mistralrs-core 0.8.1

Fast, flexible LLM inference.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
//! MLA forward pass functions for decode and cache operations.

use candle_core::{Device, Result, Tensor};

use crate::{
    attention::SdpaParams,
    pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
};

use super::MlaWeights;

#[cfg(all(feature = "cuda", target_family = "unix"))]
use candle_core::{DType, D};

#[cfg(all(feature = "cuda", target_family = "unix"))]
use crate::layers::Sdpa;

#[cfg(all(feature = "cuda", target_family = "unix"))]
use crate::ops::SplitOp;

/// Name of the environment variable that disables the MLA optimization.
///
/// The MLA code paths in this module are skipped when the variable is set to
/// exactly `"1"`; any other value (or an unset variable) leaves MLA enabled.
#[cfg(all(feature = "cuda", target_family = "unix"))]
const MISTRALRS_NO_MLA: &str = "MISTRALRS_NO_MLA";

/// Reports whether the user opted out of MLA through the environment.
///
/// Returns `true` only when the `MISTRALRS_NO_MLA` variable is present, is
/// valid Unicode, and holds exactly the string `"1"`. The environment is read
/// on every call.
#[cfg(all(feature = "cuda", target_family = "unix"))]
fn is_mla_disabled() -> bool {
    matches!(std::env::var(MISTRALRS_NO_MLA).as_deref(), Ok("1"))
}

/// Decides whether the MLA decode kernel path applies to the current step.
///
/// Every one of the following must hold for the function to return `true`:
/// - `MISTRALRS_NO_MLA` is not set to "1"
/// - `attention_mask` is `None`
/// - `seq_len` is exactly 1 (single-token decode)
/// - `paged_attn_enabled` is `true`
/// - `device` is a CUDA device
/// - `metadata` is present and its `paged_kv_indptr` field is populated
#[cfg(all(feature = "cuda", target_family = "unix"))]
pub fn should_use_mla_decode(
    attention_mask: Option<&Tensor>,
    seq_len: usize,
    paged_attn_enabled: bool,
    device: &Device,
    metadata: &Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
) -> bool {
    // Cheap disqualifiers first: opt-out, masked (prefill) call, multi-token step,
    // or paged attention being off.
    if is_mla_disabled() || attention_mask.is_some() || seq_len != 1 || !paged_attn_enabled {
        return false;
    }
    if !matches!(device, Device::Cuda(_)) {
        return false;
    }
    // The decode kernel needs the paged KV index pointers to be available.
    matches!(metadata, Some((_, meta)) if meta.paged_kv_indptr.is_some())
}

/// Fallback used when the crate is built without the `cuda` feature or for a
/// non-Unix target: MLA decode is never selected.
///
/// All arguments are ignored and the function always returns `false`.
#[cfg(not(all(feature = "cuda", target_family = "unix")))]
pub fn should_use_mla_decode(
    _attention_mask: Option<&Tensor>,
    _seq_len: usize,
    _paged_attn_enabled: bool,
    _device: &Device,
    _metadata: &Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
) -> bool {
    false
}

/// Decides whether the MLA cache forward path applies to a prefill step.
///
/// Returns `true` only when all of the following hold:
/// - `MISTRALRS_NO_MLA` is not set to "1"
/// - `paged_attn_enabled` is `true`
/// - `device` is a CUDA device
#[cfg(all(feature = "cuda", target_family = "unix"))]
pub fn should_use_mla_cache(paged_attn_enabled: bool, device: &Device) -> bool {
    if is_mla_disabled() {
        return false;
    }
    match device {
        Device::Cuda(_) => paged_attn_enabled,
        _ => false,
    }
}

/// Fallback used when the crate is built without the `cuda` feature or for a
/// non-Unix target: the MLA cache path is never selected.
///
/// Both arguments are ignored and the function always returns `false`.
#[cfg(not(all(feature = "cuda", target_family = "unix")))]
pub fn should_use_mla_cache(_paged_attn_enabled: bool, _device: &Device) -> bool {
    false
}

/// MLA decode forward pass for single-token generation.
///
/// Writes the new compressed KV / rope-key rows into the paged cache, absorbs
/// the key up-projection into the query, runs the FlashInfer MLA decode kernel
/// in latent space, and finally applies the value up-projection.
///
/// # Arguments
/// * `q_nope` - Non-positional query tensor [bs, num_heads, seq_len, qk_nope_head_dim]
/// * `q_pe` - Positional query tensor [bs, num_heads, seq_len, qk_rope_head_dim]
/// * `ckv` - Compressed KV tensor [bs, seq_len, kv_lora_rank]
/// * `k_pe` - Positional key tensor [bs, 1, seq_len, qk_rope_head_dim]
/// * `metadata` - PagedAttention cache and metadata
/// * `mla_weights` - Cached MLA weight matrices
/// * `kv_b_proj` - KV B projection layer for weight computation
/// * `sdpa_params` - SDPA parameters including softmax scale
/// * `num_attention_heads` - Number of attention heads
/// * `kv_lora_rank` - KV latent dimension
/// * `qk_rope_head_dim` - Rope head dimension
/// * `qk_nope_head_dim` - Non-positional head dimension
/// * `v_head_dim` - Value head dimension
/// * `bs` - Batch size
/// * `seq_len` - Sequence length (should be 1)
///
/// # Returns
/// The attention output after the value up-projection, with the singleton
/// matmul axis squeezed away (presumably `[bs, num_heads, v_head_dim]` —
/// TODO confirm against the kernel's output layout).
///
/// # Errors
/// Returns an error when `metadata` is `None`, when the slot mapping or any of
/// the `paged_kv_*` entries is missing for the query's device, or when any
/// tensor operation or kernel launch fails.
#[cfg(all(feature = "cuda", target_family = "unix"))]
#[allow(clippy::too_many_arguments)]
pub fn mla_decode_forward(
    q_nope: &Tensor,
    q_pe: &Tensor,
    ckv: &Tensor,
    k_pe: &Tensor,
    metadata: &Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    mla_weights: &MlaWeights,
    kv_b_proj: &dyn mistralrs_quant::QuantMethod,
    sdpa_params: &SdpaParams,
    num_attention_heads: usize,
    kv_lora_rank: usize,
    qk_rope_head_dim: usize,
    qk_nope_head_dim: usize,
    v_head_dim: usize,
    bs: usize,
    seq_len: usize,
) -> Result<Tensor> {
    let ((key_cache, value_cache), input_metadata) = metadata
        .as_ref()
        .ok_or_else(|| candle_core::Error::msg("paged attention metadata missing"))?;

    // All per-device metadata is keyed by the location of the query tensor.
    let device_location = q_nope.device().location();
    let slot_mapping = input_metadata
        .slot_mappings
        .get(&device_location)
        .ok_or_else(|| candle_core::Error::msg("slot mapping missing"))?;
    // The cache-write kernel expects a flat slot list. `Tensor::flatten` takes an
    // *inclusive* end dimension, so passing the rank there is out of range;
    // `flatten_all` collapses every dimension without that pitfall.
    let slot_mapping = if slot_mapping.dims().len() > 1 {
        slot_mapping.flatten_all()?
    } else {
        slot_mapping.clone()
    };
    let paged_kv_indptr = input_metadata
        .paged_kv_indptr
        .as_ref()
        .and_then(|m| m.get(&device_location))
        .ok_or_else(|| candle_core::Error::msg("paged_kv_indptr missing"))?;
    let paged_kv_indices = input_metadata
        .paged_kv_indices
        .as_ref()
        .and_then(|m| m.get(&device_location))
        .ok_or_else(|| candle_core::Error::msg("paged_kv_indices missing"))?;
    let paged_kv_last_page_len = input_metadata
        .paged_kv_last_page_len
        .as_ref()
        .and_then(|m| m.get(&device_location))
        .ok_or_else(|| candle_core::Error::msg("paged_kv_last_page_len missing"))?;
    let paged_kv_request_indices = input_metadata
        .paged_kv_request_indices
        .as_ref()
        .and_then(|m| m.get(&device_location))
        .ok_or_else(|| candle_core::Error::msg("paged_kv_request_indices missing"))?;
    let paged_kv_tile_indices = input_metadata
        .paged_kv_tile_indices
        .as_ref()
        .and_then(|m| m.get(&device_location))
        .ok_or_else(|| candle_core::Error::msg("paged_kv_tile_indices missing"))?;
    let paged_kv_o_indptr = input_metadata
        .paged_kv_o_indptr
        .as_ref()
        .and_then(|m| m.get(&device_location))
        .ok_or_else(|| candle_core::Error::msg("paged_kv_o_indptr missing"))?;
    let paged_kv_chunk_size = input_metadata
        .paged_kv_chunk_size
        .as_ref()
        .and_then(|m| m.get(&device_location))
        .ok_or_else(|| candle_core::Error::msg("paged_kv_chunk_size missing"))?;

    // Store this step's latent KV and rope key, one row per token, before attending
    // so the decode kernel below sees the current token in the cache.
    let ckv_flat = ckv.contiguous()?.reshape((bs * seq_len, kv_lora_rank))?;
    let k_pe_flat = k_pe
        .squeeze(1)?
        .contiguous()?
        .reshape((bs * seq_len, qk_rope_head_dim))?;

    mistralrs_paged_attn::concat_and_cache_mla(
        &ckv_flat,
        &k_pe_flat,
        key_cache,
        value_cache,
        &slot_mapping,
    )?;

    // Drop the sequence axis (index 2); this fails unless that axis has size 1.
    let q_nope = q_nope.squeeze(2)?.contiguous()?;
    let q_pe = q_pe.squeeze(2)?.contiguous()?;
    let (w_uk, w_uv_t) = mla_weights.get_or_compute(
        kv_b_proj,
        q_nope.device(),
        num_attention_heads,
        kv_lora_rank,
        qk_nope_head_dim,
        v_head_dim,
    )?;
    // Project the non-positional query through `w_uk` so attention can be computed
    // directly against the cached latent representation.
    let ql_nope = q_nope
        .unsqueeze(D::Minus2)?
        .broadcast_matmul(&w_uk.unsqueeze(0)?)?
        .squeeze(D::Minus2)?
        .contiguous()?;

    let attn_latent = mistralrs_paged_attn::flashinfer_mla_decode(
        &ql_nope,
        &q_pe,
        key_cache,
        value_cache,
        paged_kv_indptr,
        paged_kv_indices,
        paged_kv_last_page_len,
        paged_kv_request_indices,
        paged_kv_tile_indices,
        paged_kv_o_indptr,
        paged_kv_chunk_size,
        sdpa_params.softmax_scale,
    )?;

    // Map the latent attention output back through `w_uv_t`.
    attn_latent
        .unsqueeze(D::Minus2)?
        .broadcast_matmul(&w_uv_t.unsqueeze(0)?)?
        .squeeze(D::Minus2)
}

/// Fallback used when the crate is built without the `cuda` feature or for a
/// non-Unix target.
///
/// All arguments are ignored.
///
/// # Errors
/// Always returns an error stating that MLA decode requires CUDA support.
#[cfg(not(all(feature = "cuda", target_family = "unix")))]
#[allow(clippy::too_many_arguments)]
pub fn mla_decode_forward(
    _q_nope: &Tensor,
    _q_pe: &Tensor,
    _ckv: &Tensor,
    _k_pe: &Tensor,
    _metadata: &Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    _mla_weights: &MlaWeights,
    _kv_b_proj: &dyn mistralrs_quant::QuantMethod,
    _sdpa_params: &SdpaParams,
    _num_attention_heads: usize,
    _kv_lora_rank: usize,
    _qk_rope_head_dim: usize,
    _qk_nope_head_dim: usize,
    _v_head_dim: usize,
    _bs: usize,
    _seq_len: usize,
) -> Result<Tensor> {
    candle_core::bail!("MLA decode requires CUDA support")
}

/// MLA cache forward pass for prefill with prefix caching support.
///
/// Handles both fresh prefill and prefix-cached scenarios:
/// 1. When `metadata` is present, the new compressed KV / rope-key rows are
///    written into the paged cache.
/// 2. When there is no paged cache, or when no sequence has a cached prefix and
///    an attention mask was supplied, attention is delegated to
///    `Sdpa.run_attention` on `q`, `k`, `v` directly.
/// 3. Otherwise the cached prefix is gathered from the paged cache, expanded
///    through `kv_b_proj`, prepended to each sequence's keys/values, and
///    attention is run one sequence at a time with a prefix-aware causal mask.
///
/// # Arguments
/// * `q` - Query tensor [bs, num_heads, seq_len, head_dim]
/// * `k` - Key tensor [bs, num_heads, seq_len, head_dim]
/// * `v` - Value tensor [bs, num_heads, seq_len, v_head_dim]
/// * `ckv` - Compressed KV tensor [bs, seq_len, kv_lora_rank]
/// * `k_pe` - Positional key tensor [bs, 1, seq_len, qk_rope_head_dim]
/// * `attention_mask` - Optional attention mask
/// * `seqlen_offsets` - Prefix lengths for each sequence
/// * `metadata` - PagedAttention cache and metadata
/// * `flash_params` - Flash attention parameters
/// * `kv_b_proj` - KV B projection layer
/// * `sdpa_params` - SDPA parameters
/// * `num_attention_heads` - Number of attention heads
/// * `kv_lora_rank` - KV latent dimension
/// * `qk_rope_head_dim` - Rope head dimension
/// * `qk_nope_head_dim` - Non-positional head dimension
/// * `v_head_dim` - Value head dimension
/// * `bs` - Batch size
/// * `seq_len` - Sequence length
///
/// # Returns
/// On the delegated path, whatever `Sdpa.run_attention` returns. On the
/// per-sequence path, the per-sequence outputs concatenated along dimension 0,
/// each zero-padded to `seq_len` rows and laid out like the
/// `(seq_len, num_attention_heads, v_head_dim)` placeholder used for empty
/// sequences.
///
/// NOTE(review): the two paths appear to produce different layouts — confirm
/// that the caller handles both.
///
/// # Errors
/// Returns an error when the slot mapping or block tables are missing for the
/// query's device, when the block table has an unexpected number of rows, when
/// `seqlen_offsets` has fewer entries than there are sequences, or when any
/// tensor operation or kernel launch fails.
#[cfg(all(feature = "cuda", target_family = "unix"))]
#[allow(clippy::too_many_arguments)]
pub fn mla_cache_forward(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    ckv: &Tensor,
    k_pe: &Tensor,
    attention_mask: Option<&Tensor>,
    seqlen_offsets: &[usize],
    metadata: &Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    flash_params: &FlashParams,
    kv_b_proj: &dyn mistralrs_quant::QuantMethod,
    sdpa_params: &SdpaParams,
    num_attention_heads: usize,
    kv_lora_rank: usize,
    qk_rope_head_dim: usize,
    qk_nope_head_dim: usize,
    v_head_dim: usize,
    bs: usize,
    seq_len: usize,
) -> Result<Tensor> {
    // All per-device metadata is keyed by the location of the query tensor.
    let device_location = q.device().location();

    // Step 1: when a paged cache is available, persist this chunk's latent KV and
    // rope key. The slot mapping is resolved once here and reused further down.
    let paged = match metadata.as_ref() {
        Some(((key_cache, value_cache), meta)) => {
            let slot_mapping = meta
                .slot_mappings
                .get(&device_location)
                .ok_or_else(|| candle_core::Error::msg("slot mapping missing"))?;
            // The cache-write kernel expects a flat slot list. `Tensor::flatten`
            // takes an *inclusive* end dimension, so passing the rank there is out
            // of range; `flatten_all` collapses every dimension instead.
            let slot_mapping_flat = if slot_mapping.dims().len() > 1 {
                slot_mapping.flatten_all()?
            } else {
                slot_mapping.clone()
            };
            let ckv_flat = ckv.contiguous()?.reshape((bs * seq_len, kv_lora_rank))?;
            let k_pe_flat = k_pe
                .squeeze(1)?
                .contiguous()?
                .reshape((bs * seq_len, qk_rope_head_dim))?;
            mistralrs_paged_attn::concat_and_cache_mla(
                &ckv_flat,
                &k_pe_flat,
                key_cache,
                value_cache,
                &slot_mapping_flat,
            )?;
            Some((key_cache, value_cache, *meta, slot_mapping))
        }
        None => None,
    };

    let prefix_lens = seqlen_offsets;
    let needs_prefix = prefix_lens.iter().any(|&len| len > 0);

    // Step 2: plain attention suffices when there is no paged cache to read a prefix
    // from, or when nothing is cached and the caller supplied a mask.
    let (key_cache, value_cache, input_metadata, slot_mapping) = match paged {
        Some(parts) if needs_prefix || attention_mask.is_none() => parts,
        _ => {
            return Sdpa.run_attention(q, k, v, attention_mask, Some(flash_params), sdpa_params);
        }
    };

    // Step 3: per-sequence attention. Count the non-padding slots of each sequence
    // to learn how many of its `seq_len` tokens are real.
    let slot_mapping_cpu = slot_mapping.to_device(&Device::Cpu)?;
    let slot_mapping_cpu = if slot_mapping_cpu.dims().len() == 2 {
        slot_mapping_cpu
    } else {
        slot_mapping_cpu.reshape((bs, seq_len))?
    };
    let cur_lens: Vec<usize> = slot_mapping_cpu
        .to_vec2::<i64>()?
        .iter()
        .map(|row| {
            row.iter()
                .filter(|&&slot| slot != crate::paged_attention::_PAD_SLOT_ID)
                .count()
        })
        .collect();

    // Every processed sequence needs a prefix length; fail cleanly instead of
    // panicking on an out-of-bounds index in the loop below.
    if prefix_lens.len() < cur_lens.len() {
        candle_core::bail!(
            "seqlen_offsets has {} entries but {} sequences are being processed",
            prefix_lens.len(),
            cur_lens.len()
        );
    }

    // Build one block-table row per sequence. The table arrives either with one row
    // per sequence or with one row per token; in the latter case the row at each
    // sequence's first token is used.
    let block_tables = input_metadata
        .block_tables
        .as_ref()
        .and_then(|m| m.get(&device_location))
        .ok_or_else(|| candle_core::Error::msg("block tables missing"))?;
    let block_tables_cpu = block_tables.to_device(&Device::Cpu)?;
    let (block_rows, block_stride) = block_tables_cpu.dims2()?;
    let block_tables_vec = block_tables_cpu.to_vec2::<u32>()?;
    let expected_rows: usize = cur_lens.iter().sum();
    let mut block_tables_seq: Vec<i32> = Vec::with_capacity(bs * block_stride);
    if block_rows == bs {
        for row in block_tables_vec.iter().take(bs) {
            block_tables_seq.extend(row.iter().map(|&v| v as i32));
        }
    } else if block_rows == expected_rows {
        let mut offset = 0usize;
        for len in cur_lens.iter() {
            match block_tables_vec.get(offset) {
                Some(row) => block_tables_seq.extend(row.iter().map(|&v| v as i32)),
                // Past the end of the table: fill the row with zeros.
                None => block_tables_seq.extend(std::iter::repeat_n(0, block_stride)),
            }
            offset = offset.saturating_add(*len);
        }
    } else {
        candle_core::bail!(
            "unexpected block_tables rows: got {block_rows}, expected {bs} or {expected_rows}"
        );
    }

    let block_tables_seq = Tensor::from_vec(block_tables_seq, (bs, block_stride), &Device::Cpu)?
        .to_device(q.device())?;

    // Exclusive prefix sums: `prefix_offsets[i]` is where sequence `i` starts inside
    // the gathered prefix tensors.
    let total_prefix_tokens: usize = prefix_lens.iter().sum();
    let mut k_prefix = None;
    let mut v_prefix = None;
    let mut prefix_offsets = Vec::with_capacity(prefix_lens.len() + 1);
    let mut running = 0usize;
    prefix_offsets.push(running);
    for len in prefix_lens {
        running += *len;
        prefix_offsets.push(running);
    }

    if total_prefix_tokens > 0 {
        let mut cu_seq_lens = Vec::with_capacity(bs + 1);
        cu_seq_lens.push(0i32);
        let mut token_to_seq = Vec::with_capacity(total_prefix_tokens);
        #[allow(clippy::cast_possible_truncation)]
        for (seq_idx, len) in prefix_lens.iter().enumerate() {
            let next = *cu_seq_lens.last().unwrap() + *len as i32;
            cu_seq_lens.push(next);
            token_to_seq.extend(std::iter::repeat_n(seq_idx as i32, *len));
        }

        let cu_seq_lens =
            Tensor::from_vec(cu_seq_lens, (bs + 1,), &Device::Cpu)?.to_device(q.device())?;
        let token_to_seq = Tensor::from_vec(token_to_seq, (total_prefix_tokens,), &Device::Cpu)?
            .to_device(q.device())?;

        // Pull the cached latent KV and rope keys for all prefix tokens at once.
        let (ckv_prefix, kpe_prefix) = mistralrs_paged_attn::gather_mla_cache(
            key_cache,
            value_cache,
            &block_tables_seq,
            &cu_seq_lens,
            &token_to_seq,
        )?;

        // Expand the latent prefix into per-head non-positional keys and values.
        let kv_prefix = kv_b_proj.forward_autocast(&ckv_prefix)?.reshape((
            total_prefix_tokens,
            num_attention_heads,
            qk_nope_head_dim + v_head_dim,
        ))?;
        let kv_prefix_split = kv_prefix.split(&[qk_nope_head_dim, v_head_dim], D::Minus1)?;
        let k_nope_prefix = kv_prefix_split[0].clone();
        let v_prefix_full = kv_prefix_split[1].clone();
        // The rope key is shared by all heads, so broadcast it across the head axis.
        let kpe_prefix = kpe_prefix.unsqueeze(1)?.expand((
            total_prefix_tokens,
            num_attention_heads,
            qk_rope_head_dim,
        ))?;
        k_prefix = Some(Tensor::cat(&[&k_nope_prefix, &kpe_prefix], D::Minus1)?);
        v_prefix = Some(v_prefix_full);
    }

    let mut outputs = Vec::with_capacity(bs);
    for (seq_idx, cur_len) in cur_lens.iter().copied().enumerate() {
        let cur_len = cur_len.min(seq_len);
        if cur_len == 0 {
            // Fully padded sequence: emit zeros so the batch layout stays intact.
            outputs.push(Tensor::zeros(
                (seq_len, num_attention_heads, v_head_dim),
                q.dtype(),
                q.device(),
            )?);
            continue;
        }

        let mut q_i = q.narrow(0, seq_idx, 1)?;
        let mut k_i = k.narrow(0, seq_idx, 1)?;
        let mut v_i = v.narrow(0, seq_idx, 1)?;
        if cur_len < seq_len {
            // Keep only the leading `cur_len` tokens of this sequence.
            q_i = q_i.narrow(D::Minus2, 0, cur_len)?;
            k_i = k_i.narrow(D::Minus2, 0, cur_len)?;
            v_i = v_i.narrow(D::Minus2, 0, cur_len)?;
        }

        let prefix_len = prefix_lens[seq_idx];
        let prefix_start = prefix_offsets[seq_idx];
        let k_full = mla_prepend_prefix(&k_i, k_prefix.as_ref(), prefix_start, prefix_len)?;
        let v_full = mla_prepend_prefix(&v_i, v_prefix.as_ref(), prefix_start, prefix_len)?;

        // A single query token may attend to everything, so it needs no mask.
        let mask = if cur_len > 1 {
            Some(mla_prefix_causal_mask(q, cur_len, prefix_len)?)
        } else {
            None
        };

        let attn_out_i = Sdpa.run_attention_noflash(
            &q_i,
            &k_full.unsqueeze(0)?,
            &v_full.unsqueeze(0)?,
            mask.as_ref(),
            sdpa_params,
        )?;
        let mut attn_out_i = attn_out_i.squeeze(0)?.transpose(0, 1)?;
        if cur_len < seq_len {
            // Restore the trimmed token rows as zeros.
            attn_out_i = attn_out_i.pad_with_zeros(D::Minus(3), 0, seq_len - cur_len)?;
        }
        outputs.push(attn_out_i);
    }
    Tensor::cat(&outputs, 0)
}

/// Drops the leading batch axis of `cur` and, when the sequence has cached
/// prefix tokens, prepends rows `[start, start + prefix_len)` of `prefix_all`
/// (with its first two axes swapped) along axis 1.
#[cfg(all(feature = "cuda", target_family = "unix"))]
fn mla_prepend_prefix(
    cur: &Tensor,
    prefix_all: Option<&Tensor>,
    start: usize,
    prefix_len: usize,
) -> Result<Tensor> {
    let cur = cur.squeeze(0)?.contiguous()?;
    match prefix_all {
        Some(prefix_all) if prefix_len > 0 => {
            let prefix = prefix_all
                .narrow(0, start, prefix_len)?
                .transpose(0, 1)?
                .contiguous()?;
            Tensor::cat(&[&prefix, &cur], 1)
        }
        _ => Ok(cur),
    }
}

/// Builds an additive `(cur_len, cur_len + prefix_len)` mask on `q`'s device and
/// dtype: entry `(i, j)` is negative infinity when `j > i + prefix_len` (a future
/// token) and zero otherwise, so every prefix token stays visible.
#[cfg(all(feature = "cuda", target_family = "unix"))]
fn mla_prefix_causal_mask(q: &Tensor, cur_len: usize, prefix_len: usize) -> Result<Tensor> {
    let total_len = cur_len + prefix_len;
    let flags: Vec<u8> = (0..cur_len)
        .flat_map(|i| (0..total_len).map(move |j| u8::from(j > i + prefix_len)))
        .collect();
    let flags =
        Tensor::from_slice(&flags, (cur_len, total_len), q.device())?.to_dtype(DType::U8)?;
    let zeros = Tensor::new(0.0f32, q.device())?
        .to_dtype(q.dtype())?
        .broadcast_as(flags.shape())?;
    crate::layers_masker::masked_fill(&zeros, &flags, f32::NEG_INFINITY)
}

/// Fallback used when the crate is built without the `cuda` feature or for a
/// non-Unix target.
///
/// All arguments are ignored.
///
/// # Errors
/// Always returns an error stating that the MLA cache path requires CUDA support.
#[cfg(not(all(feature = "cuda", target_family = "unix")))]
#[allow(clippy::too_many_arguments)]
pub fn mla_cache_forward(
    _q: &Tensor,
    _k: &Tensor,
    _v: &Tensor,
    _ckv: &Tensor,
    _k_pe: &Tensor,
    _attention_mask: Option<&Tensor>,
    _seqlen_offsets: &[usize],
    _metadata: &Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    _flash_params: &FlashParams,
    _kv_b_proj: &dyn mistralrs_quant::QuantMethod,
    _sdpa_params: &SdpaParams,
    _num_attention_heads: usize,
    _kv_lora_rank: usize,
    _qk_rope_head_dim: usize,
    _qk_nope_head_dim: usize,
    _v_head_dim: usize,
    _bs: usize,
    _seq_len: usize,
) -> Result<Tensor> {
    candle_core::bail!("MLA cache requires CUDA support")
}