use std::collections::HashMap;
use std::fs::File;
use std::hash::{Hash, Hasher};
use std::io::{Read, Seek, Write};
use std::path::Path;
use log::{debug, trace};
use zip::CompressionMethod;
use zip::read::ZipArchive;
use zip::write::{SimpleFileOptions, ZipWriter};
use super::content_types::{ContentTypes, ContentTypesBuilder};
use super::error::{Error, Result};
use super::properties::{AppProperties, CoreProperties};
use super::relationships::{Relationships, RelationshipsBuilder, TargetMode, rel_types};
/// A validated OPC part name: an absolute, `/`-separated path inside the package.
///
/// The wrapped string is kept exactly as accepted by `PartName::new` (after
/// percent-decoding). Equality and hashing are implemented by hand further down
/// in this file and ignore ASCII case; the stored text itself is not case-folded.
#[derive(Debug, Clone)]
pub struct PartName(String);
/// Decodes `%XX` escapes in `input` into the bytes they stand for.
///
/// A `%` that is not followed by two hexadecimal digits is copied through
/// unchanged. If the decoded byte sequence is not valid UTF-8, the original
/// `input` is returned untouched instead.
fn decode_percent_encoding(input: &str) -> String {
    let raw = input.as_bytes();
    let mut out = Vec::with_capacity(raw.len());
    let mut pos = 0;
    while pos < raw.len() {
        // A complete escape is `%` plus two more bytes that are both hex digits.
        if let [b'%', hi, lo, ..] = &raw[pos..] {
            if let Some((hi, lo)) = hex_digit(*hi).zip(hex_digit(*lo)) {
                out.push(hi << 4 | lo);
                pos += 3;
                continue;
            }
        }
        out.push(raw[pos]);
        pos += 1;
    }
    match String::from_utf8(out) {
        Ok(decoded) => decoded,
        // Escapes that do not form valid UTF-8: hand back the input verbatim.
        Err(_) => input.to_string(),
    }
}
/// Returns the numeric value (0–15) of an ASCII hexadecimal digit, or `None`
/// if `b` is not one of `0-9`, `a-f`, `A-F`.
fn hex_digit(b: u8) -> Option<u8> {
    // `to_digit(16)` yields values below 16, so narrowing to `u8` is lossless.
    char::from(b).to_digit(16).map(|value| value as u8)
}
impl PartName {
pub fn new(name: &str) -> Result<Self> {
let name = if name.contains('%') {
decode_percent_encoding(name)
} else {
name.to_string()
};
if !name.starts_with('/') {
return Err(Error::InvalidPartName(format!("must start with '/': {name}")));
}
if name.len() > 1 && name.ends_with('/') {
return Err(Error::InvalidPartName(format!("must not end with '/': {name}")));
}
if name.contains("//") {
return Err(Error::InvalidPartName(format!("empty segment (//): {name}")));
}
for segment in name.split('/').skip(1) {
if segment == "." || segment == ".." {
return Err(Error::InvalidPartName(format!("dot segment not allowed: {name}")));
}
if segment.ends_with('.') {
return Err(Error::InvalidPartName(format!("segment ends with dot: {name}")));
}
}
if name.contains('?') || name.contains('#') {
return Err(Error::InvalidPartName(format!("query/fragment not allowed: {name}")));
}
if name.contains('\\') {
return Err(Error::InvalidPartName(format!("backslash not allowed: {name}")));
}
Ok(Self(name))
}
pub fn as_str(&self) -> &str {
&self.0
}
pub fn extension(&self) -> Option<&str> {
let filename = self.filename();
filename.rfind('.').map(|i| &filename[i + 1..])
}
pub fn directory(&self) -> &str {
match self.0.rfind('/') {
Some(i) => &self.0[..=i],
None => "/",
}
}
pub fn filename(&self) -> &str {
match self.0.rfind('/') {
Some(i) => &self.0[i + 1..],
None => &self.0,
}
}
pub fn rels_path(&self) -> String {
let dir = self.directory();
let file = self.filename();
format!("{dir}_rels/{file}.rels")
}
pub fn resolve_relative(&self, relative: &str) -> Result<PartName> {
if relative.starts_with('/') {
return PartName::new(relative);
}
let base = self.directory();
let combined = format!("{base}{relative}");
let normalized = normalize_path(&combined);
PartName::new(&normalized)
}
}
/// Part names compare equal when they match byte-for-byte after folding ASCII
/// letters to one case; non-ASCII bytes must match exactly.
impl PartialEq for PartName {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0.as_bytes(), other.0.as_bytes());
        lhs.len() == rhs.len()
            && lhs
                .iter()
                .zip(rhs)
                .all(|(a, b)| a.to_ascii_lowercase() == b.to_ascii_lowercase())
    }
}
// Marker impl: the hand-written `PartialEq` compares the names byte-wise with
// ASCII case folded, which is a total equivalence relation, so `Eq` is sound.
impl Eq for PartName {}
/// Hashes the ASCII-lowercased bytes of the name, one byte at a time, so that
/// names differing only in ASCII case produce the same hash (matching `eq`).
impl Hash for PartName {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0
            .bytes()
            .map(|byte| byte.to_ascii_lowercase())
            .for_each(|folded| state.write_u8(folded));
    }
}
/// Displays the part name exactly as stored; width/fill/precision flags on the
/// outer formatter are not applied.
impl std::fmt::Display for PartName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
/// Collapses `.`/`..`/empty segments of a `/`-separated path.
///
/// - Empty and `.` segments are dropped, except that the first one seen while
///   nothing has been kept yet becomes the leading empty entry (which is what
///   makes the joined result start with `/`).
/// - `..` removes the previously kept segment but never the first entry, so an
///   absolute path cannot climb above the root.
/// - A result consisting only of the leading empty entry is returned as `/`.
fn normalize_path(path: &str) -> String {
    let mut kept: Vec<&str> = Vec::new();
    for part in path.split('/') {
        if part == ".." {
            if kept.len() > 1 {
                kept.pop();
            }
        } else if part.is_empty() || part == "." {
            if kept.is_empty() {
                kept.push("");
            }
        } else {
            kept.push(part);
        }
    }
    match kept.as_slice() {
        [""] => String::from("/"),
        parts => parts.join("/"),
    }
}
/// Read-side handle on an OPC package stored as a ZIP archive.
///
/// Construction (`OpcReader::new`) eagerly parses `[Content_Types].xml` and the
/// package-level `_rels/.rels`; individual parts are read from the archive on
/// demand.
pub struct OpcReader<R: Read + Seek> {
    /// Underlying ZIP archive that parts are read from.
    archive: ZipArchive<R>,
    /// Content types parsed from `[Content_Types].xml`.
    content_types: ContentTypes,
    /// Package-level relationships parsed from `_rels/.rels`
    /// (empty when that entry could not be read or is empty).
    package_rels: Relationships,
}
impl OpcReader<File> {
    /// Opens the file at `path` and reads it as an OPC package.
    ///
    /// # Errors
    /// Fails if the file cannot be opened, or with any error from
    /// `OpcReader::new`.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        Self::new(File::open(path)?)
    }
}
#[cfg(feature = "mmap")]
impl OpcReader<std::io::Cursor<memmap2::Mmap>> {
    /// Opens the file at `path` through a read-only memory map and reads it as
    /// an OPC package.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or mapped, or with any error from
    /// `OpcReader::new`.
    pub fn open_mmap(path: impl AsRef<Path>) -> Result<Self> {
        let source = File::open(path)?;
        // SAFETY: NOTE(review) — mapping a file is only sound while nothing
        // truncates or rewrites it for the lifetime of the map. Nothing in this
        // function enforces that; callers must not modify the file concurrently.
        let mapped = unsafe { memmap2::Mmap::map(&source) }?;
        debug!("OPC package opened via mmap ({} bytes)", mapped.len());
        Self::new(std::io::Cursor::new(mapped))
    }
}
impl<R: Read + Seek> OpcReader<R> {
pub fn new(reader: R) -> Result<Self> {
let mut archive = ZipArchive::new(reader)?;
debug!("OPC package opened, {} entries", archive.len());
let ct_data = read_zip_entry(&mut archive, "[Content_Types].xml")?;
let content_types = ContentTypes::parse(&ct_data)?;
let rels_data = read_zip_entry(&mut archive, "_rels/.rels").unwrap_or_default();
let package_rels = if rels_data.is_empty() {
Relationships::empty()
} else {
Relationships::parse(&rels_data)?
};
Ok(Self {
archive,
content_types,
package_rels,
})
}
pub fn content_types(&self) -> &ContentTypes {
&self.content_types
}
pub fn package_rels(&self) -> &Relationships {
&self.package_rels
}
pub fn read_part(&mut self, name: &PartName) -> Result<Vec<u8>> {
let zip_path = &name.as_str()[1..]; let data = read_zip_entry(&mut self.archive, zip_path)?;
trace!("read_part '{}' ({} bytes)", name, data.len());
if name.as_str().ends_with(".xml") || name.as_str().ends_with(".rels") {
if let Some(utf8_data) = super::xml::ensure_utf8(&data) {
trace!("read_part '{}': transcoded to UTF-8", name);
return Ok(utf8_data);
}
}
Ok(data)
}
pub fn read_rels_for(&mut self, part: &PartName) -> Result<Relationships> {
let rels_zip_path = part.rels_path();
let zip_path = &rels_zip_path[1..]; match read_zip_entry(&mut self.archive, zip_path) {
Ok(data) => {
trace!("read_rels_for '{}' ({} bytes)", part, data.len());
Relationships::parse(&data)
},
Err(Error::Zip(zip::result::ZipError::FileNotFound)) => {
trace!("read_rels_for '{}': no rels file", part);
Ok(Relationships::empty())
},
Err(Error::MissingPart(_)) => {
trace!("read_rels_for '{}': no rels file", part);
Ok(Relationships::empty())
},
Err(e) => Err(e),
}
}
pub fn main_document_part(&self) -> Result<PartName> {
if let Some(rel) = self.package_rels.first_by_type(rel_types::OFFICE_DOCUMENT) {
let target = if rel.target.starts_with('/') {
rel.target.clone()
} else {
format!("/{}", rel.target)
};
return PartName::new(&target);
}
const MAIN_CONTENT_TYPES: &[&str] = &[
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet.main+xml",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document.main+xml",
"application/vnd.openxmlformats-officedocument.presentationml.presentation.main+xml",
];
for (part_name, ct) in self.content_types.overrides() {
if MAIN_CONTENT_TYPES.iter().any(|&expected| ct == expected) {
debug!(
"main_document_part: fallback to Content_Types override '{}' ({})",
part_name, ct
);
return Ok(part_name.clone());
}
}
Err(Error::RelationshipNotFound("officeDocument relationship not found".to_string()))
}
pub fn part_names(&self) -> Vec<PartName> {
self.archive
.file_names()
.filter_map(|name| {
let name = name.replace('\\', "/");
if name.ends_with('/') {
return None;
}
if name.eq_ignore_ascii_case("[Content_Types].xml") {
return None;
}
if name.contains("_rels/") {
return None;
}
let part_name = format!("/{name}");
PartName::new(&part_name).ok()
})
.collect()
}
pub fn has_part(&self, name: &PartName) -> bool {
let zip_path = &name.as_str()[1..];
self.archive.file_names().any(|n| {
let normalized = n.replace('\\', "/");
normalized.eq_ignore_ascii_case(zip_path)
})
}
}
pub(crate) fn read_zip_entry<R: Read + Seek>(
archive: &mut ZipArchive<R>,
name: &str,
) -> Result<Vec<u8>> {
let index = match archive.index_for_name(name) {
Some(i) => i,
None => {
let normalized = name.replace('\\', "/");
let mut found = None;
for i in 0..archive.len() {
if let Ok(entry) = archive.by_index_raw(i) {
let entry_name = entry.name().replace('\\', "/");
if entry_name.eq_ignore_ascii_case(&normalized) {
found = Some(i);
break;
}
}
}
found.ok_or_else(|| Error::MissingPart(name.to_string()))?
},
};
let mut file = archive
.by_index(index)
.map_err(|_| Error::MissingPart(name.to_string()))?;
let mut buf = Vec::with_capacity(file.size() as usize);
match file.read_to_end(&mut buf) {
Ok(_) => Ok(buf),
Err(e)
if e.kind() == std::io::ErrorKind::InvalidData
&& e.to_string().contains("checksum")
&& !buf.is_empty() =>
{
trace!("read_zip_entry '{}': ignoring CRC mismatch", name);
Ok(buf)
},
Err(e) => Err(e.into()),
}
}
/// Write-side builder for an OPC package.
///
/// Part data is written into the ZIP as it is added; content types and
/// relationships are accumulated in memory and written out by `finish`.
pub struct OpcWriter<W: Write + Seek> {
    /// ZIP writer receiving the part data.
    writer: ZipWriter<W>,
    /// Content-type overrides collected for `[Content_Types].xml`.
    content_types: ContentTypesBuilder,
    /// Package-level relationships, serialized to `_rels/.rels` by `finish`.
    package_rels: RelationshipsBuilder,
    /// Per-part relationships, keyed by the source part name string
    /// (a plain `String` key, so lookups here are case-sensitive).
    part_rels: HashMap<String, RelationshipsBuilder>,
}
impl OpcWriter<File> {
    /// Creates (or truncates) the file at `path` and starts writing an OPC
    /// package into it.
    ///
    /// # Errors
    /// Fails if the file cannot be created, or with any error from
    /// `OpcWriter::new`.
    pub fn create(path: impl AsRef<Path>) -> Result<Self> {
        Self::new(File::create(path)?)
    }
}
impl<W: Write + Seek> OpcWriter<W> {
    /// Starts a new, empty package that writes into `writer`.
    pub fn new(writer: W) -> Result<Self> {
        Ok(Self {
            writer: ZipWriter::new(writer),
            content_types: ContentTypesBuilder::new(),
            package_rels: RelationshipsBuilder::new(),
            part_rels: HashMap::new(),
        })
    }

    /// Adds a part: registers a content-type override for `name`, then writes
    /// `data` as a Deflate-compressed ZIP entry.
    ///
    /// The override is registered before the entry is written, so it stays
    /// registered if writing fails.
    ///
    /// # Errors
    /// Propagates errors from starting the ZIP entry or writing the data.
    pub fn add_part(&mut self, name: &PartName, content_type: &str, data: &[u8]) -> Result<()> {
        self.content_types.add_override(name.clone(), content_type);
        // ZIP entry names carry no leading '/'.
        let zip_path = &name.as_str()[1..];
        let options = SimpleFileOptions::default().compression_method(CompressionMethod::Deflated);
        self.writer.start_file(zip_path, options)?;
        self.writer.write_all(data)?;
        Ok(())
    }

    /// Adds a package-level relationship and returns the string produced by the
    /// relationships builder (presumably the relationship id — confirm against
    /// `RelationshipsBuilder::add`).
    pub fn add_package_rel(&mut self, rel_type: &str, target: &str) -> String {
        self.package_rels.add(rel_type, target)
    }

    /// Adds a relationship originating from the part `source`; returns the
    /// string produced by the relationships builder.
    pub fn add_part_rel(&mut self, source: &PartName, rel_type: &str, target: &str) -> String {
        self.part_rels
            .entry(source.as_str().to_string())
            .or_default()
            .add(rel_type, target)
    }

    /// Like `add_part_rel`, with an explicit target mode.
    pub fn add_part_rel_with_mode(
        &mut self,
        source: &PartName,
        rel_type: &str,
        target: &str,
        target_mode: TargetMode,
    ) -> String {
        self.part_rels
            .entry(source.as_str().to_string())
            .or_default()
            .add_with_mode(rel_type, target, target_mode)
    }

    /// Writes `/docProps/core.xml` and links it from the package relationships.
    ///
    /// # Errors
    /// Propagates errors from writing the part.
    pub fn set_core_properties(&mut self, props: &CoreProperties) -> Result<()> {
        let data = props.serialize();
        let name = PartName::new("/docProps/core.xml")?;
        self.add_part(
            &name,
            "application/vnd.openxmlformats-package.core-properties+xml",
            &data,
        )?;
        self.package_rels
            .add(rel_types::CORE_PROPERTIES, "docProps/core.xml");
        Ok(())
    }

    /// Writes `/docProps/app.xml` and links it from the package relationships.
    ///
    /// # Errors
    /// Propagates errors from writing the part.
    pub fn set_app_properties(&mut self, props: &AppProperties) -> Result<()> {
        let data = props.serialize();
        let name = PartName::new("/docProps/app.xml")?;
        self.add_part(
            &name,
            "application/vnd.openxmlformats-officedocument.extended-properties+xml",
            &data,
        )?;
        self.package_rels
            .add(rel_types::EXTENDED_PROPERTIES, "docProps/app.xml");
        Ok(())
    }

    /// Writes the remaining package metadata and finalizes the ZIP.
    ///
    /// Order of the entries written here: every non-empty per-part
    /// relationships part (sorted by source part name), then `_rels/.rels`,
    /// then `[Content_Types].xml`.
    ///
    /// # Errors
    /// Propagates errors from writing any entry or finalizing the archive.
    pub fn finish(mut self) -> Result<W> {
        let options = SimpleFileOptions::default().compression_method(CompressionMethod::Deflated);
        // `HashMap` iteration order is arbitrary; sorting keeps the entry order
        // (and therefore the produced archive) reproducible between runs.
        let mut pending: Vec<_> = self
            .part_rels
            .iter()
            .filter(|(_, builder)| !builder.is_empty())
            .collect();
        pending.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        for (source_path, builder) in pending {
            let source = PartName::new(source_path)?;
            let rels_path = source.rels_path();
            let zip_path = &rels_path[1..];
            let data = builder.serialize();
            self.writer.start_file(zip_path, options)?;
            self.writer.write_all(&data)?;
        }
        let rels_data = self.package_rels.serialize();
        self.writer.start_file("_rels/.rels", options)?;
        self.writer.write_all(&rels_data)?;
        let ct_data = self.content_types.serialize();
        self.writer.start_file("[Content_Types].xml", options)?;
        self.writer.write_all(&ct_data)?;
        Ok(self.writer.finish()?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Content type of the main WordprocessingML document part.
    const WORD_MAIN_CT: &str =
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document.main+xml";

    /// Builds a part name the test expects to be valid.
    fn part(name: &str) -> PartName {
        PartName::new(name).unwrap()
    }

    #[test]
    fn valid_part_names() {
        for name in [
            "/word/document.xml",
            "/xl/worksheets/sheet1.xml",
            "/docProps/core.xml",
            "/word/media/image1.png",
        ] {
            assert!(PartName::new(name).is_ok(), "{} should be accepted", name);
        }
    }

    #[test]
    fn invalid_part_names() {
        for name in [
            "word/document.xml",
            "/word/document.xml/",
            "/word//document.xml",
            "/word/./document.xml",
            "/word/../document.xml",
        ] {
            assert!(PartName::new(name).is_err(), "{} should be rejected", name);
        }
        // Percent-escapes are decoded rather than rejected.
        let decoded = part("/word/my%20doc.xml");
        assert_eq!(decoded.as_str(), "/word/my doc.xml");
        assert!(PartName::new("/word/doc.xml?v=1").is_err());
    }

    #[test]
    fn part_name_case_insensitive_eq() {
        assert_eq!(part("/Word/Document.xml"), part("/word/document.xml"));
    }

    #[test]
    fn part_name_case_insensitive_hash() {
        use std::collections::HashSet;
        let set: HashSet<PartName> = std::iter::once(part("/Word/Document.xml")).collect();
        assert!(set.contains(&part("/word/document.xml")));
    }

    #[test]
    fn part_name_components() {
        let name = part("/word/document.xml");
        assert_eq!(name.directory(), "/word/");
        assert_eq!(name.filename(), "document.xml");
        assert_eq!(name.extension(), Some("xml"));
        assert_eq!(name.rels_path(), "/word/_rels/document.xml.rels");
    }

    #[test]
    fn resolve_relative_simple() {
        let resolved = part("/word/document.xml")
            .resolve_relative("media/image1.png")
            .unwrap();
        assert_eq!(resolved.as_str(), "/word/media/image1.png");
    }

    #[test]
    fn resolve_relative_parent() {
        let resolved = part("/word/document.xml")
            .resolve_relative("../docProps/core.xml")
            .unwrap();
        assert_eq!(resolved.as_str(), "/docProps/core.xml");
    }

    #[test]
    fn resolve_relative_absolute() {
        let resolved = part("/word/document.xml")
            .resolve_relative("/xl/workbook.xml")
            .unwrap();
        assert_eq!(resolved.as_str(), "/xl/workbook.xml");
    }

    #[test]
    fn opc_round_trip() {
        use std::io::Cursor;

        // Write a one-part package into memory.
        let doc_name = part("/word/document.xml");
        let mut writer = OpcWriter::new(Cursor::new(Vec::new())).unwrap();
        writer
            .add_part(&doc_name, WORD_MAIN_CT, b"<document/>")
            .unwrap();
        writer.add_package_rel(rel_types::OFFICE_DOCUMENT, "word/document.xml");
        let bytes = writer.finish().unwrap().into_inner();

        // Read it back and check every view of the part.
        let mut reader = OpcReader::new(Cursor::new(bytes)).unwrap();
        assert_eq!(
            reader.content_types().resolve(&doc_name),
            Some(WORD_MAIN_CT)
        );
        let main = reader.main_document_part().unwrap();
        assert_eq!(main.as_str(), "/word/document.xml");
        let content = reader.read_part(&doc_name).unwrap();
        assert_eq!(content, b"<document/>");
        assert!(reader.part_names().contains(&doc_name));
    }
}