Skip to main content

vyre_libs/
builder.rs

1//! Shared helpers used by the per-op Cat-A builders.
2//!
3//! Each op in `vyre-libs` ships a chainable builder that:
4//!
5//! 1. Accepts [`TensorRef`]s instead of bare `&str` buffer names, so
6//!    dtype + shape mismatches fail at `build()` time.
7//! 2. Checks every pair of buffer names is unique.
8//! 3. Verifies every [`TensorRef`]'s dtype against the op's expected dtype.
9//! 4. Verifies element-count overflow.
10//! 5. Allows chained overrides (workgroup size, region generator,
11//!    tenant id) without churning the function signature  -  extension
12//!    fields live inside a `#[non_exhaustive]` options struct so new
13//!    knobs never break existing call sites.
14//!
15//! `BuildOptions` is intentionally small at launch; fields are added
16//! rather than removed (the `#[non_exhaustive]` attribute enforces
17//! this). Every Cat-A op exposes its builder as `<Op>Builder::new(...)`
18//! and delegates defaults through `BuildOptions::default()`.
19
20use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};
21use vyre_foundation::ir::model::expr::GeneratorRef;
22
23use crate::tensor_ref::{TensorRef, TensorRefError};
24
/// Op id of the child region shared by every one-output indexed map.
///
/// This is the kernel skeleton behind embedding lookup, byte shuffles,
/// quant pack/unpack, and similar data-layout transforms:
/// `for i in 0..n { out[dst(i)] = value(i) }`.
pub(crate) const INDEXED_MAP_OP_ID: &str = "vyre-libs::substrate::indexed_map";
/// Op id of the child region shared by strided per-lane workgroup accumulators.
pub(crate) const STRIDED_ACCUMULATE_OP_ID: &str = "vyre-libs::substrate::strided_accumulate";
/// Op id of the child region shared by strided writeback after a tiled row reduction.
// NOTE(review): unlike the two ids above, this one carries an `anonymous::`
// prefix — confirm that is intentional before relying on prefix matching.
pub(crate) const STRIDED_WRITEBACK_OP_ID: &str = "anonymous::vyre-libs::substrate::strided_writeback";
35
/// Option surface threaded through every Cat-A builder.
///
/// Declared once in this module so all ops expose the same set of knobs.
/// Every field defaults to `None`, meaning "use the op's canonical value".
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct BuildOptions {
    /// Replacement workgroup size; `None` keeps the op's canonical default.
    pub workgroup_size: Option<[u32; 3]>,
    /// Replacement region generator; `None` keeps the op's canonical
    /// `"vyre-libs::…"` identifier. Intended for downstream crates that wrap
    /// a Cat-A op and want their own generator id in conformance
    /// certificates.
    pub region_generator: Option<&'static str>,
    /// Tenant id baked into the region metadata for multi-tenant
    /// deployments. Routed through the megakernel's tenant-mask table
    /// when the Program runs inside `vyre-runtime`.
    pub tenant_id: Option<u32>,
}

impl BuildOptions {
    /// Creates an option set with every override unset; chain the `with_*`
    /// setters to fill in the ones you need.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the options with the workgroup size replaced by `size`.
    #[must_use]
    pub fn with_workgroup_size(self, size: [u32; 3]) -> Self {
        Self {
            workgroup_size: Some(size),
            ..self
        }
    }

    /// Returns the options with the region generator replaced by `name`
    /// (which must be a `&'static str`).
    #[must_use]
    pub fn with_region_generator(self, name: &'static str) -> Self {
        Self {
            region_generator: Some(name),
            ..self
        }
    }

    /// Returns the options with `tenant_id` recorded for the Cat-A op's
    /// region metadata.
    #[must_use]
    pub fn with_tenant_id(self, tenant_id: u32) -> Self {
        Self {
            tenant_id: Some(tenant_id),
            ..self
        }
    }
}
81
/// Generates the shared `with_*` option setters on a Cat-A builder type.
///
/// `$builder` must have an `options` field whose type provides
/// `with_workgroup_size`, `with_region_generator`, and `with_tenant_id`
/// (i.e. `BuildOptions`); each generated method forwards to the setter of the
/// same name and hands the builder back for further chaining.
macro_rules! impl_cat_a_builder_options {
    ($builder:ident) => {
        impl $builder {
            /// Override the generated Program workgroup size.
            #[must_use]
            pub fn with_workgroup_size(mut self, size: [u32; 3]) -> Self {
                let updated = self.options.with_workgroup_size(size);
                self.options = updated;
                self
            }

            /// Override the Region generator id.
            #[must_use]
            pub fn with_region_generator(mut self, name: &'static str) -> Self {
                let updated = self.options.with_region_generator(name);
                self.options = updated;
                self
            }

            /// Stamp the Region metadata with a tenant id.
            #[must_use]
            pub fn with_tenant_id(mut self, tenant_id: u32) -> Self {
                let updated = self.options.with_tenant_id(tenant_id);
                self.options = updated;
                self
            }
        }
    };
}
108
109pub(crate) use impl_cat_a_builder_options;
110
111/// Validate a slice of `TensorRef`s against an expected `DataType`
112/// for each position, plus name-uniqueness across the whole slice.
113/// Used by every op's `build()` to consolidate the fanout of checks.
114pub fn check_tensors(
115    op: &'static str,
116    tensors: &[(&TensorRef, DataType)],
117) -> Result<(), TensorRefError> {
118    // Dtype check per tensor.
119    for (r, expected) in tensors {
120        crate::tensor_ref::check_dtype(r, expected.clone(), op)?;
121        if r.element_count().is_none() {
122            return Err(TensorRefError::ElementCountOverflow {
123                name: r.name.as_str().to_string(),
124                shape: r.shape.to_vec(),
125            });
126        }
127    }
128    for (idx, (left, _)) in tensors.iter().enumerate() {
129        for (right, _) in &tensors[idx + 1..] {
130            if left.name_str() == right.name_str() {
131                return Err(TensorRefError::NameCollision {
132                    name: left.name.as_str().to_string(),
133                    op,
134                });
135            }
136        }
137    }
138    Ok(())
139}
140
#[cfg(test)]
mod cat_a_builder_option_macro_tests {
    #![allow(unreachable_pub)]

    use super::BuildOptions;

    /// Smallest possible builder: just the `options` field the macro drives.
    #[derive(Debug, Clone)]
    struct DemoBuilder {
        options: BuildOptions,
    }

    impl DemoBuilder {
        fn new() -> Self {
            let options = BuildOptions::default();
            Self { options }
        }
    }

    super::impl_cat_a_builder_options!(DemoBuilder);

    /// Every setter the macro generates must land in the matching
    /// `BuildOptions` field.
    #[test]
    fn generated_option_surface_threads_every_shared_knob() {
        let configured = DemoBuilder::new()
            .with_workgroup_size([8, 4, 2])
            .with_region_generator("custom::generator")
            .with_tenant_id(17);
        let stored = &configured.options;

        assert_eq!(stored.workgroup_size, Some([8, 4, 2]));
        assert_eq!(stored.region_generator, Some("custom::generator"));
        assert_eq!(stored.tenant_id, Some(17));
    }
}
174
175/// Build the canonical one-output indexed-map skeleton.
176///
177/// Callers provide buffer declarations plus the semantic mapping from logical
178/// element `i` to `(dst_index, value)`. The loop, bounds guard, invocation id,
179/// workgroup default, and composition region stay centralized.
180pub(crate) fn build_indexed_map<F>(
181    op_id: &'static str,
182    buffers: Vec<BufferDecl>,
183    output: &str,
184    count: u32,
185    workgroup_size: [u32; 3],
186    f: F,
187) -> Program
188where
189    F: FnOnce(Expr) -> (Expr, Expr),
190{
191    let i = Expr::var("i");
192    let (dst_index, value) = f(i.clone());
193    let child_body = vec![
194        Node::let_bind("i", Expr::InvocationId { axis: 0 }),
195        Node::if_then(
196            Expr::lt(i, Expr::u32(count)),
197            vec![Node::store(output, dst_index, value)],
198        ),
199    ];
200    let parent = GeneratorRef {
201        name: op_id.to_string(),
202    };
203
204    Program::wrapped(
205        buffers,
206        workgroup_size,
207        vec![crate::region::wrap_anonymous(
208            op_id,
209            vec![crate::region::wrap_child(
210                INDEXED_MAP_OP_ID,
211                parent,
212                child_body,
213            )],
214        )],
215    )
216}
217
218/// Build a shared strided single-accumulator child region.
219///
220/// The parent must bind `local = LocalId(0)` before this child. The child
221/// accumulates `i = chunk * tile + local` for `chunk in 0..chunks`, guards
222/// `i < n`, and stores the lane-local accumulator into `scratch[local]`.
223pub(crate) fn strided_accumulate_child<F>(
224    parent_op_id: &'static str,
225    tile: u32,
226    chunks: u32,
227    n: u32,
228    acc_name: &'static str,
229    initial: Expr,
230    scratch: &'static str,
231    step: F,
232) -> Node
233where
234    F: Fn(Expr, Expr) -> Expr,
235{
236    let local = Expr::var("local");
237    let idx = Expr::var("idx");
238    let acc = Expr::var(acc_name);
239    let child_body = vec![Node::if_then(
240        Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
241        vec![
242            Node::let_bind(acc_name, initial),
243            strided_loop(
244                tile,
245                chunks,
246                n,
247                vec![Node::assign(acc_name, step(idx, acc))],
248            ),
249            Node::store(scratch, local, Expr::var(acc_name)),
250        ],
251    )];
252
253    child_region(parent_op_id, STRIDED_ACCUMULATE_OP_ID, child_body)
254}
255
256/// Build a shared strided dual-accumulator child region.
257///
258/// This keeps paired reductions such as `(sum, sum_sq)` in one memory pass
259/// instead of forcing two separate scans over the input.
260#[allow(dead_code)]
261pub(crate) fn strided_accumulate2_child<F1, F2>(
262    parent_op_id: &'static str,
263    tile: u32,
264    chunks: u32,
265    n: u32,
266    first: (&'static str, Expr, &'static str, F1),
267    second: (&'static str, Expr, &'static str, F2),
268) -> Node
269where
270    F1: Fn(Expr, Expr) -> Expr,
271    F2: Fn(Expr, Expr) -> Expr,
272{
273    let (first_name, first_initial, first_scratch, first_step) = first;
274    let (second_name, second_initial, second_scratch, second_step) = second;
275    let local = Expr::var("local");
276    let idx = Expr::var("idx");
277    let child_body = vec![Node::if_then(
278        Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
279        vec![
280            Node::let_bind(first_name, first_initial),
281            Node::let_bind(second_name, second_initial),
282            strided_loop(
283                tile,
284                chunks,
285                n,
286                vec![
287                    Node::assign(first_name, first_step(idx.clone(), Expr::var(first_name))),
288                    Node::assign(second_name, second_step(idx, Expr::var(second_name))),
289                ],
290            ),
291            Node::store(first_scratch, local.clone(), Expr::var(first_name)),
292            Node::store(second_scratch, local, Expr::var(second_name)),
293        ],
294    )];
295
296    child_region(parent_op_id, STRIDED_ACCUMULATE_OP_ID, child_body)
297}
298
299/// Build a shared strided writeback child region.
300///
301/// The parent must bind `local = LocalId(0)` before this child. Optional
302/// `prelude` nodes run once in workgroup zero before the strided write loop,
303/// which lets row reductions load reduced scalars exactly once per lane.
304pub(crate) fn strided_writeback_child<F>(
305    parent_op_id: &'static str,
306    tile: u32,
307    chunks: u32,
308    n: u32,
309    output: &str,
310    prelude: Vec<Node>,
311    value: F,
312) -> Node
313where
314    F: Fn(Expr) -> Expr,
315{
316    let idx = Expr::var("idx");
317    let mut guarded = prelude;
318    guarded.push(strided_loop(
319        tile,
320        chunks,
321        n,
322        vec![Node::store(output, idx.clone(), value(idx))],
323    ));
324    child_region(
325        parent_op_id,
326        STRIDED_WRITEBACK_OP_ID,
327        vec![Node::if_then(
328            Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
329            guarded,
330        )],
331    )
332}
333
334fn strided_loop(tile: u32, chunks: u32, n: u32, guarded_body: Vec<Node>) -> Node {
335    Node::loop_for(
336        "chunk",
337        Expr::u32(0),
338        Expr::u32(chunks),
339        vec![
340            Node::let_bind(
341                "idx",
342                Expr::add(
343                    Expr::mul(Expr::var("chunk"), Expr::u32(tile)),
344                    Expr::var("local"),
345                ),
346            ),
347            Node::if_then(Expr::lt(Expr::var("idx"), Expr::u32(n)), guarded_body),
348        ],
349    )
350}
351
352fn child_region(parent_op_id: &'static str, child_op_id: &'static str, body: Vec<Node>) -> Node {
353    crate::region::wrap_child(
354        child_op_id,
355        GeneratorRef {
356            name: parent_op_id.to_string(),
357        },
358        body,
359    )
360}
361
362/// Build a scalar-output trap program for invalid Cat-A builder inputs.
363///
364/// This keeps public compatibility wrappers infallible without panicking on
365/// user-controlled names or shapes. Typed builders should still return
366/// `Result`; this helper is for legacy `fn foo(...) -> Program` surfaces.
367#[allow(dead_code)]
368pub(crate) fn invalid_output_program(
369    op_id: &'static str,
370    output: &str,
371    data_type: DataType,
372    message: String,
373) -> Program {
374    Program::wrapped(
375        vec![BufferDecl::output(output, 0, data_type).with_count(1)],
376        [1, 1, 1],
377        vec![crate::region::wrap_anonymous(
378            op_id,
379            vec![Node::trap(Expr::u32(0), message)],
380        )],
381    )
382}
383
384/// Tensor-ref elementwise binary builder, used by `math::avg_floor`,
385/// `math::algebra`, and other binary-arithmetic primitives.
386#[allow(dead_code)]
387pub(crate) fn build_elementwise_binary<F>(
388    op_id: &'static str,
389    a: crate::tensor_ref::TensorRef,
390    b: crate::tensor_ref::TensorRef,
391    out: crate::tensor_ref::TensorRef,
392    options: BuildOptions,
393    f: F,
394) -> Result<vyre::ir::Program, crate::tensor_ref::TensorRefError>
395where
396    F: Fn(vyre::ir::Expr, vyre::ir::Expr) -> vyre::ir::Expr,
397{
398    check_tensors(
399        op_id,
400        &[
401            (&a, vyre::ir::DataType::U32),
402            (&b, vyre::ir::DataType::U32),
403            (&out, vyre::ir::DataType::U32),
404        ],
405    )?;
406
407    if a.shape != b.shape || a.shape != out.shape {
408        return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
409            name: "elementwise_binary".into(),
410            found: vec![],
411            expected: vec![],
412            op: op_id,
413        });
414    }
415
416    let a_count = a.element_count().ok_or_else(|| {
417        crate::tensor_ref::TensorRefError::ElementCountOverflow {
418            name: a.name_str().to_string(),
419            shape: a.shape.to_vec(),
420        }
421    })?;
422    let out_count = out.element_count().ok_or_else(|| {
423        crate::tensor_ref::TensorRefError::ElementCountOverflow {
424            name: out.name_str().to_string(),
425            shape: out.shape.to_vec(),
426        }
427    })?;
428    if out_count < a_count {
429        return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
430            name: out.name_str().to_string(),
431            found: out.shape.to_vec(),
432            expected: a.shape.to_vec(),
433            op: op_id,
434        });
435    }
436
437    let n = a_count;
438    let body = vec![
439        vyre::ir::Node::let_bind("idx", vyre::ir::Expr::InvocationId { axis: 0 }),
440        vyre::ir::Node::if_then(
441            vyre::ir::Expr::lt(vyre::ir::Expr::var("idx"), vyre::ir::Expr::u32(n)),
442            vec![vyre::ir::Node::store(
443                out.name_str(),
444                vyre::ir::Expr::var("idx"),
445                f(
446                    vyre::ir::Expr::load(a.name_str(), vyre::ir::Expr::var("idx")),
447                    vyre::ir::Expr::load(b.name_str(), vyre::ir::Expr::var("idx")),
448                ),
449            )],
450        ),
451    ];
452
453    let group = options.workgroup_size.unwrap_or([64, 1, 1]);
454
455    Ok(vyre::ir::Program::wrapped(
456        vec![
457            vyre::ir::BufferDecl::storage(
458                a.name_str(),
459                0,
460                vyre::ir::BufferAccess::ReadOnly,
461                vyre::ir::DataType::U32,
462            )
463            .with_count(n),
464            vyre::ir::BufferDecl::storage(
465                b.name_str(),
466                1,
467                vyre::ir::BufferAccess::ReadOnly,
468                vyre::ir::DataType::U32,
469            )
470            .with_count(n),
471            vyre::ir::BufferDecl::output(out.name_str(), 2, vyre::ir::DataType::U32).with_count(n),
472        ],
473        group,
474        vec![crate::region::wrap_anonymous(op_id, body)],
475    ))
476}
477
478#[allow(dead_code)]
479pub(crate) fn build_elementwise_unary<F>(
480    op_id: &'static str,
481    a: crate::tensor_ref::TensorRef,
482    out: crate::tensor_ref::TensorRef,
483    options: BuildOptions,
484    f: F,
485) -> Result<vyre::ir::Program, crate::tensor_ref::TensorRefError>
486where
487    F: Fn(vyre::ir::Expr) -> vyre::ir::Expr,
488{
489    check_tensors(
490        op_id,
491        &[
492            (&a, vyre::ir::DataType::U32),
493            (&out, vyre::ir::DataType::U32),
494        ],
495    )?;
496
497    if a.shape != out.shape {
498        return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
499            name: "elementwise_unary".into(),
500            found: vec![],
501            expected: vec![],
502            op: op_id,
503        });
504    }
505
506    let n = a.element_count().ok_or_else(|| {
507        crate::tensor_ref::TensorRefError::ElementCountOverflow {
508            name: a.name_str().to_string(),
509            shape: a.shape.to_vec(),
510        }
511    })?;
512    let body = vec![
513        vyre::ir::Node::let_bind("idx", vyre::ir::Expr::InvocationId { axis: 0 }),
514        vyre::ir::Node::if_then(
515            vyre::ir::Expr::lt(vyre::ir::Expr::var("idx"), vyre::ir::Expr::u32(n)),
516            vec![vyre::ir::Node::store(
517                out.name_str(),
518                vyre::ir::Expr::var("idx"),
519                f(vyre::ir::Expr::load(
520                    a.name_str(),
521                    vyre::ir::Expr::var("idx"),
522                )),
523            )],
524        ),
525    ];
526
527    let group = options.workgroup_size.unwrap_or([64, 1, 1]);
528
529    Ok(vyre::ir::Program::wrapped(
530        vec![
531            vyre::ir::BufferDecl::storage(
532                a.name_str(),
533                0,
534                vyre::ir::BufferAccess::ReadOnly,
535                vyre::ir::DataType::U32,
536            )
537            .with_count(n),
538            vyre::ir::BufferDecl::output(out.name_str(), 1, vyre::ir::DataType::U32).with_count(n),
539        ],
540        group,
541        vec![crate::region::wrap_anonymous(op_id, body)],
542    ))
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// A default option set carries no overrides.
    #[test]
    fn build_options_defaults_are_all_none() {
        let defaults = BuildOptions::default();
        assert_eq!(defaults.workgroup_size, None);
        assert_eq!(defaults.region_generator, None);
        assert_eq!(defaults.tenant_id, None);
    }

    /// Later setters in a chain must not reset earlier ones.
    #[test]
    fn build_options_chain_preserves_earlier_setters() {
        let chained = BuildOptions::new()
            .with_workgroup_size([128, 1, 1])
            .with_region_generator("test::op")
            .with_tenant_id(7);
        assert_eq!(chained.workgroup_size, Some([128, 1, 1]));
        assert_eq!(chained.region_generator, Some("test::op"));
        assert_eq!(chained.tenant_id, Some(7));
    }

    /// Distinct names, matching dtypes, small shapes: nothing to report.
    #[test]
    fn check_tensors_passes_on_clean_inputs() {
        let lhs = TensorRef::u32_1d("a", 4);
        let rhs = TensorRef::u32_1d("b", 4);
        let outcome = check_tensors("op", &[(&lhs, DataType::U32), (&rhs, DataType::U32)]);
        assert!(outcome.is_ok());
    }

    /// A `U32` tensor checked against `F32` is a dtype mismatch.
    #[test]
    fn check_tensors_catches_dtype_mismatch() {
        let tensor = TensorRef::u32_1d("a", 4);
        let outcome = check_tensors("op", &[(&tensor, DataType::F32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::DtypeMismatch { .. })
        ));
    }

    /// A `2^20 x 2^20` shape is reported as an element-count overflow.
    #[test]
    fn check_tensors_catches_overflow() {
        let huge = TensorRef::new("big", DataType::U32, vec![1u32 << 20, 1u32 << 20]);
        let outcome = check_tensors("op", &[(&huge, DataType::U32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::ElementCountOverflow { .. })
        ));
    }

    /// Two tensors sharing a name collide.
    #[test]
    fn check_tensors_catches_name_collision() {
        let first = TensorRef::u32_1d("x", 4);
        let second = TensorRef::u32_1d("x", 4);
        let outcome = check_tensors("op", &[(&first, DataType::U32), (&second, DataType::U32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::NameCollision { .. })
        ));
    }

    /// The indexed-map builder must emit the shared child region id.
    #[test]
    fn indexed_map_builder_emits_shared_child_region() {
        let input = BufferDecl::storage("input", 0, vyre::ir::BufferAccess::ReadOnly, DataType::U32)
            .with_count(4);
        let output = BufferDecl::output("output", 1, DataType::U32).with_count(4);

        let program = build_indexed_map(
            "vyre-libs::test::indexed_map_user",
            vec![input, output],
            "output",
            4,
            [64, 1, 1],
            |i| (i.clone(), Expr::load("input", i)),
        );

        let rendered = format!("{:?}", program.entry());
        assert!(
            rendered.contains(INDEXED_MAP_OP_ID),
            "Fix: indexed-map users must share the same child region instead of copying loop skeletons: {rendered}"
        );
    }

    /// The strided-writeback builder must emit the shared child region id.
    #[test]
    fn strided_writeback_builder_emits_shared_child_region() {
        let writeback = strided_writeback_child(
            "vyre-libs::test::row_reduction_user",
            4,
            1,
            4,
            "out",
            vec![Node::let_bind("scale", Expr::f32(0.5))],
            |_idx| Expr::var("scale"),
        );
        let region = crate::region::wrap_anonymous(
            "vyre-libs::test::row_reduction_user",
            vec![Node::let_bind("local", Expr::LocalId { axis: 0 }), writeback],
        );
        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::F32).with_count(4)],
            [4, 1, 1],
            vec![region],
        );

        let rendered = format!("{:?}", program.entry());
        assert!(
            rendered.contains(STRIDED_WRITEBACK_OP_ID),
            "Fix: row-reduction writeback users must share the same child region instead of copying loop skeletons: {rendered}"
        );
    }
}
650