use std::collections::HashMap;
use std::io::{Cursor, Read};
use std::path::Path;
use quick_xml::events::Event;
use quick_xml::Reader;
use crate::error::{Error, Result};
/// In-memory OOXML package: the raw bytes of a ZIP container.
///
/// Construction only checks the leading ZIP signature; the archive is opened
/// from `data` again on every accessor call.
pub struct OoxmlParser {
    /// Complete bytes of the ZIP container.
    data: Vec<u8>,
}

// Written by hand instead of derived so that `{:?}` (and therefore
// `Result::unwrap_err`, `dbg!`, `assert_eq!` on results) works without dumping
// every byte of a potentially multi-megabyte archive.
impl std::fmt::Debug for OoxmlParser {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OoxmlParser")
            .field("len", &self.data.len())
            .finish_non_exhaustive()
    }
}
/// One XML element extracted by [`OoxmlParser::iter_xml_elements`].
///
/// `PartialEq`/`Eq` allow results to be compared directly (e.g. in tests), and
/// `Default` gives an empty element with no name, attributes or text.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct XmlElement {
    /// Local tag name with any namespace prefix removed (e.g. `t` for `<w:t>`).
    pub name: String,
    /// Attributes keyed by their local name (prefix removed). Values are decoded
    /// lossily as UTF-8 but are not entity-unescaped.
    pub attributes: HashMap<String, String>,
    /// Character data captured for the element by the parser.
    pub text: String,
}
impl OoxmlParser {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let data = std::fs::read(path)?;
Self::from_bytes(&data)
}
pub fn from_bytes(data: &[u8]) -> Result<Self> {
if data.len() < 4 || data[0..4] != [0x50, 0x4B, 0x03, 0x04] {
return Err(Error::InvalidOoxml("Not a valid ZIP/OOXML file".into()));
}
Ok(Self {
data: data.to_vec(),
})
}
pub fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
let mut data = Vec::new();
reader.read_to_end(&mut data)?;
Self::from_bytes(&data)
}
pub fn is_ooxml(data: &[u8]) -> bool {
data.len() >= 4 && data[0..4] == [0x50, 0x4B, 0x03, 0x04]
}
pub fn iter_files(&self) -> Result<Vec<String>> {
let cursor = Cursor::new(&self.data);
let archive = zip::ZipArchive::new(cursor)
.map_err(|e| Error::InvalidOoxml(format!("Invalid ZIP: {e}")))?;
let names: Vec<String> = (0..archive.len())
.filter_map(|i| {
archive
.clone()
.by_index(i)
.ok()
.map(|e| e.name().to_string())
})
.collect();
Ok(names)
}
pub fn read_file(&self, name: &str) -> Result<Vec<u8>> {
let cursor = Cursor::new(&self.data);
let mut archive = zip::ZipArchive::new(cursor)
.map_err(|e| Error::InvalidOoxml(format!("Invalid ZIP: {e}")))?;
let mut entry = archive
.by_name(name)
.map_err(|e| Error::InvalidOoxml(format!("File not found '{name}': {e}")))?;
let mut buf = Vec::new();
entry.read_to_end(&mut buf)?;
Ok(buf)
}
pub fn read_file_as_string(&self, name: &str) -> Result<String> {
let data = self.read_file(name)?;
String::from_utf8(data).map_err(|e| Error::InvalidOoxml(format!("Invalid UTF-8: {e}")))
}
pub fn iter_xml_elements(&self, subfile: &str, tags: &[&str]) -> Result<Vec<XmlElement>> {
let xml_data = self.read_file(subfile)?;
let mut reader = Reader::from_reader(Cursor::new(xml_data));
reader.config_mut().trim_text(true);
let mut elements = Vec::new();
let mut buf = Vec::new();
let mut current_element: Option<XmlElement> = None;
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(ref e)) | Ok(Event::Empty(ref e)) => {
let local_name = String::from_utf8_lossy(e.local_name().as_ref()).to_string();
if tags.iter().any(|&t| t == local_name) {
let mut attrs = HashMap::new();
for attr in e.attributes().flatten() {
let key =
String::from_utf8_lossy(attr.key.local_name().as_ref()).to_string();
let value = String::from_utf8_lossy(&attr.value).to_string();
attrs.insert(key, value);
}
let elem = XmlElement {
name: local_name,
attributes: attrs,
text: String::new(),
};
if matches!(reader.read_event_into(&mut Vec::new()), Ok(Event::End(_))) {
elements.push(elem);
} else {
current_element = Some(elem);
}
}
}
Ok(Event::Text(ref e)) => {
if let Some(ref mut elem) = current_element {
elem.text = e.unescape().unwrap_or_default().to_string();
}
}
Ok(Event::End(_)) => {
if let Some(elem) = current_element.take() {
elements.push(elem);
}
}
Ok(Event::Eof) => break,
Err(e) => {
return Err(Error::XmlParsing(format!(
"Error parsing {subfile}: {e}"
)));
}
_ => {}
}
buf.clear();
}
Ok(elements)
}
pub fn is_single_xml(data: &[u8]) -> bool {
if let Ok(text) = std::str::from_utf8(&data[..std::cmp::min(data.len(), 500)]) {
text.contains("<?xml") && !Self::is_ooxml(data)
} else {
false
}
}
pub fn find_vba_projects(&self) -> Result<Vec<String>> {
let files = self.iter_files()?;
Ok(files
.into_iter()
.filter(|f| f.to_lowercase().ends_with("vbaproject.bin"))
.collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature `OoxmlParser` expects at offset 0.
    const ZIP_MAGIC: [u8; 4] = [0x50, 0x4B, 0x03, 0x04];

    #[test]
    fn test_is_ooxml() {
        assert!(OoxmlParser::is_ooxml(&[0x50, 0x4B, 0x03, 0x04, 0x00]));
        // A different 4-byte signature must be rejected.
        assert!(!OoxmlParser::is_ooxml(&[0xD0, 0xCF, 0x11, 0xE0]));
        assert!(!OoxmlParser::is_ooxml(&[0x00, 0x01]));
    }

    #[test]
    fn test_is_ooxml_boundaries() {
        assert!(!OoxmlParser::is_ooxml(&[]));
        // A truncated signature is not enough.
        assert!(!OoxmlParser::is_ooxml(&ZIP_MAGIC[..3]));
        // Exactly the signature and nothing else is accepted.
        assert!(OoxmlParser::is_ooxml(&ZIP_MAGIC));
    }

    #[test]
    fn test_is_single_xml() {
        assert!(OoxmlParser::is_single_xml(b"<?xml version=\"1.0\"?><doc/>"));
        assert!(!OoxmlParser::is_single_xml(&[0x50, 0x4B, 0x03, 0x04]));
    }

    #[test]
    fn test_is_single_xml_edge_cases() {
        assert!(!OoxmlParser::is_single_xml(b""));
        assert!(!OoxmlParser::is_single_xml(b"<doc/> without a declaration"));
        // Invalid UTF-8 with no declaration.
        assert!(!OoxmlParser::is_single_xml(&[0xD0, 0xCF, 0x11, 0xE0]));
        // A UTF-8 byte-order mark before the declaration is still valid UTF-8.
        assert!(OoxmlParser::is_single_xml(
            "\u{FEFF}<?xml version=\"1.0\"?><doc/>".as_bytes()
        ));
    }

    #[test]
    fn test_invalid_ooxml() {
        let result = OoxmlParser::from_bytes(&[0x00, 0x01, 0x02]);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_bytes_rejects_empty_input() {
        assert!(OoxmlParser::from_bytes(&[]).is_err());
    }

    #[test]
    fn test_signature_alone_is_not_a_readable_archive() {
        // Construction only checks the signature; a missing ZIP structure is
        // reported by the accessors instead.
        let parser = match OoxmlParser::from_bytes(&ZIP_MAGIC) {
            Ok(parser) => parser,
            Err(_) => panic!("the ZIP signature alone should pass the constructor check"),
        };
        assert!(parser.iter_files().is_err());
        assert!(parser.read_file("[Content_Types].xml").is_err());
        assert!(parser.find_vba_projects().is_err());
    }
}