litellm_rs/core/
completion.rs

//! Python LiteLLM compatible completion API
//!
//! This module provides a Python LiteLLM-style API for making completion requests.
//! It serves as the main entry point for the library, providing a unified interface
//! to call 100+ LLM APIs using OpenAI format.

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::OnceCell;
// Import core types from our unified type system
use crate::core::types::{
    ChatMessage, ChatRequest, ChatResponse, RequestContext, Tool, ToolChoice,
};

// Import provider system
use crate::core::providers::{Provider, ProviderRegistry, ProviderType};
use crate::utils::error::{GatewayError, Result};
use tracing::debug;
/// Core completion function - the main entry point for all LLM calls
/// Mimics Python LiteLLM's completion function signature
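///
/// # Example
///
/// A minimal sketch (not a verbatim doctest from this crate): it assumes the
/// crate is consumed as `litellm_rs`, that `GatewayError` implements
/// `std::error::Error`, and that a matching provider key such as
/// `OPENAI_API_KEY` is configured.
///
/// ```no_run
/// use litellm_rs::core::completion::{completion, user_message};
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let response = completion("gpt-4o", vec![user_message("Hello!")], None).await?;
/// println!("{}", response.id);
/// # Ok(())
/// # }
/// ```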
pub async fn completion(
    model: &str,
    messages: Vec<Message>,
    options: Option<CompletionOptions>,
) -> Result<CompletionResponse> {
    let router = get_global_router().await;
    router
        .complete(model, messages, options.unwrap_or_default())
        .await
}

/// Async alias of `completion`, kept for Python LiteLLM API parity
/// (every function in this module is already async)
pub async fn acompletion(
    model: &str,
    messages: Vec<Message>,
    options: Option<CompletionOptions>,
) -> Result<CompletionResponse> {
    completion(model, messages, options).await
}

/// Streaming completion function
pub async fn completion_stream(
    _model: &str,
    _messages: Vec<Message>,
    _options: Option<CompletionOptions>,
) -> Result<CompletionStream> {
    // TODO: Implement streaming
    todo!("Streaming completion not yet implemented")
}

/// Unified message format (OpenAI compatible) - just re-export the core type
pub type Message = ChatMessage;

// Re-export types with proper paths (no duplicate imports)
pub use crate::core::types::{MessageContent, MessageRole};

/// Content part for multimodal messages (re-export from core types)
pub use crate::core::types::ContentPart;

/// Tool call structure
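///
/// With the serde derives below, this matches the OpenAI wire shape, e.g.
/// (illustrative values):
/// `{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{}"}}`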
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    pub r#type: String,
    pub function: FunctionCall,
}

/// Function call structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}

/// Completion options - Python LiteLLM compatible
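///
/// # Example
///
/// An illustrative construction (the local `api_base` is a placeholder);
/// fields left as `None` are omitted from the serialized request:
///
/// ```no_run
/// use litellm_rs::core::completion::CompletionOptions;
///
/// let options = CompletionOptions {
///     temperature: Some(0.2),
///     max_tokens: Some(256),
///     api_base: Some("http://localhost:8000/v1".to_string()),
///     ..Default::default()
/// };
/// ```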
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CompletionOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(default)]
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,

    // Python LiteLLM compatibility fields
    /// Custom API base URL - overrides provider's default base URL
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_base: Option<String>,

    /// Custom API key - overrides provider's default API key
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,

    /// Custom organization ID (for OpenAI)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub organization: Option<String>,

    /// Custom API version (for Azure)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_version: Option<String>,

    /// Custom headers to add to the request
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,

    /// Timeout in seconds for the request
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<u64>,

    #[serde(flatten)]
    pub extra_params: HashMap<String, serde_json::Value>,
}

/// Completion response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}

/// Choice in completion response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: Message,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
}

/// Usage statistics (re-export from core types)
pub type Usage = crate::core::types::responses::Usage;

/// Finish reason enumeration (re-export from core types)
pub type FinishReason = crate::core::types::responses::FinishReason;

/// Streaming response type (placeholder for now)
pub type CompletionStream =
    Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin + 'static>;

/// LiteLLM Error type
pub type LiteLLMError = GatewayError;

/// Router trait for handling completion requests
#[async_trait]
pub trait Router: Send + Sync {
    async fn complete(
        &self,
        model: &str,
        messages: Vec<Message>,
        options: CompletionOptions,
    ) -> Result<CompletionResponse>;

    async fn complete_stream(
        &self,
        model: &str,
        messages: Vec<Message>,
        options: CompletionOptions,
    ) -> Result<CompletionStream>;
}

/// Default router implementation using the provider registry
pub struct DefaultRouter {
    provider_registry: Arc<ProviderRegistry>,
}

impl DefaultRouter {
    /// Helper function to find and select a provider by name with model prefix stripping
    fn select_provider_by_name<'a>(
        providers: &'a [&'a crate::core::providers::Provider],
        provider_name: &str,
        original_model: &str,
        prefix: &str,
        chat_request: &ChatRequest,
    ) -> Option<(&'a crate::core::providers::Provider, ChatRequest)> {
        if !original_model.starts_with(prefix) {
            return None;
        }

        let actual_model = original_model.strip_prefix(prefix).unwrap_or(original_model);

        // Only log once a matching provider has actually been found
        let provider = providers
            .iter()
            .copied()
            .find(|provider| provider.name() == provider_name)?;

        debug!(
            provider = provider_name,
            model = %actual_model,
            "Using static {} provider", provider_name
        );

        let mut updated_request = chat_request.clone();
        updated_request.model = actual_model.to_string();
        Some((provider, updated_request))
    }

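    /// Build a router from the environment.
    ///
    /// Each static provider is registered only when its credentials are
    /// present: `OPENAI_API_KEY`, `OPENROUTER_API_KEY`, `ANTHROPIC_API_KEY`,
    /// `GOOGLE_APPLICATION_CREDENTIALS` (Vertex AI), and `DEEPSEEK_API_KEY`.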
    pub async fn new() -> Result<Self> {
        let mut provider_registry = ProviderRegistry::new();

        // Add OpenAI provider if API key is available
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            use crate::core::providers::base::BaseConfig;
            use crate::core::providers::openai::OpenAIProvider;
            use crate::core::providers::openai::config::OpenAIConfig;

            // Create OpenAI provider config
            let config = OpenAIConfig {
                base: BaseConfig {
                    api_key: Some(api_key),
                    api_base: Some("https://api.openai.com/v1".to_string()),
                    timeout: 60,
                    max_retries: 3,
                    headers: Default::default(),
                    organization: std::env::var("OPENAI_ORGANIZATION").ok(),
                    api_version: None,
                },
                organization: std::env::var("OPENAI_ORGANIZATION").ok(),
                project: None,
                model_mappings: Default::default(),
                features: Default::default(),
            };

            // Create and register OpenAI provider
            if let Ok(openai_provider) = OpenAIProvider::new(config).await {
                provider_registry.register(Provider::OpenAI(openai_provider));
            }
        }

        // Add OpenRouter provider if API key is available
        if let Ok(api_key) = std::env::var("OPENROUTER_API_KEY") {
            use crate::core::providers::openrouter::{OpenRouterConfig, OpenRouterProvider};

            // Clean the API key to remove any whitespace or newlines
            let api_key = api_key.trim().to_string();

            // Create OpenRouter provider config
            let config = OpenRouterConfig {
                api_key,
                base_url: "https://openrouter.ai/api/v1".to_string(),
                site_url: std::env::var("OPENROUTER_HTTP_REFERER").ok(),
                site_name: std::env::var("OPENROUTER_X_TITLE").ok(),
                timeout_seconds: 60,
                max_retries: 3,
                extra_params: Default::default(),
            };

            // Create and register OpenRouter provider
            if let Ok(openrouter_provider) = OpenRouterProvider::new(config) {
                provider_registry.register(Provider::OpenRouter(openrouter_provider));
            }
        }

        // Add Anthropic provider if API key is available
        if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
            use crate::core::providers::anthropic::{AnthropicConfig, AnthropicProvider};

            let config = AnthropicConfig::new(api_key)
                .with_base_url("https://api.anthropic.com")
                .with_experimental(false);

            // Degrade gracefully on failure, matching the other providers,
            // rather than aborting router construction with `?`
            if let Ok(anthropic_provider) = AnthropicProvider::new(config) {
                provider_registry.register(Provider::Anthropic(anthropic_provider));
            }
        }

        // Azure provider registration temporarily disabled - needs migration to new system
        // TODO: Re-enable once Azure provider is fully migrated from BaseLLM to LLMProvider

        // Add VertexAI provider if service account is available
        if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
            use crate::core::providers::vertex_ai::{
                VertexAIProvider, VertexAIProviderConfig, VertexCredentials,
            };

            let config = VertexAIProviderConfig {
                project_id: std::env::var("GOOGLE_PROJECT_ID")
                    .unwrap_or_else(|_| "default-project".to_string()),
                location: std::env::var("GOOGLE_LOCATION")
                    .unwrap_or_else(|_| "us-central1".to_string()),
                api_version: "v1".to_string(),
                credentials: VertexCredentials::ApplicationDefault,
                api_base: None,
                timeout_seconds: 60,
                max_retries: 3,
                enable_experimental: false,
            };

            if let Ok(vertex_provider) = VertexAIProvider::new(config).await {
                provider_registry.register(Provider::VertexAI(vertex_provider));
            }
        }

        // Add DeepSeek provider if API key is available
        if std::env::var("DEEPSEEK_API_KEY").is_ok() {
            use crate::core::providers::deepseek::{DeepSeekConfig, DeepSeekProvider};

            let config = DeepSeekConfig::from_env();

            if let Ok(deepseek_provider) = DeepSeekProvider::new(config) {
                provider_registry.register(Provider::DeepSeek(deepseek_provider));
            }
        }

        Ok(Self {
            provider_registry: Arc::new(provider_registry),
        })
    }

    /// Dynamic provider creation (Python LiteLLM style)
    /// Creates providers on-demand based on model name and provided options
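    ///
    /// Recognized prefixes: `openrouter/`, `anthropic/`, `deepseek/`,
    /// `azure_ai/` (or `azure-ai/`), and `openai/`. A model without a prefix
    /// is treated as OpenAI-compatible only when `options.api_base` is set;
    /// otherwise this returns `Ok(None)` and the static registry is used.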
    async fn try_dynamic_provider_creation(
        &self,
        chat_request: &ChatRequest,
        context: RequestContext,
        options: &CompletionOptions,
    ) -> Result<Option<CompletionResponse>> {
        let model = &chat_request.model;

        // Only proceed if user provided an API key
        let api_key = match &options.api_key {
            Some(key) => key.clone(),
            None => return Ok(None), // No dynamic creation without API key
        };

        // Determine provider type from model name
        let (provider_type, actual_model, api_base) = if model.starts_with("openrouter/") {
            let actual_model = model.strip_prefix("openrouter/").unwrap_or(model);
            let api_base = options
                .api_base
                .clone()
                .unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
            ("openrouter", actual_model, api_base)
        } else if model.starts_with("anthropic/") {
            let actual_model = model.strip_prefix("anthropic/").unwrap_or(model);
            let api_base = options
                .api_base
                .clone()
                .unwrap_or_else(|| "https://api.anthropic.com".to_string());
            ("anthropic", actual_model, api_base)
        } else if model.starts_with("deepseek/") {
            let actual_model = model.strip_prefix("deepseek/").unwrap_or(model);
            let api_base = options
                .api_base
                .clone()
                .unwrap_or_else(|| "https://api.deepseek.com".to_string());
            ("deepseek", actual_model, api_base)
        } else if model.starts_with("azure_ai/") || model.starts_with("azure-ai/") {
            let actual_model = model
                .strip_prefix("azure_ai/")
                .or_else(|| model.strip_prefix("azure-ai/"))
                .unwrap_or(model);
            let api_base = options
                .api_base
                .clone()
                .or_else(|| std::env::var("AZURE_AI_API_BASE").ok())
                .unwrap_or_else(|| "https://api.azure.com".to_string());
            ("azure_ai", actual_model, api_base)
        } else if model.starts_with("openai/") {
            let actual_model = model.strip_prefix("openai/").unwrap_or(model);
            let api_base = options
                .api_base
                .clone()
                .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
            ("openai", actual_model, api_base)
        } else {
            // For models without provider prefix, try to infer or use custom api_base
            if let Some(api_base) = &options.api_base {
                ("openai-compatible", model.as_str(), api_base.clone())
            } else {
                return Ok(None); // Can't create dynamic provider without api_base
            }
        };

        debug!(
            provider_type = %provider_type,
            model = %actual_model,
            "Creating dynamic provider for model"
        );

        // Create dynamic provider based on type
        let response = match provider_type {
            "openrouter" => {
                self.create_dynamic_openrouter(
                    actual_model,
                    &api_key,
                    &api_base,
                    chat_request,
                    context,
                )
                .await?
            }
            "anthropic" => {
                self.create_dynamic_anthropic(
                    actual_model,
                    &api_key,
                    &api_base,
                    chat_request,
                    context,
                )
                .await?
            }
            "deepseek" => {
                self.create_dynamic_openai_compatible(
                    actual_model,
                    &api_key,
                    &api_base,
                    chat_request,
                    context,
                    "DeepSeek",
                )
                .await?
            }
            "azure_ai" => {
                self.create_dynamic_azure_ai(
                    actual_model,
                    &api_key,
                    &api_base,
                    chat_request,
                    context,
                )
                .await?
            }
            "openai" => {
                self.create_dynamic_openai_compatible(
                    actual_model,
                    &api_key,
                    &api_base,
                    chat_request,
                    context,
                    "OpenAI",
                )
                .await?
            }
            "openai-compatible" => {
                self.create_dynamic_openai_compatible(
                    actual_model,
                    &api_key,
                    &api_base,
                    chat_request,
                    context,
                    "OpenAI-Compatible",
                )
                .await?
            }
            _ => return Ok(None),
        };

        Ok(Some(response))
    }

    /// Create dynamic OpenRouter provider
    async fn create_dynamic_openrouter(
        &self,
        model: &str,
        api_key: &str,
        api_base: &str,
        chat_request: &ChatRequest,
        context: RequestContext,
    ) -> Result<CompletionResponse> {
        use crate::core::providers::openrouter::{OpenRouterConfig, OpenRouterProvider};
        use crate::core::traits::LLMProvider;

        let config = OpenRouterConfig {
            api_key: api_key.to_string(),
            base_url: api_base.to_string(),
            site_url: None, // Could be extracted from options if needed
            site_name: None,
            timeout_seconds: 60,
            max_retries: 3,
            extra_params: Default::default(),
        };

        let provider = OpenRouterProvider::new(config).map_err(|e| {
            GatewayError::internal(format!(
                "Failed to create dynamic OpenRouter provider: {}",
                e
            ))
        })?;

        let mut updated_request = chat_request.clone();
        updated_request.model = model.to_string();

        let response = provider
            .chat_completion(updated_request, context)
            .await
            .map_err(|e| {
                GatewayError::internal(format!("Dynamic OpenRouter provider error: {}", e))
            })?;

        convert_from_chat_completion_response(response)
    }

    /// Create dynamic Anthropic provider
    async fn create_dynamic_anthropic(
        &self,
        model: &str,
        api_key: &str,
        api_base: &str,
        chat_request: &ChatRequest,
        context: RequestContext,
    ) -> Result<CompletionResponse> {
        use crate::core::providers::anthropic::{AnthropicConfig, AnthropicProvider};
        use crate::core::traits::LLMProvider;

        let config = AnthropicConfig::new(api_key)
            .with_base_url(api_base)
            .with_experimental(false);

        let provider = AnthropicProvider::new(config)?;

        let mut updated_request = chat_request.clone();
        updated_request.model = model.to_string();

        let response = LLMProvider::chat_completion(&provider, updated_request, context)
            .await
            .map_err(|e| {
                GatewayError::internal(format!("Dynamic Anthropic provider error: {}", e))
            })?;

        convert_from_chat_completion_response(response)
    }

    /// Create dynamic OpenAI-compatible provider (works for OpenAI, DeepSeek, and other compatible APIs)
    async fn create_dynamic_openai_compatible(
        &self,
        model: &str,
        api_key: &str,
        api_base: &str,
        chat_request: &ChatRequest,
        context: RequestContext,
        provider_name: &str,
    ) -> Result<CompletionResponse> {
        use crate::core::providers::base::BaseConfig;
        use crate::core::providers::openai::OpenAIProvider;
        use crate::core::providers::openai::config::OpenAIConfig;
        use crate::core::traits::LLMProvider;

        let config = OpenAIConfig {
            base: BaseConfig {
                api_key: Some(api_key.to_string()),
                api_base: Some(api_base.to_string()),
                timeout: 60,
                max_retries: 3,
                headers: Default::default(),
                organization: None,
                api_version: None,
            },
            organization: None,
            project: None,
            model_mappings: Default::default(),
            features: Default::default(),
        };

        let provider = OpenAIProvider::new(config).await.map_err(|e| {
            GatewayError::internal(format!(
                "Failed to create dynamic {} provider: {}",
                provider_name, e
            ))
        })?;

        let mut updated_request = chat_request.clone();
        updated_request.model = model.to_string();

        let response = provider
            .chat_completion(updated_request, context)
            .await
            .map_err(|e| {
                GatewayError::internal(format!("Dynamic {} provider error: {}", provider_name, e))
            })?;

        convert_from_chat_completion_response(response)
    }

    /// Create dynamic Azure AI provider
    async fn create_dynamic_azure_ai(
        &self,
        model: &str,
        api_key: &str,
        api_base: &str,
        chat_request: &ChatRequest,
        context: RequestContext,
    ) -> Result<CompletionResponse> {
        use crate::core::providers::azure_ai::{AzureAIConfig, AzureAIProvider};
        use crate::core::traits::LLMProvider;

        // api_key and api_base are resolved by the caller (explicit options
        // first, then AZURE_AI_API_BASE for the base URL), so both are always
        // set here and no further environment lookup is needed.
        let mut config = AzureAIConfig::new("azure_ai");
        config.base.api_key = Some(api_key.to_string());
        config.base.api_base = Some(api_base.to_string());

        let provider = AzureAIProvider::new(config).map_err(|e| {
            GatewayError::internal(format!(
                "Failed to create dynamic Azure AI provider: {}",
                e
            ))
        })?;

        let mut updated_request = chat_request.clone();
        updated_request.model = model.to_string();

        let response = provider
            .chat_completion(updated_request, context)
            .await
            .map_err(|e| {
                GatewayError::internal(format!("Dynamic Azure AI provider error: {}", e))
            })?;

        convert_from_chat_completion_response(response)
    }
}

#[async_trait]
impl Router for DefaultRouter {
    async fn complete(
        &self,
        model: &str,
        messages: Vec<Message>,
        options: CompletionOptions,
    ) -> Result<CompletionResponse> {
        // Convert to internal types
        let chat_messages = convert_messages_to_chat_messages(messages);
        let chat_request =
            convert_to_chat_completion_request(model, chat_messages, options.clone())?;

        // Create request context with override parameters from options
        let mut context = RequestContext::new();

        // Check for dynamic provider configuration overrides
        if let Some(api_base) = &options.api_base {
            context.metadata.insert(
                "api_base_override".to_string(),
                serde_json::Value::String(api_base.clone()),
            );
        }

        if let Some(api_key) = &options.api_key {
            context.metadata.insert(
                "api_key_override".to_string(),
                serde_json::Value::String(api_key.clone()),
            );
        }

        if let Some(organization) = &options.organization {
            context.metadata.insert(
                "organization_override".to_string(),
                serde_json::Value::String(organization.clone()),
            );
        }

        if let Some(api_version) = &options.api_version {
            context.metadata.insert(
                "api_version_override".to_string(),
                serde_json::Value::String(api_version.clone()),
            );
        }

        if let Some(headers) = &options.headers {
            context.metadata.insert(
                "headers_override".to_string(),
                serde_json::to_value(headers).unwrap_or_default(),
            );
        }

        if let Some(timeout) = options.timeout {
            context.metadata.insert(
                "timeout_override".to_string(),
                serde_json::Value::Number(serde_json::Number::from(timeout)),
            );
        }

        // Check if user provided custom api_base (Python LiteLLM compatibility)
        if let Some(api_base) = &options.api_base {
            // When api_base is provided, create a temporary OpenAI-compatible provider
            // This matches Python LiteLLM behavior for custom endpoints
            use crate::core::providers::base::BaseConfig;
            use crate::core::providers::openai::OpenAIProvider;
            use crate::core::providers::openai::config::OpenAIConfig;
            use crate::core::traits::LLMProvider;

            let api_key = options
                .api_key
                .clone()
                .or_else(|| std::env::var("OPENAI_API_KEY").ok())
                .unwrap_or_else(|| "dummy-key-for-local".to_string());

            let config = OpenAIConfig {
                base: BaseConfig {
                    api_key: Some(api_key),
                    api_base: Some(api_base.clone()),
                    timeout: options.timeout.unwrap_or(60),
                    max_retries: 3,
                    headers: options.headers.clone().unwrap_or_default(),
                    organization: options.organization.clone(),
                    api_version: None,
                },
                organization: options.organization.clone(),
                project: None,
                model_mappings: Default::default(),
                features: Default::default(),
            };

            // Create temporary provider with custom base URL
            match OpenAIProvider::new(config).await {
                Ok(temp_provider) => {
                    // Use the temporary provider directly
                    let response = temp_provider
                        .chat_completion(chat_request, context)
                        .await
                        .map_err(|e| GatewayError::internal(format!("Provider error: {}", e)))?;
                    return convert_from_chat_completion_response(response);
                }
                Err(e) => {
                    return Err(GatewayError::internal(format!(
                        "Failed to create provider with custom api_base: {}",
                        e
                    )));
                }
            }
        }

        // Dynamic provider creation (Python LiteLLM style)
        // Try dynamic creation first, fallback to static registry
        if let Some(response) = self
            .try_dynamic_provider_creation(&chat_request, context.clone(), &options)
            .await?
        {
            return Ok(response);
        }

        // Fallback to static provider registry
        let providers = self.provider_registry.all();

        // Check if model explicitly specifies a provider - using helper function
        let mut selected_provider =
            Self::select_provider_by_name(&providers, "openrouter", model, "openrouter/", &chat_request)
                .or_else(|| {
                    Self::select_provider_by_name(&providers, "deepseek", model, "deepseek/", &chat_request)
                })
                .or_else(|| {
                    Self::select_provider_by_name(&providers, "anthropic", model, "anthropic/", &chat_request)
                })
                .or_else(|| {
                    Self::select_provider_by_name(&providers, "azure_ai", model, "azure_ai/", &chat_request)
                });

        // Handle special cases that don't follow the standard pattern
        if selected_provider.is_none() {
            if model.starts_with("openai/") || model.starts_with("azure/") {
                for provider in providers.iter().copied() {
                    if provider.provider_type() == ProviderType::OpenAI
                        && provider.supports_model(model)
                    {
                        selected_provider = Some((provider, chat_request.clone()));
                        break;
                    }
                }
            } else {
                // No explicit provider, try to find one that supports the model
                for provider in providers.iter().copied() {
                    if provider.supports_model(model) {
                        selected_provider = Some((provider, chat_request.clone()));
                        break;
                    }
                }
            }
        }

        // Use static provider if found
        if let Some((provider, request)) = selected_provider {
            let response = provider.chat_completion(request, context).await?;
            return convert_from_chat_completion_response(response);
        }

        Err(GatewayError::internal(format!(
            "No suitable provider found for model '{}'",
            model
        )))
    }

    async fn complete_stream(
        &self,
        _model: &str,
        _messages: Vec<Message>,
        _options: CompletionOptions,
    ) -> Result<CompletionStream> {
        // TODO: Implement streaming
        todo!("Streaming not yet implemented")
    }
}

/// Fallback router for when initialization fails
pub struct ErrorRouter {
    error: String,
}

#[async_trait]
impl Router for ErrorRouter {
    async fn complete(
        &self,
        _model: &str,
        _messages: Vec<Message>,
        _options: CompletionOptions,
    ) -> Result<CompletionResponse> {
        Err(GatewayError::internal(format!(
            "Router initialization failed: {}",
            self.error
        )))
    }

    async fn complete_stream(
        &self,
        _model: &str,
        _messages: Vec<Message>,
        _options: CompletionOptions,
    ) -> Result<CompletionStream> {
        Err(GatewayError::internal(format!(
            "Router initialization failed: {}",
            self.error
        )))
    }
}

/// Global router instance
static GLOBAL_ROUTER: OnceCell<Box<dyn Router>> = OnceCell::const_new();

/// Get or initialize the global router
async fn get_global_router() -> &'static Box<dyn Router> {
    GLOBAL_ROUTER
        .get_or_init(|| async {
            match DefaultRouter::new().await {
                Ok(router) => Box::new(router) as Box<dyn Router>,
                Err(e) => Box::new(ErrorRouter {
                    error: e.to_string(),
                }) as Box<dyn Router>,
            }
        })
        .await
}

/// Helper function to create user message
pub fn user_message(content: impl Into<String>) -> Message {
    use crate::core::types::{MessageContent, MessageRole};
    ChatMessage {
        role: MessageRole::User,
        content: Some(MessageContent::Text(content.into())),
        name: None,
        tool_calls: None,
        tool_call_id: None,
        function_call: None,
    }
}

/// Helper function to create system message
pub fn system_message(content: impl Into<String>) -> Message {
    use crate::core::types::{MessageContent, MessageRole};
    ChatMessage {
        role: MessageRole::System,
        content: Some(MessageContent::Text(content.into())),
        name: None,
        tool_calls: None,
        tool_call_id: None,
        function_call: None,
    }
}

/// Helper function to create assistant message
pub fn assistant_message(content: impl Into<String>) -> Message {
    use crate::core::types::{MessageContent, MessageRole};
    ChatMessage {
        role: MessageRole::Assistant,
        content: Some(MessageContent::Text(content.into())),
        name: None,
        tool_calls: None,
        tool_call_id: None,
        function_call: None,
    }
}

// Internal conversion functions

fn convert_messages_to_chat_messages(messages: Vec<Message>) -> Vec<ChatMessage> {
    // Since Message is now an alias for ChatMessage, this is just a no-op
    messages
}

fn convert_to_chat_completion_request(
    model: &str,
    messages: Vec<ChatMessage>,
    options: CompletionOptions,
) -> Result<ChatRequest> {
    Ok(ChatRequest {
        model: model.to_string(),
        messages,
        temperature: options.temperature,
        max_tokens: options.max_tokens,
        max_completion_tokens: None,
        top_p: options.top_p,
        frequency_penalty: options.frequency_penalty,
        presence_penalty: options.presence_penalty,
        stop: options.stop,
        stream: options.stream,
        tools: None,       // TODO: tool conversion not yet implemented; options.tools is currently dropped
        tool_choice: None, // TODO: likewise for options.tool_choice
        parallel_tool_calls: None,
        response_format: None,
        user: options.user,
        seed: options.seed,
        n: options.n,
        logit_bias: None,
        functions: None,
        function_call: None,
        logprobs: options.logprobs,
        top_logprobs: options.top_logprobs,
        extra_params: options.extra_params,
    })
}

fn convert_from_chat_completion_response(response: ChatResponse) -> Result<CompletionResponse> {
    let choices = response
        .choices
        .into_iter()
        .map(|choice| Choice {
            index: choice.index,
            message: choice.message,             // Same type already
            finish_reason: choice.finish_reason, // Same type already
        })
        .collect();

    Ok(CompletionResponse {
        id: response.id,
        object: response.object,
        created: response.created,
        model: response.model,
        choices,
        usage: response.usage, // Same type already
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let msg = user_message("Hello, world!");
        assert_eq!(msg.role, MessageRole::User);
        if let Some(MessageContent::Text(content)) = msg.content {
            assert_eq!(content, "Hello, world!");
        } else {
            panic!("Expected text content");
        }
    }

    #[test]
    fn test_completion_options_default() {
        let options = CompletionOptions::default();
        assert!(!options.stream);
        assert_eq!(options.extra_params.len(), 0);
    }
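
    // Minimal round-trip check: sampling options set on CompletionOptions
    // should survive conversion into a ChatRequest.
    #[test]
    fn test_request_conversion_carries_options() {
        let options = CompletionOptions {
            temperature: Some(0.5),
            max_tokens: Some(128),
            ..Default::default()
        };
        let request =
            convert_to_chat_completion_request("gpt-4o", vec![user_message("hi")], options)
                .expect("conversion should not fail");
        assert_eq!(request.model, "gpt-4o");
        assert_eq!(request.temperature, Some(0.5));
        assert_eq!(request.max_tokens, Some(128));
    }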
}