memscope_rs/estimation/type_classifier.rs

1use std::collections::HashMap;
2
/// Coarse category assigned to a type name by [`TypeClassifier::classify`].
///
/// The enum carries no data, so it is cheap to copy and can be compared,
/// hashed and used as a `HashMap`/`HashSet` key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TypeCategory {
    /// Built-in scalar types: the fixed-width integers, `usize`/`isize`,
    /// `f32`/`f64`, `bool` and `char`.
    Primitive,
    /// Standard collections: `Vec`, `HashMap`, `BTreeMap`, `HashSet`,
    /// `BTreeSet` and `VecDeque`.
    Collection,
    /// Smart pointers: `Box`, `Arc`, `Rc` and `Weak`.
    SmartPointer,
    /// Fallback for any name that matches none of the other categories.
    UserDefined,
    /// Names written as a path starting with `std::`.
    System,
    /// Names that look asynchronous (containing `Future` or `Stream`).
    Async,
}

13pub struct TypeClassifier {
14    category_map: HashMap<String, TypeCategory>,
15}
16
17impl TypeClassifier {
18    pub fn new() -> Self {
19        let mut classifier = Self {
20            category_map: HashMap::new(),
21        };
22        classifier.initialize_categories();
23        classifier
24    }
25
26    fn initialize_categories(&mut self) {
27        let primitives = [
28            "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64",
29            "bool", "char", "usize", "isize",
30        ];
31
32        for &prim in &primitives {
33            self.category_map
34                .insert(prim.to_string(), TypeCategory::Primitive);
35        }
36
37        let collections = [
38            "Vec", "HashMap", "BTreeMap", "HashSet", "BTreeSet", "VecDeque",
39        ];
40        for &coll in &collections {
41            self.category_map
42                .insert(coll.to_string(), TypeCategory::Collection);
43        }
44
45        let smart_ptrs = ["Box", "Arc", "Rc", "Weak"];
46        for &ptr in &smart_ptrs {
47            self.category_map
48                .insert(ptr.to_string(), TypeCategory::SmartPointer);
49        }
50    }
51
52    pub fn classify(&self, type_name: &str) -> TypeCategory {
53        if let Some(category) = self.category_map.get(type_name) {
54            return category.clone();
55        }
56
57        if type_name.starts_with("std::") {
58            return TypeCategory::System;
59        }
60
61        for (prefix, category) in &self.category_map {
62            if type_name.starts_with(prefix) && type_name.contains('<') {
63                return category.clone();
64            }
65        }
66
67        if type_name.contains("Future") || type_name.contains("Stream") {
68            return TypeCategory::Async;
69        }
70
71        TypeCategory::UserDefined
72    }
73
74    pub fn is_container(&self, type_name: &str) -> bool {
75        matches!(self.classify(type_name), TypeCategory::Collection)
76    }
77
78    pub fn is_smart_pointer(&self, type_name: &str) -> bool {
79        matches!(self.classify(type_name), TypeCategory::SmartPointer)
80    }
81}
82
83impl Default for TypeClassifier {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh classifier and checks each `(type name, expected)` pair.
    fn expect_all(cases: &[(&str, TypeCategory)]) {
        let classifier = TypeClassifier::new();
        for (name, expected) in cases {
            assert_eq!(classifier.classify(name), *expected);
        }
    }

    #[test]
    fn test_primitive_classification() {
        expect_all(&[
            ("i32", TypeCategory::Primitive),
            ("bool", TypeCategory::Primitive),
        ]);
    }

    #[test]
    fn test_collection_classification() {
        expect_all(&[
            ("Vec<i32>", TypeCategory::Collection),
            ("HashMap<String, i32>", TypeCategory::Collection),
        ]);
    }

    #[test]
    fn test_smart_pointer_classification() {
        expect_all(&[
            ("Box<i32>", TypeCategory::SmartPointer),
            ("Arc<String>", TypeCategory::SmartPointer),
        ]);
    }

    #[test]
    fn test_system_classification() {
        expect_all(&[
            ("std::string::String", TypeCategory::System),
            ("std::vec::Vec", TypeCategory::System),
        ]);
    }

    #[test]
    fn test_helper_methods() {
        let c = TypeClassifier::new();
        assert!(c.is_container("Vec<i32>"));
        assert!(c.is_smart_pointer("Box<String>"));
        assert!(!c.is_container("i32"));
    }
}