use std::fmt::Debug;
use cairo_lang_diagnostics::DiagnosticAdded;
use cairo_lang_semantic::{self as semantic, ConcreteVariant, PatternVariable};
use cairo_lang_syntax::node::ast::ExprPtr;
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use itertools::Itertools;
use num_bigint::BigInt;
use num_integer::Integer;
use crate::diagnostic::{
LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder, MatchKind,
};
use crate::ids::LocationId;
use crate::lower::context::LoweringContext;
/// A variable in a [`FlowControlGraph`].
///
/// The wrapped value is an index into the graph's per-variable type and location tables.
/// New variables are allocated by [`FlowControlGraphBuilder::new_var`].
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct FlowControlVar(usize);
impl<'db> FlowControlVar {
    /// Returns the type of the variable.
    pub fn ty(&self, graph: &FlowControlGraph<'db>) -> semantic::TypeId<'db> {
        graph.var_types[self.0]
    }
    /// Returns the location of the variable.
    pub fn location(&self, graph: &FlowControlGraph<'db>) -> LocationId<'db> {
        graph.var_locations[self.0]
    }
    /// Returns the number of nodes in `graph` whose input variable (see
    /// [`FlowControlNode::input_var`]) is this variable.
    ///
    /// Returns `0` for a variable that no node takes as input.
    pub fn times_used(&self, graph: &FlowControlGraph<'db>) -> usize {
        // The `times_used` map only gets an entry the first time a node consumes the variable
        // (see `FlowControlGraphBuilder::add_node`), so a missing entry means zero uses.
        // Indexing the map directly would panic for such variables.
        graph.times_used.get(self).copied().unwrap_or(0)
    }
}
impl Debug for FlowControlVar {
    /// Formats the variable as `v<index>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "v{}", self.0)
    }
}
/// Identifier of a pattern variable registered in a [`FlowControlGraph`] through
/// [`FlowControlGraphBuilder::register_pattern_var`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PatternVarId(usize);
impl PatternVarId {
    /// Looks up the registered [`PatternVariable`] this id refers to.
    ///
    /// Panics (out-of-bounds index) if the id does not belong to `graph`.
    pub fn get<'db, 'a>(&self, graph: &'a FlowControlGraph<'db>) -> &'a PatternVariable<'db> {
        let Self(index) = *self;
        &graph.pattern_vars[index]
    }
}
/// Identifier of a node in a [`FlowControlGraph`].
///
/// This is the node's position in the graph's node list:
/// [`FlowControlGraphBuilder::add_node`] returns the index at which the node was pushed, and
/// [`FlowControlGraph::node`] looks nodes up by that index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);
/// A node that holds a semantic expression together with a flow-control variable.
///
/// This node has no input variable (see [`FlowControlNode::input_var`]).
/// NOTE(review): presumably the expression is evaluated and its value is bound to `var_id`
/// before control continues at `next` — confirm against the lowering code.
#[derive(Debug)]
pub struct EvaluateExpr {
    /// The semantic expression.
    pub expr: semantic::ExprId,
    /// The flow-control variable associated with `expr` (presumably its result).
    pub var_id: FlowControlVar,
    /// The node to continue to.
    pub next: NodeId,
}
/// A two-way branch on a condition variable.
///
/// `condition_var` is this node's input variable (see [`FlowControlNode::input_var`]).
#[derive(Debug)]
pub struct BooleanIf {
    /// The variable being branched on.
    pub condition_var: FlowControlVar,
    /// The target for the `true` case.
    pub true_branch: NodeId,
    /// The target for the `false` case.
    pub false_branch: NodeId,
}
/// A match on the variants of an enum value.
///
/// `matched_var` is this node's input variable (see [`FlowControlNode::input_var`]).
pub struct EnumMatch<'db> {
    /// The variable holding the enum value.
    pub matched_var: FlowControlVar,
    /// The concrete enum being matched.
    pub concrete_enum_id: semantic::ConcreteEnumId<'db>,
    /// For each variant: the variant, the node to continue to, and a flow-control variable
    /// (presumably bound to the variant's inner value — confirm).
    pub variants: Vec<(ConcreteVariant<'db>, NodeId, FlowControlVar)>,
}
impl<'db> std::fmt::Debug for EnumMatch<'db> {
    /// Formats the match without the enum/variant ids, which need a database to print.
    /// Only the matched variable and each arm's `(node, var)` pair are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let matched_var = self.matched_var;
        let arms = self
            .variants
            .iter()
            .map(|(_variant, target, bound)| format!("({target:?}, {bound:?})"))
            .join(", ");
        write!(f, "EnumMatch {{ matched_var: {matched_var:?}, variants: {arms}}}")
    }
}
/// A multi-way branch on the value of a variable.
///
/// `matched_var` is this node's input variable (see [`FlowControlNode::input_var`]).
#[derive(Debug)]
pub struct ValueMatch {
    /// The variable being matched.
    pub matched_var: FlowControlVar,
    /// The branch targets.
    /// NOTE(review): presumably indexed by the matched value — confirm with the lowering code.
    pub nodes: Vec<NodeId>,
}
/// A two-way branch on whether `input` equals `literal`.
///
/// `input` is this node's input variable (see [`FlowControlNode::input_var`]).
pub struct EqualsLiteral<'db> {
    /// The variable compared against the literal.
    pub input: FlowControlVar,
    /// The literal value.
    pub literal: BigInt,
    /// Syntax pointer of the literal's expression (presumably used for diagnostics).
    pub stable_ptr: ExprPtr<'db>,
    /// The target for the "equal" case.
    pub true_branch: NodeId,
    /// The target for the "not equal" case.
    pub false_branch: NodeId,
}
impl<'db> std::fmt::Debug for EqualsLiteral<'db> {
    /// Formats every field except `stable_ptr`; the literal is shown via `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { input, literal, true_branch, false_branch, .. } = self;
        write!(
            f,
            "EqualsLiteral {{ input: {input:?}, literal: {literal}, true_branch: \
             {true_branch:?}, false_branch: {false_branch:?} }}"
        )
    }
}
/// A node holding an arm's expression.
///
/// It has no input variable and no successor-node field.
#[derive(Debug)]
pub struct ArmExpr {
    /// The semantic expression of the arm.
    pub expr: semantic::ExprId,
}
/// A node holding the body of a loop.
///
/// It has no input variable and no successor-node field.
#[derive(Debug)]
pub struct WhileBody<'db> {
    /// The semantic expression of the loop body.
    pub body: semantic::ExprId,
    /// The semantic expression of the whole loop.
    pub loop_expr_id: semantic::ExprId,
    /// Syntax pointer of the loop.
    pub loop_stable_ptr: SyntaxStablePtrId<'db>,
}
/// A node that takes `input` and produces the variables in `outputs`.
///
/// NOTE(review): presumably splits a composite value (e.g. a struct/tuple) into its
/// members — confirm with the lowering code.
#[derive(Debug)]
pub struct Deconstruct {
    /// The variable being deconstructed; this node's input variable.
    pub input: FlowControlVar,
    /// The resulting variables.
    pub outputs: Vec<FlowControlVar>,
    /// The node to continue to.
    pub next: NodeId,
}
/// A node that binds a flow-control variable to a registered pattern variable.
#[derive(Debug)]
pub struct BindVar {
    /// The variable being bound; this node's input variable.
    pub input: FlowControlVar,
    /// The pattern variable to bind (look it up with [`PatternVarId::get`]).
    pub output: PatternVarId,
    /// The node to continue to.
    pub next: NodeId,
}
/// A node converting `input` into `output`; it has a single successor.
#[derive(Debug)]
pub struct Upcast {
    /// The source variable; this node's input variable.
    pub input: FlowControlVar,
    /// The converted variable.
    pub output: FlowControlVar,
    /// The node to continue to.
    pub next: NodeId,
}
/// A node converting `input` into `output` with two outcomes.
///
/// NOTE(review): presumably `in_range` is taken when the value fits the output type and
/// `out_of_range` otherwise — confirm.
#[derive(Debug)]
pub struct Downcast {
    /// The source variable; this node's input variable.
    pub input: FlowControlVar,
    /// The converted variable.
    pub output: FlowControlVar,
    /// The target when the conversion succeeds.
    pub in_range: NodeId,
    /// The target when the conversion fails.
    pub out_of_range: NodeId,
}
/// A terminal node carrying semantic variables and their syntax pointers.
///
/// NOTE(review): presumably the variables bound by a successful `let ... else` pattern —
/// confirm with the lowering code.
#[derive(Debug)]
pub struct LetElseSuccess<'db> {
    /// Semantic variables paired with their syntax pointers.
    pub var_ids_and_stable_ptrs: Vec<(semantic::VarId<'db>, SyntaxStablePtrId<'db>)>,
}
/// A node in a [`FlowControlGraph`].
///
/// Most variants wrap a struct with the node-specific data. Successor nodes are referenced
/// by [`NodeId`].
pub enum FlowControlNode<'db> {
    EvaluateExpr(EvaluateExpr),
    BooleanIf(BooleanIf),
    EnumMatch(EnumMatch<'db>),
    ValueMatch(ValueMatch),
    EqualsLiteral(EqualsLiteral<'db>),
    ArmExpr(ArmExpr),
    WhileBody(WhileBody<'db>),
    Deconstruct(Deconstruct),
    BindVar(BindVar),
    Upcast(Upcast),
    Downcast(Downcast),
    LetElseSuccess(LetElseSuccess<'db>),
    /// A payload-less node (NOTE(review): presumably yields a unit value — confirm).
    UnitResult,
    /// A placeholder for a node that could not be built. It carries the diagnostic that was
    /// reported instead (see [`FlowControlGraphBuilder::report_with_missing_node`]).
    Missing(DiagnosticAdded),
}
impl<'db> FlowControlNode<'db> {
pub fn input_var(&self) -> Option<FlowControlVar> {
match self {
FlowControlNode::EvaluateExpr(..) => None,
FlowControlNode::BooleanIf(node) => Some(node.condition_var),
FlowControlNode::EnumMatch(node) => Some(node.matched_var),
FlowControlNode::ValueMatch(node) => Some(node.matched_var),
FlowControlNode::EqualsLiteral(node) => Some(node.input),
FlowControlNode::ArmExpr(..) => None,
FlowControlNode::WhileBody(..) => None,
FlowControlNode::Deconstruct(node) => Some(node.input),
FlowControlNode::BindVar(node) => Some(node.input),
FlowControlNode::Upcast(node) => Some(node.input),
FlowControlNode::Downcast(node) => Some(node.input),
FlowControlNode::LetElseSuccess(..) => None,
FlowControlNode::UnitResult => None,
FlowControlNode::Missing(_) => None,
}
}
}
impl<'db> Debug for FlowControlNode<'db> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FlowControlNode::EvaluateExpr(node) => node.fmt(f),
FlowControlNode::BooleanIf(node) => node.fmt(f),
FlowControlNode::EnumMatch(node) => node.fmt(f),
FlowControlNode::ValueMatch(node) => node.fmt(f),
FlowControlNode::EqualsLiteral(node) => node.fmt(f),
FlowControlNode::ArmExpr(node) => node.fmt(f),
FlowControlNode::WhileBody(node) => node.fmt(f),
FlowControlNode::Deconstruct(node) => node.fmt(f),
FlowControlNode::BindVar(node) => node.fmt(f),
FlowControlNode::Upcast(node) => node.fmt(f),
FlowControlNode::Downcast(node) => node.fmt(f),
FlowControlNode::LetElseSuccess(node) => node.fmt(f),
FlowControlNode::UnitResult => write!(f, "UnitResult"),
FlowControlNode::Missing(_) => write!(f, "Missing"),
}
}
}
pub struct FlowControlGraph<'db> {
nodes: Vec<FlowControlNode<'db>>,
var_types: Vec<semantic::TypeId<'db>>,
var_locations: Vec<LocationId<'db>>,
pattern_vars: Vec<PatternVariable<'db>>,
kind: MatchKind<'db>,
times_used: UnorderedHashMap<FlowControlVar, usize>,
}
impl<'db> FlowControlGraph<'db> {
pub fn root(&self) -> NodeId {
NodeId(self.nodes.len() - 1)
}
pub fn size(&self) -> usize {
self.nodes.len()
}
pub fn node(&self, id: NodeId) -> &FlowControlNode<'db> {
&self.nodes[id.0]
}
pub fn kind(&self) -> MatchKind<'db> {
self.kind
}
}
impl<'db> Debug for FlowControlGraph<'db> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Root: {}", self.root().0)?;
for (i, node) in self.nodes.iter().enumerate().rev() {
writeln!(f, "{i} {node:?}")?;
}
Ok(())
}
}
/// Incrementally builds a [`FlowControlGraph`], collecting diagnostics reported along the way.
pub struct FlowControlGraphBuilder<'db> {
    /// The graph under construction.
    graph: FlowControlGraph<'db>,
    /// Diagnostics reported while building; moved into the lowering context by
    /// [`FlowControlGraphBuilder::finalize`].
    diagnostics: LoweringDiagnostics<'db>,
}
impl<'db> FlowControlGraphBuilder<'db> {
    /// Creates a builder with an empty graph of the given kind.
    pub fn new(kind: MatchKind<'db>) -> Self {
        let graph = FlowControlGraph {
            nodes: Vec::new(),
            var_types: Vec::new(),
            var_locations: Vec::new(),
            pattern_vars: Vec::new(),
            kind,
            times_used: UnorderedHashMap::default(),
        };
        Self { graph, diagnostics: LoweringDiagnostics::default() }
    }

    /// Appends `node` to the graph and returns its id (its index in the node list).
    ///
    /// If the node consumes a variable (see [`FlowControlNode::input_var`]), that variable's
    /// use count is incremented.
    pub fn add_node(&mut self, node: FlowControlNode<'db>) -> NodeId {
        if let Some(input_var) = node.input_var() {
            // `inc` comes from `num_integer::Integer`.
            self.graph.times_used.entry(input_var).or_insert(0).inc();
        }
        let id = NodeId(self.graph.size());
        self.graph.nodes.push(node);
        id
    }

    /// Returns whether `var` is the input variable of any node added so far.
    pub fn is_var_used(&self, var: FlowControlVar) -> bool {
        self.graph.times_used.contains_key(&var)
    }

    /// Finishes building and returns the graph. The collected diagnostics are moved into
    /// `ctx`.
    ///
    /// # Panics
    ///
    /// Panics if `root` is not the last node added, including when no node was added at all.
    pub fn finalize(
        self,
        root: NodeId,
        ctx: &mut LoweringContext<'db, '_>,
    ) -> FlowControlGraph<'db> {
        // `checked_sub` makes an empty builder fail this assertion with its intended message.
        // Plain `size() - 1` would hit an overflow panic (debug) or wrap to `usize::MAX`
        // (release).
        assert_eq!(
            Some(root.0),
            self.graph.size().checked_sub(1),
            "The root must be the last node."
        );
        ctx.diagnostics.extend(self.diagnostics.build());
        self.graph
    }

    /// Allocates a new variable with the given type and location.
    pub fn new_var(
        &mut self,
        ty: semantic::TypeId<'db>,
        location: LocationId<'db>,
    ) -> FlowControlVar {
        let var = FlowControlVar(self.graph.var_types.len());
        self.graph.var_types.push(ty);
        self.graph.var_locations.push(location);
        var
    }

    /// Registers a pattern variable and returns its id.
    pub fn register_pattern_var(&mut self, var: PatternVariable<'db>) -> PatternVarId {
        let idx = self.graph.pattern_vars.len();
        self.graph.pattern_vars.push(var);
        PatternVarId(idx)
    }

    /// Returns the type of `input_var`.
    pub fn var_ty(&self, input_var: FlowControlVar) -> semantic::TypeId<'db> {
        self.graph.var_types[input_var.0]
    }

    /// Returns the location of `input_var`.
    pub fn var_location(&self, input_var: FlowControlVar) -> LocationId<'db> {
        self.graph.var_locations[input_var.0]
    }

    /// Reports a diagnostic at `stable_ptr` into this builder's diagnostics.
    pub fn report(
        &mut self,
        stable_ptr: impl Into<SyntaxStablePtrId<'db>>,
        kind: LoweringDiagnosticKind<'db>,
    ) -> DiagnosticAdded {
        self.diagnostics.report(stable_ptr, kind)
    }

    /// Reports a diagnostic and adds a [`FlowControlNode::Missing`] node carrying it.
    /// Returns the new node's id.
    pub fn report_with_missing_node(
        &mut self,
        stable_ptr: impl Into<SyntaxStablePtrId<'db>>,
        kind: LoweringDiagnosticKind<'db>,
    ) -> NodeId {
        let diag_added = self.report(stable_ptr, kind);
        self.add_node(FlowControlNode::Missing(diag_added))
    }

    /// Returns the kind of the graph being built.
    pub fn kind(&self) -> MatchKind<'db> {
        self.graph.kind
    }
}