1use crate::client::types::{cancel_pair, CancelHandle, ControlledStream};
2use crate::types::{events::StreamingEvent, message::Message};
3use crate::Result;
4use futures::{stream::Stream, TryStreamExt};
5use std::pin::Pin;
6
7use super::core::{AiClient, UnifiedResponse};
8
/// Chat request parameters without an attached client.
///
/// Carries the same knobs as [`ChatRequestBuilder`] (messages, sampling
/// options, tools) but owns all of its data and is `Clone`, so it can be
/// built once and reused. NOTE(review): presumably consumed by a batch
/// API elsewhere in the crate — not visible in this file.
#[derive(Debug, Clone)]
pub struct ChatBatchRequest {
    /// Conversation history to send to the model.
    pub messages: Vec<Message>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Optional cap on the number of generated tokens.
    pub max_tokens: Option<u32>,
    /// Optional tool definitions the model may call.
    pub tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    /// Optional provider-specific tool-choice directive (raw JSON).
    pub tool_choice: Option<serde_json::Value>,
}
18
19impl ChatBatchRequest {
20 pub fn new(messages: Vec<Message>) -> Self {
21 Self {
22 messages,
23 temperature: None,
24 max_tokens: None,
25 tools: None,
26 tool_choice: None,
27 }
28 }
29
30 pub fn temperature(mut self, temp: f64) -> Self {
31 self.temperature = Some(temp);
32 self
33 }
34
35 pub fn max_tokens(mut self, max: u32) -> Self {
36 self.max_tokens = Some(max);
37 self
38 }
39
40 pub fn tools(mut self, tools: Vec<crate::types::tool::ToolDefinition>) -> Self {
41 self.tools = Some(tools);
42 self
43 }
44
45 pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
46 self.tool_choice = Some(tool_choice);
47 self
48 }
49}
50
/// Fluent builder for a single chat call against an [`AiClient`].
///
/// Created via the crate-internal [`ChatRequestBuilder::new`], configured
/// with the chained setter methods, and consumed by one of the `execute*`
/// methods.
pub struct ChatRequestBuilder<'a> {
    /// Client the request executes against; also supplies the model id,
    /// fallbacks, and policy configuration.
    pub(crate) client: &'a AiClient,
    /// Conversation history to send.
    pub(crate) messages: Vec<Message>,
    /// Optional sampling temperature.
    pub(crate) temperature: Option<f64>,
    /// Optional cap on the number of generated tokens.
    pub(crate) max_tokens: Option<u32>,
    /// Whether `execute()` should take the streaming path.
    pub(crate) stream: bool,
    /// Optional tool definitions the model may call.
    pub(crate) tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    /// Optional provider-specific tool-choice directive (raw JSON).
    pub(crate) tool_choice: Option<serde_json::Value>,
}
61
62impl<'a> ChatRequestBuilder<'a> {
63 pub(crate) fn new(client: &'a AiClient) -> Self {
64 Self {
65 client,
66 messages: Vec::new(),
67 temperature: None,
68 max_tokens: None,
69 stream: false,
70 tools: None,
71 tool_choice: None,
72 }
73 }
74
75 pub fn messages(mut self, messages: Vec<Message>) -> Self {
77 self.messages = messages;
78 self
79 }
80
81 pub fn temperature(mut self, temp: f64) -> Self {
83 self.temperature = Some(temp);
84 self
85 }
86
87 pub fn max_tokens(mut self, max: u32) -> Self {
89 self.max_tokens = Some(max);
90 self
91 }
92
93 pub fn stream(mut self) -> Self {
95 self.stream = true;
96 self
97 }
98
99 pub fn tools(mut self, tools: Vec<crate::types::tool::ToolDefinition>) -> Self {
101 self.tools = Some(tools);
102 self
103 }
104
105 pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
107 self.tool_choice = Some(tool_choice);
108 self
109 }
110
111 pub async fn execute_stream(
113 self,
114 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>> {
115 let (stream, _cancel) = self.execute_stream_with_cancel().await?;
116 Ok(stream)
117 }
118
119 pub async fn execute_stream_with_cancel_and_stats(
125 self,
126 ) -> Result<(
127 Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>,
128 CancelHandle,
129 crate::client::types::CallStats,
130 )> {
131 self.client.validate_request(&self)?;
133
134 let base_client = self.client;
135 let unified_req = self.into_unified_request();
136
137 let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(base_client.fallbacks.len());
139 for model in &base_client.fallbacks {
140 if let Ok(c) = base_client.with_model(model).await {
141 fallback_clients.push(c);
142 }
143 }
144
145 let (cancel_handle, cancel_rx) = cancel_pair();
146
147 let mut last_err: Option<crate::Error> = None;
148
149 for (candidate_idx, client) in std::iter::once(base_client)
150 .chain(fallback_clients.iter())
151 .enumerate()
152 {
153 let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
154 let policy = crate::client::policy::PolicyEngine::new(&client.manifest);
155 let mut attempt: u32 = 0;
156 let mut retry_count: u32 = 0;
157
158 loop {
159 let sig = client.signals().await;
161 if let Some(crate::client::policy::Decision::Fallback) =
162 policy.pre_decide(&sig, has_fallback)
163 {
164 last_err = Some(crate::Error::runtime_with_context(
165 "skipped candidate due to signals",
166 crate::ErrorContext::new().with_source("policy_engine"),
167 ));
168 break;
169 }
170
171 let mut req = unified_req.clone();
172 req.model = client.model_id.clone();
173
174 match client.execute_stream_once(&req).await {
175 Ok((mut event_stream, permit, mut stats)) => {
176 use futures::StreamExt;
179 let next_fut = event_stream.next();
180 let first = if let Some(t) = client.attempt_timeout {
181 match tokio::time::timeout(t, next_fut).await {
182 Ok(v) => v,
183 Err(_) => Some(Err(crate::Error::runtime_with_context(
184 "attempt timeout",
185 crate::ErrorContext::new().with_source("timeout_policy"),
186 ))),
187 }
188 } else {
189 next_fut.await
190 };
191
192 match first {
193 None => {
194 stats.retry_count = retry_count;
195 stats.emitted_any = false;
196 let wrapped = ControlledStream::new(
197 Box::pin(futures::stream::empty()),
198 Some(cancel_rx),
199 permit,
200 );
201 return Ok((Box::pin(wrapped), cancel_handle, stats));
202 }
203 Some(Ok(first_ev)) => {
204 let first_ms = stats.duration_ms;
205 let stream = futures::stream::once(async move { Ok(first_ev) })
206 .chain(event_stream);
207 let wrapped = ControlledStream::new(
208 Box::pin(stream.map_err(|e| {
209 e
212 })),
213 Some(cancel_rx),
214 permit,
215 );
216
217 stats.retry_count = retry_count;
218 stats.first_event_ms = Some(first_ms);
219 stats.emitted_any = true;
220
221 return Ok((Box::pin(wrapped), cancel_handle, stats));
222 }
223 Some(Err(e)) => {
224 let decision = policy.decide(&e, attempt, has_fallback)?;
225 last_err = Some(e);
226 match decision {
227 crate::client::policy::Decision::Retry { delay } => {
228 retry_count = retry_count.saturating_add(1);
229 if delay.as_millis() > 0 {
230 tokio::time::sleep(delay).await;
231 }
232 attempt = attempt.saturating_add(1);
233 continue;
234 }
235 crate::client::policy::Decision::Fallback => break,
236 crate::client::policy::Decision::Fail => {
237 return Err(last_err.unwrap());
238 }
239 }
240 }
241 }
242 }
243 Err(e) => {
244 let decision = policy.decide(&e, attempt, has_fallback)?;
245 last_err = Some(e);
246 match decision {
247 crate::client::policy::Decision::Retry { delay } => {
248 retry_count = retry_count.saturating_add(1);
249 if delay.as_millis() > 0 {
250 tokio::time::sleep(delay).await;
251 }
252 attempt = attempt.saturating_add(1);
253 continue;
254 }
255 crate::client::policy::Decision::Fallback => break,
256 crate::client::policy::Decision::Fail => {
257 return Err(last_err.unwrap());
258 }
259 }
260 }
261 }
262 }
263 }
264
265 Err(last_err.unwrap_or_else(|| {
266 crate::Error::runtime_with_context(
267 "all streaming attempts failed",
268 crate::ErrorContext::new().with_source("retry_policy"),
269 )
270 }))
271 }
272
273 pub async fn execute_stream_with_cancel(
275 self,
276 ) -> Result<(
277 Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>,
278 CancelHandle,
279 )> {
280 let (s, c, _stats) = self.execute_stream_with_cancel_and_stats().await?;
281 Ok((s, c))
282 }
283
284 pub async fn execute(self) -> Result<UnifiedResponse> {
286 let stream_flag = self.stream;
287 let client = self.client;
288 let unified_req = self.into_unified_request();
289
290 if !stream_flag {
292 let (resp, _stats) = client.call_model_with_stats(unified_req).await?;
293 return Ok(resp);
294 }
295
296 let mut stream = {
299 let builder = ChatRequestBuilder {
300 client,
301 messages: unified_req.messages.clone(),
302 temperature: unified_req.temperature,
303 max_tokens: unified_req.max_tokens,
304 stream: true,
305 tools: unified_req.tools.clone(),
306 tool_choice: unified_req.tool_choice.clone(),
307 };
308 builder.execute_stream().await?
309 };
310 let mut response = UnifiedResponse::default();
311 let mut tool_asm = crate::utils::tool_call_assembler::ToolCallAssembler::new();
312
313 use futures::StreamExt;
314 let mut event_count = 0;
315 while let Some(event) = stream.next().await {
316 event_count += 1;
317 match event? {
318 StreamingEvent::PartialContentDelta { content, .. } => {
319 response.content.push_str(&content);
320 }
321 StreamingEvent::ToolCallStarted {
322 tool_call_id,
323 tool_name,
324 ..
325 } => {
326 tool_asm.on_started(tool_call_id, tool_name);
327 }
328 StreamingEvent::PartialToolCall {
329 tool_call_id,
330 arguments,
331 ..
332 } => {
333 tool_asm.on_partial(&tool_call_id, &arguments);
334 }
335 StreamingEvent::Metadata { usage, .. } => {
336 response.usage = usage;
337 }
338 StreamingEvent::StreamEnd { .. } => {
339 break;
340 }
341 other => {
342 tracing::warn!("Unexpected event in execute(): {:?}", other);
344 }
345 }
346 }
347
348 if event_count == 0 {
349 tracing::warn!(
350 "No events received from stream. Possible causes: provider returned empty stream, \
351 network interruption, or event mapping configuration issue. Provider: {}, Model: {}",
352 client.manifest.id,
353 client.model_id
354 );
355 } else if response.content.is_empty() {
356 tracing::warn!(
357 "Received {} events but content is empty. This might indicate: (1) provider filtered \
358 content (safety/content policy), (2) non-streaming response format mismatch, \
359 (3) event mapping issue. Provider: {}, Model: {}",
360 event_count,
361 client.manifest.id,
362 client.model_id
363 );
364 }
365
366 response.tool_calls = tool_asm.finalize();
367
368 Ok(response)
369 }
370
371 fn into_unified_request(self) -> crate::protocol::UnifiedRequest {
372 crate::protocol::UnifiedRequest {
373 operation: "chat".to_string(),
374 model: self.client.model_id.clone(),
375 messages: self.messages,
376 temperature: self.temperature,
377 max_tokens: self.max_tokens,
378 stream: self.stream,
379 tools: self.tools,
380 tool_choice: self.tool_choice,
381 }
382 }
383}