// llm_sdk/api/chat_completion.rs

1use crate::{IntoRequest, ToSchema};
2use derive_builder::Builder;
3use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
4use serde::{Deserialize, Serialize};
5use strum::{Display, EnumIter, EnumMessage, EnumString, EnumVariantNames};
6
/// Request body for the OpenAI Chat Completions API.
///
/// Build via [`ChatCompletionRequest::new`], [`ChatCompletionRequest::new_with_tools`],
/// or [`ChatCompletionRequestBuilder`]. Unset optional fields and an empty `tools`
/// list are omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize, Builder)]
pub struct ChatCompletionRequest {
    /// A list of messages comprising the conversation so far.
    #[builder(setter(into))]
    messages: Vec<ChatCompletionMessage>,
    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
    #[builder(default)]
    model: ChatCompleteModel,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,

    // Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    // NOTE(review): field intentionally disabled; the commented type `Option<f32>` does not
    // match the map shape described above — revisit when implementing.
    // #[builder(default, setter(strip_option))]
    // #[serde(skip_serializing_if = "Option::is_none")]
    // logit_bias: Option<f32>,
    /// The maximum number of tokens to generate in the chat completion.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<usize>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ChatResponseFormatObject>,
    /// This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<usize>,
    /// Up to 4 sequences where the API will stop generating further tokens.
    // TODO: make this as an enum (the API also accepts an array of strings)
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<String>,
    /// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for.
    #[builder(default, setter(into))]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Tool>,
    /// Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces the model to call that function. none is the default when no functions are present. auto is the default if functions are present.
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    #[builder(default, setter(strip_option, into))]
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}
74
75#[derive(
76    Debug, Clone, Default, PartialEq, Eq, Serialize, EnumString, Display, EnumVariantNames,
77)]
78#[serde(rename_all = "snake_case")]
79pub enum ToolChoice {
80    #[default]
81    None,
82    Auto,
83    // TODO: we need something like this: #[serde(tag = "type", content = "function")]
84    Function {
85        name: String,
86    },
87}
88
/// A tool the model may call. Currently, only functions are supported.
#[derive(Debug, Clone, Serialize)]
pub struct Tool {
    /// The type of the tool. Currently, only functions are supported.
    r#type: ToolType,
    /// The function the model may call, including its JSON-schema parameters.
    function: FunctionInfo,
}
96
/// Description of a callable function, sent to the model as part of a [`Tool`].
#[derive(Debug, Clone, Serialize)]
pub struct FunctionInfo {
    /// A description of what the function does, used by the model to choose when and how to call the function.
    description: String,
    /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
    name: String,
    /// The parameters the functions accepts, described as a JSON Schema object.
    parameters: serde_json::Value,
}
106
/// Wrapper object for the `response_format` request field, serialized as {"type": ...}.
#[derive(Debug, Clone, Serialize)]
pub struct ChatResponseFormatObject {
    // The desired output format.
    r#type: ChatResponseFormat,
}
111
112#[derive(
113    Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, EnumString, Display, EnumVariantNames,
114)]
115#[serde(rename_all = "snake_case")]
116pub enum ChatResponseFormat {
117    Text,
118    #[default]
119    Json,
120}
121
/// A single conversation message. Serialized with the variant name as the
/// `role` field (e.g. {"role": "user", ...}) via the internally-tagged repr.
#[derive(Debug, Clone, Serialize, Display, EnumVariantNames, EnumMessage)]
#[serde(rename_all = "snake_case", tag = "role")]
pub enum ChatCompletionMessage {
    /// A message from a system.
    System(SystemMessage),
    /// A message from a human.
    User(UserMessage),
    /// A message from the assistant.
    Assistant(AssistantMessage),
    /// A message from a tool.
    Tool(ToolMessage),
}
134
/// Supported chat-completion model identifiers. `serde` renames map to the
/// exact API model IDs; `strum` serializations provide friendlier CLI names.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Eq,
    Serialize,
    Deserialize,
    EnumString,
    EnumIter,
    Display,
    EnumVariantNames,
    EnumMessage,
)]
pub enum ChatCompleteModel {
    /// The default model. Currently, this is the gpt-3.5-turbo-1106 model.
    #[default]
    #[serde(rename = "gpt-3.5-turbo-1106")]
    #[strum(serialize = "gpt-3.5-turbo")]
    Gpt3Turbo,
    /// GPT-3.5 turbo model with instruct capability.
    #[serde(rename = "gpt-3.5-turbo-instruct")]
    #[strum(serialize = "gpt-3.5-turbo-instruct")]
    Gpt3TurboInstruct,
    /// The latest GPT4 model. Currently, this is the gpt-4-1106-preview model.
    #[serde(rename = "gpt-4-1106-preview")]
    #[strum(serialize = "gpt-4-turbo")]
    Gpt4Turbo,
    /// The latest GPT4 model with vision capability. Currently, this is the gpt-4-1106-vision-preview model.
    #[serde(rename = "gpt-4-1106-vision-preview")]
    #[strum(serialize = "gpt-4-turbo-vision")]
    Gpt4TurboVision,
}
170
/// Payload of a [`ChatCompletionMessage::System`] message.
#[derive(Debug, Clone, Serialize)]
pub struct SystemMessage {
    /// The contents of the system message.
    content: String,
    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
}
179
/// Payload of a [`ChatCompletionMessage::User`] message.
#[derive(Debug, Clone, Serialize)]
pub struct UserMessage {
    /// The contents of the user message.
    content: String,
    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
}
188
/// Payload of a [`ChatCompletionMessage::Assistant`] message; also the shape of
/// the model-generated message inside a response choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssistantMessage {
    /// The contents of the assistant message. `None` when the model responded
    /// with tool calls instead of text.
    #[serde(default)]
    pub content: Option<String>,
    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub name: Option<String>,
    /// The tool calls generated by the model, such as function calls.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tool_calls: Vec<ToolCall>,
}
201
/// Payload of a [`ChatCompletionMessage::Tool`] message (a tool's result fed back to the model).
#[derive(Debug, Clone, Serialize)]
pub struct ToolMessage {
    /// The contents of the tool message.
    content: String,
    /// Tool call that this message is responding to.
    tool_call_id: String,
}
209
/// A tool invocation generated by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// The ID of the tool call.
    pub id: String,
    /// The type of the tool. Currently, only function is supported.
    pub r#type: ToolType,
    /// The function that the model called.
    pub function: FunctionCall,
}
219
/// The concrete function invocation inside a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// The name of the function to call.
    pub name: String,
    /// The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.
    pub arguments: String,
}
227
/// The kind of tool; the API currently defines only "function".
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Default,
    Serialize,
    Deserialize,
    EnumString,
    Display,
    EnumVariantNames,
)]
#[serde(rename_all = "snake_case")]
pub enum ToolType {
    /// Serialized as "function".
    #[default]
    Function,
}
246
/// Response body returned by the Chat Completions API.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
    /// A unique identifier for the chat completion.
    pub id: String,
    /// A list of chat completion choices. Can be more than one if n is greater than 1.
    pub choices: Vec<ChatCompletionChoice>,
    /// The Unix timestamp (in seconds) of when the chat completion was created.
    pub created: usize,
    /// The model used for the chat completion.
    pub model: ChatCompleteModel,
    /// This fingerprint represents the backend configuration that the model runs with. Can be used in conjunction with the seed request parameter to understand when backend changes have been made that might impact determinism.
    pub system_fingerprint: String,
    /// The object type, which is always chat.completion.
    pub object: String,
    /// Usage statistics for the completion request.
    pub usage: ChatCompleteUsage,
}
264
/// One generated completion alternative inside a [`ChatCompletionResponse`].
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionChoice {
    /// The reason the model stopped generating tokens. This will be stop if the model hit a natural stop point or a provided stop sequence, length if the maximum number of tokens specified in the request was reached, content_filter if content was omitted due to a flag from our content filters, tool_calls if the model called a tool, or function_call (deprecated) if the model called a function.
    pub finish_reason: FinishReason,
    /// The index of the choice in the list of choices.
    pub index: usize,
    /// A chat completion message generated by the model.
    pub message: AssistantMessage,
}
274
/// Token-usage accounting for one completion request.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompleteUsage {
    /// Number of tokens in the generated completion.
    pub completion_tokens: usize,
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens used in the request (prompt + completion).
    pub total_tokens: usize,
}
284
/// Why the model stopped generating tokens (see `ChatCompletionChoice::finish_reason`).
#[derive(
    Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, EnumString, Display, EnumVariantNames,
)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Natural stop point or a provided stop sequence was hit.
    #[default]
    Stop,
    /// The requested max_tokens limit was reached.
    Length,
    /// Content was omitted due to a content filter flag.
    ContentFilter,
    /// The model called one or more tools.
    ToolCalls,
}
296
297impl IntoRequest for ChatCompletionRequest {
298    fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
299        let url = format!("{}/chat/completions", base_url);
300        client.post(url).json(&self)
301    }
302}
303
304impl ChatCompletionRequest {
305    pub fn new(model: ChatCompleteModel, messages: impl Into<Vec<ChatCompletionMessage>>) -> Self {
306        ChatCompletionRequestBuilder::default()
307            .model(model)
308            .messages(messages)
309            .build()
310            .unwrap()
311    }
312
313    pub fn new_with_tools(
314        model: ChatCompleteModel,
315        messages: impl Into<Vec<ChatCompletionMessage>>,
316        tools: impl Into<Vec<Tool>>,
317    ) -> Self {
318        ChatCompletionRequestBuilder::default()
319            .model(model)
320            .messages(messages)
321            .tools(tools)
322            .build()
323            .unwrap()
324    }
325}
326
327impl ChatCompletionMessage {
328    pub fn new_system(content: impl Into<String>, name: &str) -> ChatCompletionMessage {
329        ChatCompletionMessage::System(SystemMessage {
330            content: content.into(),
331            name: Self::get_name(name),
332        })
333    }
334
335    pub fn new_user(content: impl Into<String>, name: &str) -> ChatCompletionMessage {
336        ChatCompletionMessage::User(UserMessage {
337            content: content.into(),
338            name: Self::get_name(name),
339        })
340    }
341
342    fn get_name(name: &str) -> Option<String> {
343        if name.is_empty() {
344            None
345        } else {
346            Some(name.into())
347        }
348    }
349}
350
351impl Tool {
352    pub fn new_function<T: ToSchema>(
353        name: impl Into<String>,
354        description: impl Into<String>,
355    ) -> Self {
356        let parameters = T::to_schema();
357        Self {
358            r#type: ToolType::Function,
359            function: FunctionInfo {
360                name: name.into(),
361                description: description.into(),
362                parameters,
363            },
364        }
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToSchema, SDK};
    use anyhow::Result;
    use schemars::JsonSchema;

    // Example tool-argument type used to exercise schema generation.
    #[allow(dead_code)]
    #[derive(Debug, Clone, Deserialize, JsonSchema)]
    struct GetWeatherArgs {
        /// The city to get the weather for.
        pub city: String,
        /// the unit
        pub unit: TemperatureUnit,
    }

    #[allow(dead_code)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, JsonSchema)]
    enum TemperatureUnit {
        /// Celsius
        #[default]
        Celsius,
        /// Fahrenheit
        Fahrenheit,
    }

    // Canned response type for the fake weather tool below.
    #[derive(Debug, Clone)]
    struct GetWeatherResponse {
        temperature: f32,
        unit: TemperatureUnit,
    }

    // Second tool-argument type, so requests can carry more than one tool.
    #[allow(dead_code)]
    #[derive(Debug, Deserialize, JsonSchema)]
    struct ExplainMoodArgs {
        /// The mood to explain.
        pub name: String,
    }

    // Fake tool implementation: returns fixed temperatures keyed only on the
    // requested unit (the city is ignored).
    fn get_weather_forecast(args: GetWeatherArgs) -> GetWeatherResponse {
        match args.unit {
            TemperatureUnit::Celsius => GetWeatherResponse {
                temperature: 22.2,
                unit: TemperatureUnit::Celsius,
            },
            TemperatureUnit::Fahrenheit => GetWeatherResponse {
                temperature: 72.0,
                unit: TemperatureUnit::Fahrenheit,
            },
        }
    }

    #[test]
    // Ignored: the derived `Serialize` for `ToolChoice::Function` does not emit the
    // {"type": "function", "function": {...}} shape this test asserts (see the TODO
    // on `ToolChoice`). Un-ignore once the serialization is fixed.
    #[ignore]
    fn chat_completion_request_tool_choice_function_serialize_should_work() {
        let req = ChatCompletionRequestBuilder::default()
            .tool_choice(ToolChoice::Function {
                name: "my_function".to_string(),
            })
            .messages(vec![])
            .build()
            .unwrap();
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
              "tool_choice": {
                "type": "function",
                "function": {
                  "name": "my_function"
                }
              },
              "messages": []
            })
        );
    }

    #[test]
    fn chat_completion_request_serialize_should_work() {
        let mut req = get_simple_completion_request();
        req.tool_choice = Some(ToolChoice::Auto);
        let json = serde_json::to_value(req).unwrap();
        // Unset optional fields must be absent; the empty-name system message
        // must omit "name" entirely.
        assert_eq!(
            json,
            serde_json::json!({
              "tool_choice": "auto",
              "model": "gpt-3.5-turbo-1106",
              "messages": [{
                "role": "system",
                "content": "I can answer any question you ask me."
              }, {
                "role": "user",
                "content": "What is human life expectancy in the world?",
                "name": "user1"
              }]
            })
        );
    }

    #[test]
    fn chat_completion_request_with_tools_serialize_should_work() {
        let req = get_tool_completion_request();
        let json = serde_json::to_value(req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
              "model": "gpt-3.5-turbo-1106",
              "messages": [{
                "role": "system",
                "content": "I can choose the right function for you."
              }, {
                "role": "user",
                "content": "What is the weather like in Boston?",
                "name": "user1"
              }],
              "tools": [
                {
                  "type": "function",
                  "function": {
                    "description": "Get the weather forecast for a city.",
                    "name": "get_weather_forecast",
                    "parameters": GetWeatherArgs::to_schema()
                  }
                },
                {
                  "type": "function",
                  "function": {
                    "description": "Explain the meaning of the given mood.",
                    "name": "explain_mood",
                    "parameters": ExplainMoodArgs::to_schema()
                  }
                }
              ]
            })
        );
    }

    // Integration test: hits the live API via the shared SDK client.
    #[tokio::test]
    async fn simple_chat_completion_should_work() -> Result<()> {
        let req = get_simple_completion_request();
        let res = SDK.chat_completion(req).await?;
        assert_eq!(res.model, ChatCompleteModel::Gpt3Turbo);
        assert_eq!(res.object, "chat.completion");
        assert_eq!(res.choices.len(), 1);
        let choice = &res.choices[0];
        assert_eq!(choice.finish_reason, FinishReason::Stop);
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.tool_calls.len(), 0);
        Ok(())
    }

    // Integration test: the model is expected to pick `get_weather_forecast`
    // and return a tool call whose arguments parse into `GetWeatherArgs`.
    #[tokio::test]
    async fn chat_completion_with_tools_should_work() -> Result<()> {
        let req = get_tool_completion_request();
        let res = SDK.chat_completion(req).await?;
        assert_eq!(res.model, ChatCompleteModel::Gpt3Turbo);
        assert_eq!(res.object, "chat.completion");
        assert_eq!(res.choices.len(), 1);
        let choice = &res.choices[0];
        assert_eq!(choice.finish_reason, FinishReason::ToolCalls);
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.content, None);
        assert_eq!(choice.message.tool_calls.len(), 1);
        let tool_call = &choice.message.tool_calls[0];
        assert_eq!(tool_call.function.name, "get_weather_forecast");
        let ret = get_weather_forecast(serde_json::from_str(&tool_call.function.arguments)?);
        assert_eq!(ret.unit, TemperatureUnit::Celsius);
        assert_eq!(ret.temperature, 22.2);
        Ok(())
    }

    // Shared fixture: minimal request with no tools.
    fn get_simple_completion_request() -> ChatCompletionRequest {
        let messages = vec![
            ChatCompletionMessage::new_system("I can answer any question you ask me.", ""),
            ChatCompletionMessage::new_user("What is human life expectancy in the world?", "user1"),
        ];
        ChatCompletionRequest::new(ChatCompleteModel::Gpt3Turbo, messages)
    }

    // Shared fixture: request carrying two function tools.
    fn get_tool_completion_request() -> ChatCompletionRequest {
        let messages = vec![
            ChatCompletionMessage::new_system("I can choose the right function for you.", ""),
            ChatCompletionMessage::new_user("What is the weather like in Boston?", "user1"),
        ];
        let tools = vec![
            Tool::new_function::<GetWeatherArgs>(
                "get_weather_forecast",
                "Get the weather forecast for a city.",
            ),
            Tool::new_function::<ExplainMoodArgs>(
                "explain_mood",
                "Explain the meaning of the given mood.",
            ),
        ];
        ChatCompletionRequest::new_with_tools(ChatCompleteModel::Gpt3Turbo, messages, tools)
    }
}