/// Runs one expert's SwiGLU feed-forward pass using the CUDA executor.
///
/// The computation is `down(silu(gate(hidden)) * up(hidden))`:
///
/// 1. `gate_bytes` and `up_bytes` are each applied to `hidden` through
///    `CudaExecutor::q4k_matvec`, producing two vectors of length `intermediate`.
/// 2. SiLU (`g / (1 + exp(-g))`) of the gate output is multiplied element-wise
///    with the up output. This step runs in plain Rust on the returned vectors.
/// 3. `down_bytes` is applied to that product through `CudaExecutor::q6k_gemv`,
///    producing the result vector of length `hidden_dim`.
///
/// The weight byte slices are forwarded to the executor untouched; their length
/// is not validated here.
/// NOTE(review): presumably Q4_K (gate/up) and Q6_K (down) quantized blocks, going
/// by the kernel names — confirm against `CudaExecutor`.
///
/// # Errors
///
/// * [`RealizarError::InvalidShape`] if `hidden.len() != hidden_dim`, if either
///   `hidden_dim` or `intermediate` is zero, or if either dimension does not fit
///   in the `u32` handed to the kernels.
/// * [`RealizarError::UnsupportedOperation`] if any executor call fails; the
///   `operation` field names the failing stage and `reason` carries the
///   executor's error text.
#[cfg(feature = "cuda")]
pub(crate) fn expert_swiglu_cuda(
    executor: &mut crate::cuda::CudaExecutor,
    gate_bytes: &[u8],
    up_bytes: &[u8],
    down_bytes: &[u8],
    hidden: &[f32],
    hidden_dim: usize,
    intermediate: usize,
) -> Result<Vec<f32>> {
    if hidden.len() != hidden_dim {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "expert_swiglu_cuda: hidden.len() = {} but hidden_dim = {}",
                hidden.len(),
                hidden_dim
            ),
        });
    }
    if hidden_dim == 0 || intermediate == 0 {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "expert_swiglu_cuda: hidden_dim ({hidden_dim}) and intermediate \
                 ({intermediate}) must both be > 0"
            ),
        });
    }

    // The kernels take `u32` dimensions. A plain `as u32` would silently wrap on
    // 64-bit targets and launch with a smaller size than the host buffers, so
    // convert with a check and reject oversized shapes before allocating.
    // Fully qualified so the trait need not be imported on any edition.
    let checked_u32 = |name: &str, value: usize| -> Result<u32> {
        <u32 as std::convert::TryFrom<usize>>::try_from(value).map_err(|_| {
            RealizarError::InvalidShape {
                reason: format!("expert_swiglu_cuda: {name} ({value}) does not fit in u32"),
            }
        })
    };
    let hidden_dim_u32 = checked_u32("hidden_dim", hidden_dim)?;
    let intermediate_u32 = checked_u32("intermediate", intermediate)?;

    // Gate projection: hidden (hidden_dim) -> gate_out (intermediate).
    let mut gate_out = vec![0.0f32; intermediate];
    executor
        .q4k_matvec(
            gate_bytes,
            hidden,
            &mut gate_out,
            intermediate_u32,
            hidden_dim_u32,
        )
        .map_err(|e| RealizarError::UnsupportedOperation {
            operation: "expert_swiglu_cuda::gate_q4k_matvec".to_string(),
            reason: e.to_string(),
        })?;

    // Up projection: hidden (hidden_dim) -> up_out (intermediate).
    let mut up_out = vec![0.0f32; intermediate];
    executor
        .q4k_matvec(
            up_bytes,
            hidden,
            &mut up_out,
            intermediate_u32,
            hidden_dim_u32,
        )
        .map_err(|e| RealizarError::UnsupportedOperation {
            operation: "expert_swiglu_cuda::up_q4k_matvec".to_string(),
            reason: e.to_string(),
        })?;

    // SwiGLU inner activation: silu(gate) * up, element by element.
    let ffn_inner: Vec<f32> = gate_out
        .iter()
        .zip(&up_out)
        .map(|(&g, &u)| {
            let silu_g = g / (1.0 + (-g).exp());
            silu_g * u
        })
        .collect();

    // Down projection: ffn_inner (intermediate) -> expert_out (hidden_dim).
    let mut expert_out = vec![0.0f32; hidden_dim];
    executor
        .q6k_gemv(
            down_bytes,
            &ffn_inner,
            &mut expert_out,
            hidden_dim_u32,
            intermediate_u32,
        )
        .map_err(|e| RealizarError::UnsupportedOperation {
            operation: "expert_swiglu_cuda::down_q6k_gemv".to_string(),
            reason: e.to_string(),
        })?;

    Ok(expert_out)
}
/// Unit tests for `expert_swiglu_cuda`.
#[cfg(test)]
mod expert_swiglu_cuda_tests {
    use super::*;

    /// Intentionally empty; always passes.
    // NOTE(review): presumably a placeholder so this module has a test even
    // without the `cuda` feature — confirm the intended purpose.
    #[test]
    fn expert_swiglu_cuda_signature_drift_gate() {}

    /// A `hidden` slice whose length differs from `hidden_dim` must be rejected
    /// with `InvalidShape`. When no executor can be created for device 0 the
    /// test returns early and passes without asserting anything.
    #[cfg(feature = "cuda")]
    #[test]
    fn expert_swiglu_cuda_rejects_mismatched_hidden_len() {
        let mut executor = match crate::cuda::CudaExecutor::new(0) {
            Ok(executor) => executor,
            Err(_) => return,
        };

        // The same placeholder buffer stands in for all three weight tensors.
        let placeholder_weights = [0u8; 144];
        // Five activations against a declared hidden_dim of ten.
        let short_hidden = [1.0f32; 5];

        let outcome = expert_swiglu_cuda(
            &mut executor,
            &placeholder_weights,
            &placeholder_weights,
            &placeholder_weights,
            &short_hidden,
            10,
            4,
        );

        assert!(matches!(outcome, Err(RealizarError::InvalidShape { .. })));
    }
}