pub mod types;

use async_trait::async_trait;
use llm::LLMProvider;
use llm::builder::{LLMBackend, LLMBuilder};
use llm::chat::ChatMessage;
use serde::{Deserialize, Serialize};
use types::{Message, ToolSpec};

/// Common interface for LLM responses: typed content, parsed tool calls,
/// and a completion flag.
pub trait LLMResponseTrait<C: for<'de> Deserialize<'de> + Default + Send> {
    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self;

    fn is_complete(&self) -> bool;

    fn tool_calls(&self) -> Vec<ToolCall>;

    fn content(&self) -> C;
}

/// Default implementation of [`LLMResponseTrait`].
#[derive(Debug, Clone, Serialize, Default)]
pub struct LLMResponse<C>
where
    C: for<'de> Deserialize<'de> + Default + Clone + Send,
{
    pub content: C,

    pub tool_calls: Vec<ToolCall>,

    pub is_complete: bool,
}

impl<C> LLMResponseTrait<C> for LLMResponse<C>
where
    C: for<'de> Deserialize<'de> + Default + Clone + Send,
{
    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self {
        Self {
            content,
            tool_calls,
            is_complete,
        }
    }

    fn is_complete(&self) -> bool {
        self.is_complete
    }

    fn tool_calls(&self) -> Vec<ToolCall> {
        self.tool_calls.clone()
    }

    fn content(&self) -> C {
        self.content.clone()
    }
}

/// A tool invocation requested by the model: the tool name plus its JSON
/// arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub name: String,

    pub input: serde_json::Value,
}

/// Full provider configuration. `model` uses the `provider::model-name`
/// format; everything else is an optional override.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    pub model: String,
    pub api_key: Option<String>,
    pub base_url: Option<String>,

    // Generation parameters.
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub system: Option<String>,

    pub timeout_seconds: Option<u64>,

    // Embedding options.
    pub embedding_encoding_format: Option<String>,
    pub embedding_dimensions: Option<u32>,

    pub enable_parallel_tool_use: Option<bool>,

    // Reasoning options.
    pub reasoning: Option<bool>,
    pub reasoning_effort: Option<String>,
    pub reasoning_budget_tokens: Option<u32>,

    // Azure OpenAI options.
    pub api_version: Option<String>,
    pub deployment_id: Option<String>,

    // Voice selection (e.g. ElevenLabs).
    pub voice: Option<String>,

    // xAI search options.
    pub xai_search_mode: Option<String>,
    pub xai_search_source_type: Option<String>,
    pub xai_search_excluded_websites: Option<Vec<String>>,
    pub xai_search_max_results: Option<u32>,
    pub xai_search_from_date: Option<String>,
    pub xai_search_to_date: Option<String>,

    // OpenAI web-search options.
    pub openai_enable_web_search: Option<bool>,
    pub openai_web_search_context_size: Option<String>,
    pub openai_web_search_user_location_type: Option<String>,
    pub openai_web_search_user_location_approximate_country: Option<String>,
    pub openai_web_search_user_location_approximate_city: Option<String>,
    pub openai_web_search_user_location_approximate_region: Option<String>,

    // Retry/resilience options.
    pub resilient_enable: Option<bool>,
    pub resilient_attempts: Option<usize>,
    pub resilient_base_delay_ms: Option<u64>,
    pub resilient_max_delay_ms: Option<u64>,
    pub resilient_jitter: Option<bool>,
}

impl Default for LLMConfig {
    fn default() -> Self {
        Self {
            model: String::new(),
            api_key: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: None,
            top_p: None,
            top_k: None,
            system: None,
            timeout_seconds: None,
            embedding_encoding_format: None,
            embedding_dimensions: None,
            enable_parallel_tool_use: None,
            reasoning: None,
            reasoning_effort: None,
            reasoning_budget_tokens: None,
            api_version: None,
            deployment_id: None,
            voice: None,
            xai_search_mode: None,
            xai_search_source_type: None,
            xai_search_excluded_websites: None,
            xai_search_max_results: None,
            xai_search_from_date: None,
            xai_search_to_date: None,
            openai_enable_web_search: None,
            openai_web_search_context_size: None,
            openai_web_search_user_location_type: None,
            openai_web_search_user_location_approximate_country: None,
            openai_web_search_user_location_approximate_city: None,
            openai_web_search_user_location_approximate_region: None,
            resilient_enable: None,
            resilient_attempts: None,
            resilient_base_delay_ms: None,
            resilient_max_delay_ms: None,
            resilient_jitter: None,
        }
    }
}

impl LLMConfig {
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    pub fn with_top_k(mut self, top_k: u32) -> Self {
        self.top_k = Some(top_k);
        self
    }

    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    pub fn with_timeout_seconds(mut self, timeout: u64) -> Self {
        self.timeout_seconds = Some(timeout);
        self
    }

    pub fn with_reasoning(mut self, enabled: bool) -> Self {
        self.reasoning = Some(enabled);
        self
    }

    pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
        self.reasoning_effort = Some(effort.into());
        self
    }

    pub fn with_deployment_id(mut self, deployment_id: impl Into<String>) -> Self {
        self.deployment_id = Some(deployment_id.into());
        self
    }

    pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
        self.api_version = Some(api_version.into());
        self
    }

    pub fn with_openai_web_search(mut self, enabled: bool) -> Self {
        self.openai_enable_web_search = Some(enabled);
        self
    }

    pub fn with_resilience(mut self, enabled: bool, attempts: usize) -> Self {
        self.resilient_enable = Some(enabled);
        self.resilient_attempts = Some(attempts);
        self
    }
}
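
// Usage sketch (illustrative values only): configs compose through the
// fluent builders above, so a typical setup reads as a single chain.
//
//     let config = LLMConfig::new("anthropic::claude-sonnet")
//         .with_system("You are terse.")
//         .with_timeout_seconds(30)
//         .with_resilience(true, 3);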

/// Minimal provider config kept on the client (the legacy config used by
/// `Clone` and the constructors).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMProviderConfig {
    pub model: String,
    pub max_tokens: u64,
    pub api_key: Option<String>,
    pub base_url: String,
}

impl From<LLMConfig> for LLMProviderConfig {
    fn from(config: LLMConfig) -> Self {
        Self {
            model: config.model,
            max_tokens: config.max_tokens.unwrap_or(4096) as u64,
            api_key: config.api_key,
            base_url: config.base_url.unwrap_or_default(),
        }
    }
}

/// Maps a provider name to its conventional API-key environment variable.
/// Returns `None` for providers (such as Ollama) that need no key.
fn get_api_key_env_var(provider: &str) -> Option<&'static str> {
    match provider.to_lowercase().as_str() {
        "ollama" => None,
        "anthropic" | "claude" => Some("ANTHROPIC_API_KEY"),
        "openai" | "gpt" => Some("OPENAI_API_KEY"),
        "deepseek" => Some("DEEPSEEK_API_KEY"),
        "xai" | "x.ai" => Some("XAI_API_KEY"),
        "phind" => Some("PHIND_API_KEY"),
        "google" | "gemini" => Some("GOOGLE_API_KEY"),
        "groq" => Some("GROQ_API_KEY"),
        "azure" | "azureopenai" | "azure-openai" => Some("AZURE_OPENAI_API_KEY"),
        "elevenlabs" | "11labs" => Some("ELEVENLABS_API_KEY"),
        "cohere" => Some("COHERE_API_KEY"),
        "mistral" => Some("MISTRAL_API_KEY"),
        "openrouter" => Some("OPENROUTER_API_KEY"),
        _ => None,
    }
}

fn get_api_key_from_env(provider: &str) -> Result<Option<String>, String> {
    match get_api_key_env_var(provider) {
        None => Ok(None),
        Some(env_var) => match std::env::var(env_var) {
            Ok(key) => Ok(Some(key)),
            Err(_) => Err(format!(
                "API key required for provider '{}'. Please set the {} environment variable or pass the API key explicitly.",
                provider, env_var
            )),
        },
    }
}

fn parse_provider(provider: &str) -> Result<LLMBackend, String> {
    match provider.to_lowercase().as_str() {
        "ollama" => Ok(LLMBackend::Ollama),
        "anthropic" | "claude" => Ok(LLMBackend::Anthropic),
        "openai" | "gpt" => Ok(LLMBackend::OpenAI),
        "deepseek" => Ok(LLMBackend::DeepSeek),
        "xai" | "x.ai" => Ok(LLMBackend::XAI),
        "phind" => Ok(LLMBackend::Phind),
        "google" | "gemini" => Ok(LLMBackend::Google),
        "groq" => Ok(LLMBackend::Groq),
        "azure" | "azureopenai" | "azure-openai" => Ok(LLMBackend::AzureOpenAI),
        "elevenlabs" | "11labs" => Ok(LLMBackend::ElevenLabs),
        "cohere" => Ok(LLMBackend::Cohere),
        "mistral" => Ok(LLMBackend::Mistral),
        "openrouter" => Ok(LLMBackend::OpenRouter),
        _ => Err(format!("Unknown provider: {}", provider)),
    }
}

/// Translates an [`LLMConfig`] into a built `llm` crate provider, applying
/// only the options that are set.
fn build_llm_from_config(
    config: &LLMConfig,
    backend: LLMBackend,
) -> Result<Box<dyn LLMProvider>, String> {
    let mut builder = LLMBuilder::new().backend(backend);

    // Accept either "provider::model" or a bare model name.
    let model_name = if config.model.contains("::") {
        config.model.split("::").nth(1).unwrap_or(&config.model)
    } else {
        &config.model
    };

    builder = builder.model(model_name);

    if let Some(max_tokens) = config.max_tokens {
        builder = builder.max_tokens(max_tokens);
    }

    if let Some(ref api_key) = config.api_key {
        builder = builder.api_key(api_key);
    }

    if let Some(ref base_url) = config.base_url {
        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }
    }

    if let Some(temperature) = config.temperature {
        builder = builder.temperature(temperature);
    }

    if let Some(top_p) = config.top_p {
        builder = builder.top_p(top_p);
    }

    if let Some(top_k) = config.top_k {
        builder = builder.top_k(top_k);
    }

    if let Some(ref system) = config.system {
        builder = builder.system(system);
    }

    if let Some(timeout) = config.timeout_seconds {
        builder = builder.timeout_seconds(timeout);
    }

    if let Some(ref format) = config.embedding_encoding_format {
        builder = builder.embedding_encoding_format(format);
    }

    if let Some(dims) = config.embedding_dimensions {
        builder = builder.embedding_dimensions(dims);
    }

    if let Some(enabled) = config.enable_parallel_tool_use {
        builder = builder.enable_parallel_tool_use(enabled);
    }

    if let Some(enabled) = config.reasoning {
        builder = builder.reasoning(enabled);
    }

    if let Some(budget) = config.reasoning_budget_tokens {
        builder = builder.reasoning_budget_tokens(budget);
    }

    if let Some(ref api_version) = config.api_version {
        builder = builder.api_version(api_version);
    }

    if let Some(ref deployment_id) = config.deployment_id {
        builder = builder.deployment_id(deployment_id);
    }

    if let Some(ref voice) = config.voice {
        builder = builder.voice(voice);
    }

    if let Some(ref mode) = config.xai_search_mode {
        builder = builder.xai_search_mode(mode);
    }

    if let Some(ref source_type) = config.xai_search_source_type {
        builder =
            builder.xai_search_source(source_type, config.xai_search_excluded_websites.clone());
    }

    if let Some(ref from_date) = config.xai_search_from_date {
        builder = builder.xai_search_from_date(from_date);
    }

    if let Some(ref to_date) = config.xai_search_to_date {
        builder = builder.xai_search_to_date(to_date);
    }

    if let Some(enabled) = config.openai_enable_web_search {
        builder = builder.openai_enable_web_search(enabled);
    }

    if let Some(ref context_size) = config.openai_web_search_context_size {
        builder = builder.openai_web_search_context_size(context_size);
    }

    if let Some(ref loc_type) = config.openai_web_search_user_location_type {
        builder = builder.openai_web_search_user_location_type(loc_type);
    }

    if let Some(ref country) = config.openai_web_search_user_location_approximate_country {
        builder = builder.openai_web_search_user_location_approximate_country(country);
    }

    if let Some(ref city) = config.openai_web_search_user_location_approximate_city {
        builder = builder.openai_web_search_user_location_approximate_city(city);
    }

    if let Some(ref region) = config.openai_web_search_user_location_approximate_region {
        builder = builder.openai_web_search_user_location_approximate_region(region);
    }

    if let Some(enabled) = config.resilient_enable {
        builder = builder.resilient(enabled);
    }

    if let Some(attempts) = config.resilient_attempts {
        builder = builder.resilient_attempts(attempts);
    }

    builder
        .build()
        .map_err(|e| format!("Failed to build LLM: {}", e))
}

/// Provider-agnostic chat client backed by any `llm` crate provider.
pub struct UniversalLLMClient {
    config: LLMProviderConfig,
    llm: Box<dyn LLMProvider>,
}

#[async_trait]
pub trait LLMClient: Send + Sync {
    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize;
}
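
// Usage sketch for `LLMClient::complete` (mirrors the tests at the bottom of
// this file; `client` and `messages` are placeholders):
//
//     let resp: LLMResponse<String> = client
//         .complete::<LLMResponse<String>, String>(&messages, &[])
//         .await?;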

impl Clone for UniversalLLMClient {
    fn clone(&self) -> Self {
        // `Box<dyn LLMProvider>` is not cloneable, so rebuild the provider
        // from the stored config.
        let config = self.config.clone();
        let api_key = config.api_key.clone();
        let base_url = config.base_url.clone();

        // The constructors store `model` as "provider::model-name"; fall back
        // gracefully instead of panicking if the separator is missing.
        let mut parts = config.model.splitn(2, "::");
        let provider = parts.next().unwrap_or_default();
        let model = parts.next().unwrap_or(provider);

        let backend = parse_provider(provider).unwrap_or(LLMBackend::Ollama);

        let mut builder = LLMBuilder::new()
            .backend(backend)
            .model(model)
            .max_tokens(4096);

        if let Some(api_key) = api_key {
            builder = builder.api_key(api_key);
        }

        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }

        let llm = builder
            .build()
            .expect("Failed to rebuild LLM while cloning client");

        Self { llm, config }
    }
}

impl UniversalLLMClient {
    const DEFAULT_SYSTEM_PROMPT: &'static str = "You are a helpful AI assistant.";

    const DEFAULT_TOOL_PROMPT: &'static str = "You have access to the following tools.\n\
        To call ONE tool, respond EXACTLY in this format:\n\
        USE_TOOL: tool_name\n\
        {\"param1\": \"value1\"}\n\n\
        To call MULTIPLE tools at once, respond in this format:\n\
        USE_TOOLS:\n\
        tool_name1\n\
        {\"param1\": \"value1\"}\n\
        ---\n\
        tool_name2\n\
        {\"param1\": \"value1\"}\n\n\
        Only call tools using these exact formats. Otherwise, respond normally.";

    /// Builds a "respond as JSON" instruction from a serialized sample of `C`.
    fn generate_schema_instruction<C>(sample: &C) -> String
    where
        C: Serialize,
    {
        let sample_json =
            serde_json::to_string_pretty(sample).unwrap_or_else(|_| "{}".to_string());

        format!(
            "Respond with ONLY a JSON object in this exact format:\n{}\n\nProvide your response as valid JSON.",
            sample_json
        )
    }
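
    // Sketch of the instruction this produces for a hypothetical
    // `struct Out { field1: f64, flag: bool }` (serialized from `Out::default()`):
    //
    //     Respond with ONLY a JSON object in this exact format:
    //     {
    //       "field1": 0.0,
    //       "flag": false
    //     }
    //
    //     Provide your response as valid JSON.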

    pub fn new(provider_model: &str, api_key: Option<String>) -> Result<Self, String> {
        let parts: Vec<&str> = provider_model.split("::").collect();

        if parts.len() != 2 {
            return Err(format!(
                "Invalid format. Use 'provider::model-name'. Got: {}",
                provider_model
            ));
        }

        let provider = parts[0];
        let model = parts[1];

        // Prefer an explicit key; otherwise fall back to the provider's
        // conventional environment variable.
        let final_api_key = match api_key {
            Some(key) => Some(key),
            None => get_api_key_from_env(provider)?,
        };

        let backend = parse_provider(provider)?;

        // Ollama is the only provider that needs a base URL here; default to
        // the local daemon.
        let base_url = match provider.to_lowercase().as_str() {
            "ollama" => std::env::var("OLLAMA_URL")
                .unwrap_or_else(|_| "http://127.0.0.1:11434".to_string()),
            _ => String::new(),
        };

        // Store the resolved base URL so `Clone` can rebuild an identical
        // client.
        let config = LLMProviderConfig {
            model: provider_model.to_string(),
            max_tokens: 4096,
            api_key: final_api_key.clone(),
            base_url: base_url.clone(),
        };

        let mut builder = LLMBuilder::new()
            .backend(backend)
            .model(model)
            .max_tokens(4096);

        if let Some(api_key) = final_api_key {
            builder = builder.api_key(api_key);
        }

        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }

        let llm = builder
            .build()
            .map_err(|e| format!("Failed to build LLM: {}", e))?;

        Ok(Self { llm, config })
    }
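
    // Usage sketch for `new` (model strings are placeholders): a local model
    // needs no key, while hosted providers read their env var when `None` is
    // passed.
    //
    //     let local = UniversalLLMClient::new("ollama::llama3", None)?;
    //     let hosted = UniversalLLMClient::new("openai::gpt-4o", None)?;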

    /// Like [`UniversalLLMClient::new`], but driven by a full [`LLMConfig`];
    /// fills in the Ollama base URL and the provider API key from the
    /// environment when they are not set explicitly.
    pub fn new_with_config(llm_config: LLMConfig) -> Result<Self, String> {
        let parts: Vec<&str> = llm_config.model.split("::").collect();

        if parts.len() != 2 {
            return Err(format!(
                "Invalid format. Use 'provider::model-name'. Got: {}",
                llm_config.model
            ));
        }

        let provider = parts[0];
        let backend = parse_provider(provider)?;

        let mut config = llm_config.clone();
        if config.base_url.is_none() && provider.eq_ignore_ascii_case("ollama") {
            config.base_url = Some(
                std::env::var("OLLAMA_URL")
                    .unwrap_or_else(|_| "http://127.0.0.1:11434".to_string()),
            );
        }

        if config.api_key.is_none() {
            config.api_key = get_api_key_from_env(provider)?;
        }

        let llm = build_llm_from_config(&config, backend)?;

        let legacy_config = LLMProviderConfig::from(config);

        Ok(Self {
            llm,
            config: legacy_config,
        })
    }

    fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
        messages
            .iter()
            .map(|msg| match msg.role.as_str() {
                "user" => ChatMessage::user().content(&msg.content).build(),
                "assistant" => ChatMessage::assistant().content(&msg.content).build(),
                // System and tool results are folded into assistant turns;
                // only user/assistant messages are built here.
                "system" => ChatMessage::assistant().content(&msg.content).build(),
                "tool" => ChatMessage::assistant()
                    .content(format!("Tool result: {}", msg.content))
                    .build(),
                _ => ChatMessage::user().content(&msg.content).build(),
            })
            .collect()
    }

    /// Renders tools as "- name(params): description" lines for the prompt.
    fn build_tool_description(tools: &[ToolSpec]) -> String {
        tools
            .iter()
            .map(|t| {
                let params = t
                    .input_schema
                    .get("properties")
                    .and_then(|p| p.as_object())
                    .map(|o| o.keys().cloned().collect::<Vec<_>>().join(", "))
                    .unwrap_or_default();

                if t.description.is_empty() {
                    format!("- {}({})", t.name, params)
                } else {
                    format!("- {}({}): {}", t.name, params, t.description)
                }
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Parses `USE_TOOL:` / `USE_TOOLS:` blocks out of a raw model response.
    fn parse_tool_calls(response_text: &str) -> Vec<ToolCall> {
        let mut tool_calls = Vec::new();

        if response_text.starts_with("USE_TOOLS:") {
            // Multi-tool format: blocks of "name\n{json}" separated by "---".
            let parts: Vec<&str> = response_text
                .strip_prefix("USE_TOOLS:")
                .unwrap_or("")
                .split("---")
                .collect();

            for part in parts {
                let lines: Vec<&str> = part.trim().lines().collect();
                if lines.is_empty() {
                    continue;
                }

                let tool_name = lines[0].trim().to_string();
                let json_block = lines.get(1..).unwrap_or(&[]).join("\n");

                if let Ok(input_value) = serde_json::from_str(&json_block) {
                    tool_calls.push(ToolCall {
                        name: tool_name,
                        input: input_value,
                    });
                }
            }
        } else if response_text.starts_with("USE_TOOL:") {
            // Single-tool format: "USE_TOOL: name" followed by a JSON body.
            let lines: Vec<&str> = response_text.lines().collect();
            let tool_name = lines[0]
                .strip_prefix("USE_TOOL:")
                .unwrap_or("")
                .trim()
                .to_string();

            let json_block = lines.get(1..).unwrap_or(&[]).join("\n");

            if let Ok(input_value) = serde_json::from_str(&json_block) {
                tool_calls.push(ToolCall {
                    name: tool_name,
                    input: input_value,
                });
            }
        }

        tool_calls
    }
}

#[async_trait]
impl LLMClient for UniversalLLMClient {
    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
    {
        let mut chat_messages = vec![];

        // Inject the default system prompt only if the caller did not
        // provide one.
        let has_user_system_prompt = messages.iter().any(|m| m.role == "system");
        if !has_user_system_prompt {
            chat_messages.push(
                ChatMessage::assistant()
                    .content(Self::DEFAULT_SYSTEM_PROMPT)
                    .build(),
            );
        }

        // A "system_tools" message, if present, overrides the default
        // tool-calling instructions.
        let user_tool_prompt = messages
            .iter()
            .find(|m| m.role == "system_tools")
            .map(|m| m.content.clone());

        if !tools.is_empty() {
            let tool_list = Self::build_tool_description(tools);
            let tool_prompt = user_tool_prompt.unwrap_or_else(|| {
                format!(
                    "{}\n\nAvailable Tools:\n{}\n\n{}",
                    Self::DEFAULT_TOOL_PROMPT,
                    tool_list,
                    "Use only the EXACT formats shown above when calling tools."
                )
            });
            chat_messages.push(ChatMessage::assistant().content(tool_prompt).build());
        }

        // Ask the model to answer in the JSON shape of `C::default()`.
        let sample_c = C::default();
        let schema_instruction = Self::generate_schema_instruction(&sample_c);

        chat_messages.push(ChatMessage::assistant().content(schema_instruction).build());

        chat_messages.extend(self.convert_messages(messages));

        // Best-effort extraction of typed content: try the text as-is, then
        // with Markdown code fences stripped, then the outermost `{...}`
        // span, and finally the raw text re-quoted as a JSON string (which
        // covers `C = String`).
        let try_parse_c = |s: &str| -> C {
            let text = s.trim();

            if let Ok(parsed) = serde_json::from_str::<C>(text) {
                return parsed;
            }

            let cleaned = text
                .strip_prefix("```json")
                .or_else(|| text.strip_prefix("```"))
                .unwrap_or(text);
            let cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned).trim();

            if let Ok(parsed) = serde_json::from_str::<C>(cleaned) {
                return parsed;
            }

            if let Some(start) = text.find('{') {
                if let Some(end) = text.rfind('}') {
                    let json_part = &text[start..=end];
                    if let Ok(parsed) = serde_json::from_str::<C>(json_part) {
                        return parsed;
                    }
                }
            }

            if let Ok(quoted) = serde_json::to_string(text) {
                if let Ok(parsed) = serde_json::from_str::<C>(&quoted) {
                    return parsed;
                }
            }

            C::default()
        };

        let response = self
            .llm
            .chat(&chat_messages)
            .await
            .map_err(|e| format!("LLM error: {}", e))?;

        let response_text = response.text().unwrap_or_default();

        let tool_calls = Self::parse_tool_calls(&response_text);

        // A tool-call response carries no typed content and is not complete;
        // the caller is expected to execute the tools and call again.
        if !tool_calls.is_empty() {
            let parsed_content: C = C::default();
            return Ok(T::new(parsed_content, tool_calls, false));
        }

        let parsed_content: C = try_parse_c(&response_text);
        Ok(T::new(parsed_content, vec![], true))
    }
}

/// Test double that returns a canned response without touching the network.
pub struct MockLLMClient {
    response_content: String,
}

impl MockLLMClient {
    pub fn new(response: &str) -> Self {
        Self {
            response_content: response.to_string(),
        }
    }

    pub fn default_hello() -> Self {
        Self::new("Hello! How can I help you today?")
    }
}

#[async_trait::async_trait]
impl LLMClient for MockLLMClient {
    async fn complete<T, C>(&self, _messages: &[Message], _tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
    {
        // Use the canned response if it is valid JSON for `C`; otherwise
        // fall back to `C::default()`.
        let content: C = serde_json::from_str(&self.response_content).unwrap_or_default();

        Ok(T::new(content, vec![], true))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_llm_basic() {
        let client = MockLLMClient::default_hello();
        let messages = vec![Message {
            role: "user".into(),
            content: "Say hello".into(),
        }];

        let result: LLMResponse<String> = client
            .complete::<LLMResponse<String>, String>(&messages, &[])
            .await
            .expect("Mock LLM failed");

        // Plain text is not valid JSON for `String`, so content falls back
        // to the default (empty) value.
        assert!(result.content.is_empty());
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }
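
    // Sketches of the tool-call wire format handled by `parse_tool_calls`;
    // the tool names and payloads are made up for illustration.
    #[test]
    fn test_parse_single_tool_call() {
        let text = "USE_TOOL: get_weather\n{\"city\": \"Paris\"}";
        let calls = UniversalLLMClient::parse_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].input["city"], "Paris");
    }

    #[test]
    fn test_parse_multiple_tool_calls() {
        let text =
            "USE_TOOLS:\nsearch\n{\"query\": \"rust\"}\n---\nfetch\n{\"url\": \"https://example.com\"}";
        let calls = UniversalLLMClient::parse_tool_calls(text);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[1].name, "fetch");
    }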

    #[tokio::test]
    async fn test_mock_structured_output() {
        let client = MockLLMClient::new(r#"{"field1": 42.5, "flag": true}"#);
        let messages = vec![Message {
            role: "user".into(),
            content: "Return structured data".into(),
        }];

        #[derive(Deserialize, Serialize, Default, Clone, Debug)]
        struct MyOutput {
            field1: f64,
            flag: bool,
        }

        let result: LLMResponse<MyOutput> = client
            .complete::<LLMResponse<MyOutput>, MyOutput>(&messages, &[])
            .await
            .expect("Mock LLM failed");

        assert_eq!(result.content.field1, 42.5);
        assert!(result.content.flag);
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }
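
    // Sketch exercising the LLMConfig fluent builders; values are arbitrary.
    #[test]
    fn test_llm_config_builder() {
        let config = LLMConfig::new("openai::gpt-4o")
            .with_temperature(0.2)
            .with_max_tokens(1024)
            .with_resilience(true, 3);

        assert_eq!(config.model, "openai::gpt-4o");
        assert_eq!(config.temperature, Some(0.2));
        assert_eq!(config.max_tokens, Some(1024));
        assert_eq!(config.resilient_enable, Some(true));
        assert_eq!(config.resilient_attempts, Some(3));
    }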
}