1use async_trait::async_trait;
2use crate::{JsonRequest, Usage};
3use std::collections::HashMap;
4use serde::{Serialize,Deserialize};
5
6#[derive(Clone, Debug,Serialize,Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9 System,
10 User,
11 Assistant,
12}
13
14#[derive(Clone, Serialize, Deserialize, Debug)]
15pub struct Message{
16 pub role:Role,
17 pub content:String
18}
19
20#[derive(Clone, Serialize, Deserialize, Debug)]
21#[serde(untagged)]
22pub enum StopSeq{
23 String(String),
24 Vec(Vec<String>)
25}
26
27#[derive(Clone, Serialize, Deserialize, Debug)]
42pub struct ChatRequest {
43 model:String,
44 messages:Vec<Message>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 temperature: Option<f64>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 top_p: Option<f64>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 n: Option<u16>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 stream: Option<bool>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 stop:Option<StopSeq>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 max_tokens: Option<u64>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 presence_penalty: Option<f64>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 frequency_penalty:Option<f64>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 logit_bias: Option<HashMap<String,f32>>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 user: Option<String>
65}
66
67
68#[async_trait(?Send)]
69impl JsonRequest<ChatSuccess> for ChatRequest {
70 const ENDPOINT: &'static str = "/chat/completions";
71}
72
73impl ChatRequest {
74
75 pub fn new(messages : Vec<Message>) -> Self {
76 Self {
77 model: "gpt-3.5-turbo".to_string(),
78 messages,
79 temperature: None,
80 top_p: None,
81 n: None,
82 stream: None,
83 stop: None,
84 max_tokens: None,
85 presence_penalty: None,
86 frequency_penalty: None,
87 logit_bias: None,
88 user: None,
89 }
90 }
91
92 pub fn with_model_and_messages(model: &str, messages : Vec<Message>) -> Self {
93 Self {
94 model: model.to_string(),
95 messages,
96 temperature: None,
97 top_p: None,
98 n: None,
99 stream: None,
100 stop: None,
101 max_tokens: None,
102 presence_penalty: None,
103 frequency_penalty: None,
104 logit_bias: None,
105 user: None,
106 }
107 }
108
109 pub fn add_message(mut self, message:Message) ->Self{
110 self.messages.push(message);
111 self
112 }
113
114 pub fn model(mut self, model: String) -> Self {
115 self.model = model;
116 self
117 }
118
119 pub fn temperature(mut self, temperature: f64) -> Self {
120 if self.top_p.is_some() {
121 self.top_p = None;
122 }
123 self.temperature = Some(temperature.clamp(0f64,2f64));
124 self
125 }
126
127 pub fn top_p(mut self, top_p: f64) -> Self {
128 if self.temperature.is_some() {
129 self.temperature = None;
130 }
131 self.top_p = Some(top_p.clamp(0f64,1f64));
132 self
133 }
134
135 pub fn n(mut self, n: u16) -> Self {
136 self.n = Some(n);
137 self
138 }
139
140 pub fn stream(mut self, stream: bool) -> Self {
141 self.stream = Some(stream);
142 self
143 }
144
145 pub fn stop(mut self, stop: StopSeq) -> Self {
146 self.stop = Some(stop);
147 self
148 }
149
150 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
151 self.max_tokens = Some(max_tokens);
152 self
153 }
154
155 pub fn presence_penalty(mut self, presence_penalty: f64) -> Self{
156 self.presence_penalty= Some(presence_penalty.clamp(-2f64,2f64));
157 self
158 }
159
160 pub fn frequency_penalty(mut self, frequency_penalty: f64) -> Self {
161 self.frequency_penalty = Some(frequency_penalty.clamp(-2f64,2f64));
162 self
163 }
164
165 pub fn logit_bias(mut self, logit_bias: HashMap<String, f32>) -> Self {
166 self.logit_bias = Some(logit_bias);
167 self
168 }
169
170 pub fn user(mut self, user: String) -> Self {
171 self.user = Some(user);
172 self
173 }
174
175}
176
177#[derive(Clone, Serialize, Deserialize, Debug)]
178pub struct ChatChoice {
179 pub index: u16,
180 pub message: Message,
181 pub finish_reason: String
182}
183
184
185#[derive(Clone, Serialize, Deserialize, Debug)]
186pub struct ChatSuccess {
187 pub id: String,
188 pub object: String,
189 pub created: u64,
190 pub choices: Vec<ChatChoice>,
191 pub usage:Usage
192}
193
194