memscope_rs/classification/
type_classifier.rs

1use regex::Regex;
2use std::collections::HashMap;
3use std::sync::OnceLock;
4
/// High-level category a type name is sorted into for memory analysis.
///
/// Besides identifying the category, each value carries presentation data:
/// a label ([`TypeCategory::display_name`]), a CSS class
/// ([`TypeCategory::css_class`]) and a colour ([`TypeCategory::color`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TypeCategory {
    /// Built-in scalar types such as `i32`, `f64`, `bool` or `char`.
    Primitive,
    /// Container types such as `Vec<T>` or `HashMap<K, V>`.
    Collection,
    /// Owning / reference-counted pointers such as `Box<T>`, `Arc<T>`, `Rc<T>`.
    SmartPointer,
    /// Anything that is not recognised as one of the other categories.
    UserDefined,
    /// Standard-library paths not covered by a more specific category.
    System,
    /// Asynchronous abstractions such as `Future<..>` or `Stream<..>`.
    Async,
    /// String types such as `String` and `&str`.
    String,
    /// `Option<T>`.
    Option,
    /// `Result<T, E>`.
    Result,
    /// Tuple types such as `(i32, String)`.
    Tuple,
    /// Fixed-size arrays such as `[i32; 10]`.
    Array,
    /// Caller-defined category; the payload is used verbatim as its display name.
    Custom(String),
}

impl TypeCategory {
    /// Returns the human-readable label of the category.
    ///
    /// For [`TypeCategory::Custom`] this is the stored name, which is why the
    /// returned string borrows from `self`.
    pub fn display_name(&self) -> &str {
        match self {
            Self::Primitive => "Primitive Types",
            Self::Collection => "Collections",
            Self::SmartPointer => "Smart Pointers",
            Self::UserDefined => "User Defined",
            Self::System => "System Types",
            Self::Async => "Async Types",
            Self::String => "String Types",
            Self::Option => "Option Types",
            Self::Result => "Result Types",
            Self::Tuple => "Tuples",
            Self::Array => "Arrays",
            Self::Custom(name) => name,
        }
    }

    /// Returns the CSS class used for this category in visualizations.
    ///
    /// Every class is a string literal, so the result is `'static` and may
    /// outlive the category value it was obtained from. All custom categories
    /// share the class `"type-custom"`.
    pub fn css_class(&self) -> &'static str {
        match self {
            Self::Primitive => "type-primitive",
            Self::Collection => "type-collection",
            Self::SmartPointer => "type-smart-pointer",
            Self::UserDefined => "type-user-defined",
            Self::System => "type-system",
            Self::Async => "type-async",
            Self::String => "type-string",
            Self::Option => "type-option",
            Self::Result => "type-result",
            Self::Tuple => "type-tuple",
            Self::Array => "type-array",
            Self::Custom(_) => "type-custom",
        }
    }

    /// Returns the hex colour (`#RRGGBB`) used for this category in visualizations.
    ///
    /// Every colour is a string literal, so the result is `'static` and may
    /// outlive the category value it was obtained from. All custom categories
    /// share the same grey.
    pub fn color(&self) -> &'static str {
        match self {
            Self::Primitive => "#4CAF50",    // Green
            Self::Collection => "#2196F3",   // Blue
            Self::SmartPointer => "#FF9800", // Orange
            Self::UserDefined => "#9C27B0",  // Purple
            Self::System => "#607D8B",       // Blue Grey
            Self::Async => "#E91E63",        // Pink
            Self::String => "#795548",       // Brown
            Self::Option => "#00BCD4",       // Cyan
            Self::Result => "#CDDC39",       // Lime
            Self::Tuple => "#FFC107",        // Amber
            Self::Array => "#3F51B5",        // Indigo
            Self::Custom(_) => "#9E9E9E",    // Grey
        }
    }
}
77
78/// Classification rule with pattern and priority
79pub struct ClassificationRule {
80    pattern: Regex,
81    category: TypeCategory,
82    priority: u8, // Lower number = higher priority
83    description: String,
84}
85
86impl ClassificationRule {
87    pub fn new(
88        pattern: &str,
89        category: TypeCategory,
90        priority: u8,
91        description: &str,
92    ) -> Result<Self, regex::Error> {
93        Ok(Self {
94            pattern: Regex::new(pattern)?,
95            category,
96            priority,
97            description: description.to_string(),
98        })
99    }
100
101    pub fn matches(&self, type_name: &str) -> bool {
102        self.pattern.is_match(type_name)
103    }
104
105    pub fn category(&self) -> &TypeCategory {
106        &self.category
107    }
108
109    pub fn priority(&self) -> u8 {
110        self.priority
111    }
112
113    pub fn description(&self) -> &str {
114        &self.description
115    }
116}
117
118/// Centralized type classifier
119pub struct TypeClassifier {
120    rules: Vec<ClassificationRule>,
121    cache: std::sync::Mutex<HashMap<String, TypeCategory>>,
122}
123
124impl std::fmt::Debug for TypeClassifier {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        f.debug_struct("TypeClassifier")
127            .field("rules_count", &self.rules.len())
128            .finish()
129    }
130}
131
/// Storage for the process-wide classifier; filled lazily by `TypeClassifier::global`.
static GLOBAL_CLASSIFIER: OnceLock<TypeClassifier> = OnceLock::new();
133
134impl TypeClassifier {
135    /// Get the global type classifier instance
136    pub fn global() -> &'static TypeClassifier {
137        GLOBAL_CLASSIFIER.get_or_init(|| Self::new().expect("Failed to initialize type classifier"))
138    }
139
140    /// Create a new type classifier with default rules
141    pub fn new() -> Result<Self, regex::Error> {
142        let mut classifier = Self {
143            rules: Vec::new(),
144            cache: std::sync::Mutex::new(HashMap::new()),
145        };
146        classifier.initialize_default_rules()?;
147        Ok(classifier)
148    }
149
150    /// Initialize the default classification rules
151    fn initialize_default_rules(&mut self) -> Result<(), regex::Error> {
152        // Primitive types (highest priority)
153        self.add_rule(
154            r"^(i8|i16|i32|i64|i128|isize|u8|u16|u32|u64|u128|usize|f32|f64|bool|char)$",
155            TypeCategory::Primitive,
156            1,
157            "Basic primitive types",
158        )?;
159
160        // String types
161        self.add_rule(
162            r"^(String|&str|str|std::string::String)$",
163            TypeCategory::String,
164            1,
165            "String types",
166        )?;
167
168        // Option and Result types
169        self.add_rule(
170            r"^(Option|std::option::Option)<",
171            TypeCategory::Option,
172            2,
173            "Option types",
174        )?;
175        self.add_rule(
176            r"^(Result|std::result::Result)<",
177            TypeCategory::Result,
178            2,
179            "Result types",
180        )?;
181
182        // Smart pointers
183        self.add_rule(
184            r"^(Box|std::boxed::Box)<",
185            TypeCategory::SmartPointer,
186            2,
187            "Box smart pointer",
188        )?;
189        self.add_rule(
190            r"^(Arc|std::sync::Arc)<",
191            TypeCategory::SmartPointer,
192            2,
193            "Arc smart pointer",
194        )?;
195        self.add_rule(
196            r"^(Rc|std::rc::Rc)<",
197            TypeCategory::SmartPointer,
198            2,
199            "Rc smart pointer",
200        )?;
201        self.add_rule(
202            r"^(Weak|std::sync::Weak|std::rc::Weak)<",
203            TypeCategory::SmartPointer,
204            2,
205            "Weak smart pointer",
206        )?;
207
208        // Collections
209        self.add_rule(
210            r"^(Vec|std::vec::Vec)<",
211            TypeCategory::Collection,
212            2,
213            "Vector collection",
214        )?;
215        self.add_rule(
216            r"^(HashMap|std::collections::HashMap)<",
217            TypeCategory::Collection,
218            2,
219            "HashMap collection",
220        )?;
221        self.add_rule(
222            r"^(BTreeMap|std::collections::BTreeMap)<",
223            TypeCategory::Collection,
224            2,
225            "BTreeMap collection",
226        )?;
227        self.add_rule(
228            r"^(HashSet|std::collections::HashSet)<",
229            TypeCategory::Collection,
230            2,
231            "HashSet collection",
232        )?;
233        self.add_rule(
234            r"^(BTreeSet|std::collections::BTreeSet)<",
235            TypeCategory::Collection,
236            2,
237            "BTreeSet collection",
238        )?;
239        self.add_rule(
240            r"^(VecDeque|std::collections::VecDeque)<",
241            TypeCategory::Collection,
242            2,
243            "VecDeque collection",
244        )?;
245        self.add_rule(
246            r"^(LinkedList|std::collections::LinkedList)<",
247            TypeCategory::Collection,
248            2,
249            "LinkedList collection",
250        )?;
251
252        // Arrays and tuples
253        self.add_rule(
254            r"^\[.*;\s*\d+\]$",
255            TypeCategory::Array,
256            2,
257            "Fixed-size arrays",
258        )?;
259        self.add_rule(r"^\(.*,.*\)$", TypeCategory::Tuple, 2, "Tuple types")?;
260
261        // Async types
262        self.add_rule(
263            r"^(Future|std::future::Future)<",
264            TypeCategory::Async,
265            2,
266            "Future types",
267        )?;
268        self.add_rule(
269            r"^(Stream|futures::stream::Stream)<",
270            TypeCategory::Async,
271            2,
272            "Stream types",
273        )?;
274        self.add_rule(
275            r"^(Sink|futures::sink::Sink)<",
276            TypeCategory::Async,
277            2,
278            "Sink types",
279        )?;
280
281        // System types (lower priority)
282        self.add_rule(r"^std::", TypeCategory::System, 3, "Standard library types")?;
283        self.add_rule(r"^core::", TypeCategory::System, 3, "Core library types")?;
284
285        Ok(())
286    }
287
288    /// Add a new classification rule
289    pub fn add_rule(
290        &mut self,
291        pattern: &str,
292        category: TypeCategory,
293        priority: u8,
294        description: &str,
295    ) -> Result<(), regex::Error> {
296        let rule = ClassificationRule::new(pattern, category, priority, description)?;
297        self.rules.push(rule);
298
299        // Sort rules by priority after adding
300        self.rules.sort_by_key(|rule| rule.priority);
301
302        // Clear cache since rules changed
303        if let Ok(mut cache) = self.cache.lock() {
304            cache.clear();
305        }
306
307        Ok(())
308    }
309
310    /// Classify a type name
311    pub fn classify(&self, type_name: &str) -> TypeCategory {
312        // Check cache first
313        if let Ok(cache) = self.cache.lock() {
314            if let Some(category) = cache.get(type_name) {
315                return category.clone();
316            }
317        }
318
319        // Find matching rules and get the highest priority one
320        let mut matched_rules: Vec<_> = self
321            .rules
322            .iter()
323            .filter(|rule| rule.matches(type_name))
324            .collect();
325
326        matched_rules.sort_by_key(|rule| rule.priority());
327
328        let category = matched_rules
329            .first()
330            .map(|rule| rule.category().clone())
331            .unwrap_or(TypeCategory::UserDefined);
332
333        // Cache the result
334        if let Ok(mut cache) = self.cache.lock() {
335            cache.insert(type_name.to_string(), category.clone());
336        }
337
338        category
339    }
340
341    /// Get statistics about type classification
342    pub fn get_stats(&self) -> ClassificationStats {
343        let cache = self.cache.lock().unwrap();
344        let mut category_counts: HashMap<TypeCategory, usize> = HashMap::new();
345
346        for category in cache.values() {
347            *category_counts.entry(category.clone()).or_insert(0) += 1;
348        }
349
350        ClassificationStats {
351            total_types_classified: cache.len(),
352            category_counts,
353            cache_hit_ratio: 1.0, // Since we always cache results
354        }
355    }
356
357    /// Clear the classification cache
358    pub fn clear_cache(&self) {
359        if let Ok(mut cache) = self.cache.lock() {
360            cache.clear();
361        }
362    }
363
364    /// Get all available rules
365    pub fn get_rules(&self) -> &[ClassificationRule] {
366        &self.rules
367    }
368}
369
370impl Default for TypeClassifier {
371    fn default() -> Self {
372        Self::new().expect("Failed to create default type classifier")
373    }
374}
375
376/// Statistics about type classification
377#[derive(Debug, Clone)]
378pub struct ClassificationStats {
379    pub total_types_classified: usize,
380    pub category_counts: HashMap<TypeCategory, usize>,
381    pub cache_hit_ratio: f64,
382}
383
384impl ClassificationStats {
385    /// Generate a human-readable report
386    pub fn report(&self) -> String {
387        let mut report = "Type Classification Statistics:\n".to_string();
388        report.push_str(&format!(
389            "  Total types classified: {}\n",
390            self.total_types_classified
391        ));
392        report.push_str(&format!(
393            "  Cache hit ratio: {:.1}%\n",
394            self.cache_hit_ratio * 100.0
395        ));
396        report.push_str("  Category breakdown:\n");
397
398        let mut sorted_categories: Vec<_> = self.category_counts.iter().collect();
399        sorted_categories.sort_by(|a, b| b.1.cmp(a.1)); // Sort by count descending
400
401        for (category, count) in sorted_categories {
402            let percentage = (*count as f64 / self.total_types_classified as f64) * 100.0;
403            report.push_str(&format!(
404                "    {}: {} ({:.1}%)\n",
405                category.display_name(),
406                count,
407                percentage
408            ));
409        }
410
411        report
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a classifier with the default rules and checks that every
    /// `(type name, expected category)` pair classifies as expected.
    fn assert_classified(cases: &[(&str, TypeCategory)]) {
        let classifier = TypeClassifier::new().unwrap();
        for (type_name, expected) in cases {
            assert_eq!(
                &classifier.classify(type_name),
                expected,
                "unexpected category for `{}`",
                type_name
            );
        }
    }

    #[test]
    fn test_primitive_types() {
        assert_classified(&[
            ("i32", TypeCategory::Primitive),
            ("u64", TypeCategory::Primitive),
            ("f64", TypeCategory::Primitive),
            ("bool", TypeCategory::Primitive),
            ("char", TypeCategory::Primitive),
        ]);
    }

    #[test]
    fn test_string_types() {
        assert_classified(&[
            ("String", TypeCategory::String),
            ("&str", TypeCategory::String),
            ("str", TypeCategory::String),
        ]);
    }

    #[test]
    fn test_smart_pointers() {
        assert_classified(&[
            ("Box<i32>", TypeCategory::SmartPointer),
            ("Arc<String>", TypeCategory::SmartPointer),
            ("Rc<Vec<u8>>", TypeCategory::SmartPointer),
        ]);
    }

    #[test]
    fn test_collections() {
        assert_classified(&[
            ("Vec<i32>", TypeCategory::Collection),
            ("HashMap<String, i32>", TypeCategory::Collection),
            ("BTreeSet<u64>", TypeCategory::Collection),
        ]);
    }

    #[test]
    fn test_option_result() {
        assert_classified(&[
            ("Option<i32>", TypeCategory::Option),
            ("Result<String, Error>", TypeCategory::Result),
        ]);
    }

    #[test]
    fn test_arrays_tuples() {
        assert_classified(&[
            ("[i32; 10]", TypeCategory::Array),
            ("(i32, String)", TypeCategory::Tuple),
            ("(i32, String, bool)", TypeCategory::Tuple),
        ]);
    }

    #[test]
    fn test_user_defined() {
        assert_classified(&[
            ("MyStruct", TypeCategory::UserDefined),
            ("custom::MyType", TypeCategory::UserDefined),
        ]);
    }

    #[test]
    fn test_system_types() {
        assert_classified(&[
            ("std::thread::Thread", TypeCategory::System),
            ("core::ptr::NonNull<u8>", TypeCategory::System),
        ]);
    }

    #[test]
    fn test_async_types() {
        assert_classified(&[
            ("Future<Output = i32>", TypeCategory::Async),
            ("Stream<Item = String>", TypeCategory::Async),
        ]);
    }

    #[test]
    fn test_priority_system() {
        // `std::string::String` matches both the String rule and the generic
        // `std::` rule; the higher-priority String rule must win.
        assert_classified(&[("std::string::String", TypeCategory::String)]);
    }

    #[test]
    fn test_caching() {
        let classifier = TypeClassifier::new().unwrap();

        // The second call is answered from the cache and must agree with the first.
        let first = classifier.classify("Vec<i32>");
        let second = classifier.classify("Vec<i32>");

        assert_eq!(first, second);
        assert_eq!(first, TypeCategory::Collection);
    }

    #[test]
    fn test_global_classifier() {
        let first = TypeClassifier::global();
        let second = TypeClassifier::global();

        // Both calls must hand out the very same instance.
        assert!(std::ptr::eq(first, second));
    }

    #[test]
    fn test_stats() {
        let classifier = TypeClassifier::new().unwrap();

        // Populate the cache with four distinct type names.
        for type_name in ["i32", "String", "Vec<u8>", "Option<i32>"] {
            classifier.classify(type_name);
        }

        let stats = classifier.get_stats();
        assert_eq!(stats.total_types_classified, 4);

        let counts = &stats.category_counts;
        assert!(counts.contains_key(&TypeCategory::Primitive));
        assert!(counts.contains_key(&TypeCategory::String));
        assert!(counts.contains_key(&TypeCategory::Collection));
        assert!(counts.contains_key(&TypeCategory::Option));
    }
}