1use serde::{Deserialize, Serialize};
17
18pub mod backends;
20
21pub mod builder;
23
24pub mod chat;
26
27pub mod completion;
29
30pub mod embedding;
32
33pub mod error;
35
36pub mod evaluator;
38
39#[cfg(not(target_arch = "wasm32"))]
41pub mod secret_store;
42
43pub mod models;
45
46pub trait LLMProvider:
49 chat::ChatProvider
50 + completion::CompletionProvider
51 + embedding::EmbeddingProvider
52 + models::ModelsProvider
53 + Send
54 + Sync
55 + 'static
56{
57}
58
59#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
62pub struct ToolCall {
63 pub id: String,
65 #[serde(rename = "type")]
67 pub call_type: String,
68 pub function: FunctionCall,
70}
71
72#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
74pub struct FunctionCall {
75 pub name: String,
77 pub arguments: String,
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    /// Builds a `"function"`-typed [`ToolCall`] from borrowed parts, so each test
    /// only spells out the values it actually cares about.
    fn function_tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    /// Serializes `tool_call` to JSON and parses it back, panicking on any serde error.
    fn round_trip(tool_call: &ToolCall) -> ToolCall {
        let serialized = serde_json::to_string(tool_call).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: "{\"param\": \"value\"}".to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = function_tool_call("call_456", "serialize_test", "{\"test\": true}");

        let serialized = serde_json::to_string(&tool_call).unwrap();
        // `call_type` is renamed on the wire; the Rust field name must not leak out.
        assert!(serialized.contains("\"type\":\"function\""));
        assert!(!serialized.contains("call_type"));

        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, tool_call);
        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
    }

    #[test]
    fn test_tool_call_deserializes_wire_format() {
        // Hand-written JSON that uses the external key `"type"`, not the Rust field name.
        let raw = r#"{"id":"call_789","type":"function","function":{"name":"wire_test","arguments":"{\"a\":1}"}}"#;

        let parsed: ToolCall = serde_json::from_str(raw).unwrap();
        assert_eq!(
            parsed,
            function_tool_call("call_789", "wire_test", "{\"a\":1}")
        );
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call2 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call3 = function_tool_call("call_2", "equal_test", "{}");

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = function_tool_call("clone_test", "test_clone", "{\"clone\": true}");

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = function_tool_call("debug_test", "debug_function", "{}");

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        // `arguments` is emitted as an escaped JSON string, not as a nested object.
        assert_eq!(
            serialized,
            r#"{"name":"serialize_function","arguments":"{\"data\": [1, 2, 3]}"}"#
        );

        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, function_call);
        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };
        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };
        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::new(),
            call_type: String::new(),
            function: FunctionCall {
                name: String::new(),
                arguments: String::new(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });
        let arguments = complex_args.to_string();
        let tool_call = function_tool_call("complex_call", "complex_function", &arguments);

        let deserialized = round_trip(&tool_call);

        assert_eq!(deserialized, tool_call);
        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        assert!(deserialized.function.arguments.contains("nested"));
        assert!(deserialized.function.arguments.contains("array"));
        // The payload must still decode to exactly the structure that was encoded.
        let reparsed: serde_json::Value =
            serde_json::from_str(&deserialized.function.arguments).unwrap();
        assert_eq!(reparsed, complex_args);
    }

    #[test]
    fn test_tool_call_with_unicode() {
        // Mixes 3-byte (CJK) and 4-byte (emoji, outside the BMP) UTF-8 sequences.
        let tool_call = function_tool_call(
            "unicode_call",
            "unicode_function",
            "{\"message\": \"Hello 世界! 🌍\"}",
        );

        let deserialized = round_trip(&tool_call);

        assert_eq!(deserialized, tool_call);
        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let arguments = format!("{{\"large_param\": \"{large_arg}\"}}");
        let tool_call = function_tool_call("large_call", "large_function", &arguments);

        let deserialized = round_trip(&tool_call);

        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        // Exact comparison: a truncated or altered payload must fail, not just a short one.
        assert_eq!(deserialized.function.arguments, arguments);
        assert!(deserialized.function.arguments.len() > 10000);
    }

    /// Minimal provider whose every capability returns canned data.
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            // Deterministic fake vectors: the i-th input maps to [i, i + 1].
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    // Compiles only if the mock satisfies every `LLMProvider` supertrait.
    impl LLMProvider for MockLLMProvider {}

    /// Chat response that carries optional text and never any tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}