1use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};
21use vyre_foundation::ir::model::expr::GeneratorRef;
22
23use crate::tensor_ref::{TensorRef, TensorRefError};
24
/// Op id attached to the child region that `build_indexed_map` emits.
pub(crate) const INDEXED_MAP_OP_ID: &str = "vyre-libs::substrate::indexed_map";
/// Op id attached to the child regions emitted by `strided_accumulate_child` and
/// `strided_accumulate2_child`.
pub(crate) const STRIDED_ACCUMULATE_OP_ID: &str = "vyre-libs::substrate::strided_accumulate";
/// Op id attached to the child region emitted by `strided_writeback_child`.
// NOTE(review): unlike the two ids above, this one carries an `anonymous::` prefix.
// Nothing in this file shows why — confirm the prefix is intentional.
pub(crate) const STRIDED_WRITEBACK_OP_ID: &str = "anonymous::vyre-libs::substrate::strided_writeback";
35
/// Optional knobs shared by the program builders.
///
/// Every field starts out as `None`. Values are set through the chainable
/// `with_*` methods, each of which consumes the options and returns the
/// updated copy.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct BuildOptions {
    /// Workgroup size override. The elementwise builders in this file fall
    /// back to `[64, 1, 1]` when this is `None`.
    pub workgroup_size: Option<[u32; 3]>,
    /// Region generator name override.
    // NOTE(review): no builder visible in this file reads this field.
    pub region_generator: Option<&'static str>,
    /// Tenant id override.
    // NOTE(review): no builder visible in this file reads this field.
    pub tenant_id: Option<u32>,
}

impl BuildOptions {
    /// Creates options with every knob unset (same value as `Default`).
    #[must_use]
    pub fn new() -> Self {
        Self {
            workgroup_size: None,
            region_generator: None,
            tenant_id: None,
        }
    }

    /// Returns the options with `workgroup_size` set to `size`.
    #[must_use]
    pub fn with_workgroup_size(self, size: [u32; 3]) -> Self {
        Self {
            workgroup_size: Some(size),
            ..self
        }
    }

    /// Returns the options with `region_generator` set to `name`.
    #[must_use]
    pub fn with_region_generator(self, name: &'static str) -> Self {
        Self {
            region_generator: Some(name),
            ..self
        }
    }

    /// Returns the options with `tenant_id` set to `tenant_id`.
    #[must_use]
    pub fn with_tenant_id(self, tenant_id: u32) -> Self {
        Self {
            tenant_id: Some(tenant_id),
            ..self
        }
    }
}
81
/// Generates the shared `with_*` option setters on a builder type.
///
/// The target type must have a field named `options` whose type offers
/// `with_workgroup_size`, `with_region_generator` and `with_tenant_id`
/// methods that consume and return the options value. Each generated method
/// forwards to the matching method on `options` and leaves every other field
/// of the builder untouched.
macro_rules! impl_cat_a_builder_options {
    ($builder:ident) => {
        impl $builder {
            /// Sets the workgroup size on the builder's options.
            #[must_use]
            pub fn with_workgroup_size(self, size: [u32; 3]) -> Self {
                let options = self.options.with_workgroup_size(size);
                Self { options, ..self }
            }

            /// Sets the region generator name on the builder's options.
            #[must_use]
            pub fn with_region_generator(self, name: &'static str) -> Self {
                let options = self.options.with_region_generator(name);
                Self { options, ..self }
            }

            /// Sets the tenant id on the builder's options.
            #[must_use]
            pub fn with_tenant_id(self, tenant_id: u32) -> Self {
                let options = self.options.with_tenant_id(tenant_id);
                Self { options, ..self }
            }
        }
    };
}
108
109pub(crate) use impl_cat_a_builder_options;
110
111pub fn check_tensors(
115 op: &'static str,
116 tensors: &[(&TensorRef, DataType)],
117) -> Result<(), TensorRefError> {
118 for (r, expected) in tensors {
120 crate::tensor_ref::check_dtype(r, expected.clone(), op)?;
121 if r.element_count().is_none() {
122 return Err(TensorRefError::ElementCountOverflow {
123 name: r.name.as_str().to_string(),
124 shape: r.shape.to_vec(),
125 });
126 }
127 }
128 for (idx, (left, _)) in tensors.iter().enumerate() {
129 for (right, _) in &tensors[idx + 1..] {
130 if left.name_str() == right.name_str() {
131 return Err(TensorRefError::NameCollision {
132 name: left.name.as_str().to_string(),
133 op,
134 });
135 }
136 }
137 }
138 Ok(())
139}
140
/// Exercises the setters generated by `impl_cat_a_builder_options!` on a
/// throwaway builder type.
#[cfg(test)]
mod cat_a_builder_option_macro_tests {
    #![allow(unreachable_pub)]

    use super::BuildOptions;

    /// Smallest possible builder: nothing but the `options` field the macro needs.
    #[derive(Debug, Clone)]
    struct DemoBuilder {
        options: BuildOptions,
    }

    impl DemoBuilder {
        fn new() -> Self {
            let options = BuildOptions::default();
            Self { options }
        }
    }

    super::impl_cat_a_builder_options!(DemoBuilder);

    #[test]
    fn generated_option_surface_threads_every_shared_knob() {
        let configured = DemoBuilder::new()
            .with_workgroup_size([8, 4, 2])
            .with_region_generator("custom::generator")
            .with_tenant_id(17);
        let options = configured.options;

        assert_eq!(options.workgroup_size, Some([8, 4, 2]));
        assert_eq!(options.region_generator, Some("custom::generator"));
        assert_eq!(options.tenant_id, Some(17));
    }
}
174
175pub(crate) fn build_indexed_map<F>(
181 op_id: &'static str,
182 buffers: Vec<BufferDecl>,
183 output: &str,
184 count: u32,
185 workgroup_size: [u32; 3],
186 f: F,
187) -> Program
188where
189 F: FnOnce(Expr) -> (Expr, Expr),
190{
191 let i = Expr::var("i");
192 let (dst_index, value) = f(i.clone());
193 let child_body = vec![
194 Node::let_bind("i", Expr::InvocationId { axis: 0 }),
195 Node::if_then(
196 Expr::lt(i, Expr::u32(count)),
197 vec![Node::store(output, dst_index, value)],
198 ),
199 ];
200 let parent = GeneratorRef {
201 name: op_id.to_string(),
202 };
203
204 Program::wrapped(
205 buffers,
206 workgroup_size,
207 vec![crate::region::wrap_anonymous(
208 op_id,
209 vec![crate::region::wrap_child(
210 INDEXED_MAP_OP_ID,
211 parent,
212 child_body,
213 )],
214 )],
215 )
216}
217
218pub(crate) fn strided_accumulate_child<F>(
224 parent_op_id: &'static str,
225 tile: u32,
226 chunks: u32,
227 n: u32,
228 acc_name: &'static str,
229 initial: Expr,
230 scratch: &'static str,
231 step: F,
232) -> Node
233where
234 F: Fn(Expr, Expr) -> Expr,
235{
236 let local = Expr::var("local");
237 let idx = Expr::var("idx");
238 let acc = Expr::var(acc_name);
239 let child_body = vec![Node::if_then(
240 Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
241 vec![
242 Node::let_bind(acc_name, initial),
243 strided_loop(
244 tile,
245 chunks,
246 n,
247 vec![Node::assign(acc_name, step(idx, acc))],
248 ),
249 Node::store(scratch, local, Expr::var(acc_name)),
250 ],
251 )];
252
253 child_region(parent_op_id, STRIDED_ACCUMULATE_OP_ID, child_body)
254}
255
256#[allow(dead_code)]
261pub(crate) fn strided_accumulate2_child<F1, F2>(
262 parent_op_id: &'static str,
263 tile: u32,
264 chunks: u32,
265 n: u32,
266 first: (&'static str, Expr, &'static str, F1),
267 second: (&'static str, Expr, &'static str, F2),
268) -> Node
269where
270 F1: Fn(Expr, Expr) -> Expr,
271 F2: Fn(Expr, Expr) -> Expr,
272{
273 let (first_name, first_initial, first_scratch, first_step) = first;
274 let (second_name, second_initial, second_scratch, second_step) = second;
275 let local = Expr::var("local");
276 let idx = Expr::var("idx");
277 let child_body = vec![Node::if_then(
278 Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
279 vec![
280 Node::let_bind(first_name, first_initial),
281 Node::let_bind(second_name, second_initial),
282 strided_loop(
283 tile,
284 chunks,
285 n,
286 vec![
287 Node::assign(first_name, first_step(idx.clone(), Expr::var(first_name))),
288 Node::assign(second_name, second_step(idx, Expr::var(second_name))),
289 ],
290 ),
291 Node::store(first_scratch, local.clone(), Expr::var(first_name)),
292 Node::store(second_scratch, local, Expr::var(second_name)),
293 ],
294 )];
295
296 child_region(parent_op_id, STRIDED_ACCUMULATE_OP_ID, child_body)
297}
298
299pub(crate) fn strided_writeback_child<F>(
305 parent_op_id: &'static str,
306 tile: u32,
307 chunks: u32,
308 n: u32,
309 output: &str,
310 prelude: Vec<Node>,
311 value: F,
312) -> Node
313where
314 F: Fn(Expr) -> Expr,
315{
316 let idx = Expr::var("idx");
317 let mut guarded = prelude;
318 guarded.push(strided_loop(
319 tile,
320 chunks,
321 n,
322 vec![Node::store(output, idx.clone(), value(idx))],
323 ));
324 child_region(
325 parent_op_id,
326 STRIDED_WRITEBACK_OP_ID,
327 vec![Node::if_then(
328 Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
329 guarded,
330 )],
331 )
332}
333
334fn strided_loop(tile: u32, chunks: u32, n: u32, guarded_body: Vec<Node>) -> Node {
335 Node::loop_for(
336 "chunk",
337 Expr::u32(0),
338 Expr::u32(chunks),
339 vec![
340 Node::let_bind(
341 "idx",
342 Expr::add(
343 Expr::mul(Expr::var("chunk"), Expr::u32(tile)),
344 Expr::var("local"),
345 ),
346 ),
347 Node::if_then(Expr::lt(Expr::var("idx"), Expr::u32(n)), guarded_body),
348 ],
349 )
350}
351
352fn child_region(parent_op_id: &'static str, child_op_id: &'static str, body: Vec<Node>) -> Node {
353 crate::region::wrap_child(
354 child_op_id,
355 GeneratorRef {
356 name: parent_op_id.to_string(),
357 },
358 body,
359 )
360}
361
362#[allow(dead_code)]
368pub(crate) fn invalid_output_program(
369 op_id: &'static str,
370 output: &str,
371 data_type: DataType,
372 message: String,
373) -> Program {
374 Program::wrapped(
375 vec![BufferDecl::output(output, 0, data_type).with_count(1)],
376 [1, 1, 1],
377 vec![crate::region::wrap_anonymous(
378 op_id,
379 vec![Node::trap(Expr::u32(0), message)],
380 )],
381 )
382}
383
384#[allow(dead_code)]
387pub(crate) fn build_elementwise_binary<F>(
388 op_id: &'static str,
389 a: crate::tensor_ref::TensorRef,
390 b: crate::tensor_ref::TensorRef,
391 out: crate::tensor_ref::TensorRef,
392 options: BuildOptions,
393 f: F,
394) -> Result<vyre::ir::Program, crate::tensor_ref::TensorRefError>
395where
396 F: Fn(vyre::ir::Expr, vyre::ir::Expr) -> vyre::ir::Expr,
397{
398 check_tensors(
399 op_id,
400 &[
401 (&a, vyre::ir::DataType::U32),
402 (&b, vyre::ir::DataType::U32),
403 (&out, vyre::ir::DataType::U32),
404 ],
405 )?;
406
407 if a.shape != b.shape || a.shape != out.shape {
408 return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
409 name: "elementwise_binary".into(),
410 found: vec![],
411 expected: vec![],
412 op: op_id,
413 });
414 }
415
416 let a_count = a.element_count().ok_or_else(|| {
417 crate::tensor_ref::TensorRefError::ElementCountOverflow {
418 name: a.name_str().to_string(),
419 shape: a.shape.to_vec(),
420 }
421 })?;
422 let out_count = out.element_count().ok_or_else(|| {
423 crate::tensor_ref::TensorRefError::ElementCountOverflow {
424 name: out.name_str().to_string(),
425 shape: out.shape.to_vec(),
426 }
427 })?;
428 if out_count < a_count {
429 return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
430 name: out.name_str().to_string(),
431 found: out.shape.to_vec(),
432 expected: a.shape.to_vec(),
433 op: op_id,
434 });
435 }
436
437 let n = a_count;
438 let body = vec![
439 vyre::ir::Node::let_bind("idx", vyre::ir::Expr::InvocationId { axis: 0 }),
440 vyre::ir::Node::if_then(
441 vyre::ir::Expr::lt(vyre::ir::Expr::var("idx"), vyre::ir::Expr::u32(n)),
442 vec![vyre::ir::Node::store(
443 out.name_str(),
444 vyre::ir::Expr::var("idx"),
445 f(
446 vyre::ir::Expr::load(a.name_str(), vyre::ir::Expr::var("idx")),
447 vyre::ir::Expr::load(b.name_str(), vyre::ir::Expr::var("idx")),
448 ),
449 )],
450 ),
451 ];
452
453 let group = options.workgroup_size.unwrap_or([64, 1, 1]);
454
455 Ok(vyre::ir::Program::wrapped(
456 vec![
457 vyre::ir::BufferDecl::storage(
458 a.name_str(),
459 0,
460 vyre::ir::BufferAccess::ReadOnly,
461 vyre::ir::DataType::U32,
462 )
463 .with_count(n),
464 vyre::ir::BufferDecl::storage(
465 b.name_str(),
466 1,
467 vyre::ir::BufferAccess::ReadOnly,
468 vyre::ir::DataType::U32,
469 )
470 .with_count(n),
471 vyre::ir::BufferDecl::output(out.name_str(), 2, vyre::ir::DataType::U32).with_count(n),
472 ],
473 group,
474 vec![crate::region::wrap_anonymous(op_id, body)],
475 ))
476}
477
478#[allow(dead_code)]
479pub(crate) fn build_elementwise_unary<F>(
480 op_id: &'static str,
481 a: crate::tensor_ref::TensorRef,
482 out: crate::tensor_ref::TensorRef,
483 options: BuildOptions,
484 f: F,
485) -> Result<vyre::ir::Program, crate::tensor_ref::TensorRefError>
486where
487 F: Fn(vyre::ir::Expr) -> vyre::ir::Expr,
488{
489 check_tensors(
490 op_id,
491 &[
492 (&a, vyre::ir::DataType::U32),
493 (&out, vyre::ir::DataType::U32),
494 ],
495 )?;
496
497 if a.shape != out.shape {
498 return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
499 name: "elementwise_unary".into(),
500 found: vec![],
501 expected: vec![],
502 op: op_id,
503 });
504 }
505
506 let n = a.element_count().ok_or_else(|| {
507 crate::tensor_ref::TensorRefError::ElementCountOverflow {
508 name: a.name_str().to_string(),
509 shape: a.shape.to_vec(),
510 }
511 })?;
512 let body = vec![
513 vyre::ir::Node::let_bind("idx", vyre::ir::Expr::InvocationId { axis: 0 }),
514 vyre::ir::Node::if_then(
515 vyre::ir::Expr::lt(vyre::ir::Expr::var("idx"), vyre::ir::Expr::u32(n)),
516 vec![vyre::ir::Node::store(
517 out.name_str(),
518 vyre::ir::Expr::var("idx"),
519 f(vyre::ir::Expr::load(
520 a.name_str(),
521 vyre::ir::Expr::var("idx"),
522 )),
523 )],
524 ),
525 ];
526
527 let group = options.workgroup_size.unwrap_or([64, 1, 1]);
528
529 Ok(vyre::ir::Program::wrapped(
530 vec![
531 vyre::ir::BufferDecl::storage(
532 a.name_str(),
533 0,
534 vyre::ir::BufferAccess::ReadOnly,
535 vyre::ir::DataType::U32,
536 )
537 .with_count(n),
538 vyre::ir::BufferDecl::output(out.name_str(), 1, vyre::ir::DataType::U32).with_count(n),
539 ],
540 group,
541 vec![crate::region::wrap_anonymous(op_id, body)],
542 ))
543}
544
/// Unit tests for the option struct, tensor validation and the shared
/// child-region builders.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_options_defaults_are_all_none() {
        let defaults = BuildOptions::default();
        assert_eq!(defaults.workgroup_size, None);
        assert_eq!(defaults.region_generator, None);
        assert_eq!(defaults.tenant_id, None);
    }

    #[test]
    fn build_options_chain_preserves_earlier_setters() {
        let chained = BuildOptions::new()
            .with_workgroup_size([128, 1, 1])
            .with_region_generator("test::op")
            .with_tenant_id(7);
        assert_eq!(chained.workgroup_size, Some([128, 1, 1]));
        assert_eq!(chained.region_generator, Some("test::op"));
        assert_eq!(chained.tenant_id, Some(7));
    }

    #[test]
    fn check_tensors_passes_on_clean_inputs() {
        let a = TensorRef::u32_1d("a", 4);
        let b = TensorRef::u32_1d("b", 4);
        let outcome = check_tensors("op", &[(&a, DataType::U32), (&b, DataType::U32)]);
        assert!(matches!(outcome, Ok(())));
    }

    #[test]
    fn check_tensors_catches_dtype_mismatch() {
        let a = TensorRef::u32_1d("a", 4);
        let outcome = check_tensors("op", &[(&a, DataType::F32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::DtypeMismatch { .. })
        ));
    }

    #[test]
    fn check_tensors_catches_overflow() {
        // 2^20 * 2^20 elements cannot be counted in 32 bits.
        let big = TensorRef::new("big", DataType::U32, vec![1u32 << 20, 1u32 << 20]);
        let outcome = check_tensors("op", &[(&big, DataType::U32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::ElementCountOverflow { .. })
        ));
    }

    #[test]
    fn check_tensors_catches_name_collision() {
        let first = TensorRef::u32_1d("x", 4);
        let second = TensorRef::u32_1d("x", 4);
        let outcome = check_tensors(
            "op",
            &[(&first, DataType::U32), (&second, DataType::U32)],
        );
        assert!(matches!(
            outcome,
            Err(TensorRefError::NameCollision { .. })
        ));
    }

    #[test]
    fn indexed_map_builder_emits_shared_child_region() {
        let input = BufferDecl::storage("input", 0, vyre::ir::BufferAccess::ReadOnly, DataType::U32)
            .with_count(4);
        let output = BufferDecl::output("output", 1, DataType::U32).with_count(4);

        let program = build_indexed_map(
            "vyre-libs::test::indexed_map_user",
            vec![input, output],
            "output",
            4,
            [64, 1, 1],
            |i| (i.clone(), Expr::load("input", i)),
        );

        let rendered = format!("{:?}", program.entry());
        assert!(
            rendered.contains(INDEXED_MAP_OP_ID),
            "Fix: indexed-map users must share the same child region instead of copying loop skeletons: {rendered}"
        );
    }

    #[test]
    fn strided_writeback_builder_emits_shared_child_region() {
        let writeback = strided_writeback_child(
            "vyre-libs::test::row_reduction_user",
            4,
            1,
            4,
            "out",
            vec![Node::let_bind("scale", Expr::f32(0.5))],
            |_idx| Expr::var("scale"),
        );
        // The child reads `local`, so the parent region binds it first.
        let parent_body = vec![
            Node::let_bind("local", Expr::LocalId { axis: 0 }),
            writeback,
        ];

        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::F32).with_count(4)],
            [4, 1, 1],
            vec![crate::region::wrap_anonymous(
                "vyre-libs::test::row_reduction_user",
                parent_body,
            )],
        );

        let rendered = format!("{:?}", program.entry());
        assert!(
            rendered.contains(STRIDED_WRITEBACK_OP_ID),
            "Fix: row-reduction writeback users must share the same child region instead of copying loop skeletons: {rendered}"
        );
    }
}
650