// alpine/types.rs

use std::time::Duration;

use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};

6// ---------------------------------------------------------------------------
7// ModelId
8// ---------------------------------------------------------------------------
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct ModelId(pub String);
12
13impl ModelId {
14    pub fn new(id: impl Into<String>) -> Self {
15        Self(id.into())
16    }
17
18    pub fn as_str(&self) -> &str {
19        &self.0
20    }
21}
22
23impl Default for ModelId {
24    fn default() -> Self {
25        Self("default".into())
26    }
27}
28
29impl std::fmt::Display for ModelId {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.write_str(&self.0)
32    }
33}
34
35// ---------------------------------------------------------------------------
36// Message
37// ---------------------------------------------------------------------------
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Message {
41    pub role: Role,
42    pub content: String,
43}
44
45impl Message {
46    pub fn user(content: impl Into<String>) -> Self {
47        Self {
48            role: Role::User,
49            content: content.into(),
50        }
51    }
52
53    pub fn assistant(content: impl Into<String>) -> Self {
54        Self {
55            role: Role::Assistant,
56            content: content.into(),
57        }
58    }
59
60    pub fn system(content: impl Into<String>) -> Self {
61        Self {
62            role: Role::System,
63            content: content.into(),
64        }
65    }
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
69#[serde(rename_all = "lowercase")]
70pub enum Role {
71    System,
72    User,
73    Assistant,
74}
75
76// ---------------------------------------------------------------------------
77// Request / Response
78// ---------------------------------------------------------------------------
79
80#[derive(Debug, Clone, Default)]
81pub struct Request {
82    pub messages: Vec<Message>,
83    pub model: ModelId,
84    pub max_tokens: Option<u32>,
85    pub temperature: Option<f32>,
86    pub system: Option<String>,
87    pub stop: Vec<String>,
88}
89
90#[derive(Debug, Clone)]
91pub struct Response {
92    pub content: String,
93    pub usage: Usage,
94    pub model: ModelId,
95    pub finish_reason: FinishReason,
96    pub latency: Duration,
97    pub raw: serde_json::Value,
98}
99
/// Token accounting for one call.
///
/// `PartialEq`/`Eq` are derived so usage values can be compared directly
/// (e.g. in tests or when aggregating results).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    pub input_tokens: u32,
    pub output_tokens: u32,
}

impl Usage {
    /// Sum of input and output tokens.
    ///
    /// Saturates at `u32::MAX` instead of overflowing (a plain `+` would panic
    /// in debug builds and wrap in release builds).
    #[must_use]
    pub fn total(&self) -> u32 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}

/// Why generation stopped. Defaults to [`FinishReason::Stop`].
///
/// `Hash` is derived alongside `Eq` so reasons can key a `HashMap`/`HashSet`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum FinishReason {
    #[default]
    Stop,
    MaxTokens,
    ContentFilter,
    /// Any reason without a dedicated variant, carried as a string.
    Other(String),
}

impl std::fmt::Display for FinishReason {
    /// Writes a short lowercase label for user-facing output.
    ///
    /// `Other` writes its payload verbatim, without the quoting/escaping that
    /// `Debug` would add.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Stop => "stop",
            Self::MaxTokens => "max_tokens",
            Self::ContentFilter => "content_filter",
            Self::Other(reason) => reason.as_str(),
        };
        f.write_str(label)
    }
}

115impl std::fmt::Display for Response {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        writeln!(f, "[{}] ({:.0?})", self.model, self.latency)?;
118        writeln!(f, "{}", self.content)?;
119        write!(
120            f,
121            "tokens: {} in / {} out | finish: {:?}",
122            self.usage.input_tokens, self.usage.output_tokens, self.finish_reason
123        )
124    }
125}
126
127// ---------------------------------------------------------------------------
128// Streaming
129// ---------------------------------------------------------------------------
130
131#[derive(Debug, Clone)]
132pub enum StreamChunk {
133    Delta(String),
134    Done { usage: Option<Usage> },
135    Error(String),
136}
137
138/// Convenience alias used throughout the crate.
139pub type StreamResponse<'a> = BoxStream<'a, StreamChunk>;
140
141// ---------------------------------------------------------------------------
142// Embeddings
143// ---------------------------------------------------------------------------
144
145#[derive(Debug, Clone)]
146pub struct EmbedRequest {
147    pub model: ModelId,
148    pub input: Vec<String>,
149}
150
151#[derive(Debug, Clone)]
152pub struct Embedding {
153    pub vectors: Vec<Vec<f32>>,
154    pub model: ModelId,
155    pub usage: Usage,
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    // -- ModelId ---------------------------------------------------------------

    #[test]
    fn model_id_new_and_as_str() {
        let m = ModelId::new("gpt-4");
        assert_eq!(m.as_str(), "gpt-4");
    }

    #[test]
    fn model_id_default() {
        assert_eq!(ModelId::default().as_str(), "default");
    }

    #[test]
    fn model_id_display() {
        let m = ModelId::new("claude-3");
        assert_eq!(format!("{m}"), "claude-3");
    }

    #[test]
    fn model_id_eq_and_hash() {
        let a = ModelId::new("x");
        let b = ModelId::new("x");
        let c = ModelId::new("y");
        assert_eq!(a, b);
        assert_ne!(a, c);

        let mut set = HashSet::new();
        set.insert(a);
        set.insert(b);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn model_id_serde_roundtrip() {
        let m = ModelId::new("llama3.2");
        let json = serde_json::to_string(&m).unwrap();
        // A single-field tuple struct serializes as its inner value: a bare JSON string.
        assert_eq!(json, "\"llama3.2\"");
        let back: ModelId = serde_json::from_str(&json).unwrap();
        assert_eq!(m, back);
    }

    // -- Message ---------------------------------------------------------------

    #[test]
    fn message_user() {
        let m = Message::user("hi");
        assert_eq!(m.role, Role::User);
        assert_eq!(m.content, "hi");
    }

    #[test]
    fn message_assistant() {
        let m = Message::assistant("ok");
        assert_eq!(m.role, Role::Assistant);
        assert_eq!(m.content, "ok");
    }

    #[test]
    fn message_system() {
        let m = Message::system("you are helpful");
        assert_eq!(m.role, Role::System);
        assert_eq!(m.content, "you are helpful");
    }

    #[test]
    fn message_serde_shape() {
        // Pin the wire shape: an object with a lowercase `role` and a `content` string.
        let v = serde_json::to_value(Message::user("hi")).unwrap();
        assert_eq!(v, serde_json::json!({ "role": "user", "content": "hi" }));
        let back: Message = serde_json::from_value(v).unwrap();
        assert_eq!(back.role, Role::User);
        assert_eq!(back.content, "hi");
    }

    // -- Role ------------------------------------------------------------------

    #[test]
    fn role_serde_roundtrip() {
        for (role, expected) in [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
        ] {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, expected);
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    // -- Request ---------------------------------------------------------------

    #[test]
    fn request_default() {
        let r = Request::default();
        assert!(r.messages.is_empty());
        assert_eq!(r.model, ModelId::default());
        assert!(r.max_tokens.is_none());
        assert!(r.temperature.is_none());
        assert!(r.system.is_none());
        assert!(r.stop.is_empty());
    }

    // -- Response Display ------------------------------------------------------

    #[test]
    fn response_display() {
        let resp = Response {
            content: "Hello!".into(),
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
            },
            model: ModelId::new("test-model"),
            finish_reason: FinishReason::Stop,
            latency: Duration::from_millis(1234),
            raw: serde_json::Value::Null,
        };
        let s = format!("{resp}");
        assert!(s.contains("test-model"));
        assert!(s.contains("Hello!"));
        assert!(s.contains("10 in"));
        assert!(s.contains("5 out"));
        assert!(s.contains("Stop"));
        // `{:.0?}` prints the Duration with zero fractional digits, so 1234 ms
        // renders as "1s" (not "1.234s" or "1234ms").
        assert_eq!(s.lines().next(), Some("[test-model] (1s)"));
        assert_eq!(s.lines().count(), 3);
    }

    // -- FinishReason ----------------------------------------------------------

    #[test]
    fn finish_reason_default_is_stop() {
        assert_eq!(FinishReason::default(), FinishReason::Stop);
    }

    #[test]
    fn finish_reason_variants() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
        assert_ne!(FinishReason::MaxTokens, FinishReason::ContentFilter);
        let other = FinishReason::Other("custom".into());
        assert_eq!(other, FinishReason::Other("custom".into()));
        assert_ne!(other, FinishReason::Other("different".into()));
    }

    // -- Usage -----------------------------------------------------------------

    #[test]
    fn usage_default() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
    }

    // -- StreamChunk -----------------------------------------------------------

    #[test]
    fn stream_chunk_debug() {
        let delta = format!("{:?}", StreamChunk::Delta("hi".into()));
        assert!(delta.contains("Delta"));
        assert!(delta.contains("\"hi\""));

        let done = format!("{:?}", StreamChunk::Done { usage: None });
        assert!(done.contains("Done"));
        assert!(done.contains("None"));

        let done_with_usage = format!(
            "{:?}",
            StreamChunk::Done {
                usage: Some(Usage::default())
            }
        );
        assert!(done_with_usage.contains("input_tokens: 0"));
        assert!(done_with_usage.contains("output_tokens: 0"));

        let err = format!("{:?}", StreamChunk::Error("err".into()));
        assert!(err.contains("Error"));
        assert!(err.contains("\"err\""));
    }

    // -- EmbedRequest / Embedding ----------------------------------------------

    #[test]
    fn embed_request_construction() {
        let r = EmbedRequest {
            model: ModelId::new("nomic"),
            input: vec!["hello".into(), "world".into()],
        };
        assert_eq!(r.model.as_str(), "nomic");
        assert_eq!(r.input.len(), 2);
    }

    #[test]
    fn embedding_construction() {
        let e = Embedding {
            vectors: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
            model: ModelId::new("nomic"),
            usage: Usage {
                input_tokens: 4,
                output_tokens: 0,
            },
        };
        assert_eq!(e.vectors.len(), 2);
        assert_eq!(e.vectors[0].len(), 2);
        assert_eq!(e.model.as_str(), "nomic");
    }
}