// c_ast/ast.rs

1use crate::indented_text::IndentedText as I;
2use std::rc::Rc;
3
/// Rendering of an AST node whose C code fits on one line (expressions,
/// types, includes).
///
/// `Param` carries any extra context the rendering needs; it defaults to `()`.
trait SingleLineCode<Param = ()> {
    /// Produce the C source for `self` as a single `String`.
    fn to_code_line(&self, context: Param) -> String;
}
8
9/// Trait for turning Ast into code.
10trait MultiLineCode<Param = ()> {
11    fn to_code_i(&self, param: Param) -> I;
12
13    fn to_code(&self, param: Param) -> String {
14        self.to_code_i(param).to_string()
15    }
16}
17
/// Automatic implementations (provided by the `auto_implement` module):
/// - every `SingleLineCode` type also implements `MultiLineCode`;
/// - every `MultiLineCode` type is meant to implement `Display` as well.
///
/// `Display` cannot be provided by a blanket impl because of Rust's orphan
/// rules, so it is generated per type by the `impl_display_for_ast!` call
/// below; an AST type that should be printable must be listed there.
///
/// NOTE(review): for types whose rendering parameter is not `()` (e.g.
/// `TopLevelDeclaration`, `Function`, `TypeExpr`), the parameter value used by
/// the generated `Display` is decided inside `auto_implement` — presumably
/// `Default::default()`; confirm there.
mod auto_implement;
auto_implement::impl_display_for_ast![
    Program,
    Include,
    TopLevelDeclaration,
    Function,
    Statement,
    Expr,
    TypeExpr,
];
33
34/// Represents a c program.
35///
36/// # Example
37/// ```
38/// use c_ast::*;
39///
40/// let p = Program {
41///     includes: vec![Include::Arrow("stdio.h".to_string())],
42///     declarations: vec![
43///         TopLevelDeclaration::Function(Function {
44///             return_type: tvar("int"),
45///             name: "main".to_string(),
46///             parameters: vec![],
47///             body: Some(vec![
48///                 return_(0.literal()),
49///             ]),
50///         }),
51///     ],
52/// };
53/// let code = r#"
54/// #include <stdio.h>
55///
56/// int main();
57///
58/// int main() {
59///     return 0;
60/// }
61/// "#;
62///
63/// assert_eq!(p.to_string().trim(), code.trim());
64/// ```
65#[derive(Debug, Clone, PartialEq, Hash)]
66pub struct Program {
67    pub includes: Vec<Include>,
68    // TODO: Seperate into functions and structs - and put structs first.
69    pub declarations: Vec<TopLevelDeclaration>,
70}
71
/// A C `#include` directive.
///
/// # Example
/// ```
/// use c_ast::Include;
///
/// assert_eq!(Include::Arrow("stdio.h".into()).to_string().trim(), "#include <stdio.h>");
/// assert_eq!(Include::Quote("stdio.h".into()).to_string().trim(), "#include \"stdio.h\"");
/// ```
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Include {
    /// System-style include: `#include <path>`.
    Arrow(String),
    /// Local-style include: `#include "path"`.
    Quote(String),
}
88
89impl SingleLineCode for Include {
90    fn to_code_line(&self, _: ()) -> String {
91        match self {
92            Include::Arrow(s) => format!("#include <{}>", s),
93            Include::Quote(s) => format!("#include \"{}\"", s),
94        }
95    }
96}
97
98/// A top level declaration. Can only be directly inside a program's
99/// declarations field.
100///
101/// # Example
102/// ```
103/// use c_ast::TopLevelDeclaration;
104///
105///
106#[derive(Debug, Clone, PartialEq, Hash)]
107pub enum TopLevelDeclaration {
108    Function(Function),
109    Struct(String, Option<Vec<(PTypeExpr, String)>>),
110    Var(PTypeExpr, String, Option<Expr>),
111}
112
113/// A c function (potentially with a body).
114///
115/// # Example
116/// ```
117/// use c_ast::*;
118///
119/// // float f(int x) {
120/// //     return x;
121/// // }
122/// let f = Function {
123///     return_type: tvar("float"),
124///     name: "main".to_string(),
125///     parameters: vec![],
126///     body: Some(vec![
127///         return_(0_i32.literal())
128///     ]),
129/// };
130#[derive(Debug, Clone, PartialEq, Hash)]
131pub struct Function {
132    pub return_type: PTypeExpr,
133    pub name: String,
134    pub parameters: Vec<(PTypeExpr, String)>,
135    pub body: Option<Block>,
136}
137
/// A C type, used on its own (casts, `sizeof`, parameter lists) or together
/// with a name in a declaration.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TypeExpr {
    /// A type name emitted verbatim, e.g. `int`.
    Var(String),
    /// `struct <tag>`, referring to a struct by its tag.
    StructVar(String),
    /// Pointer to the inner type.
    Ptr(PTypeExpr),
    /// An anonymous struct with the given `(type, name)` fields.
    Struct(Vec<(PTypeExpr, String)>),
    /// Pointer to a function with the given return type and parameter types.
    FunctionPtr(PTypeExpr, Vec<PTypeExpr>),
    /// Array of the inner type, without a length (`T[]`).
    Array(PTypeExpr),
}

/// Shared, reference-counted handle to a `TypeExpr`; cloning is cheap.
pub type PTypeExpr = Rc<TypeExpr>;
149
150pub type Block = Vec<Statement>;
151
152#[derive(Debug, Clone, PartialEq, Hash)]
153pub enum Statement {
154    Expr(Expr),
155    Assign(Expr, Expr),
156    Return(Expr),
157    If(Expr, Block, Option<Block>),
158    While(Expr, Block),
159    For(Expr, Expr, Expr, Block),
160    Declaration {
161        type_expression: PTypeExpr,
162        name: String,
163        initializer: Option<Expr>,
164    },
165}
166
167#[derive(Debug, Clone, PartialEq, Hash)]
168pub enum Expr {
169    Var(String),
170    Call(Box<Expr>, Vec<Expr>),
171    Index(Box<Expr>, Box<Expr>),
172    Int(i32),
173    Str(String),
174    Char(char),
175    Unary(UnaryOp, Box<Expr>),
176    Binary(BinaryOp, Box<Expr>, Box<Expr>),
177    Cast(PTypeExpr, Box<Expr>),
178    SizeOf(PTypeExpr),
179    Arrow(Box<Expr>, String),
180    Dot(Box<Expr>, String),
181    Inc(Box<Expr>),
182    Dec(Box<Expr>),
183}
184
/// A prefix unary operator.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum UnaryOp {
    /// Arithmetic negation, `-x`.
    Neg,
}

/// An infix binary operator.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `==`
    Eq,
    /// `!=`
    Neq,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `&&`
    And,
    /// `||`
    Or,
}
205
206fn seperate_with_newlines(i: impl IntoIterator<Item = I>) -> impl Iterator<Item = I> {
207    i.into_iter().flat_map(|i| [i, I::line("")])
208}
209
210impl MultiLineCode for Program {
211    fn to_code_i(&self, (): ()) -> I {
212        let mut structs = (vec![], vec![]);
213        let mut functions = (vec![], vec![]);
214
215        for decl in &self.declarations {
216            let (start_vec, end_vec) = match decl {
217                TopLevelDeclaration::Struct(..) => &mut structs,
218                _ => &mut functions,
219            };
220            let (start, end) = (
221                decl.to_code_i(PrototypeOrImplementation::Prototype),
222                decl.to_code_i(PrototypeOrImplementation::Implementation),
223            );
224            start_vec.push(start);
225            end_vec.push(end);
226        }
227
228        I::many([
229            I::lines(self.includes.iter().map(|include| match include {
230                Include::Arrow(path) => format!("#include <{path}>"),
231                Include::Quote(path) => format!("#include \"{path}\""),
232            })),
233            I::line(""),
234            I::many(seperate_with_newlines(structs.0)),
235            I::many(seperate_with_newlines(functions.0)),
236            I::many(seperate_with_newlines(structs.1)),
237            I::many(seperate_with_newlines(functions.1)),
238        ])
239    }
240}
241
/// Which half of a top-level declaration to render.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum PrototypeOrImplementation {
    /// The forward declaration: a function signature followed by `;`, a struct
    /// `typedef`, or a variable declaration without its initializer.
    Prototype,
    /// The definition: a function with its body, a struct definition, or an
    /// initialized variable. This is the default.
    #[default]
    Implementation,
}
248
249impl MultiLineCode<PrototypeOrImplementation> for TopLevelDeclaration {
250    fn to_code_i(&self, p_or_i: PrototypeOrImplementation) -> I {
251        use PrototypeOrImplementation::*;
252        use TopLevelDeclaration::*;
253        match self {
254            Function(function) => function.to_code_i(p_or_i),
255            Struct(name, fields) => match p_or_i {
256                Prototype => I::line(format!("typedef struct {} {};", name, name)),
257                Implementation => {
258                    if let Some(fields) = fields {
259                        I::line(struct_code(Some(name), fields) + ";")
260                    } else {
261                        I::empty()
262                    }
263                }
264            },
265            Var(typ, name, expr) => match (p_or_i, expr) {
266                (Implementation, None) => I::empty(),
267                (Implementation, Some(expr)) => {
268                    I::line(typ.to_code_line(Some(name)) + " = " + &expr.to_code_line(()) + ";")
269                }
270                (Prototype, _) => I::line(typ.to_code_line(Some(name)) + ";"),
271            },
272        }
273    }
274}
275
276impl MultiLineCode<PrototypeOrImplementation> for Function {
277    fn to_code_i(&self, p_or_i: PrototypeOrImplementation) -> I {
278        let body = self
279            .body
280            .as_ref()
281            .filter(|_| p_or_i == PrototypeOrImplementation::Implementation);
282
283        I::line(format!(
284            "{} {}({})",
285            self.return_type.to_code_line(None),
286            &self.name,
287            &self
288                .parameters
289                .iter()
290                .map(|(typ, name)| typ.to_code_line(Some(name)))
291                .collect::<Vec<_>>()
292                .join(", "),
293        ))
294        .with_last_line(if body.is_some() { " {" } else { ";" })
295        .then(if let &Some(body) = &body {
296            I::many([
297                I::indent_many(body.iter().map(|stmt| stmt.to_code_i(()))),
298                I::line("}"),
299            ])
300        } else {
301            I::empty()
302        })
303    }
304}
305
306fn block_code(x: &Block) -> I {
307    I::many_vec(x.iter().map(|s| s.to_code_i(())).collect())
308}
309
310impl MultiLineCode for Statement {
311    fn to_code_i(&self, (): ()) -> I {
312        match self {
313            Statement::Expr(expr) => I::line(expr.to_code_line(()) + ";"),
314            Statement::Return(expr) => I::line(format!("return {};", expr.to_code_line(()))),
315            Statement::If(cond, then, else_) => I::many([
316                I::line(format!("if ({}) {{", cond.to_code_line(()))),
317                I::indent(block_code(then)),
318                if let Some(else_) = else_ {
319                    I::many([I::line("} else {"), I::indent(block_code(else_))])
320                } else {
321                    I::Empty
322                },
323                I::line("}"),
324            ]),
325            Statement::While(cond, body) => I::many([
326                I::line(format!("while ({}) {{", cond.to_code_line(()))),
327                I::indent(block_code(body)),
328                I::line("}"),
329            ]),
330            Statement::For(start, cond, inc, body) => I::many([
331                I::line(format!(
332                    "for ({}; {}; {}) {{",
333                    start.to_code_line(()),
334                    cond.to_code_line(()),
335                    inc.to_code_line(()),
336                )),
337                I::indent(block_code(body)),
338                I::line("}"),
339            ]),
340            Statement::Declaration {
341                type_expression,
342                name,
343                initializer,
344            } => I::line(assignment_declaration_code(type_expression, name, initializer) + ";"),
345            Statement::Assign(left, right) => I::line(format!(
346                "{} = {};",
347                left.to_code_line(()),
348                right.to_code_line(())
349            )),
350        }
351    }
352}
353
354impl SingleLineCode for Expr {
355    fn to_code_line(&self, (): ()) -> String {
356        match self {
357            Expr::Var(name) => name.into(),
358            Expr::Int(value) => value.to_string(),
359            Expr::Str(string) => format!("\"{string}\""),
360            Expr::Char(ch) => format!("'{}'", ch),
361            Expr::Call(function, arguments) => {
362                let mut code = String::new();
363                code.push_str(function.to_code_line(()).as_str());
364                code.push('(');
365                code.push_str(
366                    arguments
367                        .iter()
368                        .map(|argument| argument.to_code_line(()))
369                        .collect::<Vec<String>>()
370                        .join(", ")
371                        .as_str(),
372                );
373                code.push(')');
374                code
375            }
376            Expr::Unary(op, expr) => {
377                let mut code = String::new();
378                code.push('(');
379                code.push(match op {
380                    UnaryOp::Neg => '-',
381                });
382                code.push_str(expr.to_code_line(()).as_str());
383                code.push(')');
384                code
385            }
386            Expr::Binary(op, lhs, rhs) => {
387                let mut code = String::new();
388                code.push('(');
389                code.push_str(lhs.to_code_line(()).as_str());
390                code.push_str(match op {
391                    BinaryOp::Add => " + ",
392                    BinaryOp::Sub => " - ",
393                    BinaryOp::Mul => " * ",
394                    BinaryOp::Div => " / ",
395                    BinaryOp::Eq => " == ",
396                    BinaryOp::Neq => " != ",
397                    BinaryOp::Lt => " < ",
398                    BinaryOp::Le => " <= ",
399                    BinaryOp::Gt => " > ",
400                    BinaryOp::Ge => " >= ",
401                    BinaryOp::And => " && ",
402                    BinaryOp::Or => " || ",
403                });
404                code.push_str(rhs.to_code_line(()).as_str());
405                code.push(')');
406                code
407            }
408            Expr::Cast(type_expr, expr) => {
409                let mut code = String::new();
410                code.push_str("((");
411                code.push_str(type_expr.to_code_line(None).as_str());
412                code.push(')');
413                code.push_str(expr.to_code_line(()).as_str());
414                code.push(')');
415                code
416            }
417            Expr::Arrow(e, name) => e.to_code_line(()) + "->" + name,
418            Expr::Dot(e, name) => e.to_code_line(()) + "." + name,
419            Expr::SizeOf(type_expr) => {
420                format!("sizeof({})", type_expr.to_code_line(None))
421            }
422            Expr::Inc(e) => e.to_code_line(()) + "++",
423            Expr::Dec(e) => e.to_code_line(()) + "--",
424            Expr::Index(e, index) => {
425                format!("{}[{}]", e.to_code_line(()), index.to_code_line(()))
426            }
427        }
428    }
429}
430
431fn struct_code(struct_type_name: Option<&str>, fields: &Vec<(PTypeExpr, String)>) -> String {
432    // Write to a buffer piece by piece.
433    let mut buf = String::new();
434
435    buf += "struct ";
436    if let Some(name) = struct_type_name {
437        buf += name;
438        buf += " ";
439    }
440    buf += "{ ";
441
442    for (field_typ, field_name) in fields {
443        buf += &field_typ.to_code(Some(field_name));
444        buf += "; ";
445    }
446
447    buf += "}";
448    buf
449}
450
451impl MultiLineCode<()> for Block {
452    fn to_code_i(&self, (): ()) -> I {
453        I::many(self.iter().map(|s| s.to_code_i(())))
454    }
455}
456
457impl SingleLineCode<Option<&str>> for TypeExpr {
458    /// Convert a type expression to a string. Does not take names into account.
459    fn to_code_line(&self, name: Option<&str>) -> String {
460        use TypeExpr::*;
461        let add_name = |s: String| {
462            if let Some(name) = name {
463                s + " " + name
464            } else {
465                s
466            }
467        };
468
469        match self {
470            Var(type_name) => add_name(type_name.into()),
471            StructVar(name) => add_name("struct ".to_string() + name.as_ref()),
472            Struct(fields) => add_name(struct_code(None, fields)),
473            FunctionPtr(ret, params) => {
474                let mut buf = String::new();
475                buf += &ret.to_code_line(None);
476                buf += " (*";
477                buf += name.unwrap_or("");
478                buf += ")(";
479                for (i, param) in params.iter().enumerate() {
480                    if i != 0 {
481                        buf += ", ";
482                    }
483                    buf += &param.to_code_line(None);
484                }
485                buf += ")";
486                buf
487            }
488            Ptr(typ) => {
489                if let Some(name) = name {
490                    typ.to_code_line(None) + " " + name + "*"
491                } else {
492                    typ.to_code_line(None) + "*"
493                }
494            }
495            Array(typ) => {
496                if let Some(name) = name {
497                    typ.to_code_line(None) + " " + name + "[]"
498                } else {
499                    typ.to_code_line(None) + "[]"
500                }
501            }
502        }
503    }
504
505    /*
506    /// Convert a type expression to a string, of a value with a name for a
507    /// typedef.
508    /// The difference is that in a typedef, a struct needs to get the name
509    /// while in a variable declaration, it should not.
510    ///
511    /// Example:
512    /// ```c
513    /// typedef struct foo { int bar; } foo;
514    /// struct { int bar; } my_foo; // No name after struct keyword!
515    /// ```
516    pub fn to_code_with_name_typedef(&self, name: &str) -> String {
517        use TypeExpr::*;
518        match self {
519            Struct(fields) => struct_code(Some(name), fields) + " " + name,
520            _ => self.to_code_with_name(name),
521        }
522    }
523    */
524}
525
526fn assignment_declaration_code(
527    type_expression: &TypeExpr,
528    name: &str,
529    rhs: &Option<Expr>,
530) -> String {
531    if let Some(rhs) = rhs {
532        format!(
533            "{} = {}",
534            type_expression.to_code_line(Some(name)),
535            rhs.to_code_line(())
536        )
537    } else {
538        type_expression.to_code(Some(name))
539    }
540}
541
542/*
543macro_rules! some_try {
544    ($x: expr) => {{
545        match $x {
546            Ok(res) => res,
547            Err(err) => return Some(Err(err)),
548        }
549    }};
550}
551*/