pub mod llm_agent;
pub mod react;
pub mod types;

pub use llm_agent::{LlmAgent, LlmAgentBuilder};
pub use react::{FinishReason, ReActConfig, ReActEngine, ReActResult, ReActStep};

use async_trait::async_trait;
use llm::builder::{LLMBackend, LLMBuilder};
use llm::chat::ChatMessage;
use llm::LLMProvider;
use serde::{Deserialize, Serialize};
use types::{Message, ToolSpec};

/// Abstraction over an LLM response carrying a typed payload `C`, any
/// parsed tool calls, and a completion flag.
pub trait LLMResponseTrait<C: for<'de> Deserialize<'de> + Default + Send> {
    /// Constructs a response from its parts.
    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self;

    /// Returns `true` when the model produced a final answer (no pending tool calls).
    fn is_complete(&self) -> bool;

    /// Tool calls requested by the model, if any.
    fn tool_calls(&self) -> Vec<ToolCall>;

    /// The typed payload parsed from the model output.
    fn content(&self) -> C;
}

/// Default response wrapper pairing a typed payload with tool calls.
#[derive(Debug, Clone, Serialize, Default)]
pub struct LLMResponse<C>
where
    C: for<'de> Deserialize<'de> + Default + Clone + Send,
{
    /// Typed payload parsed from the model output.
    pub content: C,

    /// Tool calls extracted from the response, if any.
    pub tool_calls: Vec<ToolCall>,

    /// Whether the model finished without requesting tools.
    pub is_complete: bool,
}

impl<C> LLMResponseTrait<C> for LLMResponse<C>
where
    C: for<'de> Deserialize<'de> + Default + Clone + Send,
{
    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self {
        Self {
            content,
            tool_calls,
            is_complete,
        }
    }

    fn is_complete(&self) -> bool {
        self.is_complete
    }

    fn tool_calls(&self) -> Vec<ToolCall> {
        self.tool_calls.clone()
    }

    fn content(&self) -> C {
        self.content.clone()
    }
}

/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Name of the tool to invoke.
    pub name: String,

    /// JSON arguments for the invocation.
    pub input: serde_json::Value,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    // Connection.
    pub model: String,
    pub api_key: Option<String>,
    pub base_url: Option<String>,

    // Generation parameters.
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub system: Option<String>,

    pub timeout_seconds: Option<u64>,

    // Embedding options.
    pub embedding_encoding_format: Option<String>,
    pub embedding_dimensions: Option<u32>,

    pub enable_parallel_tool_use: Option<bool>,

    // Reasoning-model controls.
    pub reasoning: Option<bool>,
    pub reasoning_effort: Option<String>,
    pub reasoning_budget_tokens: Option<u32>,

    // Azure OpenAI deployment settings.
    pub api_version: Option<String>,
    pub deployment_id: Option<String>,

    // Voice selection (speech-capable providers).
    pub voice: Option<String>,

    // xAI live-search options.
    pub xai_search_mode: Option<String>,
    pub xai_search_source_type: Option<String>,
    pub xai_search_excluded_websites: Option<Vec<String>>,
    pub xai_search_max_results: Option<u32>,
    pub xai_search_from_date: Option<String>,
    pub xai_search_to_date: Option<String>,

    // OpenAI web-search options.
    pub openai_enable_web_search: Option<bool>,
    pub openai_web_search_context_size: Option<String>,
    pub openai_web_search_user_location_type: Option<String>,
    pub openai_web_search_user_location_approximate_country: Option<String>,
    pub openai_web_search_user_location_approximate_city: Option<String>,
    pub openai_web_search_user_location_approximate_region: Option<String>,

    // Retry/resilience settings.
    pub resilient_enable: Option<bool>,
    pub resilient_attempts: Option<usize>,
    pub resilient_base_delay_ms: Option<u64>,
    pub resilient_max_delay_ms: Option<u64>,
    pub resilient_jitter: Option<bool>,
}

impl Default for LLMConfig {
    fn default() -> Self {
        Self {
            model: String::new(),
            api_key: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: None,
            top_p: None,
            top_k: None,
            system: None,
            timeout_seconds: None,
            embedding_encoding_format: None,
            embedding_dimensions: None,
            enable_parallel_tool_use: None,
            reasoning: None,
            reasoning_effort: None,
            reasoning_budget_tokens: None,
            api_version: None,
            deployment_id: None,
            voice: None,
            xai_search_mode: None,
            xai_search_source_type: None,
            xai_search_excluded_websites: None,
            xai_search_max_results: None,
            xai_search_from_date: None,
            xai_search_to_date: None,
            openai_enable_web_search: None,
            openai_web_search_context_size: None,
            openai_web_search_user_location_type: None,
            openai_web_search_user_location_approximate_country: None,
            openai_web_search_user_location_approximate_city: None,
            openai_web_search_user_location_approximate_region: None,
            resilient_enable: None,
            resilient_attempts: None,
            resilient_base_delay_ms: None,
            resilient_max_delay_ms: None,
            resilient_jitter: None,
        }
    }
}

impl LLMConfig {
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    pub fn with_top_k(mut self, top_k: u32) -> Self {
        self.top_k = Some(top_k);
        self
    }

    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    pub fn with_timeout_seconds(mut self, timeout: u64) -> Self {
        self.timeout_seconds = Some(timeout);
        self
    }

    pub fn with_reasoning(mut self, enabled: bool) -> Self {
        self.reasoning = Some(enabled);
        self
    }

    pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
        self.reasoning_effort = Some(effort.into());
        self
    }

    pub fn with_deployment_id(mut self, deployment_id: impl Into<String>) -> Self {
        self.deployment_id = Some(deployment_id.into());
        self
    }

    pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
        self.api_version = Some(api_version.into());
        self
    }

    pub fn with_openai_web_search(mut self, enabled: bool) -> Self {
        self.openai_enable_web_search = Some(enabled);
        self
    }

    pub fn with_resilience(mut self, enabled: bool, attempts: usize) -> Self {
        self.resilient_enable = Some(enabled);
        self.resilient_attempts = Some(attempts);
        self
    }
}
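
// Example (illustrative sketch): composing a config with the fluent
// builder. The "provider::model" string and settings below are made up;
// this is the form `UniversalLLMClient::new_with_config` expects.
//
// let config = LLMConfig::new("anthropic::claude-3-5-sonnet")
//     .with_max_tokens(2048)
//     .with_temperature(0.7)
//     .with_system("You are terse.")
//     .with_resilience(true, 3);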

/// Minimal legacy config kept alongside the built provider handle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMProviderConfig {
    pub model: String,
    pub max_tokens: u64,
    pub api_key: Option<String>,
    pub base_url: String,
}

impl From<LLMConfig> for LLMProviderConfig {
    fn from(config: LLMConfig) -> Self {
        Self {
            model: config.model,
            max_tokens: config.max_tokens.unwrap_or(4096) as u64,
            api_key: config.api_key,
            base_url: config.base_url.unwrap_or_default(),
        }
    }
}

/// Returns the conventional API-key environment variable for a provider,
/// or `None` for providers that need no key (e.g. Ollama).
fn get_api_key_env_var(provider: &str) -> Option<&'static str> {
    match provider.to_lowercase().as_str() {
        "ollama" => None,
        "anthropic" | "claude" => Some("ANTHROPIC_API_KEY"),
        "openai" | "gpt" => Some("OPENAI_API_KEY"),
        "deepseek" => Some("DEEPSEEK_API_KEY"),
        "xai" | "x.ai" => Some("XAI_API_KEY"),
        "phind" => Some("PHIND_API_KEY"),
        "google" | "gemini" => Some("GOOGLE_API_KEY"),
        "groq" => Some("GROQ_API_KEY"),
        "azure" | "azureopenai" | "azure-openai" => Some("AZURE_OPENAI_API_KEY"),
        "elevenlabs" | "11labs" => Some("ELEVENLABS_API_KEY"),
        "cohere" => Some("COHERE_API_KEY"),
        "mistral" => Some("MISTRAL_API_KEY"),
        "openrouter" => Some("OPENROUTER_API_KEY"),
        _ => None,
    }
}

/// Resolves an API key from the environment, erroring when the provider
/// requires a key and the corresponding variable is unset.
fn get_api_key_from_env(provider: &str) -> Result<Option<String>, String> {
    match get_api_key_env_var(provider) {
        None => Ok(None),
        Some(env_var) => std::env::var(env_var).map(Some).map_err(|_| {
            format!(
                "API key required for provider '{}'. Please set the {} environment variable or pass the API key explicitly.",
                provider, env_var
            )
        }),
    }
}

fn parse_provider(provider: &str) -> Result<LLMBackend, String> {
    match provider.to_lowercase().as_str() {
        "ollama" => Ok(LLMBackend::Ollama),
        "anthropic" | "claude" => Ok(LLMBackend::Anthropic),
        "openai" | "gpt" => Ok(LLMBackend::OpenAI),
        "deepseek" => Ok(LLMBackend::DeepSeek),
        "xai" | "x.ai" => Ok(LLMBackend::XAI),
        "phind" => Ok(LLMBackend::Phind),
        "google" | "gemini" => Ok(LLMBackend::Google),
        "groq" => Ok(LLMBackend::Groq),
        "azure" | "azureopenai" | "azure-openai" => Ok(LLMBackend::AzureOpenAI),
        "elevenlabs" | "11labs" => Ok(LLMBackend::ElevenLabs),
        "cohere" => Ok(LLMBackend::Cohere),
        "mistral" => Ok(LLMBackend::Mistral),
        "openrouter" => Ok(LLMBackend::OpenRouter),
        _ => Err(format!("Unknown provider: {}", provider)),
    }
}

/// Translates an `LLMConfig` into a concrete provider instance, forwarding
/// every populated option to the underlying `LLMBuilder`.
fn build_llm_from_config(
    config: &LLMConfig,
    backend: LLMBackend,
) -> Result<Box<dyn LLMProvider>, String> {
    let mut builder = LLMBuilder::new().backend(backend.clone());

    // Accept either "provider::model" or a bare model name.
    let model_name = if config.model.contains("::") {
        config.model.split("::").nth(1).unwrap_or(&config.model)
    } else {
        &config.model
    };

    builder = builder.model(model_name);

    if let Some(max_tokens) = config.max_tokens {
        builder = builder.max_tokens(max_tokens);
    }

    if let Some(ref api_key) = config.api_key {
        builder = builder.api_key(api_key);
    }

    if let Some(ref base_url) = config.base_url {
        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }
    }

    if let Some(temperature) = config.temperature {
        builder = builder.temperature(temperature);
    }

    if let Some(top_p) = config.top_p {
        builder = builder.top_p(top_p);
    }

    if let Some(top_k) = config.top_k {
        builder = builder.top_k(top_k);
    }

    if let Some(ref system) = config.system {
        builder = builder.system(system);
    }

    if let Some(timeout) = config.timeout_seconds {
        builder = builder.timeout_seconds(timeout);
    }

    if let Some(ref format) = config.embedding_encoding_format {
        builder = builder.embedding_encoding_format(format);
    }

    if let Some(dims) = config.embedding_dimensions {
        builder = builder.embedding_dimensions(dims);
    }

    if let Some(enabled) = config.enable_parallel_tool_use {
        builder = builder.enable_parallel_tool_use(enabled);
    }

    if let Some(enabled) = config.reasoning {
        builder = builder.reasoning(enabled);
    }

    if let Some(budget) = config.reasoning_budget_tokens {
        builder = builder.reasoning_budget_tokens(budget);
    }

    if let Some(ref api_version) = config.api_version {
        builder = builder.api_version(api_version);
    }

    if let Some(ref deployment_id) = config.deployment_id {
        builder = builder.deployment_id(deployment_id);
    }

    if let Some(ref voice) = config.voice {
        builder = builder.voice(voice);
    }

    if let Some(ref mode) = config.xai_search_mode {
        builder = builder.xai_search_mode(mode);
    }

    if let (Some(source_type), excluded) = (
        &config.xai_search_source_type,
        &config.xai_search_excluded_websites,
    ) {
        builder = builder.xai_search_source(source_type, excluded.clone());
    }

    if let Some(ref from_date) = config.xai_search_from_date {
        builder = builder.xai_search_from_date(from_date);
    }

    if let Some(ref to_date) = config.xai_search_to_date {
        builder = builder.xai_search_to_date(to_date);
    }

    if let Some(enabled) = config.openai_enable_web_search {
        builder = builder.openai_enable_web_search(enabled);
    }

    if let Some(ref context_size) = config.openai_web_search_context_size {
        builder = builder.openai_web_search_context_size(context_size);
    }

    if let Some(ref loc_type) = config.openai_web_search_user_location_type {
        builder = builder.openai_web_search_user_location_type(loc_type);
    }

    if let Some(ref country) = config.openai_web_search_user_location_approximate_country {
        builder = builder.openai_web_search_user_location_approximate_country(country);
    }

    if let Some(ref city) = config.openai_web_search_user_location_approximate_city {
        builder = builder.openai_web_search_user_location_approximate_city(city);
    }

    if let Some(ref region) = config.openai_web_search_user_location_approximate_region {
        builder = builder.openai_web_search_user_location_approximate_region(region);
    }

    // Note: `reasoning_effort`, `resilient_base_delay_ms`,
    // `resilient_max_delay_ms`, and `resilient_jitter` are carried in the
    // config but are not forwarded to the builder here.
    if let Some(enabled) = config.resilient_enable {
        builder = builder.resilient(enabled);
    }

    if let Some(attempts) = config.resilient_attempts {
        builder = builder.resilient_attempts(attempts);
    }

    builder
        .build()
        .map_err(|e| format!("Failed to build LLM: {}", e))
}

/// Rough token estimate: ~4 characters per token.
fn estimate_tokens(text: &str) -> u64 {
    ((text.len() as f64) / 4.0).ceil() as u64
}

/// Per-model (input, output) prices, interpreted as micro-USD per 1,000
/// tokens (e.g. gpt-4o: 2,500 => $2.50 per million input tokens).
fn get_model_pricing(model: &str) -> (u64, u64) {
    match model.to_lowercase().as_str() {
        m if m.contains("gpt-4o") => (2_500, 10_000),
        m if m.contains("gpt-4-turbo") => (10_000, 30_000),
        m if m.contains("gpt-4") => (30_000, 60_000),
        m if m.contains("gpt-3.5-turbo") => (500, 1_500),
        m if m.contains("claude-3-opus") => (15_000, 75_000),
        m if m.contains("claude-3-5-sonnet") => (3_000, 15_000),
        m if m.contains("claude-3-sonnet") => (3_000, 15_000),
        m if m.contains("claude-3-haiku") => (250, 1_250),
        _ => (500, 1_500),
    }
}

/// Estimated call cost in micro-USD.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> u64 {
    let (input_price, output_price) = get_model_pricing(model);
    (input_tokens * input_price + output_tokens * output_price) / 1000
}
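
// Worked example (under the unit interpretation above): 2,000 input and
// 500 output tokens on gpt-4o:
//   (2_000 * 2_500 + 500 * 10_000) / 1_000 = 10_000 micro-USD = $0.01.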

/// Client that talks to any supported backend through the `llm` crate.
pub struct UniversalLLMClient {
    config: LLMProviderConfig,
    llm: Box<dyn LLMProvider>,
}

/// Common interface implemented by the real and mock clients.
#[async_trait]
pub trait LLMClient: Send + Sync {
    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize;
}

impl Clone for UniversalLLMClient {
    fn clone(&self) -> Self {
        let config = self.config.clone();
        let model = config.model.clone();
        let api_key = config.api_key.clone();
        let base_url = config.base_url.clone();

        // `config.model` is always "provider::model" here: both constructors
        // validate the format before storing it, so indexing is safe.
        let parts: Vec<&str> = model.split("::").collect();
        let provider = parts[0];
        let model = parts[1];

        let backend = parse_provider(provider).unwrap_or(LLMBackend::Ollama);

        let mut builder = LLMBuilder::new()
            .backend(backend.clone())
            .model(model)
            .max_tokens(4096);

        if let Some(api_key) = api_key {
            builder = builder.api_key(api_key);
        }

        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }

        // The provider handle is not `Clone`, so the client is rebuilt from
        // its stored config; a build failure here is unrecoverable.
        let llm = builder
            .build()
            .unwrap_or_else(|e| panic!("Failed to build LLM: {}", e));

        Self { llm, config }
    }
}


impl UniversalLLMClient {
    const DEFAULT_SYSTEM_PROMPT: &'static str = "You are a helpful AI assistant.";

    const DEFAULT_TOOL_PROMPT: &'static str = "You have access to the following tools.\n\
        To call ONE tool, respond EXACTLY in this format:\n\
        USE_TOOL: tool_name\n\
        {\"param1\": \"value1\"}\n\n\
        To call MULTIPLE tools at once, respond in this format:\n\
        USE_TOOLS:\n\
        tool_name1\n\
        {\"param1\": \"value1\"}\n\
        ---\n\
        tool_name2\n\
        {\"param1\": \"value1\"}\n\n\
        Only call tools using these exact formats. Otherwise, respond normally.";
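
    // Example of a conforming multi-tool reply (hypothetical tool names),
    // in the exact shape `parse_tool_calls` below expects:
    //
    //   USE_TOOLS:
    //   get_weather
    //   {"city": "Paris"}
    //   ---
    //   get_time
    //   {"tz": "CET"}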

    fn generate_schema_instruction<C>(sample: &C) -> String
    where
        C: Serialize,
    {
        let sample_json = serde_json::to_string_pretty(sample).unwrap_or_else(|_| "{}".to_string());

        format!(
            "Respond with ONLY a JSON object in this exact format:\n{}\n\nProvide your response as valid JSON.",
            sample_json
        )
    }

    pub fn new(provider_model: &str, api_key: Option<String>) -> Result<Self, String> {
        let parts: Vec<&str> = provider_model.split("::").collect();

        if parts.len() != 2 {
            return Err(format!(
                "Invalid format. Use 'provider::model-name'. Got: {}",
                provider_model
            ));
        }

        let provider = parts[0];
        let model = parts[1];

        let final_api_key = match api_key {
            Some(key) => Some(key),
            None => get_api_key_from_env(provider)?,
        };

        let config = LLMProviderConfig {
            model: provider_model.to_string(),
            max_tokens: 4096,
            api_key: final_api_key.clone(),
            base_url: String::new(),
        };

        let backend = parse_provider(provider)?;

        let base_url = match provider.to_lowercase().as_str() {
            "ollama" => std::env::var("OLLAMA_URL")
                .unwrap_or_else(|_| "http://127.0.0.1:11434".to_string()),
            _ => String::new(),
        };

        let mut builder = LLMBuilder::new()
            .backend(backend.clone())
            .model(model)
            .max_tokens(4096);

        if let Some(api_key) = final_api_key {
            builder = builder.api_key(api_key);
        }

        if !base_url.is_empty() {
            builder = builder.base_url(base_url);
        }

        let llm = builder
            .build()
            .map_err(|e| format!("Failed to build LLM: {}", e))?;

        Ok(Self { llm, config })
    }

    pub fn new_with_config(llm_config: LLMConfig) -> Result<Self, String> {
        let parts: Vec<&str> = llm_config.model.split("::").collect();

        if parts.len() != 2 {
            return Err(format!(
                "Invalid format. Use 'provider::model-name'. Got: {}",
                llm_config.model
            ));
        }

        let provider = parts[0];
        let backend = parse_provider(provider)?;

        let mut config = llm_config.clone();
        if config.base_url.is_none() && provider.eq_ignore_ascii_case("ollama") {
            config.base_url = Some(
                std::env::var("OLLAMA_URL")
                    .unwrap_or_else(|_| "http://127.0.0.1:11434".to_string()),
            );
        }

        if config.api_key.is_none() {
            config.api_key = get_api_key_from_env(provider)?;
        }

        let llm = build_llm_from_config(&config, backend)?;

        let legacy_config = LLMProviderConfig::from(config);

        Ok(Self {
            llm,
            config: legacy_config,
        })
    }

    fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
        messages
            .iter()
            .map(|msg| match msg.role.as_str() {
                "user" => ChatMessage::user().content(&msg.content).build(),
                "assistant" => ChatMessage::assistant().content(&msg.content).build(),
                // System and tool messages are forwarded as assistant-role
                // text, since only user/assistant constructors are used here.
                "system" => ChatMessage::assistant().content(&msg.content).build(),
                "tool" => ChatMessage::assistant()
                    .content(format!("Tool result: {}", msg.content))
                    .build(),
                _ => ChatMessage::user().content(&msg.content).build(),
            })
            .collect()
    }

    fn build_tool_description(tools: &[ToolSpec]) -> String {
        tools
            .iter()
            .map(|t| {
                let params = t
                    .input_schema
                    .get("properties")
                    .and_then(|p| p.as_object())
                    .map(|o| o.keys().cloned().collect::<Vec<_>>().join(", "))
                    .unwrap_or_default();

                if t.description.is_empty() {
                    format!("- {}({})", t.name, params)
                } else {
                    format!("- {}({}): {}", t.name, params, t.description)
                }
            })
            .collect::<Vec<_>>()
            .join("\n")
    }
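
    // For a hypothetical `get_weather` tool with a `city` property and the
    // description "Current weather", the rendered line would be:
    //   - get_weather(city): Current weather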

    fn parse_tool_calls(response_text: &str) -> Vec<ToolCall> {
        let mut tool_calls = Vec::new();

        // Multi-tool format: a "USE_TOOLS:" header, then name/JSON pairs
        // separated by "---" (see DEFAULT_TOOL_PROMPT).
        if response_text.starts_with("USE_TOOLS:") {
            let parts: Vec<&str> = response_text
                .strip_prefix("USE_TOOLS:")
                .unwrap_or("")
                .split("---")
                .collect();

            for part in parts {
                let lines: Vec<&str> = part.trim().lines().collect();
                if lines.is_empty() {
                    continue;
                }

                let tool_name = lines[0].trim().to_string();
                let json_block = lines.get(1..).unwrap_or(&[]).join("\n");

                if let Ok(input_value) = serde_json::from_str(&json_block) {
                    tool_calls.push(ToolCall {
                        name: tool_name,
                        input: input_value,
                    });
                }
            }
        // Single-tool format: "USE_TOOL: name" followed by a JSON block.
        } else if response_text.starts_with("USE_TOOL:") {
            let lines: Vec<&str> = response_text.lines().collect();
            let tool_name = lines[0]
                .strip_prefix("USE_TOOL:")
                .unwrap_or("")
                .trim()
                .to_string();

            let json_block = lines.get(1..).unwrap_or(&[]).join("\n");

            if let Ok(input_value) = serde_json::from_str(&json_block) {
                tool_calls.push(ToolCall {
                    name: tool_name,
                    input: input_value,
                });
            }
        }

        tool_calls
    }
}

#[async_trait]
impl LLMClient for UniversalLLMClient {
    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
    {
        let mut chat_messages = vec![];

        // If the caller supplied no system message, inject the default
        // prompt (as an assistant message, matching `convert_messages`).
        let has_user_system_prompt = messages.iter().any(|m| m.role == "system");
        if !has_user_system_prompt {
            chat_messages.push(
                ChatMessage::assistant()
                    .content(Self::DEFAULT_SYSTEM_PROMPT)
                    .build(),
            );
        }

        // A "system_tools" message, if present, overrides the default
        // tool-calling prompt.
        let user_tool_prompt = messages
            .iter()
            .find(|m| m.role == "system_tools")
            .map(|m| m.content.clone());

        if !tools.is_empty() {
            let tool_list = Self::build_tool_description(tools);
            let tool_prompt = user_tool_prompt.unwrap_or_else(|| {
                format!(
                    "{}\n\nAvailable Tools:\n{}\n\n{}",
                    Self::DEFAULT_TOOL_PROMPT,
                    tool_list,
                    "Use only the EXACT formats shown above when calling tools."
                )
            });
            chat_messages.push(ChatMessage::assistant().content(tool_prompt).build());
        }

        // Show the model the JSON shape of `C` via its Default value.
        let sample_c = C::default();
        let schema_instruction = Self::generate_schema_instruction(&sample_c);

        chat_messages.push(ChatMessage::assistant().content(schema_instruction).build());

        chat_messages.extend(self.convert_messages(messages));

        // Parses the model text into the typed payload `C`: raw JSON first,
        // then fenced JSON, then an embedded `{...}` object, and finally the
        // raw text re-quoted as a JSON string, before falling back to
        // `C::default()`.
        let try_parse_c = |s: &str| -> C {
            let text = s.trim();

            if let Ok(parsed) = serde_json::from_str::<C>(text) {
                return parsed;
            }

            // Strip a leading ```json or ``` fence and a trailing ``` fence,
            // keeping intermediate results so the strips compose.
            let cleaned = text
                .strip_prefix("```json")
                .or_else(|| text.strip_prefix("```"))
                .unwrap_or(text);
            let cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned).trim();

            if let Ok(parsed) = serde_json::from_str::<C>(cleaned) {
                return parsed;
            }

            if let Some(start) = text.find('{') {
                if let Some(end) = text.rfind('}') {
                    let json_part = &text[start..=end];
                    if let Ok(parsed) = serde_json::from_str::<C>(json_part) {
                        return parsed;
                    }
                }
            }

            // Last resort: re-quote the raw text so plain strings can
            // deserialize when `C` is `String`.
            if let Ok(quoted) = serde_json::to_string(text) {
                if let Ok(parsed) = serde_json::from_str::<C>(&quoted) {
                    return parsed;
                }
            }

            C::default()
        };

        let start = std::time::Instant::now();
        let response = self
            .llm
            .chat(&chat_messages)
            .await
            .map_err(|e| format!("LLM error: {}", e))?;
        let duration = start.elapsed().as_micros() as u64;

        let response_text = response.text().unwrap_or_default();

        // Token counts and cost are rough estimates (see `estimate_tokens`).
        let input_text: String = messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        let input_tokens = estimate_tokens(&input_text);
        let output_tokens = estimate_tokens(&response_text);
        let total_tokens = input_tokens + output_tokens;
        let cost_us = calculate_cost(&self.config.model, input_tokens, output_tokens);

        crate::metrics::metrics().record_llm_call(duration, total_tokens, cost_us);

        let tool_calls = Self::parse_tool_calls(&response_text);

        if !tool_calls.is_empty() {
            let parsed_content: C = C::default();
            return Ok(T::new(parsed_content, tool_calls, false));
        }

        let parsed_content: C = try_parse_c(&response_text);
        Ok(T::new(parsed_content, vec![], true))
    }
}
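
// Example (sketch, not compiled): driving a completion end to end with a
// plain-`String` payload and no tools. The provider/model string is
// illustrative, and the surrounding async/error-handling context is assumed.
//
// let client = UniversalLLMClient::new("openai::gpt-4o", None)?;
// let messages = vec![Message { role: "user".into(), content: "Hi".into() }];
// let resp: LLMResponse<String> = client
//     .complete::<LLMResponse<String>, String>(&messages, &[])
//     .await?;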

/// Test double that returns a fixed response without any network calls.
pub struct MockLLMClient {
    response_content: String,
}

impl MockLLMClient {
    pub fn new(response: &str) -> Self {
        Self {
            response_content: response.to_string(),
        }
    }

    pub fn default_hello() -> Self {
        Self::new("Hello! How can I help you today?")
    }
}

#[async_trait::async_trait]
impl LLMClient for MockLLMClient {
    async fn complete<T, C>(&self, _messages: &[Message], _tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
    {
        // If the canned response parses as `C`, return it; otherwise fall
        // back to the default payload (e.g. plain text with a typed `C`).
        let content: C = if let Ok(parsed) = serde_json::from_str(&self.response_content) {
            parsed
        } else {
            C::default()
        };

        Ok(T::new(content, vec![], true))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_llm_basic() {
        let client = MockLLMClient::default_hello();
        let messages = vec![Message {
            role: "user".into(),
            content: "Say hello".into(),
        }];

        let result: LLMResponse<String> = client
            .complete::<LLMResponse<String>, String>(&messages, &[])
            .await
            .expect("Mock LLM failed");

        // Plain text is not valid JSON, so the typed payload falls back to
        // `String::default()` (empty).
        assert!(result.content.is_empty());
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }

    #[tokio::test]
    async fn test_mock_structured_output() {
        let client = MockLLMClient::new(r#"{"field1": 42.5, "flag": true}"#);
        let messages = vec![Message {
            role: "user".into(),
            content: "Return structured data".into(),
        }];

        #[derive(Deserialize, Serialize, Default, Clone, Debug)]
        struct MyOutput {
            field1: f64,
            flag: bool,
        }

        let result: LLMResponse<MyOutput> = client
            .complete::<LLMResponse<MyOutput>, MyOutput>(&messages, &[])
            .await
            .expect("Mock LLM failed");

        assert_eq!(result.content.field1, 42.5);
        assert!(result.content.flag);
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }