// memscope_rs/estimation/type_classifier.rs
use std::collections::HashMap;

/// Coarse bucket that a Rust type name is sorted into by `TypeClassifier`.
///
/// The enum carries no data, so it is `Copy` and can be hashed/compared,
/// which lets callers use it directly as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TypeCategory {
    /// Built-in scalar types such as `i32`, `f64`, `bool` and `char`.
    Primitive,
    /// Standard collections such as `Vec`, `HashMap` and `BTreeSet`.
    Collection,
    /// Smart pointers such as `Box`, `Arc`, `Rc` and `Weak`.
    SmartPointer,
    /// Fallback for any type name not matched by another category.
    UserDefined,
    /// Type names spelled with a `std::` path prefix.
    System,
    /// Type names that mention `Future` or `Stream`.
    Async,
}
12
13pub struct TypeClassifier {
14 category_map: HashMap<String, TypeCategory>,
15}
16
17impl TypeClassifier {
18 pub fn new() -> Self {
19 let mut classifier = Self {
20 category_map: HashMap::new(),
21 };
22 classifier.initialize_categories();
23 classifier
24 }
25
26 fn initialize_categories(&mut self) {
27 let primitives = [
28 "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64",
29 "bool", "char", "usize", "isize",
30 ];
31
32 for &prim in &primitives {
33 self.category_map
34 .insert(prim.to_string(), TypeCategory::Primitive);
35 }
36
37 let collections = [
38 "Vec", "HashMap", "BTreeMap", "HashSet", "BTreeSet", "VecDeque",
39 ];
40 for &coll in &collections {
41 self.category_map
42 .insert(coll.to_string(), TypeCategory::Collection);
43 }
44
45 let smart_ptrs = ["Box", "Arc", "Rc", "Weak"];
46 for &ptr in &smart_ptrs {
47 self.category_map
48 .insert(ptr.to_string(), TypeCategory::SmartPointer);
49 }
50 }
51
52 pub fn classify(&self, type_name: &str) -> TypeCategory {
53 if let Some(category) = self.category_map.get(type_name) {
54 return category.clone();
55 }
56
57 if type_name.starts_with("std::") {
58 return TypeCategory::System;
59 }
60
61 for (prefix, category) in &self.category_map {
62 if type_name.starts_with(prefix) && type_name.contains('<') {
63 return category.clone();
64 }
65 }
66
67 if type_name.contains("Future") || type_name.contains("Stream") {
68 return TypeCategory::Async;
69 }
70
71 TypeCategory::UserDefined
72 }
73
74 pub fn is_container(&self, type_name: &str) -> bool {
75 matches!(self.classify(type_name), TypeCategory::Collection)
76 }
77
78 pub fn is_smart_pointer(&self, type_name: &str) -> bool {
79 matches!(self.classify(type_name), TypeCategory::SmartPointer)
80 }
81}
82
83impl Default for TypeClassifier {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_primitive_classification() {
        let classifier = TypeClassifier::new();
        assert_eq!(classifier.classify("i32"), TypeCategory::Primitive);
        assert_eq!(classifier.classify("bool"), TypeCategory::Primitive);
    }

    #[test]
    fn test_collection_classification() {
        let classifier = TypeClassifier::new();
        assert_eq!(classifier.classify("Vec<i32>"), TypeCategory::Collection);
        assert_eq!(
            classifier.classify("HashMap<String, i32>"),
            TypeCategory::Collection
        );
    }

    #[test]
    fn test_smart_pointer_classification() {
        let classifier = TypeClassifier::new();
        assert_eq!(classifier.classify("Box<i32>"), TypeCategory::SmartPointer);
        assert_eq!(
            classifier.classify("Arc<String>"),
            TypeCategory::SmartPointer
        );
    }

    #[test]
    fn test_system_classification() {
        let classifier = TypeClassifier::new();
        assert_eq!(
            classifier.classify("std::string::String"),
            TypeCategory::System
        );
        assert_eq!(classifier.classify("std::vec::Vec"), TypeCategory::System);
    }

    #[test]
    fn test_helper_methods() {
        let classifier = TypeClassifier::new();
        assert!(classifier.is_container("Vec<i32>"));
        assert!(classifier.is_smart_pointer("Box<String>"));
        assert!(!classifier.is_container("i32"));
    }

    /// Bare (non-generic) names and nested generics are classified by their outer type.
    #[test]
    fn test_bare_and_nested_generic_names() {
        let classifier = TypeClassifier::new();
        assert_eq!(classifier.classify("Vec"), TypeCategory::Collection);
        assert_eq!(classifier.classify("VecDeque<u8>"), TypeCategory::Collection);
        assert_eq!(
            classifier.classify("Rc<RefCell<i32>>"),
            TypeCategory::SmartPointer
        );
        // The outer type decides, even when a known collection appears inside.
        assert_eq!(
            classifier.classify("Option<Vec<u8>>"),
            TypeCategory::UserDefined
        );
        assert!(!classifier.is_smart_pointer("Vec<u8>"));
    }

    /// Names mentioning `Future` or `Stream` fall into the async bucket.
    #[test]
    fn test_async_classification() {
        let classifier = TypeClassifier::new();
        assert_eq!(
            classifier.classify("Pin<Box<dyn Future<Output = ()>>>"),
            TypeCategory::Async
        );
        assert_eq!(classifier.classify("MyStream"), TypeCategory::Async);
    }

    /// Anything not matched by another rule is treated as user-defined.
    #[test]
    fn test_user_defined_fallback() {
        let classifier = TypeClassifier::new();
        assert_eq!(classifier.classify("MyStruct"), TypeCategory::UserDefined);
        assert_eq!(classifier.classify(""), TypeCategory::UserDefined);
    }

    #[test]
    fn test_default_matches_new() {
        let classifier = TypeClassifier::default();
        assert_eq!(classifier.classify("u64"), TypeCategory::Primitive);
        assert!(classifier.is_container("BTreeSet<u8>"));
    }
}