1use std::time::Duration;
2
3use futures::stream::BoxStream;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct ModelId(pub String);
12
13impl ModelId {
14 pub fn new(id: impl Into<String>) -> Self {
15 Self(id.into())
16 }
17
18 pub fn as_str(&self) -> &str {
19 &self.0
20 }
21}
22
23impl Default for ModelId {
24 fn default() -> Self {
25 Self("default".into())
26 }
27}
28
29impl std::fmt::Display for ModelId {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.write_str(&self.0)
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Message {
41 pub role: Role,
42 pub content: Vec<ContentBlock>,
43}
44
45impl Message {
46 pub fn new(role: Role, content: Vec<ContentBlock>) -> Self {
48 Self { role, content }
49 }
50
51 pub fn user(content: impl Into<String>) -> Self {
52 Self {
53 role: Role::User,
54 content: vec![ContentBlock::text(content)],
55 }
56 }
57
58 pub fn assistant(content: impl Into<String>) -> Self {
59 Self {
60 role: Role::Assistant,
61 content: vec![ContentBlock::text(content)],
62 }
63 }
64
65 pub fn system(content: impl Into<String>) -> Self {
66 Self {
67 role: Role::System,
68 content: vec![ContentBlock::text(content)],
69 }
70 }
71
72 pub fn tool_result(
75 tool_use_id: impl Into<String>,
76 content: impl Into<String>,
77 is_error: bool,
78 ) -> Self {
79 Self {
80 role: Role::User,
81 content: vec![ContentBlock::ToolResult {
82 tool_use_id: tool_use_id.into(),
83 content: content.into(),
84 is_error,
85 }],
86 }
87 }
88
89 pub fn text(&self) -> String {
93 self.content
94 .iter()
95 .filter_map(|b| match b {
96 ContentBlock::Text { text } => Some(text.as_str()),
97 _ => None,
98 })
99 .collect::<Vec<_>>()
100 .join("")
101 }
102
103 pub fn tool_uses(&self) -> impl Iterator<Item = &ContentBlock> {
105 self.content
106 .iter()
107 .filter(|b| matches!(b, ContentBlock::ToolUse { .. }))
108 }
109}
110
111#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
119#[serde(tag = "type", rename_all = "snake_case")]
120pub enum ContentBlock {
121 Text {
122 text: String,
123 },
124 ToolUse {
126 id: String,
127 name: String,
128 input: serde_json::Value,
129 },
130 ToolResult {
132 tool_use_id: String,
133 content: String,
134 #[serde(default)]
135 is_error: bool,
136 },
137}
138
139impl ContentBlock {
140 pub fn text(text: impl Into<String>) -> Self {
141 Self::Text { text: text.into() }
142 }
143}
144
145#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
146#[serde(rename_all = "lowercase")]
147pub enum Role {
148 System,
149 User,
150 Assistant,
151}
152
153#[derive(Debug, Clone, Default)]
158pub struct Request {
159 pub messages: Vec<Message>,
160 pub model: ModelId,
161 pub max_tokens: Option<u32>,
162 pub temperature: Option<f32>,
163 pub system: Option<String>,
164 pub stop: Vec<String>,
165 pub tools: Vec<ToolDefinition>,
168}
169
170#[derive(Debug, Clone)]
171pub struct Response {
172 pub content: String,
174 pub tool_calls: Vec<ToolUse>,
177 pub usage: Usage,
178 pub model: ModelId,
179 pub finish_reason: FinishReason,
180 pub latency: Duration,
181 pub raw: serde_json::Value,
182}
183
184#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
191pub struct ToolDefinition {
192 pub name: String,
193 pub description: String,
194 pub input_schema: serde_json::Value,
195}
196
197impl ToolDefinition {
198 pub fn new(
199 name: impl Into<String>,
200 description: impl Into<String>,
201 input_schema: serde_json::Value,
202 ) -> Self {
203 Self {
204 name: name.into(),
205 description: description.into(),
206 input_schema,
207 }
208 }
209}
210
211#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
213pub struct ToolUse {
214 pub id: String,
215 pub name: String,
216 pub input: serde_json::Value,
217}
218
/// Token accounting reported for a call.
///
/// It is just two integers, so it derives `Copy` (pass it by value freely) and
/// `PartialEq`/`Eq` (compare whole values), like the other small value types in
/// this module. `Default` is zero for both counts.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Usage {
    /// Tokens consumed by the input.
    pub input_tokens: u32,
    /// Tokens produced as output.
    pub output_tokens: u32,
}
224
/// Why generation ended. Defaults to [`FinishReason::Stop`].
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub enum FinishReason {
    /// Normal completion; this is the default.
    #[default]
    Stop,
    /// Presumably the output-token limit was reached.
    MaxTokens,
    /// Presumably output was withheld by a content filter.
    ContentFilter,
    /// Presumably the model stopped in order to request a tool call.
    ToolUse,
    /// Any other reason, carrying a free-form description.
    Other(String),
}
235
236impl std::fmt::Display for Response {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 writeln!(f, "[{}] ({:.0?})", self.model, self.latency)?;
239 writeln!(f, "{}", self.content)?;
240 write!(
241 f,
242 "tokens: {} in / {} out | finish: {:?}",
243 self.usage.input_tokens, self.usage.output_tokens, self.finish_reason
244 )
245 }
246}
247
248#[derive(Debug, Clone)]
253pub enum StreamChunk {
254 Delta(String),
255 Done { usage: Option<Usage> },
256 Error(String),
257}
258
/// Boxed stream of [`StreamChunk`]s returned for a streaming call.
///
/// `'a` bounds whatever data the underlying stream borrows.
pub type StreamResponse<'a> = BoxStream<'a, StreamChunk>;
261
262#[derive(Debug, Clone)]
267pub struct EmbedRequest {
268 pub model: ModelId,
269 pub input: Vec<String>,
270}
271
272#[derive(Debug, Clone)]
273pub struct Embedding {
274 pub vectors: Vec<Vec<f32>>,
275 pub model: ModelId,
276 pub usage: Usage,
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    // ---- ModelId ----

    #[test]
    fn model_id_new_and_as_str() {
        let m = ModelId::new("gpt-4");
        assert_eq!(m.as_str(), "gpt-4");
    }

    #[test]
    fn model_id_default() {
        assert_eq!(ModelId::default().as_str(), "default");
    }

    #[test]
    fn model_id_display() {
        let m = ModelId::new("claude-3");
        assert_eq!(format!("{m}"), "claude-3");
    }

    #[test]
    fn model_id_eq_and_hash() {
        let a = ModelId::new("x");
        let b = ModelId::new("x");
        let c = ModelId::new("y");
        assert_eq!(a, b);
        assert_ne!(a, c);

        let mut set = HashSet::new();
        set.insert(a);
        set.insert(b);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn model_id_serde_roundtrip() {
        let m = ModelId::new("llama3.2");
        let json = serde_json::to_string(&m).unwrap();
        let back: ModelId = serde_json::from_str(&json).unwrap();
        assert_eq!(m, back);
    }

    // ---- Message / ContentBlock / Role ----

    #[test]
    fn message_user() {
        let m = Message::user("hi");
        assert_eq!(m.role, Role::User);
        assert_eq!(m.text(), "hi");
    }

    #[test]
    fn message_assistant() {
        let m = Message::assistant("ok");
        assert_eq!(m.role, Role::Assistant);
        assert_eq!(m.text(), "ok");
    }

    #[test]
    fn message_system() {
        let m = Message::system("you are helpful");
        assert_eq!(m.role, Role::System);
        assert_eq!(m.text(), "you are helpful");
    }

    #[test]
    fn message_text_concatenates_and_ignores_non_text() {
        let m = Message::new(
            Role::Assistant,
            vec![
                ContentBlock::text("a"),
                ContentBlock::ToolUse {
                    id: "t1".into(),
                    name: "x".into(),
                    input: serde_json::json!({}),
                },
                ContentBlock::text("b"),
            ],
        );
        assert_eq!(m.text(), "ab");
        assert_eq!(m.tool_uses().count(), 1);
    }

    #[test]
    fn message_tool_result_helper() {
        let m = Message::tool_result("tu_1", "result body", false);
        assert_eq!(m.role, Role::User);
        match &m.content[0] {
            ContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
            } => {
                assert_eq!(tool_use_id, "tu_1");
                assert_eq!(content, "result body");
                assert!(!is_error);
            }
            other => panic!("expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn content_block_serde_roundtrip() {
        for block in [
            ContentBlock::text("hi"),
            ContentBlock::ToolUse {
                id: "id".into(),
                name: "search".into(),
                input: serde_json::json!({"q": "x"}),
            },
            ContentBlock::ToolResult {
                tool_use_id: "id".into(),
                content: "ok".into(),
                is_error: false,
            },
        ] {
            let json = serde_json::to_string(&block).unwrap();
            let back: ContentBlock = serde_json::from_str(&json).unwrap();
            assert_eq!(block, back);
        }
    }

    #[test]
    fn content_block_wire_format_uses_type_tag() {
        // `#[serde(tag = "type", rename_all = "snake_case")]`
        let json = serde_json::to_string(&ContentBlock::text("hi")).unwrap();
        assert_eq!(json, r#"{"type":"text","text":"hi"}"#);
    }

    #[test]
    fn tool_result_is_error_defaults_to_false() {
        // `#[serde(default)]` on `is_error`: absent field deserializes as false.
        let block: ContentBlock =
            serde_json::from_str(r#"{"type":"tool_result","tool_use_id":"x","content":"y"}"#)
                .unwrap();
        assert_eq!(
            block,
            ContentBlock::ToolResult {
                tool_use_id: "x".into(),
                content: "y".into(),
                is_error: false,
            }
        );
    }

    #[test]
    fn role_serde_roundtrip() {
        for (role, expected) in [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
        ] {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, expected);
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    // ---- Request / Response ----

    #[test]
    fn request_default() {
        let r = Request::default();
        assert!(r.messages.is_empty());
        assert_eq!(r.model, ModelId::default());
        assert!(r.max_tokens.is_none());
        assert!(r.temperature.is_none());
        assert!(r.system.is_none());
        assert!(r.stop.is_empty());
        assert!(r.tools.is_empty());
    }

    #[test]
    fn response_display() {
        let resp = Response {
            content: "Hello!".into(),
            tool_calls: vec![],
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
            },
            model: ModelId::new("test-model"),
            finish_reason: FinishReason::Stop,
            latency: Duration::from_millis(1234),
            raw: serde_json::Value::Null,
        };
        // Exact match: the old `contains("1")` check was vacuous ("10 in"
        // already contains a '1') and never verified the latency rendering.
        // `{:.0?}` on 1234ms renders as "1s".
        assert_eq!(
            format!("{resp}"),
            "[test-model] (1s)\nHello!\ntokens: 10 in / 5 out | finish: Stop"
        );
    }

    // ---- FinishReason / Usage ----

    #[test]
    fn finish_reason_default_is_stop() {
        assert_eq!(FinishReason::default(), FinishReason::Stop);
    }

    #[test]
    fn finish_reason_variants() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
        assert_ne!(FinishReason::MaxTokens, FinishReason::ContentFilter);
        assert_ne!(FinishReason::ToolUse, FinishReason::Stop);
        let other = FinishReason::Other("custom".into());
        assert_eq!(other, FinishReason::Other("custom".into()));
        assert_ne!(other, FinishReason::Other("different".into()));
    }

    #[test]
    fn usage_default() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
    }

    // ---- Streaming ----

    #[test]
    fn stream_chunk_debug() {
        assert_eq!(
            format!("{:?}", StreamChunk::Delta("hi".into())),
            r#"Delta("hi")"#
        );
        assert_eq!(
            format!("{:?}", StreamChunk::Done { usage: None }),
            "Done { usage: None }"
        );
        assert_eq!(
            format!(
                "{:?}",
                StreamChunk::Done {
                    usage: Some(Usage::default())
                }
            ),
            "Done { usage: Some(Usage { input_tokens: 0, output_tokens: 0 }) }"
        );
        assert_eq!(
            format!("{:?}", StreamChunk::Error("err".into())),
            r#"Error("err")"#
        );
    }

    // ---- Embeddings ----

    #[test]
    fn embed_request_construction() {
        let r = EmbedRequest {
            model: ModelId::new("nomic"),
            input: vec!["hello".into(), "world".into()],
        };
        assert_eq!(r.model.as_str(), "nomic");
        assert_eq!(r.input.len(), 2);
    }

    #[test]
    fn embedding_construction() {
        let e = Embedding {
            vectors: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
            model: ModelId::new("nomic"),
            usage: Usage {
                input_tokens: 4,
                output_tokens: 0,
            },
        };
        assert_eq!(e.vectors.len(), 2);
        assert_eq!(e.vectors[0].len(), 2);
        assert_eq!(e.model.as_str(), "nomic");
    }
}