1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::{ChatOutcome, PauseReason, StructuredResponse};
5use crate::{
6 chat::state::{Structured, Unstructured},
7 error::{ChatError, ChatFailure},
8 traits::CompletionProvider,
9 types::{
10 callback::CallbackRetryContext,
11 messages::{Messages, content::Content, parts::PartEnum},
12 metadata::Metadata,
13 response::ChatResponse,
14 },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19 pub async fn complete(
24 &mut self,
25 messages: &mut Messages,
26 ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
27 self.execute_with_retries(messages, |response| {
28 Ok(ChatResponse {
29 content: response.content.clone(),
30 metadata: response.metadata.clone(),
31 })
32 })
33 .await
34 }
35
36 pub async fn resume(
41 &mut self,
42 messages: &mut Messages,
43 ) -> Result<ChatOutcome<ChatResponse>, ChatFailure> {
44 self.resume_with(messages, |response| {
45 Ok(ChatResponse {
46 content: response.content.clone(),
47 metadata: response.metadata.clone(),
48 })
49 })
50 .await
51 }
52}
53
54impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
55where
56 T: DeserializeOwned + JsonSchema,
57{
58 pub async fn complete(
59 &mut self,
60 messages: &mut Messages,
61 ) -> Result<ChatOutcome<StructuredResponse<T>>, ChatFailure> {
62 self.execute_with_retries(messages, |response| {
63 let value = extract_structured_candidate(&response.content).ok_or_else(|| {
64 ChatError::InvalidResponse(
65 "Response did not contain valid structured output".into(),
66 )
67 })?;
68 serde_json::from_value::<T>(value.clone())
69 .map(|content| StructuredResponse {
70 content,
71 metadata: response.metadata.clone(),
72 })
73 .map_err(|err| {
74 ChatError::InvalidResponse(format!(
75 "Failed to parse structured output: {}",
76 err
77 ))
78 })
79 })
80 .await
81 }
82}
83
84enum LoopStep {
87 Complete(ChatResponse),
88 Paused(PauseReason, Option<Metadata>),
89}
90
91impl<CP: CompletionProvider, Output> Chat<CP, Output> {
92 async fn call_loop(&mut self, messages: &mut Messages) -> Result<LoopStep, ChatFailure> {
93 let mut last_metadata: Option<Metadata> = None;
94
95 if let Some(last) = messages.0.last_mut() {
96 let pre = self.tool_call(last).await.map_err(|err| ChatFailure {
97 err,
98 metadata: None,
99 })?;
100 if let Some(reason) = pre.pause {
101 return Ok(LoopStep::Paused(reason, last_metadata));
102 }
103 }
104
105 for _ in 0..self.max_steps.unwrap_or(1) {
106 let decls = crate::chat::tool_declarations_from(&self.scoped_collections);
107 let decls_dyn = decls
108 .as_ref()
109 .map(|d| d as &dyn crate::types::tools::ToolDeclarations);
110 let response = self
111 .model
112 .complete(
113 messages,
114 decls_dyn,
115 self.model_options.as_ref(),
116 self.output_shape.as_ref(),
117 )
118 .await?;
119
120 if let Some(metadata) = response.metadata.clone() {
121 match &mut last_metadata {
122 Some(existing) => {
123 existing.extend(&metadata);
124 }
125 None => {
126 last_metadata = Some(metadata);
127 }
128 }
129 }
130
131 messages.push(response.content.clone());
132
133 let pass = match messages.0.last_mut() {
134 Some(last) => self.tool_call(last).await.map_err(|err| ChatFailure {
135 err,
136 metadata: last_metadata.clone(),
137 })?,
138 None => crate::chat::ToolCallPass::default(),
139 };
140
141 if let Some(reason) = pass.pause {
142 return Ok(LoopStep::Paused(reason, last_metadata));
143 }
144 if pass.executed {
145 continue;
146 }
147
148 match response.content.parts.last() {
149 Some(res) => match res {
150 PartEnum::Text(_) | PartEnum::Structured(_) => {
151 return Ok(LoopStep::Complete(ChatResponse {
152 metadata: last_metadata,
153 content: response.content,
154 }));
155 }
156 PartEnum::Reasoning(_) => {
157 continue;
158 }
159 _ => {}
160 },
161 None => {
162 return Err(ChatFailure {
163 err: ChatError::InvalidResponse(
164 "Response did not generate any parts".to_string(),
165 ),
166 metadata: last_metadata,
167 });
168 }
169 };
170 }
171
172 Err(ChatFailure {
173 err: ChatError::MaxStepsExceeded,
174 metadata: last_metadata,
175 })
176 }
177
178 async fn execute_with_retries<F, R>(
179 &mut self,
180 messages: &mut Messages,
181 mut processor: F,
182 ) -> Result<ChatOutcome<R>, ChatFailure>
183 where
184 F: FnMut(&ChatResponse) -> Result<R, ChatError>,
185 {
186 let max_retries = self.max_retries.unwrap_or(1);
187 let mut last_err: Option<ChatError> = None;
188 let mut last_metadata: Option<Metadata> = None;
189
190 if let Some(strategy) = self.before_strategy.as_mut() {
191 strategy(messages, last_metadata.as_ref()).await;
192 }
193
194 for idx in 0..max_retries {
195 let original_len = messages.len();
196 match self.call_loop(messages).await {
197 Ok(LoopStep::Paused(reason, _metadata)) => {
198 return Ok(ChatOutcome::Paused { reason });
199 }
200 Ok(LoopStep::Complete(response)) => {
201 if let Some(metadata) = response.metadata.clone() {
202 match &mut last_metadata {
203 Some(existing) => {
204 existing.extend(&metadata);
205 }
206 None => {
207 last_metadata = Some(metadata);
208 }
209 }
210 }
211
212 match processor(&response) {
213 Ok(parsed_result) => {
214 if let Some(strategy) = self.after_strategy.as_mut() {
215 strategy(messages, last_metadata.as_ref()).await;
216 }
217 return Ok(ChatOutcome::Complete(parsed_result));
218 }
219 Err(err) => {
220 last_err = Some(err.clone());
221 if idx + 1 < max_retries {
222 let ctx = CallbackRetryContext {
223 idx,
224 failure: ChatFailure {
225 err,
226 metadata: last_metadata.clone(),
227 },
228 };
229 if let Some(strategy) = self.retry_strategy.as_mut() {
230 strategy(messages, last_metadata.as_ref(), ctx).await;
231 }
232 }
233 }
234 }
235 }
236 Err(failure) => {
237 if let Some(metadata) = failure.metadata.clone() {
238 match &mut last_metadata {
239 Some(existing) => {
240 existing.extend(&metadata);
241 }
242 None => {
243 last_metadata = Some(metadata);
244 }
245 }
246 }
247
248 last_err = Some(failure.err.clone());
249
250 if !failure.err.is_retryable() {
251 break;
252 }
253
254 if idx + 1 < max_retries {
255 let ctx = CallbackRetryContext { idx, failure };
256 if let Some(strategy) = self.retry_strategy.as_mut() {
257 strategy(messages, last_metadata.as_ref(), ctx).await;
258 }
259 }
260 }
261 }
262
263 messages.0.truncate(original_len);
264 }
265
266 Err(ChatFailure {
267 metadata: last_metadata,
268 err: last_err.unwrap_or(ChatError::RateLimited),
269 })
270 }
271
272 async fn resume_with<F, R>(
276 &mut self,
277 messages: &mut Messages,
278 mut processor: F,
279 ) -> Result<ChatOutcome<R>, ChatFailure>
280 where
281 F: FnMut(&ChatResponse) -> Result<R, ChatError>,
282 {
283 match self.call_loop(messages).await? {
284 LoopStep::Paused(reason, _) => Ok(ChatOutcome::Paused { reason }),
285 LoopStep::Complete(response) => match processor(&response) {
286 Ok(parsed) => Ok(ChatOutcome::Complete(parsed)),
287 Err(err) => Err(ChatFailure {
288 err,
289 metadata: response.metadata,
290 }),
291 },
292 }
293 }
294}
295
296fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
297 let last = content.parts.last()?;
298
299 match last {
300 PartEnum::Structured(v) => Some(v.clone()),
301 PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
302 _ => None,
303 }
304}