Skip to main content

chat_core/chat/
completion.rs

1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::{ChatOutcome, PauseReason, StructuredResponse};
5use crate::{
6    chat::state::{Structured, Unstructured},
7    error::{ChatError, ChatFailure},
8    traits::CompletionProvider,
9    types::{
10        callback::CallbackRetryContext,
11        messages::{Messages, content::Content, parts::PartEnum},
12        metadata::Metadata,
13        response::ChatResponse,
14    },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19    /// Run the chat loop until the model completes, max_steps is reached,
20    /// or a tool call strategy pauses execution (pending approval,
21    /// scheduled, etc). Callers handle `ChatOutcome::Paused` by mutating
22    /// pending tool statuses and invoking [`Chat::resume`].
23    pub async fn complete(
24        &mut self,
25        messages: &mut Messages,
26    ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
27        self.execute_with_retries(messages, |response| {
28            Ok(ChatResponse {
29                content: response.content.clone(),
30                metadata: response.metadata.clone(),
31            })
32        })
33        .await
34    }
35
36    /// Resume a loop that previously returned `ChatOutcome::Paused`. The
37    /// caller is expected to have resolved at least one pending tool
38    /// (typically by calling `tool.approve(...)` or `tool.reject(...)`
39    /// on each) before calling resume.
40    pub async fn resume(
41        &mut self,
42        messages: &mut Messages,
43    ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
44        self.resume_with(messages, |response| {
45            Ok(ChatResponse {
46                content: response.content.clone(),
47                metadata: response.metadata.clone(),
48            })
49        })
50        .await
51    }
52}
53
54impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
55where
56    T: DeserializeOwned + JsonSchema,
57{
58    pub async fn complete(
59        &mut self,
60        messages: &mut Messages,
61    ) -> Result<ChatOutcome<StructuredResponse<T>>, ChatFailure> {
62        self.execute_with_retries(messages, |response| {
63            let value = extract_structured_candidate(&response.content).ok_or_else(|| {
64                ChatError::InvalidResponse(
65                    "Response did not contain valid structured output".into(),
66                )
67            })?;
68            serde_json::from_value::<T>(value.clone())
69                .map(|content| StructuredResponse {
70                    content,
71                    metadata: response.metadata.clone(),
72                })
73                .map_err(|err| {
74                    ChatError::InvalidResponse(format!(
75                        "Failed to parse structured output: {}",
76                        err
77                    ))
78                })
79        })
80        .await
81    }
82}
83
84/// Internal loop result: either the model reached a terminal text/structured
85/// response, or a tool-call strategy paused us.
86enum LoopStep {
87    Complete(ChatResponse),
88    Paused(PauseReason, Option<Metadata>),
89}
90
91impl<CP: CompletionProvider, Output> Chat<CP, Output> {
92    async fn call_loop(&mut self, messages: &mut Messages) -> Result<LoopStep, ChatFailure> {
93        let mut last_metadata: Option<Metadata> = None;
94
95        // First attempt to resume any existing pending work on the last
96        // Content before calling the model. Lets `resume()` pick up where
97        // `complete()` paused without re-entering the provider.
98        if let Some(last) = messages.0.last_mut() {
99            let pre = self.tool_call(last).await.map_err(|err| ChatFailure {
100                err,
101                metadata: None,
102            })?;
103            if let Some(reason) = pre.pause {
104                return Ok(LoopStep::Paused(reason, last_metadata));
105            }
106            // If any tools just ran, fall through into the normal loop
107            // so the model sees the results.
108        }
109
110        for _ in 0..self.max_steps.unwrap_or(1) {
111            // Split the borrows manually: `decls` views only
112            // `scoped_collections`, leaving `self.model` free to borrow
113            // mutably for the `complete()` call.
114            let decls =
115                crate::chat::tool_declarations_from(&self.scoped_collections);
116            let decls_dyn = decls
117                .as_ref()
118                .map(|d| d as &dyn crate::types::tools::ToolDeclarations);
119            let response = self
120                .model
121                .complete(
122                    messages,
123                    decls_dyn,
124                    self.model_options.as_ref(),
125                    self.output_shape.as_ref(),
126                )
127                .await?;
128
129            if let Some(metadata) = response.metadata.clone() {
130                match &mut last_metadata {
131                    Some(existing) => {
132                        existing.extend(&metadata);
133                    }
134                    None => {
135                        last_metadata = Some(metadata);
136                    }
137                }
138            }
139
140            messages.push(response.content.clone());
141
142            // Walk the just-pushed Content: run tools whose strategy
143            // says Execute, leave tools that need approval/deferral in
144            // a non-resolved state, and return Paused if any remain.
145            let pass = match messages.0.last_mut() {
146                Some(last) => self.tool_call(last).await.map_err(|err| ChatFailure {
147                    err,
148                    metadata: last_metadata.clone(),
149                })?,
150                None => crate::chat::ToolCallPass::default(),
151            };
152
153            if let Some(reason) = pass.pause {
154                return Ok(LoopStep::Paused(reason, last_metadata));
155            }
156            if pass.executed {
157                continue;
158            }
159
160            match response.content.parts.last() {
161                Some(res) => match res {
162                    PartEnum::Text(_) | PartEnum::Structured(_) => {
163                        return Ok(LoopStep::Complete(ChatResponse {
164                            metadata: last_metadata,
165                            content: response.content,
166                        }));
167                    }
168                    PartEnum::Reasoning(_) => {
169                        continue;
170                    }
171                    _ => {}
172                },
173                None => {
174                    return Err(ChatFailure {
175                        err: ChatError::InvalidResponse(
176                            "Response did not generate any parts".to_string(),
177                        ),
178                        metadata: last_metadata,
179                    });
180                }
181            };
182        }
183
184        Err(ChatFailure {
185            err: ChatError::MaxStepsExceeded,
186            metadata: last_metadata,
187        })
188    }
189
190    async fn execute_with_retries<F, R>(
191        &mut self,
192        messages: &mut Messages,
193        mut processor: F,
194    ) -> Result<ChatOutcome<R>, ChatFailure>
195    where
196        F: FnMut(&ChatResponse) -> Result<R, ChatError>,
197    {
198        let max_retries = self.max_retries.unwrap_or(1);
199        let mut last_err: Option<ChatError> = None;
200        let mut last_metadata: Option<Metadata> = None;
201
202        if let Some(strategy) = self.before_strategy.as_mut() {
203            strategy(messages, last_metadata.as_ref()).await;
204        }
205
206        for idx in 0..max_retries {
207            let original_len = messages.len();
208            match self.call_loop(messages).await {
209                Ok(LoopStep::Paused(reason, _metadata)) => {
210                    // Pause short-circuits the retry loop — metadata from
211                    // the paused step isn't currently surfaced to the
212                    // caller (ChatOutcome::Paused carries no metadata).
213                    // Add it to PauseReason's envelope if that changes.
214                    return Ok(ChatOutcome::Paused { reason });
215                }
216                Ok(LoopStep::Complete(response)) => {
217                    if let Some(metadata) = response.metadata.clone() {
218                        match &mut last_metadata {
219                            Some(existing) => {
220                                existing.extend(&metadata);
221                            }
222                            None => {
223                                last_metadata = Some(metadata);
224                            }
225                        }
226                    }
227
228                    match processor(&response) {
229                        Ok(parsed_result) => {
230                            if let Some(strategy) = self.after_strategy.as_mut() {
231                                strategy(messages, last_metadata.as_ref()).await;
232                            }
233                            return Ok(ChatOutcome::Complete(parsed_result));
234                        }
235                        Err(err) => {
236                            last_err = Some(err.clone());
237                            if idx + 1 < max_retries {
238                                let ctx = CallbackRetryContext {
239                                    idx,
240                                    failure: ChatFailure {
241                                        err,
242                                        metadata: last_metadata.clone(),
243                                    },
244                                };
245                                if let Some(strategy) = self.retry_strategy.as_mut() {
246                                    strategy(messages, last_metadata.as_ref(), ctx).await;
247                                }
248                            }
249                        }
250                    }
251                }
252                Err(failure) => {
253                    if let Some(metadata) = failure.metadata.clone() {
254                        match &mut last_metadata {
255                            Some(existing) => {
256                                existing.extend(&metadata);
257                            }
258                            None => {
259                                last_metadata = Some(metadata);
260                            }
261                        }
262                    }
263
264                    last_err = Some(failure.err.clone());
265
266                    if !failure.err.is_retryable() {
267                        break;
268                    }
269
270                    if idx + 1 < max_retries {
271                        let ctx = CallbackRetryContext { idx, failure };
272                        if let Some(strategy) = self.retry_strategy.as_mut() {
273                            strategy(messages, last_metadata.as_ref(), ctx).await;
274                        }
275                    }
276                }
277            }
278
279            messages.0.truncate(original_len);
280        }
281
282        Err(ChatFailure {
283            metadata: last_metadata,
284            err: last_err.unwrap_or(ChatError::RateLimited),
285        })
286    }
287
288    /// Resume helper for structured / unstructured variants. Does not
289    /// run retries — resume is always a continuation of a prior attempt,
290    /// not a new one.
291    async fn resume_with<F, R>(
292        &mut self,
293        messages: &mut Messages,
294        mut processor: F,
295    ) -> Result<ChatOutcome<R>, ChatFailure>
296    where
297        F: FnMut(&ChatResponse) -> Result<R, ChatError>,
298    {
299        match self.call_loop(messages).await? {
300            LoopStep::Paused(reason, _) => Ok(ChatOutcome::Paused { reason }),
301            LoopStep::Complete(response) => {
302                match processor(&response) {
303                    Ok(parsed) => Ok(ChatOutcome::Complete(parsed)),
304                    Err(err) => Err(ChatFailure {
305                        err,
306                        metadata: response.metadata,
307                    }),
308                }
309            }
310        }
311    }
312}
313
314fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
315    let last = content.parts.last()?;
316
317    match last {
318        PartEnum::Structured(v) => Some(v.clone()),
319        PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
320        _ => None,
321    }
322}