#![cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
use cudarc::driver::result as cudarc_result;
use cudarc::driver::sys;
use cudarc::driver::{CudaFunction, CudaSlice, CudaStream};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Bundles a raw CUDA driver graph, its instantiated executable, and the
/// stream the executable is launched/uploaded on.
///
/// The two raw handles are destroyed when the holder is dropped.
struct CuGraphHolder {
/// Raw driver handle of the graph; destroyed in `Drop`.
cu_graph: sys::CUgraph,
/// Raw driver handle of the executable graph used by `launch`/`upload`.
cu_graph_exec: sys::CUgraphExec,
/// Stream whose raw handle is passed to the driver on launch and upload.
stream: Arc<CudaStream>,
}
impl CuGraphHolder {
/// Launches the executable graph on `self.stream`.
///
/// # Safety
///
/// Forwards directly to the raw driver wrapper with the stored handles.
/// NOTE(review): presumably every device allocation referenced by the
/// graph must still be alive and the handles must be valid — confirm
/// against the capture site.
unsafe fn launch(&self) -> Result<(), cudarc::driver::DriverError> {
cudarc_result::graph::launch(self.cu_graph_exec, self.stream.cu_stream())
}
/// Uploads the executable graph to the device via `self.stream`.
///
/// # Safety
///
/// Same contract as [`CuGraphHolder::launch`]: the stored raw handles are
/// handed to the driver unchecked.
unsafe fn upload(&self) -> Result<(), cudarc::driver::DriverError> {
cudarc_result::graph::upload(self.cu_graph_exec, self.stream.cu_stream())
}
}
/// Releases both driver handles when the holder goes out of scope.
impl Drop for CuGraphHolder {
fn drop(&mut self) {
// The executable is destroyed before the graph it was created from.
// Errors are deliberately discarded: `drop` has no way to report them.
unsafe {
let _ = cudarc_result::graph::exec_destroy(self.cu_graph_exec);
let _ = cudarc_result::graph::destroy(self.cu_graph);
}
}
}
// Manual `Send` assertion for the raw driver handles stored in the holder.
// NOTE(review): soundness relies on the handles only being used by one thread
// at a time (the holder lives behind a `Mutex` in this file) — confirm.
unsafe impl Send for CuGraphHolder {}
use super::cuda_attn_kernels::CUDA_ATTENTION_KERNELS_SRC;
use super::cuda_graph::{compile_or_load_ptx, CudaGraph, CudaGraphError};
mod launchers;
use launchers::{
launch_batched_attn_scores_v2, launch_batched_attn_weighted_sum, launch_batched_softmax,
launch_fused_kv_store, launch_fused_qk_norm_rope,
};
pub mod encode_q1;
pub mod encode_ternary;
/// Handles to the attention kernels loaded from the attention kernel module.
///
/// Each field is loaded by `init_attn_modules` using a kernel name identical
/// to the field name.
pub struct CudaAttnModules {
/// Kernel loaded under the name `fused_qk_norm`.
pub fused_qk_norm: CudaFunction,
/// Kernel loaded under the name `fused_qk_rope`.
pub fused_qk_rope: CudaFunction,
/// Kernel loaded under the name `fused_qk_norm_rope`.
pub fused_qk_norm_rope: CudaFunction,
/// Kernel loaded under the name `fused_kv_store`.
pub fused_kv_store: CudaFunction,
/// Kernel loaded under the name `batched_attn_scores_v2`.
pub batched_attn_scores_v2: CudaFunction,
/// Kernel loaded under the name `batched_softmax`.
pub batched_softmax: CudaFunction,
/// Kernel loaded under the name `batched_attn_weighted_sum`.
pub batched_attn_weighted_sum: CudaFunction,
}
// Manual thread-safety assertions so the modules can be shared through the
// process-wide state. NOTE(review): confirm `CudaFunction` handles may be
// used from multiple threads.
unsafe impl Send for CudaAttnModules {}
unsafe impl Sync for CudaAttnModules {}
/// Device-resident key/value cache shared by all layers.
///
/// `acquire_kv_cache` allocates each of the two buffers with
/// `n_layers * n_kv * max_seq * head_dim` elements.
pub struct CudaKvCache {
/// Key cache storage. NOTE(review): `u16` elements presumably hold
/// half-precision bit patterns — confirm against the kernels.
pub k_cache: CudaSlice<u16>,
/// Value cache storage, same length and element type as `k_cache`.
pub v_cache: CudaSlice<u16>,
/// Number of layers the cache was allocated for.
pub n_layers: usize,
/// Number of key/value heads the cache was allocated for.
pub n_kv: usize,
/// Maximum sequence length the cache was allocated for.
pub max_seq: usize,
/// Per-head dimension the cache was allocated for.
pub head_dim: usize,
}
// Manual thread-safety assertions; the cache is stored behind a `Mutex` in
// the process-wide state. NOTE(review): confirm `CudaSlice` may cross threads.
unsafe impl Send for CudaKvCache {}
unsafe impl Sync for CudaKvCache {}
impl CudaKvCache {
    /// Returns the element offset of `layer_idx`'s region inside `k_cache` /
    /// `v_cache`, i.e. `layer_idx * n_kv * max_seq * head_dim`.
    ///
    /// # Panics
    ///
    /// Panics if the offset overflows `usize` or does not fit in `u32`.
    /// A silently truncated offset would point the kernels at another
    /// layer's region of the cache, so failing loudly is the safer outcome.
    #[inline]
    pub fn layer_offset_elements(&self, layer_idx: usize) -> u32 {
        layer_idx
            .checked_mul(self.n_kv)
            .and_then(|n| n.checked_mul(self.max_seq))
            .and_then(|n| n.checked_mul(self.head_dim))
            .and_then(|n| u32::try_from(n).ok())
            .expect("KV cache layer offset does not fit in u32")
    }

    /// Reports whether the cache was allocated for exactly these dimensions.
    pub fn matches(&self, n_layers: usize, n_kv: usize, max_seq: usize, head_dim: usize) -> bool {
        [self.n_layers, self.n_kv, self.max_seq, self.head_dim]
            == [n_layers, n_kv, max_seq, head_dim]
    }
}
/// Device scratch buffers reused across layers, plus the dimensions they
/// were allocated for.
///
/// Element counts below are the ones used by `acquire_full_layer_buffers`.
pub struct CudaFullLayerBuffers {
/// `hidden_size` elements; input to the pre-attention RMS norm.
pub d_hidden: CudaSlice<f32>,
/// `hidden_size` elements; output of the RMS norm, input to the QKV GEMV.
pub d_normed: CudaSlice<f32>,
/// `nq * head_dim + 2 * nkv * head_dim` elements; fused Q, K, V output.
pub d_qkv: CudaSlice<f32>,
/// `nq * head_dim` elements; Q output of the fused norm+RoPE kernel.
pub d_q_rope: CudaSlice<f32>,
/// `nkv * head_dim` elements; K output of the fused norm+RoPE kernel.
pub d_k_rope: CudaSlice<f32>,
/// `head_dim / 2` elements passed to the norm+RoPE kernel as cosines.
pub d_cos: CudaSlice<f32>,
/// `head_dim / 2` elements passed to the norm+RoPE kernel as sines.
pub d_sin: CudaSlice<f32>,
/// `nq * max_seq` elements; attention scores, normalised in place by softmax.
pub d_scores: CudaSlice<f32>,
/// `nq * head_dim` elements; output of the weighted-sum kernel.
pub d_attn_out: CudaSlice<f32>,
/// `2 * intermediate_size` elements.
pub d_gate_up: CudaSlice<f32>,
/// `intermediate_size` elements.
pub d_swiglu: CudaSlice<f32>,
/// Two `u32` values read by the attention kernels.
/// NOTE(review): presumably `[pos, seq_len]` as the name suggests — confirm.
pub d_pos_seqlen: CudaSlice<u32>,
/// Hidden size the buffers were allocated for.
pub hidden_size: usize,
/// Number of query heads the buffers were allocated for.
pub nq: usize,
/// Number of key/value heads the buffers were allocated for.
pub nkv: usize,
/// Per-head dimension the buffers were allocated for.
pub head_dim: usize,
/// Maximum sequence length the buffers were allocated for.
pub max_seq: usize,
/// Intermediate (feed-forward) size the buffers were allocated for.
pub intermediate_size: usize,
}
// Manual thread-safety assertions; the buffers live behind a `Mutex` in the
// process-wide state. NOTE(review): confirm `CudaSlice` may cross threads.
unsafe impl Send for CudaFullLayerBuffers {}
unsafe impl Sync for CudaFullLayerBuffers {}
impl CudaFullLayerBuffers {
    /// Reports whether the buffers were allocated for exactly these
    /// dimensions; any single mismatch yields `false`.
    pub fn matches(
        &self,
        hidden_size: usize,
        nq: usize,
        nkv: usize,
        head_dim: usize,
        max_seq: usize,
        intermediate_size: usize,
    ) -> bool {
        let allocated_for = [
            self.hidden_size,
            self.nq,
            self.nkv,
            self.head_dim,
            self.max_seq,
            self.intermediate_size,
        ];
        let requested = [hidden_size, nq, nkv, head_dim, max_seq, intermediate_size];
        allocated_for == requested
    }
}
/// Device-resident weights of one layer, as assembled by
/// `get_or_build_model_weights`.
pub struct CudaCachedLayerWeights {
/// Upload of the fused QKV weight (keyed by `fused_qkv_handle`).
pub q_weight: Arc<CudaSlice<u8>>,
/// Filled with the shared one-byte dummy allocation by the builder in this file.
pub k_weight: Arc<CudaSlice<u8>>,
/// Filled with the shared one-byte dummy allocation by the builder in this file.
pub v_weight: Arc<CudaSlice<u8>>,
/// Upload of the attention output projection (keyed by `attn_proj_handle`).
pub o_weight: Arc<CudaSlice<u8>>,
/// Upload of the gate bytes followed by the up bytes (keyed by `gate_up_handle`).
pub gate_up_weight: Arc<CudaSlice<u8>>,
/// Upload of the down projection (keyed by `down_handle`).
pub down_weight: Arc<CudaSlice<u8>>,
/// Upload of the attention norm weights (keyed by `attn_norm_handle`).
pub pre_attn_norm: Arc<CudaSlice<f32>>,
/// Upload of the FFN norm weights (keyed by `ffn_norm_handle`).
pub post_attn_norm: Arc<CudaSlice<f32>>,
/// Upload of the Q norm weights (keyed by `q_norm_handle`).
pub q_norm: Arc<CudaSlice<f32>>,
/// Upload of the K norm weights (keyed by `k_norm_handle`).
pub k_norm: Arc<CudaSlice<f32>>,
}
// Manual thread-safety assertions so the weights can be shared via `Arc`.
// NOTE(review): confirm `CudaSlice` may be shared across threads.
unsafe impl Send for CudaCachedLayerWeights {}
unsafe impl Sync for CudaCachedLayerWeights {}
/// Host-side description of one layer's weights.
///
/// Each `*_handle` is used as the cache key when the matching host data is
/// uploaded to the device; the borrowed slices are the data to upload.
/// Note that the norm fields carry `f32` values despite the `_bytes` suffix.
pub struct CudaFullForwardLayerParams<'a> {
/// Cache key for the attention norm weights.
pub attn_norm_handle: u64,
/// Attention norm weights (`f32` values).
pub attn_norm_bytes: &'a [f32],
/// Cache key for the fused QKV weight.
pub fused_qkv_handle: u64,
/// Raw bytes of the fused QKV weight.
pub fused_qkv_bytes: &'a [u8],
/// Cache key for the Q norm weights.
pub q_norm_handle: u64,
/// Q norm weights (`f32` values).
pub q_norm_bytes: &'a [f32],
/// Cache key for the K norm weights.
pub k_norm_handle: u64,
/// K norm weights (`f32` values).
pub k_norm_bytes: &'a [f32],
/// Cache key for the attention output projection.
pub attn_proj_handle: u64,
/// Raw bytes of the attention output projection.
pub attn_proj_bytes: &'a [u8],
/// Cache key for the FFN norm weights.
pub ffn_norm_handle: u64,
/// FFN norm weights (`f32` values).
pub ffn_norm_bytes: &'a [f32],
/// Cache key for the concatenated gate+up weight.
pub gate_up_handle: u64,
/// Raw gate weight bytes; uploaded first in the concatenation.
pub gate_bytes: &'a [u8],
/// Raw up weight bytes; appended after `gate_bytes` in the concatenation.
pub up_bytes: &'a [u8],
/// Cache key for the down projection.
pub down_handle: u64,
/// Raw bytes of the down projection.
pub down_bytes: &'a [u8],
}
/// Process-wide cache entry holding every layer's uploaded weights together
/// with the graph they were uploaded through.
pub struct CudaCachedModelWeights {
/// Graph obtained from `CudaGraph::global()` when the entry was built.
pub graph: Arc<CudaGraph>,
/// One-byte zeroed allocation shared by the unused `k_weight`/`v_weight` slots.
pub dummy_weight: Arc<CudaSlice<u8>>,
/// Per-layer weights, in the order the layer parameters were supplied.
pub layers: Arc<Vec<CudaCachedLayerWeights>>,
/// Layer count the entry was built for; used as the cache-hit test.
pub n_layers: usize,
}
// Manual thread-safety assertions; the entry is stored behind a `Mutex`.
// NOTE(review): confirm the contained CUDA types may cross threads.
unsafe impl Send for CudaCachedModelWeights {}
unsafe impl Sync for CudaCachedModelWeights {}
/// Lazily-populated, process-wide state for the full-layer CUDA path.
/// Every slot is guarded by its own mutex.
struct CudaFullLayerState {
/// Loaded attention kernels; filled by `init_attn_modules`.
attn_modules: Mutex<Option<Arc<CudaAttnModules>>>,
/// Scratch buffers; (re)allocated by `acquire_full_layer_buffers`.
full_layer_buffers: Mutex<Option<CudaFullLayerBuffers>>,
/// KV cache; (re)allocated by `acquire_kv_cache`.
kv_cache: Mutex<Option<CudaKvCache>>,
/// Uploaded `f32` weights keyed by caller-supplied handle.
f32_weight_cache: Mutex<HashMap<u64, Arc<CudaSlice<f32>>>>,
/// Cached per-model weights; filled by `get_or_build_model_weights`.
cached_model_weights: Mutex<Option<CudaCachedModelWeights>>,
/// Captured driver graph; reset to `None` when the scratch buffers are
/// reallocated. NOTE(review): the inner `Option` presumably distinguishes
/// "capture attempted and failed" from "not attempted" — confirm.
cuda_driver_graph: Mutex<Option<Option<CuGraphHolder>>>,
}
// Manual thread-safety assertions required for storing the state in a
// `static`. NOTE(review): soundness relies on all access going through the
// mutexes above — confirm for `CuGraphHolder`, which is only asserted `Send`.
unsafe impl Send for CudaFullLayerState {}
unsafe impl Sync for CudaFullLayerState {}
static FULL_LAYER_STATE: OnceLock<CudaFullLayerState> = OnceLock::new();
fn full_layer_state() -> &'static CudaFullLayerState {
FULL_LAYER_STATE.get_or_init(|| CudaFullLayerState {
attn_modules: Mutex::new(None),
full_layer_buffers: Mutex::new(None),
kv_cache: Mutex::new(None),
f32_weight_cache: Mutex::new(HashMap::new()),
cached_model_weights: Mutex::new(None),
cuda_driver_graph: Mutex::new(None),
})
}
/// Caches the result of the first `CUDA_PROFILE` environment lookup.
static PROFILE_ENABLED: OnceLock<bool> = OnceLock::new();

/// Reports whether profiling was requested.
///
/// The `CUDA_PROFILE` environment variable is read once, on the first call;
/// the answer is `true` when `std::env::var` succeeds for it and is reused
/// for the rest of the process.
#[inline(always)]
pub(super) fn profiling() -> bool {
    fn read_env_flag() -> bool {
        std::env::var("CUDA_PROFILE").is_ok()
    }

    *PROFILE_ENABLED.get_or_init(read_env_flag)
}
/// Returns the device copy of an `f32` weight, uploading it on first use.
///
/// `key` identifies the weight in the process-wide cache; `data` is only
/// read when the key is not cached yet.
///
/// The cache lock is released while the host-to-device copy runs, so two
/// threads may upload the same key concurrently. Whichever inserts first
/// wins: every caller receives that same cached buffer and the losing upload
/// is dropped, so a key never maps to two live device buffers.
///
/// # Errors
///
/// * [`CudaGraphError::LockPoisoned`] if the cache mutex is poisoned.
/// * [`CudaGraphError::DriverError`] if the host-to-device copy fails.
pub fn get_or_upload_f32_weight(
    graph: &CudaGraph,
    key: u64,
    data: &[f32],
) -> Result<Arc<CudaSlice<f32>>, CudaGraphError> {
    let state = full_layer_state();

    // Fast path: already uploaded.
    {
        let cache = state
            .f32_weight_cache
            .lock()
            .map_err(|_| CudaGraphError::LockPoisoned)?;
        if let Some(existing) = cache.get(&key) {
            return Ok(Arc::clone(existing));
        }
    }

    // Upload without holding the lock so other keys are not blocked.
    let uploaded = graph
        .stream_arc()
        .clone_htod(data)
        .map(Arc::new)
        .map_err(|e| CudaGraphError::DriverError(format!("clone_htod f32: {e}")))?;

    let mut cache = state
        .f32_weight_cache
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;
    // Keep an entry inserted by a racing thread instead of overwriting it.
    Ok(Arc::clone(cache.entry(key).or_insert(uploaded)))
}
/// Returns the global graph and the device-resident weights of every layer,
/// building and caching them on the first call.
///
/// Returns `None` if the cache lock is poisoned on lookup, the global graph
/// is unavailable, or any allocation/upload fails; in that case nothing is
/// cached.
///
/// NOTE(review): a cached entry is reused whenever its layer count equals
/// `layer_params.len()`; the handles are not compared — confirm that callers
/// never switch to a different model with the same number of layers.
pub(super) fn get_or_build_model_weights(
    layer_params: &[CudaFullForwardLayerParams<'_>],
) -> Option<(Arc<CudaGraph>, Arc<Vec<CudaCachedLayerWeights>>)> {
    let state = full_layer_state();
    let n_layers = layer_params.len();

    // Cache hit: hand out clones of the shared handles.
    {
        let guard = state.cached_model_weights.lock().ok()?;
        match guard.as_ref() {
            Some(cmw) if cmw.n_layers == n_layers => {
                return Some((Arc::clone(&cmw.graph), Arc::clone(&cmw.layers)));
            }
            _ => {}
        }
    }

    let graph = CudaGraph::global().ok()?;
    // One-byte placeholder shared by the unused K/V weight slots.
    let dummy_weight = Arc::new(graph.stream_arc().alloc_zeros::<u8>(1).ok()?);

    // Uploads (or fetches from the per-handle caches) one layer's weights.
    let upload_layer = |lp: &CudaFullForwardLayerParams<'_>| -> Option<CudaCachedLayerWeights> {
        let q_weight = graph
            .get_or_upload_weight_soa(lp.fused_qkv_handle, lp.fused_qkv_bytes)
            .ok()?;
        let o_weight = graph
            .get_or_upload_weight_soa(lp.attn_proj_handle, lp.attn_proj_bytes)
            .ok()?;
        // Gate and up are concatenated only if the handle is not cached yet.
        let (gate, up) = (lp.gate_bytes, lp.up_bytes);
        let gate_up_weight = graph
            .get_or_upload_weight_soa_lazy(lp.gate_up_handle, || [gate, up].concat())
            .ok()?;
        let down_weight = graph
            .get_or_upload_weight_soa(lp.down_handle, lp.down_bytes)
            .ok()?;
        let pre_attn_norm =
            get_or_upload_f32_weight(&graph, lp.attn_norm_handle, lp.attn_norm_bytes).ok()?;
        let post_attn_norm =
            get_or_upload_f32_weight(&graph, lp.ffn_norm_handle, lp.ffn_norm_bytes).ok()?;
        let q_norm = get_or_upload_f32_weight(&graph, lp.q_norm_handle, lp.q_norm_bytes).ok()?;
        let k_norm = get_or_upload_f32_weight(&graph, lp.k_norm_handle, lp.k_norm_bytes).ok()?;

        Some(CudaCachedLayerWeights {
            q_weight,
            k_weight: Arc::clone(&dummy_weight),
            v_weight: Arc::clone(&dummy_weight),
            o_weight,
            gate_up_weight,
            down_weight,
            pre_attn_norm,
            post_attn_norm,
            q_norm,
            k_norm,
        })
    };

    let mut per_layer: Vec<CudaCachedLayerWeights> = Vec::with_capacity(n_layers);
    for lp in layer_params {
        per_layer.push(upload_layer(lp)?);
    }
    let layers = Arc::new(per_layer);

    // Publishing is best-effort: a poisoned lock only skips caching.
    if let Ok(mut slot) = state.cached_model_weights.lock() {
        *slot = Some(CudaCachedModelWeights {
            graph: Arc::clone(&graph),
            dummy_weight,
            layers: Arc::clone(&layers),
            n_layers,
        });
    }

    Some((graph, layers))
}
/// Returns the attention kernels, compiling/loading them on first use.
///
/// The state lock is held for the whole initialisation, so the module is
/// loaded at most once even when several threads call this concurrently.
///
/// # Errors
///
/// * [`CudaGraphError::LockPoisoned`] if the state mutex is poisoned.
/// * Whatever `compile_or_load_ptx` reports for the attention kernel source.
/// * [`CudaGraphError::DriverError`] if loading the module or any kernel fails.
pub fn init_attn_modules(graph: &CudaGraph) -> Result<Arc<CudaAttnModules>, CudaGraphError> {
    let mut slot = full_layer_state()
        .attn_modules
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;

    if let Some(existing) = slot.as_ref() {
        return Ok(Arc::clone(existing));
    }

    let ptx = compile_or_load_ptx(CUDA_ATTENTION_KERNELS_SRC, "attn_kernels")?;
    let module = graph
        .context_arc()
        .load_module(ptx)
        .map_err(|e| CudaGraphError::DriverError(format!("load_module attn: {e}")))?;

    // Resolves one kernel by name, tagging failures with that name.
    let kernel = |name: &str| -> Result<CudaFunction, CudaGraphError> {
        module
            .load_function(name)
            .map_err(|e| CudaGraphError::DriverError(format!("load_function({name}): {e}")))
    };

    // Kernels are resolved in declaration order; the first failure aborts.
    let fused_qk_norm = kernel("fused_qk_norm")?;
    let fused_qk_rope = kernel("fused_qk_rope")?;
    let fused_qk_norm_rope = kernel("fused_qk_norm_rope")?;
    let fused_kv_store = kernel("fused_kv_store")?;
    let batched_attn_scores_v2 = kernel("batched_attn_scores_v2")?;
    let batched_softmax = kernel("batched_softmax")?;
    let batched_attn_weighted_sum = kernel("batched_attn_weighted_sum")?;

    let loaded = Arc::new(CudaAttnModules {
        fused_qk_norm,
        fused_qk_rope,
        fused_qk_norm_rope,
        fused_kv_store,
        batched_attn_scores_v2,
        batched_softmax,
        batched_attn_weighted_sum,
    });
    Ok(Arc::clone(slot.insert(loaded)))
}
/// Locks the scratch-buffer slot and makes sure it holds buffers allocated
/// for exactly the requested dimensions, reallocating when it does not.
///
/// The returned guard keeps the slot locked; after a successful call it
/// always contains `Some(..)`.
///
/// When the buffers are reallocated the captured driver graph is reset
/// *before* the old buffers are released, so a stale graph can never outlive
/// the allocations it was captured against. If resetting fails because the
/// graph mutex is poisoned, the old buffers are left untouched and the error
/// is returned instead of being ignored.
///
/// # Errors
///
/// * [`CudaGraphError::LockPoisoned`] if the buffer or graph mutex is poisoned.
/// * [`CudaGraphError::DriverError`] if any device allocation fails; the
///   previous buffers (if any) stay in place.
pub(super) fn acquire_full_layer_buffers(
    graph: &CudaGraph,
    hidden_size: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    max_seq: usize,
    intermediate_size: usize,
) -> Result<std::sync::MutexGuard<'static, Option<CudaFullLayerBuffers>>, CudaGraphError> {
    let state = full_layer_state();
    let mut guard = state
        .full_layer_buffers
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;

    let reusable = guard
        .as_ref()
        .is_some_and(|b| b.matches(hidden_size, nq, nkv, head_dim, max_seq, intermediate_size));

    if !reusable {
        let alloc = |n: usize| -> Result<CudaSlice<f32>, CudaGraphError> {
            graph
                .stream_arc()
                .alloc_zeros::<f32>(n)
                .map_err(|e| CudaGraphError::DriverError(format!("alloc_zeros fl({n}): {e}")))
        };
        let alloc_u32 = |n: usize| -> Result<CudaSlice<u32>, CudaGraphError> {
            graph
                .stream_arc()
                .alloc_zeros::<u32>(n)
                .map_err(|e| CudaGraphError::DriverError(format!("alloc_zeros u32({n}): {e}")))
        };
        let qkv_total = nq * head_dim + 2 * nkv * head_dim;
        let half_dim = head_dim / 2;

        // Allocate everything first so a failure leaves the old state intact.
        let fresh = CudaFullLayerBuffers {
            d_hidden: alloc(hidden_size)?,
            d_normed: alloc(hidden_size)?,
            d_qkv: alloc(qkv_total)?,
            d_q_rope: alloc(nq * head_dim)?,
            d_k_rope: alloc(nkv * head_dim)?,
            d_cos: alloc(half_dim)?,
            d_sin: alloc(half_dim)?,
            d_scores: alloc(nq * max_seq)?,
            d_attn_out: alloc(nq * head_dim)?,
            d_gate_up: alloc(2 * intermediate_size)?,
            d_swiglu: alloc(intermediate_size)?,
            d_pos_seqlen: alloc_u32(2)?,
            hidden_size,
            nq,
            nkv,
            head_dim,
            max_seq,
            intermediate_size,
        };

        // Drop any captured graph while the buffers it refers to still exist.
        *state
            .cuda_driver_graph
            .lock()
            .map_err(|_| CudaGraphError::LockPoisoned)? = None;

        *guard = Some(fresh);
    }

    Ok(guard)
}
/// Locks the KV-cache slot and makes sure it holds a cache allocated for
/// exactly the requested dimensions, reallocating (zero-filled) otherwise.
///
/// The returned guard keeps the slot locked; after a successful call it
/// always contains `Some(..)`.
///
/// NOTE(review): unlike `acquire_full_layer_buffers`, a reallocation here
/// does not reset the captured driver graph — confirm that a captured graph
/// can never reference the previous K/V allocations.
///
/// # Errors
///
/// * [`CudaGraphError::LockPoisoned`] if the cache mutex is poisoned.
/// * [`CudaGraphError::DriverError`] if either device allocation fails; the
///   previous cache (if any) stays in place.
pub(super) fn acquire_kv_cache(
    graph: &CudaGraph,
    n_layers: usize,
    n_kv: usize,
    max_seq: usize,
    head_dim: usize,
) -> Result<std::sync::MutexGuard<'static, Option<CudaKvCache>>, CudaGraphError> {
    let mut slot = full_layer_state()
        .kv_cache
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;

    let reusable = matches!(
        slot.as_ref(),
        Some(existing) if existing.matches(n_layers, n_kv, max_seq, head_dim)
    );
    if reusable {
        return Ok(slot);
    }

    // K and V are two equally sized, zero-initialised allocations.
    let total_elements = n_layers * n_kv * max_seq * head_dim;
    let alloc_half = || graph.stream_arc().alloc_zeros::<u16>(total_elements);

    let k_cache =
        alloc_half().map_err(|e| CudaGraphError::DriverError(format!("alloc kv k_cache: {e}")))?;
    let v_cache =
        alloc_half().map_err(|e| CudaGraphError::DriverError(format!("alloc kv v_cache: {e}")))?;

    *slot = Some(CudaKvCache {
        k_cache,
        v_cache,
        n_layers,
        n_kv,
        max_seq,
        head_dim,
    });
    Ok(slot)
}
#[allow(clippy::too_many_arguments)]
pub unsafe fn encode_attn_phase(
graph: &CudaGraph,
mods: &CudaAttnModules,
d_pre_norm_weight: &CudaSlice<f32>,
d_fused_qkv_weight: &Arc<CudaSlice<u8>>,
d_q_norm_weight: &CudaSlice<f32>,
d_k_norm_weight: &CudaSlice<f32>,
kv: &mut CudaKvCache,
layer_idx: usize,
_pos: usize,
nq: usize,
nkv: usize,
head_dim: usize,
heads_per_group: usize,
norm_eps: f32,
hidden_size: usize,
bufs: &mut CudaFullLayerBuffers,
) -> Result<(), CudaGraphError> {
let h_u32 = hidden_size as u32;
let nq_u32 = nq as u32;
let nkv_u32 = nkv as u32;
let hd_u32 = head_dim as u32;
let qkv_total_rows = (nq * head_dim + 2 * nkv * head_dim) as u32;
let heads_per_group_u32 = heads_per_group as u32;
let max_seq_u32 = bufs.max_seq as u32;
let inv_sqrt_hd = 1.0f32 / (head_dim as f32).sqrt();
let layer_offset = kv.layer_offset_elements(layer_idx);
graph.launch_rmsnorm_pub(
&bufs.d_hidden,
d_pre_norm_weight,
&mut bufs.d_normed,
h_u32,
norm_eps,
)?;
graph.launch_gemv_pub(
d_fused_qkv_weight,
&bufs.d_normed,
&mut bufs.d_qkv,
qkv_total_rows,
h_u32,
)?;
let k_offset = nq * head_dim;
let k_in_view = bufs.d_qkv.slice(k_offset..);
launch_fused_qk_norm_rope(
graph,
mods,
&bufs.d_qkv, &k_in_view, &mut bufs.d_q_rope,
&mut bufs.d_k_rope,
d_q_norm_weight,
d_k_norm_weight,
&bufs.d_cos,
&bufs.d_sin,
nq_u32,
nkv_u32,
hd_u32,
norm_eps,
)?;
let v_offset = (nq + nkv) * head_dim;
let v_view = bufs.d_qkv.slice(v_offset..);
launch_fused_kv_store(
graph,
mods,
&bufs.d_k_rope,
&v_view,
&mut kv.k_cache,
&mut kv.v_cache,
hd_u32,
nkv_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
layer_offset,
)?;
launch_batched_attn_scores_v2(
graph,
mods,
&bufs.d_q_rope,
&kv.k_cache,
&mut bufs.d_scores,
hd_u32,
nq_u32,
nkv_u32,
heads_per_group_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
inv_sqrt_hd,
layer_offset,
)?;
launch_batched_softmax(
graph,
mods,
&mut bufs.d_scores,
nq_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
)?;
launch_batched_attn_weighted_sum(
graph,
mods,
&bufs.d_scores,
&kv.v_cache,
&mut bufs.d_attn_out,
hd_u32,
nq_u32,
nkv_u32,
heads_per_group_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
layer_offset,
)
}
#[allow(clippy::too_many_arguments)]
pub unsafe fn encode_attn_phase_tq2(
graph: &CudaGraph,
mods: &CudaAttnModules,
d_pre_norm_weight: &CudaSlice<f32>,
d_fused_qkv_weight: &Arc<CudaSlice<u8>>,
d_q_norm_weight: &CudaSlice<f32>,
d_k_norm_weight: &CudaSlice<f32>,
kv: &mut CudaKvCache,
layer_idx: usize,
nq: usize,
nkv: usize,
head_dim: usize,
heads_per_group: usize,
norm_eps: f32,
hidden_size: usize,
bufs: &mut CudaFullLayerBuffers,
) -> Result<(), CudaGraphError> {
let h_u32 = hidden_size as u32;
let nq_u32 = nq as u32;
let nkv_u32 = nkv as u32;
let hd_u32 = head_dim as u32;
let qkv_total_rows = (nq * head_dim + 2 * nkv * head_dim) as u32;
let heads_per_group_u32 = heads_per_group as u32;
let max_seq_u32 = bufs.max_seq as u32;
let inv_sqrt_hd = 1.0f32 / (head_dim as f32).sqrt();
let layer_offset = kv.layer_offset_elements(layer_idx);
graph.launch_rmsnorm_pub(
&bufs.d_hidden,
d_pre_norm_weight,
&mut bufs.d_normed,
h_u32,
norm_eps,
)?;
graph.launch_gemv_tq2_v1_pub(
d_fused_qkv_weight,
&bufs.d_normed,
&mut bufs.d_qkv,
qkv_total_rows,
h_u32,
)?;
let k_offset = nq * head_dim;
let k_in_view = bufs.d_qkv.slice(k_offset..);
launch_fused_qk_norm_rope(
graph,
mods,
&bufs.d_qkv,
&k_in_view,
&mut bufs.d_q_rope,
&mut bufs.d_k_rope,
d_q_norm_weight,
d_k_norm_weight,
&bufs.d_cos,
&bufs.d_sin,
nq_u32,
nkv_u32,
hd_u32,
norm_eps,
)?;
let v_offset = (nq + nkv) * head_dim;
let v_view = bufs.d_qkv.slice(v_offset..);
launch_fused_kv_store(
graph,
mods,
&bufs.d_k_rope,
&v_view,
&mut kv.k_cache,
&mut kv.v_cache,
hd_u32,
nkv_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
layer_offset,
)?;
launch_batched_attn_scores_v2(
graph,
mods,
&bufs.d_q_rope,
&kv.k_cache,
&mut bufs.d_scores,
hd_u32,
nq_u32,
nkv_u32,
heads_per_group_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
inv_sqrt_hd,
layer_offset,
)?;
launch_batched_softmax(
graph,
mods,
&mut bufs.d_scores,
nq_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
)?;
launch_batched_attn_weighted_sum(
graph,
mods,
&bufs.d_scores,
&kv.v_cache,
&mut bufs.d_attn_out,
hd_u32,
nq_u32,
nkv_u32,
heads_per_group_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
layer_offset,
)
}
/// Enqueues the attention phase of one layer, assuming `bufs.d_qkv` already
/// holds the fused Q, K, V projection.
///
/// Launch order:
/// 1. Fused Q/K norm + RoPE: Q is read from the start of `d_qkv`, K from
///    element `nq * head_dim`; results go to `d_q_rope` / `d_k_rope`.
/// 2. KV store: `d_k_rope` and V (from element `(nq + nkv) * head_dim` of
///    `d_qkv`) are written into the cache at this layer's offset.
/// 3. Attention scores into `d_scores`, scaled by `1 / sqrt(head_dim)`.
/// 4. Softmax over `d_scores`, in place.
/// 5. Weighted sum of the cached V into `d_attn_out`.
///
/// # Safety
///
/// The launcher helpers are invoked without further validation.
/// NOTE(review): callers presumably must guarantee that all device buffers
/// belong to `graph`'s context and are sized for the dimensions passed in —
/// confirm against the launchers.
///
/// # Errors
///
/// Returns the first error reported by any launch; later launches are skipped.
#[allow(clippy::too_many_arguments)]
pub unsafe fn encode_attn_phase_from_qkv(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_q_norm_weight: &CudaSlice<f32>,
    d_k_norm_weight: &CudaSlice<f32>,
    kv: &mut CudaKvCache,
    layer_idx: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    bufs: &mut CudaFullLayerBuffers,
) -> Result<(), CudaGraphError> {
    // Kernel scalar arguments, converted once up front.
    let (q_heads, kv_heads, dim) = (nq as u32, nkv as u32, head_dim as u32);
    let group_size = heads_per_group as u32;
    let seq_capacity = bufs.max_seq as u32;
    let score_scale = 1.0f32 / (head_dim as f32).sqrt();
    let layer_base = kv.layer_offset_elements(layer_idx);

    // K starts right after the Q rows inside the fused projection.
    let k_start = nq * head_dim;
    let k_src = bufs.d_qkv.slice(k_start..);
    launch_fused_qk_norm_rope(
        graph,
        mods,
        &bufs.d_qkv,
        &k_src,
        &mut bufs.d_q_rope,
        &mut bufs.d_k_rope,
        d_q_norm_weight,
        d_k_norm_weight,
        &bufs.d_cos,
        &bufs.d_sin,
        q_heads,
        kv_heads,
        dim,
        norm_eps,
    )?;

    // V follows the Q and K rows.
    let v_start = (nq + nkv) * head_dim;
    let v_src = bufs.d_qkv.slice(v_start..);
    launch_fused_kv_store(
        graph,
        mods,
        &bufs.d_k_rope,
        &v_src,
        &mut kv.k_cache,
        &mut kv.v_cache,
        dim,
        kv_heads,
        seq_capacity,
        &bufs.d_pos_seqlen,
        layer_base,
    )?;

    launch_batched_attn_scores_v2(
        graph,
        mods,
        &bufs.d_q_rope,
        &kv.k_cache,
        &mut bufs.d_scores,
        dim,
        q_heads,
        kv_heads,
        group_size,
        seq_capacity,
        &bufs.d_pos_seqlen,
        score_scale,
        layer_base,
    )?;

    launch_batched_softmax(
        graph,
        mods,
        &mut bufs.d_scores,
        q_heads,
        seq_capacity,
        &bufs.d_pos_seqlen,
    )?;

    launch_batched_attn_weighted_sum(
        graph,
        mods,
        &bufs.d_scores,
        &kv.v_cache,
        &mut bufs.d_attn_out,
        dim,
        q_heads,
        kv_heads,
        group_size,
        seq_capacity,
        &bufs.d_pos_seqlen,
        layer_base,
    )
}
pub use encode_q1::{
encode_full_forward, encode_full_layer, try_cuda_full_forward,
try_cuda_full_forward_with_gpu_lm_head, try_cuda_full_layer,
};
pub use encode_ternary::{
encode_full_forward_ternary, encode_layer_into_ternary, encode_lm_head_gemv_ternary,
try_cuda_full_forward_ternary, try_cuda_full_forward_ternary_with_gpu_lm_head,
CudaFullForwardLayerParamsTernary,
};
#[cfg(test)]
mod tests;