use super::consts::*;
use std::io::{self, Read, Seek, SeekFrom};
use zerocopy::{FromBytes, LE, U16, U32, U64};
use zerocopy_derive::FromBytes as DeriveFromBytes;
/// On-disk image of one directory entry, decoded with a single
/// `read_from_bytes` call in `OleFile::parse_directory_entry`.
///
/// The field widths add up to 128 bytes and every multi-byte integer uses a
/// little-endian wrapper (`U16<LE>`, `U32<LE>`, `U64<LE>`). The field order is
/// the byte order on disk and must not be changed.
#[derive(Debug, Clone, DeriveFromBytes)]
#[repr(C)]
struct RawDirectoryEntry {
/// UTF-16LE name buffer; only the first `name_len - 2` bytes are decoded.
name: [u8; 64],
/// Byte length of the name; the parser drops its final 2 bytes (presumably the UTF-16 NUL terminator).
name_len: U16<LE>,
/// Object type, compared against the `STGTY_*` constants.
entry_type: u8,
/// Tree node colour flag. NOTE(review): not read anywhere in this file.
node_color: u8,
/// SID of the left sibling, or `NOSTREAM`.
sid_left: U32<LE>,
/// SID of the right sibling, or `NOSTREAM`.
sid_right: U32<LE>,
/// SID of the first child entry (storages and root), or `NOSTREAM`.
sid_child: U32<LE>,
/// Raw 16-byte CLSID, rendered as text by `format_clsid`.
clsid: [u8; 16],
/// State flags. NOTE(review): not read anywhere in this file.
state_bits: U32<LE>,
/// Creation timestamp. NOTE(review): not read anywhere in this file.
creation_time: U64<LE>,
/// Modification timestamp. NOTE(review): not read anywhere in this file.
modified_time: U64<LE>,
/// First sector of the entry's data (a FAT or MiniFAT sector, depending on the size).
start_sector: U32<LE>,
/// Data size in bytes; only the low 32 bits are used for 512-byte-sector files.
stream_size: U64<LE>,
}
/// A parsed OLE / Compound File Binary container backed by a seekable reader.
///
/// `open` loads the FAT, the directory tree and the MiniFAT up front. Stream
/// data is read on demand, and the mini stream is cached after its first use.
#[derive(Debug)]
pub struct OleFile<R: Read + Seek> {
/// Underlying byte source; every sector read seeks it explicitly.
reader: R,
/// Total input length in bytes, measured when the file was opened.
file_size: u64,
/// Regular sector size in bytes (`1 << sector shift` from the header).
sector_size: usize,
/// Mini sector size in bytes (`1 << mini sector shift` from the header).
mini_sector_size: usize,
/// Streams smaller than this many bytes are stored in the mini stream.
mini_stream_cutoff: u32,
/// File allocation table: `fat[n]` is the sector that follows sector `n` in its chain.
fat: Vec<u32>,
/// MiniFAT: `minifat[n]` is the mini sector that follows mini sector `n` in its chain.
minifat: Vec<u32>,
/// First sector of the directory stream.
first_dir_sector: u32,
/// Directory entry 0 (the root storage), if the directory is non-empty.
root: Option<DirectoryEntry>,
/// Directory entries indexed by SID; `None` for entries that were not reached while
/// walking the tree below the root (normally including entry 0, which lives in `root`).
dir_entries: Vec<Option<DirectoryEntry>>,
/// Lazily loaded contents of the mini stream (the root entry's data chain).
ministream: Option<Vec<u8>>,
}
/// Decoded directory entry describing the root, a storage, or a stream.
#[derive(Debug, Clone)]
pub struct DirectoryEntry {
/// Index of this entry in the directory stream.
pub sid: u32,
/// Entry name decoded from UTF-16LE, with trailing NULs removed.
pub name: String,
/// Object type, compared against the `STGTY_*` constants.
pub entry_type: u8,
/// SID of the left sibling, or `NOSTREAM`.
pub sid_left: u32,
/// SID of the right sibling, or `NOSTREAM`.
pub sid_right: u32,
/// SID of the first child entry (storages and root), or `NOSTREAM`.
pub sid_child: u32,
/// CLSID as `XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX`, or empty when the raw CLSID is all zeros.
pub clsid: String,
/// First sector of the data chain (in the FAT or the MiniFAT, see `is_minifat`).
pub start_sector: u32,
/// Data size in bytes (low 32 bits only for 512-byte-sector files).
pub size: u64,
/// `true` for streams smaller than the mini stream cutoff, whose data lives in the mini stream.
pub is_minifat: bool,
/// Child entries. NOTE(review): never populated in this file, so it is always empty.
pub children: Vec<DirectoryEntry>,
}
/// Errors produced while opening or reading an OLE compound file.
#[derive(Debug)]
pub enum OleError {
/// Seeking or reading the underlying source failed.
Io(io::Error),
/// The header or a requested entry failed a format check (byte order, sector size, entry type).
InvalidFormat(String),
/// Invalid data. NOTE(review): not constructed anywhere in this file.
InvalidData(String),
/// The input is too short or does not start with the OLE signature.
NotOleFile,
/// Internal structures are inconsistent (bad sector or directory indices, missing root).
CorruptedFile(String),
/// No entry matches the requested path (or the file has no root entry).
StreamNotFound,
}
impl From<io::Error> for OleError {
fn from(err: io::Error) -> Self {
OleError::Io(err)
}
}
impl std::fmt::Display for OleError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OleError::Io(e) => write!(f, "IO error: {}", e),
OleError::InvalidFormat(s) => write!(f, "Invalid format: {}", s),
OleError::InvalidData(s) => write!(f, "Invalid data: {}", s),
OleError::NotOleFile => write!(f, "Not an OLE file"),
OleError::CorruptedFile(s) => write!(f, "Corrupted file: {}", s),
OleError::StreamNotFound => write!(f, "Stream not found"),
}
}
}
/// Marks `OleError` as a standard error type (usable with `Box<dyn Error>`).
/// `source()` keeps its default implementation and returns `None`; the wrapped
/// `io::Error` is reachable through the `Io` variant and the `Display` text.
impl std::error::Error for OleError {}
impl<R: Read + Seek> OleFile<R> {
pub fn open(mut reader: R) -> Result<Self, OleError> {
let file_size = reader.seek(SeekFrom::End(0))?;
reader.seek(SeekFrom::Start(0))?;
if file_size < MINIMAL_OLEFILE_SIZE as u64 {
return Err(OleError::NotOleFile);
}
let mut header = [0u8; 512];
reader.read_exact(&mut header)?;
if &header[0..8] != MAGIC {
return Err(OleError::NotOleFile);
}
let dll_version = U16::<LE>::read_from_bytes(&header[0x1A..0x1C])
.map(|v| v.get())
.unwrap_or(0);
let byte_order = U16::<LE>::read_from_bytes(&header[0x1C..0x1E])
.map(|v| v.get())
.unwrap_or(0);
let sector_shift = U16::<LE>::read_from_bytes(&header[0x1E..0x20])
.map(|v| v.get())
.unwrap_or(0);
let mini_sector_shift = U16::<LE>::read_from_bytes(&header[0x20..0x22])
.map(|v| v.get())
.unwrap_or(0);
let first_dir_sector = U32::<LE>::read_from_bytes(&header[0x30..0x34])
.map(|v| v.get())
.unwrap_or(0);
let mini_stream_cutoff = U32::<LE>::read_from_bytes(&header[0x38..0x3C])
.map(|v| v.get())
.unwrap_or(0);
let first_minifat_sector = U32::<LE>::read_from_bytes(&header[0x3C..0x40])
.map(|v| v.get())
.unwrap_or(0);
let num_minifat_sectors = U32::<LE>::read_from_bytes(&header[0x40..0x44])
.map(|v| v.get())
.unwrap_or(0);
let first_difat_sector = U32::<LE>::read_from_bytes(&header[0x44..0x48])
.map(|v| v.get())
.unwrap_or(0);
let num_difat_sectors = U32::<LE>::read_from_bytes(&header[0x48..0x4C])
.map(|v| v.get())
.unwrap_or(0);
if byte_order != 0xFFFE {
return Err(OleError::InvalidFormat("Invalid byte order".to_string()));
}
let sector_size = 1usize << sector_shift;
let mini_sector_size = 1usize << mini_sector_shift;
if (dll_version == 3 && sector_size != 512) || (dll_version == 4 && sector_size != 4096) {
return Err(OleError::InvalidFormat("Sector size mismatch".to_string()));
}
let mut ole = OleFile {
reader,
file_size,
sector_size,
mini_sector_size,
mini_stream_cutoff,
fat: Vec::new(),
minifat: Vec::new(),
first_dir_sector,
root: None,
dir_entries: Vec::new(),
ministream: None,
};
ole.load_fat(&header, first_difat_sector, num_difat_sectors)?;
ole.load_directory()?;
if num_minifat_sectors > 0 {
ole.load_minifat(first_minifat_sector)?;
}
Ok(ole)
}
pub fn file_size(&self) -> u64 {
self.file_size
}
fn load_fat(
&mut self,
header: &[u8; 512],
first_difat_sector: u32,
num_difat_sectors: u32,
) -> Result<(), OleError> {
let mut fat_sectors = Vec::new();
for i in 0..109 {
let offset = 0x4C + i * 4;
if offset + 4 > 512 {
break;
}
let sector = U32::<LE>::read_from_bytes(&header[offset..offset + 4])
.map(|v| v.get())
.unwrap_or(0);
if sector == FREESECT || sector == ENDOFCHAIN {
break;
}
fat_sectors.push(sector);
}
if num_difat_sectors > 0 {
let mut difat_sector = first_difat_sector;
let entries_per_sector = (self.sector_size / 4) - 1;
for _ in 0..num_difat_sectors {
let sector_data = self.read_sector(difat_sector)?;
for i in 0..entries_per_sector {
let offset = i * 4;
let sector = U32::<LE>::read_from_bytes(§or_data[offset..offset + 4])
.map(|v| v.get())
.unwrap_or(0);
if sector == FREESECT || sector == ENDOFCHAIN {
break;
}
fat_sectors.push(sector);
}
let next_offset = entries_per_sector * 4;
difat_sector = U32::<LE>::read_from_bytes(§or_data[next_offset..next_offset + 4])
.map(|v| v.get())
.unwrap_or(0);
if difat_sector == ENDOFCHAIN || difat_sector == FREESECT {
break;
}
}
}
let entries_per_sector = self.sector_size / 4;
self.fat.reserve(fat_sectors.len() * entries_per_sector);
for §or_id in &fat_sectors {
let sector_data = self.read_sector(sector_id)?;
for i in 0..entries_per_sector {
let offset = i * 4;
let entry = U32::<LE>::read_from_bytes(§or_data[offset..offset + 4])
.map(|v| v.get())
.unwrap_or(0);
self.fat.push(entry);
}
}
Ok(())
}
fn load_minifat(&mut self, first_minifat_sector: u32) -> Result<(), OleError> {
let minifat_data = self.read_stream_from_fat(first_minifat_sector)?;
let entries_count = minifat_data.len() / 4;
self.minifat.reserve(entries_count);
for i in 0..entries_count {
let offset = i * 4;
let entry = U32::<LE>::read_from_bytes(&minifat_data[offset..offset + 4])
.map_err(|_| OleError::InvalidFormat("Failed to read u32".to_string()))?;
self.minifat.push(entry.get());
}
Ok(())
}
fn load_directory(&mut self) -> Result<(), OleError> {
let dir_data = self.read_stream_from_fat(self.first_dir_sector)?;
let num_entries = dir_data.len() / DIRENTRY_SIZE;
self.dir_entries = vec![None; num_entries];
if num_entries > 0 {
let root = self.parse_directory_entry(&dir_data[0..DIRENTRY_SIZE], 0)?;
let root_child_sid = root.sid_child;
self.root = Some(root);
self.build_storage_tree(root_child_sid, &dir_data)?;
}
Ok(())
}
fn parse_directory_entry(&self, data: &[u8], sid: u32) -> Result<DirectoryEntry, OleError> {
let raw = RawDirectoryEntry::read_from_bytes(data)
.map_err(|_| OleError::InvalidFormat("Failed to parse directory entry".to_string()))?;
let name_len = raw.name_len.get() as usize;
let name_bytes = &raw.name[0..name_len.saturating_sub(2).min(64)];
let name = decode_utf16le(name_bytes);
let clsid = format_clsid(&raw.clsid);
let size = if self.sector_size == 512 {
raw.stream_size.get() & 0xFFFFFFFF
} else {
raw.stream_size.get()
};
let is_minifat = size < self.mini_stream_cutoff as u64 && raw.entry_type == STGTY_STREAM;
Ok(DirectoryEntry {
sid,
name,
entry_type: raw.entry_type,
sid_left: raw.sid_left.get(),
sid_right: raw.sid_right.get(),
sid_child: raw.sid_child.get(),
clsid,
start_sector: raw.start_sector.get(),
size,
is_minifat,
children: Vec::new(),
})
}
fn build_storage_tree(&mut self, child_sid: u32, dir_data: &[u8]) -> Result<(), OleError> {
if child_sid == NOSTREAM {
return Ok(());
}
let sid = child_sid as usize;
if sid >= dir_data.len() / DIRENTRY_SIZE {
return Err(OleError::CorruptedFile(
"Invalid directory entry index".to_string(),
));
}
if self.dir_entries[sid].is_none() {
let offset = sid * DIRENTRY_SIZE;
let entry =
self.parse_directory_entry(&dir_data[offset..offset + DIRENTRY_SIZE], sid as u32)?;
self.dir_entries[sid] = Some(entry);
}
let entry = self.dir_entries[sid].as_ref().unwrap();
let left_sid = entry.sid_left;
let right_sid = entry.sid_right;
let child_sid = entry.sid_child;
self.build_storage_tree(left_sid, dir_data)?;
self.build_storage_tree(right_sid, dir_data)?;
self.build_storage_tree(child_sid, dir_data)?;
Ok(())
}
fn read_sector(&mut self, sector_id: u32) -> Result<Vec<u8>, OleError> {
let position = ((sector_id as u64) + 1) * (self.sector_size as u64);
self.reader.seek(SeekFrom::Start(position))?;
let mut buffer = vec![0u8; self.sector_size];
self.reader.read_exact(&mut buffer)?;
Ok(buffer)
}
fn read_stream_from_fat(&mut self, start_sector: u32) -> Result<Vec<u8>, OleError> {
let mut data = Vec::new();
let mut sector = start_sector;
loop {
if sector == ENDOFCHAIN {
break;
}
if sector >= self.fat.len() as u32 {
return Err(OleError::CorruptedFile(
"Invalid sector index in FAT".to_string(),
));
}
let sector_data = self.read_sector(sector)?;
data.extend_from_slice(§or_data);
sector = self.fat[sector as usize];
}
Ok(data)
}
fn read_stream_from_minifat(
&mut self,
start_sector: u32,
size: u64,
) -> Result<Vec<u8>, OleError> {
if self.ministream.is_none() {
if let Some(ref root) = self.root {
let ministream_data = self.read_stream_from_fat(root.start_sector)?;
self.ministream = Some(ministream_data);
} else {
return Err(OleError::CorruptedFile("No root entry".to_string()));
}
}
let ministream = self.ministream.as_ref().unwrap();
let mut data = Vec::new();
let mut sector = start_sector;
loop {
if sector == ENDOFCHAIN {
break;
}
if sector >= self.minifat.len() as u32 {
return Err(OleError::CorruptedFile(
"Invalid sector index in MiniFAT".to_string(),
));
}
let position = (sector as usize) * self.mini_sector_size;
if position + self.mini_sector_size > ministream.len() {
return Err(OleError::CorruptedFile(
"Mini sector out of bounds".to_string(),
));
}
data.extend_from_slice(&ministream[position..position + self.mini_sector_size]);
sector = self.minifat[sector as usize];
}
data.truncate(size as usize);
Ok(data)
}
pub fn list_streams(&self) -> Vec<Vec<String>> {
let mut streams = Vec::new();
if let Some(ref root) = self.root {
self.collect_streams(root, &mut Vec::new(), &mut streams);
}
streams
}
pub fn list_directory_entries(&self, path: &[&str]) -> Result<Vec<DirectoryEntry>, OleError> {
let mut entries = Vec::new();
let dir_entry = if path.is_empty() {
self.root.as_ref().ok_or(OleError::StreamNotFound)?
} else {
&self.find_entry(path)?
};
if dir_entry.entry_type != STGTY_STORAGE && dir_entry.entry_type != STGTY_ROOT {
return Err(OleError::InvalidFormat("Not a directory".to_string()));
}
if dir_entry.sid_child != NOSTREAM {
self.collect_directory_children(dir_entry.sid_child, &mut entries);
}
Ok(entries)
}
fn collect_directory_children(&self, sid: u32, entries: &mut Vec<DirectoryEntry>) {
if sid == NOSTREAM || sid as usize >= self.dir_entries.len() {
return;
}
if let Some(ref entry) = self.dir_entries[sid as usize] {
if entry.sid_left != NOSTREAM {
self.collect_directory_children(entry.sid_left, entries);
}
entries.push(entry.clone());
if entry.sid_right != NOSTREAM {
self.collect_directory_children(entry.sid_right, entries);
}
}
}
pub fn directory_exists(&self, path: &[&str]) -> bool {
match self.find_entry(path) {
Ok(entry) => entry.entry_type == STGTY_STORAGE || entry.entry_type == STGTY_ROOT,
Err(_) => false,
}
}
fn collect_streams(
&self,
entry: &DirectoryEntry,
path: &mut [String],
streams: &mut Vec<Vec<String>>,
) {
let mut current_path = path.to_owned();
if !entry.name.is_empty() && entry.entry_type != STGTY_ROOT {
current_path.push(entry.name.clone());
}
if entry.entry_type == STGTY_STREAM {
streams.push(current_path);
return;
}
if entry.entry_type == STGTY_STORAGE || entry.entry_type == STGTY_ROOT {
if entry.sid_child != NOSTREAM {
self.traverse_children(entry.sid_child, ¤t_path, streams);
}
}
}
fn traverse_children(&self, sid: u32, path: &Vec<String>, streams: &mut Vec<Vec<String>>) {
if sid == NOSTREAM || sid as usize >= self.dir_entries.len() {
return;
}
if let Some(ref entry) = self.dir_entries[sid as usize] {
if entry.sid_left != NOSTREAM {
self.traverse_children(entry.sid_left, path, streams);
}
let mut current_path = path.clone();
self.collect_streams(entry, &mut current_path, streams);
if entry.sid_right != NOSTREAM {
self.traverse_children(entry.sid_right, path, streams);
}
}
}
pub fn open_stream(&mut self, path: &[&str]) -> Result<Vec<u8>, OleError> {
let entry = self.find_entry(path)?;
if entry.entry_type != STGTY_STREAM {
return Err(OleError::InvalidFormat("Not a stream".to_string()));
}
if entry.is_minifat {
self.read_stream_from_minifat(entry.start_sector, entry.size)
} else {
let mut data = self.read_stream_from_fat(entry.start_sector)?;
data.truncate(entry.size as usize);
Ok(data)
}
}
fn find_entry(&self, path: &[&str]) -> Result<DirectoryEntry, OleError> {
if path.is_empty() {
return self.root.clone().ok_or(OleError::StreamNotFound);
}
let root = self.root.as_ref().ok_or(OleError::StreamNotFound)?;
let mut current_sid = root.sid_child;
for (i, &name) in path.iter().enumerate() {
let entry = self.find_child_by_name(current_sid, name)?;
if i == path.len() - 1 {
return Ok(entry);
}
current_sid = entry.sid_child;
}
Err(OleError::StreamNotFound)
}
fn find_child_by_name(&self, sid: u32, name: &str) -> Result<DirectoryEntry, OleError> {
if sid == NOSTREAM || sid as usize >= self.dir_entries.len() {
return Err(OleError::StreamNotFound);
}
let entry = self.dir_entries[sid as usize]
.as_ref()
.ok_or(OleError::StreamNotFound)?;
if entry.name.to_lowercase() == name.to_lowercase() {
return Ok(entry.clone());
}
if entry.sid_left != NOSTREAM && let Ok(found) = self.find_child_by_name(entry.sid_left, name) {
return Ok(found);
}
if entry.sid_right != NOSTREAM && let Ok(found) = self.find_child_by_name(entry.sid_right, name) {
return Ok(found);
}
Err(OleError::StreamNotFound)
}
pub fn get_root_name(&self) -> Option<&str> {
self.root.as_ref().map(|r| r.name.as_str())
}
pub fn exists(&self, path: &[&str]) -> bool {
self.find_entry(path).is_ok()
}
}
/// Decodes little-endian UTF-16 bytes into a `String`.
///
/// An odd trailing byte is ignored, unpaired surrogates become U+FFFD, and
/// trailing NUL characters are stripped.
fn decode_utf16le(bytes: &[u8]) -> String {
    let units = bytes
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]));
    let decoded: String = char::decode_utf16(units)
        .map(|unit| unit.unwrap_or(char::REPLACEMENT_CHARACTER))
        .collect();
    decoded.trim_end_matches('\0').to_owned()
}
/// Formats a 16-byte CLSID as `XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX`.
///
/// The first three groups are stored little-endian; the last eight bytes are
/// printed in storage order. Returns an empty string when the input is not
/// exactly 16 bytes or is all zeros.
fn format_clsid(bytes: &[u8]) -> String {
    if bytes.len() != 16 || bytes.iter().all(|&b| b == 0) {
        return String::new();
    }
    let data1 = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    let data2 = u16::from_le_bytes([bytes[4], bytes[5]]);
    let data3 = u16::from_le_bytes([bytes[6], bytes[7]]);
    let tail = &bytes[8..16];
    format!(
        "{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}",
        data1, data2, data3, tail[0], tail[1], tail[2], tail[3], tail[4], tail[5], tail[6], tail[7],
    )
}
/// Quick in-memory check for an OLE compound file.
///
/// Returns `true` when `data` is at least `MINIMAL_OLEFILE_SIZE` bytes long
/// and its first 8 bytes equal the OLE signature `MAGIC`.
pub fn is_ole_file(data: &[u8]) -> bool {
    // The length is checked first, so the signature slice below cannot go out of bounds.
    if data.len() < MINIMAL_OLEFILE_SIZE {
        return false;
    }
    let signature = &data[..8];
    signature == MAGIC
}