Skip to main content

rajac_classpath/
classpath.rs

1use rajac_base::shared_string::SharedString;
2use rajac_symbols::{Symbol, SymbolKind, SymbolTable};
3use rayon::prelude::*;
4use ristretto_classfile::{BaseType, ClassFile, FieldAccessFlags, FieldType, MethodAccessFlags};
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{Cursor, Read};
8use std::path::Path;
9use std::time::Instant;
10use zip::ZipArchive;
11
/// A collection of locations (directories and JAR files) from which compiled
/// Java classes are loaded into a [`SymbolTable`].
pub struct Classpath {
    /// Locations in the order they were added; they are processed in this order.
    entries: Vec<ClasspathEntry>,
}
15
/// A single location on the classpath.
enum ClasspathEntry {
    /// A directory that is walked recursively for files with a `class` extension.
    Directory(PathBuf),
    /// A JAR (zip) archive whose `.class` entries are read.
    Jar(PathBuf),
}
20
/// The subset of a class file that is needed to populate the symbol table.
struct ParsedClass {
    /// Dotted package name; empty when the class has no package.
    package: SharedString,
    /// Simple class name (the part of the internal name after the last `/`).
    class_name: SharedString,
    /// `true` when the class file carries the `INTERFACE` access flag.
    is_interface: bool,
    /// Dotted name of the superclass; `None` when the class file records no
    /// superclass (index 0) or the constant-pool lookup failed.
    super_class: Option<SharedString>,
    /// Dotted names of the directly implemented interfaces that could be
    /// resolved from the constant pool.
    interfaces: Vec<SharedString>,
    /// Declared methods; `<clinit>` is omitted and constructors are stored
    /// under the class name.
    methods: Vec<ParsedMethod>,
    /// Declared fields.
    fields: Vec<ParsedField>,
}
30
/// A method read from a class file, with types still in descriptor form.
#[derive(Clone, Debug)]
struct ParsedMethod {
    /// Method name; constructors (`<init>`) carry the declaring class name.
    name: SharedString,
    /// Parameter types as parsed from the method descriptor.
    params: Vec<FieldType>,
    /// Return type from the descriptor; `None` is later resolved as `void`.
    return_type: Option<FieldType>,
    /// Modifiers translated from the method's access flags.
    modifiers: rajac_types::MethodModifiers,
}
38
/// A field read from a class file, with its type still in descriptor form.
#[derive(Clone, Debug)]
struct ParsedField {
    /// Field name as stored in the constant pool.
    name: SharedString,
    /// Declared type of the field.
    ty: FieldType,
    /// Modifiers translated from the field's access flags.
    modifiers: rajac_types::FieldModifiers,
}
45
46impl Classpath {
47    pub fn new() -> Self {
48        Self {
49            entries: Vec::new(),
50        }
51    }
52
53    pub fn is_empty(&self) -> bool {
54        self.entries.is_empty()
55    }
56
57    pub fn add_directory(&mut self, path: impl Into<PathBuf>) {
58        self.entries.push(ClasspathEntry::Directory(path.into()));
59    }
60
61    pub fn add_jar(&mut self, path: impl Into<PathBuf>) {
62        self.entries.push(ClasspathEntry::Jar(path.into()));
63    }
64
65    pub fn add_to_symbol_table(&self, symbol_table: &mut SymbolTable) -> RajacResult<()> {
66        for entry in &self.entries {
67            match entry {
68                ClasspathEntry::Directory(dir) => {
69                    self.add_directory_to_symbol_table(dir, symbol_table)?;
70                }
71                ClasspathEntry::Jar(jar) => {
72                    self.add_jar_to_symbol_table(jar, symbol_table)?;
73                }
74            }
75        }
76        Ok(())
77    }
78
79    fn add_directory_to_symbol_table(
80        &self,
81        dir: &Path,
82        symbol_table: &mut SymbolTable,
83    ) -> RajacResult<()> {
84        if !dir.is_dir() {
85            return Ok(());
86        }
87
88        // First pass: Collect and parse all class files
89        let mut parsed_classes = Vec::new();
90        for entry in walkdir::WalkDir::new(dir)
91            .follow_links(true)
92            .into_iter()
93            .filter_map(|e| e.ok())
94        {
95            let path = entry.path();
96            if path.is_file() && path.extension().is_some_and(|ext| ext == "class") {
97                let bytes = std::fs::read(path).context("Failed to read class file")?;
98                if let Ok(class_file) = ClassFile::from_bytes(&mut Cursor::new(bytes))
99                    && let Some(parsed) = parse_class_file(&class_file)
100                {
101                    parsed_classes.push(parsed);
102                }
103            }
104        }
105
106        // Second pass: Collect raw data (without holding references to symbol_table)
107        let class_info: Vec<_> = parsed_classes
108            .iter()
109            .map(|parsed_class| {
110                let package_name = parsed_class.package.clone();
111                let class_name = parsed_class.class_name.clone();
112                let kind = if parsed_class.is_interface {
113                    SymbolKind::Interface
114                } else {
115                    SymbolKind::Class
116                };
117                (package_name, class_name, kind)
118            })
119            .collect();
120
121        // Third pass: Allocate types first
122        let type_arena = symbol_table.type_arena_mut();
123        let type_ids: Vec<_> = class_info
124            .iter()
125            .map(|(package_name, class_name, _)| {
126                let class_type = if !package_name.is_empty() {
127                    rajac_types::ClassType::new(class_name.clone())
128                        .with_package(package_name.clone())
129                } else {
130                    rajac_types::ClassType::new(class_name.clone())
131                };
132                type_arena.alloc(rajac_types::Type::class(class_type))
133            })
134            .collect();
135
136        // Fourth pass: Insert into symbol table
137        for (type_id, (package_name, class_name, kind)) in type_ids.into_iter().zip(class_info) {
138            let package = symbol_table.package(&package_name);
139            package.insert(class_name.clone(), Symbol::new(class_name, kind, type_id));
140        }
141
142        // Fifth pass: Resolve superclass and interface relationships
143        resolve_class_relationships(&parsed_classes, symbol_table)?;
144
145        Ok(())
146    }
147
148    fn add_jar_to_symbol_table(
149        &self,
150        jar: &Path,
151        symbol_table: &mut SymbolTable,
152    ) -> RajacResult<()> {
153        let file = File::open(jar).context("Failed to open JAR file")?;
154        let mut archive = ZipArchive::new(file).context("Failed to read JAR file")?;
155
156        let class_data: Vec<Vec<u8>> = (0..archive.len())
157            .filter_map(|i| {
158                let mut file = archive.by_index(i).ok()?;
159                let name = file.name();
160                if name.ends_with(".class") && !name.contains('$') {
161                    let mut bytes = Vec::new();
162                    file.read_to_end(&mut bytes).ok()?;
163                    Some(bytes)
164                } else {
165                    None
166                }
167            })
168            .collect();
169
170        drop(archive);
171
172        let start = Instant::now();
173        let parsed_classes: Vec<ParsedClass> = class_data
174            .into_par_iter()
175            .filter_map(|bytes| {
176                let class_file = ClassFile::from_bytes(&mut Cursor::new(bytes)).ok()?;
177                parse_class_file(&class_file)
178            })
179            .collect();
180        println!(
181            "Read {:?} in {}ms ({} classes)",
182            jar,
183            start.elapsed().as_millis(),
184            parsed_classes.len()
185        );
186
187        // First pass: Collect raw data
188        let class_info: Vec<_> = parsed_classes
189            .iter()
190            .map(|parsed_class| {
191                let package_name = parsed_class.package.clone();
192                let class_name = parsed_class.class_name.clone();
193                let kind = if parsed_class.is_interface {
194                    SymbolKind::Interface
195                } else {
196                    SymbolKind::Class
197                };
198                (package_name, class_name, kind)
199            })
200            .collect();
201
202        // Second pass: Allocate types first
203        let type_arena = symbol_table.type_arena_mut();
204        let type_ids: Vec<_> = class_info
205            .iter()
206            .map(|(package_name, class_name, _)| {
207                let class_type = if !package_name.is_empty() {
208                    rajac_types::ClassType::new(class_name.clone())
209                        .with_package(package_name.clone())
210                } else {
211                    rajac_types::ClassType::new(class_name.clone())
212                };
213                type_arena.alloc(rajac_types::Type::class(class_type))
214            })
215            .collect();
216
217        // Third pass: Insert into symbol table
218        for (type_id, (package_name, class_name, kind)) in type_ids.into_iter().zip(class_info) {
219            let package = symbol_table.package(&package_name);
220            package.insert(class_name.clone(), Symbol::new(class_name, kind, type_id));
221        }
222
223        // Fourth pass: Resolve superclass and interface relationships
224        resolve_class_relationships(&parsed_classes, symbol_table)?;
225
226        Ok(())
227    }
228}
229
230impl Default for Classpath {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
236fn parse_class_file(class_file: &ClassFile) -> Option<ParsedClass> {
237    let internal_name = class_file.class_name().ok()?;
238
239    let (package, class_name) = if let Some(last_slash) = internal_name.rfind('/') {
240        (
241            SharedString::new(internal_name[..last_slash].replace('/', ".")),
242            SharedString::new(&internal_name[last_slash + 1..]),
243        )
244    } else {
245        (SharedString::new(""), SharedString::new(internal_name))
246    };
247
248    let is_interface = class_file
249        .access_flags
250        .contains(ristretto_classfile::ClassAccessFlags::INTERFACE);
251
252    // Extract superclass information
253    let super_class = if class_file.super_class != 0 {
254        class_file
255            .constant_pool
256            .try_get_class(class_file.super_class)
257            .ok()
258            .map(|name| SharedString::new(name.replace('/', ".")))
259    } else {
260        None
261    };
262
263    // Extract interface information
264    let interfaces: Vec<SharedString> = class_file
265        .interfaces
266        .iter()
267        .filter_map(|&interface_idx| {
268            class_file
269                .constant_pool
270                .try_get_class(interface_idx)
271                .ok()
272                .map(|name| SharedString::new(name.replace('/', ".")))
273        })
274        .collect();
275
276    let methods: Vec<ParsedMethod> = class_file
277        .methods
278        .iter()
279        .filter_map(|method| parse_class_method(class_file, method, &class_name))
280        .collect();
281
282    let fields: Vec<ParsedField> = class_file
283        .fields
284        .iter()
285        .filter_map(|field| parse_class_field(class_file, field))
286        .collect();
287
288    Some(ParsedClass {
289        package,
290        class_name,
291        is_interface,
292        super_class,
293        interfaces,
294        methods,
295        fields,
296    })
297}
298
299fn parse_class_method(
300    class_file: &ClassFile,
301    method: &ristretto_classfile::Method,
302    class_name: &SharedString,
303) -> Option<ParsedMethod> {
304    let raw_name = class_file
305        .constant_pool
306        .try_get_utf8(method.name_index)
307        .ok()?;
308    let name = method_name_for_class(raw_name, class_name)?;
309    let descriptor = class_file
310        .constant_pool
311        .try_get_utf8(method.descriptor_index)
312        .ok()?;
313    let (params, return_type) = FieldType::parse_method_descriptor(descriptor).ok()?;
314
315    Some(ParsedMethod {
316        name,
317        params,
318        return_type,
319        modifiers: method_modifiers_from_access_flags(method.access_flags),
320    })
321}
322
323fn parse_class_field(
324    class_file: &ClassFile,
325    field: &ristretto_classfile::Field,
326) -> Option<ParsedField> {
327    let raw_name = class_file
328        .constant_pool
329        .try_get_utf8(field.name_index)
330        .ok()?;
331    let name = SharedString::new(raw_name);
332
333    Some(ParsedField {
334        name,
335        ty: field.field_type.clone(),
336        modifiers: field_modifiers_from_access_flags(field.access_flags),
337    })
338}
339
340fn method_name_for_class(raw_name: &str, class_name: &SharedString) -> Option<SharedString> {
341    match raw_name {
342        "<clinit>" => None,
343        "<init>" => Some(class_name.clone()),
344        _ => Some(SharedString::new(raw_name)),
345    }
346}
347
348fn method_modifiers_from_access_flags(
349    access_flags: MethodAccessFlags,
350) -> rajac_types::MethodModifiers {
351    let mut bits = 0;
352    if access_flags.contains(MethodAccessFlags::PUBLIC) {
353        bits |= rajac_types::MethodModifiers::PUBLIC;
354    }
355    if access_flags.contains(MethodAccessFlags::PRIVATE) {
356        bits |= rajac_types::MethodModifiers::PRIVATE;
357    }
358    if access_flags.contains(MethodAccessFlags::PROTECTED) {
359        bits |= rajac_types::MethodModifiers::PROTECTED;
360    }
361    if access_flags.contains(MethodAccessFlags::STATIC) {
362        bits |= rajac_types::MethodModifiers::STATIC;
363    }
364    if access_flags.contains(MethodAccessFlags::FINAL) {
365        bits |= rajac_types::MethodModifiers::FINAL;
366    }
367    if access_flags.contains(MethodAccessFlags::ABSTRACT) {
368        bits |= rajac_types::MethodModifiers::ABSTRACT;
369    }
370    if access_flags.contains(MethodAccessFlags::NATIVE) {
371        bits |= rajac_types::MethodModifiers::NATIVE;
372    }
373    if access_flags.contains(MethodAccessFlags::SYNCHRONIZED) {
374        bits |= rajac_types::MethodModifiers::SYNCHRONIZED;
375    }
376    if access_flags.contains(MethodAccessFlags::STRICT) {
377        bits |= rajac_types::MethodModifiers::STRICTFP;
378    }
379
380    rajac_types::MethodModifiers(bits)
381}
382
383fn field_modifiers_from_access_flags(
384    access_flags: FieldAccessFlags,
385) -> rajac_types::FieldModifiers {
386    let mut bits = 0;
387    if access_flags.contains(FieldAccessFlags::PUBLIC) {
388        bits |= rajac_types::FieldModifiers::PUBLIC;
389    }
390    if access_flags.contains(FieldAccessFlags::PRIVATE) {
391        bits |= rajac_types::FieldModifiers::PRIVATE;
392    }
393    if access_flags.contains(FieldAccessFlags::PROTECTED) {
394        bits |= rajac_types::FieldModifiers::PROTECTED;
395    }
396    if access_flags.contains(FieldAccessFlags::STATIC) {
397        bits |= rajac_types::FieldModifiers::STATIC;
398    }
399    if access_flags.contains(FieldAccessFlags::FINAL) {
400        bits |= rajac_types::FieldModifiers::FINAL;
401    }
402    if access_flags.contains(FieldAccessFlags::VOLATILE) {
403        bits |= rajac_types::FieldModifiers::VOLATILE;
404    }
405    if access_flags.contains(FieldAccessFlags::TRANSIENT) {
406        bits |= rajac_types::FieldModifiers::TRANSIENT;
407    }
408
409    rajac_types::FieldModifiers(bits)
410}
411
412fn resolve_class_relationships(
413    parsed_classes: &[ParsedClass],
414    symbol_table: &mut SymbolTable,
415) -> RajacResult<()> {
416    // First pass: Collect all the relationships we need to resolve (only read from symbol_table)
417    let relationships: Vec<_> = parsed_classes
418        .iter()
419        .filter_map(|parsed_class| {
420            let package_table = symbol_table.get_package_shared(&parsed_class.package)?;
421            let symbol = package_table.get(&parsed_class.class_name)?;
422            let type_id = symbol.ty;
423
424            let super_type_id = parsed_class
425                .super_class
426                .as_ref()
427                .and_then(|super_class_name| {
428                    find_type_id_for_class_impl(super_class_name, symbol_table)
429                });
430
431            let interface_type_ids: Vec<rajac_types::TypeId> = parsed_class
432                .interfaces
433                .iter()
434                .filter_map(|interface_name| {
435                    find_type_id_for_class_impl(interface_name, symbol_table)
436                })
437                .collect();
438
439            Some((
440                type_id,
441                super_type_id,
442                interface_type_ids,
443                parsed_class.methods.clone(),
444                parsed_class.fields.clone(),
445            ))
446        })
447        .collect();
448
449    // Second pass: Apply the relationships (only write to type_arena)
450    let class_lookup = build_class_lookup(symbol_table);
451    let primitive_lookup = symbol_table.primitive_types().clone();
452    let (type_arena, method_arena, field_arena) = symbol_table.arenas_mut();
453    for (type_id, super_type_id, interface_type_ids, methods, fields) in relationships {
454        let mut resolved_methods = Vec::with_capacity(methods.len());
455        for method in methods {
456            let params = method
457                .params
458                .iter()
459                .map(|param| {
460                    resolve_field_type(param, &primitive_lookup, &class_lookup, type_arena)
461                })
462                .collect::<Vec<_>>();
463            let return_type = match &method.return_type {
464                Some(field_type) => {
465                    resolve_field_type(field_type, &primitive_lookup, &class_lookup, type_arena)
466                }
467                None => void_type_id(&primitive_lookup),
468            };
469            let signature = rajac_types::MethodSignature {
470                name: method.name.clone(),
471                params,
472                return_type,
473                throws: Vec::new(),
474                modifiers: method.modifiers,
475            };
476            let method_id = method_arena.alloc(signature);
477            resolved_methods.push((method.name, method_id));
478        }
479
480        let mut resolved_fields = Vec::with_capacity(fields.len());
481        for field in fields {
482            let field_type =
483                resolve_field_type(&field.ty, &primitive_lookup, &class_lookup, type_arena);
484            let signature =
485                rajac_types::FieldSignature::new(field.name.clone(), field_type, field.modifiers);
486            let field_id = field_arena.alloc(signature);
487            resolved_fields.push((field.name, field_id));
488        }
489
490        let class_type = type_arena.get_mut(type_id);
491        if let rajac_types::Type::Class(class_type_mut) = class_type {
492            class_type_mut.superclass = super_type_id;
493            class_type_mut.interfaces = interface_type_ids;
494            for (name, method_id) in resolved_methods {
495                class_type_mut.add_method(name, method_id);
496            }
497            for (name, field_id) in resolved_fields {
498                class_type_mut.add_field(name, field_id);
499            }
500        }
501    }
502
503    Ok(())
504}
505
506fn find_type_id_for_class_impl(
507    class_name: &str,
508    symbol_table: &SymbolTable,
509) -> Option<rajac_types::TypeId> {
510    let (package, simple_name) = if let Some(last_dot) = class_name.rfind('.') {
511        (
512            SharedString::new(&class_name[..last_dot]),
513            SharedString::new(&class_name[last_dot + 1..]),
514        )
515    } else {
516        (SharedString::new(""), SharedString::new(class_name))
517    };
518
519    symbol_table.lookup_type_id(package.as_str(), simple_name.as_str())
520}
521
522fn build_class_lookup(symbol_table: &SymbolTable) -> HashMap<String, rajac_types::TypeId> {
523    let mut lookup = HashMap::new();
524    for (package, table) in symbol_table.iter() {
525        for (name, symbol) in table.iter() {
526            let fqn = if package.is_empty() {
527                name.as_str().to_string()
528            } else {
529                format!("{}.{}", package.as_str(), name.as_str())
530            };
531            lookup.insert(fqn, symbol.ty);
532        }
533    }
534    lookup
535}
536
537fn resolve_field_type(
538    field_type: &FieldType,
539    primitive_lookup: &HashMap<SharedString, rajac_types::TypeId>,
540    class_lookup: &HashMap<String, rajac_types::TypeId>,
541    type_arena: &mut rajac_types::TypeArena,
542) -> rajac_types::TypeId {
543    match field_type {
544        FieldType::Base(base_type) => primitive_lookup
545            .get(&SharedString::new(primitive_name_from_base_type(base_type)))
546            .copied()
547            .unwrap_or(rajac_types::TypeId::INVALID),
548        FieldType::Object(class_name) => {
549            let fqn = class_name.replace('/', ".");
550            class_lookup
551                .get(&fqn)
552                .copied()
553                .unwrap_or(rajac_types::TypeId::INVALID)
554        }
555        FieldType::Array(component_type) => {
556            let element_type =
557                resolve_field_type(component_type, primitive_lookup, class_lookup, type_arena);
558            if element_type == rajac_types::TypeId::INVALID {
559                rajac_types::TypeId::INVALID
560            } else {
561                type_arena.alloc(rajac_types::Type::array(element_type))
562            }
563        }
564    }
565}
566
567fn primitive_name_from_base_type(base_type: &BaseType) -> &'static str {
568    match base_type {
569        BaseType::Boolean => "boolean",
570        BaseType::Byte => "byte",
571        BaseType::Char => "char",
572        BaseType::Short => "short",
573        BaseType::Int => "int",
574        BaseType::Long => "long",
575        BaseType::Float => "float",
576        BaseType::Double => "double",
577    }
578}
579
580fn void_type_id(
581    primitive_lookup: &HashMap<SharedString, rajac_types::TypeId>,
582) -> rajac_types::TypeId {
583    primitive_lookup
584        .get(&SharedString::new("void"))
585        .copied()
586        .unwrap_or(rajac_types::TypeId::INVALID)
587}
588
589use rajac_base::result::{RajacResult, ResultExt};
590use std::path::PathBuf;
591
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed classpath has no entries.
    #[test]
    fn new_classpath_has_no_entries() {
        let classpath = Classpath::new();
        assert!(classpath.entries.is_empty());
        assert!(classpath.is_empty());
    }

    /// Constructors take the class name, static initializers are dropped and
    /// ordinary methods keep their name.
    #[test]
    fn method_name_handles_init_and_clinit() {
        let class_name = SharedString::new("Widget");
        assert_eq!(
            method_name_for_class("<init>", &class_name),
            Some(SharedString::new("Widget"))
        );
        assert_eq!(method_name_for_class("<clinit>", &class_name), None);
        assert_eq!(
            method_name_for_class("run", &class_name),
            Some(SharedString::new("run"))
        );
    }

    /// Object, primitive and array descriptors resolve to the expected types.
    #[test]
    fn resolves_field_types_with_lookup_and_arrays() {
        let mut symbol_table = SymbolTable::new();
        let primitive_lookup = symbol_table.primitive_types().clone();
        let type_arena = symbol_table.type_arena_mut();
        let string_id = type_arena.alloc(rajac_types::Type::class(
            rajac_types::ClassType::new(SharedString::new("String"))
                .with_package(SharedString::new("java.lang")),
        ));
        let mut class_lookup = HashMap::new();
        class_lookup.insert("java.lang.String".to_string(), string_id);

        // Object types are looked up by their dotted name.
        let object_type = FieldType::Object("java/lang/String".to_string());
        assert_eq!(
            resolve_field_type(&object_type, &primitive_lookup, &class_lookup, type_arena),
            string_id
        );

        // Primitive types come from the primitive lookup.
        let int_type = FieldType::Base(BaseType::Int);
        let int_id = resolve_field_type(&int_type, &primitive_lookup, &class_lookup, type_arena);
        assert_eq!(
            type_arena.get(int_id),
            &rajac_types::Type::primitive(rajac_types::PrimitiveType::Int)
        );

        // Array types are allocated around the resolved component type.
        let array_type = FieldType::Array(Box::new(FieldType::Base(BaseType::Boolean)));
        let array_id =
            resolve_field_type(&array_type, &primitive_lookup, &class_lookup, type_arena);
        match type_arena.get(array_id) {
            rajac_types::Type::Array(array) => {
                let element_type = type_arena.get(array.element_type);
                assert_eq!(
                    element_type,
                    &rajac_types::Type::primitive(rajac_types::PrimitiveType::Boolean)
                );
            }
            other => panic!("expected array type, got {other:?}"),
        }
    }

    /// Method access flags map onto the matching modifier bits.
    #[test]
    fn maps_access_flags_to_method_modifiers() {
        let flags =
            MethodAccessFlags::PUBLIC | MethodAccessFlags::STATIC | MethodAccessFlags::FINAL;
        let modifiers = method_modifiers_from_access_flags(flags);
        assert_eq!(
            modifiers,
            rajac_types::MethodModifiers(
                rajac_types::MethodModifiers::PUBLIC
                    | rajac_types::MethodModifiers::STATIC
                    | rajac_types::MethodModifiers::FINAL
            )
        );
    }

    /// Field access flags map onto the matching modifier bits.
    #[test]
    fn maps_access_flags_to_field_modifiers() {
        let flags = FieldAccessFlags::PUBLIC | FieldAccessFlags::STATIC | FieldAccessFlags::FINAL;
        let modifiers = field_modifiers_from_access_flags(flags);
        assert_eq!(
            modifiers,
            rajac_types::FieldModifiers(
                rajac_types::FieldModifiers::PUBLIC
                    | rajac_types::FieldModifiers::STATIC
                    | rajac_types::FieldModifiers::FINAL
            )
        );
    }
}