1use crate::priority::call_graph::FunctionId;
9use im::{HashMap, HashSet, Vector};
10use std::path::PathBuf;
11use syn::visit::Visit;
12use syn::{
13 GenericParam, Generics, ImplItem, Item, ItemImpl, ItemTrait, Path as SynPath, TraitItem, Type,
14 TypeParam, TypePath, WhereClause, WherePredicate,
15};
16
17#[derive(Debug, Clone)]
19pub struct TraitDefinition {
20 pub name: String,
21 pub methods: Vector<TraitMethod>,
22 pub associated_types: Vector<AssociatedType>,
23 pub supertraits: Vector<String>,
24 pub generic_params: Vector<GenericParameter>,
25 pub module_path: Vector<String>,
26}
27
/// One `fn` item declared inside a trait body.
#[derive(Clone, Debug)]
pub struct TraitMethod {
    /// Method identifier.
    pub name: String,
    /// `true` when the trait supplies a default body for the method.
    pub has_default: bool,
    /// `true` when the method is declared `async`.
    pub is_async: bool,
    /// Stringified method tokens as produced by the extractor (`quote!` output,
    /// so tokens are space-separated).
    pub signature: String,
}
36
37#[derive(Debug, Clone)]
39pub struct AssociatedType {
40 pub name: String,
41 pub bounds: Vector<String>,
42 pub default: Option<String>,
43}
44
45#[derive(Debug, Clone)]
47pub struct GenericParameter {
48 pub name: String,
49 pub bounds: Vector<String>,
50}
51
52#[derive(Debug, Clone)]
54pub struct Implementation {
55 pub trait_name: String,
56 pub implementing_type: String,
57 pub methods: HashMap<String, MethodImpl>,
58 pub generic_constraints: Vector<WhereClauseItem>,
59 pub is_blanket: bool,
60 pub is_negative: bool,
61 pub module_path: Vector<String>,
62}
63
64#[derive(Debug, Clone)]
66pub struct MethodImpl {
67 pub name: String,
68 pub function_id: FunctionId,
69 pub overrides_default: bool,
70}
71
72#[derive(Debug, Clone)]
74pub struct WhereClauseItem {
75 pub type_param: String,
76 pub bounds: Vector<String>,
77}
78
79#[derive(Debug, Clone)]
81pub struct TraitObject {
82 pub trait_name: String,
83 pub additional_bounds: Vector<String>,
84 pub lifetime: Option<String>,
85}
86
87#[derive(Debug, Clone, Default)]
89pub struct TraitImplementationTracker {
90 pub traits: HashMap<String, TraitDefinition>,
92 pub implementations: HashMap<String, Vector<Implementation>>,
94 pub trait_objects: HashMap<String, HashSet<String>>,
96 pub generic_bounds: HashMap<String, Vector<TraitBound>>,
98 pub type_to_traits: HashMap<String, HashSet<String>>,
100 pub blanket_impls: Vector<Implementation>,
102 pub associated_types: HashMap<(String, String), String>, }
105
106#[derive(Debug, Clone)]
108pub struct TraitBound {
109 pub trait_name: String,
110 pub type_params: Vector<String>,
111}
112
113impl TraitImplementationTracker {
114 pub fn new() -> Self {
115 Self::default()
116 }
117
118 pub fn register_trait(&mut self, trait_def: TraitDefinition) {
120 let name = trait_def.name.clone();
121 self.traits.insert(name, trait_def);
122 }
123
124 pub fn register_implementation(&mut self, implementation: Implementation) {
126 let trait_name = implementation.trait_name.clone();
127 let implementing_type = implementation.implementing_type.clone();
128
129 self.type_to_traits
131 .entry(implementing_type.clone())
132 .or_default()
133 .insert(trait_name.clone());
134
135 if implementation.is_blanket {
137 self.blanket_impls.push_back(implementation.clone());
138 }
139
140 self.implementations
142 .entry(trait_name.clone())
143 .or_default()
144 .push_back(implementation.clone());
145
146 if !implementation.is_negative {
148 self.trait_objects
149 .entry(trait_name)
150 .or_default()
151 .insert(implementing_type);
152 }
153 }
154
155 pub fn get_implementors(&self, trait_name: &str) -> Option<HashSet<String>> {
157 self.trait_objects.get(trait_name).cloned()
158 }
159
160 pub fn resolve_trait_object_call(
162 &self,
163 trait_name: &str,
164 method_name: &str,
165 ) -> Vector<FunctionId> {
166 let mut implementations = Vector::new();
167
168 if let Some(implementors) = self.get_implementors(trait_name) {
170 for impl_type in implementors {
171 if let Some(method_id) = self.resolve_method(&impl_type, trait_name, method_name) {
172 implementations.push_back(method_id);
173 }
174 }
175 }
176
177 implementations
178 }
179
180 pub fn resolve_method(
182 &self,
183 type_name: &str,
184 trait_name: &str,
185 method_name: &str,
186 ) -> Option<FunctionId> {
187 self.implementations
188 .get(trait_name)?
189 .iter()
190 .find(|impl_info| impl_info.implementing_type == type_name)
191 .and_then(|impl_info| impl_info.methods.get(method_name))
192 .map(|method| method.function_id.clone())
193 }
194
195 pub fn resolve_generic_bound(
197 &self,
198 bound: &TraitBound,
199 method_name: &str,
200 ) -> Vector<FunctionId> {
201 let mut implementations = Vector::new();
202
203 if let Some(impls) = self.implementations.get(&bound.trait_name) {
205 for impl_info in impls {
206 if let Some(method) = impl_info.methods.get(method_name) {
209 implementations.push_back(method.function_id.clone());
210 }
211 }
212 }
213
214 for blanket in &self.blanket_impls {
216 if blanket.trait_name == bound.trait_name {
217 if let Some(method) = blanket.methods.get(method_name) {
218 implementations.push_back(method.function_id.clone());
219 }
220 }
221 }
222
223 implementations
224 }
225
226 pub fn implements_trait(&self, type_name: &str, trait_name: &str) -> bool {
228 self.type_to_traits
229 .get(type_name)
230 .map(|traits| traits.contains(trait_name))
231 .unwrap_or(false)
232 }
233
234 pub fn get_traits_for_type(&self, type_name: &str) -> Option<&HashSet<String>> {
236 self.type_to_traits.get(type_name)
237 }
238
239 pub fn resolve_associated_type(&self, type_name: &str, assoc_type: &str) -> Option<String> {
241 self.associated_types
242 .get(&(type_name.to_string(), assoc_type.to_string()))
243 .cloned()
244 }
245
246 pub fn register_associated_type(
248 &mut self,
249 type_name: String,
250 assoc_type: String,
251 resolved_type: String,
252 ) {
253 self.associated_types
254 .insert((type_name, assoc_type), resolved_type);
255 }
256
257 pub fn is_blanket_impl(&self, implementation: &Implementation) -> bool {
259 implementation.implementing_type.contains('<')
261 || !implementation.generic_constraints.is_empty()
262 }
263
264 pub fn get_trait(&self, name: &str) -> Option<&TraitDefinition> {
266 self.traits.get(name)
267 }
268
269 pub fn get_blanket_impls(&self) -> &Vector<Implementation> {
271 &self.blanket_impls
272 }
273
274 pub fn trait_has_method(&self, trait_name: &str, method_name: &str) -> bool {
276 self.traits
277 .get(trait_name)
278 .map(|trait_def| {
279 trait_def
280 .methods
281 .iter()
282 .any(|method| method.name == method_name)
283 })
284 .unwrap_or(false)
285 }
286}
287
288pub struct TraitExtractor {
290 file_path: PathBuf,
291 module_path: Vec<String>,
292 tracker: TraitImplementationTracker,
293}
294
295impl TraitExtractor {
296 pub fn new(file_path: PathBuf) -> Self {
297 Self {
298 file_path,
299 module_path: Vec::new(),
300 tracker: TraitImplementationTracker::new(),
301 }
302 }
303
304 pub fn extract(&mut self, file: &syn::File) -> TraitImplementationTracker {
306 self.visit_file(file);
307 self.tracker.clone()
308 }
309
310 fn extract_trait_definition(&self, item_trait: &ItemTrait) -> TraitDefinition {
311 let mut methods = Vector::new();
312 let mut associated_types = Vector::new();
313
314 for trait_item in &item_trait.items {
315 match trait_item {
316 TraitItem::Fn(method) => {
317 methods.push_back(TraitMethod {
318 name: method.sig.ident.to_string(),
319 has_default: method.default.is_some(),
320 is_async: method.sig.asyncness.is_some(),
321 signature: format!("{}", quote::quote! { #method }),
322 });
323 }
324 TraitItem::Type(assoc_type) => {
325 let bounds = assoc_type
326 .bounds
327 .iter()
328 .map(|b| format!("{}", quote::quote! { #b }))
329 .collect();
330 let default = assoc_type
331 .default
332 .as_ref()
333 .map(|(_, ty)| format!("{}", quote::quote! { #ty }));
334
335 associated_types.push_back(AssociatedType {
336 name: assoc_type.ident.to_string(),
337 bounds,
338 default,
339 });
340 }
341 _ => {}
342 }
343 }
344
345 let generic_params = self.extract_generic_params(&item_trait.generics);
346 let supertraits = self.extract_supertraits(&item_trait.supertraits);
347
348 TraitDefinition {
349 name: item_trait.ident.to_string(),
350 methods,
351 associated_types,
352 supertraits,
353 generic_params,
354 module_path: self.module_path.clone().into(),
355 }
356 }
357
358 fn extract_generic_params(&self, generics: &Generics) -> Vector<GenericParameter> {
359 generics
360 .params
361 .iter()
362 .filter_map(|param| match param {
363 GenericParam::Type(type_param) => Some(self.extract_type_param(type_param)),
364 _ => None,
365 })
366 .collect()
367 }
368
369 fn extract_type_param(&self, type_param: &TypeParam) -> GenericParameter {
370 let bounds = type_param
371 .bounds
372 .iter()
373 .map(|b| format!("{}", quote::quote! { #b }))
374 .collect();
375
376 GenericParameter {
377 name: type_param.ident.to_string(),
378 bounds,
379 }
380 }
381
382 fn extract_supertraits(
383 &self,
384 supertraits: &syn::punctuated::Punctuated<syn::TypeParamBound, syn::token::Plus>,
385 ) -> Vector<String> {
386 supertraits
387 .iter()
388 .filter_map(|bound| match bound {
389 syn::TypeParamBound::Trait(trait_bound) => {
390 Some(self.path_to_string(&trait_bound.path))
391 }
392 _ => None,
393 })
394 .collect()
395 }
396
397 fn extract_implementation(&mut self, item_impl: &ItemImpl) -> Option<Implementation> {
398 let (_, trait_path, _) = item_impl.trait_.as_ref()?;
399 let trait_name = self.path_to_string(trait_path);
400 let implementing_type = self.type_to_string(&item_impl.self_ty);
401
402 let mut methods = HashMap::new();
403 for impl_item in &item_impl.items {
404 if let ImplItem::Fn(method) = impl_item {
405 let method_name = method.sig.ident.to_string();
406 let line = method.sig.ident.span().start().line;
407
408 let method_impl = MethodImpl {
409 name: method_name.clone(),
410 function_id: FunctionId::new(
411 self.file_path.clone(),
412 format!("{}::{}", implementing_type, method_name),
413 line,
414 ),
415 overrides_default: false, };
417
418 methods.insert(method_name, method_impl);
419 }
420 }
421
422 let generic_constraints =
423 self.extract_where_clause(item_impl.generics.where_clause.as_ref());
424 let is_blanket = self.is_blanket_implementation(item_impl);
425 let is_negative = false; Some(Implementation {
428 trait_name,
429 implementing_type,
430 methods,
431 generic_constraints,
432 is_blanket,
433 is_negative,
434 module_path: self.module_path.clone().into(),
435 })
436 }
437
438 fn extract_where_clause(&self, where_clause: Option<&WhereClause>) -> Vector<WhereClauseItem> {
439 where_clause
440 .map(|wc| {
441 wc.predicates
442 .iter()
443 .filter_map(|pred| match pred {
444 WherePredicate::Type(type_pred) => {
445 let type_param = self.type_to_string(&type_pred.bounded_ty);
446 let bounds = type_pred
447 .bounds
448 .iter()
449 .map(|b| format!("{}", quote::quote! { #b }))
450 .collect();
451 Some(WhereClauseItem { type_param, bounds })
452 }
453 _ => None,
454 })
455 .collect()
456 })
457 .unwrap_or_default()
458 }
459
460 fn is_blanket_implementation(&self, item_impl: &ItemImpl) -> bool {
461 matches!(&*item_impl.self_ty, Type::Path(TypePath { path, .. }) if path.segments.iter().any(|seg| !seg.arguments.is_empty()))
463 || !item_impl.generics.params.is_empty()
464 }
465
466 fn type_to_string(&self, ty: &Type) -> String {
467 format!("{}", quote::quote! { #ty })
468 .replace(" ", "")
469 .replace(",", ", ")
470 }
471
472 fn path_to_string(&self, path: &SynPath) -> String {
473 path.segments
474 .iter()
475 .map(|seg| seg.ident.to_string())
476 .collect::<Vec<_>>()
477 .join("::")
478 }
479}
480
481impl<'ast> Visit<'ast> for TraitExtractor {
482 fn visit_item(&mut self, item: &'ast Item) {
483 match item {
484 Item::Trait(item_trait) => {
485 let trait_def = self.extract_trait_definition(item_trait);
486 self.tracker.register_trait(trait_def);
487 }
488 Item::Impl(item_impl) => {
489 if let Some(implementation) = self.extract_implementation(item_impl) {
490 self.tracker.register_implementation(implementation);
491 }
492 }
493 Item::Mod(item_mod) => {
494 self.module_path.push(item_mod.ident.to_string());
495 }
496 _ => {}
497 }
498
499 syn::visit::visit_item(self, item);
500
501 if matches!(item, Item::Mod(_)) {
503 self.module_path.pop();
504 }
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a positive, non-blanket implementation with no methods,
    /// no constraints, and an empty module path.
    fn bare_impl(trait_name: &str, implementing_type: &str) -> Implementation {
        Implementation {
            trait_name: trait_name.to_string(),
            implementing_type: implementing_type.to_string(),
            methods: HashMap::new(),
            generic_constraints: Vector::new(),
            is_blanket: false,
            is_negative: false,
            module_path: Vector::new(),
        }
    }

    #[test]
    fn test_trait_implementation_tracker_new() {
        let empty = TraitImplementationTracker::new();
        assert!(empty.traits.is_empty());
        assert!(empty.implementations.is_empty());
    }

    #[test]
    fn test_register_trait() {
        let mut tracker = TraitImplementationTracker::new();
        tracker.register_trait(TraitDefinition {
            name: "TestTrait".to_string(),
            methods: Vector::new(),
            associated_types: Vector::new(),
            supertraits: Vector::new(),
            generic_params: Vector::new(),
            module_path: Vector::new(),
        });

        assert!(tracker.get_trait("TestTrait").is_some());
    }

    #[test]
    fn test_implements_trait() {
        let mut tracker = TraitImplementationTracker::new();
        tracker.register_implementation(bare_impl("Display", "MyType"));

        assert!(tracker.implements_trait("MyType", "Display"));
        assert!(!tracker.implements_trait("MyType", "Debug"));
    }
}