use super::context::CudaContext;
use crate::error::{HiveGpuError, Result};
/// Field-less unit struct used as a namespace for CUDA helper associated functions
/// (see its `impl` block). Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub struct CudaHelpers;
#[cfg(feature = "cuda")]
impl CudaHelpers {
    /// Chooses a thread-block shape `(x, y, z)` that covers as much of `grid_size` as the
    /// device allows in a single block.
    ///
    /// Dimensions are filled greedily in `x`, `y`, `z` order:
    /// - `x` takes up to the per-block thread budget.
    /// - `y` gets whatever budget remains after `x`.
    /// - `z` gets what remains after `x * y`.
    ///
    /// As a result, `x * y * z` never exceeds the budget. `z` is additionally capped at 64,
    /// the CUDA limit on a block's z-dimension.
    ///
    /// The thread budget is 1024 when `context.compute_capability()` reports a major
    /// version of 6 or higher, and 512 otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`HiveGpuError::Other`] if any component of `grid_size` is zero. A
    /// zero-extent dimension has no valid block shape, and before this check it caused a
    /// division-by-zero panic.
    pub fn calculate_block_size(
        context: &CudaContext,
        grid_size: (u32, u32, u32),
    ) -> Result<(u32, u32, u32)> {
        // CUDA caps blockDim.z at 64 on every compute capability, independently of the
        // total threads-per-block limit.
        const MAX_BLOCK_DIM_Z: u32 = 64;

        let (gx, gy, gz) = grid_size;
        if gx == 0 || gy == 0 || gz == 0 {
            return Err(HiveGpuError::Other(format!(
                "grid_size must be non-zero in every dimension, got {:?}",
                grid_size
            )));
        }

        let (major, _minor) = context.compute_capability();
        // NOTE(review): CUDA devices of compute capability 2.0+ allow 1024 threads per
        // block. The 512 fallback for major < 6 looks deliberately conservative; confirm.
        let max_threads_per_block: u32 = if major >= 6 { 1024 } else { 512 };

        // With gx, gy, gz >= 1, each divisor below is >= 1 and <= max_threads_per_block.
        // This rules out division by zero and overflow in `x * y`.
        let x = gx.min(max_threads_per_block);
        let y = gy.min(max_threads_per_block / x);
        let z = gz
            .min(max_threads_per_block / (x * y))
            .min(MAX_BLOCK_DIM_Z);
        Ok((x, y, z))
    }

    /// Checks that the device behind `context` supports the features this crate requires.
    ///
    /// # Errors
    ///
    /// Returns [`HiveGpuError::Other`] when `context.supports_required_features()` is false.
    pub fn validate_device_capabilities(context: &CudaContext) -> Result<()> {
        if !context.supports_required_features() {
            return Err(HiveGpuError::Other(
                "CUDA device does not support required features".to_string(),
            ));
        }
        Ok(())
    }
}