use super::manifest::{BundleManifest, ModelEntry};
use super::BUNDLE_MAGIC;
use crate::error::{AprenderError, Result};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
/// Builds an [`AprenderError::Other`] carrying `msg`.
///
/// This file uses it for failures of file-system and stream operations.
fn io_error(msg: impl Into<String>) -> AprenderError {
    let text: String = msg.into();
    AprenderError::Other(text)
}
/// Builds an [`AprenderError::FormatError`] whose `message` is `msg`.
///
/// This file uses it when bundle contents are structurally invalid
/// (bad magic, undecodable header fields, unparsable manifest).
fn format_error(msg: impl Into<String>) -> AprenderError {
    let message: String = msg.into();
    AprenderError::FormatError { message }
}
/// Stateless namespace for the bundle header layout.
///
/// The type carries no data; the associated constant and functions that
/// describe and decode the header are defined in its `impl` block.
#[derive(Clone, Copy, Debug)]
pub struct BundleFormat;
impl BundleFormat {
    /// Size in bytes of the fixed header: 8 magic bytes, a little-endian `u32`
    /// version at bytes 8..12, and a little-endian `u64` manifest length at
    /// bytes 12..20.
    pub const HEADER_SIZE: usize = 20;

    /// Returns `true` when `bytes` holds at least 8 bytes and the first 8 equal
    /// `BUNDLE_MAGIC`.
    ///
    /// The `contract_*` macros are invoked before and after the comparison;
    /// they are defined elsewhere in the project.
    #[must_use]
    pub fn validate_magic(bytes: &[u8]) -> bool {
        contract_pre_magic_validation!();
        let result = if bytes.len() < 8 {
            false
        } else {
            bytes.get(0..8) == Some(BUNDLE_MAGIC)
        };
        contract_post_magic_validation!(&result);
        result
    }

    /// Decodes the little-endian `u32` version stored at bytes 8..12 of `header`.
    ///
    /// Returns `None` when `header` is shorter than 12 bytes.
    #[must_use]
    pub fn read_version(header: &[u8]) -> Option<u32> {
        let field: [u8; 4] = header.get(8..12)?.try_into().ok()?;
        Some(u32::from_le_bytes(field))
    }

    /// Decodes the little-endian `u64` manifest length stored at bytes 12..20
    /// of `header`.
    ///
    /// Returns `None` when `header` is shorter than 20 bytes.
    #[must_use]
    pub fn read_manifest_length(header: &[u8]) -> Option<u64> {
        let field: [u8; 8] = header.get(12..20)?.try_into().ok()?;
        Some(u64::from_le_bytes(field))
    }
}
/// Reader for a bundle file: a fixed header, followed by a manifest, followed
/// by the model data section.
pub struct BundleReader {
    /// Format version decoded from the file header.
    header_version: u32,
    /// Absolute file offset of the first manifest byte.
    manifest_offset: u64,
    /// Absolute file offset at which the model data section begins.
    data_offset: u64,
    /// Buffered handle to the opened bundle file.
    reader: BufReader<File>,
}
impl std::fmt::Debug for BundleReader {
    /// Formats the header-derived fields; the buffered file handle is omitted
    /// and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            header_version,
            manifest_offset,
            data_offset,
            ..
        } = self;
        let mut out = f.debug_struct("BundleReader");
        out.field("header_version", header_version);
        out.field("manifest_offset", manifest_offset);
        out.field("data_offset", data_offset);
        out.finish_non_exhaustive()
    }
}
impl BundleReader {
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let file = File::open(path.as_ref()).map_err(|e| {
io_error(format!(
"Failed to open bundle file '{}': {}",
path.as_ref().display(),
e
))
})?;
let mut reader = BufReader::new(file);
let mut header = [0u8; BundleFormat::HEADER_SIZE];
reader
.read_exact(&mut header)
.map_err(|e| io_error(format!("Failed to read bundle header: {e}")))?;
if !BundleFormat::validate_magic(&header) {
return Err(format_error("Invalid bundle file: magic bytes mismatch"));
}
let version = BundleFormat::read_version(&header)
.ok_or_else(|| format_error("Invalid bundle version"))?;
let manifest_len = BundleFormat::read_manifest_length(&header)
.ok_or_else(|| format_error("Invalid manifest length"))?;
Ok(Self {
reader,
header_version: version,
manifest_offset: BundleFormat::HEADER_SIZE as u64,
data_offset: BundleFormat::HEADER_SIZE as u64 + manifest_len,
})
}
#[must_use]
pub fn version(&self) -> u32 {
self.header_version
}
pub fn read_manifest(&mut self) -> Result<BundleManifest> {
self.reader
.seek(SeekFrom::Start(self.manifest_offset))
.map_err(|e| io_error(format!("Failed to seek to manifest: {e}")))?;
let manifest_len = (self.data_offset - self.manifest_offset) as usize;
let mut manifest_bytes = vec![0u8; manifest_len];
self.reader
.read_exact(&mut manifest_bytes)
.map_err(|e| io_error(format!("Failed to read manifest: {e}")))?;
BundleManifest::from_bytes(&manifest_bytes)
.ok_or_else(|| format_error("Failed to parse manifest"))
}
pub fn read_model(&mut self, entry: &ModelEntry) -> Result<Vec<u8>> {
let offset = self.data_offset + entry.offset;
self.reader
.seek(SeekFrom::Start(offset))
.map_err(|e| io_error(format!("Failed to seek to model data: {e}")))?;
let mut data = vec![0u8; entry.size];
self.reader
.read_exact(&mut data)
.map_err(|e| io_error(format!("Failed to read model data: {e}")))?;
Ok(data)
}
pub fn read_all_models(
&mut self,
manifest: &BundleManifest,
) -> Result<HashMap<String, Vec<u8>>> {
let mut models = HashMap::new();
for entry in manifest.iter() {
let data = self.read_model(entry)?;
models.insert(entry.name.clone(), data);
}
Ok(models)
}
#[must_use]
pub fn data_offset(&self) -> u64 {
self.data_offset
}
}
/// Writer that serializes a header, a manifest and model payloads into a
/// single bundle file.
pub struct BundleWriter {
    /// Buffered handle to the output file.
    writer: BufWriter<File>,
}
impl std::fmt::Debug for BundleWriter {
    /// Formats only the type name; the buffered file handle is omitted and the
    /// output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BundleWriter");
        out.finish_non_exhaustive()
    }
}
impl BundleWriter {
    /// Creates the file at `path` (via `File::create`) and wraps it in a
    /// buffered writer.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created.
    pub fn create(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        let file = File::create(path).map_err(|e| {
            io_error(format!(
                "Failed to create bundle file '{}': {}",
                path.display(),
                e
            ))
        })?;
        Ok(Self {
            writer: BufWriter::new(file),
        })
    }

    /// Writes a complete bundle: header, manifest, then each model payload in
    /// the order given by `manifest.model_names()`.
    ///
    /// A copy of `manifest` is written in which every entry's `offset` is set
    /// to the position, relative to the start of the data section, at which
    /// that model's bytes are actually written. Models named in the manifest
    /// but absent from `models` contribute no bytes. Declared `size` values
    /// are written unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if writing the header, manifest, any payload, or the
    /// final flush fails.
    pub fn write_bundle(
        mut self,
        manifest: &BundleManifest,
        models: &HashMap<String, Vec<u8>>,
    ) -> Result<()> {
        let mut updated_manifest = manifest.clone();
        let mut current_offset = 0u64;
        for name in manifest.model_names() {
            // Advance by the bytes that will really be emitted below, not by
            // the declared size: a missing or differently sized payload would
            // otherwise shift every later model away from its recorded offset.
            let payload_len = models.get(name).map_or(0, |data| data.len() as u64);
            if let Some(entry) = updated_manifest.get_model_mut(name) {
                entry.offset = current_offset;
                current_offset += payload_len;
            }
        }

        let manifest_bytes = updated_manifest.to_bytes();
        self.write_header(manifest.version, manifest_bytes.len() as u64)?;
        self.writer
            .write_all(&manifest_bytes)
            .map_err(|e| io_error(format!("Failed to write manifest: {e}")))?;

        // Must iterate in the same order as the offset pass above.
        for name in manifest.model_names() {
            if let Some(data) = models.get(name) {
                self.writer
                    .write_all(data)
                    .map_err(|e| io_error(format!("Failed to write model '{name}': {e}")))?;
            }
        }

        // Flush explicitly so buffered write errors are reported rather than
        // lost when the writer is dropped.
        self.writer
            .flush()
            .map_err(|e| io_error(format!("Failed to flush bundle: {e}")))?;
        Ok(())
    }

    /// Writes the fixed header: magic bytes, little-endian `version`, then
    /// little-endian `manifest_len`.
    fn write_header(&mut self, version: u32, manifest_len: u64) -> Result<()> {
        self.writer
            .write_all(BUNDLE_MAGIC)
            .map_err(|e| io_error(format!("Failed to write magic bytes: {e}")))?;
        self.writer
            .write_all(&version.to_le_bytes())
            .map_err(|e| io_error(format!("Failed to write version: {e}")))?;
        self.writer
            .write_all(&manifest_len.to_le_bytes())
            .map_err(|e| io_error(format!("Failed to write manifest length: {e}")))?;
        Ok(())
    }
}
// Unit tests for this module are loaded from the file `format_tests.rs` and
// are compiled only in test builds.
#[cfg(test)]
#[path = "format_tests.rs"]
mod tests;