use serde::{de::DeserializeOwned, ser::SerializeStruct, Serialize, Serializer};

use crate::{
    request::{
        FrequencyPenalty, MaxToken, MessageRequest, PresencePenalty, ResponseFormat, ResponseType,
        Stop, StreamOptions, Temperature, ToolChoice, ToolObject, TopLogprobs, TopP,
    },
    response::{
        ChatCompletion, ChatCompletionStream, ChatResponse, JSONChoiceStream, ModelType,
        TextChoiceStream,
    },
    DeepSeekClient,
};
use anyhow::{Ok, Result};

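/// Common interface implemented by the request builders in this module.
///
/// `build` assembles the serializable request body, while `is_beta` and
/// `is_stream` tell the client whether the request targets the beta API and
/// whether the response should be streamed. `do_request` hands the builder to
/// `DeepSeekClient::send_completion_request`; with the `is_sync` feature it is
/// a blocking call, otherwise it returns a future.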
pub trait RequestBuilder: Sized + Send {
    type Request: Serialize + Send;
    type Response: DeserializeOwned + Send + 'static;
    type Item: DeserializeOwned + Send + 'static;

    fn is_beta(&self) -> bool;
    fn is_stream(&self) -> bool;
    fn build(self) -> Self::Request;

    cfg_if::cfg_if! {
        if #[cfg(feature = "is_sync")] {
            fn do_request(self, client: &DeepSeekClient) -> Result<ChatResponse<Self::Response, Self::Item>> {
                client.send_completion_request(self)
            }
        } else {
            fn do_request(self, client: &DeepSeekClient) -> impl std::future::Future<Output = Result<ChatResponse<Self::Response, Self::Item>>> + Send {
                async { client.send_completion_request(self).await }
            }
        }
    }
}

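/// Request body for the chat completions endpoint, normally produced via
/// [`CompletionsRequestBuilder::build`].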
#[derive(Debug, Default, Clone)]
pub struct CompletionsRequest<'a> {
    pub messages: &'a [MessageRequest],
    pub model: ModelType,
    pub max_tokens: Option<MaxToken>,
    pub response_format: Option<ResponseFormat>,
    pub stop: Option<Stop>,
    pub stream: bool,
    pub stream_options: Option<StreamOptions>,
    pub tools: Option<&'a [ToolObject]>,
    pub tool_choice: Option<ToolChoice>,

    pub temperature: Option<Temperature>,
    pub top_p: Option<TopP>,
    pub presence_penalty: Option<PresencePenalty>,
    pub frequency_penalty: Option<FrequencyPenalty>,
    pub logprobs: Option<bool>,
    pub top_logprobs: Option<TopLogprobs>,
}

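// Hand-rolled serialization: optional fields are emitted only when set, and
// the sampling parameters (temperature, top_p, penalties, logprobs) are
// dropped entirely when the model is `DeepSeekReasoner`.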
impl Serialize for CompletionsRequest<'_> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut state = serializer.serialize_struct("CompletionsRequest", 12)?;

        state.serialize_field("messages", &self.messages)?;
        state.serialize_field("model", &self.model)?;

        if let Some(max_tokens) = &self.max_tokens {
            state.serialize_field("max_tokens", max_tokens)?;
        }
        if let Some(response_format) = &self.response_format {
            state.serialize_field("response_format", response_format)?;
        }
        if let Some(stop) = &self.stop {
            state.serialize_field("stop", stop)?;
        }
        state.serialize_field("stream", &self.stream)?;
        if let Some(stream_options) = &self.stream_options {
            state.serialize_field("stream_options", stream_options)?;
        }
        if let Some(tools) = &self.tools {
            state.serialize_field("tools", tools)?;
        }
        if let Some(tool_choice) = &self.tool_choice {
            state.serialize_field("tool_choice", tool_choice)?;
        }

        if self.model != ModelType::DeepSeekReasoner {
            if let Some(temperature) = &self.temperature {
                state.serialize_field("temperature", temperature)?;
            }
            if let Some(top_p) = &self.top_p {
                state.serialize_field("top_p", top_p)?;
            }
            if let Some(presence_penalty) = &self.presence_penalty {
                state.serialize_field("presence_penalty", presence_penalty)?;
            }
            if let Some(frequency_penalty) = &self.frequency_penalty {
                state.serialize_field("frequency_penalty", frequency_penalty)?;
            }
            if let Some(logprobs) = &self.logprobs {
                state.serialize_field("logprobs", logprobs)?;
            }
            if let Some(top_logprobs) = &self.top_logprobs {
                state.serialize_field("top_logprobs", top_logprobs)?;
            }
        }

        state.end()
    }
}

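/// Builder for [`CompletionsRequest`].
///
/// A rough usage sketch (message construction and client setup depend on the
/// rest of the crate and are elided here; with the `is_sync` feature enabled,
/// `do_request` is blocking and the `.await` is dropped):
///
/// ```ignore
/// let messages: Vec<MessageRequest> = vec![/* ... */];
/// let builder = CompletionsRequestBuilder::new(&messages)
///     .use_model(ModelType::DeepSeekChat)
///     .temperature(0.7)?
///     .max_tokens(512)?;
/// let response = builder.do_request(&client).await?;
/// ```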
#[derive(Debug, Default)]
pub struct CompletionsRequestBuilder<'a> {
    beta: bool,
    messages: &'a [MessageRequest],
    model: ModelType,

    stream: bool,
    stream_options: Option<StreamOptions>,

    max_tokens: Option<MaxToken>,
    response_format: Option<ResponseFormat>,
    stop: Option<Stop>,
    tools: Option<&'a [ToolObject]>,
    tool_choice: Option<ToolChoice>,
    temperature: Option<Temperature>,
    top_p: Option<TopP>,
    presence_penalty: Option<PresencePenalty>,
    frequency_penalty: Option<FrequencyPenalty>,
    logprobs: Option<bool>,
    top_logprobs: Option<TopLogprobs>,
}

impl<'a> CompletionsRequestBuilder<'a> {
    pub fn new(messages: &'a [MessageRequest]) -> Self {
        Self {
            messages,
            model: ModelType::DeepSeekChat,
            ..Default::default()
        }
    }

    pub fn use_model(mut self, model: ModelType) -> Self {
        self.model = model;
        self
    }

    pub fn max_tokens(mut self, value: u32) -> Result<Self> {
        self.max_tokens = Some(MaxToken::new(value)?);
        Ok(self)
    }

    pub fn use_beta(mut self, value: bool) -> Self {
        self.beta = value;
        self
    }

    pub fn stream(mut self, value: bool) -> Self {
        self.stream = value;
        self
    }

    pub fn stream_options(mut self, value: StreamOptions) -> Self {
        self.stream_options = Some(value);
        self
    }

    pub fn response_format(mut self, value: ResponseType) -> Self {
        self.response_format = Some(ResponseFormat { resp_type: value });
        self
    }

    pub fn stop(mut self, value: Stop) -> Self {
        self.stop = Some(value);
        self
    }

    pub fn tools(mut self, value: &'a [ToolObject]) -> Self {
        self.tools = Some(value);
        self
    }

    pub fn tool_choice(mut self, value: ToolChoice) -> Self {
        self.tool_choice = Some(value);
        self
    }

    pub fn temperature(mut self, value: f32) -> Result<Self> {
        self.temperature = Some(Temperature::new(value)?);
        Ok(self)
    }

    pub fn top_p(mut self, value: f32) -> Result<Self> {
        self.top_p = Some(TopP::new(value)?);
        Ok(self)
    }

    pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
        self.presence_penalty = Some(PresencePenalty::new(value)?);
        Ok(self)
    }

    pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
        self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
        Ok(self)
    }

    pub fn logprobs(mut self, value: bool) -> Self {
        self.logprobs = Some(value);
        self
    }

    pub fn top_logprobs(mut self, value: u32) -> Result<Self> {
        self.top_logprobs = Some(TopLogprobs::new(value)?);
        Ok(self)
    }
}

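// Chat completions deserialize into `ChatCompletion`, and streamed chunks into
// `ChatCompletionStream<JSONChoiceStream>`; the FMI builder below streams
// `TextChoiceStream` chunks instead.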
impl<'a> RequestBuilder for CompletionsRequestBuilder<'a> {
    type Request = CompletionsRequest<'a>;
    type Response = ChatCompletion;
    type Item = ChatCompletionStream<JSONChoiceStream>;

    fn is_beta(&self) -> bool {
        self.beta
    }

    fn is_stream(&self) -> bool {
        self.stream
    }

    fn build(self) -> CompletionsRequest<'a> {
        CompletionsRequest {
            messages: self.messages,
            model: self.model,
            max_tokens: self.max_tokens,
            response_format: self.response_format,
            stop: self.stop,
            stream: self.stream,
            stream_options: self.stream_options,
            tools: self.tools,
            tool_choice: self.tool_choice,
            temperature: self.temperature,
            top_p: self.top_p,
            presence_penalty: self.presence_penalty,
            frequency_penalty: self.frequency_penalty,
            logprobs: self.logprobs,
            top_logprobs: self.top_logprobs,
        }
    }
}

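/// Request body for the beta prompt/suffix completion endpoint (fill-in-the-middle
/// style). Unlike [`CompletionsRequest`], the derived `Serialize` is sufficient here
/// because unset optional fields can simply be skipped.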
#[derive(Debug, Default, Clone, PartialEq, Serialize)]
pub struct FMICompletionsRequest {
    pub model: ModelType,
    pub prompt: String,
    pub echo: bool,
    pub suffix: String,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<FrequencyPenalty>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<MaxToken>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<PresencePenalty>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<Temperature>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<TopP>,
}

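/// Builder for [`FMICompletionsRequest`]; always targets the beta endpoint.
///
/// A rough usage sketch (client construction elided):
///
/// ```ignore
/// let builder = FMICompletionsRequestBuilder::new("fn add(a: i32, b: i32) -> i32 {", "}")
///     .max_tokens(64)?
///     .temperature(0.7)?;
/// let completion = builder.do_request(&client).await?;
/// ```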
#[derive(Debug, Default)]
pub struct FMICompletionsRequestBuilder {
    model: ModelType,
    prompt: String,
    echo: bool,
    frequency_penalty: Option<FrequencyPenalty>,
    logprobs: Option<bool>,
    max_tokens: Option<MaxToken>,
    presence_penalty: Option<PresencePenalty>,
    stop: Option<Stop>,
    stream: bool,
    stream_options: Option<StreamOptions>,
    suffix: String,
    temperature: Option<Temperature>,
    top_p: Option<TopP>,
}

impl FMICompletionsRequestBuilder {
    pub fn new(prompt: &str, suffix: &str) -> Self {
        Self {
            model: ModelType::DeepSeekChat,
            prompt: prompt.to_string(),
            suffix: suffix.to_string(),
            echo: false,
            stream: false,
            ..Default::default()
        }
    }

    pub fn echo(mut self, value: bool) -> Self {
        self.echo = value;
        self
    }

    pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
        self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
        Ok(self)
    }

    pub fn logprobs(mut self, value: bool) -> Self {
        self.logprobs = Some(value);
        self
    }

    pub fn max_tokens(mut self, value: u32) -> Result<Self> {
        self.max_tokens = Some(MaxToken::new(value)?);
        Ok(self)
    }

    pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
        self.presence_penalty = Some(PresencePenalty::new(value)?);
        Ok(self)
    }

    pub fn stop(mut self, value: Stop) -> Self {
        self.stop = Some(value);
        self
    }

    pub fn stream(mut self, value: bool) -> Self {
        self.stream = value;
        self
    }

    pub fn stream_options(mut self, value: StreamOptions) -> Self {
        self.stream_options = Some(value);
        self
    }

    pub fn temperature(mut self, value: f32) -> Result<Self> {
        self.temperature = Some(Temperature::new(value)?);
        Ok(self)
    }

    pub fn top_p(mut self, value: f32) -> Result<Self> {
        self.top_p = Some(TopP::new(value)?);
        Ok(self)
    }
}

impl RequestBuilder for FMICompletionsRequestBuilder {
    type Request = FMICompletionsRequest;
    type Response = ChatCompletion;
    type Item = ChatCompletionStream<TextChoiceStream>;

    fn is_beta(&self) -> bool {
        true
    }

    fn is_stream(&self) -> bool {
        self.stream
    }

    fn build(self) -> FMICompletionsRequest {
        FMICompletionsRequest {
            model: self.model,
            prompt: self.prompt,
            echo: self.echo,
            frequency_penalty: self.frequency_penalty,
            logprobs: self.logprobs,
            max_tokens: self.max_tokens,
            presence_penalty: self.presence_penalty,
            stop: self.stop,
            stream: self.stream,
            stream_options: self.stream_options,
            suffix: self.suffix,
            temperature: self.temperature,
            top_p: self.top_p,
        }
    }
}