#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use crate::MemoryUsage;
use hanzo_ml::{Device, Result, Tensor};
use hanzo_quant::MatMul;
use crate::attention::{chunked_attention, SdpaParams};
/// Synchronizes a CUDA `device` when `MemoryUsage` reports it is running low on memory.
///
/// Non-CUDA devices are left alone and the memory query is skipped entirely. On CUDA,
/// `device.synchronize()` is called when `MemoryUsage.query(device)?.available()` is below
/// `4 * 1024^3`, i.e. 4 GiB if `available()` is measured in bytes.
/// NOTE(review): units are presumed to be bytes — confirm against `MemoryUsage`.
///
/// On targets whose `usize` is narrower than 64 bits, 4 GiB is not representable. The
/// threshold is then `usize::MAX`, so a CUDA device is synchronized whenever `available()`
/// is below `usize::MAX`.
///
/// # Errors
/// Propagates errors from the memory query or from `Device::synchronize`.
pub(crate) fn maybe_synchronize(device: &Device) -> Result<()> {
    #[cfg(target_pointer_width = "64")]
    const SYNC_BELOW_AVAILABLE: usize = 4 << 30;
    #[cfg(not(target_pointer_width = "64"))]
    const SYNC_BELOW_AVAILABLE: usize = usize::MAX;

    // `&&` short-circuits, so the memory query only runs for CUDA devices.
    if device.is_cuda() && MemoryUsage.query(device)?.available() < SYNC_BELOW_AVAILABLE {
        device.synchronize()?;
    }
    Ok(())
}
/// Reference ("naive") scaled-dot-product attention composed from plain tensor ops.
///
/// First calls `maybe_synchronize` on `q`'s device. It then delegates to `chunked_attention`,
/// passing a closure that handles one chunk. Judging by the closure parameter names, the
/// queries and mask are presumably chunked while keys/values are passed whole.
/// NOTE(review): confirm the chunking axis and the result stitching in `chunked_attention`.
///
/// For each chunk the closure computes:
/// 1. logits = `MatMul.matmul_affine_mul(q_chunk, k.t(), softmax_scale)`
///    (presumably `(q_chunk @ k^T) * softmax_scale`);
/// 2. if `sdpa_params.softcap` is set: `logits = cap * tanh(logits / cap)`;
/// 3. if a mask chunk is present: `logits = logits.broadcast_add(mask)`;
/// 4. softmax over the last dimension. BF16/F16 logits are converted to F32 first, and the
///    result is cast back to the logits' dtype if it differs;
/// 5. `probs @ v`.
///
/// # Errors
/// Propagates any error from synchronization, the matmuls, or the elementwise tensor ops
/// (for example shape or dtype mismatches).
pub(crate) fn naive_sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    maybe_synchronize(q.device())?;

    chunked_attention(q, k, v, mask, |query_slice, keys, values, mask_slice| {
        // Scaled logits; `t()` transposes the keys for the QK^T product.
        let mut logits = MatMul.matmul_affine_mul(
            query_slice,
            &keys.t()?,
            sdpa_params.softmax_scale.into(),
        )?;

        // Logit soft-capping: cap * tanh(x / cap) keeps |logits| below |cap|.
        if let Some(cap) = sdpa_params.softcap {
            let cap = cap as f64;
            logits = ((logits / cap)?.tanh()? * cap)?;
        }

        // Additive mask (e.g. causal / padding bias), broadcast onto the logits.
        let logits = match mask_slice {
            Some(bias) => logits.broadcast_add(bias)?,
            None => logits,
        };

        // Run the softmax in F32 for half-precision inputs, then return to the input dtype.
        let logits_dtype = logits.dtype();
        let needs_upcast = matches!(logits_dtype, hanzo_ml::DType::BF16 | hanzo_ml::DType::F16);
        let logits = if needs_upcast {
            logits.to_dtype(hanzo_ml::DType::F32)?
        } else {
            logits
        };
        let probs = hanzo_nn::ops::softmax_last_dim(&logits)?;
        let probs = if probs.dtype() == logits_dtype {
            probs
        } else {
            probs.to_dtype(logits_dtype)?
        };

        MatMul.matmul(&probs, values)
    })
}