// debtmap/complexity/token_classifier.rs

use std::collections::HashMap;

/// Represents different types of variables for classification.
///
/// The example names on each variant are the ones the local-variable
/// heuristics in `TokenClassifier` compare against (case-insensitively).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VarType {
    /// Loop iterators: exactly `i`, `j`, `k`, `n`, `idx`, `index`, `iter`, `it` or `cursor`.
    Iterator,
    /// Counting variables: names containing `count`, `num` or `total`.
    Counter,
    /// Scratch values: exactly `temp`, `tmp`, `result`, `res`, `ret` or `val`.
    Temporary,
    /// Configuration: names containing `config`, `setting`, `option` or `param`.
    Configuration,
    /// Resources: names containing `file`, `conn`, `client`, `socket`, `stream` or `handle`.
    Resource,
    /// Data carriers: names containing `data`, `value`, `item`, `element`, `node` or `entry`.
    Data,
    /// Everything else.
    Other,
}

/// Represents different types of field access patterns.
///
/// Note: the field-access classifier in this file currently always reports
/// [`AccessType::Getter`]; the other variants are only referenced by the
/// default weight table.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AccessType {
    /// Simple field access.
    Getter,
    /// Field assignment.
    Setter,
    /// Chained access, i.e. the `a.b.c` pattern.
    Chained,
    /// Array/map access.
    Collection,
}

/// Represents different types of method calls.
///
/// The patterns quoted on each variant are examples of the (lower-cased)
/// method names the classifier associates with it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CallType {
    /// Accessors: `get_*`, `*_ref`, `as_*`.
    Getter,
    /// Mutators / builders: `set_*`, `with_*`.
    Setter,
    /// Predicates: `is_*`, `has_*`, `can_*`, `should_*`.
    Validator,
    /// Conversions: `to_*`, `into_*`, `from_*`, `parse`.
    Converter,
    /// I/O: e.g. `read`, `write`, `send`, `receive`.
    IO,
    /// Error handling: e.g. `unwrap`, `expect`, `map_err`.
    ErrorHandle,
    /// Collection operations: e.g. `push`, `pop`, `insert`, `remove`.
    Collection,
    /// Calls to external crates.
    External,
    /// Everything else.
    Other,
}

/// Represents different control flow constructs.
///
/// Each variant is produced by the classifier for the keyword(s) quoted on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FlowType {
    /// Conditional branching: `if`, `else`, `elif`.
    If,
    /// `match`.
    Match,
    /// `loop`.
    Loop,
    /// `while`.
    While,
    /// `for`.
    For,
    /// `return`.
    Return,
    /// `break`.
    Break,
    /// `continue`.
    Continue,
}

/// Represents error handling patterns.
///
/// NOTE(review): in the code visible in this file these variants are only used
/// as keys of the default weight table; each is presumably named after the
/// Rust construct it stands for (`Result`, `Option`, `.unwrap()`, `.expect()`,
/// `?`, `.map_err()`, `.and_then()`, `.or_else()`) — confirm against the
/// producers of `TokenClass::ErrorHandling`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ErrorType {
    Result,
    Option,
    Unwrap,
    Expect,
    QuestionMark,
    MapErr,
    AndThen,
    OrElse,
}

/// Represents collection operations.
///
/// The method names on each variant are the examples given by the original
/// author for that category.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CollectionOp {
    /// Iteration: `iter`, `into_iter`.
    Iteration,
    /// Mapping: `map`, `filter_map`.
    Mapping,
    /// Filtering: `filter`, `take_while`.
    Filtering,
    /// Aggregation: `fold`, `reduce`, `collect`.
    Aggregation,
    /// Access: `get`, `contains`.
    Access,
    /// Mutation: `push`, `insert`, `remove`.
    Mutation,
}

/// Represents different types of literals.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LiteralCategory {
    /// Numeric literal (the classifier accepts anything that parses as `f64`).
    Numeric,
    /// Double-quoted string literal.
    String,
    /// `true` or `false`.
    Boolean,
    /// Single-quoted character literal.
    Char,
    /// Absent-value literal: `null`, `None` or `nil`.
    Null,
}

85/// Main token classification enum
86#[derive(Debug, Clone, PartialEq, Eq, Hash)]
87pub enum TokenClass {
88    LocalVar(VarType),
89    FieldAccess(AccessType),
90    MethodCall(CallType),
91    ExternalAPI(String), // Module/crate name
92    ControlFlow(FlowType),
93    ErrorHandling(ErrorType),
94    Collection(CollectionOp),
95    Literal(LiteralCategory),
96    Keyword(String),
97    Operator(String),
98    Unknown(String),
99}
100
101/// Context information for token classification
102#[derive(Debug, Clone)]
103pub struct TokenContext {
104    pub is_method_call: bool,
105    pub is_field_access: bool,
106    pub is_external: bool,
107    pub scope_depth: usize,
108    pub parent_node_type: NodeType,
109}
110
/// Represents the type of parent AST node.
///
/// Stored in `TokenContext::parent_node_type`; unlike the other enums in this
/// file it does not derive `Hash`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeType {
    Function,
    Method,
    Closure,
    Block,
    Expression,
    Statement,
    Pattern,
    Type,
}

124/// Configuration for token classification
125#[derive(Debug, Clone)]
126pub struct ClassificationConfig {
127    pub enabled: bool,
128    pub weights: HashMap<TokenClass, f64>,
129    pub cache_size: usize,
130}
131
132impl Default for ClassificationConfig {
133    fn default() -> Self {
134        let mut weights = HashMap::new();
135
136        // Local variables - lower weight for common patterns
137        weights.insert(TokenClass::LocalVar(VarType::Iterator), 0.1);
138        weights.insert(TokenClass::LocalVar(VarType::Counter), 0.2);
139        weights.insert(TokenClass::LocalVar(VarType::Temporary), 0.3);
140        weights.insert(TokenClass::LocalVar(VarType::Configuration), 0.5);
141        weights.insert(TokenClass::LocalVar(VarType::Resource), 0.7);
142        weights.insert(TokenClass::LocalVar(VarType::Data), 0.5);
143        weights.insert(TokenClass::LocalVar(VarType::Other), 0.4);
144
145        // Field access - moderate weight
146        weights.insert(TokenClass::FieldAccess(AccessType::Getter), 0.3);
147        weights.insert(TokenClass::FieldAccess(AccessType::Setter), 0.4);
148        weights.insert(TokenClass::FieldAccess(AccessType::Chained), 0.6);
149        weights.insert(TokenClass::FieldAccess(AccessType::Collection), 0.5);
150
151        // Method calls - varied by type
152        weights.insert(TokenClass::MethodCall(CallType::Getter), 0.2);
153        weights.insert(TokenClass::MethodCall(CallType::Setter), 0.3);
154        weights.insert(TokenClass::MethodCall(CallType::Validator), 0.4);
155        weights.insert(TokenClass::MethodCall(CallType::Converter), 0.5);
156        weights.insert(TokenClass::MethodCall(CallType::IO), 0.9);
157        weights.insert(TokenClass::MethodCall(CallType::ErrorHandle), 0.7);
158        weights.insert(TokenClass::MethodCall(CallType::Collection), 0.4);
159        weights.insert(TokenClass::MethodCall(CallType::External), 1.0);
160        weights.insert(TokenClass::MethodCall(CallType::Other), 0.6);
161
162        // Control flow - standard weight
163        weights.insert(TokenClass::ControlFlow(FlowType::If), 0.5);
164        weights.insert(TokenClass::ControlFlow(FlowType::Match), 0.6);
165        weights.insert(TokenClass::ControlFlow(FlowType::Loop), 0.7);
166        weights.insert(TokenClass::ControlFlow(FlowType::While), 0.7);
167        weights.insert(TokenClass::ControlFlow(FlowType::For), 0.6);
168        weights.insert(TokenClass::ControlFlow(FlowType::Return), 0.3);
169        weights.insert(TokenClass::ControlFlow(FlowType::Break), 0.4);
170        weights.insert(TokenClass::ControlFlow(FlowType::Continue), 0.4);
171
172        // Error handling - higher weight
173        weights.insert(TokenClass::ErrorHandling(ErrorType::Result), 0.6);
174        weights.insert(TokenClass::ErrorHandling(ErrorType::Option), 0.5);
175        weights.insert(TokenClass::ErrorHandling(ErrorType::Unwrap), 0.8);
176        weights.insert(TokenClass::ErrorHandling(ErrorType::Expect), 0.8);
177        weights.insert(TokenClass::ErrorHandling(ErrorType::QuestionMark), 0.4);
178        weights.insert(TokenClass::ErrorHandling(ErrorType::MapErr), 0.6);
179        weights.insert(TokenClass::ErrorHandling(ErrorType::AndThen), 0.5);
180        weights.insert(TokenClass::ErrorHandling(ErrorType::OrElse), 0.5);
181
182        // Collection operations - moderate weight
183        weights.insert(TokenClass::Collection(CollectionOp::Iteration), 0.3);
184        weights.insert(TokenClass::Collection(CollectionOp::Mapping), 0.5);
185        weights.insert(TokenClass::Collection(CollectionOp::Filtering), 0.5);
186        weights.insert(TokenClass::Collection(CollectionOp::Aggregation), 0.7);
187        weights.insert(TokenClass::Collection(CollectionOp::Access), 0.4);
188        weights.insert(TokenClass::Collection(CollectionOp::Mutation), 0.6);
189
190        // Literals - very low weight
191        weights.insert(TokenClass::Literal(LiteralCategory::Numeric), 0.1);
192        weights.insert(TokenClass::Literal(LiteralCategory::String), 0.2);
193        weights.insert(TokenClass::Literal(LiteralCategory::Boolean), 0.1);
194        weights.insert(TokenClass::Literal(LiteralCategory::Char), 0.1);
195        weights.insert(TokenClass::Literal(LiteralCategory::Null), 0.1);
196
197        Self {
198            enabled: false, // Disabled by default for backward compatibility
199            weights,
200            cache_size: 10000,
201        }
202    }
203}
204
205/// Main token classifier
206#[derive(Debug)]
207pub struct TokenClassifier {
208    config: ClassificationConfig,
209    cache: HashMap<(String, bool, bool), TokenClass>,
210}
211
212impl TokenClassifier {
213    pub fn new(config: ClassificationConfig) -> Self {
214        Self {
215            config,
216            cache: HashMap::new(),
217        }
218    }
219
220    pub fn classify(&mut self, token: &str, context: &TokenContext) -> TokenClass {
221        if !self.config.enabled {
222            return TokenClass::Unknown(token.to_string());
223        }
224
225        // Check cache first
226        let cache_key = (
227            token.to_string(),
228            context.is_method_call,
229            context.is_field_access,
230        );
231        if let Some(cached) = self.cache.get(&cache_key) {
232            return cached.clone();
233        }
234
235        // Perform classification
236        let class = self.classify_internal(token, context);
237
238        // Update cache if not at capacity
239        if self.cache.len() < self.config.cache_size {
240            self.cache.insert(cache_key, class.clone());
241        }
242
243        class
244    }
245
246    fn classify_internal(&self, token: &str, context: &TokenContext) -> TokenClass {
247        // Control flow keywords
248        if matches!(token, "if" | "else" | "elif") {
249            return TokenClass::ControlFlow(FlowType::If);
250        }
251        if token == "match" {
252            return TokenClass::ControlFlow(FlowType::Match);
253        }
254        if token == "loop" {
255            return TokenClass::ControlFlow(FlowType::Loop);
256        }
257        if token == "while" {
258            return TokenClass::ControlFlow(FlowType::While);
259        }
260        if token == "for" {
261            return TokenClass::ControlFlow(FlowType::For);
262        }
263        if token == "return" {
264            return TokenClass::ControlFlow(FlowType::Return);
265        }
266        if token == "break" {
267            return TokenClass::ControlFlow(FlowType::Break);
268        }
269        if token == "continue" {
270            return TokenClass::ControlFlow(FlowType::Continue);
271        }
272
273        // Method calls
274        if context.is_method_call {
275            return self.classify_method_call(token, context);
276        }
277
278        // Field access
279        if context.is_field_access {
280            return self.classify_field_access(token, context);
281        }
282
283        // Local variables
284        if !context.is_external && token.chars().all(|c| c.is_alphanumeric() || c == '_') {
285            return self.classify_local_var(token);
286        }
287
288        // Literals
289        if token.parse::<f64>().is_ok() {
290            return TokenClass::Literal(LiteralCategory::Numeric);
291        }
292        if token == "true" || token == "false" {
293            return TokenClass::Literal(LiteralCategory::Boolean);
294        }
295        if token.starts_with('"') && token.ends_with('"') {
296            return TokenClass::Literal(LiteralCategory::String);
297        }
298        if token.starts_with('\'') && token.ends_with('\'') && token.len() == 3 {
299            return TokenClass::Literal(LiteralCategory::Char);
300        }
301        if token == "null" || token == "None" || token == "nil" {
302            return TokenClass::Literal(LiteralCategory::Null);
303        }
304
305        // Keywords
306        if matches!(
307            token,
308            "fn" | "let"
309                | "const"
310                | "mut"
311                | "pub"
312                | "struct"
313                | "enum"
314                | "trait"
315                | "impl"
316                | "mod"
317                | "use"
318                | "async"
319                | "await"
320                | "self"
321                | "Self"
322        ) {
323            return TokenClass::Keyword(token.to_string());
324        }
325
326        // Operators
327        if token.chars().all(|c| "+-*/%=<>!&|^~?.".contains(c)) {
328            return TokenClass::Operator(token.to_string());
329        }
330
331        TokenClass::Unknown(token.to_string())
332    }
333
334    fn classify_method_call(&self, token: &str, context: &TokenContext) -> TokenClass {
335        let lower = token.to_lowercase();
336
337        // Getters
338        if lower.starts_with("get_") || lower.ends_with("_ref") || lower.starts_with("as_") {
339            return TokenClass::MethodCall(CallType::Getter);
340        }
341
342        // Setters
343        if lower.starts_with("set_") || lower.starts_with("with_") {
344            return TokenClass::MethodCall(CallType::Setter);
345        }
346
347        // Validators
348        if lower.starts_with("is_")
349            || lower.starts_with("has_")
350            || lower.starts_with("can_")
351            || lower.starts_with("should_")
352        {
353            return TokenClass::MethodCall(CallType::Validator);
354        }
355
356        // Converters
357        if lower.starts_with("to_")
358            || lower.starts_with("into_")
359            || lower.starts_with("from_")
360            || lower == "parse"
361        {
362            return TokenClass::MethodCall(CallType::Converter);
363        }
364
365        // I/O operations
366        if matches!(
367            lower.as_str(),
368            "read"
369                | "write"
370                | "send"
371                | "receive"
372                | "recv"
373                | "flush"
374                | "sync"
375                | "open"
376                | "close"
377                | "connect"
378        ) {
379            return TokenClass::MethodCall(CallType::IO);
380        }
381
382        // Error handling
383        if matches!(
384            lower.as_str(),
385            "unwrap" | "expect" | "map_err" | "ok" | "err" | "and_then" | "or_else" | "unwrap_or"
386        ) {
387            return TokenClass::MethodCall(CallType::ErrorHandle);
388        }
389
390        // Collection operations
391        if matches!(
392            lower.as_str(),
393            "push"
394                | "pop"
395                | "insert"
396                | "remove"
397                | "clear"
398                | "len"
399                | "is_empty"
400                | "contains"
401                | "get"
402                | "iter"
403                | "map"
404                | "filter"
405                | "fold"
406                | "collect"
407                | "sort"
408        ) {
409            return TokenClass::MethodCall(CallType::Collection);
410        }
411
412        // External API if marked as external
413        if context.is_external {
414            return TokenClass::MethodCall(CallType::External);
415        }
416
417        TokenClass::MethodCall(CallType::Other)
418    }
419
420    fn classify_field_access(&self, _token: &str, _context: &TokenContext) -> TokenClass {
421        // Simple classification based on context
422        // Could be enhanced with more sophisticated analysis
423        TokenClass::FieldAccess(AccessType::Getter)
424    }
425
426    fn classify_local_var(&self, token: &str) -> TokenClass {
427        let lower = token.to_lowercase();
428
429        // Iterators
430        if matches!(
431            lower.as_str(),
432            "i" | "j" | "k" | "n" | "idx" | "index" | "iter" | "it" | "cursor"
433        ) {
434            return TokenClass::LocalVar(VarType::Iterator);
435        }
436
437        // Counters
438        if lower.contains("count") || lower.contains("num") || lower.contains("total") {
439            return TokenClass::LocalVar(VarType::Counter);
440        }
441
442        // Temporary
443        if matches!(
444            lower.as_str(),
445            "temp" | "tmp" | "result" | "res" | "ret" | "val"
446        ) {
447            return TokenClass::LocalVar(VarType::Temporary);
448        }
449
450        // Configuration
451        if lower.contains("config")
452            || lower.contains("setting")
453            || lower.contains("option")
454            || lower.contains("param")
455        {
456            return TokenClass::LocalVar(VarType::Configuration);
457        }
458
459        // Resources
460        if lower.contains("file")
461            || lower.contains("conn")
462            || lower.contains("client")
463            || lower.contains("socket")
464            || lower.contains("stream")
465            || lower.contains("handle")
466        {
467            return TokenClass::LocalVar(VarType::Resource);
468        }
469
470        // Data
471        if lower.contains("data")
472            || lower.contains("value")
473            || lower.contains("item")
474            || lower.contains("element")
475            || lower.contains("node")
476            || lower.contains("entry")
477        {
478            return TokenClass::LocalVar(VarType::Data);
479        }
480
481        TokenClass::LocalVar(VarType::Other)
482    }
483
484    pub fn get_weight(&self, class: &TokenClass) -> f64 {
485        self.config.weights.get(class).copied().unwrap_or(0.5)
486    }
487
488    pub fn update_weights(&mut self, weights: HashMap<TokenClass, f64>) {
489        self.config.weights = weights;
490    }
491
492    pub fn clear_cache(&mut self) {
493        self.cache.clear();
494    }
495}
496
497/// Result of token classification with metadata
498#[derive(Debug, Clone)]
499pub struct ClassifiedToken {
500    pub class: TokenClass,
501    pub raw_token: String,
502    pub context: TokenContext,
503    pub weight: f64,
504}
505
506impl ClassifiedToken {
507    pub fn new(class: TokenClass, raw_token: String, context: TokenContext, weight: f64) -> Self {
508        Self {
509            class,
510            raw_token,
511            context,
512            weight,
513        }
514    }
515}