Skip to main content

vyre_libs/
builder.rs

1//! Shared helpers used by the per-op Cat-A builders.
2//!
3//! Each op in `vyre-libs` ships a chainable builder that:
4//!
//! 1. Accepts [`TensorRef`]s instead of bare `&str` buffer names, so
//!    dtype + shape mismatches fail at `build()` time.
//! 2. Checks that no two buffer names collide.
//! 3. Verifies every [`TensorRef`]'s dtype against the op's expected dtype.
//! 4. Rejects shapes whose element count overflows.
//! 5. Allows chained overrides (workgroup size, region generator,
//!    tenant id) without churning the function signature — extension
//!    fields live inside a `#[non_exhaustive]` options struct so new
//!    knobs never break existing call sites.
14//!
15//! `BuildOptions` is intentionally small at launch; fields are added
16//! rather than removed (the `#[non_exhaustive]` attribute enforces
17//! this). Every Cat-A op exposes its builder as `<Op>Builder::new(...)`
18//! and delegates defaults through `BuildOptions::default()`.
19
20use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};
21use vyre_foundation::ir::model::expr::GeneratorRef;
22
23use crate::tensor_ref::{TensorRef, TensorRefError};
24
/// Shared child region for one-output indexed maps.
///
/// This is the kernel skeleton behind embedding lookup, byte shuffles,
/// quant pack/unpack, and similar data-layout transforms:
/// `for i in 0..n { out[dst(i)] = value(i) }`.
pub(crate) const INDEXED_MAP_OP_ID: &str = "vyre-libs::substrate::indexed_map";
/// Shared child region for strided per-lane workgroup accumulators.
///
/// Both the single- and the dual-accumulator child builders in this file
/// tag their region with this id.
pub(crate) const STRIDED_ACCUMULATE_OP_ID: &str = "vyre-libs::substrate::strided_accumulate";
/// Shared child region for strided writeback after a tiled row reduction.
///
/// NOTE(review): unlike the two ids above, this one carries an
/// `anonymous::` prefix — confirm that the difference is intentional.
pub(crate) const STRIDED_WRITEBACK_OP_ID: &str =
    "anonymous::vyre-libs::substrate::strided_writeback";
36
/// Shared options every Cat-A builder threads through. Lives here so
/// every op agrees on the same surface.
///
/// Every field is an `Option`; `None` means "use the op's canonical
/// default". The struct is `#[non_exhaustive]`, so code outside this crate
/// must start from [`BuildOptions::default`] / [`BuildOptions::new`] and
/// chain the `with_*` setters instead of using a struct literal.
///
/// `PartialEq`/`Eq` are derived so that two option sets can be compared
/// directly (for example in tests) instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct BuildOptions {
    /// Workgroup size override. `None` = op's canonical default.
    pub workgroup_size: Option<[u32; 3]>,
    /// Region generator override. `None` = op's canonical `"vyre-libs::…"`
    /// identifier. Used when a downstream crate wraps a Cat-A op and
    /// wants its own generator id in conformance certificates.
    pub region_generator: Option<&'static str>,
    /// Tenant id baked into the region metadata for multi-tenant
    /// deployments. Routed through the megakernel's tenant-mask table
    /// when the Program runs inside `vyre-runtime`.
    pub tenant_id: Option<u32>,
}
53
54impl BuildOptions {
55    /// Fluent constructor  -  start with defaults and chain overrides.
56    #[must_use]
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Override the workgroup size.
62    #[must_use]
63    pub fn with_workgroup_size(mut self, size: [u32; 3]) -> Self {
64        self.workgroup_size = Some(size);
65        self
66    }
67
68    /// Override the region generator name (must be `&'static str`).
69    #[must_use]
70    pub fn with_region_generator(mut self, name: &'static str) -> Self {
71        self.region_generator = Some(name);
72        self
73    }
74
75    /// Stamp a tenant id into the Cat-A op's region metadata.
76    #[must_use]
77    pub fn with_tenant_id(mut self, tenant_id: u32) -> Self {
78        self.tenant_id = Some(tenant_id);
79        self
80    }
81}
82
/// Generate the shared `with_*` option setters on a Cat-A builder type.
///
/// The target type must have a field named `options` holding a
/// `BuildOptions`; each generated method forwards to the `BuildOptions`
/// setter of the same name and stores the result back into that field.
macro_rules! impl_cat_a_builder_options {
    ($target:ident) => {
        impl $target {
            /// Override the generated Program workgroup size.
            #[must_use]
            pub fn with_workgroup_size(mut self, size: [u32; 3]) -> Self {
                let updated = self.options.with_workgroup_size(size);
                self.options = updated;
                self
            }

            /// Override the Region generator id.
            #[must_use]
            pub fn with_region_generator(mut self, name: &'static str) -> Self {
                let updated = self.options.with_region_generator(name);
                self.options = updated;
                self
            }

            /// Stamp the Region metadata with a tenant id.
            #[must_use]
            pub fn with_tenant_id(mut self, tenant_id: u32) -> Self {
                let updated = self.options.with_tenant_id(tenant_id);
                self.options = updated;
                self
            }
        }
    };
}
109
110pub(crate) use impl_cat_a_builder_options;
111
112/// Validate a slice of `TensorRef`s against an expected `DataType`
113/// for each position, plus name-uniqueness across the whole slice.
114/// Used by every op's `build()` to consolidate the fanout of checks.
115pub fn check_tensors(
116    op: &'static str,
117    tensors: &[(&TensorRef, DataType)],
118) -> Result<(), TensorRefError> {
119    // Dtype check per tensor.
120    for (r, expected) in tensors {
121        crate::tensor_ref::check_dtype(r, expected.clone(), op)?;
122        if r.element_count().is_none() {
123            return Err(TensorRefError::ElementCountOverflow {
124                name: r.name.as_str().to_string(),
125                shape: r.shape.to_vec(),
126            });
127        }
128    }
129    for (idx, (left, _)) in tensors.iter().enumerate() {
130        for (right, _) in &tensors[idx + 1..] {
131            if left.name_str() == right.name_str() {
132                return Err(TensorRefError::NameCollision {
133                    name: left.name.as_str().to_string(),
134                    op,
135                });
136            }
137        }
138    }
139    Ok(())
140}
141
#[cfg(test)]
mod cat_a_builder_option_macro_tests {
    #![allow(unreachable_pub)]

    use super::BuildOptions;

    /// Smallest possible builder the macro can be applied to: it only
    /// carries the `options` field the generated setters write to.
    #[derive(Debug, Clone)]
    struct DemoBuilder {
        options: BuildOptions,
    }

    impl DemoBuilder {
        fn new() -> Self {
            let options = BuildOptions::default();
            Self { options }
        }
    }

    super::impl_cat_a_builder_options!(DemoBuilder);

    /// Each generated setter must land in the matching `BuildOptions` field.
    #[test]
    fn generated_option_surface_threads_every_shared_knob() {
        let configured = DemoBuilder::new()
            .with_workgroup_size([8, 4, 2])
            .with_region_generator("custom::generator")
            .with_tenant_id(17);
        let options = &configured.options;

        assert_eq!(options.workgroup_size, Some([8, 4, 2]));
        assert_eq!(options.region_generator, Some("custom::generator"));
        assert_eq!(options.tenant_id, Some(17));
    }
}
175
176/// Build the canonical one-output indexed-map skeleton.
177///
178/// Callers provide buffer declarations plus the semantic mapping from logical
179/// element `i` to `(dst_index, value)`. The loop, bounds guard, invocation id,
180/// workgroup default, and composition region stay centralized.
181pub(crate) fn build_indexed_map<F>(
182    op_id: &'static str,
183    buffers: Vec<BufferDecl>,
184    output: &str,
185    count: u32,
186    workgroup_size: [u32; 3],
187    f: F,
188) -> Program
189where
190    F: FnOnce(Expr) -> (Expr, Expr),
191{
192    let i = Expr::var("i");
193    let (dst_index, value) = f(i.clone());
194    let child_body = vec![
195        Node::let_bind("i", Expr::InvocationId { axis: 0 }),
196        Node::if_then(
197            Expr::lt(i, Expr::u32(count)),
198            vec![Node::store(output, dst_index, value)],
199        ),
200    ];
201    let parent = GeneratorRef {
202        name: op_id.to_string(),
203    };
204
205    Program::wrapped(
206        buffers,
207        workgroup_size,
208        vec![crate::region::wrap_anonymous(
209            op_id,
210            vec![crate::region::wrap_child(
211                INDEXED_MAP_OP_ID,
212                parent,
213                child_body,
214            )],
215        )],
216    )
217}
218
219/// Build a shared strided single-accumulator child region.
220///
221/// The parent must bind `local = LocalId(0)` before this child. The child
222/// accumulates `i = chunk * tile + local` for `chunk in 0..chunks`, guards
223/// `i < n`, and stores the lane-local accumulator into `scratch[local]`.
224pub(crate) fn strided_accumulate_child<F>(
225    parent_op_id: &'static str,
226    tile: u32,
227    chunks: u32,
228    n: u32,
229    acc_name: &'static str,
230    initial: Expr,
231    scratch: &'static str,
232    step: F,
233) -> Node
234where
235    F: Fn(Expr, Expr) -> Expr,
236{
237    let local = Expr::var("local");
238    let idx = Expr::var("idx");
239    let acc = Expr::var(acc_name);
240    let child_body = vec![Node::if_then(
241        Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
242        vec![
243            Node::let_bind(acc_name, initial),
244            strided_loop(
245                tile,
246                chunks,
247                n,
248                vec![Node::assign(acc_name, step(idx, acc))],
249            ),
250            Node::store(scratch, local, Expr::var(acc_name)),
251        ],
252    )];
253
254    child_region(parent_op_id, STRIDED_ACCUMULATE_OP_ID, child_body)
255}
256
257/// Build a shared strided dual-accumulator child region.
258///
259/// This keeps paired reductions such as `(sum, sum_sq)` in one memory pass
260/// instead of forcing two separate scans over the input.
261#[allow(dead_code)]
262pub(crate) fn strided_accumulate2_child<F1, F2>(
263    parent_op_id: &'static str,
264    tile: u32,
265    chunks: u32,
266    n: u32,
267    first: (&'static str, Expr, &'static str, F1),
268    second: (&'static str, Expr, &'static str, F2),
269) -> Node
270where
271    F1: Fn(Expr, Expr) -> Expr,
272    F2: Fn(Expr, Expr) -> Expr,
273{
274    let (first_name, first_initial, first_scratch, first_step) = first;
275    let (second_name, second_initial, second_scratch, second_step) = second;
276    let local = Expr::var("local");
277    let idx = Expr::var("idx");
278    let child_body = vec![Node::if_then(
279        Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
280        vec![
281            Node::let_bind(first_name, first_initial),
282            Node::let_bind(second_name, second_initial),
283            strided_loop(
284                tile,
285                chunks,
286                n,
287                vec![
288                    Node::assign(first_name, first_step(idx.clone(), Expr::var(first_name))),
289                    Node::assign(second_name, second_step(idx, Expr::var(second_name))),
290                ],
291            ),
292            Node::store(first_scratch, local.clone(), Expr::var(first_name)),
293            Node::store(second_scratch, local, Expr::var(second_name)),
294        ],
295    )];
296
297    child_region(parent_op_id, STRIDED_ACCUMULATE_OP_ID, child_body)
298}
299
300/// Build a shared strided writeback child region.
301///
302/// The parent must bind `local = LocalId(0)` before this child. Optional
303/// `prelude` nodes run once in workgroup zero before the strided write loop,
304/// which lets row reductions load reduced scalars exactly once per lane.
305pub(crate) fn strided_writeback_child<F>(
306    parent_op_id: &'static str,
307    tile: u32,
308    chunks: u32,
309    n: u32,
310    output: &str,
311    prelude: Vec<Node>,
312    value: F,
313) -> Node
314where
315    F: Fn(Expr) -> Expr,
316{
317    let idx = Expr::var("idx");
318    let mut guarded = prelude;
319    guarded.push(strided_loop(
320        tile,
321        chunks,
322        n,
323        vec![Node::store(output, idx.clone(), value(idx))],
324    ));
325    child_region(
326        parent_op_id,
327        STRIDED_WRITEBACK_OP_ID,
328        vec![Node::if_then(
329            Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
330            guarded,
331        )],
332    )
333}
334
335fn strided_loop(tile: u32, chunks: u32, n: u32, guarded_body: Vec<Node>) -> Node {
336    Node::loop_for(
337        "chunk",
338        Expr::u32(0),
339        Expr::u32(chunks),
340        vec![
341            Node::let_bind(
342                "idx",
343                Expr::add(
344                    Expr::mul(Expr::var("chunk"), Expr::u32(tile)),
345                    Expr::var("local"),
346                ),
347            ),
348            Node::if_then(Expr::lt(Expr::var("idx"), Expr::u32(n)), guarded_body),
349        ],
350    )
351}
352
353fn child_region(parent_op_id: &'static str, child_op_id: &'static str, body: Vec<Node>) -> Node {
354    crate::region::wrap_child(
355        child_op_id,
356        GeneratorRef {
357            name: parent_op_id.to_string(),
358        },
359        body,
360    )
361}
362
363/// Build a scalar-output trap program for invalid Cat-A builder inputs.
364///
365/// This keeps public compatibility wrappers infallible without panicking on
366/// user-controlled names or shapes. Typed builders should still return
367/// `Result`; this helper is for legacy `fn foo(...) -> Program` surfaces.
368#[allow(dead_code)]
369pub(crate) fn invalid_output_program(
370    op_id: &'static str,
371    output: &str,
372    data_type: DataType,
373    message: String,
374) -> Program {
375    Program::wrapped(
376        vec![BufferDecl::output(output, 0, data_type).with_count(1)],
377        [1, 1, 1],
378        vec![crate::region::wrap_anonymous(
379            op_id,
380            vec![Node::trap(Expr::u32(0), message)],
381        )],
382    )
383}
384
385/// Tensor-ref elementwise binary builder, used by `math::avg_floor`,
386/// `math::algebra`, and other binary-arithmetic primitives.
387#[allow(dead_code)]
388pub(crate) fn build_elementwise_binary<F>(
389    op_id: &'static str,
390    a: crate::tensor_ref::TensorRef,
391    b: crate::tensor_ref::TensorRef,
392    out: crate::tensor_ref::TensorRef,
393    options: BuildOptions,
394    f: F,
395) -> Result<vyre::ir::Program, crate::tensor_ref::TensorRefError>
396where
397    F: Fn(vyre::ir::Expr, vyre::ir::Expr) -> vyre::ir::Expr,
398{
399    check_tensors(
400        op_id,
401        &[
402            (&a, vyre::ir::DataType::U32),
403            (&b, vyre::ir::DataType::U32),
404            (&out, vyre::ir::DataType::U32),
405        ],
406    )?;
407
408    if a.shape != b.shape || a.shape != out.shape {
409        return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
410            name: "elementwise_binary".into(),
411            found: vec![],
412            expected: vec![],
413            op: op_id,
414        });
415    }
416
417    let a_count = a.element_count().ok_or_else(|| {
418        crate::tensor_ref::TensorRefError::ElementCountOverflow {
419            name: a.name_str().to_string(),
420            shape: a.shape.to_vec(),
421        }
422    })?;
423    let out_count = out.element_count().ok_or_else(|| {
424        crate::tensor_ref::TensorRefError::ElementCountOverflow {
425            name: out.name_str().to_string(),
426            shape: out.shape.to_vec(),
427        }
428    })?;
429    if out_count < a_count {
430        return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
431            name: out.name_str().to_string(),
432            found: out.shape.to_vec(),
433            expected: a.shape.to_vec(),
434            op: op_id,
435        });
436    }
437
438    let n = a_count;
439    let body = vec![
440        vyre::ir::Node::let_bind("idx", vyre::ir::Expr::InvocationId { axis: 0 }),
441        vyre::ir::Node::if_then(
442            vyre::ir::Expr::lt(vyre::ir::Expr::var("idx"), vyre::ir::Expr::u32(n)),
443            vec![vyre::ir::Node::store(
444                out.name_str(),
445                vyre::ir::Expr::var("idx"),
446                f(
447                    vyre::ir::Expr::load(a.name_str(), vyre::ir::Expr::var("idx")),
448                    vyre::ir::Expr::load(b.name_str(), vyre::ir::Expr::var("idx")),
449                ),
450            )],
451        ),
452    ];
453
454    let group = options.workgroup_size.unwrap_or([64, 1, 1]);
455
456    Ok(vyre::ir::Program::wrapped(
457        vec![
458            vyre::ir::BufferDecl::storage(
459                a.name_str(),
460                0,
461                vyre::ir::BufferAccess::ReadOnly,
462                vyre::ir::DataType::U32,
463            )
464            .with_count(n),
465            vyre::ir::BufferDecl::storage(
466                b.name_str(),
467                1,
468                vyre::ir::BufferAccess::ReadOnly,
469                vyre::ir::DataType::U32,
470            )
471            .with_count(n),
472            vyre::ir::BufferDecl::output(out.name_str(), 2, vyre::ir::DataType::U32).with_count(n),
473        ],
474        group,
475        vec![crate::region::wrap_anonymous(op_id, body)],
476    ))
477}
478
479#[allow(dead_code)]
480pub(crate) fn build_elementwise_unary<F>(
481    op_id: &'static str,
482    a: crate::tensor_ref::TensorRef,
483    out: crate::tensor_ref::TensorRef,
484    options: BuildOptions,
485    f: F,
486) -> Result<vyre::ir::Program, crate::tensor_ref::TensorRefError>
487where
488    F: Fn(vyre::ir::Expr) -> vyre::ir::Expr,
489{
490    check_tensors(
491        op_id,
492        &[
493            (&a, vyre::ir::DataType::U32),
494            (&out, vyre::ir::DataType::U32),
495        ],
496    )?;
497
498    if a.shape != out.shape {
499        return Err(crate::tensor_ref::TensorRefError::ShapeMismatch {
500            name: "elementwise_unary".into(),
501            found: vec![],
502            expected: vec![],
503            op: op_id,
504        });
505    }
506
507    let n = a.element_count().ok_or_else(|| {
508        crate::tensor_ref::TensorRefError::ElementCountOverflow {
509            name: a.name_str().to_string(),
510            shape: a.shape.to_vec(),
511        }
512    })?;
513    let body = vec![
514        vyre::ir::Node::let_bind("idx", vyre::ir::Expr::InvocationId { axis: 0 }),
515        vyre::ir::Node::if_then(
516            vyre::ir::Expr::lt(vyre::ir::Expr::var("idx"), vyre::ir::Expr::u32(n)),
517            vec![vyre::ir::Node::store(
518                out.name_str(),
519                vyre::ir::Expr::var("idx"),
520                f(vyre::ir::Expr::load(
521                    a.name_str(),
522                    vyre::ir::Expr::var("idx"),
523                )),
524            )],
525        ),
526    ];
527
528    let group = options.workgroup_size.unwrap_or([64, 1, 1]);
529
530    Ok(vyre::ir::Program::wrapped(
531        vec![
532            vyre::ir::BufferDecl::storage(
533                a.name_str(),
534                0,
535                vyre::ir::BufferAccess::ReadOnly,
536                vyre::ir::DataType::U32,
537            )
538            .with_count(n),
539            vyre::ir::BufferDecl::output(out.name_str(), 1, vyre::ir::DataType::U32).with_count(n),
540        ],
541        group,
542        vec![crate::region::wrap_anonymous(op_id, body)],
543    ))
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// A default option set must not override anything.
    #[test]
    fn build_options_defaults_are_all_none() {
        let defaults = BuildOptions::default();
        assert_eq!(defaults.workgroup_size, None);
        assert_eq!(defaults.region_generator, None);
        assert_eq!(defaults.tenant_id, None);
    }

    /// Later setters in a chain must not reset fields set earlier.
    #[test]
    fn build_options_chain_preserves_earlier_setters() {
        let chained = BuildOptions::new()
            .with_workgroup_size([128, 1, 1])
            .with_region_generator("test::op")
            .with_tenant_id(7);

        assert_eq!(chained.workgroup_size, Some([128, 1, 1]));
        assert_eq!(chained.region_generator, Some("test::op"));
        assert_eq!(chained.tenant_id, Some(7));
    }

    #[test]
    fn check_tensors_passes_on_clean_inputs() {
        let a = TensorRef::u32_1d("a", 4);
        let b = TensorRef::u32_1d("b", 4);

        let outcome = check_tensors("op", &[(&a, DataType::U32), (&b, DataType::U32)]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn check_tensors_catches_dtype_mismatch() {
        let a = TensorRef::u32_1d("a", 4);

        let outcome = check_tensors("op", &[(&a, DataType::F32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::DtypeMismatch { .. })
        ));
    }

    #[test]
    fn check_tensors_catches_overflow() {
        // 2^20 * 2^20 elements does not fit the element-count type.
        let side = 1u32 << 20;
        let a = TensorRef::new("big", DataType::U32, vec![side, side]);

        let outcome = check_tensors("op", &[(&a, DataType::U32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::ElementCountOverflow { .. })
        ));
    }

    #[test]
    fn check_tensors_catches_name_collision() {
        let a = TensorRef::u32_1d("x", 4);
        let b = TensorRef::u32_1d("x", 4);

        let outcome = check_tensors("op", &[(&a, DataType::U32), (&b, DataType::U32)]);
        assert!(matches!(
            outcome,
            Err(TensorRefError::NameCollision { .. })
        ));
    }

    #[test]
    fn indexed_map_builder_emits_shared_child_region() {
        let input = BufferDecl::storage("input", 0, vyre::ir::BufferAccess::ReadOnly, DataType::U32)
            .with_count(4);
        let output = BufferDecl::output("output", 1, DataType::U32).with_count(4);

        let program = build_indexed_map(
            "vyre-libs::test::indexed_map_user",
            vec![input, output],
            "output",
            4,
            [64, 1, 1],
            |i| (i.clone(), Expr::load("input", i)),
        );

        let rendered = format!("{:?}", program.entry());
        assert!(
            rendered.contains(INDEXED_MAP_OP_ID),
            "Fix: indexed-map users must share the same child region instead of copying loop skeletons: {rendered}"
        );
    }

    #[test]
    fn strided_writeback_builder_emits_shared_child_region() {
        let bind_local = Node::let_bind("local", Expr::LocalId { axis: 0 });
        let writeback = strided_writeback_child(
            "vyre-libs::test::row_reduction_user",
            4,
            1,
            4,
            "out",
            vec![Node::let_bind("scale", Expr::f32(0.5))],
            |_idx| Expr::var("scale"),
        );
        let region = crate::region::wrap_anonymous(
            "vyre-libs::test::row_reduction_user",
            vec![bind_local, writeback],
        );

        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::F32).with_count(4)],
            [4, 1, 1],
            vec![region],
        );

        let rendered = format!("{:?}", program.entry());
        assert!(
            rendered.contains(STRIDED_WRITEBACK_OP_ID),
            "Fix: row-reduction writeback users must share the same child region instead of copying loop skeletons: {rendered}"
        );
    }
}