openai/
chat.rs

1//! Given a chat conversation, the model will return a chat completion response.
2
3use super::{
4    openai_get, openai_get_with_query, openai_post, ApiResponseOrError, Credentials,
5    RequestPagination, Usage,
6};
7use crate::openai_request_stream;
8use derive_builder::Builder;
9use futures_util::StreamExt;
10use reqwest::Method;
11use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use tokio::sync::mpsc::{channel, Receiver, Sender};
16
/// A full chat completion, whose choices each hold a complete
/// [`ChatCompletionMessage`].
pub type ChatCompletion = ChatCompletionGeneric<ChatCompletionChoice>;
19
/// A delta chat completion, which is streamed token by token.
///
/// Its choices hold partial [`ChatCompletionMessageDelta`] values; successive
/// chunks can be combined with [`ChatCompletionDelta::merge`].
pub type ChatCompletionDelta = ChatCompletionGeneric<ChatCompletionChoiceDelta>;
22
/// Response envelope shared by full and streamed chat completions.
///
/// `C` is the per-choice payload: [`ChatCompletionChoice`] for
/// [`ChatCompletion`] and [`ChatCompletionChoiceDelta`] for
/// [`ChatCompletionDelta`].
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionGeneric<C> {
    /// Identifier of the completion; empty when absent from the payload.
    #[serde(default)]
    pub id: String,
    /// Object type reported by the API; empty when absent from the payload.
    #[serde(default)]
    pub object: String,
    /// Creation time reported by the API; `0` when absent from the payload.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm against the API reference.
    #[serde(default)]
    pub created: u64,
    /// Name of the model that produced the completion; empty when absent.
    #[serde(default)]
    pub model: String,
    /// The generated choices; an empty vector when absent from the payload.
    // A named default function is used here instead of a bare `#[serde(default)]`.
    #[serde(default = "default_empty_vec")]
    pub choices: Vec<C>,
    /// Token usage statistics, when the API reports them.
    pub usage: Option<Usage>,
}
37
/// One generated choice within a full [`ChatCompletion`].
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionChoice {
    /// Index of this choice, as reported by the API.
    pub index: u64,
    /// The reason generation stopped, as reported by the API.
    ///
    /// When this choice is converted from a streamed delta that carried no
    /// finish reason, the value is an empty string.
    pub finish_reason: String,
    /// The message generated for this choice.
    pub message: ChatCompletionMessage,
}
44
/// One choice within a streamed [`ChatCompletionDelta`] chunk.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionChoiceDelta {
    /// Index of this choice; chunks are matched on it when merging.
    pub index: u64,
    /// The reason generation stopped, if this chunk reports one.
    pub finish_reason: Option<String>,
    /// The partial message carried by this chunk.
    pub delta: ChatCompletionMessageDelta,
}
51
/// Returns `true` when `opt` is `None` or wraps an empty vector.
///
/// Used as a serde `skip_serializing_if` predicate so that absent and empty
/// lists are both left out of the serialized output.
fn is_none_or_empty_vec<T>(opt: &Option<Vec<T>>) -> bool {
    match opt {
        Some(items) => items.is_empty(),
        None => true,
    }
}
55
/// A single message in a chat conversation.
///
/// Used both as input (the `messages` of a [`ChatCompletionRequest`]) and as
/// output (the `message` of a [`ChatCompletionChoice`]).
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)]
pub struct ChatCompletionMessage {
    /// The role of the author of this message.
    pub role: ChatCompletionMessageRole,
    /// The contents of the message.
    ///
    /// This is always required for all messages, except for when ChatGPT calls
    /// a function.
    pub content: Option<String>,
    /// The name of the user in a multi-user chat.
    /// Omitted from the serialized message when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The function that ChatGPT called. This should usually be `None`; it is
    /// returned by ChatGPT and not provided by the developer.
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ChatCompletionFunctionCall>,
    /// Tool call that this message is responding to.
    /// Required if the role is `Tool`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls that the assistant is requesting to invoke.
    /// Can only be populated if the role is `Assistant`,
    /// otherwise it should be empty.
    ///
    /// Both `None` and an empty vector are omitted from the serialized message.
    #[serde(skip_serializing_if = "is_none_or_empty_vec")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
83
/// Same as [`ChatCompletionMessage`], but received during a response stream.
///
/// Every field is optional because a single chunk carries only part of the
/// message.
// NOTE(review): only `Deserialize` is derived, so the `skip_serializing_if`
// attributes below look like leftovers with no effect — confirm before relying on them.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionMessageDelta {
    /// The role of the author of this message.
    pub role: Option<ChatCompletionMessageRole>,
    /// The fragment of message content carried by this chunk.
    pub content: Option<String>,
    /// The name of the user in a multi-user chat.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The function that ChatGPT called.
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ChatCompletionFunctionCallDelta>,
    /// Tool call that this message is responding to.
    /// Required if the role is `Tool`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls that the assistant is requesting to invoke.
    /// Can only be populated if the role is `Assistant`,
    /// otherwise it should be empty.
    // NOTE(review): `ToolCall` requires `id`, `type` and a complete `function`;
    // chunks that stream tool calls piecewise would presumably fail to
    // deserialize with this type — verify against real streamed payloads.
    #[serde(skip_serializing_if = "is_none_or_empty_vec")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
109
/// A request by the model to invoke a tool.
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// The ID of the tool call.
    pub id: String,
    /// The type of the tool. Currently, only `function` is supported.
    ///
    /// Declared as `r#type` because `type` is a Rust keyword.
    pub r#type: String,
    /// The function that the model called.
    pub function: ToolCallFunction,
}
119
/// The function named by a [`ToolCall`], together with its raw arguments.
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCallFunction {
    /// The name of the function to call.
    pub name: String,
    /// The arguments to call the function with, as generated by the model in
    /// JSON format.
    ///
    /// Note that the model does not always generate valid JSON, and may
    /// hallucinate parameters not defined by your function schema.
    /// Validate the arguments in your code before calling your function.
    pub arguments: String,
}
131
/// Describes a function that the model may choose to call.
///
/// Passed through the `functions` field of a [`ChatCompletionRequest`].
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatCompletionFunctionDefinition {
    /// The name of the function.
    pub name: String,
    /// The description of the function. Omitted from the request when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// The parameters of the function, formatted as a JSON Schema object.
    /// Omitted from the request when `None`.
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-parameters)
    /// [See more information about JSON Schema.](https://json-schema.org/understanding-json-schema/)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}
145
/// A function call made by the model in a full (non-delta) message.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct ChatCompletionFunctionCall {
    /// The name of the function ChatGPT called.
    pub name: String,
    /// The arguments ChatGPT passed to the function, as a JSON-formatted string.
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    pub arguments: String,
}
154
/// Same as [`ChatCompletionFunctionCall`], but received during a response stream.
///
/// Both fields are optional because a chunk may carry only part of the call.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct ChatCompletionFunctionCallDelta {
    /// The name of the function ChatGPT called, if present in this chunk.
    pub name: Option<String>,
    /// The fragment of the JSON-formatted arguments carried by this chunk.
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    pub arguments: Option<String>,
}
164
/// The role of the author of a chat message.
///
/// Variants are (de)serialized in lowercase, e.g. `Assistant` as `"assistant"`.
#[derive(Deserialize, Serialize, Debug, Clone, Copy, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChatCompletionMessageRole {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"function"`.
    Function,
    /// Serialized as `"tool"`.
    Tool,
    /// Serialized as `"developer"`.
    Developer,
}
175
/// The body of a chat completion request.
///
/// Instances are assembled through the generated [`ChatCompletionBuilder`]
/// (owned pattern; setters accept `impl Into<T>` and strip the `Option`
/// wrapper). `model` and `messages` have no builder default and must be set.
#[derive(Serialize, Builder, Debug, Clone)]
#[builder(derive(Clone, Debug, PartialEq))]
#[builder(pattern = "owned")]
#[builder(name = "ChatCompletionBuilder")]
#[builder(setter(strip_option, into))]
pub struct ChatCompletionRequest {
    /// ID of the model to use, e.g. `gpt-3.5-turbo` or `gpt-4o`.
    model: String,
    /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction).
    messages: Vec<ChatCompletionMessage>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// How many chat completion choices to generate for each input message.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u8>,
    /// Whether the response should be streamed.
    ///
    /// [`ChatCompletionBuilder::create_stream`] sets this to `true` itself.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    /// Up to 4 sequences where the API will stop generating further tokens.
    /// Omitted from the request when empty.
    #[builder(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    /// This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens).
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// The maximum number of tokens allowed for the generated answer.
    /// For reasoning models such as o1 and o3-mini, this does not include reasoning tokens.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    ///
    /// Omitted from the request when empty.
    #[builder(default)]
    #[serde(skip_serializing_if = "String::is_empty")]
    user: String,
    /// Describe functions that ChatGPT can call.
    ///
    /// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt.
    /// For example, you can define a function called "get_weather" that returns the weather in a given city.
    ///
    /// [Function calling API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions)
    /// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling)
    #[builder(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    functions: Vec<ChatCompletionFunctionDefinition>,
    /// A string or object of the function to call.
    ///
    /// Controls how the model responds to function calls.
    ///
    /// - "none" means the model does not call a function, and responds to the end-user.
    /// - "auto" means the model can pick between an end-user or calling a function.
    /// - Specifying a particular function via `{"name": "my_function"}` forces the model to call that function.
    ///
    /// "none" is the default when no functions are present. "auto" is the default if functions are present.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<Value>,
    /// An object specifying the format that the model must output. Compatible with GPT-4 Turbo and all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106.
    /// Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
    /// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ChatCompletionResponseFormat>,
    /// The credentials to use for this request.
    ///
    /// Never serialized into the request body.
    #[serde(skip_serializing)]
    #[builder(default)]
    credentials: Option<Credentials>,
    /// Parameters unique to the Venice API.
    ///
    /// See <https://docs.venice.ai/api-reference/api-spec>.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    venice_parameters: Option<VeniceParameters>,
    /// Whether to store the completion for use in distillation or evals.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub store: Option<bool>,
}
286
/// Request parameters specific to the Venice API, sent through the
/// `venice_parameters` field of a [`ChatCompletionRequest`].
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct VeniceParameters {
    /// NOTE(review): presumably toggles whether Venice adds its own system
    /// prompt to the conversation — confirm against the Venice API spec.
    pub include_venice_system_prompt: bool,
}
291
/// The output format requested from the model.
///
/// The field is private; build values with
/// [`ChatCompletionResponseFormat::json_object`] or
/// [`ChatCompletionResponseFormat::text`].
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatCompletionResponseFormat {
    /// Must be one of text or json_object (defaults to text).
    ///
    /// Serialized under the key `type`, which is a Rust keyword.
    #[serde(rename = "type")]
    typ: String,
}
298
299impl ChatCompletionResponseFormat {
300    pub fn json_object() -> Self {
301        ChatCompletionResponseFormat {
302            typ: "json_object".to_string(),
303        }
304    }
305
306    pub fn text() -> Self {
307        ChatCompletionResponseFormat {
308            typ: "text".to_string(),
309        }
310    }
311}
312
313impl<C> ChatCompletionGeneric<C> {
314    pub fn builder(
315        model: &str,
316        messages: impl Into<Vec<ChatCompletionMessage>>,
317    ) -> ChatCompletionBuilder {
318        ChatCompletionBuilder::create_empty()
319            .model(model)
320            .messages(messages)
321    }
322}
323
/// Request for listing the messages of a stored chat completion.
///
/// Only the flattened pagination fields are serialized; `completion_id` and
/// `credentials` are skipped.
#[derive(Serialize, Builder, Debug, Clone, Default)]
#[builder(derive(Clone, Debug, PartialEq))]
#[builder(pattern = "owned")]
#[builder(name = "ChatCompletionMessagesRequestBuilder")]
#[builder(setter(strip_option, into))]
pub struct ChatCompletionMessagesRequest {
    /// ID of the stored completion; used to build the request route rather
    /// than being serialized.
    #[serde(skip_serializing)]
    pub completion_id: String,

    /// The credentials to use for this request. Never serialized.
    #[builder(default)]
    #[serde(skip_serializing)]
    pub credentials: Option<Credentials>,

    /// Pagination options, flattened into the serialized request.
    #[builder(default)]
    #[serde(flatten)]
    pub pagination: RequestPagination,
}
341
/// A list of messages for a chat completion.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionMessages {
    /// The messages contained in this page of results.
    pub data: Vec<ChatCompletionMessage>,
    /// Object type reported by the API.
    pub object: String,
    /// NOTE(review): presumably the ID of the first item in `data` — confirm against the API reference.
    pub first_id: Option<String>,
    /// NOTE(review): presumably the ID of the last item in `data` — confirm against the API reference.
    pub last_id: Option<String>,
    /// NOTE(review): presumably whether further pages are available — confirm against the API reference.
    pub has_more: bool,
}
351
352impl ChatCompletion {
353    pub async fn create(request: ChatCompletionRequest) -> ApiResponseOrError<Self> {
354        let credentials_opt = request.credentials.clone();
355        openai_post("chat/completions", &request, credentials_opt).await
356    }
357
358    /// Get a stored completion.
359    pub async fn get(id: &str, credentials: Credentials) -> ApiResponseOrError<Self> {
360        let route = format!("chat/completions/{}", id);
361        openai_get(route.as_str(), Some(credentials)).await
362    }
363}
364
365impl ChatCompletionDelta {
366    pub async fn create(
367        request: ChatCompletionRequest,
368    ) -> Result<Receiver<Self>, CannotCloneRequestError> {
369        let credentials_opt = request.credentials.clone();
370        let stream = openai_request_stream(
371            Method::POST,
372            "chat/completions",
373            |r| r.json(&request),
374            credentials_opt,
375        )
376        .await?;
377        let (tx, rx) = channel::<Self>(32);
378        tokio::spawn(forward_deserialized_chat_response_stream(stream, tx));
379        Ok(rx)
380    }
381
382    /// Merges the input delta completion into `self`.
383    pub fn merge(
384        &mut self,
385        other: ChatCompletionDelta,
386    ) -> Result<(), ChatCompletionDeltaMergeError> {
387        if other.id.ne(&self.id) {
388            return Err(ChatCompletionDeltaMergeError::DifferentCompletionIds);
389        }
390        for other_choice in other.choices.iter() {
391            for choice in self.choices.iter_mut() {
392                if choice.index != other_choice.index {
393                    continue;
394                }
395                choice.merge(other_choice)?;
396            }
397        }
398        Ok(())
399    }
400}
401
402impl ChatCompletionChoiceDelta {
403    pub fn merge(
404        &mut self,
405        other: &ChatCompletionChoiceDelta,
406    ) -> Result<(), ChatCompletionDeltaMergeError> {
407        if self.index != other.index {
408            return Err(ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices);
409        }
410        if self.delta.role.is_none() {
411            if let Some(other_role) = other.delta.role {
412                // Set role to other_role.
413                self.delta.role = Some(other_role);
414            }
415        }
416        if self.delta.name.is_none() {
417            if let Some(other_name) = &other.delta.name {
418                // Set name to other_name.
419                self.delta.name = Some(other_name.clone());
420            }
421        }
422        // Merge contents.
423        match self.delta.content.as_mut() {
424            Some(content) => {
425                match &other.delta.content {
426                    Some(other_content) => {
427                        // Push other content into this one.
428                        content.push_str(other_content)
429                    }
430                    None => {}
431                }
432            }
433            None => {
434                match &other.delta.content {
435                    Some(other_content) => {
436                        // Set this content to other content.
437                        self.delta.content = Some(other_content.clone());
438                    }
439                    None => {}
440                }
441            }
442        };
443
444        // merge function calls
445        // function call names are concatenated
446        // arguments are merged by concatenating them
447        match self.delta.function_call.as_mut() {
448            Some(function_call) => {
449                match &other.delta.function_call {
450                    Some(other_function_call) => {
451                        // push the arguments string of the other function call into this one
452                        match (&mut function_call.arguments, &other_function_call.arguments) {
453                            (Some(function_call), Some(other_function_call)) => {
454                                function_call.push_str(&other_function_call);
455                            }
456                            (None, Some(other_function_call)) => {
457                                function_call.arguments = Some(other_function_call.clone());
458                            }
459                            _ => {}
460                        }
461                    }
462                    None => {}
463                }
464            }
465            None => {
466                match &other.delta.function_call {
467                    Some(other_function_call) => {
468                        // Set this content to other content.
469                        self.delta.function_call = Some(other_function_call.clone());
470                    }
471                    None => {}
472                }
473            }
474        };
475        Ok(())
476    }
477}
478
479impl From<ChatCompletionDelta> for ChatCompletion {
480    fn from(delta: ChatCompletionDelta) -> Self {
481        ChatCompletion {
482            id: delta.id,
483            object: delta.object,
484            created: delta.created,
485            model: delta.model,
486            usage: delta.usage,
487            choices: delta
488                .choices
489                .iter()
490                .map(|choice| ChatCompletionChoice {
491                    index: choice.index,
492                    finish_reason: clone_default_unwrapped_option_string(&choice.finish_reason),
493                    message: ChatCompletionMessage {
494                        role: choice
495                            .delta
496                            .role
497                            .unwrap_or_else(|| ChatCompletionMessageRole::System),
498                        content: choice.delta.content.clone(),
499                        name: choice.delta.name.clone(),
500                        function_call: choice.delta.function_call.clone().map(|f| f.into()),
501                        tool_call_id: None,
502                        tool_calls: Some(Vec::new()),
503                    },
504                })
505                .collect(),
506        }
507    }
508}
509
510impl From<ChatCompletionFunctionCallDelta> for ChatCompletionFunctionCall {
511    fn from(delta: ChatCompletionFunctionCallDelta) -> Self {
512        ChatCompletionFunctionCall {
513            name: delta.name.unwrap_or("".to_string()),
514            arguments: delta.arguments.unwrap_or_default(),
515        }
516    }
517}
518
519impl ChatCompletionMessages {
520    /// Create a builder for fetching messages for a stored completion.
521    pub fn builder(completion_id: String) -> ChatCompletionMessagesRequestBuilder {
522        ChatCompletionMessagesRequestBuilder::create_empty()
523            .completion_id(completion_id.to_string())
524    }
525
526    /// Fetch messages for a stored completion.
527    pub async fn fetch(
528        request: ChatCompletionMessagesRequest,
529    ) -> ApiResponseOrError<ChatCompletionMessages> {
530        let route = format!("chat/completions/{}/messages", request.completion_id);
531        let credentials = request.credentials.clone();
532        openai_get_with_query(route.as_str(), &request, credentials).await
533    }
534}
535
/// Errors that can occur while merging streamed chat completion chunks.
#[derive(Debug)]
pub enum ChatCompletionDeltaMergeError {
    /// The two completions being merged have different IDs.
    DifferentCompletionIds,
    /// The two choices being merged have different indices.
    DifferentCompletionChoiceIndices,
    /// Function call arguments of mismatched types were encountered.
    /// NOTE(review): no merge code visible in this part of the file returns
    /// this variant — confirm whether it is still used.
    FunctionCallArgumentTypeMismatch,
}
542
543impl std::fmt::Display for ChatCompletionDeltaMergeError {
544    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
545        match self {
546            ChatCompletionDeltaMergeError::DifferentCompletionIds => {
547                f.write_str("Different completion IDs")
548            }
549            ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices => {
550                f.write_str("Different completion choice indices")
551            }
552            ChatCompletionDeltaMergeError::FunctionCallArgumentTypeMismatch => {
553                f.write_str("Function call argument type mismatch")
554            }
555        }
556    }
557}
558
// Relies on the default `std::error::Error` methods; the required `Debug`
// and `Display` impls are defined above.
impl std::error::Error for ChatCompletionDeltaMergeError {}
560
561async fn forward_deserialized_chat_response_stream(
562    mut stream: EventSource,
563    tx: Sender<ChatCompletionDelta>,
564) -> anyhow::Result<()> {
565    while let Some(event) = stream.next().await {
566        let event = event?;
567        match event {
568            Event::Message(event) => {
569                let completion = serde_json::from_str::<ChatCompletionDelta>(&event.data)?;
570                tx.send(completion).await?;
571            }
572            _ => {}
573        }
574    }
575    Ok(())
576}
577
578impl ChatCompletionBuilder {
579    pub async fn create(self) -> ApiResponseOrError<ChatCompletion> {
580        ChatCompletion::create(self.build().unwrap()).await
581    }
582
583    pub async fn create_stream(
584        mut self,
585    ) -> Result<Receiver<ChatCompletionDelta>, CannotCloneRequestError> {
586        self.stream = Some(Some(true));
587        ChatCompletionDelta::create(self.build().unwrap()).await
588    }
589}
590
591impl ChatCompletionMessagesRequestBuilder {
592    /// Fetch messages for the specified completion.
593    pub async fn fetch(self) -> ApiResponseOrError<ChatCompletionMessages> {
594        ChatCompletionMessages::fetch(self.build().unwrap()).await
595    }
596}
597
/// Returns a copy of the contained string, or an empty string for `None`.
fn clone_default_unwrapped_option_string(string: &Option<String>) -> String {
    string.clone().unwrap_or_default()
}
604
605impl Default for ChatCompletionMessageRole {
606    fn default() -> Self {
607        Self::User
608    }
609}
610
/// Serde default for list fields: an empty vector, with no bound on `C`.
fn default_empty_vec<C>() -> Vec<C> {
    Vec::<C>::default()
}
614
615#[cfg(test)]
616mod tests {
617    use super::*;
618    use dotenvy::dotenv;
619    use std::time::Duration;
620    use tokio::time::sleep;
621
622    #[tokio::test]
623    async fn chat() {
624        dotenv().ok();
625        let credentials = Credentials::from_env();
626
627        let chat_completion = ChatCompletion::builder(
628            "gpt-3.5-turbo",
629            [ChatCompletionMessage {
630                role: ChatCompletionMessageRole::User,
631                content: Some("Hello!".to_string()),
632                name: None,
633                function_call: None,
634                tool_call_id: None,
635                tool_calls: Some(Vec::new()),
636            }],
637        )
638        .temperature(0.0)
639        .response_format(ChatCompletionResponseFormat::text())
640        .credentials(credentials)
641        .create()
642        .await
643        .unwrap();
644
645        assert_eq!(
646            chat_completion
647                .choices
648                .first()
649                .unwrap()
650                .message
651                .content
652                .as_ref()
653                .unwrap(),
654            "Hello! How can I assist you today?"
655        );
656    }
657
658    // Seeds are not deterministic so the only point of the test is to
659    // ensure that passing a seed still results in a valid response.
660    #[tokio::test]
661    async fn chat_seed() {
662        dotenv().ok();
663        let credentials = Credentials::from_env();
664
665        let chat_completion = ChatCompletion::builder(
666            "gpt-3.5-turbo",
667            [ChatCompletionMessage {
668                role: ChatCompletionMessageRole::User,
669                content: Some(
670                    "What type of seed does Mr. England sow in the song? Reply with 1 word."
671                        .to_string(),
672                ),
673                name: None,
674                function_call: None,
675                tool_call_id: None,
676                tool_calls: Some(Vec::new()),
677            }],
678        )
679        // Determinism currently comes from temperature 0, not seed.
680        .temperature(0.0)
681        .seed(1337u64)
682        .credentials(credentials)
683        .create()
684        .await
685        .unwrap();
686
687        assert_eq!(
688            chat_completion
689                .choices
690                .first()
691                .unwrap()
692                .message
693                .content
694                .as_ref()
695                .unwrap(),
696            "Love"
697        );
698    }
699
700    #[tokio::test]
701    async fn chat_stream() {
702        dotenv().ok();
703        let credentials = Credentials::from_env();
704
705        let chat_stream = ChatCompletion::builder(
706            "gpt-3.5-turbo",
707            [ChatCompletionMessage {
708                role: ChatCompletionMessageRole::User,
709                content: Some("Hello!".to_string()),
710                name: None,
711                function_call: None,
712                tool_call_id: None,
713                tool_calls: Some(Vec::new()),
714            }],
715        )
716        .temperature(0.0)
717        .credentials(credentials)
718        .create_stream()
719        .await
720        .unwrap();
721
722        let chat_completion = stream_to_completion(chat_stream).await;
723
724        assert_eq!(
725            chat_completion
726                .choices
727                .first()
728                .unwrap()
729                .message
730                .content
731                .as_ref()
732                .unwrap(),
733            "Hello! How can I assist you today?"
734        );
735    }
736
737    #[tokio::test]
738    async fn chat_function() {
739        dotenv().ok();
740        let credentials = Credentials::from_env();
741
742        let chat_stream = ChatCompletion::builder(
743            "gpt-4o",
744            [
745                ChatCompletionMessage {
746                    role: ChatCompletionMessageRole::User,
747                    content: Some("What is the weather in Boston?".to_string()),
748                    name: None,
749                    function_call: None,
750                    tool_call_id: None,
751                    tool_calls: Some(Vec::new()),
752                }
753            ]
754        ).functions([ChatCompletionFunctionDefinition {
755            description: Some("Get the current weather in a given location.".to_string()),
756            name: "get_current_weather".to_string(),
757            parameters: Some(serde_json::json!({
758                "type": "object",
759                "properties": {
760                    "location": {
761                        "type": "string",
762                        "description": "The city and state to get the weather for. (eg: San Francisco, CA)"
763                    }
764                },
765                "required": ["location"]
766            })),
767        }])
768        .temperature(0.2)
769        .credentials(credentials)
770        .create_stream()
771        .await
772        .unwrap();
773
774        let chat_completion = stream_to_completion(chat_stream).await;
775
776        assert_eq!(
777            chat_completion
778                .choices
779                .first()
780                .unwrap()
781                .message
782                .function_call
783                .as_ref()
784                .unwrap()
785                .name,
786            "get_current_weather".to_string(),
787        );
788
789        assert_eq!(
790            serde_json::from_str::<Value>(
791                &chat_completion
792                    .choices
793                    .first()
794                    .unwrap()
795                    .message
796                    .function_call
797                    .as_ref()
798                    .unwrap()
799                    .arguments
800            )
801            .unwrap(),
802            serde_json::json!({
803                "location": "Boston, MA"
804            }),
805        );
806    }
807
808    #[tokio::test]
809    async fn chat_response_format_json() {
810        dotenv().ok();
811        let credentials = Credentials::from_env();
812        let chat_completion = ChatCompletion::builder(
813            "gpt-3.5-turbo",
814            [ChatCompletionMessage {
815                role: ChatCompletionMessageRole::User,
816                content: Some("Write an example JSON for a JWT header using RS256".to_string()),
817                name: None,
818                function_call: None,
819                tool_call_id: None,
820                tool_calls: Some(Vec::new()),
821            }],
822        )
823        .temperature(0.0)
824        .seed(1337u64)
825        .response_format(ChatCompletionResponseFormat::json_object())
826        .credentials(credentials)
827        .create()
828        .await
829        .unwrap();
830        let response_string = chat_completion
831            .choices
832            .first()
833            .unwrap()
834            .message
835            .content
836            .as_ref()
837            .unwrap();
838        #[derive(Deserialize, Eq, PartialEq, Debug)]
839        struct Response {
840            alg: String,
841            typ: String,
842        }
843        let response = serde_json::from_str::<Response>(response_string).unwrap();
844        assert_eq!(
845            response,
846            Response {
847                alg: "RS256".to_owned(),
848                typ: "JWT".to_owned()
849            }
850        );
851    }
852
853    #[test]
854    fn builder_clone_and_eq() {
855        let builder_a = ChatCompletion::builder("gpt-4", [])
856            .temperature(0.0)
857            .seed(65u64);
858        let builder_b = builder_a.clone();
859        let builder_c = builder_b.clone().temperature(1.0);
860        let builder_d = ChatCompletionBuilder::default();
861        assert_eq!(builder_a, builder_b);
862        assert_ne!(builder_a, builder_c);
863        assert_ne!(builder_b, builder_c);
864        assert_ne!(builder_a, builder_d);
865        assert_ne!(builder_c, builder_d);
866    }
867
868    async fn stream_to_completion(
869        mut chat_stream: Receiver<ChatCompletionDelta>,
870    ) -> ChatCompletion {
871        let mut merged: Option<ChatCompletionDelta> = None;
872        while let Some(delta) = chat_stream.recv().await {
873            match merged.as_mut() {
874                Some(c) => {
875                    c.merge(delta).unwrap();
876                }
877                None => merged = Some(delta),
878            };
879        }
880        merged.unwrap().into()
881    }
882
883    #[tokio::test]
884    async fn chat_tool_response_completion() {
885        dotenv().ok();
886        let credentials = Credentials::from_env();
887
888        let chat_completion = ChatCompletion::builder(
889            "gpt-4o-mini",
890            [
891                ChatCompletionMessage {
892                    role: ChatCompletionMessageRole::User,
893                    content: Some(
894                        "What's 0.9102847*28456? \
895                        reply in plain text, \
896                        round the number to to 2 decimals \
897                        and reply with the result number only, \
898                        with no full stop at the end"
899                            .to_string(),
900                    ),
901                    name: None,
902                    function_call: None,
903                    tool_call_id: None,
904                    tool_calls: Some(Vec::new()),
905                },
906                ChatCompletionMessage {
907                    role: ChatCompletionMessageRole::Assistant,
908                    content: Some("Let me calculate that for you.".to_string()),
909                    name: None,
910                    function_call: None,
911                    tool_call_id: None,
912                    tool_calls: Some(vec![ToolCall {
913                        id: "the_tool_call".to_string(),
914                        r#type: "function".to_string(),
915                        function: ToolCallFunction {
916                            name: "mul".to_string(),
917                            arguments: "not_required_to_be_valid_here".to_string(),
918                        },
919                    }]),
920                },
921                ChatCompletionMessage {
922                    role: ChatCompletionMessageRole::Tool,
923                    content: Some("the result is 25903.061423199997".to_string()),
924                    name: None,
925                    function_call: None,
926                    tool_call_id: Some("the_tool_call".to_owned()),
927                    tool_calls: Some(Vec::new()),
928                },
929            ],
930        )
931        // Determinism currently comes from temperature 0, not seed.
932        .temperature(0.0)
933        .seed(1337u64)
934        .credentials(credentials)
935        .create()
936        .await
937        .unwrap();
938
939        assert_eq!(
940            chat_completion
941                .choices
942                .first()
943                .unwrap()
944                .message
945                .content
946                .as_ref()
947                .unwrap(),
948            "25903.06"
949        );
950    }
951
952    #[tokio::test]
953    async fn get_completion() {
954        dotenv().ok();
955        let credentials = Credentials::from_env();
956
957        let chat_completion = ChatCompletion::builder(
958            "gpt-3.5-turbo",
959            [ChatCompletionMessage {
960                role: ChatCompletionMessageRole::User,
961                content: Some("Hello!".to_string()),
962                ..Default::default()
963            }],
964        )
965        .credentials(credentials.clone())
966        .store(true)
967        .create()
968        .await
969        .unwrap();
970
971        // Unfortunatelly completions are not available immediately so we need to wait a bit
972        sleep(Duration::from_secs(7)).await;
973
974        let retrieved_completion = ChatCompletion::get(&chat_completion.id, credentials.clone())
975            .await
976            .unwrap();
977
978        assert_eq!(retrieved_completion, chat_completion);
979    }
980
981    #[tokio::test]
982    async fn get_completion_non_existent() {
983        dotenv().ok();
984        let credentials = Credentials::from_env();
985
986        match ChatCompletion::get("non_existent_id", credentials.clone()).await {
987            Ok(_) => panic!("Expected error"),
988            Err(e) => assert_eq!(e.code, Some("not_found".to_string())),
989        }
990    }
991
992    #[tokio::test]
993    async fn get_completion_messages() {
994        dotenv().ok();
995        let credentials = Credentials::from_env();
996
997        let user_message = ChatCompletionMessage {
998            role: ChatCompletionMessageRole::User,
999            content: Some("Tell me a short joke".to_string()),
1000            ..Default::default()
1001        };
1002
1003        let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", [user_message.clone()])
1004            .credentials(credentials.clone())
1005            .store(true)
1006            .create()
1007            .await
1008            .unwrap();
1009
1010        // Unfortunatelly completions are not available immediately so we need to wait a bit
1011        sleep(Duration::from_secs(7)).await;
1012
1013        let retrieved_messages = ChatCompletionMessages::builder(chat_completion.id)
1014            .credentials(credentials.clone())
1015            .fetch()
1016            .await
1017            .unwrap();
1018
1019        assert_eq!(retrieved_messages.data, vec![user_message]);
1020        assert_eq!(retrieved_messages.has_more, false);
1021    }
1022
1023    #[tokio::test]
1024    async fn get_completion_messages_with_pagination() {
1025        dotenv().ok();
1026        let credentials = Credentials::from_env();
1027
1028        let user_message = ChatCompletionMessage {
1029            role: ChatCompletionMessageRole::User,
1030            content: Some("Tell me a short joke".to_string()),
1031            ..Default::default()
1032        };
1033
1034        let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", [user_message.clone()])
1035            .credentials(credentials.clone())
1036            .store(true)
1037            .create()
1038            .await
1039            .unwrap();
1040
1041        dbg!(&chat_completion);
1042
1043        // Unfortunatelly completions are not available immediately so we need to wait a bit
1044        sleep(Duration::from_secs(7)).await;
1045
1046        // Fetch the first page
1047        let retrieved_messages1 = ChatCompletionMessages::builder(chat_completion.id.clone())
1048            .credentials(credentials.clone())
1049            .pagination(RequestPagination {
1050                limit: Some(1),
1051                ..Default::default()
1052            })
1053            .fetch()
1054            .await
1055            .unwrap();
1056
1057        assert_eq!(retrieved_messages1.data, vec![user_message]);
1058        assert_eq!(retrieved_messages1.has_more, false);
1059        assert!(retrieved_messages1.first_id.is_some());
1060        assert!(retrieved_messages1.last_id.is_some());
1061
1062        // Fetch the second page, which should be empty
1063        let retrieved_messages2 = ChatCompletionMessages::builder(chat_completion.id.clone())
1064            .credentials(credentials.clone())
1065            .pagination(RequestPagination {
1066                limit: Some(1),
1067                after: Some(retrieved_messages1.first_id.unwrap()),
1068                ..Default::default()
1069            })
1070            .fetch()
1071            .await
1072            .unwrap();
1073
1074        assert_eq!(retrieved_messages2.data, vec![]);
1075        assert_eq!(retrieved_messages2.has_more, false);
1076        assert!(retrieved_messages2.first_id.is_none());
1077        assert!(retrieved_messages2.last_id.is_none());
1078    }
1079}