use crate::error::{HiveGpuError, Result};
/// Holds WGSL shader source text keyed by shader name.
///
/// Derives `Debug` and `Clone` like the sibling `ShaderCompilationInfo`, so the
/// manager can be logged and duplicated by callers.
#[derive(Debug, Clone)]
pub struct WgslShaderManager {
    /// Shader name -> WGSL source text.
    shader_sources: std::collections::HashMap<String, String>,
}
impl WgslShaderManager {
    /// Creates a manager populated with every shader embedded by
    /// `load_wgsl_shaders`.
    pub fn new() -> Self {
        Self {
            shader_sources: Self::load_wgsl_shaders(),
        }
    }

    /// Returns the WGSL source registered under `name`, or `None` if no
    /// shader with that name is registered.
    pub fn get_shader(&self, name: &str) -> Option<&str> {
        self.shader_sources.get(name).map(String::as_str)
    }

    /// Returns the names of all registered shaders, sorted in ascending
    /// (byte-wise lexicographic) order.
    ///
    /// The backing store is a `HashMap`, whose iteration order is arbitrary
    /// and can differ between runs; sorting makes the result reproducible.
    pub fn get_shader_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.shader_sources.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Builds the name -> source table from shaders embedded with
    /// `include_str!`.
    ///
    /// `include_str!` resolves paths relative to this source file at compile
    /// time, so a missing shader file is a build error, not a runtime one.
    fn load_wgsl_shaders() -> std::collections::HashMap<String, String> {
        const SHADERS: [(&str, &str); 6] = [
            ("similarity", include_str!("../shaders/similarity.wgsl")),
            ("distance", include_str!("../shaders/distance.wgsl")),
            ("dot_product", include_str!("../shaders/dot_product.wgsl")),
            (
                "hnsw_construction",
                include_str!("../shaders/hnsw_construction.wgsl"),
            ),
            (
                "hnsw_navigation",
                include_str!("../shaders/hnsw_navigation.wgsl"),
            ),
            (
                "batch_construction",
                include_str!("../shaders/batch_construction.wgsl"),
            ),
        ];

        SHADERS
            .iter()
            .map(|&(name, source)| (name.to_string(), source.to_string()))
            .collect()
    }

    /// Performs a lightweight sanity check on every registered shader.
    ///
    /// This is not a WGSL compiler. It only checks two things for each source:
    /// the source is not blank, and it mentions at least one shader-stage
    /// attribute (`@compute`, `@vertex` or `@fragment`). Shaders are checked
    /// in name order, so the reported error is deterministic when several
    /// shaders are invalid.
    ///
    /// # Errors
    ///
    /// Returns `HiveGpuError::ShaderCompilationFailed` for the first shader
    /// (by name) that fails either check:
    /// - its source is empty or whitespace-only, or
    /// - it contains none of the stage attributes.
    pub fn validate_all_shaders(&self) -> Result<()> {
        const STAGE_ATTRIBUTES: [&str; 3] = ["@compute", "@vertex", "@fragment"];

        // Sort so the failure reported first does not depend on HashMap order.
        let mut entries: Vec<(&String, &String)> = self.shader_sources.iter().collect();
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        for (name, source) in entries {
            // A whitespace-only source can never hold an entry point. Report
            // it as empty, which is the more accurate diagnosis.
            if source.trim().is_empty() {
                return Err(HiveGpuError::ShaderCompilationFailed(format!(
                    "Empty shader source for: {}",
                    name
                )));
            }
            // Plain substring match. An attribute that appears only inside a
            // comment will still pass this check.
            if !STAGE_ATTRIBUTES.iter().any(|&attr| source.contains(attr)) {
                return Err(HiveGpuError::ShaderCompilationFailed(format!(
                    "Invalid WGSL shader: {} (missing entry point)",
                    name
                )));
            }
        }
        Ok(())
    }

    /// Summarises the registered shaders. The summary contains:
    /// - the shader count,
    /// - the shader names (sorted, as returned by `get_shader_names`),
    /// - the combined source length in bytes.
    pub fn get_compilation_info(&self) -> ShaderCompilationInfo {
        ShaderCompilationInfo {
            total_shaders: self.shader_sources.len(),
            shader_names: self.get_shader_names(),
            total_size: self.shader_sources.values().map(String::len).sum(),
        }
    }
}
impl Default for WgslShaderManager {
fn default() -> Self {
Self::new()
}
}
/// Summary of a set of WGSL shaders, as produced by
/// `WgslShaderManager::get_compilation_info`.
///
/// This is plain data, so it also derives `Default`, `PartialEq` and `Eq` for
/// easy construction and comparison.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ShaderCompilationInfo {
    /// Number of shaders summarised.
    pub total_shaders: usize,
    /// Names of the shaders summarised.
    pub shader_names: Vec<String>,
    /// Combined length of all shader sources, in bytes.
    pub total_size: usize,
}