1use std::collections::HashMap;
6use std::sync::mpsc as std_mpsc;
7
8use crate::chat::request::{Stop, StreamOptions, is_none_or_empty_stop};
9use crate::chat::response::ChatGeneric;
10use crate::error::DeepSeekError;
11use crate::{DeepSeekClient, api_request_stream};
12use crate::{DeepSeekRequest, api_post};
13use derive_builder::Builder;
14use futures_util::StreamExt;
15use reqwest::Method;
16use reqwest_eventsource::Event;
17use serde::{Deserialize, Serialize};
18use tokio::sync::mpsc;
19
20pub type Completion = ChatGeneric<CompletionChoice>;
22
23#[derive(Clone, Debug, PartialEq, Serialize, Builder)]
25#[builder(
26 pattern = "owned",
27 setter(into, strip_option),
28 build_fn(validate = "Self::validate"),
29 name = "FIMCompletionRequestBuilder"
30)]
31pub struct FIMCompletionRequest {
32 #[serde(skip_serializing)]
33 pub client: DeepSeekClient,
34
35 pub model: String,
39
40 pub prompt: String,
42
43 #[builder(default)]
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub echo: Option<bool>,
47
48 #[builder(default)]
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub logprobs: Option<u32>,
57
58 #[builder(default)]
60 #[serde(skip_serializing_if = "Option::is_none")]
61 pub max_tokens: Option<u32>,
62
63 #[builder(default)]
66 #[serde(skip_serializing_if = "is_none_or_empty_stop")]
67 pub stop: Option<Stop>,
68
69 #[builder(default)]
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub stream: Option<bool>,
74
75 #[builder(default)]
77 #[serde(skip_serializing_if = "Option::is_none")]
78 pub stream_options: Option<StreamOptions>,
79
80 #[builder(default)]
82 #[serde(skip_serializing_if = "Option::is_none")]
83 pub suffix: Option<String>,
84
85 #[builder(default)]
93 #[serde(skip_serializing_if = "Option::is_none")]
94 pub temperature: Option<f64>,
95
96 #[builder(default)]
105 #[serde(skip_serializing_if = "Option::is_none")]
106 pub top_p: Option<f64>,
107}
108
109impl FIMCompletionRequestBuilder {
110 fn validate(&self) -> Result<(), String> {
111 if let Some(temperature) = self.temperature.flatten()
112 && !(0.0..=2.0).contains(&temperature) {
113 return Err("temperature must be between 0 and 2".to_string());
114 }
115 if let Some(logprobs) = self.logprobs.flatten()
116 && logprobs > 20 {
117 return Err("logprobs must be <= 20".to_string());
118 }
119
120 if let Some(top_p) = self.top_p.flatten()
121 && !(0.0..=1.0).contains(&top_p) {
122 return Err("top_p must be between 0 and 1".to_string());
123 }
124
125 if let Some(stream) = self.stream.flatten()
126 && !stream && self.stream_options.is_some() {
127 return Err("stream_options cannot be set when stream is false".to_string());
128 }
129
130 if let Some(stop) = self.stop.as_ref().and_then(|s| s.as_ref())
131 && let Stop::Many(values) = stop
132 && values.len() > 16 {
133 return Err("a maximum of 16 stop sequences are allowed".to_string());
134 }
135
136 Ok(())
137 }
138}
139
140#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
142pub struct CompletionChoice {
143 pub finish_reason: FinishReason,
151 pub index: u64,
152 pub text: String,
153 #[serde(skip_serializing_if = "Option::is_none")]
154 pub logprobs: Option<Logprobs>,
155}
156
157#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
159#[serde(rename_all = "snake_case")]
160pub enum FinishReason {
161 Stop,
162 Length,
163 ContentFilter,
164 InsufficientSystemResources,
165}
166
167#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
169pub struct Logprobs {
170 pub text_offset: Vec<u64>,
171 pub token_logprobs: Vec<f64>,
172 pub tokens: Vec<String>,
173 pub top_logprobs: Option<Vec<HashMap<String, f64>>>,
174}
175#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
177pub struct CompletionChoiceStream {
178 pub finish_reason: Option<FinishReason>,
179 pub index: u64,
180 pub text: String,
181 #[serde(skip_serializing_if = "Option::is_none")]
182 pub logprobs: Option<Logprobs>,
183}
184pub type CompletionStream = ChatGeneric<CompletionChoiceStream>;
195pub type CompletionStreamItem = Result<CompletionStream, DeepSeekError>;
197pub struct CompletionStreamBlocking {
199 rx: std_mpsc::Receiver<CompletionStreamItem>,
200}
201
202impl Iterator for CompletionStreamBlocking {
203 type Item = CompletionStreamItem;
204
205 fn next(&mut self) -> Option<Self::Item> {
206 self.rx.recv().ok()
207 }
208}
209impl DeepSeekRequest for FIMCompletionRequest {
210 type Response = Completion;
211 type StreamItem = CompletionStreamItem;
212 type BlockingStream = CompletionStreamBlocking;
213
214 async fn send(self) -> Result<Self::Response, DeepSeekError> {
215 let client = self.client.clone();
216 api_post("/completions", &self, client).await
217 }
218
219 async fn stream(self) -> Result<mpsc::Receiver<Self::StreamItem>, DeepSeekError> {
220 let mut request = self;
221 request.stream = Some(true);
222
223 let client = request.client.clone();
224 let mut event_source = api_request_stream(
225 Method::POST,
226 "/completions",
227 |builder| builder.json(&request),
228 client,
229 )
230 .await?;
231
232 let (tx, rx) = mpsc::channel(32);
233
234 tokio::spawn(async move {
235 while let Some(event) = event_source.next().await {
236 match event {
237 Ok(Event::Open) => {}
238 Ok(Event::Message(message)) => {
239 if message.data == "[DONE]" {
240 break;
241 }
242 match serde_json::from_str::<CompletionStream>(&message.data) {
243 Ok(chunk) => {
244 if tx.send(Ok(chunk)).await.is_err() {
245 break;
246 }
247 }
248 Err(err) => {
249 let _ = tx
250 .send(Err(DeepSeekError::decode(err.to_string(), message.data)))
251 .await;
252 break;
253 }
254 }
255 }
256 Err(err) => {
257 let _ = tx
258 .send(Err(DeepSeekError::decode(err.to_string(), String::new())))
259 .await;
260 break;
261 }
262 }
263 }
264 });
265
266 Ok(rx)
267 }
268
269 fn stream_blocking(self) -> Result<CompletionStreamBlocking, DeepSeekError> {
270 let (tx, rx) = std_mpsc::channel();
271
272 std::thread::spawn(move || {
273 let runtime = match tokio::runtime::Builder::new_current_thread()
274 .enable_all()
275 .build()
276 {
277 Ok(runtime) => runtime,
278 Err(err) => {
279 let _ = tx.send(Err(DeepSeekError::decode(err.to_string(), String::new())));
280 return;
281 }
282 };
283
284 runtime.block_on(async move {
285 match self.stream().await {
286 Ok(mut stream_rx) => {
287 while let Some(item) = stream_rx.recv().await {
288 if tx.send(item).is_err() {
289 break;
290 }
291 }
292 }
293 Err(err) => {
294 let _ = tx.send(Err(err));
295 }
296 }
297 });
298 });
299
300 Ok(CompletionStreamBlocking { rx })
301 }
302}
303
#[cfg(test)]
mod tests {
    //! Integration tests against the live DeepSeek beta endpoint; they require
    //! the `DEEPSEEK_API` environment variable to hold a valid API key.

    use super::*;
    use crate::DEFAULT_BETA_BASE_URL;

    /// Client pointed at the beta base URL, authenticated from `DEEPSEEK_API`.
    fn get_client() -> DeepSeekClient {
        DeepSeekClient::new(
            std::env::var("DEEPSEEK_API").expect("DEEPSEEK_API is not set"),
            DEFAULT_BETA_BASE_URL.clone(),
        )
    }

    /// Builder pre-filled with the client, model and a small token budget.
    fn get_fim_builder() -> FIMCompletionRequestBuilder {
        FIMCompletionRequestBuilder::default()
            .client(get_client())
            .model("deepseek-v4-flash")
            .max_tokens(64_u32)
    }

    #[tokio::test]
    async fn test_fim_completion() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix("    return fib(a-1) + fib(a-2)")
            .build()
            .unwrap();
        let response = fim_request.send().await.unwrap();
        println!("{:#?}", response);
        assert_eq!(response.object, "text_completion");
        assert_eq!(response.model, "deepseek-v4-flash");
        assert_eq!(response.choices.len(), 1);
    }

    #[tokio::test]
    async fn test_fim_completion_stream() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix("    return fib(a-1) + fib(a-2)")
            .stream(true)
            .build()
            .unwrap();
        let mut stream = fim_request.stream().await.unwrap();

        // Errors must fail the test rather than just being printed.
        let mut chunk_count = 0_usize;
        while let Some(item) = stream.recv().await {
            let chunk = item.expect("stream yielded an error");
            println!("Received chunk: {:#?}", chunk);
            chunk_count += 1;
        }
        assert!(chunk_count > 0, "stream produced no chunks");
    }

    // A plain `#[test]`: the blocking iterator parks the calling thread, which
    // should not be an async runtime worker.
    #[test]
    fn test_fim_completion_stream_blocking() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix("    return fib(a-1) + fib(a-2)")
            .stream(true)
            .build()
            .unwrap();
        let stream = fim_request.stream_blocking().unwrap();

        let mut chunk_count = 0_usize;
        for item in stream {
            let chunk = item.expect("blocking stream yielded an error");
            println!("Received chunk: {:#?}", chunk);
            chunk_count += 1;
        }
        assert!(chunk_count > 0, "blocking stream produced no chunks");
    }
}