1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::{ChatOutcome, PauseReason, StructuredResponse};
5use crate::{
6 chat::state::{Structured, Unstructured},
7 error::{ChatError, ChatFailure},
8 traits::CompletionProvider,
9 types::{
10 callback::CallbackRetryContext,
11 messages::{Messages, content::Content, parts::PartEnum},
12 metadata::Metadata,
13 response::ChatResponse,
14 },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19 pub async fn complete(
24 &mut self,
25 messages: &mut Messages,
26 ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
27 self.execute_with_retries(messages, |response| {
28 Ok(ChatResponse {
29 content: response.content.clone(),
30 metadata: response.metadata.clone(),
31 })
32 })
33 .await
34 }
35
36 pub async fn resume(
41 &mut self,
42 messages: &mut Messages,
43 ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
44 self.resume_with(messages, |response| {
45 Ok(ChatResponse {
46 content: response.content.clone(),
47 metadata: response.metadata.clone(),
48 })
49 })
50 .await
51 }
52}
53
54impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
55where
56 T: DeserializeOwned + JsonSchema,
57{
58 pub async fn complete(
59 &mut self,
60 messages: &mut Messages,
61 ) -> Result<ChatOutcome<StructuredResponse<T>>, ChatFailure> {
62 self.execute_with_retries(messages, |response| {
63 let value = extract_structured_candidate(&response.content).ok_or_else(|| {
64 ChatError::InvalidResponse(
65 "Response did not contain valid structured output".into(),
66 )
67 })?;
68 serde_json::from_value::<T>(value.clone())
69 .map(|content| StructuredResponse {
70 content,
71 metadata: response.metadata.clone(),
72 })
73 .map_err(|err| {
74 ChatError::InvalidResponse(format!(
75 "Failed to parse structured output: {}",
76 err
77 ))
78 })
79 })
80 .await
81 }
82}
83
/// Outcome of one successful `call_loop` pass.
enum LoopStep {
    /// The model produced a final response; carries that response.
    Complete(ChatResponse),
    /// A tool-call pass asked to pause; carries the pause reason and the
    /// metadata accumulated up to that point, if any.
    Paused(PauseReason, Option<Metadata>),
}
90
91impl<CP: CompletionProvider, Output> Chat<CP, Output> {
92 async fn call_loop(&mut self, messages: &mut Messages) -> Result<LoopStep, ChatFailure> {
93 let mut last_metadata: Option<Metadata> = None;
94
95 if let Some(last) = messages.0.last_mut() {
99 let pre = self.tool_call(last).await.map_err(|err| ChatFailure {
100 err,
101 metadata: None,
102 })?;
103 if let Some(reason) = pre.pause {
104 return Ok(LoopStep::Paused(reason, last_metadata));
105 }
106 }
109
110 for _ in 0..self.max_steps.unwrap_or(1) {
111 let decls = crate::chat::tool_declarations_from(&self.scoped_collections);
115 let decls_dyn = decls
116 .as_ref()
117 .map(|d| d as &dyn crate::types::tools::ToolDeclarations);
118 let response = self
119 .model
120 .complete(
121 messages,
122 decls_dyn,
123 self.model_options.as_ref(),
124 self.output_shape.as_ref(),
125 )
126 .await?;
127
128 if let Some(metadata) = response.metadata.clone() {
129 match &mut last_metadata {
130 Some(existing) => {
131 existing.extend(&metadata);
132 }
133 None => {
134 last_metadata = Some(metadata);
135 }
136 }
137 }
138
139 messages.push(response.content.clone());
140
141 let pass = match messages.0.last_mut() {
145 Some(last) => self.tool_call(last).await.map_err(|err| ChatFailure {
146 err,
147 metadata: last_metadata.clone(),
148 })?,
149 None => crate::chat::ToolCallPass::default(),
150 };
151
152 if let Some(reason) = pass.pause {
153 return Ok(LoopStep::Paused(reason, last_metadata));
154 }
155 if pass.executed {
156 continue;
157 }
158
159 match response.content.parts.last() {
160 Some(res) => match res {
161 PartEnum::Text(_) | PartEnum::Structured(_) => {
162 return Ok(LoopStep::Complete(ChatResponse {
163 metadata: last_metadata,
164 content: response.content,
165 }));
166 }
167 PartEnum::Reasoning(_) => {
168 continue;
169 }
170 _ => {}
171 },
172 None => {
173 return Err(ChatFailure {
174 err: ChatError::InvalidResponse(
175 "Response did not generate any parts".to_string(),
176 ),
177 metadata: last_metadata,
178 });
179 }
180 };
181 }
182
183 Err(ChatFailure {
184 err: ChatError::MaxStepsExceeded,
185 metadata: last_metadata,
186 })
187 }
188
189 async fn execute_with_retries<F, R>(
190 &mut self,
191 messages: &mut Messages,
192 mut processor: F,
193 ) -> Result<ChatOutcome<R>, ChatFailure>
194 where
195 F: FnMut(&ChatResponse) -> Result<R, ChatError>,
196 {
197 let max_retries = self.max_retries.unwrap_or(1);
198 let mut last_err: Option<ChatError> = None;
199 let mut last_metadata: Option<Metadata> = None;
200
201 if let Some(strategy) = self.before_strategy.as_mut() {
202 strategy(messages, last_metadata.as_ref()).await;
203 }
204
205 for idx in 0..max_retries {
206 let original_len = messages.len();
207 match self.call_loop(messages).await {
208 Ok(LoopStep::Paused(reason, _metadata)) => {
209 return Ok(ChatOutcome::Paused { reason });
214 }
215 Ok(LoopStep::Complete(response)) => {
216 if let Some(metadata) = response.metadata.clone() {
217 match &mut last_metadata {
218 Some(existing) => {
219 existing.extend(&metadata);
220 }
221 None => {
222 last_metadata = Some(metadata);
223 }
224 }
225 }
226
227 match processor(&response) {
228 Ok(parsed_result) => {
229 if let Some(strategy) = self.after_strategy.as_mut() {
230 strategy(messages, last_metadata.as_ref()).await;
231 }
232 return Ok(ChatOutcome::Complete(parsed_result));
233 }
234 Err(err) => {
235 last_err = Some(err.clone());
236 if idx + 1 < max_retries {
237 let ctx = CallbackRetryContext {
238 idx,
239 failure: ChatFailure {
240 err,
241 metadata: last_metadata.clone(),
242 },
243 };
244 if let Some(strategy) = self.retry_strategy.as_mut() {
245 strategy(messages, last_metadata.as_ref(), ctx).await;
246 }
247 }
248 }
249 }
250 }
251 Err(failure) => {
252 if let Some(metadata) = failure.metadata.clone() {
253 match &mut last_metadata {
254 Some(existing) => {
255 existing.extend(&metadata);
256 }
257 None => {
258 last_metadata = Some(metadata);
259 }
260 }
261 }
262
263 last_err = Some(failure.err.clone());
264
265 if !failure.err.is_retryable() {
266 break;
267 }
268
269 if idx + 1 < max_retries {
270 let ctx = CallbackRetryContext { idx, failure };
271 if let Some(strategy) = self.retry_strategy.as_mut() {
272 strategy(messages, last_metadata.as_ref(), ctx).await;
273 }
274 }
275 }
276 }
277
278 messages.0.truncate(original_len);
279 }
280
281 Err(ChatFailure {
282 metadata: last_metadata,
283 err: last_err.unwrap_or(ChatError::RateLimited),
284 })
285 }
286
287 async fn resume_with<F, R>(
291 &mut self,
292 messages: &mut Messages,
293 mut processor: F,
294 ) -> Result<ChatOutcome<R>, ChatFailure>
295 where
296 F: FnMut(&ChatResponse) -> Result<R, ChatError>,
297 {
298 match self.call_loop(messages).await? {
299 LoopStep::Paused(reason, _) => Ok(ChatOutcome::Paused { reason }),
300 LoopStep::Complete(response) => match processor(&response) {
301 Ok(parsed) => Ok(ChatOutcome::Complete(parsed)),
302 Err(err) => Err(ChatFailure {
303 err,
304 metadata: response.metadata,
305 }),
306 },
307 }
308 }
309}
310
311fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
312 let last = content.parts.last()?;
313
314 match last {
315 PartEnum::Structured(v) => Some(v.clone()),
316 PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
317 _ => None,
318 }
319}