Skip to main content

infigraph_core/reflection/
mod.rs

1use anyhow::Result;
2use serde::Serialize;
3use std::collections::HashMap;
4use std::path::Path;
5
6use crate::graph::GraphStore;
7
/// One detected use of reflection or dynamic dispatch, as produced by
/// `detect_reflection_sites` and rendered by `format_reflection_sites`.
#[derive(Debug, Clone, Serialize)]
pub struct ReflectionSite {
    /// Graph id (`s.id`) of the symbol whose stored text contained the match.
    pub caller_symbol: String,
    /// Name of the matching `REFLECTION_PATTERNS` entry, e.g. `"ClassForName"`.
    pub mechanism: &'static str,
    /// Argument text extracted immediately after the matched pattern.
    pub raw_arg: String,
    /// Target returned by `try_resolve`: a symbol id when one was found,
    /// otherwise possibly a raw class name taken from configuration.
    /// `None` when nothing matched.
    pub resolved_to: Option<String>,
    /// Config key or `META-INF/services/...` entry that supplied the
    /// resolution, when the target came from configuration.
    pub config_source: Option<String>,
    /// File recorded on the calling symbol (`s.file`).
    pub file: String,
    /// Line derived from the match offset inside the symbol's stored text;
    /// it is not computed from the file on disk.
    pub line: u32,
}
18
/// A group of literal substrings that indicate one reflection mechanism.
struct ReflectionPattern {
    /// Label reported in `ReflectionSite::mechanism`.
    mechanism: &'static str,
    /// Literal substrings searched for with `str::find` (no regex, no
    /// tokenisation).
    patterns: &'static [&'static str],
    /// File extensions (without the dot) for which this group is applied.
    extensions: &'static [&'static str],
}
24
/// Table of reflection / dynamic-dispatch markers, grouped by mechanism.
///
/// Entries are tried in order for every symbol whose file extension appears
/// in `extensions`; within an entry, the first pattern that yields a
/// non-empty argument wins.
static REFLECTION_PATTERNS: &[ReflectionPattern] = &[
    // Java reflection
    ReflectionPattern {
        mechanism: "ClassForName",
        patterns: &["Class.forName(", "Class.forName ("],
        extensions: &["java", "kt"],
    },
    ReflectionPattern {
        mechanism: "ServiceLoader",
        patterns: &["ServiceLoader.load(", "ServiceLoader.load ("],
        extensions: &["java", "kt"],
    },
    ReflectionPattern {
        mechanism: "JavaReflection",
        patterns: &[".getMethod(", ".getDeclaredMethod(", ".invoke("],
        extensions: &["java", "kt"],
    },
    // Python reflection
    ReflectionPattern {
        mechanism: "Getattr",
        patterns: &["getattr(", "getattr ("],
        extensions: &["py"],
    },
    ReflectionPattern {
        mechanism: "ImportModule",
        patterns: &[
            "importlib.import_module(",
            "importlib.import_module (",
            "__import__(",
        ],
        extensions: &["py"],
    },
    // JavaScript/TypeScript dynamic require/import
    // NOTE(review): "require(variable" is matched literally, so it only fires
    // when the argument's name happens to start with `variable`. It looks like
    // a placeholder for "require( followed by a non-literal", which plain
    // substring matching cannot express — confirm the intended behaviour.
    ReflectionPattern {
        mechanism: "DynamicRequire",
        patterns: &["require(variable", "require(`"],
        extensions: &["js", "ts", "jsx", "tsx"],
    },
    ReflectionPattern {
        mechanism: "DynamicImport",
        patterns: &["import(", "import ("],
        extensions: &["js", "ts", "jsx", "tsx"],
    },
    // C# reflection
    ReflectionPattern {
        mechanism: "CSharpReflection",
        patterns: &[
            "Activator.CreateInstance(",
            "Type.GetType(",
            "Assembly.Load(",
        ],
        extensions: &["cs"],
    },
    // Ruby dynamic dispatch
    ReflectionPattern {
        mechanism: "RubySend",
        patterns: &[".send(", ".public_send(", "const_get("],
        extensions: &["rb"],
    },
    // Go plugin loading and `reflect` usage share one mechanism label
    ReflectionPattern {
        mechanism: "GoPlugin",
        patterns: &["plugin.Open(", "reflect.ValueOf(", "reflect.TypeOf("],
        extensions: &["go"],
    },
];
91
92pub fn detect_reflection_sites(store: &GraphStore, root: &Path) -> Result<Vec<ReflectionSite>> {
93    let _lock = store.write_lock()?;
94    let conn = store.connection()?;
95
96    let result = conn
97        .query("MATCH (s:Symbol) WHERE s.docstring IS NOT NULL AND s.docstring <> '' RETURN s.id, s.docstring, s.file")
98        .map_err(|e| anyhow::anyhow!("query failed: {e}"))?;
99
100    let all_symbols = load_symbol_names(store)?;
101    let config_values = scan_config_files(root);
102
103    let mut sites = Vec::new();
104
105    for row in result {
106        if row.len() < 3 {
107            continue;
108        }
109        let symbol_id = row[0].to_string();
110        let docstring = row[1].to_string();
111        let file = row[2].to_string();
112
113        let ext = file.rsplit('.').next().unwrap_or("");
114
115        for rp in REFLECTION_PATTERNS {
116            if !rp.extensions.contains(&ext) {
117                continue;
118            }
119            for &pattern in rp.patterns {
120                if let Some(pos) = docstring.find(pattern) {
121                    let raw_arg = extract_string_arg(&docstring[pos + pattern.len()..]);
122                    if raw_arg.is_empty() {
123                        continue;
124                    }
125
126                    let (resolved, config_src) =
127                        try_resolve(&raw_arg, rp.mechanism, &all_symbols, &config_values, root);
128
129                    let line = docstring[..pos].lines().count() as u32 + 1;
130
131                    sites.push(ReflectionSite {
132                        caller_symbol: symbol_id.clone(),
133                        mechanism: rp.mechanism,
134                        raw_arg: raw_arg.clone(),
135                        resolved_to: resolved,
136                        config_source: config_src,
137                        file: file.clone(),
138                        line,
139                    });
140                    break;
141                }
142            }
143        }
144    }
145
146    if !sites.is_empty() {
147        write_resolves_to(store, &sites)?;
148    }
149
150    Ok(sites)
151}
152
/// Extracts the first argument from the text that follows a matched call
/// pattern (everything after the opening parenthesis).
///
/// * If the text starts with `"`, `'` or a backtick and a matching closing
///   quote exists, the content between the quotes is returned.
/// * Otherwise the text up to the first `)`, `,` or whitespace is taken as a
///   bare token. When no terminator exists, the token is capped at 80 bytes,
///   backed off to a UTF-8 character boundary so slicing cannot panic.
/// * The bare token is returned when it contains `.` or `::`, or consists only
///   of alphanumerics, `_`, `.`, `:` and `/`; otherwise an empty string is
///   returned.
fn extract_string_arg(after_pattern: &str) -> String {
    const MAX_BARE_LEN: usize = 80;

    let trimmed = after_pattern.trim();

    // Quoted literal: return what sits between the opening quote and the next
    // matching one. An unterminated quote falls through to the bare path.
    for quote in ['"', '\'', '`'] {
        if let Some(rest) = trimmed.strip_prefix(quote) {
            if let Some(close) = rest.find(quote) {
                return rest[..close].to_string();
            }
        }
    }

    let end = trimmed
        .find(|c: char| c == ')' || c == ',' || c.is_whitespace())
        .unwrap_or_else(|| {
            // `find` returns boundaries, but the byte cap may not be one.
            let mut cut = trimmed.len().min(MAX_BARE_LEN);
            while !trimmed.is_char_boundary(cut) {
                cut -= 1;
            }
            cut
        });
    let candidate = &trimmed[..end];

    let looks_like_name = candidate.contains('.')
        || candidate.contains("::")
        || candidate
            .chars()
            .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == ':' || c == '/');

    if looks_like_name {
        candidate.to_string()
    } else {
        String::new()
    }
}
184
185fn load_symbol_names(store: &GraphStore) -> Result<HashMap<String, String>> {
186    let conn = store.connection()?;
187    let result = conn
188        .query("MATCH (s:Symbol) RETURN s.id, s.name")
189        .map_err(|e| anyhow::anyhow!("load symbols: {e}"))?;
190
191    let mut map = HashMap::new();
192    for row in result {
193        if row.len() >= 2 {
194            let id = row[0].to_string();
195            let name = row[1].to_string();
196            map.insert(name, id);
197        }
198    }
199    Ok(map)
200}
201
202fn scan_config_files(root: &Path) -> HashMap<String, String> {
203    let mut values = HashMap::new();
204
205    let config_files = [
206        "application.properties",
207        "application.yml",
208        "application.yaml",
209        "config.properties",
210        "config.yml",
211        "config.yaml",
212    ];
213
214    for cf in &config_files {
215        let path = root.join(cf);
216        if let Ok(content) = std::fs::read_to_string(&path) {
217            parse_properties_into(&content, &mut values);
218        }
219        let src_resources = root.join("src/main/resources").join(cf);
220        if let Ok(content) = std::fs::read_to_string(&src_resources) {
221            parse_properties_into(&content, &mut values);
222        }
223    }
224
225    // META-INF/services for ServiceLoader
226    let services_dir = root.join("src/main/resources/META-INF/services");
227    if services_dir.is_dir() {
228        if let Ok(entries) = std::fs::read_dir(&services_dir) {
229            for entry in entries.flatten() {
230                let iface = entry.file_name().to_string_lossy().to_string();
231                if let Ok(content) = std::fs::read_to_string(entry.path()) {
232                    for line in content.lines() {
233                        let line = line.trim();
234                        if !line.is_empty() && !line.starts_with('#') {
235                            values.insert(format!("service:{}", iface), line.to_string());
236                        }
237                    }
238                }
239            }
240        }
241    }
242
243    // Python settings
244    for settings_file in &["settings.py", "config/settings.py", "config.py"] {
245        let path = root.join(settings_file);
246        if let Ok(content) = std::fs::read_to_string(&path) {
247            parse_python_settings(&content, &mut values);
248        }
249    }
250
251    values
252}
253
/// Parses `key=value` (Java properties) and `key: value` (flat YAML) lines
/// from `content` into `values`.
///
/// * Blank lines and lines starting with `#` or `!` are skipped.
/// * The key ends at the first `=` or `:`, whichever comes first, so a YAML
///   value that contains `=` (for example a JDBC URL with query parameters)
///   stays intact.
/// * Lines with an empty key are skipped.
/// * A `:` separator with nothing after it (a YAML section header) is
///   skipped; `key=` with an empty value is stored.
///
/// Later occurrences of a key overwrite earlier ones.
fn parse_properties_into(content: &str, values: &mut HashMap<String, String>) {
    for line in content.lines() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') || line.starts_with('!') {
            continue;
        }

        let sep_pos = match line.find(|c: char| c == '=' || c == ':') {
            Some(pos) => pos,
            None => continue,
        };
        let key = line[..sep_pos].trim();
        let val = line[sep_pos + 1..].trim();

        if key.is_empty() {
            // An empty key would later match every lookup by substring.
            continue;
        }
        let is_yaml_style = line[sep_pos..].starts_with(':');
        if is_yaml_style && val.is_empty() {
            continue;
        }

        values.insert(key.to_string(), val.to_string());
    }
}
273
/// Parses simple `NAME = value` assignments from a Python settings module
/// into `values`.
///
/// * Blank lines and `#` comment lines are skipped.
/// * Only lines whose left-hand side is a non-empty run of alphanumerics and
///   `_` are accepted, which filters out attribute assignments and most
///   statements.
/// * `==` comparisons are not treated as assignments.
/// * The stored value is the text of the literal, with surrounding quotes and
///   any trailing `# comment` removed.
///
/// This is a line-based heuristic, not a Python parser.
fn parse_python_settings(content: &str, values: &mut HashMap<String, String>) {
    for line in content.lines() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        if let Some((lhs, rhs)) = line.split_once('=') {
            let key = lhs.trim();
            let raw = rhs.trim();

            // `all()` is vacuously true for "", so reject empty keys explicitly.
            let is_identifier =
                !key.is_empty() && key.chars().all(|c| c.is_alphanumeric() || c == '_');
            // A right-hand side starting with '=' means the line was `a == b`.
            if !is_identifier || raw.starts_with('=') {
                continue;
            }

            values.insert(key.to_string(), python_literal_text(raw).to_string());
        }
    }
}

/// Returns the text of the Python literal at the start of `raw`, without
/// quotes or trailing comment.
fn python_literal_text(raw: &str) -> &str {
    fn is_quote(c: char) -> bool {
        c == '\'' || c == '"'
    }

    match raw.chars().next() {
        Some(quote) if is_quote(quote) => {
            let inner = &raw[1..];
            if inner.starts_with(quote) {
                // Empty or triple-quoted string: strip every surrounding quote.
                raw.trim_matches(is_quote)
            } else {
                match inner.find(quote) {
                    // Stop at the closing quote so `'x'  # note` yields `x`.
                    Some(close) => &inner[..close],
                    None => raw.trim_matches(is_quote),
                }
            }
        }
        // Unquoted value: anything after '#' is a comment.
        _ => raw
            .split('#')
            .next()
            .unwrap_or(raw)
            .trim()
            .trim_matches(is_quote),
    }
}
291
/// Tries to map the argument of a reflective call to a known symbol.
///
/// Strategies, in order:
/// 1. `raw_arg` is itself a key of `all_symbols`.
/// 2. The last `.`-separated or `::`-separated segment of `raw_arg` is a key.
/// 3. For the `"ServiceLoader"` mechanism, a `service:<interface>` entry in
///    `config_values` names the provider. A class-literal suffix (`.class`,
///    `::class`, `::class.java`) is stripped and the interface may be matched
///    by its simple name.
/// 4. A non-empty config key that contains `raw_arg`, or is contained in it.
///    Keys are tried in sorted order so the outcome is the same on every run.
///
/// Returns `(resolved_to, config_source)`. `resolved_to` is a symbol id when
/// one was found, or the raw configured class name when the configuration
/// points at something that is not in `all_symbols`. `config_source` is set
/// only for strategies 3 and 4. `_root` is currently unused.
fn try_resolve(
    raw_arg: &str,
    mechanism: &str,
    all_symbols: &HashMap<String, String>,
    config_values: &HashMap<String, String>,
    _root: &Path,
) -> (Option<String>, Option<String>) {
    // Direct match: raw_arg is a FQCN or symbol name
    if let Some(symbol_id) = all_symbols.get(raw_arg) {
        return (Some(symbol_id.clone()), None);
    }

    // Short name match (last segment of a dotted or `::` path)
    let dotted_tail = raw_arg.rsplit('.').next().unwrap_or(raw_arg);
    let scoped_tail = raw_arg.rsplit("::").next().unwrap_or(raw_arg);
    for name in [dotted_tail, scoped_tail] {
        if let Some(symbol_id) = all_symbols.get(name) {
            return (Some(symbol_id.clone()), None);
        }
    }

    // ServiceLoader: consult the META-INF/services entries
    if mechanism == "ServiceLoader" {
        if let Some((declared, impl_fqcn)) = find_service_provider(raw_arg, config_values) {
            let source = format!("META-INF/services/{}", declared);
            let impl_short = impl_fqcn.rsplit('.').next().unwrap_or(impl_fqcn);
            // Fall back to the configured class name when it is not a known symbol.
            let target = all_symbols.get(impl_short).unwrap_or(impl_fqcn).clone();
            return (Some(target), Some(source));
        }
    }

    // Config-driven: raw_arg looks like (part of) a config key whose value is
    // a class name. Empty keys are skipped because "" is a substring of
    // everything; sorting removes the dependence on hash order.
    let mut candidates: Vec<(&String, &String)> = config_values
        .iter()
        .filter(|(key, _)| {
            !key.is_empty() && (key.contains(raw_arg) || raw_arg.contains(key.as_str()))
        })
        .collect();
    candidates.sort_unstable_by(|a, b| a.0.cmp(b.0));

    for (key, val) in candidates {
        let val_short = val.rsplit('.').next().unwrap_or(val);
        if let Some(symbol_id) = all_symbols.get(val_short) {
            return (Some(symbol_id.clone()), Some(key.clone()));
        }
        if val.contains('.') || val.contains("::") {
            return (Some(val.clone()), Some(key.clone()));
        }
    }

    (None, None)
}

/// Finds the `service:<interface>` entry that corresponds to a
/// `ServiceLoader.load` argument, returning the declared interface name and
/// the provider stored for it.
fn find_service_provider<'a>(
    raw_arg: &str,
    config_values: &'a HashMap<String, String>,
) -> Option<(&'a str, &'a String)> {
    const SERVICE_PREFIX: &str = "service:";

    // An entry named exactly like the argument always wins.
    let exact_key = format!("{}{}", SERVICE_PREFIX, raw_arg);
    if let Some((key, provider)) = config_values.get_key_value(&exact_key) {
        return Some((&key[SERVICE_PREFIX.len()..], provider));
    }

    // `Foo.class` / `Foo::class.java` refer to interface `Foo`, while provider
    // files are named after the fully qualified interface.
    let iface = ["::class.java", "::class", ".class"]
        .iter()
        .find_map(|suffix| raw_arg.strip_suffix(*suffix))
        .unwrap_or(raw_arg);
    if iface.is_empty() {
        return None;
    }
    let qualified_tail = format!(".{}", iface);

    config_values
        .iter()
        .filter_map(|(key, provider)| {
            let declared = key.strip_prefix(SERVICE_PREFIX)?;
            if declared == iface || declared.ends_with(&qualified_tail) {
                Some((declared, provider))
            } else {
                None
            }
        })
        // Smallest name first keeps the choice stable across runs.
        .min_by(|a, b| a.0.cmp(b.0))
}
346
347fn write_resolves_to(store: &GraphStore, sites: &[ReflectionSite]) -> Result<()> {
348    let conn = store.connection()?;
349
350    conn.query("BEGIN TRANSACTION")
351        .map_err(|e| anyhow::anyhow!("begin txn: {e}"))?;
352
353    let _ = conn.query("MATCH ()-[r:RESOLVES_TO]->() DELETE r");
354
355    for site in sites {
356        if let Some(ref target) = site.resolved_to {
357            let src_esc = crate::escape_str(&site.caller_symbol);
358            let tgt_esc = crate::escape_str(target);
359            let mech_esc = crate::escape_str(site.mechanism);
360            let cfg_esc = crate::escape_str(site.config_source.as_deref().unwrap_or(""));
361
362            let _ = conn.query(&format!(
363                "MATCH (s:Symbol), (t:Symbol) WHERE s.id = '{src_esc}' AND t.id = '{tgt_esc}' \
364                 CREATE (s)-[:RESOLVES_TO {{mechanism: '{mech_esc}', config_source: '{cfg_esc}'}}]->(t)"
365            ));
366        }
367    }
368
369    conn.query("COMMIT")
370        .map_err(|e| anyhow::anyhow!("commit txn: {e}"))?;
371
372    Ok(())
373}
374
375pub fn format_reflection_sites(sites: &[ReflectionSite]) -> String {
376    if sites.is_empty() {
377        return "No reflection/dynamic invocation sites detected.".to_string();
378    }
379
380    let resolved_count = sites.iter().filter(|s| s.resolved_to.is_some()).count();
381    let unresolved_count = sites.len() - resolved_count;
382
383    let mut out = format!(
384        "Reflection sites: {} total ({} resolved, {} unresolved)\n\n",
385        sites.len(),
386        resolved_count,
387        unresolved_count
388    );
389
390    let mut by_mechanism: std::collections::BTreeMap<&str, Vec<&ReflectionSite>> =
391        std::collections::BTreeMap::new();
392    for s in sites {
393        by_mechanism.entry(s.mechanism).or_default().push(s);
394    }
395
396    for (mech, items) in &by_mechanism {
397        out.push_str(&format!("## {} ({} sites)\n", mech, items.len()));
398        for item in items {
399            let status = match &item.resolved_to {
400                Some(target) => format!("-> {}", target),
401                None => "UNRESOLVED".to_string(),
402            };
403            out.push_str(&format!(
404                "  {}:{} — {}({}) {}\n",
405                item.file, item.line, mech, item.raw_arg, status
406            ));
407            if let Some(ref cfg) = item.config_source {
408                out.push_str(&format!("    via config: {}\n", cfg));
409            }
410        }
411        out.push('\n');
412    }
413
414    out
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the mechanisms with at least one pattern occurring in
    /// `docstring`, in table order. With `Some(ext)`, only pattern groups
    /// registered for that extension are considered; `None` checks them all.
    fn mechanisms_in(docstring: &str, ext: Option<&str>) -> Vec<&'static str> {
        REFLECTION_PATTERNS
            .iter()
            .filter(|rp| ext.map_or(true, |e| rp.extensions.contains(&e)))
            .filter(|rp| rp.patterns.iter().any(|&p| docstring.contains(p)))
            .map(|rp| rp.mechanism)
            .collect()
    }

    #[test]
    fn test_extract_string_arg_double_quotes() {
        assert_eq!(
            extract_string_arg("\"com.example.MyClass\")"),
            "com.example.MyClass"
        );
    }

    #[test]
    fn test_extract_string_arg_single_quotes() {
        assert_eq!(extract_string_arg("'my_module')"), "my_module");
    }

    #[test]
    fn test_extract_string_arg_backtick() {
        assert_eq!(
            extract_string_arg("`./modules/${name}`)"),
            "./modules/${name}"
        );
    }

    #[test]
    fn test_extract_string_arg_bare_identifier() {
        assert_eq!(extract_string_arg("MyClass.class)"), "MyClass.class");
    }

    // A plain variable name is returned as-is, not dropped.
    #[test]
    fn test_extract_string_arg_bare_variable() {
        assert_eq!(extract_string_arg("someVariable)"), "someVariable");
    }

    #[test]
    fn test_detect_java_class_forname() {
        let found = mechanisms_in(
            "handler = Class.forName(\"com.example.Handler\").newInstance();",
            Some("java"),
        );
        assert!(
            found.contains(&"ClassForName"),
            "should detect Class.forName"
        );
    }

    #[test]
    fn test_detect_python_importlib() {
        let found = mechanisms_in(
            "mod = importlib.import_module(\"handlers.email\")",
            Some("py"),
        );
        assert!(
            found.contains(&"ImportModule"),
            "should detect importlib.import_module"
        );
    }

    #[test]
    fn test_detect_python_getattr() {
        let found = mechanisms_in("fn = getattr(obj, method_name)", Some("py"));
        assert!(found.contains(&"Getattr"), "should detect getattr");
    }

    #[test]
    fn test_detect_csharp_activator() {
        let found = mechanisms_in(
            "var obj = Activator.CreateInstance(\"MyApp.Handlers.EmailHandler\");",
            Some("cs"),
        );
        assert!(
            found.contains(&"CSharpReflection"),
            "should detect Activator.CreateInstance"
        );
    }

    #[test]
    fn test_detect_ruby_send() {
        let found = mechanisms_in("result = obj.send(method_name, *args)", Some("rb"));
        assert!(found.contains(&"RubySend"), "should detect .send(");
    }

    #[test]
    fn test_detect_go_reflect() {
        let found = mechanisms_in("v := reflect.ValueOf(handler)", Some("go"));
        assert!(found.contains(&"GoPlugin"), "should detect reflect.ValueOf");
    }

    #[test]
    fn test_detect_java_service_loader() {
        let found = mechanisms_in("ServiceLoader.load(PaymentProcessor.class)", Some("java"));
        assert!(
            found.contains(&"ServiceLoader"),
            "should detect ServiceLoader.load"
        );
    }

    #[test]
    fn test_parse_properties() {
        let content = "handler.class=com.example.MyHandler\ndb.url=jdbc:mysql://localhost/test";
        let mut values = HashMap::new();
        parse_properties_into(content, &mut values);
        assert_eq!(
            values.get("handler.class").unwrap(),
            "com.example.MyHandler"
        );
        assert_eq!(values.get("db.url").unwrap(), "jdbc:mysql://localhost/test");
    }

    #[test]
    fn test_parse_yaml_style_properties() {
        let content = "handler: com.example.MyHandler\nport: 8080";
        let mut values = HashMap::new();
        parse_properties_into(content, &mut values);
        assert_eq!(values.get("handler").unwrap(), "com.example.MyHandler");
    }

    #[test]
    fn test_try_resolve_direct_match() {
        let mut symbols = HashMap::new();
        symbols.insert(
            "MyHandler".to_string(),
            "handler.java::MyHandler".to_string(),
        );
        let configs = HashMap::new();
        let (resolved, _) = try_resolve(
            "MyHandler",
            "ClassForName",
            &symbols,
            &configs,
            Path::new("."),
        );
        assert_eq!(resolved.unwrap(), "handler.java::MyHandler");
    }

    #[test]
    fn test_try_resolve_fqcn_short_name() {
        let mut symbols = HashMap::new();
        symbols.insert(
            "MyHandler".to_string(),
            "handler.java::MyHandler".to_string(),
        );
        let configs = HashMap::new();
        let (resolved, _) = try_resolve(
            "com.example.MyHandler",
            "ClassForName",
            &symbols,
            &configs,
            Path::new("."),
        );
        assert_eq!(resolved.unwrap(), "handler.java::MyHandler");
    }

    #[test]
    fn test_try_resolve_unresolved() {
        let symbols = HashMap::new();
        let configs = HashMap::new();
        let (resolved, _) = try_resolve(
            "com.unknown.Mystery",
            "ClassForName",
            &symbols,
            &configs,
            Path::new("."),
        );
        assert!(resolved.is_none());
    }

    #[test]
    fn test_no_false_positive_plain_text() {
        // Checked against every pattern group, regardless of extension.
        let found = mechanisms_in(
            "This class forwards messages to the service loader pattern.",
            None,
        );
        assert!(found.is_empty(), "plain text should not match: {:?}", found);
    }

    #[test]
    fn test_parse_python_settings() {
        let content = "HANDLER_CLASS = 'myapp.handlers.EmailHandler'\nDEBUG = True";
        let mut values = HashMap::new();
        super::parse_python_settings(content, &mut values);
        assert_eq!(
            values.get("HANDLER_CLASS").unwrap(),
            "myapp.handlers.EmailHandler"
        );
        assert_eq!(values.get("DEBUG").unwrap(), "True");
    }
}