Skip to main content

ai_lib_rust/client/
core.rs

1//! 核心客户端实现:管理协议加载、传输、流水线及弹性策略。
2//!
3//! Core AI client implementation.
4
5use crate::client::types::CallStats;
6use crate::protocol::ProtocolLoader;
7use crate::protocol::ProtocolManifest;
8use crate::{Error, ErrorContext, Result};
9use std::sync::Arc;
10
11use crate::pipeline::Pipeline;
12use crate::transport::HttpTransport;
13
14// Import submodules
15use crate::client::validation;
16
/// Unified AI client that works with any provider through protocol configuration.
pub struct AiClient {
    /// Protocol manifest for the active provider/model; consulted by the
    /// `PolicyEngine` for capability validation and decisions.
    pub manifest: ProtocolManifest,
    /// HTTP transport built from the manifest and model id.
    pub transport: Arc<HttpTransport>,
    /// Request/response pipeline derived from the manifest.
    pub pipeline: Arc<Pipeline>,
    /// Shared protocol loader; reused when deriving fallback clients via `with_model`.
    pub loader: Arc<ProtocolLoader>,
    // Fallback model identifiers ("provider/model") tried in order after the primary.
    pub(crate) fallbacks: Vec<String>,
    // Model portion of the "provider/model" identifier; injected into outgoing requests.
    pub(crate) model_id: String,
    // Passed to manifest validation; presumably rejects manifests without
    // streaming support when set — confirm against `validation::validate_manifest`.
    pub(crate) strict_streaming: bool,
    // Sink for user feedback events reported via `report_feedback`.
    pub(crate) feedback: Arc<dyn crate::feedback::FeedbackSink>,
    // Optional concurrency limiter; available permits feed the signals snapshot.
    pub(crate) inflight: Option<Arc<tokio::sync::Semaphore>>,
    // Configured permit count for `inflight`; both must be set for inflight signals.
    pub(crate) max_inflight: Option<usize>,
    // Per-attempt timeout applied around each execution attempt in the retry loop.
    pub(crate) attempt_timeout: Option<std::time::Duration>,
    // Optional circuit breaker; snapshotted (sync) into signals.
    pub(crate) breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
    // Optional rate limiter; snapshotted (async) into signals and updated from headers.
    pub(crate) rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
}
33
/// Unified response format.
#[derive(Debug, Default)]
pub struct UnifiedResponse {
    /// Aggregated text content of the model response.
    pub content: String,
    /// Tool calls requested by the model, if any.
    pub tool_calls: Vec<crate::types::tool::ToolCall>,
    /// Provider-reported usage payload, when available (raw JSON, schema varies by provider).
    pub usage: Option<serde_json::Value>,
}
41
42impl AiClient {
43    /// Snapshot current runtime signals (facts only) for application-layer orchestration.
44    pub async fn signals(&self) -> crate::client::signals::SignalsSnapshot {
45        let inflight = self.inflight.as_ref().and_then(|sem| {
46            let max = self.max_inflight?;
47            let available = sem.available_permits();
48            let in_use = max.saturating_sub(available);
49            Some(crate::client::signals::InflightSnapshot {
50                max,
51                available,
52                in_use,
53            })
54        });
55
56        let rate_limiter = match &self.rate_limiter {
57            Some(rl) => Some(rl.snapshot().await),
58            None => None,
59        };
60
61        let circuit_breaker = match &self.breaker {
62            Some(cb) => Some(cb.snapshot()),
63            None => None,
64        };
65
66        crate::client::signals::SignalsSnapshot {
67            inflight,
68            rate_limiter,
69            circuit_breaker,
70        }
71    }
72
73    /// Create a new client for a specific model.
74    ///
75    /// The model identifier should be in the format "provider/model-name"
76    /// (e.g., "anthropic/claude-3-5-sonnet")
77    pub async fn new(model: &str) -> Result<Self> {
78        crate::client::builder::AiClientBuilder::new()
79            .build(model)
80            .await
81    }
82
83    /// Create a new client instance for another model, reusing loader + shared runtime knobs
84    /// (feedback, inflight, breaker, rate limiter) for consistent behavior.
85    pub(crate) async fn with_model(&self, model: &str) -> Result<Self> {
86        // model is in form "provider/model-id"
87        let parts: Vec<&str> = model.split('/').collect();
88        let model_id = parts
89            .get(1)
90            .map(|s| s.to_string())
91            .unwrap_or_else(|| model.to_string());
92
93        let manifest = self.loader.load_model(model).await?;
94        validation::validate_manifest(&manifest, self.strict_streaming)?;
95
96        let transport = Arc::new(crate::transport::HttpTransport::new(&manifest, &model_id)?);
97        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
98
99        Ok(AiClient {
100            manifest,
101            transport,
102            pipeline,
103            loader: self.loader.clone(),
104            fallbacks: Vec::new(),
105            model_id,
106            strict_streaming: self.strict_streaming,
107            feedback: self.feedback.clone(),
108            inflight: self.inflight.clone(),
109            max_inflight: self.max_inflight,
110            attempt_timeout: self.attempt_timeout,
111            breaker: self.breaker.clone(),
112            rate_limiter: self.rate_limiter.clone(),
113        })
114    }
115
    /// Create a chat request builder.
    ///
    /// The builder borrows this client and is executed via its own methods.
    pub fn chat(&self) -> crate::client::chat::ChatRequestBuilder<'_> {
        crate::client::chat::ChatRequestBuilder::new(self)
    }
120
121    /// Execute multiple chat requests concurrently with an optional concurrency limit.
122    ///
123    /// Notes:
124    /// - Results preserve input order.
125    /// - Internally uses the same "streaming → UnifiedResponse" path for consistency.
126    pub async fn chat_batch(
127        &self,
128        requests: Vec<crate::client::chat::ChatBatchRequest>,
129        concurrency_limit: Option<usize>,
130    ) -> Vec<Result<UnifiedResponse>> {
131        use futures::StreamExt;
132
133        let n = requests.len();
134        if n == 0 {
135            return Vec::new();
136        }
137
138        let limit = concurrency_limit.unwrap_or(10).max(1);
139        let mut out: Vec<Option<Result<UnifiedResponse>>> = (0..n).map(|_| None).collect();
140
141        let results: Vec<(usize, Result<UnifiedResponse>)> =
142            futures::stream::iter(requests.into_iter().enumerate())
143                .map(|(idx, req)| async move {
144                    let mut b = self.chat().messages(req.messages).stream();
145                    if let Some(t) = req.temperature {
146                        b = b.temperature(t);
147                    }
148                    if let Some(m) = req.max_tokens {
149                        b = b.max_tokens(m);
150                    }
151                    if let Some(tools) = req.tools {
152                        b = b.tools(tools);
153                    }
154                    if let Some(tc) = req.tool_choice {
155                        b = b.tool_choice(tc);
156                    }
157                    let r = b.execute().await;
158                    (idx, r)
159                })
160                .buffer_unordered(limit)
161                .collect()
162                .await;
163
164        for (idx, r) in results {
165            out[idx] = Some(r);
166        }
167
168        out.into_iter()
169            .map(|o| {
170                o.unwrap_or_else(|| {
171                    Err(Error::runtime_with_context(
172                        "batch result missing",
173                        ErrorContext::new().with_source("batch_executor"),
174                    ))
175                })
176            })
177            .collect()
178    }
179
180    /// Smart batch execution with a conservative, developer-friendly default heuristic.
181    ///
182    /// - For very small batches, run sequentially to reduce overhead.
183    /// - For larger batches, run with a bounded concurrency.
184    ///
185    /// You can override the chosen concurrency via env:
186    /// - `AI_LIB_BATCH_CONCURRENCY`
187    pub async fn chat_batch_smart(
188        &self,
189        requests: Vec<crate::client::chat::ChatBatchRequest>,
190    ) -> Vec<Result<UnifiedResponse>> {
191        let n = requests.len();
192        if n == 0 {
193            return Vec::new();
194        }
195
196        let env_override = std::env::var("AI_LIB_BATCH_CONCURRENCY")
197            .ok()
198            .and_then(|s| s.parse::<usize>().ok())
199            .filter(|v| *v > 0);
200
201        let chosen = env_override.unwrap_or_else(|| {
202            if n <= 3 {
203                1
204            } else if n <= 10 {
205                5
206            } else {
207                10
208            }
209        });
210
211        self.chat_batch(requests, Some(chosen)).await
212    }
213
    /// Report user feedback (optional). This delegates to the injected `FeedbackSink`.
    ///
    /// # Errors
    /// Propagates any error returned by the sink's `report` implementation.
    pub async fn report_feedback(&self, event: crate::feedback::FeedbackEvent) -> Result<()> {
        self.feedback.report(event).await
    }
218
    /// Update rate limiter state from response headers using protocol-mapped names.
    ///
    /// This method is public for testing purposes.
    ///
    /// Delegates to the `PreflightExt` extension trait; a no-op effect is
    /// expected when no rate limiter is configured — confirm in `PreflightExt`.
    pub async fn update_rate_limits(&self, headers: &reqwest::header::HeaderMap) {
        use crate::client::preflight::PreflightExt;
        PreflightExt::update_rate_limits(self, headers).await;
    }
226
227    /// Unified entry point for calling a model.
228    /// Handles text, streaming, and error fallback automatically.
229    pub async fn call_model(
230        &self,
231        request: crate::protocol::UnifiedRequest,
232    ) -> Result<UnifiedResponse> {
233        Ok(self.call_model_with_stats(request).await?.0)
234    }
235
    /// Call a model and also return per-call stats (latency, retries, request ids, endpoint, usage, etc.).
    ///
    /// This is intended for higher-level model selection and observability.
    ///
    /// Candidate order: the primary client first, then each configured fallback.
    /// The first candidate whose retry loop succeeds wins; otherwise the last
    /// observed error is returned.
    pub async fn call_model_with_stats(
        &self,
        request: crate::protocol::UnifiedRequest,
    ) -> Result<(UnifiedResponse, CallStats)> {
        // v0.5.0: The resilience logic is now delegated to the "Resilience Layer" (Pipeline Operators).
        // This core loop is now significantly simpler: it just tries the primary client.
        // If advanced resilience (multi-candidate fallback, complex retries) is needed,
        // it should be configured via the `Pipeline` or `PolicyEngine` which now acts as an operator.

        // Note: For v0.5.0 migration, we preserve the basic fallback iteration here
        // until the `Pipeline` fully absorbs "Client Switching" logic.
        // However, the explicit *retry* loop inside each candidate is now conceptually
        // part of `execute_once_with_stats` (which will eventually use RetryOperator).

        let mut last_err: Option<Error> = None;

        // Build fallback clients first (async)
        // In v0.6.0+, this will be replaced by `FallbackOperator` inside the pipeline
        // NOTE(review): fallback clients are built eagerly even when the primary
        // succeeds, and `with_model` errors are silently dropped (best-effort);
        // confirm both are intentional for the transitional design.
        let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(self.fallbacks.len());
        for model in &self.fallbacks {
            if let Ok(c) = self.with_model(model).await {
                fallback_clients.push(c);
            }
        }

        // Iterate candidates: primary first, then fallbacks.
        for (candidate_idx, client) in std::iter::once(self)
            .chain(fallback_clients.iter())
            .enumerate()
        {
            // True while at least one more candidate remains after this one.
            let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
            let policy = crate::client::policy::PolicyEngine::new(&client.manifest);

            // 1. Validation check
            if let Err(e) = policy.validate_capabilities(&request) {
                if has_fallback {
                    last_err = Some(e);
                    continue; // Fallback to next candidate
                } else {
                    return Err(e); // No more fallbacks, fail fast
                }
            }

            // 2. Pre-decision based on signals
            // Only a `Fallback` pre-decision skips the candidate; any other
            // decision (or none) proceeds to execution.
            let sig = client.signals().await;
            if let Some(crate::client::policy::Decision::Fallback) =
                policy.pre_decide(&sig, has_fallback)
            {
                last_err = Some(Error::runtime_with_context(
                    "skipped candidate due to signals",
                    ErrorContext::new().with_source("policy_engine"),
                ));
                continue;
            }

            // Re-target the request at this candidate's model id.
            let mut req = request.clone();
            req.model = client.model_id.clone();

            // 3. Execution with Retry Policy
            // The `execute_with_retry` helper now encapsulates the retry loop,
            // paving the way for `RetryOperator` migration.
            match client.execute_with_retry(&req, &policy, has_fallback).await {
                Ok(res) => return Ok(res),
                Err(e) => {
                    // If we are here, retries were exhausted or policy said Fallback/Fail.
                    last_err = Some(e);
                    // If policy said Fallback, continue loop.
                    // If policy said Fail, strictly we should stop, but current logic implies
                    // the loop itself is the "Fallback mechanism".
                    // `unwrap` is safe: `last_err` was assigned just above.
                    if !has_fallback {
                        return Err(last_err.unwrap());
                    }
                }
            }
        }

        // Reached only when every candidate was skipped or failed with fallback
        // remaining; `last_err` should be set, the closure is a defensive default.
        Err(last_err.unwrap_or_else(|| {
            Error::runtime_with_context(
                "all attempts failed",
                ErrorContext::new().with_source("retry_policy"),
            )
        }))
    }
322
    /// Internal helper to execute with retry policy.
    /// In future versions, this Logic moves entirely into `RetryOperator`.
    ///
    /// Loops attempts until success or until the policy returns `Fallback`/`Fail`
    /// (or itself errors via `?`). `attempt` feeds the policy's decision;
    /// `retry_count` is what gets reported in the returned `CallStats`.
    async fn execute_with_retry(
        &self,
        request: &crate::protocol::UnifiedRequest,
        policy: &crate::client::policy::PolicyEngine,
        has_fallback: bool,
    ) -> Result<(UnifiedResponse, CallStats)> {
        let mut attempt: u32 = 0;
        let mut retry_count: u32 = 0;

        loop {
            // Each attempt is optionally bounded by `attempt_timeout`; a timeout
            // is converted into a runtime error so the policy can decide on it.
            let attempt_fut = self.execute_once_with_stats(request);
            let attempt_res = if let Some(t) = self.attempt_timeout {
                match tokio::time::timeout(t, attempt_fut).await {
                    Ok(r) => r,
                    Err(_) => Err(Error::runtime_with_context(
                        "attempt timeout",
                        ErrorContext::new().with_source("timeout_policy"),
                    )),
                }
            } else {
                attempt_fut.await
            };

            match attempt_res {
                Ok((resp, mut stats)) => {
                    // Stamp the number of retries performed onto the stats.
                    stats.retry_count = retry_count;
                    return Ok((resp, stats));
                }
                Err(e) => {
                    // Let the policy classify the error given the attempt count
                    // and whether another candidate exists after this one.
                    let decision = policy.decide(&e, attempt, has_fallback)?;

                    match decision {
                        crate::client::policy::Decision::Retry { delay } => {
                            retry_count = retry_count.saturating_add(1);
                            // Zero-delay retries skip the sleep entirely.
                            if delay.as_millis() > 0 {
                                tokio::time::sleep(delay).await;
                            }
                            attempt = attempt.saturating_add(1);
                            continue;
                        }
                        // Both terminal decisions surface the original error;
                        // the caller's candidate loop interprets the difference.
                        crate::client::policy::Decision::Fallback => return Err(e),
                        crate::client::policy::Decision::Fail => return Err(e),
                    }
                }
            }
        }
    }
372
373    /// Validate request capabilities.
374    pub fn validate_request(
375        &self,
376        request: &crate::client::chat::ChatRequestBuilder<'_>,
377    ) -> Result<()> {
378        // Build a minimal UnifiedRequest to check capabilities via PolicyEngine
379        let mut mock_req = crate::protocol::UnifiedRequest::default();
380        mock_req.stream = request.stream;
381        mock_req.tools = request.tools.clone();
382        mock_req.messages = request.messages.clone();
383
384        let policy = crate::client::policy::PolicyEngine::new(&self.manifest);
385        policy.validate_capabilities(&mock_req)
386    }
387}