ceylon_runtime/llm/
mod.rs

1//! LLM (Large Language Model) integration layer.
2//!
3//! This module provides a unified interface for interacting with various LLM providers
4//! including OpenAI, Anthropic, Ollama, Google, and many others.
5//!
6//! # Supported Providers
7//!
8//! Ceylon supports 13+ LLM providers through the [`UniversalLLMClient`]:
9//!
10//! - **OpenAI** - GPT-4, GPT-3.5-turbo, etc.
11//! - **Anthropic** - Claude 3 Opus, Sonnet, Haiku
12//! - **Ollama** - Local models (Llama, Mistral, Gemma, etc.)
13//! - **Google** - Gemini Pro
14//! - **DeepSeek** - DeepSeek Chat, DeepSeek Coder
15//! - **X.AI** - Grok
16//! - **Groq** - High-speed inference
17//! - **Azure OpenAI** - Enterprise OpenAI deployment
18//! - **Cohere** - Command models
19//! - **Mistral** - Mistral AI models
20//! - **Phind** - CodeLlama variants
21//! - **OpenRouter** - Multi-provider routing
22//! - **ElevenLabs** - Voice/audio generation
23//!
24//! # Configuration
25//!
26//! Use [`LLMConfig`] for comprehensive configuration:
27//!
28//! ```rust,no_run
29//! use runtime::llm::LLMConfig;
30//!
31//! let config = LLMConfig::new("openai::gpt-4")
32//!     .with_api_key("sk-...")
33//!     .with_temperature(0.7)
34//!     .with_max_tokens(2048)
35//!     .with_resilience(true, 3);
36//! ```
37//!
38//! # API Key Detection
39//!
40//! Ceylon automatically detects API keys from environment variables:
41//! - `OPENAI_API_KEY`
42//! - `ANTHROPIC_API_KEY`
43//! - `GOOGLE_API_KEY`
44//! - `DEEPSEEK_API_KEY`
45//! - `XAI_API_KEY`
46//! - `GROQ_API_KEY`
47//! - `MISTRAL_API_KEY`
48//! - `COHERE_API_KEY`
49//! - `PHIND_API_KEY`
50//! - `OPENROUTER_API_KEY`
51//! - And more...
52//!
53//! # Tool Calling
54//!
55//! Ceylon supports native tool calling for compatible models and falls back to
56//! text-based tool invocation for others.
57//!
58//! # Examples
59//!
60//! ## Basic Usage
61//!
62//! ```rust,no_run
63//! use runtime::llm::{UniversalLLMClient, LLMClient, LLMResponse};
64//! use runtime::llm::types::Message;
65//!
66//! # async fn example() -> Result<(), String> {
67//! let client = UniversalLLMClient::new("openai::gpt-4", None)?;
68//! let messages = vec![Message {
69//!     role: "user".to_string(),
70//!     content: "Hello!".to_string(),
71//! }];
72//!
73//! let response: LLMResponse<String> = client
74//!     .complete::<LLMResponse<String>, String>(&messages, &[])
75//!     .await?;
76//! # Ok(())
77//! # }
78//! ```
79//!
80//! ## Advanced Configuration
81//!
82//! ```rust,no_run
83//! use runtime::llm::{LLMConfig, UniversalLLMClient};
84//!
85//! # fn example() -> Result<(), String> {
86//! let llm_config = LLMConfig::new("anthropic::claude-3-opus-20240229")
87//!     .with_api_key(std::env::var("ANTHROPIC_API_KEY").unwrap())
88//!     .with_temperature(0.8)
89//!     .with_max_tokens(4096)
90//!     .with_reasoning(true);
91//!
92//! let client = UniversalLLMClient::new_with_config(llm_config)?;
93//! # Ok(())
94//! # }
95//! ```
96//!
97//! ## Local Models with Ollama
98//!
99//! ```rust,no_run
100//! use runtime::llm::UniversalLLMClient;
101//!
102//! # fn example() -> Result<(), String> {
103//! // No API key needed for local models
104//! let client = UniversalLLMClient::new("ollama::llama2", None)?;
105//! # Ok(())
106//! # }
107//! ```
108
109pub mod llm_agent;
110pub mod react;
111pub mod types;
112
113pub use llm_agent::{LlmAgent, LlmAgentBuilder};
114pub use react::{FinishReason, ReActConfig, ReActEngine, ReActResult, ReActStep};
115
116use async_trait::async_trait;
117use llm::builder::{LLMBackend, LLMBuilder};
118use llm::chat::ChatMessage;
119use llm::LLMProvider;
120use serde::{Deserialize, Serialize};
121use types::{Message, ToolSpec};
122
123// ============================================================================
124// LLM CLIENT - Talks to actual LLM API
125// ============================================================================
126
127/// Trait for LLM response types with tool calling support.
128///
129/// This trait defines the interface for LLM responses, supporting both
130/// content generation and tool calling capabilities.
131pub trait LLMResponseTrait<C: for<'de> Deserialize<'de> + Default + Send> {
132    /// Creates a new LLM response.
133    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self;
134
135    /// Returns whether the response is complete.
136    fn is_complete(&self) -> bool;
137
138    /// Returns the tool calls requested by the LLM.
139    fn tool_calls(&self) -> Vec<ToolCall>;
140
141    /// Returns the content of the response.
142    fn content(&self) -> C;
143}
144
145/// Response from an LLM including generated content and tool calls.
146///
147/// This struct represents a complete response from an LLM, which may include
148/// generated text content and/or requests to call tools.
149///
150/// # Examples
151///
152/// ```rust,no_run
153/// use runtime::llm::{LLMResponse, ToolCall};
154///
155/// let response = LLMResponse {
156///     content: "Let me calculate that for you".to_string(),
157///     tool_calls: vec![
158///         ToolCall {
159///             name: "calculator".to_string(),
160///             input: serde_json::json!({"operation": "add", "a": 2, "b": 2}),
161///         }
162///     ],
163///     is_complete: false,
164/// };
165/// ```
166#[derive(Debug, Clone, Serialize, Default)]
167pub struct LLMResponse<C>
168where
169    C: for<'de> Deserialize<'de> + Default + Clone + Send,
170{
171    /// The generated content from the LLM
172    pub content: C,
173
174    /// Tool calls requested by the LLM (supports multiple calls)
175    pub tool_calls: Vec<ToolCall>,
176
177    /// Whether the response is complete (false if tool calls need to be executed)
178    pub is_complete: bool,
179}
180
181impl<C> LLMResponseTrait<C> for LLMResponse<C>
182where
183    C: for<'de> Deserialize<'de> + Default + Clone + Send,
184{
185    fn new(content: C, tool_calls: Vec<ToolCall>, is_complete: bool) -> Self {
186        Self {
187            content,
188            tool_calls,
189            is_complete,
190        }
191    }
192
193    fn is_complete(&self) -> bool {
194        self.is_complete
195    }
196
197    fn tool_calls(&self) -> Vec<ToolCall> {
198        self.tool_calls.clone()
199    }
200
201    fn content(&self) -> C {
202        self.content.clone()
203    }
204}
205
206/// A request from the LLM to call a tool.
207///
208/// When an LLM wants to use a tool to perform an action, it returns a `ToolCall`
209/// specifying which tool to invoke and with what parameters.
210///
211/// # Examples
212///
213/// ```rust
214/// use runtime::llm::ToolCall;
215/// use serde_json::json;
216///
217/// let tool_call = ToolCall {
218///     name: "search_database".to_string(),
219///     input: json!({
220///         "query": "users with age > 30",
221///         "limit": 10
222///     }),
223/// };
224/// ```
225#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct ToolCall {
227    /// The name of the tool to call
228    pub name: String,
229
230    /// The input parameters for the tool as a JSON value
231    pub input: serde_json::Value,
232}
233
234/// Configuration for LLM providers with all builder options.
235///
236/// This struct provides comprehensive configuration for any LLM provider,
237/// matching all options available in the LLMBuilder.
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct LLMConfig {
240    // Basic configuration
241    pub model: String,
242    pub api_key: Option<String>,
243    pub base_url: Option<String>,
244
245    // Generation parameters
246    pub max_tokens: Option<u32>,
247    pub temperature: Option<f32>,
248    pub top_p: Option<f32>,
249    pub top_k: Option<u32>,
250    pub system: Option<String>,
251
252    // Timeout and retry
253    pub timeout_seconds: Option<u64>,
254
255    // Embeddings
256    pub embedding_encoding_format: Option<String>,
257    pub embedding_dimensions: Option<u32>,
258
259    // Tools and function calling
260    pub enable_parallel_tool_use: Option<bool>,
261
262    // Reasoning (for providers that support it)
263    pub reasoning: Option<bool>,
264    pub reasoning_effort: Option<String>,
265    pub reasoning_budget_tokens: Option<u32>,
266
267    // Provider-specific: Azure
268    pub api_version: Option<String>,
269    pub deployment_id: Option<String>,
270
271    // Provider-specific: Voice/Audio
272    pub voice: Option<String>,
273
274    // Provider-specific: XAI search
275    pub xai_search_mode: Option<String>,
276    pub xai_search_source_type: Option<String>,
277    pub xai_search_excluded_websites: Option<Vec<String>>,
278    pub xai_search_max_results: Option<u32>,
279    pub xai_search_from_date: Option<String>,
280    pub xai_search_to_date: Option<String>,
281
282    // Provider-specific: OpenAI web search
283    pub openai_enable_web_search: Option<bool>,
284    pub openai_web_search_context_size: Option<String>,
285    pub openai_web_search_user_location_type: Option<String>,
286    pub openai_web_search_user_location_approximate_country: Option<String>,
287    pub openai_web_search_user_location_approximate_city: Option<String>,
288    pub openai_web_search_user_location_approximate_region: Option<String>,
289
290    // Resilience
291    pub resilient_enable: Option<bool>,
292    pub resilient_attempts: Option<usize>,
293    pub resilient_base_delay_ms: Option<u64>,
294    pub resilient_max_delay_ms: Option<u64>,
295    pub resilient_jitter: Option<bool>,
296}
297
298impl Default for LLMConfig {
299    fn default() -> Self {
300        Self {
301            model: String::new(),
302            api_key: None,
303            base_url: None,
304            max_tokens: Some(4096),
305            temperature: None,
306            top_p: None,
307            top_k: None,
308            system: None,
309            timeout_seconds: None,
310            embedding_encoding_format: None,
311            embedding_dimensions: None,
312            enable_parallel_tool_use: None,
313            reasoning: None,
314            reasoning_effort: None,
315            reasoning_budget_tokens: None,
316            api_version: None,
317            deployment_id: None,
318            voice: None,
319            xai_search_mode: None,
320            xai_search_source_type: None,
321            xai_search_excluded_websites: None,
322            xai_search_max_results: None,
323            xai_search_from_date: None,
324            xai_search_to_date: None,
325            openai_enable_web_search: None,
326            openai_web_search_context_size: None,
327            openai_web_search_user_location_type: None,
328            openai_web_search_user_location_approximate_country: None,
329            openai_web_search_user_location_approximate_city: None,
330            openai_web_search_user_location_approximate_region: None,
331            resilient_enable: None,
332            resilient_attempts: None,
333            resilient_base_delay_ms: None,
334            resilient_max_delay_ms: None,
335            resilient_jitter: None,
336        }
337    }
338}
339
340impl LLMConfig {
341    /// Create a new LLMConfig with just the model name
342    pub fn new(model: impl Into<String>) -> Self {
343        Self {
344            model: model.into(),
345            ..Default::default()
346        }
347    }
348
349    /// Set API key
350    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
351        self.api_key = Some(api_key.into());
352        self
353    }
354
355    /// Set base URL
356    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
357        self.base_url = Some(base_url.into());
358        self
359    }
360
361    /// Set max tokens
362    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
363        self.max_tokens = Some(max_tokens);
364        self
365    }
366
367    /// Set temperature
368    pub fn with_temperature(mut self, temperature: f32) -> Self {
369        self.temperature = Some(temperature);
370        self
371    }
372
373    /// Set top_p
374    pub fn with_top_p(mut self, top_p: f32) -> Self {
375        self.top_p = Some(top_p);
376        self
377    }
378
379    /// Set top_k
380    pub fn with_top_k(mut self, top_k: u32) -> Self {
381        self.top_k = Some(top_k);
382        self
383    }
384
385    /// Set system prompt
386    pub fn with_system(mut self, system: impl Into<String>) -> Self {
387        self.system = Some(system.into());
388        self
389    }
390
391    /// Set timeout in seconds
392    pub fn with_timeout_seconds(mut self, timeout: u64) -> Self {
393        self.timeout_seconds = Some(timeout);
394        self
395    }
396
397    /// Enable reasoning (for supported providers)
398    pub fn with_reasoning(mut self, enabled: bool) -> Self {
399        self.reasoning = Some(enabled);
400        self
401    }
402
403    /// Set reasoning effort
404    pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
405        self.reasoning_effort = Some(effort.into());
406        self
407    }
408
409    /// Set Azure deployment ID
410    pub fn with_deployment_id(mut self, deployment_id: impl Into<String>) -> Self {
411        self.deployment_id = Some(deployment_id.into());
412        self
413    }
414
415    /// Set Azure API version
416    pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
417        self.api_version = Some(api_version.into());
418        self
419    }
420
421    /// Enable OpenAI web search
422    pub fn with_openai_web_search(mut self, enabled: bool) -> Self {
423        self.openai_enable_web_search = Some(enabled);
424        self
425    }
426
427    /// Enable resilience with retry/backoff
428    pub fn with_resilience(mut self, enabled: bool, attempts: usize) -> Self {
429        self.resilient_enable = Some(enabled);
430        self.resilient_attempts = Some(attempts);
431        self
432    }
433}
434
435/// Legacy config for backward compatibility
436#[derive(Debug, Clone, Serialize, Deserialize)]
437pub struct LLMProviderConfig {
438    pub model: String,
439    pub max_tokens: u64,
440    pub api_key: Option<String>,
441    pub base_url: String,
442}
443
444impl From<LLMConfig> for LLMProviderConfig {
445    fn from(config: LLMConfig) -> Self {
446        Self {
447            model: config.model,
448            max_tokens: config.max_tokens.unwrap_or(4096) as u64,
449            api_key: config.api_key,
450            base_url: config.base_url.unwrap_or_default(),
451        }
452    }
453}
454
/// Maps a provider alias (compared after lowercasing) to the environment
/// variable that holds its API key.
///
/// Returns `None` both for keyless providers (`ollama`) and for unrecognised
/// aliases.
fn get_api_key_env_var(provider: &str) -> Option<&'static str> {
    let alias = provider.to_lowercase();
    let var_name = match alias.as_str() {
        "openai" | "gpt" => "OPENAI_API_KEY",
        "anthropic" | "claude" => "ANTHROPIC_API_KEY",
        "google" | "gemini" => "GOOGLE_API_KEY",
        "deepseek" => "DEEPSEEK_API_KEY",
        "xai" | "x.ai" => "XAI_API_KEY",
        "groq" => "GROQ_API_KEY",
        "mistral" => "MISTRAL_API_KEY",
        "cohere" => "COHERE_API_KEY",
        "phind" => "PHIND_API_KEY",
        "openrouter" => "OPENROUTER_API_KEY",
        "azure" | "azureopenai" | "azure-openai" => "AZURE_OPENAI_API_KEY",
        "elevenlabs" | "11labs" => "ELEVENLABS_API_KEY",
        // Ollama runs locally without a key; unknown aliases have no variable either.
        _ => return None,
    };
    Some(var_name)
}
474
475/// Attempt to get API key from environment variable for the provider
476/// Returns Ok(Some(key)) if found, Ok(None) if provider doesn't need key, Err if required but not found
477fn get_api_key_from_env(provider: &str) -> Result<Option<String>, String> {
478    match get_api_key_env_var(provider) {
479        None => Ok(None), // Provider doesn't need API key
480        Some(env_var) => {
481            match std::env::var(env_var) {
482                Ok(key) => Ok(Some(key)),
483                Err(_) => Err(format!(
484                    "API key required for provider '{}'. Please set the {} environment variable or pass the API key explicitly.",
485                    provider, env_var
486                ))
487            }
488        }
489    }
490}
491
492/// Helper function to parse provider string to LLMBackend
493fn parse_provider(provider: &str) -> Result<LLMBackend, String> {
494    match provider.to_lowercase().as_str() {
495        "ollama" => Ok(LLMBackend::Ollama),
496        "anthropic" | "claude" => Ok(LLMBackend::Anthropic),
497        "openai" | "gpt" => Ok(LLMBackend::OpenAI),
498        "deepseek" => Ok(LLMBackend::DeepSeek),
499        "xai" | "x.ai" => Ok(LLMBackend::XAI),
500        "phind" => Ok(LLMBackend::Phind),
501        "google" | "gemini" => Ok(LLMBackend::Google),
502        "groq" => Ok(LLMBackend::Groq),
503        "azure" | "azureopenai" | "azure-openai" => Ok(LLMBackend::AzureOpenAI),
504        "elevenlabs" | "11labs" => Ok(LLMBackend::ElevenLabs),
505        "cohere" => Ok(LLMBackend::Cohere),
506        "mistral" => Ok(LLMBackend::Mistral),
507        "openrouter" => Ok(LLMBackend::OpenRouter),
508        _ => Err(format!("Unknown provider: {}", provider)),
509    }
510}
511
512/// Helper function to build LLM from config
513fn build_llm_from_config(
514    config: &LLMConfig,
515    backend: LLMBackend,
516) -> Result<Box<dyn LLMProvider>, String> {
517    let mut builder = LLMBuilder::new().backend(backend.clone());
518
519    // Parse model name from "provider::model" format
520    let model_name = if config.model.contains("::") {
521        config.model.split("::").nth(1).unwrap_or(&config.model)
522    } else {
523        &config.model
524    };
525
526    builder = builder.model(model_name);
527
528    // Apply all configuration options
529    if let Some(max_tokens) = config.max_tokens {
530        builder = builder.max_tokens(max_tokens);
531    }
532
533    if let Some(ref api_key) = config.api_key {
534        builder = builder.api_key(api_key);
535    }
536
537    if let Some(ref base_url) = config.base_url {
538        if !base_url.is_empty() {
539            builder = builder.base_url(base_url);
540        }
541    }
542
543    if let Some(temperature) = config.temperature {
544        builder = builder.temperature(temperature);
545    }
546
547    if let Some(top_p) = config.top_p {
548        builder = builder.top_p(top_p);
549    }
550
551    if let Some(top_k) = config.top_k {
552        builder = builder.top_k(top_k);
553    }
554
555    if let Some(ref system) = config.system {
556        builder = builder.system(system);
557    }
558
559    if let Some(timeout) = config.timeout_seconds {
560        builder = builder.timeout_seconds(timeout);
561    }
562
563    if let Some(ref format) = config.embedding_encoding_format {
564        builder = builder.embedding_encoding_format(format);
565    }
566
567    if let Some(dims) = config.embedding_dimensions {
568        builder = builder.embedding_dimensions(dims);
569    }
570
571    if let Some(enabled) = config.enable_parallel_tool_use {
572        builder = builder.enable_parallel_tool_use(enabled);
573    }
574
575    if let Some(enabled) = config.reasoning {
576        builder = builder.reasoning(enabled);
577    }
578
579    if let Some(budget) = config.reasoning_budget_tokens {
580        builder = builder.reasoning_budget_tokens(budget);
581    }
582
583    // Azure-specific
584    if let Some(ref api_version) = config.api_version {
585        builder = builder.api_version(api_version);
586    }
587
588    if let Some(ref deployment_id) = config.deployment_id {
589        builder = builder.deployment_id(deployment_id);
590    }
591
592    // Voice
593    if let Some(ref voice) = config.voice {
594        builder = builder.voice(voice);
595    }
596
597    // XAI search parameters
598    if let Some(ref mode) = config.xai_search_mode {
599        builder = builder.xai_search_mode(mode);
600    }
601
602    // XAI search source uses a combined method
603    if let (Some(source_type), excluded) = (
604        &config.xai_search_source_type,
605        &config.xai_search_excluded_websites,
606    ) {
607        builder = builder.xai_search_source(source_type, excluded.clone());
608    }
609
610    if let Some(ref from_date) = config.xai_search_from_date {
611        builder = builder.xai_search_from_date(from_date);
612    }
613
614    if let Some(ref to_date) = config.xai_search_to_date {
615        builder = builder.xai_search_to_date(to_date);
616    }
617
618    // OpenAI web search
619    if let Some(enabled) = config.openai_enable_web_search {
620        builder = builder.openai_enable_web_search(enabled);
621    }
622
623    if let Some(ref context_size) = config.openai_web_search_context_size {
624        builder = builder.openai_web_search_context_size(context_size);
625    }
626
627    if let Some(ref loc_type) = config.openai_web_search_user_location_type {
628        builder = builder.openai_web_search_user_location_type(loc_type);
629    }
630
631    if let Some(ref country) = config.openai_web_search_user_location_approximate_country {
632        builder = builder.openai_web_search_user_location_approximate_country(country);
633    }
634
635    if let Some(ref city) = config.openai_web_search_user_location_approximate_city {
636        builder = builder.openai_web_search_user_location_approximate_city(city);
637    }
638
639    if let Some(ref region) = config.openai_web_search_user_location_approximate_region {
640        builder = builder.openai_web_search_user_location_approximate_region(region);
641    }
642
643    // Resilience
644    if let Some(enabled) = config.resilient_enable {
645        builder = builder.resilient(enabled);
646    }
647
648    if let Some(attempts) = config.resilient_attempts {
649        builder = builder.resilient_attempts(attempts);
650    }
651
652    builder
653        .build()
654        .map_err(|e| format!("Failed to build LLM: {}", e))
655}
656
/// Roughly estimates the token count of `text` as one token per 4 bytes of
/// UTF-8, rounded up.
///
/// The estimate uses the byte length, not the character count, so multi-byte
/// scripts yield proportionally more tokens.
fn estimate_tokens(text: &str) -> u64 {
    let byte_len = text.len() as u64;
    byte_len.div_ceil(4)
}
661
/// Returns `(input, output)` prices in micro-dollars per 1,000 tokens for
/// `model`.
///
/// Matching is a case-insensitive substring test against a fixed table. Unknown
/// models fall back to `(500, 1_500)`.
fn get_model_pricing(model: &str) -> (u64, u64) {
    // Order matters: more specific fragments ("gpt-4o", "gpt-4-turbo") must be
    // checked before the generic "gpt-4".
    const PRICING: &[(&str, (u64, u64))] = &[
        ("gpt-4o", (2_500, 10_000)),
        ("gpt-4-turbo", (10_000, 30_000)),
        ("gpt-4", (30_000, 60_000)),
        ("gpt-3.5-turbo", (500, 1_500)),
        ("claude-3-opus", (15_000, 75_000)),
        ("claude-3-5-sonnet", (3_000, 15_000)),
        ("claude-3-sonnet", (3_000, 15_000)),
        ("claude-3-haiku", (250, 1_250)),
    ];
    const FALLBACK: (u64, u64) = (500, 1_500);

    let haystack = model.to_lowercase();
    PRICING
        .iter()
        .find(|&&(fragment, _)| haystack.contains(fragment))
        .map_or(FALLBACK, |&(_, price)| price)
}
676
677/// Calculate cost in micro-dollars
678fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> u64 {
679    let (input_price, output_price) = get_model_pricing(model);
680    (input_tokens * input_price + output_tokens * output_price) / 1000
681}
682
683pub struct UniversalLLMClient {
684    config: LLMProviderConfig,
685    llm: Box<dyn LLMProvider>,
686}
687
688#[async_trait]
689pub trait LLMClient: Send + Sync {
690    /// Send messages to LLM and get a response
691    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
692    where
693        T: LLMResponseTrait<C> + Default + Send,
694        C: for<'de> Deserialize<'de> + Default + Send + Serialize;
695}
696
697impl Clone for UniversalLLMClient {
698    fn clone(&self) -> Self {
699        let config = self.config.clone();
700        let model = config.clone().model;
701        let api_key = config.api_key.clone();
702        let base_url = config.base_url.clone();
703        let parts: Vec<&str> = model.split("::").collect();
704
705        let provider = parts[0];
706        let model = parts[1];
707
708        let backend = parse_provider(provider).unwrap_or(LLMBackend::Ollama);
709
710        let mut builder = LLMBuilder::new()
711            .backend(backend.clone())
712            .model(model)
713            .max_tokens(4096);
714
715        if let Some(api_key) = api_key {
716            builder = builder.api_key(api_key);
717        }
718
719        if !base_url.is_empty() {
720            builder = builder.base_url(base_url);
721        }
722
723        let llm = builder
724            .build()
725            .map_err(|e| format!("Failed to build LLM: {}", e))
726            .unwrap();
727
728        Self {
729            llm,
730            config: config.clone(),
731        }
732    }
733}
734
735impl UniversalLLMClient {
736    const DEFAULT_SYSTEM_PROMPT: &'static str = "You are a helpful AI assistant.";
737
738    // Updated: Now supports multiple tool calls
739    const DEFAULT_TOOL_PROMPT: &'static str = "You have access to the following tools.\n\
740         To call ONE tool, respond EXACTLY in this format:\n\
741         USE_TOOL: tool_name\n\
742         {\"param1\": \"value1\"}\n\n\
743         To call MULTIPLE tools at once, respond in this format:\n\
744         USE_TOOLS:\n\
745         tool_name1\n\
746         {\"param1\": \"value1\"}\n\
747         ---\n\
748         tool_name2\n\
749         {\"param1\": \"value1\"}\n\n\
750         Only call tools using these exact formats. Otherwise, respond normally.";
751
752    fn generate_schema_instruction<C>(sample: &C) -> String
753    where
754        C: Serialize,
755    {
756        let sample_json = serde_json::to_string_pretty(sample).unwrap_or_else(|_| "{}".to_string());
757
758        format!(
759            "Respond with ONLY a JSON object in this exact format:\n{}\n\nProvide your response as valid JSON.",
760            sample_json
761        )
762    }
763
764    pub fn new(provider_model: &str, api_key: Option<String>) -> Result<Self, String> {
765        let parts: Vec<&str> = provider_model.split("::").collect();
766
767        if parts.len() != 2 {
768            return Err(format!(
769                "Invalid format. Use 'provider::model-name'. Got: {}",
770                provider_model
771            ));
772        }
773
774        let provider = parts[0];
775        let model = parts[1];
776
777        // Determine the API key to use: provided > environment variable > error if required
778        let final_api_key = match api_key {
779            Some(key) => Some(key),
780            None => {
781                // Try to get from environment variable
782                match get_api_key_from_env(provider) {
783                    Ok(env_key) => env_key,
784                    Err(e) => return Err(e), // Required but not found
785                }
786            }
787        };
788
789        let config = LLMProviderConfig {
790            model: provider_model.to_string(),
791            max_tokens: 4096,
792            api_key: final_api_key.clone(),
793            base_url: String::new(),
794        };
795
796        let backend = parse_provider(provider)?;
797
798        let base_url = match provider.to_lowercase().as_str() {
799            "ollama" => std::env::var("OLLAMA_URL").unwrap_or("http://127.0.0.1:11434".to_string()),
800            _ => String::new(),
801        };
802
803        let mut builder = LLMBuilder::new()
804            .backend(backend.clone())
805            .model(model)
806            .max_tokens(4096);
807
808        if let Some(api_key) = final_api_key {
809            builder = builder.api_key(api_key);
810        }
811
812        if !base_url.is_empty() {
813            builder = builder.base_url(base_url);
814        }
815
816        let llm = builder
817            .build()
818            .map_err(|e| format!("Failed to build LLM: {}", e))?;
819
820        Ok(Self { llm, config })
821    }
822
823    /// Create a new UniversalLLMClient with comprehensive LLMConfig
824    ///
825    /// # Examples
826    ///
827    /// ```rust,no_run
828    /// use runtime::llm::LLMConfig;
829    /// use runtime::llm::UniversalLLMClient;
830    ///
831    /// let config = LLMConfig::new("openai::gpt-4")
832    ///     .with_api_key("your-api-key")
833    ///     .with_temperature(0.7)
834    ///     .with_max_tokens(2048);
835    ///
836    /// let client = UniversalLLMClient::new_with_config(config).unwrap();
837    /// ```
838    pub fn new_with_config(llm_config: LLMConfig) -> Result<Self, String> {
839        // Parse provider from model string
840        let parts: Vec<&str> = llm_config.model.split("::").collect();
841
842        if parts.len() != 2 {
843            return Err(format!(
844                "Invalid format. Use 'provider::model-name'. Got: {}",
845                llm_config.model
846            ));
847        }
848
849        let provider = parts[0];
850        let backend = parse_provider(provider)?;
851
852        // Set default base_url for certain providers if not specified
853        let mut config = llm_config.clone();
854        if config.base_url.is_none() {
855            match provider.to_lowercase().as_str() {
856                "ollama" => {
857                    config.base_url = Some(
858                        std::env::var("OLLAMA_URL").unwrap_or("http://127.0.0.1:11434".to_string()),
859                    );
860                }
861                _ => {}
862            }
863        }
864
865        // Check for API key: provided > environment variable > error if required
866        if config.api_key.is_none() {
867            match get_api_key_from_env(provider) {
868                Ok(env_key) => config.api_key = env_key,
869                Err(e) => return Err(e), // Required but not found
870            }
871        }
872
873        // Build LLM using the comprehensive config
874        let llm = build_llm_from_config(&config, backend)?;
875
876        // Convert to legacy config for internal storage
877        let legacy_config = LLMProviderConfig::from(config);
878
879        Ok(Self {
880            llm,
881            config: legacy_config,
882        })
883    }
884
885    fn convert_messages(&self, messages: &[Message]) -> Vec<ChatMessage> {
886        messages
887            .iter()
888            .map(|msg| match msg.role.as_str() {
889                "user" => ChatMessage::user().content(&msg.content).build(),
890                "assistant" => ChatMessage::assistant().content(&msg.content).build(),
891                "system" => ChatMessage::assistant().content(&msg.content).build(),
892                "tool" => ChatMessage::assistant()
893                    .content(format!("Tool result: {}", msg.content))
894                    .build(),
895                _ => ChatMessage::user().content(&msg.content).build(),
896            })
897            .collect()
898    }
899
900    fn build_tool_description(tools: &[ToolSpec]) -> String {
901        tools
902            .iter()
903            .map(|t| {
904                let params = t
905                    .input_schema
906                    .get("properties")
907                    .and_then(|p| p.as_object())
908                    .map(|o| o.keys().cloned().collect::<Vec<_>>().join(", "))
909                    .unwrap_or_default();
910
911                if t.description.is_empty() {
912                    format!("- {}({})", t.name, params)
913                } else {
914                    format!("- {}({}): {}", t.name, params, t.description)
915                }
916            })
917            .collect::<Vec<_>>()
918            .join("\n")
919    }
920
921    // New helper: Parse multiple tool calls from response
922    fn parse_tool_calls(response_text: &str) -> Vec<ToolCall> {
923        let mut tool_calls = Vec::new();
924
925        // Check for multiple tools format
926        if response_text.starts_with("USE_TOOLS:") {
927            // Split by "---" to get individual tool calls
928            let parts: Vec<&str> = response_text
929                .strip_prefix("USE_TOOLS:")
930                .unwrap_or("")
931                .split("---")
932                .collect();
933
934            for part in parts {
935                let lines: Vec<&str> = part.trim().lines().collect();
936                if lines.is_empty() {
937                    continue;
938                }
939
940                let tool_name = lines[0].trim().to_string();
941                let json_block = lines.get(1..).unwrap_or(&[]).join("\n");
942
943                if let Ok(input_value) = serde_json::from_str(&json_block) {
944                    tool_calls.push(ToolCall {
945                        name: tool_name,
946                        input: input_value,
947                    });
948                }
949            }
950        }
951        // Check for single tool format
952        else if response_text.starts_with("USE_TOOL:") {
953            let lines: Vec<&str> = response_text.lines().collect();
954            let tool_name = lines[0]
955                .strip_prefix("USE_TOOL:")
956                .unwrap_or("")
957                .trim()
958                .to_string();
959
960            let json_block = lines.get(1..).unwrap_or(&[]).join("\n");
961
962            if let Ok(input_value) = serde_json::from_str(&json_block) {
963                tool_calls.push(ToolCall {
964                    name: tool_name,
965                    input: input_value,
966                });
967            }
968        }
969
970        tool_calls
971    }
972}
973
974#[async_trait]
975impl LLMClient for UniversalLLMClient {
976    async fn complete<T, C>(&self, messages: &[Message], tools: &[ToolSpec]) -> Result<T, String>
977    where
978        T: LLMResponseTrait<C> + Default + Send,
979        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
980    {
981        let mut chat_messages = vec![];
982
983        // 1) Add system prompt if not provided by user
984        let has_user_system_prompt = messages.iter().any(|m| m.role == "system");
985        if !has_user_system_prompt {
986            chat_messages.push(
987                ChatMessage::assistant()
988                    .content(Self::DEFAULT_SYSTEM_PROMPT)
989                    .build(),
990            );
991        }
992
993        // 2) Add tool prompt if tools are provided
994        let user_tool_prompt = messages
995            .iter()
996            .find(|m| m.role == "system_tools")
997            .map(|m| m.content.clone());
998
999        if !tools.is_empty() {
1000            let tool_list = Self::build_tool_description(tools);
1001            let tool_prompt = user_tool_prompt.unwrap_or_else(|| {
1002                format!(
1003                    "{}\n\nAvailable Tools:\n{}\n\n{}",
1004                    Self::DEFAULT_TOOL_PROMPT,
1005                    tool_list,
1006                    "Use only the EXACT formats shown above when calling tools."
1007                )
1008            });
1009            chat_messages.push(ChatMessage::assistant().content(tool_prompt).build());
1010        }
1011
1012        // 3) AUTO-GENERATE SCHEMA INSTRUCTION
1013        let sample_c = C::default();
1014        let schema_instruction = Self::generate_schema_instruction(&sample_c);
1015
1016        chat_messages.push(ChatMessage::assistant().content(schema_instruction).build());
1017
1018        // 4) Add user messages
1019        chat_messages.extend(self.convert_messages(messages));
1020
1021        // Helper: try to parse into C
1022        let try_parse_c = |s: &str| -> C {
1023            let text = s.trim();
1024
1025            // Try direct JSON parse
1026            if let Ok(parsed) = serde_json::from_str::<C>(text) {
1027                return parsed;
1028            }
1029
1030            // Remove markdown code blocks
1031            let cleaned = text
1032                .strip_prefix("```json")
1033                .unwrap_or(text)
1034                .strip_prefix("```")
1035                .unwrap_or(text)
1036                .strip_suffix("```")
1037                .unwrap_or(text)
1038                .trim();
1039
1040            if let Ok(parsed) = serde_json::from_str::<C>(cleaned) {
1041                return parsed;
1042            }
1043
1044            // Find JSON object in text
1045            if let Some(start) = text.find('{') {
1046                if let Some(end) = text.rfind('}') {
1047                    let json_part = &text[start..=end];
1048                    if let Ok(parsed) = serde_json::from_str::<C>(json_part) {
1049                        return parsed;
1050                    }
1051                }
1052            }
1053
1054            // Try quoted string
1055            if let Ok(quoted) = serde_json::to_string(text) {
1056                if let Ok(parsed) = serde_json::from_str::<C>(&quoted) {
1057                    return parsed;
1058                }
1059            }
1060
1061            // Fallback to default
1062            C::default()
1063        };
1064
1065        // 5) Send to LLM
1066        let start = std::time::Instant::now();
1067        let response = self
1068            .llm
1069            .chat(&chat_messages)
1070            .await
1071            .map_err(|e| format!("LLM error: {}", e))?;
1072        let duration = start.elapsed().as_micros() as u64;
1073
1074        let response_text = response.text().unwrap_or_default();
1075
1076        // Estimate tokens and cost
1077        let input_text: String = messages
1078            .iter()
1079            .map(|m| m.content.as_str())
1080            .collect::<Vec<_>>()
1081            .join(" ");
1082        let input_tokens = estimate_tokens(&input_text);
1083        let output_tokens = estimate_tokens(&response_text);
1084        let total_tokens = input_tokens + output_tokens;
1085        let cost_us = calculate_cost(&self.config.model, input_tokens, output_tokens);
1086
1087        // Record metrics with estimates
1088        crate::metrics::metrics().record_llm_call(duration, total_tokens, cost_us);
1089
1090        // 6) Parse tool calls (handles both single and multiple)
1091        let tool_calls = Self::parse_tool_calls(&response_text);
1092
1093        // If we have tool calls, return them
1094        if !tool_calls.is_empty() {
1095            let parsed_content: C = C::default();
1096            return Ok(T::new(parsed_content, tool_calls, false));
1097        }
1098
1099        // 7) Normal response - parse into C
1100        let parsed_content: C = try_parse_c(&response_text);
1101        Ok(T::new(parsed_content, vec![], true))
1102    }
1103}
1104
1105// ============================================================================
1106// MOCK LLM CLIENT FOR TESTING
1107// ============================================================================
1108
1109/// Mock LLM client for testing - doesn't make real API calls.
1110///
1111/// This is useful for writing fast unit tests without requiring actual LLM API access.
1112///
1113/// # Examples
1114///
1115/// ```rust
1116/// use runtime::llm::{MockLLMClient, LLMClient, LLMResponse};
1117/// use runtime::llm::types::Message;
1118///
1119/// # #[tokio::main]
1120/// # async fn main() {
1121/// let client = MockLLMClient::new("Hello from mock!");
1122/// let messages = vec![Message {
1123///     role: "user".to_string(),
1124///     content: "Say hello".to_string(),
1125/// }];
1126///
1127/// let result: LLMResponse<String> = client
1128///     .complete::<LLMResponse<String>, String>(&messages, &[])
1129///     .await
1130///     .expect("Mock LLM failed");
1131/// # }
1132/// ```
1133pub struct MockLLMClient {
1134    response_content: String,
1135}
1136
1137impl MockLLMClient {
1138    /// Creates a new mock LLM client that returns the specified response.
1139    ///
1140    /// # Arguments
1141    ///
1142    /// * `response` - A string that will be parsed as the response content
1143    ///
1144    /// # Examples
1145    ///
1146    /// ```rust
1147    /// use runtime::llm::MockLLMClient;
1148    ///
1149    /// // For structured responses, provide JSON
1150    /// let client = MockLLMClient::new(r#"{"field": "value"}"#);
1151    ///
1152    /// // For simple text responses
1153    /// let client = MockLLMClient::new("Hello, world!");
1154    /// ```
1155    pub fn new(response: &str) -> Self {
1156        Self {
1157            response_content: response.to_string(),
1158        }
1159    }
1160
1161    /// Creates a mock LLM client with a default "Hello" response.
1162    ///
1163    /// # Examples
1164    ///
1165    /// ```rust
1166    /// use runtime::llm::MockLLMClient;
1167    ///
1168    /// let client = MockLLMClient::default_hello();
1169    /// ```
1170    pub fn default_hello() -> Self {
1171        Self::new("Hello! How can I help you today?")
1172    }
1173}
1174
1175#[async_trait::async_trait]
1176impl LLMClient for MockLLMClient {
1177    async fn complete<T, C>(&self, _messages: &[Message], _tools: &[ToolSpec]) -> Result<T, String>
1178    where
1179        T: LLMResponseTrait<C> + Default + Send,
1180        C: for<'de> Deserialize<'de> + Default + Send + Serialize,
1181    {
1182        // Try to parse the response into the expected content type
1183        let content: C = if let Ok(parsed) = serde_json::from_str(&self.response_content) {
1184            parsed
1185        } else {
1186            C::default()
1187        };
1188
1189        Ok(T::new(content, vec![], true))
1190    }
1191}
1192
1193// ============================================================================
1194// TESTS
1195// ============================================================================
1196
1197#[cfg(test)]
1198mod tests {
1199    use super::*;
1200
1201    #[tokio::test]
1202    async fn test_mock_llm_basic() {
1203        let client = MockLLMClient::default_hello();
1204        let messages = vec![Message {
1205            role: "user".into(),
1206            content: "Say hello".into(),
1207        }];
1208
1209        let result: LLMResponse<String> = client
1210            .complete::<LLMResponse<String>, String>(&messages, &[])
1211            .await
1212            .expect("Mock LLM failed");
1213
1214        assert!(result.content.is_empty()); // Default String is empty
1215        assert!(result.tool_calls.is_empty());
1216        assert!(result.is_complete);
1217    }
1218
1219    #[tokio::test]
1220    async fn test_mock_structured_output() {
1221        let client = MockLLMClient::new(r#"{"field1": 42.5, "flag": true}"#);
1222        let messages = vec![Message {
1223            role: "user".into(),
1224            content: "Return structured data".into(),
1225        }];
1226
1227        #[derive(Deserialize, Serialize, Default, Clone, Debug)]
1228        struct MyOutput {
1229            field1: f64,
1230            flag: bool,
1231        }
1232
1233        let result: LLMResponse<MyOutput> = client
1234            .complete::<LLMResponse<MyOutput>, MyOutput>(&messages, &[])
1235            .await
1236            .expect("Mock LLM failed");
1237
1238        assert_eq!(result.content.field1, 42.5);
1239        assert_eq!(result.content.flag, true);
1240        assert!(result.tool_calls.is_empty());
1241        assert!(result.is_complete);
1242    }
1243}