chat_core/chat/
completion.rs1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::StructuredResponse;
5use crate::{
6 chat::state::{Structured, Unstructured},
7 error::{ChatError, ChatFailure},
8 traits::CompletionProvider,
9 types::{
10 callback::CallbackRetryContext,
11 messages::{Messages, content::Content, parts::PartEnum},
12 metadata::Metadata,
13 response::ChatResponse,
14 },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19 pub async fn complete(&mut self, messages: &mut Messages) -> Result<ChatResponse, ChatFailure> {
20 self.execute_with_retries(messages, |response| {
21 Ok(ChatResponse {
22 content: response.content.clone(),
23 metadata: response.metadata.clone(),
24 })
25 })
26 .await
27 }
28}
29
30impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
31where
32 T: DeserializeOwned + JsonSchema,
33{
34 pub async fn complete(
35 &mut self,
36 messages: &mut Messages,
37 ) -> Result<StructuredResponse<T>, ChatFailure> {
38 self.execute_with_retries(messages, |response| {
39 let value = extract_structured_candidate(&response.content).ok_or_else(|| {
40 ChatError::InvalidResponse(
41 "Response did not contain valid structured output".into(),
42 )
43 })?;
44
45 serde_json::from_value::<T>(value.clone())
46 .map(|content| StructuredResponse {
47 content,
48 metadata: None,
49 })
50 .map_err(|err| {
51 ChatError::InvalidResponse(format!(
52 "Failed to parse structured output: {}",
53 err
54 ))
55 })
56 })
57 .await
58 }
59}
60
61impl<CP: CompletionProvider, Output> Chat<CP, Output> {
62 async fn call_loop(&mut self, messages: &mut Messages) -> Result<ChatResponse, ChatFailure> {
63 let mut last_metadata: Option<Metadata> = None;
64
65 for _ in 0..self.max_steps.unwrap_or(1) {
66 let response = self
67 .model
68 .complete(
69 messages,
70 self.tools.as_ref(),
71 self.model_options.as_ref(),
72 self.output_shape.as_ref(),
73 )
74 .await?;
75
76 if let Some(metadata) = response.metadata.clone() {
77 match &mut last_metadata {
78 Some(existing) => {
79 existing.extend(&metadata);
80 }
81 None => {
82 last_metadata = Some(metadata);
83 }
84 }
85 }
86
87 messages.push(response.content.clone());
88
89 if let Ok(frs) = self.tool_call(&response.content).await
90 && !frs.is_empty()
91 {
92 let mut tool_message = Content::default();
93 tool_message.parts.extend(frs);
94 messages.push(tool_message);
95 continue;
96 }
97
98 match response.content.parts.last() {
99 Some(res) => match res {
100 PartEnum::Text(_) | PartEnum::Structured(_) => {
101 return Ok(ChatResponse {
102 metadata: last_metadata,
103 content: response.content,
104 });
105 }
106 PartEnum::Reasoning(_) => {
107 continue;
108 }
109 _ => {}
110 },
111 None => {
112 return Err(ChatFailure {
113 err: ChatError::InvalidResponse(
114 "Response did not generate any parts".to_string(),
115 ),
116 metadata: last_metadata,
117 });
118 }
119 };
120 }
121
122 Err(ChatFailure {
123 err: ChatError::RateLimited,
124 metadata: last_metadata,
125 })
126 }
127
128 async fn execute_with_retries<F, R>(
129 &mut self,
130 messages: &mut Messages,
131 mut processor: F,
132 ) -> Result<R, ChatFailure>
133 where
134 F: FnMut(&ChatResponse) -> Result<R, ChatError>,
135 {
136 let max_retries = self.max_retries.unwrap_or(1);
137 let mut last_err: Option<ChatError> = None;
138 let mut last_metadata: Option<Metadata> = None;
139
140 if let Some(strategy) = self.before_strategy.as_mut() {
141 strategy(messages, last_metadata.as_ref()).await;
142 }
143
144 for idx in 0..max_retries {
145 let original_len = messages.len();
146 match self.call_loop(messages).await {
147 Ok(response) => {
148 if let Some(metadata) = response.metadata.clone() {
149 match &mut last_metadata {
150 Some(existing) => {
151 existing.extend(&metadata);
152 }
153 None => {
154 last_metadata = Some(metadata);
155 }
156 }
157 }
158
159 match processor(&response) {
160 Ok(parsed_result) => {
161 if let Some(strategy) = self.after_strategy.as_mut() {
162 strategy(messages, last_metadata.as_ref()).await;
163 }
164 return Ok(parsed_result);
165 }
166 Err(err) => {
167 last_err = Some(err.clone());
168 if idx + 1 < max_retries {
169 let ctx = CallbackRetryContext {
170 idx,
171 failure: ChatFailure {
172 err,
173 metadata: last_metadata.clone(),
174 },
175 };
176 if let Some(strategy) = self.retry_strategy.as_mut() {
177 strategy(messages, last_metadata.as_ref(), ctx).await;
178 }
179 }
180 }
181 }
182 }
183 Err(failure) => {
184 if let Some(metadata) = failure.metadata.clone() {
185 match &mut last_metadata {
186 Some(existing) => {
187 existing.extend(&metadata);
188 }
189 None => {
190 last_metadata = Some(metadata);
191 }
192 }
193 }
194
195 last_err = Some(failure.err.clone());
196
197 if idx + 1 < max_retries {
198 let ctx = CallbackRetryContext { idx, failure };
199 if let Some(strategy) = self.retry_strategy.as_mut() {
200 strategy(messages, last_metadata.as_ref(), ctx).await;
201 }
202 }
203 }
204 }
205
206 messages.0.truncate(original_len);
207 }
208
209 Err(ChatFailure {
210 metadata: last_metadata,
211 err: last_err.unwrap_or(ChatError::RateLimited),
212 })
213 }
214}
215
216fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
217 let last = content.parts.last()?;
218
219 match last {
220 PartEnum::Structured(v) => Some(v.clone()),
221 PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
222 _ => None,
223 }
224}