ai_lib_rust/client/core.rs

//! Core AI client implementation

use crate::client::types::CallStats;
use crate::protocol::ProtocolLoader;
use crate::protocol::ProtocolManifest;
use crate::{Error, ErrorContext, Result};
use std::sync::Arc;

use crate::pipeline::Pipeline;
use crate::transport::HttpTransport;

// Import submodules
use crate::client::validation;

/// Unified AI client that works with any provider through protocol configuration.
pub struct AiClient {
    pub manifest: ProtocolManifest,
    pub transport: Arc<HttpTransport>,
    pub pipeline: Arc<Pipeline>,
    pub loader: Arc<ProtocolLoader>,
    pub(crate) fallbacks: Vec<String>,
    pub(crate) model_id: String,
    pub(crate) strict_streaming: bool,
    pub(crate) feedback: Arc<dyn crate::telemetry::FeedbackSink>,
    pub(crate) inflight: Option<Arc<tokio::sync::Semaphore>>,
    pub(crate) max_inflight: Option<usize>,
    pub(crate) attempt_timeout: Option<std::time::Duration>,
    pub(crate) breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
    pub(crate) rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
}

/// Unified response format.
#[derive(Debug, Default)]
pub struct UnifiedResponse {
    pub content: String,
    pub tool_calls: Vec<crate::types::tool::ToolCall>,
    pub usage: Option<serde_json::Value>,
}

impl AiClient {
    /// Snapshot current runtime signals (facts only) for application-layer orchestration.
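    ///
    /// # Example
    ///
    /// A minimal sketch of reading the snapshot (assumes `AiClient` is
    /// re-exported at the crate root and that the snapshot fields are
    /// publicly readable as constructed in this module):
    ///
    /// ```rust,no_run
    /// # async fn demo(client: &ai_lib_rust::AiClient) {
    /// let sig = client.signals().await;
    /// if let Some(inflight) = sig.inflight {
    ///     // `in_use` is derived as `max - available` (saturating).
    ///     println!("inflight: {}/{} in use", inflight.in_use, inflight.max);
    /// }
    /// # }
    /// ```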
    pub async fn signals(&self) -> crate::client::signals::SignalsSnapshot {
        let inflight = self.inflight.as_ref().and_then(|sem| {
            let max = self.max_inflight?;
            let available = sem.available_permits();
            let in_use = max.saturating_sub(available);
            Some(crate::client::signals::InflightSnapshot {
                max,
                available,
                in_use,
            })
        });

        let rate_limiter = match &self.rate_limiter {
            Some(rl) => Some(rl.snapshot().await),
            None => None,
        };

        let circuit_breaker = match &self.breaker {
            Some(cb) => Some(cb.snapshot()),
            None => None,
        };

        crate::client::signals::SignalsSnapshot {
            inflight,
            rate_limiter,
            circuit_breaker,
        }
    }

    /// Create a new client for a specific model.
    ///
    /// The model identifier should be in the format "provider/model-name"
    /// (e.g., "anthropic/claude-3-5-sonnet").
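    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `AiClient` and `Result` are re-exported at
    /// the crate root, with the crate named after this repository):
    ///
    /// ```rust,no_run
    /// # async fn demo() -> ai_lib_rust::Result<()> {
    /// let client = ai_lib_rust::AiClient::new("anthropic/claude-3-5-sonnet").await?;
    /// # Ok(())
    /// # }
    /// ```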
    pub async fn new(model: &str) -> Result<Self> {
        crate::client::builder::AiClientBuilder::new()
            .build(model)
            .await
    }

    /// Create a new client instance for another model, reusing the loader and shared runtime
    /// knobs (feedback, inflight, breaker, rate limiter) for consistent behavior.
    pub(crate) async fn with_model(&self, model: &str) -> Result<Self> {
        // `model` is in the form "provider/model-id".
        let parts: Vec<&str> = model.split('/').collect();
        let model_id = parts
            .get(1)
            .map(|s| s.to_string())
            .unwrap_or_else(|| model.to_string());

        let manifest = self.loader.load_model(model).await?;
        validation::validate_manifest(&manifest, self.strict_streaming)?;

        let transport = Arc::new(crate::transport::HttpTransport::new(&manifest, &model_id)?);
        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);

        Ok(AiClient {
            manifest,
            transport,
            pipeline,
            loader: self.loader.clone(),
            fallbacks: Vec::new(),
            model_id,
            strict_streaming: self.strict_streaming,
            feedback: self.feedback.clone(),
            inflight: self.inflight.clone(),
            max_inflight: self.max_inflight,
            attempt_timeout: self.attempt_timeout,
            breaker: self.breaker.clone(),
            rate_limiter: self.rate_limiter.clone(),
        })
    }

    /// Create a chat request builder.
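    ///
    /// # Example
    ///
    /// A minimal sketch of the builder flow (not compiled here: message
    /// construction depends on the crate's message types, but the methods
    /// shown are the ones used elsewhere in this module):
    ///
    /// ```rust,ignore
    /// let resp = client
    ///     .chat()
    ///     .messages(messages)
    ///     .temperature(0.7)
    ///     .max_tokens(512)
    ///     .execute()
    ///     .await?;
    /// println!("{}", resp.content);
    /// ```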
    pub fn chat(&self) -> crate::client::chat::ChatRequestBuilder<'_> {
        crate::client::chat::ChatRequestBuilder::new(self)
    }

    /// Execute multiple chat requests concurrently with an optional concurrency limit.
    ///
    /// Notes:
    /// - Results preserve input order.
    /// - Internally uses the same "streaming → UnifiedResponse" path for consistency.
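    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `AiClient` is re-exported at the crate root
    /// and the `ChatBatchRequest` values are built elsewhere):
    ///
    /// ```rust,no_run
    /// # async fn demo(
    /// #     client: &ai_lib_rust::AiClient,
    /// #     requests: Vec<ai_lib_rust::client::chat::ChatBatchRequest>,
    /// # ) {
    /// // Run at most 4 requests at a time; results come back in input order.
    /// let results = client.chat_batch(requests, Some(4)).await;
    /// for (i, r) in results.iter().enumerate() {
    ///     match r {
    ///         Ok(resp) => println!("request {i}: {}", resp.content),
    ///         Err(e) => eprintln!("request {i} failed: {e}"),
    ///     }
    /// }
    /// # }
    /// ```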
    pub async fn chat_batch(
        &self,
        requests: Vec<crate::client::chat::ChatBatchRequest>,
        concurrency_limit: Option<usize>,
    ) -> Vec<Result<UnifiedResponse>> {
        use futures::StreamExt;

        let n = requests.len();
        if n == 0 {
            return Vec::new();
        }

        let limit = concurrency_limit.unwrap_or(10).max(1);
        let mut out: Vec<Option<Result<UnifiedResponse>>> = (0..n).map(|_| None).collect();

        let results: Vec<(usize, Result<UnifiedResponse>)> =
            futures::stream::iter(requests.into_iter().enumerate())
                .map(|(idx, req)| async move {
                    let mut b = self.chat().messages(req.messages).stream();
                    if let Some(t) = req.temperature {
                        b = b.temperature(t);
                    }
                    if let Some(m) = req.max_tokens {
                        b = b.max_tokens(m);
                    }
                    if let Some(tools) = req.tools {
                        b = b.tools(tools);
                    }
                    if let Some(tc) = req.tool_choice {
                        b = b.tool_choice(tc);
                    }
                    let r = b.execute().await;
                    (idx, r)
                })
                .buffer_unordered(limit)
                .collect()
                .await;

        for (idx, r) in results {
            out[idx] = Some(r);
        }

        out.into_iter()
            .map(|o| {
                o.unwrap_or_else(|| {
                    Err(Error::runtime_with_context(
                        "batch result missing",
                        ErrorContext::new().with_source("batch_executor"),
                    ))
                })
            })
            .collect()
    }

    /// Smart batch execution with a conservative, developer-friendly default heuristic.
    ///
    /// - For very small batches, run sequentially to reduce overhead.
    /// - For larger batches, run with bounded concurrency.
    ///
    /// You can override the chosen concurrency via the
    /// `AI_LIB_BATCH_CONCURRENCY` environment variable.
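    ///
    /// # Example
    ///
    /// A minimal sketch showing the environment override (the variable name
    /// is taken from this module; request construction is assumed elsewhere):
    ///
    /// ```rust,no_run
    /// # async fn demo(
    /// #     client: &ai_lib_rust::AiClient,
    /// #     requests: Vec<ai_lib_rust::client::chat::ChatBatchRequest>,
    /// # ) {
    /// // Force a concurrency of 8 regardless of batch size.
    /// std::env::set_var("AI_LIB_BATCH_CONCURRENCY", "8");
    /// let results = client.chat_batch_smart(requests).await;
    /// println!("completed {} requests", results.len());
    /// # }
    /// ```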
    pub async fn chat_batch_smart(
        &self,
        requests: Vec<crate::client::chat::ChatBatchRequest>,
    ) -> Vec<Result<UnifiedResponse>> {
        let n = requests.len();
        if n == 0 {
            return Vec::new();
        }

        let env_override = std::env::var("AI_LIB_BATCH_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .filter(|v| *v > 0);

        let chosen = env_override.unwrap_or_else(|| {
            if n <= 3 {
                1
            } else if n <= 10 {
                5
            } else {
                10
            }
        });

        self.chat_batch(requests, Some(chosen)).await
    }

    /// Report user feedback (optional). This delegates to the injected `FeedbackSink`.
    pub async fn report_feedback(&self, event: crate::telemetry::FeedbackEvent) -> Result<()> {
        self.feedback.report(event).await
    }

    /// Update rate limiter state from response headers using protocol-mapped names.
    ///
    /// This method is public for testing purposes.
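    ///
    /// # Example
    ///
    /// A minimal sketch; the header name below is hypothetical, since the
    /// real names come from the protocol manifest's mapping:
    ///
    /// ```rust,no_run
    /// # async fn demo(client: &ai_lib_rust::AiClient) {
    /// let mut headers = reqwest::header::HeaderMap::new();
    /// // Hypothetical header name; substitute the manifest-mapped one.
    /// headers.insert("x-ratelimit-remaining", "99".parse().unwrap());
    /// client.update_rate_limits(&headers).await;
    /// # }
    /// ```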
    pub async fn update_rate_limits(&self, headers: &reqwest::header::HeaderMap) {
        use crate::client::preflight::PreflightExt;
        PreflightExt::update_rate_limits(self, headers).await;
    }

    /// Unified entry point for calling a model.
    /// Handles text, streaming, and error fallback automatically.
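    ///
    /// # Example
    ///
    /// A minimal sketch (assumes the module paths are public as written in
    /// this file; `UnifiedRequest` implements `Default`, so only the fields
    /// you need must be set):
    ///
    /// ```rust,no_run
    /// # async fn demo(client: &ai_lib_rust::AiClient) -> ai_lib_rust::Result<()> {
    /// let mut request = ai_lib_rust::protocol::UnifiedRequest::default();
    /// // request.messages = ...; // fill in with the crate's message types
    /// let response = client.call_model(request).await?;
    /// println!("{}", response.content);
    /// # Ok(())
    /// # }
    /// ```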
    pub async fn call_model(
        &self,
        request: crate::protocol::UnifiedRequest,
    ) -> Result<UnifiedResponse> {
        Ok(self.call_model_with_stats(request).await?.0)
    }

    /// Call a model and also return per-call stats (latency, retries, request ids, endpoint, usage, etc.).
    ///
    /// This is intended for higher-level model selection and observability.
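    ///
    /// # Example
    ///
    /// A minimal sketch reading a stat field that this module sets
    /// (`retry_count`); request construction is assumed elsewhere:
    ///
    /// ```rust,no_run
    /// # async fn demo(
    /// #     client: &ai_lib_rust::AiClient,
    /// #     request: ai_lib_rust::protocol::UnifiedRequest,
    /// # ) -> ai_lib_rust::Result<()> {
    /// let (response, stats) = client.call_model_with_stats(request).await?;
    /// println!("retries: {}", stats.retry_count);
    /// println!("{}", response.content);
    /// # Ok(())
    /// # }
    /// ```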
    pub async fn call_model_with_stats(
        &self,
        request: crate::protocol::UnifiedRequest,
    ) -> Result<(UnifiedResponse, CallStats)> {
        // v0.5.0: resilience logic is now delegated to the "Resilience Layer" (pipeline operators).
        // This core loop is significantly simpler: it just tries the primary client.
        // If advanced resilience (multi-candidate fallback, complex retries) is needed,
        // it should be configured via the `Pipeline` or `PolicyEngine`, which now acts as an operator.

        // Note: for the v0.5.0 migration we preserve the basic fallback iteration here
        // until the `Pipeline` fully absorbs the "client switching" logic.
        // However, the explicit *retry* loop inside each candidate is now conceptually
        // part of `execute_once_with_stats` (which will eventually use `RetryOperator`).

        let mut last_err: Option<Error> = None;

        // Build fallback clients first (async).
        // In v0.6.0+, this will be replaced by `FallbackOperator` inside the pipeline.
        let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(self.fallbacks.len());
        for model in &self.fallbacks {
            if let Ok(c) = self.with_model(model).await {
                fallback_clients.push(c);
            }
        }

        // Iterate candidates: primary first, then fallbacks.
        for (candidate_idx, client) in std::iter::once(self)
            .chain(fallback_clients.iter())
            .enumerate()
        {
            let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
            let policy = crate::client::policy::PolicyEngine::new(&client.manifest);

            // 1. Validation check
            if let Err(e) = policy.validate_capabilities(&request) {
                if has_fallback {
                    last_err = Some(e);
                    continue; // Fall back to the next candidate
                } else {
                    return Err(e); // No more fallbacks, fail fast
                }
            }

            // 2. Pre-decision based on signals
            let sig = client.signals().await;
            if let Some(crate::client::policy::Decision::Fallback) =
                policy.pre_decide(&sig, has_fallback)
            {
                last_err = Some(Error::runtime_with_context(
                    "skipped candidate due to signals",
                    ErrorContext::new().with_source("policy_engine"),
                ));
                continue;
            }

            let mut req = request.clone();
            req.model = client.model_id.clone();

            // 3. Execution with retry policy
            // The `execute_with_retry` helper encapsulates the retry loop,
            // paving the way for the `RetryOperator` migration.
            match client.execute_with_retry(&req, &policy, has_fallback).await {
                Ok(res) => return Ok(res),
                Err(e) => {
                    // If we get here, retries were exhausted or the policy said Fallback/Fail.
                    last_err = Some(e);
                    // If the policy said Fallback, continue the loop.
                    // If the policy said Fail, strictly we should stop here, but in the
                    // current logic the candidate loop itself is the fallback mechanism.
                    if !has_fallback {
                        return Err(last_err.unwrap());
                    }
                }
            }
        }

        Err(last_err.unwrap_or_else(|| {
            Error::runtime_with_context(
                "all attempts failed",
                ErrorContext::new().with_source("retry_policy"),
            )
        }))
    }

    /// Internal helper to execute with the retry policy.
    /// In future versions, this logic moves entirely into `RetryOperator`.
    async fn execute_with_retry(
        &self,
        request: &crate::protocol::UnifiedRequest,
        policy: &crate::client::policy::PolicyEngine,
        has_fallback: bool,
    ) -> Result<(UnifiedResponse, CallStats)> {
        let mut attempt: u32 = 0;
        let mut retry_count: u32 = 0;

        loop {
            let attempt_fut = self.execute_once_with_stats(request);
            let attempt_res = if let Some(t) = self.attempt_timeout {
                match tokio::time::timeout(t, attempt_fut).await {
                    Ok(r) => r,
                    Err(_) => Err(Error::runtime_with_context(
                        "attempt timeout",
                        ErrorContext::new().with_source("timeout_policy"),
                    )),
                }
            } else {
                attempt_fut.await
            };

            match attempt_res {
                Ok((resp, mut stats)) => {
                    stats.retry_count = retry_count;
                    return Ok((resp, stats));
                }
                Err(e) => {
                    let decision = policy.decide(&e, attempt, has_fallback)?;

                    match decision {
                        crate::client::policy::Decision::Retry { delay } => {
                            retry_count = retry_count.saturating_add(1);
                            if delay.as_millis() > 0 {
                                tokio::time::sleep(delay).await;
                            }
                            attempt = attempt.saturating_add(1);
                            continue;
                        }
                        crate::client::policy::Decision::Fallback => return Err(e),
                        crate::client::policy::Decision::Fail => return Err(e),
                    }
                }
            }
        }
    }

    /// Validate request capabilities.
    pub fn validate_request(
        &self,
        request: &crate::client::chat::ChatRequestBuilder<'_>,
    ) -> Result<()> {
        // Build a minimal UnifiedRequest to check capabilities via PolicyEngine.
        let mut mock_req = crate::protocol::UnifiedRequest::default();
        mock_req.stream = request.stream;
        mock_req.tools = request.tools.clone();
        mock_req.messages = request.messages.clone();

        let policy = crate::client::policy::PolicyEngine::new(&self.manifest);
        policy.validate_capabilities(&mock_req)
    }
}