Skip to main content

jvm_assembler/formats/class/writer/mod.rs

1#![doc = include_str!("readme.md")]
2
3mod attributes;
4mod cp;
5mod entities;
6mod instructions;
7mod pool;
8mod utils;
9
10use crate::program::*;
11use cp::CpEntry;
12use gaia_binary::{BigEndian, BinaryWriter, Fixed};
13use gaia_types::{GaiaDiagnostics, Result};
14use std::{collections::HashMap, io::Write};
15
16/// JVM Class file writer
17pub struct ClassWriter<W: Write> {
18    /// Binary writer
19    writer: BinaryWriter<W, Fixed<BigEndian>>,
20    /// Constant pool entries
21    cp_entries: Vec<CpEntry>,
22    /// Constant pool lookup table
23    cp_map: HashMap<CpEntry, u16>,
24}
25
26impl<W: Write> ClassWriter<W> {
27    /// Create a new Class writer
28    pub fn new(writer: W) -> Self {
29        Self { writer: BinaryWriter::new(writer), cp_entries: Vec::new(), cp_map: HashMap::new() }
30    }
31
32    /// Finish writing and return the underlying writer
33    pub fn finish(self) -> W {
34        self.writer.into_inner()
35    }
36
37    /// Write ClassView as binary Class format
38    pub fn write(mut self, program: &JvmProgram) -> GaiaDiagnostics<W> {
39        match self.write_class_file(program) {
40            Ok(_) => GaiaDiagnostics::success(self.finish()),
41            Err(error) => GaiaDiagnostics::failure(error),
42        }
43    }
44
45    /// Write Class file
46    fn write_class_file(&mut self, program: &JvmProgram) -> Result<()> {
47        // 1. Pre-collect all constants
48        let this_class_idx = self.add_class(program.name.clone());
49        let super_class_idx = if let Some(super_name) = &program.super_class {
50            self.add_class(super_name.clone())
51        }
52        else {
53            self.add_class("java/lang/Object".to_string())
54        };
55
56        // Collect class attribute constants
57        for attr in &program.attributes {
58            self.collect_attribute_constants(attr);
59        }
60
61        // Collect inner classes constants
62        if !program.inner_classes.is_empty() {
63            let inner_classes_attr = JvmAttribute::InnerClasses { classes: program.inner_classes.clone() };
64            self.collect_attribute_constants(&inner_classes_attr);
65        }
66
67        // Collect enclosing method constants
68        if let Some((class_name, method_name, method_descriptor)) = &program.enclosing_method {
69            let enclosing_method_attr = JvmAttribute::EnclosingMethod {
70                class_name: class_name.clone(),
71                method_name: method_name.clone(),
72                method_descriptor: method_descriptor.clone(),
73            };
74            self.collect_attribute_constants(&enclosing_method_attr);
75        }
76
77        // Collect field constants
78        for field in &program.fields {
79            self.add_utf8(field.name.clone());
80            self.add_utf8(field.descriptor.clone());
81            for attr in &field.attributes {
82                self.collect_attribute_constants(attr);
83            }
84            if field.constant_value.is_some() {
85                self.add_utf8("ConstantValue".to_string());
86            }
87        }
88
89        // Collect method constants
90        let code_utf8_idx = self.add_utf8("Code".to_string());
91        for method in &program.methods {
92            self.add_utf8(method.name.clone());
93            self.add_utf8(method.descriptor.clone());
94
95            for handler in &method.exception_handlers {
96                if let Some(catch_type) = &handler.catch_type {
97                    self.add_class(catch_type.clone());
98                }
99            }
100
101            for attr in &method.attributes {
102                self.collect_attribute_constants(attr);
103            }
104            if !method.exceptions.is_empty() {
105                self.add_utf8("Exceptions".to_string());
106            }
107
108            // Pre-generate bytecode to get label positions for StackMapTable generation
109            let (_, label_positions) = self.generate_method_bytecode(method);
110
111            if program.version.major >= 50 {
112                self.add_utf8("StackMapTable".to_string());
113                // If automatic StackMapTable generation is needed, pre-collect its constants
114                let has_stack_map = method.attributes.iter().any(|a| matches!(a, JvmAttribute::StackMapTable { .. }));
115                if !has_stack_map {
116                    let analyzer = crate::analyzer::StackMapAnalyzer::new(program.name.clone(), method, &label_positions);
117                    let frames = analyzer.analyze();
118                    for frame in &frames {
119                        match frame {
120                            JvmStackMapFrame::SameLocals1StackItem { stack, .. }
121                            | JvmStackMapFrame::SameLocals1StackItemExtended { stack, .. } => {
122                                self.collect_verification_type_constants(stack);
123                            }
124                            JvmStackMapFrame::Append { locals, .. } => {
125                                for vt in locals {
126                                    self.collect_verification_type_constants(vt);
127                                }
128                            }
129                            JvmStackMapFrame::Full { locals, stack, .. } => {
130                                for vt in locals {
131                                    self.collect_verification_type_constants(vt);
132                                }
133                                for vt in stack {
134                                    self.collect_verification_type_constants(vt);
135                                }
136                            }
137                            _ => {}
138                        }
139                    }
140                }
141            }
142
143            // Collect constants in instructions
144            for inst in &method.instructions {
145                match inst {
146                    JvmInstruction::Ldc { symbol } | JvmInstruction::LdcW { symbol } | JvmInstruction::Ldc2W { symbol } => {
147                        self.add_string(symbol.clone());
148                    }
149                    JvmInstruction::Getstatic { class_name, field_name, descriptor }
150                    | JvmInstruction::Putstatic { class_name, field_name, descriptor }
151                    | JvmInstruction::Getfield { class_name, field_name, descriptor }
152                    | JvmInstruction::Putfield { class_name, field_name, descriptor } => {
153                        self.add_field_ref(class_name.clone(), field_name.clone(), descriptor.clone());
154                    }
155                    JvmInstruction::Invokevirtual { class_name, method_name, descriptor }
156                    | JvmInstruction::Invokespecial { class_name, method_name, descriptor }
157                    | JvmInstruction::Invokestatic { class_name, method_name, descriptor }
158                    | JvmInstruction::Invokedynamic { class_name, method_name, descriptor } => {
159                        self.add_method_ref(class_name.clone(), method_name.clone(), descriptor.clone());
160                    }
161                    JvmInstruction::Invokeinterface { class_name, method_name, descriptor } => {
162                        self.add_interface_method_ref(class_name.clone(), method_name.clone(), descriptor.clone());
163                    }
164                    JvmInstruction::New { class_name }
165                    | JvmInstruction::Anewarray { class_name }
166                    | JvmInstruction::Checkcast { class_name }
167                    | JvmInstruction::Instanceof { class_name }
168                    | JvmInstruction::Multianewarray { class_name, .. } => {
169                        self.add_class(class_name.clone());
170                    }
171                    _ => {}
172                }
173            }
174        }
175
176        // 2. Start writing
177        // Write magic number
178        self.writer.write_u32(0xCAFEBABE)?;
179
180        // Write version information
181        self.writer.write_u16(program.version.minor)?;
182        self.writer.write_u16(program.version.major)?;
183
184        // Write constant pool
185        self.write_constant_pool()?;
186
187        // Write access flags
188        self.writer.write_u16(program.access_flags.to_flags())?;
189
190        // Write this_class index
191        self.writer.write_u16(this_class_idx)?;
192
193        // Write super_class index
194        self.writer.write_u16(super_class_idx)?;
195
196        // Write interface count (currently 0)
197        self.writer.write_u16(0)?;
198
199        // Write fields
200        self.write_fields(program)?;
201
202        // Write methods
203        self.write_methods(program, code_utf8_idx)?;
204
205        // Write class attributes
206        let mut attributes = program.attributes.clone();
207        if !program.inner_classes.is_empty() {
208            attributes.push(JvmAttribute::InnerClasses { classes: program.inner_classes.clone() });
209        }
210        if let Some((class_name, method_name, method_descriptor)) = &program.enclosing_method {
211            attributes.push(JvmAttribute::EnclosingMethod {
212                class_name: class_name.clone(),
213                method_name: method_name.clone(),
214                method_descriptor: method_descriptor.clone(),
215            });
216        }
217        self.writer.write_u16(attributes.len() as u16)?;
218        for attr in &attributes {
219            self.write_attribute(attr)?;
220        }
221
222        Ok(())
223    }
224}