mistralrs-core 0.8.1

Fast, flexible LLM inference.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
use std::collections::HashMap;

use candle_core::{DType, Device, Result, Tensor};
#[allow(unused_imports)]
use mistralrs_paged_attn::{kv_scale_update, paged_attention, reshape_and_cache};

/// Number of cache writes into an FP8 (E4M3) KV cache for which
/// `kv_scale_update` is run on the layer's K/V scales; once this many updates
/// have happened the call is skipped.
const KV_SCALE_UPDATE_ITERATION: i32 = 128;
use std::sync::atomic::{AtomicI32, Ordering};

use crate::{
    attention::SdpaParams,
    layers::Sdpa,
    paged_attention::_PAD_SLOT_ID,
    pipeline::text_models_inputs_processor::{
        FlashKMeta, FlashParams, PagedAttentionInputMetadata,
    },
};

/// Picks the per-device tensor for `device` out of `tensors`.
///
/// An exact match on `device.location()` is returned as a (cheap) clone. If
/// there is none, an arbitrary entry of the map is copied onto `device`.
///
/// # Errors
/// Fails when `tensors` is empty (the message names `what`), or when the
/// device copy fails.
fn resolve_tensor_for_device(
    tensors: &HashMap<candle_core::DeviceLocation, Tensor>,
    device: &Device,
    what: &str,
) -> Result<Tensor> {
    match tensors.get(&device.location()) {
        Some(exact) => Ok(exact.clone()),
        None => match tensors.values().next() {
            // HashMap order is unspecified: any stored copy is acceptable here.
            Some(other) => other.to_device(device),
            None => candle_core::bail!("Missing {what} tensor for {:?}", device.location()),
        },
    }
}

/// Builds the exclusive prefix sum `[0, l0, l0+l1, ...]` of `lengths` as a
/// `u32` tensor (length `lengths.len() + 1`) and moves it to `device`.
fn cumulative_seqlens_from_lengths(lengths: &[usize], device: &Device) -> Result<Tensor> {
    let prefix_sums: Vec<u32> = std::iter::once(0u32)
        .chain(lengths.iter().scan(0u32, |running, &len| {
            *running += len as u32;
            Some(*running)
        }))
        .collect();
    // Built on the host first, then transferred in one copy.
    Tensor::new(prefix_sums.as_slice(), &Device::Cpu)?.to_device(device)
}

/// Counts, per batch row, how many slots in `slot_mapping` are real (not
/// `_PAD_SLOT_ID`).
///
/// A rank-2 mapping is used as-is. Any other shape is reshaped to
/// `(batch_size, seq_len)`. The mapping is read back as `i64` on the CPU.
fn new_token_lens_from_slot_mapping(
    slot_mapping: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> Result<Vec<usize>> {
    let on_host = slot_mapping.to_device(&Device::Cpu)?;
    let rows = match on_host.dims().len() {
        2 => on_host,
        _ => on_host.reshape((batch_size, seq_len))?,
    };
    let per_row = rows.to_vec2::<i64>()?;
    Ok(per_row
        .iter()
        .map(|row| row.iter().filter(|slot| **slot != _PAD_SLOT_ID).count())
        .collect())
}

/// Splits a packed K or V tensor into a zero-padded batch.
///
/// `packed` holds the sequences back to back along dim 0; `kv_lens[i]` is the
/// length of sequence `i`. Each slice `(kv_len, heads, dim)` is transposed to
/// `(1, heads, kv_len, dim)` and right-padded with zeros along the token axis
/// to the longest length. Pads are shaped from `num_kv_heads` / `head_size`.
/// The slices are then concatenated along the batch axis.
fn unpack_gathered_kv(
    packed: &Tensor,
    kv_lens: &[usize],
    num_kv_heads: usize,
    head_size: usize,
    device: &Device,
) -> Result<Tensor> {
    let longest = kv_lens.iter().max().copied().unwrap_or(0);
    let mut per_seq = Vec::with_capacity(kv_lens.len());
    let mut offset = 0usize;

    for &kv_len in kv_lens {
        let seq = packed
            .narrow(0, offset, kv_len)?
            .transpose(0, 1)?
            .unsqueeze(0)?;
        offset += kv_len;

        let missing = longest - kv_len;
        let seq = if missing == 0 {
            seq
        } else {
            let zeros = Tensor::zeros(
                (1, num_kv_heads, missing, head_size),
                packed.dtype(),
                device,
            )?;
            Tensor::cat(&[&seq, &zeros], 2)?
        };
        per_seq.push(seq);
    }

    Tensor::cat(&per_seq, 0)
}

/// Trims the last (key) axis of a rank-2, -3 or -4 `mask` down to
/// `kv_seq_len` when it is longer. Any other mask is returned unchanged
/// (cheap clone).
fn adjust_kv_mask(mask: &Tensor, kv_seq_len: usize) -> Result<Tensor> {
    let rank = mask.rank();
    if (2..=4).contains(&rank) {
        let key_axis = rank - 1;
        if mask.dims()[key_axis] > kv_seq_len {
            return mask.narrow(key_axis, 0, kv_seq_len);
        }
    }
    Ok(mask.clone())
}

/// Whether SDPA can consume packed variable-length K/V for `query`.
///
/// This holds on CPU always. On CUDA it requires flash attention to be
/// enabled and a non-F32 query dtype.
fn supports_packed_varlen_sdpa(query: &Tensor) -> bool {
    let dev = query.device();
    if dev.is_cpu() {
        return true;
    }
    dev.is_cuda() && crate::using_flash_attn() && query.dtype() != DType::F32
}

/// Per-layer paged-attention state.
pub struct PagedAttention {
    /// Key scale handed to the cache kernels (`reshape_and_cache`,
    /// `gather_kv_cache`, `paged_attention`); starts at `1.0`.
    k_scale: Option<Tensor>,
    /// Value scale handed to the same kernels; starts at `1.0`.
    v_scale: Option<Tensor>,
    /// Number of `kv_scale_update` calls made so far. Further updates are
    /// skipped once it reaches `KV_SCALE_UPDATE_ITERATION`.
    kv_updated_times: AtomicI32,
    /// Optional ALiBi slopes, stored on the device given at construction.
    alibi_slopes: Option<Tensor>,
}

impl PagedAttention {
    pub fn new(head_dim: usize, device: &Device, alibi_slopes: Option<Vec<f32>>) -> Result<Self> {
        let alibi_slopes = if let Some(alibi_slopes) = alibi_slopes {
            assert_eq!(alibi_slopes.len(), head_dim);
            Some(Tensor::new(alibi_slopes, device)?)
        } else {
            None
        };
        Ok(Self {
            alibi_slopes,
            k_scale: Some(Tensor::new(1f32, device)?),
            v_scale: Some(Tensor::new(1f32, device)?),
            kv_updated_times: AtomicI32::new(0),
        })
    }

    #[allow(
        clippy::too_many_arguments,
        clippy::cast_possible_truncation,
        unused_variables
    )]
    fn forward_impl(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attention_mask: Option<&Tensor>,
        mut key_cache: Option<Tensor>,
        mut value_cache: Option<Tensor>,
        input_metadata: &PagedAttentionInputMetadata,
        sdpa_params: &SdpaParams,
        flash_params: Option<&FlashParams>,
        write_cache: bool,
    ) -> Result<Tensor> {
        if write_cache {
            if let (Some(k_scale), Some(v_scale), Some(key_cache)) =
                (&self.k_scale, &self.v_scale, &key_cache)
            {
                if self.kv_updated_times.load(Ordering::Relaxed) < KV_SCALE_UPDATE_ITERATION
                    && key_cache.dtype() == DType::F8E4M3
                {
                    kv_scale_update(key, value, k_scale, v_scale)?;
                    self.kv_updated_times.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        let slot_mapping_full = input_metadata
            .slot_mappings
            .get(&query.device().location())
            .unwrap();
        let dims = slot_mapping_full.dims();
        let slot_mapping = if dims.len() > 1 {
            &slot_mapping_full.flatten(0, dims.len())?
        } else {
            slot_mapping_full
        };

        // For models with per-layer sliding windows (GPT-OSS, Gemma2):
        // - Full-attention layers (sliding_window == None) use the full block tables.
        // - Sliding-window layers (sliding_window == Some) use the windowed block tables.
        // If full_block_tables is not populated, fall back to the regular block_tables.
        let use_full =
            sdpa_params.sliding_window.is_none() && input_metadata.full_block_tables.is_some();

        let resolve_block_tables = |dev: &candle_core::DeviceLocation| -> Option<&Tensor> {
            if use_full {
                input_metadata.full_block_tables.as_ref()?.get(dev)
            } else {
                input_metadata.block_tables.as_ref()?.get(dev)
            }
        };
        let resolve_context_lens = |dev: &candle_core::DeviceLocation| -> Option<&Tensor> {
            if use_full {
                input_metadata.full_context_lens.as_ref()?.get(dev)
            } else {
                input_metadata.context_lens.as_ref()?.get(dev)
            }
        };

        let alibi_slopes = if let Some(alibi_slopes) = self.alibi_slopes.as_ref() {
            Some(alibi_slopes.to_device(query.device())?)
        } else {
            None
        };

        let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?;
        let (_, key_value_heads, _, _) = key.shape().dims4()?;

        // === Prefix cache / donor-gather prompt path ===
        // Entered when:
        //  - write_cache=true  AND num_cached_tokens is set (prefix cache hit)
        //  - write_cache=false AND attention_mask is set  (donor cache prompt)
        // The gather path needs block_tables. During calibration forwards
        // there is no paged cache, so block_tables is None — skip to the
        // regular prompt path.
        let has_block_tables = input_metadata.block_tables.is_some();
        let use_gather_path = if write_cache {
            input_metadata.num_cached_tokens.is_some()
                && attention_mask.is_some()
                && has_block_tables
        } else {
            attention_mask.is_some() && has_block_tables
        };

        if use_gather_path {
            let block_tables = resolve_block_tables(&query.device().location()).unwrap();
            let context_lens = resolve_context_lens(&query.device().location()).unwrap();
            // Write new tokens to cache (skipped for donor/shared layers)
            if write_cache && key_cache.as_ref().is_some_and(|_| value_cache.is_some()) {
                let k_flat = key
                    .transpose(1, 2)?
                    .reshape(((), key_value_heads, head_size))?;
                let v_flat = value
                    .transpose(1, 2)?
                    .reshape(((), key_value_heads, head_size))?;
                reshape_and_cache(
                    &k_flat,
                    &v_flat,
                    self.k_scale.as_ref(),
                    self.v_scale.as_ref(),
                    key_cache.as_mut().unwrap(),
                    value_cache.as_mut().unwrap(),
                    slot_mapping,
                )?;
            }

            assert!(
                alibi_slopes.is_none(),
                "alibi slopes not supported in prefix cache path"
            );

            let device = query.device();

            let new_token_lens =
                new_token_lens_from_slot_mapping(slot_mapping_full, batch_size, seq_len)?;
            let query_lens = input_metadata
                .query_lens
                .clone()
                .unwrap_or_else(|| new_token_lens.clone());
            let kv_lens = if let Some(num_cached_tokens) = input_metadata.num_cached_tokens.as_ref()
            {
                num_cached_tokens
                    .iter()
                    .zip(query_lens.iter())
                    .map(|(&cached, &query_len)| cached + query_len)
                    .collect::<Vec<_>>()
            } else {
                new_token_lens.clone()
            };

            // Resolve cu_seqlens_kv: scheduler-provided for prefix cache hits,
            // or synthesize it from the actual slot-mapping lengths on first prompt.
            let cu_kv = if let Some(map) = input_metadata.cu_seqlens_kv.as_ref() {
                resolve_tensor_for_device(map, device, "cu_seqlens_kv")?
            } else {
                cumulative_seqlens_from_lengths(&kv_lens, device)?
            };

            // Gather all K/V from paged cache into contiguous tensors.
            let (k_gathered, v_gathered) = mistralrs_paged_attn::gather_kv_cache(
                key_cache.as_ref().unwrap(),
                value_cache.as_ref().unwrap(),
                self.k_scale.as_ref(),
                self.v_scale.as_ref(),
                block_tables,
                &cu_kv,
                query.dtype(),
            )?;

            if supports_packed_varlen_sdpa(query) {
                let cu_q = if let Some(fp) = flash_params {
                    if !fp.cumulative_seqlens_q.is_empty() {
                        resolve_tensor_for_device(
                            &fp.cumulative_seqlens_q,
                            device,
                            "cumulative_seqlens_q",
                        )?
                    } else {
                        cumulative_seqlens_from_lengths(&query_lens, device)?
                    }
                } else {
                    cumulative_seqlens_from_lengths(&query_lens, device)?
                };

                // gathered: (total_kv, kv_heads, dim) -> (1, kv_heads, total_kv, dim)
                let k_4d = k_gathered.unsqueeze(0)?.transpose(1, 2)?;
                let v_4d = v_gathered.unsqueeze(0)?.transpose(1, 2)?;

                let mut cu_q_map = HashMap::new();
                cu_q_map.insert(device.location(), cu_q);
                let mut cu_kv_map = HashMap::new();
                cu_kv_map.insert(device.location(), cu_kv);
                let prefix_flash_params = FlashParams {
                    max_q: query_lens.iter().copied().max().unwrap_or(0) as u32,
                    cumulative_seqlens_q: cu_q_map,
                    logical_k: FlashKMeta {
                        max: kv_lens.iter().copied().max().unwrap_or(0) as u32,
                        cumulative_seqlens: cu_kv_map,
                    },
                    sliding_k: None,
                    causal: flash_params.map_or(attention_mask.is_some(), |fp| fp.causal),
                };

                return Sdpa.run_attention(
                    query,
                    &k_4d,
                    &v_4d,
                    attention_mask,
                    Some(&prefix_flash_params),
                    sdpa_params,
                );
            }

            let max_kv = kv_lens.iter().copied().max().unwrap_or(0);
            let k_batched =
                unpack_gathered_kv(&k_gathered, &kv_lens, key_value_heads, head_size, device)?;
            let v_batched =
                unpack_gathered_kv(&v_gathered, &kv_lens, key_value_heads, head_size, device)?;
            let adjusted_mask = attention_mask
                .map(|mask| adjust_kv_mask(mask, max_kv))
                .transpose()?;

            return Sdpa.run_attention(
                query,
                &k_batched,
                &v_batched,
                adjusted_mask.as_ref(),
                None,
                sdpa_params,
            );
        }

        // === Regular prompt path (no prefix cache, write_cache=true only) ===
        #[allow(clippy::cast_possible_truncation)]
        let att = match attention_mask {
            None => None,
            Some(mask) => Some(Sdpa.run_attention(
                query,
                key,
                value,
                Some(mask),
                flash_params,
                sdpa_params,
            )?),
        };

        // paged-attn expects [batch_size, num_tokens, num_heads, head_size]
        let (query, key, value) = if seq_len > 1 {
            let q = query
                .transpose(1, 2)?
                .reshape(((), attention_heads, head_size))?;
            let k = key
                .transpose(1, 2)?
                .reshape(((), key_value_heads, head_size))?;
            let v = value
                .transpose(1, 2)?
                .reshape(((), key_value_heads, head_size))?;
            (q, k, v)
        } else {
            // avoid unnecessary transpose for decoding
            let q = query.reshape(((), attention_heads, head_size))?;
            let k = key.reshape(((), key_value_heads, head_size))?;
            let v = value.reshape(((), key_value_heads, head_size))?;
            (q, k, v)
        };

        if write_cache && key_cache.as_ref().is_some_and(|_| value_cache.is_some()) {
            reshape_and_cache(
                &key,
                &value,
                self.k_scale.as_ref(),
                self.v_scale.as_ref(),
                key_cache.as_mut().unwrap(),
                value_cache.as_mut().unwrap(),
                slot_mapping,
            )?;
        }

        if let Some(att) = att {
            // Return result in prefill or first prefix chunk
            return Ok(att);
        }

        // === Decode path ===
        #[allow(clippy::cast_possible_truncation)]
        let dev = query.device().location();
        let res = paged_attention(
            &query,
            self.k_scale.as_ref(),
            self.v_scale.as_ref(),
            key_cache.as_ref().unwrap(),
            value_cache.as_ref().unwrap(),
            resolve_block_tables(&dev).unwrap(),
            resolve_context_lens(&dev).unwrap(),
            alibi_slopes.as_ref(),
            if use_full {
                input_metadata.full_max_context_len.unwrap()
            } else {
                input_metadata.max_context_len.unwrap()
            },
            sdpa_params.softmax_scale,
            sdpa_params.softcap.unwrap_or(1.0f32),
            sdpa_params.sinks.as_ref(),
        )?;

        Ok(res)
    }

    /// Standard paged attention forward: writes key/value to cache, then
    /// runs attention (Sdpa for prompt, paged kernel for decode).
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attention_mask: Option<&Tensor>,
        key_cache: Option<Tensor>,
        value_cache: Option<Tensor>,
        input_metadata: &PagedAttentionInputMetadata,
        sdpa_params: &SdpaParams,
        flash_params: Option<&FlashParams>,
    ) -> Result<Tensor> {
        self.forward_impl(
            query,
            key,
            value,
            attention_mask,
            key_cache,
            value_cache,
            input_metadata,
            sdpa_params,
            flash_params,
            true,
        )
    }

    /// Read-only paged attention against a donor layer's cache. Identical to
    /// [`forward`] but never calls `reshape_and_cache`, the donor layer has
    /// already written its K,V.  On prompt the donor's cached K,V are
    /// gathered; on decode the paged-attention kernel reads them directly.
    #[allow(clippy::too_many_arguments)]
    pub fn forward_donor_cache(
        &self,
        query: &Tensor,
        key_cache: &Tensor,
        value_cache: &Tensor,
        attention_mask: Option<&Tensor>,
        input_metadata: &PagedAttentionInputMetadata,
        sdpa_params: &SdpaParams,
        flash_params: Option<&FlashParams>,
    ) -> Result<Tensor> {
        // key/value are unused (donor's cache already has them), but
        // forward_impl needs tensors for shape queries. Reuse query as
        // a placeholder, reshape_and_cache is skipped so they're never read.
        self.forward_impl(
            query,
            query,
            query,
            attention_mask,
            Some(key_cache.clone()),
            Some(value_cache.clone()),
            input_metadata,
            sdpa_params,
            flash_params,
            false,
        )
    }
}