use super::types::{BufferCache, BufferId};
use crate::errors::{Result, TrustformersError};
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use cudarc::driver::{CudaContext, CudaStream};
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use std::sync::Arc;
/// GPU backend state for CUDA, compiled only when the `cuda` feature is
/// enabled on Linux or Windows.
///
/// Bundles the cudarc context and stream handles together with a
/// mutex-guarded [`BufferCache`]. Every field is wrapped in an `Arc`.
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
pub struct CudaBackend {
/// CUDA context created for the device ordinal passed to `CudaBackend::new`.
pub(crate) context: Arc<CudaContext>,
/// Stream handle; `CudaBackend::new` fills this with the context's default stream.
pub(crate) stream: Arc<CudaStream>,
/// Buffer cache behind a `std::sync::Mutex`; starts out as `BufferCache::new()`.
/// NOTE(review): what the cache stores and how it is keyed is defined in
/// `super::types` and is not visible here.
pub(crate) buffer_cache: Arc<std::sync::Mutex<BufferCache>>,
}
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl CudaBackend {
    /// Creates a backend bound to the CUDA device with ordinal `device_id`.
    ///
    /// Opens a context on that device, takes the context's default stream, and
    /// starts with a fresh [`BufferCache`]. Once the context and stream are
    /// available, a confirmation line is printed to stdout.
    ///
    /// # Errors
    ///
    /// Returns an error built by `TrustformersError::hardware_error` when
    /// `CudaContext::new` fails; the underlying failure is embedded in the
    /// error message.
    pub fn new(device_id: usize) -> Result<Self> {
        let context_result = CudaContext::new(device_id);
        let context = context_result.map_err(|err| {
            let message = format!("Failed to create CUDA context: {}", err);
            TrustformersError::hardware_error(&message, "CudaBackend::new")
        })?;
        let stream = context.default_stream();
        println!("✓ CUDA backend initialized on device {}", device_id);

        // Built after the log line so the order of operations matches the
        // original field-initializer evaluation.
        let buffer_cache = Arc::new(std::sync::Mutex::new(BufferCache::new()));
        Ok(Self {
            context,
            stream,
            buffer_cache,
        })
    }

    /// Returns a short label of the form `CUDA Device <ordinal>`, where the
    /// ordinal is the one reported by the underlying context.
    pub fn device_info(&self) -> String {
        let ordinal = self.context.ordinal();
        format!("CUDA Device {}", ordinal)
    }
}
pub use BufferId;