1use crate::api::ChatCompletionChunk;
2use crate::api::ChatProvider;
3use crate::config::ConnectionOptions;
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::model::{ModelResolution, ModelResolutionSource, ModelResolver};
6use crate::rate_limiter::BackpressureController;
7use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
8use futures::stream::Stream;
9use std::sync::Arc;
10
11use super::builder::AiClientBuilder;
12use super::helpers;
13use super::metadata::{metadata_from_provider, ClientMetadata};
14use super::model_options::ModelOptions;
15use super::provider::Provider;
16use super::stream::CancelHandle;
17use super::{batch, request, stream, ProviderFactory};
18
/// Unified client facade over a single chat-completion [`Provider`].
///
/// Construct via [`AiClient::new`], [`AiClient::builder`], or
/// [`AiClient::with_options`]. Most methods delegate to the `request`,
/// `stream`, `batch`, and `helpers` submodules.
pub struct AiClient {
    /// Provider-specific adapter that performs the actual API calls.
    pub(crate) chat_provider: Box<dyn ChatProvider>,
    /// Provider identity plus default-model metadata (see `default_chat_model`).
    pub(crate) metadata: ClientMetadata,
    /// Metrics sink; `NoopMetrics` when none was configured.
    pub(crate) metrics: Arc<dyn Metrics>,
    /// Resolves "auto"/default model tokens and invalid-model fallbacks per provider.
    pub(crate) model_resolver: Arc<ModelResolver>,
    /// The (env-hydrated) connection options this client was built with, if any.
    pub(crate) connection_options: Option<ConnectionOptions>,
    /// Optional request/response interceptor pipeline (feature `interceptors`).
    #[cfg(feature = "interceptors")]
    pub(crate) interceptor_pipeline: Option<crate::interceptors::InterceptorPipeline>,
    /// User-supplied default chat model; takes precedence over metadata defaults.
    pub(crate) custom_default_chat_model: Option<String>,
    /// User-supplied default multimodal model — not read in this chunk;
    /// presumably consumed by the multimodal helpers (confirm in `helpers`).
    pub(crate) custom_default_multimodal_model: Option<String>,
    /// Optional backpressure controller — not read in this chunk; presumably
    /// consulted by the request/batch paths to throttle calls (confirm there).
    pub(crate) backpressure: Option<Arc<BackpressureController>>,
}
81
impl AiClient {
    /// Returns the effective default chat model: the user-set custom default
    /// if present, otherwise the provider metadata's default.
    ///
    /// # Panics
    /// Panics if neither a custom default nor a metadata default exists —
    /// callers treat a missing default as a client-construction bug.
    pub fn default_chat_model(&self) -> String {
        self.custom_default_chat_model
            .clone()
            .or_else(|| self.metadata.default_chat_model().map(|s| s.to_string()))
            .expect("AiClient metadata missing default chat model")
    }

    /// Creates a client for `provider` with all builder defaults.
    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
        AiClientBuilder::new(provider).build()
    }

    /// Returns a builder for fine-grained client configuration.
    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }

    /// Creates a client with a custom metrics sink; otherwise builder defaults.
    pub fn new_with_metrics(
        provider: Provider,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        AiClientBuilder::new(provider).with_metrics(metrics).build()
    }

    /// Creates a client from explicit [`ConnectionOptions`].
    ///
    /// Options are first hydrated from environment variables using the
    /// provider's env prefix, then the base URL is resolved. A custom HTTP
    /// transport is built only when a proxy or timeout is configured
    /// (timeout defaults to 30s in that case); otherwise the adapter's
    /// default transport is used.
    ///
    /// # Errors
    /// Returns an error if base-URL resolution, transport construction, or
    /// adapter creation fails.
    pub fn with_options(provider: Provider, opts: ConnectionOptions) -> Result<Self, AiLibError> {
        // Fill unset fields from `<PREFIX>_*` environment variables.
        let opts = opts.hydrate_with_env(provider.env_prefix());

        let resolved_base_url = super::builder::resolve_base_url(provider, opts.base_url.clone())?;

        // `disable_proxy` wins over any configured proxy.
        let effective_proxy = if opts.disable_proxy {
            None
        } else {
            opts.proxy.clone()
        };

        // Only build a bespoke transport when something non-default is requested.
        let transport = if effective_proxy.is_some() || opts.timeout.is_some() {
            let transport_config = crate::transport::HttpTransportConfig {
                timeout: opts.timeout.unwrap_or(std::time::Duration::from_secs(30)),
                proxy: effective_proxy,
                pool_max_idle_per_host: None,
                pool_idle_timeout: None,
            };
            Some(crate::transport::HttpTransport::new_with_config(transport_config)?.boxed())
        } else {
            None
        };

        let chat_provider = ProviderFactory::create_adapter(
            provider,
            opts.api_key.clone(),
            Some(resolved_base_url.clone()),
            transport,
        )?;

        let metadata = metadata_from_provider(
            provider,
            chat_provider.name().to_string(),
            Some(resolved_base_url),
            None,
            None,
        );

        Ok(AiClient {
            chat_provider,
            metadata,
            metrics: Arc::new(NoopMetrics::new()),
            model_resolver: Arc::new(ModelResolver::new()),
            connection_options: Some(opts),
            custom_default_chat_model: None,
            custom_default_multimodal_model: None,
            backpressure: None,
            #[cfg(feature = "interceptors")]
            interceptor_pipeline: None,
        })
    }

    /// Returns the connection options this client was built with, if any.
    pub fn connection_options(&self) -> Option<&ConnectionOptions> {
        self.connection_options.as_ref()
    }

    /// Replaces the metrics sink (consuming builder-style setter).
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = metrics;
        self
    }

    /// Returns a shared handle to this client's model resolver.
    pub fn model_resolver(&self) -> Arc<ModelResolver> {
        self.model_resolver.clone()
    }

    /// Sends a chat-completion request; delegates to `request::chat_completion`.
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        request::chat_completion(self, request).await
    }

    /// Sends a chat-completion request and parses the first choice's text
    /// content with `parser` (feature `response_parser`).
    ///
    /// # Errors
    /// Returns `InvalidModelResponse` when the response has no choices or
    /// when the parser fails.
    #[cfg(feature = "response_parser")]
    pub async fn chat_completion_parsed<P>(
        &self,
        request: ChatCompletionRequest,
        parser: P,
    ) -> Result<P::Output, AiLibError>
    where
        P: crate::response_parser::ResponseParser + Send + Sync,
    {
        let response = self.chat_completion(request).await?;
        // Only the first choice is parsed; additional choices are ignored.
        let content = response
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .ok_or_else(|| {
                AiLibError::InvalidModelResponse(
                    "Response contained no text content to parse".to_string(),
                )
            })?;

        parser.parse(&content).await.map_err(|e| {
            AiLibError::InvalidModelResponse(format!("Failed to parse response: {}", e))
        })
    }

    /// Streams a chat completion as chunks; delegates to `stream::chat_completion_stream`.
    pub async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        stream::chat_completion_stream(self, request).await
    }

    /// Like [`Self::chat_completion_stream`], but also returns a
    /// [`CancelHandle`] for aborting the stream.
    pub async fn chat_completion_stream_with_cancel(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        (
            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
            CancelHandle,
        ),
        AiLibError,
    > {
        stream::chat_completion_stream_with_cancel(self, request).await
    }

    /// Runs many requests concurrently with an optional concurrency cap;
    /// per-request results are returned in a `Vec`.
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch::chat_completion_batch(self, requests, concurrency_limit).await
    }

    /// Batch variant that picks its own concurrency strategy; delegates to
    /// `batch::chat_completion_batch_smart`.
    pub async fn chat_completion_batch_smart(
        &self,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch::chat_completion_batch_smart(self, requests).await
    }

    /// Lists model names available from the current provider.
    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        helpers::list_models(self).await
    }

    /// Switches this client to a different provider in place.
    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
        helpers::switch_provider(self, provider)
    }

    /// Returns the current provider's display name.
    pub fn provider_name(&self) -> &str {
        self.metadata.provider_name()
    }

    /// Returns the current provider identifier.
    pub fn provider(&self) -> Provider {
        self.metadata.provider()
    }

    /// Builds a minimal single-prompt chat request using client defaults.
    pub fn build_simple_request<S: Into<String>>(&self, prompt: S) -> ChatCompletionRequest {
        helpers::build_simple_request(self, prompt)
    }

    /// Builds a single-prompt chat request targeting an explicit model.
    pub fn build_simple_request_with_model<S: Into<String>>(
        &self,
        prompt: S,
        model: S,
    ) -> ChatCompletionRequest {
        helpers::build_simple_request_with_model(self, prompt, model)
    }

    /// Builds a multimodal request using client defaults; may fail if no
    /// suitable multimodal model is available (see `helpers`).
    pub fn build_multimodal_request<S: Into<String>>(
        &self,
        prompt: S,
    ) -> Result<ChatCompletionRequest, AiLibError> {
        helpers::build_multimodal_request(self, prompt)
    }

    /// Builds a multimodal request targeting an explicit model.
    pub fn build_multimodal_request_with_model<S: Into<String>>(
        &self,
        prompt: S,
        model: S,
    ) -> ChatCompletionRequest {
        helpers::build_multimodal_request_with_model(self, prompt, model)
    }

    /// One-shot convenience: sends `prompt` to `provider` and returns the
    /// reply text, constructing a throwaway client internally.
    pub async fn quick_chat_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        helpers::quick_chat_text(provider, prompt).await
    }

    /// One-shot convenience with an explicit model.
    pub async fn quick_chat_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        helpers::quick_chat_text_with_model(provider, prompt, model).await
    }

    /// One-shot multimodal convenience using provider defaults.
    pub async fn quick_multimodal_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        helpers::quick_multimodal_text(provider, prompt).await
    }

    /// One-shot multimodal convenience with an explicit model.
    pub async fn quick_multimodal_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        helpers::quick_multimodal_text_with_model(provider, prompt, model).await
    }

    /// One-shot convenience with extra [`ModelOptions`].
    pub async fn quick_chat_text_with_options<P: Into<String>>(
        provider: Provider,
        prompt: P,
        options: ModelOptions,
    ) -> Result<String, AiLibError> {
        helpers::quick_chat_text_with_options(provider, prompt, options).await
    }

    /// Uploads a local file to the provider; returns a provider-issued
    /// identifier or URL (exact semantics defined in `helpers::upload_file`).
    pub async fn upload_file(&self, path: &str) -> Result<String, AiLibError> {
        helpers::upload_file(self, path).await
    }

    /// Crate-internal alias for [`Self::provider`].
    pub(crate) fn provider_id(&self) -> Provider {
        self.metadata.provider()
    }

    /// Normalizes the request's model field before dispatch: if the model is
    /// an "auto" token (empty / "auto" / "default" / "provider_default"),
    /// substitutes the custom default when set, else the resolver's choice.
    pub(crate) fn prepare_chat_request(
        &self,
        mut request: ChatCompletionRequest,
    ) -> ChatCompletionRequest {
        if should_use_auto_token(&request.model) {
            if let Some(custom) = &self.custom_default_chat_model {
                request.model = custom.clone();
            } else {
                let resolution = self
                    .model_resolver
                    .resolve_chat_model(self.provider_id(), None);
                request.model = resolution.model;
            }
        }
        request
    }

    /// Picks a fallback model after the provider rejected `failed_model`.
    ///
    /// Prefers the custom default (unless it case-insensitively equals the
    /// failed model, which would just retry the same name), then defers to
    /// the resolver's provider-specific fallback. `None` means no fallback.
    pub(crate) fn fallback_model_after_invalid(
        &self,
        failed_model: &str,
    ) -> Option<ModelResolution> {
        if let Some(custom) = &self.custom_default_chat_model {
            if !custom.eq_ignore_ascii_case(failed_model) {
                return Some(ModelResolution::new(
                    custom.clone(),
                    ModelResolutionSource::CustomDefault,
                    self.model_resolver.doc_url(self.provider_id()),
                ));
            }
        }

        self.model_resolver
            .fallback_after_invalid(self.provider_id(), failed_model)
    }
}
400
/// Reports whether `model` is a placeholder asking the client to pick a
/// default model: an empty/whitespace-only string, or one of the reserved
/// tokens "auto", "default", "provider_default" (ASCII case-insensitive).
fn should_use_auto_token(model: &str) -> bool {
    const AUTO_TOKENS: [&str; 3] = ["auto", "default", "provider_default"];
    let candidate = model.trim();
    candidate.is_empty()
        || AUTO_TOKENS
            .iter()
            .any(|token| candidate.eq_ignore_ascii_case(token))
}