Skip to main content

lisette_semantics/
call_classification.rs

1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2
3use syntax::ast::Expression;
4use syntax::program::{DefinitionBody, Module};
5use syntax::types::{Symbol, Type, type_args_match_params};
6
7pub fn is_ufcs_method_type(method_ty: &Type, base_generics_count: usize) -> bool {
8    let Type::Forall { vars, body } = method_ty else {
9        return base_generics_count > 0;
10    };
11
12    if vars.len() > base_generics_count {
13        return true;
14    }
15
16    if let Type::Function(f) = body.as_ref()
17        && let Some(receiver_param) = f.params.first()
18        && let Type::Nominal {
19            params: receiver_params,
20            ..
21        } = receiver_param.strip_refs()
22        && !type_args_match_params(&receiver_params, vars.iter())
23    {
24        return true;
25    }
26
27    false
28}
29
30/// Compute UFCS methods for a single module's types.
31///
32/// Three conditions (any one suffices):
33/// 1. Extra type params: method's Forall vars exceed base type's generics count
34/// 2. Partial receiver: receiver is not the impl's own type parameters in order
35/// 3. Mixed impl blocks: type has both bounded and unbounded impl blocks
36pub fn compute_module_ufcs(module: &Module, module_id: &str) -> Vec<(String, String)> {
37    let mut ufcs = Vec::new();
38
39    // Conditions 1+2: check each method's type signature
40    for (key, definition) in &module.definitions {
41        let (methods, base_generics_count) = match &definition.body {
42            DefinitionBody::Struct {
43                methods, generics, ..
44            } => (methods, generics.len()),
45            DefinitionBody::Enum {
46                methods, generics, ..
47            } => (methods, generics.len()),
48            DefinitionBody::TypeAlias {
49                methods, generics, ..
50            } => (methods, generics.len()),
51            _ => continue,
52        };
53
54        for (method_name, method_ty) in methods {
55            if is_ufcs_method_type(method_ty, base_generics_count) {
56                ufcs.push((key.to_string(), method_name.to_string()));
57            }
58        }
59    }
60
61    // Condition 3: mixed constrained/unconstrained impl blocks
62    let mut constrained_methods: HashMap<String, Vec<String>> = HashMap::default();
63    let mut unconstrained_types: HashSet<String> = HashSet::default();
64
65    for file in module.files.values() {
66        for item in &file.items {
67            if let Expression::ImplBlock {
68                receiver_name,
69                generics,
70                methods,
71                ..
72            } = item
73            {
74                let qualified_type = Symbol::from_parts(module_id, receiver_name).to_string();
75                if generics.iter().any(|g| !g.bounds.is_empty()) {
76                    let method_names: Vec<String> = methods
77                        .iter()
78                        .filter_map(|m| {
79                            if let Expression::Function { name, .. } = m {
80                                Some(name.to_string())
81                            } else {
82                                None
83                            }
84                        })
85                        .collect();
86                    constrained_methods
87                        .entry(qualified_type)
88                        .or_default()
89                        .extend(method_names);
90                } else {
91                    unconstrained_types.insert(qualified_type);
92                }
93            }
94        }
95    }
96
97    for (type_name, methods) in constrained_methods {
98        if unconstrained_types.contains(&type_name) {
99            for method_name in methods {
100                ufcs.push((type_name.clone(), method_name));
101            }
102        }
103    }
104
105    ufcs
106}