openai_rust2/chat.rs
//! See <https://platform.openai.com/docs/api-reference/chat>.
//! Use with [Client::create_chat](crate::Client::create_chat) or [Client::create_chat_stream](crate::Client::create_chat_stream).

use serde::{Deserialize, Serialize};

/// Request arguments for chat completion.
///
/// See <https://platform.openai.com/docs/api-reference/chat/create>.
///
/// ```
/// let args = openai_rust2::chat::ChatArguments::new("gpt-3.5-turbo", vec![
///     openai_rust2::chat::Message {
///         role: "user".to_owned(),
///         content: "Hello GPT!".to_owned(),
///     }
/// ]);
/// ```
///
/// To use streaming, use [crate::Client::create_chat_stream].
///
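/// Optional parameters can be set on the struct after construction, since all
/// request fields are public; for example (illustrative values):
///
/// ```
/// let mut args = openai_rust2::chat::ChatArguments::new("gpt-3.5-turbo", vec![
///     openai_rust2::chat::Message {
///         role: "user".to_owned(),
///         content: "Hello GPT!".to_owned(),
///     }
/// ]);
/// args.temperature = Some(0.7);
/// args.max_tokens = Some(256);
/// ```
///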
#[derive(Serialize, Debug, Clone)]
pub struct ChatArguments {
    /// ID of the model to use. See the model [endpoint compatibility table](https://platform.openai.com/docs/models/model-endpoint-compatibility) for details on which models work with the Chat API.
    pub model: String,

    /// The [Message]s to generate chat completions for.
    pub messages: Vec<Message>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// How many chat completion choices to generate for each input message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// Whether to stream back partial progress.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) stream: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<String>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the chat completion.
    ///
    /// The total length of input tokens and generated tokens is limited by the model's context length.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    // logit_bias

    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl ChatArguments {
    pub fn new(model: impl AsRef<str>, messages: Vec<Message>) -> ChatArguments {
        ChatArguments {
            model: model.as_ref().to_owned(),
            messages,
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            user: None,
        }
    }
}

/// The response to a chat completion request.
///
/// It implements [Display](std::fmt::Display) as a shortcut to easily extract the content.
/// ```
/// # use serde_json;
/// # let json = "{
/// #   \"id\": \"chatcmpl-123\",
/// #   \"object\": \"chat.completion\",
/// #   \"created\": 1677652288,
/// #   \"choices\": [{
/// #     \"index\": 0,
/// #     \"message\": {
/// #       \"role\": \"assistant\",
/// #       \"content\": \"\\n\\nHello there, how may I assist you today?\"
/// #     },
/// #     \"finish_reason\": \"stop\"
/// #   }],
/// #   \"usage\": {
/// #     \"prompt_tokens\": 9,
/// #     \"completion_tokens\": 12,
/// #     \"total_tokens\": 21
/// #   }
/// # }";
/// # let res = serde_json::from_str::<openai_rust2::chat::ChatCompletion>(json).unwrap();
/// let msg = &res.choices[0].message.content;
/// // or
/// let msg = res.to_string();
/// ```
#[derive(Deserialize, Debug, Clone)]
pub struct ChatCompletion {
    // Make id optional
    #[serde(default)]
    pub id: Option<String>,
    pub created: u32,
    #[serde(default)]
    pub model: Option<String>,
    #[serde(default)]
    pub object: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

impl std::fmt::Display for ChatCompletion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", &self.choices[0].message.content)?;
        Ok(())
    }
}

/// Structs and deserialization logic for the responses
/// received when streaming chat completions.
pub mod stream {
    use bytes::Bytes;
    use futures_util::Stream;
    use serde::Deserialize;
    use std::pin::Pin;
    use std::str;
    use std::task::Poll;

    /// This is the partial chat result received when streaming.
    ///
    /// It implements [Display](std::fmt::Display) as a shortcut to easily extract the content.
    /// ```
    /// # use serde_json;
    /// # let json = "{
    /// #   \"id\": \"chatcmpl-6yX67cSCIAm4nrNLQUPOtJu9JUoLG\",
    /// #   \"object\": \"chat.completion.chunk\",
    /// #   \"created\": 1679884927,
    /// #   \"model\": \"gpt-3.5-turbo-0301\",
    /// #   \"choices\": [
    /// #     {
    /// #       \"delta\": {
    /// #         \"content\": \" today\"
    /// #       },
    /// #       \"index\": 0,
    /// #       \"finish_reason\": null
    /// #     }
    /// #   ]
    /// # }";
    /// # let res = serde_json::from_str::<openai_rust2::chat::stream::ChatCompletionChunk>(json).unwrap();
    /// let msg = &res.choices[0].delta.content;
    /// // or
    /// let msg = res.to_string();
    /// ```
    #[derive(Deserialize, Debug, Clone)]
    pub struct ChatCompletionChunk {
        pub id: String,
        pub created: u32,
        pub model: String,
        pub choices: Vec<Choice>,
        pub system_fingerprint: Option<String>,
    }

    impl std::fmt::Display for ChatCompletionChunk {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "{}",
                self.choices[0].delta.content.as_ref().unwrap_or(&"".into())
            )?;
            Ok(())
        }
    }

    /// Choices for [ChatCompletionChunk].
    #[derive(Deserialize, Debug, Clone)]
    pub struct Choice {
        pub delta: ChoiceDelta,
        pub index: u32,
        pub finish_reason: Option<String>,
    }

    /// The content delta carried by a [Choice].
    #[derive(Deserialize, Debug, Clone)]
    pub struct ChoiceDelta {
        pub content: Option<String>,
    }

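    /// A stream of [ChatCompletionChunk]s produced when streaming chat
    /// completions (see [crate::Client::create_chat_stream]).
    ///
    /// A minimal consumption sketch using `futures_util::StreamExt`; the
    /// hypothetical `consume` helper only provides the async context:
    ///
    /// ```no_run
    /// # use futures_util::StreamExt;
    /// # async fn consume(mut stream: openai_rust2::chat::stream::ChatCompletionChunkStream) -> anyhow::Result<()> {
    /// while let Some(chunk) = stream.next().await {
    ///     // Each item is an `anyhow::Result<ChatCompletionChunk>`;
    ///     // `Display` prints the delta content.
    ///     print!("{}", chunk?);
    /// }
    /// # Ok(())
    /// # }
    /// ```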
    pub struct ChatCompletionChunkStream {
        byte_stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>,
        // Internal buffer of incomplete completion chunks.
        buf: String,
    }

    impl ChatCompletionChunkStream {
        pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>>>>) -> Self {
            Self {
                byte_stream: stream,
                buf: String::new(),
            }
        }

        /// If possible, returns the first deserialized chunk
        /// from the buffer.
        ///
        /// The buffer holds server-sent events separated by `\n\n`,
        /// each payload prefixed with `data: `.
        fn deserialize_buf(
            self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Option<anyhow::Result<ChatCompletionChunk>> {
            // Take the first chunk from the buffer.
            let bufclone = self.buf.clone();
            let mut chunks = bufclone.split("\n\n").peekable();
            let first = chunks.next();
            let second = chunks.peek();

            match first {
                Some(first) => {
                    match first.strip_prefix("data: ") {
                        Some(chunk) => {
                            if !chunk.ends_with("}") {
                                // This guard happens on partial chunks or the
                                // [DONE] marker.
                                None
                            } else {
                                // If there is a complete second chunk, wake the task
                                // so it gets processed as well.
                                if let Some(second) = second {
                                    if second.ends_with("}") {
                                        cx.waker().wake_by_ref();
                                    }
                                }

                                // Save the remainder.
                                self.get_mut().buf = chunks.collect::<Vec<_>>().join("\n\n");
                                //self.get_mut().buf = chunks.remainder().unwrap_or("").to_owned();

                                Some(
                                    serde_json::from_str::<ChatCompletionChunk>(chunk)
                                        .map_err(|e| anyhow::anyhow!(e)),
                                )
                            }
                        }
                        None => None,
                    }
                }
                None => None,
            }
        }
    }

272
273 impl Stream for ChatCompletionChunkStream {
274 type Item = anyhow::Result<ChatCompletionChunk>;
275
276 fn poll_next(
277 mut self: Pin<&mut Self>,
278 cx: &mut std::task::Context<'_>,
279 ) -> Poll<Option<Self::Item>> {
280 // Possibly fetch a chunk from the buffer
281 match self.as_mut().deserialize_buf(cx) {
282 Some(chunk) => return Poll::Ready(Some(chunk)),
283 None => {}
284 };
285
286 match self.byte_stream.as_mut().poll_next(cx) {
287 Poll::Ready(bytes_option) => match bytes_option {
288 Some(bytes_result) => match bytes_result {
289 Ok(bytes) => {
290 // Finally actually get some bytes
291 let data = str::from_utf8(&bytes)?.to_owned();
292 self.buf = self.buf.clone() + &data;
293 match self.deserialize_buf(cx) {
294 Some(chunk) => Poll::Ready(Some(chunk)),
295 // Partial
296 None => {
297 // On a partial, I think the best we can do is just to wake the
298 // task again. If we don't this task will get stuck.
299 cx.waker().wake_by_ref();
300 Poll::Pending
301 }
302 }
303 }
304 Err(e) => Poll::Ready(Some(Err(e.into()))),
305 },
306 // Stream terminated
307 None => Poll::Ready(None),
308 },
309 Poll::Pending => Poll::Pending,
310 }
311 }
312 }
313}

/// Information about the tokens used by [ChatCompletion].
#[derive(Deserialize, Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Completion choices from [ChatCompletion].
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    #[serde(default)]
    pub index: Option<u32>,
    pub message: Message,
    pub finish_reason: String,
}

/// A chat message, consisting of a role and its content.
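///
/// For example (an illustrative construction; the role is a plain string
/// such as "system", "user" or "assistant"):
///
/// ```
/// let msg = openai_rust2::chat::Message {
///     role: "system".to_owned(),
///     content: "You are a helpful assistant.".to_owned(),
/// };
/// ```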
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    pub role: String,
    pub content: String,
}

/// Role of a [Message].
pub enum Role {
    System,
    Assistant,
    User,
}
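
// Illustrative test sketch: checks the serde behaviour documented on
// `ChatArguments` above, namely that unset optional parameters are omitted
// from the serialized request body via `skip_serializing_if = "Option::is_none"`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unset_optional_chat_arguments_are_not_serialized() {
        let args = ChatArguments::new(
            "gpt-3.5-turbo",
            vec![Message {
                role: "user".to_owned(),
                content: "Hello GPT!".to_owned(),
            }],
        );
        let json = serde_json::to_string(&args).unwrap();
        assert!(json.contains("\"model\":\"gpt-3.5-turbo\""));
        assert!(!json.contains("temperature"));
        assert!(!json.contains("max_tokens"));
    }
}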