1use rust_decimal::Decimal;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13 System,
15 User,
17 Assistant,
19}
20
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23#[serde(tag = "type", rename_all = "snake_case")]
24pub enum ImageSource {
25 Base64 {
27 data: String,
29 },
30 Url {
32 url: String,
34 },
35}
36
37#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39#[serde(tag = "type", rename_all = "snake_case")]
40pub enum ContentPart {
41 Text {
43 text: String,
45 },
46 ToolUse {
48 id: String,
50 name: String,
52 input: serde_json::Value,
54 },
55 ToolResult {
57 tool_use_id: String,
59 content: String,
61 is_error: bool,
63 },
64 Image {
66 source: ImageSource,
68 media_type: String,
70 },
71}
72
73#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
75pub struct ProviderMessage {
76 pub role: Role,
78 pub content: Vec<ContentPart>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ToolSchema {
85 pub name: String,
87 pub description: String,
89 pub input_schema: serde_json::Value,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ProviderRequest {
96 pub model: Option<String>,
98 pub messages: Vec<ProviderMessage>,
100 pub tools: Vec<ToolSchema>,
102 pub max_tokens: Option<u32>,
104 pub temperature: Option<f64>,
106 pub system: Option<String>,
108 #[serde(default)]
110 pub extra: serde_json::Value,
111}
112
113#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
115#[serde(rename_all = "snake_case")]
116pub enum StopReason {
117 EndTurn,
119 ToolUse,
121 MaxTokens,
123 ContentFilter,
125}
126
127#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
129pub struct TokenUsage {
130 pub input_tokens: u64,
132 pub output_tokens: u64,
134 pub cache_read_tokens: Option<u64>,
136 pub cache_creation_tokens: Option<u64>,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct ProviderResponse {
143 pub content: Vec<ContentPart>,
145 pub stop_reason: StopReason,
147 pub usage: TokenUsage,
149 pub model: String,
151 pub cost: Option<Decimal>,
153 pub truncated: Option<bool>,
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn role_serde_roundtrip() {
        for role in [Role::System, Role::User, Role::Assistant] {
            let json = serde_json::to_string(&role).unwrap();
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(role, back);
        }
    }

    /// Pins the wire names so a serde attribute change is caught.
    #[test]
    fn role_wire_names_are_lowercase() {
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            "\"assistant\""
        );
    }

    #[test]
    fn content_part_text_roundtrip() {
        let part = ContentPart::Text {
            text: "hello".into(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "text");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn content_part_tool_use_roundtrip() {
        let part = ContentPart::ToolUse {
            id: "tu_1".into(),
            name: "bash".into(),
            input: json!({"command": "ls"}),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "tool_use");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn content_part_tool_result_roundtrip() {
        let part = ContentPart::ToolResult {
            tool_use_id: "tu_1".into(),
            content: "file.txt".into(),
            is_error: false,
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "tool_result");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn content_part_image_roundtrip() {
        let part = ContentPart::Image {
            source: ImageSource::Url {
                url: "https://example.com/img.png".into(),
            },
            media_type: "image/png".into(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "image");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn stop_reason_roundtrip() {
        for reason in [
            StopReason::EndTurn,
            StopReason::ToolUse,
            StopReason::MaxTokens,
            StopReason::ContentFilter,
        ] {
            let json = serde_json::to_string(&reason).unwrap();
            let back: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(reason, back);
        }
    }

    /// Pins the wire names so a serde attribute change is caught.
    #[test]
    fn stop_reason_wire_names_are_snake_case() {
        let cases = [
            (StopReason::EndTurn, "\"end_turn\""),
            (StopReason::ToolUse, "\"tool_use\""),
            (StopReason::MaxTokens, "\"max_tokens\""),
            (StopReason::ContentFilter, "\"content_filter\""),
        ];
        for (reason, expected) in cases {
            assert_eq!(serde_json::to_string(&reason).unwrap(), expected);
        }
    }

    #[test]
    fn provider_message_roundtrip() {
        let msg = ProviderMessage {
            role: Role::User,
            content: vec![ContentPart::Text {
                text: "hello".into(),
            }],
        };
        let json = serde_json::to_value(&msg).unwrap();
        let back: ProviderMessage = serde_json::from_value(json).unwrap();
        assert_eq!(msg, back);
    }

    #[test]
    fn token_usage_default() {
        let usage = TokenUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert!(usage.cache_read_tokens.is_none());
        assert!(usage.cache_creation_tokens.is_none());
    }

    #[test]
    fn token_usage_serde_roundtrip() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: Some(10),
            cache_creation_tokens: Some(5),
        };
        let json = serde_json::to_value(&usage).unwrap();
        let back: TokenUsage = serde_json::from_value(json).unwrap();
        assert_eq!(usage, back);
    }

    #[test]
    fn image_source_base64_roundtrip() {
        let source = ImageSource::Base64 {
            data: "aGVsbG8=".into(),
        };
        let json = serde_json::to_value(&source).unwrap();
        assert_eq!(json["type"], "base64");
        let back: ImageSource = serde_json::from_value(json).unwrap();
        assert_eq!(source, back);
    }

    #[test]
    fn image_source_url_roundtrip() {
        let source = ImageSource::Url {
            url: "https://example.com/img.png".into(),
        };
        let json = serde_json::to_value(&source).unwrap();
        assert_eq!(json["type"], "url");
        let back: ImageSource = serde_json::from_value(json).unwrap();
        assert_eq!(source, back);
    }

    #[test]
    fn provider_request_serde_roundtrip() {
        let request = ProviderRequest {
            model: Some("test-model".into()),
            messages: vec![ProviderMessage {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "hello".into(),
                }],
            }],
            tools: vec![ToolSchema {
                name: "bash".into(),
                description: "Run a command".into(),
                input_schema: json!({"type": "object"}),
            }],
            max_tokens: Some(1024),
            temperature: Some(0.7),
            system: Some("Be helpful".into()),
            extra: json!({"key": "value"}),
        };
        let json = serde_json::to_value(&request).unwrap();
        let back: ProviderRequest = serde_json::from_value(json).unwrap();
        assert_eq!(back.model, Some("test-model".into()));
        assert_eq!(back.messages, request.messages);
        assert_eq!(back.tools.len(), 1);
        assert_eq!(back.tools[0].name, "bash");
        assert_eq!(back.tools[0].description, "Run a command");
        assert_eq!(back.tools[0].input_schema, json!({"type": "object"}));
        assert_eq!(back.max_tokens, Some(1024));
        assert_eq!(back.temperature, Some(0.7));
        assert_eq!(back.system, Some("Be helpful".into()));
        assert_eq!(back.extra, json!({"key": "value"}));
    }

    /// Optional fields may be omitted entirely; `extra` falls back to JSON null.
    #[test]
    fn provider_request_missing_optional_fields() {
        let back: ProviderRequest =
            serde_json::from_value(json!({"messages": [], "tools": []})).unwrap();
        assert!(back.model.is_none());
        assert!(back.messages.is_empty());
        assert!(back.tools.is_empty());
        assert!(back.max_tokens.is_none());
        assert!(back.temperature.is_none());
        assert!(back.system.is_none());
        assert_eq!(back.extra, serde_json::Value::Null);
    }

    #[test]
    fn provider_response_serde_roundtrip() {
        let response = ProviderResponse {
            content: vec![ContentPart::Text {
                text: "hello".into(),
            }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: None,
                cache_creation_tokens: None,
            },
            model: "test-model".into(),
            cost: Some(Decimal::new(1, 4)),
            truncated: None,
        };
        let json = serde_json::to_value(&response).unwrap();
        let back: ProviderResponse = serde_json::from_value(json).unwrap();
        assert_eq!(back.model, "test-model");
        assert_eq!(back.stop_reason, StopReason::EndTurn);
        assert_eq!(back.content, response.content);
        assert_eq!(back.usage, response.usage);
        assert_eq!(back.cost, Some(Decimal::new(1, 4)));
        assert!(back.truncated.is_none());
    }

    #[test]
    fn content_part_image_base64_roundtrip() {
        let part = ContentPart::Image {
            source: ImageSource::Base64 {
                data: "aGVsbG8=".into(),
            },
            media_type: "image/jpeg".into(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "image");
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }

    #[test]
    fn provider_message_multi_content_roundtrip() {
        let msg = ProviderMessage {
            role: Role::Assistant,
            content: vec![
                ContentPart::Text {
                    text: "Let me help.".into(),
                },
                ContentPart::ToolUse {
                    id: "tu_1".into(),
                    name: "bash".into(),
                    input: json!({"cmd": "ls"}),
                },
            ],
        };
        let json = serde_json::to_value(&msg).unwrap();
        let back: ProviderMessage = serde_json::from_value(json).unwrap();
        assert_eq!(msg, back);
    }

    #[test]
    fn tool_result_with_error_roundtrip() {
        let part = ContentPart::ToolResult {
            tool_use_id: "tu_1".into(),
            content: "command failed".into(),
            is_error: true,
        };
        let json = serde_json::to_value(&part).unwrap();
        let back: ContentPart = serde_json::from_value(json).unwrap();
        assert_eq!(part, back);
    }
}