aionic/openai/chat.rs
use crate::openai::misc::Usage;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Represents the response from a chat model API call to `OpenAI`.
///
/// Contains fields that provide information about the model used, the choices made by the model,
/// the unique ID for the API call, and usage data on the number of tokens processed.
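///
/// # Example
///
/// A minimal deserialization sketch; the payload is made up for illustration and
/// assumes `serde_json` is available as a dependency:
///
/// ```
/// use aionic::openai::chat::Response;
///
/// let json = r#"{"id": "chatcmpl-123", "object": "chat.completion", "model": "gpt-3.5-turbo"}"#;
/// let response: Response = serde_json::from_str(json).unwrap();
/// assert_eq!(response.model.as_deref(), Some("gpt-3.5-turbo"));
/// ```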
#[derive(Deserialize, Debug, Clone)]
pub struct Response {
    /// Unique ID for the API call.
    pub id: Option<String>,

    /// Type of the API object. For a chat model, this should be 'chat.completion'.
    pub object: Option<String>,

    /// UNIX timestamp indicating when the chat completion was created.
    pub created: Option<u64>,

    /// The model that was used for the chat session.
    pub model: Option<String>,

    /// Choices made by the chat model during the conversation.
    pub choices: Option<Vec<Choice>>,

    /// Information on the number of tokens processed in the request.
    pub usage: Option<Usage>,
}

/// Represents a choice made by the model in a chat API call.
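///
/// # Example
///
/// A sketch of pulling the generated messages out of a parsed [`Response`]:
///
/// ```
/// use aionic::openai::chat::Response;
///
/// fn print_choices(response: &Response) {
///     if let Some(choices) = &response.choices {
///         for choice in choices {
///             println!("[{}] {}", choice.finish_reason, choice.message.content);
///         }
///     }
/// }
/// ```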
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    /// The message that corresponds to the choice made.
    pub message: Message,

    /// Reason for finishing the generation.
    pub finish_reason: String,

    /// Index of the choice in the list of choices.
    pub index: u64,
}

/// Represents one chunk of the response from a streaming chat model API call to `OpenAI`.
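///
/// # Example
///
/// A sketch of parsing a single SSE `data:` payload; the chunk below is made up for
/// illustration and assumes `serde_json` is available as a dependency:
///
/// ```
/// use aionic::openai::chat::StreamedResponse;
///
/// let chunk = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;
/// let parsed: StreamedResponse = serde_json::from_str(chunk).unwrap();
/// assert_eq!(parsed.choices[0].delta.content.as_deref(), Some("Hello"));
/// ```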
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamedResponse {
    /// Unique ID for the API call.
    pub id: String,

    /// Type of the API object. For a streamed chat chunk, this should be 'chat.completion.chunk'.
    pub object: String,

    /// UNIX timestamp indicating when the chat completion was created.
    pub created: u64,

    /// The model that was used for the chat session.
    pub model: String,

    /// Choices made by the chat model during the conversation.
    pub choices: Vec<StreamedChoices>,
}

/// Represents a choice made by the model in a streaming chat API call.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamedChoices {
    /// Index of the choice in the list of choices.
    pub index: u64,

    /// The incremental change (delta) produced by the model.
    pub delta: Delta,

    /// Reason for finishing the generation; populated only on the final chunk.
    pub finish_reason: Option<String>,
}

/// Represents an incremental change (delta) produced by the model in a streaming chat API call.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Delta {
    /// Role of the author of the delta.
    pub role: Option<String>,

    /// Content of the delta.
    pub content: Option<String>,
}

/// Enumeration of roles for authors of messages in a chat API call.
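///
/// # Example
///
/// A small round-trip sketch; unrecognized role strings fall back to `User`:
///
/// ```
/// use aionic::openai::chat::MessageRole;
///
/// let role = MessageRole::from("assistant");
/// assert_eq!(role.to_string(), "assistant");
/// assert_eq!(MessageRole::from("unknown").to_string(), "user");
/// ```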
#[derive(Clone, Debug, Copy)]
pub enum MessageRole {
    User,
    Assistant,
    System,
    Function,
}

impl std::fmt::Display for MessageRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::System => "system",
            Self::Function => "function",
        };
        f.write_str(s)
    }
}

impl<T: Into<String>> From<T> for MessageRole {
    fn from(s: T) -> Self {
        match s.into().as_str() {
            "assistant" => Self::Assistant,
            "system" => Self::System,
            "function" => Self::Function,
            _ => Self::User,
        }
    }
}

/// Represents a single message exchanged with the `OpenAI` API during a conversational model session.
///
/// The `Message` struct encapsulates the details of an individual message in the conversation: the role of the author,
/// the content of the message, the name of the author if the role is 'function', and information about any function that should be called.
///
/// Each message sent or received in a conversational model session with the `OpenAI` API is represented by an instance of this struct.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Message {
    /// The role of the message's author. One of 'system', 'user', 'assistant', or 'function'.
    pub role: String,

    /// The contents of the message. Content is required for all messages, and may be null for
    /// assistant messages with function calls.
    pub content: String,

    /// The name of the author of this message. Required if the role is 'function'; it should
    /// be the name of the function whose response is in the content. May contain a-z, A-Z, 0-9,
    /// and underscores, with a maximum length of 64 characters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// The name and arguments of a function that should be called, as generated by the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCall>,
}

impl Message {
    /// Constructs a new `Message` instance.
    ///
    /// This function is responsible for creating a new message object that will be sent to or received from the `OpenAI` API.
    ///
    /// # Arguments
    ///
    /// * `role`: The role that corresponds to the author of the message: `MessageRole::User`, `MessageRole::Assistant`, `MessageRole::System`, or `MessageRole::Function`.
    /// * `content`: The text content of the message. This should be the input provided by the user or the generated response.
    ///
    /// # Examples
    ///
    /// ```
    /// use aionic::openai::chat::{MessageRole, Message};
    ///
    /// let user_message = Message::new(&MessageRole::User, "Hello, assistant!");
    /// ```
    pub fn new<S: Into<String>>(role: &MessageRole, content: S) -> Self {
        Self {
            role: role.to_string(),
            content: content.into(),
            name: None,
            function_call: None,
        }
    }
}

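/// A convenience conversion: any string-like value becomes a user-role [`Message`].
///
/// ```
/// use aionic::openai::chat::Message;
///
/// let message: Message = "Hello!".into();
/// assert_eq!(message.role, "user");
/// ```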
impl<T: Into<String>> From<T> for Message {
    fn from(s: T) -> Self {
        Self {
            role: MessageRole::User.to_string(),
            content: s.into(),
            name: None,
            function_call: None,
        }
    }
}

impl std::fmt::Display for Message {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let role = MessageRole::from(self.role.as_str());
        write!(f, "{}: {}", role, self.content)
    }
}

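/// The name and arguments of a function call generated by the model.
///
/// A sketch of defensively parsing the model-generated arguments; the function name is
/// made up, and `serde_json` is assumed to be available as a dependency:
///
/// ```
/// use aionic::openai::chat::FunctionCall;
///
/// let call = FunctionCall {
///     name: "get_weather".to_string(),
///     arguments: r#"{"location": "Berlin"}"#.to_string(),
/// };
///
/// // The model may emit invalid JSON, so handle the error case explicitly.
/// match serde_json::from_str::<serde_json::Value>(&call.arguments) {
///     Ok(args) => assert_eq!(args["location"], "Berlin"),
///     Err(e) => eprintln!("model produced invalid JSON: {e}"),
/// }
/// ```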
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct FunctionCall {
    /// The name of the function to call.
    pub name: String,

    /// The arguments to call the function with, as generated by the model in JSON format.
    ///
    /// Note that the model does not always generate valid JSON, and may hallucinate parameters
    /// not defined by your function schema. Validate the arguments in your code before calling
    /// your function.
    pub arguments: String,
}

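/// Stop sequences for generation: either a single string or an array of up to four strings.
///
/// `#[serde(untagged)]` lets both JSON shapes serialize naturally, as this sketch shows
/// (assuming `serde_json` is available as a dependency):
///
/// ```
/// use aionic::openai::chat::Stop;
///
/// let single = serde_json::to_string(&Stop::String("\n".to_string())).unwrap();
/// assert_eq!(single, r#""\n""#);
///
/// let many = serde_json::to_string(&Stop::Array(vec!["END".to_string()])).unwrap();
/// assert_eq!(many, r#"["END"]"#);
/// ```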
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(untagged)]
pub enum Stop {
    String(String),
    Array(Vec<String>),
}

/// This struct is used for chat completions with `OpenAI`'s models.
/// It contains all the parameters that can be set for an API request.
///
/// All fields with an `Option` type can be omitted from the JSON payload,
/// thanks to the `skip_serializing_if` attribute.
///
/// For more information check the official [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create)
///
/// # Example
///
/// ```rust
/// use aionic::openai::Chat;
/// use aionic::openai::chat::Message;
///
/// let mut chat = Chat::default();
/// chat.messages.push(Message::from("Hello, assistant!"));
/// ```
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Chat {
    /// ID of the model to use. You can use the List models API to see all of your available models.
    pub model: String,

    /// A list of messages comprising the conversation so far.
    pub messages: Vec<Message>,

    /// A list of functions the model may generate JSON inputs for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub functions: Option<Vec<Function>>,

    /// Controls how the model responds to function calls. "none" means the model does not call a function
    /// and responds to the end-user. "auto" means the model can pick between responding to the end-user or calling a function.
    /// Specifying a particular function via `{"name": "my_function"}` forces the model to call that function.
    /// "none" is the default when no functions are present; "auto" is the default if functions are present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<String>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random,
    /// while lower values like 0.2 will make it more focused and deterministic.
    /// It's generally recommended to alter this or `top_p`, but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results
    /// of the tokens with `top_p` probability mass. So 0.1 means only the tokens comprising the top 10% probability
    /// mass are considered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    /// How many chat completion choices to generate for each input message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<i64>,

    /// If set, partial message deltas will be sent, like in ChatGPT.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,

    /// The maximum number of tokens to generate in the chat completion.
    /// The total length of input tokens and generated tokens is limited by the model's context length.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text
    /// so far, increasing the model's likelihood to talk about new topics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in
    /// the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer)
    /// to an associated bias value from -100 to 100. Mathematically, the bias is added to the
    /// logits generated by the model prior to sampling. The exact effect will vary per model,
    /// but values between -1 and 1 should decrease or increase likelihood of selection; values
    /// like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,

    /// A unique identifier representing your end-user, which can help `OpenAI` to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl Chat {
    const DEFAULT_TEMPERATURE: f64 = 1.0;
    const DEFAULT_MAX_TOKENS: u64 = 2048;
    const DEFAULT_STREAM_RESPONSE: bool = true;
    const DEFAULT_MODEL: &'static str = "gpt-3.5-turbo";

    /// Returns the default temperature for this AI system.
    ///
    /// # Returns
    ///
    /// This function returns a `f64` value which represents the default temperature.
    pub fn get_default_temperature() -> f64 {
        Self::DEFAULT_TEMPERATURE
    }

    /// Returns the default maximum token limit for this AI system.
    ///
    /// # Returns
    ///
    /// This function returns a `u64` value which represents the default maximum number of tokens that can be used in a single AI system action.
    pub fn get_default_max_tokens() -> u64 {
        Self::DEFAULT_MAX_TOKENS
    }

    /// Returns the default streaming behavior for this AI system.
    ///
    /// # Returns
    ///
    /// This function returns a `bool` value which represents the default behavior of the AI system when handling streaming data.
    /// If it returns `true`, the system will stream the data by default. If `false`, it will not.
    pub fn get_default_stream() -> bool {
        Self::DEFAULT_STREAM_RESPONSE
    }

    /// Returns the default model to be used by this AI system.
    ///
    /// # Returns
    ///
    /// This function returns a static string slice (`&'static str`) which represents the identifier of the default model used by the AI system.
    pub fn get_default_model() -> &'static str {
        Self::DEFAULT_MODEL
    }
}

/// This struct describes a single function the model may generate JSON inputs for.
/// It is part of the `Chat` structure.
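///
/// # Example
///
/// A sketch of describing a function with a JSON Schema object; the function name and
/// schema are made up, and `serde_json` is assumed to be available as a dependency:
///
/// ```
/// use aionic::openai::chat::Function;
/// use serde_json::json;
///
/// let function = Function {
///     name: "get_weather".to_string(),
///     description: Some("Get the current weather for a location".to_string()),
///     parameters: json!({
///         "type": "object",
///         "properties": {
///             "location": { "type": "string" }
///         },
///         "required": ["location"]
///     }),
/// };
/// assert_eq!(function.name, "get_weather");
/// ```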
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Function {
    /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
    pub name: String,

    /// A description of what the function does, used by the model to choose when and how to call the function.
    pub description: Option<String>,

    /// The parameters the function accepts, described as a JSON Schema object. See the guide for examples, and the JSON Schema
    /// reference for documentation about the format.
    ///
    /// To describe a function that accepts no parameters, provide the value `{"type": "object", "properties": {}}`.
    pub parameters: serde_json::Value,
}