openai_rust/chat.rs

//! See <https://platform.openai.com/docs/api-reference/chat>.
//! Use with [Client::create_chat](crate::Client::create_chat) or [Client::create_chat_stream](crate::Client::create_chat_stream).

use serde::{Deserialize, Serialize};

/// Request arguments for chat completion.
///
/// See <https://platform.openai.com/docs/api-reference/chat/create>.
///
/// ```
/// let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
///     openai_rust::chat::Message {
///         role: "user".to_owned(),
///         content: "Hello GPT!".to_owned(),
///     }
/// ]);
/// ```
///
/// To use streaming, use [crate::Client::create_chat_stream].
///
#[derive(Serialize, Debug, Clone)]
pub struct ChatArguments {
    /// ID of the model to use. See the model [endpoint compatibility table](https://platform.openai.com/docs/models/model-endpoint-compatibility) for details on which models work with the Chat API.
    pub model: String,

    /// The [Message]s to generate chat completions for.
    pub messages: Vec<Message>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// How many chat completion choices to generate for each input message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// Whether to stream back partial progress.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) stream: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<String>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the chat completion.
    ///
    /// The total length of input tokens and generated tokens is limited by the model's context length.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    // logit_bias
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl ChatArguments {
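    /// Create a new [ChatArguments] with all optional fields unset.
    ///
    /// A minimal sketch of tweaking the optional sampling parameters after
    /// construction (the values shown are illustrative, not recommendations):
    ///
    /// ```
    /// let mut args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
    ///     openai_rust::chat::Message {
    ///         role: "user".to_owned(),
    ///         content: "Hello GPT!".to_owned(),
    ///     }
    /// ]);
    /// args.temperature = Some(0.7);
    /// args.max_tokens = Some(256);
    /// ```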
    pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
        ChatArguments {
            model: model.as_ref().to_owned(),
            messages,
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            user: None,
        }
    }
}

/// The response to a chat completion request.
///
/// It implements [Display](std::fmt::Display) as a shortcut to easily extract the content.
/// ```
/// # use serde_json;
/// # let json = "{
/// #  \"id\": \"chatcmpl-123\",
/// #  \"object\": \"chat.completion\",
/// #  \"created\": 1677652288,
/// #  \"choices\": [{
/// #    \"index\": 0,
/// #    \"message\": {
/// #     \"role\": \"assistant\",
/// #     \"content\": \"\\n\\nHello there, how may I assist you today?\"
/// #    },
/// #    \"finish_reason\": \"stop\"
/// #  }],
/// #  \"usage\": {
/// #   \"prompt_tokens\": 9,
/// #   \"completion_tokens\": 12,
/// #   \"total_tokens\": 21
/// #  }
/// # }";
/// # let res = serde_json::from_str::<openai_rust::chat::ChatCompletion>(json).unwrap();
/// let msg = &res.choices[0].message.content;
/// // or
/// let msg = res.to_string();
/// ```
#[derive(Deserialize, Debug, Clone)]
pub struct ChatCompletion {
    pub id: String,
    pub created: u32,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

impl std::fmt::Display for ChatCompletion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", &self.choices[0].message.content)?;
        Ok(())
    }
}

/// Structs and deserialization logic for the responses
/// received when streaming chat completions.
pub mod stream {
    use bytes::Bytes;
    use futures_util::Stream;
    use serde::Deserialize;
    use std::pin::Pin;
    use std::task::Poll;
    use std::str;

    /// This is the partial chat result received when streaming.
    ///
    /// It implements [Display](std::fmt::Display) as a shortcut to easily extract the content.
    /// ```
    /// # use serde_json;
    /// # let json = "{
    /// # \"id\": \"chatcmpl-6yX67cSCIAm4nrNLQUPOtJu9JUoLG\",
    /// # \"object\": \"chat.completion.chunk\",
    /// # \"created\": 1679884927,
    /// # \"model\": \"gpt-3.5-turbo-0301\",
    /// # \"choices\": [
    /// #   {
    /// #     \"delta\": {
    /// #       \"content\": \" today\"
    /// #     },
    /// #     \"index\": 0,
    /// #     \"finish_reason\": null
    /// #   }
    /// # ]
    /// # }";
    /// # let res = serde_json::from_str::<openai_rust::chat::stream::ChatCompletionChunk>(json).unwrap();
    /// let msg = &res.choices[0].delta.content;
    /// // or
    /// let msg = res.to_string();
    /// ```
    #[derive(Deserialize, Debug, Clone)]
    pub struct ChatCompletionChunk {
        pub id: String,
        pub created: u32,
        pub model: String,
        pub choices: Vec<Choice>,
        pub system_fingerprint: Option<String>,
    }

    impl std::fmt::Display for ChatCompletionChunk {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "{}",
                self.choices[0].delta.content.as_deref().unwrap_or("")
            )?;
            Ok(())
        }
    }

    /// Choices for [ChatCompletionChunk].
    #[derive(Deserialize, Debug, Clone)]
    pub struct Choice {
        pub delta: ChoiceDelta,
        pub index: u32,
        pub finish_reason: Option<String>,
    }

    /// Additional data from [Choice].
    #[derive(Deserialize, Debug, Clone)]
    pub struct ChoiceDelta {
        pub content: Option<String>,
    }

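    /// A [Stream] of [ChatCompletionChunk]s, as returned by
    /// [crate::Client::create_chat_stream].
    ///
    /// A minimal consumption sketch, assuming an async context and
    /// `futures_util::StreamExt` in scope (the stream itself comes from the
    /// client, so this example is compile-checked but not run):
    ///
    /// ```no_run
    /// # async fn demo(mut chunks: openai_rust::chat::stream::ChatCompletionChunkStream) {
    /// use futures_util::StreamExt;
    /// while let Some(chunk) = chunks.next().await {
    ///     match chunk {
    ///         // Display prints the delta content of the first choice.
    ///         Ok(chunk) => print!("{}", chunk),
    ///         Err(e) => eprintln!("stream error: {}", e),
    ///     }
    /// }
    /// # }
    /// ```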
    pub struct ChatCompletionChunkStream {
        byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
        // Internal buffer of incomplete completion chunks.
        buf: String,
    }

    impl ChatCompletionChunkStream {
        pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
            Self {
                byte_stream: stream,
                buf: String::new(),
            }
        }

        /// If possible, returns the first deserialized chunk
        /// from the buffer.
        fn deserialize_buf(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Option<anyhow::Result<ChatCompletionChunk>> {
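            // The streaming endpoint replies with server-sent events: each event is a
            // "data: {json}" payload, events are separated by a blank line ("\n\n"),
            // and the final event is the literal "data: [DONE]".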
            // let's take the first chunk
            let bufclone = self.buf.clone();
            let mut chunks = bufclone.split("\n\n").peekable();
            let first = chunks.next();
            let second = chunks.peek();

            match first {
                Some(first) => {
                    match first.strip_prefix("data: ") {
                        Some(chunk) => {
                            if !chunk.ends_with("}") {
                                // This guard happens on partial chunks or the
                                // [DONE] marker
                                None
                            } else {
                                // If there's a second chunk, wake
                                if let Some(second) = second {
                                    if second.ends_with("}") {
                                        cx.waker().wake_by_ref();
                                    }
                                }

                                // Save the remainder
                                self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
                                //self.get_mut().buf = chunks.remainder().unwrap_or("").to_owned();

                                Some(
                                    serde_json::from_str::<ChatCompletionChunk>(&chunk)
                                    .map_err(|e| anyhow::anyhow!(e))
                                )
                            }
                        },
                        None => None,
                    }
                },
                None => None,
            }
        }
    }

    impl Stream for ChatCompletionChunkStream {
        type Item = anyhow::Result<ChatCompletionChunk>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
            // Possibly fetch a chunk from the buffer
            match self.as_mut().deserialize_buf(cx) {
                Some(chunk) => return Poll::Ready(Some(chunk)),
                None => {},
            };

            match self.byte_stream.as_mut().poll_next(cx) {
                Poll::Ready(bytes_option) => match bytes_option {
                    Some(bytes_result) => match bytes_result {
                        Ok(bytes) => {
                            // Finally actually get some bytes
                            let data = str::from_utf8(&bytes)?.to_owned();
                            self.buf.push_str(&data);
                            match self.deserialize_buf(cx) {
                                Some(chunk) => Poll::Ready(Some(chunk)),
                                // Partial
                                None => {
                                    // On a partial chunk, the best we can do is wake the
                                    // task again. If we don't, this task will get stuck.
                                    cx.waker().wake_by_ref();
                                    Poll::Pending
                                },
                            }
                        },
                        Err(e) => Poll::Ready(Some(Err(e.into()))),
                    },
                    // Stream terminated
                    None => Poll::Ready(None),
                },
                Poll::Pending => Poll::Pending,
            }
        }
    }
}

/// Information about the tokens used by [ChatCompletion].
#[derive(Deserialize, Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Completion choices from [ChatCompletion].
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    pub index: u32,
    pub message: Message,
    pub finish_reason: String,
}

/// A single chat message, pairing a role (e.g. "system", "user", or "assistant") with its content.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    pub role: String,
    pub content: String,
}

/// Role of a [Message].
pub enum Role {
    System,
    Assistant,
    User,
}
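
// A hedged sketch, not part of the original crate's API: a hypothetical,
// private helper showing how [Role] could map to the lowercase strings the
// chat API expects in [Message::role]. Shown only to illustrate how the enum
// relates to the stringly-typed `role` field.
#[allow(dead_code)]
impl Role {
    // Hypothetical helper; the crate itself uses plain `String` roles.
    fn as_wire_str(&self) -> &'static str {
        match self {
            Role::System => "system",
            Role::Assistant => "assistant",
            Role::User => "user",
        }
    }
}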