use candle_core::{Result, Tensor};
pub mod adaptive_linear;
pub mod attention;
pub mod bit_linear;
#[cfg(feature = "flash-attention")]
pub mod flash_attention;
pub mod isomorphic;
pub mod kv_cache;
pub mod linear_4bit;
pub mod rms_norm;
pub mod swiglu;
pub mod ttt;
pub use adaptive_linear::AdaptiveBitLinear;
pub use attention::{BitAttention, KVCache};
pub use bit_linear::BitLinear;
pub use isomorphic::{IsomorphicOffloader, LayerPlacement, MemoryPressure};
pub use kv_cache::QuantizedKVCache;
pub use linear_4bit::Linear4Bit;
pub use rms_norm::RMSNorm;
pub use swiglu::SwiGLU;
pub use ttt::TTTLayer;
#[cfg(feature = "flash-attention")]
pub use flash_attention::{flash_attention, FlashAttentionConfig};

/// Internal tensor helpers shared by the layer implementations in this module.
#[allow(dead_code)]
pub(crate) trait TensorExt {
    /// A `matmul` that tolerates non-contiguous operands, mismatched devices,
    /// and a 1-D or batched (rank > 2) left-hand side.
    fn matmul_robust(&self, rhs: &Tensor) -> Result<Tensor>;
}

impl TensorExt for Tensor {
    fn matmul_robust(&self, rhs: &Tensor) -> Result<Tensor> {
        let lhs = self.contiguous()?;
        let rhs = rhs.contiguous()?;
        let lhs_rank = lhs.rank();
        // Move `rhs` onto the same device as `lhs` before multiplying.
        let rhs = if rhs.device().same_device(lhs.device()) {
            rhs
        } else {
            rhs.to_device(lhs.device())?
        };
        if lhs_rank == 1 {
            // Promote the vector to a 1 x n matrix, multiply, then drop the
            // leading dimension again.
            lhs.unsqueeze(0)?.matmul(&rhs)?.squeeze(0)
        } else if lhs_rank == 2 {
            lhs.matmul(&rhs)
        } else {
            // Collapse all leading (batch) dimensions into one, perform a 2-D
            // matmul, then restore the batch shape on the output.
            let flattened = lhs.flatten(0, lhs_rank - 2)?;
            let out = flattened.matmul(&rhs)?;
            let mut new_shape = lhs.dims()[..lhs_rank - 1].to_vec();
            new_shape.push(out.dim(1)?);
            out.reshape(new_shape)
        }
    }
}
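
// A minimal usage sketch (added here as an illustration, not part of the
// original file): exercises `matmul_robust` on the CPU with 1-D, 2-D, and
// batched left-hand sides against a plain 2-D right-hand side.
#[cfg(test)]
mod tests {
    use super::TensorExt;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn matmul_robust_handles_all_lhs_ranks() -> Result<()> {
        let dev = Device::Cpu;
        let rhs = Tensor::ones((4, 3), DType::F32, &dev)?;

        // 1-D lhs: (4) x (4, 3) -> (3)
        let v = Tensor::ones(4, DType::F32, &dev)?;
        assert_eq!(v.matmul_robust(&rhs)?.dims(), &[3]);

        // 2-D lhs: (2, 4) x (4, 3) -> (2, 3)
        let m = Tensor::ones((2, 4), DType::F32, &dev)?;
        assert_eq!(m.matmul_robust(&rhs)?.dims(), &[2, 3]);

        // Batched lhs: (5, 2, 4) x (4, 3) -> (5, 2, 3)
        let b = Tensor::ones((5, 2, 4), DType::F32, &dev)?;
        assert_eq!(b.matmul_robust(&rhs)?.dims(), &[5, 2, 3]);
        Ok(())
    }
}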