use anyhow::{Context, Result};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use crate::graph::cfg_edges_extract::{CfgEdge, CfgEdgeType};
use crate::graph::cfg_extractor::BlockKind;
use crate::graph::schema::CfgBlock;
/// Errors reported while turning a Java `.class` file into control-flow graphs.
///
/// The public entry points in this file return `anyhow::Result`, so these
/// variants reach callers wrapped in an `anyhow::Error`.
#[derive(Debug, thiserror::Error)]
pub enum ClassParseError {
/// An underlying I/O failure; convertible from `std::io::Error` via `#[from]`.
#[error("Failed to read .class file: {0}")]
IoError(#[from] std::io::Error),
/// The input is not a usable class file. `extract_cfg_from_class` raises this
/// for inputs shorter than four bytes and for a wrong magic number.
#[error("Invalid .class file format: {0}")]
InvalidFormat(String),
/// `find_method_bytecode` produced no method bytecode at all.
#[error("No methods found in .class file")]
NoMethodsFound,
/// `extract_cfg_for_method` was asked for a key that is not in the extracted map.
#[error("Method not found: {0}")]
MethodNotFound(String),
}
/// Local alias for the shared CFG container defined in `cfg_edges_extract`;
/// this file fills in its `blocks`, `edges` and `function_id` fields.
pub type CfgWithEdges = crate::graph::cfg_edges_extract::CfgWithEdges;
/// One decoded bytecode instruction, as produced by `parse_bytecode`.
#[derive(Debug, Clone)]
struct BytecodeInstruction {
// The opcode byte itself.
opcode: u8,
// Operand bytes collected for this instruction by `parse_bytecode`.
operands: Vec<u8>,
// Position of the opcode byte within the method's bytecode array.
offset: usize,
}
/// A basic block: a run of instructions plus the way control leaves it.
#[derive(Debug, Clone)]
struct BytecodeBlock {
// Where the block begins; exported as `CfgBlock::byte_start` and used as the
// lookup key when branch targets are resolved to block indices.
start_offset: usize,
// Where the block ends; exported as `CfgBlock::byte_end`.
end_offset: usize,
// Copies of the instructions that make up the block, in order.
instructions: Vec<BytecodeInstruction>,
// Classification of the block's last instruction (see `classify_terminator`).
terminator: BlockTerminator,
}
/// How control leaves a basic block. `build_cfg_from_bytecode` turns each
/// variant into the outgoing edges described below.
#[derive(Debug, Clone, PartialEq)]
enum BlockTerminator {
// Method return: no outgoing edges.
Return,
// Unconditional jump: one `Jump` edge to the block starting at `target`.
Unconditional { target: usize },
// Conditional branch: a `ConditionalTrue` edge to the block starting at
// `target` and a `ConditionalFalse` edge to the next block.
Conditional { target: usize },
// Switch: `Jump` edges to the block at `default` and to every block in `cases`.
Switch { default: usize, cases: Vec<usize> },
// Exception throw: no outgoing edges.
Throw,
// Plain instruction: one `Fallthrough` edge to the next block, if any.
Fallthrough,
// The last instruction could not be classified: no outgoing edges.
Unknown,
}
/// Builds a control-flow graph for every method with bytecode in a `.class` file.
///
/// The returned map is keyed by the method names that `find_method_bytecode`
/// assigns.
///
/// # Errors
///
/// * `ClassParseError::InvalidFormat` when the input is shorter than four bytes
///   or does not start with the `CAFEBABE` magic number.
/// * `ClassParseError::NoMethodsFound` when no method bytecode was located.
/// * Any error reported by `find_method_bytecode` or `build_cfg_from_bytecode`.
pub fn extract_cfg_from_class(class_bytes: &[u8]) -> Result<HashMap<String, CfgWithEdges>> {
    const CLASS_MAGIC: [u8; 4] = [0xCA, 0xFE, 0xBA, 0xBE];

    // Validate the four-byte header before any structural parsing.
    match class_bytes.get(..4) {
        None => {
            return Err(ClassParseError::InvalidFormat("File too short".to_string()).into());
        }
        Some(header) if header != &CLASS_MAGIC[..] => {
            return Err(ClassParseError::InvalidFormat("Invalid magic number".to_string()).into());
        }
        Some(_) => {}
    }

    let bytecode_by_method = find_method_bytecode(class_bytes)?;
    if bytecode_by_method.is_empty() {
        return Err(ClassParseError::NoMethodsFound.into());
    }

    // Collecting into `Result` stops at the first method whose CFG fails to build.
    bytecode_by_method
        .into_iter()
        .map(|(name, code)| build_cfg_from_bytecode(&name, &code).map(|cfg| (name, cfg)))
        .collect()
}
/// Builds the control-flow graph of a single method of a `.class` file.
///
/// `method_name` must match one of the keys produced by
/// `extract_cfg_from_class`; those keys are the names that
/// `find_method_bytecode` assigns (`"method_0"`, `"method_1"`, ...).
///
/// # Errors
///
/// * Everything `extract_cfg_from_class` can return (the CFGs of all methods
///   are built before the requested one is selected).
/// * `ClassParseError::MethodNotFound` when `method_name` is not a key of the
///   extracted map.
pub fn extract_cfg_for_method(class_bytes: &[u8], method_name: &str) -> Result<CfgWithEdges> {
    let mut methods = extract_cfg_from_class(class_bytes)?;
    // The map is dropped right after this call, so move the graph out of it
    // instead of deep-cloning its blocks and edges.
    methods
        .remove(method_name)
        .ok_or_else(|| ClassParseError::MethodNotFound(method_name.to_string()).into())
}
/// Locates the bytecode of every method in a `.class` file.
///
/// The class file is walked structurally: constant pool, class header,
/// interfaces, fields, then methods. For each method, the attribute whose name
/// resolves to the constant-pool string `Code` supplies the bytecode.
///
/// Keys of the returned map are synthetic: `method_N`, where `N` is the number
/// of methods already stored when the method is visited. Methods without a
/// non-empty `Code` attribute (abstract, native) get no entry.
///
/// Parsing is best-effort: when the input is truncated or structurally invalid
/// the walk stops and the methods collected so far are returned. This function
/// never returns `Err` and never panics on malformed input.
fn find_method_bytecode(class_bytes: &[u8]) -> Result<HashMap<String, Vec<u8>>> {
    // Big-endian u16 at `pos`; `None` when the read would leave the buffer.
    fn be_u16(bytes: &[u8], pos: usize) -> Option<usize> {
        let raw = bytes.get(pos..pos.checked_add(2)?)?;
        Some(usize::from(u16::from_be_bytes([raw[0], raw[1]])))
    }

    // Big-endian u32 at `pos`; `None` when the read would leave the buffer.
    fn be_u32(bytes: &[u8], pos: usize) -> Option<usize> {
        let raw = bytes.get(pos..pos.checked_add(4)?)?;
        Some(u32::from_be_bytes([raw[0], raw[1], raw[2], raw[3]]) as usize)
    }

    // Steps over `count` attribute_info records (name u16, length u32, body)
    // starting at `pos` and returns the position just past them.
    fn skip_attributes(bytes: &[u8], mut pos: usize, count: usize) -> Option<usize> {
        for _ in 0..count {
            let length = be_u32(bytes, pos.checked_add(2)?)?;
            pos = pos.checked_add(6)?.checked_add(length)?;
        }
        Some(pos)
    }

    // Walks the class file and records method bytecode; `None` means the walk
    // had to stop early because the data ran out or was malformed.
    fn scan(bytes: &[u8], methods: &mut HashMap<String, Vec<u8>>) -> Option<()> {
        // magic (4 bytes), minor_version (2), major_version (2) precede the pool.
        let pool_count = be_u16(bytes, 8)?;
        let mut pos = 10;

        // utf8[i] holds the bytes of constant #i when that constant is a Utf8 entry.
        let mut utf8: Vec<Option<&[u8]>> = vec![None; pool_count];
        // Valid constant indices are 1..pool_count (slot 0 is unused).
        let mut slot = 1;
        while slot < pool_count {
            let tag = *bytes.get(pos)?;
            pos += 1;
            match tag {
                // Utf8: u16 length followed by that many bytes.
                1 => {
                    let length = be_u16(bytes, pos)?;
                    utf8[slot] = Some(bytes.get(pos + 2..pos + 2 + length)?);
                    pos += 2 + length;
                }
                // Integer, Float, Fieldref, Methodref, InterfaceMethodref,
                // NameAndType, Dynamic, InvokeDynamic.
                3 | 4 | 9 | 10 | 11 | 12 | 17 | 18 => pos += 4,
                // Long and Double occupy two constant-pool slots.
                5 | 6 => {
                    pos += 8;
                    slot += 1;
                }
                // Class, String, MethodType, Module, Package.
                7 | 8 | 16 | 19 | 20 => pos += 2,
                // MethodHandle: reference_kind (1) + reference_index (2).
                15 => pos += 3,
                _ => return None,
            }
            slot += 1;
        }

        // access_flags, this_class, super_class.
        pos += 6;
        let interfaces_count = be_u16(bytes, pos)?;
        pos += 2 + interfaces_count * 2;

        let fields_count = be_u16(bytes, pos)?;
        pos += 2;
        for _ in 0..fields_count {
            // access_flags, name_index, descriptor_index, attributes_count.
            let attributes_count = be_u16(bytes, pos.checked_add(6)?)?;
            pos = skip_attributes(bytes, pos.checked_add(8)?, attributes_count)?;
        }

        let methods_count = be_u16(bytes, pos)?;
        pos += 2;
        for _ in 0..methods_count {
            // access_flags, name_index, descriptor_index, attributes_count.
            let attributes_count = be_u16(bytes, pos.checked_add(6)?)?;
            pos = pos.checked_add(8)?;
            let method_name = format!("method_{}", methods.len());

            for _ in 0..attributes_count {
                let name_index = be_u16(bytes, pos)?;
                let attribute_length = be_u32(bytes, pos.checked_add(2)?)?;
                let body_start = pos.checked_add(6)?;
                let body_end = body_start.checked_add(attribute_length)?;

                let is_code = match utf8.get(name_index) {
                    Some(&Some(name)) => name == &b"Code"[..],
                    _ => false,
                };
                if is_code {
                    // Code body: max_stack (2), max_locals (2), code_length (4), code.
                    let code_length = be_u32(bytes, body_start.checked_add(4)?)?;
                    let code_start = body_start.checked_add(8)?;
                    let code_end = code_start.checked_add(code_length)?;
                    if code_end > body_end {
                        return None;
                    }
                    let code = bytes.get(code_start..code_end)?;
                    if !code.is_empty() {
                        methods.insert(method_name.clone(), code.to_vec());
                    }
                }
                // The declared length covers the exception table and nested
                // attributes too, so one jump skips the whole attribute.
                pos = body_end;
            }
        }
        Some(())
    }

    let mut methods = HashMap::new();
    // Best-effort: an early stop still yields whatever was collected so far.
    let _ = scan(class_bytes, &mut methods);
    Ok(methods)
}
fn build_cfg_from_bytecode(_method_name: &str, bytecode: &[u8]) -> Result<CfgWithEdges> {
let instructions = parse_bytecode(bytecode)?;
if instructions.is_empty() {
return Ok(CfgWithEdges {
blocks: vec![],
edges: vec![],
function_id: 0,
});
}
let blocks = identify_basic_blocks(&instructions);
let mut block_map: HashMap<usize, usize> = HashMap::new();
for (idx, block) in blocks.iter().enumerate() {
block_map.insert(block.start_offset, idx);
}
let mut cfg_blocks: Vec<CfgBlock> = Vec::new();
for (idx, block) in blocks.iter().enumerate() {
let kind = if idx == 0 {
BlockKind::Entry
} else if block.terminator == BlockTerminator::Return {
BlockKind::Return
} else {
BlockKind::For };
cfg_blocks.push(CfgBlock {
cfg_hash: None,
statements: Some(
block
.instructions
.iter()
.map(|instr| format!("opcode: {:#02x}", instr.opcode))
.collect(),
),
function_id: 0,
kind: format!("{:?}", kind),
terminator: format!("{:?}", block.terminator),
byte_start: block.start_offset as u64,
byte_end: block.end_offset as u64,
start_line: 0,
start_col: 0,
end_line: 0,
end_col: 0,
coord_x: 0,
coord_y: 0,
coord_z: 0,
coord_t: None,
});
}
let mut cfg_edges: Vec<CfgEdge> = Vec::new();
for (idx, block) in blocks.iter().enumerate() {
match &block.terminator {
BlockTerminator::Fallthrough => {
if idx + 1 < blocks.len() {
cfg_edges.push(CfgEdge {
source_idx: idx,
target_idx: idx + 1,
edge_type: CfgEdgeType::Fallthrough,
});
}
}
BlockTerminator::Unconditional { target } => {
if let Some(&target_idx) = block_map.get(target) {
cfg_edges.push(CfgEdge {
source_idx: idx,
target_idx: target_idx,
edge_type: CfgEdgeType::Jump,
});
}
}
BlockTerminator::Conditional { target } => {
if let Some(&target_idx) = block_map.get(target) {
cfg_edges.push(CfgEdge {
source_idx: idx,
target_idx: target_idx,
edge_type: CfgEdgeType::ConditionalTrue,
});
}
if idx + 1 < blocks.len() {
cfg_edges.push(CfgEdge {
source_idx: idx,
target_idx: idx + 1,
edge_type: CfgEdgeType::ConditionalFalse,
});
}
}
BlockTerminator::Switch { default, cases } => {
if let Some(&default_idx) = block_map.get(default) {
cfg_edges.push(CfgEdge {
source_idx: idx,
target_idx: default_idx,
edge_type: CfgEdgeType::Jump,
});
}
for case_target in cases {
if let Some(&target_idx) = block_map.get(case_target) {
cfg_edges.push(CfgEdge {
source_idx: idx,
target_idx: target_idx,
edge_type: CfgEdgeType::Jump,
});
}
}
}
BlockTerminator::Return | BlockTerminator::Throw => {
}
BlockTerminator::Unknown => {
}
}
}
Ok(CfgWithEdges {
blocks: cfg_blocks,
edges: cfg_edges,
function_id: 0,
})
}
fn parse_bytecode(bytecode: &[u8]) -> Result<Vec<BytecodeInstruction>> {
let mut instructions = Vec::new();
let mut offset = 0;
while offset < bytecode.len() {
let opcode = bytecode[offset];
let mut operands = Vec::new();
let size = 1;
for i in 1..size {
if offset + i < bytecode.len() {
operands.push(bytecode[offset + i]);
}
}
instructions.push(BytecodeInstruction {
opcode,
operands,
offset,
});
offset += size;
}
Ok(instructions)
}
fn identify_basic_blocks(instructions: &[BytecodeInstruction]) -> Vec<BytecodeBlock> {
if instructions.is_empty() {
return vec![];
}
let mut blocks: Vec<BytecodeBlock> = vec![];
let mut block_starts: Vec<usize> = vec![0];
for instr in instructions {
match instr.opcode {
0x99 | 0x9A | 0x9B | 0x9C | 0x9D | 0x9E | 0x9F | 0xA0 | 0xA1 | 0xA2 | 0xA3 | 0xA4
| 0xA5 | 0xA6 | 0xA7 => {
if instr.operands.len() >= 2 {
let offset_bytes = [instr.operands[0], instr.operands[1]];
let target_offset = i16::from_be_bytes(offset_bytes) as i32 as usize;
block_starts.push(target_offset);
}
}
0xC8 => {
if instr.operands.len() >= 2 {
let offset_bytes = [instr.operands[0], instr.operands[1]];
let target_offset = i16::from_be_bytes(offset_bytes) as i32 as usize;
block_starts.push(target_offset);
}
}
0xAA | 0xAB => {
}
0xAC | 0xAD | 0xAE | 0xAF | 0xB0 | 0xB1 | 0xB2 | 0xB3 | 0xB4 | 0xB5 | 0xB6 | 0xB7
| 0xB8 | 0xB9 | 0xBA | 0xBB | 0xBC | 0xBD | 0xBE | 0xBF => {
}
_ => {
}
}
}
block_starts.sort();
block_starts.dedup();
for (i, &start) in block_starts.iter().enumerate() {
let end = if i + 1 < block_starts.len() {
block_starts[i + 1]
} else {
instructions.len()
};
if start < instructions.len() {
let block_instructions = instructions[start..end].to_vec();
let terminator = classify_terminator(&block_instructions);
blocks.push(BytecodeBlock {
start_offset: start,
end_offset: end,
instructions: block_instructions,
terminator,
});
}
}
blocks
}
fn classify_terminator(instructions: &[BytecodeInstruction]) -> BlockTerminator {
if instructions.is_empty() {
return BlockTerminator::Unknown;
}
let last_instr = &instructions[instructions.len() - 1];
match last_instr.opcode {
0xAC | 0xAD | 0xAE | 0xAF | 0xB0 | 0xB1 => BlockTerminator::Return,
0xA7 | 0xC8 => {
if last_instr.operands.len() >= 2 {
let offset_bytes = [last_instr.operands[0], last_instr.operands[1]];
let target_offset = i16::from_be_bytes(offset_bytes) as i32 as usize;
BlockTerminator::Unconditional {
target: target_offset,
}
} else {
BlockTerminator::Unknown
}
}
0x99 | 0x9A | 0x9B | 0x9C | 0x9D | 0x9E | 0x9F | 0xA0 | 0xA1 | 0xA2 | 0xA3 | 0xA4
| 0xA5 | 0xA6 => {
if last_instr.operands.len() >= 2 {
let offset_bytes = [last_instr.operands[0], last_instr.operands[1]];
let target_offset = i16::from_be_bytes(offset_bytes) as i32 as usize;
BlockTerminator::Conditional {
target: target_offset,
}
} else {
BlockTerminator::Unknown
}
}
0xAA | 0xAB => BlockTerminator::Switch {
default: 0,
cases: vec![],
},
0xBF => BlockTerminator::Throw,
_ => BlockTerminator::Fallthrough,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an instruction without operands located at `offset`.
    fn bare(opcode: u8, offset: usize) -> BytecodeInstruction {
        BytecodeInstruction {
            opcode,
            operands: Vec::new(),
            offset,
        }
    }

    /// Empty bytecode decodes successfully into zero instructions.
    #[test]
    fn test_parse_bytecode_empty() {
        let parsed = parse_bytecode(&[]);
        assert!(parsed.is_ok());
        assert!(parsed.unwrap().is_empty());
    }

    /// Four single-byte opcodes decode into four instructions, in order.
    #[test]
    fn test_parse_bytecode_simple() {
        let parsed = parse_bytecode(&[0x03, 0x3C, 0x1C, 0xAC]);
        assert!(parsed.is_ok());
        let instructions = parsed.unwrap();
        assert_eq!(instructions.len(), 4);
        assert_eq!(instructions.first().map(|instr| instr.opcode), Some(0x03));
        assert_eq!(instructions.last().map(|instr| instr.opcode), Some(0xAC));
    }

    /// No instructions means no blocks.
    #[test]
    fn test_identify_basic_blocks_empty() {
        let blocks = identify_basic_blocks(&[]);
        assert!(blocks.is_empty());
    }

    /// Straight-line code ending in a return forms one `Return` block.
    #[test]
    fn test_identify_basic_blocks_simple() {
        let instructions = [bare(0x03, 0), bare(0x3C, 1), bare(0xAC, 2)];
        let blocks = identify_basic_blocks(&instructions);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].terminator, BlockTerminator::Return);
    }

    /// Four bytes that are not the class-file magic number are rejected.
    #[test]
    fn test_extract_cfg_from_class_invalid_magic() {
        let outcome = extract_cfg_from_class(&[0x00, 0x00, 0x00, 0x00]);
        assert!(outcome.is_err());
    }

    /// Input shorter than the magic number is rejected.
    #[test]
    fn test_extract_cfg_from_class_too_short() {
        let outcome = extract_cfg_from_class(&[0xCA, 0xFE]);
        assert!(outcome.is_err());
    }
}