use std::sync::Arc;
use mago_span::HasSpan;
use mago_syntax::ast::*;
use crate::parser::with_parsed_program;
use crate::types::ClassInfo;
use crate::util::{resolve_class_keyword, short_name};
use crate::completion::conditional_resolution::extract_class_string_from_expr;
use crate::completion::resolver::{Loaders, VarResolutionCtx};
pub(in crate::completion) fn resolve_class_string_targets(
var_name: &str,
current_class: &ClassInfo,
all_classes: &[Arc<ClassInfo>],
content: &str,
cursor_offset: u32,
class_loader: &dyn Fn(&str) -> Option<Arc<ClassInfo>>,
) -> Vec<ClassInfo> {
with_parsed_program(
content,
"resolve_class_string_targets",
|program, _content| {
let ctx = VarResolutionCtx {
var_name,
current_class,
all_classes,
content,
cursor_offset,
class_loader,
loaders: Loaders::default(),
resolved_class_cache: None,
enclosing_return_type: None,
branch_aware: false,
};
resolve_class_string_in_statements(program.statements.iter(), &ctx)
},
)
}
fn resolve_class_string_in_statements<'b>(
statements: impl Iterator<Item = &'b Statement<'b>>,
ctx: &VarResolutionCtx<'_>,
) -> Vec<ClassInfo> {
let stmts: Vec<&Statement> = statements.collect();
for &stmt in &stmts {
match stmt {
Statement::Class(class) => {
let start = class.left_brace.start.offset;
let end = class.right_brace.end.offset;
if ctx.cursor_offset >= start && ctx.cursor_offset <= end {
return resolve_class_string_in_members(class.members.iter(), ctx);
}
}
Statement::Interface(iface) => {
let start = iface.left_brace.start.offset;
let end = iface.right_brace.end.offset;
if ctx.cursor_offset >= start && ctx.cursor_offset <= end {
return resolve_class_string_in_members(iface.members.iter(), ctx);
}
}
Statement::Enum(enum_def) => {
let start = enum_def.left_brace.start.offset;
let end = enum_def.right_brace.end.offset;
if ctx.cursor_offset >= start && ctx.cursor_offset <= end {
return resolve_class_string_in_members(enum_def.members.iter(), ctx);
}
}
Statement::Trait(trait_def) => {
let start = trait_def.left_brace.start.offset;
let end = trait_def.right_brace.end.offset;
if ctx.cursor_offset >= start && ctx.cursor_offset <= end {
return resolve_class_string_in_members(trait_def.members.iter(), ctx);
}
}
Statement::Namespace(ns) => {
let results = resolve_class_string_in_statements(ns.statements().iter(), ctx);
if !results.is_empty() {
return results;
}
}
Statement::Function(func) => {
let body_start = func.body.left_brace.start.offset;
let body_end = func.body.right_brace.end.offset;
if ctx.cursor_offset >= body_start && ctx.cursor_offset <= body_end {
let mut results = Vec::new();
walk_class_string_assignments(func.body.statements.iter(), ctx, &mut results);
return results;
}
}
_ => {}
}
}
let mut results = Vec::new();
walk_class_string_assignments(stmts.into_iter(), ctx, &mut results);
results
}
fn resolve_class_string_in_members<'b>(
members: impl Iterator<Item = &'b ClassLikeMember<'b>>,
ctx: &VarResolutionCtx<'_>,
) -> Vec<ClassInfo> {
for member in members {
if let ClassLikeMember::Method(method) = member {
let body = match &method.body {
MethodBody::Concrete(body) => body,
_ => continue,
};
let start = body.left_brace.start.offset;
let end = body.right_brace.end.offset;
if ctx.cursor_offset >= start && ctx.cursor_offset <= end {
let mut results = Vec::new();
walk_class_string_assignments(body.statements.iter(), ctx, &mut results);
return results;
}
}
}
vec![]
}
fn walk_class_string_assignments<'b>(
statements: impl Iterator<Item = &'b Statement<'b>>,
ctx: &VarResolutionCtx<'_>,
results: &mut Vec<ClassInfo>,
) {
for stmt in statements {
if stmt.span().start.offset >= ctx.cursor_offset {
continue;
}
if let Statement::Expression(expr_stmt) = stmt {
check_class_string_assignment(expr_stmt.expression, ctx, results);
}
}
}
fn check_class_string_assignment(
expr: &Expression<'_>,
ctx: &VarResolutionCtx<'_>,
results: &mut Vec<ClassInfo>,
) {
let Expression::Assignment(assignment) = expr else {
return;
};
if !assignment.operator.is_assign() {
return;
}
let lhs_name = match assignment.lhs {
Expression::Variable(Variable::Direct(dv)) => dv.name.to_string(),
_ => return,
};
if lhs_name != ctx.var_name {
return;
}
let class_names = extract_class_string_names(assignment.rhs);
results.clear();
for name in class_names {
let resolved_name =
if let Some(resolved) = resolve_class_keyword(&name, Some(ctx.current_class)) {
resolved
} else {
name
};
let lookup = short_name(&resolved_name);
if let Some(cls) = ctx.all_classes.iter().find(|c| c.name == lookup) {
ClassInfo::push_unique(results, ClassInfo::clone(cls));
} else if let Some(cls) = (ctx.class_loader)(&resolved_name) {
ClassInfo::push_unique(results, Arc::unwrap_or_clone(cls));
}
}
}
fn extract_class_string_names(expr: &Expression<'_>) -> Vec<String> {
if let Some(name) = extract_class_string_from_expr(expr) {
return vec![name];
}
match expr {
Expression::Parenthesized(p) => extract_class_string_names(p.expression),
Expression::Match(match_expr) => {
let mut names = Vec::new();
for arm in match_expr.arms.iter() {
names.extend(extract_class_string_names(arm.expression()));
}
names
}
Expression::Conditional(cond) => {
let mut names = Vec::new();
let then_expr = cond.then.unwrap_or(cond.condition);
names.extend(extract_class_string_names(then_expr));
names.extend(extract_class_string_names(cond.r#else));
names
}
Expression::Binary(binary) if binary.operator.is_null_coalesce() => {
let mut names = Vec::new();
names.extend(extract_class_string_names(binary.lhs));
names.extend(extract_class_string_names(binary.rhs));
names
}
_ => vec![],
}
}