Skip to main content

ds_api/
request.rs

1//! 高级请求构建器模块
2//!
3//! 提供类型安全的 API 请求构建器,简化与 DeepSeek API 的交互。
4//!
5//! # 主要类型
6//!
7//! - [`Request`]: 主要的请求构建器,提供流畅的 API 来构建聊天补全请求
8//!
9//! # 示例
10//!
11//! ## 基本使用
12//!
13//! ```rust
14//! use ds_api::{Request, Message, Role};
15//!
16//! let request = Request::basic_query(vec![
17//!     Message::new(Role::User, "Hello, world!")
18//! ]);
19//! ```
20//!
21//! ## 使用构建器模式
22//!
23//! ```rust
24//! use ds_api::{Request, Message, Role};
25//!
26//! let request = Request::builder()
27//!     .add_message(Message::new(Role::System, "You are a helpful assistant."))
28//!     .add_message(Message::new(Role::User, "What is Rust?"))
29//!     .temperature(0.7)
30//!     .max_tokens(100);
31//! ```
32//!
33//! ## 流式响应
34//!
35//! ```rust,no_run
36//! use ds_api::{Request, Message, Role};
37//! use futures::StreamExt;
38//!
39//! # #[tokio::main]
40//! # async fn main() -> ds_api::error::Result<()> {
41//! let token = "your_token".to_string();
42//! let client = reqwest::Client::new();
43//!
44//! let request = Request::basic_query(vec![
45//!     Message::new(Role::User, "Tell me a story.")
46//! ]);
47//!
48//! let stream = request.execute_client_streaming(&client, &token).await?;
49//!
50//! // 使用 pin_mut! 宏来固定流
51//! use futures::pin_mut;
52//! pin_mut!(stream);
53//!
54//! while let Some(chunk_result) = stream.next().await {
55//!     match chunk_result {
56//!         Ok(chunk) => {
57//!             if let Some(content) = chunk.choices[0].delta.content.as_ref() {
58//!                 print!("{}", content);
59//!             }
60//!         }
61//!         Err(e) => eprintln!("Error: {}", e),
62//!     }
63//! }
64//! # Ok(())
65//! # }
66//! ```
67
68use crate::error::{ApiError, Result};
69pub use crate::raw::*;
70use eventsource_stream::Eventsource;
71use futures::Stream;
72use futures::StreamExt;
73
74/// 默认 API base url(不包含版本前缀,路径将使用 /chat/completions)
75const DEFAULT_API_BASE: &str = "https://api.deepseek.com";
76
77/// 一个发送至 Deepseek API 的请求对象,封装了原始请求数据。
78/// 该结构体保证请求合法
79pub struct Request {
80    raw: ChatCompletionRequest,
81}
82
83impl Request {
84    /// 创建一个基本的聊天请求,使用 DeepseekChat 模型。
85    /// 参数 `messages` 是一个消息列表,表示对话的上下文。
86    /// example:
87    /// ```
88    /// use ds_api::request::message::Role;
89    /// use ds_api::request::Message;
90    /// use ds_api::request::Request;
91    /// let request = Request::basic_query(vec![
92    ///    Message::new(Role::User, "What is the capital of France?")
93    /// ]);
94    /// ```
95    pub fn basic_query(messages: Vec<Message>) -> Self {
96        Self::builder()
97            .messages(messages)
98            .model(Model::DeepseekChat)
99    }
100
101    /// 创建一个基本的聊天请求,使用 DeepseekReasoner 模型。
102    /// 参数 `messages` 是一个消息列表,表示对话的上下文。
103    /// example:
104    /// ```
105    /// use ds_api::request::message::Role;
106    /// use ds_api::request::Message;
107    /// use ds_api::request::Request;
108    /// let request = Request::basic_query_reasoner(vec![
109    ///    Message::new(Role::User, "What is the capital of France?")
110    /// ]);
111    /// ```
112    pub fn basic_query_reasoner(messages: Vec<Message>) -> Self {
113        Self::builder()
114            .messages(messages)
115            .model(Model::DeepseekReasoner)
116    }
117
118    pub fn builder() -> Self {
119        Self {
120            raw: ChatCompletionRequest::default(),
121        }
122    }
123
124    pub fn add_message(mut self, message: Message) -> Self {
125        self.raw.messages.push(message);
126        self
127    }
128
129    pub fn messages(mut self, messages: Vec<Message>) -> Self {
130        self.raw.messages = messages;
131        self
132    }
133
134    pub fn model(mut self, model: Model) -> Self {
135        self.raw.model = model;
136        self
137    }
138
139    pub fn response_format_type(mut self, response_format_type: ResponseFormatType) -> Self {
140        self.raw.response_format = Some(ResponseFormat {
141            r#type: response_format_type,
142        });
143        self
144    }
145
146    pub fn json(self) -> Self {
147        self.response_format_type(ResponseFormatType::JsonObject)
148    }
149
150    pub fn text(self) -> Self {
151        self.response_format_type(ResponseFormatType::Text)
152    }
153
154    /// Possible values: >= -2 and <= 2
155    /// Default value: 0
156    /// 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
157    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
158        self.raw.frequency_penalty = Some(penalty);
159        self
160    }
161
162    /// Possible values: >= -2 and <= 2
163    /// Default value: 0
164    /// 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。
165    pub fn presence_penalty(mut self, penalty: f32) -> Self {
166        self.raw.presence_penalty = Some(penalty);
167        self
168    }
169
170    /// 限制一次请求中模型生成 completion 的最大 token 数。输入 token 和输出 token 的总长度受模型的上下文长度的限制。取值范围与默认值详见文档。
171    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
172        self.raw.max_tokens = Some(max_tokens);
173        self
174    }
175
176    /// Possible values: <= 2
177    /// Default value: 1
178    /// 采样温度,介于 0 和 2 之间。更高的值,如 0.8,会使输出更随机,而更低的值,如 0.2,会使其更加集中和确定。 我们通常建议可以更改这个值或者更改 top_p,但不建议同时对两者进行修改。
179    pub fn temperature(mut self, temperature: f32) -> Self {
180        self.raw.temperature = Some(temperature);
181        self
182    }
183
184    pub fn stop_vec(mut self, stop: Vec<String>) -> Self {
185        self.raw.stop = Some(Stop::Array(stop));
186        self
187    }
188
189    pub fn stop_str(mut self, stop: String) -> Self {
190        self.raw.stop = Some(Stop::String(stop));
191        self
192    }
193
194    /// Possible values: <= 1
195    /// Default value: 1
196    /// 作为调节采样温度的替代方案,模型会考虑前 top_p 概率的 token 的结果。所以 0.1 就意味着只有包括在最高 10% 概率中的 token 会被考虑。 我们通常建议修改这个值或者更改 temperature,但不建议同时对两者进行修改。
197    pub fn top_p(mut self, top_p: f32) -> Self {
198        self.raw.top_p = Some(top_p);
199        self
200    }
201
202    pub fn add_tool(mut self, tool: Tool) -> Self {
203        if let Some(tools) = &mut self.raw.tools {
204            tools.push(tool);
205        } else {
206            self.raw.tools = Some(vec![tool]);
207        }
208        self
209    }
210
211    pub fn tool_choice_type(mut self, tool_choice: ToolChoiceType) -> Self {
212        self.raw.tool_choice = Some(ToolChoice::String(tool_choice));
213        self
214    }
215
216    pub fn tool_choice_object(mut self, tool_choice: ToolChoiceObject) -> Self {
217        self.raw.tool_choice = Some(ToolChoice::Object(tool_choice));
218        self
219    }
220
221    /// top_logprobs: 一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,且返回这些 token 的对数概率
222    pub fn logprobs(mut self, top_logprobs: u32) -> Self {
223        self.raw.logprobs = Some(true);
224        self.raw.top_logprobs = Some(top_logprobs);
225        self
226    }
227
228    pub fn raw(&self) -> &ChatCompletionRequest {
229        &self.raw
230    }
231
232    /// 执行无流式(non-streaming)请求,使用指定的 `base_url`(会自动追加 `/chat/completions` path)。
233    /// 接收一个不可变的 `&reqwest::Client`,避免对外部 client 所有权的要求。
234    pub async fn execute_client_baseurl_nostreaming(
235        self,
236        client: &reqwest::Client,
237        base_url: &str,
238        token: &str,
239    ) -> Result<ChatCompletionResponse> {
240        // 构建 url(确保不会重复斜杠)
241        let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
242
243        let resp = client
244            .post(&url)
245            .bearer_auth(token)
246            .json(&self.raw)
247            .send()
248            .await?;
249
250        if !resp.status().is_success() {
251            let status = resp.status();
252            // 尝试读取响应体文本以便诊断,若读取失败则用错误字符串占位
253            let text = resp.text().await.unwrap_or_else(|e| e.to_string());
254            return Err(ApiError::http_error(status, text));
255        }
256
257        let parsed = resp.json::<ChatCompletionResponse>().await?;
258        Ok(parsed)
259    }
260
261    /// 便捷方法:使用默认 base URL 执行无流式请求(接收不可变的 `&Client`)。
262    pub async fn execute_client_nostreaming(
263        self,
264        client: &reqwest::Client,
265        token: &str,
266    ) -> Result<ChatCompletionResponse> {
267        self.execute_client_baseurl_nostreaming(client, DEFAULT_API_BASE, token)
268            .await
269    }
270
271    /// 在没有外部 `Client` 的情况下执行无流式请求(使用传入的 `base_url`)。
272    pub async fn execute_baseurl_nostreaming(
273        self,
274        base_url: &str,
275        token: &str,
276    ) -> Result<ChatCompletionResponse> {
277        let client = reqwest::Client::new();
278        self.execute_client_baseurl_nostreaming(&client, base_url, token)
279            .await
280    }
281
282    /// 使用默认 base URL 执行无流式请求(最常用的便捷方法)。
283    pub async fn execute_nostreaming(self, token: &str) -> Result<ChatCompletionResponse> {
284        self.execute_baseurl_nostreaming(DEFAULT_API_BASE, token)
285            .await
286    }
287
288    /// 执行流式(SSE)响应请求。使用 `DEFAULT_API_BASE` 构建 URL,并对流中可能出现的解析错误进行更加明确的映射。
289    ///
290    /// 返回一个 Stream,每个 Item 是 `Result<ChatCompletionChunk, ApiError>`。
291    /// Execute a streaming request (SSE) using a custom `base_url`.
292    /// 返回一个 Stream,每个 Item 是 `Result<ChatCompletionChunk, ApiError>`。
293    pub async fn execute_client_streaming_baseurl(
294        mut self,
295        client: &reqwest::Client,
296        base_url: &str,
297        token: &str,
298    ) -> Result<impl Stream<Item = std::result::Result<ChatCompletionChunk, ApiError>>> {
299        self.raw.stream = Some(true); // 确保请求中包含 stream: true
300
301        let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
302        let response = client
303            .post(&url)
304            .bearer_auth(token)
305            .json(&self.raw)
306            .send()
307            .await?;
308
309        if !response.status().is_success() {
310            let status = response.status();
311            let error_text = response.text().await.unwrap_or_else(|e| e.to_string());
312            return Err(ApiError::http_error(status, error_text));
313        }
314
315        // 将响应字节流转换为 SSE 事件流
316        let event_stream = response.bytes_stream().eventsource();
317
318        // 映射每个事件:
319        // - 如果是 Ok(event),判断 event.data:
320        //   - 若 data == "[DONE]",忽略(返回 None)
321        //   - 否则尝试反序列化为 ChatCompletionChunk
322        //     - 若解析成功 -> Some(Ok(chunk))
323        //     - 若解析失败 -> Some(Err(ApiError::Json(...)))
324        // - 如果 eventsource 返回错误 -> Some(Err(ApiError::Other(...)))
325        let chunk_stream = event_stream.filter_map(|event_result| async move {
326            match event_result {
327                Ok(event) => {
328                    if event.data == "[DONE]" {
329                        None
330                    } else {
331                        match serde_json::from_str::<ChatCompletionChunk>(&event.data) {
332                            Ok(chunk) => Some(Ok(chunk)),
333                            Err(e) => Some(Err(ApiError::Json(e))),
334                        }
335                    }
336                }
337                Err(e) => Some(Err(ApiError::EventSource(e.to_string()))),
338            }
339        });
340
341        Ok(chunk_stream)
342    }
343
344    /// Execute a streaming request (SSE) using the default API base URL.
345    /// Convenience wrapper that delegates to `execute_client_streaming_baseurl`.
346    pub async fn execute_client_streaming(
347        self,
348        client: &reqwest::Client,
349        token: &str,
350    ) -> Result<impl Stream<Item = std::result::Result<ChatCompletionChunk, ApiError>>> {
351        self.execute_client_streaming_baseurl(client, DEFAULT_API_BASE, token)
352            .await
353    }
354
355    /// # Safety
356    /// 该函数允许直接从原始请求数据创建一个 Request 对象,绕过了构建器的合法性检查。调用者必须确保提供的原始数据是合法且符合 API 要求的,否则可能导致请求失败或产生不可预期的行为。
357    pub unsafe fn from_raw_unchecked(raw: ChatCompletionRequest) -> Self {
358        Self { raw }
359    }
360
361    /// # Safety
362    /// 该函数返回对原始请求数据的可变引用,允许直接修改请求的各个字段。调用者必须确保在修改过程中保持请求数据的合法性和一致性,以避免产生无效的请求或引发错误。
363    pub unsafe fn get_raw_mut(&mut self) -> &mut ChatCompletionRequest {
364        &mut self.raw
365    }
366}
367
368#[cfg(test)]
369mod tests {
370    use super::*;
371
372    #[test]
373    fn test_hello_world_request() {
374        let request = Request::basic_query(vec![Message {
375            role: Role::User,
376            content: Some("Hello, world!".to_string()),
377            ..Default::default()
378        }]);
379
380        assert_eq!(request.raw().messages.len(), 1);
381        assert_eq!(
382            request.raw().messages[0].content.as_ref().unwrap(),
383            "Hello, world!"
384        );
385        assert!(matches!(request.raw().model, Model::DeepseekChat));
386    }
387}