hanzo_engine/utils/
mod.rs1pub(crate) mod debug;
2pub(crate) mod gguf_metadata;
3pub(crate) mod memory_usage;
4pub(crate) mod model_config;
5pub(crate) mod normal;
6pub(crate) mod progress;
7pub(crate) mod tiktoken;
8pub(crate) mod tokenizer;
9pub(crate) mod tokens;
10pub(crate) mod unvarbuilder;
11pub(crate) mod varbuilder_utils;
12
/// Spin until the lock behind `$lock` can be taken, then evaluate to its guard.
///
/// Every attempt calls `$lock.try_lock()`. Any `Err` result makes the current
/// thread yield via [`std::thread::yield_now`] before the next attempt, so the
/// expression `$lock` is evaluated once per attempt and the loop only ends
/// when `try_lock` returns `Ok`.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
    ($lock:expr) => {
        loop {
            match $lock.try_lock() {
                Ok(guard) => break guard,
                // Not acquired on this attempt: fall through and yield.
                Err(_) => {}
            }
            // Give other threads a chance to run (and release the lock).
            std::thread::yield_now();
        }
    };
}
28
/// Unwrap `$fallible`, or report the failure on `$response` and `return`.
///
/// * `Ok(value)` - the macro evaluates to `value`.
/// * `Err(err)` - `Response::InternalError(err.into())` is sent through
///   `$response`; a warning is logged if that send fails, and the enclosing
///   function then returns `()`.
///
/// The expansion contains `.await`, so it must be used in an async context.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                use $crate::response::Response;
                // A failed send is only logged; there is nobody left to notify.
                if $response
                    .send(Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                return;
            }
        }
    };
}
45
/// Unwrap `$fallible`, or report the failure on `$response` and `return Ok(())`.
///
/// * `Ok(value)` - the macro evaluates to `value`.
/// * `Err(err)` - `Response::InternalError(err.into())` is sent through
///   `$response`; a warning is logged if that send fails, and the enclosing
///   function then returns `Ok(())`.
///
/// The expansion contains `.await`, so it must be used in an async context
/// whose return type accepts `Ok(())`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
    ($fallible:expr, $response:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                use $crate::response::Response;
                // A failed send is only logged; there is nobody left to notify.
                if $response
                    .send(Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                return Ok(());
            }
        }
    };
}
62
/// Unwrap `$fallible`, or fail the sequence `$seq` and `return Ok(())`.
///
/// * `Ok(value)` - the macro evaluates to `value`.
/// * `Err(err)` - `Response::InternalError(err.into())` is sent through
///   `$seq.responder()` (a warning is logged if that send fails), the
///   sequence state is set to `SequenceState::Error`, and the enclosing
///   function returns `Ok(())`.
///
/// The expansion contains `.await`, so it must be used in an async context
/// whose return type accepts `Ok(())`.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
    ($fallible:expr, $seq:expr) => {
        match $fallible {
            Ok(value) => value,
            Err(err) => {
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                // A failed send is only logged; the state is updated regardless.
                if $seq
                    .responder()
                    .send(Response::InternalError(err.into()))
                    .await
                    .is_err()
                {
                    tracing::warn!("Receiver disconnected");
                }
                $seq.set_state(SequenceState::Error);
                return Ok(());
            }
        }
    };
}
85
/// Recover from a failed pipeline forward pass inside a labelled loop.
///
/// Evaluates to the `Ok` value of `$fallible`. On `Err(e)` it never produces
/// a value; instead it cleans up and executes `continue $label`:
///
/// 1. With the `metal` feature, errors whose message contains
///    `"Insufficient Permission"` or `"BackgroundExecutionNotPermitted"` are
///    treated as retryable: caches are cleared, the task sleeps for one
///    second and the loop continues without failing any sequence.
/// 2. Otherwise the error is logged, and for every sequence in `$seq_slice`
///    the tokens after the prompt are decoded (empty string if there is no
///    tokenizer or decoding fails) and recorded as a choice with
///    `finish_reason: "error"`.
/// 3. Each sequence's responder receives `Response::ModelError` (chat) or
///    `Response::CompletionModelError` (completion) carrying the partial
///    response. A failed send is logged as a warning, matching the other
///    `handle_seq_error*` macros, instead of panicking.
/// 4. Every sequence is moved to `SequenceState::Error`, the pipeline cache
///    is reset via `set_none_cache` and the prefix cacher is fully evicted.
///
/// # Arguments
/// * `$stage` - label printed in front of the logged error.
/// * `$fallible` - the `Result` to unwrap.
/// * `$seq_slice` - collection of sequences; must provide `iter_mut()`.
/// * `$pipeline`, `$prefix_cacher` - lockable handles used through
///   `get_mut_arcmutex!`.
/// * `$label` - loop label to `continue` to.
///
/// The call site must be async and must have `get_mut_arcmutex!`, `Choice`,
/// `ResponseMessage`, `CompletionChoice`, `ChatCompletionResponse` and
/// `CompletionResponse` in scope, because the expansion names them
/// unqualified.
///
/// # Panics
/// Panics if `evict_all_caches()` returns an error.
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
    ($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
        match $fallible {
            Ok(v) => v,
            Err(e) => {
                #[cfg(feature = "metal")]
                {
                    let err_str = e.to_string();
                    if err_str.contains("Insufficient Permission")
                        || err_str.contains("BackgroundExecutionNotPermitted")
                    {
                        tracing::warn!(
                            "Metal GPU background error detected (iOS app likely in background). \
                             Pausing 1s before retry..."
                        );
                        {
                            // Scoped so the pipeline guard is released before
                            // the prefix cacher is locked and before sleeping.
                            let p = get_mut_arcmutex!($pipeline);
                            p.set_none_cache($seq_slice, true, true, false);
                        }
                        get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();
                        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                        continue $label;
                    }
                }

                // Read what is needed from the pipeline and drop the guard
                // right away so it is not held while responding.
                let (tokenizer, pipeline_name) = {
                    let pipeline = get_mut_arcmutex!($pipeline);
                    let pipeline_name = pipeline.name();
                    let tokenizer = pipeline.tokenizer();
                    (tokenizer, pipeline_name)
                };
                use $crate::response::Response;
                use $crate::sequence::SequenceState;
                use $crate::response::SYSTEM_FINGERPRINT;
                use tracing::error;
                error!("{} - Model failed with error: {:?}", $stage, &e);

                // Pass 1: record the partial output of every sequence as an
                // "error" choice on its group.
                for seq in $seq_slice.iter_mut() {
                    // Clamp so the slice below cannot start past the end.
                    let start = seq.prompt_tokens().min(seq.get_toks().len());
                    let res = match &tokenizer {
                        Some(tok) => tok
                            .decode(&seq.get_toks()[start..], false)
                            .unwrap_or_default(),
                        None => String::new(),
                    };

                    if seq.get_mut_group().is_chat {
                        let choice = Choice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            message: ResponseMessage {
                                content: Some(res),
                                role: "assistant".to_string(),
                                tool_calls: None,
                                reasoning_content: None,
                            },
                            logprobs: None,
                        };
                        seq.add_choice_to_group(choice);
                    } else {
                        let choice = CompletionChoice {
                            finish_reason: "error".to_string(),
                            index: seq.get_response_index(),
                            text: res,
                            logprobs: None,
                        };
                        seq.add_completion_choice_to_group(choice);
                    }
                }

                // Pass 2: send each sequence's responder the error together
                // with the partial response assembled from its group.
                for seq in $seq_slice.iter_mut() {
                    let group = seq.get_mut_group();

                    if group.is_chat {
                        let partial_completion_response = ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                            agentic_tool_calls: None,
                            files: None,
                            session_id: None,
                        };

                        // A disconnected receiver must not panic the engine
                        // while it is already handling a model error.
                        if seq
                            .responder()
                            .send(Response::ModelError(
                                e.to_string(),
                                partial_completion_response,
                            ))
                            .await
                            .is_err()
                        {
                            tracing::warn!("Receiver disconnected");
                        }
                    } else {
                        let partial_completion_response = CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name.clone(),
                            system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        };

                        if seq
                            .responder()
                            .send(Response::CompletionModelError(
                                e.to_string(),
                                partial_completion_response,
                            ))
                            .await
                            .is_err()
                        {
                            tracing::warn!("Receiver disconnected");
                        }
                    }
                }

                // Pass 3: mark every sequence as failed. Kept separate from
                // pass 2, where the `group` guard is alive for the whole body
                // (presumably it would conflict with `set_state` - confirm).
                for seq in $seq_slice.iter_mut() {
                    seq.set_state(SequenceState::Error);
                }

                // Reset cache state for the failed sequences, then drop every
                // prefix-cache entry before moving on.
                let p = get_mut_arcmutex!($pipeline);
                p.set_none_cache($seq_slice, true, true, false);
                get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

                continue $label;
            }
        }
    };
}
224
/// Spin until `$this.group` can be locked, then evaluate to its guard.
///
/// Every attempt calls `$this.group.try_lock()`. Any `Err` result makes the
/// current thread yield via [`std::thread::yield_now`] before the next
/// attempt, so the loop only ends when `try_lock` returns `Ok`.
#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
    ($owner:expr) => {
        loop {
            match $owner.group.try_lock() {
                Ok(guard) => break guard,
                // Not acquired on this attempt: fall through and yield.
                Err(_) => {}
            }
            // Give other threads a chance to run (and release the lock).
            std::thread::yield_now();
        }
    };
}
238
/// Define a private, zero-argument function that returns a fixed value.
///
/// `serde_default_fn!(T, name, value)` expands to `fn name() -> T { value }`.
/// Judging by the macro name, the generated function is presumably meant to
/// be referenced from `#[serde(default = "name")]` - confirm at the call sites.
#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
    ($ret:ty, $fn_name:ident, $default:expr) => {
        fn $fn_name() -> $ret {
            $default
        }
    };
}
248
/// Reports whether this build supports paged attention.
///
/// Returns `true` when compiled with the `cuda` feature on a `unix` target
/// family, or with the `metal` feature; returns `false` for every other
/// configuration. The answer is fixed at compile time.
pub const fn paged_attn_supported() -> bool {
    cfg!(any(
        all(feature = "cuda", target_family = "unix"),
        feature = "metal"
    ))
}
260
/// Reports whether this build was compiled with flash attention.
///
/// Returns `true` when either the `flash-attn` or the `flash-attn-v3`
/// feature is enabled and `false` otherwise. The answer is fixed at compile
/// time.
pub const fn using_flash_attn() -> bool {
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v3"))
}