1use crate::client::types::CallStats;
6use crate::protocol::ProtocolLoader;
7use crate::protocol::ProtocolManifest;
8use crate::{Error, ErrorContext, Result};
9use std::sync::Arc;
10
11use crate::pipeline::Pipeline;
12use crate::transport::HttpTransport;
13
14use crate::client::validation;
16
/// Client bound to a single primary model, with optional fallback models
/// and resilience controls (retry/timeout, circuit breaker, rate limiter,
/// in-flight concurrency cap).
///
/// Constructed via `crate::client::builder::AiClientBuilder`; per-model
/// derived copies are produced by `with_model`.
pub struct AiClient {
    /// Protocol manifest loaded for the current model.
    pub manifest: ProtocolManifest,
    /// HTTP transport built from the manifest and model id.
    pub transport: Arc<HttpTransport>,
    /// Request/response pipeline built from the manifest.
    pub pipeline: Arc<Pipeline>,
    /// Manifest loader, reused when deriving clients for other models.
    pub loader: Arc<ProtocolLoader>,
    /// Fallback model identifiers tried in order when the primary fails.
    pub(crate) fallbacks: Vec<String>,
    /// Model id sent on the wire: the second `/`-separated segment of the
    /// model string, or the whole string when it contains no `/`.
    pub(crate) model_id: String,
    /// Forwarded to manifest validation; presumably tightens streaming
    /// capability checks — confirm in `client::validation`.
    pub(crate) strict_streaming: bool,
    /// Sink that receives events from `report_feedback`.
    pub(crate) feedback: Arc<dyn crate::feedback::FeedbackSink>,
    /// Optional semaphore limiting concurrent calls; surfaced by `signals`.
    pub(crate) inflight: Option<Arc<tokio::sync::Semaphore>>,
    /// Total permits backing `inflight`; needed to compute in-use counts.
    pub(crate) max_inflight: Option<usize>,
    /// Per-attempt timeout applied inside the retry loop, when set.
    pub(crate) attempt_timeout: Option<std::time::Duration>,
    /// Optional circuit breaker, snapshotted by `signals`.
    pub(crate) breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
    /// Optional rate limiter, snapshotted (async) by `signals`.
    pub(crate) rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
}
33
/// Provider-agnostic chat result returned by `call_model` and the batch
/// helpers.
#[derive(Debug, Default)]
pub struct UnifiedResponse {
    /// Assistant message text content.
    pub content: String,
    /// Tool invocations requested by the model, if any.
    pub tool_calls: Vec<crate::types::tool::ToolCall>,
    /// Raw provider-reported usage payload, when present — presumably token
    /// counts; schema varies by provider, so it is kept as loose JSON.
    pub usage: Option<serde_json::Value>,
}
41
42impl AiClient {
43 pub async fn signals(&self) -> crate::client::signals::SignalsSnapshot {
45 let inflight = self.inflight.as_ref().and_then(|sem| {
46 let max = self.max_inflight?;
47 let available = sem.available_permits();
48 let in_use = max.saturating_sub(available);
49 Some(crate::client::signals::InflightSnapshot {
50 max,
51 available,
52 in_use,
53 })
54 });
55
56 let rate_limiter = match &self.rate_limiter {
57 Some(rl) => Some(rl.snapshot().await),
58 None => None,
59 };
60
61 let circuit_breaker = match &self.breaker {
62 Some(cb) => Some(cb.snapshot()),
63 None => None,
64 };
65
66 crate::client::signals::SignalsSnapshot {
67 inflight,
68 rate_limiter,
69 circuit_breaker,
70 }
71 }
72
73 pub async fn new(model: &str) -> Result<Self> {
78 crate::client::builder::AiClientBuilder::new()
79 .build(model)
80 .await
81 }
82
83 pub(crate) async fn with_model(&self, model: &str) -> Result<Self> {
86 let parts: Vec<&str> = model.split('/').collect();
88 let model_id = parts
89 .get(1)
90 .map(|s| s.to_string())
91 .unwrap_or_else(|| model.to_string());
92
93 let manifest = self.loader.load_model(model).await?;
94 validation::validate_manifest(&manifest, self.strict_streaming)?;
95
96 let transport = Arc::new(crate::transport::HttpTransport::new(&manifest, &model_id)?);
97 let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
98
99 Ok(AiClient {
100 manifest,
101 transport,
102 pipeline,
103 loader: self.loader.clone(),
104 fallbacks: Vec::new(),
105 model_id,
106 strict_streaming: self.strict_streaming,
107 feedback: self.feedback.clone(),
108 inflight: self.inflight.clone(),
109 max_inflight: self.max_inflight,
110 attempt_timeout: self.attempt_timeout,
111 breaker: self.breaker.clone(),
112 rate_limiter: self.rate_limiter.clone(),
113 })
114 }
115
    /// Starts a fluent chat request builder bound to this client.
    pub fn chat(&self) -> crate::client::chat::ChatRequestBuilder<'_> {
        crate::client::chat::ChatRequestBuilder::new(self)
    }
120
121 pub async fn chat_batch(
127 &self,
128 requests: Vec<crate::client::chat::ChatBatchRequest>,
129 concurrency_limit: Option<usize>,
130 ) -> Vec<Result<UnifiedResponse>> {
131 use futures::StreamExt;
132
133 let n = requests.len();
134 if n == 0 {
135 return Vec::new();
136 }
137
138 let limit = concurrency_limit.unwrap_or(10).max(1);
139 let mut out: Vec<Option<Result<UnifiedResponse>>> = (0..n).map(|_| None).collect();
140
141 let results: Vec<(usize, Result<UnifiedResponse>)> =
142 futures::stream::iter(requests.into_iter().enumerate())
143 .map(|(idx, req)| async move {
144 let mut b = self.chat().messages(req.messages).stream();
145 if let Some(t) = req.temperature {
146 b = b.temperature(t);
147 }
148 if let Some(m) = req.max_tokens {
149 b = b.max_tokens(m);
150 }
151 if let Some(tools) = req.tools {
152 b = b.tools(tools);
153 }
154 if let Some(tc) = req.tool_choice {
155 b = b.tool_choice(tc);
156 }
157 let r = b.execute().await;
158 (idx, r)
159 })
160 .buffer_unordered(limit)
161 .collect()
162 .await;
163
164 for (idx, r) in results {
165 out[idx] = Some(r);
166 }
167
168 out.into_iter()
169 .map(|o| {
170 o.unwrap_or_else(|| {
171 Err(Error::runtime_with_context(
172 "batch result missing",
173 ErrorContext::new().with_source("batch_executor"),
174 ))
175 })
176 })
177 .collect()
178 }
179
180 pub async fn chat_batch_smart(
188 &self,
189 requests: Vec<crate::client::chat::ChatBatchRequest>,
190 ) -> Vec<Result<UnifiedResponse>> {
191 let n = requests.len();
192 if n == 0 {
193 return Vec::new();
194 }
195
196 let env_override = std::env::var("AI_LIB_BATCH_CONCURRENCY")
197 .ok()
198 .and_then(|s| s.parse::<usize>().ok())
199 .filter(|v| *v > 0);
200
201 let chosen = env_override.unwrap_or_else(|| {
202 if n <= 3 {
203 1
204 } else if n <= 10 {
205 5
206 } else {
207 10
208 }
209 });
210
211 self.chat_batch(requests, Some(chosen)).await
212 }
213
    /// Forwards `event` to the configured feedback sink.
    ///
    /// # Errors
    /// Propagates any error returned by the sink's `report` implementation.
    pub async fn report_feedback(&self, event: crate::feedback::FeedbackEvent) -> Result<()> {
        self.feedback.report(event).await
    }
218
    /// Feeds provider response `headers` into the rate-limit tracking state.
    ///
    /// Delegates to this client's `PreflightExt` implementation; the trait is
    /// imported locally to keep it out of the module's public namespace.
    pub async fn update_rate_limits(&self, headers: &reqwest::header::HeaderMap) {
        use crate::client::preflight::PreflightExt;
        PreflightExt::update_rate_limits(self, headers).await;
    }
226
227 pub async fn call_model(
230 &self,
231 request: crate::protocol::UnifiedRequest,
232 ) -> Result<UnifiedResponse> {
233 Ok(self.call_model_with_stats(request).await?.0)
234 }
235
    /// Executes `request` against the primary model, falling back to each
    /// configured fallback model in order; returns the response together
    /// with per-call statistics.
    ///
    /// For each candidate: capabilities are validated against the request,
    /// resilience signals may advise skipping it, and otherwise
    /// `execute_with_retry` runs the attempt loop. The first success wins;
    /// otherwise the last observed error is returned.
    ///
    /// # Errors
    /// Returns the capability-validation or attempt error immediately when no
    /// further candidate remains, or the last recorded error once every
    /// candidate has been skipped or has failed.
    pub async fn call_model_with_stats(
        &self,
        request: crate::protocol::UnifiedRequest,
    ) -> Result<(UnifiedResponse, CallStats)> {
        let mut last_err: Option<Error> = None;

        // NOTE(review): fallback clients are built eagerly — one manifest
        // load per fallback — even when the primary candidate succeeds;
        // consider lazy construction if fallback lists grow. Models that
        // fail to load are silently dropped from the candidate list.
        let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(self.fallbacks.len());
        for model in &self.fallbacks {
            if let Ok(c) = self.with_model(model).await {
                fallback_clients.push(c);
            }
        }

        // Primary client first, then each successfully built fallback.
        for (candidate_idx, client) in std::iter::once(self)
            .chain(fallback_clients.iter())
            .enumerate()
        {
            // True while at least one more candidate follows this one.
            let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
            let policy = crate::client::policy::PolicyEngine::new(&client.manifest);

            // Capability mismatch: try the next candidate if any, else fail.
            if let Err(e) = policy.validate_capabilities(&request) {
                if has_fallback {
                    last_err = Some(e);
                    continue;
                } else {
                    return Err(e);
                }
            }

            // Resilience signals (in-flight, rate limiter, breaker) can veto
            // this candidate before any network work is attempted.
            let sig = client.signals().await;
            if let Some(crate::client::policy::Decision::Fallback) =
                policy.pre_decide(&sig, has_fallback)
            {
                last_err = Some(Error::runtime_with_context(
                    "skipped candidate due to signals",
                    ErrorContext::new().with_source("policy_engine"),
                ));
                continue;
            }

            // Retarget the request at this candidate's wire-level model id.
            let mut req = request.clone();
            req.model = client.model_id.clone();

            match client.execute_with_retry(&req, &policy, has_fallback).await {
                Ok(res) => return Ok(res),
                Err(e) => {
                    last_err = Some(e);
                    if !has_fallback {
                        // No candidates left: surface the error just stored.
                        return Err(last_err.unwrap());
                    }
                }
            }
        }

        // Reached only when every candidate was skipped or failed.
        Err(last_err.unwrap_or_else(|| {
            Error::runtime_with_context(
                "all attempts failed",
                ErrorContext::new().with_source("retry_policy"),
            )
        }))
    }
322
    /// Runs the attempt loop for a single candidate model.
    ///
    /// Each attempt is optionally bounded by `attempt_timeout`. After a
    /// failure the policy decides whether to retry (after an optional
    /// delay), fall back, or fail outright; on success the accumulated
    /// retry count is written into the returned `CallStats`.
    ///
    /// # Errors
    /// Propagates the last attempt error when the policy decides `Fallback`
    /// or `Fail`, and any error produced by the policy decision itself.
    async fn execute_with_retry(
        &self,
        request: &crate::protocol::UnifiedRequest,
        policy: &crate::client::policy::PolicyEngine,
        has_fallback: bool,
    ) -> Result<(UnifiedResponse, CallStats)> {
        // `attempt` feeds the policy's decision; `retry_count` is what the
        // caller sees in the stats. Saturating math guards against wrap.
        let mut attempt: u32 = 0;
        let mut retry_count: u32 = 0;

        loop {
            let attempt_fut = self.execute_once_with_stats(request);
            // Bound the attempt when a per-attempt timeout is configured;
            // a timeout is converted into a regular runtime error so the
            // policy can decide how to react to it.
            let attempt_res = if let Some(t) = self.attempt_timeout {
                match tokio::time::timeout(t, attempt_fut).await {
                    Ok(r) => r,
                    Err(_) => Err(Error::runtime_with_context(
                        "attempt timeout",
                        ErrorContext::new().with_source("timeout_policy"),
                    )),
                }
            } else {
                attempt_fut.await
            };

            match attempt_res {
                Ok((resp, mut stats)) => {
                    stats.retry_count = retry_count;
                    return Ok((resp, stats));
                }
                Err(e) => {
                    let decision = policy.decide(&e, attempt, has_fallback)?;

                    match decision {
                        crate::client::policy::Decision::Retry { delay } => {
                            retry_count = retry_count.saturating_add(1);
                            // Zero delay skips the sleep entirely.
                            if delay.as_millis() > 0 {
                                tokio::time::sleep(delay).await;
                            }
                            attempt = attempt.saturating_add(1);
                            continue;
                        }
                        // Both outcomes surface the error; the caller decides
                        // whether another candidate absorbs a `Fallback`.
                        crate::client::policy::Decision::Fallback => return Err(e),
                        crate::client::policy::Decision::Fail => return Err(e),
                    }
                }
            }
        }
    }
372
373 pub fn validate_request(
375 &self,
376 request: &crate::client::chat::ChatRequestBuilder<'_>,
377 ) -> Result<()> {
378 let mut mock_req = crate::protocol::UnifiedRequest::default();
380 mock_req.stream = request.stream;
381 mock_req.tools = request.tools.clone();
382 mock_req.messages = request.messages.clone();
383
384 let policy = crate::client::policy::PolicyEngine::new(&self.manifest);
385 policy.validate_capabilities(&mock_req)
386 }
387}