Skip to main content

onde_mistralrs_core/
request.rs

1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_audio::AudioInput;
4use mistralrs_quant::IsqType;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::VideoInput;
9
10use crate::{
11    response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
12    DiffusionGenerationParams, Tool,
13};
14use std::{fmt::Debug, path::PathBuf, sync::Arc};
15use tokio::sync::mpsc::Sender;
16
17pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
18
19#[derive(Clone, Serialize, Deserialize)]
20/// Control the constraint with llguidance.
21pub enum Constraint {
22    Regex(String),
23    Lark(String),
24    JsonSchema(serde_json::Value),
25    Llguidance(LlguidanceGrammar),
26    None,
27}
28
29#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
30#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
31#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
32/// Image generation response format
33pub enum ImageGenerationResponseFormat {
34    Url,
35    B64Json,
36}
37
38pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
39
40/// Reasoning effort level for models that support it (e.g., GPT-OSS with Harmony format).
41/// Controls the depth of reasoning/analysis in the model's response.
42#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
43#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
44#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
45#[serde(rename_all = "lowercase")]
46pub enum ReasoningEffort {
47    /// Minimal reasoning, faster responses
48    Low,
49    /// Balanced reasoning depth
50    #[default]
51    Medium,
52    /// Deep reasoning, more thorough analysis
53    High,
54}
55
56impl ReasoningEffort {
57    /// Convert to string representation for chat template
58    pub fn as_str(&self) -> &'static str {
59        match self {
60            Self::Low => "low",
61            Self::Medium => "medium",
62            Self::High => "high",
63        }
64    }
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
68/// Message or messages for a [`Request`].
69pub enum RequestMessage {
70    Chat {
71        messages: Vec<IndexMap<String, MessageContent>>,
72        enable_thinking: Option<bool>,
73        /// Reasoning effort level for Harmony-format models
74        reasoning_effort: Option<ReasoningEffort>,
75    },
76    Completion {
77        text: String,
78        echo_prompt: bool,
79        best_of: Option<usize>,
80    },
81    CompletionTokens(Vec<u32>),
82    MultimodalChat {
83        #[serde(skip)] // TODO
84        images: Vec<image::DynamicImage>,
85        #[serde(skip)] // TODO
86        audios: Vec<AudioInput>,
87        #[serde(skip)]
88        videos: Vec<VideoInput>,
89        messages: Vec<IndexMap<String, MessageContent>>,
90        enable_thinking: Option<bool>,
91        /// Reasoning effort level for Harmony-format models
92        reasoning_effort: Option<ReasoningEffort>,
93    },
94    ImageGeneration {
95        prompt: String,
96        format: ImageGenerationResponseFormat,
97        generation_params: DiffusionGenerationParams,
98        save_file: Option<PathBuf>,
99    },
100    SpeechGeneration {
101        prompt: String,
102    },
103    Embedding {
104        prompt: String,
105    },
106    EmbeddingTokens {
107        prompt: Vec<u32>,
108    },
109}
110
111fn default_responder<T>() -> Sender<T> {
112    let (sender, _) = tokio::sync::mpsc::channel(1);
113    sender
114}
115
116#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
117#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
118#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
119pub enum SearchContextSize {
120    #[serde(rename = "low")]
121    Low,
122    #[default]
123    #[serde(rename = "medium")]
124    Medium,
125    #[serde(rename = "high")]
126    High,
127}
128
129#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
130#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
131#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
132pub struct ApproximateUserLocation {
133    pub city: String,
134    pub country: String,
135    pub region: String,
136    pub timezone: String,
137}
138
139#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
140#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
141#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
142#[serde(tag = "type")]
143pub enum WebSearchUserLocation {
144    #[serde(rename = "approximate")]
145    Approximate {
146        approximate: ApproximateUserLocation,
147    },
148}
149
150#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
151#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
152#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
153pub struct WebSearchOptions {
154    pub search_context_size: Option<SearchContextSize>,
155    pub user_location: Option<WebSearchUserLocation>,
156    /// Override the description for the search tool.
157    pub search_description: Option<String>,
158    /// Override the description for the extraction tool.
159    pub extract_description: Option<String>,
160}
161
162#[derive(Clone, Serialize, Deserialize)]
163/// A normal request request to the `MistralRs`.
164/// - `messages`: Messages for the request
165/// - `sampling_params`: Sampling parameters for generation
166/// - `response`: Object to send the result through
167/// - `return_logprobs`: Whether to return logprobs
168/// - `is_streaming`: Control whether the request is streaming, if so chunk responses will be sent
169/// - `id`: Request ID
170/// - `constraint`: Constraint to use during generation
171/// - `suffix`: Suffix to add
172/// - `tools`: Tools available in this request
173/// - `tool_choice`: Choice of tools
174/// - `logits_processors`: Custom logits processors. Order of application:
175///     1) Apply penalties from `sampling_params`
176///     2) Apply these custom logits processors sequentially
177///     3) Apply temperature and softmax
178///     4) Sample the next token (topk, topp, minp, etc)
179/// - `return_raw_logits`: Return raw logits.
180/// - `truncate_sequence`: Whether to truncate the prompt if it exceeds the model's maximum context length.
181pub struct NormalRequest {
182    pub messages: RequestMessage,
183    pub sampling_params: SamplingParams,
184    #[serde(default = "default_responder")]
185    #[serde(skip)]
186    pub response: Sender<Response>,
187    pub return_logprobs: bool,
188    pub is_streaming: bool,
189    pub id: usize,
190    pub constraint: Constraint,
191    pub suffix: Option<String>,
192    pub tools: Option<Vec<Tool>>,
193    pub tool_choice: Option<ToolChoice>,
194    #[serde(skip)]
195    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
196    pub return_raw_logits: bool,
197    pub web_search_options: Option<WebSearchOptions>,
198    pub max_tool_rounds: Option<usize>,
199    /// URL to POST tool calls to when no server-side callback is registered.
200    /// The server sends `{"name": "...", "arguments": {...}}` and expects
201    /// `{"content": "..."}` back.
202    pub tool_dispatch_url: Option<String>,
203    pub model_id: Option<String>,
204    #[serde(default)]
205    pub truncate_sequence: bool,
206}
207
208impl NormalRequest {
209    pub fn new_simple(
210        messages: RequestMessage,
211        sampling_params: SamplingParams,
212        response: Sender<Response>,
213        id: usize,
214        tools: Option<Vec<Tool>>,
215        tool_choice: Option<ToolChoice>,
216    ) -> Self {
217        Self {
218            messages,
219            sampling_params,
220            response,
221            id,
222            tools,
223            tool_choice,
224            return_logprobs: false,
225            is_streaming: false,
226            constraint: Constraint::None,
227            suffix: None,
228            logits_processors: None,
229            return_raw_logits: false,
230            web_search_options: None,
231            max_tool_rounds: None,
232            tool_dispatch_url: None,
233            model_id: None,
234            truncate_sequence: false,
235        }
236    }
237}
238
239#[derive(Clone, Serialize, Deserialize)]
240/// Request to tokenize some messages or some text.
241/// - `add_generation_prompt` is only applicable if chat messages are provided and not a raw string.
242pub struct TokenizationRequest {
243    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
244    pub tools: Option<Vec<Tool>>,
245    pub add_generation_prompt: bool,
246    pub add_special_tokens: bool,
247    pub enable_thinking: Option<bool>,
248    pub reasoning_effort: Option<ReasoningEffort>,
249    #[serde(default = "default_responder")]
250    #[serde(skip)]
251    pub response: Sender<anyhow::Result<Vec<u32>>>,
252}
253
254#[derive(Clone, Serialize, Deserialize)]
255/// Request to detokenize some text.
256pub struct DetokenizationRequest {
257    pub tokens: Vec<u32>,
258    pub skip_special_tokens: bool,
259    #[serde(default = "default_responder")]
260    #[serde(skip)]
261    pub response: Sender<anyhow::Result<String>>,
262}
263
264#[derive(Clone, Serialize, Deserialize)]
265/// A request to the Engine, encapsulating the various parameters as well as
266/// the `mpsc` response `Sender` used to return the [`Response`].
267pub enum Request {
268    Normal(Box<NormalRequest>),
269    ReIsq(IsqType),
270    Tokenize(TokenizationRequest),
271    Detokenize(DetokenizationRequest),
272    // Sending a terminate request causes the `run` function to return to the thread created in `MistralRs::new`,
273    // and then Engine will be dropped.
274    Terminate,
275    TerminateAllSeqsNextStep,
276}
277
278impl Debug for Request {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        match self {
281            Request::Normal(boxed_req) => {
282                let NormalRequest {
283                    messages,
284                    sampling_params,
285                    is_streaming,
286                    id,
287                    ..
288                } = &**boxed_req;
289                write!(
290                    f,
291                    "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
292                )
293            }
294            Request::ReIsq(tp) => {
295                write!(f, "Re ISQ Request {tp:?}",)
296            }
297            Request::Tokenize(req) => {
298                write!(f, "Tokenization Request {:?}", req.text)
299            }
300            Request::Detokenize(req) => {
301                write!(f, "Tokenization Request {:?}", req.tokens)
302            }
303            Request::Terminate => write!(f, "Termination Request"),
304            Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
305        }
306    }
307}