openai_interface/chat/create/request.rs

//! This module contains the request body and POST method for the chat completion API.

use std::collections::HashMap;

use serde::Serialize;
use url::Url;

use crate::{
    chat::ServiceTier,
    errors::OapiError,
    rest::post::{Post, PostNoStream, PostStream},
};

/// Creates a model response for the given chat conversation.
///
/// # Example
///
/// ```rust,no_run
/// use std::sync::LazyLock;
/// use futures_util::StreamExt;
/// use openai_interface::chat::create::request::{Message, RequestBody};
/// use openai_interface::rest::post::PostStream;
///
/// static DEEPSEEK_API_KEY: LazyLock<&str> =
///     LazyLock::new(|| include_str!("../../../keys/deepseek_domestic_key").trim());
/// const DEEPSEEK_CHAT_URL: &str = "https://api.deepseek.com/v1";
/// const DEEPSEEK_MODEL: &str = "deepseek-chat";
///
/// #[tokio::main]
/// async fn main() {
///     let request = RequestBody {
///         messages: vec![
///             Message::System {
///                 content: "This is a request for test purposes. Reply briefly".to_string(),
///                 name: None,
///             },
///             Message::User {
///                 content: "What's your name?".to_string(),
///                 name: None,
///             },
///         ],
///         model: DEEPSEEK_MODEL.to_string(),
///         stream: true,
///         ..Default::default()
///     };
///
///     let mut response = request
///         .get_stream_response_string(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
///         .await
///         .unwrap();
///
///     while let Some(chunk) = response.next().await {
///         println!("{}", chunk.unwrap());
///     }
/// }
/// ```
#[derive(Serialize, Debug, Default, Clone)]
pub struct RequestBody {
    /// Additional request fields that are not part of the standard OpenAI API.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra_body: Option<ExtraBody>,

    /// Additional request fields that are not part of the standard OpenAI API
    /// and not covered by the `ExtraBody` struct.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra_body_map: Option<serde_json::Map<String, serde_json::Value>>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far, decreasing the model's likelihood to
    /// repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    /// Whether to return log probabilities of the output tokens or not. If true,
    /// returns the log probabilities of each output token returned in the `content` of
    /// `message`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,

    /// An upper bound for the number of tokens that can be generated for a completion,
    /// including visible output tokens and reasoning tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,

    /// The maximum number of tokens that can be generated in the chat completion.
    /// Deprecated according to OpenAI's Python SDK in favour of
    /// `max_completion_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// A list of messages comprising the conversation so far.
    pub messages: Vec<Message>,

    /// Set of 16 key-value pairs that can be attached to an object. This can be useful
    /// for storing additional information about the object in a structured format, and
    /// querying for objects via API or the dashboard.
    ///
    /// Keys are strings with a maximum length of 64 characters. Values are strings with
    /// a maximum length of 512 characters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,

    /// Output types that you would like the model to generate. Most models are capable
    /// of generating text, which is the default:
    ///
    /// `["text"]`
    ///
    /// The `gpt-4o-audio-preview` model can also be used to
    /// [generate audio](https://platform.openai.com/docs/guides/audio). To request that
    /// this model generate both text and audio responses, you can use:
    ///
    /// `["text", "audio"]`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<Modality>>,

    /// Name of the model to use to generate the response.
    pub model: String, // TODO: the type of this field could be improved.

    /// How many chat completion choices to generate for each input message. Note that
    /// you will be charged based on the number of generated tokens across all of the
    /// choices. Keep `n` as `1` to minimize costs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// Whether to enable
    /// [parallel function calling](https://platform.openai.com/docs/guides/function-calling#configuring-parallel-function-calling)
    /// during tool use.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,

    /// Static predicted output content, such as the content of a text file that is
    /// being regenerated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prediction: Option<ChatCompletionPredictionContentParam>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far, increasing the model's likelihood to
    /// talk about new topics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Used by OpenAI to cache responses for similar requests to optimize your cache
    /// hit rates. Replaces the `user` field.
    /// [Learn more](https://platform.openai.com/docs/guides/prompt-caching).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,

    /// Constrains effort on reasoning for
    /// [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently
    /// supported values are `minimal`, `low`, `medium`, and `high`. Reducing reasoning
    /// effort can result in faster responses and fewer tokens used on reasoning in a
    /// response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
    /// An object specifying the format that the model must output.
    ///
    /// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
    /// Outputs which ensures the model will match your supplied JSON schema. Learn more
    /// in the
    /// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
    /// Setting to `{ "type": "json_object" }` enables the older JSON mode, which
    /// ensures the message the model generates is valid JSON. Using `json_schema` is
    /// preferred for models that support it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,

    /// A stable identifier used to help detect users of your application that may be
    /// violating OpenAI's usage policies. The ID should be a string that uniquely
    /// identifies each user. It is recommended to hash their username or email address
    /// in order to avoid sending any identifying information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_identifier: Option<String>,

    /// If specified, the system will make a best effort to sample deterministically.
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint`
    /// response parameter to monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,

    /// Specifies the processing type used for serving the request.
    ///
    /// - If set to 'auto', then the request will be processed with the service tier
    ///   configured in the Project settings. Unless otherwise configured, the Project
    ///   will use 'default'.
    /// - If set to 'default', then the request will be processed with the standard
    ///   pricing and performance for the selected model.
    /// - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or
    ///   '[priority](https://openai.com/api-priority-processing/)', then the request
    ///   will be processed with the corresponding service tier.
    /// - When not set, the default behavior is 'auto'.
    ///
    /// When the `service_tier` parameter is set, the response body will include the
    /// `service_tier` value based on the processing mode actually used to serve the
    /// request. This response value may be different from the value set in the
    /// parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<ServiceTier>,

    /// Up to 4 sequences where the API will stop generating further tokens. The
    /// returned text will not contain the stop sequence.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopKeywords>,

    /// Whether or not to store the output of this chat completion request for use in
    /// our [model distillation](https://platform.openai.com/docs/guides/distillation)
    /// or [evals](https://platform.openai.com/docs/guides/evals) products.
    ///
    /// Supports text and image inputs. Note: image inputs over 8MB will be dropped.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,

    /// Whether to stream the response. Although `Default` provides `false`, you
    /// should set this field explicitly so that it matches the kind of response
    /// you expect.
    pub stream: bool,

    /// Options for streaming responses. Only set this when you set `stream: true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic. It is generally recommended to alter this or `top_p`
    /// but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Controls which (if any) tool is called by the model. `none` means the model will
    /// not call any tool and instead generates a message. `auto` means the model can
    /// pick between generating a message or calling one or more tools. `required` means
    /// the model must call one or more tools. Specifying a particular tool via
    /// `{"type": "function", "function": {"name": "my_function"}}` forces the model to
    /// call that tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,

    /// A list of tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<RequestTool>>,

    /// An integer between 0 and 20 specifying the number of most likely tokens to
    /// return at each token position, each with an associated log probability.
    /// `logprobs` must be set to `true` if this parameter is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the
    /// model considers the results of the tokens with top_p probability mass. So 0.1
    /// means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// It is generally recommended to alter this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// This field is being replaced by `safety_identifier` and `prompt_cache_key`. Use
    /// `prompt_cache_key` instead to maintain caching optimizations. A stable
    /// identifier for your end-users. Used to boost cache hit rates by better bucketing
    /// similar requests and to help OpenAI detect and prevent abuse.
    /// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#safety-identifiers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// Constrains the verbosity of the model's response. Lower values will result in
    /// more concise responses, while higher values will result in more verbose
    /// responses. Currently supported values are `low`, `medium`, and `high`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verbosity: Option<LowMediumHighEnum>,

    /// This tool searches the web for relevant results to use in a response. Learn more
    /// about the
    /// [web search tool](https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub web_search: Option<WebSearchOptions>,
}

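/// A single message in the conversation. Serde adds the `role` tag for each
/// variant automatically via the internally tagged representation.
///
/// A minimal sketch of the resulting JSON (illustrative values):
///
/// ```rust
/// use openai_interface::chat::create::request::Message;
///
/// let message = Message::User {
///     content: "Hello".to_string(),
///     name: None,
/// };
/// assert_eq!(
///     serde_json::to_string(&message).unwrap(),
///     r#"{"role":"user","content":"Hello"}"#
/// );
/// ```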
#[derive(Serialize, Debug, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// In this case, the role of the message author is `system`.
    /// The field `{ role = "system" }` is added automatically.
    System {
        /// The contents of the system message.
        content: String,
        /// An optional name for the participant.
        ///
        /// Provides the model information to differentiate between
        /// participants of the same role.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// In this case, the role of the message author is `user`.
    /// The field `{ role = "user" }` is added automatically.
    User {
        /// The contents of the user message.
        content: String,
        /// An optional name for the participant.
        ///
        /// Provides the model information to differentiate between
        /// participants of the same role.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// In this case, the role of the message author is `assistant`.
    /// The field `{ role = "assistant" }` is added automatically.
    ///
    /// Unimplemented params:
    /// - _audio_: Data about a previous audio response from the model.
    Assistant {
        /// The contents of the assistant message. Required unless `tool_calls`
        /// or `function_call` is specified. (Note that `function_call` is deprecated
        /// in favour of `tool_calls`.)
        content: Option<String>,
        /// The refusal message by the assistant.
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        /// An optional name for the participant.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Set this to `true` when using the Chat Prefix Completion feature,
        /// where this assistant message is treated as a prefix for the model
        /// to continue from.
        #[serde(skip_serializing_if = "is_false")]
        prefix: bool,
        /// Used for the deepseek-reasoner model in the Chat Prefix
        /// Completion feature as the input for the CoT in the last
        /// assistant message. When using this feature, the prefix
        /// parameter must be set to true.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,

        /// The tool calls generated by the model, such as function calls.
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<AssistantToolCall>>,
    },
    /// In this case, the role of the message author is `tool`.
    /// The field `{ role = "tool" }` is added automatically.
    Tool {
        /// The contents of the tool message.
        content: String,
        /// Tool call that this message is responding to.
        tool_call_id: String,
    },
    /// In this case, the role of the message author is `function`.
    /// The field `{ role = "function" }` is added automatically.
    Function {
        /// The contents of the function message.
        content: String,
        /// The name of the function to call.
        name: String,
    },
    /// In this case, the role of the message author is `developer`.
    /// The field `{ role = "developer" }` is added automatically.
    Developer {
        /// The contents of the developer message.
        content: String,
        /// An optional name for the participant.
        ///
        /// Provides the model information to differentiate between
        /// participants of the same role.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}

#[derive(Debug, Serialize, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantToolCall {
    Function {
        /// The ID of the tool call.
        id: String,
        /// The function that the model called.
        function: ToolCallFunction,
    },
    Custom {
        /// The ID of the tool call.
        id: String,
        /// The custom tool that the model called.
        custom: ToolCallCustom,
    },
}

#[derive(Debug, Serialize, Clone)]
pub struct ToolCallFunction {
    /// The arguments to call the function with, as generated by the model in JSON
    /// format. Note that the model does not always generate valid JSON, and may
    /// hallucinate parameters not defined by your function schema. Validate the
    /// arguments in your code before calling your function.
    pub arguments: String,
    /// The name of the function to call.
    pub name: String,
}

#[derive(Debug, Serialize, Clone)]
pub struct ToolCallCustom {
    /// The input for the custom tool call generated by the model.
    pub input: String,
    /// The name of the custom tool to call.
    pub name: String,
}

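/// An object specifying the format that the model must output. A hedged
/// sketch of requesting Structured Outputs with a minimal, assumed schema
/// (the schema contents are illustrative, not prescribed by this crate):
///
/// ```rust
/// use openai_interface::chat::create::request::{JSONSchema, ResponseFormat};
///
/// let mut schema = serde_json::Map::new();
/// schema.insert("type".to_string(), serde_json::json!("object"));
///
/// let format = ResponseFormat::JsonSchema {
///     json_schema: JSONSchema {
///         name: "my_format".to_string(),
///         description: "An object with no constrained fields.".to_string(),
///         schema,
///         strict: Some(true),
///     },
/// };
/// let json = serde_json::to_value(&format).unwrap();
/// assert_eq!(json["type"], "json_schema");
/// assert_eq!(json["json_schema"]["name"], "my_format");
/// ```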
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// The type of response format being defined. Always `json_schema`.
    JsonSchema {
        /// Structured Outputs configuration options, including a JSON Schema.
        json_schema: JSONSchema,
    },
    /// The type of response format being defined. Always `json_object`.
    JsonObject,
    /// The type of response format being defined. Always `text`.
    Text,
}

#[derive(Debug, Serialize, Clone)]
pub struct JSONSchema {
    /// The name of the response format. Must be a-z, A-Z, 0-9, or contain
    /// underscores and dashes, with a maximum length of 64.
    pub name: String,
    /// A description of what the response format is for, used by the model to determine
    /// how to respond in the format.
    pub description: String,
    /// The schema for the response format, described as a JSON Schema object. Learn how
    /// to build JSON schemas [here](https://json-schema.org/).
    pub schema: serde_json::Map<String, serde_json::Value>,
    /// Whether to enable strict schema adherence when generating the output. If set to
    /// true, the model will always follow the exact schema defined in the `schema`
    /// field. Only a subset of JSON Schema is supported when `strict` is `true`. To
    /// learn more, read the
    /// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

#[derive(Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Modality {
    Text,
    Audio,
}

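/// A hedged sketch of supplying predicted output content (the predicted text
/// is illustrative):
///
/// ```rust
/// use openai_interface::chat::create::request::{
///     ChatCompletionPredictionContentParam, ChatCompletionPredictionContentParamContent,
///     ChatCompletionPredictionContentParamType,
/// };
///
/// let prediction = ChatCompletionPredictionContentParam {
///     content: ChatCompletionPredictionContentParamContent::Text(
///         "fn main() {}".to_string(),
///     ),
///     type_: ChatCompletionPredictionContentParamType::Content,
/// };
/// let json = serde_json::to_value(&prediction).unwrap();
/// assert_eq!(json["type"], "content");
/// ```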
#[derive(Serialize, Debug, Clone)]
pub struct ChatCompletionPredictionContentParam {
    /// The content that should be matched when generating a model response. If
    /// generated tokens would match this content, the entire model response can be
    /// returned much more quickly.
    pub content: ChatCompletionPredictionContentParamContent,

    /// The type of the predicted content you want to provide.
    /// This type is currently always `content`.
    #[serde(rename = "type")]
    pub type_: ChatCompletionPredictionContentParamType,
}

#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum ChatCompletionPredictionContentParamContent {
    Text(String),
    ChatCompletionContentPartTextParam {
        /// The text content.
        text: String,
        /// The type of the content part.
        #[serde(rename = "type")]
        type_: ChatCompletionContentPartTextParamType,
    },
}

#[derive(Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ChatCompletionContentPartTextParamType {
    Text,
}

#[derive(Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ChatCompletionPredictionContentParamType {
    Content,
}

/// Serde helper: lets `skip_serializing_if` omit a `bool` field when it is `false`.
fn is_false(value: &bool) -> bool {
    !value
}

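/// Stop sequences may be given as a single keyword or a list of keywords.
/// An illustrative sketch of the untagged serialization:
///
/// ```rust
/// use openai_interface::chat::create::request::StopKeywords;
///
/// let one = StopKeywords::Word("END".to_string());
/// assert_eq!(serde_json::to_string(&one).unwrap(), r#""END""#);
///
/// let many = StopKeywords::Words(vec!["END".to_string(), "STOP".to_string()]);
/// assert_eq!(serde_json::to_string(&many).unwrap(), r#"["END","STOP"]"#);
/// ```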
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum StopKeywords {
    Word(String),
    Words(Vec<String>),
}

#[derive(Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum LowMediumHighEnum {
    Low,
    Medium,
    High,
}

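/// A hedged sketch of enabling web search with an approximate user location
/// (location values are illustrative):
///
/// ```rust
/// use openai_interface::chat::create::request::{
///     LowMediumHighEnum, WebSearchOptions, WebSearchOptionsUserLocation,
///     WebSearchOptionsUserLocationApproximate,
/// };
///
/// let options = WebSearchOptions {
///     search_context_size: LowMediumHighEnum::Medium,
///     user_location: Some(WebSearchOptionsUserLocation::Approximate {
///         approximate: WebSearchOptionsUserLocationApproximate {
///             city: "San Francisco".to_string(),
///             country: "US".to_string(),
///             region: "California".to_string(),
///             timezone: "America/Los_Angeles".to_string(),
///         },
///     }),
/// };
/// let json = serde_json::to_value(&options).unwrap();
/// assert_eq!(json["user_location"]["type"], "approximate");
/// ```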
#[derive(Serialize, Debug, Clone)]
pub struct WebSearchOptions {
    /// High level guidance for the amount of context window space to use for the
    /// search. One of `low`, `medium`, or `high`. `medium` is the default.
    pub search_context_size: LowMediumHighEnum,

    /// Approximate location parameters for the search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_location: Option<WebSearchOptionsUserLocation>,
}

#[derive(Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WebSearchOptionsUserLocation {
    /// The type of location approximation. Always `approximate`.
    Approximate {
        /// The approximate location details.
        approximate: WebSearchOptionsUserLocationApproximate,
    },
}

#[derive(Serialize, Debug, Clone)]
pub struct WebSearchOptionsUserLocationApproximate {
    /// Free text input for the city of the user, e.g. `San Francisco`.
    pub city: String,

    /// The two-letter [ISO country code](https://en.wikipedia.org/wiki/ISO_3166-1) of
    /// the user, e.g. `US`.
    pub country: String,

    /// Free text input for the region of the user, e.g. `California`.
    pub region: String,

    /// The [IANA timezone](https://timeapi.io/documentation/iana-timezones) of the
    /// user, e.g. `America/Los_Angeles`.
    pub timezone: String,
}

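/// An illustrative sketch: requesting a final usage chunk on a streaming call.
///
/// ```rust
/// use openai_interface::chat::create::request::{RequestBody, StreamOptions};
///
/// let request = RequestBody {
///     stream: true,
///     stream_options: Some(StreamOptions { include_usage: true }),
///     ..Default::default()
/// };
/// let json = serde_json::to_value(&request).unwrap();
/// assert_eq!(json["stream_options"]["include_usage"], true);
/// ```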
#[derive(Serialize, Debug, Clone)]
pub struct StreamOptions {
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    ///
    /// The `usage` field on this chunk shows the token usage statistics for the entire
    /// request, and the `choices` field will always be an empty array.
    ///
    /// All other chunks will also include a `usage` field, but with a null value.
    /// **NOTE:** If the stream is interrupted, you may not receive the final usage
    /// chunk which contains the total token usage for the request.
    pub include_usage: bool,
}

#[derive(Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RequestTool {
    /// The type of the tool. Currently, only `function` is supported.
    Function { function: ToolFunction },
    /// The type of the custom tool. Always `custom`.
    Custom {
        /// Properties of the custom tool.
        custom: ToolCustom,
    },
}

#[derive(Serialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Minimal,
    Low,
    Medium,
    High,
}

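/// A hedged sketch of defining a function tool (the `get_weather` name and
/// schema are illustrative):
///
/// ```rust
/// use openai_interface::chat::create::request::{RequestTool, ToolFunction};
///
/// let mut parameters = serde_json::Map::new();
/// parameters.insert("type".to_string(), serde_json::json!("object"));
/// parameters.insert(
///     "properties".to_string(),
///     serde_json::json!({ "city": { "type": "string" } }),
/// );
///
/// let tool = RequestTool::Function {
///     function: ToolFunction {
///         name: "get_weather".to_string(),
///         description: "Look up the current weather for a city.".to_string(),
///         parameters,
///         strict: Some(true),
///     },
/// };
/// let json = serde_json::to_value(&tool).unwrap();
/// assert_eq!(json["type"], "function");
/// assert_eq!(json["function"]["name"], "get_weather");
/// ```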
#[derive(Serialize, Debug, Clone)]
pub struct ToolFunction {
    /// The name of the function to be called. Must be a-z, A-Z, 0-9, or
    /// contain underscores and dashes, with a maximum length
    /// of 64.
    pub name: String,
    /// A description of what the function does, used by the model to choose when and
    /// how to call the function.
    pub description: String,
    /// The parameters the function accepts, described as a JSON Schema object.
    ///
    /// See the
    /// [openai function calling guide](https://platform.openai.com/docs/guides/function-calling)
    /// for examples, and the
    /// [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for
    /// documentation about the format.
    ///
    /// Omitting `parameters` defines a function with an empty parameter list.
    pub parameters: serde_json::Map<String, serde_json::Value>,
    /// Whether to enable strict schema adherence when generating the function call.
    ///
    /// If set to true, the model will follow the exact schema defined in the
    /// `parameters` field. Only a subset of JSON Schema is supported when `strict` is
    /// `true`. Learn more about Structured Outputs in the
    /// [openai function calling guide](https://platform.openai.com/docs/guides/function-calling).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

#[derive(Serialize, Debug, Clone)]
pub struct ToolCustom {
    /// The name of the custom tool, used to identify it in tool calls.
    pub name: String,
    /// Optional description of the custom tool, used to provide more context.
    pub description: String,
    /// The input format for the custom tool. Defaults to unconstrained text
    /// when omitted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ToolCustomFormat>,
}

#[derive(Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ToolCustomFormat {
    /// Unconstrained text format. Always `text`.
    Text,
    /// Grammar format. Always `grammar`.
    Grammar {
        /// Your chosen grammar.
        grammar: ToolCustomFormatGrammarGrammar,
    },
}

#[derive(Debug, Serialize, Clone)]
pub struct ToolCustomFormatGrammarGrammar {
    /// The grammar definition.
    pub definition: String,
    /// The syntax of the grammar definition. One of `lark` or `regex`.
    pub syntax: ToolCustomFormatGrammarGrammarSyntax,
}

#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ToolCustomFormatGrammarGrammarSyntax {
    Lark,
    Regex,
}

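/// A hedged sketch of the two serialized shapes of a tool choice (the
/// function name is illustrative):
///
/// ```rust
/// use openai_interface::chat::create::request::{
///     ToolChoice, ToolChoiceFunction, ToolChoiceSpecific,
/// };
///
/// let auto = ToolChoice::Auto;
/// assert_eq!(serde_json::to_string(&auto).unwrap(), r#""auto""#);
///
/// let specific = ToolChoice::Specific(ToolChoiceSpecific::Function {
///     function: ToolChoiceFunction { name: "my_function".to_string() },
/// });
/// let json = serde_json::to_value(&specific).unwrap();
/// assert_eq!(json["type"], "function");
/// assert_eq!(json["function"]["name"], "my_function");
/// ```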
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    None,
    Auto,
    Required,
    #[serde(untagged)]
    Specific(ToolChoiceSpecific),
}

#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ToolChoiceSpecific {
    /// Allowed tool configuration type. Always `allowed_tools`.
    AllowedTools {
        /// Constrains the tools available to the model to a pre-defined set.
        allowed_tools: ToolChoiceAllowedTools,
    },
    /// For function calling, the type is always `function`.
    Function { function: ToolChoiceFunction },
    /// For custom tool calling, the type is always `custom`.
    Custom { custom: ToolChoiceCustom },
}

#[derive(Debug, Serialize, Clone)]
pub struct ToolChoiceAllowedTools {
    /// Constrains the tools available to the model to a pre-defined set.
    ///
    /// - `auto` allows the model to pick from among the allowed tools and generate a
    ///   message.
    /// - `required` requires the model to call one or more of the allowed tools.
    pub mode: ToolChoiceAllowedToolsMode,
    /// A list of tool definitions that the model should be allowed to call.
    ///
    /// For the Chat Completions API, the list of tool definitions might look like:
    ///
    /// ```json
    /// [
    ///   { "type": "function", "function": { "name": "get_weather" } },
    ///   { "type": "function", "function": { "name": "get_time" } }
    /// ]
    /// ```
    pub tools: Vec<serde_json::Value>,
}

/// The mode for allowed tools in tool choice.
///
/// Controls how the model should handle the set of allowed tools:
///
/// - `auto` allows the model to pick from among the allowed tools and generate a
///   message.
/// - `required` requires the model to call one or more of the allowed tools.
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceAllowedToolsMode {
    /// The model can choose whether to use the allowed tools or not.
    Auto,
    /// The model must use at least one of the allowed tools.
    Required,
}

#[derive(Debug, Serialize, Clone)]
pub struct ToolChoiceFunction {
    /// The name of the function to call.
    pub name: String,
}

#[derive(Debug, Serialize, Clone)]
pub struct ToolChoiceCustom {
    /// The name of the custom tool to call.
    pub name: String,
}

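/// Provider-specific request fields that are flattened into the top level of
/// the request body. A hedged sketch (the `qwen-plus` model name is
/// illustrative):
///
/// ```rust
/// use openai_interface::chat::create::request::{ExtraBody, RequestBody};
///
/// let request = RequestBody {
///     model: "qwen-plus".to_string(),
///     extra_body: Some(ExtraBody {
///         enable_thinking: Some(true),
///         thinking_budget: None,
///         top_k: None,
///     }),
///     ..Default::default()
/// };
/// // `enable_thinking` lands at the top level thanks to `#[serde(flatten)]`.
/// let json = serde_json::to_value(&request).unwrap();
/// assert_eq!(json["enable_thinking"], true);
/// ```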
#[derive(Debug, Serialize, Clone)]
pub struct ExtraBody {
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_thinking: Option<bool>,
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
    /// The size of the candidate set for sampling during generation.
    ///
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
}

impl Post for RequestBody {
    fn is_streaming(&self) -> bool {
        self.stream
    }

    /// Builds the URL for the request.
    ///
    /// `base_url` should be like <https://api.openai.com/v1>
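    ///
    /// A hedged sketch of the resulting URL (assuming the `Post` trait is
    /// exported at `openai_interface::rest::post::Post`):
    ///
    /// ```rust
    /// use openai_interface::chat::create::request::RequestBody;
    /// use openai_interface::rest::post::Post;
    ///
    /// let body = RequestBody::default();
    /// let url = body.build_url("https://api.openai.com/v1").unwrap();
    /// assert_eq!(url, "https://api.openai.com/v1/chat/completions");
    /// ```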
    fn build_url(&self, base_url: &str) -> Result<String, OapiError> {
        let mut url = Url::parse(base_url.trim_end_matches('/')).map_err(OapiError::UrlError)?;
        url.path_segments_mut()
            .map_err(|_| OapiError::UrlCannotBeBase(base_url.to_string()))?
            .push("chat")
            .push("completions");

        Ok(url.to_string())
    }
}

impl PostNoStream for RequestBody {
    type Response = super::response::no_streaming::ChatCompletion;
}

impl PostStream for RequestBody {
    type Response = super::response::streaming::ChatCompletionChunk;
}

#[cfg(test)]
mod request_test {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use super::*;

    static DEEPSEEK_API_KEY: LazyLock<&str> =
        LazyLock::new(|| include_str!("../../../keys/deepseek_domestic_key").trim());
    const DEEPSEEK_CHAT_BASE: &str = "https://api.deepseek.com/v1";
    const DEEPSEEK_MODEL: &str = "deepseek-chat";

    #[tokio::test]
    async fn test_deepseek_no_stream() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request for test purposes. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: false,
            ..Default::default()
        };

        let response = request
            .get_response_string(DEEPSEEK_CHAT_BASE, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        println!("{}", response);

        assert!(response.to_ascii_lowercase().contains("deepseek"));
    }

    #[tokio::test]
    async fn test_deepseek_stream() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request for test purposes. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: true,
            ..Default::default()
        };

        let mut response = request
            .get_stream_response_string(DEEPSEEK_CHAT_BASE, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        while let Some(chunk) = response.next().await {
            println!("{}", chunk.unwrap());
        }
    }
}