#[cfg(feature = "cuda")]
use crate::cuda::CudaExecutor;
#[cfg(feature = "cuda")]
use crate::error::{RealizarError, Result};
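/// GPU scheduler that routes dense matrix work to CUDA device 0 through a
/// [`CudaExecutor`], with support for caching weights on the device.
///
/// Minimal usage sketch (requires a CUDA-capable device, hence `ignore`;
/// `a`, `b`, and the dimensions are placeholders):
///
/// ```ignore
/// let mut scheduler = CudaScheduler::new()?;
/// println!("running on {}", scheduler.device_name()?);
/// let output = scheduler.matmul(&a, &b, m, k, n)?;
/// assert_eq!(output.len(), m * n);
/// ```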
#[cfg(feature = "cuda")]
pub struct CudaScheduler {
    executor: CudaExecutor,
}
#[cfg(feature = "cuda")]
impl CudaScheduler {
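    /// Creates a scheduler bound to CUDA device 0.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::GpuError`] if the [`CudaExecutor`] cannot be
    /// created (for example, when no CUDA device is available).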
    pub fn new() -> Result<Self> {
        let executor = CudaExecutor::new(0).map_err(|e| RealizarError::GpuError {
            reason: format!("Failed to create CudaExecutor: {}", e),
        })?;
        Ok(Self { executor })
    }
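    /// Reports whether CUDA is available to this scheduler. This type is only
    /// compiled with the `cuda` feature, so it always answers `true`.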
#[must_use]
    pub fn has_cuda(&self) -> bool {
        true
    }
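    /// Reports whether a matmul of shape `(m x k) * (k x n)` would be
    /// dispatched to CUDA. This scheduler always answers `true`; the
    /// dimensions are currently ignored.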
#[must_use]
#[allow(clippy::unused_self)]
    pub fn uses_cuda_for(&self, _m: usize, _k: usize, _n: usize) -> bool {
        true
    }
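    /// Multiplies `a` (`m x k`) by `b` (`k x n`) on the GPU and returns the
    /// `m * n` element result.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::GpuError`] if the CUDA GEMM launch fails.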
#[allow(clippy::many_single_char_names)]
pub fn matmul(
&mut self,
a: &[f32],
b: &[f32],
m: usize,
k: usize,
n: usize,
) -> Result<Vec<f32>> {
let mut output = vec![0.0f32; m * n];
self.executor
.gemm(a, b, &mut output, m as u32, n as u32, k as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("CUDA GEMM failed: {}", e),
})?;
Ok(output)
}
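    /// Returns the name of the CUDA device backing this scheduler.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::GpuError`] if the device name cannot be
    /// queried.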
pub fn device_name(&self) -> Result<String> {
self.executor
.device_name()
.map_err(|e| RealizarError::GpuError {
reason: format!("Failed to get device name: {}", e),
})
}
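    /// Uploads `weight` to device memory and caches it under `name` so that
    /// later [`CudaScheduler::matmul_cached`] calls can reuse it.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::GpuError`] if the upload fails.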
pub fn cache_weight(&mut self, name: &str, weight: &[f32]) -> Result<()> {
self.executor
.load_weights(name, weight)
.map(|_| ())
.map_err(|e| RealizarError::GpuError {
reason: format!("Failed to cache weight '{}': {}", name, e),
})
}
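    /// Returns `true` if a weight has been cached under `name` via
    /// [`CudaScheduler::cache_weight`].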
#[must_use]
pub fn has_cached_weight(&self, name: &str) -> bool {
self.executor.has_weights(name)
}
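    /// Returns the number of weights currently cached on the device.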
#[must_use]
pub fn cached_weight_count(&self) -> usize {
self.executor.cached_weight_count()
}
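    /// Computes a GEMV against the weight cached under `weight_name`:
    /// multiplies the length-`k` input `x` by the cached weight (of `k * n`
    /// elements) and returns the length-`n` result.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::GpuError`] if the cached GEMV fails.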
pub fn matmul_cached(
&mut self,
weight_name: &str,
x: &[f32],
k: usize,
n: usize,
) -> Result<Vec<f32>> {
let mut output = vec![0.0f32; n];
self.executor
.gemv_cached(weight_name, x, &mut output, k as u32, n as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("CUDA cached GEMV failed: {}", e),
})?;
Ok(output)
}
}
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
use super::*;
#[test]
fn test_cuda_scheduler_new() {
let scheduler = CudaScheduler::new();
assert!(
scheduler.is_ok(),
"CudaScheduler::new() failed: {:?}",
scheduler.err()
);
}
#[test]
fn test_cuda_scheduler_has_cuda() {
let scheduler = CudaScheduler::new().expect("scheduler");
assert!(scheduler.has_cuda());
}
#[test]
fn test_cuda_scheduler_uses_cuda_for_all_dims() {
let scheduler = CudaScheduler::new().expect("scheduler");
        assert!(scheduler.uses_cuda_for(1, 64, 64));
        assert!(scheduler.uses_cuda_for(8, 256, 256));
        assert!(scheduler.uses_cuda_for(1, 1, 1));
        assert!(scheduler.uses_cuda_for(1024, 4096, 4096));
    }
#[test]
fn test_cuda_scheduler_device_name() {
let scheduler = CudaScheduler::new().expect("scheduler");
let name = scheduler.device_name();
assert!(name.is_ok());
let name = name.expect("name");
assert!(!name.is_empty());
}
#[test]
fn test_cuda_scheduler_matmul_basic() {
let mut scheduler = CudaScheduler::new().expect("scheduler");
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let result = scheduler.matmul(&a, &b, 2, 3, 2);
assert!(result.is_ok());
let output = result.expect("output");
assert_eq!(output.len(), 4);
}
#[test]
fn test_cuda_scheduler_matmul_single_element() {
let mut scheduler = CudaScheduler::new().expect("scheduler");
let a = vec![3.0];
let b = vec![4.0];
let result = scheduler.matmul(&a, &b, 1, 1, 1);
assert!(result.is_ok());
let output = result.expect("output");
assert_eq!(output.len(), 1);
assert!((output[0] - 12.0).abs() < 0.1);
}
#[test]
fn test_cuda_scheduler_matmul_larger() {
let mut scheduler = CudaScheduler::new().expect("scheduler");
let a: Vec<f32> = (0..256).map(|i| (i as f32) * 0.01).collect();
let b: Vec<f32> = (0..2048).map(|i| (i as f32) * 0.001).collect();
let result = scheduler.matmul(&a, &b, 4, 64, 32);
assert!(result.is_ok());
let output = result.expect("output");
assert_eq!(output.len(), 128);
}
#[test]
fn test_cuda_scheduler_cache_weight() {
let mut scheduler = CudaScheduler::new().expect("scheduler");
let weight = vec![1.0f32; 256 * 128];
let result = scheduler.cache_weight("test_weight", &weight);
assert!(result.is_ok());
assert!(scheduler.has_cached_weight("test_weight"));
assert!(!scheduler.has_cached_weight("nonexistent"));
}
#[test]
fn test_cuda_scheduler_cached_weight_count() {
let mut scheduler = CudaScheduler::new().expect("scheduler");
let initial_count = scheduler.cached_weight_count();
let weight = vec![1.0f32; 64 * 64];
scheduler
.cache_weight("weight_1", &weight)
.expect("cache_weight");
assert_eq!(scheduler.cached_weight_count(), initial_count + 1);
scheduler
.cache_weight("weight_2", &weight)
.expect("cache_weight");
assert_eq!(scheduler.cached_weight_count(), initial_count + 2);
}
#[test]
fn test_cuda_scheduler_matmul_cached() {
let mut scheduler = CudaScheduler::new().expect("scheduler");
let weight: Vec<f32> = (0..2048).map(|i| (i as f32) * 0.001).collect();
scheduler
.cache_weight("cached_test", &weight)
.expect("cache_weight");
let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
let result = scheduler.matmul_cached("cached_test", &input, 64, 32);
assert!(result.is_ok());
let output = result.expect("output");
assert_eq!(output.len(), 32);
}
#[test]
fn test_cuda_scheduler_matmul_identity() {
let mut scheduler = CudaScheduler::new().expect("scheduler");
#[rustfmt::skip]
let identity = vec![
1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0,
];
let v = vec![1.0, 2.0, 3.0, 4.0];
let result = scheduler.matmul(&identity, &v, 4, 4, 1);
assert!(result.is_ok());
let output = result.expect("output");
for (i, &expected) in v.iter().enumerate() {
assert!(
(output[i] - expected).abs() < 0.01,
"idx={} got={} expected={}",
i,
output[i],
expected
);
}
}
}