use crate::client::types::{CallStats, ClientMetrics};
use crate::protocol::ProtocolLoader;
use crate::protocol::ProtocolManifest;
use crate::{Error, ErrorContext, Result};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use crate::pipeline::Pipeline;
use crate::transport::HttpTransport;

use crate::client::validation;

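/// Unified client for calling AI models through a protocol manifest, HTTP
/// transport, and processing pipeline, with optional fallback models,
/// concurrency limiting, per-attempt timeouts, and usage metrics.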
pub struct AiClient {
    pub manifest: ProtocolManifest,
    pub transport: Arc<HttpTransport>,
    pub pipeline: Arc<Pipeline>,
    pub loader: Arc<ProtocolLoader>,
    pub(crate) fallbacks: Vec<String>,
    pub(crate) model_id: String,
    pub(crate) strict_streaming: bool,
    pub(crate) feedback: Arc<dyn crate::feedback::FeedbackSink>,
    pub(crate) inflight: Option<Arc<tokio::sync::Semaphore>>,
    pub(crate) max_inflight: Option<usize>,
    pub(crate) attempt_timeout: Option<std::time::Duration>,
    pub(crate) total_requests: AtomicU64,
    pub(crate) successful_requests: AtomicU64,
    pub(crate) total_tokens: AtomicU64,
}

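/// Provider-agnostic chat result: assistant text, any tool calls, and the raw
/// usage payload (if the provider returned one).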
#[derive(Debug, Default)]
pub struct UnifiedResponse {
    pub content: String,
    pub tool_calls: Vec<crate::types::tool::ToolCall>,
    pub usage: Option<serde_json::Value>,
}

impl AiClient {
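    /// Returns a point-in-time snapshot of this client's request counters.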
    pub fn metrics(&self) -> ClientMetrics {
        ClientMetrics {
            total_requests: self.total_requests.load(Ordering::Relaxed),
            successful_requests: self.successful_requests.load(Ordering::Relaxed),
            total_tokens: self.total_tokens.load(Ordering::Relaxed),
        }
    }

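    /// Records a successful call and accumulates its token usage, if present.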
    pub(crate) fn record_success(&self, stats: &CallStats) {
        self.successful_requests.fetch_add(1, Ordering::Relaxed);
        if let Some(tokens) = Self::extract_total_tokens(&stats.usage) {
            self.total_tokens.fetch_add(tokens, Ordering::Relaxed);
        }
    }

    pub(crate) fn record_request(&self) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
    }

    fn extract_total_tokens(usage: &Option<serde_json::Value>) -> Option<u64> {
        let u = usage.as_ref()?;
        u.get("total_tokens").and_then(|v| v.as_u64()).or_else(|| {
            u.get("usage")
                .and_then(|nested| nested.get("total_tokens"))
                .and_then(|v| v.as_u64())
        })
    }

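    /// Returns a snapshot of runtime signals; in-flight usage is reported only
    /// when a concurrency limit is configured.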
    pub async fn signals(&self) -> crate::client::signals::SignalsSnapshot {
        let inflight = self.inflight.as_ref().and_then(|sem| {
            let max = self.max_inflight?;
            let available = sem.available_permits();
            let in_use = max.saturating_sub(available);
            Some(crate::client::signals::InflightSnapshot {
                max,
                available,
                in_use,
            })
        });

        crate::client::signals::SignalsSnapshot { inflight }
    }

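    /// Builds a client for the given model using the default builder settings.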
    pub async fn new(model: &str) -> Result<Self> {
        crate::client::builder::AiClientBuilder::new()
            .build(model)
            .await
    }

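    /// Builds a sibling client for another `provider/model` string, reusing this
    /// client's loader, feedback sink, concurrency limit, and timeout settings,
    /// but with no fallbacks and fresh metrics counters.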
    pub(crate) async fn with_model(&self, model: &str) -> Result<Self> {
        let parts: Vec<&str> = model.split('/').collect();
        let model_id = parts
            .get(1)
            .map(|s| s.to_string())
            .unwrap_or_else(|| model.to_string());

        let manifest = self.loader.load_model(model).await?;
        validation::validate_manifest(&manifest, self.strict_streaming)?;

        let transport = Arc::new(crate::transport::HttpTransport::new(&manifest, &model_id)?);
        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);

        Ok(AiClient {
            manifest,
            transport,
            pipeline,
            loader: self.loader.clone(),
            fallbacks: Vec::new(),
            model_id,
            strict_streaming: self.strict_streaming,
            feedback: self.feedback.clone(),
            inflight: self.inflight.clone(),
            max_inflight: self.max_inflight,
            attempt_timeout: self.attempt_timeout,
            total_requests: AtomicU64::new(0),
            successful_requests: AtomicU64::new(0),
            total_tokens: AtomicU64::new(0),
        })
    }

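    /// Starts a chat request builder bound to this client.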
    pub fn chat(&self) -> crate::client::chat::ChatRequestBuilder<'_> {
        crate::client::chat::ChatRequestBuilder::new(self)
    }

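    /// Executes a batch of chat requests concurrently and returns results in
    /// the same order as the input. `concurrency_limit` caps the number of
    /// in-flight requests (defaults to 10, minimum 1).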
    pub async fn chat_batch(
        &self,
        requests: Vec<crate::client::chat::ChatBatchRequest>,
        concurrency_limit: Option<usize>,
    ) -> Vec<Result<UnifiedResponse>> {
        use futures::StreamExt;

        let n = requests.len();
        if n == 0 {
            return Vec::new();
        }

        let limit = concurrency_limit.unwrap_or(10).max(1);
        let mut out: Vec<Option<Result<UnifiedResponse>>> = (0..n).map(|_| None).collect();

        let results: Vec<(usize, Result<UnifiedResponse>)> =
            futures::stream::iter(requests.into_iter().enumerate())
                .map(|(idx, req)| async move {
                    let mut b = self.chat().messages(req.messages).stream();
                    if let Some(t) = req.temperature {
                        b = b.temperature(t);
                    }
                    if let Some(m) = req.max_tokens {
                        b = b.max_tokens(m);
                    }
                    if let Some(tools) = req.tools {
                        b = b.tools(tools);
                    }
                    if let Some(tc) = req.tool_choice {
                        b = b.tool_choice(tc);
                    }
                    let r = b.execute().await;
                    (idx, r)
                })
                .buffer_unordered(limit)
                .collect()
                .await;

        for (idx, r) in results {
            out[idx] = Some(r);
        }

        out.into_iter()
            .map(|o| {
                o.unwrap_or_else(|| {
                    Err(Error::runtime_with_context(
                        "batch result missing",
                        ErrorContext::new().with_source("batch_executor"),
                    ))
                })
            })
            .collect()
    }

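    /// Like [`Self::chat_batch`], but derives the concurrency level from the
    /// batch size (1 for up to 3 requests, 5 for up to 10, otherwise 10) unless
    /// the `AI_LIB_BATCH_CONCURRENCY` environment variable overrides it.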
    pub async fn chat_batch_smart(
        &self,
        requests: Vec<crate::client::chat::ChatBatchRequest>,
    ) -> Vec<Result<UnifiedResponse>> {
        let n = requests.len();
        if n == 0 {
            return Vec::new();
        }

        let env_override = std::env::var("AI_LIB_BATCH_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .filter(|v| *v > 0);

        let chosen = env_override.unwrap_or_else(|| {
            if n <= 3 {
                1
            } else if n <= 10 {
                5
            } else {
                10
            }
        });

        self.chat_batch(requests, Some(chosen)).await
    }

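    /// Forwards a feedback event to the configured feedback sink.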
    pub async fn report_feedback(&self, event: crate::feedback::FeedbackEvent) -> Result<()> {
        self.feedback.report(event).await
    }

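    /// Sends a unified request and returns only the response, discarding call
    /// statistics.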
    pub async fn call_model(
        &self,
        request: crate::protocol::UnifiedRequest,
    ) -> Result<UnifiedResponse> {
        Ok(self.call_model_with_stats(request).await?.0)
    }

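    /// Sends a unified request and returns the response together with per-call
    /// statistics. The primary model is tried first, then each configured
    /// fallback; capability checks, signal-based skips, and retries are driven
    /// by the policy engine.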
    pub async fn call_model_with_stats(
        &self,
        request: crate::protocol::UnifiedRequest,
    ) -> Result<(UnifiedResponse, CallStats)> {
        self.record_request();

        let mut last_err: Option<Error> = None;

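        // Build a client up front for each configured fallback model; models
        // that fail to load are skipped rather than aborting the call.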
        let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(self.fallbacks.len());
        for model in &self.fallbacks {
            if let Ok(c) = self.with_model(model).await {
                fallback_clients.push(c);
            }
        }

        for (candidate_idx, client) in std::iter::once(self)
            .chain(fallback_clients.iter())
            .enumerate()
        {
            let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
            let policy = crate::client::policy::PolicyEngine::new(&client.manifest);

            if let Err(e) = policy.validate_capabilities(&request) {
                if has_fallback {
                    last_err = Some(e);
                    continue;
                } else {
                    return Err(e);
                }
            }

            let sig = client.signals().await;
            if let Some(crate::client::policy::Decision::Fallback) =
                policy.pre_decide(&sig, has_fallback)
            {
                last_err = Some(Error::runtime_with_context(
                    "skipped candidate due to signals",
                    ErrorContext::new().with_source("policy_engine"),
                ));
                continue;
            }

            let mut req = request.clone();
            if candidate_idx > 0 {
                req.model = client.model_id.clone();
            }

            match client.execute_with_retry(&req, &policy, has_fallback).await {
                Ok((resp, stats)) => {
                    client.record_success(&stats);
                    return Ok((resp, stats));
                }
                Err(e) => {
                    if !has_fallback {
                        return Err(e);
                    }
                    last_err = Some(e);
                }
            }
        }

        Err(last_err.unwrap_or_else(|| {
            Error::runtime_with_context(
                "all attempts failed",
                ErrorContext::new().with_source("retry_policy"),
            )
        }))
    }

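    /// Executes a single candidate, applying the per-attempt timeout (if any)
    /// and retrying according to the policy engine until it decides to fail or
    /// fall back.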
    async fn execute_with_retry(
        &self,
        request: &crate::protocol::UnifiedRequest,
        policy: &crate::client::policy::PolicyEngine,
        has_fallback: bool,
    ) -> Result<(UnifiedResponse, CallStats)> {
        let mut attempt: u32 = 0;
        let mut retry_count: u32 = 0;

        loop {
            let attempt_fut = self.execute_once_with_stats(request);
            let attempt_res = if let Some(t) = self.attempt_timeout {
                match tokio::time::timeout(t, attempt_fut).await {
                    Ok(r) => r,
                    Err(_) => Err(Error::runtime_with_context(
                        "attempt timeout",
                        ErrorContext::new().with_source("timeout_policy"),
                    )),
                }
            } else {
                attempt_fut.await
            };

            match attempt_res {
                Ok((resp, mut stats)) => {
                    stats.retry_count = retry_count;
                    return Ok((resp, stats));
                }
                Err(e) => {
                    let decision = policy.decide(&e, attempt, has_fallback)?;

                    match decision {
                        crate::client::policy::Decision::Retry { delay } => {
                            retry_count = retry_count.saturating_add(1);
                            if delay.as_millis() > 0 {
                                tokio::time::sleep(delay).await;
                            }
                            attempt = attempt.saturating_add(1);
                            continue;
                        }
                        crate::client::policy::Decision::Fallback => return Err(e),
                        crate::client::policy::Decision::Fail => return Err(e),
                    }
                }
            }
        }
    }

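    /// Checks a chat request builder against this client's manifest
    /// capabilities without sending anything.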
    pub fn validate_request(
        &self,
        request: &crate::client::chat::ChatRequestBuilder<'_>,
    ) -> Result<()> {
        let mock_req = crate::protocol::UnifiedRequest {
            stream: request.stream,
            tools: request.tools.clone(),
            messages: request.messages.clone(),
            response_format: request.response_format.clone(),
            ..Default::default()
        };

        let policy = crate::client::policy::PolicyEngine::new(&self.manifest);
        policy.validate_capabilities(&mock_req)
    }
}