use solana_rbpf::ebpf;
use solana_rbpf::elf::Executable;
use solana_rbpf::program::{BuiltinProgram, SBPFVersion, FunctionRegistry};
use solana_rbpf::vm::Config;
use std::sync::Arc;
use elf::ElfBytes;
use elf::endian::{AnyEndian, EndianParse};
use elf::file::Class;
use std::fs;
/// One entry of the optimizer's JSON report: a transformation a pass applied or a
/// program-wide finding such as an oversized program.
#[derive(serde::Serialize, Debug)]
pub struct Issue {
    /// Short category tag, e.g. `"LogRemoved"` or `"SizeExceeded"`.
    kind: String,
    /// Position the entry refers to; most passes store an instruction index here and
    /// whole-program findings use 0.
    offset: usize,
    /// Human-readable explanation of the entry.
    desc: String,
}
pub struct Optimizer {
insns: Vec<solana_rbpf::ebpf::Insn>,
issues: Vec<Issue>,
elf_bytes: Vec<u8>,
text_section_idx: usize,
}
impl Optimizer {
pub fn new(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
let elf_bytes = fs::read(path)?;
let elf = ElfBytes::<AnyEndian>::minimal_parse(&elf_bytes)?;
let (shdrs_opt, strtab_opt) = elf.section_headers_with_strtab()?;
let shdrs = shdrs_opt.ok_or("No section headers")?;
let strtab = strtab_opt.ok_or("No string table")?;
let text_section_idx = shdrs
.iter()
.position(|sh| strtab.get(sh.sh_name as usize).ok() == Some(".text"))
.ok_or("No .text section")?;
let text_section = shdrs.get(text_section_idx).map_err(|_| "Invalid text section index")?;
let text_bytes = elf.section_data(&text_section)?.0;
let insns = Self::disassemble_text_bytes(text_bytes)?;
Ok(Self { insns, issues: Vec::new(), elf_bytes, text_section_idx })
}
fn disassemble_text_bytes(bytes: &[u8]) -> Result<Vec<solana_rbpf::ebpf::Insn>, Box<dyn std::error::Error>> {
let mut insns = Vec::new();
let mut offset = 0;
while offset + 8 <= bytes.len() {
let chunk = &bytes[offset..offset + 8];
insns.push(ebpf::Insn {
ptr: 0,
opc: chunk[0],
dst: chunk[1] & 0x0F,
src: (chunk[1] >> 4) & 0x0F,
off: i16::from_le_bytes([chunk[2], chunk[3]]),
imm: i64::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7], 0, 0, 0, 0]),
});
offset += 8;
}
Ok(insns)
}
pub fn remove_logs(&mut self) {
let original_len = self.insns.len();
self.insns.retain(|insn| {
if insn.opc == 0x91 { self.issues.push(Issue {
kind: "LogRemoved".to_string(),
offset: insn.off as usize,
desc: "Removed redundant sol_log call".to_string(),
});
false
} else {
true
}
});
println!("Removed {} log instructions", original_len - self.insns.len());
}
#[allow(unused_mut)]
pub fn merge_loads(&mut self) {
let mut i = 0;
while i < self.insns.len() - 1 {
if self.insns[i].opc == ebpf::LD_DW_IMM {
let mut j = i + 1;
while j < self.insns.len() {
if self.insns[j].opc == ebpf::LD_DW_IMM &&
self.insns[i].dst == self.insns[j].dst &&
self.insns[i].imm == self.insns[j].imm {
self.insns.remove(j);
self.issues.push(Issue {
kind: "LoadMerged".to_string(),
offset: i,
desc: "Merged redundant load".to_string(),
});
} else {
break;
}
}
}
i += 1;
}
println!("Merged redundant load instructions");
}
pub fn merge_arithmetic(&mut self) {
let mut i = 0;
while i < self.insns.len() - 1 {
let curr = &self.insns[i];
if curr.opc == ebpf::ADD64_IMM || curr.opc == ebpf::SUB64_IMM {
let next = &self.insns[i + 1];
if next.opc == ebpf::ADD64_IMM && curr.dst == next.dst && curr.src == 0 && next.src == 0 {
self.insns[i].imm += next.imm;
self.insns.remove(i + 1);
self.issues.push(Issue {
kind: "AddMerged".to_string(),
offset: i,
desc: "Merged consecutive additions".to_string(),
});
continue;
} else if next.opc == ebpf::SUB64_IMM && curr.dst == next.dst && curr.src == 0 && next.src == 0 {
if curr.opc == ebpf::ADD64_IMM {
self.insns[i].imm -= next.imm;
} else {
self.insns[i].imm += next.imm;
self.insns[i].opc = ebpf::ADD64_IMM;
}
self.insns.remove(i + 1);
self.issues.push(Issue {
kind: "ArithmeticMerged".to_string(),
offset: i,
desc: "Merged addition and subtraction".to_string(),
});
continue;
} else if (curr.opc == ebpf::ADD64_IMM && next.opc == ebpf::SUB64_IMM) ||
(curr.opc == ebpf::SUB64_IMM && next.opc == ebpf::ADD64_IMM) {
if curr.dst == next.dst && curr.imm == next.imm && curr.src == 0 && next.src == 0 {
self.insns.remove(i + 1);
self.insns.remove(i);
self.issues.push(Issue {
kind: "ArithmeticEliminated".to_string(),
offset: i,
desc: "Eliminated canceling addition and subtraction".to_string(),
});
continue;
}
}
}
i += 1;
}
println!("Merged arithmetic instructions");
}
pub fn fold_constants(&mut self) {
let mut i = 0;
while i < self.insns.len() - 1 {
let curr = &self.insns[i];
if curr.opc == ebpf::LD_DW_IMM {
let next = &self.insns[i + 1];
if next.opc == ebpf::ADD64_IMM && next.dst == curr.dst && next.src == 0 {
self.insns[i].imm += next.imm;
self.insns.remove(i + 1);
self.issues.push(Issue {
kind: "ConstantFolded".to_string(),
offset: i,
desc: "Folded load and addition constants".to_string(),
});
continue;
} else if next.opc == ebpf::SUB64_IMM && next.dst == curr.dst && next.src == 0 {
self.insns[i].imm -= next.imm;
self.insns.remove(i + 1);
self.issues.push(Issue {
kind: "ConstantFolded".to_string(),
offset: i,
desc: "Folded load and subtraction constants".to_string(),
});
continue;
}
}
i += 1;
}
println!("Folded constant computations");
}
pub fn eliminate_dead_code(&mut self) {
let mut i = 0;
while i < self.insns.len() - 1 {
let curr = &self.insns[i];
let next = &self.insns[i + 1];
if curr.dst == next.dst && next.opc != ebpf::EXIT && !Self::reads_src(curr, next) {
self.insns.remove(i);
self.issues.push(Issue {
kind: "DeadCodeEliminated".to_string(),
offset: i,
desc: "Removed overwritten dead code".to_string(),
});
continue;
}
i += 1;
}
println!("Eliminated dead code");
}
fn reads_src(curr: &solana_rbpf::ebpf::Insn, next: &solana_rbpf::ebpf::Insn) -> bool {
next.src == curr.dst || (next.opc == ebpf::JA && curr.dst == 0)
}
pub fn reduce_strength(&mut self) {
for (i, insn) in self.insns.iter_mut().enumerate() {
if insn.opc == ebpf::MUL64_IMM && insn.src == 0 {
match insn.imm {
2 => {
insn.opc = ebpf::LSH64_IMM;
insn.imm = 1;
self.issues.push(Issue {
kind: "StrengthReduced".to_string(),
offset: i,
desc: "Replaced multiplication by 2 with left shift by 1".to_string(),
});
}
4 => {
insn.opc = ebpf::LSH64_IMM;
insn.imm = 2;
self.issues.push(Issue {
kind: "StrengthReduced".to_string(),
offset: i,
desc: "Replaced multiplication by 4 with left shift by 2".to_string(),
});
}
_ => {}
}
}
}
println!("Reduced instruction strength");
}
pub fn optimize_branches(&mut self) {
let mut i = 0;
while i < self.insns.len() - 1 {
let curr = &self.insns[i];
if curr.opc == ebpf::JEQ_IMM && curr.off >= 0 {
let next = &self.insns[i + 1];
if next.opc == ebpf::JA && (i as i16 + curr.off + 1) == (i as i16 + next.off + 1) {
self.insns.remove(i + 1);
self.insns.remove(i);
self.issues.push(Issue {
kind: "BranchEliminated".to_string(),
offset: i,
desc: "Eliminated redundant branch".to_string(),
});
continue;
}
}
i += 1;
}
println!("Optimized branch instructions");
}
pub fn check_size(&mut self) {
let size = self.insns.len() * 8;
if size > 128 * 1024 {
self.issues.push(Issue {
kind: "SizeExceeded".to_string(),
offset: 0,
desc: format!("Program size {} bytes exceeds 128KB", size),
});
}
}
pub fn generate(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
use solana_rbpf::vm::TestContextObject;
let loader = Arc::new(BuiltinProgram::<TestContextObject>::new_loader(
Config::default(),
FunctionRegistry::default(),
));
let executable = Executable::from_text_bytes(
&self.insns.iter().flat_map(|insn| {
let imm_bytes = insn.imm.to_le_bytes();
[
insn.opc,
(insn.dst & 0x0F) | ((insn.src & 0x0F) << 4),
insn.off.to_le_bytes()[0],
insn.off.to_le_bytes()[1],
imm_bytes[0],
imm_bytes[1],
imm_bytes[2],
imm_bytes[3],
]
}).collect::<Vec<u8>>(),
loader,
SBPFVersion::V2,
FunctionRegistry::default(),
)?;
let optimized_text = executable.get_text_bytes().1.to_vec();
let elf = ElfBytes::<AnyEndian>::minimal_parse(&self.elf_bytes)?;
let ehdr = elf.ehdr;
let (shdrs_opt, _) = elf.section_headers_with_strtab()?;
let shdrs = shdrs_opt.ok_or("No section headers")?;
let mut new_shdrs: Vec<_> = shdrs.iter().collect();
let mut elf_bytes = Vec::new();
let header_size = ehdr.e_ehsize as usize;
elf_bytes.resize(header_size, 0);
let mut offset = header_size as u64;
let mut section_data = Vec::new();
for (i, sh) in new_shdrs.iter_mut().enumerate() {
let data = if i == self.text_section_idx {
optimized_text.clone()
} else {
elf.section_data(sh)?.0.to_vec()
};
section_data.push((offset, data.clone()));
sh.sh_offset = offset;
sh.sh_size = data.len() as u64;
offset += data.len() as u64;
offset = (offset + 7) & !7; }
let shoff = offset;
println!("e_shoff: 0x{:x}, offset after sections: 0x{:x}", shoff, offset);
for (_, data) in section_data.iter() {
elf_bytes.extend_from_slice(data);
let padding = (8 - (data.len() % 8)) % 8;
elf_bytes.extend_from_slice(&vec![0; padding]);
}
let section_table_start = elf_bytes.len();
for sh in new_shdrs.iter() {
let sh_bytes = if ehdr.class == Class::ELF64 {
let mut bytes = Vec::new();
bytes.extend_from_slice(&sh.sh_name.to_le_bytes()); bytes.extend_from_slice(&sh.sh_type.to_le_bytes()); bytes.extend_from_slice(&sh.sh_flags.to_le_bytes()); bytes.extend_from_slice(&sh.sh_addr.to_le_bytes()); bytes.extend_from_slice(&sh.sh_offset.to_le_bytes()); bytes.extend_from_slice(&sh.sh_size.to_le_bytes()); bytes.extend_from_slice(&sh.sh_link.to_le_bytes()); bytes.extend_from_slice(&sh.sh_info.to_le_bytes()); bytes.extend_from_slice(&sh.sh_addralign.to_le_bytes()); bytes.extend_from_slice(&sh.sh_entsize.to_le_bytes()); bytes
} else {
let mut bytes = Vec::new();
bytes.extend_from_slice(&sh.sh_name.to_le_bytes()); bytes.extend_from_slice(&sh.sh_type.to_le_bytes()); bytes.extend_from_slice(&(sh.sh_flags as u32).to_le_bytes()); bytes.extend_from_slice(&(sh.sh_addr as u32).to_le_bytes()); bytes.extend_from_slice(&(sh.sh_offset as u32).to_le_bytes()); bytes.extend_from_slice(&(sh.sh_size as u32).to_le_bytes()); bytes.extend_from_slice(&sh.sh_link.to_le_bytes()); bytes.extend_from_slice(&sh.sh_info.to_le_bytes()); bytes.extend_from_slice(&(sh.sh_addralign as u32).to_le_bytes()); bytes.extend_from_slice(&(sh.sh_entsize as u32).to_le_bytes()); bytes
};
elf_bytes.extend_from_slice(&sh_bytes);
}
let class_val = match ehdr.class {
Class::ELF32 => elf::abi::ELFCLASS32,
Class::ELF64 => elf::abi::ELFCLASS64,
};
let endian_val = if ehdr.endianness.is_little() {
elf::abi::ELFDATA2LSB
} else {
elf::abi::ELFDATA2MSB
};
let mut ehdr_bytes = Vec::new();
ehdr_bytes.extend_from_slice(&[0x7f, b'E', b'L', b'F']); ehdr_bytes.extend_from_slice(&[class_val, endian_val, ehdr.version.try_into()?, ehdr.osabi]); ehdr_bytes.extend_from_slice(&[ehdr.abiversion, 0, 0, 0, 0, 0, 0, 0]); ehdr_bytes.extend_from_slice(&ehdr.e_type.to_le_bytes()); ehdr_bytes.extend_from_slice(&ehdr.e_machine.to_le_bytes()); ehdr_bytes.extend_from_slice(&ehdr.version.to_le_bytes()); if ehdr.class == Class::ELF64 {
ehdr_bytes.extend_from_slice(&ehdr.e_entry.to_le_bytes()); ehdr_bytes.extend_from_slice(&ehdr.e_phoff.to_le_bytes()); ehdr_bytes.extend_from_slice(&shoff.to_le_bytes()); } else {
ehdr_bytes.extend_from_slice(&(ehdr.e_entry as u32).to_le_bytes()); ehdr_bytes.extend_from_slice(&(ehdr.e_phoff as u32).to_le_bytes()); ehdr_bytes.extend_from_slice(&(shoff as u32).to_le_bytes()); }
ehdr_bytes.extend_from_slice(&ehdr.e_flags.to_le_bytes()); ehdr_bytes.extend_from_slice(&ehdr.e_ehsize.to_le_bytes()); ehdr_bytes.extend_from_slice(&ehdr.e_phentsize.to_le_bytes()); ehdr_bytes.extend_from_slice(&ehdr.e_phnum.to_le_bytes()); let shentsize: u16 = match ehdr.class {
Class::ELF32 => 40,
Class::ELF64 => 64,
};
ehdr_bytes.extend_from_slice(&shentsize.to_le_bytes()); ehdr_bytes.extend_from_slice(&ehdr.e_shnum.to_le_bytes()); ehdr_bytes.extend_from_slice(&ehdr.e_shstrndx.to_le_bytes());
if ehdr_bytes.len() < header_size {
ehdr_bytes.resize(header_size, 0);
}
elf_bytes[..header_size].copy_from_slice(&ehdr_bytes);
println!("Section table start: 0x{:x}, Final file length: 0x{:x} ({} bytes)", section_table_start, elf_bytes.len(), elf_bytes.len());
Ok(elf_bytes)
}
pub fn report(&self) -> String {
serde_json::to_string_pretty(&self.issues).unwrap_or("[]".to_string())
}
}
pub fn optimize_sbf(input_path: &str, output_path: &str) -> Result<String, Box<dyn std::error::Error>> {
let mut optimizer = Optimizer::new(input_path)?;
optimizer.remove_logs();
optimizer.merge_loads();
optimizer.merge_arithmetic();
optimizer.fold_constants();
optimizer.reduce_strength();
optimizer.optimize_branches();
optimizer.check_size();
let optimized_bytes = optimizer.generate()?;
fs::write(output_path, optimized_bytes)?;
Ok(optimizer.report())
}
/// CLI entry point: `<prog> --input <input_file> --output <output_file>`.
///
/// Prints usage to stderr and exits with status 1 on malformed arguments. Otherwise it
/// optimizes the input and prints the report.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args: Vec<String> = std::env::args().collect();
    let (input_path, output_path) = match args.as_slice() {
        [_, flag_in, input, flag_out, output] if flag_in == "--input" && flag_out == "--output" => {
            (input, output)
        }
        _ => {
            // argv may be empty (e.g. exec with no argv[0]); don't index it blindly.
            let program = args.first().map_or("sbf-optimizer", String::as_str);
            eprintln!("Usage: {} --input <input_file> --output <output_file>", program);
            std::process::exit(1);
        }
    };
    let report = optimize_sbf(input_path, output_path)?;
    println!("Optimization report:\n{}", report);
    Ok(())
}