1use crate::client::types::{cancel_pair, CancelHandle, ControlledStream};
2use crate::types::{events::StreamingEvent, message::Message};
3use crate::Result;
4use futures::{stream::Stream, TryStreamExt};
5use std::pin::Pin;
6
7use super::core::{AiClient, UnifiedResponse};
8
/// Parameters for one chat call when submitted as part of a batch.
///
/// Unlike `ChatRequestBuilder`, this is a plain data carrier: it holds no
/// client reference and performs no execution itself.
#[derive(Debug, Clone)]
pub struct ChatBatchRequest {
    // Conversation history to send to the model.
    pub messages: Vec<Message>,
    // Optional sampling temperature; left unset when `None`.
    pub temperature: Option<f64>,
    // Optional cap on generated tokens; left unset when `None`.
    pub max_tokens: Option<u32>,
    // Tool definitions the model may call; `None` means no tools.
    pub tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    // Provider-specific tool-selection directive, passed through as raw JSON.
    pub tool_choice: Option<serde_json::Value>,
}
18
19impl ChatBatchRequest {
20 pub fn new(messages: Vec<Message>) -> Self {
21 Self {
22 messages,
23 temperature: None,
24 max_tokens: None,
25 tools: None,
26 tool_choice: None,
27 }
28 }
29
30 pub fn temperature(mut self, temp: f64) -> Self {
31 self.temperature = Some(temp);
32 self
33 }
34
35 pub fn max_tokens(mut self, max: u32) -> Self {
36 self.max_tokens = Some(max);
37 self
38 }
39
40 pub fn tools(mut self, tools: Vec<crate::types::tool::ToolDefinition>) -> Self {
41 self.tools = Some(tools);
42 self
43 }
44
45 pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
46 self.tool_choice = Some(tool_choice);
47 self
48 }
49}
50
/// Fluent builder for a single chat call against an [`AiClient`].
///
/// Obtained from the client (see `ChatRequestBuilder::new`), configured via
/// the chained setters, and finished with one of the `execute*` methods.
pub struct ChatRequestBuilder<'a> {
    // Client that performs the call; borrowed for the builder's lifetime.
    pub(crate) client: &'a AiClient,
    // Conversation history to send.
    pub(crate) messages: Vec<Message>,
    // Optional sampling temperature.
    pub(crate) temperature: Option<f64>,
    // Optional cap on generated tokens.
    pub(crate) max_tokens: Option<u32>,
    // When true, `execute()` consumes a stream instead of one-shot calling.
    pub(crate) stream: bool,
    // Tool definitions the model may call.
    pub(crate) tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    // Provider-specific tool-selection directive (raw JSON passthrough).
    pub(crate) tool_choice: Option<serde_json::Value>,
    // Per-call model override; falls back to the client's model when `None`.
    pub(crate) model: Option<String>,
    // Structured-output (JSON mode) configuration, if any.
    pub(crate) response_format: Option<crate::structured::JsonModeConfig>,
}
65
66impl<'a> ChatRequestBuilder<'a> {
67 pub(crate) fn new(client: &'a AiClient) -> Self {
68 Self {
69 client,
70 messages: Vec::new(),
71 temperature: None,
72 max_tokens: None,
73 stream: false,
74 tools: None,
75 tool_choice: None,
76 model: None,
77 response_format: None,
78 }
79 }
80
81 pub fn messages(mut self, messages: Vec<Message>) -> Self {
83 self.messages = messages;
84 self
85 }
86
87 pub fn temperature(mut self, temp: f64) -> Self {
89 self.temperature = Some(temp);
90 self
91 }
92
93 pub fn max_tokens(mut self, max: u32) -> Self {
95 self.max_tokens = Some(max);
96 self
97 }
98
99 pub fn stream(mut self) -> Self {
101 self.stream = true;
102 self
103 }
104
105 pub fn tools(mut self, tools: Vec<crate::types::tool::ToolDefinition>) -> Self {
107 self.tools = Some(tools);
108 self
109 }
110
111 pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
113 self.tool_choice = Some(tool_choice);
114 self
115 }
116
117 pub fn tools_json(self, tools: Vec<serde_json::Value>) -> Self {
122 let defs: Vec<crate::types::tool::ToolDefinition> = tools
123 .into_iter()
124 .filter_map(|v| serde_json::from_value(v).ok())
125 .collect();
126 self.tools(defs)
127 }
128
129 pub fn model(mut self, model: impl Into<String>) -> Self {
145 self.model = Some(model.into());
146 self
147 }
148
149 pub fn response_format(mut self, cfg: crate::structured::JsonModeConfig) -> Self {
151 self.response_format = Some(cfg);
152 self
153 }
154
155 pub async fn execute_stream(
157 self,
158 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>> {
159 let (stream, _cancel) = self.execute_stream_with_cancel().await?;
160 Ok(stream)
161 }
162
163 pub async fn execute_stream_with_cancel_and_stats(
169 self,
170 ) -> Result<(
171 Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>,
172 CancelHandle,
173 crate::client::types::CallStats,
174 )> {
175 self.client.validate_request(&self)?;
177
178 self.client.record_request();
179
180 let base_client = self.client;
181 let unified_req = self.into_unified_request();
182
183 let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(base_client.fallbacks.len());
185 for model in &base_client.fallbacks {
186 if let Ok(c) = base_client.with_model(model).await {
187 fallback_clients.push(c);
188 }
189 }
190
191 let (cancel_handle, cancel_rx) = cancel_pair();
192
193 let mut last_err: Option<crate::Error> = None;
194
195 for (candidate_idx, client) in std::iter::once(base_client)
196 .chain(fallback_clients.iter())
197 .enumerate()
198 {
199 let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
200 let policy = crate::client::policy::PolicyEngine::new(&client.manifest);
201 let mut attempt: u32 = 0;
202 let mut retry_count: u32 = 0;
203
204 loop {
205 let sig = client.signals().await;
207 if let Some(crate::client::policy::Decision::Fallback) =
208 policy.pre_decide(&sig, has_fallback)
209 {
210 last_err = Some(crate::Error::runtime_with_context(
211 "skipped candidate due to signals",
212 crate::ErrorContext::new().with_source("policy_engine"),
213 ));
214 break;
215 }
216
217 let mut req = unified_req.clone();
218 if candidate_idx > 0 {
219 req.model = client.model_id.clone();
220 }
221
222 match client.execute_stream_once(&req).await {
223 Ok((mut event_stream, permit, mut stats)) => {
224 use futures::StreamExt;
227 let next_fut = event_stream.next();
228 let first = if let Some(t) = client.attempt_timeout {
229 match tokio::time::timeout(t, next_fut).await {
230 Ok(v) => v,
231 Err(_) => Some(Err(crate::Error::runtime_with_context(
232 "attempt timeout",
233 crate::ErrorContext::new().with_source("timeout_policy"),
234 ))),
235 }
236 } else {
237 next_fut.await
238 };
239
240 match first {
241 None => {
242 stats.retry_count = retry_count;
243 stats.emitted_any = false;
244 base_client.record_success(&stats);
245 let wrapped = ControlledStream::new(
246 Box::pin(futures::stream::empty()),
247 Some(cancel_rx),
248 permit,
249 );
250 return Ok((Box::pin(wrapped), cancel_handle, stats));
251 }
252 Some(Ok(first_ev)) => {
253 let first_ms = stats.duration_ms;
254 let stream = futures::stream::once(async move { Ok(first_ev) })
255 .chain(event_stream);
256 let wrapped = ControlledStream::new(
257 Box::pin(stream.map_err(|e| {
258 e
261 })),
262 Some(cancel_rx),
263 permit,
264 );
265
266 stats.retry_count = retry_count;
267 stats.first_event_ms = Some(first_ms);
268 stats.emitted_any = true;
269
270 base_client.record_success(&stats);
271 return Ok((Box::pin(wrapped), cancel_handle, stats));
272 }
273 Some(Err(e)) => {
274 let decision = policy.decide(&e, attempt, has_fallback)?;
275 last_err = Some(e);
276 match decision {
277 crate::client::policy::Decision::Retry { delay } => {
278 retry_count = retry_count.saturating_add(1);
279 if delay.as_millis() > 0 {
280 tokio::time::sleep(delay).await;
281 }
282 attempt = attempt.saturating_add(1);
283 continue;
284 }
285 crate::client::policy::Decision::Fallback => break,
286 crate::client::policy::Decision::Fail => {
287 return Err(last_err.unwrap());
288 }
289 }
290 }
291 }
292 }
293 Err(e) => {
294 let decision = policy.decide(&e, attempt, has_fallback)?;
295 last_err = Some(e);
296 match decision {
297 crate::client::policy::Decision::Retry { delay } => {
298 retry_count = retry_count.saturating_add(1);
299 if delay.as_millis() > 0 {
300 tokio::time::sleep(delay).await;
301 }
302 attempt = attempt.saturating_add(1);
303 continue;
304 }
305 crate::client::policy::Decision::Fallback => break,
306 crate::client::policy::Decision::Fail => {
307 return Err(last_err.unwrap());
308 }
309 }
310 }
311 }
312 }
313 }
314
315 Err(last_err.unwrap_or_else(|| {
316 crate::Error::runtime_with_context(
317 "all streaming attempts failed",
318 crate::ErrorContext::new().with_source("retry_policy"),
319 )
320 }))
321 }
322
323 pub async fn execute_stream_with_cancel(
348 self,
349 ) -> Result<(
350 Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>,
351 CancelHandle,
352 )> {
353 let (s, c, _stats) = self.execute_stream_with_cancel_and_stats().await?;
354 Ok((s, c))
355 }
356
357 pub async fn execute(self) -> Result<UnifiedResponse> {
359 let stream_flag = self.stream;
360 let client = self.client;
361 let unified_req = self.into_unified_request();
362
363 if !stream_flag {
365 let (resp, _stats) = client.call_model_with_stats(unified_req).await?;
366 return Ok(resp);
367 }
368
369 let mut stream = {
372 let builder = ChatRequestBuilder {
373 client,
374 messages: unified_req.messages.clone(),
375 temperature: unified_req.temperature,
376 max_tokens: unified_req.max_tokens,
377 stream: true,
378 tools: unified_req.tools.clone(),
379 tool_choice: unified_req.tool_choice.clone(),
380 model: Some(unified_req.model.clone()),
381 response_format: unified_req.response_format.clone(),
382 };
383 builder.execute_stream().await?
384 };
385 let mut response = UnifiedResponse::default();
386 let mut tool_asm = crate::utils::tool_call_assembler::ToolCallAssembler::new();
387
388 use futures::StreamExt;
389 let mut event_count = 0;
390 while let Some(event) = stream.next().await {
391 event_count += 1;
392 match event? {
393 StreamingEvent::PartialContentDelta { content, .. } => {
394 response.content.push_str(&content);
395 }
396 StreamingEvent::ToolCallStarted {
397 tool_call_id,
398 tool_name,
399 ..
400 } => {
401 tool_asm.on_started(tool_call_id, tool_name);
402 }
403 StreamingEvent::PartialToolCall {
404 tool_call_id,
405 arguments,
406 ..
407 } => {
408 tool_asm.on_partial(&tool_call_id, &arguments);
409 }
410 StreamingEvent::Metadata { usage, .. } => {
411 response.usage = usage;
412 }
413 StreamingEvent::StreamEnd { .. } => {
414 break;
415 }
416 StreamingEvent::ThinkingDelta { .. } => {}
417 other => {
418 tracing::warn!("Unexpected event in execute(): {:?}", other);
420 }
421 }
422 }
423
424 if event_count == 0 {
425 tracing::warn!(
426 "No events received from stream. Possible causes: provider returned empty stream, \
427 network interruption, or event mapping configuration issue. Provider: {}, Model: {}",
428 client.manifest.id,
429 client.model_id
430 );
431 } else if response.content.is_empty() {
432 tracing::warn!(
433 "Received {} events but content is empty. This might indicate: (1) provider filtered \
434 content (safety/content policy), (2) non-streaming response format mismatch, \
435 (3) event mapping issue. Provider: {}, Model: {}",
436 event_count,
437 client.manifest.id,
438 client.model_id
439 );
440 }
441
442 response.tool_calls = tool_asm.finalize();
443
444 Ok(response)
445 }
446
447 fn into_unified_request(self) -> crate::protocol::UnifiedRequest {
448 let model = self.model.unwrap_or_else(|| self.client.model_id.clone());
449 crate::protocol::UnifiedRequest {
450 operation: "chat".to_string(),
451 model,
452 messages: self.messages,
453 temperature: self.temperature,
454 max_tokens: self.max_tokens,
455 stream: self.stream,
456 tools: self.tools,
457 tool_choice: self.tool_choice,
458 response_format: self.response_format,
459 }
460 }
461}