1use serde::{Deserialize, Serialize};
17
18pub mod backends;
20
21pub mod builder;
23
24pub mod chat;
26
27pub mod completion;
29
30pub mod embedding;
32
33pub mod error;
35
36pub mod evaluator;
38
39#[cfg(not(target_arch = "wasm32"))]
41pub mod secret_store;
42
43pub mod models;
45
46pub mod providers;
47
48pub use async_trait::async_trait;
50
51pub trait LLMProvider:
54 chat::ChatProvider
55 + completion::CompletionProvider
56 + embedding::EmbeddingProvider
57 + models::ModelsProvider
58 + Send
59 + Sync
60 + 'static
61{
62}
63
64#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
67pub struct ToolCall {
68 pub id: String,
70 #[serde(rename = "type")]
72 pub call_type: String,
73 pub function: FunctionCall,
75}
76
77#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
79pub struct FunctionCall {
80 pub name: String,
82 pub arguments: String,
84}
85
/// Returns the owned string `"function"`.
///
/// NOTE(review): the name suggests this is the fallback for
/// `ToolCall::call_type`, but nothing in this file references it; confirm
/// against its call sites.
pub fn default_call_type() -> String {
    String::from("function")
}
90
91#[cfg(test)]
92mod tests {
93 use super::*;
94 use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
95 use crate::completion::CompletionProvider;
96 use crate::embedding::EmbeddingProvider;
97 use crate::error::LLMError;
98 use async_trait::async_trait;
99 use serde_json::json;
100
/// Builds a `ToolCall` by hand and checks each field holds the value it was given.
#[test]
fn test_tool_call_creation() {
    let function = FunctionCall {
        name: String::from("test_function"),
        arguments: String::from("{\"param\": \"value\"}"),
    };
    let call = ToolCall {
        id: String::from("call_123"),
        call_type: String::from("function"),
        function,
    };

    assert_eq!(call.id, "call_123");
    assert_eq!(call.call_type, "function");
    assert_eq!(call.function.name, "test_function");
    assert_eq!(call.function.arguments, "{\"param\": \"value\"}");
}
117
/// Round-trips a `ToolCall` through JSON text and checks every field survives.
#[test]
fn test_tool_call_serialization() {
    let original = ToolCall {
        id: "call_456".to_owned(),
        call_type: "function".to_owned(),
        function: FunctionCall {
            name: "serialize_test".to_owned(),
            arguments: "{\"test\": true}".to_owned(),
        },
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded = serde_json::from_str::<ToolCall>(&encoded).unwrap();

    // Take the decoded value apart so each field is compared on its own.
    let ToolCall {
        id,
        call_type,
        function,
    } = decoded;
    assert_eq!(id, "call_456");
    assert_eq!(call_type, "function");
    assert_eq!(function.name, "serialize_test");
    assert_eq!(function.arguments, "{\"test\": true}");
}
137
/// Two calls with identical fields compare equal; changing only the id breaks equality.
#[test]
fn test_tool_call_equality() {
    // All three values share everything except (possibly) the id.
    let with_id = |id: &str| ToolCall {
        id: id.to_owned(),
        call_type: "function".to_owned(),
        function: FunctionCall {
            name: "equal_test".to_owned(),
            arguments: "{}".to_owned(),
        },
    };

    let first = with_id("call_1");
    let same_as_first = with_id("call_1");
    let different_id = with_id("call_2");

    assert_eq!(first, same_as_first);
    assert_ne!(first, different_id);
}
170
/// A cloned `ToolCall` is equal to its source, both as a whole and field by field.
#[test]
fn test_tool_call_clone() {
    let function = FunctionCall {
        name: String::from("test_clone"),
        arguments: String::from("{\"clone\": true}"),
    };
    let source = ToolCall {
        id: String::from("clone_test"),
        call_type: String::from("function"),
        function,
    };

    let copy = source.clone();

    assert_eq!(source, copy);
    assert_eq!(source.id, copy.id);
    assert_eq!(source.function.name, copy.function.name);
}
187
/// The derived `Debug` output names the type and includes the id and function name.
#[test]
fn test_tool_call_debug() {
    let call = ToolCall {
        id: "debug_test".to_owned(),
        call_type: "function".to_owned(),
        function: FunctionCall {
            name: "debug_function".to_owned(),
            arguments: "{}".to_owned(),
        },
    };

    let rendered = format!("{:?}", call);

    assert!(rendered.contains("ToolCall"));
    assert!(rendered.contains("debug_test"));
    assert!(rendered.contains("debug_function"));
}
204
/// Builds a `FunctionCall` by hand and checks both fields hold what they were given.
#[test]
fn test_function_call_creation() {
    let call = FunctionCall {
        name: String::from("test_function"),
        arguments: String::from("{\"param1\": \"value1\", \"param2\": 42}"),
    };

    let FunctionCall { name, arguments } = call;
    assert_eq!(name, "test_function");
    assert_eq!(arguments, "{\"param1\": \"value1\", \"param2\": 42}");
}
218
/// Round-trips a `FunctionCall` through JSON text and checks both fields survive.
#[test]
fn test_function_call_serialization() {
    let original = FunctionCall {
        name: "serialize_function".to_owned(),
        arguments: "{\"data\": [1, 2, 3]}".to_owned(),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded = serde_json::from_str::<FunctionCall>(&encoded).unwrap();

    assert_eq!(decoded.name, "serialize_function");
    assert_eq!(decoded.arguments, "{\"data\": [1, 2, 3]}");
}
232
/// Function calls with the same name and arguments are equal; a different name is not.
#[test]
fn test_function_call_equality() {
    // The arguments are fixed to "{}" so only the name can differ.
    let named = |name: &str| FunctionCall {
        name: name.to_owned(),
        arguments: "{}".to_owned(),
    };

    let first = named("equal_func");
    let same_as_first = named("equal_func");
    let other_name = named("different_func");

    assert_eq!(first, same_as_first);
    assert_ne!(first, other_name);
}
253
/// A cloned `FunctionCall` is equal to its source, as a whole and field by field.
#[test]
fn test_function_call_clone() {
    let source = FunctionCall {
        name: String::from("clone_func"),
        arguments: String::from("{\"clone\": \"test\"}"),
    };

    let copy = source.clone();

    assert_eq!(source, copy);
    assert_eq!(source.name, copy.name);
    assert_eq!(source.arguments, copy.arguments);
}
266
/// The derived `Debug` output names the type and includes the function name.
#[test]
fn test_function_call_debug() {
    let call = FunctionCall {
        name: "debug_func".to_owned(),
        arguments: "{}".to_owned(),
    };

    let rendered = format!("{:?}", call);

    assert!(rendered.contains("FunctionCall"));
    assert!(rendered.contains("debug_func"));
}
278
/// A `ToolCall` can be built entirely from empty strings, and each field stays empty.
#[test]
fn test_tool_call_with_empty_values() {
    let blank_function = FunctionCall {
        name: String::default(),
        arguments: String::default(),
    };
    let blank_call = ToolCall {
        id: String::default(),
        call_type: String::default(),
        function: blank_function,
    };

    assert!(blank_call.id.is_empty());
    assert!(blank_call.call_type.is_empty());
    assert!(blank_call.function.name.is_empty());
    assert!(blank_call.function.arguments.is_empty());
}
295
/// Nested JSON rendered into the `arguments` string survives a serde round trip.
#[test]
fn test_tool_call_with_complex_arguments() {
    let payload = json!({
        "nested": {
            "array": [1, 2, 3],
            "object": { "key": "value" }
        },
        "simple": "string"
    });

    let call = ToolCall {
        id: "complex_call".to_owned(),
        call_type: "function".to_owned(),
        function: FunctionCall {
            name: "complex_function".to_owned(),
            // The JSON value is stored as its textual rendering, not as structured data.
            arguments: payload.to_string(),
        },
    };

    let encoded = serde_json::to_string(&call).unwrap();
    let decoded = serde_json::from_str::<ToolCall>(&encoded).unwrap();

    assert_eq!(decoded.id, "complex_call");
    assert_eq!(decoded.function.name, "complex_function");

    let arguments = &decoded.function.arguments;
    assert!(arguments.contains("nested"));
    assert!(arguments.contains("array"));
}
326
/// Round-trips a `ToolCall` whose arguments contain multi-byte UTF-8 text.
///
/// The greeting mixes two CJK ideographs (three UTF-8 bytes each) with an
/// emoji outside the Basic Multilingual Plane (four UTF-8 bytes), so the test
/// covers both encodings rather than ASCII or two-byte characters only.
#[test]
fn test_tool_call_with_unicode() {
    // Defined once so the input and the assertion cannot drift apart.
    let greeting = "Hello 世界! 🌍";

    let tool_call = ToolCall {
        id: "unicode_call".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "unicode_function".to_string(),
            arguments: format!("{{\"message\": \"{greeting}\"}}"),
        },
    };

    let serialized = serde_json::to_string(&tool_call).unwrap();
    let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

    assert_eq!(deserialized.id, "unicode_call");
    assert_eq!(deserialized.function.name, "unicode_function");
    assert!(deserialized.function.arguments.contains(greeting));
    // The whole argument string must come back unchanged, not merely contain the greeting.
    assert_eq!(
        deserialized.function.arguments,
        tool_call.function.arguments
    );
}
345
/// An argument string of more than 10 000 bytes survives a serde round trip.
#[test]
fn test_tool_call_large_arguments() {
    let filler = "x".repeat(10_000);
    let arguments = format!("{{\"large_param\": \"{filler}\"}}");

    let call = ToolCall {
        id: "large_call".to_owned(),
        call_type: "function".to_owned(),
        function: FunctionCall {
            name: "large_function".to_owned(),
            arguments,
        },
    };

    let encoded = serde_json::to_string(&call).unwrap();
    let decoded = serde_json::from_str::<ToolCall>(&encoded).unwrap();

    assert_eq!(decoded.id, "large_call");
    assert_eq!(decoded.function.name, "large_function");
    // The JSON wrapper around the filler pushes the length past 10 000.
    assert!(decoded.function.arguments.len() > 10000);
}
365
366 struct MockLLMProvider;
368
369 #[async_trait]
370 impl chat::ChatProvider for MockLLMProvider {
371 async fn chat(
372 &self,
373 _messages: &[ChatMessage],
374 _tools: Option<&[Tool]>,
375 _json_schema: Option<StructuredOutputFormat>,
376 ) -> Result<Box<dyn ChatResponse>, LLMError> {
377 Ok(Box::new(MockChatResponse {
378 text: Some("Mock response".into()),
379 }))
380 }
381 }
382
383 #[async_trait]
384 impl completion::CompletionProvider for MockLLMProvider {
385 async fn complete(
386 &self,
387 _req: &completion::CompletionRequest,
388 _json_schema: Option<chat::StructuredOutputFormat>,
389 ) -> Result<completion::CompletionResponse, error::LLMError> {
390 Ok(completion::CompletionResponse {
391 text: "Mock completion".to_string(),
392 })
393 }
394 }
395
396 #[async_trait]
397 impl embedding::EmbeddingProvider for MockLLMProvider {
398 async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
399 let mut embeddings = Vec::new();
400 for (i, _) in input.iter().enumerate() {
401 embeddings.push(vec![i as f32, (i + 1) as f32]);
402 }
403 Ok(embeddings)
404 }
405 }
406
407 #[async_trait]
408 impl models::ModelsProvider for MockLLMProvider {}
409
410 impl LLMProvider for MockLLMProvider {}
411
412 struct MockChatResponse {
413 text: Option<String>,
414 }
415
416 impl chat::ChatResponse for MockChatResponse {
417 fn text(&self) -> Option<String> {
418 self.text.clone()
419 }
420
421 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
422 None
423 }
424 }
425
426 impl std::fmt::Debug for MockChatResponse {
427 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428 write!(f, "MockChatResponse")
429 }
430 }
431
432 impl std::fmt::Display for MockChatResponse {
433 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
434 write!(f, "{}", self.text.as_deref().unwrap_or(""))
435 }
436 }
437
438 #[tokio::test]
439 async fn test_llm_provider_trait_chat() {
440 let provider = MockLLMProvider;
441 let messages = vec![chat::ChatMessage::user().content("Test").build()];
442
443 let response = provider.chat(&messages, None, None).await.unwrap();
444 assert_eq!(response.text(), Some("Mock response".to_string()));
445 }
446
447 #[tokio::test]
448 async fn test_llm_provider_trait_completion() {
449 let provider = MockLLMProvider;
450 let request = completion::CompletionRequest::new("Test prompt");
451
452 let response = provider.complete(&request, None).await.unwrap();
453 assert_eq!(response.text, "Mock completion");
454 }
455
456 #[tokio::test]
457 async fn test_llm_provider_trait_embedding() {
458 let provider = MockLLMProvider;
459 let input = vec!["First".to_string(), "Second".to_string()];
460
461 let embeddings = provider.embed(input).await.unwrap();
462 assert_eq!(embeddings.len(), 2);
463 assert_eq!(embeddings[0], vec![0.0, 1.0]);
464 assert_eq!(embeddings[1], vec![1.0, 2.0]);
465 }
466}