1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
//! Frozen IR data-type tags shared by signatures, validators, and wire metadata.
use core::fmt;
/// Canonical data types supported by the vyre IR frozen data contract.
///
/// Integer-first by design. GPU floating-point is nondeterministic across
/// vendors through different rounding, fused multiply-add, and subnormal
/// handling. Integer arithmetic is deterministic everywhere. `F32` is
/// supported for primitives that require it, with conformance validated
/// per-backend; `F16`, `BF16`, and `F64` are additionally available as strict
/// floating-point tags (all four are reported by
/// [`DataType::is_float_family`]).
///
/// `vyre::ir::DataType` re-exports this same type; conformance metadata should
/// use this canonical contract path. Example: `DataType::Vec4U32` records a
/// four-word lane value and has a minimum byte width of 16.
///
/// Fixed-width tags have an exact byte width
/// (`max_bytes() == Some(min_bytes())`). `Bytes`, `Array`, and `Tensor` are
/// variable-length: their minimum is 0 and their ceiling is given by
/// [`DataType::max_bytes`].
///
/// NOTE(review): the derived `Hash`, and the derived serde impls for
/// index-based formats (serde passes a variant index), depend on variant
/// declaration order. Because this is a frozen wire contract, append new
/// variants at the end rather than reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
pub enum DataType {
/// Unsigned 32-bit integer. The fundamental GPU word (4 bytes).
U32,
/// Signed 32-bit integer (4 bytes).
I32,
/// Unsigned 64-bit integer, emulated as `vec2<u32>` with low and high words
/// (8 bytes).
U64,
/// Two-component `u32` vector (8 bytes).
Vec2U32,
/// Four-component `u32` vector (16 bytes).
Vec4U32,
/// Boolean value stored as a full GPU word (4 bytes, not 1).
Bool,
/// Variable-length byte buffer, capped at 64 MiB by [`DataType::max_bytes`].
Bytes,
/// Fixed-element-size array.
///
/// Each element is `element_size` bytes. The total byte count is
/// `N * element_size` where N is encoded by the value, not by this tag.
/// The total is capped at 256 MiB by [`DataType::max_bytes`].
Array {
/// Byte size of each element.
element_size: usize,
},
/// Strict IEEE 754 binary16 floating-point (2 bytes).
F16,
/// Strict bfloat16 floating-point (2 bytes).
BF16,
/// IEEE 754 binary32 floating-point (4 bytes).
F32,
/// Strict IEEE 754 binary64 floating-point (8 bytes).
F64,
/// Multi-dimensional tensor value.
///
/// The tag itself carries no element type or shape. The total size is capped
/// at 256 MiB by [`DataType::max_bytes`].
Tensor,
}
impl DataType {
    /// Minimum byte count to represent one value of this type.
    ///
    /// Fixed-width tags return their exact width (e.g. `Vec4U32` → 16,
    /// `F16` → 2). The variable-length tags (`Bytes`, `Array`, `Tensor`)
    /// return `0`, because no lower bound is imposed on their payload size.
    #[must_use]
    pub const fn min_bytes(&self) -> usize {
        match self {
            Self::F16 | Self::BF16 => 2,
            Self::Bool | Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::Vec2U32 | Self::F64 => 8,
            Self::Vec4U32 => 16,
            Self::Bytes | Self::Array { .. } | Self::Tensor => 0,
        }
    }

    /// Maximum byte count for one value of this type.
    ///
    /// Returns `None` for truly unbounded types. Currently every variant has a
    /// hard ceiling, so this always returns `Some`:
    ///
    /// - fixed-width tags return `Some(self.min_bytes())`;
    /// - `Bytes` is capped at 64 MiB;
    /// - `Array` and `Tensor` are capped at 256 MiB.
    #[must_use]
    pub const fn max_bytes(&self) -> Option<usize> {
        const MIB: usize = 1024 * 1024;
        match self {
            Self::Bytes => Some(64 * MIB),
            Self::Array { .. } | Self::Tensor => Some(256 * MIB),
            // Derive the fixed-width ceiling from `min_bytes` so the two
            // width tables cannot drift apart when a variant's width changes.
            Self::U32
            | Self::I32
            | Self::U64
            | Self::Vec2U32
            | Self::Vec4U32
            | Self::Bool
            | Self::F16
            | Self::BF16
            | Self::F32
            | Self::F64 => Some(self.min_bytes()),
        }
    }

    /// Per-element byte size carried by an `Array` tag.
    ///
    /// Returns `None` for every other tag. That includes the fixed-width
    /// scalar/vector tags and the variable-length `Bytes` and `Tensor` tags,
    /// which do not encode an element size.
    #[must_use]
    pub const fn element_size(&self) -> Option<usize> {
        match *self {
            Self::Array { element_size } => Some(element_size),
            // Listed explicitly (no `_` arm) so adding a variant forces a
            // deliberate decision here instead of silently returning `None`.
            Self::U32
            | Self::I32
            | Self::U64
            | Self::Vec2U32
            | Self::Vec4U32
            | Self::Bool
            | Self::Bytes
            | Self::F16
            | Self::BF16
            | Self::F32
            | Self::F64
            | Self::Tensor => None,
        }
    }

    /// Whether this type belongs to the strict floating-point conformance
    /// family (`F16`, `BF16`, `F32`, `F64`).
    #[must_use]
    pub const fn is_float_family(&self) -> bool {
        matches!(self, Self::F16 | Self::BF16 | Self::F32 | Self::F64)
    }
}
impl fmt::Display for DataType {
    /// Renders the tag's canonical lowercase spelling, e.g. `u32`,
    /// `vec4<u32>`, or `array<12B>` for an array of 12-byte elements.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Array` is the only tag whose spelling carries a parameter, so it is
        // formatted directly; every other tag maps to a fixed string.
        let spelling = match *self {
            Self::Array { element_size } => {
                return write!(f, "array<{element_size}B>");
            }
            Self::U32 => "u32",
            Self::I32 => "i32",
            Self::U64 => "u64",
            Self::Vec2U32 => "vec2<u32>",
            Self::Vec4U32 => "vec4<u32>",
            Self::Bool => "bool",
            Self::Bytes => "bytes",
            Self::F16 => "f16",
            Self::BF16 => "bf16",
            Self::F32 => "f32",
            Self::F64 => "f64",
            Self::Tensor => "tensor",
        };
        f.write_str(spelling)
    }
}