use crate::error::{Error, Result};
/// Magic bytes expected at the very start of the input; `SectorMap::from_bytes`
/// rejects any input whose first eight bytes differ.
const OLE_SIGNATURE: [u8; 8] = *b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1";
/// Number of leading bytes treated as the header; sector data is addressed
/// relative to this size.
const HEADER_SIZE: usize = 512;
/// Classification of a single 32-bit FAT entry.
///
/// `SectorMap::build_fat` maps the four reserved raw values to the unit
/// variants below and wraps every other value in [`SectorType::Data`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SectorType {
/// Raw value `0xFFFFFFFF`. Also the initial state of every slot before the
/// FAT sectors are read.
Free,
/// Raw value `0xFFFFFFFE`.
EndOfChain,
/// Raw value `0xFFFFFFFD`.
Fat,
/// Raw value `0xFFFFFFFC`.
Difat,
/// Any other raw value, stored unchanged. Displayed as `-> n`; presumably
/// the id of the next sector in the chain.
Data(u32),
}
impl std::fmt::Display for SectorType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SectorType::Free => write!(f, "FREE"),
SectorType::EndOfChain => write!(f, "ENDOFCHAIN"),
SectorType::Fat => write!(f, "FAT"),
SectorType::Difat => write!(f, "DIFAT"),
SectorType::Data(next) => write!(f, "-> {next}"),
}
}
}
/// Fields decoded from the fixed-size header at the start of the input.
///
/// Every value is read as a little-endian integer by
/// `SectorMap::parse_header`; the byte ranges below are the offsets used
/// there.
#[derive(Debug, Clone)]
pub struct OleHeader {
/// `u16` at header bytes 24..26.
pub minor_version: u16,
/// `u16` at header bytes 26..28.
pub major_version: u16,
/// `u16` at header bytes 28..30, stored as read (not interpreted here).
pub byte_order: u16,
/// `1 << shift`, where `shift` is the `u16` at header bytes 30..32.
pub sector_size: u32,
/// `1 << shift`, where `shift` is the `u16` at header bytes 32..34.
pub mini_sector_size: u32,
/// `u32` at header bytes 44..48.
pub total_fat_sectors: u32,
/// `u32` at header bytes 48..52.
pub first_dir_sector: u32,
/// `u32` at header bytes 56..60.
pub mini_stream_cutoff: u32,
/// `u32` at header bytes 60..64.
pub first_mini_fat_sector: u32,
/// `u32` at header bytes 64..68.
pub total_mini_fat_sectors: u32,
/// `u32` at header bytes 68..72. `SectorMap::build_fat` starts following
/// the DIFAT chain here unless the value is `0xFFFFFFFE` or `0xFFFFFFFF`.
pub first_difat_sector: u32,
/// `u32` at header bytes 72..76.
pub total_difat_sectors: u32,
}
/// Parsed view of an OLE compound file's sector allocation.
///
/// Built by `SectorMap::from_bytes`. Derives `Debug` and `Clone` so the map
/// can be logged, printed in test failures, and duplicated, like the other
/// public types in this module.
#[derive(Debug, Clone)]
pub struct SectorMap {
    /// Header fields decoded from the start of the input.
    pub header: OleHeader,
    /// One entry per sector, indexed by sector id.
    pub fat: Vec<SectorType>,
    /// `header.sector_size` widened to `usize` for offset arithmetic.
    pub sector_size: usize,
    /// Number of whole sectors found in the input beyond the header.
    pub total_sectors: usize,
}
impl SectorMap {
pub fn from_bytes(data: &[u8]) -> Result<Self> {
if data.len() < HEADER_SIZE {
return Err(Error::InvalidOle("File too small for OLE header".into()));
}
if data[0..8] != OLE_SIGNATURE {
return Err(Error::InvalidOle("Invalid OLE signature".into()));
}
let header = Self::parse_header(data)?;
let sector_size = header.sector_size as usize;
let total_sectors = if data.len() > HEADER_SIZE {
(data.len() - HEADER_SIZE) / sector_size
} else {
0
};
let fat = Self::build_fat(data, &header, sector_size, total_sectors)?;
Ok(Self {
header,
fat,
sector_size,
total_sectors,
})
}
fn parse_header(data: &[u8]) -> Result<OleHeader> {
let minor_version = u16::from_le_bytes([data[24], data[25]]);
let major_version = u16::from_le_bytes([data[26], data[27]]);
let byte_order = u16::from_le_bytes([data[28], data[29]]);
let sector_shift = u16::from_le_bytes([data[30], data[31]]);
let mini_sector_shift = u16::from_le_bytes([data[32], data[33]]);
let sector_size = 1u32 << sector_shift;
let mini_sector_size = 1u32 << mini_sector_shift;
let total_fat_sectors = u32::from_le_bytes([data[44], data[45], data[46], data[47]]);
let first_dir_sector = u32::from_le_bytes([data[48], data[49], data[50], data[51]]);
let mini_stream_cutoff = u32::from_le_bytes([data[56], data[57], data[58], data[59]]);
let first_mini_fat_sector = u32::from_le_bytes([data[60], data[61], data[62], data[63]]);
let total_mini_fat_sectors = u32::from_le_bytes([data[64], data[65], data[66], data[67]]);
let first_difat_sector = u32::from_le_bytes([data[68], data[69], data[70], data[71]]);
let total_difat_sectors = u32::from_le_bytes([data[72], data[73], data[74], data[75]]);
Ok(OleHeader {
minor_version,
major_version,
byte_order,
sector_size,
mini_sector_size,
total_fat_sectors,
first_dir_sector,
mini_stream_cutoff,
first_mini_fat_sector,
total_mini_fat_sectors,
first_difat_sector,
total_difat_sectors,
})
}
fn build_fat(
data: &[u8],
header: &OleHeader,
sector_size: usize,
total_sectors: usize,
) -> Result<Vec<SectorType>> {
let mut fat_sector_ids: Vec<u32> = Vec::new();
for i in 0..109u32 {
let off = 76 + (i as usize) * 4;
if off + 4 > HEADER_SIZE {
break;
}
let sid = u32::from_le_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]]);
if sid == 0xFFFFFFFE || sid == 0xFFFFFFFF {
break;
}
fat_sector_ids.push(sid);
}
let mut difat_sid = header.first_difat_sector;
while difat_sid != 0xFFFFFFFE && difat_sid != 0xFFFFFFFF {
let offset = HEADER_SIZE + (difat_sid as usize) * sector_size;
if offset + sector_size > data.len() {
break;
}
let entries_per_sector = (sector_size / 4) - 1;
for i in 0..entries_per_sector {
let off = offset + i * 4;
let sid = u32::from_le_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]]);
if sid == 0xFFFFFFFE || sid == 0xFFFFFFFF {
break;
}
fat_sector_ids.push(sid);
}
let next_off = offset + sector_size - 4;
difat_sid = u32::from_le_bytes([data[next_off], data[next_off + 1], data[next_off + 2], data[next_off + 3]]);
}
let mut fat = vec![SectorType::Free; total_sectors];
let mut fat_idx = 0;
for &sid in &fat_sector_ids {
let offset = HEADER_SIZE + (sid as usize) * sector_size;
if offset + sector_size > data.len() {
break;
}
for i in (0..sector_size).step_by(4) {
if fat_idx >= total_sectors {
break;
}
let val = u32::from_le_bytes([
data[offset + i],
data[offset + i + 1],
data[offset + i + 2],
data[offset + i + 3],
]);
fat[fat_idx] = match val {
0xFFFFFFFF => SectorType::Free,
0xFFFFFFFE => SectorType::EndOfChain,
0xFFFFFFFD => SectorType::Fat,
0xFFFFFFFC => SectorType::Difat,
next => SectorType::Data(next),
};
fat_idx += 1;
}
}
Ok(fat)
}
pub fn sector_counts(&self) -> SectorCounts {
let mut counts = SectorCounts::default();
for s in &self.fat {
match s {
SectorType::Free => counts.free += 1,
SectorType::EndOfChain => counts.end_of_chain += 1,
SectorType::Fat => counts.fat += 1,
SectorType::Difat => counts.difat += 1,
SectorType::Data(_) => counts.data += 1,
}
}
counts
}
}
/// Number of FAT entries of each kind, as tallied by
/// `SectorMap::sector_counts`. All counters start at zero via `Default`.
#[derive(Debug, Default)]
pub struct SectorCounts {
/// Entries classified as `SectorType::Free`.
pub free: usize,
/// Entries classified as `SectorType::EndOfChain`.
pub end_of_chain: usize,
/// Entries classified as `SectorType::Fat`.
pub fat: usize,
/// Entries classified as `SectorType::Difat`.
pub difat: usize,
/// Entries classified as `SectorType::Data`, whatever value they carry.
pub data: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A header-sized buffer of zeros does not start with the OLE signature.
    #[test]
    fn test_invalid_signature() {
        let data = vec![0u8; HEADER_SIZE];
        let result = SectorMap::from_bytes(&data);
        assert!(result.is_err());
    }

    /// Inputs shorter than the header are rejected before anything is read.
    #[test]
    fn test_too_small() {
        let data = vec![0u8; 10];
        let result = SectorMap::from_bytes(&data);
        assert!(result.is_err());
    }

    /// Boundary of the size check: one byte short of a full header is still
    /// too small, even when the signature itself is valid.
    #[test]
    fn test_one_byte_short_of_header() {
        let mut data = vec![0u8; HEADER_SIZE - 1];
        data[..OLE_SIGNATURE.len()].copy_from_slice(&OLE_SIGNATURE);
        assert!(SectorMap::from_bytes(&data).is_err());
        assert!(SectorMap::from_bytes(&[]).is_err());
    }

    /// Every `SectorType` variant has a stable textual label.
    #[test]
    fn test_sector_type_display() {
        assert_eq!(SectorType::Free.to_string(), "FREE");
        assert_eq!(SectorType::EndOfChain.to_string(), "ENDOFCHAIN");
        assert_eq!(SectorType::Fat.to_string(), "FAT");
        assert_eq!(SectorType::Difat.to_string(), "DIFAT");
        assert_eq!(SectorType::Data(5).to_string(), "-> 5");
    }
}