Skip to main content

vyre_libs/
descriptor.rs

//! `ProgramDescriptor` — lightweight structural introspection of a
//! Cat-A composition's Program without running the full IR builder.
//!
//! The descriptor answers questions like "how many buffers does this
//! op declare?" and "what's its canonical workgroup size?" without
//! forcing the caller to pay the full Program construction cost.
//! P2.9 ships descriptors derived from an already-built Program, so that
//! construction cost is still paid once today; a future optimization may
//! construct descriptors lazily without materializing the full IR. Either
//! way, the surface here is the contract that external tooling pins against.
11
12use vyre::ir::{BufferAccess, DataType, Program};
13
14/// Structural description of a Cat-A Program.
15#[derive(Debug, Clone)]
16#[non_exhaustive]
17pub struct ProgramDescriptor {
18    /// Number of declared buffers.
19    pub buffer_count: usize,
20    /// Canonical workgroup dispatch size.
21    pub workgroup_size: [u32; 3],
22    /// Buffer summaries, one per declared buffer.
23    pub buffers: Vec<BufferDescriptor>,
24    /// Total element-bytes declared across ReadWrite buffers. Useful
25    /// for rough memory-footprint estimates. Missing counts (runtime-
26    /// determined buffer sizes) contribute zero.
27    pub rw_bytes_lower_bound: usize,
28    /// Number of top-level nodes in the entry body.
29    pub entry_node_count: usize,
30}
31
32/// One buffer summary inside a [`ProgramDescriptor`].
33#[derive(Debug, Clone)]
34#[non_exhaustive]
35pub struct BufferDescriptor {
36    /// Declared name (matches `TensorRef::name`).
37    pub name: String,
38    /// Storage-class access mode.
39    pub access: BufferAccess,
40    /// Element dtype.
41    pub dtype: DataType,
42    /// Element count, or 0 when the size is runtime-determined.
43    pub count: u32,
44}
45
46impl BufferDescriptor {
47    /// Construct a `BufferDescriptor` from explicit fields. External
48    /// tooling that synthesizes buffer summaries uses this constructor
49    /// (V7-EXT-022).
50    #[must_use]
51    pub fn new(name: String, access: BufferAccess, dtype: DataType, count: u32) -> Self {
52        Self {
53            name,
54            access,
55            dtype,
56            count,
57        }
58    }
59}
60
61impl ProgramDescriptor {
62    /// Construct a `ProgramDescriptor` directly from explicit fields.
63    /// External tooling that synthesizes descriptors without going
64    /// through `from_program` uses this constructor (V7-EXT-023).
65    #[must_use]
66    pub fn new(
67        buffer_count: usize,
68        workgroup_size: [u32; 3],
69        buffers: Vec<BufferDescriptor>,
70        rw_bytes_lower_bound: usize,
71        entry_node_count: usize,
72    ) -> Self {
73        Self {
74            buffer_count,
75            workgroup_size,
76            buffers,
77            rw_bytes_lower_bound,
78            entry_node_count,
79        }
80    }
81
82    /// Derive a descriptor from an already-built Program. Zero-allocation
83    /// aside from the owned buffer-name strings (one per declared
84    /// buffer); consumers that need every dispatch to stay cheap
85    /// should cache the descriptor once and reuse it.
86    #[must_use]
87    pub fn from_program(program: &Program) -> Self {
88        let buffers: Vec<BufferDescriptor> = program
89            .buffers()
90            .iter()
91            .map(|b| BufferDescriptor {
92                name: b.name().to_string(),
93                access: b.access(),
94                dtype: b.element(),
95                count: b.count(),
96            })
97            .collect();
98
99        let rw_bytes_lower_bound: usize = buffers
100            .iter()
101            .filter(|b| matches!(b.access, BufferAccess::ReadWrite))
102            .map(|b| {
103                let elem_bytes = b.dtype.size_bytes().unwrap_or(0);
104                (b.count as usize).saturating_mul(elem_bytes)
105            })
106            .sum();
107
108        Self {
109            buffer_count: buffers.len(),
110            workgroup_size: program.workgroup_size(),
111            rw_bytes_lower_bound,
112            entry_node_count: program.entry().len(),
113            buffers,
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    // Every test below is feature-gated. With neither `nn-attention` nor
    // `math-linalg` enabled, this glob import is unused, so the lint is
    // silenced rather than gating the import on every feature combination.
    #[allow(unused_imports)]
    use super::*;

    #[cfg(feature = "nn-attention")]
    #[test]
    fn descriptor_summarizes_softmax() {
        use crate::nn::softmax;
        let program = softmax("in", "out", 64);
        let desc = ProgramDescriptor::from_program(&program);
        // Tiled softmax uses workgroup scratch buffers: input, softmax_scratch,
        // softmax_max, output = 4 buffers total.
        assert_eq!(desc.buffer_count, 4);
        assert_eq!(desc.workgroup_size, [256, 1, 1]);
        // Only the output buffer is ReadWrite; 64 F32 elements = 256 bytes.
        assert_eq!(desc.rw_bytes_lower_bound, 64 * 4);
        assert_eq!(desc.entry_node_count, 1); // one Region wrapper at top
        // Buffer summaries follow the Program's declaration order.
        let names: Vec<&str> = desc.buffers.iter().map(|b| b.name.as_str()).collect();
        assert_eq!(names, ["in", "softmax_scratch", "softmax_max", "out"]);
    }

    #[cfg(feature = "math-linalg")]
    #[test]
    fn descriptor_summarizes_matmul() {
        use crate::math::matmul;
        let program = matmul("a", "b", "out", 4, 8, 16);
        let desc = ProgramDescriptor::from_program(&program);
        assert_eq!(desc.buffer_count, 3);
        // `from_program` keeps `buffer_count` and `buffers` in sync.
        assert_eq!(desc.buffers.len(), desc.buffer_count);
        // Only `out` (4*16 = 64 u32 = 256 bytes) is ReadWrite.
        assert_eq!(desc.rw_bytes_lower_bound, 4 * 16 * 4);
    }
}