use crate::error::{Error, Result};
/// An OLE object record decoded by [`OleObject::parse`].
///
/// Fields are listed in the order they are read from the input bytes.
#[derive(Debug, Clone)]
pub struct OleObject {
    /// Version word read from the first four bytes (little-endian).
    /// NOTE(review): the tests use `0x00000501`; presumably the OLE1 `OLEVersion` field — confirm.
    pub ole_version: u32,
    /// Raw format identifier; [`OleObject::format`] decodes it into an [`OleFormatId`].
    pub format_id: u32,
    /// Class name (e.g. `Package`, `OLE2Link`, `Equation.*`, as tested by the `is_*` helpers).
    /// Decoded lossily as UTF-8 with at most one trailing NUL removed.
    pub class_name: String,
    /// Topic name, decoded the same way as `class_name`.
    pub topic_name: String,
    /// Item name, decoded the same way as `class_name`.
    pub item_name: String,
    /// Payload size exactly as declared in the input; it can exceed `data.len()`
    /// when the input was truncated.
    pub data_size: u32,
    /// Payload bytes: at most `data_size` bytes, fewer if the input ended early.
    pub data: Vec<u8>,
}
/// Decoded form of [`OleObject::format_id`].
///
/// NOTE(review): the 1/2/3 numbering appears to follow the OLE1
/// `OT_LINK`/`OT_EMBEDDED`/`OT_STATIC` constants — confirm against the spec.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OleFormatId {
    /// Format id `1`.
    Linked,
    /// Format id `2`.
    Embedded,
    /// Format id `3`.
    Static,
    /// Any other format id, preserved verbatim.
    Unknown(u32),
}

impl From<u32> for OleFormatId {
    /// Decodes a raw format id. Values other than 1, 2 and 3 are kept in
    /// [`OleFormatId::Unknown`], so the conversion never loses information.
    fn from(raw: u32) -> Self {
        match raw {
            1 => Self::Linked,
            2 => Self::Embedded,
            3 => Self::Static,
            unrecognised => Self::Unknown(unrecognised),
        }
    }
}
impl OleObject {
pub fn parse(data: &[u8]) -> Result<Self> {
if data.len() < 8 {
return Err(Error::OleObjectParsing("OLE object data too short".into()));
}
let mut pos = 0;
let ole_version = read_u32_le(data, &mut pos)?;
let format_id = read_u32_le(data, &mut pos)?;
let class_name = read_length_prefixed_string(data, &mut pos)?;
let topic_name = read_length_prefixed_string(data, &mut pos)?;
let item_name = read_length_prefixed_string(data, &mut pos)?;
let data_size = read_u32_le(data, &mut pos)?;
let end = pos + data_size as usize;
let obj_data = if end > data.len() {
data[pos..].to_vec()
} else {
data[pos..end].to_vec()
};
Ok(Self {
ole_version,
format_id,
class_name,
topic_name,
item_name,
data_size,
data: obj_data,
})
}
pub fn format(&self) -> OleFormatId {
OleFormatId::from(self.format_id)
}
pub fn is_package(&self) -> bool {
self.class_name.eq_ignore_ascii_case("Package")
}
pub fn is_ole2link(&self) -> bool {
self.class_name.eq_ignore_ascii_case("OLE2Link")
}
pub fn is_equation(&self) -> bool {
self.class_name
.to_ascii_lowercase()
.starts_with("equation.")
}
}
fn read_u32_le(data: &[u8], pos: &mut usize) -> Result<u32> {
if *pos + 4 > data.len() {
return Err(Error::OleObjectParsing(
"Unexpected end of data reading u32".into(),
));
}
let val = u32::from_le_bytes([data[*pos], data[*pos + 1], data[*pos + 2], data[*pos + 3]]);
*pos += 4;
Ok(val)
}
fn read_length_prefixed_string(data: &[u8], pos: &mut usize) -> Result<String> {
let len = read_u32_le(data, pos)? as usize;
if len == 0 {
return Ok(String::new());
}
if *pos + len > data.len() {
return Err(Error::OleObjectParsing(
"Unexpected end of data reading length-prefixed string".into(),
));
}
let end = if len > 0 && data[*pos + len - 1] == 0 {
*pos + len - 1
} else {
*pos + len
};
let s = String::from_utf8_lossy(&data[*pos..end]).to_string();
*pos += len;
Ok(s)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Appends a little-endian `u32` length followed by `bytes` verbatim.
    fn push_prefixed(buf: &mut Vec<u8>, bytes: &[u8]) {
        buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
        buf.extend_from_slice(bytes);
    }

    /// Appends `s` as a length-prefixed, NUL-terminated string (the length counts the NUL).
    fn push_nul_terminated(buf: &mut Vec<u8>, s: &str) {
        let mut bytes = Vec::with_capacity(s.len() + 1);
        bytes.extend_from_slice(s.as_bytes());
        bytes.push(0);
        push_prefixed(buf, &bytes);
    }

    /// Appends the two fixed header words: version `0x501` and `format_id`.
    fn push_header(buf: &mut Vec<u8>, format_id: u32) {
        buf.extend_from_slice(&0x00000501u32.to_le_bytes());
        buf.extend_from_slice(&format_id.to_le_bytes());
    }

    /// Builds a complete format-id-2 object with NUL-terminated names and `payload`.
    fn build_ole_object(
        class_name: &str,
        topic_name: &str,
        item_name: &str,
        payload: &[u8],
    ) -> Vec<u8> {
        let mut buf = Vec::new();
        push_header(&mut buf, 2);
        push_nul_terminated(&mut buf, class_name);
        push_nul_terminated(&mut buf, topic_name);
        push_nul_terminated(&mut buf, item_name);
        push_prefixed(&mut buf, payload);
        buf
    }

    /// Parses a minimal object that only differs in its class name.
    fn parse_class(class_name: &str) -> OleObject {
        OleObject::parse(&build_ole_object(class_name, "", "", &[])).unwrap()
    }

    #[test]
    fn test_parse_valid_object() {
        let payload = b"\x01\x02\x03\x04";
        let data = build_ole_object("Package", "test.doc", "item1", payload);
        let obj = OleObject::parse(&data).unwrap();
        assert_eq!(obj.ole_version, 0x00000501);
        assert_eq!(obj.format_id, 2);
        assert_eq!(obj.class_name, "Package");
        assert_eq!(obj.topic_name, "test.doc");
        assert_eq!(obj.item_name, "item1");
        assert_eq!(obj.data_size, 4);
        assert_eq!(obj.data, payload);
    }

    #[test]
    fn test_is_package() {
        assert!(parse_class("Package").is_package());
    }

    #[test]
    fn test_is_ole2link() {
        assert!(parse_class("OLE2Link").is_ole2link());
    }

    #[test]
    fn test_is_equation() {
        assert!(parse_class("Equation.3").is_equation());
    }

    #[test]
    fn test_format_id() {
        assert_eq!(parse_class("Test").format(), OleFormatId::Embedded);
    }

    #[test]
    fn test_parse_too_short() {
        let result = OleObject::parse(&[0x01, 0x00]);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_empty() {
        let result = OleObject::parse(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_names() {
        let mut buf = Vec::new();
        push_header(&mut buf, 2);
        for _ in 0..4 {
            buf.extend_from_slice(&0u32.to_le_bytes());
        }
        let obj = OleObject::parse(&buf).unwrap();
        assert_eq!(obj.class_name, "");
        assert_eq!(obj.data_size, 0);
    }

    #[test]
    fn test_truncated_payload_keeps_declared_size() {
        let mut data = build_ole_object("Package", "", "", b"abcdef");
        data.truncate(data.len() - 3);
        let obj = OleObject::parse(&data).unwrap();
        assert_eq!(obj.data_size, 6);
        assert_eq!(obj.data, b"abc");
    }

    #[test]
    fn test_bytes_after_payload_are_ignored() {
        let mut data = build_ole_object("Package", "", "", b"\x01\x02");
        data.extend_from_slice(b"trailing");
        let obj = OleObject::parse(&data).unwrap();
        assert_eq!(obj.data, b"\x01\x02");
    }

    #[test]
    fn test_huge_declared_payload_size_is_clamped() {
        let mut data = Vec::new();
        push_header(&mut data, 2);
        for _ in 0..3 {
            data.extend_from_slice(&0u32.to_le_bytes());
        }
        data.extend_from_slice(&u32::MAX.to_le_bytes());
        data.extend_from_slice(b"xy");
        let obj = OleObject::parse(&data).unwrap();
        assert_eq!(obj.data_size, u32::MAX);
        assert_eq!(obj.data, b"xy");
    }

    #[test]
    fn test_names_without_nul_terminator() {
        let mut data = Vec::new();
        push_header(&mut data, 2);
        push_prefixed(&mut data, b"Package");
        push_prefixed(&mut data, b"topic");
        push_prefixed(&mut data, b"");
        push_prefixed(&mut data, b"");
        let obj = OleObject::parse(&data).unwrap();
        assert_eq!(obj.class_name, "Package");
        assert_eq!(obj.topic_name, "topic");
        assert_eq!(obj.item_name, "");
        assert!(obj.data.is_empty());
    }

    #[test]
    fn test_only_one_trailing_nul_is_stripped() {
        let mut data = Vec::new();
        push_header(&mut data, 2);
        push_prefixed(&mut data, b"ab\0\0");
        for _ in 0..3 {
            push_prefixed(&mut data, b"");
        }
        let obj = OleObject::parse(&data).unwrap();
        assert_eq!(obj.class_name, "ab\0");
    }

    #[test]
    fn test_truncated_string_is_error() {
        let mut data = Vec::new();
        push_header(&mut data, 2);
        data.extend_from_slice(&50u32.to_le_bytes());
        data.extend_from_slice(b"abc");
        assert!(OleObject::parse(&data).is_err());
    }

    #[test]
    fn test_missing_data_size_is_error() {
        let mut data = build_ole_object("Package", "", "", &[]);
        data.truncate(data.len() - 4);
        assert!(OleObject::parse(&data).is_err());
    }

    #[test]
    fn test_format_id_mapping() {
        assert_eq!(OleFormatId::from(1u32), OleFormatId::Linked);
        assert_eq!(OleFormatId::from(2u32), OleFormatId::Embedded);
        assert_eq!(OleFormatId::from(3u32), OleFormatId::Static);
        assert_eq!(OleFormatId::from(0u32), OleFormatId::Unknown(0));

        let mut data = Vec::new();
        push_header(&mut data, 7);
        for _ in 0..4 {
            data.extend_from_slice(&0u32.to_le_bytes());
        }
        let obj = OleObject::parse(&data).unwrap();
        assert_eq!(obj.format(), OleFormatId::Unknown(7));
    }

    #[test]
    fn test_class_name_predicates_are_ascii_case_insensitive() {
        assert!(parse_class("package").is_package());
        assert!(!parse_class("Packager").is_package());
        assert!(parse_class("ole2link").is_ole2link());
        assert!(parse_class("EQUATION.DSMT4").is_equation());
        assert!(!parse_class("Equation").is_equation());
        assert!(!parse_class("Equations").is_equation());
        assert!(!parse_class("").is_equation());
    }
}