1use either::Either;
2use hanzo_audio::AudioInput;
3use hanzo_quant::IsqType;
4use indexmap::IndexMap;
5#[cfg(feature = "pyo3_macros")]
6use pyo3::{pyclass, pymethods};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::VideoInput;
11
12use crate::{
13 response::Response, sampler::SamplingParams, tools::ToolChoice, AgentPermission,
14 AgentToolApprovalHandler, CodeExecutionPermission, CustomLogitsProcessor,
15 DiffusionGenerationParams, Tool,
16};
17use std::{fmt::Debug, path::PathBuf, sync::Arc};
18use tokio::sync::mpsc::Sender;
19
/// Alias for llguidance's top-level grammar type; it is the payload of
/// [`Constraint::Llguidance`].
pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
21
/// Output constraint attached to a request (stored in `NormalRequest::constraint`).
///
/// Each variant names a constraint format and carries its specification. This
/// definition only stores the specification; how it is enforced is up to the
/// code that consumes the request.
#[derive(Clone, Serialize, Deserialize)]
pub enum Constraint {
    /// Specification given as a string; by the variant name, a regular expression.
    Regex(String),
    /// Specification given as a string; by the variant name, a Lark grammar.
    Lark(String),
    /// Specification given as a JSON value; by the variant name, a JSON Schema.
    JsonSchema(serde_json::Value),
    /// Specification given as an llguidance top-level grammar.
    Llguidance(LlguidanceGrammar),
    /// No constraint. This is what `NormalRequest::new_simple` uses.
    None,
}
31
/// Response format selector for image generation
/// (the `format` field of `RequestMessage::ImageGeneration`).
// NOTE(review): the variant names mirror the usual "url" / "b64_json" options of
// image-generation APIs; no serde rename is applied here, so they serialize as
// `Url` / `B64Json` — confirm that this is what callers expect.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum ImageGenerationResponseFormat {
    /// The image is referenced by URL.
    Url,
    /// The image is returned as base64 data.
    B64Json,
}
40
/// Value of one message field: either a plain `String` (left) or a list of
/// string-keyed maps of JSON values (right).
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
42
/// Reasoning effort level carried by chat and tokenization requests.
///
/// Serialized in lowercase (`"low"`, `"medium"`, `"high"`). The default is
/// [`ReasoningEffort::Medium`].
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    /// Low effort; serialized as `"low"`.
    Low,
    /// Medium effort (the default); serialized as `"medium"`.
    #[default]
    Medium,
    /// High effort; serialized as `"high"`.
    High,
}
58
59impl ReasoningEffort {
60 pub fn as_str(&self) -> &'static str {
62 match self {
63 Self::Low => "low",
64 Self::Medium => "medium",
65 Self::High => "high",
66 }
67 }
68}
69
/// Payload of a request: the input the request asks the engine to process.
///
/// The media fields of `MultimodalChat` are marked `#[serde(skip)]`: they are
/// not serialized, and deserialization fills them with empty vectors.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RequestMessage {
    /// Chat-style input: a list of messages, each a map from string key to
    /// [`MessageContent`].
    Chat {
        messages: Vec<IndexMap<String, MessageContent>>,
        // NOTE(review): presumably forwarded to the chat template — confirm.
        enable_thinking: Option<bool>,
        reasoning_effort: Option<ReasoningEffort>,
    },
    /// Plain-text completion input.
    Completion {
        text: String,
        // NOTE(review): by its name, controls whether the prompt is echoed back
        // in the output — confirm against the completion handler.
        echo_prompt: bool,
        best_of: Option<usize>,
    },
    /// Completion input given as `u32` token ids instead of text.
    CompletionTokens(Vec<u32>),
    /// Chat-style input accompanied by images, audio and video.
    MultimodalChat {
        // Skipped by serde: empty after deserialization.
        #[serde(skip)] images: Vec<image::DynamicImage>,
        // Skipped by serde: empty after deserialization.
        #[serde(skip)] audios: Vec<AudioInput>,
        // Skipped by serde: empty after deserialization.
        #[serde(skip)]
        videos: Vec<VideoInput>,
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
        reasoning_effort: Option<ReasoningEffort>,
    },
    /// Image generation from a text prompt.
    ImageGeneration {
        prompt: String,
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
        // NOTE(review): presumably where the generated image is written when
        // set — confirm against the diffusion pipeline.
        save_file: Option<PathBuf>,
    },
    /// Speech generation from a text prompt.
    SpeechGeneration {
        prompt: String,
    },
    /// Embedding of a text prompt.
    Embedding {
        prompt: String,
    },
    /// Embedding of a prompt given as `u32` token ids.
    EmbeddingTokens {
        prompt: Vec<u32>,
    },
}
113
114fn default_responder<T>() -> Sender<T> {
115 let (sender, _) = tokio::sync::mpsc::channel(1);
116 sender
117}
118
119#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
120#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
121#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
122pub enum SearchContextSize {
123 #[serde(rename = "low")]
124 Low,
125 #[default]
126 #[serde(rename = "medium")]
127 Medium,
128 #[serde(rename = "high")]
129 High,
130}
131
/// Approximate location of the user, carried by
/// `WebSearchUserLocation::Approximate`.
///
/// All fields are free-form strings; this type performs no validation.
// With the `pyo3_macros` feature this is exposed as a Python class with
// equality support and a getter for every field (`get_all`).
#[cfg_attr(feature = "pyo3_macros", pyclass(eq))]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    /// City of the user.
    pub city: String,
    /// Country of the user.
    pub country: String,
    /// Region of the user.
    pub region: String,
    /// Timezone of the user.
    // NOTE(review): the expected format (e.g. an IANA zone name) is not
    // established here — confirm with the consumer.
    pub timezone: String,
}
142
#[cfg(feature = "pyo3_macros")]
#[pymethods]
impl ApproximateUserLocation {
    // Python constructor: every field is required and stored unchanged.
    #[new]
    fn py_new(city: String, country: String, region: String, timezone: String) -> Self {
        ApproximateUserLocation {
            timezone,
            region,
            country,
            city,
        }
    }
}
156
/// User location attached to web search options.
///
/// Serialized as an internally tagged enum: the `type` key holds the variant
/// name (`"approximate"`) and the variant's fields sit alongside it.
#[cfg_attr(feature = "pyo3_macros", pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    /// An approximate location; tagged `"approximate"`.
    #[serde(rename = "approximate")]
    Approximate {
        /// The location details, stored under the `approximate` key.
        approximate: ApproximateUserLocation,
    },
}
167
#[cfg(feature = "pyo3_macros")]
#[pymethods]
impl WebSearchUserLocation {
    // Python-side factory for the `Approximate` variant; wraps the given
    // location without modifying it.
    #[staticmethod]
    fn approximate(approximate: ApproximateUserLocation) -> Self {
        WebSearchUserLocation::Approximate { approximate }
    }
}
176
/// Options for web search attached to a request
/// (`NormalRequest::web_search_options`).
///
/// Every field is optional and `Default` leaves all of them as `None`.
// With the `pyo3_macros` feature this is exposed as a Python class with
// equality support and a getter for every field (`get_all`).
#[cfg_attr(feature = "pyo3_macros", pyclass(eq))]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub struct WebSearchOptions {
    /// Requested search context size, if any.
    pub search_context_size: Option<SearchContextSize>,
    /// Location of the user, if provided.
    pub user_location: Option<WebSearchUserLocation>,
    /// Optional description text for the search step.
    // NOTE(review): presumably overrides the description of the search tool
    // shown to the model — confirm against the search tool setup.
    pub search_description: Option<String>,
    /// Optional description text for the extract step.
    // NOTE(review): presumably overrides the description of the extraction
    // tool shown to the model — confirm against the search tool setup.
    pub extract_description: Option<String>,
}
189
#[cfg(feature = "pyo3_macros")]
#[pymethods]
impl WebSearchOptions {
    // Python constructor. Every argument is optional on the Python side and
    // defaults to `None`; values are stored unchanged.
    #[new]
    #[pyo3(signature = (
        search_context_size = None,
        user_location = None,
        search_description = None,
        extract_description = None,
    ))]
    fn py_new(
        search_context_size: Option<SearchContextSize>,
        user_location: Option<WebSearchUserLocation>,
        search_description: Option<String>,
        extract_description: Option<String>,
    ) -> Self {
        WebSearchOptions {
            extract_description,
            search_description,
            user_location,
            search_context_size,
        }
    }
}
214
/// A full generation request: the input payload plus sampling, streaming, tool,
/// and permission settings, and the channel on which responses are sent.
///
/// Fields marked `#[serde(skip)]` are not serialized. On deserialization the
/// `Option` ones become `None` and `response` is filled by `default_responder`.
#[derive(Clone, Serialize, Deserialize)]
pub struct NormalRequest {
    /// The input payload.
    pub messages: RequestMessage,
    /// Sampling parameters for this request.
    pub sampling_params: SamplingParams,
    /// Channel on which responses for this request are sent.
    // Not serializable. After deserialization this is the placeholder from
    // `default_responder`, whose receiver has already been dropped.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    /// Whether log probabilities should be returned.
    pub return_logprobs: bool,
    /// Whether the response is streamed.
    pub is_streaming: bool,
    /// Request identifier; printed by the `Debug` impl of `Request`.
    pub id: usize,
    /// Output constraint; `Constraint::None` for unconstrained requests.
    pub constraint: Constraint,
    /// Optional suffix text.
    // NOTE(review): presumably appended after the completion — confirm.
    pub suffix: Option<String>,
    /// Tools made available to the request, if any.
    pub tools: Option<Vec<Tool>>,
    /// Tool choice setting, if any.
    pub tool_choice: Option<ToolChoice>,
    /// Custom logits processors; skipped by serde (`None` after deserialization).
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    /// Whether raw logits should be returned.
    pub return_raw_logits: bool,
    /// Web search options, if any.
    pub web_search_options: Option<WebSearchOptions>,
    /// Code execution switch; `false` when absent from the serialized form.
    #[serde(default)]
    pub enable_code_execution: bool,
    /// Permission setting for code execution, if any.
    #[serde(default)]
    pub code_execution_permission: Option<CodeExecutionPermission>,
    /// Notifier for code execution approvals; skipped by serde.
    #[serde(skip)]
    pub code_execution_approval_notifier: Option<Arc<hanzo_llm_mcp::CodeExecutionApprovalNotifier>>,
    /// Permission setting for agent tools, if any.
    #[serde(default)]
    pub agent_permission: Option<AgentPermission>,
    /// Handler for agent tool approvals; skipped by serde.
    #[serde(skip)]
    pub agent_approval_handler: Option<AgentToolApprovalHandler>,
    /// Notifier for agent tool approvals; skipped by serde.
    #[serde(skip)]
    pub agent_approval_notifier: Option<Arc<hanzo_llm_mcp::AgentToolApprovalNotifier>>,
    /// Optional limit on tool rounds.
    // NOTE(review): the behaviour when `None` is not visible here — confirm.
    pub max_tool_rounds: Option<usize>,
    /// Optional URL for tool dispatch.
    // NOTE(review): presumably an external endpoint that executes tool calls —
    // confirm.
    pub tool_dispatch_url: Option<String>,
    /// Optional model identifier.
    // NOTE(review): presumably selects the target model when several are
    // loaded — confirm.
    pub model_id: Option<String>,
    /// Sequence truncation switch; `false` when absent from the serialized form.
    #[serde(default)]
    pub truncate_sequence: bool,
    /// Optional session identifier.
    #[serde(default)]
    pub session_id: Option<String>,
    /// Files requested alongside the request, if any.
    #[serde(default)]
    pub files: Option<Vec<crate::files::RequestedFile>>,
}
277
278impl NormalRequest {
279 pub fn new_simple(
280 messages: RequestMessage,
281 sampling_params: SamplingParams,
282 response: Sender<Response>,
283 id: usize,
284 tools: Option<Vec<Tool>>,
285 tool_choice: Option<ToolChoice>,
286 ) -> Self {
287 Self {
288 messages,
289 sampling_params,
290 response,
291 id,
292 tools,
293 tool_choice,
294 return_logprobs: false,
295 is_streaming: false,
296 constraint: Constraint::None,
297 suffix: None,
298 logits_processors: None,
299 return_raw_logits: false,
300 web_search_options: None,
301 enable_code_execution: false,
302 code_execution_permission: None,
303 code_execution_approval_notifier: None,
304 agent_permission: None,
305 agent_approval_handler: None,
306 agent_approval_notifier: None,
307 max_tool_rounds: None,
308 tool_dispatch_url: None,
309 model_id: None,
310 truncate_sequence: false,
311 session_id: None,
312 files: None,
313 }
314 }
315}
316
/// Request to tokenize text or chat messages; the result is a list of `u32`
/// token ids (or an error) sent on `response`.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenizationRequest {
    /// Input to tokenize: chat messages (left) or a raw string (right).
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    /// Tools accompanying the input, if any.
    // NOTE(review): presumably rendered into the chat template — confirm.
    pub tools: Option<Vec<Tool>>,
    // NOTE(review): by its name, controls whether the chat template's generation
    // prompt is appended — confirm.
    pub add_generation_prompt: bool,
    // NOTE(review): by its name, controls whether the tokenizer adds special
    // tokens — confirm.
    pub add_special_tokens: bool,
    pub enable_thinking: Option<bool>,
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Channel on which the token ids (or an error) are sent.
    // Not serializable. After deserialization this is the placeholder from
    // `default_responder`, whose receiver has already been dropped.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}
331
/// Request to turn `u32` token ids back into text; the resulting `String` (or
/// an error) is sent on `response`.
#[derive(Clone, Serialize, Deserialize)]
pub struct DetokenizationRequest {
    /// Token ids to decode.
    pub tokens: Vec<u32>,
    // NOTE(review): by its name, controls whether special tokens are omitted
    // from the decoded text — confirm.
    pub skip_special_tokens: bool,
    /// Channel on which the decoded text (or an error) is sent.
    // Not serializable. After deserialization this is the placeholder from
    // `default_responder`, whose receiver has already been dropped.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}
341
/// Top-level request type: generation, re-quantization, tokenizer utilities,
/// and termination signals.
///
/// `Debug` is implemented by hand for this type and prints a condensed,
/// log-oriented summary rather than every field.
#[derive(Clone, Serialize, Deserialize)]
pub enum Request {
    /// A generation request; boxed so the enum itself stays small.
    Normal(Box<NormalRequest>),
    /// Re-quantization request carrying the target ISQ type.
    ReIsq(IsqType),
    /// Tokenization request.
    Tokenize(TokenizationRequest),
    /// Detokenization request.
    Detokenize(DetokenizationRequest),
    /// Termination request.
    // NOTE(review): presumably shuts the engine down — confirm against the
    // engine loop.
    Terminate,
    /// Request to terminate all sequences at the next step.
    // NOTE(review): exact semantics live in the engine loop — confirm.
    TerminateAllSeqsNextStep,
}
355
356impl Debug for Request {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 match self {
359 Request::Normal(boxed_req) => {
360 let NormalRequest {
361 messages,
362 sampling_params,
363 is_streaming,
364 id,
365 ..
366 } = &**boxed_req;
367 write!(
368 f,
369 "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
370 )
371 }
372 Request::ReIsq(tp) => {
373 write!(f, "Re ISQ Request {tp:?}",)
374 }
375 Request::Tokenize(req) => {
376 write!(f, "Tokenization Request {:?}", req.text)
377 }
378 Request::Detokenize(req) => {
379 write!(f, "Tokenization Request {:?}", req.tokens)
380 }
381 Request::Terminate => write!(f, "Termination Request"),
382 Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
383 }
384 }
385}