1use std::fmt::Display;
17
18use serde::{Deserialize, Serialize};
19
20pub mod backends;
22
23pub mod builder;
25
26pub mod chat;
28
29pub mod completion;
31
32pub mod embedding;
34
35pub mod error;
37
38pub mod evaluator;
40
41#[cfg(not(target_arch = "wasm32"))]
43pub mod secret_store;
44
45pub mod models;
47
48mod protocol;
49pub mod providers;
50
51pub use async_trait::async_trait;
53
54pub trait LLMProvider:
57 chat::ChatProvider
58 + completion::CompletionProvider
59 + embedding::EmbeddingProvider
60 + models::ModelsProvider
61 + Send
62 + Sync
63 + 'static
64{
65}
66
67#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
70pub struct ToolCall {
71 pub id: String,
73 #[serde(rename = "type")]
75 pub call_type: String,
76 pub function: FunctionCall,
78}
79
80impl Display for ToolCall {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(
83 f,
84 "ToolCall {{ id: {}, type: {}, function: {:?} }}",
85 self.id, self.call_type, self.function
86 )
87 }
88}
89
90#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
92pub struct FunctionCall {
93 pub name: String,
95 pub arguments: String,
97}
98
/// Returns the default tool-call type string, `"function"`.
///
/// NOTE(review): presumably used as a `#[serde(default = "...")]` helper for
/// `call_type`-style fields elsewhere in the crate. Confirm at the call sites.
pub fn default_call_type() -> String {
    String::from("function")
}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::{CompletionProvider, CompletionRequest, CompletionResponse};
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::{json, Value};

    /// Builds a `ToolCall` whose `call_type` is `"function"`, so each test only
    /// spells out the fields it actually cares about.
    fn function_tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_default_call_type() {
        assert_eq!(default_call_type(), "function");
    }

    #[test]
    fn test_tool_call_creation() {
        let tool_call = function_tool_call("call_123", "test_function", "{\"param\": \"value\"}");

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = function_tool_call("call_456", "serialize_test", "{\"test\": true}");

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, tool_call);
        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
    }

    #[test]
    fn test_tool_call_wire_format() {
        let tool_call = function_tool_call("call_789", "wire_test", "{}");
        let value = serde_json::to_value(&tool_call).unwrap();

        // A round-trip alone would still pass without `#[serde(rename = "type")]`,
        // so pin the JSON keys that serialization actually produces.
        assert_eq!(value["type"], "function");
        assert!(value.get("call_type").is_none());
        assert_eq!(value["id"], "call_789");
        assert_eq!(value["function"]["name"], "wire_test");
        assert_eq!(value["function"]["arguments"], "{}");
    }

    #[test]
    fn test_tool_call_deserializes_from_json_object() {
        let raw = json!({
            "id": "call_abc",
            "type": "function",
            "function": { "name": "lookup", "arguments": "{\"q\": 1}" }
        });

        let parsed: ToolCall = serde_json::from_value(raw).unwrap();
        assert_eq!(parsed, function_tool_call("call_abc", "lookup", "{\"q\": 1}"));
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call2 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call3 = function_tool_call("call_2", "equal_test", "{}");

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = function_tool_call("clone_test", "test_clone", "{\"clone\": true}");

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = function_tool_call("debug_test", "debug_function", "{}");

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_tool_call_display() {
        let tool_call = function_tool_call("display_test", "display_function", "{}");

        assert_eq!(
            tool_call.to_string(),
            "ToolCall { id: display_test, type: function, \
             function: FunctionCall { name: \"display_function\", arguments: \"{}\" } }"
        );
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, function_call);
        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };
        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };
        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::new(),
            call_type: String::new(),
            function: FunctionCall {
                name: String::new(),
                arguments: String::new(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call = function_tool_call(
            "complex_call",
            "complex_function",
            &complex_args.to_string(),
        );

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, tool_call);
        // The argument string must still parse back to the original JSON document.
        let reparsed: Value = serde_json::from_str(&deserialized.function.arguments).unwrap();
        assert_eq!(reparsed, complex_args);
    }

    #[test]
    fn test_tool_call_with_unicode() {
        // Mixes ASCII, CJK (3-byte UTF-8) and an emoji (4-byte UTF-8).
        let tool_call = function_tool_call(
            "unicode_call",
            "unicode_function",
            "{\"message\": \"Hello 世界! 🌍\"}",
        );

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, tool_call);
        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10_000);
        let tool_call = function_tool_call(
            "large_call",
            "large_function",
            &format!("{{\"large_param\": \"{large_arg}\"}}"),
        );

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, tool_call);
        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        assert!(deserialized.function.arguments.len() > 10_000);
    }

    /// Minimal provider that returns canned responses for every capability.
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        /// Returns `[i, i + 1]` for the `i`-th input, so tests can check ordering.
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    /// Chat response carrying optional text and never any tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat_with_tools() {
        let provider = MockLLMProvider;
        let messages = vec![ChatMessage::user().content("Test").build()];

        let response = provider
            .chat_with_tools(&messages, None, None)
            .await
            .unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
        assert!(response.tool_calls().is_none());
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);

        // No inputs means no embeddings.
        assert!(provider.embed(Vec::new()).await.unwrap().is_empty());
    }
}