1use std::collections::HashMap;
6use std::sync::mpsc as std_mpsc;
7
8use crate::chat::request::{Stop, StreamOptions, is_none_or_empty_stop};
9use crate::chat::response::ChatGeneric;
10use crate::error::DeepSeekError;
11use crate::{DeepSeekClient, api_request_stream};
12use crate::{DeepSeekRequest, api_post};
13use derive_builder::Builder;
14use futures_util::StreamExt;
15use reqwest::Method;
16use reqwest_eventsource::Event;
17use serde::{Deserialize, Serialize};
18use tokio::sync::mpsc;
19
20pub type Completion = ChatGeneric<CompletionChoice>;
22
23#[derive(Clone, Debug, Serialize, Builder)]
25#[builder(
26 pattern = "owned",
27 setter(into, strip_option),
28 build_fn(validate = "Self::validate"),
29 name = "FIMCompletionRequestBuilder"
30)]
31pub struct FIMCompletionRequest {
32 #[serde(skip_serializing)]
33 pub client: DeepSeekClient,
34
35 pub model: String,
39
40 pub prompt: String,
42
43 #[builder(default)]
45 pub echo: Option<bool>,
46
47 #[builder(default)]
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub logprobs: Option<u32>,
56
57 #[builder(default)]
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub max_tokens: Option<u32>,
61
62 #[builder(default)]
65 #[serde(skip_serializing_if = "is_none_or_empty_stop")]
66 pub stop: Option<Stop>,
67
68 #[builder(default)]
71 #[serde(skip_serializing_if = "Option::is_none")]
72 pub stream: Option<bool>,
73
74 #[builder(default)]
76 #[serde(skip_serializing_if = "Option::is_none")]
77 pub stream_options: Option<StreamOptions>,
78
79 #[builder(default)]
81 #[serde(skip_serializing_if = "Option::is_none")]
82 pub suffix: Option<String>,
83
84 #[builder(default)]
92 #[serde(skip_serializing_if = "Option::is_none")]
93 pub temperature: Option<f32>,
94
95 #[builder(default)]
104 #[serde(skip_serializing_if = "Option::is_none")]
105 pub top_p: Option<f64>,
106}
107
108impl FIMCompletionRequestBuilder {
109 fn validate(&self) -> Result<(), String> {
110 if let Some(temperature) = self.temperature.flatten() {
111 if !(0.0..=2.0).contains(&temperature) {
112 return Err("temperature must be between 0 and 2".to_string());
113 }
114 }
115 if let Some(logprobs) = self.logprobs.flatten() {
116 if logprobs > 20 {
117 return Err("logprobs must be <= 20".to_string());
118 }
119 }
120
121 if let Some(top_p) = self.top_p.flatten() {
122 if !(0.0..=1.0).contains(&top_p) {
123 return Err("top_p must be between 0 and 1".to_string());
124 }
125 }
126
127 if let Some(stream) = self.stream.flatten() {
128 if !stream && self.stream_options.is_some() {
129 return Err("stream_options cannot be set when stream is false".to_string());
130 }
131 }
132 Ok(())
133 }
134}
135
136#[derive(Clone, Debug, PartialEq, Deserialize)]
138pub struct CompletionChoice {
139 pub finish_reason: FinishReason,
147 pub index: u64,
148 pub text: String,
149 #[serde(skip_serializing_if = "Option::is_none")]
150 pub logprobs: Option<Logprobs>,
151}
152
153#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
155#[serde(rename_all = "snake_case")]
156pub enum FinishReason {
157 Stop,
158 Length,
159 ContentFilter,
160 InsufficientSystemResources,
161}
162
163#[derive(Clone, Debug, PartialEq, Deserialize)]
165pub struct Logprobs {
166 pub text_offset: Vec<u64>,
167 pub token_logprobs: Vec<f64>,
168 pub tokens: Vec<String>,
169 pub top_logprobs: Option<Vec<HashMap<String, f64>>>,
170}
171#[derive(Clone, Debug, PartialEq, Deserialize)]
173pub struct CompletionChoiceStream {
174 pub finish_reason: Option<FinishReason>,
175 pub index: u64,
176 pub text: String,
177 #[serde(skip_serializing_if = "Option::is_none")]
178 pub logprobs: Option<Logprobs>,
179}
180pub type CompletionStream = ChatGeneric<CompletionChoiceStream>;
191pub type CompletionStreamItem = Result<CompletionStream, DeepSeekError>;
193pub struct CompletionStreamBlocking {
195 rx: std_mpsc::Receiver<CompletionStreamItem>,
196}
197
198impl Iterator for CompletionStreamBlocking {
199 type Item = CompletionStreamItem;
200
201 fn next(&mut self) -> Option<Self::Item> {
202 self.rx.recv().ok()
203 }
204}
205impl DeepSeekRequest for FIMCompletionRequest {
206 type Response = Completion;
207 type StreamItem = CompletionStreamItem;
208 type BlockingStream = CompletionStreamBlocking;
209
210 async fn send(self) -> Result<Self::Response, DeepSeekError> {
211 let client = self.client.clone();
212 api_post("/completions", &self, client).await
213 }
214
215 async fn stream(self) -> Result<mpsc::Receiver<Self::StreamItem>, DeepSeekError> {
216 let mut request = self;
217 request.stream = Some(true);
218
219 let client = request.client.clone();
220 let mut event_source = api_request_stream(
221 Method::POST,
222 "/completions",
223 |builder| builder.json(&request),
224 client,
225 )
226 .await?;
227
228 let (tx, rx) = mpsc::channel(32);
229
230 tokio::spawn(async move {
231 while let Some(event) = event_source.next().await {
232 match event {
233 Ok(Event::Open) => {}
234 Ok(Event::Message(message)) => {
235 if message.data == "[DONE]" {
236 break;
237 }
238 match serde_json::from_str::<CompletionStream>(&message.data) {
239 Ok(chunk) => {
240 if tx.send(Ok(chunk)).await.is_err() {
241 break;
242 }
243 }
244 Err(err) => {
245 let _ = tx
246 .send(Err(DeepSeekError::decode(err.to_string(), message.data)))
247 .await;
248 break;
249 }
250 }
251 }
252 Err(err) => {
253 let _ = tx
254 .send(Err(DeepSeekError::decode(err.to_string(), String::new())))
255 .await;
256 break;
257 }
258 }
259 }
260 });
261
262 Ok(rx)
263 }
264
265 fn stream_blocking(self) -> Result<CompletionStreamBlocking, DeepSeekError> {
266 let (tx, rx) = std_mpsc::channel();
267
268 std::thread::spawn(move || {
269 let runtime = match tokio::runtime::Builder::new_current_thread()
270 .enable_all()
271 .build()
272 {
273 Ok(runtime) => runtime,
274 Err(err) => {
275 let _ = tx.send(Err(DeepSeekError::decode(err.to_string(), String::new())));
276 return;
277 }
278 };
279
280 runtime.block_on(async move {
281 match self.stream().await {
282 Ok(mut stream_rx) => {
283 while let Some(item) = stream_rx.recv().await {
284 if tx.send(item).is_err() {
285 break;
286 }
287 }
288 }
289 Err(err) => {
290 let _ = tx.send(Err(err));
291 }
292 }
293 });
294 });
295
296 Ok(CompletionStreamBlocking { rx })
297 }
298}
299
#[cfg(test)]
mod tests {
    // Live integration tests: they call the DeepSeek beta endpoint and need the
    // `DEEPSEEK_API` environment variable.

    use super::*;
    use crate::DEFAULT_BETA_BASE_URL;

    /// Builds a client for the beta base URL.
    ///
    /// # Panics
    /// Panics if `DEEPSEEK_API` is not set.
    fn get_client() -> DeepSeekClient {
        DeepSeekClient::new(
            std::env::var("DEEPSEEK_API").expect("DEEPSEEK_API is not set"),
            DEFAULT_BETA_BASE_URL.clone(),
        )
    }

    /// Builder pre-filled with the client, the model and a small `max_tokens` budget.
    fn get_fim_builder() -> FIMCompletionRequestBuilder {
        FIMCompletionRequestBuilder::default()
            .client(get_client())
            .model("deepseek-v4-flash")
            .max_tokens(64_u32)
    }

    #[tokio::test]
    async fn test_fim_completion() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix("    return fib(a-1) + fib(a-2)")
            .build()
            .unwrap();
        let response = fim_request.send().await.unwrap();
        println!("{:#?}", response);
        assert_eq!(response.object, "text_completion");
        assert_eq!(response.model, "deepseek-v4-flash");
        assert_eq!(response.choices.len(), 1);
    }

    #[tokio::test]
    async fn test_fim_completion_stream() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix("    return fib(a-1) + fib(a-2)")
            .stream(true)
            .build()
            .unwrap();
        let mut stream = fim_request.stream().await.unwrap();
        let mut chunks = 0_usize;
        while let Some(item) = stream.recv().await {
            // Any error delivered by the stream is a test failure, not just a log line.
            let chunk = item.unwrap_or_else(|err| panic!("Stream error: {err}"));
            println!("Received chunk: {:#?}", chunk);
            chunks += 1;
        }
        assert!(chunks > 0, "stream produced no chunks");
    }

    // A plain `#[test]`: `stream_blocking` drives its own runtime on a worker thread,
    // and blocking on its iterator inside a `#[tokio::test]` would stall that runtime.
    #[test]
    fn test_fim_completion_stream_blocking() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix("    return fib(a-1) + fib(a-2)")
            .stream(true)
            .build()
            .unwrap();
        let stream = fim_request.stream_blocking().unwrap();
        let mut chunks = 0_usize;
        for item in stream {
            let chunk = item.unwrap_or_else(|err| panic!("Stream error: {err}"));
            println!("Received chunk: {:#?}", chunk);
            chunks += 1;
        }
        assert!(chunks > 0, "stream produced no chunks");
    }
}