Skip to main content

chat_core/chat/
completion.rs

1use schemars::JsonSchema;
2
3use crate::chat::Chat;
4use crate::types::response::StructuredResponse;
5use crate::{
6    chat::state::{Structured, Unstructured},
7    error::{ChatError, ChatFailure},
8    traits::CompletionProvider,
9    types::{
10        callback::CallbackRetryContext,
11        messages::{Messages, content::Content, parts::PartEnum},
12        metadata::Metadata,
13        response::ChatResponse,
14    },
15};
16use serde::de::DeserializeOwned;
17
18impl<CP: CompletionProvider> Chat<CP, Unstructured> {
19    pub async fn complete(&mut self, messages: &mut Messages) -> Result<ChatResponse, ChatFailure> {
20        self.execute_with_retries(messages, |response| {
21            Ok(ChatResponse {
22                content: response.content.clone(),
23                metadata: response.metadata.clone(),
24            })
25        })
26        .await
27    }
28}
29
30impl<CP: CompletionProvider, T> Chat<CP, Structured<T>>
31where
32    T: DeserializeOwned + JsonSchema,
33{
34    pub async fn complete(
35        &mut self,
36        messages: &mut Messages,
37    ) -> Result<StructuredResponse<T>, ChatFailure> {
38        self.execute_with_retries(messages, |response| {
39            let value = extract_structured_candidate(&response.content).ok_or_else(|| {
40                ChatError::InvalidResponse(
41                    "Response did not contain valid structured output".into(),
42                )
43            })?;
44
45            serde_json::from_value::<T>(value.clone())
46                .map(|content| StructuredResponse {
47                    content,
48                    metadata: None,
49                })
50                .map_err(|err| {
51                    ChatError::InvalidResponse(format!(
52                        "Failed to parse structured output: {}",
53                        err
54                    ))
55                })
56        })
57        .await
58    }
59}
60
61impl<CP: CompletionProvider, Output> Chat<CP, Output> {
62    async fn call_loop(&mut self, messages: &mut Messages) -> Result<ChatResponse, ChatFailure> {
63        let mut last_metadata: Option<Metadata> = None;
64
65        for _ in 0..self.max_steps.unwrap_or(1) {
66            let response = self
67                .model
68                .complete(
69                    messages,
70                    self.tools.as_ref(),
71                    self.model_options.as_ref(),
72                    self.output_shape.as_ref(),
73                )
74                .await?;
75
76            if let Some(metadata) = response.metadata.clone() {
77                match &mut last_metadata {
78                    Some(existing) => {
79                        existing.extend(&metadata);
80                    }
81                    None => {
82                        last_metadata = Some(metadata);
83                    }
84                }
85            }
86
87            messages.push(response.content.clone());
88
89            if let Ok(frs) = self.tool_call(&response.content).await
90                && !frs.is_empty()
91            {
92                let mut tool_message = Content::default();
93                tool_message.parts.extend(frs);
94                messages.push(tool_message);
95                continue;
96            }
97
98            match response.content.parts.last() {
99                Some(res) => match res {
100                    PartEnum::Text(_) | PartEnum::Structured(_) => {
101                        return Ok(ChatResponse {
102                            metadata: last_metadata,
103                            content: response.content,
104                        });
105                    }
106                    PartEnum::Reasoning(_) => {
107                        continue;
108                    }
109                    _ => {}
110                },
111                None => {
112                    return Err(ChatFailure {
113                        err: ChatError::InvalidResponse(
114                            "Response did not generate any parts".to_string(),
115                        ),
116                        metadata: last_metadata,
117                    });
118                }
119            };
120        }
121
122        Err(ChatFailure {
123            err: ChatError::RateLimited,
124            metadata: last_metadata,
125        })
126    }
127
128    async fn execute_with_retries<F, R>(
129        &mut self,
130        messages: &mut Messages,
131        mut processor: F,
132    ) -> Result<R, ChatFailure>
133    where
134        F: FnMut(&ChatResponse) -> Result<R, ChatError>,
135    {
136        let max_retries = self.max_retries.unwrap_or(1);
137        let mut last_err: Option<ChatError> = None;
138        let mut last_metadata: Option<Metadata> = None;
139
140        if let Some(strategy) = self.before_strategy.as_mut() {
141            strategy(messages, last_metadata.as_ref()).await;
142        }
143
144        for idx in 0..max_retries {
145            let original_len = messages.len();
146            match self.call_loop(messages).await {
147                Ok(response) => {
148                    if let Some(metadata) = response.metadata.clone() {
149                        match &mut last_metadata {
150                            Some(existing) => {
151                                existing.extend(&metadata);
152                            }
153                            None => {
154                                last_metadata = Some(metadata);
155                            }
156                        }
157                    }
158
159                    match processor(&response) {
160                        Ok(parsed_result) => {
161                            if let Some(strategy) = self.after_strategy.as_mut() {
162                                strategy(messages, last_metadata.as_ref()).await;
163                            }
164                            return Ok(parsed_result);
165                        }
166                        Err(err) => {
167                            last_err = Some(err.clone());
168                            if idx + 1 < max_retries {
169                                let ctx = CallbackRetryContext {
170                                    idx,
171                                    failure: ChatFailure {
172                                        err,
173                                        metadata: last_metadata.clone(),
174                                    },
175                                };
176                                if let Some(strategy) = self.retry_strategy.as_mut() {
177                                    strategy(messages, last_metadata.as_ref(), ctx).await;
178                                }
179                            }
180                        }
181                    }
182                }
183                Err(failure) => {
184                    if let Some(metadata) = failure.metadata.clone() {
185                        match &mut last_metadata {
186                            Some(existing) => {
187                                existing.extend(&metadata);
188                            }
189                            None => {
190                                last_metadata = Some(metadata);
191                            }
192                        }
193                    }
194
195                    last_err = Some(failure.err.clone());
196
197                    if idx + 1 < max_retries {
198                        let ctx = CallbackRetryContext { idx, failure };
199                        if let Some(strategy) = self.retry_strategy.as_mut() {
200                            strategy(messages, last_metadata.as_ref(), ctx).await;
201                        }
202                    }
203                }
204            }
205
206            messages.0.truncate(original_len);
207        }
208
209        Err(ChatFailure {
210            metadata: last_metadata,
211            err: last_err.unwrap_or(ChatError::RateLimited),
212        })
213    }
214}
215
216fn extract_structured_candidate(content: &Content) -> Option<serde_json::Value> {
217    let last = content.parts.last()?;
218
219    match last {
220        PartEnum::Structured(v) => Some(v.clone()),
221        PartEnum::Text(t) => serde_json::from_str::<serde_json::Value>(t.as_str()).ok(),
222        _ => None,
223    }
224}