ai_lib/client/
client_impl.rs

use crate::api::ChatCompletionChunk;
use crate::api::ChatProvider;
use crate::config::ConnectionOptions;
use crate::metrics::{Metrics, NoopMetrics};
use crate::model::{ModelResolution, ModelResolutionSource, ModelResolver};
use crate::rate_limiter::BackpressureController;
use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
use futures::stream::Stream;
use std::sync::Arc;

use super::builder::AiClientBuilder;
use super::helpers;
use super::metadata::{metadata_from_provider, ClientMetadata};
use super::model_options::ModelOptions;
use super::provider::Provider;
use super::stream::CancelHandle;
use super::{batch, request, stream, ProviderFactory};

/// Unified AI client providing a cross-provider interface to AI services.
///
/// Usage example:
/// ```rust
/// use ai_lib::{AiClient, Provider, ChatCompletionRequest, Message, Role};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Switch model provider by changing Provider value
///     let client = AiClient::new(Provider::Groq)?;
///
///     let request = ChatCompletionRequest::new(
///         "test-model".to_string(),
///         vec![Message {
///             role: Role::User,
///             content: ai_lib::types::common::Content::Text("Hello".to_string()),
///             function_call: None,
///         }],
///     );
///
///     // Note: Set GROQ_API_KEY environment variable for actual API calls
///     // Optional: Set AI_PROXY_URL environment variable to use proxy server
///     // let response = client.chat_completion(request).await?;
///
///     println!(
///         "Client created successfully with provider: {}",
///         client.provider_name()
///     );
///     println!("Request prepared for model: {}", request.model);
///
///     Ok(())
/// }
/// ```
///
/// # Proxy Configuration
///
/// Configure proxy server by setting the `AI_PROXY_URL` environment variable:
///
/// ```bash
/// export AI_PROXY_URL=http://proxy.example.com:8080
/// ```
///
/// Supported proxy formats:
/// - HTTP proxy: `http://proxy.example.com:8080`
/// - HTTPS proxy: `https://proxy.example.com:8080`
/// - With authentication: `http://user:pass@proxy.example.com:8080`
pub struct AiClient {
    pub(crate) chat_provider: Box<dyn ChatProvider>,
    pub(crate) metadata: ClientMetadata,
    pub(crate) metrics: Arc<dyn Metrics>,
    pub(crate) model_resolver: Arc<ModelResolver>,
    pub(crate) connection_options: Option<ConnectionOptions>,
    #[cfg(feature = "interceptors")]
    pub(crate) interceptor_pipeline: Option<crate::interceptors::InterceptorPipeline>,
    // Custom default models (override provider defaults)
    pub(crate) custom_default_chat_model: Option<String>,
    pub(crate) custom_default_multimodal_model: Option<String>,
    // Optional backpressure controller
    pub(crate) backpressure: Option<Arc<BackpressureController>>,
}

impl AiClient {
    /// Get the effective default chat model for this client (honors custom override)
    pub fn default_chat_model(&self) -> String {
        self.custom_default_chat_model
            .clone()
            .or_else(|| self.metadata.default_chat_model().map(|s| s.to_string()))
            .expect("AiClient metadata missing default chat model")
    }

    /// Create a new AI client
    pub fn new(provider: Provider) -> Result<Self, AiLibError> {
        AiClientBuilder::new(provider).build()
    }

    /// Create a new AI client builder
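    ///
    /// Minimal sketch of builder usage (only `build` is shown here; additional
    /// configuration methods live on `AiClientBuilder`):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider};
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = AiClient::builder(Provider::Groq).build()?;
    /// # Ok(())
    /// # }
    /// ```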
    pub fn builder(provider: Provider) -> AiClientBuilder {
        AiClientBuilder::new(provider)
    }

    /// Create AiClient with injected metrics implementation
    pub fn new_with_metrics(
        provider: Provider,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        AiClientBuilder::new(provider).with_metrics(metrics).build()
    }

    /// Create client with minimal explicit options (base_url/proxy/timeout).
    ///
    /// Fields left as `None` in `ConnectionOptions` will fall back to environment
    /// variables (e.g., `OPENAI_API_KEY`, `AI_PROXY_URL`, `AI_TIMEOUT_SECS`).
    /// Set `disable_proxy: true` to prevent automatic proxy detection from `AI_PROXY_URL`.
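    ///
    /// Minimal sketch, assuming `ConnectionOptions` is re-exported at the crate root,
    /// exposes public fields, and implements `Default` (the base URL is illustrative):
    /// ```rust,ignore
    /// use ai_lib::{AiClient, Provider, ConnectionOptions};
    /// use std::time::Duration;
    ///
    /// let opts = ConnectionOptions {
    ///     base_url: Some("https://api.groq.com/openai/v1".to_string()),
    ///     timeout: Some(Duration::from_secs(60)),
    ///     disable_proxy: true,
    ///     ..Default::default()
    /// };
    /// let client = AiClient::with_options(Provider::Groq, opts)?;
    /// ```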
    pub fn with_options(provider: Provider, opts: ConnectionOptions) -> Result<Self, AiLibError> {
        // Hydrate unset fields from environment variables
        let opts = opts.hydrate_with_env(provider.env_prefix());

        let resolved_base_url = super::builder::resolve_base_url(provider, opts.base_url.clone())?;

        // Determine effective proxy: None if disable_proxy is true, otherwise use hydrated value
        let effective_proxy = if opts.disable_proxy {
            None
        } else {
            opts.proxy.clone()
        };

        let transport = if effective_proxy.is_some() || opts.timeout.is_some() {
            let transport_config = crate::transport::HttpTransportConfig {
                timeout: opts.timeout.unwrap_or(std::time::Duration::from_secs(30)),
                proxy: effective_proxy,
                pool_max_idle_per_host: None,
                pool_idle_timeout: None,
            };
            Some(crate::transport::HttpTransport::new_with_config(transport_config)?.boxed())
        } else {
            None
        };

        let chat_provider = ProviderFactory::create_adapter(
            provider,
            opts.api_key.clone(),
            Some(resolved_base_url.clone()),
            transport,
        )?;

        let metadata = metadata_from_provider(
            provider,
            chat_provider.name().to_string(),
            Some(resolved_base_url),
            None,
            None,
        );

        Ok(AiClient {
            chat_provider,
            metadata,
            metrics: Arc::new(NoopMetrics::new()),
            model_resolver: Arc::new(ModelResolver::new()),
            connection_options: Some(opts),
            custom_default_chat_model: None,
            custom_default_multimodal_model: None,
            backpressure: None,
            #[cfg(feature = "interceptors")]
            interceptor_pipeline: None,
        })
    }

    /// Access the connection options this client was built with, if any.
    pub fn connection_options(&self) -> Option<&ConnectionOptions> {
        self.connection_options.as_ref()
    }

    /// Set metrics implementation on client
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = metrics;
        self
    }

    /// Access the model resolver (advanced customization).
    pub fn model_resolver(&self) -> Arc<ModelResolver> {
        self.model_resolver.clone()
    }

    /// Send a chat completion request.
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        request::chat_completion(self, request).await
    }

    /// Send a chat completion request and parse the response using the provided parser.
    ///
    /// This is a convenience method that pairs `chat_completion` with a `ResponseParser`.
    /// It ensures the response contains text content and then extracts structured data
    /// defined by the parser (e.g., JSON, Markdown sections).
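    ///
    /// Illustrative sketch only (`MyJsonParser` is a hypothetical type implementing
    /// `ResponseParser`, not something provided by this crate):
    /// ```rust,ignore
    /// let parsed = client
    ///     .chat_completion_parsed(request, MyJsonParser::new())
    ///     .await?;
    /// ```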
    #[cfg(feature = "response_parser")]
    pub async fn chat_completion_parsed<P>(
        &self,
        request: ChatCompletionRequest,
        parser: P,
    ) -> Result<P::Output, AiLibError>
    where
        P: crate::response_parser::ResponseParser + Send + Sync,
    {
        let response = self.chat_completion(request).await?;
        let content = response
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .ok_or_else(|| {
                AiLibError::InvalidModelResponse(
                    "Response contained no text content to parse".to_string(),
                )
            })?;

        parser.parse(&content).await.map_err(|e| {
            AiLibError::InvalidModelResponse(format!("Failed to parse response: {}", e))
        })
    }

    /// Streaming chat completion request
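    ///
    /// Minimal consumption sketch (assumes `futures::StreamExt` is in scope; chunk
    /// field access is omitted since it depends on `ChatCompletionChunk`'s layout):
    /// ```rust,ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = client.chat_completion_stream(request).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let chunk = chunk?; // each item is a Result<ChatCompletionChunk, AiLibError>
    ///     // consume the incremental chunk here
    /// }
    /// ```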
    pub async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        stream::chat_completion_stream(self, request).await
    }

    /// Streaming chat completion request with cancel control
    pub async fn chat_completion_stream_with_cancel(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        (
            Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
            CancelHandle,
        ),
        AiLibError,
    > {
        stream::chat_completion_stream_with_cancel(self, request).await
    }

    /// Batch chat completion requests
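    ///
    /// Sketch of batched usage (`Some(4)` is only an illustrative concurrency limit):
    /// ```rust,ignore
    /// let requests = vec![
    ///     client.build_simple_request("First prompt"),
    ///     client.build_simple_request("Second prompt"),
    /// ];
    /// let results = client.chat_completion_batch(requests, Some(4)).await?;
    /// for result in results {
    ///     match result {
    ///         Ok(response) => { /* use response.choices here */ }
    ///         Err(_err) => { /* handle the per-request error */ }
    ///     }
    /// }
    /// ```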
    pub async fn chat_completion_batch(
        &self,
        requests: Vec<ChatCompletionRequest>,
        concurrency_limit: Option<usize>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch::chat_completion_batch(self, requests, concurrency_limit).await
    }

    /// Smart batch processing
    pub async fn chat_completion_batch_smart(
        &self,
        requests: Vec<ChatCompletionRequest>,
    ) -> Result<Vec<Result<ChatCompletionResponse, AiLibError>>, AiLibError> {
        batch::chat_completion_batch_smart(self, requests).await
    }

    /// Get list of supported models
    pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        helpers::list_models(self).await
    }

    /// Switch AI model provider
    pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
        helpers::switch_provider(self, provider)
    }

    /// Get the active provider name reported by the underlying strategy.
    pub fn provider_name(&self) -> &str {
        self.metadata.provider_name()
    }

    /// Get the current provider enum value
    pub fn provider(&self) -> Provider {
        self.metadata.provider()
    }

    /// Convenience helper: construct a request with the provider's default chat model.
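    ///
    /// Minimal sketch of pairing this helper with `chat_completion`:
    /// ```rust,ignore
    /// let request = client.build_simple_request("Summarize this text in one sentence.");
    /// let response = client.chat_completion(request).await?;
    /// ```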
    pub fn build_simple_request<S: Into<String>>(&self, prompt: S) -> ChatCompletionRequest {
        helpers::build_simple_request(self, prompt)
    }

    /// Convenience helper: construct a request with an explicitly specified chat model.
    pub fn build_simple_request_with_model<S: Into<String>>(
        &self,
        prompt: S,
        model: S,
    ) -> ChatCompletionRequest {
        helpers::build_simple_request_with_model(self, prompt, model)
    }

    /// Convenience helper: construct a request with the provider's default multimodal model.
    pub fn build_multimodal_request<S: Into<String>>(
        &self,
        prompt: S,
    ) -> Result<ChatCompletionRequest, AiLibError> {
        helpers::build_multimodal_request(self, prompt)
    }

    /// Convenience helper: construct a request with an explicitly specified multimodal model.
    pub fn build_multimodal_request_with_model<S: Into<String>>(
        &self,
        prompt: S,
        model: S,
    ) -> ChatCompletionRequest {
        helpers::build_multimodal_request_with_model(self, prompt, model)
    }

    /// One-shot helper: create a client for `provider`, send a single user prompt, and return the response text.
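    ///
    /// Minimal sketch (mirrors the struct-level example; `GROQ_API_KEY` must be set
    /// for an actual call):
    /// ```rust,no_run
    /// use ai_lib::{AiClient, Provider};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let text = AiClient::quick_chat_text(Provider::Groq, "Hello!").await?;
    /// println!("{text}");
    /// # Ok(())
    /// # }
    /// ```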
    pub async fn quick_chat_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        helpers::quick_chat_text(provider, prompt).await
    }

    /// One-shot helper: same as `quick_chat_text`, but with an explicitly specified model.
    pub async fn quick_chat_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        helpers::quick_chat_text_with_model(provider, prompt, model).await
    }

    /// One-shot multimodal helper using the provider's default multimodal model.
    pub async fn quick_multimodal_text<P: Into<String>>(
        provider: Provider,
        prompt: P,
    ) -> Result<String, AiLibError> {
        helpers::quick_multimodal_text(provider, prompt).await
    }

    /// One-shot multimodal helper with an explicitly specified model.
    pub async fn quick_multimodal_text_with_model<P: Into<String>, M: Into<String>>(
        provider: Provider,
        prompt: P,
        model: M,
    ) -> Result<String, AiLibError> {
        helpers::quick_multimodal_text_with_model(provider, prompt, model).await
    }

    /// One-shot helper that applies the given `ModelOptions`.
    pub async fn quick_chat_text_with_options<P: Into<String>>(
        provider: Provider,
        prompt: P,
        options: ModelOptions,
    ) -> Result<String, AiLibError> {
        helpers::quick_chat_text_with_options(provider, prompt, options).await
    }

    /// Upload a local file
    pub async fn upload_file(&self, path: &str) -> Result<String, AiLibError> {
        helpers::upload_file(self, path).await
    }

    /// Provider enum value used internally for model resolution.
    pub(crate) fn provider_id(&self) -> Provider {
        self.metadata.provider()
    }

    /// Resolve `auto`/`default` model tokens to a concrete model before dispatching the request.
    pub(crate) fn prepare_chat_request(
        &self,
        mut request: ChatCompletionRequest,
    ) -> ChatCompletionRequest {
        if should_use_auto_token(&request.model) {
            if let Some(custom) = &self.custom_default_chat_model {
                request.model = custom.clone();
            } else {
                let resolution = self
                    .model_resolver
                    .resolve_chat_model(self.provider_id(), None);
                request.model = resolution.model;
            }
        }
        request
    }

    /// Pick a replacement model after the provider rejected `failed_model`, preferring the custom default.
    pub(crate) fn fallback_model_after_invalid(
        &self,
        failed_model: &str,
    ) -> Option<ModelResolution> {
        if let Some(custom) = &self.custom_default_chat_model {
            if !custom.eq_ignore_ascii_case(failed_model) {
                return Some(ModelResolution::new(
                    custom.clone(),
                    ModelResolutionSource::CustomDefault,
                    self.model_resolver.doc_url(self.provider_id()),
                ));
            }
        }

        self.model_resolver
            .fallback_after_invalid(self.provider_id(), failed_model)
    }
}

/// Returns `true` when the model string requests automatic model resolution
/// (empty, `auto`, `default`, or `provider_default`).
fn should_use_auto_token(model: &str) -> bool {
    let trimmed = model.trim();
    trimmed.is_empty()
        || trimmed.eq_ignore_ascii_case("auto")
        || trimmed.eq_ignore_ascii_case("default")
        || trimmed.eq_ignore_ascii_case("provider_default")
}