1use super::{Brick, BrickAssertion, BrickBudget, BrickVerification};
37use std::time::Duration;
38
/// Scalar element type of a tensor binding or workgroup array.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorType {
    /// 32-bit float.
    F32,
    /// 16-bit float.
    F16,
    /// 32-bit signed integer.
    I32,
    /// 32-bit unsigned integer.
    U32,
}

impl TensorType {
    /// Returns the WGSL spelling of this element type.
    #[must_use]
    pub fn to_wgsl(&self) -> &'static str {
        match *self {
            TensorType::U32 => "u32",
            TensorType::I32 => "i32",
            TensorType::F16 => "f16",
            TensorType::F32 => "f32",
        }
    }

    /// Returns the Rust type path used for this element type in generated
    /// host code (`F16` refers to `half::f16`).
    #[must_use]
    pub fn to_rust(&self) -> &'static str {
        match *self {
            TensorType::U32 => "u32",
            TensorType::I32 => "i32",
            TensorType::F16 => "half::f16",
            TensorType::F32 => "f32",
        }
    }

    /// Size of one element in bytes.
    #[must_use]
    pub const fn byte_size(&self) -> usize {
        match *self {
            TensorType::F16 => 2,
            TensorType::F32 | TensorType::I32 | TensorType::U32 => 4,
        }
    }
}
84
85#[derive(Debug, Clone)]
87pub struct TensorBinding {
88 pub name: String,
90 pub dtype: TensorType,
92 pub shape: Vec<u32>,
94 pub group: u32,
96 pub binding: u32,
98 pub read_only: bool,
100}
101
102impl TensorBinding {
103 #[must_use]
105 pub fn new(name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
106 Self {
107 name: name.into(),
108 dtype,
109 shape: shape.to_vec(),
110 group: 0,
111 binding: 0,
112 read_only: true,
113 }
114 }
115
116 #[must_use]
118 pub fn at(mut self, group: u32, binding: u32) -> Self {
119 self.group = group;
120 self.binding = binding;
121 self
122 }
123
124 #[must_use]
126 pub fn writable(mut self) -> Self {
127 self.read_only = false;
128 self
129 }
130
131 #[must_use]
133 pub fn element_count(&self) -> u32 {
134 self.shape.iter().product()
135 }
136
137 #[must_use]
139 pub fn byte_size(&self) -> usize {
140 self.element_count() as usize * self.dtype.byte_size()
141 }
142
143 #[must_use]
145 pub fn to_wgsl_binding(&self) -> String {
146 let access = if self.read_only { "read" } else { "read_write" };
147 format!(
148 "@group({}) @binding({}) var<storage, {}> {}: array<{}>;",
149 self.group,
150 self.binding,
151 access,
152 self.name,
153 self.dtype.to_wgsl()
154 )
155 }
156}
157
/// How a kernel's work is tiled across a workgroup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TileStrategy {
    /// 2D tiling of `tile_x` by `tile_y` invocations.
    Simple2D {
        /// Tile width (workgroup x dimension).
        tile_x: u32,
        /// Tile height (workgroup y dimension).
        tile_y: u32,
    },
    /// Cooperative matrix tiling with tile dimensions `m`, `n`, `k`.
    Cooperative {
        /// Tile extent mapped to the workgroup x dimension.
        m: u32,
        /// Tile extent mapped to the workgroup y dimension.
        n: u32,
        /// Third tile extent; not used when choosing the workgroup size.
        k: u32,
    },
    /// 1D streaming over a window of elements.
    Streaming {
        /// Window length (workgroup x dimension).
        window: u32,
    },
    /// No tiling.
    None,
}

impl TileStrategy {
    /// Workgroup size `(x, y, z)` matching this strategy.
    ///
    /// `Simple2D` and `Cooperative` map their first two extents to x/y,
    /// `Streaming` maps the window to x, and `None` falls back to `(64, 1, 1)`.
    #[must_use]
    pub fn optimal_workgroup_size(&self) -> (u32, u32, u32) {
        match self {
            Self::Simple2D { tile_x, tile_y } => (*tile_x, *tile_y, 1),
            Self::Cooperative { m, n, .. } => (*m, *n, 1),
            Self::Streaming { window } => (*window, 1, 1),
            Self::None => (64, 1, 1),
        }
    }
}
198
/// Unary element-wise operation rendered as a WGSL expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementwiseOp {
    /// `log(x)`.
    Log,
    /// `exp(x)`.
    Exp,
    /// `sqrt(x)`.
    Sqrt,
    /// `abs(x)`.
    Abs,
    /// `max(x, 0.0)`.
    Relu,
    /// `1.0 / (1.0 + exp(-x))`.
    Sigmoid,
    /// `tanh(x)`.
    Tanh,
    /// `x + s`, with `s` emitted as a float literal.
    AddScalar(i32),
    /// `x * s`, with `s` emitted as a float literal.
    MulScalar(i32),
    /// `clamp(x, 0.0, 1.0)`.
    Clamp,
}

impl ElementwiseOp {
    /// Renders this operation applied to `operand` as a WGSL expression.
    ///
    /// Where the operand sits next to an operator (`Sigmoid`, `AddScalar`,
    /// `MulScalar`), any operand that is not a plain identifier, member access
    /// or unsigned literal is parenthesised so it keeps its meaning. For
    /// example, `MulScalar(2)` on `a + b` yields `((a + b) * 2.0)`, and
    /// `Sigmoid` on `-x` yields `exp(-(-x))` instead of the invalid `exp(--x)`.
    #[must_use]
    pub fn to_wgsl_expr(&self, operand: &str) -> String {
        // Identifiers, member accesses and unsigned literals bind tighter than
        // every operator, so they can be spliced in unchanged.
        let is_atomic = operand
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.');
        let grouped = if is_atomic {
            operand.to_string()
        } else {
            format!("({})", operand)
        };

        match self {
            Self::Log => format!("log({})", operand),
            Self::Exp => format!("exp({})", operand),
            Self::Sqrt => format!("sqrt({})", operand),
            Self::Abs => format!("abs({})", operand),
            Self::Relu => format!("max({}, 0.0)", operand),
            Self::Sigmoid => format!("1.0 / (1.0 + exp(-{}))", grouped),
            Self::Tanh => format!("tanh({})", operand),
            Self::AddScalar(s) => format!("({} + {}.0)", grouped, s),
            Self::MulScalar(s) => format!("({} * {}.0)", grouped, s),
            Self::Clamp => format!("clamp({}, 0.0, 1.0)", operand),
        }
    }
}
242
243#[derive(Debug, Clone)]
245pub enum TileOp {
246 LoadShared {
248 src: String,
250 tile_size: (u32, u32),
252 },
253 Mma {
255 a: String,
257 b: String,
259 c: String,
261 },
262 Elementwise {
264 op: ElementwiseOp,
266 operands: Vec<String>,
268 output: Option<String>,
270 },
271 StoreShared {
273 dst: String,
275 },
276 Barrier,
278 Reduce {
280 kind: ReduceKind,
282 input: String,
284 output: String,
286 },
287}
288
/// Reduction operator used by `TileOp::Reduce`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceKind {
    /// Sum of all elements.
    Sum,
    /// Maximum element.
    Max,
    /// Minimum element.
    Min,
    /// Arithmetic mean. It combines like `Sum`; this type does not perform the
    /// final division by the element count.
    Mean,
}

impl ReduceKind {
    /// WGSL float literal for the identity element of this reduction.
    ///
    /// `Max` starts from the lowest finite f32 and `Min` from the highest.
    /// Both are written as a decimal that round-trips exactly to
    /// `f32::MIN`/`f32::MAX`, so the value is also exactly representable when
    /// a WGSL abstract float is converted to `f32`.
    #[must_use]
    pub fn identity(&self) -> &'static str {
        // A rounded literal such as 3.402823e+38 lies *inside* the f32 range,
        // so inputs beyond it would lose to the identity. Use the exact bound.
        match self {
            Self::Sum | Self::Mean => "0.0",
            Self::Max => "-3.4028234663852886e+38",
            Self::Min => "3.4028234663852886e+38",
        }
    }

    /// Combining operation: the infix operator `+` for `Sum`/`Mean`, or the
    /// WGSL builtin function name `max`/`min`.
    #[must_use]
    pub fn combine_op(&self) -> &'static str {
        match self {
            Self::Sum | Self::Mean => "+",
            Self::Max => "max",
            Self::Min => "min",
        }
    }
}
323
324#[derive(Debug, Clone)]
326pub struct ComputeBrick {
327 name: String,
329 workgroup_size: (u32, u32, u32),
331 inputs: Vec<TensorBinding>,
333 outputs: Vec<TensorBinding>,
335 tile_strategy: TileStrategy,
337 operations: Vec<TileOp>,
339 shared_memory: Vec<(String, TensorType, u32)>,
341}
342
343impl ComputeBrick {
344 #[must_use]
346 pub fn new(name: impl Into<String>) -> Self {
347 Self {
348 name: name.into(),
349 workgroup_size: (64, 1, 1),
350 inputs: Vec::new(),
351 outputs: Vec::new(),
352 tile_strategy: TileStrategy::None,
353 operations: Vec::new(),
354 shared_memory: Vec::new(),
355 }
356 }
357
358 #[must_use]
360 pub fn workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
361 self.workgroup_size = (x, y, z);
362 self
363 }
364
365 #[must_use]
367 pub fn input(mut self, name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
368 let binding_idx = self.inputs.len() as u32;
369 self.inputs
370 .push(TensorBinding::new(name, dtype, shape).at(0, binding_idx));
371 self
372 }
373
374 #[must_use]
376 pub fn output(mut self, name: impl Into<String>, dtype: TensorType, shape: &[u32]) -> Self {
377 let binding_idx = self.outputs.len() as u32;
378 self.outputs.push(
379 TensorBinding::new(name, dtype, shape)
380 .at(1, binding_idx)
381 .writable(),
382 );
383 self
384 }
385
386 #[must_use]
388 pub fn tile_strategy(mut self, strategy: TileStrategy) -> Self {
389 self.tile_strategy = strategy;
390 self
391 }
392
393 #[must_use]
395 pub fn op(mut self, operation: TileOp) -> Self {
396 self.operations.push(operation);
397 self
398 }
399
400 #[must_use]
402 pub fn shared(mut self, name: impl Into<String>, dtype: TensorType, size: u32) -> Self {
403 self.shared_memory.push((name.into(), dtype, size));
404 self
405 }
406
407 #[must_use]
409 pub fn to_wgsl(&self) -> String {
410 let mut wgsl = String::new();
411
412 wgsl.push_str(&format!(
414 "// {} Compute Shader\n",
415 to_pascal_case(&self.name)
416 ));
417 wgsl.push_str("// Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");
418
419 for input in &self.inputs {
421 wgsl.push_str(&input.to_wgsl_binding());
422 wgsl.push('\n');
423 }
424
425 for output in &self.outputs {
427 wgsl.push_str(&output.to_wgsl_binding());
428 wgsl.push('\n');
429 }
430
431 wgsl.push('\n');
432
433 for (name, dtype, size) in &self.shared_memory {
435 wgsl.push_str(&format!(
436 "var<workgroup> {}: array<{}, {}>;\n",
437 name,
438 dtype.to_wgsl(),
439 size
440 ));
441 }
442
443 if !self.shared_memory.is_empty() {
444 wgsl.push('\n');
445 }
446
447 let (wg_x, wg_y, wg_z) = self.workgroup_size;
449 wgsl.push_str(&format!(
450 "@compute @workgroup_size({}, {}, {})\n",
451 wg_x, wg_y, wg_z
452 ));
453 wgsl.push_str("fn main(\n");
454 wgsl.push_str(" @builtin(global_invocation_id) global_id: vec3<u32>,\n");
455 wgsl.push_str(" @builtin(local_invocation_id) local_id: vec3<u32>,\n");
456 wgsl.push_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n");
457 wgsl.push_str(") {\n");
458
459 wgsl.push_str(" let gid = global_id.x + global_id.y * ");
461 wgsl.push_str(&format!("{}u;\n", wg_x));
462 wgsl.push_str(" let lid = local_id.x + local_id.y * ");
463 wgsl.push_str(&format!("{}u;\n\n", wg_x));
464
465 for op in &self.operations {
467 match op {
468 TileOp::LoadShared { src, tile_size: _ } => {
469 wgsl.push_str(&format!(" // Load from {} to shared memory\n", src));
470 wgsl.push_str(&format!(" let val_{} = {}[gid];\n", src, src));
471 }
472 TileOp::Elementwise {
473 op: elem_op,
474 operands,
475 output,
476 } => {
477 let input = &operands[0];
478 let out_name = output.as_ref().unwrap_or(input);
479 let input_val = format!("val_{}", input);
480 let expr = elem_op.to_wgsl_expr(&input_val);
481 wgsl.push_str(&format!(" let val_{} = {};\n", out_name, expr));
482 }
483 TileOp::StoreShared { dst } => {
484 wgsl.push_str(&format!(" // Store to {}\n", dst));
485 let val_name = if self.operations.iter().any(
487 |o| matches!(o, TileOp::Elementwise { output: Some(n), .. } if n == dst),
488 ) {
489 format!("val_{}", dst)
490 } else if let Some(input) = self.inputs.first() {
491 format!("val_{}", input.name)
492 } else {
493 "0.0".to_string()
494 };
495 wgsl.push_str(&format!(" {}[gid] = {};\n", dst, val_name));
496 }
497 TileOp::Barrier => {
498 wgsl.push_str(" workgroupBarrier();\n");
499 }
500 TileOp::Mma { a, b, c } => {
501 wgsl.push_str(&format!(" // Matrix multiply: {} = {} @ {}\n", c, a, b));
502 wgsl.push_str(" // TODO: Implement cooperative matrix\n");
503 }
504 TileOp::Reduce {
505 kind,
506 input,
507 output,
508 } => {
509 wgsl.push_str(&format!(
510 " // Reduce {} -> {} ({:?})\n",
511 input, output, kind
512 ));
513 }
514 }
515 }
516
517 wgsl.push_str("}\n");
518
519 wgsl
520 }
521
522 #[must_use]
524 pub fn to_rust_bindings(&self) -> String {
525 let mut rust = String::new();
526
527 rust.push_str(&format!(
529 "//! {} Compute Bindings\n",
530 to_pascal_case(&self.name)
531 ));
532 rust.push_str("//! Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");
533 rust.push_str(
534 "use wgpu::{BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry};\n",
535 );
536 rust.push_str("use wgpu::{ShaderStages, BufferBindingType, BindingType};\n\n");
537
538 let struct_name = to_pascal_case(&self.name);
539
540 rust.push_str(&format!("pub struct {}Compute {{\n", struct_name));
542 rust.push_str(" pub pipeline: wgpu::ComputePipeline,\n");
543 rust.push_str(" pub bind_group_layout: wgpu::BindGroupLayout,\n");
544 rust.push_str("}\n\n");
545
546 rust.push_str(&format!("impl {}Compute {{\n", struct_name));
548 rust.push_str(" pub const WORKGROUP_SIZE: (u32, u32, u32) = ");
549 rust.push_str(&format!("{:?};\n\n", self.workgroup_size));
550
551 rust.push_str(" pub const SHADER_SOURCE: &'static str = r#\"\n");
553 rust.push_str(&self.to_wgsl());
554 rust.push_str("\"#;\n\n");
555
556 rust.push_str(
558 " pub fn create_bind_group_layout(device: &wgpu::Device) -> BindGroupLayout {\n",
559 );
560 rust.push_str(" device.create_bind_group_layout(&BindGroupLayoutDescriptor {\n");
561 rust.push_str(&format!(
562 " label: Some(\"{} bind group layout\"),\n",
563 self.name
564 ));
565 rust.push_str(" entries: &[\n");
566
567 for input in &self.inputs {
568 rust.push_str(&format!(" // Input: {}\n", input.name));
569 rust.push_str(&format!(
570 " BindGroupLayoutEntry {{\n binding: {},\n visibility: ShaderStages::COMPUTE,\n ty: BindingType::Buffer {{\n ty: BufferBindingType::Storage {{ read_only: true }},\n has_dynamic_offset: false,\n min_binding_size: None,\n }},\n count: None,\n }},\n",
571 input.binding
572 ));
573 }
574
575 for output in &self.outputs {
576 rust.push_str(&format!(" // Output: {}\n", output.name));
577 rust.push_str(&format!(
578 " BindGroupLayoutEntry {{\n binding: {},\n visibility: ShaderStages::COMPUTE,\n ty: BindingType::Buffer {{\n ty: BufferBindingType::Storage {{ read_only: false }},\n has_dynamic_offset: false,\n min_binding_size: None,\n }},\n count: None,\n }},\n",
579 output.binding
580 ));
581 }
582
583 rust.push_str(" ],\n");
584 rust.push_str(" })\n");
585 rust.push_str(" }\n");
586 rust.push_str("}\n");
587
588 rust
589 }
590
591 #[must_use]
593 pub fn to_dispatch_js(&self) -> String {
594 let mut js = String::new();
595
596 js.push_str(&format!(
597 "// {} Compute Dispatch\n",
598 to_pascal_case(&self.name)
599 ));
600 js.push_str("// Generated by probar ComputeBrick - DO NOT EDIT MANUALLY\n\n");
601
602 let (wg_x, wg_y, wg_z) = self.workgroup_size;
603 js.push_str(&format!(
604 "const WORKGROUP_SIZE = [{}, {}, {}];\n\n",
605 wg_x, wg_y, wg_z
606 ));
607
608 js.push_str(&format!(
609 "async function dispatch{}(device, inputs, outputs) {{\n",
610 to_pascal_case(&self.name)
611 ));
612
613 js.push_str(" // Create shader module\n");
614 js.push_str(" const shaderModule = device.createShaderModule({\n");
615 js.push_str(&format!(" label: '{} shader',\n", self.name));
616 js.push_str(" code: SHADER_SOURCE,\n");
617 js.push_str(" });\n\n");
618
619 js.push_str(" // Calculate dispatch size\n");
620 if let Some(output) = self.outputs.first() {
621 let total_size = output.element_count();
622 js.push_str(&format!(" const totalElements = {};\n", total_size));
623 js.push_str(&format!(
624 " const numWorkgroups = Math.ceil(totalElements / {});\n\n",
625 wg_x * wg_y * wg_z
626 ));
627 }
628
629 js.push_str(" // Dispatch\n");
630 js.push_str(" const commandEncoder = device.createCommandEncoder();\n");
631 js.push_str(" const passEncoder = commandEncoder.beginComputePass();\n");
632 js.push_str(" passEncoder.setPipeline(pipeline);\n");
633 js.push_str(" passEncoder.setBindGroup(0, bindGroup);\n");
634 js.push_str(" passEncoder.dispatchWorkgroups(numWorkgroups, 1, 1);\n");
635 js.push_str(" passEncoder.end();\n");
636 js.push_str(" device.queue.submit([commandEncoder.finish()]);\n");
637 js.push_str("}\n");
638
639 js
640 }
641
642 #[must_use]
644 pub fn name(&self) -> &str {
645 &self.name
646 }
647
648 #[must_use]
650 pub fn get_workgroup_size(&self) -> (u32, u32, u32) {
651 self.workgroup_size
652 }
653
654 #[must_use]
656 pub fn inputs(&self) -> &[TensorBinding] {
657 &self.inputs
658 }
659
660 #[must_use]
662 pub fn outputs(&self) -> &[TensorBinding] {
663 &self.outputs
664 }
665}
666
667impl Brick for ComputeBrick {
668 fn brick_name(&self) -> &'static str {
669 "ComputeBrick"
670 }
671
672 fn assertions(&self) -> &[BrickAssertion] {
673 &[]
674 }
675
676 fn budget(&self) -> BrickBudget {
677 BrickBudget::uniform(100)
679 }
680
681 fn verify(&self) -> BrickVerification {
682 let mut passed = Vec::new();
683 let mut failed = Vec::new();
684
685 let (x, y, z) = self.workgroup_size;
687 if x * y * z > 1024 {
688 failed.push((
689 BrickAssertion::Custom {
690 name: "workgroup_size_valid".into(),
691 validator_id: 1,
692 },
693 format!(
694 "Workgroup size {}x{}x{}={} exceeds maximum 1024",
695 x,
696 y,
697 z,
698 x * y * z
699 ),
700 ));
701 } else {
702 passed.push(BrickAssertion::Custom {
703 name: "workgroup_size_valid".into(),
704 validator_id: 1,
705 });
706 }
707
708 if self.inputs.is_empty() {
710 failed.push((
711 BrickAssertion::Custom {
712 name: "has_inputs".into(),
713 validator_id: 2,
714 },
715 "ComputeBrick has no input tensors".into(),
716 ));
717 } else {
718 passed.push(BrickAssertion::Custom {
719 name: "has_inputs".into(),
720 validator_id: 2,
721 });
722 }
723
724 if self.outputs.is_empty() {
725 failed.push((
726 BrickAssertion::Custom {
727 name: "has_outputs".into(),
728 validator_id: 3,
729 },
730 "ComputeBrick has no output tensors".into(),
731 ));
732 } else {
733 passed.push(BrickAssertion::Custom {
734 name: "has_outputs".into(),
735 validator_id: 3,
736 });
737 }
738
739 let tensor_names: Vec<_> = self
741 .inputs
742 .iter()
743 .chain(self.outputs.iter())
744 .map(|t| t.name.as_str())
745 .collect();
746
747 for op in &self.operations {
748 match op {
749 TileOp::LoadShared { src, .. } => {
750 if !tensor_names.contains(&src.as_str()) {
751 failed.push((
752 BrickAssertion::Custom {
753 name: "tensor_exists".into(),
754 validator_id: 4,
755 },
756 format!("LoadShared references unknown tensor: {}", src),
757 ));
758 }
759 }
760 TileOp::StoreShared { dst } => {
761 if !tensor_names.contains(&dst.as_str()) {
762 failed.push((
763 BrickAssertion::Custom {
764 name: "tensor_exists".into(),
765 validator_id: 4,
766 },
767 format!("StoreShared references unknown tensor: {}", dst),
768 ));
769 }
770 }
771 _ => {}
772 }
773 }
774
775 if failed.is_empty() {
776 passed.push(BrickAssertion::Custom {
777 name: "compute_brick_valid".into(),
778 validator_id: 5,
779 });
780 }
781
782 BrickVerification {
783 passed,
784 failed,
785 verification_time: Duration::from_micros(100),
786 }
787 }
788
789 fn to_html(&self) -> String {
790 String::new()
792 }
793
794 fn to_css(&self) -> String {
795 String::new()
797 }
798}
799
/// Converts `s` to PascalCase for use in generated Rust/JS identifiers.
///
/// Every character that is not alphanumeric acts as a word separator and is
/// dropped. The character following a separator (and the first character) is
/// ASCII-uppercased, and all other characters are kept as-is. For example,
/// `"mel-transform"` becomes `"MelTransform"` and `"my.kernel"` becomes
/// `"MyKernel"`.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;

    for c in s.chars() {
        // Any punctuation (not just `_`, `-`, ` `) would otherwise leak into
        // the generated identifiers and make them invalid.
        if !c.is_alphanumeric() {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(c.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }

    result
}
818
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture: a "test" brick with one f32[1024] input and one f32[1024] output.
    fn passthrough(x: u32, y: u32, z: u32) -> ComputeBrick {
        ComputeBrick::new("test")
            .workgroup_size(x, y, z)
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024])
    }

    // Asserts that every needle occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(*needle),
                "expected {:?} in generated output",
                needle
            );
        }
    }

    #[test]
    fn test_compute_brick_basic() {
        let kernel = ComputeBrick::new("test")
            .workgroup_size(256, 1, 1)
            .input("audio", TensorType::F32, &[1024])
            .output("mel", TensorType::F32, &[80, 100]);

        assert_eq!(kernel.name(), "test");
        assert_eq!(kernel.get_workgroup_size(), (256, 1, 1));
        assert_eq!(kernel.inputs().len(), 1);
        assert_eq!(kernel.outputs().len(), 1);
    }

    #[test]
    fn test_compute_brick_wgsl_generation() {
        let shader = ComputeBrick::new("log-transform")
            .workgroup_size(64, 1, 1)
            .input("input", TensorType::F32, &[1024])
            .output("output", TensorType::F32, &[1024])
            .op(TileOp::LoadShared {
                src: "input".into(),
                tile_size: (64, 1),
            })
            .op(TileOp::Elementwise {
                op: ElementwiseOp::Log,
                operands: vec!["input".into()],
                output: Some("output".into()),
            })
            .op(TileOp::StoreShared {
                dst: "output".into(),
            })
            .to_wgsl();

        assert_contains_all(
            &shader,
            &[
                "@compute @workgroup_size(64, 1, 1)",
                "fn main(",
                "log(",
                "Generated by probar",
            ],
        );
    }

    #[test]
    fn test_compute_brick_verification() {
        assert!(passthrough(256, 1, 1).verify().is_valid());
    }

    #[test]
    fn test_compute_brick_verification_fails_no_inputs() {
        let outputs_only = ComputeBrick::new("test")
            .workgroup_size(256, 1, 1)
            .output("output", TensorType::F32, &[1024]);

        assert!(!outputs_only.verify().is_valid());
    }

    #[test]
    fn test_compute_brick_verification_fails_large_workgroup() {
        // 1024 * 2 * 1 = 2048 invocations, above the 1024 limit.
        assert!(!passthrough(1024, 2, 1).verify().is_valid());
    }

    #[test]
    fn test_tensor_binding() {
        let audio = TensorBinding::new("audio", TensorType::F32, &[1024, 80])
            .at(0, 1)
            .writable();

        assert_eq!(audio.name, "audio");
        assert_eq!(audio.element_count(), 1024 * 80);
        assert_eq!(audio.byte_size(), 1024 * 80 * 4);
        assert!(!audio.read_only);
    }

    #[test]
    fn test_tensor_type_wgsl() {
        let cases = [
            (TensorType::F32, "f32"),
            (TensorType::F16, "f16"),
            (TensorType::I32, "i32"),
            (TensorType::U32, "u32"),
        ];
        for (dtype, spelled) in cases.iter() {
            assert_eq!(dtype.to_wgsl(), *spelled);
        }
    }

    #[test]
    fn test_elementwise_ops() {
        let cases = [
            (ElementwiseOp::Log, "log(x)"),
            (ElementwiseOp::Exp, "exp(x)"),
            (ElementwiseOp::Relu, "max(x, 0.0)"),
            (ElementwiseOp::AddScalar(5), "(x + 5.0)"),
        ];
        for (op, expected) in cases.iter() {
            assert_eq!(op.to_wgsl_expr("x"), *expected);
        }
    }

    #[test]
    fn test_rust_bindings_generation() {
        let generated = ComputeBrick::new("mel-transform")
            .workgroup_size(256, 1, 1)
            .input("audio", TensorType::F32, &[1024])
            .output("mel", TensorType::F32, &[80])
            .to_rust_bindings();

        assert_contains_all(
            &generated,
            &[
                "pub struct MelTransformCompute",
                "WORKGROUP_SIZE",
                "SHADER_SOURCE",
                "create_bind_group_layout",
            ],
        );
    }

    #[test]
    fn test_js_dispatch_generation() {
        let script = ComputeBrick::new("fft")
            .workgroup_size(64, 1, 1)
            .input("signal", TensorType::F32, &[512])
            .output("spectrum", TensorType::F32, &[512])
            .to_dispatch_js();

        assert_contains_all(
            &script,
            &["async function dispatchFft", "WORKGROUP_SIZE", "dispatchWorkgroups"],
        );
    }

    #[test]
    fn test_tile_strategy_workgroup_size() {
        let cases = [
            (
                TileStrategy::Simple2D {
                    tile_x: 16,
                    tile_y: 16,
                },
                (16, 16, 1),
            ),
            (TileStrategy::Cooperative { m: 8, n: 8, k: 4 }, (8, 8, 1)),
            (TileStrategy::Streaming { window: 32 }, (32, 1, 1)),
        ];
        for (strategy, expected) in cases.iter() {
            assert_eq!(strategy.optimal_workgroup_size(), *expected);
        }
    }
}