openai_rust/chat.rs
//! See <https://platform.openai.com/docs/api-reference/chat>.
//! Use with [Client::create_chat](crate::Client::create_chat) or [Client::create_chat_stream](crate::Client::create_chat_stream).
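//!
//! A minimal end-to-end sketch (assuming [Client::new](crate::Client::new)
//! takes the API key as a `&str` and [Client::create_chat](crate::Client::create_chat)
//! is `async` and returns a `Result`):
//!
//! ```no_run
//! # async fn demo() {
//! let client = openai_rust::Client::new(&std::env::var("OPENAI_API_KEY").unwrap());
//! let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
//!     openai_rust::chat::Message {
//!         role: "user".to_owned(),
//!         content: "Hello GPT!".to_owned(),
//!     }
//! ]);
//! let res = client.create_chat(args).await.unwrap();
//! println!("{}", res);
//! # }
//! ```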

use serde::{Deserialize, Serialize};

/// Request arguments for chat completion.
///
/// See <https://platform.openai.com/docs/api-reference/chat/create>.
///
/// ```
/// let args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
///     openai_rust::chat::Message {
///         role: "user".to_owned(),
///         content: "Hello GPT!".to_owned(),
///     }
/// ]);
/// ```
///
/// To use streaming, use [crate::Client::create_chat_stream].
///
#[derive(Serialize, Debug, Clone)]
pub struct ChatArguments {
    /// ID of the model to use. See the model [endpoint compatibility table](https://platform.openai.com/docs/models/model-endpoint-compatibility) for details on which models work with the Chat API.
    pub model: String,

    /// The [Message]s to generate chat completions for
    pub messages: Vec<Message>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// How many chat completion choices to generate for each input message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// Whether to stream back partial progress.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) stream: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<String>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the chat completion.
    ///
    /// The total length of input tokens and generated tokens is limited by the model's context length.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    // logit_bias
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl ChatArguments {
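    /// Create a new [ChatArguments] from the required `model` and `messages`.
    /// All optional parameters start out as `None` and can be set on the
    /// returned value, e.g.:
    ///
    /// ```
    /// let mut args = openai_rust::chat::ChatArguments::new("gpt-3.5-turbo", vec![
    ///     openai_rust::chat::Message {
    ///         role: "user".to_owned(),
    ///         content: "Hello GPT!".to_owned(),
    ///     }
    /// ]);
    /// // Tweak optional sampling parameters afterwards.
    /// args.temperature = Some(0.7);
    /// args.max_tokens = Some(256);
    /// ```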
    pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
        ChatArguments {
            model: model.as_ref().to_owned(),
            messages,
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            user: None,
        }
    }
}

/// This is the response of a chat.
///
/// It implements [Display](std::fmt::Display) as a shortcut to easily extract the content.
/// ```
/// # use serde_json;
/// # let json = "{
/// # \"id\": \"chatcmpl-123\",
/// # \"object\": \"chat.completion\",
/// # \"created\": 1677652288,
/// # \"choices\": [{
/// # \"index\": 0,
/// # \"message\": {
/// # \"role\": \"assistant\",
/// # \"content\": \"\\n\\nHello there, how may I assist you today?\"
/// # },
/// # \"finish_reason\": \"stop\"
/// # }],
/// # \"usage\": {
/// # \"prompt_tokens\": 9,
/// # \"completion_tokens\": 12,
/// # \"total_tokens\": 21
/// # }
/// # }";
/// # let res = serde_json::from_str::<openai_rust::chat::ChatCompletion>(json).unwrap();
/// let msg = &res.choices[0].message.content;
/// // or
/// let msg = res.to_string();
/// ```
#[derive(Deserialize, Debug, Clone)]
pub struct ChatCompletion {
    pub id: String,
    pub created: u32,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

impl std::fmt::Display for ChatCompletion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", &self.choices[0].message.content)?;
        Ok(())
    }
}

/// Structs and deserialization logic for the responses
/// received when streaming chat completions.
pub mod stream {
    use bytes::Bytes;
    use futures_util::Stream;
    use serde::Deserialize;
    use std::pin::Pin;
    use std::task::Poll;
    use std::str;

    /// This is the partial chat result received when streaming.
    ///
    /// It implements [Display](std::fmt::Display) as a shortcut to easily extract the content.
    /// ```
    /// # use serde_json;
    /// # let json = "{
    /// # \"id\": \"chatcmpl-6yX67cSCIAm4nrNLQUPOtJu9JUoLG\",
    /// # \"object\": \"chat.completion.chunk\",
    /// # \"created\": 1679884927,
    /// # \"model\": \"gpt-3.5-turbo-0301\",
    /// # \"choices\": [
    /// # {
    /// # \"delta\": {
    /// # \"content\": \" today\"
    /// # },
    /// # \"index\": 0,
    /// # \"finish_reason\": null
    /// # }
    /// # ]
    /// # }";
    /// # let res = serde_json::from_str::<openai_rust::chat::stream::ChatCompletionChunk>(json).unwrap();
    /// let msg = &res.choices[0].delta.content;
    /// // or
    /// let msg = res.to_string();
    /// ```
    #[derive(Deserialize, Debug, Clone)]
    pub struct ChatCompletionChunk {
        pub id: String,
        pub created: u32,
        pub model: String,
        pub choices: Vec<Choice>,
        pub system_fingerprint: Option<String>,
    }

    impl std::fmt::Display for ChatCompletionChunk {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "{}",
                self.choices[0].delta.content.as_ref().unwrap_or(&"".into())
            )?;
            Ok(())
        }
    }

    /// Choices for [ChatCompletionChunk].
    #[derive(Deserialize, Debug, Clone)]
    pub struct Choice {
        pub delta: ChoiceDelta,
        pub index: u32,
        pub finish_reason: Option<String>,
    }

    /// Additional data from [Choice].
    #[derive(Deserialize, Debug, Clone)]
    pub struct ChoiceDelta {
        pub content: Option<String>,
    }

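    /// A [Stream] of [ChatCompletionChunk]s, as returned by
    /// [crate::Client::create_chat_stream].
    ///
    /// A minimal consumption sketch (assuming `create_chat_stream` is `async`
    /// and returns a `Result` wrapping this stream):
    ///
    /// ```no_run
    /// # async fn demo(client: openai_rust::Client, args: openai_rust::chat::ChatArguments) {
    /// use futures_util::StreamExt;
    /// let mut stream = client.create_chat_stream(args).await.unwrap();
    /// while let Some(chunk) = stream.next().await {
    ///     print!("{}", chunk.unwrap());
    /// }
    /// # }
    /// ```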
    pub struct ChatCompletionChunkStream {
        byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
        // Internal buffer of incomplete completion chunks.
        buf: String,
    }

    impl ChatCompletionChunkStream {

        pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
            Self {
                byte_stream: stream,
                buf: String::new(),
            }
        }

        /// If possible, returns the first deserialized chunk
        /// from the buffer.
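        ///
        /// The incoming bytes are server-sent events: `data: {json}` payloads
        /// separated by blank lines and terminated by a `data: [DONE]` marker,
        /// so only complete `data:` blocks are parsed here.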
        fn deserialize_buf(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Option<anyhow::Result<ChatCompletionChunk>> {
            // let's take the first chunk
            let bufclone = self.buf.clone();
            let mut chunks = bufclone.split("\n\n").peekable();
            let first = chunks.next();
            let second = chunks.peek();

            match first {
                Some(first) => {
                    match first.strip_prefix("data: ") {
                        Some(chunk) => {
                            if !chunk.ends_with("}") {
                                // This guard happens on partial chunks or the
                                // [DONE] marker
                                None
                            } else {
                                // If there's a second chunk, wake
                                if let Some(second) = second {
                                    if second.ends_with("}") {
                                        cx.waker().wake_by_ref();
                                    }
                                }

                                // Save the remainder
                                self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
                                //self.get_mut().buf = chunks.remainder().unwrap_or("").to_owned();

                                Some(
                                    serde_json::from_str::<ChatCompletionChunk>(&chunk)
                                        .map_err(|e| anyhow::anyhow!(e))
                                )
                            }
                        },
                        None => None,
                    }
                },
                None => None,
            }
        }
    }

    impl Stream for ChatCompletionChunkStream {
        type Item = anyhow::Result<ChatCompletionChunk>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {

            // Possibly fetch a chunk from the buffer
            match self.as_mut().deserialize_buf(cx) {
                Some(chunk) => return Poll::Ready(Some(chunk)),
                None => {},
            };

            match self.byte_stream.as_mut().poll_next(cx) {
                Poll::Ready(bytes_option) => match bytes_option {
                    Some(bytes_result) => match bytes_result {
                        Ok(bytes) => {
                            // Finally actually get some bytes
                            let data = str::from_utf8(&bytes)?.to_owned();
                            self.buf = self.buf.clone() + &data;
                            match self.deserialize_buf(cx) {
                                Some(chunk) => Poll::Ready(Some(chunk)),
                                // Partial
                                None => {
                                    // On a partial, I think the best we can do is just to wake the
                                    // task again. If we don't, this task will get stuck.
                                    cx.waker().wake_by_ref();
                                    Poll::Pending
                                },
                            }
                        },
                        Err(e) => Poll::Ready(Some(Err(e.into()))),
                    },
                    // Stream terminated
                    None => Poll::Ready(None),
                },
                Poll::Pending => Poll::Pending,
            }
        }
    }
}

/// Information about the tokens used by [ChatCompletion].
#[derive(Deserialize, Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Completion choices from [ChatCompletion].
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    pub index: u32,
    pub message: Message,
    pub finish_reason: String,
}

/// A message.
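///
/// For example, a user message can be constructed directly from its public fields:
///
/// ```
/// let msg = openai_rust::chat::Message {
///     role: "user".to_owned(),
///     content: "Hello GPT!".to_owned(),
/// };
/// ```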
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    pub role: String,
    pub content: String,
}

/// Role of a [Message].
pub enum Role {
    System,
    Assistant,
    User,
}