1use crate::error::{ApiError, Result};
69pub use crate::raw::*;
70use eventsource_stream::Eventsource;
71use futures::Stream;
72use futures::StreamExt;
73
/// Base URL used by the `execute_*` convenience methods that do not take an explicit
/// `base_url`; `/chat/completions` is appended to it when a request is sent.
const DEFAULT_API_BASE: &str = "https://api.deepseek.com";
76
77pub struct Request {
80 raw: ChatCompletionRequest,
81}
82
83impl Request {
84 pub fn basic_query(messages: Vec<Message>) -> Self {
96 Self::builder()
97 .messages(messages)
98 .model(Model::DeepseekChat)
99 }
100
101 pub fn basic_query_reasoner(messages: Vec<Message>) -> Self {
113 Self::builder()
114 .messages(messages)
115 .model(Model::DeepseekReasoner)
116 }
117
118 pub fn builder() -> Self {
119 Self {
120 raw: ChatCompletionRequest::default(),
121 }
122 }
123
124 pub fn add_message(mut self, message: Message) -> Self {
125 self.raw.messages.push(message);
126 self
127 }
128
129 pub fn messages(mut self, messages: Vec<Message>) -> Self {
130 self.raw.messages = messages;
131 self
132 }
133
134 pub fn model(mut self, model: Model) -> Self {
135 self.raw.model = model;
136 self
137 }
138
139 pub fn response_format_type(mut self, response_format_type: ResponseFormatType) -> Self {
140 self.raw.response_format = Some(ResponseFormat {
141 r#type: response_format_type,
142 });
143 self
144 }
145
146 pub fn json(self) -> Self {
147 self.response_format_type(ResponseFormatType::JsonObject)
148 }
149
150 pub fn text(self) -> Self {
151 self.response_format_type(ResponseFormatType::Text)
152 }
153
154 pub fn frequency_penalty(mut self, penalty: f32) -> Self {
158 self.raw.frequency_penalty = Some(penalty);
159 self
160 }
161
162 pub fn presence_penalty(mut self, penalty: f32) -> Self {
166 self.raw.presence_penalty = Some(penalty);
167 self
168 }
169
170 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
172 self.raw.max_tokens = Some(max_tokens);
173 self
174 }
175
176 pub fn temperature(mut self, temperature: f32) -> Self {
180 self.raw.temperature = Some(temperature);
181 self
182 }
183
184 pub fn stop_vec(mut self, stop: Vec<String>) -> Self {
185 self.raw.stop = Some(Stop::Array(stop));
186 self
187 }
188
189 pub fn stop_str(mut self, stop: String) -> Self {
190 self.raw.stop = Some(Stop::String(stop));
191 self
192 }
193
194 pub fn top_p(mut self, top_p: f32) -> Self {
198 self.raw.top_p = Some(top_p);
199 self
200 }
201
202 pub fn add_tool(mut self, tool: Tool) -> Self {
203 if let Some(tools) = &mut self.raw.tools {
204 tools.push(tool);
205 } else {
206 self.raw.tools = Some(vec![tool]);
207 }
208 self
209 }
210
211 pub fn tool_choice_type(mut self, tool_choice: ToolChoiceType) -> Self {
212 self.raw.tool_choice = Some(ToolChoice::String(tool_choice));
213 self
214 }
215
216 pub fn tool_choice_object(mut self, tool_choice: ToolChoiceObject) -> Self {
217 self.raw.tool_choice = Some(ToolChoice::Object(tool_choice));
218 self
219 }
220
221 pub fn logprobs(mut self, top_logprobs: u32) -> Self {
223 self.raw.logprobs = Some(true);
224 self.raw.top_logprobs = Some(top_logprobs);
225 self
226 }
227
228 pub fn raw(&self) -> &ChatCompletionRequest {
229 &self.raw
230 }
231
232 pub async fn execute_client_baseurl_nostreaming(
235 self,
236 client: &reqwest::Client,
237 base_url: &str,
238 token: &str,
239 ) -> Result<ChatCompletionResponse> {
240 let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
242
243 let resp = client
244 .post(&url)
245 .bearer_auth(token)
246 .json(&self.raw)
247 .send()
248 .await?;
249
250 if !resp.status().is_success() {
251 let status = resp.status();
252 let text = resp.text().await.unwrap_or_else(|e| e.to_string());
254 return Err(ApiError::http_error(status, text));
255 }
256
257 let parsed = resp.json::<ChatCompletionResponse>().await?;
258 Ok(parsed)
259 }
260
261 pub async fn execute_client_nostreaming(
263 self,
264 client: &reqwest::Client,
265 token: &str,
266 ) -> Result<ChatCompletionResponse> {
267 self.execute_client_baseurl_nostreaming(client, DEFAULT_API_BASE, token)
268 .await
269 }
270
271 pub async fn execute_baseurl_nostreaming(
273 self,
274 base_url: &str,
275 token: &str,
276 ) -> Result<ChatCompletionResponse> {
277 let client = reqwest::Client::new();
278 self.execute_client_baseurl_nostreaming(&client, base_url, token)
279 .await
280 }
281
282 pub async fn execute_nostreaming(self, token: &str) -> Result<ChatCompletionResponse> {
284 self.execute_baseurl_nostreaming(DEFAULT_API_BASE, token)
285 .await
286 }
287
288 pub async fn execute_client_streaming_baseurl(
294 mut self,
295 client: &reqwest::Client,
296 base_url: &str,
297 token: &str,
298 ) -> Result<impl Stream<Item = std::result::Result<ChatCompletionChunk, ApiError>>> {
299 self.raw.stream = Some(true); let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
302 let response = client
303 .post(&url)
304 .bearer_auth(token)
305 .json(&self.raw)
306 .send()
307 .await?;
308
309 if !response.status().is_success() {
310 let status = response.status();
311 let error_text = response.text().await.unwrap_or_else(|e| e.to_string());
312 return Err(ApiError::http_error(status, error_text));
313 }
314
315 let event_stream = response.bytes_stream().eventsource();
317
318 let chunk_stream = event_stream.filter_map(|event_result| async move {
326 match event_result {
327 Ok(event) => {
328 if event.data == "[DONE]" {
329 None
330 } else {
331 match serde_json::from_str::<ChatCompletionChunk>(&event.data) {
332 Ok(chunk) => Some(Ok(chunk)),
333 Err(e) => Some(Err(ApiError::Json(e))),
334 }
335 }
336 }
337 Err(e) => Some(Err(ApiError::EventSource(e.to_string()))),
338 }
339 });
340
341 Ok(chunk_stream)
342 }
343
344 pub async fn execute_client_streaming(
347 self,
348 client: &reqwest::Client,
349 token: &str,
350 ) -> Result<impl Stream<Item = std::result::Result<ChatCompletionChunk, ApiError>>> {
351 self.execute_client_streaming_baseurl(client, DEFAULT_API_BASE, token)
352 .await
353 }
354
355 pub unsafe fn from_raw_unchecked(raw: ChatCompletionRequest) -> Self {
358 Self { raw }
359 }
360
361 pub unsafe fn get_raw_mut(&mut self) -> &mut ChatCompletionRequest {
364 &mut self.raw
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a user-role message carrying `text` as its content.
    fn user_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: Some(text.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_hello_world_request() {
        let request = Request::basic_query(vec![user_message("Hello, world!")]);

        assert_eq!(request.raw().messages.len(), 1);
        assert_eq!(
            request.raw().messages[0].content.as_ref().unwrap(),
            "Hello, world!"
        );
        assert!(matches!(request.raw().model, Model::DeepseekChat));
    }

    #[test]
    fn test_reasoner_query_selects_reasoner_model() {
        let request = Request::basic_query_reasoner(vec![user_message("Think.")]);

        assert_eq!(request.raw().messages.len(), 1);
        assert!(matches!(request.raw().model, Model::DeepseekReasoner));
    }

    #[test]
    fn test_add_message_appends_in_order() {
        let request = Request::builder()
            .add_message(user_message("first"))
            .add_message(user_message("second"));

        let messages = &request.raw().messages;
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content.as_ref().unwrap(), "first");
        assert_eq!(messages[1].content.as_ref().unwrap(), "second");
    }

    #[test]
    fn test_response_format_last_call_wins() {
        let request = Request::builder().json();
        assert!(matches!(
            request.raw().response_format.as_ref().map(|f| &f.r#type),
            Some(ResponseFormatType::JsonObject)
        ));

        let request = request.text();
        assert!(matches!(
            request.raw().response_format.as_ref().map(|f| &f.r#type),
            Some(ResponseFormatType::Text)
        ));
    }

    #[test]
    fn test_sampling_parameters_are_stored() {
        let request = Request::builder()
            .temperature(0.5)
            .top_p(0.25)
            .frequency_penalty(1.0)
            .presence_penalty(-1.0)
            .max_tokens(42);

        let raw = request.raw();
        assert_eq!(raw.temperature, Some(0.5));
        assert_eq!(raw.top_p, Some(0.25));
        assert_eq!(raw.frequency_penalty, Some(1.0));
        assert_eq!(raw.presence_penalty, Some(-1.0));
        assert_eq!(raw.max_tokens, Some(42));
    }

    #[test]
    fn test_stop_variants_replace_each_other() {
        let request = Request::builder().stop_str("END".to_string());
        assert!(matches!(&request.raw().stop, Some(Stop::String(s)) if s == "END"));

        let request = request.stop_vec(vec!["a".to_string(), "b".to_string()]);
        assert!(matches!(&request.raw().stop, Some(Stop::Array(v)) if v.len() == 2));
    }

    #[test]
    fn test_logprobs_enables_flag_and_count() {
        let request = Request::builder().logprobs(5);

        assert_eq!(request.raw().logprobs, Some(true));
        assert_eq!(request.raw().top_logprobs, Some(5));
    }
}