Skip to main content

chat_core/chat/
completion.rs

1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::{ChatOutcome, PauseReason, StructuredResponse};
5use crate::{
6    chat::state::{Structured, Unstructured},
7    error::{ChatError, ChatFailure},
8    traits::CompletionProvider,
9    types::{
10        callback::CallbackRetryContext,
11        messages::{Messages, content::Content, parts::PartEnum},
12        metadata::Metadata,
13        response::ChatResponse,
14    },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19    /// Run the chat loop until the model completes, max_steps is reached,
20    /// or a tool call strategy pauses execution (pending approval,
21    /// scheduled, etc). Callers handle `ChatOutcome::Paused` by mutating
22    /// pending tool statuses and invoking [`Chat::resume`].
23    pub async fn complete(
24        &mut self,
25        messages: &mut Messages,
26    ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
27        self.execute_with_retries(messages, |response| {
28            Ok(ChatResponse {
29                content: response.content.clone(),
30                metadata: response.metadata.clone(),
31            })
32        })
33        .await
34    }
35
36    /// Resume a loop that previously returned `ChatOutcome::Paused`. The
37    /// caller is expected to have resolved at least one pending tool
38    /// (typically by calling `tool.approve(...)` or `tool.reject(...)`
39    /// on each) before calling resume.
40    pub async fn resume(
41        &mut self,
42        messages: &mut Messages,
43    ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
44        self.resume_with(messages, |response| {
45            Ok(ChatResponse {
46                content: response.content.clone(),
47                metadata: response.metadata.clone(),
48            })
49        })
50        .await
51    }
52}
53
54impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
55where
56    T: DeserializeOwned + JsonSchema,
57{
58    pub async fn complete(
59        &mut self,
60        messages: &mut Messages,
61    ) -> Result<ChatOutcome<StructuredResponse<T>>, ChatFailure> {
62        self.execute_with_retries(messages, |response| {
63            let value = extract_structured_candidate(&response.content).ok_or_else(|| {
64                ChatError::InvalidResponse(
65                    "Response did not contain valid structured output".into(),
66                )
67            })?;
68            serde_json::from_value::<T>(value.clone())
69                .map(|content| StructuredResponse {
70                    content,
71                    metadata: response.metadata.clone(),
72                })
73                .map_err(|err| {
74                    ChatError::InvalidResponse(format!(
75                        "Failed to parse structured output: {}",
76                        err
77                    ))
78                })
79        })
80        .await
81    }
82}
83
84/// Internal loop result: either the model reached a terminal text/structured
85/// response, or a tool-call strategy paused us.
86enum LoopStep {
87    Complete(ChatResponse),
88    Paused(PauseReason, Option<Metadata>),
89}
90
91impl<CP: CompletionProvider, Output> Chat<CP, Output> {
92    async fn call_loop(&mut self, messages: &mut Messages) -> Result<LoopStep, ChatFailure> {
93        let mut last_metadata: Option<Metadata> = None;
94
95        if let Some(last) = messages.0.last_mut() {
96            let pre = self.tool_call(last).await.map_err(|err| ChatFailure {
97                err,
98                metadata: None,
99            })?;
100            if let Some(reason) = pre.pause {
101                return Ok(LoopStep::Paused(reason, last_metadata));
102            }
103        }
104
105        for _ in 0..self.max_steps.unwrap_or(1) {
106            let decls = crate::chat::tool_declarations_from(&self.scoped_collections);
107            let decls_dyn = decls
108                .as_ref()
109                .map(|d| d as &dyn crate::types::tools::ToolDeclarations);
110            let response = self
111                .model
112                .complete(
113                    messages,
114                    decls_dyn,
115                    self.model_options.as_ref(),
116                    self.output_shape.as_ref(),
117                )
118                .await?;
119
120            if let Some(metadata) = response.metadata.clone() {
121                match &mut last_metadata {
122                    Some(existing) => {
123                        existing.extend(&metadata);
124                    }
125                    None => {
126                        last_metadata = Some(metadata);
127                    }
128                }
129            }
130
131            messages.push(response.content.clone());
132
133            let pass = match messages.0.last_mut() {
134                Some(last) => self.tool_call(last).await.map_err(|err| ChatFailure {
135                    err,
136                    metadata: last_metadata.clone(),
137                })?,
138                None => crate::chat::ToolCallPass::default(),
139            };
140
141            if let Some(reason) = pass.pause {
142                return Ok(LoopStep::Paused(reason, last_metadata));
143            }
144            if pass.executed {
145                continue;
146            }
147
148            match response.content.parts.last() {
149                Some(res) => match res {
150                    PartEnum::Text(_) | PartEnum::Structured(_) => {
151                        return Ok(LoopStep::Complete(ChatResponse {
152                            metadata: last_metadata,
153                            content: response.content,
154                        }));
155                    }
156                    PartEnum::Reasoning(_) => {
157                        continue;
158                    }
159                    _ => {}
160                },
161                None => {
162                    return Err(ChatFailure {
163                        err: ChatError::InvalidResponse(
164                            "Response did not generate any parts".to_string(),
165                        ),
166                        metadata: last_metadata,
167                    });
168                }
169            };
170        }
171
172        Err(ChatFailure {
173            err: ChatError::MaxStepsExceeded,
174            metadata: last_metadata,
175        })
176    }
177
178    async fn execute_with_retries<F, R>(
179        &mut self,
180        messages: &mut Messages,
181        mut processor: F,
182    ) -> Result<ChatOutcome<R>, ChatFailure>
183    where
184        F: FnMut(&ChatResponse) -> Result<R, ChatError>,
185    {
186        let max_retries = self.max_retries.unwrap_or(1);
187        let mut last_err: Option<ChatError> = None;
188        let mut last_metadata: Option<Metadata> = None;
189
190        if let Some(strategy) = self.before_strategy.as_mut() {
191            strategy(messages, last_metadata.as_ref()).await;
192        }
193
194        for idx in 0..max_retries {
195            let original_len = messages.len();
196            match self.call_loop(messages).await {
197                Ok(LoopStep::Paused(reason, _metadata)) => {
198                    return Ok(ChatOutcome::Paused { reason });
199                }
200                Ok(LoopStep::Complete(response)) => {
201                    if let Some(metadata) = response.metadata.clone() {
202                        match &mut last_metadata {
203                            Some(existing) => {
204                                existing.extend(&metadata);
205                            }
206                            None => {
207                                last_metadata = Some(metadata);
208                            }
209                        }
210                    }
211
212                    match processor(&response) {
213                        Ok(parsed_result) => {
214                            if let Some(strategy) = self.after_strategy.as_mut() {
215                                strategy(messages, last_metadata.as_ref()).await;
216                            }
217                            return Ok(ChatOutcome::Complete(parsed_result));
218                        }
219                        Err(err) => {
220                            last_err = Some(err.clone());
221                            if idx + 1 < max_retries {
222                                let ctx = CallbackRetryContext {
223                                    idx,
224                                    failure: ChatFailure {
225                                        err,
226                                        metadata: last_metadata.clone(),
227                                    },
228                                };
229                                if let Some(strategy) = self.retry_strategy.as_mut() {
230                                    strategy(messages, last_metadata.as_ref(), ctx).await;
231                                }
232                            }
233                        }
234                    }
235                }
236                Err(failure) => {
237                    if let Some(metadata) = failure.metadata.clone() {
238                        match &mut last_metadata {
239                            Some(existing) => {
240                                existing.extend(&metadata);
241                            }
242                            None => {
243                                last_metadata = Some(metadata);
244                            }
245                        }
246                    }
247
248                    last_err = Some(failure.err.clone());
249
250                    if !failure.err.is_retryable() {
251                        break;
252                    }
253
254                    if idx + 1 < max_retries {
255                        let ctx = CallbackRetryContext { idx, failure };
256                        if let Some(strategy) = self.retry_strategy.as_mut() {
257                            strategy(messages, last_metadata.as_ref(), ctx).await;
258                        }
259                    }
260                }
261            }
262
263            messages.0.truncate(original_len);
264        }
265
266        Err(ChatFailure {
267            metadata: last_metadata,
268            err: last_err.unwrap_or(ChatError::RateLimited),
269        })
270    }
271
272    /// Resume helper for structured / unstructured variants. Does not
273    /// run retries — resume is always a continuation of a prior attempt,
274    /// not a new one.
275    async fn resume_with<F, R>(
276        &mut self,
277        messages: &mut Messages,
278        mut processor: F,
279    ) -> Result<ChatOutcome<R>, ChatFailure>
280    where
281        F: FnMut(&ChatResponse) -> Result<R, ChatError>,
282    {
283        match self.call_loop(messages).await? {
284            LoopStep::Paused(reason, _) => Ok(ChatOutcome::Paused { reason }),
285            LoopStep::Complete(response) => match processor(&response) {
286                Ok(parsed) => Ok(ChatOutcome::Complete(parsed)),
287                Err(err) => Err(ChatFailure {
288                    err,
289                    metadata: response.metadata,
290                }),
291            },
292        }
293    }
294}
295
296fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
297    let last = content.parts.last()?;
298
299    match last {
300        PartEnum::Structured(v) => Some(v.clone()),
301        PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
302        _ => None,
303    }
304}