Skip to main content

autoagents_llm/
protocol.rs

1use crate::chat::{CompletionTokensDetails, ImageMime, PromptTokensDetails, StreamChunk, Usage};
2use crate::{FunctionCall, ToolCall};
3use autoagents_protocol as protocol;
4
5impl From<protocol::ImageMime> for ImageMime {
6    fn from(value: protocol::ImageMime) -> Self {
7        match value {
8            protocol::ImageMime::JPEG => ImageMime::JPEG,
9            protocol::ImageMime::PNG => ImageMime::PNG,
10            protocol::ImageMime::GIF => ImageMime::GIF,
11            protocol::ImageMime::WEBP => ImageMime::WEBP,
12            _ => ImageMime::PNG,
13        }
14    }
15}
16
17impl From<ImageMime> for protocol::ImageMime {
18    fn from(value: ImageMime) -> Self {
19        match value {
20            ImageMime::JPEG => protocol::ImageMime::JPEG,
21            ImageMime::PNG => protocol::ImageMime::PNG,
22            ImageMime::GIF => protocol::ImageMime::GIF,
23            ImageMime::WEBP => protocol::ImageMime::WEBP,
24        }
25    }
26}
27
28impl From<protocol::FunctionCall> for FunctionCall {
29    fn from(value: protocol::FunctionCall) -> Self {
30        Self {
31            name: value.name,
32            arguments: value.arguments,
33        }
34    }
35}
36
37impl From<FunctionCall> for protocol::FunctionCall {
38    fn from(value: FunctionCall) -> Self {
39        Self {
40            name: value.name,
41            arguments: value.arguments,
42        }
43    }
44}
45
46impl From<protocol::ToolCall> for ToolCall {
47    fn from(value: protocol::ToolCall) -> Self {
48        Self {
49            id: value.id,
50            call_type: value.call_type,
51            function: value.function.into(),
52        }
53    }
54}
55
56impl From<ToolCall> for protocol::ToolCall {
57    fn from(value: ToolCall) -> Self {
58        Self {
59            id: value.id,
60            call_type: value.call_type,
61            function: value.function.into(),
62        }
63    }
64}
65
66impl From<protocol::CompletionTokensDetails> for CompletionTokensDetails {
67    fn from(value: protocol::CompletionTokensDetails) -> Self {
68        Self {
69            reasoning_tokens: value.reasoning_tokens,
70            audio_tokens: value.audio_tokens,
71        }
72    }
73}
74
75impl From<CompletionTokensDetails> for protocol::CompletionTokensDetails {
76    fn from(value: CompletionTokensDetails) -> Self {
77        Self {
78            reasoning_tokens: value.reasoning_tokens,
79            audio_tokens: value.audio_tokens,
80        }
81    }
82}
83
84impl From<protocol::PromptTokensDetails> for PromptTokensDetails {
85    fn from(value: protocol::PromptTokensDetails) -> Self {
86        Self {
87            cached_tokens: value.cached_tokens,
88            audio_tokens: value.audio_tokens,
89        }
90    }
91}
92
93impl From<PromptTokensDetails> for protocol::PromptTokensDetails {
94    fn from(value: PromptTokensDetails) -> Self {
95        Self {
96            cached_tokens: value.cached_tokens,
97            audio_tokens: value.audio_tokens,
98        }
99    }
100}
101
102impl From<protocol::Usage> for Usage {
103    fn from(value: protocol::Usage) -> Self {
104        Self {
105            prompt_tokens: value.prompt_tokens,
106            completion_tokens: value.completion_tokens,
107            total_tokens: value.total_tokens,
108            completion_tokens_details: value
109                .completion_tokens_details
110                .map(CompletionTokensDetails::from),
111            prompt_tokens_details: value.prompt_tokens_details.map(PromptTokensDetails::from),
112        }
113    }
114}
115
116impl From<Usage> for protocol::Usage {
117    fn from(value: Usage) -> Self {
118        Self {
119            prompt_tokens: value.prompt_tokens,
120            completion_tokens: value.completion_tokens,
121            total_tokens: value.total_tokens,
122            completion_tokens_details: value
123                .completion_tokens_details
124                .map(protocol::CompletionTokensDetails::from),
125            prompt_tokens_details: value
126                .prompt_tokens_details
127                .map(protocol::PromptTokensDetails::from),
128        }
129    }
130}
131
132impl From<protocol::StreamChunk> for StreamChunk {
133    fn from(value: protocol::StreamChunk) -> Self {
134        match value {
135            protocol::StreamChunk::Text(text) => StreamChunk::Text(text),
136            protocol::StreamChunk::ToolUseStart { index, id, name } => {
137                StreamChunk::ToolUseStart { index, id, name }
138            }
139            protocol::StreamChunk::ToolUseInputDelta {
140                index,
141                partial_json,
142            } => StreamChunk::ToolUseInputDelta {
143                index,
144                partial_json,
145            },
146            protocol::StreamChunk::ToolUseComplete { index, tool_call } => {
147                StreamChunk::ToolUseComplete {
148                    index,
149                    tool_call: tool_call.into(),
150                }
151            }
152            protocol::StreamChunk::Done { stop_reason } => StreamChunk::Done { stop_reason },
153            protocol::StreamChunk::Usage(usage) => StreamChunk::Usage(usage.into()),
154        }
155    }
156}
157
158impl From<StreamChunk> for protocol::StreamChunk {
159    fn from(value: StreamChunk) -> Self {
160        match value {
161            StreamChunk::Text(text) => protocol::StreamChunk::Text(text),
162            StreamChunk::ToolUseStart { index, id, name } => {
163                protocol::StreamChunk::ToolUseStart { index, id, name }
164            }
165            StreamChunk::ToolUseInputDelta {
166                index,
167                partial_json,
168            } => protocol::StreamChunk::ToolUseInputDelta {
169                index,
170                partial_json,
171            },
172            StreamChunk::ToolUseComplete { index, tool_call } => {
173                protocol::StreamChunk::ToolUseComplete {
174                    index,
175                    tool_call: tool_call.into(),
176                }
177            }
178            StreamChunk::Done { stop_reason } => protocol::StreamChunk::Done { stop_reason },
179            StreamChunk::Usage(usage) => protocol::StreamChunk::Usage(usage.into()),
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    //! Round-trip tests for the LLM-layer <-> protocol conversions.
    //!
    //! Each test converts a local value to the protocol type and back, then
    //! checks that the payload survived, not just that the variant matched.

    use super::*;

    /// Builds the tool call shared by the tool-call round-trip tests.
    fn sample_tool_call() -> ToolCall {
        ToolCall {
            id: "tc1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "tool".to_string(),
                arguments: r#"{"q":"x"}"#.to_string(),
            },
        }
    }

    #[test]
    fn converts_stream_chunk_roundtrip() {
        let chunk = StreamChunk::ToolUseStart {
            index: 1,
            id: "tool_1".to_string(),
            name: "search".to_string(),
        };
        let protocol_chunk: protocol::StreamChunk = chunk.clone().into();
        let roundtrip: StreamChunk = protocol_chunk.into();
        assert_eq!(format!("{chunk:?}"), format!("{roundtrip:?}"));
    }

    #[test]
    fn converts_usage_roundtrip() {
        let usage = Usage {
            prompt_tokens: 1,
            completion_tokens: 2,
            total_tokens: 3,
            completion_tokens_details: Some(CompletionTokensDetails {
                reasoning_tokens: Some(4),
                audio_tokens: None,
            }),
            prompt_tokens_details: Some(PromptTokensDetails {
                cached_tokens: Some(5),
                audio_tokens: None,
            }),
        };
        let protocol_usage: protocol::Usage = usage.clone().into();
        let roundtrip: Usage = protocol_usage.into();
        assert_eq!(usage, roundtrip);
    }

    #[test]
    fn converts_image_mime_roundtrip() {
        for mime in [
            ImageMime::JPEG,
            ImageMime::PNG,
            ImageMime::GIF,
            ImageMime::WEBP,
        ] {
            let proto: protocol::ImageMime = mime.into();
            let back: ImageMime = proto.into();
            assert_eq!(mime, back);
        }
    }

    #[test]
    fn converts_function_call_roundtrip() {
        let fc = FunctionCall {
            name: "search".to_string(),
            arguments: r#"{"q":"test"}"#.to_string(),
        };
        let proto: protocol::FunctionCall = fc.clone().into();
        let back: FunctionCall = proto.into();
        assert_eq!(back.name, fc.name);
        assert_eq!(back.arguments, fc.arguments);
    }

    #[test]
    fn converts_tool_call_roundtrip() {
        let tc = sample_tool_call();
        let proto: protocol::ToolCall = tc.clone().into();
        let back: ToolCall = proto.into();
        assert_eq!(back.id, tc.id);
        assert_eq!(back.call_type, tc.call_type);
        assert_eq!(back.function.name, tc.function.name);
        // Previously unchecked: the arguments payload must survive too.
        assert_eq!(back.function.arguments, tc.function.arguments);
    }

    #[test]
    fn converts_stream_chunk_text_roundtrip() {
        let chunk = StreamChunk::Text("hello".to_string());
        let proto: protocol::StreamChunk = chunk.into();
        let back: StreamChunk = proto.into();
        assert!(matches!(back, StreamChunk::Text(ref s) if s == "hello"));
    }

    #[test]
    fn converts_stream_chunk_tool_use_input_delta() {
        let chunk = StreamChunk::ToolUseInputDelta {
            index: 0,
            partial_json: r#"{"ke"#.to_string(),
        };
        let proto: protocol::StreamChunk = chunk.into();
        let back: StreamChunk = proto.into();
        assert!(matches!(
            back,
            StreamChunk::ToolUseInputDelta { index: 0, ref partial_json }
                if partial_json == r#"{"ke"#
        ));
    }

    #[test]
    fn converts_stream_chunk_tool_use_complete() {
        let tc = sample_tool_call();
        let chunk = StreamChunk::ToolUseComplete {
            index: 0,
            tool_call: tc.clone(),
        };
        let proto: protocol::StreamChunk = chunk.into();
        let back: StreamChunk = proto.into();
        match back {
            StreamChunk::ToolUseComplete { index, tool_call } => {
                assert_eq!(index, 0);
                assert_eq!(tool_call.id, tc.id);
                assert_eq!(tool_call.call_type, tc.call_type);
                assert_eq!(tool_call.function.name, tc.function.name);
                assert_eq!(tool_call.function.arguments, tc.function.arguments);
            }
            other => panic!("expected ToolUseComplete, got {other:?}"),
        }
    }

    #[test]
    fn converts_stream_chunk_done() {
        let chunk = StreamChunk::Done {
            stop_reason: "end_turn".to_string(),
        };
        let proto: protocol::StreamChunk = chunk.into();
        let back: StreamChunk = proto.into();
        assert!(matches!(back, StreamChunk::Done { ref stop_reason } if stop_reason == "end_turn"));
    }

    #[test]
    fn converts_stream_chunk_usage() {
        // Populate the nested detail records so their conversions are exercised too.
        let usage = Usage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
            completion_tokens_details: Some(CompletionTokensDetails {
                reasoning_tokens: Some(7),
                audio_tokens: None,
            }),
            prompt_tokens_details: Some(PromptTokensDetails {
                cached_tokens: None,
                audio_tokens: Some(3),
            }),
        };
        let chunk = StreamChunk::Usage(usage.clone());
        let proto: protocol::StreamChunk = chunk.into();
        let back: StreamChunk = proto.into();
        assert!(matches!(back, StreamChunk::Usage(ref u) if *u == usage));
    }

    #[test]
    fn converts_completion_tokens_details_roundtrip() {
        let details = CompletionTokensDetails {
            reasoning_tokens: Some(10),
            audio_tokens: Some(5),
        };
        let proto: protocol::CompletionTokensDetails = details.clone().into();
        let back: CompletionTokensDetails = proto.into();
        assert_eq!(details, back);
    }

    #[test]
    fn converts_prompt_tokens_details_roundtrip() {
        let details = PromptTokensDetails {
            cached_tokens: Some(100),
            audio_tokens: None,
        };
        let proto: protocol::PromptTokensDetails = details.clone().into();
        let back: PromptTokensDetails = proto.into();
        assert_eq!(details, back);
    }
}