use crate::api::{ChatApi, ChatCompletionChunk};
use crate::config::{ConnectionOptions, ResilienceConfig};
use crate::metrics::{Metrics, NoopMetrics};
use crate::provider::{
    classification::ProviderClassification, CohereAdapter, GeminiAdapter, GenericAdapter,
    MistralAdapter, OpenAiAdapter, ProviderConfigs,
};
use crate::rate_limiter::{BackpressureController, BackpressurePermit};
use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use futures::stream::Stream;
use futures::Future;
use std::sync::Arc;
use tokio::sync::oneshot;

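/// Model selection options for the `quick_*` convenience helpers.
///
/// Note: in this module only `chat_model` is currently consulted (see
/// [`AiClient::quick_chat_text_with_options`]); the remaining fields are for
/// callers that implement their own fallback logic.
///
/// Illustrative usage (a sketch; the model names are examples):
///
/// ```ignore
/// let opts = ModelOptions::new()
///     .with_chat_model("llama-3.1-8b-instant")
///     .with_fallback_models(vec!["mixtral-8x7b-32768"])
///     .with_auto_discovery(false);
/// ```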
#[derive(Debug, Clone)]
pub struct ModelOptions {
    pub chat_model: Option<String>,
    pub multimodal_model: Option<String>,
    pub fallback_models: Vec<String>,
    pub auto_discovery: bool,
}

impl Default for ModelOptions {
    fn default() -> Self {
        Self {
            chat_model: None,
            multimodal_model: None,
            fallback_models: Vec::new(),
            auto_discovery: true,
        }
    }
}

impl ModelOptions {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_chat_model(mut self, model: &str) -> Self {
        self.chat_model = Some(model.to_string());
        self
    }

    pub fn with_multimodal_model(mut self, model: &str) -> Self {
        self.multimodal_model = Some(model.to_string());
        self
    }

    pub fn with_fallback_models(mut self, models: Vec<&str>) -> Self {
        self.fallback_models = models.into_iter().map(|s| s.to_string()).collect();
        self
    }

    pub fn with_auto_discovery(mut self, enabled: bool) -> Self {
        self.auto_discovery = enabled;
        self
    }
}

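/// Build a `GenericAdapter`, preferring the caller-supplied transport when
/// one is provided.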
fn create_generic_adapter(
    config: crate::provider::config::ProviderConfig,
    transport: Option<crate::transport::DynHttpTransportRef>,
) -> Result<Box<dyn ChatApi>, AiLibError> {
    if let Some(custom_transport) = transport {
        Ok(Box::new(GenericAdapter::with_transport_ref(
            config,
            custom_transport,
        )?))
    } else {
        Ok(Box::new(GenericAdapter::new(config)?))
    }
}

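/// Supported AI providers.
///
/// Config-driven providers (Groq, DeepSeek, Moonshot, ...) share the
/// OpenAI-compatible [`GenericAdapter`]; OpenAI, Gemini, Mistral, and Cohere
/// use dedicated adapters. The `ProviderClassification` trait
/// (`is_config_driven` / `is_independent`) decides which path is taken.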
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Provider {
    Groq,
    XaiGrok,
    Ollama,
    DeepSeek,
    Anthropic,
    AzureOpenAI,
    HuggingFace,
    TogetherAI,
    BaiduWenxin,
    TencentHunyuan,
    IflytekSpark,
    Moonshot,
    OpenAI,
    Qwen,
    Gemini,
    Mistral,
    Cohere,
}

impl Provider {
    pub fn default_chat_model(&self) -> &'static str {
        match self {
            Provider::Groq => "llama-3.1-8b-instant",
            Provider::XaiGrok => "grok-beta",
            Provider::Ollama => "llama3-8b",
            Provider::DeepSeek => "deepseek-chat",
            Provider::Anthropic => "claude-3-5-sonnet-20241022",
            Provider::AzureOpenAI => "gpt-35-turbo",
            Provider::HuggingFace => "microsoft/DialoGPT-medium",
            Provider::TogetherAI => "meta-llama/Llama-3-8b-chat-hf",
            Provider::BaiduWenxin => "ernie-3.5",
            Provider::TencentHunyuan => "hunyuan-standard",
            Provider::IflytekSpark => "spark-v3.0",
            Provider::Moonshot => "moonshot-v1-8k",
            Provider::OpenAI => "gpt-3.5-turbo",
            Provider::Qwen => "qwen-turbo",
            Provider::Gemini => "gemini-1.5-flash",
            Provider::Mistral => "mistral-small",
            Provider::Cohere => "command-r",
        }
    }

    pub fn default_multimodal_model(&self) -> Option<&'static str> {
        match self {
            Provider::OpenAI => Some("gpt-4o"),
            Provider::AzureOpenAI => Some("gpt-4o"),
            Provider::Anthropic => Some("claude-3-5-sonnet-20241022"),
            Provider::Groq => None,
            Provider::Gemini => Some("gemini-1.5-flash"),
            Provider::Cohere => Some("command-r-plus"),
            _ => None,
        }
    }
}

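/// Unified client over multiple AI providers.
///
/// Quickstart sketch (assumes provider credentials are available in the
/// environment and a `tokio` runtime):
///
/// ```ignore
/// let client = AiClient::new(Provider::Groq)?;
/// let req = client.build_simple_request("Hello, world!");
/// let resp = client.chat_completion(req).await?;
/// println!("{}", resp.first_text()?);
/// ```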
pub struct AiClient {
    provider: Provider,
    adapter: Box<dyn ChatApi>,
    metrics: Arc<dyn Metrics>,
    connection_options: Option<ConnectionOptions>,
    #[cfg(feature = "interceptors")]
    interceptor_pipeline: Option<crate::interceptors::InterceptorPipeline>,
    custom_default_chat_model: Option<String>,
    custom_default_multimodal_model: Option<String>,
    backpressure: Option<Arc<BackpressureController>>,
    #[cfg(feature = "routing_mvp")]
    routing_array: Option<crate::provider::models::ModelArray>,
}

impl AiClient {
    pub fn default_chat_model(&self) -> String {
        self.custom_default_chat_model
            .clone()
            .unwrap_or_else(|| self.provider.default_chat_model().to_string())
    }
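
    /// Create a client for `provider`, equivalent to
    /// `AiClientBuilder::new(provider).build()` with no recorded connection
    /// options. Credentials are resolved from the environment (the exact
    /// variable names depend on the provider configuration).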
    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
        let mut c = AiClientBuilder::new(provider).build()?;
        c.connection_options = None;
        Ok(c)
    }

    #[cfg(feature = "routing_mvp")]
    pub fn with_routing_array(mut self, array: crate::provider::models::ModelArray) -> Self {
        self.routing_array = Some(array);
        self
    }

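    /// Create a client with explicit connection overrides. Explicit values in
    /// `opts` take precedence over environment variables and provider
    /// defaults.
    ///
    /// Sketch (assumes `ConnectionOptions` implements `Default`; the field
    /// names match the struct as used below):
    ///
    /// ```ignore
    /// let opts = ConnectionOptions {
    ///     api_key: Some("sk-...".into()),
    ///     base_url: Some("https://api.groq.com/openai/v1".into()),
    ///     ..Default::default()
    /// };
    /// let client = AiClient::with_options(Provider::Groq, opts)?;
    /// ```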
    pub fn with_options(provider: Provider, opts: ConnectionOptions) -> Result<Self, AiLibError> {
        // Map config-driven providers to their default configs; independent
        // providers (OpenAI, Gemini, Mistral, Cohere) return None. This
        // replaces two identical 13-arm matches below.
        fn generic_config(provider: Provider) -> Option<crate::provider::config::ProviderConfig> {
            match provider {
                Provider::Groq => Some(ProviderConfigs::groq()),
                Provider::XaiGrok => Some(ProviderConfigs::xai_grok()),
                Provider::Ollama => Some(ProviderConfigs::ollama()),
                Provider::DeepSeek => Some(ProviderConfigs::deepseek()),
                Provider::Qwen => Some(ProviderConfigs::qwen()),
                Provider::BaiduWenxin => Some(ProviderConfigs::baidu_wenxin()),
                Provider::TencentHunyuan => Some(ProviderConfigs::tencent_hunyuan()),
                Provider::IflytekSpark => Some(ProviderConfigs::iflytek_spark()),
                Provider::Moonshot => Some(ProviderConfigs::moonshot()),
                Provider::Anthropic => Some(ProviderConfigs::anthropic()),
                Provider::AzureOpenAI => Some(ProviderConfigs::azure_openai()),
                Provider::HuggingFace => Some(ProviderConfigs::huggingface()),
                Provider::TogetherAI => Some(ProviderConfigs::together_ai()),
                _ => None,
            }
        }

        let config_driven = provider.is_config_driven();
        let need_builder = config_driven
            && (opts.base_url.is_some()
                || opts.proxy.is_some()
                || opts.timeout.is_some()
                || opts.disable_proxy);
        if need_builder {
            let mut b = AiClient::builder(provider);
            if let Some(ref base) = opts.base_url {
                b = b.with_base_url(base);
            }
            if opts.disable_proxy {
                b = b.without_proxy();
            } else if let Some(ref proxy) = opts.proxy {
                if proxy.is_empty() {
                    b = b.without_proxy();
                } else {
                    b = b.with_proxy(Some(proxy));
                }
            }
            if let Some(t) = opts.timeout {
                b = b.with_timeout(t);
            }
            let mut client = b.build()?;
            if opts.api_key.is_some() {
                if let Some(cfg) = generic_config(provider) {
                    client.adapter = Box::new(GenericAdapter::new_with_api_key(
                        cfg,
                        opts.api_key.clone(),
                    )?);
                }
            }
            client.connection_options = Some(opts);
            return Ok(client);
        }

        if provider.is_independent() {
            let adapter: Box<dyn ChatApi> = match provider {
                Provider::OpenAI => {
                    if let Some(ref k) = opts.api_key {
                        Box::new(OpenAiAdapter::new_with_overrides(
                            k.clone(),
                            opts.base_url.clone(),
                        )?)
                    } else {
                        Box::new(OpenAiAdapter::new()?)
                    }
                }
                Provider::Gemini => {
                    if let Some(ref k) = opts.api_key {
                        Box::new(GeminiAdapter::new_with_overrides(
                            k.clone(),
                            opts.base_url.clone(),
                        )?)
                    } else {
                        Box::new(GeminiAdapter::new()?)
                    }
                }
                Provider::Mistral => {
                    if opts.api_key.is_some() || opts.base_url.is_some() {
                        Box::new(MistralAdapter::new_with_overrides(
                            opts.api_key.clone(),
                            opts.base_url.clone(),
                        )?)
                    } else {
                        Box::new(MistralAdapter::new()?)
                    }
                }
                Provider::Cohere => {
                    if let Some(ref k) = opts.api_key {
                        Box::new(CohereAdapter::new_with_overrides(
                            k.clone(),
                            opts.base_url.clone(),
                        )?)
                    } else {
                        Box::new(CohereAdapter::new()?)
                    }
                }
                _ => unreachable!("is_independent() covers exactly these providers"),
            };
            return Ok(AiClient {
                provider,
                adapter,
                metrics: Arc::new(NoopMetrics::new()),
                connection_options: Some(opts),
                custom_default_chat_model: None,
                custom_default_multimodal_model: None,
                backpressure: None,
                #[cfg(feature = "routing_mvp")]
                routing_array: None,
                #[cfg(feature = "interceptors")]
                interceptor_pipeline: None,
            });
        }

        let mut client = AiClient::new(provider)?;
        if let Some(ref k) = opts.api_key {
            if let Some(cfg) = generic_config(provider) {
                client.adapter =
                    Box::new(GenericAdapter::new_with_api_key(cfg, Some(k.clone()))?);
            }
        }
        client.connection_options = Some(opts);
        Ok(client)
    }

    pub fn connection_options(&self) -> Option<&ConnectionOptions> {
        self.connection_options.as_ref()
    }

    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }

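    /// Create a client with default settings and a caller-supplied
    /// [`Metrics`] implementation.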
    pub fn new_with_metrics(
        provider: Provider,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        let adapter: Box<dyn ChatApi> = match provider {
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
        };

        Ok(Self {
            provider,
            adapter,
            metrics,
            connection_options: None,
            custom_default_chat_model: None,
            custom_default_multimodal_model: None,
            backpressure: None,
            #[cfg(feature = "routing_mvp")]
            routing_array: None,
            #[cfg(feature = "interceptors")]
            interceptor_pipeline: None,
        })
    }

    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = metrics;
        self
    }

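    /// Send a chat completion request. This applies backpressure when a
    /// concurrency limit is configured, `routing_mvp` model routing for the
    /// `"__route__"` sentinel model, and the interceptor pipeline when those
    /// features are enabled.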
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let _bp_permit: Option<BackpressurePermit> = if let Some(ctrl) = &self.backpressure {
            match ctrl.acquire_permit().await {
                Ok(p) => Some(p),
                Err(_) => {
                    return Err(AiLibError::RateLimitExceeded(
                        "Backpressure: no permits available".to_string(),
                    ))
                }
            }
        } else {
            None
        };
        #[cfg(feature = "routing_mvp")]
        {
            if request.model == "__route__" {
                let _ = self.metrics.incr_counter("routing_mvp.request", 1).await;
                let mut chosen = self.provider.default_chat_model().to_string();
                if let Some(arr) = &self.routing_array {
                    let mut arr_clone = arr.clone();
                    if let Some(ep) = arr_clone.select_endpoint() {
                        match crate::provider::utils::health_check(&ep.url).await {
                            Ok(()) => {
                                let _ = self.metrics.incr_counter("routing_mvp.selected", 1).await;
                                chosen = ep.model_name.clone();
                            }
                            Err(_) => {
                                let _ = self
                                    .metrics
                                    .incr_counter("routing_mvp.health_fail", 1)
                                    .await;
                                chosen = self.provider.default_chat_model().to_string();
                                let _ = self
                                    .metrics
                                    .incr_counter("routing_mvp.fallback_default", 1)
                                    .await;
                            }
                        }
                    } else {
                        let _ = self
                            .metrics
                            .incr_counter("routing_mvp.no_endpoint", 1)
                            .await;
                    }
                } else {
                    let _ = self
                        .metrics
                        .incr_counter("routing_mvp.missing_array", 1)
                        .await;
                }
                let mut req2 = request;
                req2.model = chosen;
                return self.adapter.chat_completion(req2).await;
            }
        }
        #[cfg(feature = "interceptors")]
        if let Some(p) = &self.interceptor_pipeline {
            let ctx = crate::interceptors::RequestContext {
                provider: format!("{:?}", self.provider).to_lowercase(),
                model: request.model.clone(),
            };
            return p
                .execute(&ctx, &request, || self.adapter.chat_completion(request.clone()))
                .await;
        }

        self.adapter.chat_completion(request).await
    }

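    /// Stream a chat completion as incremental chunks; `request.stream` is
    /// forced to `Some(true)`.
    ///
    /// Consumption sketch (assumes a `tokio` runtime; the chunk field layout
    /// shown is OpenAI-style and illustrative):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = client.chat_completion_stream(req).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let chunk = chunk?;
    ///     if let Some(text) = chunk.choices.first().and_then(|c| c.delta.content.clone()) {
    ///         print!("{text}");
    ///     }
    /// }
    /// ```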
    pub async fn chat_completion_stream(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        request.stream = Some(true);
        let bp_permit: Option<BackpressurePermit> = if let Some(ctrl) = &self.backpressure {
            match ctrl.acquire_permit().await {
                Ok(p) => Some(p),
                Err(_) => {
                    return Err(AiLibError::RateLimitExceeded(
                        "Backpressure: no permits available".to_string(),
                    ))
                }
            }
        } else {
            None
        };
        #[cfg(feature = "interceptors")]
        if let Some(p) = &self.interceptor_pipeline {
            let ctx = crate::interceptors::RequestContext {
                provider: format!("{:?}", self.provider).to_lowercase(),
                model: request.model.clone(),
            };
            for ic in &p.interceptors {
                ic.on_request(&ctx, &request).await;
            }
            let inner = self.adapter.chat_completion_stream(request).await?;
            let cs = ControlledStream::new_with_bp(inner, None, bp_permit);
            return Ok(Box::new(cs));
        }
        let inner = self.adapter.chat_completion_stream(request).await?;
        let cs = ControlledStream::new_with_bp(inner, None, bp_permit);
        Ok(Box::new(cs))
    }

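    /// Like [`Self::chat_completion_stream`], but also returns a
    /// [`CancelHandle`]; calling `cancel()` makes the stream yield a final
    /// `ProviderError("Stream cancelled")` item.
    ///
    /// ```ignore
    /// let (mut stream, cancel) = client.chat_completion_stream_with_cancel(req).await?;
    /// // ... later, from another task or after deciding to stop:
    /// cancel.cancel();
    /// ```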
    pub async fn chat_completion_stream_with_cancel(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<
        (
            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
            CancelHandle,
        ),
        AiLibError,
    > {
        request.stream = Some(true);
        let bp_permit: Option<BackpressurePermit> = if let Some(ctrl) = &self.backpressure {
            match ctrl.acquire_permit().await {
                Ok(p) => Some(p),
                Err(_) => {
                    return Err(AiLibError::RateLimitExceeded(
                        "Backpressure: no permits available".to_string(),
                    ))
                }
            }
        } else {
            None
        };
        let stream = self.adapter.chat_completion_stream(request).await?;
        let (cancel_tx, cancel_rx) = oneshot::channel();
        let cancel_handle = CancelHandle {
            sender: Some(cancel_tx),
        };

        let controlled_stream = ControlledStream::new_with_bp(stream, Some(cancel_rx), bp_permit);
        Ok((Box::new(controlled_stream), cancel_handle))
    }

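    /// Execute multiple chat requests, optionally bounding concurrency;
    /// `None` lets the adapter run them without an explicit limit.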
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        self.adapter
            .chat_completion_batch(requests, concurrency_limit)
            .await
    }

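    /// Batch variant with a simple heuristic: small batches (3 requests or
    /// fewer) run unbounded, larger ones are capped at 10 concurrent requests.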
    pub async fn chat_completion_batch_smart(
        &self,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        let concurrency_limit = if requests.len() <= 3 { None } else { Some(10) };
        self.chat_completion_batch(requests, concurrency_limit)
            .await
    }

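    /// List the model identifiers the current provider exposes.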
    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        self.adapter.list_models().await
    }

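    /// Switch this client to a different provider in place, rebuilding the
    /// underlying adapter with that provider's default configuration.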
    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
        let new_adapter: Box<dyn ChatApi> = match provider {
            Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
            Provider::XaiGrok => Box::new(GenericAdapter::new(ProviderConfigs::xai_grok())?),
            Provider::Ollama => Box::new(GenericAdapter::new(ProviderConfigs::ollama())?),
            Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
            Provider::Qwen => Box::new(GenericAdapter::new(ProviderConfigs::qwen())?),
            Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
            Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
            Provider::BaiduWenxin => {
                Box::new(GenericAdapter::new(ProviderConfigs::baidu_wenxin())?)
            }
            Provider::TencentHunyuan => {
                Box::new(GenericAdapter::new(ProviderConfigs::tencent_hunyuan())?)
            }
            Provider::IflytekSpark => {
                Box::new(GenericAdapter::new(ProviderConfigs::iflytek_spark())?)
            }
            Provider::Moonshot => Box::new(GenericAdapter::new(ProviderConfigs::moonshot())?),
            Provider::Gemini => Box::new(GeminiAdapter::new()?),
            Provider::AzureOpenAI => {
                Box::new(GenericAdapter::new(ProviderConfigs::azure_openai())?)
            }
            Provider::HuggingFace => Box::new(GenericAdapter::new(ProviderConfigs::huggingface())?),
            Provider::TogetherAI => Box::new(GenericAdapter::new(ProviderConfigs::together_ai())?),
            Provider::Mistral => Box::new(MistralAdapter::new()?),
            Provider::Cohere => Box::new(CohereAdapter::new()?),
        };

        self.provider = provider;
        self.adapter = new_adapter;
        Ok(())
    }

    pub fn current_provider(&self) -> Provider {
        self.provider
    }

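    /// Build a single-user-message request using the client's default chat
    /// model (custom default if set, otherwise the provider default).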
    pub fn build_simple_request<S: Into<String>>(&self, prompt: S) -> ChatCompletionRequest {
        let model = self
            .custom_default_chat_model
            .clone()
            .unwrap_or_else(|| self.provider.default_chat_model().to_string());

        ChatCompletionRequest::new(
            model,
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

    pub fn build_simple_request_with_model<S: Into<String>>(
        &self,
        prompt: S,
        model: S,
    ) -> ChatCompletionRequest {
        ChatCompletionRequest::new(
            model.into(),
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

    pub fn build_multimodal_request<S: Into<String>>(
        &self,
        prompt: S,
    ) -> Result<ChatCompletionRequest, AiLibError> {
        let model = self
            .custom_default_multimodal_model
            .clone()
            .or_else(|| {
                self.provider
                    .default_multimodal_model()
                    .map(|s| s.to_string())
            })
            .ok_or_else(|| {
                AiLibError::ConfigurationError(format!(
                    "No multimodal model available for provider {:?}",
                    self.provider
                ))
            })?;

        Ok(ChatCompletionRequest::new(
            model,
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        ))
    }

    pub fn build_multimodal_request_with_model<S: Into<String>>(
        &self,
        prompt: S,
        model: S,
    ) -> ChatCompletionRequest {
        ChatCompletionRequest::new(
            model.into(),
            vec![crate::types::Message {
                role: crate::types::Role::User,
                content: crate::types::common::Content::Text(prompt.into()),
                function_call: None,
            }],
        )
    }

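    /// One-shot convenience: create a client, send `prompt`, and return the
    /// first text choice. Sketch (assumes provider credentials in the
    /// environment):
    ///
    /// ```ignore
    /// let text = AiClient::quick_chat_text(Provider::Groq, "Hello!").await?;
    /// ```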
    pub async fn quick_chat_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_simple_request(prompt.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    pub async fn quick_chat_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_simple_request_with_model(prompt.into(), model.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    pub async fn quick_multimodal_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_multimodal_request(prompt.into())?;
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    pub async fn quick_multimodal_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;
        let req = client.build_multimodal_request_with_model(prompt.into(), model.into());
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    pub async fn quick_chat_text_with_options<P: Into<String>>(
        provider: Provider,
        prompt: P,
        options: ModelOptions,
    ) -> Result<String, AiLibError> {
        let client = AiClient::new(provider)?;

        let model = options
            .chat_model
            .unwrap_or_else(|| provider.default_chat_model().to_string());

        let req = client.build_simple_request_with_model(prompt.into(), model);
        let resp = client.chat_completion(req).await?;
        resp.first_text().map(|s| s.to_string())
    }

    pub async fn upload_file(&self, path: &str) -> Result<String, AiLibError> {
        let base_url = if let Some(opts) = &self.connection_options {
            if let Some(b) = &opts.base_url {
                b.clone()
            } else {
                self.provider_default_base_url()?
            }
        } else {
            self.provider_default_base_url()?
        };

        let endpoint: Option<String> = if self.provider.is_config_driven() {
            let cfg = self.provider.get_default_config()?;
            cfg.upload_endpoint.clone()
        } else {
            match self.provider {
                Provider::OpenAI => Some("/files".to_string()),
                _ => None,
            }
        };

        let endpoint = endpoint.ok_or_else(|| {
            AiLibError::UnsupportedFeature(format!(
                "Provider {:?} does not expose an upload endpoint in OSS",
                self.provider
            ))
        })?;

        let upload_url = format!("{}{}", base_url.trim_end_matches('/'), endpoint);

        crate::provider::utils::upload_file_with_transport(None, &upload_url, path, "file").await
    }

    fn provider_default_base_url(&self) -> Result<String, AiLibError> {
        if self.provider.is_config_driven() {
            Ok(self.provider.get_default_config()?.base_url)
        } else {
            match self.provider {
                Provider::OpenAI => Ok("https://api.openai.com/v1".to_string()),
                Provider::Gemini => {
                    Ok("https://generativelanguage.googleapis.com/v1beta".to_string())
                }
                Provider::Mistral => Ok("https://api.mistral.ai".to_string()),
                Provider::Cohere => Ok("https://api.cohere.ai".to_string()),
                _ => Err(AiLibError::ConfigurationError(
                    "No default base URL for provider".to_string(),
                )),
            }
        }
    }
}

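/// Handle for cancelling an in-flight streaming response; dropping it without
/// calling [`CancelHandle::cancel`] leaves the stream running.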
pub struct CancelHandle {
    sender: Option<oneshot::Sender<()>>,
}

impl CancelHandle {
    pub fn cancel(mut self) {
        if let Some(sender) = self.sender.take() {
            let _ = sender.send(());
        }
    }
}

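/// Fluent builder for [`AiClient`].
///
/// Typical configuration sketch:
///
/// ```ignore
/// let client = AiClient::builder(Provider::Groq)
///     .with_timeout(std::time::Duration::from_secs(60))
///     .with_max_concurrency(32)
///     .for_production()
///     .build()?;
/// ```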
pub struct AiClientBuilder {
    provider: Provider,
    base_url: Option<String>,
    proxy_url: Option<String>,
    timeout: Option<std::time::Duration>,
    pool_max_idle: Option<usize>,
    pool_idle_timeout: Option<std::time::Duration>,
    metrics: Option<Arc<dyn Metrics>>,
    default_chat_model: Option<String>,
    default_multimodal_model: Option<String>,
    resilience_config: ResilienceConfig,
    #[cfg(feature = "routing_mvp")]
    routing_array: Option<crate::provider::models::ModelArray>,
    #[cfg(feature = "interceptors")]
    interceptor_pipeline: Option<crate::interceptors::InterceptorPipeline>,
}

impl AiClientBuilder {
    pub fn new(provider: Provider) -> Self {
        Self {
            provider,
            base_url: None,
            proxy_url: None,
            timeout: None,
            pool_max_idle: None,
            pool_idle_timeout: None,
            metrics: None,
            default_chat_model: None,
            default_multimodal_model: None,
            resilience_config: ResilienceConfig::default(),
            #[cfg(feature = "routing_mvp")]
            routing_array: None,
            #[cfg(feature = "interceptors")]
            interceptor_pipeline: None,
        }
    }

    fn is_config_driven_provider(provider: Provider) -> bool {
        provider.is_config_driven()
    }

    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = Some(base_url.to_string());
        self
    }

    pub fn with_proxy(mut self, proxy_url: Option<&str>) -> Self {
        self.proxy_url = proxy_url.map(|s| s.to_string());
        self
    }

    pub fn without_proxy(mut self) -> Self {
        // Empty string is the internal sentinel for "explicitly no proxy";
        // `determine_proxy_url` maps it back to `None`.
        self.proxy_url = Some("".to_string());
        self
    }

    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    pub fn with_pool_config(mut self, max_idle: usize, idle_timeout: std::time::Duration) -> Self {
        self.pool_max_idle = Some(max_idle);
        self.pool_idle_timeout = Some(idle_timeout);
        self
    }

    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = Some(metrics);
        self
    }

    #[cfg(feature = "interceptors")]
    pub fn with_interceptor_pipeline(
        mut self,
        pipeline: crate::interceptors::InterceptorPipeline,
    ) -> Self {
        self.interceptor_pipeline = Some(pipeline);
        self
    }

    #[cfg(feature = "interceptors")]
    pub fn enable_default_interceptors(mut self) -> Self {
        let p = crate::interceptors::create_default_interceptors();
        self.interceptor_pipeline = Some(p);
        self
    }

    #[cfg(feature = "interceptors")]
    pub fn enable_minimal_interceptors(mut self) -> Self {
        let p = crate::interceptors::default::DefaultInterceptorsBuilder::new()
            .enable_circuit_breaker(false)
            .enable_rate_limit(false)
            .build();
        self.interceptor_pipeline = Some(p);
        self
    }

    pub fn with_default_chat_model(mut self, model: &str) -> Self {
        self.default_chat_model = Some(model.to_string());
        self
    }

    pub fn with_default_multimodal_model(mut self, model: &str) -> Self {
        self.default_multimodal_model = Some(model.to_string());
        self
    }

    pub fn with_smart_defaults(mut self) -> Self {
        self.resilience_config = ResilienceConfig::smart_defaults();
        self
    }

    pub fn for_production(mut self) -> Self {
        self.resilience_config = ResilienceConfig::production();
        self
    }

    pub fn for_development(mut self) -> Self {
        self.resilience_config = ResilienceConfig::development();
        self
    }

    pub fn with_max_concurrency(mut self, max_concurrent_requests: usize) -> Self {
        self.resilience_config.backpressure =
            Some(crate::config::BackpressureConfig { max_concurrent_requests });
        self
    }

    pub fn with_resilience_config(mut self, config: ResilienceConfig) -> Self {
        self.resilience_config = config;
        self
    }

    #[cfg(feature = "routing_mvp")]
    pub fn with_routing_array(mut self, array: crate::provider::models::ModelArray) -> Self {
        self.routing_array = Some(array);
        self
    }

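    /// Build the configured [`AiClient`], constructing the adapter, transport,
    /// and (when a concurrency limit is set) the backpressure controller.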
    pub fn build(self) -> Result<AiClient, AiLibError> {
        let base_url = self.determine_base_url()?;

        let proxy_url = self.determine_proxy_url();

        let timeout = self
            .timeout
            .unwrap_or_else(|| std::time::Duration::from_secs(30));

        let adapter: Box<dyn ChatApi> = if Self::is_config_driven_provider(self.provider) {
            let config = self.create_custom_config(base_url)?;
            let transport = self.create_custom_transport(proxy_url.clone(), timeout)?;
            create_generic_adapter(config, transport)?
        } else {
            match self.provider {
                Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
                Provider::Gemini => Box::new(GeminiAdapter::new()?),
                Provider::Mistral => Box::new(MistralAdapter::new()?),
                Provider::Cohere => Box::new(CohereAdapter::new()?),
                _ => unreachable!("all config-driven providers are handled above"),
            }
        };

        let bp_ctrl: Option<Arc<BackpressureController>> = self
            .resilience_config
            .backpressure
            .as_ref()
            .map(|cfg| Arc::new(BackpressureController::new(cfg.max_concurrent_requests)));

        let client = AiClient {
            provider: self.provider,
            adapter,
            metrics: self.metrics.unwrap_or_else(|| Arc::new(NoopMetrics::new())),
            connection_options: None,
            custom_default_chat_model: self.default_chat_model,
            custom_default_multimodal_model: self.default_multimodal_model,
            backpressure: bp_ctrl,
            #[cfg(feature = "routing_mvp")]
            routing_array: self.routing_array,
            #[cfg(feature = "interceptors")]
            interceptor_pipeline: self.interceptor_pipeline,
        };

        Ok(client)
    }

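    /// Resolve the base URL: explicit builder value, then the provider's
    /// `*_BASE_URL` environment variable, then the provider default.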
    fn determine_base_url(&self) -> Result<String, AiLibError> {
        if let Some(ref base_url) = self.base_url {
            return Ok(base_url.clone());
        }

        let env_var_name = self.get_base_url_env_var_name();
        if !env_var_name.is_empty() {
            if let Ok(base_url) = std::env::var(&env_var_name) {
                return Ok(base_url);
            }
        }

        if Self::is_config_driven_provider(self.provider) {
            let default_config = self.get_default_provider_config()?;
            Ok(default_config.base_url)
        } else {
            match self.provider {
                Provider::OpenAI => Ok("https://api.openai.com".to_string()),
                Provider::Gemini => Ok("https://generativelanguage.googleapis.com".to_string()),
                Provider::Mistral => Ok("https://api.mistral.ai".to_string()),
                Provider::Cohere => Ok("https://api.cohere.ai".to_string()),
                _ => Err(AiLibError::ConfigurationError(
                    "Unknown provider for base URL determination".to_string(),
                )),
            }
        }
    }

    fn determine_proxy_url(&self) -> Option<String> {
        if let Some(ref proxy_url) = self.proxy_url {
            if proxy_url.is_empty() {
                return None;
            }
            return Some(proxy_url.clone());
        }

        std::env::var("AI_PROXY_URL").ok()
    }

    fn get_base_url_env_var_name(&self) -> String {
        match self.provider {
            Provider::Groq => "GROQ_BASE_URL".to_string(),
            Provider::XaiGrok => "GROK_BASE_URL".to_string(),
            Provider::Ollama => "OLLAMA_BASE_URL".to_string(),
            Provider::DeepSeek => "DEEPSEEK_BASE_URL".to_string(),
            Provider::Qwen => "DASHSCOPE_BASE_URL".to_string(),
            Provider::BaiduWenxin => "BAIDU_WENXIN_BASE_URL".to_string(),
            Provider::TencentHunyuan => "TENCENT_HUNYUAN_BASE_URL".to_string(),
            Provider::IflytekSpark => "IFLYTEK_BASE_URL".to_string(),
            Provider::Moonshot => "MOONSHOT_BASE_URL".to_string(),
            Provider::Anthropic => "ANTHROPIC_BASE_URL".to_string(),
            Provider::AzureOpenAI => "AZURE_OPENAI_BASE_URL".to_string(),
            Provider::HuggingFace => "HUGGINGFACE_BASE_URL".to_string(),
            Provider::TogetherAI => "TOGETHER_BASE_URL".to_string(),
            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
                // Independent providers have no base-URL override variable.
                "".to_string()
            }
        }
    }

    fn get_default_provider_config(
        &self,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        self.provider.get_default_config()
    }

    fn create_custom_config(
        &self,
        base_url: String,
    ) -> Result<crate::provider::config::ProviderConfig, AiLibError> {
        let mut config = self.get_default_provider_config()?;
        config.base_url = base_url;
        Ok(config)
    }

    fn create_custom_transport(
        &self,
        proxy_url: Option<String>,
        timeout: std::time::Duration,
    ) -> Result<Option<crate::transport::DynHttpTransportRef>, AiLibError> {
        if proxy_url.is_none() && self.pool_max_idle.is_none() && self.pool_idle_timeout.is_none() {
            return Ok(None);
        }

        let transport_config = crate::transport::HttpTransportConfig {
            timeout,
            proxy: proxy_url,
            pool_max_idle_per_host: self.pool_max_idle,
            pool_idle_timeout: self.pool_idle_timeout,
        };

        let transport = crate::transport::HttpTransport::new_with_config(transport_config)?;
        Ok(Some(transport.boxed()))
    }
}

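/// Stream wrapper that watches an optional cancellation channel and holds a
/// backpressure permit for the lifetime of the stream.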
struct ControlledStream {
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    cancel_rx: Option<oneshot::Receiver<()>>,
    _bp_permit: Option<BackpressurePermit>,
}

impl ControlledStream {
    fn new_with_bp(
        inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        cancel_rx: Option<oneshot::Receiver<()>>,
        bp_permit: Option<BackpressurePermit>,
    ) -> Self {
        Self {
            inner,
            cancel_rx,
            _bp_permit: bp_permit,
        }
    }
}

impl Stream for ControlledStream {
    type Item = Result<ChatCompletionChunk, AiLibError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use futures::stream::StreamExt;
        use std::task::Poll;

        // Surface cancellation before polling the inner stream.
        if let Some(ref mut cancel_rx) = self.cancel_rx {
            match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
                Poll::Ready(_) => {
                    self.cancel_rx = None;
                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
                        "Stream cancelled".to_string(),
                    ))));
                }
                Poll::Pending => {}
            }
        }

        self.inner.poll_next_unpin(cx)
    }
}