use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::ids::NamedLanguageElementId;
use cairo_lang_diagnostics::{DiagnosticNote, Maybe};
use cairo_lang_filesystem::flag::FlagsGroup;
use cairo_lang_semantic::corelib::{CorelibSemantic, validate_literal};
use cairo_lang_semantic::expr::compute::unwrap_pattern_type;
use cairo_lang_semantic::items::enm::SemanticEnumEx;
use cairo_lang_semantic::items::structure::StructSemantic;
use cairo_lang_semantic::{
self as semantic, ConcreteEnumId, ConcreteStructId, ConcreteTypeId, ExprNumericLiteral,
PatternEnumVariant, PatternLiteral, PatternStruct, PatternTuple, PatternWrappingInfo, TypeId,
TypeLongId, corelib,
};
use cairo_lang_syntax::node::TypedStablePtr;
use cairo_lang_syntax::node::ast::ExprPtr;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::{Itertools, zip_eq};
use num_bigint::BigInt;
use salsa::Database;
use super::super::graph::{
Deconstruct, EnumMatch, FlowControlGraphBuilder, FlowControlNode, FlowControlVar, NodeId,
};
use super::cache::Cache;
use super::filtered_patterns::{Bindings, FilteredPatterns};
use crate::diagnostic::{LoweringDiagnosticKind, MatchDiagnostic, MatchError};
use crate::ids::LocationId;
use crate::lower::context::LoweringContext;
use crate::lower::flow_control::graph::{Downcast, EqualsLiteral, Upcast, ValueMatch};
/// Callback that builds the node to continue with once a subset of the patterns was selected.
///
/// It receives the graph builder, the patterns that are still accepted, and a textual
/// description of the matched path (for example `"_"` when only catch-all patterns apply), and
/// returns the id of the resulting node.
type BuildNodeCallback<'db, 'a> =
&'a mut dyn FnMut(&mut FlowControlGraphBuilder<'db>, FilteredPatterns, String) -> NodeId;
/// An optional reference to a semantic pattern.
///
/// The functions in this file handle `None` the same way as [`semantic::Pattern::Otherwise`].
type PatternOption<'a, 'db> = Option<&'a semantic::Pattern<'db>>;
/// The common arguments of the `create_node_for_*` functions.
pub struct CreateNodeParams<'db, 'mt, 'a> {
/// The lowering context; used for database access and for resolving pattern ids.
pub ctx: &'a LoweringContext<'db, 'mt>,
/// The builder of the flow control graph to which new nodes and variables are added.
pub graph: &'a mut FlowControlGraphBuilder<'db>,
/// The patterns to match against the input variable. `None` is handled like a catch-all
/// pattern.
pub patterns: &'a [PatternOption<'a, 'db>],
/// Called to build the continuation node for the subset of `patterns` that was selected.
pub build_node_callback: BuildNodeCallback<'db, 'a>,
/// The location assigned to the member variables created when deconstructing tuples and
/// structs.
pub location: LocationId<'db>,
}
/// Builds the node that matches `input_var` against the patterns in `params`.
///
/// Variable patterns are registered as bindings of `input_var` and are treated as catch-alls
/// afterwards. The remaining patterns are dispatched according to the unwrapped type of
/// `input_var`: numeric types that can be upcast to `felt252`, enums, structs and tuples. Any
/// other type is reported as an unsupported matched type.
///
/// Returns the id of the node at which the matching starts.
pub fn create_node_for_patterns<'db>(
    params: CreateNodeParams<'db, '_, '_>,
    input_var: FlowControlVar,
) -> NodeId {
    let CreateNodeParams { ctx, graph, patterns, build_node_callback, location } = params;

    // Replace every variable pattern by `None`, remembering the binding it introduces. Both
    // vectors stay index-aligned with the original `patterns`.
    let mut bindings: Vec<Bindings> = Vec::with_capacity(patterns.len());
    let mut stripped: Vec<PatternOption<'_, 'db>> = Vec::with_capacity(patterns.len());
    for pattern in patterns {
        match pattern {
            Some(semantic::Pattern::Variable(pattern_variable)) => {
                let pattern_var = graph.register_pattern_var(pattern_variable.clone());
                bindings.push(Bindings::single(input_var, pattern_var));
                stripped.push(None);
            }
            other => {
                bindings.push(Bindings::default());
                stripped.push(*other);
            }
        }
    }

    // Wrap the caller's callback: attach the bindings collected above and go through the cache.
    let mut cache = Cache::default();
    let mut cached_callback = |graph: &mut FlowControlGraphBuilder<'db>,
                               pattern_indices: FilteredPatterns,
                               path: String| {
        let with_bindings = pattern_indices.add_bindings(&bindings);
        cache.get_or_compute(build_node_callback, graph, with_bindings, path)
    };

    let var_ty = graph.var_ty(input_var);
    let (long_ty, wrapping_info) = unwrap_pattern_type(ctx.db, var_ty);

    let params = CreateNodeParams {
        ctx,
        graph,
        patterns: &stripped,
        build_node_callback: &mut cached_callback,
        location,
    };

    // No patterns at all over an enum without variants: delegate to the enum handling.
    if stripped.is_empty()
        && let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) = long_ty
        && ctx.db.concrete_enum_variants(concrete_enum_id).unwrap().is_empty()
    {
        return create_node_for_enum(params, input_var, concrete_enum_id, wrapping_info);
    }

    // When every pattern is a catch-all there is no need to inspect `input_var`.
    let Some(first_non_any_pattern) =
        stripped.iter().flatten().find(|pattern| !pattern_is_any(pattern))
    else {
        return cached_callback(graph, FilteredPatterns::all(stripped.len()), "_".into());
    };

    if corelib::numeric_upcastable_to_felt252(ctx.db, var_ty) {
        return create_node_for_value(params, input_var);
    }

    match long_ty {
        TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) => {
            create_node_for_enum(params, input_var, concrete_enum_id, wrapping_info)
        }
        TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) => {
            create_node_for_struct(params, input_var, concrete_struct_id, wrapping_info)
        }
        TypeLongId::Tuple(types) => create_node_for_tuple(params, input_var, &types, wrapping_info),
        _ => graph.report_with_missing_node(
            first_non_any_pattern.stable_ptr(),
            LoweringDiagnosticKind::MatchError(MatchError {
                kind: graph.kind(),
                error: MatchDiagnostic::UnsupportedMatchedType(
                    wrapping_info.wrap(ctx.db, TypeId::new(ctx.db, long_ty)).format(ctx.db),
                ),
            }),
        ),
    }
}
/// The information collected for a single enum variant while scanning the patterns of an enum
/// match.
#[derive(Clone, Default)]
struct VariantInfo<'a, 'db> {
/// The patterns that may match this variant.
filter: FilteredPatterns,
/// One entry per pattern added to `filter`: the inner pattern of that pattern, or `None`
/// when it has no inner pattern or is a catch-all.
inner_patterns: Vec<PatternOption<'a, 'db>>,
/// The location of the first inner pattern seen for this variant, if any.
inner_var_location: Option<LocationId<'db>>,
}
/// Builds the node that matches `input_var`, whose unwrapped type is the enum
/// `concrete_enum_id`, against the patterns in `params`.
///
/// For every variant the relevant patterns and their inner patterns are collected, and the
/// inner patterns are matched recursively against a fresh variable holding the variant's
/// payload. If all variants lead to the same node and none of the payload variables is used, that
/// node is returned directly; otherwise an [`EnumMatch`] node is added.
///
/// A pattern kind that cannot appear on an enum is reported as an unexpected error and the
/// resulting missing node is returned immediately.
///
/// # Panics
/// Panics if the variants of `concrete_enum_id` cannot be computed, or if a variable pattern is
/// encountered (those are expected to be stripped by `create_node_for_patterns`).
fn create_node_for_enum<'db>(
    params: CreateNodeParams<'db, '_, '_>,
    input_var: FlowControlVar,
    concrete_enum_id: ConcreteEnumId<'db>,
    wrapping_info: PatternWrappingInfo,
) -> NodeId {
    let CreateNodeParams { ctx, graph, patterns, build_node_callback, location } = params;
    let concrete_variants = ctx.db.concrete_enum_variants(concrete_enum_id).unwrap();
    let mut variants = vec![VariantInfo::default(); concrete_variants.len()];

    for (idx, pattern) in patterns.iter().enumerate() {
        match pattern {
            Some(semantic::Pattern::EnumVariant(PatternEnumVariant {
                variant,
                inner_pattern: inner_pattern_id,
                ..
            })) => {
                let inner_pattern =
                    inner_pattern_id.map(|inner_pattern| get_pattern(ctx, inner_pattern));
                variants[variant.idx].filter.add(idx);
                variants[variant.idx].inner_patterns.push(inner_pattern);
                // Keep the location of the first inner pattern of the variant.
                if let Some(inner_pattern) = inner_pattern {
                    variants[variant.idx]
                        .inner_var_location
                        .get_or_insert(ctx.get_location(inner_pattern.stable_ptr().untyped()));
                }
            }
            Some(semantic::Pattern::Otherwise(..)) | None => {
                // A catch-all pattern is relevant to every variant.
                for variant_info in variants.iter_mut() {
                    variant_info.filter.add(idx);
                    variant_info.inner_patterns.push(None);
                }
            }
            Some(semantic::Pattern::Variable(..)) => unreachable!(),
            Some(
                pattern @ (semantic::Pattern::StringLiteral(..)
                | semantic::Pattern::Literal(..)
                | semantic::Pattern::Struct(..)
                | semantic::Pattern::Tuple(..)
                | semantic::Pattern::FixedSizeArray(..)
                | semantic::Pattern::Missing(..)),
            ) => {
                // Stop here, as the sibling `create_node_for_*` functions do: continuing would
                // silently drop this pattern and discard the missing node.
                return graph.report_with_missing_node(
                    pattern.stable_ptr().untyped(),
                    LoweringDiagnosticKind::UnexpectedError,
                );
            }
        }
    }

    let variants = zip_eq(concrete_variants, variants)
        .map(|(concrete_variant, variant_info)| {
            // Without an inner pattern, fall back to the location of the matched variable with a
            // note naming the variant.
            let inner_var_location = variant_info.inner_var_location.unwrap_or_else(|| {
                graph.var_location(input_var).with_note(
                    ctx.db,
                    DiagnosticNote::text_only(format!(
                        "In variant {:?}.",
                        concrete_variant.into_debug(ctx.db)
                    )),
                )
            });
            let inner_var =
                graph.new_var(wrapping_info.wrap(ctx.db, concrete_variant.ty), inner_var_location);
            let node = create_node_for_patterns(
                CreateNodeParams {
                    ctx,
                    graph,
                    patterns: &variant_info.inner_patterns,
                    build_node_callback: &mut |graph, pattern_indices_inner, path| {
                        // Translate the indices of the inner patterns back to indices into
                        // `patterns`, and prefix the path with the variant name.
                        build_node_callback(
                            graph,
                            pattern_indices_inner.lift(&variant_info.filter),
                            format!("{}({path})", concrete_variant.id.name(ctx.db).long(ctx.db)),
                        )
                    },
                    location,
                },
                inner_var,
            );
            (concrete_variant, node, inner_var)
        })
        .collect_vec();

    // If every variant leads to the same node and no payload variable is used, the enum match is
    // redundant.
    if let Some(first_variant) = variants.first() {
        let first_variant_node = first_variant.1;
        if variants.iter().all(|(_, node_id, inner_var)| {
            *node_id == first_variant_node && !graph.is_var_used(*inner_var)
        }) {
            return first_variant_node;
        }
    }

    graph.add_node(FlowControlNode::EnumMatch(EnumMatch {
        matched_var: input_var,
        concrete_enum_id,
        variants,
    }))
}
fn create_node_for_tuple<'db>(
params: CreateNodeParams<'db, '_, '_>,
input_var: FlowControlVar,
types: &[TypeId<'db>],
wrapping_info: PatternWrappingInfo,
) -> NodeId {
let CreateNodeParams { ctx, graph, patterns, build_node_callback, location } = params;
let inner_vars = types
.iter()
.map(|ty| graph.new_var(wrapping_info.wrap(ctx.db, *ty), location))
.collect_vec();
let node = create_node_for_tuple_inner(
CreateNodeParams {
ctx,
graph,
patterns,
build_node_callback: &mut |graph, pattern_indices, path| {
build_node_callback(graph, pattern_indices, format!("({path})"))
},
location,
},
&inner_vars,
types,
0,
None,
);
graph.add_node(FlowControlNode::Deconstruct(Deconstruct {
input: input_var,
outputs: inner_vars,
next: node,
}))
}
fn create_node_for_struct<'db>(
params: CreateNodeParams<'db, '_, '_>,
input_var: FlowControlVar,
concrete_struct_id: ConcreteStructId<'db>,
wrapping_info: PatternWrappingInfo,
) -> NodeId {
let CreateNodeParams { ctx, graph, patterns, build_node_callback, location } = params;
let members = match ctx.db.concrete_struct_members(concrete_struct_id) {
Ok(members) => members,
Err(diag_added) => return graph.add_node(FlowControlNode::Missing(diag_added)),
};
let types = members.iter().map(|(_, member)| member.ty).collect_vec();
let inner_vars = types
.iter()
.map(|ty| graph.new_var(wrapping_info.wrap(ctx.db, *ty), location))
.collect_vec();
let node = create_node_for_tuple_inner(
CreateNodeParams {
ctx,
graph,
patterns,
build_node_callback: &mut |graph, pattern_indices, path| {
let struct_name = concrete_struct_id.struct_id(ctx.db).name(ctx.db).long(ctx.db);
build_node_callback(graph, pattern_indices, format!("{struct_name}{{{path}}}"))
},
location,
},
&inner_vars,
&types,
0,
Some(&members.iter().map(|(_, member)| member).collect_vec()),
);
graph.add_node(FlowControlNode::Deconstruct(Deconstruct {
input: input_var,
outputs: inner_vars,
next: node,
}))
}
/// Matches the items of a tuple (or the members of a struct) one after the other, starting at
/// `item_idx`.
///
/// * `inner_vars` - the variables holding the items, index-aligned with `types`.
/// * `types` - the types of the items; the recursion ends when `item_idx == types.len()`.
/// * `item_idx` - the index of the item handled by this call.
/// * `struct_members` - `Some` when matching a struct (index-aligned with `types`), `None` when
///   matching a tuple.
///
/// For the current item, the corresponding sub-pattern of every pattern is extracted and matched
/// with `create_node_for_patterns`; its callback recurses into the next item with the patterns
/// that are still accepted.
///
/// A `u256` literal pattern is split into its two `u128` parts by `handle_u256_literal`. If
/// that fails, a missing node carrying the reported diagnostic is returned, so that the extracted
/// sub-patterns never get out of sync with `patterns`.
///
/// # Panics
/// Panics if a variable pattern is encountered (those are expected to be stripped by
/// `create_node_for_patterns`).
fn create_node_for_tuple_inner<'db>(
    params: CreateNodeParams<'db, '_, '_>,
    inner_vars: &[FlowControlVar],
    types: &[TypeId<'db>],
    item_idx: usize,
    struct_members: Option<&[&semantic::Member<'db>]>,
) -> NodeId {
    let CreateNodeParams { ctx, graph, patterns, build_node_callback, location } = params;

    // All the items were handled: every remaining pattern is accepted.
    if item_idx == types.len() {
        return build_node_callback(graph, FilteredPatterns::all(patterns.len()), "".into());
    }

    let current_member = struct_members.map(|members| members[item_idx]);

    // The sub-pattern of each pattern for the current item. Must stay index-aligned with
    // `patterns`, since the indices selected below are used to index `patterns`.
    let mut patterns_on_current_item = Vec::<Option<semantic::Pattern<'db>>>::default();
    for pattern in patterns {
        match pattern {
            Some(semantic::Pattern::Tuple(PatternTuple { field_patterns, .. }))
                if current_member.is_none() =>
            {
                patterns_on_current_item
                    .push(Some(get_pattern(ctx, field_patterns[item_idx]).clone()));
            }
            Some(semantic::Pattern::Struct(PatternStruct { field_patterns, .. }))
                if current_member.is_some() =>
            {
                // A struct pattern may omit members; an omitted member yields `None`.
                let item_pattern = field_patterns
                    .iter()
                    .find(|(_, member)| member.id == current_member.unwrap().id)
                    .map(|(pattern, _)| get_pattern(ctx, *pattern));
                patterns_on_current_item.push(item_pattern.cloned());
            }
            Some(semantic::Pattern::Otherwise(..)) | None => {
                patterns_on_current_item.push(None);
            }
            Some(semantic::Pattern::Variable(..)) => unreachable!(),
            Some(semantic::Pattern::Literal(pattern_literal))
                if pattern_literal.literal.ty == ctx.db.core_info().u256 =>
            {
                match handle_u256_literal(ctx, graph, pattern_literal, item_idx) {
                    Ok(inner_pattern) => patterns_on_current_item.push(Some(inner_pattern)),
                    // The diagnostic was already reported. Skipping the push would misalign
                    // `patterns_on_current_item` with `patterns`, so stop here instead.
                    Err(diag_added) => {
                        return graph.add_node(FlowControlNode::Missing(diag_added));
                    }
                }
            }
            Some(
                pattern @ (semantic::Pattern::StringLiteral(..)
                | semantic::Pattern::EnumVariant(..)
                | semantic::Pattern::Literal(..)
                | semantic::Pattern::Tuple(..)
                | semantic::Pattern::Struct(..)
                | semantic::Pattern::FixedSizeArray(..)
                | semantic::Pattern::Missing(..)),
            ) => {
                return graph.report_with_missing_node(
                    pattern.stable_ptr().untyped(),
                    LoweringDiagnosticKind::UnexpectedError,
                );
            }
        }
    }

    let patterns_ref: Vec<_> = patterns_on_current_item.iter().map(|x| x.as_ref()).collect();

    create_node_for_patterns(
        CreateNodeParams {
            ctx,
            graph,
            patterns: &patterns_ref,
            build_node_callback: &mut |graph, pattern_indices, path_head| {
                // Continue with the next item, restricted to the patterns that accepted the
                // current one.
                create_node_for_tuple_inner(
                    CreateNodeParams {
                        ctx,
                        graph,
                        patterns: &pattern_indices.indices().map(|idx| patterns[idx]).collect_vec(),
                        build_node_callback: &mut |graph, pattern_indices_inner, path_tail| {
                            // Translate the indices back to indices into `patterns`.
                            build_node_callback(
                                graph,
                                pattern_indices_inner.lift(&pattern_indices),
                                add_item_to_path(ctx.db, &path_head, &path_tail, current_member),
                            )
                        },
                        location,
                    },
                    inner_vars,
                    types,
                    item_idx + 1,
                    struct_members,
                )
            },
            location,
        },
        inner_vars[item_idx],
    )
}
/// Prepends the description of one item to the description of the items that follow it.
///
/// * `item` - the description of the current item.
/// * `path_tail` - the description of the following items; empty if there are none.
/// * `current_member` - the struct member of the current item, or `None` for a tuple item.
///
/// For a struct member the item is rendered as `name: item`. The item and a non-empty tail are
/// joined by `", "`.
fn add_item_to_path<'db>(
    db: &dyn Database,
    item: &str,
    path_tail: &str,
    current_member: Option<&semantic::Member<'db>>,
) -> String {
    let item_str = match current_member {
        Some(member) => format!("{}: {}", member.id.name(db).long(db), item),
        None => item.to_owned(),
    };
    if path_tail.is_empty() { item_str } else { format!("{item_str}, {path_tail}") }
}
/// Converts a `u256` literal pattern into the `u128` literal pattern of one of its two parts.
///
/// `item_idx` selects the part: `0` yields the low 128 bits of the value, `1` yields the value
/// shifted right by 128 bits. The returned pattern keeps the stable pointers of the original one.
///
/// Returns an error, after reporting a diagnostic, if the literal is not valid for its type.
///
/// # Panics
/// Panics if `item_idx` is neither `0` nor `1`.
fn handle_u256_literal<'db>(
    ctx: &LoweringContext<'db, '_>,
    graph: &mut FlowControlGraphBuilder<'db>,
    pattern_literal: &semantic::PatternLiteral<'db>,
    item_idx: usize,
) -> Maybe<semantic::Pattern<'db>> {
    let literal = &pattern_literal.literal;

    validate_literal(ctx.db, literal.ty, &literal.value).map_err(|err| {
        graph.report(pattern_literal.stable_ptr, LoweringDiagnosticKind::LiteralError(err))
    })?;

    let inner_value = match item_idx {
        0 => literal.value.clone() & ((BigInt::from(1) << 128) - 1),
        1 => literal.value.clone() >> 128,
        _ => unreachable!("Unexpected number of members for u256."),
    };

    let inner_literal = semantic::ExprNumericLiteral {
        value: inner_value,
        ty: ctx.db.core_info().u128,
        stable_ptr: literal.stable_ptr,
    };
    Ok(semantic::Pattern::Literal(semantic::PatternLiteral {
        literal: inner_literal,
        stable_ptr: pattern_literal.stable_ptr,
    }))
}
fn create_node_for_value<'db>(
params: CreateNodeParams<'db, '_, '_>,
input_var: FlowControlVar,
) -> NodeId {
let CreateNodeParams { ctx, graph, patterns, build_node_callback, location: _ } = params;
let var_ty = graph.var_ty(input_var);
let mut literals_map = OrderedHashMap::<BigInt, (FilteredPatterns, ExprPtr<'db>)>::default();
let mut otherwise_filter = FilteredPatterns::default();
for (pattern_index, pattern) in patterns.iter().enumerate() {
match pattern {
Some(semantic::Pattern::Literal(semantic::PatternLiteral { literal, .. })) => {
if let Err(err) = validate_literal(ctx.db, var_ty, &literal.value) {
graph.report(literal.stable_ptr, LoweringDiagnosticKind::LiteralError(err));
continue;
}
literals_map
.entry(literal.value.clone())
.or_insert((otherwise_filter.clone(), literal.stable_ptr))
.0
.add(pattern_index);
}
Some(semantic::Pattern::Otherwise(_)) | None => {
otherwise_filter.add(pattern_index);
for (_, (filter, _)) in literals_map.iter_mut() {
filter.add(pattern_index);
}
}
Some(semantic::Pattern::Variable(..)) => unreachable!(),
Some(
pattern @ (semantic::Pattern::StringLiteral(..)
| semantic::Pattern::EnumVariant(..)
| semantic::Pattern::Struct(..)
| semantic::Pattern::Tuple(..)
| semantic::Pattern::FixedSizeArray(..)
| semantic::Pattern::Missing(..)),
) => {
return graph.report_with_missing_node(
pattern.stable_ptr().untyped(),
LoweringDiagnosticKind::UnexpectedError,
);
}
}
}
let info = ctx.db.core_info();
let felt252_ty = info.felt252;
let value_match_size = optimized_value_match_size(ctx, &literals_map, var_ty != felt252_ty);
let convert_to_felt252 = var_ty != felt252_ty && literals_map.len() > value_match_size;
let input_var_felt252 = if convert_to_felt252 {
graph.new_var(felt252_ty, graph.var_location(input_var))
} else {
input_var
};
let mut current_node = build_node_callback(graph, otherwise_filter.clone(), "_".into());
let value_match_size_bigint = BigInt::from(value_match_size);
for (literal, (filter, stable_ptr)) in literals_map.iter().rev() {
if *literal < value_match_size_bigint {
continue;
}
let node_if_literal = build_node_callback(graph, filter.clone(), literal.to_string());
if node_if_literal == current_node {
continue;
}
current_node = graph.add_node(FlowControlNode::EqualsLiteral(EqualsLiteral {
input: input_var_felt252,
literal: literal.clone(),
stable_ptr: *stable_ptr,
true_branch: node_if_literal,
false_branch: current_node,
}));
}
if convert_to_felt252 {
current_node = graph.add_node(FlowControlNode::Upcast(Upcast {
input: input_var,
output: input_var_felt252,
next: current_node,
}));
}
if value_match_size > 0 {
let bounded_int_ty =
corelib::bounded_int_ty(ctx.db, 0.into(), (value_match_size - 1).into());
let in_range_var = graph.new_var(bounded_int_ty, graph.var_location(input_var));
let nodes = (0..value_match_size)
.map(|i| {
build_node_callback(graph, literals_map[&BigInt::from(i)].0.clone(), i.to_string())
})
.collect();
let value_match_node = graph
.add_node(FlowControlNode::ValueMatch(ValueMatch { matched_var: in_range_var, nodes }));
current_node = graph.add_node(FlowControlNode::Downcast(Downcast {
input: input_var,
output: in_range_var,
in_range: value_match_node,
out_of_range: current_node,
}));
}
current_node
}
/// Returns the number of literals, starting at `0`, that should be handled by a value match.
///
/// The candidate size is the length of the run of consecutive keys `0, 1, 2, ...` present in
/// `values`. It is returned only if that length plus one reaches the threshold given by
/// `numeric_match_optimization_threshold`; otherwise `0` is returned.
fn optimized_value_match_size<'db>(
    ctx: &LoweringContext<'_, '_>,
    values: &OrderedHashMap<BigInt, (FilteredPatterns, ExprPtr<'db>)>,
    is_small_type: bool,
) -> usize {
    let consecutive = (0usize..).take_while(|i| values.contains_key(&BigInt::from(*i))).count();
    // One more than the run length - presumably accounting for the fallback arm; verify.
    let n_arms = consecutive + 1;
    if n_arms < numeric_match_optimization_threshold(ctx, is_small_type) { 0 } else { consecutive }
}
/// Returns the minimal number of arms from which a numeric match is lowered as a value match.
///
/// The value of the `numeric_match_optimization_min_arms_threshold` flag is used when it is set.
/// Otherwise the default is `8` if `is_small_type` is true and `10` if it is false.
pub fn numeric_match_optimization_threshold<'db>(
    ctx: &LoweringContext<'db, '_>,
    is_small_type: bool,
) -> usize {
    ctx.db
        .flag_numeric_match_optimization_min_arms_threshold()
        .unwrap_or_else(|| if is_small_type { 8 } else { 10 })
}
/// Returns `true` if `pattern` is a catch-all, i.e. an otherwise pattern or a variable pattern.
fn pattern_is_any<'a, 'db>(pattern: &'a semantic::Pattern<'db>) -> bool {
    // The match is kept exhaustive so that a new pattern kind has to be classified explicitly.
    match pattern {
        semantic::Pattern::EnumVariant(..)
        | semantic::Pattern::FixedSizeArray(..)
        | semantic::Pattern::Literal(..)
        | semantic::Pattern::Missing(..)
        | semantic::Pattern::StringLiteral(..)
        | semantic::Pattern::Struct(..)
        | semantic::Pattern::Tuple(..) => false,
        semantic::Pattern::Variable(..) | semantic::Pattern::Otherwise(..) => true,
    }
}
/// Returns the semantic pattern stored under `semantic_pattern` in the pattern arena of the
/// function body being lowered.
pub fn get_pattern<'db, 'a>(
    ctx: &'a LoweringContext<'db, '_>,
    semantic_pattern: semantic::PatternId,
) -> &'a semantic::Pattern<'db> {
    let pattern_arena = &ctx.function_body.arenas.patterns;
    &pattern_arena[semantic_pattern]
}