1use crate::response::AssistantMessage;
2use anyhow::{anyhow, Ok, Result};
3use schemars::schema::SchemaObject;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct FrequencyPenalty(pub f32);
9
10impl FrequencyPenalty {
11 pub fn new(v: f32) -> Result<Self> {
21 if !(-2.0..=2.0).contains(&v) {
22 return Err(anyhow!(
23 "Frequency penalty value must be between -2 and 2.".to_string()
24 ));
25 }
26 Ok(FrequencyPenalty(v))
27 }
28}
29
30impl Default for FrequencyPenalty {
31 fn default() -> Self {
33 FrequencyPenalty(0.0)
34 }
35}
36
37#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39pub struct PresencePenalty(pub f32);
40
41impl PresencePenalty {
42 pub fn new(v: f32) -> Result<Self> {
52 if !(-2.0..=2.0).contains(&v) {
53 return Err(anyhow!(
54 "Presence penalty value must be between -2 and 2.".to_string()
55 ));
56 }
57 Ok(PresencePenalty(v))
58 }
59}
60
61impl Default for PresencePenalty {
62 fn default() -> Self {
64 PresencePenalty(0.0)
65 }
66}
67
68#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
70pub enum ResponseType {
71 #[serde(rename = "json_object")]
72 Json,
73 #[serde(rename = "text")]
74 Text,
75}
76
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
79pub struct ResponseFormat {
80 #[serde(rename = "type")]
81 pub resp_type: ResponseType,
82}
83
84impl ResponseFormat {
85 pub fn new(rt: ResponseType) -> Self {
91 ResponseFormat { resp_type: rt }
92 }
93}
94
95#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
97pub struct MaxToken(pub u32);
98
99impl MaxToken {
100 pub fn new(v: u32) -> Result<Self> {
110 if !(1..=8192).contains(&v) {
111 return Err(anyhow!("Max token must be between 1 and 8192.".to_string()));
112 }
113 Ok(MaxToken(v))
114 }
115}
116
117impl Default for MaxToken {
118 fn default() -> Self {
120 MaxToken(4096)
121 }
122}
123
124#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
126pub enum Stop {
127 Single(String),
128 Multiple(Vec<String>),
129}
130
131#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
133pub struct StreamOptions {
134 pub include_usage: bool,
135}
136
137impl StreamOptions {
138 pub fn new(include_usage: bool) -> Self {
144 StreamOptions { include_usage }
145 }
146}
147
148#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
150pub struct Temperature(pub f32);
151
152impl Temperature {
153 pub fn new(v: f32) -> Result<Self> {
163 if !(0.0..=2.0).contains(&v) {
164 return Err(anyhow!("Temperature must be between 0 and 2.".to_string()));
165 }
166 Ok(Temperature(v))
167 }
168}
169
170impl Default for Temperature {
171 fn default() -> Self {
173 Temperature(1.0)
174 }
175}
176
177#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
179pub struct TopP(pub f32);
180
181impl TopP {
182 pub fn new(v: f32) -> Result<Self> {
192 if !(0.0..=1.0).contains(&v) {
193 return Err(anyhow!("TopP value must be between 0and 2.".to_string()));
194 }
195 Ok(TopP(v))
196 }
197}
198
199impl Default for TopP {
200 fn default() -> Self {
202 TopP(1.0)
203 }
204}
205
206#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
208pub enum ToolType {
209 #[serde(rename = "function")]
210 Function,
211}
212
213#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
215pub struct Function {
216 pub description: String,
217 pub name: String,
218 pub parameters: SchemaObject,
219}
220
221#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
223pub struct ToolObject {
224 #[serde(rename = "type")]
225 pub tool_type: ToolType,
226 pub function: Function,
227}
228
229#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231pub enum ChatCompletionToolChoice {
232 #[serde(rename = "none")]
233 None,
234 #[serde(rename = "auto")]
235 Auto,
236 #[serde(rename = "required")]
237 Required,
238}
239
240#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
242pub struct FunctionChoice {
243 pub name: String,
244}
245
246#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
248pub struct ChatCompletionNamedToolChoice {
249 #[serde(rename = "type")]
250 pub tool_type: ToolType,
251 pub function: FunctionChoice,
252}
253
254#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
256pub enum ToolChoice {
257 ChatCompletion(ChatCompletionToolChoice),
258 ChatCompletionNamed(ChatCompletionNamedToolChoice),
259}
260
261#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
263pub struct TopLogprobs(pub u32);
264
265impl TopLogprobs {
266 pub fn new(v: u32) -> Result<Self> {
276 if v > 20 {
277 return Err(anyhow!(
278 "Top log probs must be between 0 and 20.".to_string()
279 ));
280 }
281 Ok(TopLogprobs(v))
282 }
283}
284
285impl Default for TopLogprobs {
286 fn default() -> Self {
288 TopLogprobs(0)
289 }
290}
291
292#[derive(Debug, Clone, Serialize, Deserialize)]
294#[serde(tag = "role")]
295pub enum MessageRequest {
296 #[serde(rename = "system")]
297 System(SystemMessageRequest),
298 #[serde(rename = "user")]
299 User(UserMessageRequest),
300 #[serde(rename = "assistant")]
301 Assistant(AssistantMessage),
302 #[serde(rename = "tool")]
303 Tool(ToolMessageRequest),
304}
305
306impl MessageRequest {
307 pub fn user(content: &str) -> Self {
314 MessageRequest::User(UserMessageRequest {
315 content: content.to_string(),
316 name: None,
317 })
318 }
319
320 pub fn sys(content: &str) -> Self {
326 MessageRequest::System(SystemMessageRequest {
327 content: content.to_string(),
328 name: None,
329 })
330 }
331 pub fn get_content(&self) -> &str {
332 match self {
333 MessageRequest::System(req) => req.content.as_str(),
334 MessageRequest::User(req) => req.content.as_str(),
335 MessageRequest::Assistant(req) => req.content.as_str(),
336 MessageRequest::Tool(req) => req.content.as_str(),
337 }
338 }
339}
340
341#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
343pub struct SystemMessageRequest {
344 pub content: String,
345 pub name: Option<String>,
346}
347
348impl SystemMessageRequest {
349 pub fn new(msg: &str) -> Self {
355 SystemMessageRequest {
356 content: msg.to_string(),
357 name: None,
358 }
359 }
360
361 pub fn new_with_name(name: &str, msg: &str) -> Self {
368 SystemMessageRequest {
369 content: msg.to_string(),
370 name: Some(name.to_string()),
371 }
372 }
373}
374
375#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
377pub struct UserMessageRequest {
378 pub content: String,
379 pub name: Option<String>,
380}
381
382impl UserMessageRequest {
383 pub fn new(msg: &str) -> Self {
389 UserMessageRequest {
390 content: msg.to_string(),
391 name: None,
392 }
393 }
394
395 pub fn new_with_name(name: &str, msg: &str) -> Self {
402 UserMessageRequest {
403 content: msg.to_string(),
404 name: Some(name.to_string()),
405 }
406 }
407}
408
409#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
411pub struct ToolMessageRequest {
412 pub content: String,
413 pub tool_call_id: String,
414}
415
416impl ToolMessageRequest {
417 pub fn new(msg: &str, tool_call_id: &str) -> Self {
424 ToolMessageRequest {
425 content: msg.to_string(),
426 tool_call_id: tool_call_id.to_string(),
427 }
428 }
429}