vyre_libs/descriptor.rs
//! `ProgramDescriptor` - lightweight structural introspection of a
//! Cat-A composition's Program, intended to avoid running the full IR
//! builder.
//!
//! The descriptor answers questions like "how many buffers does this
//! op declare?" and "what's its canonical workgroup size?" without
//! forcing the caller to pay the full Program construction cost.
//! P2.9 ships descriptors derived from an already-built Program, so that
//! cost is not yet avoided; a future optimization may construct
//! descriptors lazily without materializing the full IR. Either way,
//! the surface here is the contract external tooling pins against.
11
12use vyre::ir::{BufferAccess, DataType, Program};
13
/// Structural description of a Cat-A Program.
///
/// Built either by [`ProgramDescriptor::from_program`] (derived from an
/// already-built Program) or by [`ProgramDescriptor::new`] (fields supplied
/// verbatim by the caller). The type is `#[non_exhaustive]`, so code
/// outside this crate cannot use a struct literal and must go through one
/// of those constructors.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ProgramDescriptor {
    /// Number of declared buffers. `from_program` sets this to
    /// `buffers.len()`; `new` stores whatever value the caller passes.
    pub buffer_count: usize,
    /// Canonical workgroup dispatch size, as reported by
    /// `Program::workgroup_size()` when built via `from_program`.
    pub workgroup_size: [u32; 3],
    /// Buffer summaries, one per declared buffer, in the order the
    /// Program lists them.
    pub buffers: Vec<BufferDescriptor>,
    /// Lower bound on the total bytes declared across `ReadWrite`
    /// buffers (element count times element size). Useful for rough
    /// memory-footprint estimates. Buffers with a runtime-determined
    /// size (count 0) contribute zero, as do element types for which
    /// `DataType::size_bytes()` yields no value.
    pub rw_bytes_lower_bound: usize,
    /// Number of top-level nodes in the entry body.
    pub entry_node_count: usize,
}
31
/// One buffer summary inside a [`ProgramDescriptor`].
///
/// The type is `#[non_exhaustive]`, so code outside this crate builds
/// values through [`BufferDescriptor::new`] rather than a struct literal.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct BufferDescriptor {
    /// Declared name (matches `TensorRef::name`). When produced by
    /// `ProgramDescriptor::from_program`, this is an owned copy of the
    /// buffer declaration's `name()`.
    pub name: String,
    /// Storage-class access mode.
    pub access: BufferAccess,
    /// Element dtype.
    pub dtype: DataType,
    /// Element count, or 0 when the size is runtime-determined.
    pub count: u32,
}
45
46impl BufferDescriptor {
47 /// Construct a `BufferDescriptor` from explicit fields. External
48 /// tooling that synthesizes buffer summaries uses this constructor
49 /// (V7-EXT-022).
50 #[must_use]
51 pub fn new(name: String, access: BufferAccess, dtype: DataType, count: u32) -> Self {
52 Self {
53 name,
54 access,
55 dtype,
56 count,
57 }
58 }
59}
60
61impl ProgramDescriptor {
62 /// Construct a `ProgramDescriptor` directly from explicit fields.
63 /// External tooling that synthesizes descriptors without going
64 /// through `from_program` uses this constructor (V7-EXT-023).
65 #[must_use]
66 pub fn new(
67 buffer_count: usize,
68 workgroup_size: [u32; 3],
69 buffers: Vec<BufferDescriptor>,
70 rw_bytes_lower_bound: usize,
71 entry_node_count: usize,
72 ) -> Self {
73 Self {
74 buffer_count,
75 workgroup_size,
76 buffers,
77 rw_bytes_lower_bound,
78 entry_node_count,
79 }
80 }
81
82 /// Derive a descriptor from an already-built Program. Zero-allocation
83 /// aside from the owned buffer-name strings (one per declared
84 /// buffer); consumers that need every dispatch to stay cheap
85 /// should cache the descriptor once and reuse it.
86 #[must_use]
87 pub fn from_program(program: &Program) -> Self {
88 let buffers: Vec<BufferDescriptor> = program
89 .buffers()
90 .iter()
91 .map(|b| BufferDescriptor {
92 name: b.name().to_string(),
93 access: b.access(),
94 dtype: b.element(),
95 count: b.count(),
96 })
97 .collect();
98
99 let rw_bytes_lower_bound: usize = buffers
100 .iter()
101 .filter(|b| matches!(b.access, BufferAccess::ReadWrite))
102 .map(|b| {
103 let elem_bytes = b.dtype.size_bytes().unwrap_or(0);
104 (b.count as usize).saturating_mul(elem_bytes)
105 })
106 .sum();
107
108 Self {
109 buffer_count: buffers.len(),
110 workgroup_size: program.workgroup_size(),
111 rw_bytes_lower_bound,
112 entry_node_count: program.entry().len(),
113 buffers,
114 }
115 }
116}
117
118#[cfg(test)]
119mod tests {
120 #[allow(unused_imports)]
121 use super::*;
122
123 #[cfg(feature = "nn-attention")]
124 #[test]
125 fn descriptor_summarizes_softmax() {
126 use crate::nn::softmax;
127 let program = softmax("in", "out", 64);
128 let desc = ProgramDescriptor::from_program(&program);
129 // Tiled softmax uses workgroup scratch buffers: input, softmax_scratch,
130 // softmax_max, output = 4 buffers total.
131 assert_eq!(desc.buffer_count, 4);
132 assert_eq!(desc.workgroup_size, [256, 1, 1]);
133 // Only the output buffer is ReadWrite; 64 F32 elements = 256 bytes.
134 assert_eq!(desc.rw_bytes_lower_bound, 64 * 4);
135 assert_eq!(desc.entry_node_count, 1); // one Region wrapper at top
136 assert_eq!(desc.buffers[0].name, "in");
137 assert_eq!(desc.buffers[1].name, "softmax_scratch");
138 assert_eq!(desc.buffers[2].name, "softmax_max");
139 assert_eq!(desc.buffers[3].name, "out");
140 }
141
142 #[cfg(feature = "math-linalg")]
143 #[test]
144 fn descriptor_summarizes_matmul() {
145 use crate::math::matmul;
146 let program = matmul("a", "b", "out", 4, 8, 16);
147 let desc = ProgramDescriptor::from_program(&program);
148 assert_eq!(desc.buffer_count, 3);
149 // Only `out` (4*16 = 64 u32 = 256 bytes) is ReadWrite.
150 assert_eq!(desc.rw_bytes_lower_bound, 4 * 16 * 4);
151 }
152}