// Source: openai_fork/chat.rs

1//! Given a chat conversation, the model will return a chat completion response.
2
3use super::{openai_post, ApiResponseOrError, Usage};
4use crate::openai_request_stream;
5use derive_builder::Builder;
6use futures_util::StreamExt;
7use reqwest::Method;
8use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use tokio::sync::mpsc::{channel, Receiver, Sender};
13
/// A full chat completion, as returned by a blocking (non-streaming) request.
pub type ChatCompletion = ChatCompletionGeneric<ChatCompletionChoice>;

/// A delta chat completion, which is streamed token by token.
pub type ChatCompletionDelta = ChatCompletionGeneric<ChatCompletionChoiceDelta>;
19
/// Common shape of a chat completion response; `C` is the choice type
/// (a full message for blocking requests, a delta for streamed ones).
#[derive(Deserialize, Clone, Debug)]
pub struct ChatCompletionGeneric<C> {
    /// Unique identifier of this completion, assigned by the API.
    pub id: String,
    /// Object type tag returned by the API.
    pub object: String,
    /// Unix timestamp (seconds) of when the completion was created.
    pub created: u64,
    /// The model that produced this completion.
    pub model: String,
    /// The generated choices (one per requested `n`).
    pub choices: Vec<C>,
    /// Token usage statistics; `None` on streamed chunks.
    pub usage: Option<Usage>,
}
29
/// A single generated answer within a full (non-streamed) completion.
#[derive(Deserialize, Clone, Debug)]
pub struct ChatCompletionChoice {
    /// Position of this choice within the response.
    pub index: u64,
    /// Why generation stopped for this choice.
    pub finish_reason: String,
    /// The complete generated message.
    pub message: ChatCompletionMessage,
}
36
/// A single choice fragment within one streamed completion chunk.
#[derive(Deserialize, Clone, Debug)]
pub struct ChatCompletionChoiceDelta {
    /// Position of this choice within the response.
    pub index: u64,
    /// Why generation stopped; only set on the final chunk for a choice.
    pub finish_reason: Option<String>,
    /// The incremental message fragment carried by this chunk.
    pub delta: ChatCompletionMessageDelta,
}
43
/// A message in a chat conversation: either developer-supplied input or a
/// reply produced by the model.
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatCompletionMessage {
    /// The role of the author of this message.
    pub role: ChatCompletionMessageRole,
    /// The contents of the message
    ///
    /// This is always required for all messages, except for when ChatGPT calls
    /// a function.
    pub content: Option<String>,
    /// The name of the user in a multi-user chat
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The function that ChatGPT called. This should be "None" usually, and is returned by ChatGPT and not provided by the developer
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ChatCompletionFunctionCall>,
}
62
/// Same as ChatCompletionMessage, but received during a response stream.
///
/// Every field is optional because each streamed chunk carries only the
/// pieces that changed since the previous chunk.
// NOTE(review): this struct derives Deserialize but not Serialize, so the
// `skip_serializing_if` attributes below are currently inert.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionMessageDelta {
    /// The role of the author of this message.
    pub role: Option<ChatCompletionMessageRole>,
    /// The contents of the message
    pub content: Option<String>,
    /// The name of the user in a multi-user chat
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The function that ChatGPT called
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ChatCompletionFunctionCallDelta>,
}
79
/// A function the model is allowed to call, as advertised in the request.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct ChatCompletionFunctionDefinition {
    /// The name of the function
    pub name: String,
    /// The description of the function
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// The parameters of the function formatted in JSON Schema
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-parameters)
    /// [See more information about JSON Schema.](https://json-schema.org/understanding-json-schema/)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}
93
/// A function call the model decided to make, as received in a response.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct ChatCompletionFunctionCall {
    /// The name of the function ChatGPT called
    pub name: String,
    /// The arguments that ChatGPT called (formatted in JSON)
    ///
    /// NOTE(review): the model generates this string; it is not guaranteed
    /// to be valid JSON — validate before parsing.
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    pub arguments: String,
}
102
/// Same as ChatCompletionFunctionCall, but received during a response stream.
///
/// Fields are optional because a stream chunk carries only the fragment that
/// arrived in that chunk; fragments are accumulated by the merge logic below.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct ChatCompletionFunctionCallDelta {
    /// The name of the function ChatGPT called
    pub name: Option<String>,
    /// The arguments that ChatGPT called (formatted in JSON)
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    pub arguments: Option<String>,
}
112
/// The author role of a chat message.
///
/// Serialized in lowercase (`"system"`, `"user"`, …) to match the API wire
/// format.
#[derive(Deserialize, Serialize, Debug, Clone, Copy, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChatCompletionMessageRole {
    System,
    User,
    Assistant,
    Function,
}
121
/// Parameters for a `chat/completions` request.
///
/// Construct via [`ChatCompletionGeneric::builder`], which pre-fills the two
/// required fields (`model`, `messages`); all other fields are optional and
/// omitted from the serialized request when unset.
#[derive(Serialize, Builder, Debug, Clone)]
#[builder(pattern = "owned")]
#[builder(name = "ChatCompletionBuilder")]
#[builder(setter(strip_option, into))]
pub struct ChatCompletionRequest {
    /// ID of the model to use. Currently, only `gpt-3.5-turbo`, `gpt-3.5-turbo-0301` and `gpt-4`
    /// are supported.
    model: String,
    /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction).
    messages: Vec<ChatCompletionMessage>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// How many chat completion choices to generate for each input message.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u8>,
    /// If set, partial message deltas are sent as server-sent events.
    /// Prefer `ChatCompletionBuilder::create_stream`, which sets this flag
    /// automatically.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    /// Up to 4 sequences where the API will stop generating further tokens.
    #[builder(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    /// This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens).
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[builder(default)]
    #[serde(skip_serializing_if = "String::is_empty")]
    user: String,
    /// Describe functions that ChatGPT can call
    /// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt.
    /// For example, you can define a function called "get_weather" that returns the weather in a given city
    ///
    /// [Function calling API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions)
    /// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling)
    #[builder(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    functions: Vec<ChatCompletionFunctionDefinition>,
    /// A string or object of the function to call
    ///
    /// Controls how the model responds to function calls
    ///
    /// - "none" means the model does not call a function, and responds to the end-user.
    /// - "auto" means the model can pick between an end-user or calling a function.
    /// - Specifying a particular function via {"name":\ "my_function"} forces the model to call that function.
    ///
    /// "none" is the default when no functions are present. "auto" is the default if functions are present.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<Value>,
}
207
impl<C> ChatCompletionGeneric<C> {
    /// Creates a request builder pre-populated with the two required fields:
    /// the model name and the conversation messages.
    pub fn builder(
        model: &str,
        messages: impl Into<Vec<ChatCompletionMessage>>,
    ) -> ChatCompletionBuilder {
        ChatCompletionBuilder::create_empty()
            .model(model)
            .messages(messages)
    }
}
218
impl ChatCompletion {
    /// Sends a blocking (non-streaming) chat completion request to the
    /// `chat/completions` endpoint and returns the full response.
    pub async fn create(request: &ChatCompletionRequest) -> ApiResponseOrError<Self> {
        openai_post("chat/completions", request).await
    }
}
224
225impl ChatCompletionDelta {
226    pub async fn create(
227        request: &ChatCompletionRequest,
228    ) -> Result<Receiver<Self>, CannotCloneRequestError> {
229        let stream =
230            openai_request_stream(Method::POST, "chat/completions", |r| r.json(request)).await?;
231        let (tx, rx) = channel::<Self>(32);
232        tokio::spawn(forward_deserialized_chat_response_stream(stream, tx));
233        Ok(rx)
234    }
235
236    /// Merges the input delta completion into `self`.
237    pub fn merge(
238        &mut self,
239        other: ChatCompletionDelta,
240    ) -> Result<(), ChatCompletionDeltaMergeError> {
241        if other.id.ne(&self.id) {
242            return Err(ChatCompletionDeltaMergeError::DifferentCompletionIds);
243        }
244        for other_choice in other.choices.iter() {
245            for choice in self.choices.iter_mut() {
246                if choice.index != other_choice.index {
247                    continue;
248                }
249                choice.merge(other_choice)?;
250            }
251        }
252        Ok(())
253    }
254}
255
256impl ChatCompletionChoiceDelta {
257    pub fn merge(
258        &mut self,
259        other: &ChatCompletionChoiceDelta,
260    ) -> Result<(), ChatCompletionDeltaMergeError> {
261        if self.index != other.index {
262            return Err(ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices);
263        }
264        if self.delta.role.is_none() {
265            if let Some(other_role) = other.delta.role {
266                // Set role to other_role.
267                self.delta.role = Some(other_role);
268            }
269        }
270        if self.delta.name.is_none() {
271            if let Some(other_name) = &other.delta.name {
272                // Set name to other_name.
273                self.delta.name = Some(other_name.clone());
274            }
275        }
276        // Merge contents.
277        match self.delta.content.as_mut() {
278            Some(content) => {
279                match &other.delta.content {
280                    Some(other_content) => {
281                        // Push other content into this one.
282                        content.push_str(other_content)
283                    }
284                    None => {}
285                }
286            }
287            None => {
288                match &other.delta.content {
289                    Some(other_content) => {
290                        // Set this content to other content.
291                        self.delta.content = Some(other_content.clone());
292                    }
293                    None => {}
294                }
295            }
296        };
297
298        // merge function calls
299        // function call names are concatenated
300        // arguments are merged by concatenating them
301        match self.delta.function_call.as_mut() {
302            Some(function_call) => {
303                match &other.delta.function_call {
304                    Some(other_function_call) => {
305                        // push the arguments string of the other function call into this one
306                        match (&mut function_call.arguments, &other_function_call.arguments) {
307                            (Some(function_call), Some(other_function_call)) => {
308                                function_call.push_str(&other_function_call);
309                            }
310                            (None, Some(other_function_call)) => {
311                                function_call.arguments = Some(other_function_call.clone());
312                            }
313                            _ => {}
314                        }
315                    }
316                    None => {}
317                }
318            }
319            None => {
320                match &other.delta.function_call {
321                    Some(other_function_call) => {
322                        // Set this content to other content.
323                        self.delta.function_call = Some(other_function_call.clone());
324                    }
325                    None => {}
326                }
327            }
328        };
329        Ok(())
330    }
331}
332
333impl From<ChatCompletionDelta> for ChatCompletion {
334    fn from(delta: ChatCompletionDelta) -> Self {
335        ChatCompletion {
336            id: delta.id,
337            object: delta.object,
338            created: delta.created,
339            model: delta.model,
340            usage: delta.usage,
341            choices: delta
342                .choices
343                .iter()
344                .map(|choice| ChatCompletionChoice {
345                    index: choice.index,
346                    finish_reason: clone_default_unwrapped_option_string(&choice.finish_reason),
347                    message: ChatCompletionMessage {
348                        role: choice
349                            .delta
350                            .role
351                            .unwrap_or_else(|| ChatCompletionMessageRole::System),
352                        content: choice.delta.content.clone(),
353                        name: choice.delta.name.clone(),
354                        function_call: choice.delta.function_call.clone().map(|f| f.into()),
355                    },
356                })
357                .collect(),
358        }
359    }
360}
361
362impl From<ChatCompletionFunctionCallDelta> for ChatCompletionFunctionCall {
363    fn from(delta: ChatCompletionFunctionCallDelta) -> Self {
364        ChatCompletionFunctionCall {
365            name: delta.name.unwrap_or("".to_string()),
366            arguments: delta.arguments.unwrap_or_default(),
367        }
368    }
369}
370
/// Errors that can occur when merging one streamed delta into another.
#[derive(Debug)]
pub enum ChatCompletionDeltaMergeError {
    /// The two deltas belong to different completions (`id` mismatch).
    DifferentCompletionIds,
    /// The two choice deltas have different `index` values.
    DifferentCompletionChoiceIndices,
    /// A function call argument fragment had an unexpected type.
    // NOTE(review): this variant is not produced anywhere in this file;
    // presumably kept for API compatibility — confirm before removing.
    FunctionCallArgumentTypeMismatch,
}
377
378async fn forward_deserialized_chat_response_stream(
379    mut stream: EventSource,
380    tx: Sender<ChatCompletionDelta>,
381) -> anyhow::Result<()> {
382    while let Some(event) = stream.next().await {
383        let event = event?;
384        match event {
385            Event::Message(event) => {
386                let completion = serde_json::from_str::<ChatCompletionDelta>(&event.data)?;
387                tx.send(completion).await?;
388            }
389            _ => {}
390        }
391    }
392    Ok(())
393}
394
impl ChatCompletionBuilder {
    /// Builds the request and sends it, returning the full completion.
    ///
    /// Panics if a required field (`model`, `messages`) was never set; both
    /// are always set when the builder comes from `ChatCompletion::builder`.
    pub async fn create(self) -> ApiResponseOrError<ChatCompletion> {
        ChatCompletion::create(&self.build().unwrap()).await
    }

    /// Builds the request with `stream` forced to `true` and sends it,
    /// returning a receiver of incremental deltas.
    ///
    /// Panics under the same conditions as `create`.
    pub async fn create_stream(
        mut self,
    ) -> Result<Receiver<ChatCompletionDelta>, CannotCloneRequestError> {
        // Override whatever the caller may have set: this entry point always
        // streams.
        self.stream = Some(Some(true));
        ChatCompletionDelta::create(&self.build().unwrap()).await
    }
}
407
/// Clones the contained string, or returns an empty `String` when `None`.
///
/// Combinator form replaces the manual `match` (clippy `manual_unwrap_or`
/// family); behavior is identical.
fn clone_default_unwrapped_option_string(string: &Option<String>) -> String {
    string.clone().unwrap_or_default()
}
414
415#[cfg(test)]
416mod tests {
417    use super::*;
418    use crate::set_key;
419    use dotenvy::dotenv;
420    use std::env;
421
    // Integration test: requires a valid OPENAI_KEY in the environment (or a
    // .env file) plus live network access.
    //
    // NOTE(review): the asserted reply text depends on the remote model's
    // output; temperature 0.0 makes it repeatable in practice, but it is not
    // guaranteed to stay stable across model updates.
    #[tokio::test]
    async fn chat() {
        dotenv().ok();
        set_key(env::var("OPENAI_KEY").unwrap());

        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Hello!".to_string()),
                name: None,
                function_call: None,
            }],
        )
        .temperature(0.0)
        .create()
        .await
        .unwrap();

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Hello! How can I assist you today?"
        );
    }
453
    // Seeds are not deterministic so the only point of the test is to
    // ensure that passing a seed still results in a valid response.
    //
    // Integration test: requires OPENAI_KEY and live network access.
    #[tokio::test]
    async fn chat_seed() {
        dotenv().ok();
        set_key(env::var("OPENAI_KEY").unwrap());

        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some(
                    "What type of seed does Mr. England sow in the song? Reply with 1 word."
                        .to_string(),
                ),
                name: None,
                function_call: None,
            }],
        )
        // Determinism currently comes from temperature 0, not seed.
        .temperature(0.0)
        .seed(1337u64)
        .create()
        .await
        .unwrap();

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Love"
        );
    }
492
    // Integration test: streams a completion and checks that merging all
    // deltas reproduces the expected full reply.
    // Requires OPENAI_KEY and live network access; the asserted text depends
    // on the remote model's output.
    #[tokio::test]
    async fn chat_stream() {
        dotenv().ok();
        set_key(env::var("OPENAI_KEY").unwrap());

        let chat_stream = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Hello!".to_string()),
                name: None,
                function_call: None,
            }],
        )
        .temperature(0.0)
        .create_stream()
        .await
        .unwrap();

        let chat_completion = stream_to_completion(chat_stream).await;

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Hello! How can I assist you today?"
        );
    }
526
    // Integration test: verifies that a streamed function call is assembled
    // correctly (name set once, argument fragments concatenated into valid
    // JSON). Requires OPENAI_KEY and live network access; the asserted
    // arguments depend on the remote model's output.
    #[tokio::test]
    async fn chat_function() {
        dotenv().ok();
        set_key(env::var("OPENAI_KEY").unwrap());

        let chat_stream = ChatCompletion::builder(
            "gpt-3.5-turbo-0613",
            [
                ChatCompletionMessage {
                    role: ChatCompletionMessageRole::User,
                    content: Some("What is the weather in Boston?".to_string()),
                    name: None,
                    function_call: None,
                }
            ]
        ).functions([ChatCompletionFunctionDefinition {
            description: Some("Get the current weather in a given location.".to_string()),
            name: "get_current_weather".to_string(),
            parameters: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state to get the weather for. (eg: San Francisco, CA)"
                    }
                },
                "required": ["location"]
            })),
        }])
        .temperature(0.2)
        .create_stream()
        .await
        .unwrap();

        let chat_completion = stream_to_completion(chat_stream).await;

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .function_call
                .as_ref()
                .unwrap()
                .name,
            "get_current_weather".to_string(),
        );

        assert_eq!(
            serde_json::from_str::<Value>(
                &chat_completion
                    .choices
                    .first()
                    .unwrap()
                    .message
                    .function_call
                    .as_ref()
                    .unwrap()
                    .arguments
            )
            .unwrap(),
            serde_json::json!({
                "location": "Boston, MA"
            }),
        );
    }
594
    /// Drains a delta stream, merging every chunk into one accumulated
    /// delta, then converts the result into a full `ChatCompletion`.
    ///
    /// Panics if the stream yields no chunks at all or if any chunk fails
    /// to merge (acceptable in a test helper).
    async fn stream_to_completion(
        mut chat_stream: Receiver<ChatCompletionDelta>,
    ) -> ChatCompletion {
        let mut merged: Option<ChatCompletionDelta> = None;
        while let Some(delta) = chat_stream.recv().await {
            match merged.as_mut() {
                Some(c) => {
                    c.merge(delta).unwrap();
                }
                None => merged = Some(delta),
            };
        }
        merged.unwrap().into()
    }
609}