1use std::fmt;
6
/// A primitive scalar element type of the IR.
///
/// Covers a boolean, signed and unsigned integers of 8 to 64 bits, and
/// floating-point numbers of 16 to 64 bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarType {
    /// Boolean; counted as one byte by [`ScalarType::size_bytes`].
    Bool,
    /// 8-bit signed integer.
    I8,
    /// 16-bit signed integer.
    I16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// 8-bit unsigned integer.
    U8,
    /// 16-bit unsigned integer.
    U16,
    /// 32-bit unsigned integer.
    U32,
    /// 64-bit unsigned integer.
    U64,
    /// 16-bit floating point.
    F16,
    /// 32-bit floating point.
    F32,
    /// 64-bit floating point.
    F64,
}

impl ScalarType {
    /// Storage size of a single value of this type, in bytes.
    ///
    /// `Bool` occupies one byte.
    pub fn size_bytes(&self) -> usize {
        let width_bits: usize = match *self {
            Self::Bool | Self::I8 | Self::U8 => 8,
            Self::I16 | Self::U16 | Self::F16 => 16,
            Self::I32 | Self::U32 | Self::F32 => 32,
            Self::I64 | Self::U64 | Self::F64 => 64,
        };
        width_bits / 8
    }

    /// `true` for the floating-point types `F16`, `F32` and `F64`.
    pub fn is_float(&self) -> bool {
        matches!(*self, Self::F16 | Self::F32 | Self::F64)
    }

    /// `true` for the signed integer types `I8` through `I64`.
    pub fn is_signed_int(&self) -> bool {
        matches!(*self, Self::I8 | Self::I16 | Self::I32 | Self::I64)
    }

    /// `true` for the unsigned integer types `U8` through `U64`.
    pub fn is_unsigned_int(&self) -> bool {
        matches!(*self, Self::U8 | Self::U16 | Self::U32 | Self::U64)
    }

    /// `true` for any integer type, signed or unsigned.
    ///
    /// Every variant other than `Bool` and the float types is an integer.
    pub fn is_int(&self) -> bool {
        !(*self == Self::Bool || self.is_float())
    }

    /// `true` only for `F64`.
    ///
    /// NOTE(review): presumably used to detect that a target must support
    /// 64-bit floats; confirm against callers.
    pub fn requires_f64(&self) -> bool {
        *self == Self::F64
    }

    /// `true` for the 64-bit integer types `I64` and `U64`.
    ///
    /// NOTE(review): presumably used to detect that a target must support
    /// 64-bit integers; confirm against callers.
    pub fn requires_i64(&self) -> bool {
        matches!(*self, Self::I64 | Self::U64)
    }
}

/// Formats the scalar with its short lowercase name (`bool`, `i32`, `f16`, ...).
impl fmt::Display for ScalarType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Bool => "bool",
            Self::I8 => "i8",
            Self::I16 => "i16",
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
            Self::F16 => "f16",
            Self::F32 => "f32",
            Self::F64 => "f64",
        };
        f.write_str(name)
    }
}
102
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
105pub struct VectorType {
106 pub element: ScalarType,
108 pub count: u8,
110}
111
112impl VectorType {
113 pub fn new(element: ScalarType, count: u8) -> Self {
115 debug_assert!((2..=4).contains(&count), "Vector count must be 2, 3, or 4");
116 Self { element, count }
117 }
118
119 pub fn size_bytes(&self) -> usize {
121 self.element.size_bytes() * self.count as usize
122 }
123}
124
125impl fmt::Display for VectorType {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 write!(f, "vec{}<{}>", self.count, self.element)
128 }
129}
130
131#[derive(Debug, Clone, PartialEq, Eq, Hash)]
133pub enum IrType {
134 Void,
136 Scalar(ScalarType),
138 Vector(VectorType),
140 Ptr(Box<IrType>),
142 Array(Box<IrType>, usize),
144 Slice(Box<IrType>),
146 Struct(StructType),
148 Function(FunctionType),
150}
151
152impl IrType {
153 pub const BOOL: IrType = IrType::Scalar(ScalarType::Bool);
157 pub const I32: IrType = IrType::Scalar(ScalarType::I32);
159 pub const I64: IrType = IrType::Scalar(ScalarType::I64);
161 pub const U32: IrType = IrType::Scalar(ScalarType::U32);
163 pub const U64: IrType = IrType::Scalar(ScalarType::U64);
165 pub const F32: IrType = IrType::Scalar(ScalarType::F32);
167 pub const F64: IrType = IrType::Scalar(ScalarType::F64);
169
170 pub fn ptr(inner: IrType) -> Self {
172 IrType::Ptr(Box::new(inner))
173 }
174
175 pub fn array(inner: IrType, size: usize) -> Self {
177 IrType::Array(Box::new(inner), size)
178 }
179
180 pub fn slice(inner: IrType) -> Self {
182 IrType::Slice(Box::new(inner))
183 }
184
185 pub fn size_bytes(&self) -> Option<usize> {
187 match self {
188 IrType::Void => Some(0),
189 IrType::Scalar(s) => Some(s.size_bytes()),
190 IrType::Vector(v) => Some(v.size_bytes()),
191 IrType::Ptr(_) => Some(8), IrType::Array(inner, count) => inner.size_bytes().map(|s| s * count),
193 IrType::Slice(_) => None, IrType::Struct(s) => s.size_bytes(),
195 IrType::Function(_) => None,
196 }
197 }
198
199 pub fn is_ptr(&self) -> bool {
201 matches!(self, IrType::Ptr(_))
202 }
203
204 pub fn is_scalar(&self) -> bool {
206 matches!(self, IrType::Scalar(_))
207 }
208
209 pub fn is_numeric(&self) -> bool {
211 match self {
212 IrType::Scalar(s) => s.is_float() || s.is_int(),
213 _ => false,
214 }
215 }
216
217 pub fn element_type(&self) -> Option<&IrType> {
219 match self {
220 IrType::Ptr(inner) | IrType::Array(inner, _) | IrType::Slice(inner) => Some(inner),
221 _ => None,
222 }
223 }
224}
225
226impl fmt::Display for IrType {
227 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228 match self {
229 IrType::Void => write!(f, "void"),
230 IrType::Scalar(s) => write!(f, "{}", s),
231 IrType::Vector(v) => write!(f, "{}", v),
232 IrType::Ptr(inner) => write!(f, "*{}", inner),
233 IrType::Array(inner, size) => write!(f, "[{}; {}]", inner, size),
234 IrType::Slice(inner) => write!(f, "[{}]", inner),
235 IrType::Struct(s) => write!(f, "struct {}", s.name),
236 IrType::Function(ft) => write!(f, "{}", ft),
237 }
238 }
239}
240
241#[derive(Debug, Clone, PartialEq, Eq, Hash)]
243pub struct StructType {
244 pub name: String,
246 pub fields: Vec<(String, IrType)>,
248}
249
250impl StructType {
251 pub fn new(name: impl Into<String>, fields: Vec<(String, IrType)>) -> Self {
253 Self {
254 name: name.into(),
255 fields,
256 }
257 }
258
259 pub fn size_bytes(&self) -> Option<usize> {
261 let mut size = 0;
262 for (_, ty) in &self.fields {
263 size += ty.size_bytes()?;
264 }
265 Some(size)
266 }
267
268 pub fn get_field(&self, name: &str) -> Option<&IrType> {
270 self.fields
271 .iter()
272 .find(|(n, _)| n == name)
273 .map(|(_, ty)| ty)
274 }
275
276 pub fn get_field_index(&self, name: &str) -> Option<usize> {
278 self.fields.iter().position(|(n, _)| n == name)
279 }
280}
281
282#[derive(Debug, Clone, PartialEq, Eq, Hash)]
284pub struct FunctionType {
285 pub params: Vec<IrType>,
287 pub return_type: Box<IrType>,
289}
290
291impl FunctionType {
292 pub fn new(params: Vec<IrType>, return_type: IrType) -> Self {
294 Self {
295 params,
296 return_type: Box::new(return_type),
297 }
298 }
299}
300
301impl fmt::Display for FunctionType {
302 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 write!(f, "fn(")?;
304 for (i, param) in self.params.iter().enumerate() {
305 if i > 0 {
306 write!(f, ", ")?;
307 }
308 write!(f, "{}", param)?;
309 }
310 write!(f, ") -> {}", self.return_type)
311 }
312}
313
/// Unit tests for the IR type model: sizes, classification, lookups and
/// `Display` formatting of every type.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_size() {
        assert_eq!(ScalarType::Bool.size_bytes(), 1);
        assert_eq!(ScalarType::I32.size_bytes(), 4);
        assert_eq!(ScalarType::F64.size_bytes(), 8);
        assert_eq!(ScalarType::U8.size_bytes(), 1);
        assert_eq!(ScalarType::F16.size_bytes(), 2);
        assert_eq!(ScalarType::U64.size_bytes(), 8);
    }

    #[test]
    fn test_scalar_classification() {
        assert!(ScalarType::F32.is_float());
        assert!(!ScalarType::I32.is_float());

        assert!(ScalarType::I32.is_signed_int());
        assert!(!ScalarType::U32.is_signed_int());

        assert!(ScalarType::U32.is_unsigned_int());
        assert!(!ScalarType::I32.is_unsigned_int());

        // `is_int` covers both signednesses and excludes bool and floats.
        assert!(ScalarType::I64.is_int());
        assert!(ScalarType::U8.is_int());
        assert!(!ScalarType::Bool.is_int());
        assert!(!ScalarType::F16.is_int());

        assert!(ScalarType::F64.requires_f64());
        assert!(!ScalarType::F32.requires_f64());
        assert!(ScalarType::I64.requires_i64());
        assert!(ScalarType::U64.requires_i64());
        assert!(!ScalarType::I32.requires_i64());
    }

    #[test]
    fn test_vector_type() {
        let v = VectorType::new(ScalarType::F32, 4);
        assert_eq!(v.size_bytes(), 16);
        assert_eq!(format!("{}", v), "vec4<f32>");

        let v3 = VectorType::new(ScalarType::F16, 3);
        assert_eq!(v3.size_bytes(), 6);
        assert_eq!(v3.to_string(), "vec3<f16>");
    }

    #[test]
    fn test_ir_type_display() {
        assert_eq!(format!("{}", IrType::I32), "i32");
        assert_eq!(format!("{}", IrType::ptr(IrType::F32)), "*f32");
        assert_eq!(format!("{}", IrType::array(IrType::I32, 16)), "[i32; 16]");

        assert_eq!(IrType::Void.to_string(), "void");
        assert_eq!(IrType::slice(IrType::U32).to_string(), "[u32]");
        assert_eq!(
            IrType::ptr(IrType::array(IrType::F32, 4)).to_string(),
            "*[f32; 4]"
        );
        assert_eq!(
            IrType::Vector(VectorType::new(ScalarType::I32, 2)).to_string(),
            "vec2<i32>"
        );
        assert_eq!(
            IrType::Struct(StructType::new("Point", Vec::new())).to_string(),
            "struct Point"
        );
        assert_eq!(
            IrType::Function(FunctionType::new(Vec::new(), IrType::Void)).to_string(),
            "fn() -> void"
        );
    }

    #[test]
    fn test_ir_type_size() {
        assert_eq!(IrType::Void.size_bytes(), Some(0));
        assert_eq!(IrType::U64.size_bytes(), Some(8));
        assert_eq!(
            IrType::Vector(VectorType::new(ScalarType::F32, 3)).size_bytes(),
            Some(12)
        );
        // Pointers have a fixed size even when the pointee is unsized.
        assert_eq!(IrType::ptr(IrType::slice(IrType::I32)).size_bytes(), Some(8));
        assert_eq!(IrType::array(IrType::F64, 3).size_bytes(), Some(24));
        assert_eq!(IrType::array(IrType::I32, 0).size_bytes(), Some(0));
        assert_eq!(IrType::slice(IrType::I32).size_bytes(), None);
        assert_eq!(
            IrType::array(IrType::slice(IrType::I32), 2).size_bytes(),
            None
        );
        assert_eq!(
            IrType::Function(FunctionType::new(vec![IrType::I32], IrType::Void)).size_bytes(),
            None
        );
    }

    #[test]
    fn test_ir_type_queries() {
        assert!(IrType::I64.is_numeric());
        assert!(IrType::F32.is_numeric());
        assert!(!IrType::BOOL.is_numeric());
        assert!(!IrType::ptr(IrType::I32).is_numeric());

        assert!(IrType::ptr(IrType::I32).is_ptr());
        assert!(!IrType::I32.is_ptr());
        assert!(IrType::U32.is_scalar());
        assert!(IrType::BOOL.is_scalar());
        assert!(!IrType::Void.is_scalar());

        assert_eq!(IrType::ptr(IrType::I32).element_type(), Some(&IrType::I32));
        assert_eq!(IrType::array(IrType::F32, 2).element_type(), Some(&IrType::F32));
        assert_eq!(IrType::slice(IrType::U64).element_type(), Some(&IrType::U64));
        assert_eq!(IrType::I32.element_type(), None);
    }

    #[test]
    fn test_struct_type() {
        let s = StructType::new(
            "Point",
            vec![
                ("x".to_string(), IrType::F32),
                ("y".to_string(), IrType::F32),
            ],
        );
        assert_eq!(s.size_bytes(), Some(8));
        assert_eq!(s.get_field("x"), Some(&IrType::F32));
        assert_eq!(s.get_field_index("y"), Some(1));

        assert_eq!(s.get_field("z"), None);
        assert_eq!(s.get_field_index("z"), None);
        assert_eq!(IrType::Struct(s.clone()).size_bytes(), Some(8));

        let empty = StructType::new("Empty", Vec::new());
        assert_eq!(empty.size_bytes(), Some(0));

        // A struct containing an unsized field has no fixed size.
        let unsized_struct = StructType::new(
            "Buf",
            vec![
                ("len".to_string(), IrType::U32),
                (
                    "data".to_string(),
                    IrType::slice(IrType::Scalar(ScalarType::U8)),
                ),
            ],
        );
        assert_eq!(unsized_struct.size_bytes(), None);
    }

    #[test]
    fn test_function_type() {
        let ft = FunctionType::new(vec![IrType::I32, IrType::F32], IrType::F32);
        assert_eq!(format!("{}", ft), "fn(i32, f32) -> f32");

        let nullary = FunctionType::new(Vec::new(), IrType::Void);
        assert_eq!(nullary.to_string(), "fn() -> void");
    }
}