deepseek_api/
request.rs

1use crate::response::{
2    ChatCompletion, ChatCompletionStream, JSONChoiceStream, Message, ModelType, TextChoiceStream,
3};
4use anyhow::{anyhow, Ok, Result};
5use schemars::schema::SchemaObject;
6use serde::{de::DeserializeOwned, ser::SerializeStruct, Deserialize, Serialize, Serializer};
7
8/// Represents a frequency penalty with a value between -2 and 2.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10pub struct FrequencyPenalty(pub f32);
11
12impl FrequencyPenalty {
13    /// Creates a new `FrequencyPenalty` instance.
14    ///
15    /// # Arguments
16    ///
17    /// * `v` - A float value representing the frequency penalty.
18    ///
19    /// # Errors
20    ///
21    /// Returns an error if the value is not between -2 and 2.
22    pub fn new(v: f32) -> Result<Self> {
23        if !(-2.0..=2.0).contains(&v) {
24            return Err(anyhow!(
25                "Frequency penalty value must be between -2 and 2.".to_string()
26            ));
27        }
28        Ok(FrequencyPenalty(v))
29    }
30}
31
32impl Default for FrequencyPenalty {
33    /// Returns the default value for `FrequencyPenalty`, which is 0.0.
34    fn default() -> Self {
35        FrequencyPenalty(0.0)
36    }
37}
38
39/// Represents a presence penalty with a value between -2 and 2.
40#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
41pub struct PresencePenalty(pub f32);
42
43impl PresencePenalty {
44    /// Creates a new `PresencePenalty` instance.
45    ///
46    /// # Arguments
47    ///
48    /// * `v` - A float value representing the presence penalty.
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if the value is not between -2 and 2.
53    pub fn new(v: f32) -> Result<Self> {
54        if !(-2.0..=2.0).contains(&v) {
55            return Err(anyhow!(
56                "Presence penalty value must be between -2 and 2.".to_string()
57            ));
58        }
59        Ok(PresencePenalty(v))
60    }
61}
62
63impl Default for PresencePenalty {
64    /// Returns the default value for `PresencePenalty`, which is 0.0.
65    fn default() -> Self {
66        PresencePenalty(0.0)
67    }
68}
69
/// Output format requested from the model; used as the `type` of a
/// [`ResponseFormat`].
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub enum ResponseType {
    /// Serialized as `"json_object"`.
    #[serde(rename = "json_object")]
    Json,
    /// Serialized as `"text"`.
    #[serde(rename = "text")]
    Text,
}
78
79/// Represents the format of the response.
80#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
81pub struct ResponseFormat {
82    #[serde(rename = "type")]
83    pub resp_type: ResponseType,
84}
85
86impl ResponseFormat {
87    /// Creates a new `ResponseFormat` instance.
88    ///
89    /// # Arguments
90    ///
91    /// * `rt` - The type of response.
92    pub fn new(rt: ResponseType) -> Self {
93        ResponseFormat { resp_type: rt }
94    }
95}
96
97/// Represents the maximum number of tokens with a value between 1 and 8192.
98#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
99pub struct MaxToken(pub u32);
100
101impl MaxToken {
102    /// Creates a new `MaxToken` instance.
103    ///
104    /// # Arguments
105    ///
106    /// * `v` - An unsigned integer representing the maximum number of tokens.
107    ///
108    /// # Errors
109    ///
110    /// Returns an error if the value is not between 1 and 8192.
111    pub fn new(v: u32) -> Result<Self> {
112        if !(1..=8192).contains(&v) {
113            return Err(anyhow!("Max token must be between 1 and 8192.".to_string()));
114        }
115        Ok(MaxToken(v))
116    }
117}
118
119impl Default for MaxToken {
120    /// Returns the default value for `MaxToken`, which is 4096.
121    fn default() -> Self {
122        MaxToken(4096)
123    }
124}
125
126/// Represents the stopping criteria for the completion.
127#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
128pub enum Stop {
129    Single(String),
130    Multiple(Vec<String>),
131}
132
133/// Represents the options for streaming responses.
134#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
135pub struct StreamOptions {
136    pub include_usage: bool,
137}
138
139impl StreamOptions {
140    /// Creates a new `StreamOptions` instance.
141    ///
142    /// # Arguments
143    ///
144    /// * `include_usage` - A boolean indicating whether to include usage information.
145    pub fn new(include_usage: bool) -> Self {
146        StreamOptions { include_usage }
147    }
148}
149
150/// Represents the temperature with a value between 0 and 2.
151#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
152pub struct Temperature(pub u32);
153
154impl Temperature {
155    /// Creates a new `Temperature` instance.
156    ///
157    /// # Arguments
158    ///
159    /// * `v` - An unsigned integer representing the temperature.
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if the value is not between 0 and 2.
164    pub fn new(v: u32) -> Result<Self> {
165        if v > 2 {
166            return Err(anyhow!("Temperature must be between 0 and 2.".to_string()));
167        }
168        Ok(Temperature(v))
169    }
170}
171
172impl Default for Temperature {
173    /// Returns the default value for `Temperature`, which is 1.
174    fn default() -> Self {
175        Temperature(1)
176    }
177}
178
179/// Represents the top-p value with a value between 0.0 and 1.0.
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
181pub struct TopP(pub f32);
182
183impl TopP {
184    /// Creates a new `TopP` instance.
185    ///
186    /// # Arguments
187    ///
188    /// * `v` - A float value representing the top-p value.
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if the value is not between 0.0 and 1.0.
193    pub fn new(v: f32) -> Result<Self> {
194        if !(0.0..=1.0).contains(&v) {
195            return Err(anyhow!("TopP value must be between 0and 2.".to_string()));
196        }
197        Ok(TopP(v))
198    }
199}
200
201impl Default for TopP {
202    /// Returns the default value for `TopP`, which is 1.0.
203    fn default() -> Self {
204        TopP(1.0)
205    }
206}
207
208/// Represents the type of tool.
209#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
210pub enum ToolType {
211    #[serde(rename = "function")]
212    Function,
213}
214
/// Description of a function tool: its name, a description, and a schema for
/// its parameters.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Function {
    /// Text describing the function.
    pub description: String,
    /// Name of the function.
    pub name: String,
    /// Schema object (from `schemars`) describing the function's parameters.
    pub parameters: SchemaObject,
}
222
/// A tool entry: a [`ToolType`] together with the [`Function`] it describes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolObject {
    /// Kind of tool; serialized under the key `type`.
    #[serde(rename = "type")]
    pub tool_type: ToolType,
    /// The function definition.
    pub function: Function,
}
230
231/// Represents the choice of chat completion tool.
232#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
233pub enum ChatCompletionToolChoice {
234    #[serde(rename = "none")]
235    None,
236    #[serde(rename = "auto")]
237    Auto,
238    #[serde(rename = "required")]
239    Required,
240}
241
/// Identifies a function by name inside a [`ChatCompletionNamedToolChoice`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionChoice {
    /// Name of the function.
    pub name: String,
}
247
/// Tool choice that names one specific function.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatCompletionNamedToolChoice {
    /// Kind of tool; serialized under the key `type`.
    #[serde(rename = "type")]
    pub tool_type: ToolType,
    /// The function being selected.
    pub function: FunctionChoice,
}
255
256/// Represents the choice of tool.
257#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
258pub enum ToolChoice {
259    ChatCompletion(ChatCompletionToolChoice),
260    ChatCompletionNamed(ChatCompletionNamedToolChoice),
261}
262
263/// Represents the top log probabilities with a value between 0 and 20.
264#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
265pub struct TopLogprobs(pub u32);
266
267impl TopLogprobs {
268    /// Creates a new `TopLogprobs` instance.
269    ///
270    /// # Arguments
271    ///
272    /// * `v` - An unsigned integer representing the top log probabilities.
273    ///
274    /// # Errors
275    ///
276    /// Returns an error if the value is not between 0 and 20.
277    pub fn new(v: u32) -> Result<Self> {
278        if v > 20 {
279            return Err(anyhow!(
280                "Top log probs must be between 0 and 20.".to_string()
281            ));
282        }
283        Ok(TopLogprobs(v))
284    }
285}
286
287impl Default for TopLogprobs {
288    /// Returns the default value for `TopLogprobs`, which is 0.
289    fn default() -> Self {
290        TopLogprobs(0)
291    }
292}
293
294/// Represents a message request with different roles.
295#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
296#[serde(tag = "role")]
297pub enum MessageRequest {
298    #[serde(rename = "system")]
299    System(SystemMessageRequest),
300    #[serde(rename = "user")]
301    User(UserMessageRequest),
302    #[serde(rename = "assistant")]
303    Assistant(AssistantMessageRequest),
304    #[serde(rename = "tool")]
305    Tool(ToolMessageRequest),
306}
307
308impl MessageRequest {
309    /// Creates a `MessageRequest` instance from a `Message`.
310    ///
311    /// # Arguments
312    ///
313    /// * `resp_message` - A reference to a `Message`.
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if the message role is invalid.
318    pub fn from_message(resp_message: &Message) -> Result<Self> {
319        match resp_message.role.as_str() {
320            "system" => Ok(MessageRequest::System(SystemMessageRequest {
321                content: resp_message.content.clone(),
322                name: None,
323            })),
324            "user" => Ok(MessageRequest::User(UserMessageRequest {
325                content: resp_message.content.clone(),
326                name: None,
327            })),
328            "assistant" => {
329                let request = match resp_message.reasoning_content.clone() {
330                    Some(reasoning_content) => {
331                        AssistantMessageRequest::new(resp_message.content.as_str())
332                            .set_reasoning_content(reasoning_content.as_str())
333                    }
334                    None => AssistantMessageRequest::new(resp_message.content.as_str()),
335                };
336                Ok(MessageRequest::Assistant(request))
337            }
338            "tool" => Ok(MessageRequest::Tool(ToolMessageRequest {
339                content: resp_message.content.clone(),
340                tool_call_id: "".to_string(), //todo how to get tool_call_id ?
341            })),
342            _ => Err(anyhow!("Invalid message role.".to_string())),
343        }
344    }
345
346    pub fn get_content(&self) -> String {
347        match self {
348            MessageRequest::System(req) => req.content.clone(),
349            MessageRequest::User(req) => req.content.clone(),
350            MessageRequest::Assistant(req) => req.content.clone(),
351            MessageRequest::Tool(req) => req.content.clone(),
352        }
353    }
354}
355
356/// Represents a system message request.
357#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
358pub struct SystemMessageRequest {
359    pub content: String,
360    pub name: Option<String>,
361}
362
363impl SystemMessageRequest {
364    /// Creates a new `SystemMessageRequest` instance.
365    ///
366    /// # Arguments
367    ///
368    /// * `msg` - A string slice representing the message content.
369    pub fn new(msg: &str) -> Self {
370        SystemMessageRequest {
371            content: msg.to_string(),
372            name: None,
373        }
374    }
375
376    /// Creates a new `SystemMessageRequest` instance with a name.
377    ///
378    /// # Arguments
379    ///
380    /// * `name` - A string slice representing the name.
381    /// * `msg` - A string slice representing the message content.
382    pub fn new_with_name(name: &str, msg: &str) -> Self {
383        SystemMessageRequest {
384            content: msg.to_string(),
385            name: Some(name.to_string()),
386        }
387    }
388}
389
390/// Represents a user message request.
391#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
392pub struct UserMessageRequest {
393    pub content: String,
394    pub name: Option<String>,
395}
396
397impl UserMessageRequest {
398    /// Creates a new `UserMessageRequest` instance.
399    ///
400    /// # Arguments
401    ///
402    /// * `msg` - A string slice representing the message content.
403    pub fn new(msg: &str) -> Self {
404        UserMessageRequest {
405            content: msg.to_string(),
406            name: None,
407        }
408    }
409
410    /// Creates a new `UserMessageRequest` instance with a name.
411    ///
412    /// # Arguments
413    ///
414    /// * `name` - A string slice representing the name.
415    /// * `msg` - A string slice representing the message content.
416    pub fn new_with_name(name: &str, msg: &str) -> Self {
417        UserMessageRequest {
418            content: msg.to_string(),
419            name: Some(name.to_string()),
420        }
421    }
422}
423
424/// Represents an assistant message request.
425#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
426pub struct AssistantMessageRequest {
427    pub content: String,
428    pub name: Option<String>,
429    pub prefix: bool,
430    pub reasoning_content: Option<String>,
431}
432
433impl AssistantMessageRequest {
434    /// Creates a new `AssistantMessageRequest` instance.
435    ///
436    /// # Arguments
437    ///
438    /// * `msg` - A string slice representing the message content.
439    pub fn new(msg: &str) -> Self {
440        AssistantMessageRequest {
441            content: msg.to_string(),
442            name: None,
443            prefix: false,
444            reasoning_content: None,
445        }
446    }
447
448    /// Creates a new `AssistantMessageRequest` instance with a name.
449    ///
450    /// # Arguments
451    ///
452    /// * `name` - A string slice representing the name.
453    /// * `msg` - A string slice representing the message content.
454    pub fn new_with_name(name: &str, msg: &str) -> Self {
455        AssistantMessageRequest {
456            content: msg.to_string(),
457            name: Some(name.to_string()),
458            prefix: false,
459            reasoning_content: None,
460        }
461    }
462
463    /// Sets the reasoning content for the `AssistantMessageRequest`.
464    ///
465    /// # Arguments
466    ///
467    /// * `content` - A string slice representing the reasoning content.
468    ///
469    /// # Returns
470    ///
471    /// Returns the updated `AssistantMessageRequest` instance.
472    pub fn set_reasoning_content(mut self, content: &str) -> Self {
473        self.prefix = true;
474        self.reasoning_content = Some(content.to_string());
475        self
476    }
477
478    pub fn set_prefix(mut self, content: &str) -> Self {
479        self.prefix = true;
480        self.content = content.to_string();
481        self
482    }
483}
484
485/// Represents a tool message request.
486#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
487pub struct ToolMessageRequest {
488    pub content: String,
489    pub tool_call_id: String,
490}
491
492impl ToolMessageRequest {
493    /// Creates a new `ToolMessageRequest` instance.
494    ///
495    /// # Arguments
496    ///
497    /// * `msg` - A string slice representing the message content.
498    /// * `tool_call_id` - A string slice representing the tool call ID.
499    pub fn new(msg: &str, tool_call_id: &str) -> Self {
500        ToolMessageRequest {
501            content: msg.to_string(),
502            tool_call_id: tool_call_id.to_string(),
503        }
504    }
505}
506
/// Common interface of the request builders defined in this file.
pub trait RequestBuilder {
    /// Serializable request value produced by [`RequestBuilder::build`].
    type Request: Serialize;
    /// Deserializable type paired with the request — presumably the complete
    /// (non-streamed) response body; confirm against the caller.
    type Response: DeserializeOwned;
    /// Deserializable, `Send + 'static` type paired with the request —
    /// presumably one streamed chunk; confirm against the caller.
    type Item: DeserializeOwned + Send + 'static;

    /// Returns the builder's beta flag; how it is consumed is not visible here.
    fn is_beta(&self) -> bool;
    /// Returns the builder's stream flag.
    fn is_stream(&self) -> bool;
    /// Consumes the builder and returns the request value.
    fn build(self) -> Self::Request;
}
516
/// Request body built by [`CompletionsRequestBuilder`].
///
/// Only `Deserialize` is derived here; `Serialize` is implemented by hand for
/// this type, so the `skip_serializing_if` attributes below are not consulted
/// by a derive.
#[derive(Debug, Default, Clone, PartialEq, Deserialize)]
pub struct CompletionsRequest {
    /// Conversation messages.
    pub messages: Vec<MessageRequest>,
    /// Model to use.
    pub model: ModelType,
    /// Prompt text.
    pub prompt: String,
    /// Optional limit on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<MaxToken>,
    /// Optional response format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Optional stop sequence(s).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,
    /// Stream flag.
    pub stream: bool,
    /// Optional streaming options.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Optional list of tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolObject>>,
    /// Optional tool choice.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,

    // The fields below are not serialized when the model is
    // `ModelType::DeepSeekReasoner`.
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<Temperature>,
    /// Optional top-p value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<TopP>,
    /// Optional presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<PresencePenalty>,
    /// Optional frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<FrequencyPenalty>,
    /// Optional log-probabilities flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    /// Optional number of top log probabilities.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<TopLogprobs>,
}
551
552impl Serialize for CompletionsRequest {
553    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
554    where
555        S: Serializer,
556    {
557        let mut state = serializer.serialize_struct("CompletionsRequest", 12)?;
558
559        state.serialize_field("messages", &self.messages)?;
560        state.serialize_field("model", &self.model)?;
561        state.serialize_field("max_tokens", &self.max_tokens)?;
562        state.serialize_field("response_format", &self.response_format)?;
563        state.serialize_field("stop", &self.stop)?;
564        state.serialize_field("stream", &self.stream)?;
565        state.serialize_field("stream_options", &self.stream_options)?;
566        state.serialize_field("tools", &self.tools)?;
567        state.serialize_field("tool_choice", &self.tool_choice)?;
568        state.serialize_field("prompt", &self.prompt)?;
569
570        // Skip these fields if model is DeepSeekReasoner
571        if self.model != ModelType::DeepSeekReasoner {
572            state.serialize_field("temperature", &self.temperature)?;
573            state.serialize_field("top_p", &self.top_p)?;
574            state.serialize_field("presence_penalty", &self.presence_penalty)?;
575            state.serialize_field("frequency_penalty", &self.frequency_penalty)?;
576            state.serialize_field("logprobs", &self.logprobs)?;
577            state.serialize_field("top_logprobs", &self.top_logprobs)?;
578        }
579
580        state.end()
581    }
582}
583
/// Builder that accumulates the settings of a [`CompletionsRequest`].
#[derive(Debug, Default)]
pub struct CompletionsRequestBuilder {
    // Beta flag; reported by `is_beta` and not copied into the request.
    beta: bool,
    // Conversation messages collected so far.
    messages: Vec<MessageRequest>,
    model: ModelType,

    stream: bool,
    stream_options: Option<StreamOptions>,

    max_tokens: Option<MaxToken>,
    response_format: Option<ResponseFormat>,
    stop: Option<Stop>,
    tools: Option<Vec<ToolObject>>,
    tool_choice: Option<ToolChoice>,
    prompt: String,
    temperature: Option<Temperature>,
    top_p: Option<TopP>,
    presence_penalty: Option<PresencePenalty>,
    frequency_penalty: Option<FrequencyPenalty>,
    logprobs: Option<bool>,
    top_logprobs: Option<TopLogprobs>,
}
606
607impl CompletionsRequestBuilder {
608    pub fn new(messages: Vec<MessageRequest>) -> Self {
609        Self {
610            messages,
611            model: ModelType::DeepSeekChat,
612            prompt: String::new(),
613            ..Default::default()
614        }
615    }
616    pub fn use_model(mut self, model: ModelType) -> Self {
617        self.model = model;
618        self
619    }
620
621    //https://api-docs.deepseek.com/guides/fim_completion
622    pub fn append_fim_message(self, _prompt: &str, _suffix: &str) -> Self {
623        todo!("Not enough detail in document")
624    }
625
626    // https://api-docs.deepseek.com/zh-cn/guides/chat_prefix_completion
627    pub fn append_prefix_message(mut self, msg: &str) -> Self {
628        self.messages.push(MessageRequest::Assistant(
629            AssistantMessageRequest::new(msg).set_prefix(msg),
630        ));
631        self
632    }
633
634    pub fn append_user_message(mut self, msg: &str) -> Self {
635        self.messages
636            .push(MessageRequest::User(UserMessageRequest::new(msg)));
637        self
638    }
639
640    pub fn max_tokens(mut self, value: u32) -> Result<Self> {
641        self.max_tokens = Some(MaxToken::new(value)?);
642        Ok(self)
643    }
644
645    pub fn use_beta(mut self, value: bool) -> Self {
646        self.beta = value;
647        self
648    }
649
650    pub fn stream(mut self, value: bool) -> Self {
651        self.stream = value;
652        self
653    }
654
655    pub fn stream_options(mut self, value: StreamOptions) -> Self {
656        self.stream_options = Some(value);
657        self
658    }
659
660    pub fn response_format(mut self, value: ResponseType) -> Self {
661        self.response_format = Some(ResponseFormat { resp_type: value });
662        self
663    }
664
665    pub fn stop(mut self, value: Stop) -> Self {
666        self.stop = Some(value);
667        self
668    }
669
670    pub fn tools(mut self, value: Vec<ToolObject>) -> Self {
671        self.tools = Some(value);
672        self
673    }
674
675    pub fn tool_choice(mut self, value: ToolChoice) -> Self {
676        self.tool_choice = Some(value);
677        self
678    }
679
680    pub fn prompt(mut self, value: String) -> Self {
681        self.prompt = value;
682        self
683    }
684
685    pub fn temperature(mut self, value: u32) -> Result<Self> {
686        self.temperature = Some(Temperature::new(value)?);
687        Ok(self)
688    }
689
690    pub fn top_p(mut self, value: f32) -> Result<Self> {
691        self.top_p = Some(TopP::new(value)?);
692        Ok(self)
693    }
694
695    pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
696        self.presence_penalty = Some(PresencePenalty::new(value)?);
697        Ok(self)
698    }
699
700    pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
701        self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
702        Ok(self)
703    }
704
705    pub fn logprobs(mut self, value: bool) -> Self {
706        self.logprobs = Some(value);
707        self
708    }
709
710    pub fn top_logprobs(mut self, value: u32) -> Result<Self> {
711        self.top_logprobs = Some(TopLogprobs::new(value)?);
712        Ok(self)
713    }
714}
715
716impl RequestBuilder for CompletionsRequestBuilder {
717    type Request = CompletionsRequest;
718    type Response = ChatCompletion;
719    type Item = ChatCompletionStream<JSONChoiceStream>;
720
721    fn is_beta(&self) -> bool {
722        self.beta
723    }
724
725    fn is_stream(&self) -> bool {
726        self.stream
727    }
728
729    fn build(self) -> CompletionsRequest {
730        CompletionsRequest {
731            messages: self.messages,
732            model: self.model,
733            max_tokens: self.max_tokens,
734            response_format: self.response_format,
735            stop: self.stop,
736            stream: self.stream,
737            stream_options: self.stream_options,
738            tools: self.tools,
739            tool_choice: self.tool_choice,
740            prompt: self.prompt,
741            temperature: self.temperature,
742            top_p: self.top_p,
743            presence_penalty: self.presence_penalty,
744            frequency_penalty: self.frequency_penalty,
745            logprobs: self.logprobs,
746            top_logprobs: self.top_logprobs,
747        }
748    }
749}
750
/// Request body built by [`FMICompletionsRequestBuilder`]: a completion
/// described by a `prompt` and a `suffix`.
///
/// NOTE(review): presumably the fill-in-the-middle ("FIM") completion request,
/// given the prompt/suffix pair — confirm; the type name spells it "FMI".
///
/// `Option` fields are omitted from the serialized output when `None`; all
/// other fields are always written.
#[derive(Debug, Default, Clone, PartialEq, Serialize)]
pub struct FMICompletionsRequest {
    /// Model to use.
    pub model: ModelType,
    /// Prompt text.
    pub prompt: String,
    /// Echo flag.
    pub echo: bool,

    /// Optional frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<FrequencyPenalty>,
    /// Optional log-probabilities flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    /// Optional limit on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<MaxToken>,
    /// Optional presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<PresencePenalty>,
    /// Optional stop sequence(s).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,
    /// Stream flag.
    pub stream: bool,
    /// Optional streaming options.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Suffix text.
    pub suffix: String,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<Temperature>,
    /// Optional top-p value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<TopP>,
}
777
/// Builder that accumulates the settings of an [`FMICompletionsRequest`].
///
/// Its fields mirror the request's fields one-to-one.
#[derive(Debug, Default)]
pub struct FMICompletionsRequestBuilder {
    model: ModelType,
    prompt: String,
    echo: bool,
    frequency_penalty: Option<FrequencyPenalty>,
    logprobs: Option<bool>,
    max_tokens: Option<MaxToken>,
    presence_penalty: Option<PresencePenalty>,
    stop: Option<Stop>,
    stream: bool,
    stream_options: Option<StreamOptions>,
    suffix: String,
    temperature: Option<Temperature>,
    top_p: Option<TopP>,
}
794
795impl FMICompletionsRequestBuilder {
796    pub fn new(prompt: &str, suffix: &str) -> Self {
797        Self {
798            model: ModelType::DeepSeekChat,
799            prompt: prompt.to_string(),
800            suffix: suffix.to_string(),
801            echo: false,
802            stream: false,
803            ..Default::default()
804        }
805    }
806
807    pub fn use_model(mut self, model: ModelType) -> Self {
808        self.model = model;
809        self
810    }
811
812    pub fn echo(mut self, value: bool) -> Self {
813        self.echo = value;
814        self
815    }
816
817    pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
818        self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
819        Ok(self)
820    }
821
822    pub fn logprobs(mut self, value: bool) -> Self {
823        self.logprobs = Some(value);
824        self
825    }
826
827    pub fn max_tokens(mut self, value: u32) -> Result<Self> {
828        self.max_tokens = Some(MaxToken::new(value)?);
829        Ok(self)
830    }
831
832    pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
833        self.presence_penalty = Some(PresencePenalty::new(value)?);
834        Ok(self)
835    }
836
837    pub fn stop(mut self, value: Stop) -> Self {
838        self.stop = Some(value);
839        self
840    }
841
842    pub fn stream(mut self, value: bool) -> Self {
843        self.stream = value;
844        self
845    }
846
847    pub fn stream_options(mut self, value: StreamOptions) -> Self {
848        self.stream_options = Some(value);
849        self
850    }
851
852    pub fn temperature(mut self, value: u32) -> Result<Self> {
853        self.temperature = Some(Temperature::new(value)?);
854        Ok(self)
855    }
856
857    pub fn top_p(mut self, value: f32) -> Result<Self> {
858        self.top_p = Some(TopP::new(value)?);
859        Ok(self)
860    }
861}
862
863impl RequestBuilder for FMICompletionsRequestBuilder {
864    type Request = FMICompletionsRequest;
865    type Response = ChatCompletion;
866    type Item = ChatCompletionStream<TextChoiceStream>;
867
868    fn is_beta(&self) -> bool {
869        true
870    }
871
872    fn is_stream(&self) -> bool {
873        self.stream
874    }
875
876    fn build(self) -> FMICompletionsRequest {
877        FMICompletionsRequest {
878            model: self.model,
879            prompt: self.prompt,
880            echo: self.echo,
881            frequency_penalty: self.frequency_penalty,
882            logprobs: self.logprobs,
883            max_tokens: self.max_tokens,
884            presence_penalty: self.presence_penalty,
885            stop: self.stop,
886            stream: self.stream,
887            stream_options: self.stream_options,
888            suffix: self.suffix,
889            temperature: self.temperature,
890            top_p: self.top_p,
891        }
892    }
893}