use std::collections::HashMap;
use std::sync::mpsc as std_mpsc;
use crate::chat::request::{Stop, StreamOptions, is_none_or_empty_stop};
use crate::chat::response::ChatGeneric;
use crate::error::DeepSeekError;
use crate::{DeepSeekClient, api_request_stream};
use crate::{DeepSeekRequest, api_post};
use derive_builder::Builder;
use futures_util::StreamExt;
use reqwest::Method;
use reqwest_eventsource::Event;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
/// Response type of a non-streaming request (`FIMCompletionRequest::send`): the shared
/// `ChatGeneric` envelope specialised to carry `CompletionChoice` entries.
pub type Completion = ChatGeneric<CompletionChoice>;
/// Request sent to the `/completions` endpoint by the `DeepSeekRequest` impl of this type.
///
/// It can be assembled with the generated `FIMCompletionRequestBuilder` (owned-pattern
/// setters that accept `Into` values and strip the `Option` wrapper). `build()` runs the
/// builder's `validate` hook first, which range-checks several of the optional fields.
///
/// NOTE(review): "FIM" presumably stands for fill-in-the-middle (prompt + suffix) —
/// confirm against the API documentation.
#[derive(Clone, Debug, PartialEq, Serialize, Builder)]
#[builder(
pattern = "owned",
setter(into, strip_option),
build_fn(validate = "Self::validate"),
name = "FIMCompletionRequestBuilder"
)]
pub struct FIMCompletionRequest {
/// Client handed to the transport helpers when the request is sent; never serialized.
#[serde(skip_serializing)]
pub client: DeepSeekClient,
/// Model identifier, serialized as `model`.
pub model: String,
/// Text serialized as `prompt` (presumably the text preceding the part to generate — confirm).
pub prompt: String,
/// Optional `echo` flag; omitted from the body when `None`.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub echo: Option<bool>,
/// Optional `logprobs` count; omitted when `None`. The builder rejects values above 20.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<u32>,
/// Optional `max_tokens` limit; omitted when `None`.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional stop sequence(s). Skipped whenever `is_none_or_empty_stop` (defined in
/// `crate::chat::request`) returns true — presumably for `None` or an empty value.
/// The builder rejects a `Stop::Many` holding more than 16 entries.
#[builder(default)]
#[serde(skip_serializing_if = "is_none_or_empty_stop")]
pub stop: Option<Stop>,
/// Optional `stream` flag; omitted when `None`. `stream()` overwrites it with `Some(true)`.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Optional streaming options; omitted when `None`. The builder refuses to build when
/// these are set while `stream` was explicitly set to `false`.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
/// Optional text serialized as `suffix` (presumably the text following the part to
/// generate — confirm); omitted when `None`.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
/// Optional sampling temperature; omitted when `None`. The builder requires `0.0..=2.0`.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Optional `top_p` value; omitted when `None`. The builder requires `0.0..=1.0`.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
}
impl FIMCompletionRequestBuilder {
    /// Pre-build validation hook referenced by `build_fn(validate = "Self::validate")`.
    ///
    /// Only values that were explicitly provided to the builder are inspected. The checks
    /// run in a fixed order and the first violation decides the returned message:
    ///
    /// 1. `temperature` must lie in `0.0..=2.0`.
    /// 2. `logprobs` must not exceed 20.
    /// 3. `top_p` must lie in `0.0..=1.0`.
    /// 4. `stream_options` must not be set when `stream` was explicitly set to `false`.
    /// 5. A `Stop::Many` list may hold at most 16 sequences.
    ///
    /// # Errors
    ///
    /// Returns the human-readable description of the first violated rule.
    fn validate(&self) -> Result<(), String> {
        // Each builder field is `Option<Option<T>>`: the outer layer records whether the
        // setter was called, the inner layer is the value the request will carry.
        let temperature_invalid = self
            .temperature
            .flatten()
            .is_some_and(|value| !(0.0..=2.0).contains(&value));
        if temperature_invalid {
            return Err(String::from("temperature must be between 0 and 2"));
        }

        let logprobs_invalid = self.logprobs.flatten().is_some_and(|count| count > 20);
        if logprobs_invalid {
            return Err(String::from("logprobs must be <= 20"));
        }

        let top_p_invalid = self
            .top_p
            .flatten()
            .is_some_and(|value| !(0.0..=1.0).contains(&value));
        if top_p_invalid {
            return Err(String::from("top_p must be between 0 and 1"));
        }

        // Only an explicit `stream(false)` conflicts with stream options; an unset flag
        // is accepted here.
        let streaming_disabled = self.stream.flatten() == Some(false);
        if streaming_disabled && self.stream_options.is_some() {
            return Err(String::from(
                "stream_options cannot be set when stream is false",
            ));
        }

        let too_many_stops = matches!(
            &self.stop,
            Some(Some(Stop::Many(sequences))) if sequences.len() > 16
        );
        if too_many_stops {
            return Err(String::from("a maximum of 16 stop sequences are allowed"));
        }

        Ok(())
    }
}
/// One entry of `choices` in a non-streaming `Completion` response.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct CompletionChoice {
/// Reason reported for the end of generation; always present for this type.
pub finish_reason: FinishReason,
/// Value of the `index` field (presumably the position of this choice in the list — confirm).
pub index: u64,
/// Text carried by this choice.
pub text: String,
/// Log-probability details when present; left out again when re-serializing a `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<Logprobs>,
}
/// Finish reason attached to a completion choice.
///
/// Variants are (de)serialized in snake_case, i.e. `"stop"`, `"length"`,
/// `"content_filter"` and `"insufficient_system_resources"`. Any other string fails
/// to deserialize because there is no catch-all variant.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
/// Wire value `"stop"`.
Stop,
/// Wire value `"length"`.
Length,
/// Wire value `"content_filter"`.
ContentFilter,
/// Wire value `"insufficient_system_resources"`.
InsufficientSystemResources,
}
/// Log-probability payload of a completion choice.
///
/// NOTE(review): the four fields look like parallel per-token arrays (same length, same
/// order) — confirm against the API documentation before relying on that.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Logprobs {
/// Values of `text_offset` (presumably the offset of each token within the text — confirm).
pub text_offset: Vec<u64>,
/// Values of `token_logprobs` (presumably one log probability per token — confirm).
pub token_logprobs: Vec<f64>,
/// Values of `tokens` (presumably the token strings — confirm).
pub tokens: Vec<String>,
/// Optional list of string-to-number maps (presumably the alternative tokens and their
/// log probabilities per position — confirm).
pub top_logprobs: Option<Vec<HashMap<String, f64>>>,
}
/// One entry of `choices` in a streamed chunk (`CompletionStream`).
///
/// Same fields as `CompletionChoice`, except that `finish_reason` is optional here.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct CompletionChoiceStream {
/// Finish reason, when the chunk carries one (presumably only on the final chunk — confirm).
pub finish_reason: Option<FinishReason>,
/// Value of the `index` field (presumably the position of the choice this chunk belongs to — confirm).
pub index: u64,
/// Text carried by this chunk.
pub text: String,
/// Log-probability details when present; left out again when re-serializing a `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<Logprobs>,
}
/// One decoded chunk of a streamed completion: the shared `ChatGeneric` envelope
/// specialised to carry `CompletionChoiceStream` entries.
pub type CompletionStream = ChatGeneric<CompletionChoiceStream>;
/// Item delivered to stream consumers: a decoded chunk, or the error after which the
/// producer stops sending.
pub type CompletionStreamItem = Result<CompletionStream, DeepSeekError>;
/// Blocking handle on a streamed completion, returned by `stream_blocking`.
///
/// Consumed through its `Iterator` implementation, which reads from the wrapped channel.
pub struct CompletionStreamBlocking {
// Receiving end of the std channel fed by the background thread that drives the stream.
rx: std_mpsc::Receiver<CompletionStreamItem>,
}
impl Iterator for CompletionStreamBlocking {
    type Item = CompletionStreamItem;

    /// Blocks the calling thread until the next stream item arrives.
    ///
    /// Returns `None` once the sending side of the channel has been dropped and no
    /// buffered items remain, which ends the iteration.
    fn next(&mut self) -> Option<Self::Item> {
        // `Receiver::iter` yields by blocking on `recv` and stops when the channel hangs up.
        self.rx.iter().next()
    }
}
impl DeepSeekRequest for FIMCompletionRequest {
    type Response = Completion;
    type StreamItem = CompletionStreamItem;
    type BlockingStream = CompletionStreamBlocking;

    /// Sends the request to `/completions` and returns the complete response.
    ///
    /// This method always performs a non-streaming call. If the request was built with
    /// `stream` set to `true`, the flag is switched to `false` and `stream_options` is
    /// dropped before sending, mirroring how [`stream`](Self::stream) forces the flag
    /// on. Without this the body would ask for an event stream, which cannot be decoded
    /// as a single `Completion`. Requests that did not enable streaming are sent
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `api_post` reports.
    async fn send(self) -> Result<Self::Response, DeepSeekError> {
        let mut request = self;
        if request.stream == Some(true) {
            request.stream = Some(false);
            // The builder's own validation treats stream options without streaming as
            // invalid, so they are removed together with the flag.
            request.stream_options = None;
        }
        let client = request.client.clone();
        api_post("/completions", &request, client).await
    }

    /// Starts a streaming call to `/completions` and returns the receiving end of a
    /// bounded (32 items) tokio channel that yields the decoded chunks.
    ///
    /// `stream` is forced to `Some(true)` regardless of what the request carried. A
    /// spawned task forwards events until one of the following happens, after which the
    /// channel closes:
    ///
    /// * a message whose data is exactly `[DONE]` arrives,
    /// * the event source is exhausted,
    /// * the receiver has been dropped,
    /// * a chunk fails to decode or the event source yields an error — that error is
    ///   forwarded as the last item.
    ///
    /// Must be called from within a tokio runtime because it uses `tokio::spawn`.
    ///
    /// # Errors
    ///
    /// Propagates the error of `api_request_stream` when the stream cannot be opened.
    async fn stream(self) -> Result<mpsc::Receiver<Self::StreamItem>, DeepSeekError> {
        let mut request = self;
        request.stream = Some(true);
        let client = request.client.clone();
        let mut event_source = api_request_stream(
            Method::POST,
            "/completions",
            |builder| builder.json(&request),
            client,
        )
        .await?;

        let (tx, rx) = mpsc::channel(32);
        tokio::spawn(async move {
            while let Some(event) = event_source.next().await {
                match event {
                    // Connection established; nothing to forward.
                    Ok(Event::Open) => {}
                    Ok(Event::Message(message)) => {
                        // Sentinel payload marking the regular end of the stream.
                        if message.data == "[DONE]" {
                            break;
                        }
                        match serde_json::from_str::<CompletionStream>(&message.data) {
                            Ok(chunk) => {
                                // A failed send means the receiver is gone: stop working.
                                if tx.send(Ok(chunk)).await.is_err() {
                                    break;
                                }
                            }
                            Err(err) => {
                                // Hand the raw payload to the error alongside the message.
                                let _ = tx
                                    .send(Err(DeepSeekError::decode(err.to_string(), message.data)))
                                    .await;
                                break;
                            }
                        }
                    }
                    Err(err) => {
                        // Event-source failures are reported through the same `decode`
                        // constructor, with no payload attached.
                        let _ = tx
                            .send(Err(DeepSeekError::decode(err.to_string(), String::new())))
                            .await;
                        break;
                    }
                }
            }
        });
        Ok(rx)
    }

    /// Blocking variant of [`stream`](Self::stream) for callers without an async runtime.
    ///
    /// Spawns a dedicated OS thread that builds a current-thread tokio runtime, drives
    /// `stream()` on it and relays every item into an unbounded std channel wrapped by
    /// the returned iterator. Failures (runtime construction, opening the stream) are
    /// delivered as the first and only item instead of being returned here, so this
    /// method itself always returns `Ok`.
    fn stream_blocking(self) -> Result<CompletionStreamBlocking, DeepSeekError> {
        let (tx, rx) = std_mpsc::channel();
        std::thread::spawn(move || {
            let runtime = match tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
            {
                Ok(runtime) => runtime,
                Err(err) => {
                    let _ = tx.send(Err(DeepSeekError::decode(err.to_string(), String::new())));
                    return;
                }
            };
            runtime.block_on(async move {
                match self.stream().await {
                    Ok(mut stream_rx) => {
                        while let Some(item) = stream_rx.recv().await {
                            // The iterator was dropped; stop relaying.
                            if tx.send(item).is_err() {
                                break;
                            }
                        }
                    }
                    Err(err) => {
                        let _ = tx.send(Err(err));
                    }
                }
            });
        });
        Ok(CompletionStreamBlocking { rx })
    }
}
#[cfg(test)]
mod tests {
    // Live integration tests: they call the real API and therefore need network access
    // and the `DEEPSEEK_API` environment variable holding an API key.
    use super::*;
    use crate::DEFAULT_BETA_BASE_URL;

    /// Builds a client for the beta base URL from the `DEEPSEEK_API` environment variable.
    fn get_client() -> DeepSeekClient {
        DeepSeekClient::new(
            std::env::var("DEEPSEEK_API").expect("DEEPSEEK_API is not set"),
            DEFAULT_BETA_BASE_URL.clone(),
        )
    }

    /// Builder pre-filled with the client, model and token limit shared by all tests.
    fn get_fim_builder() -> FIMCompletionRequestBuilder {
        FIMCompletionRequestBuilder::default()
            .client(get_client())
            .model("deepseek-v4-flash")
            .max_tokens(64_u32)
    }

    #[tokio::test]
    async fn test_fim_completion() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix(" return fib(a-1) + fib(a-2)")
            .build()
            .unwrap();
        let response = fim_request.send().await.unwrap();
        println!("{:#?}", response);
        assert_eq!(response.object, "text_completion");
        assert_eq!(response.model, "deepseek-v4-flash");
        assert_eq!(response.choices.len(), 1);
    }

    #[tokio::test]
    async fn test_fim_completion_stream() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix(" return fib(a-1) + fib(a-2)")
            .stream(true)
            .build()
            .unwrap();
        let mut stream = fim_request.stream().await.unwrap();
        while let Some(item) = stream.recv().await {
            match item {
                Ok(chunk) => println!("Received chunk: {:#?}", chunk),
                Err(err) => eprintln!("Stream error: {}", err),
            }
        }
    }

    #[tokio::test]
    async fn test_fim_completion_stream_blocking() {
        let fim_request = get_fim_builder()
            .prompt("def fib(a):")
            .suffix(" return fib(a-1) + fib(a-2)")
            .stream(true)
            .build()
            .unwrap();
        let stream = fim_request.stream_blocking().unwrap();
        // `CompletionStreamBlocking` is an `Iterator`, so it is consumed with `for`.
        for item in stream {
            match item {
                Ok(chunk) => println!("Received chunk: {:#?}", chunk),
                Err(err) => eprintln!("Stream error: {}", err),
            }
        }
    }
}