ceylon_llm/
client.rs

1//! LLM (Large Language Model) integration layer.
2//!
3//! This module provides a unified interface for interacting with various LLM providers
4//! including OpenAI, Anthropic, Ollama, Google, and many others.
5//!
6//! # Supported Providers
7//!
8//! Ceylon supports 13+ LLM providers through the [`UniversalLLMClient`]:
9//!
10//! - **OpenAI** - GPT-4, GPT-3.5-turbo, etc.
11//! - **Anthropic** - Claude 3 Opus, Sonnet, Haiku
12//! - **Ollama** - Local models (Llama, Mistral, Gemma, etc.)
13//! - **Google** - Gemini Pro
14//! - **DeepSeek** - DeepSeek Chat, DeepSeek Coder
15//! - **X.AI** - Grok
16//! - **Groq** - High-speed inference
17//! - **Azure OpenAI** - Enterprise OpenAI deployment
18//! - **Cohere** - Command models
19//! - **Mistral** - Mistral AI models
20//! - **Phind** - CodeLlama variants
21//! - **OpenRouter** - Multi-provider routing
22//! - **ElevenLabs** - Voice/audio generation
23//!
24//! # Configuration
25//!
26//! Use [`LLMConfig`] for comprehensive configuration:
27//!
28//! ```rust,no_run
29//! use runtime::llm::LLMConfig;
30//!
31//! let config = LLMConfig::new("openai::gpt-4")
32//!     .with_api_key("sk-...")
33//!     .with_temperature(0.7)
34//!     .with_max_tokens(2048)
35//!     .with_resilience(true, 3);
36//! ```
37//!
38//! # API Key Detection
39//!
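//! When no API key is passed explicitly, Ceylon reads it from the provider's
//! environment variable:
//! - `OPENAI_API_KEY`
//! - `ANTHROPIC_API_KEY`
//! - `GOOGLE_API_KEY`
//! - `DEEPSEEK_API_KEY`
//! - `XAI_API_KEY`
//! - `GROQ_API_KEY`
//! - `MISTRAL_API_KEY`
//! - `COHERE_API_KEY`
//! - `PHIND_API_KEY`
//! - `OPENROUTER_API_KEY`
//! - `AZURE_OPENAI_API_KEY`
//! - `ELEVENLABS_API_KEY`
//!
//! Ollama needs no key. If a required key is neither passed nor set, client
//! construction fails with an error naming the variable to set.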
52//!
53//! # Tool Calling
54//!
55//! Ceylon supports native tool calling for compatible models and falls back to
56//! text-based tool invocation for others.
57//!
58//! # Examples
59//!
60//! ## Basic Usage
61//!
62//! ```rust,no_run
63//! use runtime::llm::{UniversalLLMClient, LLMClient, LLMResponse};
64//! use runtime::llm::types::Message;
65//!
66//! # async fn example() -> Result<(), String> {
67//! let client = UniversalLLMClient::new("openai::gpt-4", None)?;
68//! let messages = vec![Message {
69//!     role: "user".to_string(),
70//!     content: "Hello!".to_string(),
71//! }];
72//!
73//! let response: LLMResponse<String> = client
74//!     .complete::<LLMResponse<String>, String>(&messages, &[])
75//!     .await?;
76//! # Ok(())
77//! # }
78//! ```
79//!
80//! ## Advanced Configuration
81//!
82//! ```rust,no_run
83//! use runtime::llm::{LLMConfig, UniversalLLMClient};
84//!
85//! # fn example() -> Result<(), String> {
86//! let llm_config = LLMConfig::new("anthropic::claude-3-opus-20240229")
87//!     .with_api_key(std::env::var("ANTHROPIC_API_KEY").unwrap())
88//!     .with_temperature(0.8)
89//!     .with_max_tokens(4096)
90//!     .with_reasoning(true);
91//!
92//! let client = UniversalLLMClient::new_with_config(llm_config)?;
93//! # Ok(())
94//! # }
95//! ```
96//!
97//! ## Local Models with Ollama
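//!
//! Ollama needs no API key. The server URL is taken from the `OLLAMA_URL`
//! environment variable and defaults to `http://127.0.0.1:11434`.
//!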
99//! ```rust,no_run
100//! use runtime::llm::UniversalLLMClient;
101//!
102//! # fn example() -> Result<(), String> {
103//! // No API key needed for local models
104//! let client = UniversalLLMClient::new("ollama::llama2", None)?;
105//! # Ok(())
106//! # }
107//! ```
108
109// Removed pub mod declarations - modules are now separate files managed by lib.rs
110
111pub use crate::llm_agent::{LlmAgent, LlmAgentBuilder};
112pub use crate::react::{FinishReason, ReActConfig, ReActEngine, ReActResult, ReActStep};
113
114use crate::types::{Message, ToolSpec};
115use async_trait::async_trait;
116use llm::builder::{LLMBackend, LLMBuilder};
117use llm::chat::ChatMessage;
118use llm::LLMProvider;
119use serde::{Deserialize, Serialize};
120
121// ============================================================================
122// LLM CLIENT - Talks to actual LLM API
123// ============================================================================
124
125/// Trait for LLM response types with tool calling support.
126///
127/// This trait defines the interface for LLM responses, supporting both
128/// content generation and tool calling capabilities.
129pub trait LLMResponseTrait<C: for<'de> Deserialize<'de> + Default + Send> {
130    /// Creates a new LLM response.
131    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self;
132
133    /// Returns whether the response is complete.
134    fn is_complete(&self) -> bool;
135
136    /// Returns the tool calls requested by the LLM.
137    fn tool_calls(&self) -> Vec<ToolCall>;
138
139    /// Returns the content of the response.
140    fn content(&self) -> C;
141}
142
143/// Response from an LLM including generated content and tool calls.
144///
145/// This struct represents a complete response from an LLM, which may include
146/// generated text content and/or requests to call tools.
147///
148/// # Examples
149///
150/// ```rust,no_run
151/// use runtime::llm::{LLMResponse, ToolCall};
152///
153/// let response = LLMResponse {
154///     content: "Let me calculate that for you".to_string(),
155///     tool_calls: vec![
156///         ToolCall {
157///             name: "calculator".to_string(),
158///             input: serde_json::json!({"operation": "add", "a": 2, "b": 2}),
159///         }
160///     ],
161///     is_complete: false,
162/// };
163/// ```
164#[derive(Debug, Clone, Serialize, Default)]
165pub struct LLMResponse<C>
166where
167    C: for<'de> Deserialize<'de> + Default + Clone + Send,
168{
169    /// The generated content from the LLM
170    pub content: C,
171
172    /// Tool calls requested by the LLM (supports multiple calls)
173    pub tool_calls: Vec<ToolCall>,
174
175    /// Whether the response is complete (false if tool calls need to be executed)
176    pub is_complete: bool,
177}
178
179impl<C> LLMResponseTrait<C> for LLMResponse<C>
180where
181    C: for<'de> Deserialize<'de> + Default + Clone + Send,
182{
183    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self {
184        Self {
185            content,
186            tool_calls,
187            is_complete,
188        }
189    }
190
191    fn is_complete(&self) -> bool {
192        self.is_complete
193    }
194
195    fn tool_calls(&self) -> Vec<ToolCall> {
196        self.tool_calls.clone()
197    }
198
199    fn content(&self) -> C {
200        self.content.clone()
201    }
202}
203
/// A request from the LLM to call a tool.
///
/// When an LLM wants to use a tool to perform an action, it returns a `ToolCall`
/// naming the tool to invoke and the JSON arguments to pass to it. For
/// text-based tool calling, these values are produced by parsing `USE_TOOL:` /
/// `USE_TOOLS:` replies. The type derives `Serialize` and `Deserialize`, so it
/// can be stored or transmitted as JSON.
///
/// # Examples
///
/// ```rust
/// use runtime::llm::ToolCall;
/// use serde_json::json;
///
/// let tool_call = ToolCall {
///     name: "search_database".to_string(),
///     input: json!({
///         "query": "users with age > 30",
///         "limit": 10
///     }),
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// The name of the tool to call.
    pub name: String,

    /// The input parameters for the tool, as an arbitrary JSON value.
    pub input: serde_json::Value,
}
231
/// Configuration for LLM providers with all builder options.
///
/// Pass it to [`UniversalLLMClient::new_with_config`] to build a client. Apart
/// from `model`, every field is optional: fields left as `None` are not applied
/// to the underlying `llm::builder::LLMBuilder`. The only non-`None` default is
/// `max_tokens` (`Some(4096)`).
///
/// All fields are public, so besides the `with_*` builder methods the struct
/// can be filled in with struct-update syntax, e.g.
/// `LLMConfig { top_k: Some(40), ..LLMConfig::new("openai::gpt-4") }`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    // Basic configuration
    /// Model identifier in `provider::model-name` form, e.g. `openai::gpt-4`.
    pub model: String,
    /// Provider API key. When `None`, `new_with_config` reads it from the
    /// provider's environment variable (e.g. `OPENAI_API_KEY`).
    pub api_key: Option<String>,
    /// Custom API endpoint. An empty string is treated as unset. For Ollama,
    /// `new_with_config` fills in `OLLAMA_URL` (default `http://127.0.0.1:11434`)
    /// when this is `None`.
    pub base_url: Option<String>,

    // Generation parameters
    /// Maximum number of tokens to generate (defaults to `Some(4096)`).
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Top-p (nucleus) sampling parameter.
    pub top_p: Option<f32>,
    /// Top-k sampling parameter.
    pub top_k: Option<u32>,
    /// System prompt passed to the builder.
    pub system: Option<String>,

    // Timeout (retries are configured under "Resilience" below)
    /// Request timeout in seconds.
    pub timeout_seconds: Option<u64>,

    // Embeddings
    /// Encoding format requested for embeddings.
    pub embedding_encoding_format: Option<String>,
    /// Requested embedding dimensionality.
    pub embedding_dimensions: Option<u32>,

    // Tools and function calling
    /// Whether the provider may issue several tool calls in parallel.
    pub enable_parallel_tool_use: Option<bool>,

    // Reasoning (for providers that support it)
    /// Enables reasoning mode.
    pub reasoning: Option<bool>,
    /// Reasoning effort level (provider-specific string).
    /// Note: not currently applied when the client is built.
    pub reasoning_effort: Option<String>,
    /// Token budget for reasoning.
    pub reasoning_budget_tokens: Option<u32>,

    // Provider-specific: Azure
    /// Azure OpenAI API version.
    pub api_version: Option<String>,
    /// Azure OpenAI deployment ID.
    pub deployment_id: Option<String>,

    // Provider-specific: Voice/Audio
    /// Voice to use for voice/audio providers.
    pub voice: Option<String>,

    // Provider-specific: XAI search
    /// xAI search mode.
    pub xai_search_mode: Option<String>,
    /// xAI search source type; `xai_search_excluded_websites` is only applied
    /// together with this field.
    pub xai_search_source_type: Option<String>,
    /// Websites excluded from xAI search (requires `xai_search_source_type`).
    pub xai_search_excluded_websites: Option<Vec<String>>,
    /// Maximum number of xAI search results.
    /// Note: not currently applied when the client is built.
    pub xai_search_max_results: Option<u32>,
    /// Lower date bound for xAI search.
    pub xai_search_from_date: Option<String>,
    /// Upper date bound for xAI search.
    pub xai_search_to_date: Option<String>,

    // Provider-specific: OpenAI web search
    /// Enables OpenAI web search.
    pub openai_enable_web_search: Option<bool>,
    /// OpenAI web search context size.
    pub openai_web_search_context_size: Option<String>,
    /// OpenAI web search user-location type.
    pub openai_web_search_user_location_type: Option<String>,
    /// Approximate user country for OpenAI web search.
    pub openai_web_search_user_location_approximate_country: Option<String>,
    /// Approximate user city for OpenAI web search.
    pub openai_web_search_user_location_approximate_city: Option<String>,
    /// Approximate user region for OpenAI web search.
    pub openai_web_search_user_location_approximate_region: Option<String>,

    // Resilience
    /// Enables retry/backoff around provider calls.
    pub resilient_enable: Option<bool>,
    /// Number of attempts when resilience is enabled.
    pub resilient_attempts: Option<usize>,
    /// Base backoff delay in milliseconds.
    /// Note: not currently applied when the client is built.
    pub resilient_base_delay_ms: Option<u64>,
    /// Maximum backoff delay in milliseconds.
    /// Note: not currently applied when the client is built.
    pub resilient_max_delay_ms: Option<u64>,
    /// Whether to add jitter to backoff delays.
    /// Note: not currently applied when the client is built.
    pub resilient_jitter: Option<bool>,
}
295
296impl Default for LLMConfig {
297    fn default() -> Self {
298        Self {
299            model: String::new(),
300            api_key: None,
301            base_url: None,
302            max_tokens: Some(4096),
303            temperature: None,
304            top_p: None,
305            top_k: None,
306            system: None,
307            timeout_seconds: None,
308            embedding_encoding_format: None,
309            embedding_dimensions: None,
310            enable_parallel_tool_use: None,
311            reasoning: None,
312            reasoning_effort: None,
313            reasoning_budget_tokens: None,
314            api_version: None,
315            deployment_id: None,
316            voice: None,
317            xai_search_mode: None,
318            xai_search_source_type: None,
319            xai_search_excluded_websites: None,
320            xai_search_max_results: None,
321            xai_search_from_date: None,
322            xai_search_to_date: None,
323            openai_enable_web_search: None,
324            openai_web_search_context_size: None,
325            openai_web_search_user_location_type: None,
326            openai_web_search_user_location_approximate_country: None,
327            openai_web_search_user_location_approximate_city: None,
328            openai_web_search_user_location_approximate_region: None,
329            resilient_enable: None,
330            resilient_attempts: None,
331            resilient_base_delay_ms: None,
332            resilient_max_delay_ms: None,
333            resilient_jitter: None,
334        }
335    }
336}
337
338impl LLMConfig {
339    /// Create a new LLMConfig with just the model name
340    pub fn new(model: impl Into<String>) -> Self {
341        Self {
342            model: model.into(),
343            ..Default::default()
344        }
345    }
346
347    /// Set API key
348    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
349        self.api_key = Some(api_key.into());
350        self
351    }
352
353    /// Set base URL
354    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
355        self.base_url = Some(base_url.into());
356        self
357    }
358
359    /// Set max tokens
360    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
361        self.max_tokens = Some(max_tokens);
362        self
363    }
364
365    /// Set temperature
366    pub fn with_temperature(mut self, temperature: f32) -> Self {
367        self.temperature = Some(temperature);
368        self
369    }
370
371    /// Set top_p
372    pub fn with_top_p(mut self, top_p: f32) -> Self {
373        self.top_p = Some(top_p);
374        self
375    }
376
377    /// Set top_k
378    pub fn with_top_k(mut self, top_k: u32) -> Self {
379        self.top_k = Some(top_k);
380        self
381    }
382
383    /// Set system prompt
384    pub fn with_system(mut self, system: impl Into<String>) -> Self {
385        self.system = Some(system.into());
386        self
387    }
388
389    /// Set timeout in seconds
390    pub fn with_timeout_seconds(mut self, timeout: u64) -> Self {
391        self.timeout_seconds = Some(timeout);
392        self
393    }
394
395    /// Enable reasoning (for supported providers)
396    pub fn with_reasoning(mut self, enabled: bool) -> Self {
397        self.reasoning = Some(enabled);
398        self
399    }
400
401    /// Set reasoning effort
402    pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
403        self.reasoning_effort = Some(effort.into());
404        self
405    }
406
407    /// Set Azure deployment ID
408    pub fn with_deployment_id(mut self, deployment_id: impl Into<String>) -> Self {
409        self.deployment_id = Some(deployment_id.into());
410        self
411    }
412
413    /// Set Azure API version
414    pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
415        self.api_version = Some(api_version.into());
416        self
417    }
418
419    /// Enable OpenAI web search
420    pub fn with_openai_web_search(mut self, enabled: bool) -> Self {
421        self.openai_enable_web_search = Some(enabled);
422        self
423    }
424
425    /// Enable resilience with retry/backoff
426    pub fn with_resilience(mut self, enabled: bool, attempts: usize) -> Self {
427        self.resilient_enable = Some(enabled);
428        self.resilient_attempts = Some(attempts);
429        self
430    }
431}
432
/// Legacy config for backward compatibility
///
/// A reduced view of [`LLMConfig`] holding only the model, token limit, API key
/// and base URL. `UniversalLLMClient` stores this after construction; options
/// not listed here are dropped by `LLMProviderConfig::from(LLMConfig)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMProviderConfig {
    /// Model identifier as given to the constructor (normally `provider::model`).
    pub model: String,
    /// Maximum tokens to generate (4096 when the source config had none).
    pub max_tokens: u64,
    /// API key, if any.
    pub api_key: Option<String>,
    /// Custom API endpoint; an empty string means no endpoint was configured.
    pub base_url: String,
}
441
442impl From<LLMConfig> for LLMProviderConfig {
443    fn from(config: LLMConfig) -> Self {
444        Self {
445            model: config.model,
446            max_tokens: config.max_tokens.unwrap_or(4096) as u64,
447            api_key: config.api_key,
448            base_url: config.base_url.unwrap_or_default(),
449        }
450    }
451}
452
/// Returns the environment variable that holds the API key for `provider`.
///
/// Provider names are matched case-insensitively and accept the same aliases
/// as `parse_provider`. `None` means no key variable is known: either the
/// provider runs locally (Ollama) or the name is not recognised.
fn get_api_key_env_var(provider: &str) -> Option<&'static str> {
    let normalized = provider.to_lowercase();
    let var_name = match normalized.as_str() {
        "openai" | "gpt" => "OPENAI_API_KEY",
        "anthropic" | "claude" => "ANTHROPIC_API_KEY",
        "google" | "gemini" => "GOOGLE_API_KEY",
        "deepseek" => "DEEPSEEK_API_KEY",
        "xai" | "x.ai" => "XAI_API_KEY",
        "groq" => "GROQ_API_KEY",
        "mistral" => "MISTRAL_API_KEY",
        "cohere" => "COHERE_API_KEY",
        "phind" => "PHIND_API_KEY",
        "openrouter" => "OPENROUTER_API_KEY",
        "azure" | "azureopenai" | "azure-openai" => "AZURE_OPENAI_API_KEY",
        "elevenlabs" | "11labs" => "ELEVENLABS_API_KEY",
        // Ollama needs no key; unrecognised names have no known variable.
        _ => return None,
    };
    Some(var_name)
}
472
473/// Attempt to get API key from environment variable for the provider
474/// Returns Ok(Some(key)) if found, Ok(None) if provider doesn't need key, Err if required but not found
475fn get_api_key_from_env(provider: &str) -> Result<Option<String>, String> {
476    match get_api_key_env_var(provider) {
477        None => Ok(None), // Provider doesn't need API key
478        Some(env_var) => {
479            match std::env::var(env_var) {
480                Ok(key) => Ok(Some(key)),
481                Err(_) => Err(format!(
482                    "API key required for provider '{}'. Please set the {} environment variable or pass the API key explicitly.",
483                    provider, env_var
484                ))
485            }
486        }
487    }
488}
489
490/// Helper function to parse provider string to LLMBackend
491fn parse_provider(provider: &str) -> Result<LLMBackend, String> {
492    match provider.to_lowercase().as_str() {
493        "ollama" => Ok(LLMBackend::Ollama),
494        "anthropic" | "claude" => Ok(LLMBackend::Anthropic),
495        "openai" | "gpt" => Ok(LLMBackend::OpenAI),
496        "deepseek" => Ok(LLMBackend::DeepSeek),
497        "xai" | "x.ai" => Ok(LLMBackend::XAI),
498        "phind" => Ok(LLMBackend::Phind),
499        "google" | "gemini" => Ok(LLMBackend::Google),
500        "groq" => Ok(LLMBackend::Groq),
501        "azure" | "azureopenai" | "azure-openai" => Ok(LLMBackend::AzureOpenAI),
502        "elevenlabs" | "11labs" => Ok(LLMBackend::ElevenLabs),
503        "cohere" => Ok(LLMBackend::Cohere),
504        "mistral" => Ok(LLMBackend::Mistral),
505        "openrouter" => Ok(LLMBackend::OpenRouter),
506        _ => Err(format!("Unknown provider: {}", provider)),
507    }
508}
509
510/// Helper function to build LLM from config
511fn build_llm_from_config(
512    config: &LLMConfig,
513    backend: LLMBackend,
514) -> Result<Box<dyn LLMProvider>, String> {
515    let mut builder = LLMBuilder::new().backend(backend.clone());
516
517    // Parse model name from "provider::model" format
518    let model_name = if config.model.contains("::") {
519        config.model.split("::").nth(1).unwrap_or(&config.model)
520    } else {
521        &config.model
522    };
523
524    builder = builder.model(model_name);
525
526    // Apply all configuration options
527    if let Some(max_tokens) = config.max_tokens {
528        builder = builder.max_tokens(max_tokens);
529    }
530
531    if let Some(ref api_key) = config.api_key {
532        builder = builder.api_key(api_key);
533    }
534
535    if let Some(ref base_url) = config.base_url {
536        if !base_url.is_empty() {
537            builder = builder.base_url(base_url);
538        }
539    }
540
541    if let Some(temperature) = config.temperature {
542        builder = builder.temperature(temperature);
543    }
544
545    if let Some(top_p) = config.top_p {
546        builder = builder.top_p(top_p);
547    }
548
549    if let Some(top_k) = config.top_k {
550        builder = builder.top_k(top_k);
551    }
552
553    if let Some(ref system) = config.system {
554        builder = builder.system(system);
555    }
556
557    if let Some(timeout) = config.timeout_seconds {
558        builder = builder.timeout_seconds(timeout);
559    }
560
561    if let Some(ref format) = config.embedding_encoding_format {
562        builder = builder.embedding_encoding_format(format);
563    }
564
565    if let Some(dims) = config.embedding_dimensions {
566        builder = builder.embedding_dimensions(dims);
567    }
568
569    if let Some(enabled) = config.enable_parallel_tool_use {
570        builder = builder.enable_parallel_tool_use(enabled);
571    }
572
573    if let Some(enabled) = config.reasoning {
574        builder = builder.reasoning(enabled);
575    }
576
577    if let Some(budget) = config.reasoning_budget_tokens {
578        builder = builder.reasoning_budget_tokens(budget);
579    }
580
581    // Azure-specific
582    if let Some(ref api_version) = config.api_version {
583        builder = builder.api_version(api_version);
584    }
585
586    if let Some(ref deployment_id) = config.deployment_id {
587        builder = builder.deployment_id(deployment_id);
588    }
589
590    // Voice
591    if let Some(ref voice) = config.voice {
592        builder = builder.voice(voice);
593    }
594
595    // XAI search parameters
596    if let Some(ref mode) = config.xai_search_mode {
597        builder = builder.xai_search_mode(mode);
598    }
599
600    // XAI search source uses a combined method
601    if let (Some(source_type), excluded) = (
602        &config.xai_search_source_type,
603        &config.xai_search_excluded_websites,
604    ) {
605        builder = builder.xai_search_source(source_type, excluded.clone());
606    }
607
608    if let Some(ref from_date) = config.xai_search_from_date {
609        builder = builder.xai_search_from_date(from_date);
610    }
611
612    if let Some(ref to_date) = config.xai_search_to_date {
613        builder = builder.xai_search_to_date(to_date);
614    }
615
616    // OpenAI web search
617    if let Some(enabled) = config.openai_enable_web_search {
618        builder = builder.openai_enable_web_search(enabled);
619    }
620
621    if let Some(ref context_size) = config.openai_web_search_context_size {
622        builder = builder.openai_web_search_context_size(context_size);
623    }
624
625    if let Some(ref loc_type) = config.openai_web_search_user_location_type {
626        builder = builder.openai_web_search_user_location_type(loc_type);
627    }
628
629    if let Some(ref country) = config.openai_web_search_user_location_approximate_country {
630        builder = builder.openai_web_search_user_location_approximate_country(country);
631    }
632
633    if let Some(ref city) = config.openai_web_search_user_location_approximate_city {
634        builder = builder.openai_web_search_user_location_approximate_city(city);
635    }
636
637    if let Some(ref region) = config.openai_web_search_user_location_approximate_region {
638        builder = builder.openai_web_search_user_location_approximate_region(region);
639    }
640
641    // Resilience
642    if let Some(enabled) = config.resilient_enable {
643        builder = builder.resilient(enabled);
644    }
645
646    if let Some(attempts) = config.resilient_attempts {
647        builder = builder.resilient_attempts(attempts);
648    }
649
650    builder
651        .build()
652        .map_err(|e| format!("Failed to build LLM: {}", e))
653}
654
/// Roughly estimates the token count of `text` at ~4 bytes per token.
///
/// The length is measured in UTF-8 bytes (not characters) and rounded up, so
/// any non-empty string counts as at least one token. This is a heuristic for
/// cost estimation, not a real tokenizer.
fn estimate_tokens(text: &str) -> u64 {
    // Exact integer ceiling division; avoids a lossy round-trip through f64.
    (text.len() as u64).div_ceil(4)
}
659
/// Returns approximate pricing for `model` as `(input, output)` micro-dollars
/// per 1K tokens.
///
/// Matching is a case-insensitive substring search, so `openai::gpt-4o` and
/// `gpt-4o-2024-08-06` both hit the `gpt-4o` entry. The first match wins, so
/// more specific names are listed before their prefixes (`gpt-4o-mini` before
/// `gpt-4o` before `gpt-4`). Unknown models fall back to `(500, 1_500)`, the
/// same rate as `gpt-3.5-turbo`.
fn get_model_pricing(model: &str) -> (u64, u64) {
    // (model substring, input price, output price), most specific first.
    const PRICING: &[(&str, u64, u64)] = &[
        ("gpt-4o-mini", 150, 600),
        ("gpt-4o", 2_500, 10_000),
        ("gpt-4-turbo", 10_000, 30_000),
        ("gpt-4", 30_000, 60_000),
        ("gpt-3.5-turbo", 500, 1_500),
        ("claude-3-opus", 15_000, 75_000),
        ("claude-3-5-sonnet", 3_000, 15_000),
        ("claude-3-5-haiku", 800, 4_000),
        ("claude-3-sonnet", 3_000, 15_000),
        ("claude-3-haiku", 250, 1_250),
    ];
    const DEFAULT_PRICING: (u64, u64) = (500, 1_500);

    let model = model.to_lowercase();
    PRICING
        .iter()
        .find(|(name, _, _)| model.contains(*name))
        .map_or(DEFAULT_PRICING, |&(_, input, output)| (input, output))
}
674
675/// Calculate cost in micro-dollars
676fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> u64 {
677    let (input_price, output_price) = get_model_pricing(model);
678    (input_tokens * input_price + output_tokens * output_price) / 1000
679}
680
/// Provider-agnostic LLM client.
///
/// Construct it with [`UniversalLLMClient::new`] (a `provider::model` string
/// plus an optional API key) or [`UniversalLLMClient::new_with_config`]
/// (a full [`LLMConfig`]).
pub struct UniversalLLMClient {
    /// Reduced copy of the configuration; `Clone` reads it to rebuild `llm`.
    config: LLMProviderConfig,
    /// The provider instance produced by `LLMBuilder`.
    llm: Box<dyn LLMProvider>,
}
685
/// Asynchronous chat-completion interface implemented by LLM clients.
///
/// The trait has `Send + Sync` as supertraits, so implementors can be shared
/// between threads. It is declared through `#[async_trait]`.
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Send messages to LLM and get a response
    ///
    /// `messages` is the conversation to send and `tools` lists the tools the
    /// model may request. The result is a `T` implementing
    /// [`LLMResponseTrait`] whose content has type `C`.
    ///
    /// # Errors
    ///
    /// Failures are reported as a `String` message.
    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
    where
        T: LLMResponseTrait<C> + Default + Send,
        C: for<'de> Deserialize<'de> + Default + Send + Serialize;
}
694
695impl Clone for UniversalLLMClient {
696    fn clone(&self) -> Self {
697        let config = self.config.clone();
698        let model = config.clone().model;
699        let api_key = config.api_key.clone();
700        let base_url = config.base_url.clone();
701        let parts: Vec<&str> = model.split("::").collect();
702
703        let provider = parts[0];
704        let model = parts[1];
705
706        let backend = parse_provider(provider).unwrap_or(LLMBackend::Ollama);
707
708        let mut builder = LLMBuilder::new()
709            .backend(backend.clone())
710            .model(model)
711            .max_tokens(4096);
712
713        if let Some(api_key) = api_key {
714            builder = builder.api_key(api_key);
715        }
716
717        if !base_url.is_empty() {
718            builder = builder.base_url(base_url);
719        }
720
721        let llm = builder
722            .build()
723            .map_err(|e| format!("Failed to build LLM: {}", e))
724            .unwrap();
725
726        Self {
727            llm,
728            config: config.clone(),
729        }
730    }
731}
732
733impl UniversalLLMClient {
734    const DEFAULT_SYSTEM_PROMPT: &'static str = "You are a helpful AI assistant.";
735
736    // Updated: Now supports multiple tool calls
737    const DEFAULT_TOOL_PROMPT: &'static str = "You have access to the following tools.\n\
738         To call ONE tool, respond EXACTLY in this format:\n\
739         USE_TOOL: tool_name\n\
740         {\"param1\": \"value1\"}\n\n\
741         To call MULTIPLE tools at once, respond in this format:\n\
742         USE_TOOLS:\n\
743         tool_name1\n\
744         {\"param1\": \"value1\"}\n\
745         ---\n\
746         tool_name2\n\
747         {\"param1\": \"value1\"}\n\n\
748         Only call tools using these exact formats. Otherwise, respond normally.";
749
750    fn generate_schema_instruction<C>(sample: &C) -> String
751    where
752        C: Serialize,
753    {
754        let sample_json = serde_json::to_string_pretty(sample).unwrap_or_else(|_| "{}".to_string());
755
756        format!(
757            "Respond with ONLY a JSON object in this exact format:\n{}\n\nProvide your response as valid JSON.",
758            sample_json
759        )
760    }
761
762    pub fn new(provider_model: &str, api_key: Option<String>) -> Result<Self, String> {
763        let parts: Vec<&str> = provider_model.split("::").collect();
764
765        if parts.len() != 2 {
766            return Err(format!(
767                "Invalid format. Use 'provider::model-name'. Got: {}",
768                provider_model
769            ));
770        }
771
772        let provider = parts[0];
773        let model = parts[1];
774
775        // Determine the API key to use: provided > environment variable > error if required
776        let final_api_key = match api_key {
777            Some(key) => Some(key),
778            None => {
779                // Try to get from environment variable
780                match get_api_key_from_env(provider) {
781                    Ok(env_key) => env_key,
782                    Err(e) => return Err(e), // Required but not found
783                }
784            }
785        };
786
787        let config = LLMProviderConfig {
788            model: provider_model.to_string(),
789            max_tokens: 4096,
790            api_key: final_api_key.clone(),
791            base_url: String::new(),
792        };
793
794        let backend = parse_provider(provider)?;
795
796        let base_url = match provider.to_lowercase().as_str() {
797            "ollama" => std::env::var("OLLAMA_URL").unwrap_or("http://127.0.0.1:11434".to_string()),
798            _ => String::new(),
799        };
800
801        let mut builder = LLMBuilder::new()
802            .backend(backend.clone())
803            .model(model)
804            .max_tokens(4096);
805
806        if let Some(api_key) = final_api_key {
807            builder = builder.api_key(api_key);
808        }
809
810        if !base_url.is_empty() {
811            builder = builder.base_url(base_url);
812        }
813
814        let llm = builder
815            .build()
816            .map_err(|e| format!("Failed to build LLM: {}", e))?;
817
818        Ok(Self { llm, config })
819    }
820
821    /// Create a new UniversalLLMClient with comprehensive LLMConfig
822    ///
823    /// # Examples
824    ///
825    /// ```rust,no_run
826    /// use runtime::llm::LLMConfig;
827    /// use runtime::llm::UniversalLLMClient;
828    ///
829    /// let config = LLMConfig::new("openai::gpt-4")
830    ///     .with_api_key("your-api-key")
831    ///     .with_temperature(0.7)
832    ///     .with_max_tokens(2048);
833    ///
834    /// let client = UniversalLLMClient::new_with_config(config).unwrap();
835    /// ```
836    pub fn new_with_config(llm_config: LLMConfig) -> Result<Self, String> {
837        // Parse provider from model string
838        let parts: Vec<&str> = llm_config.model.split("::").collect();
839
840        if parts.len() != 2 {
841            return Err(format!(
842                "Invalid format. Use 'provider::model-name'. Got: {}",
843                llm_config.model
844            ));
845        }
846
847        let provider = parts[0];
848        let backend = parse_provider(provider)?;
849
850        // Set default base_url for certain providers if not specified
851        let mut config = llm_config.clone();
852        if config.base_url.is_none() {
853            match provider.to_lowercase().as_str() {
854                "ollama" => {
855                    config.base_url = Some(
856                        std::env::var("OLLAMA_URL").unwrap_or("http://127.0.0.1:11434".to_string()),
857                    );
858                }
859                _ => {}
860            }
861        }
862
863        // Check for API key: provided > environment variable > error if required
864        if config.api_key.is_none() {
865            match get_api_key_from_env(provider) {
866                Ok(env_key) => config.api_key = env_key,
867                Err(e) => return Err(e), // Required but not found
868            }
869        }
870
871        // Build LLM using the comprehensive config
872        let llm = build_llm_from_config(&config, backend)?;
873
874        // Convert to legacy config for internal storage
875        let legacy_config = LLMProviderConfig::from(config);
876
877        Ok(Self {
878            llm,
879            config: legacy_config,
880        })
881    }
882
883    fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
884        messages
885            .iter()
886            .map(|msg| match msg.role.as_str() {
887                "user" => ChatMessage::user().content(&msg.content).build(),
888                "assistant" => ChatMessage::assistant().content(&msg.content).build(),
889                "system" => ChatMessage::assistant().content(&msg.content).build(),
890                "tool" => ChatMessage::assistant()
891                    .content(format!("Tool result: {}", msg.content))
892                    .build(),
893                _ => ChatMessage::user().content(&msg.content).build(),
894            })
895            .collect()
896    }
897
898    fn build_tool_description(tools: &[ToolSpec]) -> String {
899        tools
900            .iter()
901            .map(|t| {
902                let params = t
903                    .input_schema
904                    .get("properties")
905                    .and_then(|p| p.as_object())
906                    .map(|o| o.keys().cloned().collect::<Vec<_>>().join(", "))
907                    .unwrap_or_default();
908
909                if t.description.is_empty() {
910                    format!("- {}({})", t.name, params)
911                } else {
912                    format!("- {}({}): {}", t.name, params, t.description)
913                }
914            })
915            .collect::<Vec<_>>()
916            .join("\n")
917    }
918
919    // New helper: Parse multiple tool calls from response
920    fn parse_tool_calls(response_text: &str) -> Vec<ToolCall> {
921        let mut tool_calls = Vec::new();
922
923        // Check for multiple tools format
924        if response_text.starts_with("USE_TOOLS:") {
925            // Split by "---" to get individual tool calls
926            let parts: Vec<&str> = response_text
927                .strip_prefix("USE_TOOLS:")
928                .unwrap_or("")
929                .split("---")
930                .collect();
931
932            for part in parts {
933                let lines: Vec<&str> = part.trim().lines().collect();
934                if lines.is_empty() {
935                    continue;
936                }
937
938                let tool_name = lines[0].trim().to_string();
939                let json_block = lines.get(1..).unwrap_or(&[]).join("\n");
940
941                if let Ok(input_value) = serde_json::from_str(&json_block) {
942                    tool_calls.push(ToolCall {
943                        name: tool_name,
944                        input: input_value,
945                    });
946                }
947            }
948        }
949        // Check for single tool format
950        else if response_text.starts_with("USE_TOOL:") {
951            let lines: Vec<&str> = response_text.lines().collect();
952            let tool_name = lines[0]
953                .strip_prefix("USE_TOOL:")
954                .unwrap_or("")
955                .trim()
956                .to_string();
957
958            let json_block = lines.get(1..).unwrap_or(&[]).join("\n");
959
960            if let Ok(input_value) = serde_json::from_str(&json_block) {
961                tool_calls.push(ToolCall {
962                    name: tool_name,
963                    input: input_value,
964                });
965            }
966        }
967
968        tool_calls
969    }
970}
971
972#[async_trait]
973impl LLMClient for UniversalLLMClient {
974    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
975    where
976        T: LLMResponseTrait<C> + Default + Send,
977        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
978    {
979        let mut chat_messages = vec![];
980
981        // 1) Add system prompt if not provided by user
982        let has_user_system_prompt = messages.iter().any(|m| m.role == "system");
983        if !has_user_system_prompt {
984            chat_messages.push(
985                ChatMessage::assistant()
986                    .content(Self::DEFAULT_SYSTEM_PROMPT)
987                    .build(),
988            );
989        }
990
991        // 2) Add tool prompt if tools are provided
992        let user_tool_prompt = messages
993            .iter()
994            .find(|m| m.role == "system_tools")
995            .map(|m| m.content.clone());
996
997        if !tools.is_empty() {
998            let tool_list = Self::build_tool_description(tools);
999            let tool_prompt = user_tool_prompt.unwrap_or_else(|| {
1000                format!(
1001                    "{}\n\nAvailable Tools:\n{}\n\n{}",
1002                    Self::DEFAULT_TOOL_PROMPT,
1003                    tool_list,
1004                    "Use only the EXACT formats shown above when calling tools."
1005                )
1006            });
1007            chat_messages.push(ChatMessage::assistant().content(tool_prompt).build());
1008        }
1009
1010        // 3) AUTO-GENERATE SCHEMA INSTRUCTION
1011        let sample_c = C::default();
1012        let schema_instruction = Self::generate_schema_instruction(&sample_c);
1013
1014        chat_messages.push(ChatMessage::assistant().content(schema_instruction).build());
1015
1016        // 4) Add user messages
1017        chat_messages.extend(self.convert_messages(messages));
1018
1019        // Helper: try to parse into C
1020        let try_parse_c = |s: &str| -> C {
1021            let text = s.trim();
1022
1023            // Try direct JSON parse
1024            if let Ok(parsed) = serde_json::from_str::<C>(text) {
1025                return parsed;
1026            }
1027
1028            // Remove markdown code blocks
1029            let cleaned = text
1030                .strip_prefix("```json")
1031                .unwrap_or(text)
1032                .strip_prefix("```")
1033                .unwrap_or(text)
1034                .strip_suffix("```")
1035                .unwrap_or(text)
1036                .trim();
1037
1038            if let Ok(parsed) = serde_json::from_str::<C>(cleaned) {
1039                return parsed;
1040            }
1041
1042            // Find JSON object in text
1043            if let Some(start) = text.find('{') {
1044                if let Some(end) = text.rfind('}') {
1045                    let json_part = &text[start..=end];
1046                    if let Ok(parsed) = serde_json::from_str::<C>(json_part) {
1047                        return parsed;
1048                    }
1049                }
1050            }
1051
1052            // Try quoted string
1053            if let Ok(quoted) = serde_json::to_string(text) {
1054                if let Ok(parsed) = serde_json::from_str::<C>(&quoted) {
1055                    return parsed;
1056                }
1057            }
1058
1059            // Fallback to default
1060            C::default()
1061        };
1062
1063        // 5) Send to LLM
1064        let start = std::time::Instant::now();
1065        let response = self
1066            .llm
1067            .chat(&chat_messages)
1068            .await
1069            .map_err(|e| format!("LLM error: {}", e))?;
1070        let duration = start.elapsed().as_micros() as u64;
1071
1072        let response_text = response.text().unwrap_or_default();
1073
1074        // Estimate tokens and cost
1075        let input_text: String = messages
1076            .iter()
1077            .map(|m| m.content.as_str())
1078            .collect::<Vec<_>>()
1079            .join(" ");
1080        let input_tokens = estimate_tokens(&input_text);
1081        let output_tokens = estimate_tokens(&response_text);
1082        let total_tokens = input_tokens + output_tokens;
1083        let cost_us = calculate_cost(&self.config.model, input_tokens, output_tokens);
1084
1085        // Record metrics with estimates
1086        tracing::trace!(
1087            duration_us = duration,
1088            tokens = total_tokens,
1089            cost_us = cost_us,
1090            "llm call completed"
1091        );
1092
1093        // 6) Parse tool calls (handles both single and multiple)
1094        let tool_calls = Self::parse_tool_calls(&response_text);
1095
1096        // If we have tool calls, return them
1097        if !tool_calls.is_empty() {
1098            let parsed_content: C = C::default();
1099            return Ok(T::new(parsed_content, tool_calls, false));
1100        }
1101
1102        // 7) Normal response - parse into C
1103        let parsed_content: C = try_parse_c(&response_text);
1104        Ok(T::new(parsed_content, vec![], true))
1105    }
1106}
1107
1108// ============================================================================
1109// MOCK LLM CLIENT FOR TESTING
1110// ============================================================================
1111
/// Mock LLM client for testing - doesn't make real API calls.
///
/// This is useful for writing fast unit tests without requiring actual LLM API access.
///
/// The configured response text is deserialized as JSON into the requested content
/// type `C`. Text that is not valid JSON for `C`, including plain unquoted prose,
/// yields `C::default()`. For `String` content, pass a JSON string literal such as
/// `"\"Hello\""`.
///
/// # Examples
///
/// ```rust
/// use runtime::llm::{MockLLMClient, LLMClient, LLMResponse};
/// use runtime::llm::types::Message;
///
/// # #[tokio::main]
/// # async fn main() {
/// // A JSON string literal, so `String` content deserializes to "Hello from mock!".
/// let client = MockLLMClient::new(r#""Hello from mock!""#);
/// let messages = vec![Message {
///     role: "user".to_string(),
///     content: "Say hello".to_string(),
/// }];
///
/// let result: LLMResponse<String> = client
///     .complete::<LLMResponse<String>, String>(&messages, &[])
///     .await
///     .expect("Mock LLM failed");
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct MockLLMClient {
    /// Raw response text; deserialized as JSON into the requested content type.
    response_content: String,
}

impl MockLLMClient {
    /// Creates a new mock LLM client that returns the specified response.
    ///
    /// # Arguments
    ///
    /// * `response` - Text that is parsed as JSON into the requested content type
    ///   when the client is asked for a completion. Non-JSON text produces the
    ///   content type's `Default` value.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use runtime::llm::MockLLMClient;
    ///
    /// // For structured responses, provide a JSON object
    /// let client = MockLLMClient::new(r#"{"field": "value"}"#);
    ///
    /// // For `String` responses, provide a JSON string literal (quotes included);
    /// // unquoted text such as `Hello, world!` is not JSON and yields an empty String.
    /// let client = MockLLMClient::new(r#""Hello, world!""#);
    /// ```
    pub fn new(response: &str) -> Self {
        Self {
            response_content: response.to_string(),
        }
    }

    /// Creates a mock LLM client with a default "Hello" response.
    ///
    /// The canned text is plain prose, not JSON, so completions from this client
    /// carry the content type's `Default` value (an empty `String` for `String`).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use runtime::llm::MockLLMClient;
    ///
    /// let client = MockLLMClient::default_hello();
    /// ```
    pub fn default_hello() -> Self {
        Self::new("Hello! How can I help you today?")
    }
}
1177
1178#[async_trait::async_trait]
1179impl LLMClient for MockLLMClient {
1180    async fn complete<T, C>(&self, _messages: &[Message], _tools: &[ToolSpec]) -> Result<T, String>
1181    where
1182        T: LLMResponseTrait<C> + Default + Send,
1183        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
1184    {
1185        // Try to parse the response into the expected content type
1186        let content: C = if let Ok(parsed) = serde_json::from_str(&self.response_content) {
1187            parsed
1188        } else {
1189            C::default()
1190        };
1191
1192        Ok(T::new(content, vec![], true))
1193    }
1194}
1195
1196// ============================================================================
1197// TESTS
1198// ============================================================================
1199
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-message conversation with a single user turn.
    fn user_turn(text: &str) -> Vec<Message> {
        vec![Message {
            role: "user".into(),
            content: text.into(),
        }]
    }

    #[tokio::test]
    async fn test_mock_llm_basic() {
        let client = MockLLMClient::default_hello();

        let result: LLMResponse<String> = client
            .complete::<LLMResponse<String>, String>(&user_turn("Say hello"), &[])
            .await
            .expect("Mock LLM failed");

        // The canned reply is plain text, not JSON, so content falls back to String::default().
        assert!(result.content.is_empty());
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }

    #[tokio::test]
    async fn test_mock_json_string_output() {
        let client = MockLLMClient::new(r#""Hello, world!""#);

        let result: LLMResponse<String> = client
            .complete::<LLMResponse<String>, String>(&user_turn("Say hello"), &[])
            .await
            .expect("Mock LLM failed");

        assert_eq!(result.content, "Hello, world!");
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }

    #[tokio::test]
    async fn test_mock_structured_output() {
        #[derive(Deserialize, Serialize, Default, Clone, Debug)]
        struct MyOutput {
            field1: f64,
            flag: bool,
        }

        let client = MockLLMClient::new(r#"{"field1": 42.5, "flag": true}"#);

        let result: LLMResponse<MyOutput> = client
            .complete::<LLMResponse<MyOutput>, MyOutput>(&user_turn("Return structured data"), &[])
            .await
            .expect("Mock LLM failed");

        assert_eq!(result.content.field1, 42.5);
        assert!(result.content.flag);
        assert!(result.tool_calls.is_empty());
        assert!(result.is_complete);
    }

    #[test]
    fn test_parse_single_tool_call() {
        let calls =
            UniversalLLMClient::parse_tool_calls("USE_TOOL: search\n{\"query\": \"rust\"}");

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
    }

    #[test]
    fn test_parse_multiple_tool_calls() {
        let calls = UniversalLLMClient::parse_tool_calls(
            "USE_TOOLS:\nfirst\n{\"a\": \"b\"}\n---\nsecond\n{\"c\": \"d\"}",
        );

        let names: Vec<&str> = calls.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, vec!["first", "second"]);
    }

    #[test]
    fn test_parse_tool_calls_ignores_plain_text() {
        assert!(UniversalLLMClient::parse_tool_calls("Just a normal answer.").is_empty());
    }
}