1use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};
21use vyre_foundation::ir::model::expr::GeneratorRef;
22
23use crate::tensor_ref::{TensorRef, TensorRefError};
24
/// Op id given to the child region emitted by `build_indexed_map`.
pub(crate) const INDEXED_MAP_OP_ID: &str = "vyre-libs::substrate::indexed_map";
/// Op id given to the child regions emitted by `strided_accumulate_child`
/// and `strided_accumulate2_child`.
pub(crate) const STRIDED_ACCUMULATE_OP_ID: &str = "vyre-libs::substrate::strided_accumulate";
// NOTE(review): unlike the two ids above, this one carries an `anonymous::`
// prefix — confirm that this is intentional and not a copy/paste slip.
/// Op id given to the child region emitted by `strided_writeback_child`.
pub(crate) const STRIDED_WRITEBACK_OP_ID: &str =
    "anonymous::vyre-libs::substrate::strided_writeback";
36
/// Optional knobs shared by the builders in this crate.
///
/// Every field defaults to `None`. Values are set through the chainable
/// `with_*` methods, each of which consumes the options and returns the
/// updated copy.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct BuildOptions {
    /// Workgroup size override, one entry per axis.
    pub workgroup_size: Option<[u32; 3]>,
    /// Region generator name override.
    pub region_generator: Option<&'static str>,
    /// Tenant id override.
    pub tenant_id: Option<u32>,
}

impl BuildOptions {
    /// Creates options with every knob unset (same as `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the options with `workgroup_size` set to `size`.
    #[must_use]
    pub fn with_workgroup_size(self, size: [u32; 3]) -> Self {
        Self {
            workgroup_size: Some(size),
            ..self
        }
    }

    /// Returns the options with `region_generator` set to `name`.
    #[must_use]
    pub fn with_region_generator(self, name: &'static str) -> Self {
        Self {
            region_generator: Some(name),
            ..self
        }
    }

    /// Returns the options with `tenant_id` set to `tenant_id`.
    #[must_use]
    pub fn with_tenant_id(self, tenant_id: u32) -> Self {
        Self {
            tenant_id: Some(tenant_id),
            ..self
        }
    }
}
82
/// Generates the chainable `with_*` option setters on a builder type.
///
/// The builder must have a field named `options` whose type offers
/// `with_workgroup_size`, `with_region_generator` and `with_tenant_id`
/// methods that take the value by move and return the updated options;
/// each generated setter forwards to the method of the same name.
macro_rules! impl_cat_a_builder_options {
    ($builder:ident) => {
        impl $builder {
            /// Forwards `size` to the `options` field and returns the builder.
            #[must_use]
            pub fn with_workgroup_size(self, size: [u32; 3]) -> Self {
                let mut builder = self;
                builder.options = builder.options.with_workgroup_size(size);
                builder
            }

            /// Forwards `name` to the `options` field and returns the builder.
            #[must_use]
            pub fn with_region_generator(self, name: &'static str) -> Self {
                let mut builder = self;
                builder.options = builder.options.with_region_generator(name);
                builder
            }

            /// Forwards `tenant_id` to the `options` field and returns the builder.
            #[must_use]
            pub fn with_tenant_id(self, tenant_id: u32) -> Self {
                let mut builder = self;
                builder.options = builder.options.with_tenant_id(tenant_id);
                builder
            }
        }
    };
}
109
110pub(crate) use impl_cat_a_builder_options;
111
112pub fn check_tensors(
116 op: &'static str,
117 tensors: &[(&TensorRef, DataType)],
118) -> Result<(), TensorRefError> {
119 for (r, expected) in tensors {
121 crate::tensor_ref::check_dtype(r, expected.clone(), op)?;
122 if r.element_count().is_none() {
123 return Err(TensorRefError::ElementCountOverflow {
124 name: r.name.as_str().to_string(),
125 shape: r.shape.to_vec(),
126 });
127 }
128 }
129 for (idx, (left, _)) in tensors.iter().enumerate() {
130 for (right, _) in &tensors[idx + 1..] {
131 if left.name_str() == right.name_str() {
132 return Err(TensorRefError::NameCollision {
133 name: left.name.as_str().to_string(),
134 op,
135 });
136 }
137 }
138 }
139 Ok(())
140}
141
#[cfg(test)]
mod cat_a_builder_option_macro_tests {
    // The macro emits `pub fn` items on a private test-only type.
    #![allow(unreachable_pub)]

    use super::BuildOptions;

    /// Minimal builder: just the `options` field the macro expects.
    #[derive(Debug, Clone)]
    struct DemoBuilder {
        options: BuildOptions,
    }

    impl DemoBuilder {
        fn new() -> Self {
            let options = BuildOptions::default();
            Self { options }
        }
    }

    super::impl_cat_a_builder_options!(DemoBuilder);

    /// Every generated setter must land its value in `options`.
    #[test]
    fn generated_option_surface_threads_every_shared_knob() {
        let DemoBuilder { options } = DemoBuilder::new()
            .with_workgroup_size([8, 4, 2])
            .with_region_generator("custom::generator")
            .with_tenant_id(17);

        assert_eq!(options.workgroup_size, Some([8, 4, 2]));
        assert_eq!(options.region_generator, Some("custom::generator"));
        assert_eq!(options.tenant_id, Some(17));
    }
}
175
176pub(crate) fn build_indexed_map<F>(
182 op_id: &'static str,
183 buffers: Vec<BufferDecl>,
184 output: &str,
185 count: u32,
186 workgroup_size: [u32; 3],
187 f: F,
188) -> Program
189where
190 F: FnOnce(Expr) -> (Expr, Expr),
191{
192 let i = Expr::var("i");
193 let (dst_index, value) = f(i.clone());
194 let child_body = vec![
195 Node::let_bind("i", Expr::InvocationId { axis: 0 }),
196 Node::if_then(
197 Expr::lt(i, Expr::u32(count)),
198 vec![Node::store(output, dst_index, value)],
199 ),
200 ];
201 let parent = GeneratorRef {
202 name: op_id.to_string(),
203 };
204
205 Program::wrapped(
206 buffers,
207 workgroup_size,
208 vec![crate::region::wrap_anonymous(
209 op_id,
210 vec![crate::region::wrap_child(
211 INDEXED_MAP_OP_ID,
212 parent,
213 child_body,
214 )],
215 )],
216 )
217}
218
219pub(crate) fn strided_accumulate_child<F>(
225 parent_op_id: &'static str,
226 tile: u32,
227 chunks: u32,
228 n: u32,
229 acc_name: &'static str,
230 initial: Expr,
231 scratch: &'static str,
232 step: F,
233) -> Node
234where
235 F: Fn(Expr, Expr) -> Expr,
236{
237 let local = Expr::var("local");
238 let idx = Expr::var("idx");
239 let acc = Expr::var(acc_name);
240 let child_body = vec![Node::if_then(
241 Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
242 vec![
243 Node::let_bind(acc_name, initial),
244 strided_loop(
245 tile,
246 chunks,
247 n,
248 vec![Node::assign(acc_name, step(idx, acc))],
249 ),
250 Node::store(scratch, local, Expr::var(acc_name)),
251 ],
252 )];
253
254 child_region(parent_op_id, STRIDED_ACCUMULATE_OP_ID, child_body)
255}
256
257#[allow(dead_code)]
262pub(crate) fn strided_accumulate2_child<F1, F2>(
263 parent_op_id: &'static str,
264 tile: u32,
265 chunks: u32,
266 n: u32,
267 first: (&'static str, Expr, &'static str, F1),
268 second: (&'static str, Expr, &'static str, F2),
269) -> Node
270where
271 F1: Fn(Expr, Expr) -> Expr,
272 F2: Fn(Expr, Expr) -> Expr,
273{
274 let (first_name, first_initial, first_scratch, first_step) = first;
275 let (second_name, second_initial, second_scratch, second_step) = second;
276 let local = Expr::var("local");
277 let idx = Expr::var("idx");
278 let child_body = vec![Node::if_then(
279 Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
280 vec![
281 Node::let_bind(first_name, first_initial),
282 Node::let_bind(second_name, second_initial),
283 strided_loop(
284 tile,
285 chunks,
286 n,
287 vec![
288 Node::assign(first_name, first_step(idx.clone(), Expr::var(first_name))),
289 Node::assign(second_name, second_step(idx, Expr::var(second_name))),
290 ],
291 ),
292 Node::store(first_scratch, local.clone(), Expr::var(first_name)),
293 Node::store(second_scratch, local, Expr::var(second_name)),
294 ],
295 )];
296
297 child_region(parent_op_id, STRIDED_ACCUMULATE_OP_ID, child_body)
298}
299
300pub(crate) fn strided_writeback_child<F>(
306 parent_op_id: &'static str,
307 tile: u32,
308 chunks: u32,
309 n: u32,
310 output: &str,
311 prelude: Vec<Node>,
312 value: F,
313) -> Node
314where
315 F: Fn(Expr) -> Expr,
316{
317 let idx = Expr::var("idx");
318 let mut guarded = prelude;
319 guarded.push(strided_loop(
320 tile,
321 chunks,
322 n,
323 vec![Node::store(output, idx.clone(), value(idx))],
324 ));
325 child_region(
326 parent_op_id,
327 STRIDED_WRITEBACK_OP_ID,
328 vec![Node::if_then(
329 Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
330 guarded,
331 )],
332 )
333}
334
335fn strided_loop(tile: u32, chunks: u32, n: u32, guarded_body: Vec<Node>) -> Node {
336 Node::loop_for(
337 "chunk",
338 Expr::u32(0),
339 Expr::u32(chunks),
340 vec![
341 Node::let_bind(
342 "idx",
343 Expr::add(
344 Expr::mul(Expr::var("chunk"), Expr::u32(tile)),
345 Expr::var("local"),
346 ),
347 ),
348 Node::if_then(Expr::lt(Expr::var("idx"), Expr::u32(n)), guarded_body),
349 ],
350 )
351}
352
353fn child_region(parent_op_id: &'static str, child_op_id: &'static str, body: Vec<Node>) -> Node {
354 crate::region::wrap_child(
355 child_op_id,
356 GeneratorRef {
357 name: parent_op_id.to_string(),
358 },
359 body,
360 )
361}
362
363#[allow(dead_code)]
369pub(crate) fn invalid_output_program(
370 op_id: &'static str,
371 output: &str,
372 data_type: DataType,
373 message: String,
374) -> Program {
375 Program::wrapped(
376 vec![BufferDecl::output(output, 0, data_type).with_count(1)],
377 [1, 1, 1],
378 vec![crate::region::wrap_anonymous(
379 op_id,
380 vec![Node::trap(Expr::u32(0), message)],
381 )],
382 )
383}
384
385#[allow(dead_code)]
388pub(crate) fn build_elementwise_binary<F>(
389 op_id: &'static str,
390 a: crate::tensor_ref::TensorRef,
391 b: crate::tensor_ref::TensorRef,
392 out: crate::tensor_ref::TensorRef,
393 options: BuildOptions,
394 f: F,
395) -> Result<vyre::ir::Program, crate::tensor_ref::TensorRefError>
396where
397 F: Fn(vyre::ir::Expr, vyre::ir::Expr) -> vyre::ir::Expr,
398{
399 check_tensors(
400 op_id,
401 &[
402 (&a, vyre::ir::DataType::U32),
403 (&b, vyre::ir::DataType::U32),
404 (&out, vyre::ir::DataType::U32),
405 ],
406 )?;
407
408 if a.shape != b.shape || a.shape != out.shape {
409 return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
410 name: "elementwise_binary".into(),
411 found: vec![],
412 expected: vec![],
413 op: op_id,
414 });
415 }
416
417 let a_count = a.element_count().ok_or_else(|| {
418 crate::tensor_ref::TensorRefError::ElementCountOverflow {
419 name: a.name_str().to_string(),
420 shape: a.shape.to_vec(),
421 }
422 })?;
423 let out_count = out.element_count().ok_or_else(|| {
424 crate::tensor_ref::TensorRefError::ElementCountOverflow {
425 name: out.name_str().to_string(),
426 shape: out.shape.to_vec(),
427 }
428 })?;
429 if out_count < a_count {
430 return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
431 name: out.name_str().to_string(),
432 found: out.shape.to_vec(),
433 expected: a.shape.to_vec(),
434 op: op_id,
435 });
436 }
437
438 let n = a_count;
439 let body = vec![
440 vyre::ir::Node::let_bind("idx", vyre::ir::Expr::InvocationId { axis: 0 }),
441 vyre::ir::Node::if_then(
442 vyre::ir::Expr::lt(vyre::ir::Expr::var("idx"), vyre::ir::Expr::u32(n)),
443 vec![vyre::ir::Node::store(
444 out.name_str(),
445 vyre::ir::Expr::var("idx"),
446 f(
447 vyre::ir::Expr::load(a.name_str(), vyre::ir::Expr::var("idx")),
448 vyre::ir::Expr::load(b.name_str(), vyre::ir::Expr::var("idx")),
449 ),
450 )],
451 ),
452 ];
453
454 let group = options.workgroup_size.unwrap_or([64, 1, 1]);
455
456 Ok(vyre::ir::Program::wrapped(
457 vec![
458 vyre::ir::BufferDecl::storage(
459 a.name_str(),
460 0,
461 vyre::ir::BufferAccess::ReadOnly,
462 vyre::ir::DataType::U32,
463 )
464 .with_count(n),
465 vyre::ir::BufferDecl::storage(
466 b.name_str(),
467 1,
468 vyre::ir::BufferAccess::ReadOnly,
469 vyre::ir::DataType::U32,
470 )
471 .with_count(n),
472 vyre::ir::BufferDecl::output(out.name_str(), 2, vyre::ir::DataType::U32).with_count(n),
473 ],
474 group,
475 vec![crate::region::wrap_anonymous(op_id, body)],
476 ))
477}
478
479#[allow(dead_code)]
480pub(crate) fn build_elementwise_unary<F>(
481 op_id: &'static str,
482 a: crate::tensor_ref::TensorRef,
483 out: crate::tensor_ref::TensorRef,
484 options: BuildOptions,
485 f: F,
486) -> Result<vyre::ir::Program, crate::tensor_ref::TensorRefError>
487where
488 F: Fn(vyre::ir::Expr) -> vyre::ir::Expr,
489{
490 check_tensors(
491 op_id,
492 &[
493 (&a, vyre::ir::DataType::U32),
494 (&out, vyre::ir::DataType::U32),
495 ],
496 )?;
497
498 if a.shape != out.shape {
499 return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
500 name: "elementwise_unary".into(),
501 found: vec![],
502 expected: vec![],
503 op: op_id,
504 });
505 }
506
507 let n = a.element_count().ok_or_else(|| {
508 crate::tensor_ref::TensorRefError::ElementCountOverflow {
509 name: a.name_str().to_string(),
510 shape: a.shape.to_vec(),
511 }
512 })?;
513 let body = vec![
514 vyre::ir::Node::let_bind("idx", vyre::ir::Expr::InvocationId { axis: 0 }),
515 vyre::ir::Node::if_then(
516 vyre::ir::Expr::lt(vyre::ir::Expr::var("idx"), vyre::ir::Expr::u32(n)),
517 vec![vyre::ir::Node::store(
518 out.name_str(),
519 vyre::ir::Expr::var("idx"),
520 f(vyre::ir::Expr::load(
521 a.name_str(),
522 vyre::ir::Expr::var("idx"),
523 )),
524 )],
525 ),
526 ];
527
528 let group = options.workgroup_size.unwrap_or([64, 1, 1]);
529
530 Ok(vyre::ir::Program::wrapped(
531 vec![
532 vyre::ir::BufferDecl::storage(
533 a.name_str(),
534 0,
535 vyre::ir::BufferAccess::ReadOnly,
536 vyre::ir::DataType::U32,
537 )
538 .with_count(n),
539 vyre::ir::BufferDecl::output(out.name_str(), 1, vyre::ir::DataType::U32).with_count(n),
540 ],
541 group,
542 vec![crate::region::wrap_anonymous(op_id, body)],
543 ))
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `BuildOptions` leaves every knob unset.
    #[test]
    fn build_options_defaults_are_all_none() {
        let BuildOptions {
            workgroup_size,
            region_generator,
            tenant_id,
        } = BuildOptions::default();
        assert_eq!(workgroup_size, None);
        assert_eq!(region_generator, None);
        assert_eq!(tenant_id, None);
    }

    /// Later setters must not clobber values set earlier in the chain.
    #[test]
    fn build_options_chain_preserves_earlier_setters() {
        let BuildOptions {
            workgroup_size,
            region_generator,
            tenant_id,
        } = BuildOptions::new()
            .with_workgroup_size([128, 1, 1])
            .with_region_generator("test::op")
            .with_tenant_id(7);
        assert_eq!(workgroup_size, Some([128, 1, 1]));
        assert_eq!(region_generator, Some("test::op"));
        assert_eq!(tenant_id, Some(7));
    }

    #[test]
    fn check_tensors_passes_on_clean_inputs() {
        let a = TensorRef::u32_1d("a", 4);
        let b = TensorRef::u32_1d("b", 4);
        let outcome = check_tensors("op", &[(&a, DataType::U32), (&b, DataType::U32)]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn check_tensors_catches_dtype_mismatch() {
        let a = TensorRef::u32_1d("a", 4);
        let outcome = check_tensors("op", &[(&a, DataType::F32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::DtypeMismatch { .. })
        ));
    }

    #[test]
    fn check_tensors_catches_overflow() {
        let a = TensorRef::new("big", DataType::U32, vec![1u32 << 20, 1u32 << 20]);
        let outcome = check_tensors("op", &[(&a, DataType::U32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::ElementCountOverflow { .. })
        ));
    }

    #[test]
    fn check_tensors_catches_name_collision() {
        let a = TensorRef::u32_1d("x", 4);
        let b = TensorRef::u32_1d("x", 4);
        let outcome = check_tensors("op", &[(&a, DataType::U32), (&b, DataType::U32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::NameCollision { .. })
        ));
    }

    /// The rendered entry must mention the shared indexed-map child op id.
    #[test]
    fn indexed_map_builder_emits_shared_child_region() {
        let buffers = vec![
            BufferDecl::storage("input", 0, vyre::ir::BufferAccess::ReadOnly, DataType::U32)
                .with_count(4),
            BufferDecl::output("output", 1, DataType::U32).with_count(4),
        ];
        let program = build_indexed_map(
            "vyre-libs::test::indexed_map_user",
            buffers,
            "output",
            4,
            [64, 1, 1],
            |i| (i.clone(), Expr::load("input", i)),
        );
        let rendered = format!("{:?}", program.entry());
        assert!(
            rendered.contains(INDEXED_MAP_OP_ID),
            "Fix: indexed-map users must share the same child region instead of copying loop skeletons: {rendered}"
        );
    }

    /// The rendered entry must mention the shared strided-writeback child op id.
    #[test]
    fn strided_writeback_builder_emits_shared_child_region() {
        let bind_local = Node::let_bind("local", Expr::LocalId { axis: 0 });
        let writeback = strided_writeback_child(
            "vyre-libs::test::row_reduction_user",
            4,
            1,
            4,
            "out",
            vec![Node::let_bind("scale", Expr::f32(0.5))],
            |_idx| Expr::var("scale"),
        );
        let region = crate::region::wrap_anonymous(
            "vyre-libs::test::row_reduction_user",
            vec![bind_local, writeback],
        );
        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::F32).with_count(4)],
            [4, 1, 1],
            vec![region],
        );
        let rendered = format!("{:?}", program.entry());
        assert!(
            rendered.contains(STRIDED_WRITEBACK_OP_ID),
            "Fix: row-reduction writeback users must share the same child region instead of copying loop skeletons: {rendered}"
        );
    }
}