// openai_interface/chat/request.rs

//! This module contains the request body and POST method for the chat completion API.
2
3use futures_util::{TryStreamExt, stream::BoxStream};
4use serde::Serialize;
5
6use crate::errors::RequestError;
7
8#[derive(Serialize, Debug, Default)]
9pub struct RequestBody {
10    /// A list of messages comprising the conversation so far.
11    pub messages: Vec<Message>,
12
13    /// Name of the model to use to generate the response.
14    pub model: String,
15
16    /// Although it is optional, you should explicitly designate it
17    /// for an expected response.
18    pub stream: bool,
19
20    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
21    /// existing frequency in the text so far, decreasing the model's likelihood to
22    /// repeat the same line verbatim.
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub frequency_penalty: Option<f32>,
25
26    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
27    /// whether they appear in the text so far, increasing the model's likelihood to
28    /// talk about new topics.
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub presence_penalty: Option<f32>,
31
32    /// The maximum number of tokens that can be generated in the chat completion.
33    /// Deprecated according to OpenAI's Python SDK in favour of
34    /// `max_completion_tokens`.
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub max_tokens: Option<u32>,
37
38    /// An upper bound for the number of tokens that can be generated for a completion,
39    /// including visible output tokens and reasoning tokens.
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub max_completion_tokens: Option<u32>,
42
43    /// specifying the format that the model must output.
44    ///
45    /// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
46    /// Outputs which ensures the model will match your supplied JSON schema. Learn more
47    /// in the
48    /// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
49    /// Setting to `{ "type": "json_object" }` enables the older JSON mode, which
50    /// ensures the message the model generates is valid JSON. Using `json_schema` is
51    /// preferred for models that support it.
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub response_format: Option<ResponseFormat>, // The type of this attribute needs improvements.
54
55    /// A stable identifier used to help detect users of your application that may be
56    /// violating OpenAI's usage policies. The IDs should be a string that uniquely
57    /// identifies each user. It is recommended to hash their username or email address, in
58    /// order to avoid sending any identifying information.
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub safety_identifier: Option<String>,
61
62    /// If specified, the system will make a best effort to sample deterministically. Determinism
63    /// is not guaranteed, and you should refer to the `system_fingerprint` response parameter to
64    /// monitor changes in the backend.
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub seed: Option<i64>,
67
68    /// How many chat completion choices to generate for each input message. Note that
69    /// you will be charged based on the number of generated tokens across all of the
70    /// choices. Keep `n` as `1` to minimize costs.
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub n: Option<u32>,
73
74    /// Up to 4 sequences where the API will stop generating further tokens. The
75    /// returned text will not contain the stop sequence.
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub stop: Option<StopKeywords>,
78
79    /// Options for streaming response. Only set this when you set `stream: true`
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub stream_options: Option<StreamOptions>,
82
83    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
84    /// make the output more random, while lower values like 0.2 will make it more
85    /// focused and deterministic. It is generally recommended to alter this or `top_p` but
86    /// not both.
87    pub temperature: Option<f32>,
88
89    /// An alternative to sampling with temperature, called nucleus sampling, where the
90    /// model considers the results of the tokens with top_p probability mass. So 0.1
91    /// means only the tokens comprising the top 10% probability mass are considered.
92    ///
93    /// It is generally recommended to alter this or `temperature` but not both.
94    pub top_p: Option<f32>,
95
96    /// A list of tools the model may call.
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub tools: Option<Vec<RequestTool>>,
99
100    /// Controls which (if any) tool is called by the model. `none` means the model will
101    /// not call any tool and instead generates a message. `auto` means the model can
102    /// pick between generating a message or calling one or more tools. `required` means
103    /// the model must call one or more tools. Specifying a particular tool via
104    /// `{"type": "function", "function": {"name": "my_function"}}` forces the model to
105    /// call that tool.
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub tool_choice: Option<ToolChoice>,
108
109    /// Whether to return log probabilities of the output tokens or not. If true,
110    /// returns the log probabilities of each output token returned in the `content` of
111    /// `message`.
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub logprobs: Option<bool>,
114
115    /// An integer between 0 and 20 specifying the number of most likely tokens to
116    /// return at each token position, each with an associated log probability.
117    /// `logprobs` must be set to `true` if this parameter is used.
118    #[serde(skip_serializing_if = "Option::is_none")]
119    pub top_logprobs: Option<u32>,
120
121    /// Other request bodies that are not in standard OpenAI API.
122    #[serde(flatten, skip_serializing_if = "Option::is_none")]
123    pub extra_body: Option<ExtraBody>,
124
125    /// Other request bodies that are not in standard OpenAI API and
126    /// not included in the ExtraBody struct.
127    #[serde(flatten, skip_serializing_if = "Option::is_none")]
128    pub extra_body_map: Option<serde_json::Map<String, serde_json::Value>>,
129}
130
131#[derive(Serialize, Debug)]
132#[serde(tag = "role", rename_all = "lowercase")]
133pub enum Message {
134    /// In this case, the role of the message author is `system`.
135    /// The field `{ role = "system" }` is added automatically.
136    System {
137        /// The contents of the system message.
138        content: String,
139        /// An optional name for the participant.
140        ///
141        /// Provides the model information to differentiate between
142        /// participants of the same role.
143        #[serde(skip_serializing_if = "Option::is_none")]
144        name: Option<String>,
145    },
146    /// In this case, the role of the message author is `user`.
147    /// The field `{ role = "user" }` is added automatically.
148    User {
149        /// The contents of the user message.
150        content: String,
151        /// An optional name for the participant.
152        ///
153        /// Provides the model information to differentiate between
154        /// participants of the same role.
155        #[serde(skip_serializing_if = "Option::is_none")]
156        name: Option<String>,
157    },
158    /// In this case, the role of the message author is `assistant`.
159    /// The field `{ role = "assistant" }` is added automatically.
160    ///
161    /// Unimplemented params:
162    /// - _audio_: Data about a previous audio response from the model.
163    Assistant {
164        /// The contents of the assistant message. Required unless `tool_calls`
165        /// or `function_call` is specified. (Note that `function_call` is deprecated
166        /// in favour of `tool_calls`.)
167        content: Option<String>,
168        /// The refusal message by the assistant.
169        #[serde(skip_serializing_if = "Option::is_none")]
170        refusal: Option<String>,
171        #[serde(skip_serializing_if = "Option::is_none")]
172        name: Option<String>,
173        /// Set this to true for completion
174        #[serde(skip_serializing_if = "is_false")]
175        prefix: bool,
176        /// Used for the deepseek-reasoner model in the Chat Prefix
177        /// Completion feature as the input for the CoT in the last
178        /// assistant message. When using this feature, the prefix
179        /// parameter must be set to true.
180        #[serde(skip_serializing_if = "Option::is_none")]
181        reasoning_content: Option<String>,
182
183        /// The tool calls generated by the model, such as function calls.
184        #[serde(skip_serializing_if = "Option::is_none")]
185        tool_calls: Option<Vec<AssistantToolCall>>,
186    },
187    /// In this case, the role of the message author is `assistant`.
188    /// The field `{ role = "tool" }` is added automatically.
189    Tool {
190        /// The contents of the tool message.
191        content: String,
192        /// Tool call that this message is responding to.
193        tool_call_id: String,
194    },
195    /// In this case, the role of the message author is `function`.
196    /// The field `{ role = "function" }` is added automatically.
197    Function {
198        /// The contents of the function message.
199        content: String,
200        /// The name of the function to call.
201        name: String,
202    },
203    /// In this case, the role of the message author is `developer`.
204    /// The field `{ role = "developer" }` is added automatically.
205    Developer {
206        /// The contents of the developer message.
207        content: String,
208        /// An optional name for the participant.
209        ///
210        /// Provides the model information to differentiate between
211        /// participants of the same role.
212        name: Option<String>,
213    },
214}
215
216#[derive(Debug, Serialize)]
217#[serde(tag = "role", rename_all = "lowercase")]
218pub enum AssistantToolCall {
219    Function {
220        /// The ID of the tool call.
221        id: String,
222        /// The function that the model called.
223        function: ToolCallFunction,
224    },
225    Custom {
226        /// The ID of the tool call.
227        id: String,
228        /// The custom tool that the model called.
229        custom: ToolCallCustom,
230    },
231}
232
233#[derive(Debug, Serialize)]
234pub struct ToolCallFunction {
235    /// The arguments to call the function with, as generated by the model in JSON
236    /// format. Note that the model does not always generate valid JSON, and may
237    /// hallucinate parameters not defined by your function schema. Validate the
238    /// arguments in your code before calling your function.
239    arguments: String,
240    /// The name of the function to call.
241    name: String,
242}
243
244#[derive(Debug, Serialize)]
245pub struct ToolCallCustom {
246    /// The input for the custom tool call generated by the model.
247    input: String,
248    /// The name of the custom tool to call.
249    name: String,
250}
251
252#[derive(Debug, Serialize)]
253#[serde(tag = "type", rename_all = "snake_case")]
254pub enum ResponseFormat {
255    /// The type of response format being defined. Always `json_schema`.
256    JsonSchema {
257        /// Structured Outputs configuration options, including a JSON Schema.
258        json_schema: JSONSchema,
259    },
260    /// The type of response format being defined. Always `json_object`.
261    JsonObject,
262    /// The type of response format being defined. Always `text`.
263    Text,
264}
265
266#[derive(Debug, Serialize)]
267pub struct JSONSchema {
268    /// The name of the response format. Must be a-z, A-Z, 0-9, or contain
269    /// underscores and dashes, with a maximum length of 64.
270    pub name: String,
271    /// A description of what the response format is for, used by the model to determine
272    /// how to respond in the format.
273    pub description: String,
274    /// The schema for the response format, described as a JSON Schema object. Learn how
275    /// to build JSON schemas [here](https://json-schema.org/).
276    pub schema: serde_json::Map<String, serde_json::Value>,
277    /// Whether to enable strict schema adherence when generating the output. If set to
278    /// true, the model will always follow the exact schema defined in the `schema`
279    /// field. Only a subset of JSON Schema is supported when `strict` is `true`. To
280    /// learn more, read the
281    /// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
282    pub strict: Option<bool>,
283}
284
/// Serde helper for `#[serde(skip_serializing_if = "is_false")]`: a boolean
/// field is omitted from the payload when it is `false`.
#[inline]
fn is_false(value: &bool) -> bool {
    !*value
}
289
290#[derive(Serialize, Debug)]
291#[serde(untagged)]
292pub enum StopKeywords {
293    Word(String),
294    Words(Vec<String>),
295}
296
297#[derive(Serialize, Debug)]
298pub struct StreamOptions {
299    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
300    ///
301    /// The `usage` field on this chunk shows the token usage statistics for the entire
302    /// request, and the `choices` field will always be an empty array.
303    ///
304    /// All other chunks will also include a `usage` field, but with a null value.
305    /// **NOTE:** If the stream is interrupted, you may not receive the final usage
306    /// chunk which contains the total token usage for the request.
307    pub include_usage: bool,
308}
309
310#[derive(Serialize, Debug)]
311#[serde(tag = "type", rename_all = "snake_case")]
312pub enum RequestTool {
313    /// The type of the tool. Currently, only `function` is supported.
314    Function { function: ToolFunction },
315    /// The type of the custom tool. Always `custom`.
316    Custom {
317        /// Properties of the custom tool.
318        custom: ToolCustom,
319    },
320}
321
322#[derive(Serialize, Debug)]
323pub struct ToolFunction {
324    /// The name of the function to be called. Must be a-z, A-Z, 0-9, or
325    /// contain underscores and dashes, with a maximum length
326    /// of 64.
327    pub name: String,
328    /// A description of what the function does, used by the model to choose when and
329    /// how to call the function.
330    pub description: String,
331    /// The parameters the functions accepts, described as a JSON Schema object.
332    ///
333    /// See the
334    /// [openai function calling guide](https://platform.openai.com/docs/guides/function-calling)
335    /// for examples, and the
336    /// [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for
337    /// documentation about the format.
338    ///
339    /// Omitting `parameters` defines a function with an empty parameter list.
340    pub parameters: serde_json::Map<String, serde_json::Value>,
341    /// Whether to enable strict schema adherence when generating the function call.
342    ///
343    /// If set to true, the model will follow the exact schema defined in the
344    /// `parameters` field. Only a subset of JSON Schema is supported when `strict` is
345    /// `true`. Learn more about Structured Outputs in the
346    /// [openai function calling guide](https://platform.openai.com/docs/guides/function-calling).
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub strict: Option<bool>,
349}
350
351#[derive(Serialize, Debug)]
352pub struct ToolCustom {
353    /// The name of the custom tool, used to identify it in tool calls.
354    pub name: String,
355    /// Optional description of the custom tool, used to provide more context.
356    pub description: String,
357    /// The input format for the custom tool. Default is unconstrained text.
358    pub format: String,
359}
360
361#[derive(Serialize, Debug)]
362#[serde(rename_all = "snake_case", tag = "type")]
363pub enum ToolCustomFormat {
364    /// Unconstrained text format. Always `text`.
365    CustomFormatText,
366    /// Grammar format. Always `grammar`.
367    CustomFormatGrammar {
368        /// Your chosen grammar.
369        grammar: ToolCustomFormatGrammarGrammar,
370    },
371}
372
373#[derive(Debug, Serialize)]
374pub struct ToolCustomFormatGrammarGrammar {
375    /// The grammar definition.
376    pub definition: String,
377    /// The syntax of the grammar definition. One of `lark` or `regex`.
378    pub syntax: ToolCustomFormatGrammarGrammarSyntax,
379}
380
381#[derive(Debug, Serialize)]
382#[serde(rename_all = "snake_case")]
383pub enum ToolCustomFormatGrammarGrammarSyntax {
384    Lark,
385    Regex,
386}
387
388#[derive(Debug, Serialize)]
389#[serde(rename_all = "snake_case")]
390pub enum ToolChoice {
391    None,
392    Auto,
393    Required,
394    #[serde(untagged)]
395    Specific(ToolChoiceSpecific),
396}
397
398#[derive(Debug, Serialize)]
399#[serde(rename_all = "snake_case", tag = "type")]
400pub enum ToolChoiceSpecific {
401    /// Allowed tool configuration type. Always `allowed_tools`.
402    AllowedTools {
403        /// Constrains the tools available to the model to a pre-defined set.
404        allowed_tools: ToolChoiceAllowedTools,
405    },
406    /// For function calling, the type is always `function`.
407    Function { function: ToolChoiceFunction },
408    /// For custom tool calling, the type is always `custom`.
409    Custom { custom: ToolChoiceCustom },
410}
411
412#[derive(Debug, Serialize)]
413pub struct ToolChoiceAllowedTools {
414    /// Constrains the tools available to the model to a pre-defined set.
415    ///
416    /// - `auto` allows the model to pick from among the allowed tools and generate a
417    /// message.
418    /// - `required` requires the model to call one or more of the allowed tools.
419    pub mode: ToolChoiceAllowedToolsMode,
420    /// A list of tool definitions that the model should be allowed to call.
421    ///
422    /// For the Chat Completions API, the list of tool definitions might look like:
423    ///
424    /// ```json
425    /// [
426    ///   { "type": "function", "function": { "name": "get_weather" } },
427    ///   { "type": "function", "function": { "name": "get_time" } }
428    /// ]
429    /// ```
430    pub tools: serde_json::Map<String, serde_json::Value>,
431}
432
433/// The mode for allowed tools in tool choice.
434///
435/// Controls how the model should handle the set of allowed tools:
436///
437/// - `auto` allows the model to pick from among the allowed tools and generate a
438///   message.
439/// - `required` requires the model to call one or more of the allowed tools.
440#[derive(Debug, Serialize)]
441#[serde(rename_all = "lowercase")]
442pub enum ToolChoiceAllowedToolsMode {
443    /// The model can choose whether to use the allowed tools or not.
444    Auto,
445    /// The model must use at least one of the allowed tools.
446    Required,
447}
448
449#[derive(Debug, Serialize)]
450pub struct ToolChoiceFunction {
451    /// The name of the function to call.
452    pub name: String,
453}
454
455#[derive(Debug, Serialize)]
456pub struct ToolChoiceCustom {
457    /// The name of the custom tool to call.
458    pub name: String,
459}
460
461#[derive(Serialize, Debug)]
462pub struct ExtraBody {
463    /// Make sense only for Qwen API.
464    #[serde(skip_serializing_if = "Option::is_none")]
465    pub enable_thinking: Option<bool>,
466    /// Make sense only for Qwen API.
467    #[serde(skip_serializing_if = "Option::is_none")]
468    pub thinking_budget: Option<u32>,
469    ///The size of the candidate set for sampling during generation.
470    ///
471    /// Make sense only for Qwen API.
472    #[serde(skip_serializing_if = "Option::is_none")]
473    pub top_k: Option<u32>,
474}
475
476impl RequestBody {
477    pub async fn get_response(&self, url: &str, key: &str) -> anyhow::Result<String> {
478        assert!(!self.stream);
479
480        let client = reqwest::Client::new();
481        let response = client
482            .post(url)
483            .headers({
484                let mut headers = reqwest::header::HeaderMap::new();
485                headers.insert("Content-Type", "application/json".parse().unwrap());
486                headers.insert("Accept", "application/json".parse().unwrap());
487                headers
488            })
489            .bearer_auth(key)
490            .json(self)
491            .send()
492            .await
493            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
494
495        if response.status() != reqwest::StatusCode::OK {
496            return Err(
497                crate::errors::RequestError::ResponseStatus(response.status().as_u16()).into(),
498            );
499        }
500
501        let text = response.text().await?;
502
503        Ok(text)
504    }
505
506    /// Getting stream response. You must ensure self.stream is true, or otherwise it will panic.
507    ///
508    /// # Example
509    ///
510    /// ```rust
511    /// use std::sync::LazyLock;
512    /// use futures_util::StreamExt;
513    /// use openai_interface::chat::request::{Message, RequestBody};
514    ///
515    /// const DEEPSEEK_API_KEY: LazyLock<&str> =
516    ///     LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
517    /// const DEEPSEEK_CHAT_URL: &'static str = "https://api.deepseek.com/chat/completions";
518    /// const DEEPSEEK_MODEL: &'static str = "deepseek-chat";
519    ///
520    /// #[tokio::main]
521    /// async fn main() {
522    ///     let request = RequestBody {
523    ///         messages: vec![
524    ///             Message::System {
525    ///                 content: "This is a request of test purpose. Reply briefly".to_string(),
526    ///                 name: None,
527    ///             },
528    ///             Message::User {
529    ///                 content: "What's your name?".to_string(),
530    ///                 name: None,
531    ///             },
532    ///         ],
533    ///         model: DEEPSEEK_MODEL.to_string(),
534    ///         stream: true,
535    ///         ..Default::default()
536    ///     };
537    ///
538    ///     let mut response = request
539    ///         .get_stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
540    ///         .await
541    ///         .unwrap();
542    ///
543    ///     while let Some(chunk) = response.next().await {
544    ///         println!("{}", chunk.unwrap());
545    ///     }
546    /// }
547    /// ```
548    pub async fn get_stream_response(
549        &self,
550        url: &str,
551        api_key: &str,
552    ) -> Result<BoxStream<'static, Result<String, anyhow::Error>>, anyhow::Error> {
553        // 断言开启了流模式
554        assert!(
555            self.stream,
556            "RequestBody::get_stream_response requires `stream: true`"
557        );
558
559        let client = reqwest::Client::new();
560
561        let response = client
562            .post(url)
563            .headers({
564                let mut headers = reqwest::header::HeaderMap::new();
565                headers.insert("Content-Type", "application/json".parse().unwrap());
566                headers.insert("Accept", "application/json".parse().unwrap());
567                headers
568            })
569            .bearer_auth(api_key)
570            .json(self)
571            .send()
572            .await
573            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
574
575        if !response.status().is_success() {
576            return Err(RequestError::ResponseStatus(response.status().as_u16()).into());
577        }
578
579        let stream = response
580            .bytes_stream()
581            .map_err(|e| RequestError::StreamError(e.to_string()).into())
582            .try_filter_map(|bytes| async move {
583                let s = std::str::from_utf8(&bytes)
584                    .map_err(|e| RequestError::SseParseError(e.to_string()))?;
585                if s.starts_with("[DONE]") {
586                    Ok(None)
587                } else {
588                    Ok(Some(s.to_string()))
589                }
590            });
591
592        Ok(Box::pin(stream) as BoxStream<'static, _>)
593
594        // return Err(anyhow!("Not implemented"));
595    }
596}
597
#[cfg(test)]
mod request_test {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use crate::chat::request::{Message, RequestBody};

    /// API key lazily read from disk.
    ///
    /// This must be a `static`, not a `const`: a `const` with interior
    /// mutability is inlined at every use site, so each dereference would
    /// create and initialize a *fresh* `LazyLock` instead of caching the
    /// value (clippy: `declare_interior_mutable_const`).
    static DEEPSEEK_API_KEY: LazyLock<&str> =
        LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
    const DEEPSEEK_CHAT_URL: &str = "https://api.deepseek.com/chat/completions";
    const DEEPSEEK_MODEL: &str = "deepseek-chat";

    /// Round-trips a non-streaming request against the live DeepSeek API.
    #[tokio::test]
    async fn test_00_basics() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: false,
            ..Default::default()
        };

        let response = request
            .get_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        println!("{}", response);

        assert!(response.to_ascii_lowercase().contains("deepseek"));
    }

    /// Exercises the SSE streaming path against the live DeepSeek API.
    #[tokio::test]
    async fn test_01_streaming() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: true,
            ..Default::default()
        };

        let mut response = request
            .get_stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        while let Some(chunk) = response.next().await {
            println!("{}", chunk.unwrap());
        }
    }
}