use tree_sitter::{Node, Point};
use crate::cst::node_to_span;
use crate::file_analysis::{InferredType, ScopeId, Span};
use super::{connector_keyword_between, point_lt, raw_leading_op, raw_mid_op, Builder};
/// What a guard condition narrows: either a whole variable or a place
/// (an element-style expression) reached through one.
pub(super) enum NarrowSubject {
    /// A variable, named by what `crate::cst::canonical_var_name` returned.
    Variable(String),
    /// A place described by the pair `crate::cst::canonical_place_path` returned.
    Place {
        /// Canonical path of the place; narrowing witnesses are recorded under
        /// this name and assignments are matched against it.
        key: String,
        /// Name compared with the canonical names of `scalar` nodes when
        /// looking for uses that invalidate the narrowing.
        root: String,
    },
}
/// Classifies `node` as a narrowing subject.
///
/// A canonical variable name wins; otherwise the node is tried as a canonical
/// place path. Returns `None` when neither helper recognises the node.
fn narrow_subject_of(node: Node, src: &[u8]) -> Option<NarrowSubject> {
    match crate::cst::canonical_var_name(node, src) {
        Some(var) => Some(NarrowSubject::Variable(var)),
        None => crate::cst::canonical_place_path(node, src)
            .map(|(key, root)| NarrowSubject::Place { key, root }),
    }
}
/// The effect a recognised guard has on its subject's type.
#[derive(Clone)]
pub(super) enum NarrowOp {
    /// Narrow the subject to exactly this type.
    To(InferredType),
    /// Remove the optional layer from the subject's type. The type is looked up
    /// later, at `query_point`, once it is known.
    StripOptional { query_point: Point },
}
impl NarrowOp {
    /// Returns the narrowing implied when this operation's guard fails.
    ///
    /// Only `StripOptional` has one (the subject becomes `Undef`); a failed
    /// `To` check yields `None`.
    fn negated(&self) -> Option<NarrowOp> {
        if let NarrowOp::StripOptional { .. } = self {
            Some(NarrowOp::To(InferredType::Undef))
        } else {
            None
        }
    }
}
/// A narrowing fact recognised inside a guard condition.
///
/// `op` applies to `subject` in any region where the guard condition's truth
/// value equals `asserts_when_true`. In a region with the opposite truth value
/// the fact contributes `op.negated()`, but only when `negatable` is set.
pub(super) struct GuardFact {
    /// Variable or place the fact is about.
    subject: NarrowSubject,
    /// Narrowing that applies where the condition matches `asserts_when_true`.
    op: NarrowOp,
    /// Truth value of the guard condition under which `op` applies.
    asserts_when_true: bool,
    /// Whether the opposite truth value proves `op.negated()`.
    ///
    /// False for facts taken from a conjunction (a false `A && B` says nothing
    /// about `A` on its own) and for tests such as `blessed`, whose failure
    /// does not imply the negated narrowing.
    negatable: bool,
}

impl GuardFact {
    /// Returns the narrowing this fact contributes to a region in which the
    /// guard condition evaluates to `holds`, or `None` if it contributes
    /// nothing there.
    fn op_for_region(&self, holds: bool) -> Option<NarrowOp> {
        if self.asserts_when_true == holds {
            Some(self.op.clone())
        } else if self.negatable {
            self.op.negated()
        } else {
            // The guard's failure proves nothing about this subject.
            None
        }
    }
}

/// A `StripOptional` narrowing recorded for later resolution.
///
/// `emit_defined_narrowing_witnesses` queries the type of `name` in `scope` at
/// `query_point` and, if that type has an optional inner type, records the
/// inner type over `region`.
#[derive(Clone)]
pub(super) struct DefinedNarrowing {
    name: String,
    scope: ScopeId,
    region: Span,
    query_point: Point,
}

/// Maps the string literal compared against `ref(...)` / `reftype(...)` to the
/// type the subject is narrowed to.
///
/// Returns `None` for strings that carry no usable narrowing: the listed
/// built-in reference kinds, and the empty string. `ref` yields the empty
/// string for a non-reference, so `ref($x) eq ''` must not be read as a class
/// named `""`. Any other string is treated as a class name.
fn ref_string_to_type(s: &str) -> Option<InferredType> {
    let ty = match s {
        "HASH" => InferredType::HashRef,
        "ARRAY" => InferredType::ArrayRef,
        "CODE" => InferredType::CodeRef { return_edge: None },
        "Regexp" => InferredType::Regexp,
        "" | "SCALAR" | "REF" | "GLOB" | "LVALUE" | "FORMAT" | "IO" | "VSTRING" => return None,
        other => InferredType::ClassName(other.to_string()),
    };
    Some(ty)
}

/// Whether a node of this kind can be the subject of a guard.
fn is_subject_node_kind(kind: &str) -> bool {
    matches!(
        kind,
        "scalar" | "hash_element_expression" | "array_element_expression"
    )
}

/// Returns the first named child of `node` whose kind is a subject kind.
fn func1op_subject_arg<'a>(node: Node<'a>) -> Option<Node<'a>> {
    crate::cst::first_named_child_where(node, is_subject_node_kind)
}

/// Collects the narrowing facts expressed by the condition `cond`.
///
/// Grouping is peeled first. `!` flips the polarity of the facts found in its
/// operand, `&&` / `and` merges the facts of both operands, and the remaining
/// recognised node kinds are handed to the individual recognisers. Anything
/// else yields no facts.
fn recognize_guards(cond: Node, source: &[u8]) -> Vec<GuardFact> {
    let cond = crate::cst::peel_groups(cond);
    match cond.kind() {
        "unary_expression" if raw_leading_op(cond, source) == "!" => cond
            .child_by_field_name("operand")
            .map(|operand| {
                recognize_guards(operand, source)
                    .into_iter()
                    .map(|mut fact| {
                        fact.asserts_when_true = !fact.asserts_when_true;
                        fact
                    })
                    .collect()
            })
            .unwrap_or_default(),
        "binary_expression" if matches!(raw_mid_op(cond, source).as_str(), "&&" | "and") => {
            let mut conjuncts = Vec::new();
            if let Some(left) = cond.child_by_field_name("left") {
                conjuncts.extend(recognize_guards(left, source));
            }
            if let Some(right) = cond.child_by_field_name("right") {
                conjuncts.extend(recognize_guards(right, source));
            }
            // A true conjunction makes every conjunct true, so keep what each
            // conjunct proves when it holds. A false conjunction proves nothing
            // about any single conjunct, so the merged facts must never be
            // negated.
            conjuncts
                .into_iter()
                .filter_map(|fact| {
                    let op = fact.op_for_region(true)?;
                    Some(GuardFact {
                        subject: fact.subject,
                        op,
                        asserts_when_true: true,
                        negatable: false,
                    })
                })
                .collect()
        }
        "method_call_expression" => recognize_isa_guard(cond, source).into_iter().collect(),
        "equality_expression" => recognize_ref_eq_guard(cond, source).into_iter().collect(),
        "func1op_call_expression" => recognize_defined_guard(cond, source).into_iter().collect(),
        "ambiguous_function_call_expression" | "function_call_expression" => {
            recognize_blessed_guard(cond, source).into_iter().collect()
        }
        _ => Vec::new(),
    }
}

/// Recognises `defined <subject>`: when true, the subject's optional layer is
/// stripped; when false, the subject is undef.
fn recognize_defined_guard(call: Node, source: &[u8]) -> Option<GuardFact> {
    if call.child(0)?.utf8_text(source).ok()? != "defined" {
        return None;
    }
    let arg = func1op_subject_arg(call)?;
    Some(GuardFact {
        subject: narrow_subject_of(arg, source)?,
        op: NarrowOp::StripOptional { query_point: arg.start_position() },
        asserts_when_true: true,
        negatable: true,
    })
}

/// Recognises `blessed <subject>`: when true, the subject's optional layer is
/// stripped.
///
/// The fact is not negatable: `blessed` is also false for unblessed references
/// and plain scalars, so its failure does not make the subject undef.
fn recognize_blessed_guard(call: Node, source: &[u8]) -> Option<GuardFact> {
    let name = call.child_by_field_name("function")?.utf8_text(source).ok()?;
    if name != "blessed" {
        return None;
    }
    let arg = crate::cst::peel_groups(call.child_by_field_name("arguments")?);
    if !is_subject_node_kind(arg.kind()) {
        return None;
    }
    Some(GuardFact {
        subject: narrow_subject_of(arg, source)?,
        op: NarrowOp::StripOptional { query_point: arg.start_position() },
        asserts_when_true: true,
        negatable: false,
    })
}

/// Recognises `<subject>->isa('Class')` and `<subject>->DOES('Class')` with a
/// plain string literal argument: when true, the subject is that class.
fn recognize_isa_guard(call: Node, source: &[u8]) -> Option<GuardFact> {
    let mc = crate::cst::MethodCall::cast(call)?;
    let method = mc.method()?.utf8_text(source).ok()?;
    if method != "isa" && method != "DOES" {
        return None;
    }
    let subject = narrow_subject_of(mc.invocant()?, source)?;
    let class =
        crate::cst::plain_string_literal_text(call.child_by_field_name("arguments")?, source)?;
    Some(GuardFact {
        subject,
        op: NarrowOp::To(InferredType::ClassName(class)),
        asserts_when_true: true,
        // `To` has no negated form, so this flag has no effect here.
        negatable: true,
    })
}

/// Recognises `ref(<subject>) eq 'KIND'` / `reftype(<subject>) eq 'KIND'`, with
/// the call on either side of `eq` and a plain string literal on the other.
fn recognize_ref_eq_guard(eq: Node, source: &[u8]) -> Option<GuardFact> {
    if raw_mid_op(eq, source) != "eq" {
        return None;
    }
    let left = eq.child_by_field_name("left")?;
    let right = eq.child_by_field_name("right")?;
    let (ref_call, lit) = if left.kind() == "func1op_call_expression" {
        (left, right)
    } else if right.kind() == "func1op_call_expression" {
        (right, left)
    } else {
        return None;
    };
    let fname = ref_call.child(0)?.utf8_text(source).ok()?;
    if fname != "ref" && fname != "reftype" {
        return None;
    }
    let subject = narrow_subject_of(func1op_subject_arg(ref_call)?, source)?;
    let ty = ref_string_to_type(&crate::cst::plain_string_literal_text(lit, source)?)?;
    Some(GuardFact {
        subject,
        op: NarrowOp::To(ty),
        asserts_when_true: true,
        // `To` has no negated form, so this flag has no effect here.
        negatable: true,
    })
}
/// Whether evaluating `node` always leaves the current statement sequence.
///
/// True for `return` / `last` / `next` / `redo` expression nodes, for a
/// `function` or `bareword` node whose text is one of the exit names, and for
/// call expressions whose function name (the `function` field, or else the
/// first child) is one of them. `exit` is included because it never returns
/// to the caller.
fn is_exit_expression(node: Node, source: &[u8]) -> bool {
    const EXITS: &[&str] = &[
        "die", "croak", "confess", "exit", "last", "next", "redo", "goto",
    ];
    let is_exit_name = |name: &str| EXITS.contains(&name.trim());
    match node.kind() {
        "return_expression" | "last_expression" | "next_expression" | "redo_expression" => true,
        "function" | "bareword" => node
            .utf8_text(source)
            .map(|text| is_exit_name(text))
            .unwrap_or(false),
        "func1op_call_expression"
        | "function_call_expression"
        | "ambiguous_function_call_expression" => node
            .child_by_field_name("function")
            .or_else(|| node.child(0))
            .and_then(|callee| callee.utf8_text(source).ok())
            .map(|text| is_exit_name(text))
            .unwrap_or(false),
        _ => false,
    }
}
/// Returns the `block` field of the first named child of `cond_stmt` whose
/// kind is `else`, if there is one.
fn trailing_else<'a>(cond_stmt: Node<'a>) -> Option<Node<'a>> {
    (0..cond_stmt.named_child_count())
        .map(|index| cond_stmt.named_child(index))
        // Stop at the first `else` child, or at a missing child, which ends
        // the search with no result.
        .find(|child| match child {
            Some(node) => node.kind() == "else",
            None => true,
        })
        .flatten()
        .and_then(|else_node| else_node.child_by_field_name("block"))
}
impl<'a> Builder<'a> {
/// Applies the guards of an `if` / `unless` statement to its blocks.
///
/// Each fact recognised in the condition is emitted over the main block with
/// the statement's polarity and, when an `else` block exists, over that block
/// with the opposite polarity.
pub(super) fn narrow_block_guard(&mut self, cond_stmt: Node<'a>) {
    let (Some(condition), Some(block)) = (
        cond_stmt.child_by_field_name("condition"),
        cond_stmt.child_by_field_name("block"),
    ) else {
        return;
    };
    let Some(holds_when_true) = self.block_guard_polarity(cond_stmt, condition) else {
        return;
    };
    let body_span = node_to_span(block);
    let otherwise = trailing_else(cond_stmt);
    for fact in recognize_guards(condition, self.source).iter() {
        if let Some(op) = fact.op_for_region(holds_when_true) {
            self.emit_narrowing_fact(&fact.subject, op, body_span, block);
        }
        let Some(else_block) = otherwise else { continue };
        if let Some(op) = fact.op_for_region(!holds_when_true) {
            self.emit_narrowing_fact(&fact.subject, op, node_to_span(else_block), else_block);
        }
    }
}
/// Determines whether the main block of `cond_stmt` runs when its condition
/// is true (`if` → `Some(true)`) or false (`unless` → `Some(false)`).
///
/// The keyword is read from the source text between the start of the
/// statement and the start of the condition. Only the leading run of ASCII
/// letters is compared, so `if(` written without a space is recognised just
/// like `if (`. Returns `None` for any other keyword, for non-UTF-8 text, or
/// when the byte range is not valid for the source.
fn block_guard_polarity(&self, cond_stmt: Node<'a>, condition: Node<'a>) -> Option<bool> {
    let prefix = self
        .source
        .get(cond_stmt.start_byte()..condition.start_byte())?;
    let prefix = std::str::from_utf8(prefix).ok()?;
    let keyword = prefix
        .trim_start()
        .split(|c: char| !c.is_ascii_alphabetic())
        .next()?;
    match keyword {
        "if" => Some(true),
        "unless" => Some(false),
        _ => None,
    }
}
/// Handles `<exit> if <cond>` / `<exit> unless <cond>`.
///
/// When the modified expression is an exit, the rest of the enclosing block
/// only runs if the exit was not taken: the condition holds there after
/// `unless` and fails there after `if`.
pub(super) fn narrow_postfix_exit(&mut self, postfix: Node<'a>) {
    let (Some(condition), Some(modified)) = (
        postfix.child_by_field_name("condition"),
        postfix.named_child(0),
    ) else {
        return;
    };
    if !is_exit_expression(modified, self.source) {
        return;
    }
    let Some(holds_when_true) = connector_keyword_between(modified, condition, self.source)
        .and_then(|kw| match kw.as_str() {
            "unless" => Some(true),
            "if" => Some(false),
            _ => None,
        })
    else {
        return;
    };
    self.narrow_block_remainder(postfix, condition, holds_when_true);
}
/// Handles `<cond> or <exit>` / `<cond> and <exit>` (and the `||` / `&&`
/// spellings).
///
/// When the right operand is an exit, the rest of the enclosing block only
/// runs if the exit was skipped: the left operand was true for `or` and false
/// for `and`.
pub(super) fn narrow_logical_exit(&mut self, expr: Node<'a>) {
    let (Some(left), Some(right)) = (
        expr.child_by_field_name("left"),
        expr.child_by_field_name("right"),
    ) else {
        return;
    };
    if !is_exit_expression(right, self.source) {
        return;
    }
    let operator = raw_mid_op(expr, self.source);
    let holds_when_true = match operator.as_str() {
        "or" | "||" => true,
        "and" | "&&" => false,
        _ => return,
    };
    self.narrow_block_remainder(expr, left, holds_when_true);
}
/// Emits the facts of `condition` over the part of the enclosing block that
/// follows the statement containing `stmt_expr`.
///
/// Does nothing unless `stmt_expr` sits inside an `expression_statement`
/// whose parent is a `block`. The region runs from the end of that statement
/// to the end of the block.
fn narrow_block_remainder(
    &mut self,
    stmt_expr: Node<'a>,
    condition: Node<'a>,
    holds_when_true: bool,
) {
    // Climb from the expression to the statement that contains it.
    let mut cursor = Some(stmt_expr);
    let stmt = loop {
        match cursor {
            Some(node) if node.kind() == "expression_statement" => break node,
            Some(node) => cursor = node.parent(),
            None => return,
        }
    };
    let Some(block) = stmt.parent().filter(|parent| parent.kind() == "block") else {
        return;
    };
    let region = Span { start: stmt.end_position(), end: block.end_position() };
    for fact in recognize_guards(condition, self.source) {
        let Some(op) = fact.op_for_region(holds_when_true) else { continue };
        self.emit_narrowing_fact(&fact.subject, op, region, block);
    }
}
/// Records one narrowing for `subject` over `region`.
///
/// The region is cut short at the first write to the variable, or the first
/// invalidation of the place, found in `container` at or after the region
/// start. An empty region records nothing. `To` narrowings are pushed to the
/// witness bag immediately; `StripOptional` narrowings are queued in
/// `defined_narrowings`.
fn emit_narrowing_fact(
    &mut self,
    subject: &NarrowSubject,
    op: NarrowOp,
    region: Span,
    container: Node<'a>,
) {
    use crate::witnesses::{Witness, WitnessAttachment, WitnessPayload, WitnessSource};
    let (name, cutoff) = match subject {
        NarrowSubject::Variable(var) => {
            (var.clone(), self.first_subject_write(var, region, container))
        }
        NarrowSubject::Place { key, root } => (
            key.clone(),
            self.first_place_invalidation(key, root, region, container),
        ),
    };
    let narrowed = Span {
        start: region.start,
        end: cutoff.unwrap_or(region.end),
    };
    if !point_lt(narrowed.start, narrowed.end) {
        return;
    }
    let scope = self.current_scope();
    match op {
        NarrowOp::To(ty) => self.bag.push(Witness {
            attachment: WitnessAttachment::Variable { name, scope },
            source: WitnessSource::Builder("narrowing".into()),
            payload: WitnessPayload::InferredType(ty),
            span: narrowed,
        }),
        NarrowOp::StripOptional { query_point } => {
            self.defined_narrowings.push(DefinedNarrowing {
                name,
                scope,
                region: narrowed,
                query_point,
            })
        }
    }
}
/// Rebuilds the witnesses derived from queued `StripOptional` narrowings.
///
/// Witnesses previously tagged `defined_narrowing` are removed first. Each
/// queued narrowing then has its variable's type queried at the recorded
/// point; when that type has an optional inner type, the inner type is
/// recorded over the narrowing's region.
pub(super) fn emit_defined_narrowing_witnesses(&mut self) {
    use crate::witnesses::{Witness, WitnessAttachment, WitnessPayload, WitnessSource};
    self.bag.remove_by_source_tag("defined_narrowing");
    let guards = self.defined_narrowings.clone();
    // Resolve every guard before pushing, so the queries never see the
    // witnesses added by this pass.
    let pending: Vec<(String, ScopeId, InferredType, Span)> = guards
        .iter()
        .filter_map(|guard| {
            let current = self.bag_query_variable(&guard.name, guard.scope, guard.query_point);
            let inner = current.as_ref().and_then(InferredType::optional_inner)?;
            Some((guard.name.clone(), guard.scope, inner.clone(), guard.region))
        })
        .collect();
    for (name, scope, inner, region) in pending {
        self.bag.push(Witness {
            attachment: WitnessAttachment::Variable { name, scope },
            source: WitnessSource::Builder("defined_narrowing".into()),
            payload: WitnessPayload::InferredType(inner),
            span: region,
        });
    }
}
/// Finds the earliest position in `container`, at or after `region.start`,
/// where a narrowing of the place `key` stops being trustworthy.
///
/// Two things count: an assignment whose left side is the place itself, and
/// any appearance of the root variable or of a proper prefix of the place
/// that is not merely the base of a larger element expression.
fn first_place_invalidation(
    &self,
    key: &str,
    root: &str,
    region: Span,
    container: Node<'a>,
) -> Option<Point> {
    // Remembers `node`'s start when it is not before `after` and precedes the
    // earliest position seen so far.
    fn record(node: Node, after: Point, earliest: &mut Option<Point>) {
        let start = node.start_position();
        if point_lt(start, after) {
            return;
        }
        let beats_current = match *earliest {
            Some(current) => point_lt(start, current),
            None => true,
        };
        if beats_current {
            *earliest = Some(start);
        }
    }
    // True when `key` continues past `prefix` with `-`, `{` or `[`.
    fn is_proper_prefix(prefix: &str, key: &str) -> bool {
        key.strip_prefix(prefix)
            .and_then(|rest| rest.bytes().next())
            .is_some_and(|next| matches!(next, b'-' | b'{' | b'['))
    }
    fn walk(
        node: Node,
        key: &str,
        root: &str,
        after: Point,
        src: &[u8],
        earliest: &mut Option<Point>,
    ) {
        let kind = node.kind();
        if kind == "assignment_expression" {
            let assigns_key = node
                .child_by_field_name("left")
                .and_then(|lhs| crate::cst::canonical_place_path(lhs, src))
                .is_some_and(|(path, _)| path == key);
            if assigns_key {
                record(node, after, earliest);
            }
        }
        let touches_prefix = match kind {
            "scalar" => crate::cst::canonical_var_name(node, src).as_deref() == Some(root),
            "hash_element_expression" | "array_element_expression" => {
                crate::cst::canonical_place_path(node, src)
                    .is_some_and(|(path, _)| is_proper_prefix(&path, key))
            }
            _ => false,
        };
        if touches_prefix {
            // A prefix that is only the base of a larger element expression
            // does not count on its own.
            let is_element_base = node.parent().is_some_and(|parent| {
                matches!(
                    parent.kind(),
                    "hash_element_expression" | "array_element_expression"
                ) && parent.named_child(0) == Some(node)
            });
            if !is_element_base {
                record(node, after, earliest);
            }
        }
        for child in (0..node.named_child_count()).filter_map(|index| node.named_child(index)) {
            walk(child, key, root, after, src, earliest);
        }
    }
    let mut earliest = None;
    walk(container, key, root, region.start, self.source, &mut earliest);
    earliest
}
/// Finds the earliest assignment in `container`, at or after `region.start`,
/// whose left side writes the variable `var`.
///
/// The left side counts when it is the variable's own `scalar` node, or a
/// `variable_declaration` containing such a node anywhere beneath it.
fn first_subject_write(&self, var: &str, region: Span, container: Node<'a>) -> Option<Point> {
    // True when `node` or any named descendant is a `scalar` naming `var`.
    fn mentions_var(node: Node, var: &str, src: &[u8]) -> bool {
        let is_target = node.kind() == "scalar"
            && crate::cst::canonical_var_name(node, src).as_deref() == Some(var);
        is_target
            || (0..node.named_child_count())
                .filter_map(|index| node.named_child(index))
                .any(|child| mentions_var(child, var, src))
    }
    fn writes_var(left: Node, var: &str, src: &[u8]) -> bool {
        match left.kind() {
            "scalar" => crate::cst::canonical_var_name(left, src).as_deref() == Some(var),
            "variable_declaration" => mentions_var(left, var, src),
            _ => false,
        }
    }
    fn walk(node: Node, var: &str, after: Point, src: &[u8], earliest: &mut Option<Point>) {
        let is_write = node.kind() == "assignment_expression"
            && node
                .child_by_field_name("left")
                .is_some_and(|lhs| writes_var(lhs, var, src));
        if is_write {
            let start = node.start_position();
            let beats_current = match *earliest {
                Some(current) => point_lt(start, current),
                None => true,
            };
            if !point_lt(start, after) && beats_current {
                *earliest = Some(start);
            }
        }
        for child in (0..node.named_child_count()).filter_map(|index| node.named_child(index)) {
            walk(child, var, after, src, earliest);
        }
    }
    let mut earliest = None;
    walk(container, var, region.start, self.source, &mut earliest);
    earliest
}
}