openai_rust2/
chat.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Serialize, Deserialize, Debug, Clone)]
4pub enum ResponseFormat {
5    JsonObject,
6    Text,
7}
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct ImageGeneration {
11    pub quality: Option<String>,       // e.g., "standard", "hd"
12    pub size: Option<String>,          // e.g., "1024x1024"
13    pub output_format: Option<String>, // e.g., "base64", "url"
14}
15
/// Request body for a chat-completions call.
///
/// Build one with `ChatArguments::new` and then set optional fields directly. Every `Option`
/// field carries `skip_serializing_if = "Option::is_none"`, so unset parameters are left out
/// of the JSON entirely rather than sent as `null`. Field declaration order is the order of
/// keys in the serialized request.
#[derive(Serialize, Debug, Clone)]
pub struct ChatArguments {
    /// Model identifier sent as `model`.
    pub model: String,
    /// Messages sent as `messages`.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Single stop sequence (only one string can be expressed with this type).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_generation: Option<ImageGeneration>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_parameters: Option<SearchParameters>, // Grok-specific search parameter (for now)
}
45
46impl ChatArguments {
47    pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
48        ChatArguments {
49            model: model.as_ref().to_owned(),
50            messages,
51            temperature: None,
52            top_p: None,
53            n: None,
54            stream: None,
55            stop: None,
56            max_tokens: None,
57            presence_penalty: None,
58            frequency_penalty: None,
59            user: None,
60            response_format: None,
61            image_generation: None,
62            search_parameters: None, // Grok-specific search parameter (for now)
63        }
64    }
65
66    pub fn with_search_parameters(mut self, params: SearchParameters) -> Self {
67        self.search_parameters = Some(params);
68        self
69    }
70}
71
/// Response body of a non-streaming chat-completions call.
#[derive(Deserialize, Debug, Clone)]
pub struct ChatCompletion {
    /// Response identifier; `None` when the field is absent.
    #[serde(default)]
    pub id: Option<String>,
    /// Creation time. NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub created: u32,
    /// Model that produced the completion; `None` when the field is absent.
    #[serde(default)]
    pub model: Option<String>,
    /// The response's `object` field; `None` when absent.
    #[serde(default)]
    pub object: Option<String>,
    /// Generated alternatives; the `Display` impl prints the first one's content.
    pub choices: Vec<Choice>,
    /// Token accounting. Required: deserialization fails if `usage` is missing.
    pub usage: Usage,
}
84
85impl std::fmt::Display for ChatCompletion {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        write!(f, "{}", &self.choices[0].message.content)?;
88        Ok(())
89    }
90}
91
92pub mod stream {
93    use bytes::Bytes;
94    use futures_util::Stream;
95    use serde::Deserialize;
96    use std::pin::Pin;
97    use std::str;
98    use std::task::Poll;
99
100    #[derive(Deserialize, Debug, Clone)]
101    pub struct ChatCompletionChunk {
102        pub id: String,
103        pub created: u32,
104        pub model: String,
105        pub choices: Vec<Choice>,
106        pub system_fingerprint: Option<String>,
107    }
108
109    impl std::fmt::Display for ChatCompletionChunk {
110        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111            write!(
112                f,
113                "{}",
114                self.choices[0].delta.content.as_ref().unwrap_or(&"".into())
115            )?;
116            Ok(())
117        }
118    }
119
120    #[derive(Deserialize, Debug, Clone)]
121    pub struct Choice {
122        pub delta: ChoiceDelta,
123        pub index: u32,
124        pub finish_reason: Option<String>,
125    }
126
127    #[derive(Deserialize, Debug, Clone)]
128    pub struct ChoiceDelta {
129        pub content: Option<String>,
130    }
131
132    pub struct ChatCompletionChunkStream {
133        byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
134        buf: String,
135    }
136
137    impl ChatCompletionChunkStream {
138        pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
139            Self {
140                byte_stream: stream,
141                buf: String::new(),
142            }
143        }
144
145        fn deserialize_buf(
146            self: Pin<&mut Self>,
147            cx: &mut std::task::Context<'_>,
148        ) -> Option<anyhow::Result<ChatCompletionChunk>> {
149            let bufclone = self.buf.clone();
150            let mut chunks = bufclone.split("\n\n").peekable();
151            let first = chunks.next();
152            let second = chunks.peek();
153
154            match first {
155                Some(first) => match first.strip_prefix("data: ") {
156                    Some(chunk) => {
157                        if !chunk.ends_with("}") {
158                            None
159                        } else {
160                            if let Some(second) = second {
161                                if second.ends_with("}") {
162                                    cx.waker().wake_by_ref();
163                                }
164                            }
165                            self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
166                            Some(
167                                serde_json::from_str::<ChatCompletionChunk>(&chunk)
168                                    .map_err(|e| anyhow::anyhow!(e)),
169                            )
170                        }
171                    }
172                    None => None,
173                },
174                None => None,
175            }
176        }
177    }
178
179    impl Stream for ChatCompletionChunkStream {
180        type Item = anyhow::Result<ChatCompletionChunk>;
181
182        fn poll_next(
183            mut self: Pin<&mut Self>,
184            cx: &mut std::task::Context<'_>,
185        ) -> Poll<Option<Self::Item>> {
186            match self.as_mut().deserialize_buf(cx) {
187                Some(chunk) => return Poll::Ready(Some(chunk)),
188                None => {}
189            };
190
191            match self.byte_stream.as_mut().poll_next(cx) {
192                Poll::Ready(bytes_option) => match bytes_option {
193                    Some(bytes_result) => match bytes_result {
194                        Ok(bytes) => {
195                            let data = str::from_utf8(&bytes)?.to_owned();
196                            self.buf = self.buf.clone() + &data;
197                            match self.deserialize_buf(cx) {
198                                Some(chunk) => Poll::Ready(Some(chunk)),
199                                None => {
200                                    cx.waker().wake_by_ref();
201                                    Poll::Pending
202                                }
203                            }
204                        }
205                        Err(e) => Poll::Ready(Some(Err(e.into()))),
206                    },
207                    None => Poll::Ready(None),
208                },
209                Poll::Pending => Poll::Pending,
210            }
211        }
212    }
213}
214
/// Token counts reported with a `ChatCompletion`. All three fields are required.
#[derive(Deserialize, Debug, Clone)]
pub struct Usage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u32,
    /// Tokens counted for the generated completion.
    pub completion_tokens: u32,
    /// Total as reported by the server. NOTE(review): presumably
    /// `prompt_tokens + completion_tokens`; not checked here.
    pub total_tokens: u32,
}
221
/// One generated alternative in a non-streaming `ChatCompletion`.
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    /// Position of this choice; `None` when the field is absent.
    #[serde(default)]
    pub index: Option<u32>,
    /// The generated message.
    pub message: Message,
    /// Reason generation stopped. Required and must be a string: a missing field or JSON
    /// `null` makes deserialization of the whole completion fail.
    pub finish_reason: String,
}
229
/// A single chat message, used both in requests (`ChatArguments::messages`) and in
/// responses (`Choice::message`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    /// Author role as a raw string; the `Role` enum names the known roles.
    pub role: String,
    /// Text content. Required: a missing or `null` content fails deserialization.
    pub content: String,
}
235
/// Author role of a chat message.
///
/// `Message::role` is a plain `String`. Use [`Role::as_str`], `to_string()` or
/// `String::from(role)` to get the lowercase wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    System,
    Assistant,
    User,
}

impl Role {
    /// Returns the lowercase role name used in chat-completion messages.
    pub fn as_str(self) -> &'static str {
        match self {
            Role::System => "system",
            Role::Assistant => "assistant",
            Role::User => "user",
        }
    }
}

impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl From<Role> for String {
    fn from(role: Role) -> Self {
        role.as_str().to_owned()
    }
}
241
/// Grok-specific live-search parameters, attached via `ChatArguments::with_search_parameters`.
///
/// `None` fields are omitted from the serialized JSON.
#[derive(Serialize, Debug, Clone)]
pub struct SearchParameters {
    /// Search mode, serialized as `"off"`, `"on"` or `"auto"` (`"auto"`: live search is
    /// enabled but the model decides when to use it).
    pub mode: SearchMode,
    /// Sent as `return_citations` when set; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_citations: Option<bool>,
    /// Inclusive lower bound, `yyyy-mm-dd`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub from_date: Option<String>,
    /// Inclusive upper bound, `yyyy-mm-dd`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub to_date: Option<String>,
}
255
/// Live-search mode for `SearchParameters::mode`.
///
/// Variants are (de)serialized in lowercase: `"on"`, `"off"`, `"auto"`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")] // "On" -> "on", "Off" -> "off", "Auto" -> "auto"
pub enum SearchMode {
    /// Serialized as `"on"`.
    On,
    /// Serialized as `"off"`.
    Off,
    /// Serialized as `"auto"`: live search is enabled but the model decides when to use it.
    Auto,
}
263
264impl SearchParameters {
265    pub fn new(mode: SearchMode) -> Self {
266        Self {
267            mode,
268            return_citations: None,
269            from_date: None,
270            to_date: None,
271        }
272    }
273    pub fn with_citations(mut self, yes: bool) -> Self {
274        self.return_citations = Some(yes);
275        self
276    }
277    pub fn with_date_range_str(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
278        self.from_date = Some(from.into());
279        self.to_date = Some(to.into());
280        self
281    }
282}