Skip to main content

debtmap/analyzers/
trait_implementation_tracker.rs

//! Comprehensive trait implementation tracking for dynamic dispatch resolution.
//!
//! This module extends the trait registry with full support for:
//! - Generic trait implementations
//! - Trait object resolution
//! - Blanket implementations
//! - Associated types and methods
8use crate::priority::call_graph::FunctionId;
9use im::{HashMap, HashSet, Vector};
10use std::path::PathBuf;
11use syn::visit::Visit;
12use syn::{
13    GenericParam, Generics, ImplItem, Item, ItemImpl, ItemTrait, Path as SynPath, TraitItem, Type,
14    TypeParam, TypePath, WhereClause, WherePredicate,
15};
16
17/// Represents a trait definition with all its details
18#[derive(Debug, Clone)]
19pub struct TraitDefinition {
20    pub name: String,
21    pub methods: Vector<TraitMethod>,
22    pub associated_types: Vector<AssociatedType>,
23    pub supertraits: Vector<String>,
24    pub generic_params: Vector<GenericParameter>,
25    pub module_path: Vector<String>,
26}
27
/// A single method declared inside a trait.
#[derive(Clone, Debug)]
pub struct TraitMethod {
    /// Identifier of the method.
    pub name: String,
    /// `true` when the trait supplies a default body for the method.
    pub has_default: bool,
    /// `true` when the method is declared `async`.
    pub is_async: bool,
    /// Textual rendering of the method's tokens, produced with `quote`.
    pub signature: String,
}
36
37/// Represents an associated type in a trait
38#[derive(Debug, Clone)]
39pub struct AssociatedType {
40    pub name: String,
41    pub bounds: Vector<String>,
42    pub default: Option<String>,
43}
44
45/// Represents a generic parameter on a trait
46#[derive(Debug, Clone)]
47pub struct GenericParameter {
48    pub name: String,
49    pub bounds: Vector<String>,
50}
51
52/// Represents a trait implementation
53#[derive(Debug, Clone)]
54pub struct Implementation {
55    pub trait_name: String,
56    pub implementing_type: String,
57    pub methods: HashMap<String, MethodImpl>,
58    pub generic_constraints: Vector<WhereClauseItem>,
59    pub is_blanket: bool,
60    pub is_negative: bool,
61    pub module_path: Vector<String>,
62}
63
64/// Represents a method implementation
65#[derive(Debug, Clone)]
66pub struct MethodImpl {
67    pub name: String,
68    pub function_id: FunctionId,
69    pub overrides_default: bool,
70}
71
72/// Represents a where clause constraint
73#[derive(Debug, Clone)]
74pub struct WhereClauseItem {
75    pub type_param: String,
76    pub bounds: Vector<String>,
77}
78
79/// Trait object information
80#[derive(Debug, Clone)]
81pub struct TraitObject {
82    pub trait_name: String,
83    pub additional_bounds: Vector<String>,
84    pub lifetime: Option<String>,
85}
86
87/// The main trait implementation tracker
88#[derive(Debug, Clone, Default)]
89pub struct TraitImplementationTracker {
90    /// All trait definitions indexed by name
91    pub traits: HashMap<String, TraitDefinition>,
92    /// All implementations indexed by trait name
93    pub implementations: HashMap<String, Vector<Implementation>>,
94    /// Trait object candidates (types that can be behind trait objects)
95    pub trait_objects: HashMap<String, HashSet<String>>,
96    /// Generic bounds registry
97    pub generic_bounds: HashMap<String, Vector<TraitBound>>,
98    /// Type to trait mapping for quick lookup
99    pub type_to_traits: HashMap<String, HashSet<String>>,
100    /// Blanket implementations
101    pub blanket_impls: Vector<Implementation>,
102    /// Associated type projections
103    pub associated_types: HashMap<(String, String), String>, // (Type, AssocType) -> ResolvedType
104}
105
106/// Represents a trait bound
107#[derive(Debug, Clone)]
108pub struct TraitBound {
109    pub trait_name: String,
110    pub type_params: Vector<String>,
111}
112
113impl TraitImplementationTracker {
114    pub fn new() -> Self {
115        Self::default()
116    }
117
118    /// Register a trait definition
119    pub fn register_trait(&mut self, trait_def: TraitDefinition) {
120        let name = trait_def.name.clone();
121        self.traits.insert(name, trait_def);
122    }
123
124    /// Register a trait implementation
125    pub fn register_implementation(&mut self, implementation: Implementation) {
126        let trait_name = implementation.trait_name.clone();
127        let implementing_type = implementation.implementing_type.clone();
128
129        // Update type to trait mapping
130        self.type_to_traits
131            .entry(implementing_type.clone())
132            .or_default()
133            .insert(trait_name.clone());
134
135        // Track blanket implementations separately
136        if implementation.is_blanket {
137            self.blanket_impls.push_back(implementation.clone());
138        }
139
140        // Add to regular implementations
141        self.implementations
142            .entry(trait_name.clone())
143            .or_default()
144            .push_back(implementation.clone());
145
146        // Track trait object candidates
147        if !implementation.is_negative {
148            self.trait_objects
149                .entry(trait_name)
150                .or_default()
151                .insert(implementing_type);
152        }
153    }
154
155    /// Get all types that implement a trait
156    pub fn get_implementors(&self, trait_name: &str) -> Option<HashSet<String>> {
157        self.trait_objects.get(trait_name).cloned()
158    }
159
160    /// Resolve a trait object method call to concrete implementations
161    pub fn resolve_trait_object_call(
162        &self,
163        trait_name: &str,
164        method_name: &str,
165    ) -> Vector<FunctionId> {
166        let mut implementations = Vector::new();
167
168        // Find all types that implement this trait
169        if let Some(implementors) = self.get_implementors(trait_name) {
170            for impl_type in implementors {
171                if let Some(method_id) = self.resolve_method(&impl_type, trait_name, method_name) {
172                    implementations.push_back(method_id);
173                }
174            }
175        }
176
177        implementations
178    }
179
180    /// Resolve a method call on a specific type for a specific trait
181    pub fn resolve_method(
182        &self,
183        type_name: &str,
184        trait_name: &str,
185        method_name: &str,
186    ) -> Option<FunctionId> {
187        self.implementations
188            .get(trait_name)?
189            .iter()
190            .find(|impl_info| impl_info.implementing_type == type_name)
191            .and_then(|impl_info| impl_info.methods.get(method_name))
192            .map(|method| method.function_id.clone())
193    }
194
195    /// Resolve generic constraint to possible implementations
196    pub fn resolve_generic_bound(
197        &self,
198        bound: &TraitBound,
199        method_name: &str,
200    ) -> Vector<FunctionId> {
201        let mut implementations = Vector::new();
202
203        // Find all types satisfying the bound
204        if let Some(impls) = self.implementations.get(&bound.trait_name) {
205            for impl_info in impls {
206                // Check if this implementation satisfies the bound
207                // This is simplified - real implementation would need constraint checking
208                if let Some(method) = impl_info.methods.get(method_name) {
209                    implementations.push_back(method.function_id.clone());
210                }
211            }
212        }
213
214        // Check blanket implementations
215        for blanket in &self.blanket_impls {
216            if blanket.trait_name == bound.trait_name {
217                if let Some(method) = blanket.methods.get(method_name) {
218                    implementations.push_back(method.function_id.clone());
219                }
220            }
221        }
222
223        implementations
224    }
225
226    /// Check if a type implements a trait
227    pub fn implements_trait(&self, type_name: &str, trait_name: &str) -> bool {
228        self.type_to_traits
229            .get(type_name)
230            .map(|traits| traits.contains(trait_name))
231            .unwrap_or(false)
232    }
233
234    /// Get all traits implemented by a type
235    pub fn get_traits_for_type(&self, type_name: &str) -> Option<&HashSet<String>> {
236        self.type_to_traits.get(type_name)
237    }
238
239    /// Resolve associated type projection
240    pub fn resolve_associated_type(&self, type_name: &str, assoc_type: &str) -> Option<String> {
241        self.associated_types
242            .get(&(type_name.to_string(), assoc_type.to_string()))
243            .cloned()
244    }
245
246    /// Register an associated type projection
247    pub fn register_associated_type(
248        &mut self,
249        type_name: String,
250        assoc_type: String,
251        resolved_type: String,
252    ) {
253        self.associated_types
254            .insert((type_name, assoc_type), resolved_type);
255    }
256
257    /// Check if an implementation is a blanket implementation
258    pub fn is_blanket_impl(&self, implementation: &Implementation) -> bool {
259        // Check if the implementing type contains generic parameters
260        implementation.implementing_type.contains('<')
261            || !implementation.generic_constraints.is_empty()
262    }
263
264    /// Get trait definition by name
265    pub fn get_trait(&self, name: &str) -> Option<&TraitDefinition> {
266        self.traits.get(name)
267    }
268
269    /// Get all blanket implementations
270    pub fn get_blanket_impls(&self) -> &Vector<Implementation> {
271        &self.blanket_impls
272    }
273
274    /// Check if a method exists in a trait
275    pub fn trait_has_method(&self, trait_name: &str, method_name: &str) -> bool {
276        self.traits
277            .get(trait_name)
278            .map(|trait_def| {
279                trait_def
280                    .methods
281                    .iter()
282                    .any(|method| method.name == method_name)
283            })
284            .unwrap_or(false)
285    }
286}
287
288/// AST visitor for extracting trait definitions and implementations
289pub struct TraitExtractor {
290    file_path: PathBuf,
291    module_path: Vec<String>,
292    tracker: TraitImplementationTracker,
293}
294
295impl TraitExtractor {
296    pub fn new(file_path: PathBuf) -> Self {
297        Self {
298            file_path,
299            module_path: Vec::new(),
300            tracker: TraitImplementationTracker::new(),
301        }
302    }
303
304    /// Extract trait information from a file
305    pub fn extract(&mut self, file: &syn::File) -> TraitImplementationTracker {
306        self.visit_file(file);
307        self.tracker.clone()
308    }
309
310    fn extract_trait_definition(&self, item_trait: &ItemTrait) -> TraitDefinition {
311        let mut methods = Vector::new();
312        let mut associated_types = Vector::new();
313
314        for trait_item in &item_trait.items {
315            match trait_item {
316                TraitItem::Fn(method) => {
317                    methods.push_back(TraitMethod {
318                        name: method.sig.ident.to_string(),
319                        has_default: method.default.is_some(),
320                        is_async: method.sig.asyncness.is_some(),
321                        signature: format!("{}", quote::quote! { #method }),
322                    });
323                }
324                TraitItem::Type(assoc_type) => {
325                    let bounds = assoc_type
326                        .bounds
327                        .iter()
328                        .map(|b| format!("{}", quote::quote! { #b }))
329                        .collect();
330                    let default = assoc_type
331                        .default
332                        .as_ref()
333                        .map(|(_, ty)| format!("{}", quote::quote! { #ty }));
334
335                    associated_types.push_back(AssociatedType {
336                        name: assoc_type.ident.to_string(),
337                        bounds,
338                        default,
339                    });
340                }
341                _ => {}
342            }
343        }
344
345        let generic_params = self.extract_generic_params(&item_trait.generics);
346        let supertraits = self.extract_supertraits(&item_trait.supertraits);
347
348        TraitDefinition {
349            name: item_trait.ident.to_string(),
350            methods,
351            associated_types,
352            supertraits,
353            generic_params,
354            module_path: self.module_path.clone().into(),
355        }
356    }
357
358    fn extract_generic_params(&self, generics: &Generics) -> Vector<GenericParameter> {
359        generics
360            .params
361            .iter()
362            .filter_map(|param| match param {
363                GenericParam::Type(type_param) => Some(self.extract_type_param(type_param)),
364                _ => None,
365            })
366            .collect()
367    }
368
369    fn extract_type_param(&self, type_param: &TypeParam) -> GenericParameter {
370        let bounds = type_param
371            .bounds
372            .iter()
373            .map(|b| format!("{}", quote::quote! { #b }))
374            .collect();
375
376        GenericParameter {
377            name: type_param.ident.to_string(),
378            bounds,
379        }
380    }
381
382    fn extract_supertraits(
383        &self,
384        supertraits: &syn::punctuated::Punctuated<syn::TypeParamBound, syn::token::Plus>,
385    ) -> Vector<String> {
386        supertraits
387            .iter()
388            .filter_map(|bound| match bound {
389                syn::TypeParamBound::Trait(trait_bound) => {
390                    Some(self.path_to_string(&trait_bound.path))
391                }
392                _ => None,
393            })
394            .collect()
395    }
396
397    fn extract_implementation(&mut self, item_impl: &ItemImpl) -> Option<Implementation> {
398        let (_, trait_path, _) = item_impl.trait_.as_ref()?;
399        let trait_name = self.path_to_string(trait_path);
400        let implementing_type = self.type_to_string(&item_impl.self_ty);
401
402        let mut methods = HashMap::new();
403        for impl_item in &item_impl.items {
404            if let ImplItem::Fn(method) = impl_item {
405                let method_name = method.sig.ident.to_string();
406                let line = method.sig.ident.span().start().line;
407
408                let method_impl = MethodImpl {
409                    name: method_name.clone(),
410                    function_id: FunctionId::new(
411                        self.file_path.clone(),
412                        format!("{}::{}", implementing_type, method_name),
413                        line,
414                    ),
415                    overrides_default: false, // Would need trait definition to determine
416                };
417
418                methods.insert(method_name, method_impl);
419            }
420        }
421
422        let generic_constraints =
423            self.extract_where_clause(item_impl.generics.where_clause.as_ref());
424        let is_blanket = self.is_blanket_implementation(item_impl);
425        let is_negative = false; // Negative implementations are not directly supported in stable Rust
426
427        Some(Implementation {
428            trait_name,
429            implementing_type,
430            methods,
431            generic_constraints,
432            is_blanket,
433            is_negative,
434            module_path: self.module_path.clone().into(),
435        })
436    }
437
438    fn extract_where_clause(&self, where_clause: Option<&WhereClause>) -> Vector<WhereClauseItem> {
439        where_clause
440            .map(|wc| {
441                wc.predicates
442                    .iter()
443                    .filter_map(|pred| match pred {
444                        WherePredicate::Type(type_pred) => {
445                            let type_param = self.type_to_string(&type_pred.bounded_ty);
446                            let bounds = type_pred
447                                .bounds
448                                .iter()
449                                .map(|b| format!("{}", quote::quote! { #b }))
450                                .collect();
451                            Some(WhereClauseItem { type_param, bounds })
452                        }
453                        _ => None,
454                    })
455                    .collect()
456            })
457            .unwrap_or_default()
458    }
459
460    fn is_blanket_implementation(&self, item_impl: &ItemImpl) -> bool {
461        // Check if implementing type is generic
462        matches!(&*item_impl.self_ty, Type::Path(TypePath { path, .. }) if path.segments.iter().any(|seg| !seg.arguments.is_empty()))
463            || !item_impl.generics.params.is_empty()
464    }
465
466    fn type_to_string(&self, ty: &Type) -> String {
467        format!("{}", quote::quote! { #ty })
468            .replace(" ", "")
469            .replace(",", ", ")
470    }
471
472    fn path_to_string(&self, path: &SynPath) -> String {
473        path.segments
474            .iter()
475            .map(|seg| seg.ident.to_string())
476            .collect::<Vec<_>>()
477            .join("::")
478    }
479}
480
481impl<'ast> Visit<'ast> for TraitExtractor {
482    fn visit_item(&mut self, item: &'ast Item) {
483        match item {
484            Item::Trait(item_trait) => {
485                let trait_def = self.extract_trait_definition(item_trait);
486                self.tracker.register_trait(trait_def);
487            }
488            Item::Impl(item_impl) => {
489                if let Some(implementation) = self.extract_implementation(item_impl) {
490                    self.tracker.register_implementation(implementation);
491                }
492            }
493            Item::Mod(item_mod) => {
494                self.module_path.push(item_mod.ident.to_string());
495            }
496            _ => {}
497        }
498
499        syn::visit::visit_item(self, item);
500
501        // Pop module path after visiting
502        if matches!(item, Item::Mod(_)) {
503            self.module_path.pop();
504        }
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trait definition that has a name and nothing else.
    fn empty_trait(name: &str) -> TraitDefinition {
        TraitDefinition {
            name: name.to_string(),
            methods: Vector::new(),
            associated_types: Vector::new(),
            supertraits: Vector::new(),
            generic_params: Vector::new(),
            module_path: Vector::new(),
        }
    }

    /// Builds a plain (non-blanket, non-negative) impl without any methods.
    fn plain_impl(trait_name: &str, implementing_type: &str) -> Implementation {
        Implementation {
            trait_name: trait_name.to_string(),
            implementing_type: implementing_type.to_string(),
            methods: HashMap::new(),
            generic_constraints: Vector::new(),
            is_blanket: false,
            is_negative: false,
            module_path: Vector::new(),
        }
    }

    #[test]
    fn test_trait_implementation_tracker_new() {
        let tracker = TraitImplementationTracker::new();

        assert!(tracker.traits.is_empty());
        assert!(tracker.implementations.is_empty());
    }

    #[test]
    fn test_register_trait() {
        let mut tracker = TraitImplementationTracker::new();

        tracker.register_trait(empty_trait("TestTrait"));

        assert!(tracker.get_trait("TestTrait").is_some());
    }

    #[test]
    fn test_implements_trait() {
        let mut tracker = TraitImplementationTracker::new();

        tracker.register_implementation(plain_impl("Display", "MyType"));

        assert!(tracker.implements_trait("MyType", "Display"));
        assert!(!tracker.implements_trait("MyType", "Debug"));
    }
}