Skip to main content

kaish_kernel/ast/
sexpr.rs

//! S-expression formatter for the kaish AST.
//!
//! Converts AST nodes into the S-expression format used in test snapshots.
//! S-expressions provide a stable, readable format that is easier to diff
//! than `Debug` output.
6
7use super::*;
8
9/// Format a Program as an S-expression.
10/// For single-statement programs, formats just the statement.
11/// For multi-statement programs, formats as a sequence.
12pub fn format_program(program: &Program) -> String {
13    let stmts: Vec<_> = program
14        .statements
15        .iter()
16        .filter(|s| !matches!(s, Stmt::Empty))
17        .collect();
18
19    match stmts.len() {
20        0 => "(program)".to_string(),
21        1 => format_stmt(stmts[0]),
22        _ => {
23            let parts: Vec<String> = stmts.iter().map(|s| format_stmt(s)).collect();
24            format!("(program {})", parts.join(" "))
25        }
26    }
27}
28
29/// Format a statement as an S-expression.
30pub fn format_stmt(stmt: &Stmt) -> String {
31    match stmt {
32        Stmt::Assignment(a) => format_assignment(a),
33        Stmt::Command(cmd) => format_command(cmd),
34        Stmt::Pipeline(p) => format_pipeline(p),
35        Stmt::If(if_stmt) => format_if(if_stmt),
36        Stmt::For(for_loop) => format_for(for_loop),
37        Stmt::While(while_loop) => format_while(while_loop),
38        Stmt::Case(case_stmt) => format_case(case_stmt),
39        Stmt::Break(n) => match n {
40            Some(level) => format!("(break {})", level),
41            None => "(break)".to_string(),
42        },
43        Stmt::Continue(n) => match n {
44            Some(level) => format!("(continue {})", level),
45            None => "(continue)".to_string(),
46        },
47        Stmt::Return(expr) => match expr {
48            Some(e) => format!("(return {})", format_expr(e)),
49            None => "(return)".to_string(),
50        },
51        Stmt::Exit(expr) => match expr {
52            Some(e) => format!("(exit {})", format_expr(e)),
53            None => "(exit)".to_string(),
54        },
55        Stmt::ToolDef(tool) => format_tooldef(tool),
56        Stmt::Test(test_expr) => format!("(test {})", format_test_expr(test_expr)),
57        Stmt::AndChain { left, right } => {
58            format!("(and-chain {} {})", format_stmt(left), format_stmt(right))
59        }
60        Stmt::OrChain { left, right } => {
61            format!("(or-chain {} {})", format_stmt(left), format_stmt(right))
62        }
63        Stmt::Empty => "(empty)".to_string(),
64    }
65}
66
67/// Format an assignment as an S-expression.
68fn format_assignment(a: &Assignment) -> String {
69    let value = format_expr(&a.value);
70    format!("(assign {} {} local={})", a.name, value, a.local)
71}
72
73/// Format a command as an S-expression.
74pub fn format_command(cmd: &Command) -> String {
75    let mut parts = vec![format!("(cmd {}", cmd.name)];
76
77    for arg in &cmd.args {
78        parts.push(format_arg(arg));
79    }
80
81    for redir in &cmd.redirects {
82        parts.push(format_redirect(redir));
83    }
84
85    format!("{})", parts.join(" "))
86}
87
88/// Format an argument as an S-expression.
89fn format_arg(arg: &Arg) -> String {
90    match arg {
91        Arg::Positional(expr) => format!("(pos {})", format_expr(expr)),
92        Arg::Named { key, value } => format!("(named {} {})", key, format_expr(value)),
93        Arg::ShortFlag(f) => format!("(shortflag {})", f),
94        Arg::LongFlag(f) => format!("(longflag {})", f),
95        Arg::DoubleDash => "(doubledash)".to_string(),
96    }
97}
98
99/// Format a redirect as an S-expression.
100fn format_redirect(redir: &Redirect) -> String {
101    let kind = match redir.kind {
102        RedirectKind::StdoutOverwrite => ">",
103        RedirectKind::StdoutAppend => ">>",
104        RedirectKind::Stdin => "<",
105        RedirectKind::HereDoc => "<<",
106        RedirectKind::HereString => "<<<",
107        RedirectKind::Stderr => "2>",
108        RedirectKind::Both => "&>",
109        RedirectKind::MergeStderr => "2>&1",
110        RedirectKind::MergeStdout => "1>&2",
111    };
112    format!("(redir {} {})", kind, format_expr(&redir.target))
113}
114
115/// Format a pipeline as an S-expression.
116pub fn format_pipeline(p: &Pipeline) -> String {
117    let cmds: Vec<String> = p.commands.iter().map(format_command).collect();
118
119    if p.background {
120        if cmds.len() == 1 {
121            format!("(background {})", cmds[0])
122        } else {
123            format!("(background (pipeline {}))", cmds.join(" "))
124        }
125    } else {
126        format!("(pipeline {})", cmds.join(" "))
127    }
128}
129
130/// Format an if statement as an S-expression.
131fn format_if(if_stmt: &IfStmt) -> String {
132    let cond = format_expr(&if_stmt.condition);
133    let then_stmts: Vec<String> = if_stmt
134        .then_branch
135        .iter()
136        .filter(|s| !matches!(s, Stmt::Empty))
137        .map(format_stmt)
138        .collect();
139    let then_part = format!("(then {})", then_stmts.join(" "));
140
141    match &if_stmt.else_branch {
142        Some(else_stmts) => {
143            let else_inner: Vec<String> = else_stmts
144                .iter()
145                .filter(|s| !matches!(s, Stmt::Empty))
146                .map(format_stmt)
147                .collect();
148            if else_inner.is_empty() {
149                format!("(if {} {} (else))", cond, then_part)
150            } else {
151                format!("(if {} {} (else {}))", cond, then_part, else_inner.join(" "))
152            }
153        }
154        None => format!("(if {} {} (else))", cond, then_part),
155    }
156}
157
158/// Format a for loop as an S-expression.
159fn format_for(for_loop: &ForLoop) -> String {
160    let items: Vec<String> = for_loop.items.iter().map(format_expr).collect();
161    let body_stmts: Vec<String> = for_loop
162        .body
163        .iter()
164        .filter(|s| !matches!(s, Stmt::Empty))
165        .map(format_stmt)
166        .collect();
167    format!(
168        "(for {} (in {}) (do {}))",
169        for_loop.variable,
170        items.join(" "),
171        body_stmts.join(" ")
172    )
173}
174
175/// Format a while loop as an S-expression.
176fn format_while(while_loop: &WhileLoop) -> String {
177    let cond = format_expr(&while_loop.condition);
178    let body_stmts: Vec<String> = while_loop
179        .body
180        .iter()
181        .filter(|s| !matches!(s, Stmt::Empty))
182        .map(format_stmt)
183        .collect();
184    format!("(while {} (do {}))", cond, body_stmts.join(" "))
185}
186
187/// Format a case statement as an S-expression.
188fn format_case(case_stmt: &CaseStmt) -> String {
189    let expr = format_expr(&case_stmt.expr);
190    let branches: Vec<String> = case_stmt
191        .branches
192        .iter()
193        .map(format_case_branch)
194        .collect();
195    format!("(case {} ({}))", expr, branches.join(" "))
196}
197
198/// Format a case branch as an S-expression.
199fn format_case_branch(branch: &CaseBranch) -> String {
200    let patterns = branch.patterns.join("|");
201    let body_stmts: Vec<String> = branch
202        .body
203        .iter()
204        .filter(|s| !matches!(s, Stmt::Empty))
205        .map(format_stmt)
206        .collect();
207    format!("(branch \"{}\" ({}))", patterns, body_stmts.join(" "))
208}
209
210/// Format a tool definition as an S-expression.
211fn format_tooldef(tool: &ToolDef) -> String {
212    let params: Vec<String> = tool.params.iter().map(format_param).collect();
213    let body_stmts: Vec<String> = tool
214        .body
215        .iter()
216        .filter(|s| !matches!(s, Stmt::Empty))
217        .map(format_stmt)
218        .collect();
219    format!(
220        "(tooldef {} ({}) ({}))",
221        tool.name,
222        params.join(" "),
223        body_stmts.join(" ")
224    )
225}
226
227/// Format a parameter definition as an S-expression.
228fn format_param(param: &ParamDef) -> String {
229    let type_str = param
230        .param_type
231        .as_ref()
232        .map(|t| match t {
233            ParamType::String => "string",
234            ParamType::Int => "int",
235            ParamType::Float => "float",
236            ParamType::Bool => "bool",
237        })
238        .unwrap_or("any");
239
240    match &param.default {
241        Some(default) => format!("(param {} {} {})", param.name, type_str, format_expr(default)),
242        None => format!("(param {} {})", param.name, type_str),
243    }
244}
245
246/// Format an expression as an S-expression.
247pub fn format_expr(expr: &Expr) -> String {
248    match expr {
249        Expr::Literal(value) => format_value(value),
250        Expr::VarRef(path) => format!("(varref {})", format_varpath(path)),
251        Expr::Interpolated(parts) => {
252            let parts_str: Vec<String> = parts
253                .iter()
254                .map(format_string_part)
255                .collect();
256            format!("(interpolated {})", parts_str.join(" "))
257        }
258        Expr::HereDocBody { parts, strip_tabs } => {
259            let parts_str: Vec<String> = parts
260                .iter()
261                .map(|sp| format_string_part(&sp.part))
262                .collect();
263            format!(
264                "(heredoc-body strip-tabs={} {})",
265                strip_tabs,
266                parts_str.join(" ")
267            )
268        }
269        Expr::BinaryOp { left, op, right } => {
270            let op_str = match op {
271                BinaryOp::And => "and",
272                BinaryOp::Or => "or",
273            };
274            format!("({} {} {})", op_str, format_expr(left), format_expr(right))
275        }
276        Expr::CommandSubst(pipeline) => {
277            format!("(cmdsubst {})", format_pipeline(pipeline))
278        }
279        Expr::Test(test_expr) => format!("(test {})", format_test_expr(test_expr)),
280        Expr::Positional(n) => format!("(positional {})", n),
281        Expr::AllArgs => "(all-args)".to_string(),
282        Expr::ArgCount => "(arg-count)".to_string(),
283        Expr::VarLength(name) => format!("(var-length {})", name),
284        Expr::VarWithDefault { name, default } => {
285            let default_parts: Vec<String> = default.iter().map(format_string_part).collect();
286            format!("(var-default {} ({}))", name, default_parts.join(" "))
287        }
288        Expr::Arithmetic(expr_str) => format!("(arithmetic \"{}\")", expr_str),
289        Expr::Command(cmd) => format_command(cmd),
290        Expr::LastExitCode => "(last-exit-code)".to_string(),
291        Expr::CurrentPid => "(current-pid)".to_string(),
292        Expr::GlobPattern(s) => format!("(glob \"{}\")", s),
293    }
294}
295
296/// Format a test expression as an S-expression.
297pub fn format_test_expr(test: &TestExpr) -> String {
298    match test {
299        TestExpr::FileTest { op, path } => {
300            let op_str = match op {
301                FileTestOp::Exists => "-e",
302                FileTestOp::IsFile => "-f",
303                FileTestOp::IsDir => "-d",
304                FileTestOp::Readable => "-r",
305                FileTestOp::Writable => "-w",
306                FileTestOp::Executable => "-x",
307            };
308            format!("(file {} {})", op_str, format_expr(path))
309        }
310        TestExpr::StringTest { op, value } => {
311            let op_str = match op {
312                StringTestOp::IsEmpty => "-z",
313                StringTestOp::IsNonEmpty => "-n",
314            };
315            format!("(string {} {})", op_str, format_expr(value))
316        }
317        TestExpr::Comparison { left, op, right } => {
318            let op_str = match op {
319                TestCmpOp::Eq => "==",
320                TestCmpOp::NotEq => "!=",
321                TestCmpOp::Match => "=~",
322                TestCmpOp::NotMatch => "!~",
323                TestCmpOp::Gt => ">",
324                TestCmpOp::Lt => "<",
325                TestCmpOp::GtEq => ">=",
326                TestCmpOp::LtEq => "<=",
327                TestCmpOp::NumEq => "-eq",
328                TestCmpOp::NumNotEq => "-ne",
329                TestCmpOp::NumGt => "-gt",
330                TestCmpOp::NumLt => "-lt",
331                TestCmpOp::NumGtEq => "-ge",
332                TestCmpOp::NumLtEq => "-le",
333            };
334            format!(
335                "(cmp {} {} {})",
336                op_str,
337                format_expr(left),
338                format_expr(right)
339            )
340        }
341        TestExpr::And { left, right } => {
342            format!("(and {} {})", format_test_expr(left), format_test_expr(right))
343        }
344        TestExpr::Or { left, right } => {
345            format!("(or {} {})", format_test_expr(left), format_test_expr(right))
346        }
347        TestExpr::Not { expr } => {
348            format!("(not {})", format_test_expr(expr))
349        }
350    }
351}
352
353/// Format a StringPart as an S-expression.
354fn format_string_part(part: &StringPart) -> String {
355    match part {
356        StringPart::Literal(s) => format!("\"{}\"", escape_for_display(s)),
357        StringPart::Var(path) => format!("(varref {})", format_varpath(path)),
358        StringPart::VarWithDefault { name, default } => {
359            let default_parts: Vec<String> = default.iter().map(format_string_part).collect();
360            format!("(vardefault {} ({}))", name, default_parts.join(" "))
361        }
362        StringPart::VarLength(name) => format!("(varlength {})", name),
363        StringPart::Positional(n) => format!("(positional {})", n),
364        StringPart::AllArgs => "(allargs)".to_string(),
365        StringPart::ArgCount => "(argcount)".to_string(),
366        StringPart::Arithmetic(expr) => format!("(arith \"{}\")", expr),
367        StringPart::CommandSubst(pipeline) => format!("(cmdsubst {})", format_pipeline(pipeline)),
368        StringPart::LastExitCode => "(last-exit-code)".to_string(),
369        StringPart::CurrentPid => "(current-pid)".to_string(),
370    }
371}
372
/// Escape control characters for display in test output.
///
/// Newline, tab and carriage return become the two-character sequences
/// `\n`, `\t` and `\r`. Every other character is copied unchanged, in a
/// single pass over the input.
///
/// NOTE(review): backslashes and double quotes are not escaped. A literal
/// backslash followed by `n` therefore renders the same as a real newline,
/// and an embedded `"` is emitted raw inside quoted output. Changing this
/// would alter existing snapshot text, so it is left as-is.
fn escape_for_display(s: &str) -> String {
    s.chars()
        .fold(String::with_capacity(s.len()), |mut out, ch| {
            match ch {
                '\n' => out.push_str("\\n"),
                '\t' => out.push_str("\\t"),
                '\r' => out.push_str("\\r"),
                other => out.push(other),
            }
            out
        })
}
379
380/// Format a value as an S-expression.
381pub fn format_value(value: &Value) -> String {
382    match value {
383        Value::Null => "(null)".to_string(),
384        Value::Bool(b) => format!("(bool {})", b),
385        Value::Int(n) => format!("(int {})", n),
386        Value::Float(f) => format!("(float {})", f),
387        Value::String(s) => format!("(string \"{}\")", escape_for_display(s)),
388        Value::Json(json) => format!("(json {})", json),
389        Value::Blob(blob) => format!("(blob id={} size={} type={})", blob.id, blob.size, blob.content_type),
390    }
391}
392
393/// Format a variable path as an S-expression.
394pub fn format_varpath(path: &VarPath) -> String {
395    path.segments
396        .iter()
397        .map(|seg| match seg {
398            VarSegment::Field(name) => name.clone(),
399        })
400        .collect::<Vec<_>>()
401        .join(".")
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn format_simple_int() {
        let rendered = format_value(&Value::Int(42));
        assert_eq!(rendered, "(int 42)");
    }

    #[test]
    fn format_simple_string() {
        let value = Value::String(String::from("hello"));
        assert_eq!(format_value(&value), "(string \"hello\")");
    }

    #[test]
    fn format_varpath_simple() {
        assert_eq!(format_varpath(&VarPath::simple("X")), "X");
    }

    #[test]
    fn format_varpath_nested() {
        let segments = ["VAR", "field"]
            .iter()
            .map(|name| VarSegment::Field(name.to_string()))
            .collect();
        let path = VarPath { segments };
        assert_eq!(format_varpath(&path), "VAR.field");
    }
}