Skip to main content

ringkernel_ir/
types.rs

1//! IR type system.
2//!
3//! Defines types that can be used in GPU kernels across all backends.
4
5use std::fmt;
6
/// Primitive scalar types that IR values can have.
///
/// Every backend understands these; some (see [`ScalarType::requires_f64`] and
/// [`ScalarType::requires_i64`]) need optional device capabilities.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarType {
    /// Boolean value; counted as one byte.
    Bool,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Unsigned 64-bit integer.
    U64,
    /// 16-bit (half precision) floating point.
    F16,
    /// 32-bit (single precision) floating point.
    F32,
    /// 64-bit (double precision) floating point; not every backend supports it.
    F64,
}

impl ScalarType {
    /// Returns the storage width of one value of this type, in bytes.
    pub fn size_bytes(&self) -> usize {
        match self {
            Self::I64 | Self::U64 | Self::F64 => 8,
            Self::I32 | Self::U32 | Self::F32 => 4,
            Self::I16 | Self::U16 | Self::F16 => 2,
            // `Bool` shares the single-byte bucket with the 8-bit integers.
            Self::Bool | Self::I8 | Self::U8 => 1,
        }
    }

    /// Returns `true` for the floating point types `F16`, `F32` and `F64`.
    pub fn is_float(&self) -> bool {
        matches!(self, Self::F16 | Self::F32 | Self::F64)
    }

    /// Returns `true` for the signed integer types `I8` through `I64`.
    pub fn is_signed_int(&self) -> bool {
        matches!(self, Self::I8 | Self::I16 | Self::I32 | Self::I64)
    }

    /// Returns `true` for the unsigned integer types `U8` through `U64`.
    pub fn is_unsigned_int(&self) -> bool {
        matches!(self, Self::U8 | Self::U16 | Self::U32 | Self::U64)
    }

    /// Returns `true` for any integer type, signed or unsigned.
    ///
    /// `Bool` is not an integer.
    pub fn is_int(&self) -> bool {
        self.is_unsigned_int() || self.is_signed_int()
    }

    /// Returns `true` if using this type needs 64-bit float support (only `F64`).
    pub fn requires_f64(&self) -> bool {
        *self == Self::F64
    }

    /// Returns `true` if using this type needs 64-bit integer support
    /// (`I64` or `U64`).
    pub fn requires_i64(&self) -> bool {
        matches!(self, Self::U64 | Self::I64)
    }
}

impl fmt::Display for ScalarType {
    /// Writes the lowercase, Rust-style spelling of the type (e.g. `i32`, `f16`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Bool => "bool",
            Self::I8 => "i8",
            Self::I16 => "i16",
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
            Self::F16 => "f16",
            Self::F32 => "f32",
            Self::F64 => "f64",
        };
        f.write_str(name)
    }
}
102
103/// Vector types.
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
105pub struct VectorType {
106    /// Element type.
107    pub element: ScalarType,
108    /// Number of elements (2, 3, or 4).
109    pub count: u8,
110}
111
112impl VectorType {
113    /// Create a new vector type.
114    pub fn new(element: ScalarType, count: u8) -> Self {
115        debug_assert!((2..=4).contains(&count), "Vector count must be 2, 3, or 4");
116        Self { element, count }
117    }
118
119    /// Get size in bytes.
120    pub fn size_bytes(&self) -> usize {
121        self.element.size_bytes() * self.count as usize
122    }
123}
124
125impl fmt::Display for VectorType {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(f, "vec{}<{}>", self.count, self.element)
128    }
129}
130
131/// IR type.
132#[derive(Debug, Clone, PartialEq, Eq, Hash)]
133pub enum IrType {
134    /// Void type (for functions with no return).
135    Void,
136    /// Scalar type.
137    Scalar(ScalarType),
138    /// Vector type.
139    Vector(VectorType),
140    /// Pointer type.
141    Ptr(Box<IrType>),
142    /// Array type with static size.
143    Array(Box<IrType>, usize),
144    /// Slice type (runtime-sized array).
145    Slice(Box<IrType>),
146    /// Struct type with named fields.
147    Struct(StructType),
148    /// Function type.
149    Function(FunctionType),
150}
151
152impl IrType {
153    // Convenience constructors for common types
154
155    /// Boolean type.
156    pub const BOOL: IrType = IrType::Scalar(ScalarType::Bool);
157    /// 32-bit signed integer.
158    pub const I32: IrType = IrType::Scalar(ScalarType::I32);
159    /// 64-bit signed integer.
160    pub const I64: IrType = IrType::Scalar(ScalarType::I64);
161    /// 32-bit unsigned integer.
162    pub const U32: IrType = IrType::Scalar(ScalarType::U32);
163    /// 64-bit unsigned integer.
164    pub const U64: IrType = IrType::Scalar(ScalarType::U64);
165    /// 32-bit float.
166    pub const F32: IrType = IrType::Scalar(ScalarType::F32);
167    /// 64-bit float.
168    pub const F64: IrType = IrType::Scalar(ScalarType::F64);
169
170    /// Create a pointer type.
171    pub fn ptr(inner: IrType) -> Self {
172        IrType::Ptr(Box::new(inner))
173    }
174
175    /// Create an array type.
176    pub fn array(inner: IrType, size: usize) -> Self {
177        IrType::Array(Box::new(inner), size)
178    }
179
180    /// Create a slice type.
181    pub fn slice(inner: IrType) -> Self {
182        IrType::Slice(Box::new(inner))
183    }
184
185    /// Get size in bytes (None for unsized types).
186    pub fn size_bytes(&self) -> Option<usize> {
187        match self {
188            IrType::Void => Some(0),
189            IrType::Scalar(s) => Some(s.size_bytes()),
190            IrType::Vector(v) => Some(v.size_bytes()),
191            IrType::Ptr(_) => Some(8), // 64-bit pointers
192            IrType::Array(inner, count) => inner.size_bytes().map(|s| s * count),
193            IrType::Slice(_) => None, // Unsized
194            IrType::Struct(s) => s.size_bytes(),
195            IrType::Function(_) => None,
196        }
197    }
198
199    /// Check if this is a pointer type.
200    pub fn is_ptr(&self) -> bool {
201        matches!(self, IrType::Ptr(_))
202    }
203
204    /// Check if this is a scalar type.
205    pub fn is_scalar(&self) -> bool {
206        matches!(self, IrType::Scalar(_))
207    }
208
209    /// Check if this is a numeric type.
210    pub fn is_numeric(&self) -> bool {
211        match self {
212            IrType::Scalar(s) => s.is_float() || s.is_int(),
213            _ => false,
214        }
215    }
216
217    /// Get the element type for pointers, arrays, and slices.
218    pub fn element_type(&self) -> Option<&IrType> {
219        match self {
220            IrType::Ptr(inner) | IrType::Array(inner, _) | IrType::Slice(inner) => Some(inner),
221            _ => None,
222        }
223    }
224}
225
226impl fmt::Display for IrType {
227    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228        match self {
229            IrType::Void => write!(f, "void"),
230            IrType::Scalar(s) => write!(f, "{}", s),
231            IrType::Vector(v) => write!(f, "{}", v),
232            IrType::Ptr(inner) => write!(f, "*{}", inner),
233            IrType::Array(inner, size) => write!(f, "[{}; {}]", inner, size),
234            IrType::Slice(inner) => write!(f, "[{}]", inner),
235            IrType::Struct(s) => write!(f, "struct {}", s.name),
236            IrType::Function(ft) => write!(f, "{}", ft),
237        }
238    }
239}
240
241/// A struct type definition.
242#[derive(Debug, Clone, PartialEq, Eq, Hash)]
243pub struct StructType {
244    /// Struct name.
245    pub name: String,
246    /// Fields with names and types.
247    pub fields: Vec<(String, IrType)>,
248}
249
250impl StructType {
251    /// Create a new struct type.
252    pub fn new(name: impl Into<String>, fields: Vec<(String, IrType)>) -> Self {
253        Self {
254            name: name.into(),
255            fields,
256        }
257    }
258
259    /// Get size in bytes.
260    pub fn size_bytes(&self) -> Option<usize> {
261        let mut size = 0;
262        for (_, ty) in &self.fields {
263            size += ty.size_bytes()?;
264        }
265        Some(size)
266    }
267
268    /// Get field type by name.
269    pub fn get_field(&self, name: &str) -> Option<&IrType> {
270        self.fields
271            .iter()
272            .find(|(n, _)| n == name)
273            .map(|(_, ty)| ty)
274    }
275
276    /// Get field index by name.
277    pub fn get_field_index(&self, name: &str) -> Option<usize> {
278        self.fields.iter().position(|(n, _)| n == name)
279    }
280}
281
282/// A function type.
283#[derive(Debug, Clone, PartialEq, Eq, Hash)]
284pub struct FunctionType {
285    /// Parameter types.
286    pub params: Vec<IrType>,
287    /// Return type.
288    pub return_type: Box<IrType>,
289}
290
291impl FunctionType {
292    /// Create a new function type.
293    pub fn new(params: Vec<IrType>, return_type: IrType) -> Self {
294        Self {
295            params,
296            return_type: Box::new(return_type),
297        }
298    }
299}
300
301impl fmt::Display for FunctionType {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        write!(f, "fn(")?;
304        for (i, param) in self.params.iter().enumerate() {
305            if i > 0 {
306                write!(f, ", ")?;
307            }
308            write!(f, "{}", param)?;
309        }
310        write!(f, ") -> {}", self.return_type)
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_size() {
        let cases = [
            (ScalarType::Bool, 1),
            (ScalarType::I32, 4),
            (ScalarType::F64, 8),
        ];
        for &(ty, expected) in cases.iter() {
            assert_eq!(ty.size_bytes(), expected);
        }
    }

    #[test]
    fn test_scalar_classification() {
        let (signed, unsigned) = (ScalarType::I32, ScalarType::U32);

        assert!(ScalarType::F32.is_float());
        assert!(!signed.is_float());

        assert!(signed.is_signed_int());
        assert!(!unsigned.is_signed_int());

        assert!(unsigned.is_unsigned_int());
        assert!(!signed.is_unsigned_int());
    }

    #[test]
    fn test_vector_type() {
        let vec4 = VectorType::new(ScalarType::F32, 4);
        assert_eq!(vec4.size_bytes(), 16);
        assert_eq!(vec4.to_string(), "vec4<f32>");
    }

    #[test]
    fn test_ir_type_display() {
        let cases = [
            (IrType::I32, "i32"),
            (IrType::ptr(IrType::F32), "*f32"),
            (IrType::array(IrType::I32, 16), "[i32; 16]"),
        ];
        for (ty, expected) in cases.iter() {
            assert_eq!(ty.to_string(), *expected);
        }
    }

    #[test]
    fn test_struct_type() {
        let point = StructType::new(
            "Point",
            ["x", "y"]
                .iter()
                .map(|&field| (field.to_string(), IrType::F32))
                .collect(),
        );
        assert_eq!(point.size_bytes(), Some(8));
        assert_eq!(point.get_field("x"), Some(&IrType::F32));
        assert_eq!(point.get_field_index("y"), Some(1));
    }

    #[test]
    fn test_function_type() {
        let sig = FunctionType::new(vec![IrType::I32, IrType::F32], IrType::F32);
        assert_eq!(sig.to_string(), "fn(i32, f32) -> f32");
    }
}