mod common;
pub mod config;
pub mod cuda;
mod pqo;
mod precomputed;
pub(crate) mod quantize_tensor;
mod storage;
mod tq;
mod wht_tensor;
use std::sync::OnceLock;
use candle_core::{Device, Result};
use parking_lot::Mutex;
/// Holder for a lazily created [`GpuPrecomputed`] value.
///
/// The `Default` value starts with an empty `cell`; `ensure_gpu_precomputed`
/// fills it on first use and hands out shared references afterwards.
#[doc(hidden)]
#[derive(Default)]
pub struct PrecomputedState {
/// Write-once slot for the precomputed data; empty until the first
/// successful initialization.
pub(crate) cell: OnceLock<GpuPrecomputed>,
/// Lock taken by `ensure_gpu_precomputed` while it builds and stores the
/// value, so that only one caller at a time runs the construction.
pub(crate) init_mutex: Mutex<()>,
}
pub use config::{CacheConfig, QuantNormMode, QUANT_BLOCK_SIZE};
pub use pqo::PqoCache;
pub use precomputed::GpuPrecomputed;
pub use storage::{LayerBuffers, LayerStorage, QuantizedKV, StorageMetadata};
pub use tq::TqCache;
pub(crate) fn cache_err(msg: impl std::fmt::Display) -> candle_core::Error {
candle_core::Error::Msg(format!("TurboQuant cache: {msg}"))
}
/// Returns the [`GpuPrecomputed`] stored in `state`, creating it first if the
/// cell is still empty.
///
/// The cell is checked once without the lock, and again after acquiring
/// `state.init_mutex`; only a caller that still finds it empty under the lock
/// calls `GpuPrecomputed::new(config, device)` and stores the result.
///
/// # Errors
///
/// * Propagates any error returned by `GpuPrecomputed::new`. Nothing is
///   stored in that case, so a later call will attempt construction again.
/// * Returns a cache error if storing into the cell fails, or if the cell
///   reads back empty right after a successful store.
#[doc(hidden)]
pub fn ensure_gpu_precomputed<'a>(
    state: &'a PrecomputedState,
    config: &CacheConfig,
    device: &Device,
) -> Result<&'a GpuPrecomputed> {
    let cell = &state.cell;

    // Fast path: already initialized, no locking required.
    if let Some(ready) = cell.get() {
        return Ok(ready);
    }

    // Slow path: hold the init lock for the rest of the function so that
    // construction and the store happen under it.
    let _serialize_init = state.init_mutex.lock();

    match cell.get() {
        // Another caller finished initialization while we waited for the lock.
        Some(ready) => Ok(ready),
        None => {
            let built = GpuPrecomputed::new(config, device)?;
            if cell.set(built).is_err() {
                return Err(cache_err(
                    "precomputed cell was initialized concurrently during set",
                ));
            }
            cell.get().ok_or_else(|| {
                cache_err("precomputed cell unset after successful set — unreachable")
            })
        }
    }
}