// Source: ai_lib_core/client/chat.rs — chat request builders (batch parameters,
// streaming execution with retry/fallback, and response collection).

1use crate::client::types::{cancel_pair, CancelHandle, ControlledStream};
2use crate::types::{events::StreamingEvent, message::Message};
3use crate::Result;
4use futures::{stream::Stream, TryStreamExt};
5use std::pin::Pin;
6
7use super::core::{AiClient, UnifiedResponse};
8
/// Batch chat request parameters (developer-friendly, small surface).
#[derive(Debug, Clone)]
pub struct ChatBatchRequest {
    /// Conversation messages to send to the model.
    pub messages: Vec<Message>,
    /// Sampling temperature; `None` means unset.
    pub temperature: Option<f64>,
    /// Maximum number of tokens to generate; `None` means unset.
    pub max_tokens: Option<u32>,
    /// Tool definitions for function calling, if any.
    pub tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    /// OpenAI-style `tool_choice` value as raw JSON, if any.
    pub tool_choice: Option<serde_json::Value>,
}
18
19impl ChatBatchRequest {
20    pub fn new(messages: Vec<Message>) -> Self {
21        Self {
22            messages,
23            temperature: None,
24            max_tokens: None,
25            tools: None,
26            tool_choice: None,
27        }
28    }
29
30    pub fn temperature(mut self, temp: f64) -> Self {
31        self.temperature = Some(temp);
32        self
33    }
34
35    pub fn max_tokens(mut self, max: u32) -> Self {
36        self.max_tokens = Some(max);
37        self
38    }
39
40    pub fn tools(mut self, tools: Vec<crate::types::tool::ToolDefinition>) -> Self {
41        self.tools = Some(tools);
42        self
43    }
44
45    pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
46        self.tool_choice = Some(tool_choice);
47        self
48    }
49}
50
/// Builder for chat requests.
pub struct ChatRequestBuilder<'a> {
    /// Client the request executes against; supplies the default model and fallbacks.
    pub(crate) client: &'a AiClient,
    /// Conversation messages to send.
    pub(crate) messages: Vec<Message>,
    /// Sampling temperature; `None` means unset.
    pub(crate) temperature: Option<f64>,
    /// Maximum number of tokens to generate; `None` means unset.
    pub(crate) max_tokens: Option<u32>,
    /// Whether streaming was requested via `stream()`.
    pub(crate) stream: bool,
    /// Tool definitions for function calling, if any.
    pub(crate) tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    /// OpenAI-style `tool_choice` value as raw JSON, if any.
    pub(crate) tool_choice: Option<serde_json::Value>,
    /// Optional model override; when set, overrides the client's default model for this request.
    pub(crate) model: Option<String>,
    /// JSON / structured output (`response_format` in provider request body).
    pub(crate) response_format: Option<crate::structured::JsonModeConfig>,
}
65
impl<'a> ChatRequestBuilder<'a> {
    /// Create a builder bound to `client` with every option unset and streaming disabled.
    pub(crate) fn new(client: &'a AiClient) -> Self {
        Self {
            client,
            messages: Vec::new(),
            temperature: None,
            max_tokens: None,
            stream: false,
            tools: None,
            tool_choice: None,
            model: None,
            response_format: None,
        }
    }

    /// Add messages to the conversation.
    ///
    /// Note: this replaces any previously set message list; it does not append.
    pub fn messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = messages;
        self
    }

    /// Set temperature.
    pub fn temperature(mut self, temp: f64) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Set max tokens.
    pub fn max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }

    /// Enable streaming.
    ///
    /// Affects [`Self::execute`], which streams internally and collects events when
    /// this flag is set; the `execute_stream*` methods always stream regardless.
    pub fn stream(mut self) -> Self {
        self.stream = true;
        self
    }

    /// Set tools for function calling.
    pub fn tools(mut self, tools: Vec<crate::types::tool::ToolDefinition>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Set tool_choice (OpenAI-style).
    pub fn tool_choice(mut self, tool_choice: serde_json::Value) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Set tools from raw JSON values (e.g., from existing JSON Schema tool definitions).
    ///
    /// Convenience method for integrating with tool systems that produce `serde_json::Value`.
    /// Values that fail to deserialize into `ToolDefinition` are skipped.
    pub fn tools_json(self, tools: Vec<serde_json::Value>) -> Self {
        // Malformed entries are silently dropped; only valid `ToolDefinition`s survive.
        let defs: Vec<crate::types::tool::ToolDefinition> = tools
            .into_iter()
            .filter_map(|v| serde_json::from_value(v).ok())
            .collect();
        self.tools(defs)
    }

    /// Override the model for this request.
    ///
    /// When set, this overrides the client's default model. Useful for single-client
    /// multi-model usage (e.g., same API key with different models).
    ///
    /// # Example
    ///
    /// ```ignore
    /// let client = AiClient::new("openai/gpt-4o").await?;
    /// let resp = client.chat()
    ///     .messages(msgs)
    ///     .model("gpt-4o-mini")  // Use different model for this request
    ///     .execute()
    ///     .await?;
    /// ```
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Enable structured output using JSON mode configuration (OpenAI-style `response_format`).
    pub fn response_format(mut self, cfg: crate::structured::JsonModeConfig) -> Self {
        self.response_format = Some(cfg);
        self
    }

    /// Execute the request and return a stream of events.
    ///
    /// Convenience wrapper that discards the cancel handle: the stream simply
    /// runs to completion (or until dropped).
    pub async fn execute_stream(
        self,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>> {
        let (stream, _cancel) = self.execute_stream_with_cancel().await?;
        Ok(stream)
    }

    /// Execute the request and return a cancellable stream of events plus per-call stats.
    ///
    /// Streaming semantics:
    /// - retry/fallback may happen only before any event is emitted to the caller
    /// - once an event is emitted, we will not retry automatically to avoid duplicate output
    ///
    /// # Errors
    ///
    /// Returns the last observed error when the policy engine decides to fail, or
    /// after all candidates (base client plus fallbacks) have been exhausted.
    pub async fn execute_stream_with_cancel_and_stats(
        self,
    ) -> Result<(
        Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>,
        CancelHandle,
        crate::client::types::CallStats,
    )> {
        // Validate request against protocol capabilities
        self.client.validate_request(&self)?;

        self.client.record_request();

        let base_client = self.client;
        let unified_req = self.into_unified_request();

        // Pre-build fallback clients (async), then run unified policy loops.
        // Fallback models that fail to initialize are skipped silently.
        let mut fallback_clients: Vec<AiClient> = Vec::with_capacity(base_client.fallbacks.len());
        for model in &base_client.fallbacks {
            if let Ok(c) = base_client.with_model(model).await {
                fallback_clients.push(c);
            }
        }

        // `cancel_rx` is moved into whichever stream we finally commit to. Although the
        // move happens inside the loops below, it is sound because every path that moves
        // it returns immediately, so the loop body never runs again after the move.
        let (cancel_handle, cancel_rx) = cancel_pair();

        let mut last_err: Option<crate::Error> = None;

        // Candidate order: base client first, then each pre-built fallback.
        for (candidate_idx, client) in std::iter::once(base_client)
            .chain(fallback_clients.iter())
            .enumerate()
        {
            // True while at least one more candidate remains after this one.
            let has_fallback = candidate_idx + 1 < (1 + fallback_clients.len());
            let policy = crate::client::policy::PolicyEngine::new(&client.manifest);
            let mut attempt: u32 = 0;
            let mut retry_count: u32 = 0;

            // Retry loop for the current candidate; `break` advances to the next candidate.
            loop {
                // Pre-decision based on signals (skip known-bad candidates, e.g. breaker open).
                let sig = client.signals().await;
                if let Some(crate::client::policy::Decision::Fallback) =
                    policy.pre_decide(&sig, has_fallback)
                {
                    last_err = Some(crate::Error::runtime_with_context(
                        "skipped candidate due to signals",
                        crate::ErrorContext::new().with_source("policy_engine"),
                    ));
                    break;
                }

                // Fallback candidates carry their own model id; rewrite the request for them.
                let mut req = unified_req.clone();
                if candidate_idx > 0 {
                    req.model = client.model_id.clone();
                }

                match client.execute_stream_once(&req).await {
                    Ok((mut event_stream, permit, mut stats)) => {
                        // Peek the first item. If it errors BEFORE emitting anything, allow retry/fallback.
                        // If it yields an event, we commit to this stream (no more retry/fallback).
                        use futures::StreamExt;
                        let next_fut = event_stream.next();
                        // Optional per-attempt timeout applies to the first event only;
                        // a timeout is surfaced as an error item so policy can decide.
                        let first = if let Some(t) = client.attempt_timeout {
                            match tokio::time::timeout(t, next_fut).await {
                                Ok(v) => v,
                                Err(_) => Some(Err(crate::Error::runtime_with_context(
                                    "attempt timeout",
                                    crate::ErrorContext::new().with_source("timeout_policy"),
                                ))),
                            }
                        } else {
                            next_fut.await
                        };

                        match first {
                            // Upstream ended before emitting anything: treated as success
                            // with an empty (but still cancellable) stream.
                            None => {
                                stats.retry_count = retry_count;
                                stats.emitted_any = false;
                                base_client.record_success(&stats);
                                let wrapped = ControlledStream::new(
                                    Box::pin(futures::stream::empty()),
                                    Some(cancel_rx),
                                    permit,
                                );
                                return Ok((Box::pin(wrapped), cancel_handle, stats));
                            }
                            // First event arrived: commit to this stream and re-prepend
                            // the peeked event so the caller still observes it.
                            Some(Ok(first_ev)) => {
                                // NOTE(review): uses `stats.duration_ms` as the time-to-first-event;
                                // presumably measured by execute_stream_once up to this point — verify.
                                let first_ms = stats.duration_ms;
                                let stream = futures::stream::once(async move { Ok(first_ev) })
                                    .chain(event_stream);
                                let wrapped = ControlledStream::new(
                                    Box::pin(stream.map_err(|e| {
                                        // If it's already a crate::Error (like Transport error), preserve it.
                                        // Otherwise, it's likely a downstream pipeline error, wrap it.
                                        // NOTE(review): currently an identity mapping — kept as a
                                        // placeholder for future error wrapping.
                                        e
                                    })),
                                    Some(cancel_rx),
                                    permit,
                                );

                                stats.retry_count = retry_count;
                                stats.first_event_ms = Some(first_ms);
                                stats.emitted_any = true;

                                base_client.record_success(&stats);
                                return Ok((Box::pin(wrapped), cancel_handle, stats));
                            }
                            // First item was an error: nothing reached the caller yet,
                            // so the policy engine may retry, fall back, or fail.
                            Some(Err(e)) => {
                                let decision = policy.decide(&e, attempt, has_fallback)?;
                                last_err = Some(e);
                                match decision {
                                    crate::client::policy::Decision::Retry { delay } => {
                                        retry_count = retry_count.saturating_add(1);
                                        if delay.as_millis() > 0 {
                                            tokio::time::sleep(delay).await;
                                        }
                                        attempt = attempt.saturating_add(1);
                                        continue;
                                    }
                                    crate::client::policy::Decision::Fallback => break,
                                    crate::client::policy::Decision::Fail => {
                                        return Err(last_err.unwrap());
                                    }
                                }
                            }
                        }
                    }
                    // Request setup itself failed: same policy handling as a first-item error.
                    Err(e) => {
                        let decision = policy.decide(&e, attempt, has_fallback)?;
                        last_err = Some(e);
                        match decision {
                            crate::client::policy::Decision::Retry { delay } => {
                                retry_count = retry_count.saturating_add(1);
                                if delay.as_millis() > 0 {
                                    tokio::time::sleep(delay).await;
                                }
                                attempt = attempt.saturating_add(1);
                                continue;
                            }
                            crate::client::policy::Decision::Fallback => break,
                            crate::client::policy::Decision::Fail => {
                                return Err(last_err.unwrap());
                            }
                        }
                    }
                }
            }
        }

        // All candidates exhausted: surface the last error, or a generic one in the
        // (normally unreachable) case where no attempt recorded an error.
        Err(last_err.unwrap_or_else(|| {
            crate::Error::runtime_with_context(
                "all streaming attempts failed",
                crate::ErrorContext::new().with_source("retry_policy"),
            )
        }))
    }

    /// Execute the request and return a cancellable stream of events.
    ///
    /// Returns a stream and a [`CancelHandle`]. Call `cancel_handle.cancel()` to stop
    /// the stream early (e.g., when the user abandons the request).
    ///
    /// # Example
    ///
    /// ```ignore
    /// let (mut stream, cancel_handle) = client.chat()
    ///     .messages(msgs)
    ///     .stream()
    ///     .execute_stream_with_cancel()
    ///     .await?;
    ///
    /// // In another task or on user cancel:
    /// cancel_handle.cancel();
    ///
    /// while let Some(event) = stream.next().await {
    ///     match event? {
    ///         StreamingEvent::StreamEnd { .. } => break,
    ///         ev => process(ev),
    ///     }
    /// }
    /// ```
    pub async fn execute_stream_with_cancel(
        self,
    ) -> Result<(
        Pin<Box<dyn Stream<Item = Result<StreamingEvent>> + Send + 'static>>,
        CancelHandle,
    )> {
        // Delegate to the stats variant and drop the stats.
        let (s, c, _stats) = self.execute_stream_with_cancel_and_stats().await?;
        Ok((s, c))
    }

    /// Execute the request and return the complete response.
    ///
    /// When streaming was not enabled via [`Self::stream`], performs a single
    /// non-streaming call. Otherwise it streams internally and assembles all
    /// events into one [`UnifiedResponse`].
    pub async fn execute(self) -> Result<UnifiedResponse> {
        let stream_flag = self.stream;
        let client = self.client;
        let unified_req = self.into_unified_request();

        // If streaming is not explicitly enabled, use non-streaming execution
        if !stream_flag {
            let (resp, _stats) = client.call_model_with_stats(unified_req).await?;
            return Ok(resp);
        }

        // For streaming requests, collect all events
        // Rebuild builder for streaming execution
        // (the unified request's fields are cloned back into a fresh builder).
        let mut stream = {
            let builder = ChatRequestBuilder {
                client,
                messages: unified_req.messages.clone(),
                temperature: unified_req.temperature,
                max_tokens: unified_req.max_tokens,
                stream: true,
                tools: unified_req.tools.clone(),
                tool_choice: unified_req.tool_choice.clone(),
                model: Some(unified_req.model.clone()),
                response_format: unified_req.response_format.clone(),
            };
            builder.execute_stream().await?
        };
        let mut response = UnifiedResponse::default();
        // Reassembles incremental tool-call fragments into complete tool calls.
        let mut tool_asm = crate::utils::tool_call_assembler::ToolCallAssembler::new();

        use futures::StreamExt;
        let mut event_count = 0;
        while let Some(event) = stream.next().await {
            event_count += 1;
            match event? {
                // Text chunks are concatenated in arrival order.
                StreamingEvent::PartialContentDelta { content, .. } => {
                    response.content.push_str(&content);
                }
                StreamingEvent::ToolCallStarted {
                    tool_call_id,
                    tool_name,
                    ..
                } => {
                    tool_asm.on_started(tool_call_id, tool_name);
                }
                StreamingEvent::PartialToolCall {
                    tool_call_id,
                    arguments,
                    ..
                } => {
                    tool_asm.on_partial(&tool_call_id, &arguments);
                }
                StreamingEvent::Metadata { usage, .. } => {
                    response.usage = usage;
                }
                StreamingEvent::StreamEnd { .. } => {
                    break;
                }
                // Thinking/reasoning deltas are intentionally not part of the collected response.
                StreamingEvent::ThinkingDelta { .. } => {}
                other => {
                    // Log unexpected events for debugging
                    tracing::warn!("Unexpected event in execute(): {:?}", other);
                }
            }
        }

        // Diagnostics for common "empty response" situations; neither case is an error.
        if event_count == 0 {
            tracing::warn!(
                "No events received from stream. Possible causes: provider returned empty stream, \
                 network interruption, or event mapping configuration issue. Provider: {}, Model: {}",
                client.manifest.id,
                client.model_id
            );
        } else if response.content.is_empty() {
            tracing::warn!(
                "Received {} events but content is empty. This might indicate: (1) provider filtered \
                 content (safety/content policy), (2) non-streaming response format mismatch, \
                 (3) event mapping issue. Provider: {}, Model: {}",
                event_count,
                client.manifest.id,
                client.model_id
            );
        }

        response.tool_calls = tool_asm.finalize();

        Ok(response)
    }

    /// Build the provider-agnostic request, consuming the builder.
    ///
    /// A per-request model override (set via [`Self::model`]) takes precedence over
    /// the client's default model.
    fn into_unified_request(self) -> crate::protocol::UnifiedRequest {
        let model = self.model.unwrap_or_else(|| self.client.model_id.clone());
        crate::protocol::UnifiedRequest {
            operation: "chat".to_string(),
            model,
            messages: self.messages,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: self.stream,
            tools: self.tools,
            tool_choice: self.tool_choice,
            response_format: self.response_format,
        }
    }
}