Skip to main content

ai_lib_core/client/
core.rs

//! Core AI client implementation.
//!
//! Manages protocol loading, transport, the processing pipeline, and resilience policies.
4
5use crate::client::types::{CallStats, ClientMetrics};
6use crate::protocol::ProtocolLoader;
7use crate::protocol::ProtocolManifest;
8use crate::{Error, ErrorContext, Result};
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
12use crate::pipeline::Pipeline;
13use crate::transport::HttpTransport;
14
15// Import submodules
16use crate::client::validation;
17
/// Unified AI client that works with any provider through protocol configuration.
pub struct AiClient {
    /// Loaded protocol manifest describing the active provider/model.
    pub manifest: ProtocolManifest,
    /// Shared HTTP transport used to talk to the provider.
    pub transport: Arc<HttpTransport>,
    /// Request/response processing pipeline built from the manifest.
    pub pipeline: Arc<Pipeline>,
    /// Loader used to resolve protocol manifests; reused when deriving fallback clients.
    pub loader: Arc<ProtocolLoader>,
    /// Fallback model identifiers ("provider/model"), tried in order after the primary fails.
    pub(crate) fallbacks: Vec<String>,
    /// Model-id portion of the "provider/model" identifier.
    pub(crate) model_id: String,
    /// Strict-streaming flag forwarded to manifest validation.
    // NOTE(review): exact strictness semantics live in `validation::validate_manifest` — confirm there.
    pub(crate) strict_streaming: bool,
    /// Sink that receives user feedback events via `report_feedback`.
    pub(crate) feedback: Arc<dyn crate::feedback::FeedbackSink>,
    /// Optional semaphore bounding concurrent in-flight requests.
    pub(crate) inflight: Option<Arc<tokio::sync::Semaphore>>,
    /// Configured permit count for `inflight`; used to compute usage snapshots.
    pub(crate) max_inflight: Option<usize>,
    /// Optional timeout applied around each individual execution attempt.
    pub(crate) attempt_timeout: Option<std::time::Duration>,
    /// Cumulative number of calls started (incremented by `record_request`).
    pub(crate) total_requests: AtomicU64,
    /// Cumulative number of calls that completed successfully.
    pub(crate) successful_requests: AtomicU64,
    /// Cumulative token count extracted from provider usage payloads.
    pub(crate) total_tokens: AtomicU64,
}
35
/// Unified response format.
#[derive(Debug, Default)]
pub struct UnifiedResponse {
    /// Text content of the model reply.
    pub content: String,
    /// Tool calls requested by the model; empty when none were made.
    pub tool_calls: Vec<crate::types::tool::ToolCall>,
    /// Raw provider usage payload, when the provider reported one.
    pub usage: Option<serde_json::Value>,
}
43
44impl AiClient {
45    /// Returns a snapshot of cumulative client metrics.
46    ///
47    /// Useful for monitoring, routing decisions, and observability.
48    pub fn metrics(&self) -> ClientMetrics {
49        ClientMetrics {
50            total_requests: self.total_requests.load(Ordering::Relaxed),
51            successful_requests: self.successful_requests.load(Ordering::Relaxed),
52            total_tokens: self.total_tokens.load(Ordering::Relaxed),
53        }
54    }
55
56    pub(crate) fn record_success(&self, stats: &CallStats) {
57        self.successful_requests.fetch_add(1, Ordering::Relaxed);
58        if let Some(tokens) = Self::extract_total_tokens(&stats.usage) {
59            self.total_tokens.fetch_add(tokens, Ordering::Relaxed);
60        }
61    }
62
    /// Increments the total-request counter; called once per `call_model_with_stats` invocation.
    pub(crate) fn record_request(&self) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
    }
66
67    fn extract_total_tokens(usage: &Option<serde_json::Value>) -> Option<u64> {
68        let u = usage.as_ref()?;
69        u.get("total_tokens").and_then(|v| v.as_u64()).or_else(|| {
70            u.get("usage")
71                .and_then(|nested| nested.get("total_tokens"))
72                .and_then(|v| v.as_u64())
73        })
74    }
75
76    /// Snapshot current runtime signals (facts only) for application-layer orchestration.
77    pub async fn signals(&self) -> crate::client::signals::SignalsSnapshot {
78        let inflight = self.inflight.as_ref().and_then(|sem| {
79            let max = self.max_inflight?;
80            let available = sem.available_permits();
81            let in_use = max.saturating_sub(available);
82            Some(crate::client::signals::InflightSnapshot {
83                max,
84                available,
85                in_use,
86            })
87        });
88
89        crate::client::signals::SignalsSnapshot { inflight }
90    }
91
92    /// Create a new client for a specific model.
93    ///
94    /// The model identifier should be in the format "provider/model-name"
95    /// (e.g., "anthropic/claude-3-5-sonnet")
96    pub async fn new(model: &str) -> Result<Self> {
97        crate::client::builder::AiClientBuilder::new()
98            .build(model)
99            .await
100    }
101
102    /// Create a new client instance for another model, reusing loader + shared runtime knobs
103    /// (feedback, inflight) for consistent behavior.
104    pub(crate) async fn with_model(&self, model: &str) -> Result<Self> {
105        // model is in form "provider/model-id"
106        let parts: Vec<&str> = model.split('/').collect();
107        let model_id = parts
108            .get(1)
109            .map(|s| s.to_string())
110            .unwrap_or_else(|| model.to_string());
111
112        let manifest = self.loader.load_model(model).await?;
113        validation::validate_manifest(&manifest, self.strict_streaming)?;
114
115        let transport = Arc::new(crate::transport::HttpTransport::new(&manifest, &model_id)?);
116        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
117
118        Ok(AiClient {
119            manifest,
120            transport,
121            pipeline,
122            loader: self.loader.clone(),
123            fallbacks: Vec::new(),
124            model_id,
125            strict_streaming: self.strict_streaming,
126            feedback: self.feedback.clone(),
127            inflight: self.inflight.clone(),
128            max_inflight: self.max_inflight,
129            attempt_timeout: self.attempt_timeout,
130            total_requests: AtomicU64::new(0),
131            successful_requests: AtomicU64::new(0),
132            total_tokens: AtomicU64::new(0),
133        })
134    }
135
    /// Create a chat request builder.
    ///
    /// The returned builder borrows this client and executes requests against it.
    pub fn chat(&self) -> crate::client::chat::ChatRequestBuilder<'_> {
        crate::client::chat::ChatRequestBuilder::new(self)
    }
140
141    /// Execute multiple chat requests concurrently with an optional concurrency limit.
142    ///
143    /// Notes:
144    /// - Results preserve input order.
145    /// - Internally uses the same "streaming → UnifiedResponse" path for consistency.
146    pub async fn chat_batch(
147        &self,
148        requests: Vec<crate::client::chat::ChatBatchRequest>,
149        concurrency_limit: Option<usize>,
150    ) -> Vec<Result<UnifiedResponse>> {
151        use futures::StreamExt;
152
153        let n = requests.len();
154        if n == 0 {
155            return Vec::new();
156        }
157
158        let limit = concurrency_limit.unwrap_or(10).max(1);
159        let mut out: Vec<Option<Result<UnifiedResponse>>> = (0..n).map(|_| None).collect();
160
161        let results: Vec<(usize, Result<UnifiedResponse>)> =
162            futures::stream::iter(requests.into_iter().enumerate())
163                .map(|(idx, req)| async move {
164                    let mut b = self.chat().messages(req.messages).stream();
165                    if let Some(t) = req.temperature {
166                        b = b.temperature(t);
167                    }
168                    if let Some(m) = req.max_tokens {
169                        b = b.max_tokens(m);
170                    }
171                    if let Some(tools) = req.tools {
172                        b = b.tools(tools);
173                    }
174                    if let Some(tc) = req.tool_choice {
175                        b = b.tool_choice(tc);
176                    }
177                    let r = b.execute().await;
178                    (idx, r)
179                })
180                .buffer_unordered(limit)
181                .collect()
182                .await;
183
184        for (idx, r) in results {
185            out[idx] = Some(r);
186        }
187
188        out.into_iter()
189            .map(|o| {
190                o.unwrap_or_else(|| {
191                    Err(Error::runtime_with_context(
192                        "batch result missing",
193                        ErrorContext::new().with_source("batch_executor"),
194                    ))
195                })
196            })
197            .collect()
198    }
199
200    /// Smart batch execution with a conservative, developer-friendly default heuristic.
201    ///
202    /// - For very small batches, run sequentially to reduce overhead.
203    /// - For larger batches, run with a bounded concurrency.
204    ///
205    /// You can override the chosen concurrency via env:
206    /// - `AI_LIB_BATCH_CONCURRENCY`
207    pub async fn chat_batch_smart(
208        &self,
209        requests: Vec<crate::client::chat::ChatBatchRequest>,
210    ) -> Vec<Result<UnifiedResponse>> {
211        let n = requests.len();
212        if n == 0 {
213            return Vec::new();
214        }
215
216        let env_override = std::env::var("AI_LIB_BATCH_CONCURRENCY")
217            .ok()
218            .and_then(|s| s.parse::<usize>().ok())
219            .filter(|v| *v > 0);
220
221        let chosen = env_override.unwrap_or({
222            if n <= 3 {
223                1
224            } else if n <= 10 {
225                5
226            } else {
227                10
228            }
229        });
230
231        self.chat_batch(requests, Some(chosen)).await
232    }
233
    /// Report user feedback (optional). This delegates to the injected `FeedbackSink`.
    ///
    /// # Errors
    /// Propagates whatever error the configured sink returns.
    pub async fn report_feedback(&self, event: crate::feedback::FeedbackEvent) -> Result<()> {
        self.feedback.report(event).await
    }
238
239    /// Unified entry point for calling a model.
240    /// Handles text, streaming, and error fallback automatically.
241    pub async fn call_model(
242        &self,
243        request: crate::protocol::UnifiedRequest,
244    ) -> Result<UnifiedResponse> {
245        Ok(self.call_model_with_stats(request).await?.0)
246    }
247
    /// Call a model and also return per-call stats (latency, retries, request ids, endpoint, usage, etc.).
    ///
    /// This is intended for higher-level model selection and observability.
    ///
    /// Flow: count the request, build fallback clients, then try each candidate
    /// (primary first) — validating capabilities, consulting runtime signals, and
    /// executing under the retry policy — returning the first success.
    pub async fn call_model_with_stats(
        &self,
        request: crate::protocol::UnifiedRequest,
    ) -> Result<(UnifiedResponse, CallStats)> {
        // Every call is counted, whether or not it ultimately succeeds.
        self.record_request();

        // v0.5.0: The resilience logic is now delegated to the "Resilience Layer" (Pipeline Operators).
        // This core loop is now significantly simpler: it just tries the primary client.
        // If advanced resilience (multi-candidate fallback, complex retries) is needed,
        // it should be configured via the `Pipeline` or `PolicyEngine` which now acts as an operator.

        // Note: For v0.5.0 migration, we preserve the basic fallback iteration here
        // until the `Pipeline` fully absorbs "Client Switching" logic.
        // However, the explicit *retry* loop inside each candidate is now conceptually
        // part of `execute_once_with_stats` (which will eventually use RetryOperator).

        let mut last_err: Option<Error> = None;

        // Build fallback clients first (async).
        // NOTE(review): construction is eager — fallback manifests are loaded even
        // when the primary succeeds, and candidates that fail to build are silently
        // skipped. Consider lazy construction when `FallbackOperator` lands (v0.6.0+).
        let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(self.fallbacks.len());
        for model in &self.fallbacks {
            if let Ok(c) = self.with_model(model).await {
                fallback_clients.push(c);
            }
        }

        // Iterate candidates: primary first, then fallbacks.
        for (candidate_idx, client) in std::iter::once(self)
            .chain(fallback_clients.iter())
            .enumerate()
        {
            // True while at least one more candidate remains after this one.
            let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
            let policy = crate::client::policy::PolicyEngine::new(&client.manifest);

            // 1. Capability validation: move to the next candidate (or fail fast
            //    when none remain) if this manifest cannot satisfy the request.
            if let Err(e) = policy.validate_capabilities(&request) {
                if has_fallback {
                    last_err = Some(e);
                    continue; // Fallback to next candidate
                } else {
                    return Err(e); // No more fallbacks, fail fast
                }
            }

            // 2. Pre-decision based on runtime signals (e.g. inflight saturation).
            if let Some(crate::client::policy::Decision::Fallback) =
                policy.pre_decide(&sig, has_fallback)
            {
                last_err = Some(Error::runtime_with_context(
                    "skipped candidate due to signals",
                    ErrorContext::new().with_source("policy_engine"),
                ));
                continue;
            }

            // Fallback candidates target their own model id; the primary keeps
            // whatever the caller put in the request.
            let mut req = request.clone();
            if candidate_idx > 0 {
                req.model = client.model_id.clone();
            }

            // 3. Execution with Retry Policy.
            // The `execute_with_retry` helper encapsulates the retry loop,
            // paving the way for `RetryOperator` migration.
            match client.execute_with_retry(&req, &policy, has_fallback).await {
                Ok((resp, stats)) => {
                    client.record_success(&stats);
                    return Ok((resp, stats));
                }
                Err(e) => {
                    // Retries were exhausted or the policy said Fallback/Fail.
                    last_err = Some(e);
                    // If policy said Fallback, continue loop.
                    // If policy said Fail, strictly we should stop, but current logic implies
                    // the loop itself is the "Fallback mechanism".
                    if !has_fallback {
                        return Err(last_err.unwrap());
                    }
                }
            }
        }

        // All candidates exhausted without success.
        Err(last_err.unwrap_or_else(|| {
            Error::runtime_with_context(
                "all attempts failed",
                ErrorContext::new().with_source("retry_policy"),
            )
        }))
    }
341
342    /// Internal helper to execute with retry policy.
343    /// In future versions, this Logic moves entirely into `RetryOperator`.
344    async fn execute_with_retry(
345        &self,
346        request: &crate::protocol::UnifiedRequest,
347        policy: &crate::client::policy::PolicyEngine,
348        has_fallback: bool,
349    ) -> Result<(UnifiedResponse, CallStats)> {
350        let mut attempt: u32 = 0;
351        let mut retry_count: u32 = 0;
352
353        loop {
354            let attempt_fut = self.execute_once_with_stats(request);
355            let attempt_res = if let Some(t) = self.attempt_timeout {
356                match tokio::time::timeout(t, attempt_fut).await {
357                    Ok(r) => r,
358                    Err(_) => Err(Error::runtime_with_context(
359                        "attempt timeout",
360                        ErrorContext::new().with_source("timeout_policy"),
361                    )),
362                }
363            } else {
364                attempt_fut.await
365            };
366
367            match attempt_res {
368                Ok((resp, mut stats)) => {
369                    stats.retry_count = retry_count;
370                    return Ok((resp, stats));
371                }
372                Err(e) => {
373                    let decision = policy.decide(&e, attempt, has_fallback)?;
374
375                    match decision {
376                        crate::client::policy::Decision::Retry { delay } => {
377                            retry_count = retry_count.saturating_add(1);
378                            if delay.as_millis() > 0 {
379                                tokio::time::sleep(delay).await;
380                            }
381                            attempt = attempt.saturating_add(1);
382                            continue;
383                        }
384                        crate::client::policy::Decision::Fallback => return Err(e),
385                        crate::client::policy::Decision::Fail => return Err(e),
386                    }
387                }
388            }
389        }
390    }
391
392    /// Validate request capabilities.
393    pub fn validate_request(
394        &self,
395        request: &crate::client::chat::ChatRequestBuilder<'_>,
396    ) -> Result<()> {
397        // Build a minimal UnifiedRequest to check capabilities via PolicyEngine
398        let mock_req = crate::protocol::UnifiedRequest {
399            stream: request.stream,
400            tools: request.tools.clone(),
401            messages: request.messages.clone(),
402            response_format: request.response_format.clone(),
403            ..Default::default()
404        };
405
406        let policy = crate::client::policy::PolicyEngine::new(&self.manifest);
407        policy.validate_capabilities(&mock_req)
408    }
409}