pub use crate::llm_agent::{LlmAgent, LlmAgentBuilder};
pub use crate::react::{FinishReason, ReActConfig, ReActEngine, ReActResult, ReActStep};

use crate::types::{Message, ToolSpec};
use async_trait::async_trait;
use llm::builder::{LLMBackend, LLMBuilder};
use llm::chat::ChatMessage;
use llm::LLMProvider;
use serde::{Deserialize, Serialize};

/// Common interface for typed LLM responses carrying content of type `C`.
pub trait LLMResponseTrait<C: for<'de> Deserialize<'de> + Default + Send> {
    /// Builds a response from parsed content, extracted tool calls, and a
    /// completion flag.
    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self;

    /// Returns `true` when the model produced a final answer rather than a
    /// tool-call turn.
    fn is_complete(&self) -> bool;

    /// Tool invocations requested by the model, if any.
    fn tool_calls(&self) -> Vec<ToolCall>;

    /// The typed content payload.
    fn content(&self) -> C;
}

/// Default implementation of [`LLMResponseTrait`] used by the clients below.
#[derive(Debug, Clone, Serialize, Default)]
pub struct LLMResponse<C>
where
    C: for<'de> Deserialize<'de> + Default + Clone + Send,
{
    /// Typed content parsed from the model output.
    pub content: C,

    /// Tool calls extracted from the response text.
    pub tool_calls: Vec<ToolCall>,

    /// Whether this response is a final answer.
    pub is_complete: bool,
}

impl<C> LLMResponseTrait<C> for LLMResponse<C>
where
    C: for<'de> Deserialize<'de> + Default + Clone + Send,
{
    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self {
        Self {
            content,
            tool_calls,
            is_complete,
        }
    }

    fn is_complete(&self) -> bool {
        self.is_complete
    }

    fn tool_calls(&self) -> Vec<ToolCall> {
        self.tool_calls.clone()
    }

    fn content(&self) -> C {
        self.content.clone()
    }
}
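
// Illustrative sketch (not part of the public API): `LLMResponse<String>` is
// the simplest concrete form, carrying plain text and no tool calls.
#[allow(dead_code)]
fn example_text_response() -> LLMResponse<String> {
    LLMResponse::new("All done.".to_string(), vec![], true)
}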

/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Name of the tool to invoke.
    pub name: String,

    /// JSON arguments for the invocation.
    pub input: serde_json::Value,
}

/// Full configuration for building an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    pub model: String,
    pub api_key: Option<String>,
    pub base_url: Option<String>,

    // Sampling and prompt parameters.
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub system: Option<String>,

    pub timeout_seconds: Option<u64>,

    // Embedding options.
    pub embedding_encoding_format: Option<String>,
    pub embedding_dimensions: Option<u32>,

    pub enable_parallel_tool_use: Option<bool>,

    // Reasoning options.
    pub reasoning: Option<bool>,
    pub reasoning_effort: Option<String>,
    pub reasoning_budget_tokens: Option<u32>,

    // Azure OpenAI options.
    pub api_version: Option<String>,
    pub deployment_id: Option<String>,

    pub voice: Option<String>,

    // xAI live-search options.
    pub xai_search_mode: Option<String>,
    pub xai_search_source_type: Option<String>,
    pub xai_search_excluded_websites: Option<Vec<String>>,
    pub xai_search_max_results: Option<u32>,
    pub xai_search_from_date: Option<String>,
    pub xai_search_to_date: Option<String>,

    // OpenAI web-search options.
    pub openai_enable_web_search: Option<bool>,
    pub openai_web_search_context_size: Option<String>,
    pub openai_web_search_user_location_type: Option<String>,
    pub openai_web_search_user_location_approximate_country: Option<String>,
    pub openai_web_search_user_location_approximate_city: Option<String>,
    pub openai_web_search_user_location_approximate_region: Option<String>,

    // Retry/resilience options.
    pub resilient_enable: Option<bool>,
    pub resilient_attempts: Option<usize>,
    pub resilient_base_delay_ms: Option<u64>,
    pub resilient_max_delay_ms: Option<u64>,
    pub resilient_jitter: Option<bool>,
}

impl Default for LLMConfig {
    fn default() -> Self {
        Self {
            model: String::new(),
            api_key: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: None,
            top_p: None,
            top_k: None,
            system: None,
            timeout_seconds: None,
            embedding_encoding_format: None,
            embedding_dimensions: None,
            enable_parallel_tool_use: None,
            reasoning: None,
            reasoning_effort: None,
            reasoning_budget_tokens: None,
            api_version: None,
            deployment_id: None,
            voice: None,
            xai_search_mode: None,
            xai_search_source_type: None,
            xai_search_excluded_websites: None,
            xai_search_max_results: None,
            xai_search_from_date: None,
            xai_search_to_date: None,
            openai_enable_web_search: None,
            openai_web_search_context_size: None,
            openai_web_search_user_location_type: None,
            openai_web_search_user_location_approximate_country: None,
            openai_web_search_user_location_approximate_city: None,
            openai_web_search_user_location_approximate_region: None,
            resilient_enable: None,
            resilient_attempts: None,
            resilient_base_delay_ms: None,
            resilient_max_delay_ms: None,
            resilient_jitter: None,
        }
    }
}

impl LLMConfig {
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    pub fn with_top_k(mut self, top_k: u32) -> Self {
        self.top_k = Some(top_k);
        self
    }

    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    pub fn with_timeout_seconds(mut self, timeout: u64) -> Self {
        self.timeout_seconds = Some(timeout);
        self
    }

    pub fn with_reasoning(mut self, enabled: bool) -> Self {
        self.reasoning = Some(enabled);
        self
    }

    pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
        self.reasoning_effort = Some(effort.into());
        self
    }

    pub fn with_deployment_id(mut self, deployment_id: impl Into<String>) -> Self {
        self.deployment_id = Some(deployment_id.into());
        self
    }

    pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
        self.api_version = Some(api_version.into());
        self
    }

    pub fn with_openai_web_search(mut self, enabled: bool) -> Self {
        self.openai_enable_web_search = Some(enabled);
        self
    }

    pub fn with_resilience(mut self, enabled: bool, attempts: usize) -> Self {
        self.resilient_enable = Some(enabled);
        self.resilient_attempts = Some(attempts);
        self
    }
}
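
// A minimal usage sketch of the fluent setters above. The model id and key
// are placeholders; any "provider::model" pair accepted by `parse_provider`
// works the same way.
#[allow(dead_code)]
fn example_config() -> LLMConfig {
    LLMConfig::new("openai::gpt-4o")
        .with_api_key("sk-placeholder")
        .with_max_tokens(2048)
        .with_temperature(0.2)
        .with_system("You are a terse assistant.")
        .with_resilience(true, 3)
}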

/// Legacy, flattened provider configuration retained for compatibility with
/// existing call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMProviderConfig {
    pub model: String,
    pub max_tokens: u64,
    pub api_key: Option<String>,
    pub base_url: String,
}
impl From<LLMConfig> for LLMProviderConfig {
    fn from(config: LLMConfig) -> Self {
        Self {
            model: config.model,
            max_tokens: config.max_tokens.unwrap_or(4096) as u64,
            api_key: config.api_key,
            base_url: config.base_url.unwrap_or_default(),
        }
    }
}

fn get_api_key_env_var(provider: &str) -> Option<&'static str> {
    match provider.to_lowercase().as_str() {
        "ollama" => None,
        "anthropic" | "claude" => Some("ANTHROPIC_API_KEY"),
        "openai" | "gpt" => Some("OPENAI_API_KEY"),
        "deepseek" => Some("DEEPSEEK_API_KEY"),
        "xai" | "x.ai" => Some("XAI_API_KEY"),
        "phind" => Some("PHIND_API_KEY"),
        "google" | "gemini" => Some("GOOGLE_API_KEY"),
        "groq" => Some("GROQ_API_KEY"),
        "azure" | "azureopenai" | "azure-openai" => Some("AZURE_OPENAI_API_KEY"),
        "elevenlabs" | "11labs" => Some("ELEVENLABS_API_KEY"),
        "cohere" => Some("COHERE_API_KEY"),
        "mistral" => Some("MISTRAL_API_KEY"),
        "openrouter" => Some("OPENROUTER_API_KEY"),
        _ => None,
    }
}

fn get_api_key_from_env(provider: &str) -> Result<Option<String>, String> {
    match get_api_key_env_var(provider) {
        None => Ok(None),
        Some(env_var) => match std::env::var(env_var) {
            Ok(key) => Ok(Some(key)),
            Err(_) => Err(format!(
                "API key required for provider '{}'. Please set the {} environment variable or pass the API key explicitly.",
                provider, env_var
            )),
        },
    }
}

fn parse_provider(provider: &str) -> Result<LLMBackend, String> {
    match provider.to_lowercase().as_str() {
        "ollama" => Ok(LLMBackend::Ollama),
        "anthropic" | "claude" => Ok(LLMBackend::Anthropic),
        "openai" | "gpt" => Ok(LLMBackend::OpenAI),
        "deepseek" => Ok(LLMBackend::DeepSeek),
        "xai" | "x.ai" => Ok(LLMBackend::XAI),
        "phind" => Ok(LLMBackend::Phind),
        "google" | "gemini" => Ok(LLMBackend::Google),
        "groq" => Ok(LLMBackend::Groq),
        "azure" | "azureopenai" | "azure-openai" => Ok(LLMBackend::AzureOpenAI),
        "elevenlabs" | "11labs" => Ok(LLMBackend::ElevenLabs),
        "cohere" => Ok(LLMBackend::Cohere),
        "mistral" => Ok(LLMBackend::Mistral),
        "openrouter" => Ok(LLMBackend::OpenRouter),
        _ => Err(format!("Unknown provider: {}", provider)),
    }
}
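
// Hedged illustration of the provider-resolution helpers above: aliases map
// to the same backend, key-less providers (Ollama) resolve to no env var,
// and unknown names are rejected rather than guessed.
#[allow(dead_code)]
fn example_provider_resolution() {
    assert!(matches!(parse_provider("claude"), Ok(LLMBackend::Anthropic)));
    assert_eq!(get_api_key_env_var("ollama"), None);
    assert!(parse_provider("not-a-provider").is_err());
}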

fn build_llm_from_config(
    config: &LLMConfig,
    backend: LLMBackend,
) -> Result<Box<dyn LLMProvider>, String> {
    let mut builder = LLMBuilder::new().backend(backend.clone());

    // Accept both bare model names and the "provider::model" form.
    let model_name = if config.model.contains("::") {
        config.model.split("::").nth(1).unwrap_or(&config.model)
    } else {
        &config.model
    };

    builder = builder.model(model_name);

    if let Some(max_tokens) = config.max_tokens {
        builder = builder.max_tokens(max_tokens);
    }

    if let Some(ref api_key) = config.api_key {
        builder = builder.api_key(api_key);
    }

    if let Some(ref base_url) = config.base_url {
        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }
    }

    if let Some(temperature) = config.temperature {
        builder = builder.temperature(temperature);
    }

    if let Some(top_p) = config.top_p {
        builder = builder.top_p(top_p);
    }

    if let Some(top_k) = config.top_k {
        builder = builder.top_k(top_k);
    }

    if let Some(ref system) = config.system {
        builder = builder.system(system);
    }

    if let Some(timeout) = config.timeout_seconds {
        builder = builder.timeout_seconds(timeout);
    }

    if let Some(ref format) = config.embedding_encoding_format {
        builder = builder.embedding_encoding_format(format);
    }

    if let Some(dims) = config.embedding_dimensions {
        builder = builder.embedding_dimensions(dims);
    }

    if let Some(enabled) = config.enable_parallel_tool_use {
        builder = builder.enable_parallel_tool_use(enabled);
    }

    if let Some(enabled) = config.reasoning {
        builder = builder.reasoning(enabled);
    }

    if let Some(budget) = config.reasoning_budget_tokens {
        builder = builder.reasoning_budget_tokens(budget);
    }

    if let Some(ref api_version) = config.api_version {
        builder = builder.api_version(api_version);
    }

    if let Some(ref deployment_id) = config.deployment_id {
        builder = builder.deployment_id(deployment_id);
    }

    if let Some(ref voice) = config.voice {
        builder = builder.voice(voice);
    }

    if let Some(ref mode) = config.xai_search_mode {
        builder = builder.xai_search_mode(mode);
    }

    if let Some(ref source_type) = config.xai_search_source_type {
        builder =
            builder.xai_search_source(source_type, config.xai_search_excluded_websites.clone());
    }

    if let Some(ref from_date) = config.xai_search_from_date {
        builder = builder.xai_search_from_date(from_date);
    }

    if let Some(ref to_date) = config.xai_search_to_date {
        builder = builder.xai_search_to_date(to_date);
    }

    if let Some(enabled) = config.openai_enable_web_search {
        builder = builder.openai_enable_web_search(enabled);
    }

    if let Some(ref context_size) = config.openai_web_search_context_size {
        builder = builder.openai_web_search_context_size(context_size);
    }

    if let Some(ref loc_type) = config.openai_web_search_user_location_type {
        builder = builder.openai_web_search_user_location_type(loc_type);
    }

    if let Some(ref country) = config.openai_web_search_user_location_approximate_country {
        builder = builder.openai_web_search_user_location_approximate_country(country);
    }

    if let Some(ref city) = config.openai_web_search_user_location_approximate_city {
        builder = builder.openai_web_search_user_location_approximate_city(city);
    }

    if let Some(ref region) = config.openai_web_search_user_location_approximate_region {
        builder = builder.openai_web_search_user_location_approximate_region(region);
    }

    // Note: only the enable flag and attempt count are forwarded here; the
    // delay and jitter fields on `LLMConfig` are not applied.
    if let Some(enabled) = config.resilient_enable {
        builder = builder.resilient(enabled);
    }

    if let Some(attempts) = config.resilient_attempts {
        builder = builder.resilient_attempts(attempts);
    }

    builder
        .build()
        .map_err(|e| format!("Failed to build LLM: {}", e))
}

/// Rough token estimate: about four characters per token.
fn estimate_tokens(text: &str) -> u64 {
    ((text.len() as f64) / 4.0).ceil() as u64
}

/// Approximate (input, output) prices in micro-dollars per 1K tokens.
fn get_model_pricing(model: &str) -> (u64, u64) {
    match model.to_lowercase().as_str() {
        m if m.contains("gpt-4o") => (2_500, 10_000),
        m if m.contains("gpt-4-turbo") => (10_000, 30_000),
        m if m.contains("gpt-4") => (30_000, 60_000),
        m if m.contains("gpt-3.5-turbo") => (500, 1_500),
        m if m.contains("claude-3-opus") => (15_000, 75_000),
        m if m.contains("claude-3-5-sonnet") => (3_000, 15_000),
        m if m.contains("claude-3-sonnet") => (3_000, 15_000),
        m if m.contains("claude-3-haiku") => (250, 1_250),
        _ => (500, 1_500),
    }
}

/// Estimated call cost in micro-dollars.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> u64 {
    let (input_price, output_price) = get_model_pricing(model);
    (input_tokens * input_price + output_tokens * output_price) / 1000
}
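
// Worked example of the heuristics above: a prompt of ~4_000 characters is
// estimated at 1_000 tokens; with a 500-token reply on a gpt-4o-class model
// the cost is (1_000 * 2_500 + 500 * 10_000) / 1_000 = 7_500 micro-dollars,
// i.e. roughly $0.0075.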

/// Client that adapts any supported backend behind a single interface.
pub struct UniversalLLMClient {
    config: LLMProviderConfig,
    llm: Box<dyn LLMProvider>,
}

#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Runs one chat completion and parses the reply into `T`.
    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize;
}

impl Clone for UniversalLLMClient {
    fn clone(&self) -> Self {
        // `Box<dyn LLMProvider>` is not `Clone`, so rebuild the provider from
        // the stored config. The constructors guarantee the stored model is
        // in "provider::model" form, so indexing the split is safe here.
        let config = self.config.clone();
        let model = config.model.clone();
        let api_key = config.api_key.clone();
        let base_url = config.base_url.clone();
        let parts: Vec<&str> = model.split("::").collect();

        let provider = parts[0];
        let model = parts[1];

        let backend = parse_provider(provider).unwrap_or(LLMBackend::Ollama);

        let mut builder = LLMBuilder::new()
            .backend(backend.clone())
            .model(model)
            .max_tokens(4096);

        if let Some(api_key) = api_key {
            builder = builder.api_key(api_key);
        }

        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }

        let llm = builder
            .build()
            .unwrap_or_else(|e| panic!("Failed to rebuild LLM while cloning: {}", e));

        Self { llm, config }
    }
}

impl UniversalLLMClient {
    const DEFAULT_SYSTEM_PROMPT: &'static str = "You are a helpful AI assistant.";

    const DEFAULT_TOOL_PROMPT: &'static str = "You have access to the following tools.\n\
        To call ONE tool, respond EXACTLY in this format:\n\
        USE_TOOL: tool_name\n\
        {\"param1\": \"value1\"}\n\n\
        To call MULTIPLE tools at once, respond in this format:\n\
        USE_TOOLS:\n\
        tool_name1\n\
        {\"param1\": \"value1\"}\n\
        ---\n\
        tool_name2\n\
        {\"param1\": \"value1\"}\n\n\
        Only call tools using these exact formats. Otherwise, respond normally.";
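
    // Sketch of a model reply this protocol accepts, using a hypothetical
    // `get_weather` tool (parsed by `parse_tool_calls` below):
    //
    //   USE_TOOL: get_weather
    //   {"city": "Paris"}
    //
    // The multi-tool form lists name/JSON pairs separated by `---` lines.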

    fn generate_schema_instruction<C>(sample: &C) -> String
    where
        C: Serialize,
    {
        let sample_json = serde_json::to_string_pretty(sample).unwrap_or_else(|_| "{}".to_string());

        format!(
            "Respond with ONLY a JSON object in this exact format:\n{}\n\nProvide your response as valid JSON.",
            sample_json
        )
    }

    pub fn new(provider_model: &str, api_key: Option<String>) -> Result<Self, String> {
        let parts: Vec<&str> = provider_model.split("::").collect();

        if parts.len() != 2 {
            return Err(format!(
                "Invalid format. Use 'provider::model-name'. Got: {}",
                provider_model
            ));
        }

        let provider = parts[0];
        let model = parts[1];

        // An explicit key wins; otherwise fall back to the provider's
        // environment variable (key-less providers resolve to None).
        let final_api_key = match api_key {
            Some(key) => Some(key),
            None => get_api_key_from_env(provider)?,
        };

        let config = LLMProviderConfig {
            model: provider_model.to_string(),
            max_tokens: 4096,
            api_key: final_api_key.clone(),
            base_url: String::new(),
        };

        let backend = parse_provider(provider)?;

        let base_url = match provider.to_lowercase().as_str() {
            "ollama" => std::env::var("OLLAMA_URL")
                .unwrap_or_else(|_| "http://127.0.0.1:11434".to_string()),
            _ => String::new(),
        };

        let mut builder = LLMBuilder::new()
            .backend(backend.clone())
            .model(model)
            .max_tokens(4096);

        if let Some(api_key) = final_api_key {
            builder = builder.api_key(api_key);
        }

        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }

        let llm = builder
            .build()
            .map_err(|e| format!("Failed to build LLM: {}", e))?;

        Ok(Self { llm, config })
    }
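
    // Illustrative call (the model id is a placeholder; it needs a reachable
    // backend or the matching API-key env var at runtime):
    //
    //   let client = UniversalLLMClient::new("ollama::llama3", None)?;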

    pub fn new_with_config(llm_config: LLMConfig) -> Result<Self, String> {
        let parts: Vec<&str> = llm_config.model.split("::").collect();

        if parts.len() != 2 {
            return Err(format!(
                "Invalid format. Use 'provider::model-name'. Got: {}",
                llm_config.model
            ));
        }

        let provider = parts[0];
        let backend = parse_provider(provider)?;

        let mut config = llm_config.clone();
        // Default Ollama to a local server when no base URL is given.
        if config.base_url.is_none() && provider.to_lowercase() == "ollama" {
            config.base_url = Some(
                std::env::var("OLLAMA_URL")
                    .unwrap_or_else(|_| "http://127.0.0.1:11434".to_string()),
            );
        }

        if config.api_key.is_none() {
            config.api_key = get_api_key_from_env(provider)?;
        }

        let llm = build_llm_from_config(&config, backend)?;

        let legacy_config = LLMProviderConfig::from(config);

        Ok(Self {
            llm,
            config: legacy_config,
        })
    }

    fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
        messages
            .iter()
            .map(|msg| match msg.role.as_str() {
                "user" => ChatMessage::user().content(&msg.content).build(),
                "assistant" => ChatMessage::assistant().content(&msg.content).build(),
                // System and tool turns are folded into assistant messages,
                // since only user/assistant turns are emitted here.
                "system" => ChatMessage::assistant().content(&msg.content).build(),
                "tool" => ChatMessage::assistant()
                    .content(format!("Tool result: {}", msg.content))
                    .build(),
                _ => ChatMessage::user().content(&msg.content).build(),
            })
            .collect()
    }

    fn build_tool_description(tools: &[ToolSpec]) -> String {
        tools
            .iter()
            .map(|t| {
                let params = t
                    .input_schema
                    .get("properties")
                    .and_then(|p| p.as_object())
                    .map(|o| o.keys().cloned().collect::<Vec<_>>().join(", "))
                    .unwrap_or_default();

                if t.description.is_empty() {
                    format!("- {}({})", t.name, params)
                } else {
                    format!("- {}({}): {}", t.name, params, t.description)
                }
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    fn parse_tool_calls(response_text: &str) -> Vec<ToolCall> {
        let mut tool_calls = Vec::new();

        if response_text.starts_with("USE_TOOLS:") {
            let parts: Vec<&str> = response_text
                .strip_prefix("USE_TOOLS:")
                .unwrap_or("")
                .split("---")
                .collect();

            for part in parts {
                let lines: Vec<&str> = part.trim().lines().collect();
                if lines.is_empty() {
                    continue;
                }

                let tool_name = lines[0].trim().to_string();
                let json_block = lines.get(1..).unwrap_or(&[]).join("\n");

                if let Ok(input_value) = serde_json::from_str(&json_block) {
                    tool_calls.push(ToolCall {
                        name: tool_name,
                        input: input_value,
                    });
                }
            }
        } else if response_text.starts_with("USE_TOOL:") {
            let lines: Vec<&str> = response_text.lines().collect();
            let tool_name = lines[0]
                .strip_prefix("USE_TOOL:")
                .unwrap_or("")
                .trim()
                .to_string();

            let json_block = lines.get(1..).unwrap_or(&[]).join("\n");

            if let Ok(input_value) = serde_json::from_str(&json_block) {
                tool_calls.push(ToolCall {
                    name: tool_name,
                    input: input_value,
                });
            }
        }

        tool_calls
    }
}
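
// Focused checks of the plain-text tool protocol above; replies that do not
// match the USE_TOOL/USE_TOOLS formats deliberately yield no calls. The tool
// names and arguments are hypothetical.
#[cfg(test)]
mod tool_call_parsing_tests {
    use super::*;

    #[test]
    fn parses_single_tool_call() {
        let calls = UniversalLLMClient::parse_tool_calls(
            "USE_TOOL: get_weather\n{\"city\": \"Paris\"}",
        );
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].input["city"], "Paris");
    }

    #[test]
    fn parses_multiple_tool_calls() {
        let calls = UniversalLLMClient::parse_tool_calls(
            "USE_TOOLS:\nfirst_tool\n{\"a\": 1}\n---\nsecond_tool\n{\"b\": 2}",
        );
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "first_tool");
        assert_eq!(calls[1].name, "second_tool");
    }

    #[test]
    fn plain_text_yields_no_calls() {
        assert!(UniversalLLMClient::parse_tool_calls("Hello there").is_empty());
    }
}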

#[async_trait]
impl LLMClient for UniversalLLMClient {
    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
    {
        let mut chat_messages = vec![];

        // Inject the default system prompt only when the caller supplied none.
        let has_user_system_prompt = messages.iter().any(|m| m.role == "system");
        if !has_user_system_prompt {
            chat_messages.push(
                ChatMessage::assistant()
                    .content(Self::DEFAULT_SYSTEM_PROMPT)
                    .build(),
            );
        }

        // A "system_tools" message overrides the built-in tool prompt.
        let user_tool_prompt = messages
            .iter()
            .find(|m| m.role == "system_tools")
            .map(|m| m.content.clone());

        if !tools.is_empty() {
            let tool_list = Self::build_tool_description(tools);
            let tool_prompt = user_tool_prompt.unwrap_or_else(|| {
                format!(
                    "{}\n\nAvailable Tools:\n{}\n\n{}",
                    Self::DEFAULT_TOOL_PROMPT,
                    tool_list,
                    "Use only the EXACT formats shown above when calling tools."
                )
            });
            chat_messages.push(ChatMessage::assistant().content(tool_prompt).build());
        }

        // Show the model the JSON shape expected for `C` via a default sample.
        let sample_c = C::default();
        let schema_instruction = Self::generate_schema_instruction(&sample_c);

        chat_messages.push(ChatMessage::assistant().content(schema_instruction).build());

        chat_messages.extend(self.convert_messages(messages));

        // Best-effort parser for `C`, trying progressively looser recoveries.
        let try_parse_c = |s: &str| -> C {
            let text = s.trim();

            // 1. Direct parse.
            if let Ok(parsed) = serde_json::from_str::<C>(text) {
                return parsed;
            }

            // 2. Strip a Markdown code fence. Each successful strip must feed
            // the next step (falling back to the raw text would undo it).
            let mut cleaned = text;
            if let Some(rest) = cleaned.strip_prefix("```json") {
                cleaned = rest;
            } else if let Some(rest) = cleaned.strip_prefix("```") {
                cleaned = rest;
            }
            cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned);
            let cleaned = cleaned.trim();

            if let Ok(parsed) = serde_json::from_str::<C>(cleaned) {
                return parsed;
            }

            // 3. Parse the outermost brace-delimited slice.
            if let Some(start) = text.find('{') {
                if let Some(end) = text.rfind('}') {
                    let json_part = &text[start..=end];
                    if let Ok(parsed) = serde_json::from_str::<C>(json_part) {
                        return parsed;
                    }
                }
            }

            // 4. JSON-quote the raw text, which lets plain prose parse when
            // `C` is `String`.
            if let Ok(quoted) = serde_json::to_string(text) {
                if let Ok(parsed) = serde_json::from_str::<C>(&quoted) {
                    return parsed;
                }
            }

            C::default()
        };
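
        // Illustration of the ladder: "```json\n{\"x\": 1}\n```" is caught by
        // step 2, "Sure! {\"x\": 1}" by step 3, and bare prose by step 4 when
        // `C` is `String` (the field name `x` here is purely hypothetical).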

        let start = std::time::Instant::now();
        let response = self
            .llm
            .chat(&chat_messages)
            .await
            .map_err(|e| format!("LLM error: {}", e))?;
        let duration = start.elapsed().as_micros() as u64;

        let response_text = response.text().unwrap_or_default();

        // Rough usage accounting from character counts (see estimate_tokens).
        let input_text: String = messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        let input_tokens = estimate_tokens(&input_text);
        let output_tokens = estimate_tokens(&response_text);
        let total_tokens = input_tokens + output_tokens;
        let cost_us = calculate_cost(&self.config.model, input_tokens, output_tokens);

        tracing::trace!(
            duration_us = duration,
            tokens = total_tokens,
            cost_us = cost_us,
            "llm call completed"
        );

        let tool_calls = Self::parse_tool_calls(&response_text);

        // A tool-call turn carries no typed content yet and is not complete.
        if !tool_calls.is_empty() {
            let parsed_content: C = C::default();
            return Ok(T::new(parsed_content, tool_calls, false));
        }

        let parsed_content: C = try_parse_c(&response_text);
        Ok(T::new(parsed_content, vec![], true))
    }
}

/// Deterministic stand-in for an LLM, useful in tests: it always returns the
/// configured response and never calls a real backend.
pub struct MockLLMClient {
    response_content: String,
}

impl MockLLMClient {
    /// Creates a mock that replies with `response`. If the string parses as
    /// JSON for the requested `C`, it becomes the typed content; otherwise
    /// the content falls back to `C::default()`.
    pub fn new(response: &str) -> Self {
        Self {
            response_content: response.to_string(),
        }
    }

    /// Convenience mock with a canned greeting.
    pub fn default_hello() -> Self {
        Self::new("Hello! How can I help you today?")
    }
}

#[async_trait::async_trait]
impl LLMClient for MockLLMClient {
    async fn complete<T, C>(&self, _messages: &[Message], _tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
    {
        let content: C = if let Ok(parsed) = serde_json::from_str(&self.response_content) {
            parsed
        } else {
            C::default()
        };

        Ok(T::new(content, vec![], true))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_llm_basic() {
        let client = MockLLMClient::default_hello();
        let messages = vec![Message {
            role: "user".into(),
            content: "Say hello".into(),
        }];

        let result: LLMResponse<String> = client
            .complete::<LLMResponse<String>, String>(&messages, &[])
            .await
            .expect("Mock LLM failed");

        // Plain text is not valid JSON for `String`, so the mock falls back
        // to the default (empty) content.
        assert!(result.content.is_empty());
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }

    #[tokio::test]
    async fn test_mock_structured_output() {
        let client = MockLLMClient::new(r#"{"field1": 42.5, "flag": true}"#);
        let messages = vec![Message {
            role: "user".into(),
            content: "Return structured data".into(),
        }];

        #[derive(Deserialize, Serialize, Default, Clone, Debug)]
        struct MyOutput {
            field1: f64,
            flag: bool,
        }

        let result: LLMResponse<MyOutput> = client
            .complete::<LLMResponse<MyOutput>, MyOutput>(&messages, &[])
            .await
            .expect("Mock LLM failed");

        assert_eq!(result.content.field1, 42.5);
        assert!(result.content.flag);
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }
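
    // Hedged check of the LLMConfig-to-LLMProviderConfig conversion defined
    // above: max_tokens defaults to 4096 and a missing base URL flattens to
    // an empty string.
    #[test]
    fn test_provider_config_defaults() {
        let cfg = LLMConfig::new("openai::gpt-4o");
        let provider_cfg = LLMProviderConfig::from(cfg);
        assert_eq!(provider_cfg.model, "openai::gpt-4o");
        assert_eq!(provider_cfg.max_tokens, 4096);
        assert!(provider_cfg.base_url.is_empty());
    }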
}