//! rs_openai/interfaces/chat.rs

use crate::shared::response_wrapper::OpenAIError;
use crate::shared::types::Stop;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Serialize, Deserialize, Clone, Default, strum::Display)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    #[strum(serialize = "system")]
    System,
    #[default]
    #[strum(serialize = "user")]
    User,
    #[strum(serialize = "assistant")]
    Assistant,
}
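// A minimal sketch (not part of the original file) of what the derives on
// `Role` give you: `strum::Display` produces the lowercase wire names, and
// `#[default]` makes `User` the default variant.
#[allow(dead_code)]
fn example_role_names() {
    assert_eq!(Role::Assistant.to_string(), "assistant");
    assert_eq!(Role::default().to_string(), "user"); // `Role` defaults to `User`
}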
#[derive(Builder, Default, Debug, Clone, Deserialize, Serialize)]
#[builder(name = "ChatCompletionMessageRequestBuilder")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct ChatCompletionMessage {
    /// The role of the author of this message. One of `system`, `user`, or `assistant`.
    pub role: Role,

    /// The contents of the message.
    pub content: String,

    /// The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
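// A minimal sketch (not part of the original file) of building a message with
// the generated builder; the content string is an arbitrary placeholder.
#[allow(dead_code)]
fn example_build_message() -> Result<ChatCompletionMessage, OpenAIError> {
    ChatCompletionMessageRequestBuilder::default()
        .role(Role::User)
        .content("Hello!") // `setter(into)` accepts `&str`
        .build()
}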
#[derive(Builder, Clone, Debug, Default, Serialize)]
#[builder(name = "CreateChatRequestBuilder")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateChatRequest {
    /// ID of the model to use.
    /// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
    pub model: String,

    /// A list of messages describing the conversation so far.
    pub messages: Vec<ChatCompletionMessage>,

    /// What sampling temperature to use, between 0 and 2.
    /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>, // min: 0, max: 2, default: 1

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>, // default: 1

    /// How many chat completion choices to generate for each input message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u8>, // default: 1

    /// If set, partial message deltas will be sent, like in ChatGPT.
    /// Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message.
    /// See the OpenAI Cookbook for [example code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).
    ///
    /// For streamed progress, use [`create_with_stream`](Chat::create_with_stream).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>, // default: false

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>, // default: null

    /// The maximum number of tokens to generate in the chat completion.
    ///
    /// The total length of input tokens and generated tokens is limited by the model's context length.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// Number between -2.0 and 2.0.
    /// Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0

    /// Number between -2.0 and 2.0.
    /// Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
    /// Mathematically, the bias is added to the logits generated by the model prior to sampling.
    /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection;
    /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null

    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
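// A minimal sketch (not part of the original file) of assembling a request with
// the generated builder; the model name and sampling values are illustrative
// assumptions, not defaults of this crate.
#[allow(dead_code)]
fn example_build_chat_request(
    messages: Vec<ChatCompletionMessage>,
) -> Result<CreateChatRequest, OpenAIError> {
    CreateChatRequestBuilder::default()
        .model("gpt-3.5-turbo")
        .messages(messages)
        .temperature(0.7_f32) // leave `top_p` unset, per the docs above
        .max_tokens(256_u32)
        .build()
}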
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Message {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct ChatUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct ChatChoice {
    pub message: ChatCompletionMessage,
    pub finish_reason: String,
    pub index: u32,
}

#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u32,
    pub choices: Vec<ChatChoice>,
    pub usage: ChatUsage,
}
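// A minimal sketch (not part of the original file) of pulling the assistant's
// reply out of a parsed response; `body` is assumed to be the raw JSON returned
// by the chat completions endpoint.
#[allow(dead_code)]
fn example_first_choice(body: &str) -> Result<String, serde_json::Error> {
    let response: ChatResponse = serde_json::from_str(body)?;
    // The first choice's message content is usually what callers want.
    Ok(response
        .choices
        .first()
        .map(|choice| choice.message.content.clone())
        .unwrap_or_default())
}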
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Delta {
    pub content: Option<String>,
}

#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct ChatChoiceStream {
    pub delta: Delta,
    pub finish_reason: Option<String>,
    pub index: u32,
}

#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct ChatStreamResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    pub created: u32,
    pub choices: Vec<ChatChoiceStream>,
}
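// A minimal sketch (not part of the original file) of stitching streamed deltas
// back into the full reply; SSE transport and the `data: [DONE]` terminator are
// assumed to be handled upstream, yielding parsed `ChatStreamResponse` chunks.
#[allow(dead_code)]
fn example_collect_stream(chunks: &[ChatStreamResponse]) -> String {
    let mut full_text = String::new();
    for chunk in chunks {
        for choice in &chunk.choices {
            if let Some(content) = &choice.delta.content {
                full_text.push_str(content);
            }
        }
    }
    full_text
}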