candle_nn/cpu_flash_attention.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle::{Device, Result, Storage, Tensor, WithDType};
use std::sync::LazyLock;
use std::{f32, iter::Sum};

use rayon::prelude::*;
use rayon::ThreadPool;

#[cfg(target_os = "macos")]
/// Elevate the thread QoS so macOS prefers running it on Performance (P) cores.
unsafe fn set_thread_affinity() {
    // USER_INTERACTIVE has the highest scheduling priority that user code
    // can request and is most likely to be scheduled on P-cores.
    use libc::{pthread_set_qos_class_self_np, qos_class_t::QOS_CLASS_USER_INTERACTIVE};
    // The second argument is a relative priority within the QoS class (0 = default).
    pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0);
}

#[cfg(not(target_os = "macos"))]
#[inline(always)]
unsafe fn set_thread_affinity() {
    // On non-macOS platforms we currently leave affinity untouched.
}

/// Rayon pool used by the flash-attention CPU kernels, with a per-thread
/// start handler that applies our affinity hint exactly once.
static FLASH_ATTN_POOL: LazyLock<ThreadPool> = LazyLock::new(|| {
    rayon::ThreadPoolBuilder::new()
        .start_handler(|_| unsafe {
            set_thread_affinity();
        })
        .build()
        .expect("Failed to build custom Rayon thread-pool for flash-attention")
});

/// Number of elements handled per unrolled iteration in `vec_dot`.
const DOT_CHUNK: usize = 4;

/// Size (in KV positions) processed by each inner-tile job.
const TILE_KV: usize = 16;

/// Unrolled dot product over two equal-length slices, accumulating
/// `DOT_CHUNK` products per iteration and handling the remainder separately.
#[inline]
fn vec_dot<T: WithDType + Sum + Copy + std::ops::Mul<Output = T>>(a: &[T], b: &[T]) -> T {
    let mut sum = T::zero();
    let chunks = a.len() / DOT_CHUNK;

    for i in 0..chunks {
        let i_chunk = i * DOT_CHUNK;
        sum = sum
            + a[i_chunk] * b[i_chunk]
            + a[i_chunk + 1] * b[i_chunk + 1]
            + a[i_chunk + 2] * b[i_chunk + 2]
            + a[i_chunk + 3] * b[i_chunk + 3];
    }

    for i in (chunks * DOT_CHUNK)..a.len() {
        sum += a[i] * b[i];
    }
    sum
}

/// Fused attention optimized for CPU.
///
/// Computes softmax(qk^T * scale) v.
///
/// **Input shapes:**
/// - `q`: (bs, seq, qhead, hidden)
/// - `k`: (bs, kv_seq, k_head, hidden)
/// - `v`: (bs, kv_seq, v_head, v_hidden)
/// - `softmax_scale` is applied to the logits before the softmax.
///
/// This supports ALiBi with `max_bias` as well as logit softcapping with `softcap`.
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
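///
/// # Example
///
/// A minimal illustrative sketch. Assumptions: `f32` tensors already resident
/// on the CPU, and that this function is reachable from your crate path; the
/// fence is `ignore` because the public re-export path may differ.
///
/// ```ignore
/// use candle::{DType, Device, Tensor};
///
/// let dev = Device::Cpu;
/// let (bs, seq, kv_seq, heads, hidden) = (1, 4, 4, 8, 64);
/// let q = Tensor::zeros((bs, seq, heads, hidden), DType::F32, &dev)?;
/// let k = Tensor::zeros((bs, kv_seq, heads, hidden), DType::F32, &dev)?;
/// let v = Tensor::zeros((bs, kv_seq, heads, hidden), DType::F32, &dev)?;
/// let scale = 1.0 / (hidden as f32).sqrt();
/// // No mask, no ALiBi, no softcapping.
/// let out = run_flash_attn_cpu::<f32>(&q, &k, &v, None, scale, None, None)?;
/// assert_eq!(out.dims(), &[bs, heads, seq, hidden]);
/// ```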
pub fn run_flash_attn_cpu<T>(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    softmax_scale: f32,
    max_bias: Option<f32>,
    softcap: Option<f32>,
) -> Result<Tensor>
where
    T: WithDType + Sum + num_traits::real::Real,
{
    // Inline CPU slice extraction for q, k, v, and optional mask
    let (q_guard, q_layout) = q.storage_and_layout();
    let q_data: &[T] = if let Storage::Cpu(cpu) = &*q_guard {
        let data = cpu.as_slice::<T>()?;
        &data[q_layout.start_offset()..]
    } else {
        return Err(candle::Error::Msg("Expected CPU storage for q".into()));
    };
    let (k_guard, k_layout) = k.storage_and_layout();
    let k_data: &[T] = if let Storage::Cpu(cpu) = &*k_guard {
        let data = cpu.as_slice::<T>()?;
        &data[k_layout.start_offset()..]
    } else {
        return Err(candle::Error::Msg("Expected CPU storage for k".into()));
    };
    let (v_guard, v_layout) = v.storage_and_layout();
    let v_data: &[T] = if let Storage::Cpu(cpu) = &*v_guard {
        let data = cpu.as_slice::<T>()?;
        &data[v_layout.start_offset()..]
    } else {
        return Err(candle::Error::Msg("Expected CPU storage for v".into()));
    };
    let mask_guard = mask.map(|mask| mask.storage_and_layout().0);
    let mask_data: Option<&[T]> = if let Some(mask_guard) = &mask_guard {
        let mask = mask.as_ref().unwrap();

        if let Storage::Cpu(cpu) = &**mask_guard {
            let data = cpu.as_slice::<T>()?;
            Some(&data[mask.layout().start_offset()..])
        } else {
            return Err(candle::Error::Msg("Expected CPU storage for mask".into()));
        }
    } else {
        None
    };
    // q_guard, k_guard, v_guard, and mask_guard (if any) are kept in scope to keep the storage alive

    let q_stride = q.stride();
    let k_stride = k.stride();
    let v_stride = v.stride();

    // Fast path for decode: q_len == 1
    if q.shape().dims()[1] == 1 {
        return flash_attn_cpu_single_q(
            q_data,
            k_data,
            v_data,
            mask_data,
            q.shape().dims(),
            k.shape().dims(),
            v.shape().dims(),
            q_stride,
            k_stride,
            v_stride,
            softmax_scale,
            max_bias.unwrap_or(0.0),
            softcap.unwrap_or(0.0),
        );
    }

    flash_attn_cpu(
        q_data,
        k_data,
        v_data,
        mask_data,
        q.shape().dims(),
        k.shape().dims(),
        v.shape().dims(),
        q_stride,
        k_stride,
        v_stride,
        softmax_scale,
        max_bias.unwrap_or(0.0),
        softcap.unwrap_or(0.0),
    )
}

/// Optimised path for the common decode case: q_len == 1 but kv_len ≫ 1.
/// We drop the inner q-position loop and parallelise over `(batch, head)`.
#[allow(clippy::too_many_arguments)]
fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
    q_data: &[T],
    k_data: &[T],
    v_data: &[T],
    mask_vec: Option<&[T]>,
    qshape: &[usize],
    kshape: &[usize],
    vshape: &[usize],
    qstride: &[usize],
    kstride: &[usize],
    vstride: &[usize],
    scale: f32,
    max_bias: f32,
    logit_softcap: f32,
) -> Result<Tensor> {
    // Shapes: (B, 1, H, D)
    let (b, _q_len, h, d) = (
        qshape[0], qshape[1], // == 1
        qshape[2], qshape[3],
    );
    let kv_len = kshape[1];
    let k_h = kshape[2];
    let v_h = vshape[2];
    let rk2 = h / k_h;
    let rv2 = h / v_h;
    let dv = d;

    let n2 = 2_usize.pow((h as f32).log2().ceil() as u32);
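    // `n2` is the head count rounded up to the next power of two. The per-head
    // ALiBi slope computed below is 2^(-max_bias * (h_i + 1) / n2), which
    // reduces to the usual ALiBi slope schedule when `max_bias` is 8 and the
    // head count is itself a power of two.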

    // Output buffer: (B, H, 1, D)
    let mut out = vec![0f32; b * h * dv];

    // Expose a second dimension of work: split the KV axis into tiles that
    // fit in the last-level cache and let Rayon schedule them.
    let kv_tiles = kv_len.div_ceil(TILE_KV);

    // SAFETY: `par_chunks_mut` hands out non-overlapping &mut slices, so no two
    // threads write the same output area.
    FLASH_ATTN_POOL.install(|| {
        out.par_chunks_mut(dv)
            .with_min_len(64)
            .enumerate()
            .for_each(|(row_idx, out_chunk)| {
                let b_i = row_idx / h;
                let h_i = row_idx % h;

                // ALiBi positional bias (standard formula)
                let slope = if max_bias > 0.0 {
                    2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32)
                } else {
                    1.0
                };

                // For grouped-KV we collapse multiple query heads into the same K/V head.
                let k_head = h_i / rk2;
                let v_head = h_i / rv2;

                // ------------------------------------------------------------------
                // Nested parallelism: each KV tile is mapped independently, then we
                // reduce the partial results with the correct soft-max algebra.
                // ------------------------------------------------------------------
                let (vkq, s_tot, _m_tot) = (0..kv_tiles)
                    .into_par_iter()
                    .map(|tile_idx| {
                        // ---- per-tile scratch -------------------------------------------------
                        let start = tile_idx * TILE_KV;
                        let end = (start + TILE_KV).min(kv_len);

                        let mut vkq = vec![0f32; dv];
                        let mut s = 0.0f32;
                        let mut m = f32::NEG_INFINITY;

                        // ---------------- single-Q row (already contiguous) -------------------
                        let q_base =
                            b_i * qstride[0] /*batch*/ + h_i * qstride[2] /*head*/;
                        let q_row = &q_data[q_base..q_base + d];

                        // ---------------- iterate over this KV slice --------------------------
                        for kv_pos in start..end {
                            // Mask
                            let mv = if let Some(mv_vec) = mask_vec {
                                let mval = mv_vec[(b_i * kv_len) + kv_pos];
                                slope * mval.to_f64() as f32
                            } else {
                                0.0
                            };
                            if mv == f32::NEG_INFINITY {
                                continue;
                            }

                            // K row
                            let k_base =
                                b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2];
                            let k_row = &k_data[k_base..k_base + d];

                            // dot(Q, K)
                            let mut s_val = vec_dot::<T>(q_row, k_row).to_f64() as f32;

                            let mut scale_applied = scale;
                            if logit_softcap != 0.0 {
                                scale_applied /= logit_softcap;
                            }
                            s_val *= scale_applied;
                            if logit_softcap != 0.0 {
                                s_val = logit_softcap * s_val.tanh();
                            }
                            s_val += mv;

                            // Tile-local online softmax ------------------------------------------
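                            // Running-softmax invariants kept by this loop:
                            //   m   = max logit seen so far in this tile
                            //   s   = sum of exp(logit - m) over the tile
                            //   vkq = sum of exp(logit - m) * V_row over the tile
                            // A new maximum rescales the old accumulators by
                            // exp(m_old - m_new); otherwise the incoming term is
                            // weighted by exp(logit - m) before being added.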
                            let m_old = m;
                            let mut ms = 1.0f32;
                            let mut vs = 1.0f32;
                            if s_val > m {
                                m = s_val;
                                ms = (m_old - m).exp();
                                for v in vkq.iter_mut() {
                                    *v *= ms;
                                }
                            } else {
                                vs = (s_val - m).exp();
                            }

                            // V row
                            let v_base =
                                b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2];
                            for d_i in 0..dv {
                                vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs;
                            }

                            s = s * ms + vs;
                        }

                        // Return per-tile accumulator + softmax stats
                        (vkq, s, m)
                    })
                    // -------- reduce two tiles -----------------------------------------------
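                    // Merging two partial tiles (o_a, s_a, m_a) and (o_b, s_b, m_b)
                    // uses the flash-attention combine rule:
                    //   m = max(m_a, m_b)
                    //   o = o_a * exp(m_a - m) + o_b * exp(m_b - m)
                    //   s = s_a * exp(m_a - m) + s_b * exp(m_b - m)
                    // The two branches below just avoid computing the factor that
                    // is known to be exp(0) = 1.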
                    .reduce(
                        || (vec![0f32; dv], 0.0f32, f32::NEG_INFINITY),
                        |(mut vkq_a, s_a, m_a), (mut vkq_b, s_b, m_b)| {
                            if m_a >= m_b {
                                let factor = (m_b - m_a).exp();
                                for (va, vb) in vkq_a.iter_mut().zip(&vkq_b) {
                                    *va += *vb * factor;
                                }
                                (vkq_a, s_a + s_b * factor, m_a)
                            } else {
                                let factor = (m_a - m_b).exp();
                                for (vb, va) in vkq_b.iter_mut().zip(&vkq_a) {
                                    *vb += *va * factor;
                                }
                                (vkq_b, s_b + s_a * factor, m_b)
                            }
                        },
                    );

                // ---------------- final normalisation ---------------------------------------
                let inv_s = 1.0 / s_tot;
                for (o, v) in out_chunk.iter_mut().zip(vkq.iter()) {
                    *o = *v * inv_s;
                }
            });
    });

    let out_shape = (b, h, 1usize, dv);
    Tensor::from_vec(out, out_shape, &Device::Cpu)
}

/// Main forward flash-attention CPU routine.
/// Shapes follow Candle convention: (B, S, H, D)
#[allow(clippy::too_many_arguments)]
fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
    q_data: &[T],
    k_data: &[T],
    v_data: &[T],
    mask_vec: Option<&[T]>,
    qshape: &[usize],
    kshape: &[usize],
    vshape: &[usize],
    qstride: &[usize],
    kstride: &[usize],
    vstride: &[usize],
    scale: f32,
    max_bias: f32,
    logit_softcap: f32,
) -> Result<Tensor> {
    let (b, q_len, h, d) = (qshape[0], qshape[1], qshape[2], qshape[3]);
    let kv_len = kshape[1];
    // --- Head broadcasting factors ----------------------------------------------------
    // Allows K and V to have fewer heads than Q (grouped-KV); the ratio is an
    // integer factor.  rk2 = #Q-heads / #K-heads,  rv2 = #Q-heads / #V-heads.
    let k_h = kshape[2];
    let v_h = vshape[2];
    let rk2 = h / k_h; // assumed to divide exactly (grouped-KV head ratio)
    let rv2 = h / v_h;
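    // Example: 32 query heads with 8 K/V heads gives rk2 = rv2 = 4, so query
    // heads 0..=3 all read K/V head 0, heads 4..=7 read head 1, and so on.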
    let dv = d; // value dim = key dim in this kernel

    // Precompute value for ALiBi slope calculation
    let n2 = 2_usize.pow((h as f32).log2().ceil() as u32);

    let mut out = vec![0f32; b * q_len * h * dv];

    // ------------------------------------------------------------------
    // Rayon-parallel version: each (b_i, h_i, q_pos) row is independent.
    // ------------------------------------------------------------------

    let _rows = b * h * q_len; // total independent work items

    // SAFETY: `par_chunks_mut` hands out non-overlapping &mut [f32] slices,
    // so no two threads can write the same output area.
    FLASH_ATTN_POOL.install(|| {
        out.par_chunks_mut(dv)
            .with_min_len(64)
            .enumerate()
            .for_each(|(row_idx, out_chunk)| {
                // Decode flat index back to (batch, head, q_pos)
                let rows_per_batch = h * q_len;
                let b_i = row_idx / rows_per_batch;
                let rem = row_idx % rows_per_batch;
                let h_i = rem / q_len;
                let q_pos = rem % q_len;

                let slope = if max_bias > 0.0 {
                    2.0f32.powf(-max_bias * ((h_i + 1) as f32) / n2 as f32)
                } else {
                    1.0
                };

                // For grouped-KV we collapse multiple query heads into the same K/V head.
                let k_head = h_i / rk2;
                let v_head = h_i / rv2;

                // Buffers local to this row
                let mut vkq = vec![0f32; dv];
                let mut s = 0.0f32;
                let mut m = f32::NEG_INFINITY;

                // Allocate q_row and k_row once per row
                let mut q_row: Vec<T> = Vec::with_capacity(d);
                let mut k_row: Vec<T> = Vec::with_capacity(d);

                // ------------------- gather Q (strided) --------------------
                let q_base = b_i * qstride[0] + q_pos * qstride[1] + h_i * qstride[2];
                q_row.clear();
                for di in 0..d {
                    q_row.push(q_data[q_base + di * qstride[3]]);
                }

                // ---------------- iterate over keys/values -----------------
                for kv_pos in 0..kv_len {
                    // Mask (optional)
                    let mv = if let Some(mv_vec) = mask_vec {
                        let mval = mv_vec[((b_i * q_len + q_pos) * kv_len) + kv_pos];
                        slope * mval.to_f64() as f32
                    } else {
                        0.0
                    };
                    if mv == f32::NEG_INFINITY {
                        continue;
                    }

                    // K row (strided)
                    let k_base = b_i * kstride[0] + kv_pos * kstride[1] + k_head * kstride[2];
                    k_row.clear();
                    for di in 0..d {
                        k_row.push(k_data[k_base + di * kstride[3]]);
                    }

                    // dot(Q, K)
                    let mut s_val = vec_dot::<T>(&q_row, &k_row);
                    let mut scale_applied = scale;
                    if logit_softcap != 0.0 {
                        scale_applied /= logit_softcap;
                    }
                    s_val *= T::from_f64(scale_applied as f64);
                    if logit_softcap != 0.0 {
                        s_val = T::from_f64(logit_softcap as f64 * s_val.to_f64().tanh());
                    }
                    s_val += T::from_f64(mv as f64);

                    // online softmax
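                    // Same running-max / running-sum update as in
                    // `flash_attn_cpu_single_q` above: a new maximum rescales the
                    // accumulators, otherwise the new term is down-weighted.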
                    let m_old = m;
                    let mut ms = 1.0f32;
                    let mut vs = 1.0f32;
                    if s_val.to_f64() as f32 > m {
                        m = s_val.to_f64() as f32;
                        ms = (m_old - m).exp();
                        for v in vkq.iter_mut() {
                            *v *= ms;
                        }
                    } else {
                        vs = (s_val.to_f64() as f32 - m).exp();
                    }

                    // V row (strided)
                    let v_base = b_i * vstride[0] + kv_pos * vstride[1] + v_head * vstride[2];
                    for d_i in 0..dv {
                        vkq[d_i] += v_data[v_base + d_i * vstride[3]].to_f64() as f32 * vs;
                    }

                    s = s * ms + vs;
                }

                // ------------------- normalise & write out ------------------
                let inv_s = 1.0 / s;
                for v in vkq.iter_mut() {
                    *v *= inv_s;
                }
                out_chunk.copy_from_slice(&vkq);
            });
    });

    // Build output tensor with shape (B, H, S, D) to match standard (permute 0,2,1,3)
    let out_shape = (b, h, q_len, dv);
    Tensor::from_vec(out, out_shape, &Device::Cpu)
}
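
// ---------------------------------------------------------------------------
// Illustrative sanity check (an added sketch, not part of the original kernel):
// it compares the fused CPU path against a naive softmax(QK^T * scale)V
// reference built from plain tensor ops. The shapes and tolerance are arbitrary
// choices, and `crate::ops::softmax_last_dim` is assumed to be available here.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn matches_naive_reference() -> Result<()> {
        let dev = Device::Cpu;
        let (b, s, kv, h, d) = (1usize, 3, 5, 4, 8);
        let q = Tensor::randn(0f32, 1f32, (b, s, h, d), &dev)?;
        let k = Tensor::randn(0f32, 1f32, (b, kv, h, d), &dev)?;
        let v = Tensor::randn(0f32, 1f32, (b, kv, h, d), &dev)?;
        let scale = 1.0 / (d as f32).sqrt();

        // Fused kernel: (B, S, H, D) in, (B, H, S, D) out.
        let fused = run_flash_attn_cpu::<f32>(&q, &k, &v, None, scale, None, None)?;

        // Naive reference computed in (B, H, S, D) layout.
        let q_t = q.transpose(1, 2)?.contiguous()?; // (B, H, S, D)
        let k_t = k.transpose(1, 2)?.transpose(2, 3)?.contiguous()?; // (B, H, D, KV)
        let v_t = v.transpose(1, 2)?.contiguous()?; // (B, H, KV, D)
        let attn = (q_t.matmul(&k_t)? * scale as f64)?;
        let attn = crate::ops::softmax_last_dim(&attn)?;
        let reference = attn.matmul(&v_t)?;

        let max_diff = (fused - reference)?
            .abs()?
            .flatten_all()?
            .max(0)?
            .to_scalar::<f32>()?;
        assert!(max_diff < 1e-4, "max abs diff {max_diff}");
        Ok(())
    }
}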