vyre-foundation 0.4.1

Foundation layer: IR, type system, memory model, wire format. Zero application semantics. Part of the vyre GPU compiler.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
// Expression nodes — produce values.
//
// Every expression evaluates to a typed value. Expressions are pure:
// they read state but do not modify it.

use crate::ir_inner::model::types::DataType;
use rustc_hash::FxHasher;
use std::borrow::Borrow;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::Arc;

/// Identifies the generator or macro that emitted a region of the AST.
///
/// Carried alongside IR so tooling can map generated code back to its
/// origin (source mapping / DWARF-style debugging context).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct GeneratorRef {
    /// Generator name, for example `vyre-nn::flash_attention`.
    pub name: String,
}

/// Interned identifier used by expression nodes.
///
/// `Ident` is cheap to clone and keeps expression trees from repeatedly
/// allocating owned `String` values for the same variable or buffer names.
///
/// Invariant: `hash` is always the FxHash of `text`, computed once at
/// construction. Equality exploits this to reject mismatches in O(1).
#[derive(Clone, Eq)]
pub struct Ident {
    text: Arc<str>,
    hash: u64,
}

impl PartialEq for Ident {
    /// Compare two identifiers by text.
    ///
    /// The cached hashes are compared first so unequal identifiers are
    /// rejected without touching the string bytes. Clones of the same
    /// interned handle short-circuit via pointer equality. Only then are
    /// the bytes compared. The result is identical to comparing `text`
    /// alone, which keeps `Eq` consistent with the `Borrow<str>` impl.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.hash == other.hash
            && (Arc::ptr_eq(&self.text, &other.text) || self.text == other.text)
    }
}

impl Ident {
    /// Compute the FxHash of `text`; called exactly once per identifier.
    #[inline]
    fn prehash(text: &str) -> u64 {
        let mut state = FxHasher::default();
        Hash::hash(text, &mut state);
        state.finish()
    }

    /// Build an identifier from shared text, computing and caching its hash
    /// up front so later lookups never rehash the bytes.
    #[must_use]
    #[inline]
    pub fn new(text: Arc<str>) -> Self {
        Self {
            hash: Self::prehash(&text),
            text,
        }
    }

    /// Hand out another reference-counted handle to the identifier text.
    ///
    /// Only the refcount is bumped; the UTF-8 bytes are not copied.
    #[must_use]
    #[inline]
    pub fn shared_text(&self) -> Arc<str> {
        Arc::clone(&self.text)
    }

    /// Borrow the identifier as a string slice.
    #[must_use]
    #[inline]
    pub fn as_str(&self) -> &str {
        &*self.text
    }

    /// The FxHash computed at construction, for callers that key on it directly.
    #[must_use]
    #[inline]
    pub fn cached_hash(&self) -> u64 {
        self.hash
    }
}

impl From<&str> for Ident {
    /// Copy the borrowed text into a fresh shared allocation and intern it.
    #[inline]
    fn from(text: &str) -> Self {
        let shared: Arc<str> = text.into();
        Self::new(shared)
    }
}

impl From<String> for Ident {
    #[inline]
    fn from(value: String) -> Self {
        Self::new(Arc::from(value))
    }
}

impl From<Arc<str>> for Ident {
    #[inline]
    fn from(value: Arc<str>) -> Self {
        Self::new(value)
    }
}

impl From<&String> for Ident {
    /// Intern a borrowed `String` by routing through the `&str` conversion.
    #[inline]
    fn from(text: &String) -> Self {
        text.as_str().into()
    }
}

impl From<&Ident> for Ident {
    /// Cheap handle copy so builders accepting `impl Into<Ident>` also take `&Ident`.
    #[inline]
    fn from(ident: &Ident) -> Self {
        Self::clone(ident)
    }
}

impl fmt::Debug for Ident {
    /// Render as `Ident("name")`, hiding the cached hash from debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = &self.text;
        f.debug_tuple("Ident").field(&text).finish()
    }
}

impl Hash for Ident {
    /// Feed the identifier's text (not the cached `u64`) into `state`.
    ///
    /// Audit P-IDENT-BORROW (2026-04-29): `Ident: Borrow<str>` means a map
    /// keyed by `Ident` may be queried with a plain `&str`, which is only
    /// sound when both types produce identical hash streams. Hashing the
    /// string slice guarantees that; writing the cached FxHash via
    /// `write_u64` did not, and made `FxHashMap<Ident, V>::get(&str)` miss
    /// entries that were present. Code that wants the precomputed FxHash as
    /// a fast key should call [`Ident::cached_hash`] instead.
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_str().hash(state);
    }
}

impl Deref for Ident {
    type Target = str;

    /// Expose the identifier text so `str` methods can be called directly.
    #[inline]
    fn deref(&self) -> &str {
        &self.text
    }
}

impl AsRef<str> for Ident {
    /// View the identifier as a string slice for generic `AsRef<str>` callers.
    #[inline]
    fn as_ref(&self) -> &str {
        &self.text
    }
}

impl Borrow<str> for Ident {
    /// Allow `&str` lookups in `Ident`-keyed maps and sets.
    ///
    /// Sound only because `Hash`, `Eq` and `Ord` for `Ident` all agree with
    /// the corresponding `str` impls.
    #[inline]
    fn borrow(&self) -> &str {
        &self.text
    }
}

impl fmt::Display for Ident {
    /// Print the bare identifier text.
    ///
    /// Uses `Formatter::pad` rather than `write_str` so width, fill,
    /// alignment and precision flags behave as they do for `str`
    /// (e.g. `format!("{:<12}", ident)` pads the name). With no flags,
    /// the output is byte-identical to writing the text directly.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}

impl PartialEq<str> for Ident {
    /// True when the identifier text equals `other` byte-for-byte.
    #[inline]
    fn eq(&self, other: &str) -> bool {
        &*self.text == other
    }
}

impl PartialEq<&str> for Ident {
    /// Convenience for `ident == "literal"`; defers to the `str` comparison.
    #[inline]
    fn eq(&self, other: &&str) -> bool {
        <Self as PartialEq<str>>::eq(self, *other)
    }
}

impl PartialOrd for Ident {
    /// Total order, so this always returns `Some` of the `Ord` result.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Ident {
    /// Lexicographic byte order of the text, matching `str`'s ordering.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&*self.text, &*other.text)
    }
}

/// An expression that produces a value.
///
/// # Examples
///
/// ```
/// use vyre::ir::Expr;
///
/// let lit = Expr::u32(42);
/// let var = Expr::var("x");
/// let add = Expr::add(lit, var);
/// ```
pub use crate::ir_inner::model::generated::Expr;

/// Contract that downstream crates implement to add their own expression nodes.
///
/// Core treats extension nodes as opaque. The owning crate holds the
/// semantic payload and supplies the stable metadata core relies on for
/// validation, debug output, equality, and CSE identity. A backend that
/// understands an extension downcasts through its own wrapper type before
/// emitting target code. A backend that does not understand it must reject
/// the node with an actionable error.
pub trait ExprNode: fmt::Debug + Send + Sync + 'static {
    /// Stable namespace identifying the extension, e.g. `my_backend.tensor.shuffle`.
    fn extension_kind(&self) -> &'static str;

    /// Human-readable label shown in diagnostics and debug logs.
    fn debug_identity(&self) -> &str;

    /// Static type of the value this expression produces.
    fn result_type(&self) -> Option<DataType>;

    /// Whether CSE may treat this node as a pure, repeatable expression.
    fn cse_safe(&self) -> bool;

    /// Stable content-addressed identity used for equality and optimizer keys.
    fn stable_fingerprint(&self) -> [u8; 32];

    /// Check invariants local to this extension.
    ///
    /// # Errors
    ///
    /// The returned message must describe the violated invariant and
    /// include a `Fix:` hint.
    fn validate_extension(&self) -> Result<(), String>;

    /// Upcast to `Any` so backends can downcast the opaque payload.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Stable byte encoding of the extension payload.
    ///
    /// The wire encoder emits these bytes for `Expr::Opaque` (tag `0x80`).
    /// The default returns an empty payload. That is appropriate for
    /// extensions whose type identity is their only state. Stateful
    /// extensions must override this and produce exactly the bytes the
    /// extension's matching `OpaqueExprResolver` will consume.
    ///
    /// Byte order is fixed by contract. Every numeric field wider than one
    /// byte MUST be written with `to_le_bytes`, and the matching decoder
    /// MUST read it back with `from_le_bytes`. Host-endian encodings such
    /// as `to_ne_bytes` are forbidden: the wire format must be
    /// byte-identical on every architecture.
    ///
    /// Extension authors are recommended (though, for API compatibility,
    /// not required) to build payloads with
    /// [`crate::opaque_payload::LeBytesWriter`], which makes little-endian
    /// the only option at the type level.
    ///
    /// Literal extensions carrying regex payloads must also canonicalize
    /// inline flag prefixes before serializing. For example, `(?mi)` and
    /// `(?im)` are semantically the same and MUST produce the same flag
    /// ordering in the emitted bytes.
    fn wire_payload(&self) -> Vec<u8> {
        vec![]
    }
}

impl Expr {
    /// Load from buffer at index.
    ///
    /// `buffer` names the buffer; `index` is the element index expression.
    ///
    /// # Examples
    ///
    /// ```
    /// use vyre::ir::Expr;
    /// let _ = Expr::load("a", Expr::u32(0));
    /// ```
    #[must_use]
    #[inline]
    pub fn load(buffer: impl Into<Ident>, index: Self) -> Self {
        Self::Load {
            buffer: buffer.into(),
            index: Box::new(index),
        }
    }

    /// Buffer element count.
    ///
    /// # Examples
    ///
    /// ```
    /// use vyre::ir::Expr;
    /// let _ = Expr::buf_len("a");
    /// ```
    #[must_use]
    #[inline]
    pub fn buf_len(buffer: impl Into<Ident>) -> Self {
        Self::BufLen {
            buffer: buffer.into(),
        }
    }

    // --- Global invocation id (axis 0 = x, 1 = y, 2 = z) ---

    /// `global_invocation_id.x`
    #[must_use]
    #[inline]
    pub fn gid_x() -> Self {
        Self::InvocationId { axis: 0 }
    }

    /// `global_invocation_id.y`
    #[must_use]
    #[inline]
    pub fn gid_y() -> Self {
        Self::InvocationId { axis: 1 }
    }

    /// `global_invocation_id.z`
    #[must_use]
    #[inline]
    pub fn gid_z() -> Self {
        Self::InvocationId { axis: 2 }
    }

    // --- Workgroup id ---

    /// `workgroup_id.x`
    #[must_use]
    #[inline]
    pub fn workgroup_x() -> Self {
        Self::WorkgroupId { axis: 0 }
    }

    /// `workgroup_id.y`
    #[must_use]
    #[inline]
    pub fn workgroup_y() -> Self {
        Self::WorkgroupId { axis: 1 }
    }

    /// `workgroup_id.z`
    #[must_use]
    #[inline]
    pub fn workgroup_z() -> Self {
        Self::WorkgroupId { axis: 2 }
    }

    // --- Local invocation id ---

    /// `local_invocation_id.x`
    #[must_use]
    #[inline]
    pub fn local_x() -> Self {
        Self::LocalId { axis: 0 }
    }

    /// `local_invocation_id.y`
    #[must_use]
    #[inline]
    pub fn local_y() -> Self {
        Self::LocalId { axis: 1 }
    }

    /// `local_invocation_id.z`
    #[must_use]
    #[inline]
    pub fn local_z() -> Self {
        Self::LocalId { axis: 2 }
    }

    // --- Subgroup builtins ---

    /// `subgroup_invocation_id` (lane index within subgroup).
    #[must_use]
    #[inline]
    pub fn subgroup_local_id() -> Self {
        Self::SubgroupLocalId
    }

    /// `subgroup_size` (number of lanes per subgroup).
    #[must_use]
    #[inline]
    pub fn subgroup_size() -> Self {
        Self::SubgroupSize
    }

    // --- Substrate-neutral aliases ---
    //
    // Each alias delegates to its canonical builder rather than constructing
    // the variant itself, so the two spellings cannot drift apart.

    /// Substrate-neutral alias for [`workgroup_x`](Self::workgroup_x).
    ///
    /// "Parallel region" is the vocabulary used in vyre-core's public
    /// surface. Concrete drivers translate this concept into their own
    /// target vocabulary at the boundary.
    #[must_use]
    #[inline]
    pub fn parallel_region_x() -> Self {
        Self::workgroup_x()
    }

    /// Substrate-neutral alias for [`workgroup_y`](Self::workgroup_y).
    #[must_use]
    #[inline]
    pub fn parallel_region_y() -> Self {
        Self::workgroup_y()
    }

    /// Substrate-neutral alias for [`workgroup_z`](Self::workgroup_z).
    #[must_use]
    #[inline]
    pub fn parallel_region_z() -> Self {
        Self::workgroup_z()
    }

    /// Substrate-neutral alias for [`local_x`](Self::local_x).
    #[must_use]
    #[inline]
    pub fn invocation_local_x() -> Self {
        Self::local_x()
    }

    /// Substrate-neutral alias for [`local_y`](Self::local_y).
    #[must_use]
    #[inline]
    pub fn invocation_local_y() -> Self {
        Self::local_y()
    }

    /// Substrate-neutral alias for [`local_z`](Self::local_z).
    #[must_use]
    #[inline]
    pub fn invocation_local_z() -> Self {
        Self::local_z()
    }

    // --- Compound expressions ---

    /// Conditional select: `true_val` when `cond` holds, otherwise `false_val`.
    #[must_use]
    #[inline]
    pub fn select(cond: Self, true_val: Self, false_val: Self) -> Self {
        Self::Select {
            cond: Box::new(cond),
            true_val: Box::new(true_val),
            false_val: Box::new(false_val),
        }
    }

    /// Subgroup inclusive-add reduction across the active subgroup.
    #[must_use]
    #[inline]
    pub fn subgroup_add(value: Self) -> Self {
        Self::SubgroupAdd {
            value: Box::new(value),
        }
    }

    /// Subgroup shuffle: broadcast `value` from the given lane id to
    /// every active lane in the subgroup.
    #[must_use]
    #[inline]
    pub fn subgroup_shuffle(value: Self, lane: Self) -> Self {
        Self::SubgroupShuffle {
            value: Box::new(value),
            lane: Box::new(lane),
        }
    }

    /// Subgroup ballot: gather the boolean predicate `cond` across
    /// the active subgroup into a single bitmask.
    #[must_use]
    #[inline]
    pub fn subgroup_ballot(cond: Self) -> Self {
        Self::SubgroupBallot {
            cond: Box::new(cond),
        }
    }

    // --- Leaves ---

    /// Named variable reference.
    #[must_use]
    #[inline]
    pub fn var(name: impl Into<Ident>) -> Self {
        Self::Var(name.into())
    }

    /// Unsigned 32-bit literal.
    #[must_use]
    #[inline]
    pub fn u32(value: u32) -> Self {
        Self::LitU32(value)
    }

    /// Signed 32-bit literal.
    #[must_use]
    #[inline]
    pub fn i32(value: i32) -> Self {
        Self::LitI32(value)
    }

    /// 32-bit floating-point literal.
    #[must_use]
    #[inline]
    pub fn f32(value: f32) -> Self {
        Self::LitF32(value)
    }

    /// Boolean literal.
    #[must_use]
    #[inline]
    pub fn bool(value: bool) -> Self {
        Self::LitBool(value)
    }

    /// Operation call by stable operation ID, with positional `args`.
    #[must_use]
    #[inline]
    pub fn call(op_id: impl Into<Ident>, args: Vec<Self>) -> Self {
        Self::Call {
            op_id: op_id.into(),
            args,
        }
    }

    /// Fused multiply-add `a * b + c` (f32).
    #[must_use]
    #[inline]
    pub fn fma(a: Self, b: Self, c: Self) -> Self {
        Self::Fma {
            a: Box::new(a),
            b: Box::new(b),
            c: Box::new(c),
        }
    }

    /// Cast a value to `target`.
    #[must_use]
    #[inline]
    pub fn cast(target: DataType, value: Self) -> Self {
        Self::Cast {
            target,
            value: Box::new(value),
        }
    }

    /// Wrap a downstream extension expression node (moved into a new `Arc`).
    #[must_use]
    #[inline]
    pub fn opaque(node: impl ExprNode) -> Self {
        Self::Opaque(Arc::new(node))
    }

    /// Wrap an already-shared downstream extension expression node.
    #[must_use]
    #[inline]
    pub fn opaque_arc(node: Arc<dyn ExprNode>) -> Self {
        Self::Opaque(node)
    }
}
mod atomics;
mod builders;

#[cfg(test)]
mod tests {
    use super::Expr;

    /// Keeps `Expr` small: every IR tree node pays for the largest variant,
    /// so growth past the budget should be caught here, not in profiling.
    #[test]
    fn expr_size_is_bounded() {
        const BUDGET_BYTES: usize = 128;
        let bytes = std::mem::size_of::<Expr>();
        eprintln!("Expr size: {bytes}");
        assert!(
            bytes <= BUDGET_BYTES,
            "Expr grew to {bytes} bytes. Fix: box the largest variant before adding more fields."
        );
    }
}