deepseek_api/
request.rs

1use crate::response::AssistantMessage;
2use anyhow::{anyhow, Ok, Result};
3use schemars::schema::SchemaObject;
4use serde::{Deserialize, Serialize};
5
6/// Represents a frequency penalty with a value between -2 and 2.
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct FrequencyPenalty(pub f32);
9
10impl FrequencyPenalty {
11    /// Creates a new `FrequencyPenalty` instance.
12    ///
13    /// # Arguments
14    ///
15    /// * `v` - A float value representing the frequency penalty.
16    ///
17    /// # Errors
18    ///
19    /// Returns an error if the value is not between -2 and 2.
20    pub fn new(v: f32) -> Result<Self> {
21        if !(-2.0..=2.0).contains(&v) {
22            return Err(anyhow!(
23                "Frequency penalty value must be between -2 and 2.".to_string()
24            ));
25        }
26        Ok(FrequencyPenalty(v))
27    }
28}
29
30impl Default for FrequencyPenalty {
31    /// Returns the default value for `FrequencyPenalty`, which is 0.0.
32    fn default() -> Self {
33        FrequencyPenalty(0.0)
34    }
35}
36
37/// Represents a presence penalty with a value between -2 and 2.
38#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39pub struct PresencePenalty(pub f32);
40
41impl PresencePenalty {
42    /// Creates a new `PresencePenalty` instance.
43    ///
44    /// # Arguments
45    ///
46    /// * `v` - A float value representing the presence penalty.
47    ///
48    /// # Errors
49    ///
50    /// Returns an error if the value is not between -2 and 2.
51    pub fn new(v: f32) -> Result<Self> {
52        if !(-2.0..=2.0).contains(&v) {
53            return Err(anyhow!(
54                "Presence penalty value must be between -2 and 2.".to_string()
55            ));
56        }
57        Ok(PresencePenalty(v))
58    }
59}
60
61impl Default for PresencePenalty {
62    /// Returns the default value for `PresencePenalty`, which is 0.0.
63    fn default() -> Self {
64        PresencePenalty(0.0)
65    }
66}
67
68/// Represents the type of response.
69#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
70pub enum ResponseType {
71    #[serde(rename = "json_object")]
72    Json,
73    #[serde(rename = "text")]
74    Text,
75}
76
77/// Represents the format of the response.
78#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
79pub struct ResponseFormat {
80    #[serde(rename = "type")]
81    pub resp_type: ResponseType,
82}
83
84impl ResponseFormat {
85    /// Creates a new `ResponseFormat` instance.
86    ///
87    /// # Arguments
88    ///
89    /// * `rt` - The type of response.
90    pub fn new(rt: ResponseType) -> Self {
91        ResponseFormat { resp_type: rt }
92    }
93}
94
95/// Represents the maximum number of tokens with a value between 1 and 8192.
96#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
97pub struct MaxToken(pub u32);
98
99impl MaxToken {
100    /// Creates a new `MaxToken` instance.
101    ///
102    /// # Arguments
103    ///
104    /// * `v` - An unsigned integer representing the maximum number of tokens.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if the value is not between 1 and 8192.
109    pub fn new(v: u32) -> Result<Self> {
110        if !(1..=8192).contains(&v) {
111            return Err(anyhow!("Max token must be between 1 and 8192.".to_string()));
112        }
113        Ok(MaxToken(v))
114    }
115}
116
117impl Default for MaxToken {
118    /// Returns the default value for `MaxToken`, which is 4096.
119    fn default() -> Self {
120        MaxToken(4096)
121    }
122}
123
124/// Represents the stopping criteria for the completion.
125#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
126pub enum Stop {
127    Single(String),
128    Multiple(Vec<String>),
129}
130
131/// Represents the options for streaming responses.
132#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
133pub struct StreamOptions {
134    pub include_usage: bool,
135}
136
137impl StreamOptions {
138    /// Creates a new `StreamOptions` instance.
139    ///
140    /// # Arguments
141    ///
142    /// * `include_usage` - A boolean indicating whether to include usage information.
143    pub fn new(include_usage: bool) -> Self {
144        StreamOptions { include_usage }
145    }
146}
147
148/// Represents the temperature with a value between 0 and 2.
149#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
150pub struct Temperature(pub f32);
151
152impl Temperature {
153    /// Creates a new `Temperature` instance.
154    ///
155    /// # Arguments
156    ///
157    /// * `v` - An unsigned integer representing the temperature.
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if the value is not between 0 and 2.
162    pub fn new(v: f32) -> Result<Self> {
163        if !(0.0..=2.0).contains(&v) {
164            return Err(anyhow!("Temperature must be between 0 and 2.".to_string()));
165        }
166        Ok(Temperature(v))
167    }
168}
169
170impl Default for Temperature {
171    /// Returns the default value for `Temperature`, which is 1.
172    fn default() -> Self {
173        Temperature(1.0)
174    }
175}
176
177/// Represents the top-p value with a value between 0.0 and 1.0.
178#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
179pub struct TopP(pub f32);
180
181impl TopP {
182    /// Creates a new `TopP` instance.
183    ///
184    /// # Arguments
185    ///
186    /// * `v` - A float value representing the top-p value.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if the value is not between 0.0 and 1.0.
191    pub fn new(v: f32) -> Result<Self> {
192        if !(0.0..=1.0).contains(&v) {
193            return Err(anyhow!("TopP value must be between 0and 2.".to_string()));
194        }
195        Ok(TopP(v))
196    }
197}
198
199impl Default for TopP {
200    /// Returns the default value for `TopP`, which is 1.0.
201    fn default() -> Self {
202        TopP(1.0)
203    }
204}
205
206/// Represents the type of tool.
207#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
208pub enum ToolType {
209    #[serde(rename = "function")]
210    Function,
211}
212
213/// Represents a function with a description, name, and parameters.
214#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
215pub struct Function {
216    pub description: String,
217    pub name: String,
218    pub parameters: SchemaObject,
219}
220
221/// Represents a tool object with a type and function.
222#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
223pub struct ToolObject {
224    #[serde(rename = "type")]
225    pub tool_type: ToolType,
226    pub function: Function,
227}
228
229/// Represents the choice of chat completion tool.
230#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231pub enum ChatCompletionToolChoice {
232    #[serde(rename = "none")]
233    None,
234    #[serde(rename = "auto")]
235    Auto,
236    #[serde(rename = "required")]
237    Required,
238}
239
240/// Represents a function choice with a name.
241#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
242pub struct FunctionChoice {
243    pub name: String,
244}
245
246/// Represents the choice of named chat completion tool.
247#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
248pub struct ChatCompletionNamedToolChoice {
249    #[serde(rename = "type")]
250    pub tool_type: ToolType,
251    pub function: FunctionChoice,
252}
253
254/// Represents the choice of tool.
255#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
256pub enum ToolChoice {
257    ChatCompletion(ChatCompletionToolChoice),
258    ChatCompletionNamed(ChatCompletionNamedToolChoice),
259}
260
261/// Represents the top log probabilities with a value between 0 and 20.
262#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
263pub struct TopLogprobs(pub u32);
264
265impl TopLogprobs {
266    /// Creates a new `TopLogprobs` instance.
267    ///
268    /// # Arguments
269    ///
270    /// * `v` - An unsigned integer representing the top log probabilities.
271    ///
272    /// # Errors
273    ///
274    /// Returns an error if the value is not between 0 and 20.
275    pub fn new(v: u32) -> Result<Self> {
276        if v > 20 {
277            return Err(anyhow!(
278                "Top log probs must be between 0 and 20.".to_string()
279            ));
280        }
281        Ok(TopLogprobs(v))
282    }
283}
284
285impl Default for TopLogprobs {
286    /// Returns the default value for `TopLogprobs`, which is 0.
287    fn default() -> Self {
288        TopLogprobs(0)
289    }
290}
291
292/// Represents a message request with different roles.
293#[derive(Debug, Clone, Serialize, Deserialize)]
294#[serde(tag = "role")]
295pub enum MessageRequest {
296    #[serde(rename = "system")]
297    System(SystemMessageRequest),
298    #[serde(rename = "user")]
299    User(UserMessageRequest),
300    #[serde(rename = "assistant")]
301    Assistant(AssistantMessage),
302    #[serde(rename = "tool")]
303    Tool(ToolMessageRequest),
304}
305
306impl MessageRequest {
307    /// Creates a new `MessageRequest` instance for a user message.
308    ///
309    /// # Arguments
310    ///
311    /// * `content` - The content of the user message.
312    /// * `name` - An optional name for the user message.
313    pub fn user(content: &str) -> Self {
314        MessageRequest::User(UserMessageRequest {
315            content: content.to_string(),
316            name: None,
317        })
318    }
319
320    /// Creates a new `MessageRequest` instance for a system message.
321    ///
322    /// # Arguments
323    ///
324    /// * `content` - The content of the system message.
325    pub fn sys(content: &str) -> Self {
326        MessageRequest::System(SystemMessageRequest {
327            content: content.to_string(),
328            name: None,
329        })
330    }
331    pub fn get_content(&self) -> &str {
332        match self {
333            MessageRequest::System(req) => req.content.as_str(),
334            MessageRequest::User(req) => req.content.as_str(),
335            MessageRequest::Assistant(req) => req.content.as_str(),
336            MessageRequest::Tool(req) => req.content.as_str(),
337        }
338    }
339}
340
341/// Represents a system message request.
342#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
343pub struct SystemMessageRequest {
344    pub content: String,
345    pub name: Option<String>,
346}
347
348impl SystemMessageRequest {
349    /// Creates a new `SystemMessageRequest` instance.
350    ///
351    /// # Arguments
352    ///
353    /// * `msg` - A string slice representing the message content.
354    pub fn new(msg: &str) -> Self {
355        SystemMessageRequest {
356            content: msg.to_string(),
357            name: None,
358        }
359    }
360
361    /// Creates a new `SystemMessageRequest` instance with a name.
362    ///
363    /// # Arguments
364    ///
365    /// * `name` - A string slice representing the name.
366    /// * `msg` - A string slice representing the message content.
367    pub fn new_with_name(name: &str, msg: &str) -> Self {
368        SystemMessageRequest {
369            content: msg.to_string(),
370            name: Some(name.to_string()),
371        }
372    }
373}
374
375/// Represents a user message request.
376#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
377pub struct UserMessageRequest {
378    pub content: String,
379    pub name: Option<String>,
380}
381
382impl UserMessageRequest {
383    /// Creates a new `UserMessageRequest` instance.
384    ///
385    /// # Arguments
386    ///
387    /// * `msg` - A string slice representing the message content.
388    pub fn new(msg: &str) -> Self {
389        UserMessageRequest {
390            content: msg.to_string(),
391            name: None,
392        }
393    }
394
395    /// Creates a new `UserMessageRequest` instance with a name.
396    ///
397    /// # Arguments
398    ///
399    /// * `name` - A string slice representing the name.
400    /// * `msg` - A string slice representing the message content.
401    pub fn new_with_name(name: &str, msg: &str) -> Self {
402        UserMessageRequest {
403            content: msg.to_string(),
404            name: Some(name.to_string()),
405        }
406    }
407}
408
409/// Represents a tool message request.
410#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
411pub struct ToolMessageRequest {
412    pub content: String,
413    pub tool_call_id: String,
414}
415
416impl ToolMessageRequest {
417    /// Creates a new `ToolMessageRequest` instance.
418    ///
419    /// # Arguments
420    ///
421    /// * `msg` - A string slice representing the message content.
422    /// * `tool_call_id` - A string slice representing the tool call ID.
423    pub fn new(msg: &str, tool_call_id: &str) -> Self {
424        ToolMessageRequest {
425            content: msg.to_string(),
426            tool_call_id: tool_call_id.to_string(),
427        }
428    }
429}