hive_gpu/shaders/
wgsl_shaders.rs

//! WGSL Shader Management
//!
//! This module provides loading and management of WGSL shaders for cross-platform GPU operations.
4
5use crate::error::{Result, HiveGpuError};
6
/// WGSL shader manager: holds WGSL shader sources keyed by shader name.
///
/// Shaders are only loaded and checked textually here; no GPU compilation happens in this type.
#[derive(Debug, Clone)]
pub struct WgslShaderManager {
    /// Loaded shader sources, keyed by shader name.
    shader_sources: std::collections::HashMap<String, String>,
}
12
13impl WgslShaderManager {
14    /// Create a new WGSL shader manager
15    pub fn new() -> Self {
16        Self {
17            shader_sources: Self::load_wgsl_shaders(),
18        }
19    }
20
21    /// Get a specific shader source
22    pub fn get_shader(&self, name: &str) -> Option<&str> {
23        self.shader_sources.get(name).map(|s| s.as_str())
24    }
25
26    /// Get all available shader names
27    pub fn get_shader_names(&self) -> Vec<String> {
28        self.shader_sources.keys().cloned().collect()
29    }
30
31    /// Load WGSL shaders from embedded sources
32    fn load_wgsl_shaders() -> std::collections::HashMap<String, String> {
33        let mut shaders = std::collections::HashMap::new();
34        
35        // Load individual WGSL shaders
36        // These will be populated when we migrate the shaders from vectorizer
37        shaders.insert("similarity".to_string(), include_str!("../shaders/similarity.wgsl").to_string());
38        shaders.insert("distance".to_string(), include_str!("../shaders/distance.wgsl").to_string());
39        shaders.insert("dot_product".to_string(), include_str!("../shaders/dot_product.wgsl").to_string());
40        shaders.insert("hnsw_construction".to_string(), include_str!("../shaders/hnsw_construction.wgsl").to_string());
41        shaders.insert("hnsw_navigation".to_string(), include_str!("../shaders/hnsw_navigation.wgsl").to_string());
42        shaders.insert("batch_construction".to_string(), include_str!("../shaders/batch_construction.wgsl").to_string());
43        
44        shaders
45    }
46
47    /// Validate all shaders
48    pub fn validate_all_shaders(&self) -> Result<()> {
49        for (name, source) in &self.shader_sources {
50            if source.is_empty() {
51                return Err(HiveGpuError::ShaderCompilationFailed(
52                    format!("Empty shader source for: {}", name)
53                ));
54            }
55
56            // Basic WGSL validation
57            if !source.contains("@compute") && !source.contains("@vertex") && !source.contains("@fragment") {
58                return Err(HiveGpuError::ShaderCompilationFailed(
59                    format!("Invalid WGSL shader: {} (missing entry point)", name)
60                ));
61            }
62        }
63
64        Ok(())
65    }
66
67    /// Get shader compilation info
68    pub fn get_compilation_info(&self) -> ShaderCompilationInfo {
69        ShaderCompilationInfo {
70            total_shaders: self.shader_sources.len(),
71            shader_names: self.get_shader_names(),
72            total_size: self.shader_sources.values().map(|s| s.len()).sum(),
73        }
74    }
75}
76
77impl Default for WgslShaderManager {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
/// Summary of the shaders held by a WGSL shader manager.
///
/// This is plain data, so it also supports equality comparison and an all-zero `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ShaderCompilationInfo {
    /// Total number of shaders.
    pub total_shaders: usize,
    /// Names of all shaders.
    pub shader_names: Vec<String>,
    /// Total size of all shader sources, in bytes.
    pub total_size: usize,
}