Skip to main content

alpine/
types.rs

1use std::time::Duration;
2
3use futures::stream::BoxStream;
4use serde::{Deserialize, Serialize};
5
6// ---------------------------------------------------------------------------
7// ModelId
8// ---------------------------------------------------------------------------
9
/// Identifier of a model, stored as an owned `String`.
///
/// A newtype so that model names are not confused with arbitrary strings;
/// the inner value is public and can also be read through `as_str`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelId(pub String);
12
13impl ModelId {
14    pub fn new(id: impl Into<String>) -> Self {
15        Self(id.into())
16    }
17
18    pub fn as_str(&self) -> &str {
19        &self.0
20    }
21}
22
23impl Default for ModelId {
24    fn default() -> Self {
25        Self("default".into())
26    }
27}
28
29impl std::fmt::Display for ModelId {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.write_str(&self.0)
32    }
33}
34
35// ---------------------------------------------------------------------------
36// Message
37// ---------------------------------------------------------------------------
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Message {
41    pub role: Role,
42    pub content: Vec<ContentBlock>,
43}
44
45impl Message {
46    /// Construct a message from a role and explicit content blocks.
47    pub fn new(role: Role, content: Vec<ContentBlock>) -> Self {
48        Self { role, content }
49    }
50
51    pub fn user(content: impl Into<String>) -> Self {
52        Self {
53            role: Role::User,
54            content: vec![ContentBlock::text(content)],
55        }
56    }
57
58    pub fn assistant(content: impl Into<String>) -> Self {
59        Self {
60            role: Role::Assistant,
61            content: vec![ContentBlock::text(content)],
62        }
63    }
64
65    pub fn system(content: impl Into<String>) -> Self {
66        Self {
67            role: Role::System,
68            content: vec![ContentBlock::text(content)],
69        }
70    }
71
72    /// A user message carrying a single tool result (the conventional way to
73    /// return tool output to the model).
74    pub fn tool_result(
75        tool_use_id: impl Into<String>,
76        content: impl Into<String>,
77        is_error: bool,
78    ) -> Self {
79        Self {
80            role: Role::User,
81            content: vec![ContentBlock::ToolResult {
82                tool_use_id: tool_use_id.into(),
83                content: content.into(),
84                is_error,
85            }],
86        }
87    }
88
89    /// Concatenate all text blocks into a single string. Non-text blocks
90    /// (tool_use / tool_result) are ignored. This is the convenience accessor
91    /// for callers that only care about textual content.
92    pub fn text(&self) -> String {
93        self.content
94            .iter()
95            .filter_map(|b| match b {
96                ContentBlock::Text { text } => Some(text.as_str()),
97                _ => None,
98            })
99            .collect::<Vec<_>>()
100            .join("")
101    }
102
103    /// All `tool_use` blocks in this message.
104    pub fn tool_uses(&self) -> impl Iterator<Item = &ContentBlock> {
105        self.content
106            .iter()
107            .filter(|b| matches!(b, ContentBlock::ToolUse { .. }))
108    }
109}
110
111// ---------------------------------------------------------------------------
112// Content blocks
113// ---------------------------------------------------------------------------
114
/// A single piece of a message's content. Messages are sequences of blocks so
/// that a turn can mix text with tool-use requests (assistant) and tool
/// results (user) — the shape required for provider tool/function calling.
///
/// Serde representation: internally tagged with a `type` field, variant
/// names in snake_case.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text {
        /// The text payload.
        text: String,
    },
    /// The model is requesting a tool invocation.
    ToolUse {
        /// Identifier of this tool call (presumably what a later
        /// `ToolResult::tool_use_id` refers back to — confirm with callers).
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Arguments for the call, as arbitrary JSON.
        input: serde_json::Value,
    },
    /// The caller is returning the result of a tool invocation.
    ToolResult {
        /// Identifier of the tool call this result answers.
        tool_use_id: String,
        /// The tool's output as text.
        content: String,
        /// Whether the invocation failed. Falls back to `bool::default()`
        /// (`false`) when the field is absent during deserialization.
        #[serde(default)]
        is_error: bool,
    },
}
138
139impl ContentBlock {
140    pub fn text(text: impl Into<String>) -> Self {
141        Self::Text { text: text.into() }
142    }
143}
144
/// Author of a `Message`.
///
/// Serialized by serde as the lowercase variant name (`"system"`, `"user"`,
/// `"assistant"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System-level instructions.
    System,
    /// The end user (also used for messages that return tool results).
    User,
    /// The model.
    Assistant,
}
152
153// ---------------------------------------------------------------------------
154// Request / Response
155// ---------------------------------------------------------------------------
156
/// Parameters for a single model call.
///
/// Derives `Default`, so callers can fill in only the fields they need with
/// struct-update syntax (`Request { messages, ..Default::default() }`).
#[derive(Debug, Clone, Default)]
pub struct Request {
    /// Conversation messages, in order.
    pub messages: Vec<Message>,
    /// Model to call; `ModelId::default()` when left unset.
    pub model: ModelId,
    /// Optional token limit (presumably a cap on generated tokens — confirm
    /// against the provider implementations).
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional system text carried separately from `messages` (how it is
    /// combined with `Role::System` messages is up to the provider — TODO confirm).
    pub system: Option<String>,
    /// Stop strings (presumably stop sequences); empty by default.
    pub stop: Vec<String>,
    /// Tools the model may call. Empty (the default) means no tool calling —
    /// identical behavior to before tool support existed.
    pub tools: Vec<ToolDefinition>,
}
169
/// Result of a completed (non-streaming) model call.
#[derive(Debug, Clone)]
pub struct Response {
    /// Flattened text content (all text blocks concatenated).
    pub content: String,
    /// Any tool calls the model requested this turn. Empty when the model
    /// returned only text.
    pub tool_calls: Vec<ToolUse>,
    /// Token counts for this call.
    pub usage: Usage,
    /// Model associated with the response.
    pub model: ModelId,
    /// Why generation ended.
    pub finish_reason: FinishReason,
    /// Time taken by the call (what exactly is measured is decided by the
    /// code that builds the response — not visible here).
    pub latency: Duration,
    /// Raw JSON kept alongside the parsed fields (presumably the provider's
    /// unmodified response body — confirm with the providers).
    pub raw: serde_json::Value,
}
183
184// ---------------------------------------------------------------------------
185// Tools
186// ---------------------------------------------------------------------------
187
/// A provider-agnostic tool the model may call. `input_schema` is a JSON Schema
/// describing the tool's arguments.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON Schema for the tool's arguments.
    pub input_schema: serde_json::Value,
}
196
197impl ToolDefinition {
198    pub fn new(
199        name: impl Into<String>,
200        description: impl Into<String>,
201        input_schema: serde_json::Value,
202    ) -> Self {
203        Self {
204            name: name.into(),
205            description: description.into(),
206            input_schema,
207        }
208    }
209}
210
/// A tool call the model requested.
///
/// Carries the same three fields as `ContentBlock::ToolUse`, as a standalone
/// struct (used by `Response::tool_calls`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolUse {
    /// Identifier of this tool call.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Arguments for the call, as arbitrary JSON.
    pub input: serde_json::Value,
}
218
/// Token counts for a single call.
///
/// `Default` yields zero for both counters. `PartialEq`/`Eq` are derived so
/// usage values can be compared directly instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    /// Number of input-side tokens.
    pub input_tokens: u32,
    /// Number of output-side tokens.
    pub output_tokens: u32,
}
224
/// Why the model stopped generating.
///
/// `Default` is `Stop`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum FinishReason {
    /// Generation ended on its own (presumably a natural end or a stop
    /// sequence — confirm against the provider mappings).
    #[default]
    Stop,
    /// A token limit was hit (presumably `Request::max_tokens` — confirm).
    MaxTokens,
    /// Output was cut off by a content filter.
    ContentFilter,
    /// The model stopped because it wants to call one or more tools.
    ToolUse,
    /// Any other reason, carried as a string.
    Other(String),
}
235
236impl std::fmt::Display for Response {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        writeln!(f, "[{}] ({:.0?})", self.model, self.latency)?;
239        writeln!(f, "{}", self.content)?;
240        write!(
241            f,
242            "tokens: {} in / {} out | finish: {:?}",
243            self.usage.input_tokens, self.usage.output_tokens, self.finish_reason
244        )
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Streaming
250// ---------------------------------------------------------------------------
251
/// One item of a streamed response.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A piece of text (presumably an incremental fragment to append to what
    /// was received so far — confirm with the stream producers).
    Delta(String),
    /// The stream finished; `usage` carries token counts when available.
    Done { usage: Option<Usage> },
    /// The stream reported an error, described by the message.
    Error(String),
}
258
/// Convenience alias used throughout the crate: a boxed stream
/// (`futures::stream::BoxStream`) that yields `StreamChunk` items and is
/// valid for the lifetime `'a`.
pub type StreamResponse<'a> = BoxStream<'a, StreamChunk>;
261
262// ---------------------------------------------------------------------------
263// Embeddings
264// ---------------------------------------------------------------------------
265
/// Parameters for an embedding call.
#[derive(Debug, Clone)]
pub struct EmbedRequest {
    /// Embedding model to use.
    pub model: ModelId,
    /// Texts to embed.
    pub input: Vec<String>,
}
271
/// Result of an embedding call.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// Embedding vectors (presumably one per `EmbedRequest::input` entry, in
    /// the same order — confirm with the providers).
    pub vectors: Vec<Vec<f32>>,
    /// Model associated with the embeddings.
    pub model: ModelId,
    /// Token counts for the call.
    pub usage: Usage,
}
278
#[cfg(test)]
mod tests {
    // Unit tests for the core request/response types defined in this file.
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    // -- ModelId ---------------------------------------------------------------

    #[test]
    fn model_id_new_and_as_str() {
        let m = ModelId::new("gpt-4");
        assert_eq!(m.as_str(), "gpt-4");
    }

    #[test]
    fn model_id_default() {
        assert_eq!(ModelId::default().as_str(), "default");
    }

    #[test]
    fn model_id_display() {
        let m = ModelId::new("claude-3");
        assert_eq!(format!("{m}"), "claude-3");
    }

    #[test]
    fn model_id_eq_and_hash() {
        let a = ModelId::new("x");
        let b = ModelId::new("x");
        let c = ModelId::new("y");
        assert_eq!(a, b);
        assert_ne!(a, c);

        // Equal ids must collapse to a single set entry.
        let mut set = HashSet::new();
        set.insert(a);
        set.insert(b);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn model_id_serde_roundtrip() {
        let m = ModelId::new("llama3.2");
        let json = serde_json::to_string(&m).unwrap();
        let back: ModelId = serde_json::from_str(&json).unwrap();
        assert_eq!(m, back);
    }

    // -- Message ---------------------------------------------------------------

    #[test]
    fn message_user() {
        let m = Message::user("hi");
        assert_eq!(m.role, Role::User);
        assert_eq!(m.text(), "hi");
    }

    #[test]
    fn message_assistant() {
        let m = Message::assistant("ok");
        assert_eq!(m.role, Role::Assistant);
        assert_eq!(m.text(), "ok");
    }

    #[test]
    fn message_system() {
        let m = Message::system("you are helpful");
        assert_eq!(m.role, Role::System);
        assert_eq!(m.text(), "you are helpful");
    }

    #[test]
    fn message_text_concatenates_and_ignores_non_text() {
        let m = Message::new(
            Role::Assistant,
            vec![
                ContentBlock::text("a"),
                ContentBlock::ToolUse {
                    id: "t1".into(),
                    name: "x".into(),
                    input: serde_json::json!({}),
                },
                ContentBlock::text("b"),
            ],
        );
        assert_eq!(m.text(), "ab");
        assert_eq!(m.tool_uses().count(), 1);
    }

    #[test]
    fn message_text_is_empty_without_text_blocks() {
        assert_eq!(Message::new(Role::User, vec![]).text(), "");
        assert_eq!(Message::tool_result("tu_1", "body", false).text(), "");
    }

    #[test]
    fn message_tool_result_helper() {
        let m = Message::tool_result("tu_1", "result body", false);
        assert_eq!(m.role, Role::User);
        assert_eq!(m.content.len(), 1);
        match &m.content[0] {
            ContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
            } => {
                assert_eq!(tool_use_id, "tu_1");
                assert_eq!(content, "result body");
                assert!(!is_error);
            }
            other => panic!("expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn content_block_serde_roundtrip() {
        for block in [
            ContentBlock::text("hi"),
            ContentBlock::ToolUse {
                id: "id".into(),
                name: "search".into(),
                input: serde_json::json!({"q": "x"}),
            },
            ContentBlock::ToolResult {
                tool_use_id: "id".into(),
                content: "ok".into(),
                is_error: false,
            },
        ] {
            let json = serde_json::to_string(&block).unwrap();
            let back: ContentBlock = serde_json::from_str(&json).unwrap();
            assert_eq!(block, back);
        }
    }

    #[test]
    fn content_block_wire_format() {
        // Internally tagged with `type`, snake_case variant names.
        let v = serde_json::to_value(ContentBlock::text("hi")).unwrap();
        assert_eq!(v, serde_json::json!({"type": "text", "text": "hi"}));

        // `is_error` may be omitted on input and then defaults to false.
        let back: ContentBlock = serde_json::from_value(serde_json::json!({
            "type": "tool_result",
            "tool_use_id": "id",
            "content": "ok"
        }))
        .unwrap();
        assert_eq!(
            back,
            ContentBlock::ToolResult {
                tool_use_id: "id".into(),
                content: "ok".into(),
                is_error: false,
            }
        );
    }

    // -- Role ------------------------------------------------------------------

    #[test]
    fn role_serde_roundtrip() {
        for (role, expected) in [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
        ] {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, expected);
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    // -- Request ---------------------------------------------------------------

    #[test]
    fn request_default() {
        let r = Request::default();
        assert!(r.messages.is_empty());
        assert_eq!(r.model, ModelId::default());
        assert!(r.max_tokens.is_none());
        assert!(r.temperature.is_none());
        assert!(r.system.is_none());
        assert!(r.stop.is_empty());
        // No tools by default means tool calling is off.
        assert!(r.tools.is_empty());
    }

    // -- ToolDefinition --------------------------------------------------------

    #[test]
    fn tool_definition_new() {
        let schema = serde_json::json!({"type": "object"});
        let t = ToolDefinition::new("search", "Search the web", schema.clone());
        assert_eq!(t.name, "search");
        assert_eq!(t.description, "Search the web");
        assert_eq!(t.input_schema, schema);
    }

    // -- Response Display ------------------------------------------------------

    #[test]
    fn response_display() {
        let resp = Response {
            content: "Hello!".into(),
            tool_calls: vec![],
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
            },
            model: ModelId::new("test-model"),
            finish_reason: FinishReason::Stop,
            latency: Duration::from_millis(1234),
            raw: serde_json::Value::Null,
        };
        let s = format!("{resp}");
        assert!(s.contains("test-model"));
        assert!(s.contains("Hello!"));
        assert!(s.contains("10 in"));
        assert!(s.contains("5 out"));
        assert!(s.contains("Stop"));
        // Latency is printed with `{:.0?}`: zero fractional digits, so
        // 1.234s renders as "1s". Match it with its parentheses so the check
        // cannot be satisfied by the "10 in" token count.
        assert!(s.contains("(1s)"));
    }

    // -- FinishReason ----------------------------------------------------------

    #[test]
    fn finish_reason_default_is_stop() {
        assert_eq!(FinishReason::default(), FinishReason::Stop);
    }

    #[test]
    fn finish_reason_variants() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
        assert_ne!(FinishReason::MaxTokens, FinishReason::ContentFilter);
        assert_ne!(FinishReason::ToolUse, FinishReason::Stop);
        let other = FinishReason::Other("custom".into());
        assert_eq!(other, FinishReason::Other("custom".into()));
        assert_ne!(other, FinishReason::Other("different".into()));
    }

    // -- Usage -----------------------------------------------------------------

    #[test]
    fn usage_default() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
    }

    // -- StreamChunk -----------------------------------------------------------

    #[test]
    fn stream_chunk_debug() {
        // Smoke test: every variant must be Debug-formattable.
        let _ = format!("{:?}", StreamChunk::Delta("hi".into()));
        let _ = format!("{:?}", StreamChunk::Done { usage: None });
        let _ = format!(
            "{:?}",
            StreamChunk::Done {
                usage: Some(Usage::default())
            }
        );
        let _ = format!("{:?}", StreamChunk::Error("err".into()));
    }

    // -- EmbedRequest / Embedding ----------------------------------------------

    #[test]
    fn embed_request_construction() {
        let r = EmbedRequest {
            model: ModelId::new("nomic"),
            input: vec!["hello".into(), "world".into()],
        };
        assert_eq!(r.model.as_str(), "nomic");
        assert_eq!(r.input.len(), 2);
    }

    #[test]
    fn embedding_construction() {
        let e = Embedding {
            vectors: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
            model: ModelId::new("nomic"),
            usage: Usage {
                input_tokens: 4,
                output_tokens: 0,
            },
        };
        assert_eq!(e.vectors.len(), 2);
        assert_eq!(e.vectors[0].len(), 2);
        assert_eq!(e.model.as_str(), "nomic");
    }
}