#[cfg(all(feature = "cuda", target_family = "unix"))]
use std::sync::Mutex;
#[cfg(all(feature = "cuda", target_family = "unix"))]
use candle_core::{Device, Result, Tensor, D};
#[cfg(all(feature = "cuda", target_family = "unix"))]
use mistralrs_quant::QuantMethod;
/// Lazily-computed cache for the absorbed MLA projection weights
/// (`w_uk`, `w_uv_t`) derived from a `kv_b_proj` layer.
///
/// The cache is only functional when built with the `cuda` feature on a
/// unix target; otherwise the type is a zero-sized stub whose
/// `get_or_compute` always errors.
pub struct MlaWeights {
#[cfg(all(feature = "cuda", target_family = "unix"))]
// `None` when caching is disabled for this configuration (see `new`);
// otherwise a mutex guarding the once-computed `(w_uk, w_uv_t)` pair,
// which starts out as `None` and is filled on first `get_or_compute`.
weights: Option<Mutex<Option<(Tensor, Tensor)>>>,
#[cfg(not(all(feature = "cuda", target_family = "unix")))]
// Placeholder field so the struct is well-formed on non-CUDA builds.
_phantom: std::marker::PhantomData<()>,
}
impl MlaWeights {
    /// Creates a new (empty) MLA weight cache.
    ///
    /// The cache slot is allocated (`Some`) only when the absorbed weights
    /// can be used:
    /// - paged attention is disabled, or
    /// - paged attention is enabled and the device is CUDA (or no device
    ///   was supplied, in which case CUDA is assumed available).
    ///
    /// In every other case `weights` stays `None` and `get_or_compute`
    /// will return an error.
    #[cfg(all(feature = "cuda", target_family = "unix"))]
    pub fn new(paged_attn_enabled: bool, device: Option<&Device>) -> Self {
        let weights = if paged_attn_enabled {
            if let Some(device) = device {
                if matches!(device, Device::Cuda(_)) {
                    Some(Mutex::new(None))
                } else {
                    // Paged attention on a non-CUDA device: no cache.
                    None
                }
            } else {
                // No device given; assume the cache may be needed.
                Some(Mutex::new(None))
            }
        } else {
            Some(Mutex::new(None))
        };
        Self { weights }
    }

    /// Stub constructor for builds without CUDA support.
    #[cfg(not(all(feature = "cuda", target_family = "unix")))]
    pub fn new(_paged_attn_enabled: bool, _device: Option<&candle_core::Device>) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }

    /// Dequantizes `kv_b_proj` and splits it into the absorbed MLA weights.
    ///
    /// The dequantized weight must have shape
    /// `(num_attention_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)`;
    /// dimension mismatches produce an error. The weight is moved to
    /// `device` if it is not already there.
    ///
    /// Returns `(w_uk, w_uv_t)` where:
    /// - `w_uk` has shape `(num_attention_heads, qk_nope_head_dim, kv_lora_rank)`
    /// - `w_uv_t` has shape `(num_attention_heads, kv_lora_rank, v_head_dim)`
    ///   (the per-head value projection, transposed)
    ///
    /// # Errors
    /// Fails on dequantization errors, device transfer errors, or shape
    /// mismatches.
    #[cfg(all(feature = "cuda", target_family = "unix"))]
    pub fn compute_weights(
        kv_b_proj: &dyn QuantMethod,
        device: &Device,
        num_attention_heads: usize,
        kv_lora_rank: usize,
        qk_nope_head_dim: usize,
        v_head_dim: usize,
    ) -> Result<(Tensor, Tensor)> {
        let mut w = kv_b_proj.dequantize_w()?;
        if !w.device().same_device(device) {
            w = w.to_device(device)?;
        }
        let (out_dim, in_dim) = w.dims2()?;
        if in_dim != kv_lora_rank {
            candle_core::bail!(
                "kv_b_proj weight in_dim mismatch: expected {}, got {}",
                kv_lora_rank,
                in_dim
            );
        }
        let per_head_dim = qk_nope_head_dim + v_head_dim;
        if out_dim != num_attention_heads * per_head_dim {
            candle_core::bail!(
                "kv_b_proj weight out_dim mismatch: expected {}, got {}",
                num_attention_heads * per_head_dim,
                out_dim
            );
        }
        // Split each head's rows into the K (nope) and V portions.
        let w = w.reshape((num_attention_heads, per_head_dim, kv_lora_rank))?;
        let w_uk = w.narrow(D::Minus2, 0, qk_nope_head_dim)?.contiguous()?;
        let w_uv = w
            .narrow(D::Minus2, qk_nope_head_dim, v_head_dim)?
            .contiguous()?;
        // Transpose to (heads, kv_lora_rank, v_head_dim) for the absorbed matmul.
        let w_uv_t = w_uv.transpose(1, 2)?.contiguous()?;
        Ok((w_uk, w_uv_t))
    }

    /// Returns the cached `(w_uk, w_uv_t)` pair, computing and caching it on
    /// first use. The mutex is held across the computation so the weights
    /// are computed at most once.
    ///
    /// # Errors
    /// Fails when the cache was not initialized for this configuration
    /// (see `new`), when the cache mutex is poisoned, or when
    /// `compute_weights` fails.
    #[cfg(all(feature = "cuda", target_family = "unix"))]
    pub fn get_or_compute(
        &self,
        kv_b_proj: &dyn QuantMethod,
        device: &Device,
        num_attention_heads: usize,
        kv_lora_rank: usize,
        qk_nope_head_dim: usize,
        v_head_dim: usize,
    ) -> Result<(Tensor, Tensor)> {
        let Some(mla_weights) = &self.weights else {
            candle_core::bail!("MLA weights are not initialized on this device");
        };
        // Propagate a poisoned lock as an error rather than panicking:
        // this API returns `Result`, and a panic in another thread during
        // computation should not abort callers here.
        let mut guard = match mla_weights.lock() {
            Ok(guard) => guard,
            Err(_) => candle_core::bail!("MLA weights mutex was poisoned"),
        };
        if let Some((w_uk, w_uv_t)) = guard.as_ref() {
            // Tensor clones are cheap (shared storage).
            return Ok((w_uk.clone(), w_uv_t.clone()));
        }
        let (w_uk, w_uv_t) = Self::compute_weights(
            kv_b_proj,
            device,
            num_attention_heads,
            kv_lora_rank,
            qk_nope_head_dim,
            v_head_dim,
        )?;
        *guard = Some((w_uk.clone(), w_uv_t.clone()));
        Ok((w_uk, w_uv_t))
    }

    /// Stub for builds without CUDA support; always errors.
    #[cfg(not(all(feature = "cuda", target_family = "unix")))]
    #[allow(dead_code)]
    pub fn get_or_compute(
        &self,
        _kv_b_proj: &dyn mistralrs_quant::QuantMethod,
        _device: &candle_core::Device,
        _num_attention_heads: usize,
        _kv_lora_rank: usize,
        _qk_nope_head_dim: usize,
        _v_head_dim: usize,
    ) -> candle_core::Result<(candle_core::Tensor, candle_core::Tensor)> {
        candle_core::bail!("MLA weights require CUDA support")
    }
}