1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::{ChatOutcome, PauseReason, StructuredResponse};
5use crate::{
6 chat::state::{Structured, Unstructured},
7 error::{ChatError, ChatFailure},
8 traits::CompletionProvider,
9 types::{
10 callback::CallbackRetryContext,
11 messages::{Messages, content::Content, parts::PartEnum},
12 metadata::Metadata,
13 response::ChatResponse,
14 },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19 pub async fn complete(
24 &mut self,
25 messages: &mut Messages,
26 ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
27 self.execute_with_retries(messages, |response| {
28 Ok(ChatResponse {
29 content: response.content.clone(),
30 metadata: response.metadata.clone(),
31 })
32 })
33 .await
34 }
35
36 pub async fn resume(
41 &mut self,
42 messages: &mut Messages,
43 ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
44 self.resume_with(messages, |response| {
45 Ok(ChatResponse {
46 content: response.content.clone(),
47 metadata: response.metadata.clone(),
48 })
49 })
50 .await
51 }
52}
53
54impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
55where
56 T: DeserializeOwned + JsonSchema,
57{
58 pub async fn complete(
59 &mut self,
60 messages: &mut Messages,
61 ) -> Result<ChatOutcome<StructuredResponse<T>>, ChatFailure> {
62 self.execute_with_retries(messages, |response| {
63 let value = extract_structured_candidate(&response.content).ok_or_else(|| {
64 ChatError::InvalidResponse(
65 "Response did not contain valid structured output".into(),
66 )
67 })?;
68 serde_json::from_value::<T>(value.clone())
69 .map(|content| StructuredResponse {
70 content,
71 metadata: response.metadata.clone(),
72 })
73 .map_err(|err| {
74 ChatError::InvalidResponse(format!(
75 "Failed to parse structured output: {}",
76 err
77 ))
78 })
79 })
80 .await
81 }
82}
83
84enum LoopStep {
87 Complete(ChatResponse),
88 Paused(PauseReason, Option<Metadata>),
89}
90
91impl<CP: CompletionProvider, Output> Chat<CP, Output> {
92 async fn call_loop(&mut self, messages: &mut Messages) -> Result<LoopStep, ChatFailure> {
93 let mut last_metadata: Option<Metadata> = None;
94
95 if let Some(last) = messages.0.last_mut() {
99 let pre = self.tool_call(last).await.map_err(|err| ChatFailure {
100 err,
101 metadata: None,
102 })?;
103 if let Some(reason) = pre.pause {
104 return Ok(LoopStep::Paused(reason, last_metadata));
105 }
106 }
109
110 for _ in 0..self.max_steps.unwrap_or(1) {
111 let decls =
115 crate::chat::tool_declarations_from(&self.scoped_collections);
116 let decls_dyn = decls
117 .as_ref()
118 .map(|d| d as &dyn crate::types::tools::ToolDeclarations);
119 let response = self
120 .model
121 .complete(
122 messages,
123 decls_dyn,
124 self.model_options.as_ref(),
125 self.output_shape.as_ref(),
126 )
127 .await?;
128
129 if let Some(metadata) = response.metadata.clone() {
130 match &mut last_metadata {
131 Some(existing) => {
132 existing.extend(&metadata);
133 }
134 None => {
135 last_metadata = Some(metadata);
136 }
137 }
138 }
139
140 messages.push(response.content.clone());
141
142 let pass = match messages.0.last_mut() {
146 Some(last) => self.tool_call(last).await.map_err(|err| ChatFailure {
147 err,
148 metadata: last_metadata.clone(),
149 })?,
150 None => crate::chat::ToolCallPass::default(),
151 };
152
153 if let Some(reason) = pass.pause {
154 return Ok(LoopStep::Paused(reason, last_metadata));
155 }
156 if pass.executed {
157 continue;
158 }
159
160 match response.content.parts.last() {
161 Some(res) => match res {
162 PartEnum::Text(_) | PartEnum::Structured(_) => {
163 return Ok(LoopStep::Complete(ChatResponse {
164 metadata: last_metadata,
165 content: response.content,
166 }));
167 }
168 PartEnum::Reasoning(_) => {
169 continue;
170 }
171 _ => {}
172 },
173 None => {
174 return Err(ChatFailure {
175 err: ChatError::InvalidResponse(
176 "Response did not generate any parts".to_string(),
177 ),
178 metadata: last_metadata,
179 });
180 }
181 };
182 }
183
184 Err(ChatFailure {
185 err: ChatError::MaxStepsExceeded,
186 metadata: last_metadata,
187 })
188 }
189
190 async fn execute_with_retries<F, R>(
191 &mut self,
192 messages: &mut Messages,
193 mut processor: F,
194 ) -> Result<ChatOutcome<R>, ChatFailure>
195 where
196 F: FnMut(&ChatResponse) -> Result<R, ChatError>,
197 {
198 let max_retries = self.max_retries.unwrap_or(1);
199 let mut last_err: Option<ChatError> = None;
200 let mut last_metadata: Option<Metadata> = None;
201
202 if let Some(strategy) = self.before_strategy.as_mut() {
203 strategy(messages, last_metadata.as_ref()).await;
204 }
205
206 for idx in 0..max_retries {
207 let original_len = messages.len();
208 match self.call_loop(messages).await {
209 Ok(LoopStep::Paused(reason, _metadata)) => {
210 return Ok(ChatOutcome::Paused { reason });
215 }
216 Ok(LoopStep::Complete(response)) => {
217 if let Some(metadata) = response.metadata.clone() {
218 match &mut last_metadata {
219 Some(existing) => {
220 existing.extend(&metadata);
221 }
222 None => {
223 last_metadata = Some(metadata);
224 }
225 }
226 }
227
228 match processor(&response) {
229 Ok(parsed_result) => {
230 if let Some(strategy) = self.after_strategy.as_mut() {
231 strategy(messages, last_metadata.as_ref()).await;
232 }
233 return Ok(ChatOutcome::Complete(parsed_result));
234 }
235 Err(err) => {
236 last_err = Some(err.clone());
237 if idx + 1 < max_retries {
238 let ctx = CallbackRetryContext {
239 idx,
240 failure: ChatFailure {
241 err,
242 metadata: last_metadata.clone(),
243 },
244 };
245 if let Some(strategy) = self.retry_strategy.as_mut() {
246 strategy(messages, last_metadata.as_ref(), ctx).await;
247 }
248 }
249 }
250 }
251 }
252 Err(failure) => {
253 if let Some(metadata) = failure.metadata.clone() {
254 match &mut last_metadata {
255 Some(existing) => {
256 existing.extend(&metadata);
257 }
258 None => {
259 last_metadata = Some(metadata);
260 }
261 }
262 }
263
264 last_err = Some(failure.err.clone());
265
266 if !failure.err.is_retryable() {
267 break;
268 }
269
270 if idx + 1 < max_retries {
271 let ctx = CallbackRetryContext { idx, failure };
272 if let Some(strategy) = self.retry_strategy.as_mut() {
273 strategy(messages, last_metadata.as_ref(), ctx).await;
274 }
275 }
276 }
277 }
278
279 messages.0.truncate(original_len);
280 }
281
282 Err(ChatFailure {
283 metadata: last_metadata,
284 err: last_err.unwrap_or(ChatError::RateLimited),
285 })
286 }
287
288 async fn resume_with<F, R>(
292 &mut self,
293 messages: &mut Messages,
294 mut processor: F,
295 ) -> Result<ChatOutcome<R>, ChatFailure>
296 where
297 F: FnMut(&ChatResponse) -> Result<R, ChatError>,
298 {
299 match self.call_loop(messages).await? {
300 LoopStep::Paused(reason, _) => Ok(ChatOutcome::Paused { reason }),
301 LoopStep::Complete(response) => {
302 match processor(&response) {
303 Ok(parsed) => Ok(ChatOutcome::Complete(parsed)),
304 Err(err) => Err(ChatFailure {
305 err,
306 metadata: response.metadata,
307 }),
308 }
309 }
310 }
311 }
312}
313
314fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
315 let last = content.parts.last()?;
316
317 match last {
318 PartEnum::Structured(v) => Some(v.clone()),
319 PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
320 _ => None,
321 }
322}