use super::{Brick, BrickAssertion, BrickBudget, BrickVerification};
use std::time::Duration;

/// Scalar element type of a GPU tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorType {
    F32,
    F16,
    I32,
    U32,
}

impl TensorType {
    /// WGSL name of this element type.
    #[must_use]
    pub fn to_wgsl(&self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::F16 => "f16",
            Self::I32 => "i32",
            Self::U32 => "u32",
        }
    }

    /// Rust (host-side) name of this element type.
    #[must_use]
    pub fn to_rust(&self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::F16 => "half::f16",
            Self::I32 => "i32",
            Self::U32 => "u32",
        }
    }

    /// Size of one element in bytes.
    #[must_use]
    pub const fn byte_size(&self) -> usize {
        match self {
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::F16 => 2,
        }
    }
}
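
// Illustrative (not part of the API): each `TensorType` maps to a WGSL name,
// a host-side Rust name, and an element size, e.g.
//
//     assert_eq!(TensorType::F16.to_wgsl(), "f16");
//     assert_eq!(TensorType::F16.to_rust(), "half::f16");
//     assert_eq!(TensorType::F16.byte_size(), 2);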

/// A tensor buffer bound to the generated compute shader.
#[derive(Debug, Clone)]
pub struct TensorBinding {
    /// Variable name used in the generated WGSL.
    pub name: String,
    /// Element type of the buffer.
    pub dtype: TensorType,
    /// Logical tensor shape.
    pub shape: Vec<u32>,
    /// Bind group index.
    pub group: u32,
    /// Binding index within the group.
    pub binding: u32,
    /// Whether the shader may only read from this buffer.
    pub read_only: bool,
}

impl TensorBinding {
    /// Creates a read-only binding at group 0, binding 0.
    #[must_use]
    pub fn new(name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
        Self {
            name: name.into(),
            dtype,
            shape: shape.to_vec(),
            group: 0,
            binding: 0,
            read_only: true,
        }
    }

    /// Places the binding at the given bind group and binding index.
    #[must_use]
    pub fn at(mut self, group: u32, binding: u32) -> Self {
        self.group = group;
        self.binding = binding;
        self
    }

    /// Marks the buffer as read-write.
    #[must_use]
    pub fn writable(mut self) -> Self {
        self.read_only = false;
        self
    }

    /// Total number of elements (product of the shape dimensions).
    #[must_use]
    pub fn element_count(&self) -> u32 {
        self.shape.iter().product()
    }

    /// Total buffer size in bytes.
    #[must_use]
    pub fn byte_size(&self) -> usize {
        self.element_count() as usize * self.dtype.byte_size()
    }

    /// WGSL storage-buffer declaration for this binding.
    #[must_use]
    pub fn to_wgsl_binding(&self) -> String {
        let access = if self.read_only { "read" } else { "read_write" };
        format!(
            "@group({}) @binding({}) var<storage, {}> {}: array<{}>;",
            self.group,
            self.binding,
            access,
            self.name,
            self.dtype.to_wgsl()
        )
    }
}
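
// Illustrative: a binding placed at group 1 / binding 2 renders as a WGSL
// storage declaration, e.g. for
// `TensorBinding::new("data", TensorType::F32, &[100]).at(1, 2)`:
//
//     @group(1) @binding(2) var<storage, read> data: array<f32>;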

/// How the kernel partitions work across invocations.
#[derive(Debug, Clone)]
pub enum TileStrategy {
    /// 2D tiling with independent tile edges.
    Simple2D {
        tile_x: u32,
        tile_y: u32,
    },
    /// Cooperative matrix tiling over an M x N output tile with depth K.
    Cooperative {
        m: u32,
        n: u32,
        k: u32,
    },
    /// 1D streaming window.
    Streaming {
        window: u32,
    },
    /// No tiling; a flat 1D dispatch.
    None,
}

impl TileStrategy {
    /// Workgroup size implied by the strategy.
    #[must_use]
    pub fn optimal_workgroup_size(&self) -> (u32, u32, u32) {
        match self {
            Self::Simple2D { tile_x, tile_y } => (*tile_x, *tile_y, 1),
            Self::Cooperative { m, n, .. } => (*m, *n, 1),
            Self::Streaming { window } => (*window, 1, 1),
            Self::None => (64, 1, 1),
        }
    }
}
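
// Illustrative sketch (assumed usage; nothing in this module calls it for you):
// a strategy's implied size can seed the builder explicitly, e.g.
//
//     let (x, y, z) = TileStrategy::Simple2D { tile_x: 16, tile_y: 16 }
//         .optimal_workgroup_size();
//     let brick = ComputeBrick::new("tiled").workgroup_size(x, y, z);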

/// Elementwise operation applied to each tensor element.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementwiseOp {
    Log,
    Exp,
    Sqrt,
    Abs,
    Relu,
    Sigmoid,
    Tanh,
    AddScalar(i32),
    MulScalar(i32),
    Clamp,
}

impl ElementwiseOp {
    /// WGSL expression applying this operation to `operand`.
    #[must_use]
    pub fn to_wgsl_expr(&self, operand: &str) -> String {
        match self {
            Self::Log => format!("log({})", operand),
            Self::Exp => format!("exp({})", operand),
            Self::Sqrt => format!("sqrt({})", operand),
            Self::Abs => format!("abs({})", operand),
            Self::Relu => format!("max({}, 0.0)", operand),
            Self::Sigmoid => format!("1.0 / (1.0 + exp(-{}))", operand),
            Self::Tanh => format!("tanh({})", operand),
            Self::AddScalar(s) => format!("({} + {}.0)", operand, s),
            Self::MulScalar(s) => format!("({} * {}.0)", operand, s),
            Self::Clamp => format!("clamp({}, 0.0, 1.0)", operand),
        }
    }
}
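
// Illustrative: each op expands to a single WGSL expression over one operand, e.g.
//
//     assert_eq!(ElementwiseOp::Sigmoid.to_wgsl_expr("x"), "1.0 / (1.0 + exp(-x))");
//     assert_eq!(ElementwiseOp::AddScalar(5).to_wgsl_expr("x"), "(x + 5.0)");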

/// One step in the generated kernel body.
#[derive(Debug, Clone)]
pub enum TileOp {
    /// Load from a bound tensor.
    LoadShared {
        src: String,
        tile_size: (u32, u32),
    },
    /// Matrix multiply: `c = a @ b`.
    Mma {
        a: String,
        b: String,
        c: String,
    },
    /// Apply an elementwise operation.
    Elementwise {
        op: ElementwiseOp,
        operands: Vec<String>,
        output: Option<String>,
    },
    /// Store to a bound tensor.
    StoreShared {
        dst: String,
    },
    /// Workgroup barrier.
    Barrier,
    /// Reduce `input` into `output`.
    Reduce {
        kind: ReduceKind,
        input: String,
        output: String,
    },
}

/// Reduction operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceKind {
    Sum,
    Max,
    Min,
    Mean,
}

impl ReduceKind {
    /// Identity element for the reduction, as a WGSL literal.
    #[must_use]
    pub fn identity(&self) -> &'static str {
        match self {
            Self::Sum | Self::Mean => "0.0",
            Self::Max => "-3.402823e+38",
            Self::Min => "3.402823e+38",
        }
    }

    /// WGSL operator or builtin used to combine two values.
    #[must_use]
    pub fn combine_op(&self) -> &'static str {
        match self {
            Self::Sum | Self::Mean => "+",
            Self::Max => "max",
            Self::Min => "min",
        }
    }
}
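
// Illustrative sketch (assumed output; `to_wgsl` currently emits only a comment
// for `TileOp::Reduce`): a serial reduction could be generated from these two
// pieces, e.g. for `ReduceKind::Max`:
//
//     var acc: f32 = -3.402823e+38;               // identity()
//     for (var i = 0u; i < n; i = i + 1u) {
//         acc = max(acc, values[i]);              // combine_op()
//     }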

/// Builder that describes a compute kernel and generates its WGSL source,
/// Rust (wgpu) bindings, and JavaScript (WebGPU) dispatch code.
#[derive(Debug, Clone)]
pub struct ComputeBrick {
    name: String,
    workgroup_size: (u32, u32, u32),
    inputs: Vec<TensorBinding>,
    outputs: Vec<TensorBinding>,
    tile_strategy: TileStrategy,
    operations: Vec<TileOp>,
    shared_memory: Vec<(String, TensorType, u32)>,
}

impl ComputeBrick {
    /// Creates an empty brick with a default (64, 1, 1) workgroup.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            workgroup_size: (64, 1, 1),
            inputs: Vec::new(),
            outputs: Vec::new(),
            tile_strategy: TileStrategy::None,
            operations: Vec::new(),
            shared_memory: Vec::new(),
        }
    }

    /// Sets the workgroup dimensions.
    #[must_use]
    pub fn workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroup_size = (x, y, z);
        self
    }

    /// Adds a read-only input tensor in bind group 0.
    #[must_use]
    pub fn input(mut self, name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
        let binding_idx = self.inputs.len() as u32;
        self.inputs
            .push(TensorBinding::new(name, dtype, shape).at(0, binding_idx));
        self
    }

    /// Adds a writable output tensor in bind group 1.
    #[must_use]
    pub fn output(mut self, name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
        let binding_idx = self.outputs.len() as u32;
        self.outputs.push(
            TensorBinding::new(name, dtype, shape)
                .at(1, binding_idx)
                .writable(),
        );
        self
    }

    /// Sets the tiling strategy.
    #[must_use]
    pub fn tile_strategy(mut self, strategy: TileStrategy) -> Self {
        self.tile_strategy = strategy;
        self
    }

    /// Appends an operation to the kernel body.
    #[must_use]
    pub fn op(mut self, operation: TileOp) -> Self {
        self.operations.push(operation);
        self
    }

    /// Declares a workgroup-shared array.
    #[must_use]
    pub fn shared(mut self, name: impl Into<String>, dtype: TensorType, size: u32) -> Self {
        self.shared_memory.push((name.into(), dtype, size));
        self
    }

    /// Generates the WGSL compute shader source.
    #[must_use]
    pub fn to_wgsl(&self) -> String {
        let mut wgsl = String::new();

        // Header comment.
        wgsl.push_str(&format!(
            "// {} Compute Shader\n",
            to_pascal_case(&self.name)
        ));
        wgsl.push_str("// Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");

        // Storage buffer declarations.
        for input in &self.inputs {
            wgsl.push_str(&input.to_wgsl_binding());
            wgsl.push('\n');
        }

        for output in &self.outputs {
            wgsl.push_str(&output.to_wgsl_binding());
            wgsl.push('\n');
        }

        wgsl.push('\n');

        // Workgroup-shared arrays.
        for (name, dtype, size) in &self.shared_memory {
            wgsl.push_str(&format!(
                "var<workgroup> {}: array<{}, {}>;\n",
                name,
                dtype.to_wgsl(),
                size
            ));
        }

        if !self.shared_memory.is_empty() {
            wgsl.push('\n');
        }

        // Entry point.
        let (wg_x, wg_y, wg_z) = self.workgroup_size;
        wgsl.push_str(&format!(
            "@compute @workgroup_size({}, {}, {})\n",
            wg_x, wg_y, wg_z
        ));
        wgsl.push_str("fn main(\n");
        wgsl.push_str("    @builtin(global_invocation_id) global_id: vec3<u32>,\n");
        wgsl.push_str("    @builtin(local_invocation_id) local_id: vec3<u32>,\n");
        wgsl.push_str("    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
        wgsl.push_str(") {\n");

        // Flattened global and local indices.
        wgsl.push_str("    let gid = global_id.x + global_id.y * ");
        wgsl.push_str(&format!("{}u;\n", wg_x));
        wgsl.push_str("    let lid = local_id.x + local_id.y * ");
        wgsl.push_str(&format!("{}u;\n\n", wg_x));

        // Kernel body, one statement group per operation.
        for op in &self.operations {
            match op {
                TileOp::LoadShared { src, tile_size: _ } => {
                    wgsl.push_str(&format!("    // Load from {} to shared memory\n", src));
                    wgsl.push_str(&format!("    let val_{} = {}[gid];\n", src, src));
                }
                TileOp::Elementwise {
                    op: elem_op,
                    operands,
                    output,
                } => {
                    let input = &operands[0];
                    let out_name = output.as_ref().unwrap_or(input);
                    let input_val = format!("val_{}", input);
                    let expr = elem_op.to_wgsl_expr(&input_val);
                    wgsl.push_str(&format!("    let val_{} = {};\n", out_name, expr));
                }
                TileOp::StoreShared { dst } => {
                    wgsl.push_str(&format!("    // Store to {}\n", dst));
                    // Prefer the value produced by an Elementwise op targeting `dst`;
                    // otherwise fall back to the first loaded input, then to 0.0.
                    let val_name = if self.operations.iter().any(
                        |o| matches!(o, TileOp::Elementwise { output: Some(n), .. } if n == dst),
                    ) {
                        format!("val_{}", dst)
                    } else if let Some(input) = self.inputs.first() {
                        format!("val_{}", input.name)
                    } else {
                        "0.0".to_string()
                    };
                    wgsl.push_str(&format!("    {}[gid] = {};\n", dst, val_name));
                }
                TileOp::Barrier => {
                    wgsl.push_str("    workgroupBarrier();\n");
                }
                TileOp::Mma { a, b, c } => {
                    wgsl.push_str(&format!("    // Matrix multiply: {} = {} @ {}\n", c, a, b));
                    wgsl.push_str("    // TODO: Implement cooperative matrix\n");
                }
                TileOp::Reduce {
                    kind,
                    input,
                    output,
                } => {
                    wgsl.push_str(&format!(
                        "    // Reduce {} -> {} ({:?})\n",
                        input, output, kind
                    ));
                }
            }
        }

        wgsl.push_str("}\n");

        wgsl
    }

    /// Generates Rust (wgpu) pipeline and bind group layout boilerplate.
    #[must_use]
    pub fn to_rust_bindings(&self) -> String {
        let mut rust = String::new();

        rust.push_str(&format!(
            "//! {} Compute Bindings\n",
            to_pascal_case(&self.name)
        ));
        rust.push_str("//! Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");
        rust.push_str(
            "use wgpu::{BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry};\n",
        );
        rust.push_str("use wgpu::{ShaderStages, BufferBindingType, BindingType};\n\n");

        let struct_name = to_pascal_case(&self.name);

        rust.push_str(&format!("pub struct {}Compute {{\n", struct_name));
        rust.push_str("    pub pipeline: wgpu::ComputePipeline,\n");
        rust.push_str("    pub bind_group_layout: wgpu::BindGroupLayout,\n");
        rust.push_str("}\n\n");

        rust.push_str(&format!("impl {}Compute {{\n", struct_name));
        rust.push_str("    pub const WORKGROUP_SIZE: (u32, u32, u32) = ");
        rust.push_str(&format!("{:?};\n\n", self.workgroup_size));

        rust.push_str("    pub const SHADER_SOURCE: &'static str = r#\"\n");
        rust.push_str(&self.to_wgsl());
        rust.push_str("\"#;\n\n");

        rust.push_str(
            "    pub fn create_bind_group_layout(device: &wgpu::Device) -> BindGroupLayout {\n",
        );
        rust.push_str("        device.create_bind_group_layout(&BindGroupLayoutDescriptor {\n");
        rust.push_str(&format!(
            "            label: Some(\"{} bind group layout\"),\n",
            self.name
        ));
        rust.push_str("            entries: &[\n");

        for input in &self.inputs {
            rust.push_str(&format!("                // Input: {}\n", input.name));
            rust.push_str(&format!(
                "                BindGroupLayoutEntry {{\n                    binding: {},\n                    visibility: ShaderStages::COMPUTE,\n                    ty: BindingType::Buffer {{\n                        ty: BufferBindingType::Storage {{ read_only: true }},\n                        has_dynamic_offset: false,\n                        min_binding_size: None,\n                    }},\n                    count: None,\n                }},\n",
                input.binding
            ));
        }

        for output in &self.outputs {
            rust.push_str(&format!("                // Output: {}\n", output.name));
            rust.push_str(&format!(
                "                BindGroupLayoutEntry {{\n                    binding: {},\n                    visibility: ShaderStages::COMPUTE,\n                    ty: BindingType::Buffer {{\n                        ty: BufferBindingType::Storage {{ read_only: false }},\n                        has_dynamic_offset: false,\n                        min_binding_size: None,\n                    }},\n                    count: None,\n                }},\n",
                output.binding
            ));
        }

        rust.push_str("            ],\n");
        rust.push_str("        })\n");
        rust.push_str("    }\n");
        rust.push_str("}\n");

        rust
    }

    /// Generates a JavaScript (WebGPU) dispatch helper.
    #[must_use]
    pub fn to_dispatch_js(&self) -> String {
        let mut js = String::new();

        js.push_str(&format!(
            "// {} Compute Dispatch\n",
            to_pascal_case(&self.name)
        ));
        js.push_str("// Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");

        let (wg_x, wg_y, wg_z) = self.workgroup_size;
        js.push_str(&format!(
            "const WORKGROUP_SIZE = [{}, {}, {}];\n\n",
            wg_x, wg_y, wg_z
        ));

        js.push_str(&format!(
            "async function dispatch{}(device, inputs, outputs) {{\n",
            to_pascal_case(&self.name)
        ));

        js.push_str("  // Create shader module\n");
        js.push_str("  const shaderModule = device.createShaderModule({\n");
        js.push_str(&format!("    label: '{} shader',\n", self.name));
        js.push_str("    code: SHADER_SOURCE,\n");
        js.push_str("  });\n\n");

        js.push_str("  // Calculate dispatch size\n");
        if let Some(output) = self.outputs.first() {
            let total_size = output.element_count();
            js.push_str(&format!("  const totalElements = {};\n", total_size));
            js.push_str(&format!(
                "  const numWorkgroups = Math.ceil(totalElements / {});\n\n",
                wg_x * wg_y * wg_z
            ));
        }

        // NOTE: the generated function assumes `pipeline` and `bindGroup` are in scope.
        js.push_str("  // Dispatch\n");
        js.push_str("  const commandEncoder = device.createCommandEncoder();\n");
        js.push_str("  const passEncoder = commandEncoder.beginComputePass();\n");
        js.push_str("  passEncoder.setPipeline(pipeline);\n");
        js.push_str("  passEncoder.setBindGroup(0, bindGroup);\n");
        js.push_str("  passEncoder.dispatchWorkgroups(numWorkgroups, 1, 1);\n");
        js.push_str("  passEncoder.end();\n");
        js.push_str("  device.queue.submit([commandEncoder.finish()]);\n");
        js.push_str("}\n");

        js
    }

    /// Brick name.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Configured workgroup size.
    #[must_use]
    pub fn get_workgroup_size(&self) -> (u32, u32, u32) {
        self.workgroup_size
    }

    /// Input tensor bindings.
    #[must_use]
    pub fn inputs(&self) -> &[TensorBinding] {
        &self.inputs
    }

    /// Output tensor bindings.
    #[must_use]
    pub fn outputs(&self) -> &[TensorBinding] {
        &self.outputs
    }
}
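
// Illustrative end-to-end sketch (names are examples only): one brick
// description yields all three artifacts.
//
//     let brick = ComputeBrick::new("log-transform")
//         .workgroup_size(64, 1, 1)
//         .input("input", TensorType::F32, &[1024])
//         .output("output", TensorType::F32, &[1024])
//         .op(TileOp::LoadShared { src: "input".into(), tile_size: (64, 1) })
//         .op(TileOp::Elementwise {
//             op: ElementwiseOp::Log,
//             operands: vec!["input".into()],
//             output: Some("output".into()),
//         })
//         .op(TileOp::StoreShared { dst: "output".into() });
//
//     let wgsl = brick.to_wgsl();          // WGSL compute shader
//     let rust = brick.to_rust_bindings(); // wgpu bind group layout helper
//     let js = brick.to_dispatch_js();     // WebGPU dispatch scaffold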

impl Brick for ComputeBrick {
    fn brick_name(&self) -> &'static str {
        "ComputeBrick"
    }

    fn assertions(&self) -> &[BrickAssertion] {
        &[]
    }

    fn budget(&self) -> BrickBudget {
        BrickBudget::uniform(100)
    }

    fn verify(&self) -> BrickVerification {
        let mut passed = Vec::new();
        let mut failed = Vec::new();

        // The total workgroup invocation count must not exceed 1024.
        let (x, y, z) = self.workgroup_size;
        if x * y * z > 1024 {
            failed.push((
                BrickAssertion::Custom {
                    name: "workgroup_size_valid".into(),
                    validator_id: 1,
                },
                format!(
                    "Workgroup size {}x{}x{}={} exceeds maximum 1024",
                    x,
                    y,
                    z,
                    x * y * z
                ),
            ));
        } else {
            passed.push(BrickAssertion::Custom {
                name: "workgroup_size_valid".into(),
                validator_id: 1,
            });
        }

        // The brick must declare at least one input tensor...
        if self.inputs.is_empty() {
            failed.push((
                BrickAssertion::Custom {
                    name: "has_inputs".into(),
                    validator_id: 2,
                },
                "ComputeBrick has no input tensors".into(),
            ));
        } else {
            passed.push(BrickAssertion::Custom {
                name: "has_inputs".into(),
                validator_id: 2,
            });
        }

        // ...and at least one output tensor.
        if self.outputs.is_empty() {
            failed.push((
                BrickAssertion::Custom {
                    name: "has_outputs".into(),
                    validator_id: 3,
                },
                "ComputeBrick has no output tensors".into(),
            ));
        } else {
            passed.push(BrickAssertion::Custom {
                name: "has_outputs".into(),
                validator_id: 3,
            });
        }

        // Every LoadShared/StoreShared op must reference a declared tensor.
        let tensor_names: Vec<_> = self
            .inputs
            .iter()
            .chain(self.outputs.iter())
            .map(|t| t.name.as_str())
            .collect();

        for op in &self.operations {
            match op {
                TileOp::LoadShared { src, .. } => {
                    if !tensor_names.contains(&src.as_str()) {
                        failed.push((
                            BrickAssertion::Custom {
                                name: "tensor_exists".into(),
                                validator_id: 4,
                            },
                            format!("LoadShared references unknown tensor: {}", src),
                        ));
                    }
                }
                TileOp::StoreShared { dst } => {
                    if !tensor_names.contains(&dst.as_str()) {
                        failed.push((
                            BrickAssertion::Custom {
                                name: "tensor_exists".into(),
                                validator_id: 4,
                            },
                            format!("StoreShared references unknown tensor: {}", dst),
                        ));
                    }
                }
                _ => {}
            }
        }

        if failed.is_empty() {
            passed.push(BrickAssertion::Custom {
                name: "compute_brick_valid".into(),
                validator_id: 5,
            });
        }

        BrickVerification {
            passed,
            failed,
            verification_time: Duration::from_micros(100),
        }
    }

    fn to_html(&self) -> String {
        String::new()
    }

    fn to_css(&self) -> String {
        String::new()
    }
}
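
// Illustrative: `verify()` is a pure structural check, so it can gate code
// generation in tests without touching a GPU, e.g.
//
//     let report = ComputeBrick::new("demo")
//         .input("x", TensorType::F32, &[64])
//         .output("y", TensorType::F32, &[64])
//         .verify();
//     assert!(report.is_valid());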

/// Converts `snake_case`, `kebab-case`, or space-separated names to `PascalCase`.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::new();
    let mut capitalize_next = true;

    for c in s.chars() {
        if c == '_' || c == '-' || c == ' ' {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(c.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }

    result
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_brick_basic() {
        let brick = ComputeBrick::new("test")
            .workgroup_size(256, 1, 1)
            .input("audio", TensorType::F32, &[1024])
            .output("mel", TensorType::F32, &[80, 100]);

        assert_eq!(brick.name(), "test");
        assert_eq!(brick.get_workgroup_size(), (256, 1, 1));
        assert_eq!(brick.inputs().len(), 1);
        assert_eq!(brick.outputs().len(), 1);
    }

    #[test]
    fn test_compute_brick_wgsl_generation() {
        let brick = ComputeBrick::new("log-transform")
            .workgroup_size(64, 1, 1)
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024])
            .op(TileOp::LoadShared {
                src: "input".into(),
                tile_size: (64, 1),
            })
            .op(TileOp::Elementwise {
                op: ElementwiseOp::Log,
                operands: vec!["input".into()],
                output: Some("output".into()),
            })
            .op(TileOp::StoreShared {
                dst: "output".into(),
            });

        let wgsl = brick.to_wgsl();

        assert!(wgsl.contains("@compute @workgroup_size(64, 1, 1)"));
        assert!(wgsl.contains("fn main("));
        assert!(wgsl.contains("log("));
        assert!(wgsl.contains("Generated by probar"));
    }

    #[test]
    fn test_compute_brick_verification() {
        let brick = ComputeBrick::new("test")
            .workgroup_size(256, 1, 1)
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024]);

        let result = brick.verify();
        assert!(result.is_valid());
    }

    #[test]
    fn test_compute_brick_verification_fails_no_inputs() {
        let brick = ComputeBrick::new("test").workgroup_size(256, 1, 1).output(
            "output",
            TensorType::F32,
            &[1024],
        );

        let result = brick.verify();
        assert!(!result.is_valid());
    }

    #[test]
    fn test_compute_brick_verification_fails_large_workgroup() {
        let brick = ComputeBrick::new("test")
            .workgroup_size(1024, 2, 1)
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024]);

        let result = brick.verify();
        assert!(!result.is_valid());
    }

    #[test]
    fn test_tensor_binding() {
        let binding = TensorBinding::new("audio", TensorType::F32, &[1024, 80])
            .at(0, 1)
            .writable();

        assert_eq!(binding.name, "audio");
        assert_eq!(binding.element_count(), 1024 * 80);
        assert_eq!(binding.byte_size(), 1024 * 80 * 4);
        assert!(!binding.read_only);
    }

    #[test]
    fn test_tensor_type_wgsl() {
        assert_eq!(TensorType::F32.to_wgsl(), "f32");
        assert_eq!(TensorType::F16.to_wgsl(), "f16");
        assert_eq!(TensorType::I32.to_wgsl(), "i32");
        assert_eq!(TensorType::U32.to_wgsl(), "u32");
    }

    #[test]
    fn test_elementwise_ops() {
        assert_eq!(ElementwiseOp::Log.to_wgsl_expr("x"), "log(x)");
        assert_eq!(ElementwiseOp::Exp.to_wgsl_expr("x"), "exp(x)");
        assert_eq!(ElementwiseOp::Relu.to_wgsl_expr("x"), "max(x, 0.0)");
        assert_eq!(ElementwiseOp::AddScalar(5).to_wgsl_expr("x"), "(x + 5.0)");
    }

    #[test]
    fn test_rust_bindings_generation() {
        let brick = ComputeBrick::new("mel-transform")
            .workgroup_size(256, 1, 1)
            .input("audio", TensorType::F32, &[1024])
            .output("mel", TensorType::F32, &[80]);

        let rust = brick.to_rust_bindings();

        assert!(rust.contains("pub struct MelTransformCompute"));
        assert!(rust.contains("WORKGROUP_SIZE"));
        assert!(rust.contains("SHADER_SOURCE"));
        assert!(rust.contains("create_bind_group_layout"));
    }

    #[test]
    fn test_js_dispatch_generation() {
        let brick = ComputeBrick::new("fft")
            .workgroup_size(64, 1, 1)
            .input("signal", TensorType::F32, &[512])
            .output("spectrum", TensorType::F32, &[512]);

        let js = brick.to_dispatch_js();

        assert!(js.contains("async function dispatchFft"));
        assert!(js.contains("WORKGROUP_SIZE"));
        assert!(js.contains("dispatchWorkgroups"));
    }

    #[test]
    fn test_tile_strategy_workgroup_size() {
        let simple = TileStrategy::Simple2D {
            tile_x: 16,
            tile_y: 16,
        };
        assert_eq!(simple.optimal_workgroup_size(), (16, 16, 1));

        let coop = TileStrategy::Cooperative { m: 8, n: 8, k: 4 };
        assert_eq!(coop.optimal_workgroup_size(), (8, 8, 1));

        let streaming = TileStrategy::Streaming { window: 32 };
        assert_eq!(streaming.optimal_workgroup_size(), (32, 1, 1));
    }

    #[test]
    fn test_tensor_type_rust() {
        assert_eq!(TensorType::F32.to_rust(), "f32");
        assert_eq!(TensorType::F16.to_rust(), "half::f16");
        assert_eq!(TensorType::I32.to_rust(), "i32");
        assert_eq!(TensorType::U32.to_rust(), "u32");
    }

    #[test]
    fn test_tensor_type_byte_size() {
        assert_eq!(TensorType::F32.byte_size(), 4);
        assert_eq!(TensorType::F16.byte_size(), 2);
        assert_eq!(TensorType::I32.byte_size(), 4);
        assert_eq!(TensorType::U32.byte_size(), 4);
    }

    #[test]
    fn test_tensor_type_clone() {
        let t = TensorType::F32;
        let cloned = t;
        assert_eq!(t, cloned);
    }

    #[test]
    fn test_tensor_binding_default_values() {
        let binding = TensorBinding::new("test", TensorType::I32, &[10, 20]);
        assert_eq!(binding.group, 0);
        assert_eq!(binding.binding, 0);
        assert!(binding.read_only);
    }

    #[test]
    fn test_tensor_binding_to_wgsl_binding_read_only() {
        let binding = TensorBinding::new("data", TensorType::F32, &[100]).at(1, 2);
        let wgsl = binding.to_wgsl_binding();
        assert!(wgsl.contains("@group(1) @binding(2)"));
        assert!(wgsl.contains("var<storage, read>"));
        assert!(wgsl.contains("data"));
        assert!(wgsl.contains("f32"));
    }

    #[test]
    fn test_tensor_binding_to_wgsl_binding_read_write() {
        let binding = TensorBinding::new("output", TensorType::U32, &[50])
            .at(0, 0)
            .writable();
        let wgsl = binding.to_wgsl_binding();
        assert!(wgsl.contains("var<storage, read_write>"));
    }

    #[test]
    fn test_tensor_binding_clone() {
        let binding = TensorBinding::new("test", TensorType::F32, &[1, 2, 3])
            .at(1, 2)
            .writable();
        let cloned = binding.clone();
        assert_eq!(binding.name, cloned.name);
        assert_eq!(binding.shape, cloned.shape);
        assert_eq!(binding.read_only, cloned.read_only);
    }

    #[test]
    fn test_tile_strategy_none() {
        let strategy = TileStrategy::None;
        assert_eq!(strategy.optimal_workgroup_size(), (64, 1, 1));
    }

    #[test]
    fn test_tile_strategy_clone() {
        let strategy = TileStrategy::Simple2D {
            tile_x: 8,
            tile_y: 8,
        };
        let cloned = strategy.clone();
        assert!(matches!(
            cloned,
            TileStrategy::Simple2D {
                tile_x: 8,
                tile_y: 8
            }
        ));
    }

    #[test]
    fn test_elementwise_op_sqrt() {
        assert_eq!(ElementwiseOp::Sqrt.to_wgsl_expr("val"), "sqrt(val)");
    }

    #[test]
    fn test_elementwise_op_abs() {
        assert_eq!(ElementwiseOp::Abs.to_wgsl_expr("v"), "abs(v)");
    }

    #[test]
    fn test_elementwise_op_sigmoid() {
        assert_eq!(
            ElementwiseOp::Sigmoid.to_wgsl_expr("x"),
            "1.0 / (1.0 + exp(-x))"
        );
    }

    #[test]
    fn test_elementwise_op_tanh() {
        assert_eq!(ElementwiseOp::Tanh.to_wgsl_expr("x"), "tanh(x)");
    }

    #[test]
    fn test_elementwise_op_mul_scalar() {
        assert_eq!(ElementwiseOp::MulScalar(3).to_wgsl_expr("y"), "(y * 3.0)");
        assert_eq!(ElementwiseOp::MulScalar(-2).to_wgsl_expr("x"), "(x * -2.0)");
    }

    #[test]
    fn test_elementwise_op_clamp() {
        assert_eq!(ElementwiseOp::Clamp.to_wgsl_expr("x"), "clamp(x, 0.0, 1.0)");
    }

    #[test]
    fn test_elementwise_op_eq() {
        assert_eq!(ElementwiseOp::Log, ElementwiseOp::Log);
        assert_ne!(ElementwiseOp::Log, ElementwiseOp::Exp);
        assert_eq!(ElementwiseOp::AddScalar(5), ElementwiseOp::AddScalar(5));
        assert_ne!(ElementwiseOp::AddScalar(5), ElementwiseOp::AddScalar(6));
    }

    #[test]
    fn test_reduce_kind_identity() {
        assert_eq!(ReduceKind::Sum.identity(), "0.0");
        assert_eq!(ReduceKind::Mean.identity(), "0.0");
        assert_eq!(ReduceKind::Max.identity(), "-3.402823e+38");
        assert_eq!(ReduceKind::Min.identity(), "3.402823e+38");
    }

    #[test]
    fn test_reduce_kind_combine_op() {
        assert_eq!(ReduceKind::Sum.combine_op(), "+");
        assert_eq!(ReduceKind::Mean.combine_op(), "+");
        assert_eq!(ReduceKind::Max.combine_op(), "max");
        assert_eq!(ReduceKind::Min.combine_op(), "min");
    }

    #[test]
    fn test_reduce_kind_eq() {
        assert_eq!(ReduceKind::Sum, ReduceKind::Sum);
        assert_ne!(ReduceKind::Sum, ReduceKind::Max);
    }

    #[test]
    fn test_tile_op_load_shared() {
        let op = TileOp::LoadShared {
            src: "audio".into(),
            tile_size: (32, 32),
        };
        match op {
            TileOp::LoadShared { src, tile_size } => {
                assert_eq!(src, "audio");
                assert_eq!(tile_size, (32, 32));
            }
            _ => panic!("Expected LoadShared"),
        }
    }

    #[test]
    fn test_tile_op_mma() {
        let op = TileOp::Mma {
            a: "A".into(),
            b: "B".into(),
            c: "C".into(),
        };
        match op {
            TileOp::Mma { a, b, c } => {
                assert_eq!(a, "A");
                assert_eq!(b, "B");
                assert_eq!(c, "C");
            }
            _ => panic!("Expected Mma"),
        }
    }

    #[test]
    fn test_tile_op_reduce() {
        let op = TileOp::Reduce {
            kind: ReduceKind::Max,
            input: "values".into(),
            output: "max_val".into(),
        };
        match op {
            TileOp::Reduce {
                kind,
                input,
                output,
            } => {
                assert_eq!(kind, ReduceKind::Max);
                assert_eq!(input, "values");
                assert_eq!(output, "max_val");
            }
            _ => panic!("Expected Reduce"),
        }
    }

    #[test]
    fn test_tile_op_barrier() {
        let op = TileOp::Barrier;
        assert!(matches!(op, TileOp::Barrier));
    }

    #[test]
    fn test_tile_op_clone() {
        let op = TileOp::Elementwise {
            op: ElementwiseOp::Relu,
            operands: vec!["x".into(), "y".into()],
            output: Some("z".into()),
        };
        let cloned = op.clone();
        assert!(matches!(cloned, TileOp::Elementwise { .. }));
    }

    #[test]
    fn test_compute_brick_tile_strategy() {
        let brick = ComputeBrick::new("test").tile_strategy(TileStrategy::Cooperative {
            m: 16,
            n: 16,
            k: 8,
        });

        assert_eq!(brick.name(), "test");
    }

    #[test]
    fn test_compute_brick_shared_memory() {
        let brick = ComputeBrick::new("test")
            .shared("tile_a", TensorType::F32, 256)
            .shared("tile_b", TensorType::F32, 128);

        let wgsl = brick.to_wgsl();
        assert!(wgsl.contains("var<workgroup> tile_a"));
        assert!(wgsl.contains("var<workgroup> tile_b"));
    }

    #[test]
    fn test_compute_brick_verification_no_outputs() {
        let brick = ComputeBrick::new("test").input("input", TensorType::F32, &[1024]);

        let result = brick.verify();
        assert!(!result.is_valid());
    }

    #[test]
    fn test_compute_brick_verification_invalid_load_tensor() {
        let brick = ComputeBrick::new("test")
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024])
            .op(TileOp::LoadShared {
                src: "nonexistent".into(),
                tile_size: (64, 1),
            });

        let result = brick.verify();
        assert!(!result.is_valid());
    }

    #[test]
    fn test_compute_brick_verification_invalid_store_tensor() {
        let brick = ComputeBrick::new("test")
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024])
            .op(TileOp::StoreShared {
                dst: "nonexistent".into(),
            });

        let result = brick.verify();
        assert!(!result.is_valid());
    }

    #[test]
    fn test_compute_brick_wgsl_barrier() {
        let brick = ComputeBrick::new("test")
            .input("input", TensorType::F32, &[64])
            .output("output", TensorType::F32, &[64])
            .op(TileOp::Barrier);

        let wgsl = brick.to_wgsl();
        assert!(wgsl.contains("workgroupBarrier()"));
    }

    #[test]
    fn test_compute_brick_wgsl_mma() {
        let brick = ComputeBrick::new("matmul")
            .input("A", TensorType::F32, &[64, 64])
            .input("B", TensorType::F32, &[64, 64])
            .output("C", TensorType::F32, &[64, 64])
            .op(TileOp::Mma {
                a: "A".into(),
                b: "B".into(),
                c: "C".into(),
            });

        let wgsl = brick.to_wgsl();
        assert!(wgsl.contains("Matrix multiply"));
    }

    #[test]
    fn test_compute_brick_wgsl_reduce() {
        let brick = ComputeBrick::new("reduce")
            .input("values", TensorType::F32, &[1024])
            .output("result", TensorType::F32, &[1])
            .op(TileOp::Reduce {
                kind: ReduceKind::Sum,
                input: "values".into(),
                output: "result".into(),
            });

        let wgsl = brick.to_wgsl();
        assert!(wgsl.contains("Reduce"));
    }

    #[test]
    fn test_compute_brick_wgsl_elementwise_no_output() {
        let brick = ComputeBrick::new("test")
            .input("x", TensorType::F32, &[64])
            .output("y", TensorType::F32, &[64])
            .op(TileOp::LoadShared {
                src: "x".into(),
                tile_size: (64, 1),
            })
            .op(TileOp::Elementwise {
                op: ElementwiseOp::Log,
                operands: vec!["x".into()],
                output: None,
            });

        let wgsl = brick.to_wgsl();
        assert!(wgsl.contains("log(val_x)"));
    }

    #[test]
    fn test_compute_brick_wgsl_store_fallback() {
        let brick = ComputeBrick::new("test")
            .input("input", TensorType::F32, &[64])
            .output("output", TensorType::F32, &[64])
            .op(TileOp::LoadShared {
                src: "input".into(),
                tile_size: (64, 1),
            })
            .op(TileOp::StoreShared {
                dst: "output".into(),
            });

        let wgsl = brick.to_wgsl();
        assert!(wgsl.contains("output[gid]"));
    }

    #[test]
    fn test_compute_brick_implements_brick() {
        let brick = ComputeBrick::new("test")
            .input("in", TensorType::F32, &[32])
            .output("out", TensorType::F32, &[32]);

        assert_eq!(brick.brick_name(), "ComputeBrick");
        assert!(brick.assertions().is_empty());
        assert_eq!(brick.budget().total_ms, 100);
        assert!(brick.to_html().is_empty());
        assert!(brick.to_css().is_empty());
    }

    #[test]
    fn test_to_pascal_case_variants() {
        assert_eq!(to_pascal_case("simple"), "Simple");
        assert_eq!(to_pascal_case("two_words"), "TwoWords");
        assert_eq!(to_pascal_case("three-part-name"), "ThreePartName");
        assert_eq!(to_pascal_case("mixed_style-here"), "MixedStyleHere");
        assert_eq!(to_pascal_case("with space"), "WithSpace");
    }

    #[test]
    fn test_compute_brick_multiple_inputs() {
        let brick = ComputeBrick::new("multi")
            .input("a", TensorType::F32, &[100])
            .input("b", TensorType::I32, &[100])
            .input("c", TensorType::U32, &[100])
            .output("result", TensorType::F32, &[100]);

        assert_eq!(brick.inputs().len(), 3);
        assert_eq!(brick.inputs()[0].binding, 0);
        assert_eq!(brick.inputs()[1].binding, 1);
        assert_eq!(brick.inputs()[2].binding, 2);
    }

    #[test]
    fn test_compute_brick_multiple_outputs() {
        let brick = ComputeBrick::new("multi_out")
            .input("in", TensorType::F32, &[50])
            .output("out1", TensorType::F32, &[50])
            .output("out2", TensorType::F32, &[25]);

        assert_eq!(brick.outputs().len(), 2);
        assert_eq!(brick.outputs()[0].binding, 0);
        assert_eq!(brick.outputs()[1].binding, 1);
        assert_eq!(brick.outputs()[0].group, 1);
        assert_eq!(brick.outputs()[1].group, 1);
    }

    #[test]
    fn test_compute_brick_clone() {
        let brick = ComputeBrick::new("test")
            .workgroup_size(128, 4, 1)
            .input("in", TensorType::F16, &[256])
            .output("out", TensorType::F16, &[256])
            .shared("cache", TensorType::F16, 512);

        let cloned = brick.clone();
        assert_eq!(brick.name(), cloned.name());
        assert_eq!(brick.get_workgroup_size(), cloned.get_workgroup_size());
    }

    #[test]
    fn test_js_dispatch_no_outputs() {
        let brick = ComputeBrick::new("no_out").input("in", TensorType::F32, &[10]);

        let js = brick.to_dispatch_js();
        assert!(js.contains("dispatchNoOut"));
    }

    #[test]
    fn test_rust_bindings_multiple_io() {
        let brick = ComputeBrick::new("complex")
            .input("in1", TensorType::F32, &[100])
            .input("in2", TensorType::I32, &[50])
            .output("out1", TensorType::F32, &[100])
            .output("out2", TensorType::U32, &[25]);

        let rust = brick.to_rust_bindings();
        assert!(rust.contains("Input: in1"));
        assert!(rust.contains("Input: in2"));
        assert!(rust.contains("Output: out1"));
        assert!(rust.contains("Output: out2"));
    }

    #[test]
    fn test_tensor_binding_empty_shape() {
        let binding = TensorBinding::new("scalar", TensorType::F32, &[]);
        assert_eq!(binding.element_count(), 1); // product over an empty shape is 1
        assert_eq!(binding.byte_size(), 4);
    }
}