openai_interface/chat/request.rs
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use futures_util::{TryStreamExt, stream::BoxStream};

use crate::errors::RequestError;

#[derive(Serialize, Deserialize, Debug)]
pub struct RequestBody {
    /// A list of messages comprising the conversation so far.
    pub messages: Vec<Message>,
    /// Name of the model to use to generate the response.
    pub model: String,
    /// Whether to stream the response. The API treats this field as
    /// optional, but you should set it explicitly so the response
    /// arrives in the form you expect.
    pub stream: bool,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far, decreasing the model's likelihood to
    /// repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far, increasing the model's likelihood to
    /// talk about new topics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// The maximum number of tokens that can be generated in the chat completion.
    /// Deprecated according to OpenAI's Python SDK in favour of
    /// `max_completion_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// An upper bound for the number of tokens that can be generated for a completion,
    /// including visible output tokens and reasoning tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    /// An object specifying the format that the model must output.
    ///
    /// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
    /// Outputs, which ensures the model will match your supplied JSON schema. Learn more
    /// in the
    /// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
    /// Setting to `{ "type": "json_object" }` enables the older JSON mode, which
    /// ensures the message the model generates is valid JSON. Using `json_schema` is
    /// preferred for models that support it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>, // This type still needs improvement (no `json_schema` support yet).
    /// If specified, the system will make a best effort to sample deterministically. Determinism
    /// is not guaranteed, and you should refer to the `system_fingerprint` response parameter to
    /// monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// How many chat completion choices to generate for each input message. Note that
    /// you will be charged based on the number of generated tokens across all of the
    /// choices. Keep `n` as `1` to minimize costs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Up to 4 sequences where the API will stop generating further tokens. The
    /// returned text will not contain the stop sequence.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopKeywords>,
    /// Options for streaming the response. Only set this when you set `stream: true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic. It is generally recommended to alter this or `top_p`
    /// but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the
    /// model considers the results of the tokens with top_p probability mass. So 0.1
    /// means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// It is generally recommended to alter this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// A list of tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tools>>,
    /// Controls which (if any) tool is called by the model. `none` means the model will
    /// not call any tool and instead generates a message. `auto` means the model can
    /// pick between generating a message or calling one or more tools. `required` means
    /// the model must call one or more tools. Specifying a particular tool via
    /// `{"type": "function", "function": {"name": "my_function"}}` forces the model to
    /// call that tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to return log probabilities of the output tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    /// An integer between 0 and 20 specifying the number of most likely tokens to
    /// return at each token position, each with an associated log probability.
    /// `logprobs` must be set to `true` if this parameter is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,

    /// Other request-body fields that are not part of the standard OpenAI API.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra_body: Option<ExtraBody>,

    /// Other request-body fields that are not part of the standard OpenAI API
    /// and not covered by the `ExtraBody` struct.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra_body_map: Option<HashMap<String, String>>,
}

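/// A single message in the conversation. Serialization is internally
/// tagged on `role`, so the correct `role` field is emitted automatically
/// for each variant.
///
/// A minimal sketch of the wire format, assuming `serde_json` is
/// available as a dev-dependency:
///
/// ```rust
/// use openai_interface::chat::request::Message;
///
/// let message = Message::User {
///     content: "Hello".to_string(),
///     name: None,
/// };
/// // `name` is `None`, so it is skipped entirely.
/// assert_eq!(
///     serde_json::to_string(&message).unwrap(),
///     r#"{"role":"user","content":"Hello"}"#
/// );
/// ```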
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// In this case, the role of the message author is `system`.
    /// The field `{"role": "system"}` is added automatically.
    System {
        /// The contents of the system message.
        content: String,
        /// An optional name for the participant.
        ///
        /// Provides the model information to differentiate between
        /// participants of the same role.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// In this case, the role of the message author is `user`.
    /// The field `{"role": "user"}` is added automatically.
    User {
        /// The contents of the user message.
        content: String,
        /// An optional name for the participant.
        ///
        /// Provides the model information to differentiate between
        /// participants of the same role.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// In this case, the role of the message author is `assistant`.
    /// The field `{"role": "assistant"}` is added automatically.
    ///
    /// Unimplemented params:
    /// - _audio_: Data about a previous audio response from the model.
    Assistant {
        /// The contents of the assistant message. Required unless `tool_calls`
        /// or `function_call` is specified. (Note that `function_call` is deprecated
        /// in favour of `tool_calls`.)
        content: Option<String>,
        /// The refusal message by the assistant.
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        /// An optional name for the participant.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Set this to `true` for the Chat Prefix Completion feature
        /// (DeepSeek beta), which makes the model continue from this
        /// assistant message.
        #[serde(skip_serializing_if = "is_false")]
        prefix: bool,
        /// Used for the deepseek-reasoner model in the Chat Prefix
        /// Completion feature as the input for the CoT in the last
        /// assistant message. When using this feature, the `prefix`
        /// parameter must be set to `true`.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
        /// The tool calls generated by the model, such as function calls.
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<AssistantToolCall>>,
    },
    /// In this case, the role of the message author is `tool`.
    /// The field `{"role": "tool"}` is added automatically.
    Tool {
        /// The contents of the tool message.
        content: String,
        /// Tool call that this message is responding to.
        tool_call_id: String,
    },
}

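/// A tool call recorded on a previous assistant message. Tagged on
/// `type`, matching the `{"type": "function", ...}` wire format.
///
/// A minimal sketch, assuming `serde_json` as a dev-dependency; the
/// `id` and argument values are made up for illustration:
///
/// ```rust
/// use openai_interface::chat::request::{AssistantToolCall, ToolCallFunction};
///
/// let call = AssistantToolCall::Function {
///     id: "call_0".to_string(),
///     function: ToolCallFunction {
///         arguments: r#"{"city":"Paris"}"#.to_string(),
///         name: "get_weather".to_string(),
///     },
/// };
/// assert_eq!(
///     serde_json::to_string(&call).unwrap(),
///     r#"{"type":"function","id":"call_0","function":{"arguments":"{\"city\":\"Paris\"}","name":"get_weather"}}"#
/// );
/// ```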
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantToolCall {
    Function {
        /// The ID of the tool call.
        id: String,
        /// The function that the model called.
        function: ToolCallFunction,
    },
    Custom {
        /// The ID of the tool call.
        id: String,
        /// The custom tool that the model called.
        custom: ToolCallCustom,
    },
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ToolCallFunction {
    /// The arguments to call the function with, as generated by the model in JSON
    /// format. Note that the model does not always generate valid JSON, and may
    /// hallucinate parameters not defined by your function schema. Validate the
    /// arguments in your code before calling your function.
    pub arguments: String,
    /// The name of the function to call.
    pub name: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ToolCallCustom {
    /// The input for the custom tool call generated by the model.
    pub input: String,
    /// The name of the custom tool to call.
    pub name: String,
}

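/// The response format requested from the model. Tagged on `type`, so
/// the variants serialize as `{"type": "json_object"}` and
/// `{"type": "text"}`.
///
/// A minimal sketch, assuming `serde_json` as a dev-dependency:
///
/// ```rust
/// use openai_interface::chat::request::ResponseFormat;
///
/// assert_eq!(
///     serde_json::to_string(&ResponseFormat::JsonObject).unwrap(),
///     r#"{"type":"json_object"}"#
/// );
/// ```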
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    JsonObject,
    Text,
}

fn is_false(value: &bool) -> bool {
    !value
}

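/// One or more stop sequences. `#[serde(untagged)]` lets a single word
/// serialize as a bare JSON string and multiple words as an array,
/// mirroring the API's string-or-list union type.
///
/// A minimal sketch, assuming `serde_json` as a dev-dependency:
///
/// ```rust
/// use openai_interface::chat::request::StopKeywords;
///
/// let one = StopKeywords::Word("stop".to_string());
/// assert_eq!(serde_json::to_string(&one).unwrap(), r#""stop""#);
///
/// let many = StopKeywords::Words(vec!["a".to_string(), "b".to_string()]);
/// assert_eq!(serde_json::to_string(&many).unwrap(), r#"["a","b"]"#);
/// ```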
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum StopKeywords {
    Word(String),
    Words(Vec<String>),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct StreamOptions {
    /// If `true`, an additional chunk with token usage statistics is
    /// streamed before the `data: [DONE]` message.
    pub include_usage: bool,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Tools {
    /// The type of the tool, e.g. `"function"`.
    #[serde(rename = "type")]
    pub type_: String,
    /// The function definition when `type_` is `"function"`. Each tool
    /// entry carries a single function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<ToolFunction>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ToolFunction {
    pub name: String,
    pub description: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ToolFunctionParameter {
    pub name: String,
    pub description: String,
    pub required: bool,
    pub parameters: String,
}

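/// A minimal sketch of the tool-choice wire format, assuming
/// `serde_json` as a dev-dependency; `my_function` is a made-up name:
///
/// ```rust
/// use openai_interface::chat::request::{
///     ToolChoice, ToolChoiceFunction, ToolChoiceSpecificType,
/// };
///
/// // Unit variants serialize as bare strings.
/// assert_eq!(serde_json::to_string(&ToolChoice::Auto).unwrap(), r#""auto""#);
///
/// // Forcing a specific function produces the object form.
/// let forced = ToolChoice::Specific {
///     type_: ToolChoiceSpecificType::Function,
///     function: ToolChoiceFunction {
///         name: "my_function".to_string(),
///     },
/// };
/// assert_eq!(
///     serde_json::to_string(&forced).unwrap(),
///     r#"{"type":"function","function":{"name":"my_function"}}"#
/// );
/// ```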
#[derive(Serialize, Deserialize, Debug)]
pub enum ToolChoice {
    #[serde(rename = "none")]
    None,
    #[serde(rename = "auto")]
    Auto,
    #[serde(rename = "required")]
    Required,
    #[serde(untagged)]
    Specific {
        /// This field is always the literal `"function"`.
        #[serde(rename = "type")]
        type_: ToolChoiceSpecificType,
        function: ToolChoiceFunction,
    },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ToolChoiceFunction {
    pub name: String,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceSpecificType {
    Function,
}

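/// Provider-specific parameters. Because `RequestBody` flattens this
/// struct, its fields appear at the top level of the request JSON
/// rather than nested under an `extra_body` key.
///
/// A minimal sketch, assuming `serde_json` as a dev-dependency:
///
/// ```rust
/// use openai_interface::chat::request::{ExtraBody, RequestBody};
///
/// let body = RequestBody {
///     extra_body: Some(ExtraBody {
///         enable_thinking: Some(false),
///         thinking_budget: None,
///         top_k: None,
///     }),
///     ..Default::default()
/// };
/// let json = serde_json::to_value(&body).unwrap();
/// // The flattened field sits at the top level of the body.
/// assert_eq!(json["enable_thinking"], false);
/// ```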
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtraBody {
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_thinking: Option<bool>,
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
    /// The size of the candidate set for sampling during generation.
    ///
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
}

impl Default for RequestBody {
    fn default() -> Self {
        RequestBody {
            messages: vec![],
            model: "deepseek-chat".to_string(),
            frequency_penalty: None,
            presence_penalty: None,
            max_completion_tokens: None,
            max_tokens: None,
            response_format: None,
            seed: None,
            n: None,
            stop: None,
            stream: false,
            stream_options: None,
            temperature: None,
            top_p: None,
            tools: None,
            tool_choice: None,
            logprobs: None,
            top_logprobs: None,
            extra_body: None,
            extra_body_map: None,
        }
    }
}

impl RequestBody {
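    /// Sends a non-streaming chat completion request and returns the raw
    /// response body as a `String`. `self.stream` must be `false`,
    /// otherwise this method panics.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a DeepSeek-compatible endpoint and a
    /// placeholder API key:
    ///
    /// ```rust,no_run
    /// use openai_interface::chat::request::{Message, RequestBody};
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let request = RequestBody {
    ///         messages: vec![Message::User {
    ///             content: "What's your name?".to_string(),
    ///             name: None,
    ///         }],
    ///         model: "deepseek-chat".to_string(),
    ///         stream: false,
    ///         ..Default::default()
    ///     };
    ///
    ///     let response = request
    ///         .get_response("https://api.deepseek.com/chat/completions", "<your-api-key>")
    ///         .await
    ///         .unwrap();
    ///     println!("{}", response);
    /// }
    /// ```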
    pub async fn get_response(&self, url: &str, key: &str) -> anyhow::Result<String> {
        assert!(
            !self.stream,
            "RequestBody::get_response requires `stream: false`"
        );

        let client = reqwest::Client::new();
        let response = client
            .post(url)
            .headers({
                let mut headers = reqwest::header::HeaderMap::new();
                headers.insert("Content-Type", "application/json".parse().unwrap());
                headers.insert("Accept", "application/json".parse().unwrap());
                headers
            })
            .bearer_auth(key)
            .json(self)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;

        if !response.status().is_success() {
            return Err(
                crate::errors::RequestError::ResponseStatus(response.status().as_u16()).into(),
            );
        }

        let text = response.text().await?;

        Ok(text)
    }

    /// Streams the response. `self.stream` must be `true`, otherwise this
    /// method panics.
    ///
    /// # Example
    ///
    /// ```rust
    /// use std::sync::LazyLock;
    /// use futures_util::StreamExt;
    /// use openai_interface::chat::request::{Message, RequestBody};
    ///
    /// static DEEPSEEK_API_KEY: LazyLock<&str> =
    ///     LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
    /// const DEEPSEEK_CHAT_URL: &str = "https://api.deepseek.com/chat/completions";
    /// const DEEPSEEK_MODEL: &str = "deepseek-chat";
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let request = RequestBody {
    ///         messages: vec![
    ///             Message::System {
    ///                 content: "This is a request of test purpose. Reply briefly".to_string(),
    ///                 name: None,
    ///             },
    ///             Message::User {
    ///                 content: "What's your name?".to_string(),
    ///                 name: None,
    ///             },
    ///         ],
    ///         model: DEEPSEEK_MODEL.to_string(),
    ///         stream: true,
    ///         ..Default::default()
    ///     };
    ///
    ///     let mut response = request
    ///         .stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
    ///         .await
    ///         .unwrap();
    ///
    ///     while let Some(chunk) = response.next().await {
    ///         println!("{}", chunk.unwrap());
    ///     }
    /// }
    /// ```
    pub async fn stream_response(
        &self,
        url: &str,
        api_key: &str,
    ) -> Result<BoxStream<'static, Result<String, anyhow::Error>>, anyhow::Error> {
        // Assert that streaming mode is enabled.
        assert!(
            self.stream,
            "RequestBody::stream_response requires `stream: true`"
        );

        let client = reqwest::Client::new();

        let response = client
            .post(url)
            .headers({
                let mut headers = reqwest::header::HeaderMap::new();
                headers.insert("Content-Type", "application/json".parse().unwrap());
                headers.insert("Accept", "application/json".parse().unwrap());
                headers
            })
            .bearer_auth(api_key)
            .json(self)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;

        if !response.status().is_success() {
            return Err(RequestError::ResponseStatus(response.status().as_u16()).into());
        }

        let stream = response
            .bytes_stream()
            .map_err(|e| RequestError::StreamError(e.to_string()).into())
            .try_filter_map(|bytes| async move {
                let s = std::str::from_utf8(&bytes)
                    .map_err(|e| RequestError::SseParseError(e.to_string()))?;
                // A chunk may contain one or more SSE lines. Drop the
                // terminal `data: [DONE]` marker and forward everything
                // else verbatim.
                if s.trim_start().starts_with("data: [DONE]") {
                    Ok(None)
                } else {
                    Ok(Some(s.to_string()))
                }
            });

        Ok(Box::pin(stream) as BoxStream<'static, _>)
    }
}

#[cfg(test)]
mod request_test {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use crate::chat::request::{Message, RequestBody};

    static DEEPSEEK_API_KEY: LazyLock<&str> =
        LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
    const DEEPSEEK_CHAT_URL: &str = "https://api.deepseek.com/chat/completions";
    const DEEPSEEK_MODEL: &str = "deepseek-chat";

    #[tokio::test]
    async fn test_00_basics() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: false,
            ..Default::default()
        };

        let response = request
            .get_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        println!("{}", response);

        assert!(response.to_ascii_lowercase().contains("deepseek"));
    }

    #[tokio::test]
    async fn test_01_streaming() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: true,
            ..Default::default()
        };

        let mut response = request
            .stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        while let Some(chunk) = response.next().await {
            println!("{}", chunk.unwrap());
        }
    }
}