// hanzo_engine/request.rs

1use either::Either;
2use hanzo_audio::AudioInput;
3use hanzo_quant::IsqType;
4use indexmap::IndexMap;
5#[cfg(feature = "pyo3_macros")]
6use pyo3::{pyclass, pymethods};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::VideoInput;
11
12use crate::{
13    response::Response, sampler::SamplingParams, tools::ToolChoice, AgentPermission,
14    AgentToolApprovalHandler, CodeExecutionPermission, CustomLogitsProcessor,
15    DiffusionGenerationParams, Tool,
16};
17use std::{fmt::Debug, path::PathBuf, sync::Arc};
18use tokio::sync::mpsc::Sender;
19
20pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
21
22#[derive(Clone, Serialize, Deserialize)]
23/// Control the constraint with llguidance.
24pub enum Constraint {
25    Regex(String),
26    Lark(String),
27    JsonSchema(serde_json::Value),
28    Llguidance(LlguidanceGrammar),
29    None,
30}
31
32#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
33#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
34#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
35/// Image generation response format
36pub enum ImageGenerationResponseFormat {
37    Url,
38    B64Json,
39}
40
41pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
42
43/// Reasoning effort level for models that support it (e.g., GPT-OSS with Harmony format).
44/// Controls the depth of reasoning/analysis in the model's response.
45#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
46#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
47#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
48#[serde(rename_all = "lowercase")]
49pub enum ReasoningEffort {
50    /// Minimal reasoning, faster responses
51    Low,
52    /// Balanced reasoning depth
53    #[default]
54    Medium,
55    /// Deep reasoning, more thorough analysis
56    High,
57}
58
59impl ReasoningEffort {
60    /// Convert to string representation for chat template
61    pub fn as_str(&self) -> &'static str {
62        match self {
63            Self::Low => "low",
64            Self::Medium => "medium",
65            Self::High => "high",
66        }
67    }
68}
69
70#[derive(Clone, Debug, Serialize, Deserialize)]
71/// Message or messages for a [`Request`].
72pub enum RequestMessage {
73    Chat {
74        messages: Vec<IndexMap<String, MessageContent>>,
75        enable_thinking: Option<bool>,
76        /// Reasoning effort level for Harmony-format models
77        reasoning_effort: Option<ReasoningEffort>,
78    },
79    Completion {
80        text: String,
81        echo_prompt: bool,
82        best_of: Option<usize>,
83    },
84    CompletionTokens(Vec<u32>),
85    MultimodalChat {
86        #[serde(skip)] // TODO
87        images: Vec<image::DynamicImage>,
88        #[serde(skip)] // TODO
89        audios: Vec<AudioInput>,
90        #[serde(skip)]
91        videos: Vec<VideoInput>,
92        messages: Vec<IndexMap<String, MessageContent>>,
93        enable_thinking: Option<bool>,
94        /// Reasoning effort level for Harmony-format models
95        reasoning_effort: Option<ReasoningEffort>,
96    },
97    ImageGeneration {
98        prompt: String,
99        format: ImageGenerationResponseFormat,
100        generation_params: DiffusionGenerationParams,
101        save_file: Option<PathBuf>,
102    },
103    SpeechGeneration {
104        prompt: String,
105    },
106    Embedding {
107        prompt: String,
108    },
109    EmbeddingTokens {
110        prompt: Vec<u32>,
111    },
112}
113
114fn default_responder<T>() -> Sender<T> {
115    let (sender, _) = tokio::sync::mpsc::channel(1);
116    sender
117}
118
119#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
120#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
121#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
122pub enum SearchContextSize {
123    #[serde(rename = "low")]
124    Low,
125    #[default]
126    #[serde(rename = "medium")]
127    Medium,
128    #[serde(rename = "high")]
129    High,
130}
131
132#[cfg_attr(feature = "pyo3_macros", pyclass(eq))]
133#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
134#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
135#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
136pub struct ApproximateUserLocation {
137    pub city: String,
138    pub country: String,
139    pub region: String,
140    pub timezone: String,
141}
142
#[cfg(feature = "pyo3_macros")]
#[pymethods]
impl ApproximateUserLocation {
    /// Python constructor: `ApproximateUserLocation(city, country, region, timezone)`.
    #[new]
    fn py_new(city: String, country: String, region: String, timezone: String) -> Self {
        let location = Self {
            city,
            country,
            region,
            timezone,
        };
        location
    }
}
156
157#[cfg_attr(feature = "pyo3_macros", pyclass(eq))]
158#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
159#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
160#[serde(tag = "type")]
161pub enum WebSearchUserLocation {
162    #[serde(rename = "approximate")]
163    Approximate {
164        approximate: ApproximateUserLocation,
165    },
166}
167
#[cfg(feature = "pyo3_macros")]
#[pymethods]
impl WebSearchUserLocation {
    /// Python factory: wraps an [`ApproximateUserLocation`] in
    /// [`WebSearchUserLocation::Approximate`].
    #[staticmethod]
    fn approximate(approximate: ApproximateUserLocation) -> Self {
        WebSearchUserLocation::Approximate {
            approximate: approximate,
        }
    }
}
176
177#[cfg_attr(feature = "pyo3_macros", pyclass(eq))]
178#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
179#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
180#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
181pub struct WebSearchOptions {
182    pub search_context_size: Option<SearchContextSize>,
183    pub user_location: Option<WebSearchUserLocation>,
184    /// Override the description for the search tool.
185    pub search_description: Option<String>,
186    /// Override the description for the extraction tool.
187    pub extract_description: Option<String>,
188}
189
#[cfg(feature = "pyo3_macros")]
#[pymethods]
impl WebSearchOptions {
    /// Python constructor; every argument is optional and defaults to `None`.
    #[new]
    #[pyo3(signature = (
        search_context_size = None,
        user_location = None,
        search_description = None,
        extract_description = None,
    ))]
    fn py_new(
        search_context_size: Option<SearchContextSize>,
        user_location: Option<WebSearchUserLocation>,
        search_description: Option<String>,
        extract_description: Option<String>,
    ) -> Self {
        let options = WebSearchOptions {
            search_context_size,
            user_location,
            search_description,
            extract_description,
        };
        options
    }
}
214
215#[derive(Clone, Serialize, Deserialize)]
216/// A normal request request to the `Hanzo`.
217/// - `messages`: Messages for the request
218/// - `sampling_params`: Sampling parameters for generation
219/// - `response`: Object to send the result through
220/// - `return_logprobs`: Whether to return logprobs
221/// - `is_streaming`: Control whether the request is streaming, if so chunk responses will be sent
222/// - `id`: Request ID
223/// - `constraint`: Constraint to use during generation
224/// - `suffix`: Suffix to add
225/// - `tools`: Tools available in this request
226/// - `tool_choice`: Choice of tools
227/// - `logits_processors`: Custom logits processors. Order of application:
228///     1) Apply penalties from `sampling_params`
229///     2) Apply these custom logits processors sequentially
230///     3) Apply temperature and softmax
231///     4) Sample the next token (topk, topp, minp, etc)
232/// - `return_raw_logits`: Return raw logits.
233/// - `truncate_sequence`: Whether to truncate the prompt if it exceeds the model's maximum context length.
234pub struct NormalRequest {
235    pub messages: RequestMessage,
236    pub sampling_params: SamplingParams,
237    #[serde(default = "default_responder")]
238    #[serde(skip)]
239    pub response: Sender<Response>,
240    pub return_logprobs: bool,
241    pub is_streaming: bool,
242    pub id: usize,
243    pub constraint: Constraint,
244    pub suffix: Option<String>,
245    pub tools: Option<Vec<Tool>>,
246    pub tool_choice: Option<ToolChoice>,
247    #[serde(skip)]
248    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
249    pub return_raw_logits: bool,
250    pub web_search_options: Option<WebSearchOptions>,
251    /// When true, registered code-execution tools are injected and the agentic loop runs.
252    #[serde(default)]
253    pub enable_code_execution: bool,
254    #[serde(default)]
255    pub code_execution_permission: Option<CodeExecutionPermission>,
256    #[serde(skip)]
257    pub code_execution_approval_notifier: Option<Arc<hanzo_llm_mcp::CodeExecutionApprovalNotifier>>,
258    #[serde(default)]
259    pub agent_permission: Option<AgentPermission>,
260    #[serde(skip)]
261    pub agent_approval_handler: Option<AgentToolApprovalHandler>,
262    #[serde(skip)]
263    pub agent_approval_notifier: Option<Arc<hanzo_llm_mcp::AgentToolApprovalNotifier>>,
264    pub max_tool_rounds: Option<usize>,
265    /// URL to POST `{"name": ..., "arguments": ...}` to when no server-side callback is registered. Expects `{"content": "..."}` back.
266    pub tool_dispatch_url: Option<String>,
267    pub model_id: Option<String>,
268    #[serde(default)]
269    pub truncate_sequence: bool,
270    /// Persistent agentic state. If `None`, a new session is created and the ID is returned in the response.
271    #[serde(default)]
272    pub session_id: Option<String>,
273    /// Required output files. The runtime asks the model to produce them and surfaces a `File` (or error placeholder) for each.
274    #[serde(default)]
275    pub files: Option<Vec<crate::files::RequestedFile>>,
276}
277
278impl NormalRequest {
279    pub fn new_simple(
280        messages: RequestMessage,
281        sampling_params: SamplingParams,
282        response: Sender<Response>,
283        id: usize,
284        tools: Option<Vec<Tool>>,
285        tool_choice: Option<ToolChoice>,
286    ) -> Self {
287        Self {
288            messages,
289            sampling_params,
290            response,
291            id,
292            tools,
293            tool_choice,
294            return_logprobs: false,
295            is_streaming: false,
296            constraint: Constraint::None,
297            suffix: None,
298            logits_processors: None,
299            return_raw_logits: false,
300            web_search_options: None,
301            enable_code_execution: false,
302            code_execution_permission: None,
303            code_execution_approval_notifier: None,
304            agent_permission: None,
305            agent_approval_handler: None,
306            agent_approval_notifier: None,
307            max_tool_rounds: None,
308            tool_dispatch_url: None,
309            model_id: None,
310            truncate_sequence: false,
311            session_id: None,
312            files: None,
313        }
314    }
315}
316
317#[derive(Clone, Serialize, Deserialize)]
318/// Request to tokenize some messages or some text.
319/// - `add_generation_prompt` is only applicable if chat messages are provided and not a raw string.
320pub struct TokenizationRequest {
321    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
322    pub tools: Option<Vec<Tool>>,
323    pub add_generation_prompt: bool,
324    pub add_special_tokens: bool,
325    pub enable_thinking: Option<bool>,
326    pub reasoning_effort: Option<ReasoningEffort>,
327    #[serde(default = "default_responder")]
328    #[serde(skip)]
329    pub response: Sender<anyhow::Result<Vec<u32>>>,
330}
331
332#[derive(Clone, Serialize, Deserialize)]
333/// Request to detokenize some text.
334pub struct DetokenizationRequest {
335    pub tokens: Vec<u32>,
336    pub skip_special_tokens: bool,
337    #[serde(default = "default_responder")]
338    #[serde(skip)]
339    pub response: Sender<anyhow::Result<String>>,
340}
341
342#[derive(Clone, Serialize, Deserialize)]
343/// A request to the Engine, encapsulating the various parameters as well as
344/// the `mpsc` response `Sender` used to return the [`Response`].
345pub enum Request {
346    Normal(Box<NormalRequest>),
347    ReIsq(IsqType),
348    Tokenize(TokenizationRequest),
349    Detokenize(DetokenizationRequest),
350    // Sending a terminate request causes the `run` function to return to the thread created in `Hanzo::new`,
351    // and then Engine will be dropped.
352    Terminate,
353    TerminateAllSeqsNextStep,
354}
355
356impl Debug for Request {
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        match self {
359            Request::Normal(boxed_req) => {
360                let NormalRequest {
361                    messages,
362                    sampling_params,
363                    is_streaming,
364                    id,
365                    ..
366                } = &**boxed_req;
367                write!(
368                    f,
369                    "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
370                )
371            }
372            Request::ReIsq(tp) => {
373                write!(f, "Re ISQ Request {tp:?}",)
374            }
375            Request::Tokenize(req) => {
376                write!(f, "Tokenization Request {:?}", req.text)
377            }
378            Request::Detokenize(req) => {
379                write!(f, "Tokenization Request {:?}", req.tokens)
380            }
381            Request::Terminate => write!(f, "Termination Request"),
382            Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
383        }
384    }
385}