// deepseek_api/request_builder.rs

1use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize, Serializer};
2
3use crate::{
4    request::{
5        FrequencyPenalty, MaxToken, MessageRequest, PresencePenalty, ResponseFormat, ResponseType,
6        Stop, StreamOptions, Temperature, ToolChoice, ToolObject, TopLogprobs, TopP,
7    },
8    response::{
9        ChatCompletion, ChatCompletionStream, ChatResponse, JSONChoiceStream, ModelType,
10        TextChoiceStream,
11    },
12    DeepSeekClient,
13};
14use anyhow::{Ok, Result};
15
16pub trait RequestBuilder: Sized + Send {
17    type Request: Serialize + Send;
18    type Response: DeserializeOwned + Send + 'static;
19    type Item: DeserializeOwned + Send + 'static;
20
21    fn is_beta(&self) -> bool;
22    fn is_stream(&self) -> bool;
23    fn build(self) -> Self::Request;
24
25    cfg_if::cfg_if! {
26        if #[cfg(feature = "is_sync")] {
27            fn do_request(self, client: &DeepSeekClient) ->  Result<ChatResponse<Self::Response, Self::Item>>  {
28                client.send_completion_request(self)
29            }
30        } else {
31            fn do_request(self, client: &DeepSeekClient) ->  impl std::future::Future<Output = Result<ChatResponse<Self::Response, Self::Item>>> + Send {async {
32                client.send_completion_request(self).await
33            }}
34        }
35    }
36}
37
38/// Represents a request for completions.
39#[derive(Debug, Default, Clone)]
40pub struct CompletionsRequest<'a> {
41    pub messages: &'a [MessageRequest],
42    pub model: ModelType,
43    pub max_tokens: Option<MaxToken>,
44    pub response_format: Option<ResponseFormat>,
45    pub stop: Option<Stop>,
46    pub stream: bool,
47    pub stream_options: Option<StreamOptions>,
48    pub tools: Option<&'a [ToolObject]>,
49    pub tool_choice: Option<ToolChoice>,
50
51    // ignore when model is deepseek-reasoner
52    pub temperature: Option<Temperature>,
53    pub top_p: Option<TopP>,
54    pub presence_penalty: Option<PresencePenalty>,
55    pub frequency_penalty: Option<FrequencyPenalty>,
56    pub logprobs: Option<bool>,
57    pub top_logprobs: Option<TopLogprobs>,
58}
59
60impl Serialize for CompletionsRequest<'_> {
61    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
62    where
63        S: Serializer,
64    {
65        let mut state = serializer.serialize_struct("CompletionsRequest", 12)?;
66
67        state.serialize_field("messages", &self.messages)?;
68        state.serialize_field("model", &self.model)?;
69
70        if let Some(max_tokens) = &self.max_tokens {
71            state.serialize_field("max_tokens", max_tokens)?;
72        }
73        if let Some(response_format) = &self.response_format {
74            state.serialize_field("response_format", response_format)?;
75        }
76        if let Some(stop) = &self.stop {
77            state.serialize_field("stop", stop)?;
78        }
79        state.serialize_field("stream", &self.stream)?;
80        if let Some(stream_options) = &self.stream_options {
81            state.serialize_field("stream_options", stream_options)?;
82        }
83        if let Some(tools) = &self.tools {
84            state.serialize_field("tools", tools)?;
85        }
86        if let Some(tool_choice) = &self.tool_choice {
87            state.serialize_field("tool_choice", tool_choice)?;
88        }
89
90        // Skip these fields if model is DeepSeekReasoner
91        if self.model != ModelType::DeepSeekReasoner {
92            if let Some(temperature) = &self.temperature {
93                state.serialize_field("temperature", temperature)?;
94            }
95            if let Some(top_p) = &self.top_p {
96                state.serialize_field("top_p", top_p)?;
97            }
98            if let Some(presence_penalty) = &self.presence_penalty {
99                state.serialize_field("presence_penalty", presence_penalty)?;
100            }
101            if let Some(frequency_penalty) = &self.frequency_penalty {
102                state.serialize_field("frequency_penalty", frequency_penalty)?;
103            }
104            if let Some(logprobs) = &self.logprobs {
105                state.serialize_field("logprobs", logprobs)?;
106            }
107            if let Some(top_logprobs) = &self.top_logprobs {
108                state.serialize_field("top_logprobs", top_logprobs)?;
109            }
110        }
111
112        state.end()
113    }
114}
115
116#[derive(Debug, Default)]
117pub struct CompletionsRequestBuilder<'a> {
118    //todo too many colone when use this type, improve it especially for message field
119    beta: bool,
120    messages: &'a [MessageRequest],
121    model: ModelType,
122
123    stream: bool,
124    stream_options: Option<StreamOptions>,
125
126    max_tokens: Option<MaxToken>,
127    response_format: Option<ResponseFormat>,
128    stop: Option<Stop>,
129    tools: Option<&'a [ToolObject]>,
130    tool_choice: Option<ToolChoice>,
131    temperature: Option<Temperature>,
132    top_p: Option<TopP>,
133    presence_penalty: Option<PresencePenalty>,
134    frequency_penalty: Option<FrequencyPenalty>,
135    logprobs: Option<bool>,
136    top_logprobs: Option<TopLogprobs>,
137}
138
139impl<'a> CompletionsRequestBuilder<'a> {
140    pub fn new(messages: &'a [MessageRequest]) -> Self {
141        Self {
142            messages,
143            model: ModelType::DeepSeekChat,
144            ..Default::default()
145        }
146    }
147    pub fn use_model(mut self, model: ModelType) -> Self {
148        self.model = model;
149        self
150    }
151
152    pub fn max_tokens(mut self, value: u32) -> Result<Self> {
153        self.max_tokens = Some(MaxToken::new(value)?);
154        Ok(self)
155    }
156
157    pub fn use_beta(mut self, value: bool) -> Self {
158        self.beta = value;
159        self
160    }
161
162    pub fn stream(mut self, value: bool) -> Self {
163        self.stream = value;
164        self
165    }
166
167    pub fn stream_options(mut self, value: StreamOptions) -> Self {
168        self.stream_options = Some(value);
169        self
170    }
171
172    pub fn response_format(mut self, value: ResponseType) -> Self {
173        self.response_format = Some(ResponseFormat { resp_type: value });
174        self
175    }
176
177    pub fn stop(mut self, value: Stop) -> Self {
178        self.stop = Some(value);
179        self
180    }
181
182    pub fn tools(mut self, value: &'a [ToolObject]) -> Self {
183        self.tools = Some(value);
184        self
185    }
186
187    pub fn tool_choice(mut self, value: ToolChoice) -> Self {
188        self.tool_choice = Some(value);
189        self
190    }
191
192    pub fn temperature(mut self, value: f32) -> Result<Self> {
193        self.temperature = Some(Temperature::new(value)?);
194        Ok(self)
195    }
196
197    pub fn top_p(mut self, value: f32) -> Result<Self> {
198        self.top_p = Some(TopP::new(value)?);
199        Ok(self)
200    }
201
202    pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
203        self.presence_penalty = Some(PresencePenalty::new(value)?);
204        Ok(self)
205    }
206
207    pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
208        self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
209        Ok(self)
210    }
211
212    pub fn logprobs(mut self, value: bool) -> Self {
213        self.logprobs = Some(value);
214        self
215    }
216
217    pub fn top_logprobs(mut self, value: u32) -> Result<Self> {
218        self.top_logprobs = Some(TopLogprobs::new(value)?);
219        Ok(self)
220    }
221}
222
223impl<'a> RequestBuilder for CompletionsRequestBuilder<'a> {
224    type Request = CompletionsRequest<'a>;
225    type Response = ChatCompletion;
226    type Item = ChatCompletionStream<JSONChoiceStream>;
227
228    fn is_beta(&self) -> bool {
229        self.beta
230    }
231
232    fn is_stream(&self) -> bool {
233        self.stream
234    }
235
236    fn build(self) -> CompletionsRequest<'a> {
237        CompletionsRequest {
238            messages: self.messages,
239            model: self.model,
240            max_tokens: self.max_tokens,
241            response_format: self.response_format,
242            stop: self.stop,
243            stream: self.stream,
244            stream_options: self.stream_options,
245            tools: self.tools,
246            tool_choice: self.tool_choice,
247            temperature: self.temperature,
248            top_p: self.top_p,
249            presence_penalty: self.presence_penalty,
250            frequency_penalty: self.frequency_penalty,
251            logprobs: self.logprobs,
252            top_logprobs: self.top_logprobs,
253        }
254    }
255}
256
/// Request body for FIM (fill-in-the-middle) completions, produced by
/// [`FMICompletionsRequestBuilder`].
///
/// Uses the derived `Serialize`: fields are written in declaration order. `model`,
/// `prompt`, `echo`, `suffix` and `stream` are always written, and every `Option` field
/// is omitted when `None`.
#[derive(Debug, Default, Clone, PartialEq, Serialize)]
pub struct FMICompletionsRequest {
    /// Target model; `FMICompletionsRequestBuilder::new` always sets `ModelType::DeepSeekChat`.
    pub model: ModelType,
    /// Prompt text sent to the model.
    pub prompt: String,
    /// Echo flag, forwarded as-is.
    pub echo: bool,
    /// Suffix text for the fill-in-the-middle request (serialized even when empty).
    pub suffix: String,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<FrequencyPenalty>,
    // NOTE(review): the DeepSeek FIM reference appears to document `logprobs` as an integer
    // (number of top tokens), not a boolean — confirm before relying on this field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<MaxToken>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<PresencePenalty>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,
    /// Streaming flag; always serialized.
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<Temperature>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<TopP>,
}
284
285#[derive(Debug, Default)]
286pub struct FMICompletionsRequestBuilder {
287    model: ModelType,
288    prompt: String,
289    echo: bool,
290    frequency_penalty: Option<FrequencyPenalty>,
291    logprobs: Option<bool>,
292    max_tokens: Option<MaxToken>,
293    presence_penalty: Option<PresencePenalty>,
294    stop: Option<Stop>,
295    stream: bool,
296    stream_options: Option<StreamOptions>,
297    suffix: String,
298    temperature: Option<Temperature>,
299    top_p: Option<TopP>,
300}
301
302impl FMICompletionsRequestBuilder {
303    pub fn new(prompt: &str, suffix: &str) -> Self {
304        Self {
305            //fim only support deepseek-chat model
306            model: ModelType::DeepSeekChat,
307            prompt: prompt.to_string(),
308            suffix: suffix.to_string(),
309            echo: false,
310            stream: false,
311            ..Default::default()
312        }
313    }
314
315    pub fn echo(mut self, value: bool) -> Self {
316        self.echo = value;
317        self
318    }
319
320    pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
321        self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
322        Ok(self)
323    }
324
325    pub fn logprobs(mut self, value: bool) -> Self {
326        self.logprobs = Some(value);
327        self
328    }
329
330    pub fn max_tokens(mut self, value: u32) -> Result<Self> {
331        self.max_tokens = Some(MaxToken::new(value)?);
332        Ok(self)
333    }
334
335    pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
336        self.presence_penalty = Some(PresencePenalty::new(value)?);
337        Ok(self)
338    }
339
340    pub fn stop(mut self, value: Stop) -> Self {
341        self.stop = Some(value);
342        self
343    }
344
345    pub fn stream(mut self, value: bool) -> Self {
346        self.stream = value;
347        self
348    }
349
350    pub fn stream_options(mut self, value: StreamOptions) -> Self {
351        self.stream_options = Some(value);
352        self
353    }
354
355    pub fn temperature(mut self, value: f32) -> Result<Self> {
356        self.temperature = Some(Temperature::new(value)?);
357        Ok(self)
358    }
359
360    pub fn top_p(mut self, value: f32) -> Result<Self> {
361        self.top_p = Some(TopP::new(value)?);
362        Ok(self)
363    }
364}
365
366impl RequestBuilder for FMICompletionsRequestBuilder {
367    type Request = FMICompletionsRequest;
368    type Response = ChatCompletion;
369    type Item = ChatCompletionStream<TextChoiceStream>;
370
371    fn is_beta(&self) -> bool {
372        true
373    }
374
375    fn is_stream(&self) -> bool {
376        self.stream
377    }
378
379    fn build(self) -> FMICompletionsRequest {
380        FMICompletionsRequest {
381            model: self.model,
382            prompt: self.prompt,
383            echo: self.echo,
384            frequency_penalty: self.frequency_penalty,
385            logprobs: self.logprobs,
386            max_tokens: self.max_tokens,
387            presence_penalty: self.presence_penalty,
388            stop: self.stop,
389            stream: self.stream,
390            stream_options: self.stream_options,
391            suffix: self.suffix,
392            temperature: self.temperature,
393            top_p: self.top_p,
394        }
395    }
396}