Skip to main content

hanzo_engine/utils/
mod.rs

1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod memory_usage;
4pub(crate) mod model_config;
5pub(crate) mod normal;
6pub(crate) mod progress;
7pub(crate) mod tiktoken;
8pub(crate) mod tokenizer;
9pub(crate) mod tokens;
10pub(crate) mod unvarbuilder;
11pub(crate) mod varbuilder_utils;
12
/// Spins until the lock behind the given expression can be taken, then
/// evaluates to the acquired guard.
///
/// The expression must expose a `try_lock()` method returning a `Result`
/// whose `Ok` variant is the guard. Every `Err` is treated as "not available
/// yet" and retried.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($lock:expr) => {
        loop {
            match $lock.try_lock().ok() {
                Some(guard) => break guard,
                // Give other threads a chance to run and release the lock.
                // Without the yield, a spawned async task busy-looping here
                // can deadlock against a task that holds the lock across an
                // await point.
                None => std::thread::yield_now(),
            }
        }
    };
}
28
/// Unwraps `$fallible`, or reports the failure through `$response` and
/// returns from the enclosing function.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)` it sends
/// `Response::InternalError(e.into())` via `$response.send(..).await`, logs a
/// warning if that send fails, and then executes a bare `return;`. It must
/// therefore be expanded inside an `async` context whose return type is `()`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Err(err) => {
                // A failed send means nobody is listening any more; that is
                // only worth a warning, not a second error.
                if $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
            Ok(value) => value,
        }
    };
}
45
/// Unwraps `$fallible`, or reports the failure through `$response` and
/// returns `Ok(())` from the enclosing function.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)` it sends
/// `Response::InternalError(e.into())` via `$response.send(..).await`, logs a
/// warning if that send fails, and then executes `return Ok(());`. It must
/// therefore be expanded inside an `async` context returning a `Result` whose
/// success type is `()`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Err(err) => {
                // The error has been handed to the client (or the client is
                // gone), so the caller itself finishes successfully.
                if $response
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
            Ok(value) => value,
        }
    };
}
62
/// Unwraps `$fallible`, or reports the failure to the sequence's responder,
/// marks the sequence as errored, and returns `Ok(())`.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)` it sends
/// `Response::InternalError(e.into())` via `$seq.responder().send(..).await`,
/// logs a warning if that send fails, sets the sequence state to
/// `SequenceState::Error`, and then executes `return Ok(());`. It must be
/// expanded inside an `async` context returning a `Result` whose success type
/// is `()`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Err(err) => {
                if $seq
                    .responder()
                    .send($crate::response::Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                // The state is updated whether or not the client was reachable.
                $seq.set_state($crate::sequence::SequenceState::Error);
                return Ok(());
            }
            Ok(value) => value,
        }
    };
}
85
/// Unwraps the result of a pipeline forward step or, on failure, fails every
/// sequence in the batch and restarts the labelled engine loop.
///
/// Arguments:
/// - `$stage`: token printed as the prefix of the error log line.
/// - `$fallible`: the `Result` to unwrap; on `Ok(v)` the macro evaluates to `v`.
/// - `$seq_slice`: the sequences that took part in the failed step.
/// - `$pipeline`: lock-protected pipeline, acquired with `get_mut_arcmutex!`.
/// - `$label`: label of the enclosing loop, which is `continue`d once the
///   error has been handled.
/// - `$prefix_cacher`: lock-protected prefix cacher whose caches are evicted.
///
/// On `Err(e)` the macro:
/// 1. (only with the `metal` feature) retries after a one second pause if the
///    error text looks like an iOS background-execution rejection;
/// 2. otherwise logs the error, attaches an `"error"` choice containing the
///    text decoded so far to every sequence, sends each responder a
///    `Response::ModelError` / `Response::CompletionModelError` carrying the
///    partial response, marks every sequence as `SequenceState::Error`,
///    resets the pipeline cache, evicts the prefix caches and continues
///    `$label`.
///
/// A responder whose receiver has gone away is logged with a warning rather
/// than unwrapped, so one disconnected client cannot panic the engine loop and
/// skip the state update and cache reset for the rest of the batch.
///
/// The expansion uses `.await` and `continue $label`, so it must be placed in
/// an `async` context inside that loop. It also refers to `Choice`,
/// `ResponseMessage`, `CompletionChoice`, `ChatCompletionResponse` and
/// `CompletionResponse` unqualified, so those names must be in scope at the
/// call site.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                // Auto-retry on iOS Metal background GPU error: when the iOS app
                // goes to background, Metal rejects command buffers. We detect this,
                // reset cache, sleep, and let the engine loop retry. Sequences stay
                // in the scheduler (still in Running state) and are re-scheduled.
                #[cfg(feature = "metal")]
                {
                    let err_str = e.to_string();
                    if err_str.contains("Insufficient Permission")
                        || err_str.contains("BackgroundExecutionNotPermitted")
                    {
                        tracing::warn!(
                            "Metal GPU background error detected (iOS app likely in background). \
                             Pausing 1s before retry..."
                        );
                        {
                            // Scoped so the pipeline lock is released before sleeping.
                            let p = get_mut_arcmutex!($pipeline);
                            p.set_none_cache($seq_slice, true, true, false);
                        }
                        get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();
                        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                        continue $label;
                    }
                }

                // Grab what is needed from the pipeline and release its lock
                // before any of the awaits below.
                let (tokenizer, pipeline_name) = {
                    let pipeline = get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);
                for seq in $seq_slice.iter_mut() {
                    // Step 1: Add all choices to groups
                    // Clamp the start index so the slice below cannot go out of bounds.
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    // Best effort: a missing tokenizer or a decode failure
                    // yields an empty partial text instead of a second error.
                    let res = match &tokenizer {
                        Some(tok) => match tok.decode(&seq.get_toks()[start..], false) {
                            Ok(t) => t,
                            Err(_) => String::new(),
                        },
                        None => String::new(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                                reasoning_content: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 2: Respond with all groups
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                            agentic_tool_calls: None,
                            files: None,
                            session_id: None,
                        };

                        // A disconnected receiver must not take the engine
                        // down: log it and keep cleaning up the batch.
                        if seq
                            .responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .is_err()
                        {
                            tracing::warn!("Receiver disconnected");
                        }
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        // Same best-effort delivery as the chat branch.
                        if seq
                            .responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response
                            ))
                            .await
                            .is_err()
                        {
                            tracing::warn!("Receiver disconnected");
                        }
                    }
                }
                for seq in $seq_slice.iter_mut() {
                    // Step 3: Set state - This cannot be done in Step 2 as `group` is locking the refcell
                    seq.set_state(SequenceState::Error);
                }

                let p = get_mut_arcmutex!($pipeline);
                // Also reset non granular state because:
                // - The sequence is gone
                // - We should reset the state then, including draft.
                p.set_none_cache($seq_slice, true, true, false);
                get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}
224
/// Spins until the lock stored in the `group` field of the given expression
/// can be taken, then evaluates to the acquired guard.
///
/// `$this.group` must expose a `try_lock()` method returning a `Result` whose
/// `Ok` variant is the guard. Every `Err` is treated as "not available yet"
/// and retried.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($owner:expr) => {
        loop {
            match $owner.group.try_lock().ok() {
                Some(guard) => break guard,
                // Let other threads make progress and release the lock.
                None => std::thread::yield_now(),
            }
        }
    };
}
238
/// Declares a private, zero-argument function that returns a fixed value.
///
/// `serde_default_fn!(Type, name, value)` expands to
/// `fn name() -> Type { value }`. The macro name suggests the generated
/// function is meant to be referenced from `#[serde(default = "name")]`.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ret_ty:ty, $fn_name:ident, $default:expr) => {
        fn $fn_name() -> $ret_ty {
            $default
        }
    };
}
248
/// Reports whether this build can use paged attention.
///
/// Returns `true` when compiled with the `metal` feature, or with the `cuda`
/// feature on a Unix-family target; returns `false` otherwise. The answer is
/// fixed at compile time.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(
        all(feature = "cuda", target_family = "unix"),
        feature = "metal"
    ))
}
260
/// Reports whether this build was compiled with flash attention.
///
/// Returns `true` when either the `flash-attn` or the `flash-attn-v3` feature
/// is enabled, and `false` otherwise. The answer is fixed at compile time.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}