Skip to main content

airl_ir/
node.rs

1use crate::ids::NodeId;
2use crate::types::Type;
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::fmt;
5
6// ---------------------------------------------------------------------------
7// Supporting enums
8// ---------------------------------------------------------------------------
9
/// Literal values in the IR.
///
/// Serialized as a bare scalar rather than a tagged object: integers, floats,
/// booleans and strings use their native serde forms, and `Unit` is written as
/// a "none" value (`null` in JSON). Deserialization reads a JSON value back
/// through `literal_from_json_value`, which maps `null` to `Unit`.
///
/// NOTE(review): serde_json writes non-finite `f64`s (NaN/±inf) as `null`, so a
/// `Float(NaN)` would read back as `Unit` — confirm whether such literals must
/// round-trip.
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralValue {
    /// Signed 64-bit integer literal.
    Integer(i64),
    /// 64-bit floating-point literal.
    Float(f64),
    /// Boolean literal (`true` or `false`).
    Boolean(bool),
    /// String literal.
    Str(String),
    /// Unit literal `()`.
    Unit,
}
24
25impl Serialize for LiteralValue {
26    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
27    where
28        S: Serializer,
29    {
30        match self {
31            LiteralValue::Integer(v) => serializer.serialize_i64(*v),
32            LiteralValue::Float(v) => serializer.serialize_f64(*v),
33            LiteralValue::Boolean(v) => serializer.serialize_bool(*v),
34            LiteralValue::Str(v) => serializer.serialize_str(v),
35            LiteralValue::Unit => serializer.serialize_none(),
36        }
37    }
38}
39
40impl<'de> Deserialize<'de> for LiteralValue {
41    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
42    where
43        D: Deserializer<'de>,
44    {
45        let val = serde_json::Value::deserialize(deserializer)?;
46        Ok(literal_from_json_value(&val))
47    }
48}
49
50fn literal_from_json_value(val: &serde_json::Value) -> LiteralValue {
51    match val {
52        serde_json::Value::Bool(b) => LiteralValue::Boolean(*b),
53        serde_json::Value::Number(n) => {
54            if let Some(i) = n.as_i64() {
55                LiteralValue::Integer(i)
56            } else if let Some(f) = n.as_f64() {
57                LiteralValue::Float(f)
58            } else {
59                LiteralValue::Integer(0)
60            }
61        }
62        serde_json::Value::String(s) => LiteralValue::Str(s.clone()),
63        serde_json::Value::Null => LiteralValue::Unit,
64        _ => LiteralValue::Unit,
65    }
66}
67
/// Binary operators for [`Node::BinOp`].
///
/// Serialized by serde's derive as the bare variant name (e.g. `"Add"`,
/// `"Lte"`); `node_from_value` reads a node's `"op"` string back into this
/// enum the same way.
///
/// NOTE(review): the per-variant semantics below (short-circuiting, signed
/// division, arithmetic shift) are not enforced by this type — presumably the
/// interpreter/backend implements them; confirm there.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinOpKind {
    /// `a + b` — integer/float addition or string concatenation
    Add,
    /// `a - b`
    Sub,
    /// `a * b`
    Mul,
    /// `a / b` (signed integer division or float division)
    Div,
    /// `a % b` (signed remainder)
    Mod,
    /// `a == b`
    Eq,
    /// `a != b`
    Neq,
    /// `a < b`
    Lt,
    /// `a <= b`
    Lte,
    /// `a > b`
    Gt,
    /// `a >= b`
    Gte,
    /// Logical `a && b` (short-circuits)
    And,
    /// Logical `a || b` (short-circuits)
    Or,
    /// Bitwise `a & b`
    BitAnd,
    /// Bitwise `a | b`
    BitOr,
    /// Bitwise `a ^ b`
    BitXor,
    /// Left shift `a << b`
    Shl,
    /// Arithmetic right shift `a >> b`
    Shr,
}
108
/// Unary operators for [`Node::UnaryOp`].
///
/// Serialized by serde's derive as the bare variant name (`"Neg"`, `"Not"`,
/// `"BitNot"`), which is also how `node_from_value` parses a node's `"op"`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnaryOpKind {
    /// Arithmetic negation: `-x`
    Neg,
    /// Logical not: `!x`
    Not,
    /// Bitwise not: `~x`
    BitNot,
}
119
/// A pattern in a [`Node::Match`] arm.
///
/// Serialized as an internally tagged object whose `"kind"` field names the
/// variant, for example `{"kind": "Wildcard"}`,
/// `{"kind": "Variable", "name": "x"}` or `{"kind": "Literal", "value": 42}`.
/// The literal payload uses [`LiteralValue`]'s bare-scalar encoding.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum Pattern {
    /// Match a specific literal value (integer, string, bool, unit).
    Literal {
        /// The literal value to match against.
        value: LiteralValue,
    },
    /// Match anything (catch-all `_`).
    Wildcard,
    /// Match anything and bind to a local variable.
    Variable {
        /// The name to bind the scrutinee value to inside the arm body.
        name: String,
    },
}
137
/// A single arm of a [`Node::Match`] expression: a pattern and the body to evaluate.
///
/// Serialized as a two-entry map `{"pattern": ..., "body": ...}`. When
/// deserializing, a missing `pattern` defaults to [`Pattern::Wildcard`], while
/// a missing `body` is an error.
#[derive(Clone, Debug, PartialEq)]
pub struct MatchArm {
    /// The pattern to match against the scrutinee.
    pub pattern: Pattern,
    /// The body expression, evaluated if the pattern matches.
    pub body: Node,
}
146
147// ---------------------------------------------------------------------------
148// Node - the core IR node type
149// ---------------------------------------------------------------------------
150
/// The core IR node enum. Each variant represents a different kind of
/// computation in the AIRL intermediate representation.
///
/// Nodes are serialized to/from JSON with a `"kind"` discriminator field.
///
/// All variants share:
/// - `id` — unique [`NodeId`] for patch targeting
/// - `node_type` — the [`Type`] this node evaluates to (except for [`Node::Error`])
///
/// The JSON form is a flat object. It starts with `"id"` and `"kind"`. Every
/// variant except `Error` then has a `"type"` entry holding `node_type`
/// rendered by `Type::to_type_str`. The variant's remaining fields follow,
/// under the same names used here. `StructLiteral::fields` is written as an
/// array of `{"name", "value"}` objects, and `Match::arms` as an array of
/// `{"pattern", "body"}` objects.
#[derive(Clone, Debug, PartialEq)]
#[allow(missing_docs)] // field names are self-documenting
pub enum Node {
    /// A literal value (integer, float, bool, string, unit).
    Literal {
        id: NodeId,
        node_type: Type,
        value: LiteralValue,
    },
    /// Reference to a function parameter or let-bound local variable.
    ///
    /// NOTE(review): what `index` numbers for let-bound locals (vs. parameter
    /// position) is not established in this file — confirm with the producer.
    Param {
        id: NodeId,
        name: String,
        index: u32,
        node_type: Type,
    },
    /// `let name = value in body` — introduces a local binding.
    Let {
        id: NodeId,
        name: String,
        node_type: Type,
        value: Box<Node>,
        body: Box<Node>,
    },
    /// `if cond then then_branch else else_branch`.
    If {
        id: NodeId,
        node_type: Type,
        cond: Box<Node>,
        then_branch: Box<Node>,
        else_branch: Box<Node>,
    },
    /// A function call. `target` is a name (user-defined or builtin like `std::io::println`).
    Call {
        id: NodeId,
        node_type: Type,
        target: String,
        args: Vec<Node>,
    },
    /// `return value` — early exit from a function.
    Return {
        id: NodeId,
        node_type: Type,
        value: Box<Node>,
    },
    /// Binary operator application (`lhs op rhs`). See [`BinOpKind`].
    BinOp {
        id: NodeId,
        op: BinOpKind,
        node_type: Type,
        lhs: Box<Node>,
        rhs: Box<Node>,
    },
    /// Unary operator application (`op operand`). See [`UnaryOpKind`].
    UnaryOp {
        id: NodeId,
        op: UnaryOpKind,
        node_type: Type,
        operand: Box<Node>,
    },
    /// A sequence of statements followed by a result expression.
    Block {
        id: NodeId,
        node_type: Type,
        statements: Vec<Node>,
        result: Box<Node>,
    },
    /// Infinite loop — exit only via `Return` or internal `LoopBreak`.
    ///
    /// NOTE(review): there is no `LoopBreak` variant in this enum; presumably
    /// it is a builtin `Call` target or handled by a later pass — confirm.
    Loop {
        id: NodeId,
        node_type: Type,
        body: Box<Node>,
    },
    /// Pattern match. Evaluates `scrutinee`, then the first matching arm's body.
    Match {
        id: NodeId,
        node_type: Type,
        scrutinee: Box<Node>,
        arms: Vec<MatchArm>,
    },
    /// Struct literal: `{ field1: v1, field2: v2 }`.
    StructLiteral {
        id: NodeId,
        node_type: Type,
        fields: Vec<(String, Node)>,
    },
    /// Access a struct field: `object.field`.
    FieldAccess {
        id: NodeId,
        node_type: Type,
        object: Box<Node>,
        field: String,
    },
    /// Array literal: `[e1, e2, e3]`.
    ArrayLiteral {
        id: NodeId,
        node_type: Type,
        elements: Vec<Node>,
    },
    /// Array index access: `array[index]`.
    IndexAccess {
        id: NodeId,
        node_type: Type,
        array: Box<Node>,
        index: Box<Node>,
    },
    /// An explicit error node, used as a placeholder when parsing fails.
    /// It carries no `node_type`, so [`Node::node_type`] returns `None` for it.
    Error { id: NodeId, message: String },
}
268
269impl Node {
270    /// Get the NodeId of this node.
271    pub fn id(&self) -> &NodeId {
272        match self {
273            Node::Literal { id, .. }
274            | Node::Param { id, .. }
275            | Node::Let { id, .. }
276            | Node::If { id, .. }
277            | Node::Call { id, .. }
278            | Node::Return { id, .. }
279            | Node::BinOp { id, .. }
280            | Node::UnaryOp { id, .. }
281            | Node::Block { id, .. }
282            | Node::Loop { id, .. }
283            | Node::Match { id, .. }
284            | Node::StructLiteral { id, .. }
285            | Node::FieldAccess { id, .. }
286            | Node::ArrayLiteral { id, .. }
287            | Node::IndexAccess { id, .. }
288            | Node::Error { id, .. } => id,
289        }
290    }
291
292    /// Get the type of this node (if it has one).
293    pub fn node_type(&self) -> Option<&Type> {
294        match self {
295            Node::Literal { node_type, .. }
296            | Node::Param { node_type, .. }
297            | Node::Let { node_type, .. }
298            | Node::If { node_type, .. }
299            | Node::Call { node_type, .. }
300            | Node::Return { node_type, .. }
301            | Node::BinOp { node_type, .. }
302            | Node::UnaryOp { node_type, .. }
303            | Node::Block { node_type, .. }
304            | Node::Loop { node_type, .. }
305            | Node::Match { node_type, .. }
306            | Node::StructLiteral { node_type, .. }
307            | Node::FieldAccess { node_type, .. }
308            | Node::ArrayLiteral { node_type, .. }
309            | Node::IndexAccess { node_type, .. } => Some(node_type),
310            Node::Error { .. } => None,
311        }
312    }
313}
314
315// ---------------------------------------------------------------------------
316// Custom Serde for Node
317// ---------------------------------------------------------------------------
318
319// We use a flat JSON representation with a "kind" field as discriminator.
320// The "type" field is a string that gets parsed via Type::from_type_str.
321
322impl Serialize for Node {
323    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
324    where
325        S: Serializer,
326    {
327        use serde::ser::SerializeMap;
328
329        match self {
330            Node::Literal {
331                id,
332                node_type,
333                value,
334            } => {
335                let mut map = serializer.serialize_map(None)?;
336                map.serialize_entry("id", &id)?;
337                map.serialize_entry("kind", "Literal")?;
338                map.serialize_entry("type", &node_type.to_type_str())?;
339                map.serialize_entry("value", value)?;
340                map.end()
341            }
342            Node::Param {
343                id,
344                name,
345                index,
346                node_type,
347            } => {
348                let mut map = serializer.serialize_map(None)?;
349                map.serialize_entry("id", &id)?;
350                map.serialize_entry("kind", "Param")?;
351                map.serialize_entry("type", &node_type.to_type_str())?;
352                map.serialize_entry("name", name)?;
353                map.serialize_entry("index", index)?;
354                map.end()
355            }
356            Node::Let {
357                id,
358                name,
359                node_type,
360                value,
361                body,
362            } => {
363                let mut map = serializer.serialize_map(None)?;
364                map.serialize_entry("id", &id)?;
365                map.serialize_entry("kind", "Let")?;
366                map.serialize_entry("type", &node_type.to_type_str())?;
367                map.serialize_entry("name", name)?;
368                map.serialize_entry("value", &**value)?;
369                map.serialize_entry("body", &**body)?;
370                map.end()
371            }
372            Node::If {
373                id,
374                node_type,
375                cond,
376                then_branch,
377                else_branch,
378            } => {
379                let mut map = serializer.serialize_map(None)?;
380                map.serialize_entry("id", &id)?;
381                map.serialize_entry("kind", "If")?;
382                map.serialize_entry("type", &node_type.to_type_str())?;
383                map.serialize_entry("cond", &**cond)?;
384                map.serialize_entry("then_branch", &**then_branch)?;
385                map.serialize_entry("else_branch", &**else_branch)?;
386                map.end()
387            }
388            Node::Call {
389                id,
390                node_type,
391                target,
392                args,
393            } => {
394                let mut map = serializer.serialize_map(None)?;
395                map.serialize_entry("id", &id)?;
396                map.serialize_entry("kind", "Call")?;
397                map.serialize_entry("type", &node_type.to_type_str())?;
398                map.serialize_entry("target", target)?;
399                map.serialize_entry("args", args)?;
400                map.end()
401            }
402            Node::Return {
403                id,
404                node_type,
405                value,
406            } => {
407                let mut map = serializer.serialize_map(None)?;
408                map.serialize_entry("id", &id)?;
409                map.serialize_entry("kind", "Return")?;
410                map.serialize_entry("type", &node_type.to_type_str())?;
411                map.serialize_entry("value", &**value)?;
412                map.end()
413            }
414            Node::BinOp {
415                id,
416                op,
417                node_type,
418                lhs,
419                rhs,
420            } => {
421                let mut map = serializer.serialize_map(None)?;
422                map.serialize_entry("id", &id)?;
423                map.serialize_entry("kind", "BinOp")?;
424                map.serialize_entry("type", &node_type.to_type_str())?;
425                map.serialize_entry("op", op)?;
426                map.serialize_entry("lhs", &**lhs)?;
427                map.serialize_entry("rhs", &**rhs)?;
428                map.end()
429            }
430            Node::UnaryOp {
431                id,
432                op,
433                node_type,
434                operand,
435            } => {
436                let mut map = serializer.serialize_map(None)?;
437                map.serialize_entry("id", &id)?;
438                map.serialize_entry("kind", "UnaryOp")?;
439                map.serialize_entry("type", &node_type.to_type_str())?;
440                map.serialize_entry("op", op)?;
441                map.serialize_entry("operand", &**operand)?;
442                map.end()
443            }
444            Node::Block {
445                id,
446                node_type,
447                statements,
448                result,
449            } => {
450                let mut map = serializer.serialize_map(None)?;
451                map.serialize_entry("id", &id)?;
452                map.serialize_entry("kind", "Block")?;
453                map.serialize_entry("type", &node_type.to_type_str())?;
454                map.serialize_entry("statements", statements)?;
455                map.serialize_entry("result", &**result)?;
456                map.end()
457            }
458            Node::Loop {
459                id,
460                node_type,
461                body,
462            } => {
463                let mut map = serializer.serialize_map(None)?;
464                map.serialize_entry("id", &id)?;
465                map.serialize_entry("kind", "Loop")?;
466                map.serialize_entry("type", &node_type.to_type_str())?;
467                map.serialize_entry("body", &**body)?;
468                map.end()
469            }
470            Node::Match {
471                id,
472                node_type,
473                scrutinee,
474                arms,
475            } => {
476                let mut map = serializer.serialize_map(None)?;
477                map.serialize_entry("id", &id)?;
478                map.serialize_entry("kind", "Match")?;
479                map.serialize_entry("type", &node_type.to_type_str())?;
480                map.serialize_entry("scrutinee", &**scrutinee)?;
481                // Serialize arms as array of objects
482                let arm_values: Vec<serde_json::Value> = arms
483                    .iter()
484                    .map(|arm| {
485                        let body_val = serde_json::to_value(&arm.body).unwrap_or_default();
486                        let pattern_val = serde_json::to_value(&arm.pattern).unwrap_or_default();
487                        serde_json::json!({
488                            "pattern": pattern_val,
489                            "body": body_val,
490                        })
491                    })
492                    .collect();
493                map.serialize_entry("arms", &arm_values)?;
494                map.end()
495            }
496            Node::StructLiteral {
497                id,
498                node_type,
499                fields,
500            } => {
501                let mut map = serializer.serialize_map(None)?;
502                map.serialize_entry("id", &id)?;
503                map.serialize_entry("kind", "StructLiteral")?;
504                map.serialize_entry("type", &node_type.to_type_str())?;
505                // Serialize fields as array of {name, value} objects
506                let field_values: Vec<serde_json::Value> = fields
507                    .iter()
508                    .map(|(name, node)| {
509                        let node_val = serde_json::to_value(node).unwrap_or_default();
510                        serde_json::json!({
511                            "name": name,
512                            "value": node_val,
513                        })
514                    })
515                    .collect();
516                map.serialize_entry("fields", &field_values)?;
517                map.end()
518            }
519            Node::FieldAccess {
520                id,
521                node_type,
522                object,
523                field,
524            } => {
525                let mut map = serializer.serialize_map(None)?;
526                map.serialize_entry("id", &id)?;
527                map.serialize_entry("kind", "FieldAccess")?;
528                map.serialize_entry("type", &node_type.to_type_str())?;
529                map.serialize_entry("object", &**object)?;
530                map.serialize_entry("field", field)?;
531                map.end()
532            }
533            Node::ArrayLiteral {
534                id,
535                node_type,
536                elements,
537            } => {
538                let mut map = serializer.serialize_map(None)?;
539                map.serialize_entry("id", &id)?;
540                map.serialize_entry("kind", "ArrayLiteral")?;
541                map.serialize_entry("type", &node_type.to_type_str())?;
542                map.serialize_entry("elements", elements)?;
543                map.end()
544            }
545            Node::IndexAccess {
546                id,
547                node_type,
548                array,
549                index,
550            } => {
551                let mut map = serializer.serialize_map(None)?;
552                map.serialize_entry("id", &id)?;
553                map.serialize_entry("kind", "IndexAccess")?;
554                map.serialize_entry("type", &node_type.to_type_str())?;
555                map.serialize_entry("array", &**array)?;
556                map.serialize_entry("index", &**index)?;
557                map.end()
558            }
559            Node::Error { id, message } => {
560                let mut map = serializer.serialize_map(None)?;
561                map.serialize_entry("id", &id)?;
562                map.serialize_entry("kind", "Error")?;
563                map.serialize_entry("message", message)?;
564                map.end()
565            }
566        }
567    }
568}
569
570impl<'de> Deserialize<'de> for Node {
571    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
572    where
573        D: Deserializer<'de>,
574    {
575        let val = serde_json::Value::deserialize(deserializer)?;
576        node_from_value(&val).map_err(serde::de::Error::custom)
577    }
578}
579
580/// Deserialize a Node from a serde_json::Value.
581fn node_from_value(val: &serde_json::Value) -> Result<Node, String> {
582    let obj = val.as_object().ok_or("Node must be a JSON object")?;
583
584    let id = obj
585        .get("id")
586        .and_then(|v| v.as_str())
587        .map(NodeId::new)
588        .ok_or("Node missing 'id' field")?;
589
590    let kind = obj
591        .get("kind")
592        .and_then(|v| v.as_str())
593        .ok_or("Node missing 'kind' field")?;
594
595    let node_type = obj
596        .get("type")
597        .and_then(|v| v.as_str())
598        .map(Type::from_type_str)
599        .unwrap_or(Type::Unit);
600
601    match kind {
602        "Literal" => {
603            let value = obj
604                .get("value")
605                .map(literal_from_json_value)
606                .unwrap_or(LiteralValue::Unit);
607            Ok(Node::Literal {
608                id,
609                node_type,
610                value,
611            })
612        }
613        "Param" => {
614            let name = obj
615                .get("name")
616                .and_then(|v| v.as_str())
617                .unwrap_or("")
618                .to_string();
619            let index = obj.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
620            Ok(Node::Param {
621                id,
622                name,
623                index,
624                node_type,
625            })
626        }
627        "Let" => {
628            let name = obj
629                .get("name")
630                .and_then(|v| v.as_str())
631                .unwrap_or("")
632                .to_string();
633            let value = obj.get("value").ok_or("Let missing 'value'")?;
634            let body = obj.get("body").ok_or("Let missing 'body'")?;
635            Ok(Node::Let {
636                id,
637                name,
638                node_type,
639                value: Box::new(node_from_value(value)?),
640                body: Box::new(node_from_value(body)?),
641            })
642        }
643        "If" => {
644            let cond = obj.get("cond").ok_or("If missing 'cond'")?;
645            let then_branch = obj.get("then_branch").ok_or("If missing 'then_branch'")?;
646            let else_branch = obj.get("else_branch").ok_or("If missing 'else_branch'")?;
647            Ok(Node::If {
648                id,
649                node_type,
650                cond: Box::new(node_from_value(cond)?),
651                then_branch: Box::new(node_from_value(then_branch)?),
652                else_branch: Box::new(node_from_value(else_branch)?),
653            })
654        }
655        "Call" => {
656            let target = obj
657                .get("target")
658                .and_then(|v| v.as_str())
659                .unwrap_or("")
660                .to_string();
661            let args = obj
662                .get("args")
663                .and_then(|v| v.as_array())
664                .map(|arr| {
665                    arr.iter()
666                        .map(node_from_value)
667                        .collect::<Result<Vec<_>, _>>()
668                })
669                .transpose()?
670                .unwrap_or_default();
671            Ok(Node::Call {
672                id,
673                node_type,
674                target,
675                args,
676            })
677        }
678        "Return" => {
679            let value = obj.get("value").ok_or("Return missing 'value'")?;
680            Ok(Node::Return {
681                id,
682                node_type,
683                value: Box::new(node_from_value(value)?),
684            })
685        }
686        "BinOp" => {
687            let op_str = obj
688                .get("op")
689                .and_then(|v| v.as_str())
690                .ok_or("BinOp missing 'op'")?;
691            let op: BinOpKind = serde_json::from_value(serde_json::Value::String(op_str.into()))
692                .map_err(|e| format!("Invalid BinOp op: {e}"))?;
693            let lhs = obj.get("lhs").ok_or("BinOp missing 'lhs'")?;
694            let rhs = obj.get("rhs").ok_or("BinOp missing 'rhs'")?;
695            Ok(Node::BinOp {
696                id,
697                op,
698                node_type,
699                lhs: Box::new(node_from_value(lhs)?),
700                rhs: Box::new(node_from_value(rhs)?),
701            })
702        }
703        "UnaryOp" => {
704            let op_str = obj
705                .get("op")
706                .and_then(|v| v.as_str())
707                .ok_or("UnaryOp missing 'op'")?;
708            let op: UnaryOpKind = serde_json::from_value(serde_json::Value::String(op_str.into()))
709                .map_err(|e| format!("Invalid UnaryOp op: {e}"))?;
710            let operand = obj.get("operand").ok_or("UnaryOp missing 'operand'")?;
711            Ok(Node::UnaryOp {
712                id,
713                op,
714                node_type,
715                operand: Box::new(node_from_value(operand)?),
716            })
717        }
718        "Block" => {
719            let statements = obj
720                .get("statements")
721                .and_then(|v| v.as_array())
722                .map(|arr| {
723                    arr.iter()
724                        .map(node_from_value)
725                        .collect::<Result<Vec<_>, _>>()
726                })
727                .transpose()?
728                .unwrap_or_default();
729            let result = obj.get("result").ok_or("Block missing 'result'")?;
730            Ok(Node::Block {
731                id,
732                node_type,
733                statements,
734                result: Box::new(node_from_value(result)?),
735            })
736        }
737        "Loop" => {
738            let body = obj.get("body").ok_or("Loop missing 'body'")?;
739            Ok(Node::Loop {
740                id,
741                node_type,
742                body: Box::new(node_from_value(body)?),
743            })
744        }
745        "Match" => {
746            let scrutinee = obj.get("scrutinee").ok_or("Match missing 'scrutinee'")?;
747            let arms = obj
748                .get("arms")
749                .and_then(|v| v.as_array())
750                .map(|arr| {
751                    arr.iter()
752                        .map(|arm_val| {
753                            let arm_obj = arm_val.as_object().ok_or("Match arm must be object")?;
754                            let pattern: Pattern = arm_obj
755                                .get("pattern")
756                                .map(|v| {
757                                    serde_json::from_value(v.clone())
758                                        .map_err(|e| format!("Invalid pattern: {e}"))
759                                })
760                                .transpose()?
761                                .unwrap_or(Pattern::Wildcard);
762                            let body = arm_obj.get("body").ok_or("Match arm missing 'body'")?;
763                            Ok(MatchArm {
764                                pattern,
765                                body: node_from_value(body)?,
766                            })
767                        })
768                        .collect::<Result<Vec<_>, String>>()
769                })
770                .transpose()?
771                .unwrap_or_default();
772            Ok(Node::Match {
773                id,
774                node_type,
775                scrutinee: Box::new(node_from_value(scrutinee)?),
776                arms,
777            })
778        }
779        "StructLiteral" => {
780            let fields = obj
781                .get("fields")
782                .and_then(|v| v.as_array())
783                .map(|arr| {
784                    arr.iter()
785                        .map(|field_val| {
786                            let field_obj = field_val
787                                .as_object()
788                                .ok_or("StructLiteral field must be object")?;
789                            let name = field_obj
790                                .get("name")
791                                .and_then(|v| v.as_str())
792                                .unwrap_or("")
793                                .to_string();
794                            let value = field_obj
795                                .get("value")
796                                .ok_or("StructLiteral field missing 'value'")?;
797                            Ok((name, node_from_value(value)?))
798                        })
799                        .collect::<Result<Vec<_>, String>>()
800                })
801                .transpose()?
802                .unwrap_or_default();
803            Ok(Node::StructLiteral {
804                id,
805                node_type,
806                fields,
807            })
808        }
809        "FieldAccess" => {
810            let object = obj.get("object").ok_or("FieldAccess missing 'object'")?;
811            let field = obj
812                .get("field")
813                .and_then(|v| v.as_str())
814                .unwrap_or("")
815                .to_string();
816            Ok(Node::FieldAccess {
817                id,
818                node_type,
819                object: Box::new(node_from_value(object)?),
820                field,
821            })
822        }
823        "ArrayLiteral" => {
824            let elements = obj
825                .get("elements")
826                .and_then(|v| v.as_array())
827                .map(|arr| {
828                    arr.iter()
829                        .map(node_from_value)
830                        .collect::<Result<Vec<_>, _>>()
831                })
832                .transpose()?
833                .unwrap_or_default();
834            Ok(Node::ArrayLiteral {
835                id,
836                node_type,
837                elements,
838            })
839        }
840        "IndexAccess" => {
841            let array = obj.get("array").ok_or("IndexAccess missing 'array'")?;
842            let index = obj.get("index").ok_or("IndexAccess missing 'index'")?;
843            Ok(Node::IndexAccess {
844                id,
845                node_type,
846                array: Box::new(node_from_value(array)?),
847                index: Box::new(node_from_value(index)?),
848            })
849        }
850        "Error" => {
851            let message = obj
852                .get("message")
853                .and_then(|v| v.as_str())
854                .unwrap_or("")
855                .to_string();
856            Ok(Node::Error { id, message })
857        }
858        other => Err(format!("Unknown node kind: {other}")),
859    }
860}
861
862impl Serialize for MatchArm {
863    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
864    where
865        S: Serializer,
866    {
867        use serde::ser::SerializeMap;
868        let mut map = serializer.serialize_map(Some(2))?;
869        map.serialize_entry("pattern", &self.pattern)?;
870        map.serialize_entry("body", &self.body)?;
871        map.end()
872    }
873}
874
875impl<'de> Deserialize<'de> for MatchArm {
876    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
877    where
878        D: Deserializer<'de>,
879    {
880        let val = serde_json::Value::deserialize(deserializer)?;
881        let obj = val
882            .as_object()
883            .ok_or_else(|| serde::de::Error::custom("MatchArm must be a JSON object"))?;
884        let pattern: Pattern = obj
885            .get("pattern")
886            .map(|v| serde_json::from_value(v.clone()).map_err(serde::de::Error::custom))
887            .transpose()?
888            .unwrap_or(Pattern::Wildcard);
889        let body = obj
890            .get("body")
891            .ok_or_else(|| serde::de::Error::custom("MatchArm missing 'body'"))?;
892        let body = node_from_value(body).map_err(serde::de::Error::custom)?;
893        Ok(MatchArm { pattern, body })
894    }
895}
896
897impl fmt::Display for Node {
898    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
899        match self {
900            Node::Literal { value, .. } => write!(f, "{value:?}"),
901            Node::Param { name, .. } => write!(f, "param:{name}"),
902            Node::Let { name, .. } => write!(f, "let {name}"),
903            Node::If { .. } => write!(f, "if"),
904            Node::Call { target, .. } => write!(f, "call {target}"),
905            Node::Return { .. } => write!(f, "return"),
906            Node::BinOp { op, .. } => write!(f, "binop {op:?}"),
907            Node::UnaryOp { op, .. } => write!(f, "unaryop {op:?}"),
908            Node::Block { .. } => write!(f, "block"),
909            Node::Loop { .. } => write!(f, "loop"),
910            Node::Match { .. } => write!(f, "match"),
911            Node::StructLiteral { .. } => write!(f, "struct literal"),
912            Node::FieldAccess { field, .. } => write!(f, ".{field}"),
913            Node::ArrayLiteral { .. } => write!(f, "array literal"),
914            Node::IndexAccess { .. } => write!(f, "index access"),
915            Node::Error { message, .. } => write!(f, "error: {message}"),
916        }
917    }
918}
919
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::TypeId;

    /// Serializes `node` to both compact and pretty JSON and asserts that each
    /// encoding deserializes back to a node equal to the original.
    fn assert_roundtrip(node: &Node) {
        let encodings = [
            serde_json::to_string(node).unwrap(),
            serde_json::to_string_pretty(node).unwrap(),
        ];
        for json in encodings {
            let parsed: Node = serde_json::from_str(&json).unwrap();
            assert_eq!(*node, parsed, "roundtrip mismatch for JSON: {json}");
        }
    }

    #[test]
    fn test_literal_node_roundtrip() {
        let node = Node::Literal {
            id: NodeId::new("n_1"),
            node_type: Type::I64,
            value: LiteralValue::Integer(42),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_call_node_roundtrip() {
        let node = Node::Call {
            id: NodeId::new("n_100"),
            node_type: Type::Unit,
            target: "std::io::println".to_string(),
            args: vec![Node::Literal {
                id: NodeId::new("n_101"),
                node_type: Type::String,
                value: LiteralValue::Str("hello world".to_string()),
            }],
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_binop_node_roundtrip() {
        let node = Node::BinOp {
            id: NodeId::new("n_5"),
            op: BinOpKind::Add,
            node_type: Type::I64,
            lhs: Box::new(Node::Literal {
                id: NodeId::new("n_6"),
                node_type: Type::I64,
                value: LiteralValue::Integer(1),
            }),
            rhs: Box::new(Node::Literal {
                id: NodeId::new("n_7"),
                node_type: Type::I64,
                value: LiteralValue::Integer(2),
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_param_roundtrip() {
        let node = Node::Param {
            id: NodeId::new("n_1"),
            name: "x".to_string(),
            index: 0,
            node_type: Type::I64,
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_let_roundtrip() {
        let node = Node::Let {
            id: NodeId::new("n_1"),
            name: "x".to_string(),
            node_type: Type::I64,
            value: Box::new(Node::Literal {
                id: NodeId::new("n_2"),
                node_type: Type::I64,
                value: LiteralValue::Integer(42),
            }),
            body: Box::new(Node::Param {
                id: NodeId::new("n_3"),
                name: "x".to_string(),
                index: 0,
                node_type: Type::I64,
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_if_roundtrip() {
        let node = Node::If {
            id: NodeId::new("n_1"),
            node_type: Type::I64,
            cond: Box::new(Node::Literal {
                id: NodeId::new("n_2"),
                node_type: Type::Bool,
                value: LiteralValue::Boolean(true),
            }),
            then_branch: Box::new(Node::Literal {
                id: NodeId::new("n_3"),
                node_type: Type::I64,
                value: LiteralValue::Integer(1),
            }),
            else_branch: Box::new(Node::Literal {
                id: NodeId::new("n_4"),
                node_type: Type::I64,
                value: LiteralValue::Integer(0),
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_return_roundtrip() {
        let node = Node::Return {
            id: NodeId::new("n_1"),
            node_type: Type::I64,
            value: Box::new(Node::Literal {
                id: NodeId::new("n_2"),
                node_type: Type::I64,
                value: LiteralValue::Integer(42),
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_unaryop_roundtrip() {
        let node = Node::UnaryOp {
            id: NodeId::new("n_1"),
            op: UnaryOpKind::Neg,
            node_type: Type::I64,
            operand: Box::new(Node::Literal {
                id: NodeId::new("n_2"),
                node_type: Type::I64,
                value: LiteralValue::Integer(5),
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_block_roundtrip() {
        let node = Node::Block {
            id: NodeId::new("n_1"),
            node_type: Type::I64,
            statements: vec![Node::Literal {
                id: NodeId::new("n_2"),
                node_type: Type::Unit,
                value: LiteralValue::Unit,
            }],
            result: Box::new(Node::Literal {
                id: NodeId::new("n_3"),
                node_type: Type::I64,
                value: LiteralValue::Integer(42),
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_loop_roundtrip() {
        let node = Node::Loop {
            id: NodeId::new("n_1"),
            node_type: Type::Unit,
            body: Box::new(Node::Literal {
                id: NodeId::new("n_2"),
                node_type: Type::Unit,
                value: LiteralValue::Unit,
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_match_roundtrip() {
        let node = Node::Match {
            id: NodeId::new("n_1"),
            node_type: Type::String,
            scrutinee: Box::new(Node::Literal {
                id: NodeId::new("n_2"),
                node_type: Type::I64,
                value: LiteralValue::Integer(1),
            }),
            arms: vec![
                MatchArm {
                    pattern: Pattern::Literal {
                        value: LiteralValue::Integer(1),
                    },
                    body: Node::Literal {
                        id: NodeId::new("n_3"),
                        node_type: Type::String,
                        value: LiteralValue::Str("one".to_string()),
                    },
                },
                MatchArm {
                    pattern: Pattern::Wildcard,
                    body: Node::Literal {
                        id: NodeId::new("n_4"),
                        node_type: Type::String,
                        value: LiteralValue::Str("other".to_string()),
                    },
                },
            ],
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_struct_literal_roundtrip() {
        let node = Node::StructLiteral {
            id: NodeId::new("n_1"),
            node_type: Type::Named(TypeId::new("Point")),
            fields: vec![
                (
                    "x".to_string(),
                    Node::Literal {
                        id: NodeId::new("n_2"),
                        node_type: Type::F64,
                        value: LiteralValue::Float(1.0),
                    },
                ),
                (
                    "y".to_string(),
                    Node::Literal {
                        id: NodeId::new("n_3"),
                        node_type: Type::F64,
                        value: LiteralValue::Float(2.0),
                    },
                ),
            ],
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_field_access_roundtrip() {
        let node = Node::FieldAccess {
            id: NodeId::new("n_1"),
            node_type: Type::F64,
            object: Box::new(Node::Param {
                id: NodeId::new("n_2"),
                name: "point".to_string(),
                index: 0,
                node_type: Type::Named(TypeId::new("Point")),
            }),
            field: "x".to_string(),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_array_literal_roundtrip() {
        let node = Node::ArrayLiteral {
            id: NodeId::new("n_1"),
            node_type: Type::Array {
                element: Box::new(Type::I64),
            },
            elements: vec![
                Node::Literal {
                    id: NodeId::new("n_2"),
                    node_type: Type::I64,
                    value: LiteralValue::Integer(1),
                },
                Node::Literal {
                    id: NodeId::new("n_3"),
                    node_type: Type::I64,
                    value: LiteralValue::Integer(2),
                },
            ],
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_index_access_roundtrip() {
        let node = Node::IndexAccess {
            id: NodeId::new("n_1"),
            node_type: Type::I64,
            array: Box::new(Node::Param {
                id: NodeId::new("n_2"),
                name: "arr".to_string(),
                index: 0,
                node_type: Type::Array {
                    element: Box::new(Type::I64),
                },
            }),
            index: Box::new(Node::Literal {
                id: NodeId::new("n_3"),
                node_type: Type::I64,
                value: LiteralValue::Integer(0),
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_nested_let_if_binop_roundtrip() {
        // Let x = 10 in (if x > 5 then x + 1 else x - 1)
        let node = Node::Let {
            id: NodeId::new("n_1"),
            name: "x".to_string(),
            node_type: Type::I64,
            value: Box::new(Node::Literal {
                id: NodeId::new("n_2"),
                node_type: Type::I64,
                value: LiteralValue::Integer(10),
            }),
            body: Box::new(Node::If {
                id: NodeId::new("n_3"),
                node_type: Type::I64,
                cond: Box::new(Node::BinOp {
                    id: NodeId::new("n_4"),
                    op: BinOpKind::Gt,
                    node_type: Type::Bool,
                    lhs: Box::new(Node::Param {
                        id: NodeId::new("n_5"),
                        name: "x".to_string(),
                        index: 0,
                        node_type: Type::I64,
                    }),
                    rhs: Box::new(Node::Literal {
                        id: NodeId::new("n_6"),
                        node_type: Type::I64,
                        value: LiteralValue::Integer(5),
                    }),
                }),
                then_branch: Box::new(Node::BinOp {
                    id: NodeId::new("n_7"),
                    op: BinOpKind::Add,
                    node_type: Type::I64,
                    lhs: Box::new(Node::Param {
                        id: NodeId::new("n_8"),
                        name: "x".to_string(),
                        index: 0,
                        node_type: Type::I64,
                    }),
                    rhs: Box::new(Node::Literal {
                        id: NodeId::new("n_9"),
                        node_type: Type::I64,
                        value: LiteralValue::Integer(1),
                    }),
                }),
                else_branch: Box::new(Node::BinOp {
                    id: NodeId::new("n_10"),
                    op: BinOpKind::Sub,
                    node_type: Type::I64,
                    lhs: Box::new(Node::Param {
                        id: NodeId::new("n_11"),
                        name: "x".to_string(),
                        index: 0,
                        node_type: Type::I64,
                    }),
                    rhs: Box::new(Node::Literal {
                        id: NodeId::new("n_12"),
                        node_type: Type::I64,
                        value: LiteralValue::Integer(1),
                    }),
                }),
            }),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_error_node_roundtrip() {
        let node = Node::Error {
            id: NodeId::new("n_1"),
            message: "something went wrong".to_string(),
        };
        assert_roundtrip(&node);
    }

    #[test]
    fn test_deserialize_from_spec_json() {
        let json = r#"{
            "id": "n_100",
            "kind": "Call",
            "type": "Unit",
            "target": "std::io::println",
            "args": [
                {
                    "id": "n_101",
                    "kind": "Literal",
                    "type": "String",
                    "value": "hello world"
                }
            ]
        }"#;
        let node: Node = serde_json::from_str(json).unwrap();
        match &node {
            Node::Call {
                id,
                target,
                args,
                node_type,
            } => {
                assert_eq!(id.as_str(), "n_100");
                assert_eq!(target, "std::io::println");
                assert_eq!(*node_type, Type::Unit);
                assert_eq!(args.len(), 1);
            }
            _ => panic!("Expected Call node"),
        }
    }

    #[test]
    fn test_match_arm_missing_pattern_defaults_to_wildcard() {
        let body = Node::Literal {
            id: NodeId::new("n_1"),
            node_type: Type::String,
            value: LiteralValue::Str("fallback".to_string()),
        };
        // Build the JSON from the node itself so only the arm-level handling of
        // the absent "pattern" key is under test.
        let json = serde_json::json!({ "body": body });
        let arm: MatchArm = serde_json::from_value(json).unwrap();
        assert!(
            matches!(arm.pattern, Pattern::Wildcard),
            "missing pattern should default to Wildcard"
        );
        assert_eq!(arm.body, body);
    }

    #[test]
    fn test_match_arm_missing_body_is_error() {
        let err = serde_json::from_str::<MatchArm>("{}").unwrap_err();
        assert!(
            err.to_string().contains("MatchArm missing 'body'"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_match_arm_non_object_is_error() {
        let err = serde_json::from_str::<MatchArm>("[]").unwrap_err();
        assert!(
            err.to_string().contains("MatchArm must be a JSON object"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_display_labels() {
        let literal = Node::Literal {
            id: NodeId::new("n_1"),
            node_type: Type::I64,
            value: LiteralValue::Integer(42),
        };
        assert_eq!(literal.to_string(), "Integer(42)");

        let param = Node::Param {
            id: NodeId::new("n_2"),
            name: "x".to_string(),
            index: 0,
            node_type: Type::I64,
        };
        assert_eq!(param.to_string(), "param:x");

        let call = Node::Call {
            id: NodeId::new("n_3"),
            node_type: Type::Unit,
            target: "std::io::println".to_string(),
            args: vec![],
        };
        assert_eq!(call.to_string(), "call std::io::println");

        let field = Node::FieldAccess {
            id: NodeId::new("n_4"),
            node_type: Type::F64,
            object: Box::new(param),
            field: "y".to_string(),
        };
        assert_eq!(field.to_string(), ".y");

        let looped = Node::Loop {
            id: NodeId::new("n_5"),
            node_type: Type::Unit,
            body: Box::new(literal),
        };
        assert_eq!(looped.to_string(), "loop");

        let error = Node::Error {
            id: NodeId::new("n_6"),
            message: "boom".to_string(),
        };
        assert_eq!(error.to_string(), "error: boom");
    }
}