1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_audio::AudioInput;
4use mistralrs_quant::IsqType;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::VideoInput;
9
10use crate::{
11 response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
12 DiffusionGenerationParams, Tool,
13};
14use std::{fmt::Debug, path::PathBuf, sync::Arc};
15use tokio::sync::mpsc::Sender;
16
/// Shorthand for the top-level grammar type of the `llguidance` crate
/// (`llguidance::api::TopLevelGrammar`).
pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
18
/// Output constraint attached to a request (see `NormalRequest::constraint`).
///
/// Each variant only stores its payload; how the payload is interpreted is
/// decided by the consumer of this enum, which is not part of this file.
#[derive(Clone, Serialize, Deserialize)]
pub enum Constraint {
    /// Constraint supplied as a string; presumably a regular expression.
    Regex(String),
    /// Constraint supplied as a string; presumably a Lark grammar.
    Lark(String),
    /// Constraint supplied as an arbitrary JSON value; presumably a JSON schema.
    JsonSchema(serde_json::Value),
    /// Constraint supplied as a ready-made [`LlguidanceGrammar`].
    Llguidance(LlguidanceGrammar),
    /// No constraint. This is what `NormalRequest::new_simple` uses.
    None,
}
28
/// Response format selector carried by `RequestMessage::ImageGeneration`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum ImageGenerationResponseFormat {
    /// Presumably: the generated image is handed back as a URL — confirm
    /// against the code that produces the response.
    Url,
    /// Presumably: the generated image is handed back as base64 data inside
    /// the JSON response — confirm against the code that produces the response.
    B64Json,
}
37
/// Value of one message field: either a plain string (`Left`) or a list of
/// insertion-ordered, string-keyed JSON maps (`Right`).
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
39
/// Three-level reasoning-effort setting that can accompany chat and
/// tokenization requests.
///
/// Serde (de)serializes the variants under their lowercase names
/// (`"low"`, `"medium"`, `"high"`); `Medium` is the `Default`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    /// Lowest of the three levels.
    Low,
    /// Middle level; returned by `ReasoningEffort::default()`.
    #[default]
    Medium,
    /// Highest of the three levels.
    High,
}
55
56impl ReasoningEffort {
57 pub fn as_str(&self) -> &'static str {
59 match self {
60 Self::Low => "low",
61 Self::Medium => "medium",
62 Self::High => "high",
63 }
64 }
65}
66
/// Payload of a `NormalRequest`: the kind of work requested together with its
/// input data.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RequestMessage {
    /// Chat-style request made of a list of messages.
    Chat {
        /// The messages; each one is an ordered map from field name to content.
        messages: Vec<IndexMap<String, MessageContent>>,
        /// Optional thinking toggle; `None` leaves the choice to the consumer.
        enable_thinking: Option<bool>,
        /// Optional reasoning-effort level.
        reasoning_effort: Option<ReasoningEffort>,
    },
    /// Completion request on a piece of text.
    Completion {
        /// The prompt text.
        text: String,
        /// Presumably: whether the prompt is echoed back in the output — confirm
        /// against the consumer.
        echo_prompt: bool,
        /// Optional `best_of` setting; semantics are defined by the consumer.
        best_of: Option<usize>,
    },
    /// Completion request whose prompt is given as `u32` token ids.
    CompletionTokens(Vec<u32>),
    /// Chat-style request that additionally carries image, audio and video inputs.
    MultimodalChat {
        /// Image inputs. Skipped by serde, so never serialized and empty after
        /// deserialization.
        #[serde(skip)]
        images: Vec<image::DynamicImage>,
        /// Audio inputs. Skipped by serde in the same way as `images`.
        #[serde(skip)]
        audios: Vec<AudioInput>,
        /// Video inputs. Skipped by serde in the same way as `images`.
        #[serde(skip)]
        videos: Vec<VideoInput>,
        /// The messages; each one is an ordered map from field name to content.
        messages: Vec<IndexMap<String, MessageContent>>,
        /// Optional thinking toggle; `None` leaves the choice to the consumer.
        enable_thinking: Option<bool>,
        /// Optional reasoning-effort level.
        reasoning_effort: Option<ReasoningEffort>,
    },
    /// Image-generation request.
    ImageGeneration {
        /// Text prompt for the generation.
        prompt: String,
        /// Requested response format.
        format: ImageGenerationResponseFormat,
        /// Diffusion generation parameters.
        generation_params: DiffusionGenerationParams,
        /// Optional file path; presumably where the result is saved — confirm
        /// against the consumer.
        save_file: Option<PathBuf>,
    },
    /// Speech-generation request.
    SpeechGeneration {
        /// Text prompt for the generation.
        prompt: String,
    },
    /// Embedding request on a text prompt.
    Embedding {
        /// The text to embed.
        prompt: String,
    },
    /// Embedding request on a prompt given as `u32` token ids.
    EmbeddingTokens {
        /// The token ids to embed.
        prompt: Vec<u32>,
    },
}
110
111fn default_responder<T>() -> Sender<T> {
112 let (sender, _) = tokio::sync::mpsc::channel(1);
113 sender
114}
115
116#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
117#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
118#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
119pub enum SearchContextSize {
120 #[serde(rename = "low")]
121 Low,
122 #[default]
123 #[serde(rename = "medium")]
124 Medium,
125 #[serde(rename = "high")]
126 High,
127}
128
/// Coarse description of a user's location, used by
/// `WebSearchUserLocation::Approximate`. All fields are free-form strings;
/// no format is enforced by this type.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    /// City name.
    pub city: String,
    /// Country; the expected notation (name vs. code) is not fixed here.
    pub country: String,
    /// Region within the country.
    pub region: String,
    /// Timezone; the expected notation is not fixed here.
    pub timezone: String,
}
138
/// User location attached to `WebSearchOptions`.
///
/// Serialized as an internally tagged enum: the variant is identified by a
/// `"type"` field.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    /// Approximate location; tagged as `"type": "approximate"`.
    #[serde(rename = "approximate")]
    Approximate {
        /// The location details, nested under the `approximate` key.
        approximate: ApproximateUserLocation,
    },
}
149
/// Optional web-search settings for a request (see
/// `NormalRequest::web_search_options`). Every field is optional and the
/// derived `Default` leaves all of them as `None`.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub struct WebSearchOptions {
    /// Requested search context size, if any.
    pub search_context_size: Option<SearchContextSize>,
    /// User location to associate with the search, if any.
    pub user_location: Option<WebSearchUserLocation>,
    /// Presumably: custom description text for the search tool — confirm
    /// against the consumer.
    pub search_description: Option<String>,
    /// Presumably: custom description text for the extraction tool — confirm
    /// against the consumer.
    pub extract_description: Option<String>,
}
161
/// A regular request: the payload plus sampling settings, the response
/// channel and a set of optional features.
///
/// `NormalRequest::new_simple` builds one with most optional features off.
#[derive(Clone, Serialize, Deserialize)]
pub struct NormalRequest {
    /// The request payload.
    pub messages: RequestMessage,
    /// Sampling parameters for this request.
    pub sampling_params: SamplingParams,
    /// Channel on which `Response` values for this request are sent.
    /// Skipped by serde; on deserialization it is filled by
    /// `default_responder`, whose receiving half is already dropped.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    /// Whether log-probabilities are requested.
    pub return_logprobs: bool,
    /// Whether the request is a streaming one.
    pub is_streaming: bool,
    /// Identifier of this request; printed by the `Debug` impl of `Request`.
    pub id: usize,
    /// Output constraint; `Constraint::None` for no constraint.
    pub constraint: Constraint,
    /// Optional suffix; its use is defined by the consumer.
    pub suffix: Option<String>,
    /// Optional list of tools made available to the request.
    pub tools: Option<Vec<Tool>>,
    /// Optional tool-choice setting.
    pub tool_choice: Option<ToolChoice>,
    /// Optional custom logits processors. Skipped by serde, so `None` after
    /// deserialization.
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    /// Whether raw logits are requested.
    pub return_raw_logits: bool,
    /// Optional web-search settings.
    pub web_search_options: Option<WebSearchOptions>,
    /// Optional limit on tool rounds; semantics are defined by the consumer.
    pub max_tool_rounds: Option<usize>,
    /// Optional URL; presumably where tool calls are dispatched — confirm
    /// against the consumer.
    pub tool_dispatch_url: Option<String>,
    /// Optional model identifier; presumably selects the target model —
    /// confirm against the consumer.
    pub model_id: Option<String>,
    /// Sequence-truncation flag. Falls back to `false` when the field is
    /// absent during deserialization.
    #[serde(default)]
    pub truncate_sequence: bool,
}
207
208impl NormalRequest {
209 pub fn new_simple(
210 messages: RequestMessage,
211 sampling_params: SamplingParams,
212 response: Sender<Response>,
213 id: usize,
214 tools: Option<Vec<Tool>>,
215 tool_choice: Option<ToolChoice>,
216 ) -> Self {
217 Self {
218 messages,
219 sampling_params,
220 response,
221 id,
222 tools,
223 tool_choice,
224 return_logprobs: false,
225 is_streaming: false,
226 constraint: Constraint::None,
227 suffix: None,
228 logits_processors: None,
229 return_raw_logits: false,
230 web_search_options: None,
231 max_tool_rounds: None,
232 tool_dispatch_url: None,
233 model_id: None,
234 truncate_sequence: false,
235 }
236 }
237}
238
/// Request to turn chat messages or raw text into token ids.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenizationRequest {
    /// Input to tokenize: a list of chat messages (`Left`) or a plain string
    /// (`Right`).
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    /// Optional list of tools accompanying the input.
    pub tools: Option<Vec<Tool>>,
    /// Presumably: whether a generation prompt is appended when chat messages
    /// are rendered — confirm against the consumer.
    pub add_generation_prompt: bool,
    /// Presumably: whether special tokens are added during tokenization —
    /// confirm against the consumer.
    pub add_special_tokens: bool,
    /// Optional thinking toggle.
    pub enable_thinking: Option<bool>,
    /// Optional reasoning-effort level.
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Channel on which the result (token ids or an error) is sent.
    /// Skipped by serde; on deserialization it is filled by
    /// `default_responder`, whose receiving half is already dropped.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}
253
/// Request to turn token ids back into text.
#[derive(Clone, Serialize, Deserialize)]
pub struct DetokenizationRequest {
    /// The token ids to convert.
    pub tokens: Vec<u32>,
    /// Presumably: whether special tokens are left out of the produced text —
    /// confirm against the consumer.
    pub skip_special_tokens: bool,
    /// Channel on which the result (the text or an error) is sent.
    /// Skipped by serde; on deserialization it is filled by
    /// `default_responder`, whose receiving half is already dropped.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}
263
/// Top-level request type. `Debug` is implemented by hand and prints a short
/// summary per variant.
#[derive(Clone, Serialize, Deserialize)]
pub enum Request {
    /// A regular request; boxed, which keeps this enum small.
    Normal(Box<NormalRequest>),
    /// Re-ISQ request carrying the target `IsqType`.
    ReIsq(IsqType),
    /// Tokenization request.
    Tokenize(TokenizationRequest),
    /// Detokenization request.
    Detokenize(DetokenizationRequest),
    /// Termination request.
    Terminate,
    /// Presumably: terminate all running sequences at the next step —
    /// confirm against the consumer.
    TerminateAllSeqsNextStep,
}
277
278impl Debug for Request {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 match self {
281 Request::Normal(boxed_req) => {
282 let NormalRequest {
283 messages,
284 sampling_params,
285 is_streaming,
286 id,
287 ..
288 } = &**boxed_req;
289 write!(
290 f,
291 "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
292 )
293 }
294 Request::ReIsq(tp) => {
295 write!(f, "Re ISQ Request {tp:?}",)
296 }
297 Request::Tokenize(req) => {
298 write!(f, "Tokenization Request {:?}", req.text)
299 }
300 Request::Detokenize(req) => {
301 write!(f, "Tokenization Request {:?}", req.tokens)
302 }
303 Request::Terminate => write!(f, "Termination Request"),
304 Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
305 }
306 }
307}