1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use futures_util::{TryStreamExt, stream::BoxStream};
6
7use crate::errors::RequestError;
8
9#[derive(Serialize, Deserialize, Debug)]
10pub struct RequestBody {
11 pub messages: Vec<Message>,
12 pub model: String,
13 #[serde(skip_serializing_if = "Option::is_none")]
15 pub frequency_penalty: Option<f32>,
16 #[serde(skip_serializing_if = "Option::is_none")]
18 pub presence_penalty: Option<f32>,
19 #[serde(skip_serializing_if = "Option::is_none")]
21 pub max_tokens: Option<u32>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub response_format: Option<ResponseFormat>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub seed: Option<i64>,
26 #[serde(skip_serializing_if = "Option::is_none")]
28 pub n: Option<u32>,
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub stop: Option<StopKeywords>,
32 pub stream: bool,
35 pub stream_options: Option<StreamOptions>,
36 pub temperature: Option<f32>,
37 pub top_p: Option<f32>,
38 pub tools: Option<Vec<Tools>>,
39 pub tool_choice: Option<ToolChoice>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub logprobs: Option<bool>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub top_logprobs: Option<u32>,
44
45 #[serde(flatten)]
47 pub extra_body: Option<ExtraBody>,
48
49 #[serde(flatten)]
52 pub extra_body_map: Option<HashMap<String, String>>,
53}
54
55#[derive(Serialize, Deserialize, Debug)]
56#[serde(tag = "role", rename_all = "lowercase")]
57pub enum Message {
58 System {
59 content: String,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 name: Option<String>,
62 },
63 User {
64 content: String,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 name: Option<String>,
67 },
68 Assistant {
69 content: String,
70 #[serde(skip_serializing_if = "Option::is_none")]
71 name: Option<String>,
72 #[serde(skip_serializing_if = "is_false")]
74 prefix: bool,
75 #[serde(skip_serializing_if = "Option::is_none")]
80 reasoning_content: Option<String>,
81 },
82}
83
84#[derive(Serialize, Deserialize, Debug)]
85pub enum ResponseFormat {
86 JsonObject,
87 Text,
88}
89
/// `skip_serializing_if` predicate: reports `true` for an unset flag so that `false` is omitted.
fn is_false(value: &bool) -> bool {
    let flag = *value;
    !flag
}
93
94#[derive(Serialize, Deserialize, Debug)]
95#[serde(untagged)]
96pub enum StopKeywords {
97 Word(String),
98 Words(Vec<String>),
99}
100
101#[derive(Serialize, Deserialize, Debug)]
102pub struct StreamOptions {
103 pub include_usage: bool,
104}
105
106#[derive(Serialize, Deserialize, Debug)]
107pub struct Tools {
108 #[serde(rename = "type")]
109 pub type_: String,
110 pub function: Option<Vec<ToolFunction>>,
111}
112
113#[derive(Serialize, Deserialize, Debug)]
114pub struct ToolFunction {
115 name: String,
116 description: String,
117 #[serde(skip_serializing_if = "Option::is_none")]
118 strict: Option<bool>,
119}
120
121#[derive(Serialize, Deserialize, Debug)]
122pub struct ToolFunctionParameter {
123 name: String,
124 description: String,
125 required: bool,
126 parameters: String,
127}
128
129#[derive(Serialize, Deserialize, Debug)]
130pub enum ToolChoice {
131 #[serde(rename = "none")]
132 None,
133 #[serde(rename = "auto")]
134 Auto,
135 #[serde(rename = "required")]
136 Required,
137 #[serde(untagged)]
138 Specific {
139 #[serde(rename = "type")]
141 type_: ToolChoiceSpecificType,
142 function: ToolChoiceFunction,
143 },
144}
145
146#[derive(Serialize, Deserialize, Debug)]
147pub struct ToolChoiceFunction {
148 pub name: String,
149}
150
151#[derive(Serialize, Deserialize, Debug)]
152#[serde(rename_all = "lowercase")]
153pub enum ToolChoiceSpecificType {
154 Function,
155}
156
157#[derive(Serialize, Deserialize, Debug)]
158pub struct ExtraBody {
159 #[serde(skip_serializing_if = "Option::is_none")]
161 pub enable_thinking: Option<bool>,
162 #[serde(skip_serializing_if = "Option::is_none")]
164 pub thinking_budget: Option<u32>,
165 #[serde(skip_serializing_if = "Option::is_none")]
169 pub top_k: Option<u32>,
170}
171
172impl Default for RequestBody {
173 fn default() -> Self {
174 RequestBody {
175 messages: vec![],
176 model: "deepseek-chat".to_string(),
177 frequency_penalty: None,
178 presence_penalty: None,
179 max_tokens: None,
180 response_format: None,
181 seed: None,
182 n: None,
183 stop: None,
184 stream: false,
185 stream_options: None,
186 temperature: None,
187 top_p: None,
188 tools: None,
189 tool_choice: None,
190 logprobs: None,
191 top_logprobs: None,
192 extra_body: None,
193 extra_body_map: None,
194 }
195 }
196}
197
198impl RequestBody {
199 pub async fn get_response(&self, url: &str, key: &str) -> anyhow::Result<String> {
200 assert!(!self.stream);
201
202 let client = reqwest::Client::new();
203 let response = client
204 .post(url)
205 .headers({
206 let mut headers = reqwest::header::HeaderMap::new();
207 headers.insert("Content-Type", "application/json".parse().unwrap());
208 headers.insert("Accept", "application/json".parse().unwrap());
209 headers
210 })
211 .bearer_auth(key)
212 .json(self)
213 .send()
214 .await
215 .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
216
217 if response.status() != reqwest::StatusCode::OK {
218 return Err(
219 crate::errors::RequestError::ResponseStatus(response.status().as_u16()).into(),
220 );
221 }
222
223 let text = response.text().await?;
224
225 Ok(text)
226 }
227
228 pub async fn stream_response(
271 &self,
272 url: &str,
273 api_key: &str,
274 ) -> Result<BoxStream<'static, Result<String, anyhow::Error>>, anyhow::Error> {
275 assert!(
277 self.stream,
278 "RequestBody::stream_response requires `stream: true`"
279 );
280
281 let client = reqwest::Client::new();
282
283 let response = client
284 .post(url)
285 .headers({
286 let mut headers = reqwest::header::HeaderMap::new();
287 headers.insert("Content-Type", "application/json".parse().unwrap());
288 headers.insert("Accept", "application/json".parse().unwrap());
289 headers
290 })
291 .bearer_auth(api_key)
292 .json(self)
293 .send()
294 .await
295 .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
296
297 if !response.status().is_success() {
298 return Err(RequestError::ResponseStatus(response.status().as_u16()).into());
299 }
300
301 let stream = response
302 .bytes_stream()
303 .map_err(|e| RequestError::StreamError(e.to_string()).into())
304 .try_filter_map(|bytes| async move {
305 let s = std::str::from_utf8(&bytes)
306 .map_err(|e| RequestError::SseParseError(e.to_string()))?;
307 if s.starts_with("[DONE]") {
308 Ok(None)
309 } else {
310 Ok(Some(s.to_string()))
311 }
312 });
313
314 Ok(Box::pin(stream) as BoxStream<'static, _>)
315
316 }
318}
319
#[cfg(test)]
mod request_test {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use crate::chat::request::{Message, RequestBody};

    // `static`, not `const`: a `const LazyLock` is a fresh, uninitialized value at every use
    // site, so its initializer would run on each access instead of once.
    static DEEPSEEK_API_KEY: LazyLock<&str> =
        LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
    const DEEPSEEK_CHAT_URL: &str = "https://api.deepseek.com/chat/completions";
    const DEEPSEEK_MODEL: &str = "deepseek-chat";

    /// Builds the short two-message request shared by both tests (which call the live API).
    fn test_request(stream: bool) -> RequestBody {
        RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_00_basics() {
        let request = test_request(false);

        let response = request
            .get_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        println!("{}", response);

        assert!(response.to_ascii_lowercase().contains("deepseek"));
    }

    #[tokio::test]
    async fn test_01_streaming() {
        let request = test_request(true);

        let mut response = request
            .stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        let mut received = 0usize;
        while let Some(chunk) = response.next().await {
            println!("{}", chunk.unwrap());
            received += 1;
        }
        assert!(received > 0, "stream ended without yielding any payload");
    }
}