1use crate::effects::Effect;
2use crate::ids::{Symbol, TypeId};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::fmt;
5
/// One variant of an enum type, as stored in the `variants` list of [`Type::Enum`].
///
/// Unlike [`Type`], this struct uses the derived serde impls, so it is (de)serialized field
/// by field; each [`Type`] nested inside it is still encoded by `Type`'s own string-based
/// `Serialize`/`Deserialize` impls.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Variant {
    /// The variant's name.
    pub name: Symbol,
    /// The variant's payload fields as `(field name, field type)` pairs.
    pub fields: Vec<(Symbol, Type)>,
}
14
/// A type, as represented by this module.
///
/// `Type` does not derive serde: its `Serialize`/`Deserialize` impls encode it as a single
/// string via [`Type::to_type_str`] / [`Type::from_type_str`]. That string form is lossy
/// for some variants — `Struct`, `Enum` and `TypeParam` are written as just their name,
/// `Function` drops its `effects`, and forms the parser does not recognize (such as
/// `Fn(...) -> ...` or a generic `Base<args>`) are read back as [`Type::Named`].
#[derive(Clone, Debug, PartialEq)]
// NOTE(review): the reason for allowing `missing_docs` here is not visible in this file;
// confirm before removing it.
#[allow(missing_docs)]
pub enum Type {
    /// The unit type. Written as `Unit`; `()` is also accepted when parsing.
    Unit,
    /// Written as `Bool`.
    Bool,
    // Integer and floating-point primitives; each is written as its variant name.
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    F32,
    F64,
    /// Written as `String`.
    String,
    /// Written as `Bytes`.
    Bytes,

    /// Array of `element`; written as `Array<T>`.
    Array {
        element: Box<Type>,
    },
    /// Tuple of `elements`; written as `Tuple<A, B, ...>` (`Tuple<>` when empty).
    Tuple {
        elements: Vec<Type>,
    },
    /// Struct type with its fields as `(name, type)` pairs. Written as just `name`, so the
    /// fields are not preserved by the string encoding.
    Struct {
        name: Symbol,
        fields: Vec<(Symbol, Type)>,
    },
    /// Enum type with its variants. Written as just `name`, so the variants are not
    /// preserved by the string encoding.
    Enum {
        name: Symbol,
        variants: Vec<Variant>,
    },
    /// Function type. Written as `Fn(params) -> returns`; `effects` are omitted, and the
    /// parser does not reconstruct this variant from that text.
    Function {
        params: Vec<Type>,
        returns: Box<Type>,
        effects: Vec<Effect>,
    },
    /// Reference to `inner`; written as `MutRef<T>` when `mutable`, otherwise `Ref<T>`.
    Reference {
        inner: Box<Type>,
        mutable: bool,
    },
    /// Written as `Optional<T>`.
    Optional {
        inner: Box<Type>,
    },
    /// Written as `Result<Ok, Err>`.
    Result {
        ok: Box<Type>,
        err: Box<Type>,
    },
    /// Type parameter. Written as just `name`; `bounds` (stored as plain strings) are
    /// omitted from the string encoding.
    TypeParam {
        name: Symbol,
        bounds: Vec<std::string::String>,
    },
    /// `base` applied to `args`; written as `Base<A, B, ...>`.
    Generic {
        base: Box<Type>,
        args: Vec<Type>,
    },
    /// A type referred to by its [`TypeId`]. Also the fallback that `from_type_str`
    /// produces for any text it does not recognize.
    Named(TypeId),
}
81
82impl Type {
83 pub fn from_type_str(s: &str) -> Type {
89 let s = s.trim();
90 match s {
91 "Unit" | "()" => Type::Unit,
92 "Bool" => Type::Bool,
93 "I8" => Type::I8,
94 "I16" => Type::I16,
95 "I32" => Type::I32,
96 "I64" => Type::I64,
97 "U8" => Type::U8,
98 "U16" => Type::U16,
99 "U32" => Type::U32,
100 "U64" => Type::U64,
101 "F32" => Type::F32,
102 "F64" => Type::F64,
103 "String" => Type::String,
104 "Bytes" => Type::Bytes,
105 _ => {
106 if let Some(inner_str) = strip_generic(s, "Array") {
108 Type::Array {
109 element: Box::new(Type::from_type_str(inner_str)),
110 }
111 } else if let Some(inner_str) = strip_generic(s, "Optional") {
112 Type::Optional {
113 inner: Box::new(Type::from_type_str(inner_str)),
114 }
115 } else if let Some(inner_str) = strip_generic(s, "Result") {
116 let parts = split_type_args(inner_str);
118 if parts.len() == 2 {
119 Type::Result {
120 ok: Box::new(Type::from_type_str(&parts[0])),
121 err: Box::new(Type::from_type_str(&parts[1])),
122 }
123 } else {
124 Type::Named(TypeId::new(s))
125 }
126 } else if let Some(inner_str) = strip_generic(s, "Tuple") {
127 let parts = split_type_args(inner_str);
128 Type::Tuple {
129 elements: parts.iter().map(|p| Type::from_type_str(p)).collect(),
130 }
131 } else if let Some(inner_str) = strip_generic(s, "Ref") {
132 Type::Reference {
133 inner: Box::new(Type::from_type_str(inner_str)),
134 mutable: false,
135 }
136 } else if let Some(inner_str) = strip_generic(s, "MutRef") {
137 Type::Reference {
138 inner: Box::new(Type::from_type_str(inner_str)),
139 mutable: true,
140 }
141 } else {
142 Type::Named(TypeId::new(s))
143 }
144 }
145 }
146 }
147
148 pub fn to_type_str(&self) -> std::string::String {
150 match self {
151 Type::Unit => "Unit".into(),
152 Type::Bool => "Bool".into(),
153 Type::I8 => "I8".into(),
154 Type::I16 => "I16".into(),
155 Type::I32 => "I32".into(),
156 Type::I64 => "I64".into(),
157 Type::U8 => "U8".into(),
158 Type::U16 => "U16".into(),
159 Type::U32 => "U32".into(),
160 Type::U64 => "U64".into(),
161 Type::F32 => "F32".into(),
162 Type::F64 => "F64".into(),
163 Type::String => "String".into(),
164 Type::Bytes => "Bytes".into(),
165 Type::Array { element } => format!("Array<{}>", element.to_type_str()),
166 Type::Tuple { elements } => {
167 let inner: Vec<_> = elements.iter().map(|e| e.to_type_str()).collect();
168 format!("Tuple<{}>", inner.join(", "))
169 }
170 Type::Optional { inner } => format!("Optional<{}>", inner.to_type_str()),
171 Type::Result { ok, err } => {
172 format!("Result<{}, {}>", ok.to_type_str(), err.to_type_str())
173 }
174 Type::Reference { inner, mutable } => {
175 if *mutable {
176 format!("MutRef<{}>", inner.to_type_str())
177 } else {
178 format!("Ref<{}>", inner.to_type_str())
179 }
180 }
181 Type::Named(id) => id.0.clone(),
182 Type::Struct { name, .. } => name.0.clone(),
183 Type::Enum { name, .. } => name.0.clone(),
184 Type::Function {
185 params, returns, ..
186 } => {
187 let p: Vec<_> = params.iter().map(|t| t.to_type_str()).collect();
188 format!("Fn({}) -> {}", p.join(", "), returns.to_type_str())
189 }
190 Type::TypeParam { name, .. } => name.0.clone(),
191 Type::Generic { base, args } => {
192 let a: Vec<_> = args.iter().map(|t| t.to_type_str()).collect();
193 format!("{}<{}>", base.to_type_str(), a.join(", "))
194 }
195 }
196 }
197}
198
/// If `s` (ignoring surrounding whitespace) has the form `prefix<...>`, returns the text
/// between the outer angle brackets; otherwise returns `None`.
///
/// Whitespace is allowed between `prefix` and `<`; the returned slice is not trimmed.
/// The `<` right after `prefix` must be closed by the final `>` of the string. Inputs where
/// it closes earlier (e.g. `Array<A> | Array<B>`) or whose angle brackets are unbalanced
/// are rejected. The `>` of a function arrow (`->`) is not treated as a bracket, so
/// `Array<Fn(I64) -> I64>` is accepted.
fn strip_generic<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    let inner = s
        .trim()
        .strip_prefix(prefix)?
        .trim_start()
        .strip_prefix('<')?
        .strip_suffix('>')?;

    // Confirm the outer brackets really enclose all of `inner`: depth must never drop
    // below zero (the outer `<` closed early) and must end at zero (balanced).
    let mut depth: usize = 0;
    let mut prev: Option<char> = None;
    for ch in inner.chars() {
        match ch {
            '<' => depth += 1,
            // `->` is an arrow, not a closing angle bracket.
            '>' if prev == Some('-') => {}
            '>' => depth = depth.checked_sub(1)?,
            _ => {}
        }
        prev = Some(ch);
    }

    if depth == 0 {
        Some(inner)
    } else {
        None
    }
}
213
/// Splits a comma-separated list of type arguments at top-level commas.
///
/// Commas nested inside `<...>` or `(...)` (as in `Result<A, B>` or `Fn(A, B) -> C`) do
/// not split. The `>` of a function arrow `->` is not treated as a closing bracket, so
/// `Fn(I64) -> I64, String` yields two items. A stray closing bracket never drives the
/// nesting depth below zero. Each item is trimmed. Only a trailing empty item is dropped,
/// so `""` yields an empty list (used for `Tuple<>`).
fn split_type_args(s: &str) -> Vec<std::string::String> {
    let mut result = Vec::new();
    let mut depth: usize = 0;
    let mut current = std::string::String::new();
    let mut prev: Option<char> = None;

    for ch in s.chars() {
        if ch == ',' && depth == 0 {
            result.push(current.trim().to_string());
            current.clear();
        } else {
            match ch {
                '<' | '(' => depth += 1,
                // `->` (emitted by `to_type_str` for function types) is an arrow,
                // not a closing angle bracket.
                '>' if prev == Some('-') => {}
                '>' | ')' => depth = depth.saturating_sub(1),
                _ => {}
            }
            current.push(ch);
        }
        prev = Some(ch);
    }

    let last = current.trim();
    if !last.is_empty() {
        result.push(last.to_string());
    }

    result
}
249
250impl Serialize for Type {
251 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
252 where
253 S: Serializer,
254 {
255 serializer.serialize_str(&self.to_type_str())
256 }
257}
258
259impl<'de> Deserialize<'de> for Type {
260 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
261 where
262 D: Deserializer<'de>,
263 {
264 let s = std::string::String::deserialize(deserializer)?;
265 Ok(Type::from_type_str(&s))
266 }
267}
268
269impl fmt::Display for Type {
270 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271 write!(f, "{}", self.to_type_str())
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `t` renders to `expected` and that `expected` parses back to `t`.
    fn assert_roundtrip(t: Type, expected: &str) {
        assert_eq!(t.to_type_str(), expected);
        assert_eq!(Type::from_type_str(expected), t, "roundtrip failed for {expected}");
    }

    #[test]
    fn test_primitive_roundtrip() {
        let types = vec![
            Type::Unit,
            Type::Bool,
            Type::I8,
            Type::I16,
            Type::I32,
            Type::I64,
            Type::U8,
            Type::U16,
            Type::U32,
            Type::U64,
            Type::F32,
            Type::F64,
            Type::String,
            Type::Bytes,
        ];
        for t in types {
            let s = t.to_type_str();
            let parsed = Type::from_type_str(&s);
            assert_eq!(t, parsed, "roundtrip failed for {s}");
        }
    }

    #[test]
    fn test_unit_alias_and_whitespace() {
        assert_eq!(Type::from_type_str("()"), Type::Unit);
        assert_eq!(Type::from_type_str("  I64\t"), Type::I64);
        assert_eq!(
            Type::from_type_str(" Array < I64 > "),
            Type::Array {
                element: Box::new(Type::I64),
            }
        );
    }

    #[test]
    fn test_array_roundtrip() {
        let t = Type::Array {
            element: Box::new(Type::I64),
        };
        assert_eq!(t.to_type_str(), "Array<I64>");
        assert_eq!(Type::from_type_str("Array<I64>"), t);
    }

    #[test]
    fn test_result_roundtrip() {
        let t = Type::Result {
            ok: Box::new(Type::I64),
            err: Box::new(Type::String),
        };
        assert_eq!(t.to_type_str(), "Result<I64, String>");
        assert_eq!(Type::from_type_str("Result<I64, String>"), t);
    }

    #[test]
    fn test_tuple_roundtrip() {
        assert_roundtrip(
            Type::Tuple {
                elements: vec![Type::I64, Type::String],
            },
            "Tuple<I64, String>",
        );
        assert_roundtrip(Type::Tuple { elements: vec![] }, "Tuple<>");
    }

    #[test]
    fn test_reference_roundtrip() {
        assert_roundtrip(
            Type::Reference {
                inner: Box::new(Type::Bytes),
                mutable: false,
            },
            "Ref<Bytes>",
        );
        assert_roundtrip(
            Type::Reference {
                inner: Box::new(Type::Bytes),
                mutable: true,
            },
            "MutRef<Bytes>",
        );
    }

    #[test]
    fn test_nested_roundtrip() {
        let t = Type::Result {
            ok: Box::new(Type::Array {
                element: Box::new(Type::Optional {
                    inner: Box::new(Type::I64),
                }),
            }),
            err: Box::new(Type::Tuple {
                elements: vec![Type::String, Type::Bool],
            }),
        };
        assert_roundtrip(t, "Result<Array<Optional<I64>>, Tuple<String, Bool>>");
    }

    #[test]
    fn test_named_fallback() {
        let t = Type::from_type_str("MyCustomType");
        assert_eq!(t, Type::Named(TypeId::new("MyCustomType")));
    }

    #[test]
    fn test_malformed_result_falls_back_to_named() {
        assert_eq!(
            Type::from_type_str("Result<I64>"),
            Type::Named(TypeId::new("Result<I64>"))
        );
        assert_eq!(
            Type::from_type_str("Result<I64, String, Bool>"),
            Type::Named(TypeId::new("Result<I64, String, Bool>"))
        );
    }

    #[test]
    fn test_serde_roundtrip() {
        let t = Type::I64;
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, "\"I64\"");
        let parsed: Type = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn test_serde_nested_roundtrip() {
        let t = Type::Optional {
            inner: Box::new(Type::Array {
                element: Box::new(Type::U8),
            }),
        };
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, "\"Optional<Array<U8>>\"");
        let parsed: Type = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, t);
    }

    #[test]
    fn test_serde_rejects_non_string() {
        assert!(serde_json::from_str::<Type>("42").is_err());
    }

    #[test]
    fn test_display_matches_type_str() {
        let t = Type::Reference {
            inner: Box::new(Type::F32),
            mutable: true,
        };
        assert_eq!(t.to_string(), t.to_type_str());
    }
}