Skip to main content

airl_ir/
types.rs

1use crate::effects::Effect;
2use crate::ids::{Symbol, TypeId};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::fmt;
5
/// A variant in an enum type definition.
///
/// Instances of this struct make up the `variants` list of [`Type::Enum`].
/// Serialization is derived, so each field's [`Type`] is written using
/// `Type`'s own string-based serde implementation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Variant {
    /// Variant name (e.g. `Ok`, `Err`, `None`, `Some`).
    pub name: Symbol,
    /// Named fields of this variant as `(field name, field type)` pairs.
    /// An empty list describes a variant that carries no data.
    pub fields: Vec<(Symbol, Type)>,
}
14
/// The core type system for AIRL IR.
///
/// Serde support is implemented by hand elsewhere in this file: a `Type` is
/// written as the string produced by [`Type::to_type_str`] and read back with
/// [`Type::from_type_str`]. That string form is lossy for several variants
/// (see the notes on each variant below).
#[derive(Clone, Debug, PartialEq)]
#[allow(missing_docs)] // variant fields are left undocumented; their names are self-documenting
pub enum Type {
    /// The unit type `()`. Also used as a wildcard for generic builtins.
    Unit,
    /// Boolean type: `true` or `false`.
    Bool,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer (the default integer type in AIRL).
    I64,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Unsigned 64-bit integer.
    U64,
    /// 32-bit floating-point number.
    F32,
    /// 64-bit floating-point number.
    F64,
    /// UTF-8 string.
    String,
    /// Raw byte array.
    Bytes,

    // Composite
    /// Homogeneous array of `element`; written as `Array<T>`.
    Array {
        element: Box<Type>,
    },
    /// Fixed sequence of element types; written as `Tuple<A, B, ...>`.
    Tuple {
        elements: Vec<Type>,
    },
    /// Struct type with `(field name, field type)` pairs.
    /// The string form contains only `name`; the fields are not written.
    Struct {
        name: Symbol,
        fields: Vec<(Symbol, Type)>,
    },
    /// Enum type with its list of variants.
    /// The string form contains only `name`; the variants are not written.
    Enum {
        name: Symbol,
        variants: Vec<Variant>,
    },
    /// Function type. Written as `Fn(params) -> returns`; `effects` are not
    /// part of the string form, and `from_type_str` has no branch that
    /// parses this syntax back into a `Function`.
    Function {
        params: Vec<Type>,
        returns: Box<Type>,
        effects: Vec<Effect>,
    },
    /// Reference to `inner`; written as `MutRef<T>` when `mutable` is true
    /// and `Ref<T>` otherwise.
    Reference {
        inner: Box<Type>,
        mutable: bool,
    },
    /// Optional value of type `inner`; written as `Optional<T>`.
    Optional {
        inner: Box<Type>,
    },
    /// Success-or-error pair; written as `Result<Ok, Err>`.
    Result {
        ok: Box<Type>,
        err: Box<Type>,
    },
    /// Type parameter. The string form contains only `name`.
    /// NOTE(review): `bounds` presumably holds the names of the constraints
    /// on the parameter — confirm against the code that populates it.
    TypeParam {
        name: Symbol,
        bounds: Vec<std::string::String>,
    },
    /// Application of `base` to type arguments; written as `base<A, B, ...>`.
    Generic {
        base: Box<Type>,
        args: Vec<Type>,
    },
    /// Type referred to by identifier. This is also what `from_type_str`
    /// produces for any string it does not recognize.
    Named(TypeId),
}
81
82impl Type {
83    /// Parse a type string from the JSON IR format into a Type.
84    ///
85    /// Handles simple types like `I64`, `Bool`, `String`, `Unit`,
86    /// composite types like `Array<I64>`, `Optional<String>`,
87    /// `Result<I64,String>`, and falls back to Named for unknown types.
88    pub fn from_type_str(s: &str) -> Type {
89        let s = s.trim();
90        match s {
91            "Unit" | "()" => Type::Unit,
92            "Bool" => Type::Bool,
93            "I8" => Type::I8,
94            "I16" => Type::I16,
95            "I32" => Type::I32,
96            "I64" => Type::I64,
97            "U8" => Type::U8,
98            "U16" => Type::U16,
99            "U32" => Type::U32,
100            "U64" => Type::U64,
101            "F32" => Type::F32,
102            "F64" => Type::F64,
103            "String" => Type::String,
104            "Bytes" => Type::Bytes,
105            _ => {
106                // Try to parse generic types like Array<I64>, Optional<String>, etc.
107                if let Some(inner_str) = strip_generic(s, "Array") {
108                    Type::Array {
109                        element: Box::new(Type::from_type_str(inner_str)),
110                    }
111                } else if let Some(inner_str) = strip_generic(s, "Optional") {
112                    Type::Optional {
113                        inner: Box::new(Type::from_type_str(inner_str)),
114                    }
115                } else if let Some(inner_str) = strip_generic(s, "Result") {
116                    // Result<Ok, Err>
117                    let parts = split_type_args(inner_str);
118                    if parts.len() == 2 {
119                        Type::Result {
120                            ok: Box::new(Type::from_type_str(&parts[0])),
121                            err: Box::new(Type::from_type_str(&parts[1])),
122                        }
123                    } else {
124                        Type::Named(TypeId::new(s))
125                    }
126                } else if let Some(inner_str) = strip_generic(s, "Tuple") {
127                    let parts = split_type_args(inner_str);
128                    Type::Tuple {
129                        elements: parts.iter().map(|p| Type::from_type_str(p)).collect(),
130                    }
131                } else if let Some(inner_str) = strip_generic(s, "Ref") {
132                    Type::Reference {
133                        inner: Box::new(Type::from_type_str(inner_str)),
134                        mutable: false,
135                    }
136                } else if let Some(inner_str) = strip_generic(s, "MutRef") {
137                    Type::Reference {
138                        inner: Box::new(Type::from_type_str(inner_str)),
139                        mutable: true,
140                    }
141                } else {
142                    Type::Named(TypeId::new(s))
143                }
144            }
145        }
146    }
147
148    /// Convert a Type to its string representation for JSON serialization.
149    pub fn to_type_str(&self) -> std::string::String {
150        match self {
151            Type::Unit => "Unit".into(),
152            Type::Bool => "Bool".into(),
153            Type::I8 => "I8".into(),
154            Type::I16 => "I16".into(),
155            Type::I32 => "I32".into(),
156            Type::I64 => "I64".into(),
157            Type::U8 => "U8".into(),
158            Type::U16 => "U16".into(),
159            Type::U32 => "U32".into(),
160            Type::U64 => "U64".into(),
161            Type::F32 => "F32".into(),
162            Type::F64 => "F64".into(),
163            Type::String => "String".into(),
164            Type::Bytes => "Bytes".into(),
165            Type::Array { element } => format!("Array<{}>", element.to_type_str()),
166            Type::Tuple { elements } => {
167                let inner: Vec<_> = elements.iter().map(|e| e.to_type_str()).collect();
168                format!("Tuple<{}>", inner.join(", "))
169            }
170            Type::Optional { inner } => format!("Optional<{}>", inner.to_type_str()),
171            Type::Result { ok, err } => {
172                format!("Result<{}, {}>", ok.to_type_str(), err.to_type_str())
173            }
174            Type::Reference { inner, mutable } => {
175                if *mutable {
176                    format!("MutRef<{}>", inner.to_type_str())
177                } else {
178                    format!("Ref<{}>", inner.to_type_str())
179                }
180            }
181            Type::Named(id) => id.0.clone(),
182            Type::Struct { name, .. } => name.0.clone(),
183            Type::Enum { name, .. } => name.0.clone(),
184            Type::Function {
185                params, returns, ..
186            } => {
187                let p: Vec<_> = params.iter().map(|t| t.to_type_str()).collect();
188                format!("Fn({}) -> {}", p.join(", "), returns.to_type_str())
189            }
190            Type::TypeParam { name, .. } => name.0.clone(),
191            Type::Generic { base, args } => {
192                let a: Vec<_> = args.iter().map(|t| t.to_type_str()).collect();
193                format!("{}<{}>", base.to_type_str(), a.join(", "))
194            }
195        }
196    }
197}
198
/// Return the text between the angle brackets of `prefix<...>`.
///
/// `s` is trimmed first, and whitespace between `prefix` and `<` is allowed.
/// The result is `None` unless `s` starts with `prefix`, the next
/// non-whitespace character is `<`, and the final character is `>`. Bracket
/// balance inside the returned slice is not checked.
///
/// Example: `strip_generic("Array<I64>", "Array")` yields `Some("I64")`.
fn strip_generic<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    s.trim()
        .strip_prefix(prefix)
        .map(str::trim_start)
        .and_then(|tail| tail.strip_prefix('<'))
        .and_then(|tail| tail.strip_suffix('>'))
}
213
/// Split a comma-separated list of type arguments at the top level only.
///
/// Commas nested inside angle brackets or parentheses do not split, and the
/// `>` of a `->` arrow (as found in `Fn(A, B) -> R`) is not treated as a
/// closing bracket. Each piece is trimmed. An empty or all-whitespace final
/// piece is dropped, so an empty input yields an empty vector.
///
/// "I64, String" -> ["I64", "String"]
/// "Array<I64>, String" -> ["Array<I64>", "String"]
/// "Fn(I64, I64) -> I64, String" -> ["Fn(I64, I64) -> I64", "String"]
fn split_type_args(s: &str) -> Vec<std::string::String> {
    let mut result = Vec::new();
    // Combined nesting level of `<...>` and `(...)`.
    let mut depth: usize = 0;
    let mut current = std::string::String::new();
    let mut previous: Option<char> = None;

    for ch in s.chars() {
        match ch {
            '<' | '(' => {
                depth += 1;
                current.push(ch);
            }
            // Second half of a `->` arrow: part of the text, not a bracket.
            '>' if previous == Some('-') => current.push(ch),
            '>' | ')' => {
                // Clamp at zero so a stray closer cannot push the depth
                // negative and corrupt how later commas are handled.
                depth = depth.saturating_sub(1);
                current.push(ch);
            }
            ',' if depth == 0 => {
                result.push(current.trim().to_string());
                current.clear();
            }
            _ => current.push(ch),
        }
        previous = Some(ch);
    }

    let trimmed = current.trim();
    if !trimmed.is_empty() {
        result.push(trimmed.to_string());
    }

    result
}
249
250impl Serialize for Type {
251    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
252    where
253        S: Serializer,
254    {
255        serializer.serialize_str(&self.to_type_str())
256    }
257}
258
259impl<'de> Deserialize<'de> for Type {
260    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
261    where
262        D: Deserializer<'de>,
263    {
264        let s = std::string::String::deserialize(deserializer)?;
265        Ok(Type::from_type_str(&s))
266    }
267}
268
269impl fmt::Display for Type {
270    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271        write!(f, "{}", self.to_type_str())
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that rendering `ty` and parsing the text yields `ty` again.
    fn assert_roundtrip(ty: &Type) {
        let rendered = ty.to_type_str();
        let parsed = Type::from_type_str(&rendered);
        assert_eq!(*ty, parsed, "roundtrip failed for {rendered}");
    }

    #[test]
    fn test_primitive_roundtrip() {
        // Every primitive variant, not just a sample.
        let types = vec![
            Type::Unit,
            Type::Bool,
            Type::I8,
            Type::I16,
            Type::I32,
            Type::I64,
            Type::U8,
            Type::U16,
            Type::U32,
            Type::U64,
            Type::F32,
            Type::F64,
            Type::String,
            Type::Bytes,
        ];
        for t in &types {
            assert_roundtrip(t);
        }
    }

    #[test]
    fn test_unit_alias_and_whitespace() {
        assert_eq!(Type::from_type_str("()"), Type::Unit);
        assert_eq!(Type::from_type_str("  I64\t"), Type::I64);
    }

    #[test]
    fn test_array_roundtrip() {
        let t = Type::Array {
            element: Box::new(Type::I64),
        };
        assert_eq!(t.to_type_str(), "Array<I64>");
        assert_eq!(Type::from_type_str("Array<I64>"), t);
    }

    #[test]
    fn test_optional_roundtrip() {
        let t = Type::Optional {
            inner: Box::new(Type::String),
        };
        assert_eq!(t.to_type_str(), "Optional<String>");
        assert_roundtrip(&t);
    }

    #[test]
    fn test_result_roundtrip() {
        let t = Type::Result {
            ok: Box::new(Type::I64),
            err: Box::new(Type::String),
        };
        assert_eq!(t.to_type_str(), "Result<I64, String>");
        assert_eq!(Type::from_type_str("Result<I64, String>"), t);
    }

    #[test]
    fn test_result_with_wrong_arity_is_named() {
        // A `Result` that does not have exactly two arguments is kept verbatim.
        assert_eq!(
            Type::from_type_str("Result<I64>"),
            Type::Named(TypeId::new("Result<I64>"))
        );
    }

    #[test]
    fn test_tuple_roundtrip() {
        let pair = Type::Tuple {
            elements: vec![Type::I64, Type::Bool],
        };
        assert_eq!(pair.to_type_str(), "Tuple<I64, Bool>");
        assert_roundtrip(&pair);

        let empty = Type::Tuple { elements: vec![] };
        assert_eq!(empty.to_type_str(), "Tuple<>");
        assert_roundtrip(&empty);
    }

    #[test]
    fn test_reference_roundtrip() {
        let shared = Type::Reference {
            inner: Box::new(Type::I64),
            mutable: false,
        };
        assert_eq!(shared.to_type_str(), "Ref<I64>");
        assert_roundtrip(&shared);

        let exclusive = Type::Reference {
            inner: Box::new(Type::String),
            mutable: true,
        };
        assert_eq!(exclusive.to_type_str(), "MutRef<String>");
        assert_roundtrip(&exclusive);
    }

    #[test]
    fn test_nested_roundtrip() {
        let t = Type::Result {
            ok: Box::new(Type::Array {
                element: Box::new(Type::Optional {
                    inner: Box::new(Type::I64),
                }),
            }),
            err: Box::new(Type::Tuple {
                elements: vec![Type::String, Type::Bool],
            }),
        };
        assert_eq!(
            t.to_type_str(),
            "Result<Array<Optional<I64>>, Tuple<String, Bool>>"
        );
        assert_roundtrip(&t);
    }

    #[test]
    fn test_named_fallback() {
        let t = Type::from_type_str("MyCustomType");
        assert_eq!(t, Type::Named(TypeId::new("MyCustomType")));
    }

    #[test]
    fn test_display_matches_type_str() {
        let t = Type::Array {
            element: Box::new(Type::U8),
        };
        assert_eq!(t.to_string(), t.to_type_str());
    }

    #[test]
    fn test_serde_roundtrip() {
        let t = Type::I64;
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, "\"I64\"");
        let parsed: Type = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, t);
    }
}