1use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::OnceCell;
12use crate::core::types::{
14 ChatMessage, ChatRequest, ChatResponse, RequestContext, Tool, ToolChoice,
15};
16
17use crate::core::providers::{Provider, ProviderRegistry, ProviderType};
19use crate::utils::error::{GatewayError, Result};
20use tracing::debug;
21
22pub async fn completion(
25 model: &str,
26 messages: Vec<Message>,
27 options: Option<CompletionOptions>,
28) -> Result<CompletionResponse> {
29 let router = get_global_router().await;
30 router
31 .complete(model, messages, options.unwrap_or_default())
32 .await
33}
34
35pub async fn acompletion(
37 model: &str,
38 messages: Vec<Message>,
39 options: Option<CompletionOptions>,
40) -> Result<CompletionResponse> {
41 completion(model, messages, options).await
42}
43
44pub async fn completion_stream(
46 _model: &str,
47 _messages: Vec<Message>,
48 _options: Option<CompletionOptions>,
49) -> Result<CompletionStream> {
50 todo!("Streaming completion not yet implemented")
52}
53
54pub type Message = ChatMessage;
56
57pub use crate::core::types::{MessageContent, MessageRole};
59
60pub use crate::core::types::ContentPart;
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ToolCall {
66 pub id: String,
67 pub r#type: String,
68 pub function: FunctionCall,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct FunctionCall {
74 pub name: String,
75 pub arguments: String,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, Default)]
80pub struct CompletionOptions {
81 #[serde(skip_serializing_if = "Option::is_none")]
82 pub temperature: Option<f32>,
83 #[serde(skip_serializing_if = "Option::is_none")]
84 pub max_tokens: Option<u32>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 pub top_p: Option<f32>,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 pub frequency_penalty: Option<f32>,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 pub presence_penalty: Option<f32>,
91 #[serde(skip_serializing_if = "Option::is_none")]
92 pub stop: Option<Vec<String>>,
93 #[serde(default)]
94 pub stream: bool,
95 #[serde(skip_serializing_if = "Option::is_none")]
96 pub tools: Option<Vec<Tool>>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub tool_choice: Option<ToolChoice>,
99 #[serde(skip_serializing_if = "Option::is_none")]
100 pub user: Option<String>,
101 #[serde(skip_serializing_if = "Option::is_none")]
102 pub seed: Option<i32>,
103 #[serde(skip_serializing_if = "Option::is_none")]
104 pub n: Option<u32>,
105 #[serde(skip_serializing_if = "Option::is_none")]
106 pub logprobs: Option<bool>,
107 #[serde(skip_serializing_if = "Option::is_none")]
108 pub top_logprobs: Option<u32>,
109
110 #[serde(skip_serializing_if = "Option::is_none")]
113 pub api_base: Option<String>,
114
115 #[serde(skip_serializing_if = "Option::is_none")]
117 pub api_key: Option<String>,
118
119 #[serde(skip_serializing_if = "Option::is_none")]
121 pub organization: Option<String>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub api_version: Option<String>,
126
127 #[serde(skip_serializing_if = "Option::is_none")]
129 pub headers: Option<HashMap<String, String>>,
130
131 #[serde(skip_serializing_if = "Option::is_none")]
133 pub timeout: Option<u64>,
134
135 #[serde(flatten)]
136 pub extra_params: HashMap<String, serde_json::Value>,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct CompletionResponse {
142 pub id: String,
143 pub object: String,
144 pub created: i64,
145 pub model: String,
146 pub choices: Vec<Choice>,
147 #[serde(skip_serializing_if = "Option::is_none")]
148 pub usage: Option<Usage>,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct Choice {
154 pub index: u32,
155 pub message: Message,
156 #[serde(skip_serializing_if = "Option::is_none")]
157 pub finish_reason: Option<FinishReason>,
158}
159
160pub type Usage = crate::core::types::responses::Usage;
162
163pub type FinishReason = crate::core::types::responses::FinishReason;
165
166pub type CompletionStream =
168 Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin + 'static>;
169
170pub type LiteLLMError = GatewayError;
172
173#[async_trait]
175pub trait Router: Send + Sync {
176 async fn complete(
177 &self,
178 model: &str,
179 messages: Vec<Message>,
180 options: CompletionOptions,
181 ) -> Result<CompletionResponse>;
182
183 async fn complete_stream(
184 &self,
185 model: &str,
186 messages: Vec<Message>,
187 options: CompletionOptions,
188 ) -> Result<CompletionStream>;
189}
190
191pub struct DefaultRouter {
193 provider_registry: Arc<ProviderRegistry>,
194}
195
196impl DefaultRouter {
197 fn select_provider_by_name<'a>(
199 providers: &'a [&'a crate::core::providers::Provider],
200 provider_name: &str,
201 original_model: &str,
202 prefix: &str,
203 chat_request: &ChatRequest,
204 ) -> Option<(&'a crate::core::providers::Provider, ChatRequest)> {
205 if !original_model.starts_with(prefix) {
206 return None;
207 }
208
209 let actual_model = original_model.strip_prefix(prefix).unwrap_or(original_model);
210
211 debug!(
212 provider = provider_name,
213 model = %actual_model,
214 "Using static {} provider", provider_name
215 );
216
217 for provider in providers.iter() {
218 if provider.name() == provider_name {
219 let mut updated_request = chat_request.clone();
220 updated_request.model = actual_model.to_string();
221 return Some((provider, updated_request));
222 }
223 }
224
225 None
226 }
227
228 pub async fn new() -> Result<Self> {
229 let mut provider_registry = ProviderRegistry::new();
230
231 if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
233 use crate::core::providers::base::BaseConfig;
234 use crate::core::providers::openai::OpenAIProvider;
235 use crate::core::providers::openai::config::OpenAIConfig;
236
237 let config = OpenAIConfig {
239 base: BaseConfig {
240 api_key: Some(api_key),
241 api_base: Some("https://api.openai.com/v1".to_string()),
242 timeout: 60,
243 max_retries: 3,
244 headers: Default::default(),
245 organization: std::env::var("OPENAI_ORGANIZATION").ok(),
246 api_version: None,
247 },
248 organization: std::env::var("OPENAI_ORGANIZATION").ok(),
249 project: None,
250 model_mappings: Default::default(),
251 features: Default::default(),
252 };
253
254 if let Ok(openai_provider) = OpenAIProvider::new(config).await {
256 provider_registry.register(Provider::OpenAI(openai_provider));
257 }
258 }
259
260 if let Ok(api_key) = std::env::var("OPENROUTER_API_KEY") {
262 use crate::core::providers::openrouter::{OpenRouterConfig, OpenRouterProvider};
263
264 let api_key = api_key.trim().to_string();
266
267 let config = OpenRouterConfig {
269 api_key,
270 base_url: "https://openrouter.ai/api/v1".to_string(),
271 site_url: std::env::var("OPENROUTER_HTTP_REFERER").ok(),
272 site_name: std::env::var("OPENROUTER_X_TITLE").ok(),
273 timeout_seconds: 60,
274 max_retries: 3,
275 extra_params: Default::default(),
276 };
277
278 if let Ok(openrouter_provider) = OpenRouterProvider::new(config) {
280 provider_registry.register(Provider::OpenRouter(openrouter_provider));
281 }
282 }
283
284 if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
286 use crate::core::providers::anthropic::{AnthropicProvider, AnthropicConfig};
287
288 let config = AnthropicConfig::new(api_key)
289 .with_base_url("https://api.anthropic.com")
290 .with_experimental(false);
291
292 let anthropic_provider = AnthropicProvider::new(config)?;
293 provider_registry.register(Provider::Anthropic(anthropic_provider));
294 }
295
296 if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
301 use crate::core::providers::vertex_ai::{
302 VertexAIProvider, VertexAIProviderConfig, VertexCredentials,
303 };
304
305 let config = VertexAIProviderConfig {
306 project_id: std::env::var("GOOGLE_PROJECT_ID")
307 .unwrap_or_else(|_| "default-project".to_string()),
308 location: std::env::var("GOOGLE_LOCATION")
309 .unwrap_or_else(|_| "us-central1".to_string()),
310 api_version: "v1".to_string(),
311 credentials: VertexCredentials::ApplicationDefault,
312 api_base: None,
313 timeout_seconds: 60,
314 max_retries: 3,
315 enable_experimental: false,
316 };
317
318 if let Ok(vertex_provider) = VertexAIProvider::new(config).await {
319 provider_registry.register(Provider::VertexAI(vertex_provider));
320 }
321 }
322
323 if let Ok(_api_key) = std::env::var("DEEPSEEK_API_KEY") {
325 use crate::core::providers::deepseek::{DeepSeekConfig, DeepSeekProvider};
326
327 let config = DeepSeekConfig::from_env();
328
329 if let Ok(deepseek_provider) = DeepSeekProvider::new(config) {
330 provider_registry.register(Provider::DeepSeek(deepseek_provider));
331 }
332 }
333
334 Ok(Self {
335 provider_registry: Arc::new(provider_registry),
336 })
337 }
338
339 async fn try_dynamic_provider_creation(
342 &self,
343 chat_request: &ChatRequest,
344 context: RequestContext,
345 options: &CompletionOptions,
346 ) -> Result<Option<CompletionResponse>> {
347 let model = &chat_request.model;
348
349 let api_key = match &options.api_key {
351 Some(key) => key.clone(),
352 None => return Ok(None), };
354
355 let (provider_type, actual_model, api_base) = if model.starts_with("openrouter/") {
357 let actual_model = model.strip_prefix("openrouter/").unwrap_or(model);
358 let api_base = options
359 .api_base
360 .clone()
361 .unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
362 ("openrouter", actual_model, api_base)
363 } else if model.starts_with("anthropic/") {
364 let actual_model = model.strip_prefix("anthropic/").unwrap_or(model);
365 let api_base = options
366 .api_base
367 .clone()
368 .unwrap_or_else(|| "https://api.anthropic.com".to_string());
369 ("anthropic", actual_model, api_base)
370 } else if model.starts_with("deepseek/") {
371 let actual_model = model.strip_prefix("deepseek/").unwrap_or(model);
372 let api_base = options
373 .api_base
374 .clone()
375 .unwrap_or_else(|| "https://api.deepseek.com".to_string());
376 ("deepseek", actual_model, api_base)
377 } else if model.starts_with("azure_ai/") || model.starts_with("azure-ai/") {
378 let actual_model = model.strip_prefix("azure_ai/")
379 .or_else(|| model.strip_prefix("azure-ai/"))
380 .unwrap_or(model);
381 let api_base = options
382 .api_base
383 .clone()
384 .or_else(|| std::env::var("AZURE_AI_API_BASE").ok())
385 .unwrap_or_else(|| "https://api.azure.com".to_string());
386 ("azure_ai", actual_model, api_base)
387 } else if model.starts_with("openai/") {
388 let actual_model = model.strip_prefix("openai/").unwrap_or(model);
389 let api_base = options
390 .api_base
391 .clone()
392 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
393 ("openai", actual_model, api_base)
394 } else {
395 if let Some(api_base) = &options.api_base {
397 ("openai-compatible", model.as_str(), api_base.clone())
398 } else {
399 return Ok(None); }
401 };
402
403 debug!(
404 provider_type = %provider_type,
405 model = %actual_model,
406 "Creating dynamic provider for model"
407 );
408
409 let response = match provider_type {
411 "openrouter" => {
412 self.create_dynamic_openrouter(
413 actual_model,
414 &api_key,
415 &api_base,
416 chat_request,
417 context,
418 )
419 .await?
420 }
421 "anthropic" => {
422 self.create_dynamic_anthropic(
423 actual_model,
424 &api_key,
425 &api_base,
426 chat_request,
427 context,
428 )
429 .await?
430 }
431 "deepseek" => {
432 self.create_dynamic_openai_compatible(
433 actual_model,
434 &api_key,
435 &api_base,
436 chat_request,
437 context,
438 "DeepSeek",
439 )
440 .await?
441 }
442 "azure_ai" => {
443 self.create_dynamic_azure_ai(
444 actual_model,
445 &api_key,
446 &api_base,
447 chat_request,
448 context,
449 )
450 .await?
451 }
452 "openai" => {
453 self.create_dynamic_openai_compatible(
454 actual_model,
455 &api_key,
456 &api_base,
457 chat_request,
458 context,
459 "OpenAI",
460 )
461 .await?
462 }
463 "openai-compatible" => {
464 self.create_dynamic_openai_compatible(
465 actual_model,
466 &api_key,
467 &api_base,
468 chat_request,
469 context,
470 "OpenAI-Compatible",
471 )
472 .await?
473 }
474 _ => return Ok(None),
475 };
476
477 Ok(Some(response))
478 }
479
480 async fn create_dynamic_openrouter(
482 &self,
483 model: &str,
484 api_key: &str,
485 api_base: &str,
486 chat_request: &ChatRequest,
487 context: RequestContext,
488 ) -> Result<CompletionResponse> {
489 use crate::core::providers::openrouter::{OpenRouterConfig, OpenRouterProvider};
490 use crate::core::traits::LLMProvider;
491
492 let config = OpenRouterConfig {
493 api_key: api_key.to_string(),
494 base_url: api_base.to_string(),
495 site_url: None, site_name: None,
497 timeout_seconds: 60,
498 max_retries: 3,
499 extra_params: Default::default(),
500 };
501
502 let provider = OpenRouterProvider::new(config).map_err(|e| {
503 GatewayError::internal(format!(
504 "Failed to create dynamic OpenRouter provider: {}",
505 e
506 ))
507 })?;
508
509 let mut updated_request = chat_request.clone();
510 updated_request.model = model.to_string();
511
512 let response = provider
513 .chat_completion(updated_request, context)
514 .await
515 .map_err(|e| {
516 GatewayError::internal(format!("Dynamic OpenRouter provider error: {}", e))
517 })?;
518
519 convert_from_chat_completion_response(response)
520 }
521
522 async fn create_dynamic_anthropic(
524 &self,
525 model: &str,
526 api_key: &str,
527 api_base: &str,
528 chat_request: &ChatRequest,
529 context: RequestContext,
530 ) -> Result<CompletionResponse> {
531 use crate::core::providers::anthropic::{AnthropicProvider, AnthropicConfig};
532 use crate::core::traits::LLMProvider;
533
534 let config = AnthropicConfig::new(api_key)
535 .with_base_url(api_base)
536 .with_experimental(false);
537
538 let provider = AnthropicProvider::new(config)?;
539
540 let mut updated_request = chat_request.clone();
541 updated_request.model = model.to_string();
542
543 let response = LLMProvider::chat_completion(&provider, updated_request, context)
544 .await
545 .map_err(|e| {
546 GatewayError::internal(format!("Dynamic Anthropic provider error: {}", e))
547 })?;
548
549 convert_from_chat_completion_response(response)
550 }
551
552 async fn create_dynamic_openai_compatible(
554 &self,
555 model: &str,
556 api_key: &str,
557 api_base: &str,
558 chat_request: &ChatRequest,
559 context: RequestContext,
560 provider_name: &str,
561 ) -> Result<CompletionResponse> {
562 use crate::core::providers::base::BaseConfig;
563 use crate::core::providers::openai::OpenAIProvider;
564 use crate::core::providers::openai::config::OpenAIConfig;
565 use crate::core::traits::LLMProvider;
566
567 let config = OpenAIConfig {
568 base: BaseConfig {
569 api_key: Some(api_key.to_string()),
570 api_base: Some(api_base.to_string()),
571 timeout: 60,
572 max_retries: 3,
573 headers: Default::default(),
574 organization: None,
575 api_version: None,
576 },
577 organization: None,
578 project: None,
579 model_mappings: Default::default(),
580 features: Default::default(),
581 };
582
583 let provider = OpenAIProvider::new(config).await.map_err(|e| {
584 GatewayError::internal(format!(
585 "Failed to create dynamic {} provider: {}",
586 provider_name, e
587 ))
588 })?;
589
590 let mut updated_request = chat_request.clone();
591 updated_request.model = model.to_string();
592
593 let response = provider
594 .chat_completion(updated_request, context)
595 .await
596 .map_err(|e| {
597 GatewayError::internal(format!("Dynamic {} provider error: {}", provider_name, e))
598 })?;
599
600 convert_from_chat_completion_response(response)
601 }
602
603 async fn create_dynamic_azure_ai(
605 &self,
606 model: &str,
607 api_key: &str,
608 api_base: &str,
609 chat_request: &ChatRequest,
610 context: RequestContext,
611 ) -> Result<CompletionResponse> {
612 use crate::core::providers::azure_ai::{AzureAIConfig, AzureAIProvider};
613 use crate::core::traits::LLMProvider;
614
615 let mut config = AzureAIConfig::new("azure_ai");
616 config.base.api_key = Some(api_key.to_string());
617 config.base.api_base = Some(api_base.to_string());
618
619 if config.base.api_key.is_none() {
621 if let Ok(key) = std::env::var("AZURE_AI_API_KEY") {
622 config.base.api_key = Some(key);
623 }
624 }
625 if config.base.api_base.is_none() {
626 if let Ok(base) = std::env::var("AZURE_AI_API_BASE") {
627 config.base.api_base = Some(base);
628 }
629 }
630
631 let provider = AzureAIProvider::new(config).map_err(|e| {
632 GatewayError::internal(format!(
633 "Failed to create dynamic Azure AI provider: {}",
634 e
635 ))
636 })?;
637
638 let mut updated_request = chat_request.clone();
639 updated_request.model = model.to_string();
640
641 let response = provider
642 .chat_completion(updated_request, context)
643 .await
644 .map_err(|e| {
645 GatewayError::internal(format!("Dynamic Azure AI provider error: {}", e))
646 })?;
647
648 convert_from_chat_completion_response(response)
649 }
650}
651
652#[async_trait]
653impl Router for DefaultRouter {
654 async fn complete(
655 &self,
656 model: &str,
657 messages: Vec<Message>,
658 options: CompletionOptions,
659 ) -> Result<CompletionResponse> {
660 let chat_messages = convert_messages_to_chat_messages(messages);
662 let chat_request =
663 convert_to_chat_completion_request(model, chat_messages, options.clone())?;
664
665 let mut context = RequestContext::new();
667
668 if let Some(api_base) = &options.api_base {
670 context.metadata.insert(
671 "api_base_override".to_string(),
672 serde_json::Value::String(api_base.clone()),
673 );
674 }
675
676 if let Some(api_key) = &options.api_key {
677 context.metadata.insert(
678 "api_key_override".to_string(),
679 serde_json::Value::String(api_key.clone()),
680 );
681 }
682
683 if let Some(organization) = &options.organization {
684 context.metadata.insert(
685 "organization_override".to_string(),
686 serde_json::Value::String(organization.clone()),
687 );
688 }
689
690 if let Some(api_version) = &options.api_version {
691 context.metadata.insert(
692 "api_version_override".to_string(),
693 serde_json::Value::String(api_version.clone()),
694 );
695 }
696
697 if let Some(headers) = &options.headers {
698 context.metadata.insert(
699 "headers_override".to_string(),
700 serde_json::to_value(headers).unwrap_or_default(),
701 );
702 }
703
704 if let Some(timeout) = options.timeout {
705 context.metadata.insert(
706 "timeout_override".to_string(),
707 serde_json::Value::Number(serde_json::Number::from(timeout)),
708 );
709 }
710
711 if let Some(api_base) = &options.api_base {
713 use crate::core::providers::base::BaseConfig;
716 use crate::core::providers::openai::OpenAIProvider;
717 use crate::core::providers::openai::config::OpenAIConfig;
718 use crate::core::traits::LLMProvider;
719
720 let api_key = options
721 .api_key
722 .clone()
723 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
724 .unwrap_or_else(|| "dummy-key-for-local".to_string());
725
726 let config = OpenAIConfig {
727 base: BaseConfig {
728 api_key: Some(api_key),
729 api_base: Some(api_base.clone()),
730 timeout: options.timeout.unwrap_or(60),
731 max_retries: 3,
732 headers: options.headers.clone().unwrap_or_default(),
733 organization: options.organization.clone(),
734 api_version: None,
735 },
736 organization: options.organization.clone(),
737 project: None,
738 model_mappings: Default::default(),
739 features: Default::default(),
740 };
741
742 match OpenAIProvider::new(config).await {
744 Ok(temp_provider) => {
745 let response = temp_provider
747 .chat_completion(chat_request, context)
748 .await
749 .map_err(|e| GatewayError::internal(format!("Provider error: {}", e)))?;
750 return convert_from_chat_completion_response(response);
751 }
752 Err(e) => {
753 return Err(GatewayError::internal(format!(
754 "Failed to create provider with custom api_base: {}",
755 e
756 )));
757 }
758 }
759 }
760
761 if let Some(response) = self
764 .try_dynamic_provider_creation(&chat_request, context.clone(), &options)
765 .await?
766 {
767 return Ok(response);
768 }
769
770 let providers = self.provider_registry.all();
772
773 let mut selected_provider = Self::select_provider_by_name(&providers, "openrouter", model, "openrouter/", &chat_request)
775 .or_else(|| Self::select_provider_by_name(&providers, "deepseek", model, "deepseek/", &chat_request))
776 .or_else(|| Self::select_provider_by_name(&providers, "anthropic", model, "anthropic/", &chat_request))
777 .or_else(|| Self::select_provider_by_name(&providers, "azure_ai", model, "azure_ai/", &chat_request));
778
779 if selected_provider.is_none() {
781 if model.starts_with("openai/") || model.starts_with("azure/") {
782 for provider in providers.iter() {
783 if provider.provider_type() == ProviderType::OpenAI
784 && provider.supports_model(model)
785 {
786 selected_provider = Some((provider, chat_request.clone()));
787 break;
788 }
789 }
790 } else {
791 for provider in providers.iter() {
793 if provider.supports_model(model) {
794 selected_provider = Some((provider, chat_request.clone()));
795 break;
796 }
797 }
798 }
799 }
800
801 if let Some((provider, request)) = selected_provider {
803 let response = provider.chat_completion(request, context).await?;
804 return convert_from_chat_completion_response(response);
805 }
806
807 Err(GatewayError::internal(
808 "No suitable provider found for model",
809 ))
810 }
811
812 async fn complete_stream(
813 &self,
814 _model: &str,
815 _messages: Vec<Message>,
816 _options: CompletionOptions,
817 ) -> Result<CompletionStream> {
818 todo!("Streaming not yet implemented")
820 }
821}
822
823pub struct ErrorRouter {
825 error: String,
826}
827
828#[async_trait]
829impl Router for ErrorRouter {
830 async fn complete(
831 &self,
832 _model: &str,
833 _messages: Vec<Message>,
834 _options: CompletionOptions,
835 ) -> Result<CompletionResponse> {
836 Err(GatewayError::internal(format!(
837 "Router initialization failed: {}",
838 self.error
839 )))
840 }
841
842 async fn complete_stream(
843 &self,
844 _model: &str,
845 _messages: Vec<Message>,
846 _options: CompletionOptions,
847 ) -> Result<CompletionStream> {
848 Err(GatewayError::internal(format!(
849 "Router initialization failed: {}",
850 self.error
851 )))
852 }
853}
854
855static GLOBAL_ROUTER: OnceCell<Box<dyn Router>> = OnceCell::const_new();
857
858async fn get_global_router() -> &'static Box<dyn Router> {
860 GLOBAL_ROUTER
861 .get_or_init(|| async {
862 match DefaultRouter::new().await {
863 Ok(router) => Box::new(router) as Box<dyn Router>,
864 Err(e) => Box::new(ErrorRouter {
865 error: e.to_string(),
866 }) as Box<dyn Router>,
867 }
868 })
869 .await
870}
871
872pub fn user_message(content: impl Into<String>) -> Message {
874 use crate::core::types::{MessageContent, MessageRole};
875 ChatMessage {
876 role: MessageRole::User,
877 content: Some(MessageContent::Text(content.into())),
878 name: None,
879 tool_calls: None,
880 tool_call_id: None,
881 function_call: None,
882 }
883}
884
885pub fn system_message(content: impl Into<String>) -> Message {
887 use crate::core::types::{MessageContent, MessageRole};
888 ChatMessage {
889 role: MessageRole::System,
890 content: Some(MessageContent::Text(content.into())),
891 name: None,
892 tool_calls: None,
893 tool_call_id: None,
894 function_call: None,
895 }
896}
897
898pub fn assistant_message(content: impl Into<String>) -> Message {
900 use crate::core::types::{MessageContent, MessageRole};
901 ChatMessage {
902 role: MessageRole::Assistant,
903 content: Some(MessageContent::Text(content.into())),
904 name: None,
905 tool_calls: None,
906 tool_call_id: None,
907 function_call: None,
908 }
909}
910
911fn convert_messages_to_chat_messages(messages: Vec<Message>) -> Vec<ChatMessage> {
914 messages
916}
917
918fn convert_to_chat_completion_request(
919 model: &str,
920 messages: Vec<ChatMessage>,
921 options: CompletionOptions,
922) -> Result<ChatRequest> {
923 Ok(ChatRequest {
924 model: model.to_string(),
925 messages,
926 temperature: options.temperature,
927 max_tokens: options.max_tokens,
928 max_completion_tokens: None,
929 top_p: options.top_p,
930 frequency_penalty: options.frequency_penalty,
931 presence_penalty: options.presence_penalty,
932 stop: options.stop,
933 stream: options.stream,
934 tools: None, tool_choice: None, parallel_tool_calls: None,
937 response_format: None,
938 user: options.user,
939 seed: options.seed,
940 n: options.n,
941 logit_bias: None,
942 functions: None,
943 function_call: None,
944 logprobs: options.logprobs,
945 top_logprobs: options.top_logprobs,
946 extra_params: options.extra_params,
947 })
948}
949
950fn convert_from_chat_completion_response(response: ChatResponse) -> Result<CompletionResponse> {
951 let choices = response
952 .choices
953 .into_iter()
954 .map(|choice| Choice {
955 index: choice.index,
956 message: choice.message, finish_reason: choice.finish_reason, })
959 .collect();
960
961 Ok(CompletionResponse {
962 id: response.id,
963 object: response.object,
964 created: response.created,
965 model: response.model,
966 choices,
967 usage: response.usage, })
969}
970
#[cfg(test)]
mod tests {
    use super::*;

    /// `user_message` produces a user-role message with the given text.
    #[test]
    fn test_message_creation() {
        let msg = user_message("Hello, world!");
        assert_eq!(msg.role, MessageRole::User);
        match msg.content {
            Some(MessageContent::Text(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text content"),
        }
    }

    /// Default options are non-streaming and carry no extra parameters.
    #[test]
    fn test_completion_options_default() {
        let defaults = CompletionOptions::default();
        assert!(!defaults.stream);
        assert!(defaults.extra_params.is_empty());
    }
}