horkos 0.2.0

Cloud infrastructure language where insecure code won't compile
Documentation
//! Abstract Syntax Tree definitions for Horkos.
//!
//! The AST uses a tree structure where parents own their children.
//! Every node is wrapped in `Spanned<T>` to track source locations.

mod span;

pub use span::{Span, Spanned};

use serde::{Deserialize, Serialize};

/// A complete Horkos program (one source file).
///
/// This is the root of the AST and the input type of [`find_unsafe_blocks`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Program {
    /// The file's top-level statements, each paired with its source span.
    pub statements: Vec<Spanned<Statement>>,
}

/// Top-level statements.
///
/// `Module` and `Unsafe` hold nested statement lists of their own, so
/// statements form a tree rather than a flat sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Statement {
    /// `import "path" as alias`
    Import(ImportStmt),
    /// `val name: Type = expr`
    ValDecl(ValDecl),
    /// `module name { ... }`
    Module(ModuleDecl),
    /// `assert(condition, "message")`
    Assert(AssertStmt),
    /// `hcl("reason") { ... }`
    HclBlock(HclBlock),
    /// `unsafe("reason") { ... }`
    Unsafe(UnsafeStmt),
}

/// Import statement: `import "legacy.tf" as legacy`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImportStmt {
    /// The path being imported (the quoted string in the example above)
    pub path: Spanned<String>,
    /// Optional alias: `as name`; `None` when the import has no alias
    pub alias: Option<Spanned<String>>,
}

/// Value declaration: `val bucket: Bucket = S3.createBucket("data")`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValDecl {
    /// The name being bound (`bucket` in the example above)
    pub name: Spanned<String>,
    /// Optional type annotation (`Bucket` in the example above); `None` when omitted
    pub type_ann: Option<Spanned<TypeExpr>>,
    /// The value expression on the right-hand side of `=`
    pub value: Spanned<Expr>,
}

/// Module declaration for organizing code: `module name { ... }`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModuleDecl {
    /// Module name
    pub name: Spanned<String>,
    /// Module body: the statements nested inside the module
    pub body: Vec<Spanned<Statement>>,
}

/// Assert statement: `assert(condition, "error message")`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertStmt {
    /// The condition that must be true (first argument)
    pub condition: Spanned<Expr>,
    /// The error message if assertion fails (second argument, a plain string)
    pub message: Spanned<String>,
}

/// Inline HCL block: `hcl("reason") { resource "aws_s3_bucket" "b" { ... } }`
///
/// The block body is stored as an unparsed string. Note that
/// [`find_unsafe_blocks`] does not report HCL blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HclBlock {
    /// The reason for using raw HCL
    pub reason: Spanned<String>,
    /// The raw HCL content
    pub content: Spanned<String>,
}

/// Unsafe statement: `unsafe("reason") { ... }`
///
/// This is the statement form; the expression form is `Expr::Unsafe`.
/// [`find_unsafe_blocks`] reports one location per unsafe statement and then
/// continues searching inside its body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnsafeStmt {
    /// The reason for the unsafe block
    pub reason: Spanned<String>,
    /// The body of the unsafe block: the statements nested inside it
    pub body: Vec<Spanned<Statement>>,
}

/// Type expressions.
///
/// Type arguments are themselves type expressions, so generic types can nest.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TypeExpr {
    /// Simple named type: `Bucket`, `String`
    Named(String),
    /// Generic type: `Unverified<Bucket>`, `Map<String, Subnet>`
    Generic {
        /// The generic type's own name (`Map` in `Map<String, Subnet>`)
        name: String,
        /// The type arguments between the angle brackets
        args: Vec<Spanned<TypeExpr>>,
    },
}

/// Expressions.
///
/// Child expressions are stored as `Box<Spanned<Expr>>` (or vectors of
/// spanned expressions), so every sub-expression carries its own span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Expr {
    /// Literal values
    Literal(Literal),

    /// Variable reference: `bucket`, `vpc`
    Identifier(String),

    /// Member access: `S3.createBucket`, `legacy.bucket`
    MemberAccess {
        /// The expression to the left of the dot
        object: Box<Spanned<Expr>>,
        /// The member name to the right of the dot
        field: Spanned<String>,
    },

    /// Function call: `S3.createBucket("name", publicAccess: true)`
    FuncCall {
        /// The expression being called
        callee: Box<Spanned<Expr>>,
        /// The arguments; each may be positional or named (see `Arg`)
        args: Vec<Arg>,
    },

    /// Lambda expression: `x => x + 1` or `(a, b) => a + b`
    Lambda {
        /// Parameter names
        params: Vec<Spanned<String>>,
        /// The expression to the right of `=>`
        body: Box<Spanned<Expr>>,
    },

    /// Binary operation: `a + b`, `x == y`, `p && q`
    Binary {
        /// Left-hand operand
        left: Box<Spanned<Expr>>,
        /// The operator
        op: BinaryOp,
        /// Right-hand operand
        right: Box<Spanned<Expr>>,
    },

    /// Unary operation: `!x`, `-n`
    Unary {
        /// The operator
        op: UnaryOp,
        /// The expression the operator applies to
        operand: Box<Spanned<Expr>>,
    },

    /// List literal: `["us-east-1a", "us-east-1b"]`
    List(Vec<Spanned<Expr>>),

    /// Record literal: `{ name: "web", port: 80 }`
    Record(Vec<RecordField>),

    /// Unsafe expression: `unsafe("reason") { expr }`
    /// Unwraps Unverified<T> → T with mandatory justification
    Unsafe {
        /// The justification string given in parentheses
        reason: Spanned<String>,
        /// The wrapped expression
        body: Box<Spanned<Expr>>,
    },

    /// If expression: `if condition then valueA else valueB`
    /// or block form: `if condition { ... } else { ... }`
    ///
    /// The `else` branch is not optional: `else_branch` is always present.
    If {
        /// The condition being tested
        condition: Box<Spanned<Expr>>,
        /// The branch following the condition (`valueA` above)
        then_branch: Box<Spanned<Expr>>,
        /// The branch following `else` (`valueB` above)
        else_branch: Box<Spanned<Expr>>,
    },
}

/// Binary operators.
///
/// Each variant's source symbol is given in the trailing comment and is also
/// returned by [`BinaryOp::name`]; relative binding strength is given by
/// [`BinaryOp::precedence`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOp {
    // Arithmetic
    Add, // +
    Sub, // -
    Mul, // *
    Div, // /

    // Comparison
    Eq,    // ==
    NotEq, // !=
    Lt,    // <
    LtEq,  // <=
    Gt,    // >
    GtEq,  // >=

    // Logical
    And, // &&
    Or,  // ||
}

impl BinaryOp {
    /// Precedence level (higher binds tighter)
    pub fn precedence(self) -> u8 {
        match self {
            BinaryOp::Or => 1,
            BinaryOp::And => 2,
            BinaryOp::Eq | BinaryOp::NotEq => 3,
            BinaryOp::Lt | BinaryOp::LtEq | BinaryOp::Gt | BinaryOp::GtEq => 4,
            BinaryOp::Add | BinaryOp::Sub => 5,
            BinaryOp::Mul | BinaryOp::Div => 6,
        }
    }

    /// Display name for error messages
    pub fn name(self) -> &'static str {
        match self {
            BinaryOp::Add => "+",
            BinaryOp::Sub => "-",
            BinaryOp::Mul => "*",
            BinaryOp::Div => "/",
            BinaryOp::Eq => "==",
            BinaryOp::NotEq => "!=",
            BinaryOp::Lt => "<",
            BinaryOp::LtEq => "<=",
            BinaryOp::Gt => ">",
            BinaryOp::GtEq => ">=",
            BinaryOp::And => "&&",
            BinaryOp::Or => "||",
        }
    }
}

/// Unary operators.
///
/// Each variant's source symbol is given in the trailing comment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOp {
    Not, // !
    Neg, // -
}

/// Function argument (positional or named).
///
/// In `S3.createBucket("name", publicAccess: true)` the first argument is
/// positional and the second is named.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Arg {
    /// If Some, this is a named argument: `name: value`; `None` means positional
    pub name: Option<Spanned<String>>,
    /// The argument value
    pub value: Spanned<Expr>,
}

/// Record field: `name: value`
///
/// One entry of a record literal such as `{ name: "web", port: 80 }`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecordField {
    /// The field name to the left of the colon
    pub name: Spanned<String>,
    /// The field value to the right of the colon
    pub value: Spanned<Expr>,
}

/// Literal values.
///
/// Derives `PartialEq` but not `Eq`, because `Number` holds an `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Literal {
    /// A string literal
    String(String),
    /// A numeric literal; every number is stored as `f64`
    Number(f64),
    /// A boolean literal
    Bool(bool),
}

// === Helper functions ===

/// Information about an unsafe block found during analysis.
///
/// [`find_unsafe_blocks`] produces one of these for every unsafe statement
/// and every unsafe expression it encounters.
#[derive(Debug, Clone)]
pub struct UnsafeBlockLocation {
    /// Copied from the node's `span.start_line`
    /// (NOTE(review): whether this is 0- or 1-based depends on `Span` — confirm)
    pub line: usize,
    /// Copied from the node's `span.start_col` (same indexing caveat as `line`)
    pub column: usize,
    /// The justification string given to `unsafe("...")`
    pub reason: String,
}

/// Collect the location and reason of every unsafe block in `program`.
///
/// Both unsafe statements and unsafe expressions are reported, including ones
/// nested inside modules, other unsafe blocks, and sub-expressions. An
/// enclosing block is listed before the blocks nested inside it.
pub fn find_unsafe_blocks(program: &Program) -> Vec<UnsafeBlockLocation> {
    let mut found: Vec<UnsafeBlockLocation> = Vec::new();
    let top_level = &program.statements;
    find_unsafe_in_statements(top_level, &mut found);
    found
}

/// Walk `statements` and append every unsafe block found to `results`.
///
/// Imports and raw HCL blocks are skipped. For value declarations and asserts
/// only the value / condition expression is searched. Modules and unsafe
/// statements are descended into recursively; an unsafe statement is recorded
/// before its body is searched.
fn find_unsafe_in_statements(
    statements: &[Spanned<Statement>],
    results: &mut Vec<UnsafeBlockLocation>,
) {
    for entry in statements.iter() {
        match entry.node {
            // Nothing to search inside these.
            Statement::Import(_) | Statement::HclBlock(_) => {}

            // The statement itself is an unsafe block: record it, then descend.
            Statement::Unsafe(ref block) => {
                let location = UnsafeBlockLocation {
                    line: entry.span.start_line,
                    column: entry.span.start_col,
                    reason: block.reason.node.clone(),
                };
                results.push(location);
                find_unsafe_in_statements(&block.body, results);
            }

            // Nested statements.
            Statement::Module(ref module) => find_unsafe_in_statements(&module.body, results),

            // Statements that carry a single expression to search.
            Statement::ValDecl(ref decl) => find_unsafe_in_expr(&decl.value, results),
            Statement::Assert(ref assertion) => find_unsafe_in_expr(&assertion.condition, results),
        }
    }
}

/// Search `expr` and all of its sub-expressions for unsafe expressions,
/// appending one entry to `results` for each.
///
/// An unsafe expression is recorded before its body is searched, and children
/// are visited left to right, so `results` ends up in pre-order.
fn find_unsafe_in_expr(expr: &Spanned<Expr>, results: &mut Vec<UnsafeBlockLocation>) {
    // Record this node first; descending into its children happens below.
    if let Expr::Unsafe { reason, .. } = &expr.node {
        let location = UnsafeBlockLocation {
            line: expr.span.start_line,
            column: expr.span.start_col,
            reason: reason.node.clone(),
        };
        results.push(location);
    }

    match &expr.node {
        // Leaves: no sub-expressions.
        Expr::Literal(_) | Expr::Identifier(_) => {}

        // Exactly one sub-expression.
        Expr::Unsafe { body: child, .. }
        | Expr::Lambda { body: child, .. }
        | Expr::MemberAccess { object: child, .. }
        | Expr::Unary { operand: child, .. } => find_unsafe_in_expr(child, results),

        // A fixed number of sub-expressions, visited in field order.
        Expr::Binary { left, right, .. } => {
            for side in [left, right] {
                find_unsafe_in_expr(side, results);
            }
        }
        Expr::If {
            condition,
            then_branch,
            else_branch,
        } => {
            for part in [condition, then_branch, else_branch] {
                find_unsafe_in_expr(part, results);
            }
        }

        // A variable number of sub-expressions.
        Expr::FuncCall { callee, args } => {
            find_unsafe_in_expr(callee, results);
            args.iter()
                .for_each(|arg| find_unsafe_in_expr(&arg.value, results));
        }
        Expr::List(elements) => elements
            .iter()
            .for_each(|element| find_unsafe_in_expr(element, results)),
        Expr::Record(fields) => fields
            .iter()
            .for_each(|field| find_unsafe_in_expr(&field.value, results)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A program with no statements contains no unsafe blocks.
    #[test]
    fn test_find_unsafe_blocks_empty() {
        let empty = Program {
            statements: Vec::new(),
        };
        let found = find_unsafe_blocks(&empty);
        assert!(found.is_empty());
    }
}