// hive_gpu/shaders/wgsl_shaders.rs
use crate::error::{Result, HiveGpuError};
6
/// Registry of WGSL shader sources, looked up by shader name.
///
/// Derives `Debug` and `Clone` so the manager can be logged and duplicated
/// like the other public types in this module.
#[derive(Debug, Clone)]
pub struct WgslShaderManager {
    /// Maps a shader name to the WGSL source text registered under it.
    shader_sources: std::collections::HashMap<String, String>,
}
12
13impl WgslShaderManager {
14 pub fn new() -> Self {
16 Self {
17 shader_sources: Self::load_wgsl_shaders(),
18 }
19 }
20
21 pub fn get_shader(&self, name: &str) -> Option<&str> {
23 self.shader_sources.get(name).map(|s| s.as_str())
24 }
25
26 pub fn get_shader_names(&self) -> Vec<String> {
28 self.shader_sources.keys().cloned().collect()
29 }
30
31 fn load_wgsl_shaders() -> std::collections::HashMap<String, String> {
33 let mut shaders = std::collections::HashMap::new();
34
35 shaders.insert("similarity".to_string(), include_str!("../shaders/similarity.wgsl").to_string());
38 shaders.insert("distance".to_string(), include_str!("../shaders/distance.wgsl").to_string());
39 shaders.insert("dot_product".to_string(), include_str!("../shaders/dot_product.wgsl").to_string());
40 shaders.insert("hnsw_construction".to_string(), include_str!("../shaders/hnsw_construction.wgsl").to_string());
41 shaders.insert("hnsw_navigation".to_string(), include_str!("../shaders/hnsw_navigation.wgsl").to_string());
42 shaders.insert("batch_construction".to_string(), include_str!("../shaders/batch_construction.wgsl").to_string());
43
44 shaders
45 }
46
47 pub fn validate_all_shaders(&self) -> Result<()> {
49 for (name, source) in &self.shader_sources {
50 if source.is_empty() {
51 return Err(HiveGpuError::ShaderCompilationFailed(
52 format!("Empty shader source for: {}", name)
53 ));
54 }
55
56 if !source.contains("@compute") && !source.contains("@vertex") && !source.contains("@fragment") {
58 return Err(HiveGpuError::ShaderCompilationFailed(
59 format!("Invalid WGSL shader: {} (missing entry point)", name)
60 ));
61 }
62 }
63
64 Ok(())
65 }
66
67 pub fn get_compilation_info(&self) -> ShaderCompilationInfo {
69 ShaderCompilationInfo {
70 total_shaders: self.shader_sources.len(),
71 shader_names: self.get_shader_names(),
72 total_size: self.shader_sources.values().map(|s| s.len()).sum(),
73 }
74 }
75}
76
77impl Default for WgslShaderManager {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
/// Summary of the shaders held by a shader manager.
#[derive(Clone, Debug)]
pub struct ShaderCompilationInfo {
    /// Number of registered shaders.
    pub total_shaders: usize,

    /// Names of the registered shaders.
    pub shader_names: Vec<String>,

    /// Combined length, in bytes, of all shader source strings.
    pub total_size: usize,
}