Skip to main content

ai_lib_rust/client/
chat.rs

1use crate::client::types::{cancel_pair, CancelHandle, ControlledStream};
2use crate::types::{events::StreamingEvent, message::Message};
3use crate::Result;
4use futures::{stream::Stream, TryStreamExt};
5use std::pin::Pin;
6
7use super::core::{AiClient, UnifiedResponse};
8
/// Batch chat request parameters (developer-friendly, small surface).
#[derive(Debug, Clone)]
pub struct ChatBatchRequest {
    /// Conversation messages to send to the model.
    pub messages: Vec<Message>,
    /// Sampling temperature override, if set.
    pub temperature: Option<f64>,
    /// Cap on the number of generated tokens, if set.
    pub max_tokens: Option<u32>,
    /// Tool definitions for function calling, if any.
    pub tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    /// OpenAI-style `tool_choice` value, if any.
    pub tool_choice: Option<serde_json::Value>,
}
18
19impl ChatBatchRequest {
20    pub fn new(messages: Vec<Message>) -> Self {
21        Self {
22            messages,
23            temperature: None,
24            max_tokens: None,
25            tools: None,
26            tool_choice: None,
27        }
28    }
29
30    pub fn temperature(mut self, temp: f64) -> Self {
31        self.temperature = Some(temp);
32        self
33    }
34
35    pub fn max_tokens(mut self, max: u32) -> Self {
36        self.max_tokens = Some(max);
37        self
38    }
39
40    pub fn tools(mut self, tools: Vec<crate::types::tool::ToolDefinition>) -> Self {
41        self.tools = Some(tools);
42        self
43    }
44
45    pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
46        self.tool_choice = Some(tool_choice);
47        self
48    }
49}
50
/// Builder for chat requests.
pub struct ChatRequestBuilder<'a> {
    /// Client this request will execute against.
    pub(crate) client: &'a AiClient,
    /// Conversation messages to send.
    pub(crate) messages: Vec<Message>,
    /// Optional sampling temperature override.
    pub(crate) temperature: Option<f64>,
    /// Optional cap on generated tokens.
    pub(crate) max_tokens: Option<u32>,
    /// When true, `execute()` drains a stream and folds events into one response.
    pub(crate) stream: bool,
    /// Tool definitions for function calling, if any.
    pub(crate) tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    /// OpenAI-style `tool_choice` value, if any.
    pub(crate) tool_choice: Option<serde_json::Value>,
}
61
62impl<'a> ChatRequestBuilder<'a> {
63    pub(crate) fn new(client: &'a AiClient) -> Self {
64        Self {
65            client,
66            messages: Vec::new(),
67            temperature: None,
68            max_tokens: None,
69            stream: false,
70            tools: None,
71            tool_choice: None,
72        }
73    }
74
75    /// Add messages to the conversation.
76    pub fn messages(mut self, messages: Vec<Message>) -> Self {
77        self.messages = messages;
78        self
79    }
80
81    /// Set temperature.
82    pub fn temperature(mut self, temp: f64) -> Self {
83        self.temperature = Some(temp);
84        self
85    }
86
87    /// Set max tokens.
88    pub fn max_tokens(mut self, max: u32) -> Self {
89        self.max_tokens = Some(max);
90        self
91    }
92
93    /// Enable streaming.
94    pub fn stream(mut self) -> Self {
95        self.stream = true;
96        self
97    }
98
99    /// Set tools for function calling.
100    pub fn tools(mut self, tools: Vec<crate::types::tool::ToolDefinition>) -> Self {
101        self.tools = Some(tools);
102        self
103    }
104
105    /// Set tool_choice (OpenAI-style).
106    pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
107        self.tool_choice = Some(tool_choice);
108        self
109    }
110
111    /// Execute the request and return a stream of events.
112    pub async fn execute_stream(
113        self,
114    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>> {
115        let (stream, _cancel) = self.execute_stream_with_cancel().await?;
116        Ok(stream)
117    }
118
119    /// Execute the request and return a cancellable stream of events plus per-call stats.
120    ///
121    /// Streaming semantics:
122    /// - retry/fallback may happen only before any event is emitted to the caller
123    /// - once an event is emitted, we will not retry automatically to avoid duplicate output
124    pub async fn execute_stream_with_cancel_and_stats(
125        self,
126    ) -> Result<(
127        Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>,
128        CancelHandle,
129        crate::client::types::CallStats,
130    )> {
131        // Validate request against protocol capabilities
132        self.client.validate_request(&self)?;
133
134        let base_client = self.client;
135        let unified_req = self.into_unified_request();
136
137        // Pre-build fallback clients (async), then run unified policy loops.
138        let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(base_client.fallbacks.len());
139        for model in &base_client.fallbacks {
140            if let Ok(c) = base_client.with_model(model).await {
141                fallback_clients.push(c);
142            }
143        }
144
145        let (cancel_handle, cancel_rx) = cancel_pair();
146
147        let mut last_err: Option<crate::Error> = None;
148
149        for (candidate_idx, client) in std::iter::once(base_client)
150            .chain(fallback_clients.iter())
151            .enumerate()
152        {
153            let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
154            let policy = crate::client::policy::PolicyEngine::new(&client.manifest);
155            let mut attempt: u32 = 0;
156            let mut retry_count: u32 = 0;
157
158            loop {
159                // Pre-decision based on signals (skip known-bad candidates, e.g. breaker open).
160                let sig = client.signals().await;
161                if let Some(crate::client::policy::Decision::Fallback) =
162                    policy.pre_decide(&sig, has_fallback)
163                {
164                    last_err = Some(crate::Error::runtime_with_context(
165                        "skipped candidate due to signals",
166                        crate::ErrorContext::new().with_source("policy_engine"),
167                    ));
168                    break;
169                }
170
171                let mut req = unified_req.clone();
172                req.model = client.model_id.clone();
173
174                match client.execute_stream_once(&req).await {
175                    Ok((mut event_stream, permit, mut stats)) => {
176                        // Peek the first item. If it errors BEFORE emitting anything, allow retry/fallback.
177                        // If it yields an event, we commit to this stream (no more retry/fallback).
178                        use futures::StreamExt;
179                        let next_fut = event_stream.next();
180                        let first = if let Some(t) = client.attempt_timeout {
181                            match tokio::time::timeout(t, next_fut).await {
182                                Ok(v) => v,
183                                Err(_) => Some(Err(crate::Error::runtime_with_context(
184                                    "attempt timeout",
185                                    crate::ErrorContext::new().with_source("timeout_policy"),
186                                ))),
187                            }
188                        } else {
189                            next_fut.await
190                        };
191
192                        match first {
193                            None => {
194                                stats.retry_count = retry_count;
195                                stats.emitted_any = false;
196                                let wrapped = ControlledStream::new(
197                                    Box::pin(futures::stream::empty()),
198                                    Some(cancel_rx),
199                                    permit,
200                                );
201                                return Ok((Box::pin(wrapped), cancel_handle, stats));
202                            }
203                            Some(Ok(first_ev)) => {
204                                let first_ms = stats.duration_ms;
205                                let stream = futures::stream::once(async move { Ok(first_ev) })
206                                    .chain(event_stream);
207                                let wrapped = ControlledStream::new(
208                                    Box::pin(stream.map_err(|e| {
209                                        // If it's already a crate::Error (like Transport error), preserve it.
210                                        // Otherwise, it's likely a downstream pipeline error, wrap it.
211                                        e
212                                    })),
213                                    Some(cancel_rx),
214                                    permit,
215                                );
216
217                                stats.retry_count = retry_count;
218                                stats.first_event_ms = Some(first_ms);
219                                stats.emitted_any = true;
220
221                                return Ok((Box::pin(wrapped), cancel_handle, stats));
222                            }
223                            Some(Err(e)) => {
224                                let decision = policy.decide(&e, attempt, has_fallback)?;
225                                last_err = Some(e);
226                                match decision {
227                                    crate::client::policy::Decision::Retry { delay } => {
228                                        retry_count = retry_count.saturating_add(1);
229                                        if delay.as_millis() > 0 {
230                                            tokio::time::sleep(delay).await;
231                                        }
232                                        attempt = attempt.saturating_add(1);
233                                        continue;
234                                    }
235                                    crate::client::policy::Decision::Fallback => break,
236                                    crate::client::policy::Decision::Fail => {
237                                        return Err(last_err.unwrap());
238                                    }
239                                }
240                            }
241                        }
242                    }
243                    Err(e) => {
244                        let decision = policy.decide(&e, attempt, has_fallback)?;
245                        last_err = Some(e);
246                        match decision {
247                            crate::client::policy::Decision::Retry { delay } => {
248                                retry_count = retry_count.saturating_add(1);
249                                if delay.as_millis() > 0 {
250                                    tokio::time::sleep(delay).await;
251                                }
252                                attempt = attempt.saturating_add(1);
253                                continue;
254                            }
255                            crate::client::policy::Decision::Fallback => break,
256                            crate::client::policy::Decision::Fail => {
257                                return Err(last_err.unwrap());
258                            }
259                        }
260                    }
261                }
262            }
263        }
264
265        Err(last_err.unwrap_or_else(|| {
266            crate::Error::runtime_with_context(
267                "all streaming attempts failed",
268                crate::ErrorContext::new().with_source("retry_policy"),
269            )
270        }))
271    }
272
273    /// Execute the request and return a cancellable stream of events.
274    pub async fn execute_stream_with_cancel(
275        self,
276    ) -> Result<(
277        Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>,
278        CancelHandle,
279    )> {
280        let (s, c, _stats) = self.execute_stream_with_cancel_and_stats().await?;
281        Ok((s, c))
282    }
283
284    /// Execute the request and return the complete response.
285    pub async fn execute(self) -> Result<UnifiedResponse> {
286        let stream_flag = self.stream;
287        let client = self.client;
288        let unified_req = self.into_unified_request();
289
290        // If streaming is not explicitly enabled, use non-streaming execution
291        if !stream_flag {
292            let (resp, _stats) = client.call_model_with_stats(unified_req).await?;
293            return Ok(resp);
294        }
295
296        // For streaming requests, collect all events
297        // Rebuild builder for streaming execution
298        let mut stream = {
299            let builder = ChatRequestBuilder {
300                client,
301                messages: unified_req.messages.clone(),
302                temperature: unified_req.temperature,
303                max_tokens: unified_req.max_tokens,
304                stream: true,
305                tools: unified_req.tools.clone(),
306                tool_choice: unified_req.tool_choice.clone(),
307            };
308            builder.execute_stream().await?
309        };
310        let mut response = UnifiedResponse::default();
311        let mut tool_asm = crate::utils::tool_call_assembler::ToolCallAssembler::new();
312
313        use futures::StreamExt;
314        let mut event_count = 0;
315        while let Some(event) = stream.next().await {
316            event_count += 1;
317            match event? {
318                StreamingEvent::PartialContentDelta { content, .. } => {
319                    response.content.push_str(&content);
320                }
321                StreamingEvent::ToolCallStarted {
322                    tool_call_id,
323                    tool_name,
324                    ..
325                } => {
326                    tool_asm.on_started(tool_call_id, tool_name);
327                }
328                StreamingEvent::PartialToolCall {
329                    tool_call_id,
330                    arguments,
331                    ..
332                } => {
333                    tool_asm.on_partial(&tool_call_id, &arguments);
334                }
335                StreamingEvent::Metadata { usage, .. } => {
336                    response.usage = usage;
337                }
338                StreamingEvent::StreamEnd { .. } => {
339                    break;
340                }
341                other => {
342                    // Log unexpected events for debugging
343                    tracing::warn!("Unexpected event in execute(): {:?}", other);
344                }
345            }
346        }
347
348        if event_count == 0 {
349            tracing::warn!(
350                "No events received from stream. Possible causes: provider returned empty stream, \
351                 network interruption, or event mapping configuration issue. Provider: {}, Model: {}",
352                client.manifest.id,
353                client.model_id
354            );
355        } else if response.content.is_empty() {
356            tracing::warn!(
357                "Received {} events but content is empty. This might indicate: (1) provider filtered \
358                 content (safety/content policy), (2) non-streaming response format mismatch, \
359                 (3) event mapping issue. Provider: {}, Model: {}",
360                event_count,
361                client.manifest.id,
362                client.model_id
363            );
364        }
365
366        response.tool_calls = tool_asm.finalize();
367
368        Ok(response)
369    }
370
371    fn into_unified_request(self) -> crate::protocol::UnifiedRequest {
372        crate::protocol::UnifiedRequest {
373            operation: "chat".to_string(),
374            model: self.client.model_id.clone(),
375            messages: self.messages,
376            temperature: self.temperature,
377            max_tokens: self.max_tokens,
378            stream: self.stream,
379            tools: self.tools,
380            tool_choice: self.tool_choice,
381        }
382    }
383}