use super::{Config, Loader, Mapper};
use std::convert::TryInto;
use anyhow::{anyhow, Result};
use goblin::elf::{header::*, note::NoteIterator, program_header::*, Elf};
use mmarinus::{perms, Kind, Map};
use primordial::Page;
use sallyport::elf;
use std::ops::Range;
/// One loadable (`PT_LOAD`) segment, prepared for mapping by `Binary::segments`.
#[derive(Clone, Debug)]
struct Segment<'a> {
/// The segment's file-backed contents: the program header's file range sliced out
/// of the ELF bytes.
bytes: &'a [u8],
/// Virtual address range to map, already relocated, with the start rounded down
/// and the end rounded up to multiples of `Page::SIZE`.
range: Range<usize>,
/// Offset of the segment's first byte within its first page (`p_vaddr % Page::SIZE`);
/// `bytes` is copied into the mapping starting at this offset.
skipb: usize,
/// The program header's raw `p_flags`, translated later via `Config::flags`.
flags: u32,
}
/// A parsed and validated ELF image.
///
/// Field `.0` holds the raw file bytes and field `.1` the `goblin` parse of those same
/// bytes. [`Binary::new`] only accepts 64-bit, little-endian, `EM_X86_64` binaries that
/// do not request an interpreter.
pub struct Binary<'a>(&'a [u8], Elf<'a>);

impl<'a> Binary<'a> {
    /// Parses `bytes` as ELF and rejects anything this loader cannot handle.
    ///
    /// # Errors
    ///
    /// Returns an error if `goblin` cannot parse `bytes`; if the header is not
    /// `ELFCLASS64` / `ELFDATA2LSB` / `EV_CURRENT` / `EM_X86_64`; if a `PT_INTERP`
    /// header is present; if any `PT_LOAD` header has a virtual range whose end
    /// overflows, a file size larger than its memory size, or a file range that runs
    /// past the end of `bytes`; or if `e_entry` does not fall inside a `PT_LOAD`
    /// segment.
    fn new(bytes: &'a [u8]) -> Result<Self> {
        let elf = Elf::parse(bytes)?;

        if elf.header.e_ident[EI_CLASS] != ELFCLASS64 {
            return Err(anyhow!("unsupported ELF header: e_ident[EI_CLASS]"));
        }
        if elf.header.e_ident[EI_DATA] != ELFDATA2LSB {
            return Err(anyhow!("unsupported ELF header: e_ident[EI_DATA]",));
        }
        if elf.header.e_ident[EI_VERSION] != EV_CURRENT {
            return Err(anyhow!("unsupported ELF header: e_ident[EI_VERSION]",));
        }
        if elf.header.e_machine != EM_X86_64 {
            return Err(anyhow!("unsupported ELF header: e_machine"));
        }
        if elf.header.e_version != EV_CURRENT as u32 {
            return Err(anyhow!("unsupported ELF header: e_version"));
        }
        if elf.program_headers.iter().any(|ph| ph.p_type == PT_INTERP) {
            return Err(anyhow!("unsupported ELF header: p_type == PT_INTERP",));
        }

        // Validate every PT_LOAD header up front so that later code (`vm_range()`,
        // `segments()`, and the loader's copy into each mapping) can neither overflow
        // nor index out of bounds on a malformed or hostile file.
        // `usize` is at most 64 bits on Rust targets, so this widening cast is lossless.
        let file_len = bytes.len() as u64;
        for ph in elf.program_headers.iter().filter(|ph| ph.p_type == PT_LOAD) {
            // The end of the virtual range (`p_vaddr + p_memsz`) is computed below
            // and by `vm_range()`.
            if ph.p_vaddr.checked_add(ph.p_memsz).is_none() {
                return Err(anyhow!("unsupported ELF header: p_vaddr + p_memsz overflows"));
            }
            // The file-backed bytes are copied into memory sized from `p_memsz`;
            // the ELF specification forbids a file size larger than the memory size.
            if ph.p_filesz > ph.p_memsz {
                return Err(anyhow!("unsupported ELF header: p_filesz > p_memsz"));
            }
            // `segments()` slices the file with this range.
            if ph
                .p_offset
                .checked_add(ph.p_filesz)
                .filter(|&end| end <= file_len)
                .is_none()
            {
                return Err(anyhow!("unsupported ELF header: PT_LOAD exceeds file"));
            }
        }

        // The entry point must lie inside some loaded segment. The addition cannot
        // overflow: it was checked for every PT_LOAD header above.
        let entry = elf.header.e_entry;
        let entry_is_loaded = elf
            .program_headers
            .iter()
            .filter(|ph| ph.p_type == PT_LOAD)
            .any(|ph| (ph.p_vaddr..ph.p_vaddr + ph.p_memsz).contains(&entry));
        if !entry_is_loaded {
            return Err(anyhow!("unsupported ELF header: e_entry"));
        }

        Ok(Self(bytes, elf))
    }

    /// Returns this binary's `PT_LOAD` segments, shifted up by `relocate` bytes and
    /// widened to whole pages.
    ///
    /// Each segment's address range starts at its relocated start rounded down to
    /// `Page::SIZE` and ends at its relocated end rounded up to `Page::SIZE`.
    ///
    /// # Panics
    ///
    /// Panics if `relocate` is not a multiple of `Page::SIZE`.
    fn segments(&self, relocate: usize) -> impl Iterator<Item = Segment<'_>> {
        assert_eq!(relocate % Page::SIZE, 0);

        self.headers(PT_LOAD).map(move |phdr| {
            let vm = phdr.vm_range();
            let start = (vm.start + relocate) / Page::SIZE * Page::SIZE;
            let end = (vm.end + relocate + Page::SIZE - 1) / Page::SIZE * Page::SIZE;

            Segment {
                bytes: &self.0[phdr.file_range()],
                range: start..end,
                // `relocate` is page-aligned, so the in-page offset is unaffected by it.
                skipb: phdr.p_vaddr as usize % Page::SIZE,
                flags: phdr.p_flags,
            }
        })
    }

    /// Returns the span from the lowest `PT_LOAD` start address to the highest
    /// `PT_LOAD` end address (unrelocated, not page-rounded), or `0..0` when there
    /// are no `PT_LOAD` headers.
    fn range(&self) -> Range<usize> {
        let lo = self
            .headers(PT_LOAD)
            .map(|phdr| phdr.vm_range().start)
            .min();
        let hi = self.headers(PT_LOAD).map(|phdr| phdr.vm_range().end).max();
        lo.unwrap_or_default()..hi.unwrap_or_default()
    }

    /// Iterates over the program headers whose `p_type` equals `kind`.
    pub fn headers(&self, kind: u32) -> impl Iterator<Item = &ProgramHeader> {
        self.1
            .program_headers
            .iter()
            .filter(move |ph| ph.p_type == kind)
    }

    /// Iterates over the descriptor bytes of every note named `name` with type `kind`.
    ///
    /// If `goblin` provides no note iterator for this binary, nothing is yielded;
    /// notes that fail to parse are skipped rather than reported.
    pub fn notes(&self, name: &'a str, kind: u32) -> impl Iterator<Item = &[u8]> {
        let empty = NoteIterator {
            iters: vec![],
            index: 0,
        };

        self.1
            .iter_note_headers(self.0)
            .unwrap_or(empty)
            .filter_map(Result::ok)
            .filter(move |n| n.n_type == kind)
            .filter(move |n| n.name == name)
            .map(|n| n.desc)
    }

    /// Returns the first note named `name` of type `kind` whose descriptor is exactly
    /// `size_of::<T>()` bytes long, reinterpreted as a `T`; `None` if there is none.
    ///
    /// # Safety
    ///
    /// The descriptor bytes are read as a `T` without any validation, so the caller
    /// must guarantee that every bit pattern of that size is a valid `T`.
    // Nothing in this module calls `note`; NOTE(review): confirm it is still used
    // elsewhere before keeping this allowance.
    #[allow(dead_code)]
    pub unsafe fn note<T: Copy>(&self, name: &str, kind: u32) -> Option<T> {
        use core::mem::size_of;

        for note in self.notes(name, kind) {
            if note.len() == size_of::<T>() {
                // SAFETY: the length check guarantees `size_of::<T>()` readable bytes,
                // `read_unaligned` tolerates any alignment, and the caller guarantees
                // the bytes form a valid `T`.
                return Some(note.as_ptr().cast::<T>().read_unaligned());
            }
        }

        None
    }
}
impl<T: Mapper> Loader for T {
fn load(shim: impl AsRef<[u8]>, exec: impl AsRef<[u8]>) -> Result<Self::Output> {
let sbin = Binary::new(shim.as_ref())?;
let ebin = Binary::new(exec.as_ref())?;
let slot = sbin
.headers(sallyport::elf::pt::EXEC)
.next()
.ok_or_else(|| anyhow!("Shim is missing the executable slot!"))?
.vm_range();
let range = ebin.range();
if range.start != 0 || range.end > slot.end - slot.start {
return Err(anyhow!("The executable doesn't fit in the slot!"));
}
let version = semver::Version::parse(sallyport::VERSION).unwrap();
let supported = sbin
.notes(elf::note::NAME, elf::note::REQUIRES)
.filter_map(|n| std::str::from_utf8(n).ok())
.filter_map(|n| semver::VersionReq::parse(n).ok())
.any(|req| req.matches(&version));
if !supported {
return Err(anyhow!("Unable to satisfy sallyport version requirement!"));
}
let mut loader: Self = Self::Config::new(&sbin, &ebin)?.try_into()?;
let ssegs: Vec<Segment<'_>> = sbin.segments(0).collect();
let esegs: Vec<Segment<'_>> = ebin.segments(slot.start).collect();
let mut sorted: Vec<_> = ssegs.iter().chain(esegs.iter()).collect();
sorted.sort_unstable_by_key(|seg| seg.range.start);
for pair in sorted.windows(2) {
if pair[0].range.end > pair[1].range.start {
return Err(anyhow!("Segments overlap!"));
}
}
for seg in ssegs.iter().chain(esegs.iter()) {
let mut map = Map::map(seg.range.end - seg.range.start)
.anywhere()
.anonymously()
.known::<perms::ReadWrite>(Kind::Private)?;
map[seg.skipb..][..seg.bytes.len()].copy_from_slice(seg.bytes);
let flags = Self::Config::flags(seg.flags);
loader.map(map, seg.range.start, flags)?;
}
loader.try_into()
}
}