use std::collections::HashMap;
use crate::ast::{
Annotation, AnnotationHandler, AnnotationHandlerType, Expr, Item, MethodDef, Program, Span,
Statement, TypeName,
};
/// Returns a clone of `program` with every annotation-generated `extend`
/// appended to the end of its top-level items.
///
/// The generated extends come from [`collect_generated_annotation_extends`].
/// Each one is wrapped in [`Item::Extend`] with `Span::DUMMY`, because it has
/// no location in the original source text.
pub fn augment_program_with_generated_extends(program: &Program) -> Program {
    let generated_items = collect_generated_annotation_extends(program)
        .into_iter()
        .map(|extend_stmt| Item::Extend(extend_stmt, Span::DUMMY));

    let mut result = program.clone();
    result.items.extend(generated_items);
    result
}
pub fn collect_generated_annotation_extends(program: &Program) -> Vec<crate::ast::ExtendStatement> {
let mut comptime_handlers: HashMap<String, Vec<AnnotationHandler>> = HashMap::new();
for item in &program.items {
if let Item::AnnotationDef(ann_def, _) = item {
let handlers: Vec<_> = ann_def
.handlers
.iter()
.filter(|h| {
matches!(
h.handler_type,
AnnotationHandlerType::ComptimePre | AnnotationHandlerType::ComptimePost
)
})
.cloned()
.collect();
if !handlers.is_empty() {
comptime_handlers.insert(ann_def.name.clone(), handlers);
}
}
}
if comptime_handlers.is_empty() {
return Vec::new();
}
let mut methods_by_type: HashMap<String, Vec<MethodDef>> = HashMap::new();
collect_generated_annotation_extends_from_items(
&program.items,
&comptime_handlers,
&mut methods_by_type,
);
methods_by_type
.into_iter()
.map(|(type_name, methods)| crate::ast::ExtendStatement {
type_name: TypeName::Simple(type_name.into()),
methods,
})
.collect()
}
/// Recursively walks `items`, recording generated extend methods for every
/// annotated struct type or function.
///
/// Modules are descended into. A module's own annotations are not inspected;
/// only the items it contains are. All other item kinds are skipped.
fn collect_generated_annotation_extends_from_items(
    items: &[Item],
    comptime_handlers: &HashMap<String, Vec<AnnotationHandler>>,
    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
) {
    items.iter().for_each(|item| match item {
        Item::Module(nested_module, _) => {
            collect_generated_annotation_extends_from_items(
                &nested_module.items,
                comptime_handlers,
                methods_by_type,
            );
        }
        Item::Function(function, _) => {
            collect_annotation_methods_for_target(
                &function.annotations,
                &function.name,
                comptime_handlers,
                methods_by_type,
            );
        }
        Item::StructType(struct_type, _) => {
            collect_annotation_methods_for_target(
                &struct_type.annotations,
                &struct_type.name,
                comptime_handlers,
                methods_by_type,
            );
        }
        _ => {}
    });
}
/// Runs the extend collector over every comptime handler attached to
/// `annotations`, treating `target_name` as the name of the annotated item.
///
/// Annotations with no entry in `comptime_handlers` are ignored. Annotations
/// are visited in order, and each annotation's handlers in their stored order.
fn collect_annotation_methods_for_target(
    annotations: &[Annotation],
    target_name: &str,
    comptime_handlers: &HashMap<String, Vec<AnnotationHandler>>,
    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
) {
    annotations
        .iter()
        .filter_map(|annotation| comptime_handlers.get(&annotation.name))
        .flatten()
        .for_each(|handler| {
            collect_extend_methods_from_expr(&handler.body, target_name, methods_by_type);
        });
}
/// Searches a handler expression for `extend` statements and records their
/// methods in `methods_by_type`.
///
/// The search descends into:
/// - block statements and expressions;
/// - both branches of a conditional (the condition itself is not searched);
/// - match arm bodies;
/// - the target of an annotated expression;
/// - the statements of a `comptime` expression.
///
/// Every other expression kind is ignored.
fn collect_extend_methods_from_expr(
    expr: &Expr,
    target_name: &str,
    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
) {
    match expr {
        Expr::Comptime(comptime_stmts, _) => {
            for statement in comptime_stmts {
                collect_extend_methods_from_stmt(statement, target_name, methods_by_type);
            }
        }
        Expr::Annotated {
            target: annotated_expr,
            ..
        } => {
            collect_extend_methods_from_expr(annotated_expr, target_name, methods_by_type);
        }
        Expr::Match(match_node, _) => {
            for match_arm in &match_node.arms {
                collect_extend_methods_from_expr(&match_arm.body, target_name, methods_by_type);
            }
        }
        Expr::Conditional {
            then_expr,
            else_expr,
            ..
        } => {
            collect_extend_methods_from_expr(then_expr, target_name, methods_by_type);
            if let Some(alternative) = else_expr {
                collect_extend_methods_from_expr(alternative, target_name, methods_by_type);
            }
        }
        Expr::Block(block_node, _) => {
            for block_item in &block_node.items {
                if let crate::ast::BlockItem::Statement(statement) = block_item {
                    collect_extend_methods_from_stmt(statement, target_name, methods_by_type);
                } else if let crate::ast::BlockItem::Expression(nested) = block_item {
                    collect_extend_methods_from_expr(nested, target_name, methods_by_type);
                }
            }
        }
        _ => {}
    }
}
/// Records methods from `extend` statements found in `stmt`, recursing into
/// `if` / `for` / `while` bodies and expression statements.
///
/// The literal type name `target` (simple or generic) is replaced by
/// `target_name`; any other type name is used as written, and generic
/// arguments are dropped. A method whose name is already recorded for the
/// resolved type is skipped, so the first definition wins.
///
/// NOTE(review): the placeholder is the hard-coded string `"target"`. It does
/// not depend on what the handler's parameter is actually named.
fn collect_extend_methods_from_stmt(
    stmt: &Statement,
    target_name: &str,
    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
) {
    match stmt {
        Statement::Expression(inner_expr, _) => {
            collect_extend_methods_from_expr(inner_expr, target_name, methods_by_type);
        }
        Statement::While(while_loop, _) => {
            for nested in &while_loop.body {
                collect_extend_methods_from_stmt(nested, target_name, methods_by_type);
            }
        }
        Statement::For(for_loop, _) => {
            for nested in &for_loop.body {
                collect_extend_methods_from_stmt(nested, target_name, methods_by_type);
            }
        }
        Statement::If(if_stmt, _) => {
            // The then-branch is visited first, then the optional else-branch.
            let else_branch = if_stmt.else_body.iter().flatten();
            for nested in if_stmt.then_body.iter().chain(else_branch) {
                collect_extend_methods_from_stmt(nested, target_name, methods_by_type);
            }
        }
        Statement::Extend(extend, _) => {
            let (written_name, rendered_name) = match &extend.type_name {
                TypeName::Simple(name) => (name.as_str(), name.to_string()),
                TypeName::Generic { name, .. } => (name.as_str(), name.to_string()),
            };
            let resolved_type = if written_name == "target" {
                target_name.to_string()
            } else {
                rendered_name
            };

            let bucket = methods_by_type.entry(resolved_type).or_default();
            for method in &extend.methods {
                let already_defined = bucket.iter().any(|existing| existing.name == method.name);
                if !already_defined {
                    bucket.push(method.clone());
                }
            }
        }
        _ => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_program;

    /// Parses `source` and returns the extends generated from its annotations.
    /// Panics with "parse" if `source` fails to parse.
    fn generate(source: &str) -> Vec<crate::ast::ExtendStatement> {
        let parsed = parse_program(source).expect("parse");
        collect_generated_annotation_extends(&parsed)
    }

    #[test]
    fn collects_extend_target_for_annotated_type() {
        let extends = generate(
            r#"
annotation add_sum() {
targets: [type]
comptime post(target, ctx) {
extend target {
method sum() { self.x + self.y }
}
}
}
@add_sum()
type Point { x: int, y: int }
"#,
        );
        assert_eq!(extends.len(), 1, "expected one generated extend");

        let point_extend = &extends[0];
        if let TypeName::Simple(name) = &point_extend.type_name {
            assert_eq!(name, "Point");
        } else {
            panic!("expected simple type name, got {:?}", point_extend.type_name);
        }

        let has_sum = point_extend.methods.iter().any(|method| method.name == "sum");
        assert!(has_sum, "expected generated sum method");
    }

    #[test]
    fn does_not_generate_for_unused_annotation() {
        let extends = generate(
            r#"
annotation add_sum() {
targets: [type]
comptime post(target, ctx) {
extend target {
method sum() { self.x + self.y }
}
}
}
type Point { x: int, y: int }
"#,
        );
        assert!(
            extends.is_empty(),
            "unused annotations must not generate extends"
        );
    }
}