1use std::time::Duration;
2
3use futures::stream::BoxStream;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct ModelId(pub String);
12
13impl ModelId {
14 pub fn new(id: impl Into<String>) -> Self {
15 Self(id.into())
16 }
17
18 pub fn as_str(&self) -> &str {
19 &self.0
20 }
21}
22
23impl Default for ModelId {
24 fn default() -> Self {
25 Self("default".into())
26 }
27}
28
29impl std::fmt::Display for ModelId {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.write_str(&self.0)
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Message {
41 pub role: Role,
42 pub content: String,
43}
44
45impl Message {
46 pub fn user(content: impl Into<String>) -> Self {
47 Self {
48 role: Role::User,
49 content: content.into(),
50 }
51 }
52
53 pub fn assistant(content: impl Into<String>) -> Self {
54 Self {
55 role: Role::Assistant,
56 content: content.into(),
57 }
58 }
59
60 pub fn system(content: impl Into<String>) -> Self {
61 Self {
62 role: Role::System,
63 content: content.into(),
64 }
65 }
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
69#[serde(rename_all = "lowercase")]
70pub enum Role {
71 System,
72 User,
73 Assistant,
74}
75
76#[derive(Debug, Clone, Default)]
81pub struct Request {
82 pub messages: Vec<Message>,
83 pub model: ModelId,
84 pub max_tokens: Option<u32>,
85 pub temperature: Option<f32>,
86 pub system: Option<String>,
87 pub stop: Vec<String>,
88}
89
90#[derive(Debug, Clone)]
91pub struct Response {
92 pub content: String,
93 pub usage: Usage,
94 pub model: ModelId,
95 pub finish_reason: FinishReason,
96 pub latency: Duration,
97 pub raw: serde_json::Value,
98}
99
/// Token counters attached to responses, stream completion chunks and
/// embedding results.
///
/// `Default` yields zero for both counters. Equality is field-wise, so two
/// `Usage` values (or `Option<Usage>` values) can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    /// Count of input tokens.
    pub input_tokens: u32,
    /// Count of output tokens.
    pub output_tokens: u32,
}
105
/// Reason a response ended.
///
/// [`FinishReason::Stop`] is the `Default` variant.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum FinishReason {
    /// The default variant.
    #[default]
    Stop,
    /// Named for reaching a token limit.
    MaxTokens,
    /// Named for a content filter ending the response.
    ContentFilter,
    /// Any other reason, carried as free-form text.
    Other(String),
}
114
115impl std::fmt::Display for Response {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 writeln!(f, "[{}] ({:.0?})", self.model, self.latency)?;
118 writeln!(f, "{}", self.content)?;
119 write!(
120 f,
121 "tokens: {} in / {} out | finish: {:?}",
122 self.usage.input_tokens, self.usage.output_tokens, self.finish_reason
123 )
124 }
125}
126
127#[derive(Debug, Clone)]
132pub enum StreamChunk {
133 Delta(String),
134 Done { usage: Option<Usage> },
135 Error(String),
136}
137
138pub type StreamResponse<'a> = BoxStream<'a, StreamChunk>;
140
141#[derive(Debug, Clone)]
146pub struct EmbedRequest {
147 pub model: ModelId,
148 pub input: Vec<String>,
149}
150
151#[derive(Debug, Clone)]
152pub struct Embedding {
153 pub vectors: Vec<Vec<f32>>,
154 pub model: ModelId,
155 pub usage: Usage,
156}
157
/// Unit tests for the data types defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    // --- ModelId ---

    #[test]
    fn model_id_new_and_as_str() {
        let m = ModelId::new("gpt-4");
        assert_eq!(m.as_str(), "gpt-4");
    }

    #[test]
    fn model_id_default() {
        assert_eq!(ModelId::default().as_str(), "default");
    }

    #[test]
    fn model_id_display() {
        let m = ModelId::new("claude-3");
        assert_eq!(format!("{m}"), "claude-3");
    }

    #[test]
    fn model_id_eq_and_hash() {
        let a = ModelId::new("x");
        let b = ModelId::new("x");
        let c = ModelId::new("y");
        assert_eq!(a, b);
        assert_ne!(a, c);

        // Equal ids must collapse to a single set entry.
        let mut set = HashSet::new();
        set.insert(a);
        set.insert(b);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn model_id_serde_roundtrip() {
        let m = ModelId::new("llama3.2");
        let json = serde_json::to_string(&m).unwrap();
        let back: ModelId = serde_json::from_str(&json).unwrap();
        assert_eq!(m, back);
    }

    // --- Message ---

    #[test]
    fn message_user() {
        let m = Message::user("hi");
        assert_eq!(m.role, Role::User);
        assert_eq!(m.content, "hi");
    }

    #[test]
    fn message_assistant() {
        let m = Message::assistant("ok");
        assert_eq!(m.role, Role::Assistant);
        assert_eq!(m.content, "ok");
    }

    #[test]
    fn message_system() {
        let m = Message::system("you are helpful");
        assert_eq!(m.role, Role::System);
        assert_eq!(m.content, "you are helpful");
    }

    // --- Role ---

    #[test]
    fn role_serde_roundtrip() {
        for (role, expected) in [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
        ] {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, expected);
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    // --- Request ---

    #[test]
    fn request_default() {
        let r = Request::default();
        assert!(r.messages.is_empty());
        assert_eq!(r.model, ModelId::default());
        assert!(r.max_tokens.is_none());
        assert!(r.temperature.is_none());
        assert!(r.system.is_none());
        assert!(r.stop.is_empty());
    }

    // --- Response ---

    #[test]
    fn response_display() {
        let resp = Response {
            content: "Hello!".into(),
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
            },
            model: ModelId::new("test-model"),
            finish_reason: FinishReason::Stop,
            latency: Duration::from_millis(1234),
            raw: serde_json::Value::Null,
        };
        let s = format!("{resp}");
        assert!(s.contains("test-model"));
        assert!(s.contains("Hello!"));
        assert!(s.contains("10 in"));
        assert!(s.contains("5 out"));
        assert!(s.contains("Stop"));
        // Check the latency on the header line itself: a bare `contains("1")`
        // is already satisfied by "10 in" and would not test the latency.
        assert_eq!(s.lines().next(), Some("[test-model] (1s)"));
    }

    // --- FinishReason ---

    #[test]
    fn finish_reason_default_is_stop() {
        assert_eq!(FinishReason::default(), FinishReason::Stop);
    }

    #[test]
    fn finish_reason_variants() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
        assert_ne!(FinishReason::MaxTokens, FinishReason::ContentFilter);
        let other = FinishReason::Other("custom".into());
        assert_eq!(other, FinishReason::Other("custom".into()));
        assert_ne!(other, FinishReason::Other("different".into()));
    }

    // --- Usage ---

    #[test]
    fn usage_default() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
    }

    // --- StreamChunk ---

    #[test]
    fn stream_chunk_debug() {
        // Smoke test: formatting every variant must not panic.
        let _ = format!("{:?}", StreamChunk::Delta("hi".into()));
        let _ = format!("{:?}", StreamChunk::Done { usage: None });
        let _ = format!(
            "{:?}",
            StreamChunk::Done {
                usage: Some(Usage::default())
            }
        );
        let _ = format!("{:?}", StreamChunk::Error("err".into()));
    }

    // --- Embeddings ---

    #[test]
    fn embed_request_construction() {
        let r = EmbedRequest {
            model: ModelId::new("nomic"),
            input: vec!["hello".into(), "world".into()],
        };
        assert_eq!(r.model.as_str(), "nomic");
        assert_eq!(r.input.len(), 2);
    }

    #[test]
    fn embedding_construction() {
        let e = Embedding {
            vectors: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
            model: ModelId::new("nomic"),
            usage: Usage {
                input_tokens: 4,
                output_tokens: 0,
            },
        };
        assert_eq!(e.vectors.len(), 2);
        assert_eq!(e.vectors[0].len(), 2);
        assert_eq!(e.model.as_str(), "nomic");
    }
}