1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle::{Device, Result, Storage, Tensor, WithDType};
4use std::sync::LazyLock;
5use std::{f32, iter::Sum};
6
7use rayon::prelude::*;
8use rayon::ThreadPool;
9
#[cfg(target_os = "macos")]
/// Raise the calling thread's macOS quality-of-service class so the
/// flash-attention workers are favoured by the scheduler.
///
/// Best-effort: the return code of the libc call is intentionally ignored.
///
/// # Safety
/// Only affects scheduling of the calling thread; no other preconditions.
unsafe fn set_thread_affinity() {
    use libc::{pthread_set_qos_class_self_np, qos_class_t::QOS_CLASS_USER_INTERACTIVE};
    pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0);
}
19
#[cfg(not(target_os = "macos"))]
#[inline(always)]
/// No-op on non-macOS targets; kept `unsafe` with the same signature so the
/// call site in the pool's start handler compiles unconditionally.
unsafe fn set_thread_affinity() {
}
25
/// Dedicated Rayon thread pool for the CPU flash-attention kernels.
///
/// Built lazily on first use. Every worker runs `set_thread_affinity` at
/// start-up, which (on macOS) requests an elevated QoS class for the thread.
static FLASH_ATTN_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
    rayon::ThreadPoolBuilder::new()
        // SAFETY: `set_thread_affinity` only adjusts the calling thread's
        // scheduling class; it has no other preconditions.
        .start_handler(|_| unsafe {
            set_thread_affinity();
        })
        .build()
        .expect("Failed to build custom Rayon thread‑pool for flash‑attention")
});
36
// Manual unroll factor used by `vec_dot` (four multiply-adds per step).
const DOT_CHUNK: usize = 4;

// Number of KV positions handled per parallel tile in the single-query path.
const TILE_KV: usize = 16;
41
42#[inline]
43fn vec_dot<T: WithDType + Sum + Copy + std::ops::Mul<Output = T>>(a: &[T], b: &[T]) -> T {
44 let mut sum = T::zero();
45 let chunks = a.len() / DOT_CHUNK;
46
47 for i in 0..chunks {
48 let i_chunk = i * DOT_CHUNK;
49 sum = sum
50 + a[i_chunk] * b[i_chunk]
51 + a[i_chunk + 1] * b[i_chunk + 1]
52 + a[i_chunk + 2] * b[i_chunk + 2]
53 + a[i_chunk + 3] * b[i_chunk + 3];
54 }
55
56 for i in (chunks * DOT_CHUNK)..a.len() {
57 sum += a[i] * b[i];
58 }
59 sum
60}
61
/// Entry point for CPU flash-attention.
///
/// Borrows the raw CPU slices behind `q`, `k`, `v` (and the optional additive
/// `mask`), then dispatches to the single-query decode kernel when `q`'s
/// second dimension is 1, or to the general kernel otherwise.
///
/// * `softmax_scale` – multiplier applied to the `q·k` logits.
/// * `max_bias` – ALiBi-style bias strength; `None` behaves as `0.0` (off).
/// * `softcap` – logit soft-capping via `tanh`; `None` behaves as `0.0` (off).
///
/// # Errors
/// Returns an error if any tensor is not backed by CPU storage, or if the
/// stored dtype does not match `T`.
pub fn run_flash_attn_cpu<T>(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    softmax_scale: f32,
    max_bias: Option<f32>,
    softcap: Option<f32>,
) -> Result<Tensor>
where
    T: WithDType + Sum + num_traits::real::Real,
{
    // The `*_guard` values keep the storage borrows alive for as long as the
    // raw slices below are used.
    let (q_guard, q_layout) = q.storage_and_layout();
    let q_data: &[T] = if let Storage::Cpu(cpu) = &*q_guard {
        let data = cpu.as_slice::<T>()?;
        // Skip to this tensor's view within the (possibly shared) buffer.
        &data[q_layout.start_offset()..]
    } else {
        return Err(candle::Error::Msg("Expected CPU storage for q".into()));
    };
    let (k_guard, k_layout) = k.storage_and_layout();
    let k_data: &[T] = if let Storage::Cpu(cpu) = &*k_guard {
        let data = cpu.as_slice::<T>()?;
        &data[k_layout.start_offset()..]
    } else {
        return Err(candle::Error::Msg("Expected CPU storage for k".into()));
    };
    let (v_guard, v_layout) = v.storage_and_layout();
    let v_data: &[T] = if let Storage::Cpu(cpu) = &*v_guard {
        let data = cpu.as_slice::<T>()?;
        &data[v_layout.start_offset()..]
    } else {
        return Err(candle::Error::Msg("Expected CPU storage for v".into()));
    };
    let mask_guard = mask.map(|mask| mask.storage_and_layout().0);
    let mask_data: Option<&[T]> = if let Some(mask_guard) = &mask_guard {
        // The guard exists only when `mask` is `Some`, so this cannot panic.
        let mask = mask.as_ref().unwrap();

        if let Storage::Cpu(cpu) = &**mask_guard {
            let data = cpu.as_slice::<T>()?;
            Some(&data[mask.layout().start_offset()..])
        } else {
            return Err(candle::Error::Msg("Expected CPU storage for mask".into()));
        }
    } else {
        None
    };
    let q_stride = q.stride();
    let k_stride = k.stride();
    let v_stride = v.stride();

    // dims()[1] is the query length (layout (b, q_len, h, d) as consumed by
    // the kernels below); a single query position takes the decode fast path.
    if q.shape().dims()[1] == 1 {
        return flash_attn_cpu_single_q(
            q_data,
            k_data,
            v_data,
            mask_data,
            q.shape().dims(),
            k.shape().dims(),
            v.shape().dims(),
            q_stride,
            k_stride,
            v_stride,
            softmax_scale,
            max_bias.unwrap_or(0.0),
            softcap.unwrap_or(0.0),
        );
    }

    flash_attn_cpu(
        q_data,
        k_data,
        v_data,
        mask_data,
        q.shape().dims(),
        k.shape().dims(),
        v.shape().dims(),
        q_stride,
        k_stride,
        v_stride,
        softmax_scale,
        max_bias.unwrap_or(0.0),
        softcap.unwrap_or(0.0),
    )
}
163
164#[allow(clippy::too_many_arguments)]
167fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
168 q_data: &[T],
169 k_data: &[T],
170 v_data: &[T],
171 mask_vec: Option<&[T]>,
172 qshape: &[usize],
173 kshape: &[usize],
174 vshape: &[usize],
175 qstride: &[usize],
176 kstride: &[usize],
177 vstride: &[usize],
178 scale: f32,
179 max_bias: f32,
180 logit_softcap: f32,
181) -> Result<Tensor> {
182 let (b, _q_len, h, d) = (
184 qshape[0], qshape[1], qshape[2], qshape[3],
186 );
187 let kv_len = kshape[1];
188 let k_h = kshape[2];
189 let v_h = vshape[2];
190 let rk2 = h / k_h;
191 let rv2 = h / v_h;
192 let dv = d;
193
194 let n2 = 2_usize.pow((h as f32).log2().ceil() as u32);
195
196 let mut out = vec![0f32; b * h * dv];
198
199 let kv_tiles = kv_len.div_ceil(TILE_KV);
202
203 FLASH_ATTN_POOL.install(|| {
206 out.par_chunks_mut(dv)
207 .with_min_len(64)
208 .enumerate()
209 .for_each(|(row_idx, out_chunk)| {
210 let b_i = row_idx / h;
211 let h_i = row_idx % h;
212
213 let slope = if max_bias > 0.0 {
215 2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32)
216 } else {
217 1.0
218 };
219
220 let k_head = h_i / rk2;
222 let v_head = h_i / rv2;
223
224 let (vkq, s_tot, _m_tot) = (0..kv_tiles)
229 .into_par_iter()
230 .map(|tile_idx| {
231 let start = tile_idx * TILE_KV;
233 let end = (start + TILE_KV).min(kv_len);
234
235 let mut vkq = vec![0f32; dv];
236 let mut s = 0.0f32;
237 let mut m = f32::NEG_INFINITY;
238
239 let q_base =
241 b_i * qstride[0] + h_i * qstride[2] ;
242 let q_row = &q_data[q_base..q_base + d];
243
244 for kv_pos in start..end {
246 let mv = if let Some(mv_vec) = mask_vec {
248 let mval = mv_vec[(b_i * kv_len) + kv_pos];
249 slope * mval.to_f64() as f32
250 } else {
251 0.0
252 };
253 if mv == f32::NEG_INFINITY {
254 continue;
255 }
256
257 let k_base =
259 b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2];
260 let k_row = &k_data[k_base..k_base + d];
261
262 let mut s_val = vec_dot::<T>(q_row, k_row).to_f64() as f32;
264
265 let mut scale_applied = scale;
266 if logit_softcap != 0.0 {
267 scale_applied /= logit_softcap;
268 }
269 s_val *= scale_applied;
270 if logit_softcap != 0.0 {
271 s_val = logit_softcap * s_val.tanh();
272 }
273 s_val += mv;
274
275 let m_old = m;
277 let mut ms = 1.0f32;
278 let mut vs = 1.0f32;
279 if s_val > m {
280 m = s_val;
281 ms = (m_old - m).exp();
282 for v in vkq.iter_mut() {
283 *v *= ms;
284 }
285 } else {
286 vs = (s_val - m).exp();
287 }
288
289 let v_base =
291 b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2];
292 for d_i in 0..dv {
293 vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs;
294 }
295
296 s = s * ms + vs;
297 }
298
299 (vkq, s, m)
301 })
302 .reduce(
304 || (vec![0f32; dv], 0.0f32, f32::NEG_INFINITY),
305 |mut a, b| {
306 let (ref mut vkq_a, mut s_a, m_a) = a;
307 let (vkq_b, s_b, m_b) = b;
308 if m_a >= m_b {
309 let factor = (m_b - m_a).exp();
310 for (va, vb) in vkq_a.iter_mut().zip(vkq_b) {
311 *va += vb * factor;
312 }
313 s_a += s_b * factor;
314 (vkq_a.clone(), s_a, m_a)
315 } else {
316 let factor = (m_a - m_b).exp();
317 let mut vkq_new = vkq_b;
318 for (vb, va) in vkq_new.iter_mut().zip(vkq_a) {
319 *vb += *va * factor;
320 }
321 (vkq_new, s_b + s_a * factor, m_b)
322 }
323 },
324 );
325
326 let inv_s = 1.0 / s_tot;
328 for v in out_chunk.iter_mut().zip(vkq.iter()) {
329 *v.0 = *v.1 * inv_s;
330 }
331 });
332 });
333
334 let out_shape = (b, h, 1usize, dv);
335 Tensor::from_vec(out, out_shape, &Device::Cpu)
336}
337
338#[allow(clippy::too_many_arguments)]
341fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
342 q_data: &[T],
343 k_data: &[T],
344 v_data: &[T],
345 mask_vec: Option<&[T]>,
346 qshape: &[usize],
347 kshape: &[usize],
348 vshape: &[usize],
349 qstride: &[usize],
350 kstride: &[usize],
351 vstride: &[usize],
352 scale: f32,
353 max_bias: f32,
354 logit_softcap: f32,
355) -> Result<Tensor> {
356 let (b, q_len, h, d) = (qshape[0], qshape[1], qshape[2], qshape[3]);
357 let kv_len = kshape[1];
358 let k_h = kshape[2];
362 let v_h = vshape[2];
363 let rk2 = h / k_h; let rv2 = h / v_h;
365 let dv = d; let n2 = 2_usize.pow((h as f32).log2().ceil() as u32);
369
370 let mut out = vec![0f32; b * q_len * h * dv];
371
372 let _rows = b * h * q_len; FLASH_ATTN_POOL.install(|| {
381 out.par_chunks_mut(dv)
382 .with_min_len(64)
383 .enumerate()
384 .for_each(|(row_idx, out_chunk)| {
385 let rows_per_batch = h * q_len;
387 let b_i = row_idx / rows_per_batch;
388 let rem = row_idx % rows_per_batch;
389 let h_i = rem / q_len;
390 let q_pos = rem % q_len;
391
392 let slope = if max_bias > 0.0 {
393 2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32)
394 } else {
395 1.0
396 };
397
398 let k_head = h_i / rk2;
400 let v_head = h_i / rv2;
401
402 let mut vkq = vec![0f32; dv];
404 let mut s = 0.0f32;
405 let mut m = f32::NEG_INFINITY;
406
407 let mut q_row: Vec<T> = Vec::with_capacity(d);
409 let mut k_row: Vec<T> = Vec::with_capacity(d);
410
411 let q_base = b_i * qstride[0] + q_pos * qstride[1] + h_i * qstride[2];
413 q_row.clear();
414 for di in 0..d {
415 q_row.push(q_data[q_base + di * qstride[3]]);
416 }
417
418 for kv_pos in 0..kv_len {
420 let mv = if let Some(mv_vec) = mask_vec {
422 let mval = mv_vec[((b_i * q_len + q_pos) * kv_len) + kv_pos];
423 slope * mval.to_f64() as f32
424 } else {
425 0.0
426 };
427 if mv == f32::NEG_INFINITY {
428 continue;
429 }
430
431 let k_base = b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2];
433 k_row.clear();
434 for di in 0..d {
435 k_row.push(k_data[k_base + di * kstride[3]]);
436 }
437
438 let mut s_val = vec_dot::<T>(&q_row, &k_row);
440 let mut scale_applied = scale;
441 if logit_softcap != 0.0 {
442 scale_applied /= logit_softcap;
443 }
444 s_val *= T::from_f64(scale_applied as f64);
445 if logit_softcap != 0.0 {
446 s_val = T::from_f64(logit_softcap as f64 * s_val.to_f64().tanh());
447 }
448 s_val += T::from_f64(mv as f64);
449
450 let m_old = m;
452 let mut ms = 1.0f32;
453 let mut vs = 1.0f32;
454 if s_val.to_f64() as f32 > m {
455 m = s_val.to_f64() as f32;
456 ms = (m_old - m).exp();
457 for v in vkq.iter_mut() {
458 *v *= ms;
459 }
460 } else {
461 vs = (s_val.to_f64() as f32 - m).exp();
462 }
463
464 let v_base = b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2];
466 for d_i in 0..dv {
467 vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs;
468 }
469
470 s = s * ms + vs;
471 }
472
473 let inv_s = 1.0 / s;
475 for v in vkq.iter_mut() {
476 *v *= inv_s;
477 }
478 out_chunk.copy_from_slice(&vkq);
479 });
480 });
481
482 let out_shape = (b, h, q_len, dv);
484 Tensor::from_vec(out, out_shape, &Device::Cpu)
485}