1use std::fmt::Display;
17
18use serde::{Deserialize, Serialize};
19
20pub mod backends;
22
23pub mod builder;
25
26pub mod chat;
28
29pub mod completion;
31
32pub mod embedding;
34
35pub mod error;
37
38pub mod evaluator;
40
41#[cfg(not(target_arch = "wasm32"))]
43pub mod secret_store;
44
45pub mod models;
47
48mod protocol;
49pub mod providers;
50
51pub use async_trait::async_trait;
53
/// Empty configuration type for implementors that have nothing to configure.
///
/// It implements [`Default`] (and is a fieldless unit struct), so it satisfies the
/// `Default + Send + Sync + 'static` bounds required of [`HasConfig::Config`].
#[derive(Clone, Debug, Default)]
pub struct NoConfig;
57
/// Associates a configuration type with an implementor (for example a provider).
pub trait HasConfig {
    /// The configuration type.
    ///
    /// It must be default-constructible, safe to send and share between threads,
    /// and must not borrow non-`'static` data.
    type Config: 'static + Send + Sync + Default;
}
66
67pub trait LLMProvider:
70 chat::ChatProvider
71 + completion::CompletionProvider
72 + embedding::EmbeddingProvider
73 + models::ModelsProvider
74 + Send
75 + Sync
76 + 'static
77{
78}
79
80#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
83pub struct ToolCall {
84 pub id: String,
86 #[serde(rename = "type")]
88 pub call_type: String,
89 pub function: FunctionCall,
91}
92
93impl Display for ToolCall {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 write!(
96 f,
97 "ToolCall {{ id: {}, type: {}, function: {:?} }}",
98 self.id, self.call_type, self.function
99 )
100 }
101}
102
103#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
105pub struct FunctionCall {
106 pub name: String,
108 pub arguments: String,
110}
111
/// Returns the default tool-call type string, `"function"`.
///
/// Its `fn() -> String` shape matches what `#[serde(default = "...")]` expects.
pub fn default_call_type() -> String {
    String::from("function")
}
116
#[cfg(test)]
mod tests {
    // Unit tests for the crate-root tool-call types, plus a mock provider that
    // implements every capability trait required by `LLMProvider`.
    // `HasConfig` and the other crate-root items come in through `use super::*`.

    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: "{\"param\": \"value\"}".to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = ToolCall {
            id: "call_456".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "serialize_test".to_string(),
                arguments: "{\"test\": true}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
    }

    #[test]
    fn test_tool_call_serializes_call_type_as_type() {
        // `call_type` is renamed to `"type"` on the wire; the Rust field name must not leak.
        let tool_call = ToolCall {
            id: "call_789".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "rename_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let value = serde_json::to_value(&tool_call).unwrap();

        assert_eq!(
            value,
            json!({
                "id": "call_789",
                "type": "function",
                "function": {
                    "name": "rename_test",
                    "arguments": "{}"
                }
            })
        );
        assert!(value.get("call_type").is_none());
    }

    #[test]
    fn test_tool_call_display() {
        let tool_call = ToolCall {
            id: "display_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "display_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        assert_eq!(
            tool_call.to_string(),
            "ToolCall { id: display_call, type: function, function: FunctionCall { name: \"display_function\", arguments: \"{}\" } }"
        );
    }

    #[test]
    fn test_default_call_type() {
        assert_eq!(default_call_type(), "function");
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call2 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call3 = ToolCall {
            id: "call_2".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = ToolCall {
            id: "clone_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_clone".to_string(),
                arguments: "{\"clone\": true}".to_string(),
            },
        };

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = ToolCall {
            id: "debug_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "debug_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::default(),
            call_type: String::default(),
            function: FunctionCall {
                name: String::default(),
                arguments: String::default(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call = ToolCall {
            id: "complex_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "complex_function".to_string(),
                arguments: complex_args.to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        assert!(deserialized.function.arguments.contains("nested"));
        assert!(deserialized.function.arguments.contains("array"));
    }

    #[test]
    fn test_tool_call_with_unicode() {
        // Covers both 3-byte (CJK) and 4-byte (emoji) UTF-8 sequences.
        let tool_call = ToolCall {
            id: "unicode_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "unicode_function".to_string(),
                arguments: "{\"message\": \"Hello 世界! 🌍\"}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let tool_call = ToolCall {
            id: "large_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "large_function".to_string(),
                arguments: format!("{{\"large_param\": \"{large_arg}\"}}"),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        assert!(deserialized.function.arguments.len() > 10000);
    }

    /// Provider stub returning fixed, deterministic results from every capability.
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            // One deterministic 2-d vector per input: [index, index + 1].
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    impl HasConfig for MockLLMProvider {
        type Config = NoConfig;
    }

    /// Chat response stub carrying optional text and never any tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}