1use std::fmt::Display;
17
18use serde::{Deserialize, Serialize};
19
20pub mod backends;
22
23pub mod builder;
25
26pub mod chat;
28
29pub mod completion;
31
32pub mod embedding;
34
35pub mod error;
37
38pub mod evaluator;
40
41#[cfg(not(target_arch = "wasm32"))]
43pub mod secret_store;
44
45pub mod models;
47
48mod protocol;
49pub mod providers;
50
51pub mod pipeline;
53
54#[cfg(all(not(target_arch = "wasm32"), feature = "optim"))]
56pub mod optim;
57
58pub use async_trait::async_trait;
60
/// Zero-sized placeholder configuration type.
///
/// It carries no data, so its [`Default`] value is its only value. The test
/// module in this file uses it as the `Config` of a provider that has nothing
/// to configure.
#[derive(Clone, Debug, Default)]
pub struct NoConfig;
64
/// Associates an implementing type with the type of its configuration.
pub trait HasConfig {
    /// The configuration type.
    ///
    /// It must be default-constructible, shareable across threads
    /// (`Send + Sync`) and free of non-`'static` borrows.
    type Config: 'static + Default + Send + Sync;
}
73
74pub trait LLMProvider:
77 chat::ChatProvider
78 + completion::CompletionProvider
79 + embedding::EmbeddingProvider
80 + models::ModelsProvider
81 + Send
82 + Sync
83 + 'static
84{
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
90pub struct ToolCall {
91 pub id: String,
93 #[serde(rename = "type")]
95 pub call_type: String,
96 pub function: FunctionCall,
98}
99
100impl Display for ToolCall {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 write!(
103 f,
104 "ToolCall {{ id: {}, type: {}, function: {:?} }}",
105 self.id, self.call_type, self.function
106 )
107 }
108}
109
110#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
112pub struct FunctionCall {
113 pub name: String,
115 pub arguments: String,
117}
118
/// Returns the default call type, the owned string `"function"`.
pub fn default_call_type() -> String {
    String::from("function")
}
123
#[cfg(test)]
mod tests {
    // Unit tests for `ToolCall` / `FunctionCall`, plus a minimal mock provider
    // that exercises each supertrait required by `LLMProvider`.

    use super::*;
    use crate::HasConfig;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    // "Hello 世界! 🌍" written with escapes so the CJK and emoji code points
    // survive any re-encoding of this source file.
    const UNICODE_GREETING: &str = "Hello \u{4e16}\u{754c}! \u{1f30d}";

    // ---------------------------------------------------------------------
    // Fixture helpers
    // ---------------------------------------------------------------------

    /// Builds a `FunctionCall` from borrowed parts.
    fn function_call(name: &str, arguments: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_owned(),
            arguments: arguments.to_owned(),
        }
    }

    /// Builds a `ToolCall` whose `call_type` is `"function"`.
    fn tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_owned(),
            call_type: "function".to_owned(),
            function: function_call(name, arguments),
        }
    }

    /// Serializes a `ToolCall` to JSON text and parses it back.
    fn round_trip(call: &ToolCall) -> ToolCall {
        let encoded = serde_json::to_string(call).expect("ToolCall serializes to JSON");
        serde_json::from_str(&encoded).expect("serialized ToolCall parses back")
    }

    // ---------------------------------------------------------------------
    // ToolCall
    // ---------------------------------------------------------------------

    #[test]
    fn test_tool_call_creation() {
        let call = tool_call("call_123", "test_function", r#"{"param": "value"}"#);

        assert_eq!(call.id, "call_123");
        assert_eq!(call.call_type, "function");
        assert_eq!(call.function.name, "test_function");
        assert_eq!(call.function.arguments, r#"{"param": "value"}"#);
    }

    #[test]
    fn test_tool_call_serialization() {
        let original = tool_call("call_456", "serialize_test", r#"{"test": true}"#);
        let restored = round_trip(&original);

        assert_eq!(restored.id, "call_456");
        assert_eq!(restored.call_type, "function");
        assert_eq!(restored.function.name, "serialize_test");
        assert_eq!(restored.function.arguments, r#"{"test": true}"#);
        assert_eq!(restored, original);
    }

    #[test]
    fn test_tool_call_equality() {
        let first = tool_call("call_1", "equal_test", "{}");
        let same = tool_call("call_1", "equal_test", "{}");
        let other_id = tool_call("call_2", "equal_test", "{}");

        assert_eq!(first, same);
        assert_ne!(first, other_id);
    }

    #[test]
    fn test_tool_call_clone() {
        let original = tool_call("clone_test", "test_clone", r#"{"clone": true}"#);
        let copy = original.clone();

        assert_eq!(original, copy);
        assert_eq!(original.id, copy.id);
        assert_eq!(original.function.name, copy.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let call = tool_call("debug_test", "debug_function", "{}");
        let rendered = format!("{call:?}");

        for needle in ["ToolCall", "debug_test", "debug_function"] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
        }
    }

    #[test]
    fn test_tool_call_display() {
        let call = tool_call("display_test", "show", "{}");

        assert_eq!(
            call.to_string(),
            r#"ToolCall { id: display_test, type: function, function: FunctionCall { name: "show", arguments: "{}" } }"#
        );
    }

    #[test]
    fn test_default_call_type() {
        assert_eq!(default_call_type(), "function");
    }

    // ---------------------------------------------------------------------
    // FunctionCall
    // ---------------------------------------------------------------------

    #[test]
    fn test_function_call_creation() {
        let arguments = r#"{"param1": "value1", "param2": 42}"#;
        let call = function_call("test_function", arguments);

        assert_eq!(call.name, "test_function");
        assert_eq!(call.arguments, arguments);
    }

    #[test]
    fn test_function_call_serialization() {
        let original = function_call("serialize_function", r#"{"data": [1, 2, 3]}"#);

        let encoded = serde_json::to_string(&original).unwrap();
        let restored: FunctionCall = serde_json::from_str(&encoded).unwrap();

        assert_eq!(restored.name, "serialize_function");
        assert_eq!(restored.arguments, r#"{"data": [1, 2, 3]}"#);
        assert_eq!(restored, original);
    }

    #[test]
    fn test_function_call_equality() {
        let first = function_call("equal_func", "{}");
        let same = function_call("equal_func", "{}");
        let renamed = function_call("different_func", "{}");

        assert_eq!(first, same);
        assert_ne!(first, renamed);
    }

    #[test]
    fn test_function_call_clone() {
        let original = function_call("clone_func", r#"{"clone": "test"}"#);
        let copy = original.clone();

        assert_eq!(original, copy);
        assert_eq!(original.name, copy.name);
        assert_eq!(original.arguments, copy.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let rendered = format!("{:?}", function_call("debug_func", "{}"));

        assert!(rendered.contains("FunctionCall"));
        assert!(rendered.contains("debug_func"));
    }

    // ---------------------------------------------------------------------
    // Edge cases
    // ---------------------------------------------------------------------

    #[test]
    fn test_tool_call_with_empty_values() {
        // Built by hand: the `tool_call` helper would fill in "function".
        let call = ToolCall {
            id: String::default(),
            call_type: String::default(),
            function: FunctionCall {
                name: String::default(),
                arguments: String::default(),
            },
        };

        assert!(call.id.is_empty());
        assert!(call.call_type.is_empty());
        assert!(call.function.name.is_empty());
        assert!(call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let payload = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let original = tool_call("complex_call", "complex_function", &payload.to_string());
        let restored = round_trip(&original);

        assert_eq!(restored.id, "complex_call");
        assert_eq!(restored.function.name, "complex_function");
        assert!(restored.function.arguments.contains("nested"));
        assert!(restored.function.arguments.contains("array"));
    }

    #[test]
    fn test_tool_call_with_unicode() {
        let arguments = format!("{{\"message\": \"{UNICODE_GREETING}\"}}");
        let original = tool_call("unicode_call", "unicode_function", &arguments);
        let restored = round_trip(&original);

        assert_eq!(restored.id, "unicode_call");
        assert_eq!(restored.function.name, "unicode_function");
        // Multi-byte (CJK) and 4-byte (emoji) characters must survive intact.
        assert!(restored.function.arguments.contains(UNICODE_GREETING));
        assert_eq!(restored, original);
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let filler = "x".repeat(10000);
        let arguments = format!("{{\"large_param\": \"{filler}\"}}");
        let restored = round_trip(&tool_call("large_call", "large_function", &arguments));

        assert_eq!(restored.id, "large_call");
        assert_eq!(restored.function.name, "large_function");
        assert!(restored.function.arguments.len() > 10000);
    }

    // ---------------------------------------------------------------------
    // Mock provider
    // ---------------------------------------------------------------------

    /// Stateless provider that answers every request with fixed data.
    struct MockLLMProvider;

    /// Chat response holding an optional text and no tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    /// The fixed reply shared by both chat entry points of the mock.
    fn mock_chat_reply() -> Result<Box<dyn ChatResponse>, LLMError> {
        Ok(Box::new(MockChatResponse {
            text: Some("Mock response".into()),
        }))
    }

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            mock_chat_reply()
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            mock_chat_reply()
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        /// Returns `[i, i + 1]` for the input at position `i`.
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    impl HasConfig for MockLLMProvider {
        type Config = NoConfig;
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // A missing text renders as the empty string.
            f.write_str(self.text.as_deref().unwrap_or(""))
        }
    }

    // ---------------------------------------------------------------------
    // Provider trait smoke tests
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![ChatMessage::user().content("Test").build()];

        let reply = provider.chat(&messages, None).await.unwrap();
        assert_eq!(reply.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let reply = provider.complete(&request, None).await.unwrap();
        assert_eq!(reply.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}