// openai_req/chat/mod.rs

1use async_trait::async_trait;
2use crate::{JsonRequest, Usage};
3use std::collections::HashMap;
4use serde::{Serialize,Deserialize};
5
6#[derive(Clone, Debug,Serialize,Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9    System,
10    User,
11    Assistant,
12}
13
/// A single chat message: who said it and what was said.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Message{
    /// Author of the message (see [`Role`]).
    pub role:Role,
    /// Plain-text body of the message.
    pub content:String
}
19
20#[derive(Clone, Serialize, Deserialize, Debug)]
21#[serde(untagged)]
22pub enum StopSeq{
23    String(String),
24    Vec(Vec<String>)
25}
26
27/// request providing chat completion. Detailed parameter description can be found at
28/// https://platform.openai.com/docs/api-reference/chat
29/// # Usage example
30/// ```
31/// use openai_req::chat::{ChatRequest, Message, Role};
32/// use openai_req::JsonRequest;
33///
34/// let messages  = vec!(Message{
35///      role: Role::User,
36///      content: "hello!".to_string(),
37///    });
38///    let chat_request = ChatRequest::new(messages);
39///    let response = chat_request.run(&client).await?;
40/// ```
41#[derive(Clone, Serialize, Deserialize, Debug)]
42pub struct ChatRequest {
43    model:String,
44    messages:Vec<Message>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    temperature: Option<f64>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    top_p: Option<f64>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    n: Option<u16>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    stream: Option<bool>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    stop:Option<StopSeq>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    max_tokens: Option<u64>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    presence_penalty: Option<f64>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    frequency_penalty:Option<f64>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    logit_bias: Option<HashMap<String,f32>>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    user: Option<String>
65}
66
67
// Wires ChatRequest to the chat-completions endpoint; request plumbing
// (serialization, HTTP call, ChatSuccess deserialization) lives in the
// JsonRequest trait's default implementation.
// NOTE(review): `?Send` suggests the resulting futures are not required to
// cross threads — confirm against the trait definition.
#[async_trait(?Send)]
impl JsonRequest<ChatSuccess> for ChatRequest {
    /// REST path appended to the API base URL.
    const ENDPOINT: &'static str = "/chat/completions";
}
72
73impl ChatRequest {
74
75    pub fn new(messages : Vec<Message>) -> Self {
76        Self {
77            model: "gpt-3.5-turbo".to_string(),
78            messages,
79            temperature: None,
80            top_p: None,
81            n: None,
82            stream: None,
83            stop: None,
84            max_tokens: None,
85            presence_penalty: None,
86            frequency_penalty: None,
87            logit_bias: None,
88            user: None,
89        }
90    }
91
92    pub fn with_model_and_messages(model: &str, messages : Vec<Message>) -> Self {
93        Self {
94            model: model.to_string(),
95            messages,
96            temperature: None,
97            top_p: None,
98            n: None,
99            stream: None,
100            stop: None,
101            max_tokens: None,
102            presence_penalty: None,
103            frequency_penalty: None,
104            logit_bias: None,
105            user: None,
106        }
107    }
108
109    pub fn add_message(mut self, message:Message) ->Self{
110        self.messages.push(message);
111        self
112    }
113
114    pub fn model(mut self, model: String) -> Self {
115        self.model = model;
116        self
117    }
118
119    pub fn temperature(mut self, temperature: f64) -> Self {
120        if self.top_p.is_some() {
121            self.top_p = None;
122        }
123        self.temperature = Some(temperature.clamp(0f64,2f64));
124        self
125    }
126
127    pub fn top_p(mut self, top_p: f64) -> Self {
128        if self.temperature.is_some() {
129            self.temperature = None;
130        }
131        self.top_p = Some(top_p.clamp(0f64,1f64));
132        self
133    }
134
135    pub fn n(mut self, n: u16) -> Self {
136        self.n = Some(n);
137        self
138    }
139
140    pub fn stream(mut self, stream: bool) -> Self {
141        self.stream = Some(stream);
142        self
143    }
144
145    pub fn stop(mut self, stop: StopSeq) -> Self {
146        self.stop = Some(stop);
147        self
148    }
149
150    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
151        self.max_tokens = Some(max_tokens);
152        self
153    }
154
155    pub fn presence_penalty(mut self, presence_penalty: f64) -> Self{
156        self.presence_penalty= Some(presence_penalty.clamp(-2f64,2f64));
157        self
158    }
159
160    pub fn frequency_penalty(mut self, frequency_penalty: f64) -> Self {
161        self.frequency_penalty = Some(frequency_penalty.clamp(-2f64,2f64));
162        self
163    }
164
165    pub fn logit_bias(mut self, logit_bias: HashMap<String, f32>) -> Self {
166        self.logit_bias = Some(logit_bias);
167        self
168    }
169
170    pub fn user(mut self, user: String) -> Self {
171        self.user = Some(user);
172        self
173    }
174
175}
176
/// One generated completion within a [`ChatSuccess`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct ChatChoice {
    /// Zero-based index of this choice in the response.
    pub index: u16,
    /// The generated assistant message.
    pub message: Message,
    /// Why generation stopped (e.g. "stop" or "length" per the API docs).
    pub finish_reason: String
}
183
184
/// Successful response body from the chat-completions endpoint.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct ChatSuccess {
    /// Unique id of the completion.
    pub id: String,
    /// Object type tag returned by the API.
    pub object: String,
    /// Creation time — presumably a Unix timestamp in seconds; confirm
    /// against the API reference.
    pub created: u64,
    /// Generated choices; multiple entries when `n > 1` was requested.
    pub choices: Vec<ChatChoice>,
    /// Token accounting for prompt and completion.
    pub usage:Usage
}
193
194