1use crate::client::types::CallStats;
4use crate::protocol::ProtocolLoader;
5use crate::protocol::ProtocolManifest;
6use crate::{Error, ErrorContext, Result};
7use std::sync::Arc;
8
9use crate::pipeline::Pipeline;
10use crate::transport::HttpTransport;
11
12use crate::client::validation;
14
/// Client bound to a single model.
///
/// Holds the model's protocol manifest together with the transport, pipeline
/// and optional resilience components (in-flight limit, per-attempt timeout,
/// circuit breaker, rate limiter) used when executing calls.
pub struct AiClient {
    /// Protocol manifest loaded for this client's model.
    pub manifest: ProtocolManifest,
    /// HTTP transport; `with_model` constructs it from the manifest and `model_id`.
    pub transport: Arc<HttpTransport>,
    /// Pipeline; `with_model` constructs it from the manifest.
    pub pipeline: Arc<Pipeline>,
    /// Loader used to fetch manifests, e.g. when deriving a client for another model.
    pub loader: Arc<ProtocolLoader>,
    /// Model identifiers tried, in order, after this client's own model.
    pub(crate) fallbacks: Vec<String>,
    /// Model id written into each outgoing request's `model` field.
    pub(crate) model_id: String,
    /// Flag forwarded to manifest validation.
    pub(crate) strict_streaming: bool,
    /// Sink that receives events passed to `report_feedback`.
    pub(crate) feedback: Arc<dyn crate::telemetry::FeedbackSink>,
    /// Semaphore whose available permits are reported by `signals`.
    pub(crate) inflight: Option<Arc<tokio::sync::Semaphore>>,
    /// Configured maximum used with `inflight` to derive the in-use count.
    pub(crate) max_inflight: Option<usize>,
    /// Upper bound on the duration of a single attempt; `None` means no timeout.
    pub(crate) attempt_timeout: Option<std::time::Duration>,
    /// Optional circuit breaker; its snapshot is reported by `signals`.
    pub(crate) breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
    /// Optional rate limiter; its snapshot is reported by `signals`.
    pub(crate) rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
}
31
/// Response returned by the client's call and batch methods.
///
/// `Default` yields empty content, no tool calls and no usage data.
#[derive(Debug, Default)]
pub struct UnifiedResponse {
    /// Response text.
    // NOTE(review): presumably the accumulated assistant output — confirm
    // against the code that populates this struct.
    pub content: String,
    /// Tool calls carried by the response.
    pub tool_calls: Vec<crate::types::tool::ToolCall>,
    /// Optional usage information as raw JSON; its schema is not visible here.
    pub usage: Option<serde_json::Value>,
}
39
40impl AiClient {
41 pub async fn signals(&self) -> crate::client::signals::SignalsSnapshot {
43 let inflight = self.inflight.as_ref().and_then(|sem| {
44 let max = self.max_inflight?;
45 let available = sem.available_permits();
46 let in_use = max.saturating_sub(available);
47 Some(crate::client::signals::InflightSnapshot {
48 max,
49 available,
50 in_use,
51 })
52 });
53
54 let rate_limiter = match &self.rate_limiter {
55 Some(rl) => Some(rl.snapshot().await),
56 None => None,
57 };
58
59 let circuit_breaker = match &self.breaker {
60 Some(cb) => Some(cb.snapshot()),
61 None => None,
62 };
63
64 crate::client::signals::SignalsSnapshot {
65 inflight,
66 rate_limiter,
67 circuit_breaker,
68 }
69 }
70
71 pub async fn new(model: &str) -> Result<Self> {
76 crate::client::builder::AiClientBuilder::new()
77 .build(model)
78 .await
79 }
80
81 pub(crate) async fn with_model(&self, model: &str) -> Result<Self> {
84 let parts: Vec<&str> = model.split('/').collect();
86 let model_id = parts
87 .get(1)
88 .map(|s| s.to_string())
89 .unwrap_or_else(|| model.to_string());
90
91 let manifest = self.loader.load_model(model).await?;
92 validation::validate_manifest(&manifest, self.strict_streaming)?;
93
94 let transport = Arc::new(crate::transport::HttpTransport::new(&manifest, &model_id)?);
95 let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
96
97 Ok(AiClient {
98 manifest,
99 transport,
100 pipeline,
101 loader: self.loader.clone(),
102 fallbacks: Vec::new(),
103 model_id,
104 strict_streaming: self.strict_streaming,
105 feedback: self.feedback.clone(),
106 inflight: self.inflight.clone(),
107 max_inflight: self.max_inflight,
108 attempt_timeout: self.attempt_timeout,
109 breaker: self.breaker.clone(),
110 rate_limiter: self.rate_limiter.clone(),
111 })
112 }
113
114 pub fn chat(&self) -> crate::client::chat::ChatRequestBuilder<'_> {
116 crate::client::chat::ChatRequestBuilder::new(self)
117 }
118
119 pub async fn chat_batch(
125 &self,
126 requests: Vec<crate::client::chat::ChatBatchRequest>,
127 concurrency_limit: Option<usize>,
128 ) -> Vec<Result<UnifiedResponse>> {
129 use futures::StreamExt;
130
131 let n = requests.len();
132 if n == 0 {
133 return Vec::new();
134 }
135
136 let limit = concurrency_limit.unwrap_or(10).max(1);
137 let mut out: Vec<Option<Result<UnifiedResponse>>> = (0..n).map(|_| None).collect();
138
139 let results: Vec<(usize, Result<UnifiedResponse>)> =
140 futures::stream::iter(requests.into_iter().enumerate())
141 .map(|(idx, req)| async move {
142 let mut b = self.chat().messages(req.messages).stream();
143 if let Some(t) = req.temperature {
144 b = b.temperature(t);
145 }
146 if let Some(m) = req.max_tokens {
147 b = b.max_tokens(m);
148 }
149 if let Some(tools) = req.tools {
150 b = b.tools(tools);
151 }
152 if let Some(tc) = req.tool_choice {
153 b = b.tool_choice(tc);
154 }
155 let r = b.execute().await;
156 (idx, r)
157 })
158 .buffer_unordered(limit)
159 .collect()
160 .await;
161
162 for (idx, r) in results {
163 out[idx] = Some(r);
164 }
165
166 out.into_iter()
167 .map(|o| {
168 o.unwrap_or_else(|| {
169 Err(Error::runtime_with_context(
170 "batch result missing",
171 ErrorContext::new().with_source("batch_executor"),
172 ))
173 })
174 })
175 .collect()
176 }
177
178 pub async fn chat_batch_smart(
186 &self,
187 requests: Vec<crate::client::chat::ChatBatchRequest>,
188 ) -> Vec<Result<UnifiedResponse>> {
189 let n = requests.len();
190 if n == 0 {
191 return Vec::new();
192 }
193
194 let env_override = std::env::var("AI_LIB_BATCH_CONCURRENCY")
195 .ok()
196 .and_then(|s| s.parse::<usize>().ok())
197 .filter(|v| *v > 0);
198
199 let chosen = env_override.unwrap_or_else(|| {
200 if n <= 3 {
201 1
202 } else if n <= 10 {
203 5
204 } else {
205 10
206 }
207 });
208
209 self.chat_batch(requests, Some(chosen)).await
210 }
211
212 pub async fn report_feedback(&self, event: crate::telemetry::FeedbackEvent) -> Result<()> {
214 self.feedback.report(event).await
215 }
216
217 pub async fn update_rate_limits(&self, headers: &reqwest::header::HeaderMap) {
221 use crate::client::preflight::PreflightExt;
222 PreflightExt::update_rate_limits(self, headers).await;
223 }
224
225 pub async fn call_model(
228 &self,
229 request: crate::protocol::UnifiedRequest,
230 ) -> Result<UnifiedResponse> {
231 Ok(self.call_model_with_stats(request).await?.0)
232 }
233
234 pub async fn call_model_with_stats(
241 &self,
242 request: crate::protocol::UnifiedRequest,
243 ) -> Result<(UnifiedResponse, CallStats)> {
244 let mut last_err: Option<Error> = None;
255
256 let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(self.fallbacks.len());
259 for model in &self.fallbacks {
260 if let Ok(c) = self.with_model(model).await {
261 fallback_clients.push(c);
262 }
263 }
264
265 for (candidate_idx, client) in std::iter::once(self)
267 .chain(fallback_clients.iter())
268 .enumerate()
269 {
270 let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
271 let policy = crate::client::policy::PolicyEngine::new(&client.manifest);
272
273 if let Err(e) = policy.validate_capabilities(&request) {
275 if has_fallback {
276 last_err = Some(e);
277 continue; } else {
279 return Err(e); }
281 }
282
283 let sig = client.signals().await;
285 if let Some(crate::client::policy::Decision::Fallback) =
286 policy.pre_decide(&sig, has_fallback)
287 {
288 last_err = Some(Error::runtime_with_context(
289 "skipped candidate due to signals",
290 ErrorContext::new().with_source("policy_engine"),
291 ));
292 continue;
293 }
294
295 let mut req = request.clone();
296 req.model = client.model_id.clone();
297
298 match client.execute_with_retry(&req, &policy, has_fallback).await {
302 Ok(res) => return Ok(res),
303 Err(e) => {
304 last_err = Some(e);
306 if !has_fallback {
310 return Err(last_err.unwrap());
311 }
312 }
313 }
314 }
315
316 Err(last_err.unwrap_or_else(|| {
317 Error::runtime_with_context(
318 "all attempts failed",
319 ErrorContext::new().with_source("retry_policy"),
320 )
321 }))
322 }
323
324 async fn execute_with_retry(
327 &self,
328 request: &crate::protocol::UnifiedRequest,
329 policy: &crate::client::policy::PolicyEngine,
330 has_fallback: bool,
331 ) -> Result<(UnifiedResponse, CallStats)> {
332 let mut attempt: u32 = 0;
333 let mut retry_count: u32 = 0;
334
335 loop {
336 let attempt_fut = self.execute_once_with_stats(request);
337 let attempt_res = if let Some(t) = self.attempt_timeout {
338 match tokio::time::timeout(t, attempt_fut).await {
339 Ok(r) => r,
340 Err(_) => Err(Error::runtime_with_context(
341 "attempt timeout",
342 ErrorContext::new().with_source("timeout_policy"),
343 )),
344 }
345 } else {
346 attempt_fut.await
347 };
348
349 match attempt_res {
350 Ok((resp, mut stats)) => {
351 stats.retry_count = retry_count;
352 return Ok((resp, stats));
353 }
354 Err(e) => {
355 let decision = policy.decide(&e, attempt, has_fallback)?;
356
357 match decision {
358 crate::client::policy::Decision::Retry { delay } => {
359 retry_count = retry_count.saturating_add(1);
360 if delay.as_millis() > 0 {
361 tokio::time::sleep(delay).await;
362 }
363 attempt = attempt.saturating_add(1);
364 continue;
365 }
366 crate::client::policy::Decision::Fallback => return Err(e),
367 crate::client::policy::Decision::Fail => return Err(e),
368 }
369 }
370 }
371 }
372 }
373
374 pub fn validate_request(
376 &self,
377 request: &crate::client::chat::ChatRequestBuilder<'_>,
378 ) -> Result<()> {
379 let mut mock_req = crate::protocol::UnifiedRequest::default();
381 mock_req.stream = request.stream;
382 mock_req.tools = request.tools.clone();
383 mock_req.messages = request.messages.clone();
384
385 let policy = crate::client::policy::PolicyEngine::new(&self.manifest);
386 policy.validate_capabilities(&mock_req)
387 }
388}