jugar_probar/brick/
compute.rs

//! ComputeBrick: WebGPU shader generation from brick definitions (PROBAR-SPEC-009-P8)
//!
//! Generates WGSL shaders and wgpu bindings from a single brick definition.
//! Zero hand-written shaders - all code is derived from Rust types.
//!
//! # Design Philosophy
//!
//! ComputeBrick applies the same zero-artifact principle to GPU compute:
//! - Define tensor shapes and operations in Rust
//! - Generate WGSL shader code
//! - Generate Rust wgpu bindings
//! - Generate JavaScript WebGPU dispatch code
//!
//! # Inspiration
//!
//! NVIDIA CUDA Tile IR provides the model for declarative GPU programming.
//! ComputeBrick adapts these patterns for WebGPU.
//!
//! # Example
//!
//! ```rust,ignore
//! use probar::brick::compute::{ComputeBrick, ElementwiseOp, TensorType, TileOp, TileStrategy};
//!
//! let mel_brick = ComputeBrick::new("mel-filterbank")
//!     .workgroup_size(256, 1, 1)
//!     .input("audio", TensorType::F32, &[CHUNK_SIZE])
//!     .output("mel", TensorType::F32, &[N_MELS, N_FRAMES])
//!     .tile_strategy(TileStrategy::Simple2D { tile_x: 16, tile_y: 16 })
//!     .op(TileOp::LoadShared { src: "audio".into(), tile_size: (256, 1) })
//!     .op(TileOp::Elementwise {
//!         op: ElementwiseOp::Log,
//!         operands: vec!["audio".into()],
//!         output: Some("mel".into()),
//!     })
//!     .op(TileOp::StoreShared { dst: "mel".into() });
//!
//! // Generate WGSL
//! let wgsl = mel_brick.to_wgsl();
//! ```
35
36use super::{Brick, BrickAssertion, BrickBudget, BrickVerification};
37use std::time::Duration;
38
/// Element type stored in a GPU tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorType {
    /// Single-precision float (`f32`)
    F32,
    /// Half-precision float (`f16`)
    F16,
    /// Signed 32-bit integer (`i32`)
    I32,
    /// Unsigned 32-bit integer (`u32`)
    U32,
}

impl TensorType {
    /// Spelling of this element type inside WGSL source.
    #[must_use]
    pub fn to_wgsl(&self) -> &'static str {
        match *self {
            TensorType::F16 => "f16",
            TensorType::F32 => "f32",
            TensorType::I32 => "i32",
            TensorType::U32 => "u32",
        }
    }

    /// Spelling of this element type in generated Rust code.
    ///
    /// Identical to the WGSL name except for `F16`, which maps to the
    /// `half` crate's `f16` type.
    #[must_use]
    pub fn to_rust(&self) -> &'static str {
        match self {
            Self::F16 => "half::f16",
            Self::F32 | Self::I32 | Self::U32 => self.to_wgsl(),
        }
    }

    /// Size of one element, in bytes.
    #[must_use]
    pub const fn byte_size(&self) -> usize {
        match self {
            Self::F16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
        }
    }
}
84
85/// A tensor binding for compute shader
86#[derive(Debug, Clone)]
87pub struct TensorBinding {
88    /// Binding name
89    pub name: String,
90    /// Element type
91    pub dtype: TensorType,
92    /// Shape dimensions
93    pub shape: Vec<u32>,
94    /// Binding group
95    pub group: u32,
96    /// Binding index within group
97    pub binding: u32,
98    /// Read-only flag
99    pub read_only: bool,
100}
101
102impl TensorBinding {
103    /// Create a new tensor binding
104    #[must_use]
105    pub fn new(name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
106        Self {
107            name: name.into(),
108            dtype,
109            shape: shape.to_vec(),
110            group: 0,
111            binding: 0,
112            read_only: true,
113        }
114    }
115
116    /// Set binding group and index
117    #[must_use]
118    pub fn at(mut self, group: u32, binding: u32) -> Self {
119        self.group = group;
120        self.binding = binding;
121        self
122    }
123
124    /// Mark as writable
125    #[must_use]
126    pub fn writable(mut self) -> Self {
127        self.read_only = false;
128        self
129    }
130
131    /// Get total element count
132    #[must_use]
133    pub fn element_count(&self) -> u32 {
134        self.shape.iter().product()
135    }
136
137    /// Get total byte size
138    #[must_use]
139    pub fn byte_size(&self) -> usize {
140        self.element_count() as usize * self.dtype.byte_size()
141    }
142
143    /// Generate WGSL binding declaration
144    #[must_use]
145    pub fn to_wgsl_binding(&self) -> String {
146        let access = if self.read_only { "read" } else { "read_write" };
147        format!(
148            "@group({}) @binding({}) var<storage, {}> {}: array<{}>;",
149            self.group,
150            self.binding,
151            access,
152            self.name,
153            self.dtype.to_wgsl()
154        )
155    }
156}
157
/// How a compute workload is partitioned into tiles on the GPU.
#[derive(Debug, Clone)]
pub enum TileStrategy {
    /// Plain 2D tiles
    Simple2D {
        /// Tile width
        tile_x: u32,
        /// Tile height
        tile_y: u32,
    },
    /// Cooperative matrix tiles (tensor-core style)
    Cooperative {
        /// Matrix M dimension
        m: u32,
        /// Matrix N dimension
        n: u32,
        /// Matrix K dimension
        k: u32,
    },
    /// Sliding-window streaming (for convolutions)
    Streaming {
        /// Window size
        window: u32,
    },
    /// No tiling (direct compute)
    None,
}

impl TileStrategy {
    /// Workgroup dimensions suggested by this tiling scheme.
    ///
    /// 2D tiles use `(tile_x, tile_y)`, cooperative tiles use the `m x n`
    /// output tile (`k` is ignored), streaming uses `window` invocations along
    /// x, and untiled compute falls back to 64 along x. Z is always 1.
    #[must_use]
    pub fn optimal_workgroup_size(&self) -> (u32, u32, u32) {
        let (x, y) = match *self {
            TileStrategy::Simple2D { tile_x, tile_y } => (tile_x, tile_y),
            TileStrategy::Cooperative { m, n, k: _ } => (m, n),
            TileStrategy::Streaming { window } => (window, 1),
            TileStrategy::None => (64, 1),
        };
        (x, y, 1)
    }
}
198
/// A per-element operation applied inside the compute shader.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementwiseOp {
    /// Natural logarithm
    Log,
    /// Exponential
    Exp,
    /// Square root
    Sqrt,
    /// Absolute value
    Abs,
    /// Rectified linear unit: `max(x, 0.0)`
    Relu,
    /// Logistic sigmoid: `1 / (1 + exp(-x))`
    Sigmoid,
    /// Hyperbolic tangent
    Tanh,
    /// Add an integer constant (emitted as a float literal, e.g. `5.0`)
    AddScalar(i32),
    /// Multiply by an integer constant (emitted as a float literal)
    MulScalar(i32),
    /// Clamp to the fixed range `[0.0, 1.0]`
    Clamp,
}

impl ElementwiseOp {
    /// Build the WGSL expression applying this operation to `operand`.
    ///
    /// `operand` is spliced in verbatim; float literals (`0.0`, `1.0`, `N.0`)
    /// are used throughout, so the result presumes a float-typed operand.
    #[must_use]
    pub fn to_wgsl_expr(&self, operand: &str) -> String {
        // Builtins that are a plain single-argument call: `func(operand)`.
        let unary = |func: &str| format!("{}({})", func, operand);

        match self {
            Self::Log => unary("log"),
            Self::Exp => unary("exp"),
            Self::Sqrt => unary("sqrt"),
            Self::Abs => unary("abs"),
            Self::Tanh => unary("tanh"),
            Self::Relu => format!("max({}, 0.0)", operand),
            Self::Clamp => format!("clamp({}, 0.0, 1.0)", operand),
            Self::Sigmoid => format!("1.0 / (1.0 + exp(-{}))", operand),
            Self::AddScalar(constant) => format!("({} + {}.0)", operand, constant),
            Self::MulScalar(constant) => format!("({} * {}.0)", operand, constant),
        }
    }
}
242
243/// Tile operation in compute shader
244#[derive(Debug, Clone)]
245pub enum TileOp {
246    /// Load tile from global to shared memory
247    LoadShared {
248        /// Source tensor name
249        src: String,
250        /// Tile dimensions
251        tile_size: (u32, u32),
252    },
253    /// Matrix multiply accumulate (tensor core pattern)
254    Mma {
255        /// Input A tensor
256        a: String,
257        /// Input B tensor
258        b: String,
259        /// Output C tensor
260        c: String,
261    },
262    /// Element-wise operation
263    Elementwise {
264        /// Operation type
265        op: ElementwiseOp,
266        /// Input operand names
267        operands: Vec<String>,
268        /// Output name (defaults to first operand if None)
269        output: Option<String>,
270    },
271    /// Store tile from shared to global memory
272    StoreShared {
273        /// Destination tensor name
274        dst: String,
275    },
276    /// Synchronization barrier
277    Barrier,
278    /// Reduction operation (sum, max, min)
279    Reduce {
280        /// Reduction type
281        kind: ReduceKind,
282        /// Input tensor
283        input: String,
284        /// Output scalar or reduced tensor
285        output: String,
286    },
287}
288
/// Kind of reduction performed by [`TileOp::Reduce`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceKind {
    /// Sum of all elements
    Sum,
    /// Largest element
    Max,
    /// Smallest element
    Min,
    /// Arithmetic mean of all elements
    Mean,
}

impl ReduceKind {
    /// WGSL float literal used to seed the reduction accumulator.
    ///
    /// Max/Min start from roughly the most negative/positive finite `f32`
    /// (the literal is `f32::MAX` truncated to 7 significant digits).
    #[must_use]
    pub fn identity(&self) -> &'static str {
        match self {
            Self::Max => "-3.402823e+38",
            Self::Min => "3.402823e+38",
            Self::Sum | Self::Mean => "0.0",
        }
    }

    /// WGSL token that combines two partial results.
    ///
    /// For Sum/Mean this is the infix operator `+`; for Max/Min it is the
    /// name of the WGSL builtin function (`max`/`min`).
    #[must_use]
    pub fn combine_op(&self) -> &'static str {
        match self {
            Self::Max => "max",
            Self::Min => "min",
            Self::Sum | Self::Mean => "+",
        }
    }
}
323
324/// ComputeBrick: Generates WebGPU shaders from brick definition
325#[derive(Debug, Clone)]
326pub struct ComputeBrick {
327    /// Shader name
328    name: String,
329    /// Workgroup size
330    workgroup_size: (u32, u32, u32),
331    /// Input tensor bindings
332    inputs: Vec<TensorBinding>,
333    /// Output tensor bindings
334    outputs: Vec<TensorBinding>,
335    /// Tiling strategy
336    tile_strategy: TileStrategy,
337    /// Operations to perform
338    operations: Vec<TileOp>,
339    /// Shared memory allocations
340    shared_memory: Vec<(String, TensorType, u32)>,
341}
342
343impl ComputeBrick {
344    /// Create a new compute brick
345    #[must_use]
346    pub fn new(name: impl Into<String>) -> Self {
347        Self {
348            name: name.into(),
349            workgroup_size: (64, 1, 1),
350            inputs: Vec::new(),
351            outputs: Vec::new(),
352            tile_strategy: TileStrategy::None,
353            operations: Vec::new(),
354            shared_memory: Vec::new(),
355        }
356    }
357
358    /// Set workgroup size
359    #[must_use]
360    pub fn workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
361        self.workgroup_size = (x, y, z);
362        self
363    }
364
365    /// Add an input tensor
366    #[must_use]
367    pub fn input(mut self, name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
368        let binding_idx = self.inputs.len() as u32;
369        self.inputs
370            .push(TensorBinding::new(name, dtype, shape).at(0, binding_idx));
371        self
372    }
373
374    /// Add an output tensor
375    #[must_use]
376    pub fn output(mut self, name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
377        let binding_idx = self.outputs.len() as u32;
378        self.outputs.push(
379            TensorBinding::new(name, dtype, shape)
380                .at(1, binding_idx)
381                .writable(),
382        );
383        self
384    }
385
386    /// Set the tiling strategy
387    #[must_use]
388    pub fn tile_strategy(mut self, strategy: TileStrategy) -> Self {
389        self.tile_strategy = strategy;
390        self
391    }
392
393    /// Add an operation
394    #[must_use]
395    pub fn op(mut self, operation: TileOp) -> Self {
396        self.operations.push(operation);
397        self
398    }
399
400    /// Allocate shared memory
401    #[must_use]
402    pub fn shared(mut self, name: impl Into<String>, dtype: TensorType, size: u32) -> Self {
403        self.shared_memory.push((name.into(), dtype, size));
404        self
405    }
406
407    /// Generate WGSL shader code
408    #[must_use]
409    pub fn to_wgsl(&self) -> String {
410        let mut wgsl = String::new();
411
412        // Header comment
413        wgsl.push_str(&format!(
414            "// {} Compute Shader\n",
415            to_pascal_case(&self.name)
416        ));
417        wgsl.push_str("// Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");
418
419        // Input bindings
420        for input in &self.inputs {
421            wgsl.push_str(&input.to_wgsl_binding());
422            wgsl.push('\n');
423        }
424
425        // Output bindings
426        for output in &self.outputs {
427            wgsl.push_str(&output.to_wgsl_binding());
428            wgsl.push('\n');
429        }
430
431        wgsl.push('\n');
432
433        // Shared memory declarations
434        for (name, dtype, size) in &self.shared_memory {
435            wgsl.push_str(&format!(
436                "var<workgroup> {}: array<{}, {}>;\n",
437                name,
438                dtype.to_wgsl(),
439                size
440            ));
441        }
442
443        if !self.shared_memory.is_empty() {
444            wgsl.push('\n');
445        }
446
447        // Main compute function
448        let (wg_x, wg_y, wg_z) = self.workgroup_size;
449        wgsl.push_str(&format!(
450            "@compute @workgroup_size({}, {}, {})\n",
451            wg_x, wg_y, wg_z
452        ));
453        wgsl.push_str("fn main(\n");
454        wgsl.push_str("    @builtin(global_invocation_id) global_id: vec3<u32>,\n");
455        wgsl.push_str("    @builtin(local_invocation_id) local_id: vec3<u32>,\n");
456        wgsl.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
457        wgsl.push_str(") {\n");
458
459        // Index calculations
460        wgsl.push_str("    let gid = global_id.x + global_id.y * ");
461        wgsl.push_str(&format!("{}u;\n", wg_x));
462        wgsl.push_str("    let lid = local_id.x + local_id.y * ");
463        wgsl.push_str(&format!("{}u;\n\n", wg_x));
464
465        // Generate operations
466        for op in &self.operations {
467            match op {
468                TileOp::LoadShared { src, tile_size: _ } => {
469                    wgsl.push_str(&format!("    // Load from {} to shared memory\n", src));
470                    wgsl.push_str(&format!("    let val_{} = {}[gid];\n", src, src));
471                }
472                TileOp::Elementwise {
473                    op: elem_op,
474                    operands,
475                    output,
476                } => {
477                    let input = &operands[0];
478                    let out_name = output.as_ref().unwrap_or(input);
479                    let input_val = format!("val_{}", input);
480                    let expr = elem_op.to_wgsl_expr(&input_val);
481                    wgsl.push_str(&format!("    let val_{} = {};\n", out_name, expr));
482                }
483                TileOp::StoreShared { dst } => {
484                    wgsl.push_str(&format!("    // Store to {}\n", dst));
485                    // Find what value to store
486                    let val_name = if self.operations.iter().any(
487                        |o| matches!(o, TileOp::Elementwise { output: Some(n), .. } if n == dst),
488                    ) {
489                        format!("val_{}", dst)
490                    } else if let Some(input) = self.inputs.first() {
491                        format!("val_{}", input.name)
492                    } else {
493                        "0.0".to_string()
494                    };
495                    wgsl.push_str(&format!("    {}[gid] = {};\n", dst, val_name));
496                }
497                TileOp::Barrier => {
498                    wgsl.push_str("    workgroupBarrier();\n");
499                }
500                TileOp::Mma { a, b, c } => {
501                    wgsl.push_str(&format!("    // Matrix multiply: {} = {} @ {}\n", c, a, b));
502                    wgsl.push_str("    // TODO: Implement cooperative matrix\n");
503                }
504                TileOp::Reduce {
505                    kind,
506                    input,
507                    output,
508                } => {
509                    wgsl.push_str(&format!(
510                        "    // Reduce {} -> {} ({:?})\n",
511                        input, output, kind
512                    ));
513                }
514            }
515        }
516
517        wgsl.push_str("}\n");
518
519        wgsl
520    }
521
522    /// Generate Rust wgpu bindings
523    #[must_use]
524    pub fn to_rust_bindings(&self) -> String {
525        let mut rust = String::new();
526
527        // Header
528        rust.push_str(&format!(
529            "//! {} Compute Bindings\n",
530            to_pascal_case(&self.name)
531        ));
532        rust.push_str("//! Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");
533        rust.push_str(
534            "use wgpu::{BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry};\n",
535        );
536        rust.push_str("use wgpu::{ShaderStages, BufferBindingType, BindingType};\n\n");
537
538        let struct_name = to_pascal_case(&self.name);
539
540        // Struct definition
541        rust.push_str(&format!("pub struct {}Compute {{\n", struct_name));
542        rust.push_str("    pub pipeline: wgpu::ComputePipeline,\n");
543        rust.push_str("    pub bind_group_layout: wgpu::BindGroupLayout,\n");
544        rust.push_str("}\n\n");
545
546        // Implementation
547        rust.push_str(&format!("impl {}Compute {{\n", struct_name));
548        rust.push_str("    pub const WORKGROUP_SIZE: (u32, u32, u32) = ");
549        rust.push_str(&format!("{:?};\n\n", self.workgroup_size));
550
551        // WGSL source as const
552        rust.push_str("    pub const SHADER_SOURCE: &'static str = r#\"\n");
553        rust.push_str(&self.to_wgsl());
554        rust.push_str("\"#;\n\n");
555
556        // Create bind group layout
557        rust.push_str(
558            "    pub fn create_bind_group_layout(device: &wgpu::Device) -> BindGroupLayout {\n",
559        );
560        rust.push_str("        device.create_bind_group_layout(&BindGroupLayoutDescriptor {\n");
561        rust.push_str(&format!(
562            "            label: Some(\"{} bind group layout\"),\n",
563            self.name
564        ));
565        rust.push_str("            entries: &[\n");
566
567        for input in &self.inputs {
568            rust.push_str(&format!("                // Input: {}\n", input.name));
569            rust.push_str(&format!(
570                "                BindGroupLayoutEntry {{\n                    binding: {},\n                    visibility: ShaderStages::COMPUTE,\n                    ty: BindingType::Buffer {{\n                        ty: BufferBindingType::Storage {{ read_only: true }},\n                        has_dynamic_offset: false,\n                        min_binding_size: None,\n                    }},\n                    count: None,\n                }},\n",
571                input.binding
572            ));
573        }
574
575        for output in &self.outputs {
576            rust.push_str(&format!("                // Output: {}\n", output.name));
577            rust.push_str(&format!(
578                "                BindGroupLayoutEntry {{\n                    binding: {},\n                    visibility: ShaderStages::COMPUTE,\n                    ty: BindingType::Buffer {{\n                        ty: BufferBindingType::Storage {{ read_only: false }},\n                        has_dynamic_offset: false,\n                        min_binding_size: None,\n                    }},\n                    count: None,\n                }},\n",
579                output.binding
580            ));
581        }
582
583        rust.push_str("            ],\n");
584        rust.push_str("        })\n");
585        rust.push_str("    }\n");
586        rust.push_str("}\n");
587
588        rust
589    }
590
591    /// Generate JavaScript dispatch code for WebGPU
592    #[must_use]
593    pub fn to_dispatch_js(&self) -> String {
594        let mut js = String::new();
595
596        js.push_str(&format!(
597            "// {} Compute Dispatch\n",
598            to_pascal_case(&self.name)
599        ));
600        js.push_str("// Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");
601
602        let (wg_x, wg_y, wg_z) = self.workgroup_size;
603        js.push_str(&format!(
604            "const WORKGROUP_SIZE = [{}, {}, {}];\n\n",
605            wg_x, wg_y, wg_z
606        ));
607
608        js.push_str(&format!(
609            "async function dispatch{}(device, inputs, outputs) {{\n",
610            to_pascal_case(&self.name)
611        ));
612
613        js.push_str("    // Create shader module\n");
614        js.push_str("    const shaderModule = device.createShaderModule({\n");
615        js.push_str(&format!("        label: '{} shader',\n", self.name));
616        js.push_str("        code: SHADER_SOURCE,\n");
617        js.push_str("    });\n\n");
618
619        js.push_str("    // Calculate dispatch size\n");
620        if let Some(output) = self.outputs.first() {
621            let total_size = output.element_count();
622            js.push_str(&format!("    const totalElements = {};\n", total_size));
623            js.push_str(&format!(
624                "    const numWorkgroups = Math.ceil(totalElements / {});\n\n",
625                wg_x * wg_y * wg_z
626            ));
627        }
628
629        js.push_str("    // Dispatch\n");
630        js.push_str("    const commandEncoder = device.createCommandEncoder();\n");
631        js.push_str("    const passEncoder = commandEncoder.beginComputePass();\n");
632        js.push_str("    passEncoder.setPipeline(pipeline);\n");
633        js.push_str("    passEncoder.setBindGroup(0, bindGroup);\n");
634        js.push_str("    passEncoder.dispatchWorkgroups(numWorkgroups, 1, 1);\n");
635        js.push_str("    passEncoder.end();\n");
636        js.push_str("    device.queue.submit([commandEncoder.finish()]);\n");
637        js.push_str("}\n");
638
639        js
640    }
641
642    /// Get the brick name
643    #[must_use]
644    pub fn name(&self) -> &str {
645        &self.name
646    }
647
648    /// Get workgroup size
649    #[must_use]
650    pub fn get_workgroup_size(&self) -> (u32, u32, u32) {
651        self.workgroup_size
652    }
653
654    /// Get input bindings
655    #[must_use]
656    pub fn inputs(&self) -> &[TensorBinding] {
657        &self.inputs
658    }
659
660    /// Get output bindings
661    #[must_use]
662    pub fn outputs(&self) -> &[TensorBinding] {
663        &self.outputs
664    }
665}
666
667impl Brick for ComputeBrick {
668    fn brick_name(&self) -> &'static str {
669        "ComputeBrick"
670    }
671
672    fn assertions(&self) -> &[BrickAssertion] {
673        &[]
674    }
675
676    fn budget(&self) -> BrickBudget {
677        // Compute shaders have longer budgets
678        BrickBudget::uniform(100)
679    }
680
681    fn verify(&self) -> BrickVerification {
682        let mut passed = Vec::new();
683        let mut failed = Vec::new();
684
685        // Verify workgroup size is valid
686        let (x, y, z) = self.workgroup_size;
687        if x * y * z > 1024 {
688            failed.push((
689                BrickAssertion::Custom {
690                    name: "workgroup_size_valid".into(),
691                    validator_id: 1,
692                },
693                format!(
694                    "Workgroup size {}x{}x{}={} exceeds maximum 1024",
695                    x,
696                    y,
697                    z,
698                    x * y * z
699                ),
700            ));
701        } else {
702            passed.push(BrickAssertion::Custom {
703                name: "workgroup_size_valid".into(),
704                validator_id: 1,
705            });
706        }
707
708        // Verify inputs and outputs are defined
709        if self.inputs.is_empty() {
710            failed.push((
711                BrickAssertion::Custom {
712                    name: "has_inputs".into(),
713                    validator_id: 2,
714                },
715                "ComputeBrick has no input tensors".into(),
716            ));
717        } else {
718            passed.push(BrickAssertion::Custom {
719                name: "has_inputs".into(),
720                validator_id: 2,
721            });
722        }
723
724        if self.outputs.is_empty() {
725            failed.push((
726                BrickAssertion::Custom {
727                    name: "has_outputs".into(),
728                    validator_id: 3,
729                },
730                "ComputeBrick has no output tensors".into(),
731            ));
732        } else {
733            passed.push(BrickAssertion::Custom {
734                name: "has_outputs".into(),
735                validator_id: 3,
736            });
737        }
738
739        // Verify operations reference valid tensors
740        let tensor_names: Vec<_> = self
741            .inputs
742            .iter()
743            .chain(self.outputs.iter())
744            .map(|t| t.name.as_str())
745            .collect();
746
747        for op in &self.operations {
748            match op {
749                TileOp::LoadShared { src, .. } => {
750                    if !tensor_names.contains(&src.as_str()) {
751                        failed.push((
752                            BrickAssertion::Custom {
753                                name: "tensor_exists".into(),
754                                validator_id: 4,
755                            },
756                            format!("LoadShared references unknown tensor: {}", src),
757                        ));
758                    }
759                }
760                TileOp::StoreShared { dst } => {
761                    if !tensor_names.contains(&dst.as_str()) {
762                        failed.push((
763                            BrickAssertion::Custom {
764                                name: "tensor_exists".into(),
765                                validator_id: 4,
766                            },
767                            format!("StoreShared references unknown tensor: {}", dst),
768                        ));
769                    }
770                }
771                _ => {}
772            }
773        }
774
775        if failed.is_empty() {
776            passed.push(BrickAssertion::Custom {
777                name: "compute_brick_valid".into(),
778                validator_id: 5,
779            });
780        }
781
782        BrickVerification {
783            passed,
784            failed,
785            verification_time: Duration::from_micros(100),
786        }
787    }
788
789    fn to_html(&self) -> String {
790        // ComputeBrick doesn't generate HTML
791        String::new()
792    }
793
794    fn to_css(&self) -> String {
795        // ComputeBrick doesn't generate CSS
796        String::new()
797    }
798}
799
/// Convert a brick name to PascalCase for use in generated identifiers.
///
/// Every non-alphanumeric character (`-`, `_`, space, `.`, ...) acts as a word
/// separator and is dropped. The first character of each word is ASCII
/// upper-cased and the rest is kept verbatim: `"mel-filterbank"` becomes
/// `"MelFilterbank"`. Dropping all separators, not just `-`, `_` and space,
/// keeps the result usable as a Rust struct name and a JavaScript function
/// name. A leading digit is kept as-is.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;

    for c in s.chars() {
        if !c.is_alphanumeric() {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(c.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }

    result
}
818
819#[cfg(test)]
820#[allow(clippy::unwrap_used, clippy::expect_used)]
821mod tests {
822    use super::*;
823
824    #[test]
825    fn test_compute_brick_basic() {
826        let brick = ComputeBrick::new("test")
827            .workgroup_size(256, 1, 1)
828            .input("audio", TensorType::F32, &[1024])
829            .output("mel", TensorType::F32, &[80, 100]);
830
831        assert_eq!(brick.name(), "test");
832        assert_eq!(brick.get_workgroup_size(), (256, 1, 1));
833        assert_eq!(brick.inputs().len(), 1);
834        assert_eq!(brick.outputs().len(), 1);
835    }
836
837    #[test]
838    fn test_compute_brick_wgsl_generation() {
839        let brick = ComputeBrick::new("log-transform")
840            .workgroup_size(64, 1, 1)
841            .input("input", TensorType::F32, &[1024])
842            .output("output", TensorType::F32, &[1024])
843            .op(TileOp::LoadShared {
844                src: "input".into(),
845                tile_size: (64, 1),
846            })
847            .op(TileOp::Elementwise {
848                op: ElementwiseOp::Log,
849                operands: vec!["input".into()],
850                output: Some("output".into()),
851            })
852            .op(TileOp::StoreShared {
853                dst: "output".into(),
854            });
855
856        let wgsl = brick.to_wgsl();
857
858        assert!(wgsl.contains("@compute @workgroup_size(64, 1, 1)"));
859        assert!(wgsl.contains("fn main("));
860        assert!(wgsl.contains("log("));
861        assert!(wgsl.contains("Generated by probar"));
862    }
863
864    #[test]
865    fn test_compute_brick_verification() {
866        let brick = ComputeBrick::new("test")
867            .workgroup_size(256, 1, 1)
868            .input("input", TensorType::F32, &[1024])
869            .output("output", TensorType::F32, &[1024]);
870
871        let result = brick.verify();
872        assert!(result.is_valid());
873    }
874
875    #[test]
876    fn test_compute_brick_verification_fails_no_inputs() {
877        let brick = ComputeBrick::new("test").workgroup_size(256, 1, 1).output(
878            "output",
879            TensorType::F32,
880            &[1024],
881        );
882
883        let result = brick.verify();
884        assert!(!result.is_valid());
885    }
886
887    #[test]
888    fn test_compute_brick_verification_fails_large_workgroup() {
889        let brick = ComputeBrick::new("test")
890            .workgroup_size(1024, 2, 1) // 2048 > 1024 max
891            .input("input", TensorType::F32, &[1024])
892            .output("output", TensorType::F32, &[1024]);
893
894        let result = brick.verify();
895        assert!(!result.is_valid());
896    }
897
898    #[test]
899    fn test_tensor_binding() {
900        let binding = TensorBinding::new("audio", TensorType::F32, &[1024, 80])
901            .at(0, 1)
902            .writable();
903
904        assert_eq!(binding.name, "audio");
905        assert_eq!(binding.element_count(), 1024 * 80);
906        assert_eq!(binding.byte_size(), 1024 * 80 * 4);
907        assert!(!binding.read_only);
908    }
909
910    #[test]
911    fn test_tensor_type_wgsl() {
912        assert_eq!(TensorType::F32.to_wgsl(), "f32");
913        assert_eq!(TensorType::F16.to_wgsl(), "f16");
914        assert_eq!(TensorType::I32.to_wgsl(), "i32");
915        assert_eq!(TensorType::U32.to_wgsl(), "u32");
916    }
917
918    #[test]
919    fn test_elementwise_ops() {
920        assert_eq!(ElementwiseOp::Log.to_wgsl_expr("x"), "log(x)");
921        assert_eq!(ElementwiseOp::Exp.to_wgsl_expr("x"), "exp(x)");
922        assert_eq!(ElementwiseOp::Relu.to_wgsl_expr("x"), "max(x, 0.0)");
923        assert_eq!(ElementwiseOp::AddScalar(5).to_wgsl_expr("x"), "(x + 5.0)");
924    }
925
926    #[test]
927    fn test_rust_bindings_generation() {
928        let brick = ComputeBrick::new("mel-transform")
929            .workgroup_size(256, 1, 1)
930            .input("audio", TensorType::F32, &[1024])
931            .output("mel", TensorType::F32, &[80]);
932
933        let rust = brick.to_rust_bindings();
934
935        assert!(rust.contains("pub struct MelTransformCompute"));
936        assert!(rust.contains("WORKGROUP_SIZE"));
937        assert!(rust.contains("SHADER_SOURCE"));
938        assert!(rust.contains("create_bind_group_layout"));
939    }
940
941    #[test]
942    fn test_js_dispatch_generation() {
943        let brick = ComputeBrick::new("fft")
944            .workgroup_size(64, 1, 1)
945            .input("signal", TensorType::F32, &[512])
946            .output("spectrum", TensorType::F32, &[512]);
947
948        let js = brick.to_dispatch_js();
949
950        assert!(js.contains("async function dispatchFft"));
951        assert!(js.contains("WORKGROUP_SIZE"));
952        assert!(js.contains("dispatchWorkgroups"));
953    }
954
955    #[test]
956    fn test_tile_strategy_workgroup_size() {
957        let simple = TileStrategy::Simple2D {
958            tile_x: 16,
959            tile_y: 16,
960        };
961        assert_eq!(simple.optimal_workgroup_size(), (16, 16, 1));
962
963        let coop = TileStrategy::Cooperative { m: 8, n: 8, k: 4 };
964        assert_eq!(coop.optimal_workgroup_size(), (8, 8, 1));
965
966        let streaming = TileStrategy::Streaming { window: 32 };
967        assert_eq!(streaming.optimal_workgroup_size(), (32, 1, 1));
968    }
969
970    // ========================================================================
971    // Additional comprehensive tests for 95%+ coverage
972    // ========================================================================
973
974    #[test]
975    fn test_tensor_type_rust() {
976        assert_eq!(TensorType::F32.to_rust(), "f32");
977        assert_eq!(TensorType::F16.to_rust(), "half::f16");
978        assert_eq!(TensorType::I32.to_rust(), "i32");
979        assert_eq!(TensorType::U32.to_rust(), "u32");
980    }
981
982    #[test]
983    fn test_tensor_type_byte_size() {
984        assert_eq!(TensorType::F32.byte_size(), 4);
985        assert_eq!(TensorType::F16.byte_size(), 2);
986        assert_eq!(TensorType::I32.byte_size(), 4);
987        assert_eq!(TensorType::U32.byte_size(), 4);
988    }
989
990    #[test]
991    fn test_tensor_type_clone() {
992        let t = TensorType::F32;
993        let cloned = t;
994        assert_eq!(t, cloned);
995    }
996
997    #[test]
998    fn test_tensor_binding_default_values() {
999        let binding = TensorBinding::new("test", TensorType::I32, &[10, 20]);
1000        assert_eq!(binding.group, 0);
1001        assert_eq!(binding.binding, 0);
1002        assert!(binding.read_only);
1003    }
1004
1005    #[test]
1006    fn test_tensor_binding_to_wgsl_binding_read_only() {
1007        let binding = TensorBinding::new("data", TensorType::F32, &[100]).at(1, 2);
1008        let wgsl = binding.to_wgsl_binding();
1009        assert!(wgsl.contains("@group(1) @binding(2)"));
1010        assert!(wgsl.contains("var<storage, read>"));
1011        assert!(wgsl.contains("data"));
1012        assert!(wgsl.contains("f32"));
1013    }
1014
1015    #[test]
1016    fn test_tensor_binding_to_wgsl_binding_read_write() {
1017        let binding = TensorBinding::new("output", TensorType::U32, &[50])
1018            .at(0, 0)
1019            .writable();
1020        let wgsl = binding.to_wgsl_binding();
1021        assert!(wgsl.contains("var<storage, read_write>"));
1022    }
1023
1024    #[test]
1025    fn test_tensor_binding_clone() {
1026        let binding = TensorBinding::new("test", TensorType::F32, &[1, 2, 3])
1027            .at(1, 2)
1028            .writable();
1029        let cloned = binding.clone();
1030        assert_eq!(binding.name, cloned.name);
1031        assert_eq!(binding.shape, cloned.shape);
1032        assert_eq!(binding.read_only, cloned.read_only);
1033    }
1034
1035    #[test]
1036    fn test_tile_strategy_none() {
1037        let strategy = TileStrategy::None;
1038        assert_eq!(strategy.optimal_workgroup_size(), (64, 1, 1));
1039    }
1040
1041    #[test]
1042    fn test_tile_strategy_clone() {
1043        let strategy = TileStrategy::Simple2D {
1044            tile_x: 8,
1045            tile_y: 8,
1046        };
1047        let cloned = strategy;
1048        assert!(matches!(
1049            cloned,
1050            TileStrategy::Simple2D {
1051                tile_x: 8,
1052                tile_y: 8
1053            }
1054        ));
1055    }
1056
1057    #[test]
1058    fn test_elementwise_op_sqrt() {
1059        assert_eq!(ElementwiseOp::Sqrt.to_wgsl_expr("val"), "sqrt(val)");
1060    }
1061
1062    #[test]
1063    fn test_elementwise_op_abs() {
1064        assert_eq!(ElementwiseOp::Abs.to_wgsl_expr("v"), "abs(v)");
1065    }
1066
1067    #[test]
1068    fn test_elementwise_op_sigmoid() {
1069        assert_eq!(
1070            ElementwiseOp::Sigmoid.to_wgsl_expr("x"),
1071            "1.0 / (1.0 + exp(-x))"
1072        );
1073    }
1074
1075    #[test]
1076    fn test_elementwise_op_tanh() {
1077        assert_eq!(ElementwiseOp::Tanh.to_wgsl_expr("x"), "tanh(x)");
1078    }
1079
1080    #[test]
1081    fn test_elementwise_op_mul_scalar() {
1082        assert_eq!(ElementwiseOp::MulScalar(3).to_wgsl_expr("y"), "(y * 3.0)");
1083        assert_eq!(ElementwiseOp::MulScalar(-2).to_wgsl_expr("x"), "(x * -2.0)");
1084    }
1085
1086    #[test]
1087    fn test_elementwise_op_clamp() {
1088        assert_eq!(ElementwiseOp::Clamp.to_wgsl_expr("x"), "clamp(x, 0.0, 1.0)");
1089    }
1090
1091    #[test]
1092    fn test_elementwise_op_eq() {
1093        assert_eq!(ElementwiseOp::Log, ElementwiseOp::Log);
1094        assert_ne!(ElementwiseOp::Log, ElementwiseOp::Exp);
1095        assert_eq!(ElementwiseOp::AddScalar(5), ElementwiseOp::AddScalar(5));
1096        assert_ne!(ElementwiseOp::AddScalar(5), ElementwiseOp::AddScalar(6));
1097    }
1098
1099    #[test]
1100    fn test_reduce_kind_identity() {
1101        assert_eq!(ReduceKind::Sum.identity(), "0.0");
1102        assert_eq!(ReduceKind::Mean.identity(), "0.0");
1103        assert_eq!(ReduceKind::Max.identity(), "-3.402823e+38");
1104        assert_eq!(ReduceKind::Min.identity(), "3.402823e+38");
1105    }
1106
1107    #[test]
1108    fn test_reduce_kind_combine_op() {
1109        assert_eq!(ReduceKind::Sum.combine_op(), "+");
1110        assert_eq!(ReduceKind::Mean.combine_op(), "+");
1111        assert_eq!(ReduceKind::Max.combine_op(), "max");
1112        assert_eq!(ReduceKind::Min.combine_op(), "min");
1113    }
1114
1115    #[test]
1116    fn test_reduce_kind_eq() {
1117        assert_eq!(ReduceKind::Sum, ReduceKind::Sum);
1118        assert_ne!(ReduceKind::Sum, ReduceKind::Max);
1119    }
1120
1121    #[test]
1122    fn test_tile_op_load_shared() {
1123        let op = TileOp::LoadShared {
1124            src: "audio".into(),
1125            tile_size: (32, 32),
1126        };
1127        match op {
1128            TileOp::LoadShared { src, tile_size } => {
1129                assert_eq!(src, "audio");
1130                assert_eq!(tile_size, (32, 32));
1131            }
1132            _ => panic!("Expected LoadShared"),
1133        }
1134    }
1135
1136    #[test]
1137    fn test_tile_op_mma() {
1138        let op = TileOp::Mma {
1139            a: "A".into(),
1140            b: "B".into(),
1141            c: "C".into(),
1142        };
1143        match op {
1144            TileOp::Mma { a, b, c } => {
1145                assert_eq!(a, "A");
1146                assert_eq!(b, "B");
1147                assert_eq!(c, "C");
1148            }
1149            _ => panic!("Expected Mma"),
1150        }
1151    }
1152
1153    #[test]
1154    fn test_tile_op_reduce() {
1155        let op = TileOp::Reduce {
1156            kind: ReduceKind::Max,
1157            input: "values".into(),
1158            output: "max_val".into(),
1159        };
1160        match op {
1161            TileOp::Reduce {
1162                kind,
1163                input,
1164                output,
1165            } => {
1166                assert_eq!(kind, ReduceKind::Max);
1167                assert_eq!(input, "values");
1168                assert_eq!(output, "max_val");
1169            }
1170            _ => panic!("Expected Reduce"),
1171        }
1172    }
1173
1174    #[test]
1175    fn test_tile_op_barrier() {
1176        let op = TileOp::Barrier;
1177        assert!(matches!(op, TileOp::Barrier));
1178    }
1179
1180    #[test]
1181    fn test_tile_op_clone() {
1182        let op = TileOp::Elementwise {
1183            op: ElementwiseOp::Relu,
1184            operands: vec!["x".into(), "y".into()],
1185            output: Some("z".into()),
1186        };
1187        let cloned = op;
1188        assert!(matches!(cloned, TileOp::Elementwise { .. }));
1189    }
1190
1191    #[test]
1192    fn test_compute_brick_tile_strategy() {
1193        let brick = ComputeBrick::new("test").tile_strategy(TileStrategy::Cooperative {
1194            m: 16,
1195            n: 16,
1196            k: 8,
1197        });
1198
1199        // The tile_strategy is stored internally
1200        assert_eq!(brick.name(), "test");
1201    }
1202
1203    #[test]
1204    fn test_compute_brick_shared_memory() {
1205        let brick = ComputeBrick::new("test")
1206            .shared("tile_a", TensorType::F32, 256)
1207            .shared("tile_b", TensorType::F32, 128);
1208
1209        let wgsl = brick.to_wgsl();
1210        assert!(wgsl.contains("var<workgroup> tile_a"));
1211        assert!(wgsl.contains("var<workgroup> tile_b"));
1212    }
1213
1214    #[test]
1215    fn test_compute_brick_verification_no_outputs() {
1216        let brick = ComputeBrick::new("test").input("input", TensorType::F32, &[1024]);
1217
1218        let result = brick.verify();
1219        assert!(!result.is_valid());
1220    }
1221
1222    #[test]
1223    fn test_compute_brick_verification_invalid_load_tensor() {
1224        let brick = ComputeBrick::new("test")
1225            .input("input", TensorType::F32, &[1024])
1226            .output("output", TensorType::F32, &[1024])
1227            .op(TileOp::LoadShared {
1228                src: "nonexistent".into(),
1229                tile_size: (64, 1),
1230            });
1231
1232        let result = brick.verify();
1233        assert!(!result.is_valid());
1234    }
1235
1236    #[test]
1237    fn test_compute_brick_verification_invalid_store_tensor() {
1238        let brick = ComputeBrick::new("test")
1239            .input("input", TensorType::F32, &[1024])
1240            .output("output", TensorType::F32, &[1024])
1241            .op(TileOp::StoreShared {
1242                dst: "nonexistent".into(),
1243            });
1244
1245        let result = brick.verify();
1246        assert!(!result.is_valid());
1247    }
1248
1249    #[test]
1250    fn test_compute_brick_wgsl_barrier() {
1251        let brick = ComputeBrick::new("test")
1252            .input("input", TensorType::F32, &[64])
1253            .output("output", TensorType::F32, &[64])
1254            .op(TileOp::Barrier);
1255
1256        let wgsl = brick.to_wgsl();
1257        assert!(wgsl.contains("workgroupBarrier()"));
1258    }
1259
1260    #[test]
1261    fn test_compute_brick_wgsl_mma() {
1262        let brick = ComputeBrick::new("matmul")
1263            .input("A", TensorType::F32, &[64, 64])
1264            .input("B", TensorType::F32, &[64, 64])
1265            .output("C", TensorType::F32, &[64, 64])
1266            .op(TileOp::Mma {
1267                a: "A".into(),
1268                b: "B".into(),
1269                c: "C".into(),
1270            });
1271
1272        let wgsl = brick.to_wgsl();
1273        assert!(wgsl.contains("Matrix multiply"));
1274    }
1275
1276    #[test]
1277    fn test_compute_brick_wgsl_reduce() {
1278        let brick = ComputeBrick::new("reduce")
1279            .input("values", TensorType::F32, &[1024])
1280            .output("result", TensorType::F32, &[1])
1281            .op(TileOp::Reduce {
1282                kind: ReduceKind::Sum,
1283                input: "values".into(),
1284                output: "result".into(),
1285            });
1286
1287        let wgsl = brick.to_wgsl();
1288        assert!(wgsl.contains("Reduce"));
1289    }
1290
1291    #[test]
1292    fn test_compute_brick_wgsl_elementwise_no_output() {
1293        let brick = ComputeBrick::new("test")
1294            .input("x", TensorType::F32, &[64])
1295            .output("y", TensorType::F32, &[64])
1296            .op(TileOp::LoadShared {
1297                src: "x".into(),
1298                tile_size: (64, 1),
1299            })
1300            .op(TileOp::Elementwise {
1301                op: ElementwiseOp::Log,
1302                operands: vec!["x".into()],
1303                output: None, // Output defaults to first operand
1304            });
1305
1306        let wgsl = brick.to_wgsl();
1307        assert!(wgsl.contains("log(val_x)"));
1308    }
1309
1310    #[test]
1311    fn test_compute_brick_wgsl_store_fallback() {
1312        let brick = ComputeBrick::new("test")
1313            .input("input", TensorType::F32, &[64])
1314            .output("output", TensorType::F32, &[64])
1315            .op(TileOp::LoadShared {
1316                src: "input".into(),
1317                tile_size: (64, 1),
1318            })
1319            .op(TileOp::StoreShared {
1320                dst: "output".into(),
1321            });
1322
1323        let wgsl = brick.to_wgsl();
1324        assert!(wgsl.contains("output[gid]"));
1325    }
1326
1327    #[test]
1328    fn test_compute_brick_implements_brick() {
1329        let brick = ComputeBrick::new("test")
1330            .input("in", TensorType::F32, &[32])
1331            .output("out", TensorType::F32, &[32]);
1332
1333        assert_eq!(brick.brick_name(), "ComputeBrick");
1334        assert!(brick.assertions().is_empty());
1335        assert_eq!(brick.budget().total_ms, 100);
1336        assert!(brick.to_html().is_empty());
1337        assert!(brick.to_css().is_empty());
1338    }
1339
1340    #[test]
1341    fn test_to_pascal_case_variants() {
1342        assert_eq!(to_pascal_case("simple"), "Simple");
1343        assert_eq!(to_pascal_case("two_words"), "TwoWords");
1344        assert_eq!(to_pascal_case("three-part-name"), "ThreePartName");
1345        assert_eq!(to_pascal_case("mixed_style-here"), "MixedStyleHere");
1346        assert_eq!(to_pascal_case("with space"), "WithSpace");
1347    }
1348
1349    #[test]
1350    fn test_compute_brick_multiple_inputs() {
1351        let brick = ComputeBrick::new("multi")
1352            .input("a", TensorType::F32, &[100])
1353            .input("b", TensorType::I32, &[100])
1354            .input("c", TensorType::U32, &[100])
1355            .output("result", TensorType::F32, &[100]);
1356
1357        assert_eq!(brick.inputs().len(), 3);
1358        assert_eq!(brick.inputs()[0].binding, 0);
1359        assert_eq!(brick.inputs()[1].binding, 1);
1360        assert_eq!(brick.inputs()[2].binding, 2);
1361    }
1362
1363    #[test]
1364    fn test_compute_brick_multiple_outputs() {
1365        let brick = ComputeBrick::new("multi_out")
1366            .input("in", TensorType::F32, &[50])
1367            .output("out1", TensorType::F32, &[50])
1368            .output("out2", TensorType::F32, &[25]);
1369
1370        assert_eq!(brick.outputs().len(), 2);
1371        assert_eq!(brick.outputs()[0].binding, 0);
1372        assert_eq!(brick.outputs()[1].binding, 1);
1373        assert_eq!(brick.outputs()[0].group, 1);
1374        assert_eq!(brick.outputs()[1].group, 1);
1375    }
1376
1377    #[test]
1378    fn test_compute_brick_clone() {
1379        let brick = ComputeBrick::new("test")
1380            .workgroup_size(128, 4, 1)
1381            .input("in", TensorType::F16, &[256])
1382            .output("out", TensorType::F16, &[256])
1383            .shared("cache", TensorType::F16, 512);
1384
1385        let cloned = brick.clone();
1386        assert_eq!(brick.name(), cloned.name());
1387        assert_eq!(brick.get_workgroup_size(), cloned.get_workgroup_size());
1388    }
1389
1390    #[test]
1391    fn test_js_dispatch_no_outputs() {
1392        let brick = ComputeBrick::new("no_out").input("in", TensorType::F32, &[10]);
1393
1394        let js = brick.to_dispatch_js();
1395        // Should still generate dispatch function but no numWorkgroups calculation
1396        assert!(js.contains("dispatchNoOut"));
1397    }
1398
1399    #[test]
1400    fn test_rust_bindings_multiple_io() {
1401        let brick = ComputeBrick::new("complex")
1402            .input("in1", TensorType::F32, &[100])
1403            .input("in2", TensorType::I32, &[50])
1404            .output("out1", TensorType::F32, &[100])
1405            .output("out2", TensorType::U32, &[25]);
1406
1407        let rust = brick.to_rust_bindings();
1408        assert!(rust.contains("Input: in1"));
1409        assert!(rust.contains("Input: in2"));
1410        assert!(rust.contains("Output: out1"));
1411        assert!(rust.contains("Output: out2"));
1412    }
1413
1414    #[test]
1415    fn test_tensor_binding_empty_shape() {
1416        let binding = TensorBinding::new("scalar", TensorType::F32, &[]);
1417        assert_eq!(binding.element_count(), 1); // Product of empty vec is 1
1418        assert_eq!(binding.byte_size(), 4);
1419    }
1420}