Skip to main content

jvm_assembler/formats/class/writer/mod.rs

1#![doc = include_str!("readme.md")]
2
3mod attributes;
4mod cp;
5mod entities;
6mod instructions;
7mod pool;
8mod utils;
9
10use crate::program::*;
11use byteorder::BigEndian;
12use cp::CpEntry;
13use gaia_types::{BinaryWriter, GaiaDiagnostics, Result};
14use std::{collections::HashMap, io::Write};
15
16/// JVM Class file writer
17pub struct ClassWriter<W> {
18    /// Binary writer
19    writer: BinaryWriter<W, BigEndian>,
20    /// Constant pool entries
21    cp_entries: Vec<CpEntry>,
22    /// Constant pool lookup table
23    cp_map: HashMap<CpEntry, u16>,
24}
25
26impl<W> ClassWriter<W> {
27    /// Create a new Class writer
28    pub fn new(writer: W) -> Self {
29        Self { writer: BinaryWriter::new(writer), cp_entries: Vec::new(), cp_map: HashMap::new() }
30    }
31
32    /// Finish writing and return the underlying writer
33    pub fn finish(self) -> W {
34        self.writer.finish()
35    }
36}
37
38impl<W: Write> ClassWriter<W> {
39    /// Write ClassView as binary Class format
40    pub fn write(mut self, program: &JvmProgram) -> GaiaDiagnostics<W> {
41        match self.write_class_file(program) {
42            Ok(_) => GaiaDiagnostics::success(self.finish()),
43            Err(error) => GaiaDiagnostics::failure(error),
44        }
45    }
46
47    /// Write Class file
48    fn write_class_file(&mut self, program: &JvmProgram) -> Result<()> {
49        // 1. Pre-collect all constants
50        let this_class_idx = self.add_class(program.name.clone());
51        let super_class_idx = if let Some(super_name) = &program.super_class {
52            self.add_class(super_name.clone())
53        }
54        else {
55            self.add_class("java/lang/Object".to_string())
56        };
57
58        // Collect class attribute constants
59        for attr in &program.attributes {
60            self.collect_attribute_constants(attr);
61        }
62
63        // Collect field constants
64        for field in &program.fields {
65            self.add_utf8(field.name.clone());
66            self.add_utf8(field.descriptor.clone());
67            for attr in &field.attributes {
68                self.collect_attribute_constants(attr);
69            }
70            if field.constant_value.is_some() {
71                self.add_utf8("ConstantValue".to_string());
72            }
73        }
74
75        // Collect method constants
76        let code_utf8_idx = self.add_utf8("Code".to_string());
77        for method in &program.methods {
78            self.add_utf8(method.name.clone());
79            self.add_utf8(method.descriptor.clone());
80
81            for handler in &method.exception_handlers {
82                if let Some(catch_type) = &handler.catch_type {
83                    self.add_class(catch_type.clone());
84                }
85            }
86
87            for attr in &method.attributes {
88                self.collect_attribute_constants(attr);
89            }
90            if !method.exceptions.is_empty() {
91                self.add_utf8("Exceptions".to_string());
92            }
93
94            // Pre-generate bytecode to get label positions for StackMapTable generation
95            let (_, label_positions) = self.generate_method_bytecode(method);
96
97            if program.version.major >= 50 {
98                self.add_utf8("StackMapTable".to_string());
99                // If automatic StackMapTable generation is needed, pre-collect its constants
100                let has_stack_map = method.attributes.iter().any(|a| matches!(a, JvmAttribute::StackMapTable { .. }));
101                if !has_stack_map {
102                    let analyzer = crate::analyzer::StackMapAnalyzer::new(program.name.clone(), method, &label_positions);
103                    let frames = analyzer.analyze();
104                    for frame in &frames {
105                        match frame {
106                            JvmStackMapFrame::SameLocals1StackItem { stack, .. }
107                            | JvmStackMapFrame::SameLocals1StackItemExtended { stack, .. } => {
108                                self.collect_verification_type_constants(stack);
109                            }
110                            JvmStackMapFrame::Append { locals, .. } => {
111                                for vt in locals {
112                                    self.collect_verification_type_constants(vt);
113                                }
114                            }
115                            JvmStackMapFrame::Full { locals, stack, .. } => {
116                                for vt in locals {
117                                    self.collect_verification_type_constants(vt);
118                                }
119                                for vt in stack {
120                                    self.collect_verification_type_constants(vt);
121                                }
122                            }
123                            _ => {}
124                        }
125                    }
126                }
127            }
128
129            // Collect constants in instructions
130            for inst in &method.instructions {
131                match inst {
132                    JvmInstruction::Ldc { symbol } | JvmInstruction::LdcW { symbol } | JvmInstruction::Ldc2W { symbol } => {
133                        self.add_string(symbol.clone());
134                    }
135                    JvmInstruction::Getstatic { class_name, field_name, descriptor }
136                    | JvmInstruction::Putstatic { class_name, field_name, descriptor }
137                    | JvmInstruction::Getfield { class_name, field_name, descriptor }
138                    | JvmInstruction::Putfield { class_name, field_name, descriptor } => {
139                        self.add_field_ref(class_name.clone(), field_name.clone(), descriptor.clone());
140                    }
141                    JvmInstruction::Invokevirtual { class_name, method_name, descriptor }
142                    | JvmInstruction::Invokespecial { class_name, method_name, descriptor }
143                    | JvmInstruction::Invokestatic { class_name, method_name, descriptor }
144                    | JvmInstruction::Invokedynamic { class_name, method_name, descriptor } => {
145                        self.add_method_ref(class_name.clone(), method_name.clone(), descriptor.clone());
146                    }
147                    JvmInstruction::Invokeinterface { class_name, method_name, descriptor } => {
148                        self.add_interface_method_ref(class_name.clone(), method_name.clone(), descriptor.clone());
149                    }
150                    JvmInstruction::New { class_name }
151                    | JvmInstruction::Anewarray { class_name }
152                    | JvmInstruction::Checkcast { class_name }
153                    | JvmInstruction::Instanceof { class_name }
154                    | JvmInstruction::Multianewarray { class_name, .. } => {
155                        self.add_class(class_name.clone());
156                    }
157                    _ => {}
158                }
159            }
160        }
161
162        // 2. Start writing
163        // Write magic number
164        self.writer.write_u32(0xCAFEBABE)?;
165
166        // Write version information
167        self.writer.write_u16(program.version.minor)?;
168        self.writer.write_u16(program.version.major)?;
169
170        // Write constant pool
171        self.write_constant_pool()?;
172
173        // Write access flags
174        self.writer.write_u16(program.access_flags.to_flags())?;
175
176        // Write this_class index
177        self.writer.write_u16(this_class_idx)?;
178
179        // Write super_class index
180        self.writer.write_u16(super_class_idx)?;
181
182        // Write interface count (currently 0)
183        self.writer.write_u16(0)?;
184
185        // Write fields
186        self.write_fields(program)?;
187
188        // Write methods
189        self.write_methods(program, code_utf8_idx)?;
190
191        // Write class attributes
192        self.writer.write_u16(program.attributes.len() as u16)?;
193        for attr in &program.attributes {
194            self.write_attribute(attr)?;
195        }
196
197        Ok(())
198    }
199}