chat_core/chat/
completion.rs1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::StructuredResponse;
5use crate::{
6 chat::state::{Structured, Unstructured},
7 error::{ChatError, ChatFailure},
8 traits::CompletionProvider,
9 types::{
10 callback::CallbackRetryContext,
11 messages::{Messages, content::Content, parts::PartEnum},
12 metadata::Metadata,
13 response::ChatResponse,
14 },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19 pub async fn complete(&mut self, messages: &mut Messages) -> Result<ChatResponse, ChatFailure> {
20 self.execute_with_retries(messages, |response| {
21 Ok(ChatResponse {
22 content: response.content.clone(),
23 metadata: response.metadata.clone(),
24 })
25 })
26 .await
27 }
28}
29
30impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
31where
32 T: DeserializeOwned + JsonSchema,
33{
34 pub async fn complete(
35 &mut self,
36 messages: &mut Messages,
37 ) -> Result<StructuredResponse<T>, ChatFailure> {
38 self.execute_with_retries(messages, |response| {
39 let value = extract_structured_candidate(&response.content).ok_or_else(|| {
40 ChatError::InvalidResponse(
41 "Response did not contain valid structured output".into(),
42 )
43 })?;
44 serde_json::from_value::<T>(value.clone())
45 .map(|content| StructuredResponse {
46 content,
47 metadata: response.metadata.clone(),
48 })
49 .map_err(|err| {
50 ChatError::InvalidResponse(format!(
51 "Failed to parse structured output: {}",
52 err
53 ))
54 })
55 })
56 .await
57 }
58}
59
60impl<CP: CompletionProvider, Output> Chat<CP, Output> {
61 async fn call_loop(&mut self, messages: &mut Messages) -> Result<ChatResponse, ChatFailure> {
62 let mut last_metadata: Option<Metadata> = None;
63
64 for _ in 0..self.max_steps.unwrap_or(1) {
65 let response = self
66 .model
67 .complete(
68 messages,
69 self.tools.as_ref(),
70 self.model_options.as_ref(),
71 self.output_shape.as_ref(),
72 )
73 .await?;
74
75 if let Some(metadata) = response.metadata.clone() {
76 match &mut last_metadata {
77 Some(existing) => {
78 existing.extend(&metadata);
79 }
80 None => {
81 last_metadata = Some(metadata);
82 }
83 }
84 }
85
86 messages.push(response.content.clone());
87
88 if let Ok(frs) = self.tool_call(&response.content).await
89 && !frs.is_empty()
90 {
91 let mut tool_message = Content::default();
92 tool_message.parts.extend(frs);
93 messages.push(tool_message);
94 continue;
95 }
96
97 match response.content.parts.last() {
98 Some(res) => match res {
99 PartEnum::Text(_) | PartEnum::Structured(_) => {
100 return Ok(ChatResponse {
101 metadata: last_metadata,
102 content: response.content,
103 });
104 }
105 PartEnum::Reasoning(_) => {
106 continue;
107 }
108 _ => {}
109 },
110 None => {
111 return Err(ChatFailure {
112 err: ChatError::InvalidResponse(
113 "Response did not generate any parts".to_string(),
114 ),
115 metadata: last_metadata,
116 });
117 }
118 };
119 }
120
121 Err(ChatFailure {
122 err: ChatError::MaxStepsExceeded,
123 metadata: last_metadata,
124 })
125 }
126
127 async fn execute_with_retries<F, R>(
128 &mut self,
129 messages: &mut Messages,
130 mut processor: F,
131 ) -> Result<R, ChatFailure>
132 where
133 F: FnMut(&ChatResponse) -> Result<R, ChatError>,
134 {
135 let max_retries = self.max_retries.unwrap_or(1);
136 let mut last_err: Option<ChatError> = None;
137 let mut last_metadata: Option<Metadata> = None;
138
139 if let Some(strategy) = self.before_strategy.as_mut() {
140 strategy(messages, last_metadata.as_ref()).await;
141 }
142
143 for idx in 0..max_retries {
144 let original_len = messages.len();
145 match self.call_loop(messages).await {
146 Ok(response) => {
147 if let Some(metadata) = response.metadata.clone() {
148 match &mut last_metadata {
149 Some(existing) => {
150 existing.extend(&metadata);
151 }
152 None => {
153 last_metadata = Some(metadata);
154 }
155 }
156 }
157
158 match processor(&response) {
159 Ok(parsed_result) => {
160 if let Some(strategy) = self.after_strategy.as_mut() {
161 strategy(messages, last_metadata.as_ref()).await;
162 }
163 return Ok(parsed_result);
164 }
165 Err(err) => {
166 last_err = Some(err.clone());
167 if idx + 1 < max_retries {
168 let ctx = CallbackRetryContext {
169 idx,
170 failure: ChatFailure {
171 err,
172 metadata: last_metadata.clone(),
173 },
174 };
175 if let Some(strategy) = self.retry_strategy.as_mut() {
176 strategy(messages, last_metadata.as_ref(), ctx).await;
177 }
178 }
179 }
180 }
181 }
182 Err(failure) => {
183 if let Some(metadata) = failure.metadata.clone() {
184 match &mut last_metadata {
185 Some(existing) => {
186 existing.extend(&metadata);
187 }
188 None => {
189 last_metadata = Some(metadata);
190 }
191 }
192 }
193
194 last_err = Some(failure.err.clone());
195
196 if !failure.err.is_retryable() {
197 break;
198 }
199
200 if idx + 1 < max_retries {
201 let ctx = CallbackRetryContext { idx, failure };
202 if let Some(strategy) = self.retry_strategy.as_mut() {
203 strategy(messages, last_metadata.as_ref(), ctx).await;
204 }
205 }
206 }
207 }
208
209 messages.0.truncate(original_len);
210 }
211
212 Err(ChatFailure {
213 metadata: last_metadata,
214 err: last_err.unwrap_or(ChatError::RateLimited),
215 })
216 }
217}
218
219fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
220 let last = content.parts.last()?;
221
222 match last {
223 PartEnum::Structured(v) => Some(v.clone()),
224 PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
225 _ => None,
226 }
227}