shape-ast 0.1.8

AST types and Pest grammar for the Shape programming language
Documentation
//! Comptime annotation expansion utilities.
//!
//! This transform extracts direct `extend ... { ... }` directives from
//! `annotation ... comptime pre/post(...) { ... }` handler bodies and materializes
//! them as synthetic top-level `Item::Extend` entries. It is intentionally
//! static/AST-driven and does not execute comptime code.

use std::collections::HashMap;

use crate::ast::{
    Annotation, AnnotationHandler, AnnotationHandlerType, Expr, Item, MethodDef, Program, Span,
    Statement, TypeName,
};

/// Produce a copy of `program` with the synthetic extends appended.
///
/// Every extend returned by [`collect_generated_annotation_extends`] is wrapped
/// in `Item::Extend` with `Span::DUMMY` and pushed onto the end of the cloned
/// program's item list. The input program itself is not modified.
pub fn augment_program_with_generated_extends(program: &Program) -> Program {
    let synthetic_items = collect_generated_annotation_extends(program)
        .into_iter()
        .map(|generated| Item::Extend(generated, Span::DUMMY));

    let mut result = program.clone();
    result.items.extend(synthetic_items);
    result
}

/// Collect synthetic extends generated by directly declared annotation comptime handlers.
///
/// The work happens in two passes:
///
/// 1. Every `Item::AnnotationDef` found directly in `program.items` is indexed by
///    annotation name, keeping only its `ComptimePre` / `ComptimePost` handlers.
///    If two definitions share a name, the later one replaces the earlier one.
/// 2. The program items (including items nested in modules) are scanned for
///    annotated targets, and the `extend` directives found in the matching
///    handler bodies are grouped by resolved type name.
///
/// Returns one `ExtendStatement` per resolved type name. The result is sorted by
/// type name so that the output does not depend on `HashMap` iteration order,
/// which is arbitrary and may differ between runs. An empty vector is returned
/// when the program declares no comptime handlers at all.
pub fn collect_generated_annotation_extends(program: &Program) -> Vec<crate::ast::ExtendStatement> {
    let mut comptime_handlers: HashMap<String, Vec<AnnotationHandler>> = HashMap::new();
    for item in &program.items {
        let Item::AnnotationDef(ann_def, _) = item else {
            continue;
        };
        let handlers: Vec<_> = ann_def
            .handlers
            .iter()
            .filter(|h| {
                matches!(
                    h.handler_type,
                    AnnotationHandlerType::ComptimePre | AnnotationHandlerType::ComptimePost
                )
            })
            .cloned()
            .collect();
        // Annotations without comptime handlers can never generate extends.
        if !handlers.is_empty() {
            comptime_handlers.insert(ann_def.name.clone(), handlers);
        }
    }

    if comptime_handlers.is_empty() {
        return Vec::new();
    }

    let mut methods_by_type: HashMap<String, Vec<MethodDef>> = HashMap::new();
    collect_generated_annotation_extends_from_items(
        &program.items,
        &comptime_handlers,
        &mut methods_by_type,
    );

    // Keys are unique, so an unstable sort is sufficient to get a stable,
    // reproducible ordering of the generated extends.
    let mut grouped: Vec<(String, Vec<MethodDef>)> = methods_by_type.into_iter().collect();
    grouped.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));

    grouped
        .into_iter()
        .map(|(type_name, methods)| crate::ast::ExtendStatement {
            type_name: TypeName::Simple(type_name.into()),
            methods,
        })
        .collect()
}

/// Scan `items` for annotated targets and record the methods their annotations generate.
///
/// Struct types and functions are treated as annotation targets, using their own
/// name as the target name. Modules are descended into recursively. Every other
/// item kind is skipped. Results accumulate in `methods_by_type`.
fn collect_generated_annotation_extends_from_items(
    items: &[Item],
    comptime_handlers: &HashMap<String, Vec<AnnotationHandler>>,
    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
) {
    items.iter().for_each(|entry| match entry {
        // Nested modules may contain further annotated targets.
        Item::Module(nested, _) => collect_generated_annotation_extends_from_items(
            &nested.items,
            comptime_handlers,
            methods_by_type,
        ),
        Item::Function(function, _) => collect_annotation_methods_for_target(
            &function.annotations,
            &function.name,
            comptime_handlers,
            methods_by_type,
        ),
        Item::StructType(record, _) => collect_annotation_methods_for_target(
            &record.annotations,
            &record.name,
            comptime_handlers,
            methods_by_type,
        ),
        _ => {}
    });
}

/// Apply every known comptime handler attached to one annotated target.
///
/// For each annotation in `annotations` whose name has an entry in
/// `comptime_handlers`, the body of each of its handlers is searched for
/// `extend` directives, with `target_name` standing in for the `target`
/// placeholder. Annotations without a registered handler are ignored.
fn collect_annotation_methods_for_target(
    annotations: &[Annotation],
    target_name: &str,
    comptime_handlers: &HashMap<String, Vec<AnnotationHandler>>,
    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
) {
    // Annotations in declaration order, each expanded to its handlers in order.
    let matching_handlers = annotations
        .iter()
        .filter_map(|applied| comptime_handlers.get(&applied.name))
        .flatten();

    for comptime_handler in matching_handlers {
        collect_extend_methods_from_expr(&comptime_handler.body, target_name, methods_by_type);
    }
}

/// Search an expression for `extend` directives and record their methods.
///
/// The walk descends into comptime blocks, the target of annotated expressions,
/// match arm bodies, both branches of conditionals, and the statements and
/// expressions of blocks. Conditions and match scrutinees are not searched, and
/// any other expression kind is ignored.
fn collect_extend_methods_from_expr(
    expr: &Expr,
    target_name: &str,
    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
) {
    match expr {
        Expr::Comptime(body, _) => {
            for nested_stmt in body {
                collect_extend_methods_from_stmt(nested_stmt, target_name, methods_by_type);
            }
        }
        Expr::Annotated { target, .. } => {
            collect_extend_methods_from_expr(target, target_name, methods_by_type);
        }
        Expr::Match(matched, _) => {
            for case in &matched.arms {
                collect_extend_methods_from_expr(&case.body, target_name, methods_by_type);
            }
        }
        Expr::Conditional {
            then_expr: taken,
            else_expr: fallback,
            ..
        } => {
            collect_extend_methods_from_expr(taken, target_name, methods_by_type);
            if let Some(otherwise) = fallback {
                collect_extend_methods_from_expr(otherwise, target_name, methods_by_type);
            }
        }
        Expr::Block(body, _) => {
            for entry in &body.items {
                match entry {
                    crate::ast::BlockItem::Expression(nested_expr) => {
                        collect_extend_methods_from_expr(nested_expr, target_name, methods_by_type);
                    }
                    crate::ast::BlockItem::Statement(nested_stmt) => {
                        collect_extend_methods_from_stmt(nested_stmt, target_name, methods_by_type);
                    }
                    // Other block item kinds are not searched.
                    _ => {}
                }
            }
        }
        _ => {}
    }
}

/// Search a statement for `extend` directives and record their methods.
///
/// An `extend` statement contributes its methods to the bucket of its resolved
/// type: the literal name `target` resolves to `target_name`, any other name is
/// used as written. Only the base name of a generic type name is used. A method
/// is skipped when the bucket already holds a method with the same name, so the
/// first definition seen wins.
///
/// `if`, `for`, and `while` bodies and expression statements are searched
/// recursively; every other statement kind is ignored.
fn collect_extend_methods_from_stmt(
    stmt: &Statement,
    target_name: &str,
    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
) {
    match stmt {
        Statement::Expression(inner, _) => {
            collect_extend_methods_from_expr(inner, target_name, methods_by_type);
        }
        Statement::While(looping, _) => {
            for nested in &looping.body {
                collect_extend_methods_from_stmt(nested, target_name, methods_by_type);
            }
        }
        Statement::For(looping, _) => {
            for nested in &looping.body {
                collect_extend_methods_from_stmt(nested, target_name, methods_by_type);
            }
        }
        Statement::If(branch, _) => {
            // Visit the `then` body first, followed by the `else` body when present.
            let else_stmts = branch.else_body.iter().flatten();
            for nested in branch.then_body.iter().chain(else_stmts) {
                collect_extend_methods_from_stmt(nested, target_name, methods_by_type);
            }
        }
        Statement::Extend(directive, _) => {
            // `target` is the placeholder for the annotated item's own name.
            let bucket_key = match &directive.type_name {
                TypeName::Simple(name) => {
                    if name.as_str() == "target" {
                        target_name.to_string()
                    } else {
                        name.to_string()
                    }
                }
                TypeName::Generic { name, .. } => {
                    if name.as_str() == "target" {
                        target_name.to_string()
                    } else {
                        name.to_string()
                    }
                }
            };
            let bucket = methods_by_type.entry(bucket_key).or_default();
            for candidate in directive.methods.iter() {
                let already_present = bucket.iter().any(|known| known.name == candidate.name);
                if already_present {
                    continue;
                }
                bucket.push(candidate.clone());
            }
        }
        _ => {}
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_program;

    /// Parse `source` and return the extends generated for it.
    fn generated_from(source: &str) -> Vec<crate::ast::ExtendStatement> {
        let program = parse_program(source).expect("parse");
        collect_generated_annotation_extends(&program)
    }

    /// An annotation applied to a type generates an extend for that type.
    #[test]
    fn collects_extend_target_for_annotated_type() {
        let generated = generated_from(
            r#"
annotation add_sum() {
    targets: [type]
    comptime post(target, ctx) {
        extend target {
            method sum() { self.x + self.y }
        }
    }
}
@add_sum()
type Point { x: int, y: int }
"#,
        );
        assert_eq!(generated.len(), 1, "expected one generated extend");

        let ext = &generated[0];
        // The `target` placeholder must have been replaced by the annotated type's name.
        if let TypeName::Simple(name) = &ext.type_name {
            assert_eq!(name, "Point");
        } else {
            panic!("expected simple type name, got {:?}", ext.type_name);
        }

        let has_sum = ext.methods.iter().any(|m| m.name == "sum");
        assert!(has_sum, "expected generated sum method");
    }

    /// A declared annotation that is never applied generates nothing.
    #[test]
    fn does_not_generate_for_unused_annotation() {
        let generated = generated_from(
            r#"
annotation add_sum() {
    targets: [type]
    comptime post(target, ctx) {
        extend target {
            method sum() { self.x + self.y }
        }
    }
}
type Point { x: int, y: int }
"#,
        );
        assert!(
            generated.is_empty(),
            "unused annotations must not generate extends"
        );
    }
}