rs_openai/interfaces/chat.rs
1use crate::shared::response_wrapper::OpenAIError;
2use crate::shared::types::Stop;
3use derive_builder::Builder;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7#[derive(Debug, Serialize, Deserialize, Clone, Default, strum::Display)]
8#[serde(rename_all = "lowercase")]
9pub enum Role {
10 #[strum(serialize = "system")]
11 System,
12 #[default]
13 #[strum(serialize = "user")]
14 User,
15 #[strum(serialize = "assistant")]
16 Assistant,
17}
18
19#[derive(Builder, Default, Debug, Clone, Deserialize, Serialize)]
20#[builder(name = "ChatCompletionMessageRequestBuilder")]
21#[builder(pattern = "mutable")]
22#[builder(setter(into, strip_option), default)]
23#[builder(derive(Debug))]
24#[builder(build_fn(error = "OpenAIError"))]
25pub struct ChatCompletionMessage {
26 /// The role of the author of this message. One of `system`, `user`, or `assistant`.
27 pub role: Role,
28
29 /// The contents of the message.
30 pub content: String,
31
32 /// The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
33 pub name: Option<String>,
34}
35
36#[derive(Builder, Clone, Debug, Default, Serialize)]
37#[builder(name = "CreateChatRequestBuilder")]
38#[builder(pattern = "mutable")]
39#[builder(setter(into, strip_option), default)]
40#[builder(derive(Debug))]
41#[builder(build_fn(error = "OpenAIError"))]
42pub struct CreateChatRequest {
43 /// ID of the model to use.
44 /// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
45 pub model: String,
46
47 /// A list of messages describing the conversation so far.
48 pub messages: Vec<ChatCompletionMessage>,
49
50 /// What sampling temperature to use, between 0 and 2.
51 /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
52 ///
53 /// We generally recommend altering this or `top_p` but not both.
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub temperature: Option<f32>, // min: 0, max: 2, default: 1
56
57 /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
58 /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
59 ///
60 /// We generally recommend altering this or `temperature` but not both.
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub top_p: Option<f32>, // default: 1
63
64 /// How many chat completion choices to generate for each input message.
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub n: Option<u8>, // default: 1
67
68 /// If set, partial message deltas will be sent, like in ChatGPT.
69 /// Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message.
70 /// See the OpenAI Cookbook for [example code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).
71 ///
72 /// For streamed progress, use [`create_with_stream`](Chat::create_with_stream).
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pub stream: Option<bool>, // default: false
75
76 /// Up to 4 sequences where the API will stop generating further tokens.
77 #[serde(skip_serializing_if = "Option::is_none")]
78 pub stop: Option<Stop>, // default: null
79
80 /// The maximum number of tokens to generate in the chat completion.
81 ///
82 /// The total length of input tokens and generated tokens is limited by the model's context length.
83 #[serde(skip_serializing_if = "Option::is_none")]
84 pub max_tokens: Option<u32>,
85
86 /// Number between -2.0 and 2.0.
87 /// Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
88 ///
89 /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
90 #[serde(skip_serializing_if = "Option::is_none")]
91 pub presence_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0
92
93 /// Number between -2.0 and 2.0.
94 /// Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
95 ///
96 /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub frequency_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0
99
100 /// Modify the likelihood of specified tokens appearing in the completion.
101 ///
102 /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
103 /// Mathematically, the bias is added to the logits generated by the model prior to sampling.
104 /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection;
105 /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
106 #[serde(skip_serializing_if = "Option::is_none")]
107 pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null
108
109 /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
110 #[serde(skip_serializing_if = "Option::is_none")]
111 pub user: Option<String>,
112}
113
114#[derive(Debug, Deserialize, Clone, Serialize)]
115pub struct Message {
116 pub role: String,
117 pub content: String,
118}
119
120#[derive(Debug, Deserialize, Clone, Serialize)]
121pub struct ChatUsage {
122 pub prompt_tokens: u32,
123 pub completion_tokens: u32,
124 pub total_tokens: u32,
125}
126
127#[derive(Debug, Deserialize, Clone, Serialize)]
128pub struct ChatChoice {
129 pub message: ChatCompletionMessage,
130 pub finish_reason: String,
131 pub index: u32,
132}
133
134#[derive(Debug, Deserialize, Clone, Serialize)]
135pub struct ChatResponse {
136 pub id: String,
137 pub object: String,
138 pub created: u32,
139 pub choices: Vec<ChatChoice>,
140 pub usage: ChatUsage,
141}
142
143#[derive(Debug, Deserialize, Clone, Serialize)]
144pub struct Delta {
145 pub content: Option<String>,
146}
147
148#[derive(Debug, Deserialize, Clone, Serialize)]
149pub struct ChatChoiceStream {
150 pub delta: Delta,
151 pub finish_reason: Option<String>,
152 pub index: u32,
153}
154
155#[derive(Debug, Deserialize, Clone, Serialize)]
156pub struct ChatStreamResponse {
157 pub id: String,
158 pub object: String,
159 pub model: String,
160 pub created: u32,
161 pub choices: Vec<ChatChoiceStream>,
162}