jugar_probar/brick/compute.rs

//! ComputeBrick: WebGPU shader generation from brick definitions (PROBAR-SPEC-009-P8)
//!
//! Generates WGSL shaders and wgpu bindings from a single brick definition.
//! Zero hand-written shaders - all code derived from Rust types.
//!
//! # Design Philosophy
//!
//! ComputeBrick applies the same zero-artifact principle to GPU compute:
//! - Define tensor shapes and operations in Rust
//! - Generate WGSL shader code
//! - Generate Rust wgpu bindings
//!
//! # Inspiration
//!
//! NVIDIA CUDA Tile IR provides the model for declarative GPU programming.
//! ComputeBrick adapts these patterns for WebGPU.
//!
//! # Example
//!
//! ```rust,ignore
//! use probar::brick::compute::{ComputeBrick, ElementwiseOp, TensorType, TileOp, TileStrategy};
//!
//! let mel_brick = ComputeBrick::new("mel-filterbank")
//!     .workgroup_size(256, 1, 1)
//!     .input("audio", TensorType::F32, &[CHUNK_SIZE])
//!     .output("mel", TensorType::F32, &[N_MELS, N_FRAMES])
//!     .tile_strategy(TileStrategy::Simple2D { tile_x: 16, tile_y: 16 })
//!     .op(TileOp::LoadShared { src: "audio".into(), tile_size: (256, 1) })
//!     .op(TileOp::Elementwise {
//!         op: ElementwiseOp::Log,
//!         operands: vec!["audio".into()],
//!         output: Some("mel".into()),
//!     })
//!     .op(TileOp::StoreShared { dst: "mel".into() });
//!
//! // Generate WGSL
//! let wgsl = mel_brick.to_wgsl();
//! ```

use super::{Brick, BrickAssertion, BrickBudget, BrickVerification};
use std::time::Duration;

/// Tensor element type for GPU compute
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorType {
    /// 32-bit float
    F32,
    /// 16-bit float (half precision)
    F16,
    /// 32-bit signed integer
    I32,
    /// 32-bit unsigned integer
    U32,
}

impl TensorType {
    /// Get WGSL type name
    #[must_use]
    pub fn to_wgsl(&self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::F16 => "f16",
            Self::I32 => "i32",
            Self::U32 => "u32",
        }
    }

    /// Get Rust type name
    #[must_use]
    pub fn to_rust(&self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::F16 => "half::f16",
            Self::I32 => "i32",
            Self::U32 => "u32",
        }
    }

    /// Get byte size
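    ///
    /// # Example
    ///
    /// A minimal sketch; sizes follow the WGSL scalar widths:
    ///
    /// ```rust,ignore
    /// assert_eq!(TensorType::F32.byte_size(), 4);
    /// assert_eq!(TensorType::F16.byte_size(), 2);
    /// ```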
    #[must_use]
    pub const fn byte_size(&self) -> usize {
        match self {
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::F16 => 2,
        }
    }
}

/// A tensor binding for compute shader
#[derive(Debug, Clone)]
pub struct TensorBinding {
    /// Binding name
    pub name: String,
    /// Element type
    pub dtype: TensorType,
    /// Shape dimensions
    pub shape: Vec<u32>,
    /// Binding group
    pub group: u32,
    /// Binding index within group
    pub binding: u32,
    /// Read-only flag
    pub read_only: bool,
}

impl TensorBinding {
    /// Create a new tensor binding
    #[must_use]
    pub fn new(name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
        Self {
            name: name.into(),
            dtype,
            shape: shape.to_vec(),
            group: 0,
            binding: 0,
            read_only: true,
        }
    }

    /// Set binding group and index
    #[must_use]
    pub fn at(mut self, group: u32, binding: u32) -> Self {
        self.group = group;
        self.binding = binding;
        self
    }

    /// Mark as writable
    #[must_use]
    pub fn writable(mut self) -> Self {
        self.read_only = false;
        self
    }

    /// Get total element count
    #[must_use]
    pub fn element_count(&self) -> u32 {
        self.shape.iter().product()
    }

    /// Get total byte size
    #[must_use]
    pub fn byte_size(&self) -> usize {
        self.element_count() as usize * self.dtype.byte_size()
    }

    /// Generate WGSL binding declaration
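    ///
    /// # Example
    ///
    /// A minimal sketch of the generated declaration for a default
    /// (group 0, binding 0, read-only) binding:
    ///
    /// ```rust,ignore
    /// let audio = TensorBinding::new("audio", TensorType::F32, &[1024]);
    /// assert_eq!(
    ///     audio.to_wgsl_binding(),
    ///     "@group(0) @binding(0) var<storage, read> audio: array<f32>;"
    /// );
    /// ```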
    #[must_use]
    pub fn to_wgsl_binding(&self) -> String {
        let access = if self.read_only { "read" } else { "read_write" };
        format!(
            "@group({}) @binding({}) var<storage, {}> {}: array<{}>;",
            self.group,
            self.binding,
            access,
            self.name,
            self.dtype.to_wgsl()
        )
    }
}

/// Tiling strategy for GPU compute
#[derive(Debug, Clone)]
pub enum TileStrategy {
    /// Simple 2D tiling
    Simple2D {
        /// Tile width
        tile_x: u32,
        /// Tile height
        tile_y: u32,
    },
    /// Cooperative matrix (tensor core style)
    Cooperative {
        /// Matrix M dimension
        m: u32,
        /// Matrix N dimension
        n: u32,
        /// Matrix K dimension
        k: u32,
    },
    /// Streaming (for convolutions)
    Streaming {
        /// Window size
        window: u32,
    },
    /// No tiling (direct compute)
    None,
}

impl TileStrategy {
    /// Get optimal workgroup size for this strategy
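    ///
    /// # Example
    ///
    /// A minimal sketch (the fallback for [`TileStrategy::None`] is 64x1x1):
    ///
    /// ```rust,ignore
    /// let tiled = TileStrategy::Simple2D { tile_x: 16, tile_y: 16 };
    /// assert_eq!(tiled.optimal_workgroup_size(), (16, 16, 1));
    /// assert_eq!(TileStrategy::None.optimal_workgroup_size(), (64, 1, 1));
    /// ```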
    #[must_use]
    pub fn optimal_workgroup_size(&self) -> (u32, u32, u32) {
        match self {
            Self::Simple2D { tile_x, tile_y } => (*tile_x, *tile_y, 1),
            Self::Cooperative { m, n, .. } => (*m, *n, 1),
            Self::Streaming { window } => (*window, 1, 1),
            Self::None => (64, 1, 1),
        }
    }
}

/// Element-wise operation type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementwiseOp {
    /// Natural logarithm
    Log,
    /// Exponential
    Exp,
    /// Square root
    Sqrt,
    /// Absolute value
    Abs,
    /// Rectified linear unit
    Relu,
    /// Sigmoid
    Sigmoid,
    /// Hyperbolic tangent
    Tanh,
    /// Add constant
    AddScalar(i32),
    /// Multiply by constant
    MulScalar(i32),
    /// Clamp to range
    Clamp,
}

impl ElementwiseOp {
    /// Generate WGSL expression for this operation
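    ///
    /// # Example
    ///
    /// A minimal sketch of the emitted WGSL expressions:
    ///
    /// ```rust,ignore
    /// assert_eq!(ElementwiseOp::Log.to_wgsl_expr("x"), "log(x)");
    /// assert_eq!(ElementwiseOp::Relu.to_wgsl_expr("x"), "max(x, 0.0)");
    /// assert_eq!(ElementwiseOp::AddScalar(5).to_wgsl_expr("x"), "(x + 5.0)");
    /// ```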
    #[must_use]
    pub fn to_wgsl_expr(&self, operand: &str) -> String {
        match self {
            Self::Log => format!("log({})", operand),
            Self::Exp => format!("exp({})", operand),
            Self::Sqrt => format!("sqrt({})", operand),
            Self::Abs => format!("abs({})", operand),
            Self::Relu => format!("max({}, 0.0)", operand),
            Self::Sigmoid => format!("1.0 / (1.0 + exp(-{}))", operand),
            Self::Tanh => format!("tanh({})", operand),
            Self::AddScalar(s) => format!("({} + {}.0)", operand, s),
            Self::MulScalar(s) => format!("({} * {}.0)", operand, s),
            Self::Clamp => format!("clamp({}, 0.0, 1.0)", operand),
        }
    }
}

/// Tile operation in compute shader
#[derive(Debug, Clone)]
pub enum TileOp {
    /// Load tile from global to shared memory
    LoadShared {
        /// Source tensor name
        src: String,
        /// Tile dimensions
        tile_size: (u32, u32),
    },
    /// Matrix multiply accumulate (tensor core pattern)
    Mma {
        /// Input A tensor
        a: String,
        /// Input B tensor
        b: String,
        /// Output C tensor
        c: String,
    },
    /// Element-wise operation
    Elementwise {
        /// Operation type
        op: ElementwiseOp,
        /// Input operand names
        operands: Vec<String>,
        /// Output name (defaults to first operand if None)
        output: Option<String>,
    },
    /// Store tile from shared to global memory
    StoreShared {
        /// Destination tensor name
        dst: String,
    },
    /// Synchronization barrier
    Barrier,
    /// Reduction operation (sum, max, min)
    Reduce {
        /// Reduction type
        kind: ReduceKind,
        /// Input tensor
        input: String,
        /// Output scalar or reduced tensor
        output: String,
    },
}

/// Reduction operation type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceKind {
    /// Sum all elements
    Sum,
    /// Maximum element
    Max,
    /// Minimum element
    Min,
    /// Mean of elements
    Mean,
}

impl ReduceKind {
    /// Get WGSL identity value
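    ///
    /// # Example
    ///
    /// A hypothetical sketch of how a caller might seed a WGSL accumulator from
    /// [`ReduceKind::identity`] and [`ReduceKind::combine_op`] (note that the
    /// current `to_wgsl` only emits a comment for `TileOp::Reduce`):
    ///
    /// ```rust,ignore
    /// let kind = ReduceKind::Max;
    /// // Max/Min combine via a builtin call; Sum/Mean combine via an infix operator.
    /// let init = format!("var acc: f32 = {};", kind.identity());
    /// let step = format!("acc = {}(acc, input[i]);", kind.combine_op());
    /// ```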
    #[must_use]
    pub fn identity(&self) -> &'static str {
        match self {
            Self::Sum | Self::Mean => "0.0",
            Self::Max => "-3.402823e+38", // f32::MIN
            Self::Min => "3.402823e+38",  // f32::MAX
        }
    }

    /// Get WGSL combine operation
    #[must_use]
    pub fn combine_op(&self) -> &'static str {
        match self {
            Self::Sum | Self::Mean => "+",
            Self::Max => "max",
            Self::Min => "min",
        }
    }
}

/// ComputeBrick: Generates WebGPU shaders from brick definition
#[derive(Debug, Clone)]
pub struct ComputeBrick {
    /// Shader name
    name: String,
    /// Workgroup size
    workgroup_size: (u32, u32, u32),
    /// Input tensor bindings
    inputs: Vec<TensorBinding>,
    /// Output tensor bindings
    outputs: Vec<TensorBinding>,
    /// Tiling strategy
    tile_strategy: TileStrategy,
    /// Operations to perform
    operations: Vec<TileOp>,
    /// Shared memory allocations
    shared_memory: Vec<(String, TensorType, u32)>,
}

impl ComputeBrick {
    /// Create a new compute brick
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            workgroup_size: (64, 1, 1),
            inputs: Vec::new(),
            outputs: Vec::new(),
            tile_strategy: TileStrategy::None,
            operations: Vec::new(),
            shared_memory: Vec::new(),
        }
    }

    /// Set workgroup size
    #[must_use]
    pub fn workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroup_size = (x, y, z);
        self
    }

    /// Add an input tensor
    #[must_use]
    pub fn input(mut self, name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
        let binding_idx = self.inputs.len() as u32;
        self.inputs
            .push(TensorBinding::new(name, dtype, shape).at(0, binding_idx));
        self
    }

    /// Add an output tensor
    #[must_use]
    pub fn output(mut self, name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
        let binding_idx = self.outputs.len() as u32;
        self.outputs.push(
            TensorBinding::new(name, dtype, shape)
                .at(1, binding_idx)
                .writable(),
        );
        self
    }

    /// Set the tiling strategy
    #[must_use]
    pub fn tile_strategy(mut self, strategy: TileStrategy) -> Self {
        self.tile_strategy = strategy;
        self
    }

    /// Add an operation
    #[must_use]
    pub fn op(mut self, operation: TileOp) -> Self {
        self.operations.push(operation);
        self
    }

    /// Allocate shared memory
    #[must_use]
    pub fn shared(mut self, name: impl Into<String>, dtype: TensorType, size: u32) -> Self {
        self.shared_memory.push((name.into(), dtype, size));
        self
    }

    /// Generate WGSL shader code
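    ///
    /// # Example
    ///
    /// An abridged sketch of the output for a single-input, single-output brick
    /// (mirroring the `log-transform` brick in the tests below); the body elides
    /// the per-operation lines:
    ///
    /// ```text
    /// // LogTransform Compute Shader
    /// // Generated by probar ComputeBrick - DO NOT EDIT MANUALLY
    ///
    /// @group(0) @binding(0) var<storage, read> input: array<f32>;
    /// @group(1) @binding(0) var<storage, read_write> output: array<f32>;
    ///
    /// @compute @workgroup_size(64, 1, 1)
    /// fn main(
    ///     @builtin(global_invocation_id) global_id: vec3<u32>,
    ///     ...
    /// ) {
    ///     let gid = global_id.x + global_id.y * 64u;
    ///     ...
    /// }
    /// ```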
    #[must_use]
    pub fn to_wgsl(&self) -> String {
        let mut wgsl = String::new();

        // Header comment
        wgsl.push_str(&format!(
            "// {} Compute Shader\n",
            to_pascal_case(&self.name)
        ));
        wgsl.push_str("// Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");

        // Input bindings
        for input in &self.inputs {
            wgsl.push_str(&input.to_wgsl_binding());
            wgsl.push('\n');
        }

        // Output bindings
        for output in &self.outputs {
            wgsl.push_str(&output.to_wgsl_binding());
            wgsl.push('\n');
        }

        wgsl.push('\n');

        // Shared memory declarations
        for (name, dtype, size) in &self.shared_memory {
            wgsl.push_str(&format!(
                "var<workgroup> {}: array<{}, {}>;\n",
                name,
                dtype.to_wgsl(),
                size
            ));
        }

        if !self.shared_memory.is_empty() {
            wgsl.push('\n');
        }

        // Main compute function
        let (wg_x, wg_y, wg_z) = self.workgroup_size;
        wgsl.push_str(&format!(
            "@compute @workgroup_size({}, {}, {})\n",
            wg_x, wg_y, wg_z
        ));
        wgsl.push_str("fn main(\n");
        wgsl.push_str("    @builtin(global_invocation_id) global_id: vec3<u32>,\n");
        wgsl.push_str("    @builtin(local_invocation_id) local_id: vec3<u32>,\n");
        wgsl.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
        wgsl.push_str(") {\n");

        // Index calculations
        wgsl.push_str("    let gid = global_id.x + global_id.y * ");
        wgsl.push_str(&format!("{}u;\n", wg_x));
        wgsl.push_str("    let lid = local_id.x + local_id.y * ");
        wgsl.push_str(&format!("{}u;\n\n", wg_x));

        // Generate operations
        for op in &self.operations {
            match op {
                TileOp::LoadShared { src, tile_size: _ } => {
                    wgsl.push_str(&format!("    // Load from {} to shared memory\n", src));
                    wgsl.push_str(&format!("    let val_{} = {}[gid];\n", src, src));
                }
                TileOp::Elementwise {
                    op: elem_op,
                    operands,
                    output,
                } => {
                    let input = &operands[0];
                    let out_name = output.as_ref().unwrap_or(input);
                    let input_val = format!("val_{}", input);
                    let expr = elem_op.to_wgsl_expr(&input_val);
                    wgsl.push_str(&format!("    let val_{} = {};\n", out_name, expr));
                }
                TileOp::StoreShared { dst } => {
                    wgsl.push_str(&format!("    // Store to {}\n", dst));
                    // Find what value to store
                    let val_name = if self.operations.iter().any(
                        |o| matches!(o, TileOp::Elementwise { output: Some(n), .. } if n == dst),
                    ) {
                        format!("val_{}", dst)
                    } else if let Some(input) = self.inputs.first() {
                        format!("val_{}", input.name)
                    } else {
                        "0.0".to_string()
                    };
                    wgsl.push_str(&format!("    {}[gid] = {};\n", dst, val_name));
                }
                TileOp::Barrier => {
                    wgsl.push_str("    workgroupBarrier();\n");
                }
                TileOp::Mma { a, b, c } => {
                    wgsl.push_str(&format!("    // Matrix multiply: {} = {} @ {}\n", c, a, b));
                    wgsl.push_str("    // TODO: Implement cooperative matrix\n");
                }
                TileOp::Reduce {
                    kind,
                    input,
                    output,
                } => {
                    wgsl.push_str(&format!(
                        "    // Reduce {} -> {} ({:?})\n",
                        input, output, kind
                    ));
                }
            }
        }

        wgsl.push_str("}\n");

        wgsl
    }

    /// Generate Rust wgpu bindings
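    ///
    /// # Example
    ///
    /// A minimal sketch of how the generated source is shaped (the struct name is
    /// the PascalCase brick name plus `Compute`):
    ///
    /// ```rust,ignore
    /// let brick = ComputeBrick::new("mel-filterbank")
    ///     .input("audio", TensorType::F32, &[1024])
    ///     .output("mel", TensorType::F32, &[80]);
    /// let bindings = brick.to_rust_bindings();
    /// assert!(bindings.contains("pub struct MelFilterbankCompute"));
    /// assert!(bindings.contains("SHADER_SOURCE"));
    /// assert!(bindings.contains("create_bind_group_layout"));
    /// ```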
    #[must_use]
    pub fn to_rust_bindings(&self) -> String {
        let mut rust = String::new();

        // Header
        rust.push_str(&format!(
            "//! {} Compute Bindings\n",
            to_pascal_case(&self.name)
        ));
        rust.push_str("//! Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");
        rust.push_str(
            "use wgpu::{BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry};\n",
        );
        rust.push_str("use wgpu::{ShaderStages, BufferBindingType, BindingType};\n\n");

        let struct_name = to_pascal_case(&self.name);

        // Struct definition
        rust.push_str(&format!("pub struct {}Compute {{\n", struct_name));
        rust.push_str("    pub pipeline: wgpu::ComputePipeline,\n");
        rust.push_str("    pub bind_group_layout: wgpu::BindGroupLayout,\n");
        rust.push_str("}\n\n");

        // Implementation
        rust.push_str(&format!("impl {}Compute {{\n", struct_name));
        rust.push_str("    pub const WORKGROUP_SIZE: (u32, u32, u32) = ");
        rust.push_str(&format!("{:?};\n\n", self.workgroup_size));

        // WGSL source as const
        rust.push_str("    pub const SHADER_SOURCE: &'static str = r#\"\n");
        rust.push_str(&self.to_wgsl());
        rust.push_str("\"#;\n\n");

        // Create bind group layout
        rust.push_str(
            "    pub fn create_bind_group_layout(device: &wgpu::Device) -> BindGroupLayout {\n",
        );
        rust.push_str("        device.create_bind_group_layout(&BindGroupLayoutDescriptor {\n");
        rust.push_str(&format!(
            "            label: Some(\"{} bind group layout\"),\n",
            self.name
        ));
        rust.push_str("            entries: &[\n");

        for input in &self.inputs {
            rust.push_str(&format!("                // Input: {}\n", input.name));
            rust.push_str(&format!(
                "                BindGroupLayoutEntry {{\n                    binding: {},\n                    visibility: ShaderStages::COMPUTE,\n                    ty: BindingType::Buffer {{\n                        ty: BufferBindingType::Storage {{ read_only: true }},\n                        has_dynamic_offset: false,\n                        min_binding_size: None,\n                    }},\n                    count: None,\n                }},\n",
                input.binding
            ));
        }

        for output in &self.outputs {
            rust.push_str(&format!("                // Output: {}\n", output.name));
            rust.push_str(&format!(
                "                BindGroupLayoutEntry {{\n                    binding: {},\n                    visibility: ShaderStages::COMPUTE,\n                    ty: BindingType::Buffer {{\n                        ty: BufferBindingType::Storage {{ read_only: false }},\n                        has_dynamic_offset: false,\n                        min_binding_size: None,\n                    }},\n                    count: None,\n                }},\n",
                output.binding
            ));
        }

        rust.push_str("            ],\n");
        rust.push_str("        })\n");
        rust.push_str("    }\n");
        rust.push_str("}\n");

        rust
    }

    /// Generate JavaScript dispatch code for WebGPU
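    ///
    /// # Example
    ///
    /// A minimal sketch of what the generated JavaScript contains (the function
    /// name is `dispatch` plus the PascalCase brick name):
    ///
    /// ```rust,ignore
    /// let brick = ComputeBrick::new("fft")
    ///     .input("signal", TensorType::F32, &[512])
    ///     .output("spectrum", TensorType::F32, &[512]);
    /// let js = brick.to_dispatch_js();
    /// assert!(js.contains("async function dispatchFft"));
    /// assert!(js.contains("dispatchWorkgroups"));
    /// ```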
    #[must_use]
    pub fn to_dispatch_js(&self) -> String {
        let mut js = String::new();

        js.push_str(&format!(
            "// {} Compute Dispatch\n",
            to_pascal_case(&self.name)
        ));
        js.push_str("// Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");

        let (wg_x, wg_y, wg_z) = self.workgroup_size;
        js.push_str(&format!(
            "const WORKGROUP_SIZE = [{}, {}, {}];\n\n",
            wg_x, wg_y, wg_z
        ));

        js.push_str(&format!(
            "async function dispatch{}(device, inputs, outputs) {{\n",
            to_pascal_case(&self.name)
        ));

        js.push_str("    // Create shader module\n");
        js.push_str("    const shaderModule = device.createShaderModule({\n");
        js.push_str(&format!("        label: '{} shader',\n", self.name));
        js.push_str("        code: SHADER_SOURCE,\n");
        js.push_str("    });\n\n");

        js.push_str("    // Calculate dispatch size\n");
        if let Some(output) = self.outputs.first() {
            let total_size = output.element_count();
            js.push_str(&format!("    const totalElements = {};\n", total_size));
            js.push_str(&format!(
                "    const numWorkgroups = Math.ceil(totalElements / {});\n\n",
                wg_x * wg_y * wg_z
            ));
        }

        js.push_str("    // Dispatch\n");
        js.push_str("    const commandEncoder = device.createCommandEncoder();\n");
        js.push_str("    const passEncoder = commandEncoder.beginComputePass();\n");
        js.push_str("    passEncoder.setPipeline(pipeline);\n");
        js.push_str("    passEncoder.setBindGroup(0, bindGroup);\n");
        js.push_str("    passEncoder.dispatchWorkgroups(numWorkgroups, 1, 1);\n");
        js.push_str("    passEncoder.end();\n");
        js.push_str("    device.queue.submit([commandEncoder.finish()]);\n");
        js.push_str("}\n");

        js
    }

    /// Get the brick name
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Get workgroup size
    #[must_use]
    pub fn get_workgroup_size(&self) -> (u32, u32, u32) {
        self.workgroup_size
    }

    /// Get input bindings
    #[must_use]
    pub fn inputs(&self) -> &[TensorBinding] {
        &self.inputs
    }

    /// Get output bindings
    #[must_use]
    pub fn outputs(&self) -> &[TensorBinding] {
        &self.outputs
    }
}

impl Brick for ComputeBrick {
    fn brick_name(&self) -> &'static str {
        "ComputeBrick"
    }

    fn assertions(&self) -> &[BrickAssertion] {
        &[]
    }

    fn budget(&self) -> BrickBudget {
        // Compute shaders have longer budgets
        BrickBudget::uniform(100)
    }

    fn verify(&self) -> BrickVerification {
        let mut passed = Vec::new();
        let mut failed = Vec::new();

        // Verify workgroup size is valid
        let (x, y, z) = self.workgroup_size;
        if x * y * z > 1024 {
            failed.push((
                BrickAssertion::Custom {
                    name: "workgroup_size_valid".into(),
                    validator_id: 1,
                },
                format!(
                    "Workgroup size {}x{}x{}={} exceeds maximum 1024",
                    x,
                    y,
                    z,
                    x * y * z
                ),
            ));
        } else {
            passed.push(BrickAssertion::Custom {
                name: "workgroup_size_valid".into(),
                validator_id: 1,
            });
        }

        // Verify inputs and outputs are defined
        if self.inputs.is_empty() {
            failed.push((
                BrickAssertion::Custom {
                    name: "has_inputs".into(),
                    validator_id: 2,
                },
                "ComputeBrick has no input tensors".into(),
            ));
        } else {
            passed.push(BrickAssertion::Custom {
                name: "has_inputs".into(),
                validator_id: 2,
            });
        }

        if self.outputs.is_empty() {
            failed.push((
                BrickAssertion::Custom {
                    name: "has_outputs".into(),
                    validator_id: 3,
                },
                "ComputeBrick has no output tensors".into(),
            ));
        } else {
            passed.push(BrickAssertion::Custom {
                name: "has_outputs".into(),
                validator_id: 3,
            });
        }

        // Verify operations reference valid tensors
        let tensor_names: Vec<_> = self
            .inputs
            .iter()
            .chain(self.outputs.iter())
            .map(|t| t.name.as_str())
            .collect();

        for op in &self.operations {
            match op {
                TileOp::LoadShared { src, .. } => {
                    if !tensor_names.contains(&src.as_str()) {
                        failed.push((
                            BrickAssertion::Custom {
                                name: "tensor_exists".into(),
                                validator_id: 4,
                            },
                            format!("LoadShared references unknown tensor: {}", src),
                        ));
                    }
                }
                TileOp::StoreShared { dst } => {
                    if !tensor_names.contains(&dst.as_str()) {
                        failed.push((
                            BrickAssertion::Custom {
                                name: "tensor_exists".into(),
                                validator_id: 4,
                            },
                            format!("StoreShared references unknown tensor: {}", dst),
                        ));
                    }
                }
                _ => {}
            }
        }

        if failed.is_empty() {
            passed.push(BrickAssertion::Custom {
                name: "compute_brick_valid".into(),
                validator_id: 5,
            });
        }

        BrickVerification {
            passed,
            failed,
            verification_time: Duration::from_micros(100),
        }
    }

    fn to_html(&self) -> String {
        // ComputeBrick doesn't generate HTML
        String::new()
    }

    fn to_css(&self) -> String {
        // ComputeBrick doesn't generate CSS
        String::new()
    }
}

/// Convert string to PascalCase
fn to_pascal_case(s: &str) -> String {
    let mut result = String::new();
    let mut capitalize_next = true;

    for c in s.chars() {
        if c == '_' || c == '-' || c == ' ' {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(c.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_brick_basic() {
        let brick = ComputeBrick::new("test")
            .workgroup_size(256, 1, 1)
            .input("audio", TensorType::F32, &[1024])
            .output("mel", TensorType::F32, &[80, 100]);

        assert_eq!(brick.name(), "test");
        assert_eq!(brick.get_workgroup_size(), (256, 1, 1));
        assert_eq!(brick.inputs().len(), 1);
        assert_eq!(brick.outputs().len(), 1);
    }

    #[test]
    fn test_compute_brick_wgsl_generation() {
        let brick = ComputeBrick::new("log-transform")
            .workgroup_size(64, 1, 1)
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024])
            .op(TileOp::LoadShared {
                src: "input".into(),
                tile_size: (64, 1),
            })
            .op(TileOp::Elementwise {
                op: ElementwiseOp::Log,
                operands: vec!["input".into()],
                output: Some("output".into()),
            })
            .op(TileOp::StoreShared {
                dst: "output".into(),
            });

        let wgsl = brick.to_wgsl();

        assert!(wgsl.contains("@compute @workgroup_size(64, 1, 1)"));
        assert!(wgsl.contains("fn main("));
        assert!(wgsl.contains("log("));
        assert!(wgsl.contains("Generated by probar"));
    }

    #[test]
    fn test_compute_brick_verification() {
        let brick = ComputeBrick::new("test")
            .workgroup_size(256, 1, 1)
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024]);

        let result = brick.verify();
        assert!(result.is_valid());
    }

    #[test]
    fn test_compute_brick_verification_fails_no_inputs() {
        let brick = ComputeBrick::new("test").workgroup_size(256, 1, 1).output(
            "output",
            TensorType::F32,
            &[1024],
        );

        let result = brick.verify();
        assert!(!result.is_valid());
    }

    #[test]
    fn test_compute_brick_verification_fails_large_workgroup() {
        let brick = ComputeBrick::new("test")
            .workgroup_size(1024, 2, 1) // 2048 > 1024 max
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024]);

        let result = brick.verify();
        assert!(!result.is_valid());
    }

    #[test]
    fn test_tensor_binding() {
        let binding = TensorBinding::new("audio", TensorType::F32, &[1024, 80])
            .at(0, 1)
            .writable();

        assert_eq!(binding.name, "audio");
        assert_eq!(binding.element_count(), 1024 * 80);
        assert_eq!(binding.byte_size(), 1024 * 80 * 4);
        assert!(!binding.read_only);
    }

    #[test]
    fn test_tensor_type_wgsl() {
        assert_eq!(TensorType::F32.to_wgsl(), "f32");
        assert_eq!(TensorType::F16.to_wgsl(), "f16");
        assert_eq!(TensorType::I32.to_wgsl(), "i32");
        assert_eq!(TensorType::U32.to_wgsl(), "u32");
    }

    #[test]
    fn test_elementwise_ops() {
        assert_eq!(ElementwiseOp::Log.to_wgsl_expr("x"), "log(x)");
        assert_eq!(ElementwiseOp::Exp.to_wgsl_expr("x"), "exp(x)");
        assert_eq!(ElementwiseOp::Relu.to_wgsl_expr("x"), "max(x, 0.0)");
        assert_eq!(ElementwiseOp::AddScalar(5).to_wgsl_expr("x"), "(x + 5.0)");
    }

    #[test]
    fn test_rust_bindings_generation() {
        let brick = ComputeBrick::new("mel-transform")
            .workgroup_size(256, 1, 1)
            .input("audio", TensorType::F32, &[1024])
            .output("mel", TensorType::F32, &[80]);

        let rust = brick.to_rust_bindings();

        assert!(rust.contains("pub struct MelTransformCompute"));
        assert!(rust.contains("WORKGROUP_SIZE"));
        assert!(rust.contains("SHADER_SOURCE"));
        assert!(rust.contains("create_bind_group_layout"));
    }

    #[test]
    fn test_js_dispatch_generation() {
        let brick = ComputeBrick::new("fft")
            .workgroup_size(64, 1, 1)
            .input("signal", TensorType::F32, &[512])
            .output("spectrum", TensorType::F32, &[512]);

        let js = brick.to_dispatch_js();

        assert!(js.contains("async function dispatchFft"));
        assert!(js.contains("WORKGROUP_SIZE"));
        assert!(js.contains("dispatchWorkgroups"));
    }

    #[test]
    fn test_tile_strategy_workgroup_size() {
        let simple = TileStrategy::Simple2D {
            tile_x: 16,
            tile_y: 16,
        };
        assert_eq!(simple.optimal_workgroup_size(), (16, 16, 1));

        let coop = TileStrategy::Cooperative { m: 8, n: 8, k: 4 };
        assert_eq!(coop.optimal_workgroup_size(), (8, 8, 1));

        let streaming = TileStrategy::Streaming { window: 32 };
        assert_eq!(streaming.optimal_workgroup_size(), (32, 1, 1));
    }
}