Skip to main content

chat_core/chat/
completion.rs

1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::{ChatOutcome, PauseReason, StructuredResponse};
5use crate::{
6    chat::state::{Structured, Unstructured},
7    error::{ChatError, ChatFailure},
8    traits::CompletionProvider,
9    types::{
10        callback::CallbackRetryContext,
11        messages::{Messages, content::Content, parts::PartEnum},
12        metadata::Metadata,
13        response::ChatResponse,
14    },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19    /// Run the chat loop until the model completes, max_steps is reached,
20    /// or a tool call strategy pauses execution (pending approval,
21    /// scheduled, etc). Callers handle `ChatOutcome::Paused` by mutating
22    /// pending tool statuses and invoking [`Chat::resume`].
23    pub async fn complete(
24        &mut self,
25        messages: &mut Messages,
26    ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
27        self.execute_with_retries(messages, |response| {
28            Ok(ChatResponse {
29                content: response.content.clone(),
30                metadata: response.metadata.clone(),
31            })
32        })
33        .await
34    }
35
36    /// Resume a loop that previously returned `ChatOutcome::Paused`. The
37    /// caller is expected to have resolved at least one pending tool
38    /// (typically by calling `tool.approve(...)` or `tool.reject(...)`
39    /// on each) before calling resume.
40    pub async fn resume(
41        &mut self,
42        messages: &mut Messages,
43    ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
44        self.resume_with(messages, |response| {
45            Ok(ChatResponse {
46                content: response.content.clone(),
47                metadata: response.metadata.clone(),
48            })
49        })
50        .await
51    }
52}
53
54impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
55where
56    T: DeserializeOwned + JsonSchema,
57{
58    pub async fn complete(
59        &mut self,
60        messages: &mut Messages,
61    ) -> Result<ChatOutcome<StructuredResponse<T>>, ChatFailure> {
62        self.execute_with_retries(messages, |response| {
63            let value = extract_structured_candidate(&response.content).ok_or_else(|| {
64                ChatError::InvalidResponse(
65                    "Response did not contain valid structured output".into(),
66                )
67            })?;
68            serde_json::from_value::<T>(value.clone())
69                .map(|content| StructuredResponse {
70                    content,
71                    metadata: response.metadata.clone(),
72                })
73                .map_err(|err| {
74                    ChatError::InvalidResponse(format!(
75                        "Failed to parse structured output: {}",
76                        err
77                    ))
78                })
79        })
80        .await
81    }
82}
83
84/// Internal loop result: either the model reached a terminal text/structured
85/// response, or a tool-call strategy paused us.
86enum LoopStep {
87    Complete(ChatResponse),
88    Paused(PauseReason, Option<Metadata>),
89}
90
91impl<CP: CompletionProvider, Output> Chat<CP, Output> {
92    async fn call_loop(&mut self, messages: &mut Messages) -> Result<LoopStep, ChatFailure> {
93        let mut last_metadata: Option<Metadata> = None;
94
95        // First attempt to resume any existing pending work on the last
96        // Content before calling the model. Lets `resume()` pick up where
97        // `complete()` paused without re-entering the provider.
98        if let Some(last) = messages.0.last_mut() {
99            let pre = self.tool_call(last).await.map_err(|err| ChatFailure {
100                err,
101                metadata: None,
102            })?;
103            if let Some(reason) = pre.pause {
104                return Ok(LoopStep::Paused(reason, last_metadata));
105            }
106            // If any tools just ran, fall through into the normal loop
107            // so the model sees the results.
108        }
109
110        for _ in 0..self.max_steps.unwrap_or(1) {
111            // Split the borrows manually: `decls` views only
112            // `scoped_collections`, leaving `self.model` free to borrow
113            // mutably for the `complete()` call.
114            let decls = crate::chat::tool_declarations_from(&self.scoped_collections);
115            let decls_dyn = decls
116                .as_ref()
117                .map(|d| d as &dyn crate::types::tools::ToolDeclarations);
118            let response = self
119                .model
120                .complete(
121                    messages,
122                    decls_dyn,
123                    self.model_options.as_ref(),
124                    self.output_shape.as_ref(),
125                )
126                .await?;
127
128            if let Some(metadata) = response.metadata.clone() {
129                match &mut last_metadata {
130                    Some(existing) => {
131                        existing.extend(&metadata);
132                    }
133                    None => {
134                        last_metadata = Some(metadata);
135                    }
136                }
137            }
138
139            messages.push(response.content.clone());
140
141            // Walk the just-pushed Content: run tools whose strategy
142            // says Execute, leave tools that need approval/deferral in
143            // a non-resolved state, and return Paused if any remain.
144            let pass = match messages.0.last_mut() {
145                Some(last) => self.tool_call(last).await.map_err(|err| ChatFailure {
146                    err,
147                    metadata: last_metadata.clone(),
148                })?,
149                None => crate::chat::ToolCallPass::default(),
150            };
151
152            if let Some(reason) = pass.pause {
153                return Ok(LoopStep::Paused(reason, last_metadata));
154            }
155            if pass.executed {
156                continue;
157            }
158
159            match response.content.parts.last() {
160                Some(res) => match res {
161                    PartEnum::Text(_) | PartEnum::Structured(_) => {
162                        return Ok(LoopStep::Complete(ChatResponse {
163                            metadata: last_metadata,
164                            content: response.content,
165                        }));
166                    }
167                    PartEnum::Reasoning(_) => {
168                        continue;
169                    }
170                    _ => {}
171                },
172                None => {
173                    return Err(ChatFailure {
174                        err: ChatError::InvalidResponse(
175                            "Response did not generate any parts".to_string(),
176                        ),
177                        metadata: last_metadata,
178                    });
179                }
180            };
181        }
182
183        Err(ChatFailure {
184            err: ChatError::MaxStepsExceeded,
185            metadata: last_metadata,
186        })
187    }
188
189    async fn execute_with_retries<F, R>(
190        &mut self,
191        messages: &mut Messages,
192        mut processor: F,
193    ) -> Result<ChatOutcome<R>, ChatFailure>
194    where
195        F: FnMut(&ChatResponse) -> Result<R, ChatError>,
196    {
197        let max_retries = self.max_retries.unwrap_or(1);
198        let mut last_err: Option<ChatError> = None;
199        let mut last_metadata: Option<Metadata> = None;
200
201        if let Some(strategy) = self.before_strategy.as_mut() {
202            strategy(messages, last_metadata.as_ref()).await;
203        }
204
205        for idx in 0..max_retries {
206            let original_len = messages.len();
207            match self.call_loop(messages).await {
208                Ok(LoopStep::Paused(reason, _metadata)) => {
209                    // Pause short-circuits the retry loop — metadata from
210                    // the paused step isn't currently surfaced to the
211                    // caller (ChatOutcome::Paused carries no metadata).
212                    // Add it to PauseReason's envelope if that changes.
213                    return Ok(ChatOutcome::Paused { reason });
214                }
215                Ok(LoopStep::Complete(response)) => {
216                    if let Some(metadata) = response.metadata.clone() {
217                        match &mut last_metadata {
218                            Some(existing) => {
219                                existing.extend(&metadata);
220                            }
221                            None => {
222                                last_metadata = Some(metadata);
223                            }
224                        }
225                    }
226
227                    match processor(&response) {
228                        Ok(parsed_result) => {
229                            if let Some(strategy) = self.after_strategy.as_mut() {
230                                strategy(messages, last_metadata.as_ref()).await;
231                            }
232                            return Ok(ChatOutcome::Complete(parsed_result));
233                        }
234                        Err(err) => {
235                            last_err = Some(err.clone());
236                            if idx + 1 < max_retries {
237                                let ctx = CallbackRetryContext {
238                                    idx,
239                                    failure: ChatFailure {
240                                        err,
241                                        metadata: last_metadata.clone(),
242                                    },
243                                };
244                                if let Some(strategy) = self.retry_strategy.as_mut() {
245                                    strategy(messages, last_metadata.as_ref(), ctx).await;
246                                }
247                            }
248                        }
249                    }
250                }
251                Err(failure) => {
252                    if let Some(metadata) = failure.metadata.clone() {
253                        match &mut last_metadata {
254                            Some(existing) => {
255                                existing.extend(&metadata);
256                            }
257                            None => {
258                                last_metadata = Some(metadata);
259                            }
260                        }
261                    }
262
263                    last_err = Some(failure.err.clone());
264
265                    if !failure.err.is_retryable() {
266                        break;
267                    }
268
269                    if idx + 1 < max_retries {
270                        let ctx = CallbackRetryContext { idx, failure };
271                        if let Some(strategy) = self.retry_strategy.as_mut() {
272                            strategy(messages, last_metadata.as_ref(), ctx).await;
273                        }
274                    }
275                }
276            }
277
278            messages.0.truncate(original_len);
279        }
280
281        Err(ChatFailure {
282            metadata: last_metadata,
283            err: last_err.unwrap_or(ChatError::RateLimited),
284        })
285    }
286
287    /// Resume helper for structured / unstructured variants. Does not
288    /// run retries — resume is always a continuation of a prior attempt,
289    /// not a new one.
290    async fn resume_with<F, R>(
291        &mut self,
292        messages: &mut Messages,
293        mut processor: F,
294    ) -> Result<ChatOutcome<R>, ChatFailure>
295    where
296        F: FnMut(&ChatResponse) -> Result<R, ChatError>,
297    {
298        match self.call_loop(messages).await? {
299            LoopStep::Paused(reason, _) => Ok(ChatOutcome::Paused { reason }),
300            LoopStep::Complete(response) => match processor(&response) {
301                Ok(parsed) => Ok(ChatOutcome::Complete(parsed)),
302                Err(err) => Err(ChatFailure {
303                    err,
304                    metadata: response.metadata,
305                }),
306            },
307        }
308    }
309}
310
311fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
312    let last = content.parts.last()?;
313
314    match last {
315        PartEnum::Structured(v) => Some(v.clone()),
316        PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
317        _ => None,
318    }
319}