Skip to main content

autoagents_llm/
protocol.rs

1use crate::chat::{CompletionTokensDetails, ImageMime, PromptTokensDetails, StreamChunk, Usage};
2use crate::{FunctionCall, ToolCall};
3use autoagents_protocol as protocol;
4
5impl From<protocol::ImageMime> for ImageMime {
6    fn from(value: protocol::ImageMime) -> Self {
7        match value {
8            protocol::ImageMime::JPEG => ImageMime::JPEG,
9            protocol::ImageMime::PNG => ImageMime::PNG,
10            protocol::ImageMime::GIF => ImageMime::GIF,
11            protocol::ImageMime::WEBP => ImageMime::WEBP,
12            _ => ImageMime::PNG,
13        }
14    }
15}
16
17impl From<ImageMime> for protocol::ImageMime {
18    fn from(value: ImageMime) -> Self {
19        match value {
20            ImageMime::JPEG => protocol::ImageMime::JPEG,
21            ImageMime::PNG => protocol::ImageMime::PNG,
22            ImageMime::GIF => protocol::ImageMime::GIF,
23            ImageMime::WEBP => protocol::ImageMime::WEBP,
24        }
25    }
26}
27
28impl From<protocol::FunctionCall> for FunctionCall {
29    fn from(value: protocol::FunctionCall) -> Self {
30        Self {
31            name: value.name,
32            arguments: value.arguments,
33        }
34    }
35}
36
37impl From<FunctionCall> for protocol::FunctionCall {
38    fn from(value: FunctionCall) -> Self {
39        Self {
40            name: value.name,
41            arguments: value.arguments,
42        }
43    }
44}
45
46impl From<protocol::ToolCall> for ToolCall {
47    fn from(value: protocol::ToolCall) -> Self {
48        Self {
49            id: value.id,
50            call_type: value.call_type,
51            function: value.function.into(),
52        }
53    }
54}
55
56impl From<ToolCall> for protocol::ToolCall {
57    fn from(value: ToolCall) -> Self {
58        Self {
59            id: value.id,
60            call_type: value.call_type,
61            function: value.function.into(),
62        }
63    }
64}
65
66impl From<protocol::CompletionTokensDetails> for CompletionTokensDetails {
67    fn from(value: protocol::CompletionTokensDetails) -> Self {
68        Self {
69            reasoning_tokens: value.reasoning_tokens,
70            audio_tokens: value.audio_tokens,
71        }
72    }
73}
74
75impl From<CompletionTokensDetails> for protocol::CompletionTokensDetails {
76    fn from(value: CompletionTokensDetails) -> Self {
77        Self {
78            reasoning_tokens: value.reasoning_tokens,
79            audio_tokens: value.audio_tokens,
80        }
81    }
82}
83
84impl From<protocol::PromptTokensDetails> for PromptTokensDetails {
85    fn from(value: protocol::PromptTokensDetails) -> Self {
86        Self {
87            cached_tokens: value.cached_tokens,
88            audio_tokens: value.audio_tokens,
89        }
90    }
91}
92
93impl From<PromptTokensDetails> for protocol::PromptTokensDetails {
94    fn from(value: PromptTokensDetails) -> Self {
95        Self {
96            cached_tokens: value.cached_tokens,
97            audio_tokens: value.audio_tokens,
98        }
99    }
100}
101
102impl From<protocol::Usage> for Usage {
103    fn from(value: protocol::Usage) -> Self {
104        Self {
105            prompt_tokens: value.prompt_tokens,
106            completion_tokens: value.completion_tokens,
107            total_tokens: value.total_tokens,
108            completion_tokens_details: value
109                .completion_tokens_details
110                .map(CompletionTokensDetails::from),
111            prompt_tokens_details: value.prompt_tokens_details.map(PromptTokensDetails::from),
112        }
113    }
114}
115
116impl From<Usage> for protocol::Usage {
117    fn from(value: Usage) -> Self {
118        Self {
119            prompt_tokens: value.prompt_tokens,
120            completion_tokens: value.completion_tokens,
121            total_tokens: value.total_tokens,
122            completion_tokens_details: value
123                .completion_tokens_details
124                .map(protocol::CompletionTokensDetails::from),
125            prompt_tokens_details: value
126                .prompt_tokens_details
127                .map(protocol::PromptTokensDetails::from),
128        }
129    }
130}
131
132impl From<protocol::StreamChunk> for StreamChunk {
133    fn from(value: protocol::StreamChunk) -> Self {
134        match value {
135            protocol::StreamChunk::Text(text) => StreamChunk::Text(text),
136            protocol::StreamChunk::ReasoningContent(text) => StreamChunk::ReasoningContent(text),
137            protocol::StreamChunk::ToolUseStart { index, id, name } => {
138                StreamChunk::ToolUseStart { index, id, name }
139            }
140            protocol::StreamChunk::ToolUseInputDelta {
141                index,
142                partial_json,
143            } => StreamChunk::ToolUseInputDelta {
144                index,
145                partial_json,
146            },
147            protocol::StreamChunk::ToolUseComplete { index, tool_call } => {
148                StreamChunk::ToolUseComplete {
149                    index,
150                    tool_call: tool_call.into(),
151                }
152            }
153            protocol::StreamChunk::Done { stop_reason } => StreamChunk::Done { stop_reason },
154            protocol::StreamChunk::Usage(usage) => StreamChunk::Usage(usage.into()),
155        }
156    }
157}
158
159impl From<StreamChunk> for protocol::StreamChunk {
160    fn from(value: StreamChunk) -> Self {
161        match value {
162            StreamChunk::Text(text) => protocol::StreamChunk::Text(text),
163            StreamChunk::ReasoningContent(text) => protocol::StreamChunk::ReasoningContent(text),
164            StreamChunk::ToolUseStart { index, id, name } => {
165                protocol::StreamChunk::ToolUseStart { index, id, name }
166            }
167            StreamChunk::ToolUseInputDelta {
168                index,
169                partial_json,
170            } => protocol::StreamChunk::ToolUseInputDelta {
171                index,
172                partial_json,
173            },
174            StreamChunk::ToolUseComplete { index, tool_call } => {
175                protocol::StreamChunk::ToolUseComplete {
176                    index,
177                    tool_call: tool_call.into(),
178                }
179            }
180            StreamChunk::Done { stop_reason } => protocol::StreamChunk::Done { stop_reason },
181            StreamChunk::Usage(usage) => protocol::StreamChunk::Usage(usage.into()),
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends a chat-layer chunk through the protocol type and back again.
    fn roundtrip_chunk(chunk: StreamChunk) -> StreamChunk {
        let wire: protocol::StreamChunk = chunk.into();
        wire.into()
    }

    /// Asserts that `chunk` survives a protocol round trip with its entire
    /// payload intact, not just its variant.
    ///
    /// The comparison uses `Debug` output so these tests do not depend on
    /// `StreamChunk` implementing `PartialEq`.
    fn assert_chunk_roundtrip(chunk: StreamChunk) {
        let expected = format!("{chunk:?}");
        let back = roundtrip_chunk(chunk);
        assert_eq!(expected, format!("{back:?}"));
    }

    /// A tool call whose fields all hold distinct values, so a swapped or
    /// dropped field shows up in comparisons.
    fn sample_tool_call() -> ToolCall {
        ToolCall {
            id: "tc1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "tool".to_string(),
                arguments: r#"{"q":"test"}"#.to_string(),
            },
        }
    }

    #[test]
    fn converts_stream_chunk_roundtrip() {
        assert_chunk_roundtrip(StreamChunk::ToolUseStart {
            index: 1,
            id: "tool_1".to_string(),
            name: "search".to_string(),
        });
    }

    #[test]
    fn converts_usage_roundtrip() {
        let usage = Usage {
            prompt_tokens: 1,
            completion_tokens: 2,
            total_tokens: 3,
            completion_tokens_details: Some(CompletionTokensDetails {
                reasoning_tokens: Some(4),
                audio_tokens: None,
            }),
            prompt_tokens_details: Some(PromptTokensDetails {
                cached_tokens: Some(5),
                audio_tokens: None,
            }),
        };
        let protocol_usage: protocol::Usage = usage.clone().into();
        let roundtrip: Usage = protocol_usage.into();
        assert_eq!(usage, roundtrip);
    }

    #[test]
    fn converts_image_mime_roundtrip() {
        for mime in [
            ImageMime::JPEG,
            ImageMime::PNG,
            ImageMime::GIF,
            ImageMime::WEBP,
        ] {
            let proto: protocol::ImageMime = mime.into();
            let back: ImageMime = proto.into();
            assert_eq!(mime, back);
        }
    }

    #[test]
    fn converts_function_call_roundtrip() {
        let fc = FunctionCall {
            name: "search".to_string(),
            arguments: r#"{"q":"test"}"#.to_string(),
        };
        let proto: protocol::FunctionCall = fc.clone().into();
        let back: FunctionCall = proto.into();
        assert_eq!(back.name, fc.name);
        assert_eq!(back.arguments, fc.arguments);
    }

    #[test]
    fn converts_tool_call_roundtrip() {
        let tc = sample_tool_call();
        let proto: protocol::ToolCall = tc.clone().into();
        let back: ToolCall = proto.into();
        assert_eq!(back.id, tc.id);
        assert_eq!(back.call_type, tc.call_type);
        assert_eq!(back.function.name, tc.function.name);
        // The JSON argument payload must survive the round trip as well.
        assert_eq!(back.function.arguments, tc.function.arguments);
    }

    #[test]
    fn converts_stream_chunk_text_roundtrip() {
        assert_chunk_roundtrip(StreamChunk::Text("hello".to_string()));
    }

    #[test]
    fn converts_stream_chunk_reasoning_content_roundtrip() {
        assert_chunk_roundtrip(StreamChunk::ReasoningContent("thinking".to_string()));
    }

    #[test]
    fn converts_stream_chunk_tool_use_input_delta() {
        // Non-zero index so an index that is reset or dropped is caught.
        assert_chunk_roundtrip(StreamChunk::ToolUseInputDelta {
            index: 2,
            partial_json: r#"{"ke"#.to_string(),
        });
    }

    #[test]
    fn converts_stream_chunk_tool_use_complete() {
        // Compares the full nested tool call, not just the variant and index.
        assert_chunk_roundtrip(StreamChunk::ToolUseComplete {
            index: 3,
            tool_call: sample_tool_call(),
        });
    }

    #[test]
    fn converts_stream_chunk_done() {
        assert_chunk_roundtrip(StreamChunk::Done {
            stop_reason: "end_turn".to_string(),
        });
    }

    #[test]
    fn converts_stream_chunk_usage() {
        // Populated detail breakdowns exercise the nested conversions.
        let usage = Usage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
            completion_tokens_details: Some(CompletionTokensDetails {
                reasoning_tokens: Some(7),
                audio_tokens: None,
            }),
            prompt_tokens_details: Some(PromptTokensDetails {
                cached_tokens: Some(3),
                audio_tokens: Some(1),
            }),
        };
        match roundtrip_chunk(StreamChunk::Usage(usage.clone())) {
            StreamChunk::Usage(back) => assert_eq!(back, usage),
            other => panic!("expected StreamChunk::Usage, got {other:?}"),
        }
    }

    #[test]
    fn converts_completion_tokens_details_roundtrip() {
        let details = CompletionTokensDetails {
            reasoning_tokens: Some(10),
            audio_tokens: Some(5),
        };
        let proto: protocol::CompletionTokensDetails = details.clone().into();
        let back: CompletionTokensDetails = proto.into();
        assert_eq!(details, back);
    }

    #[test]
    fn converts_prompt_tokens_details_roundtrip() {
        let details = PromptTokensDetails {
            cached_tokens: Some(100),
            audio_tokens: None,
        };
        let proto: protocol::PromptTokensDetails = details.clone().into();
        let back: PromptTokensDetails = proto.into();
        assert_eq!(details, back);
    }
}