Skip to main content

chat_core/chat/
completion.rs

1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::StructuredResponse;
5use crate::{
6    chat::state::{Structured, Unstructured},
7    error::{ChatError, ChatFailure},
8    traits::CompletionProvider,
9    types::{
10        callback::CallbackRetryContext,
11        messages::{Messages, content::Content, parts::PartEnum},
12        metadata::Metadata,
13        response::ChatResponse,
14    },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19    pub async fn complete(&mut self, messages: &mut Messages) -> Result<ChatResponse, ChatFailure> {
20        self.execute_with_retries(messages, |response| {
21            Ok(ChatResponse {
22                content: response.content.clone(),
23                metadata: response.metadata.clone(),
24            })
25        })
26        .await
27    }
28}
29
30impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
31where
32    T: DeserializeOwned + JsonSchema,
33{
34    pub async fn complete(
35        &mut self,
36        messages: &mut Messages,
37    ) -> Result<StructuredResponse<T>, ChatFailure> {
38        self.execute_with_retries(messages, |response| {
39            let value = extract_structured_candidate(&response.content).ok_or_else(|| {
40                ChatError::InvalidResponse(
41                    "Response did not contain valid structured output".into(),
42                )
43            })?;
44            serde_json::from_value::<T>(value.clone())
45                .map(|content| StructuredResponse {
46                    content,
47                    metadata: response.metadata.clone(),
48                })
49                .map_err(|err| {
50                    ChatError::InvalidResponse(format!(
51                        "Failed to parse structured output: {}",
52                        err
53                    ))
54                })
55        })
56        .await
57    }
58}
59
60impl<CP: CompletionProvider, Output> Chat<CP, Output> {
61    async fn call_loop(&mut self, messages: &mut Messages) -> Result<ChatResponse, ChatFailure> {
62        let mut last_metadata: Option<Metadata> = None;
63
64        for _ in 0..self.max_steps.unwrap_or(1) {
65            let response = self
66                .model
67                .complete(
68                    messages,
69                    self.tools.as_ref(),
70                    self.model_options.as_ref(),
71                    self.output_shape.as_ref(),
72                )
73                .await?;
74
75            if let Some(metadata) = response.metadata.clone() {
76                match &mut last_metadata {
77                    Some(existing) => {
78                        existing.extend(&metadata);
79                    }
80                    None => {
81                        last_metadata = Some(metadata);
82                    }
83                }
84            }
85
86            messages.push(response.content.clone());
87
88            if let Ok(frs) = self.tool_call(&response.content).await
89                && !frs.is_empty()
90            {
91                let mut tool_message = Content::default();
92                tool_message.parts.extend(frs);
93                messages.push(tool_message);
94                continue;
95            }
96
97            match response.content.parts.last() {
98                Some(res) => match res {
99                    PartEnum::Text(_) | PartEnum::Structured(_) => {
100                        return Ok(ChatResponse {
101                            metadata: last_metadata,
102                            content: response.content,
103                        });
104                    }
105                    PartEnum::Reasoning(_) => {
106                        continue;
107                    }
108                    _ => {}
109                },
110                None => {
111                    return Err(ChatFailure {
112                        err: ChatError::InvalidResponse(
113                            "Response did not generate any parts".to_string(),
114                        ),
115                        metadata: last_metadata,
116                    });
117                }
118            };
119        }
120
121        Err(ChatFailure {
122            err: ChatError::MaxStepsExceeded,
123            metadata: last_metadata,
124        })
125    }
126
127    async fn execute_with_retries<F, R>(
128        &mut self,
129        messages: &mut Messages,
130        mut processor: F,
131    ) -> Result<R, ChatFailure>
132    where
133        F: FnMut(&ChatResponse) -> Result<R, ChatError>,
134    {
135        let max_retries = self.max_retries.unwrap_or(1);
136        let mut last_err: Option<ChatError> = None;
137        let mut last_metadata: Option<Metadata> = None;
138
139        if let Some(strategy) = self.before_strategy.as_mut() {
140            strategy(messages, last_metadata.as_ref()).await;
141        }
142
143        for idx in 0..max_retries {
144            let original_len = messages.len();
145            match self.call_loop(messages).await {
146                Ok(response) => {
147                    if let Some(metadata) = response.metadata.clone() {
148                        match &mut last_metadata {
149                            Some(existing) => {
150                                existing.extend(&metadata);
151                            }
152                            None => {
153                                last_metadata = Some(metadata);
154                            }
155                        }
156                    }
157
158                    match processor(&response) {
159                        Ok(parsed_result) => {
160                            if let Some(strategy) = self.after_strategy.as_mut() {
161                                strategy(messages, last_metadata.as_ref()).await;
162                            }
163                            return Ok(parsed_result);
164                        }
165                        Err(err) => {
166                            last_err = Some(err.clone());
167                            if idx + 1 < max_retries {
168                                let ctx = CallbackRetryContext {
169                                    idx,
170                                    failure: ChatFailure {
171                                        err,
172                                        metadata: last_metadata.clone(),
173                                    },
174                                };
175                                if let Some(strategy) = self.retry_strategy.as_mut() {
176                                    strategy(messages, last_metadata.as_ref(), ctx).await;
177                                }
178                            }
179                        }
180                    }
181                }
182                Err(failure) => {
183                    if let Some(metadata) = failure.metadata.clone() {
184                        match &mut last_metadata {
185                            Some(existing) => {
186                                existing.extend(&metadata);
187                            }
188                            None => {
189                                last_metadata = Some(metadata);
190                            }
191                        }
192                    }
193
194                    last_err = Some(failure.err.clone());
195
196                    if !failure.err.is_retryable() {
197                        break;
198                    }
199
200                    if idx + 1 < max_retries {
201                        let ctx = CallbackRetryContext { idx, failure };
202                        if let Some(strategy) = self.retry_strategy.as_mut() {
203                            strategy(messages, last_metadata.as_ref(), ctx).await;
204                        }
205                    }
206                }
207            }
208
209            messages.0.truncate(original_len);
210        }
211
212        Err(ChatFailure {
213            metadata: last_metadata,
214            err: last_err.unwrap_or(ChatError::RateLimited),
215        })
216    }
217}
218
219fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
220    let last = content.parts.last()?;
221
222    match last {
223        PartEnum::Structured(v) => Some(v.clone()),
224        PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
225        _ => None,
226    }
227}