use std::collections::BTreeMap;
use std::fs::{File, OpenOptions};
use std::io::{BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
use crate::error::Error;
use crate::models::{Checksum, Component, DType, Encoding, Format, Manifest, Object, MAGIC};
use crate::reader::{Tensor, TensorElement};
use crate::utils::{align_offset, is_little_endian, swap_endianness_in_place, DigestWriter};
/// A run of zero bytes that `write_component` slices from to pad each component
/// up to the offset returned by `align_offset`.
const ZERO_PAD: [u8; 64] = [0; 64];
/// Compression applied to a component's bytes before they are stored.
///
/// Besides `Copy`, the enum is comparable and hashable so callers can test or key on
/// the selected mode (e.g. `assert_eq!(mode, Compression::Raw)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// Store the bytes unchanged.
    Raw,
    /// Compress with zstd; the `i32` is the level handed to the zstd stream encoder.
    Zstd(i32),
}
/// Streaming writer that emits aligned tensor components to `W` and records each one
/// in a manifest; the manifest and footer are written by `finish`.
pub struct Writer<W>
where
    W: Write + Seek,
{
    /// Sink that receives every byte of the output.
    writer: W,
    /// Objects and file-level attributes collected so far.
    manifest: Manifest,
    /// Offset at which the next byte will land, counted from the start of the output.
    current_offset: u64,
}
impl Writer<BufWriter<File>> {
pub fn create(path: impl AsRef<Path>) -> Result<Self, Error> {
let file = File::create(path)?;
Self::new(BufWriter::with_capacity(256 * 1024, file))
}
pub fn append(path: impl AsRef<Path>) -> Result<Self, Error> {
let mut file = OpenOptions::new().read(true).write(true).open(&path)?;
file.seek(SeekFrom::End(-16))?;
let mut size_buf = [0u8; 8];
file.read_exact(&mut size_buf)?;
let manifest_size = u64::from_le_bytes(size_buf);
let mut footer_magic = [0u8; 8];
file.read_exact(&mut footer_magic)?;
if footer_magic != *MAGIC {
return Err(Error::InvalidMagicNumber {
found: footer_magic.to_vec(),
});
}
let file_size = file.seek(SeekFrom::End(0))?;
let manifest_start = file_size - 16 - manifest_size;
file.seek(SeekFrom::Start(manifest_start))?;
let mut cbor_buf = vec![0u8; manifest_size as usize];
file.read_exact(&mut cbor_buf)?;
let manifest: Manifest = ciborium::from_reader(std::io::Cursor::new(&cbor_buf))
.map_err(Error::CborDeserialize)?;
file.set_len(manifest_start)?;
file.seek(SeekFrom::Start(manifest_start))?;
Ok(Self {
writer: BufWriter::with_capacity(256 * 1024, file),
manifest,
current_offset: manifest_start,
})
}
}
impl<W: Write + Seek> Writer<W> {
pub fn new(mut writer: W) -> Result<Self, Error> {
writer.write_all(MAGIC)?;
Ok(Self {
writer,
manifest: Manifest::default(),
current_offset: MAGIC.len() as u64,
})
}
pub fn set_attributes(&mut self, attrs: BTreeMap<String, ciborium::Value>) {
self.manifest.attributes = Some(attrs);
}
pub fn add_bytes(
&mut self,
name: &str,
shape: Vec<u64>,
dtype: DType,
compression: Compression,
data: &[u8],
checksum: Checksum,
) -> Result<(), Error> {
let num_elements: u64 = if shape.is_empty() {
1
} else {
shape.iter().try_fold(1u64, |acc, &d| {
acc.checked_mul(d).ok_or_else(|| {
Error::InvalidFileStructure("Shape product overflows u64".into())
})
})?
};
let expected_size = num_elements
.checked_mul(dtype.byte_size() as u64)
.ok_or_else(|| Error::InvalidFileStructure("Tensor byte size overflows u64".into()))?;
if data.len() as u64 != expected_size {
return Err(Error::InconsistentDataSize {
expected: expected_size,
found: data.len() as u64,
});
}
if self.manifest.objects.contains_key(name) {
return Err(Error::Other(format!("Duplicate tensor name: '{}'", name)));
}
let component = self.write_component(data, dtype, compression, checksum)?;
let mut components = BTreeMap::new();
components.insert("data".to_string(), component);
let obj = Object {
shape,
format: Format::Dense,
attributes: None,
components,
};
self.manifest.objects.insert(name.to_string(), obj);
Ok(())
}
pub fn add<T: TensorElement + bytemuck::Pod>(
&mut self,
name: &str,
shape: &[u64],
data: &[T],
) -> Result<(), Error> {
let bytes = bytemuck::cast_slice(data);
self.add_bytes(
name,
shape.to_vec(),
T::DTYPE,
Compression::Raw,
bytes,
Checksum::None,
)
}
pub fn add_with<'a, T: TensorElement + bytemuck::Pod>(
&'a mut self,
name: &str,
shape: &[u64],
data: &'a [T],
) -> AddBuilder<'a, W, T> {
AddBuilder {
writer: self,
name: name.to_string(),
shape: shape.to_vec(),
data,
compression: Compression::Raw,
checksum: Checksum::None,
}
}
#[allow(clippy::too_many_arguments)]
pub fn add_csr_bytes(
&mut self,
name: &str,
shape: Vec<u64>,
dtype: DType,
values: &[u8],
indices: &[u64],
indptr: &[u64],
compression: Compression,
checksum: Checksum,
) -> Result<(), Error> {
if self.manifest.objects.contains_key(name) {
return Err(Error::Other(format!("Duplicate tensor name: '{}'", name)));
}
let indices_bytes = bytemuck::cast_slice(indices);
let indptr_bytes = bytemuck::cast_slice(indptr);
let values_comp = self.write_component(values, dtype, compression, checksum)?;
let indices_comp =
self.write_component(indices_bytes, DType::U64, compression, checksum)?;
let indptr_comp = self.write_component(indptr_bytes, DType::U64, compression, checksum)?;
let mut components = BTreeMap::new();
components.insert("values".to_string(), values_comp);
components.insert("indices".to_string(), indices_comp);
components.insert("indptr".to_string(), indptr_comp);
let obj = Object {
shape,
format: Format::SparseCsr,
attributes: None,
components,
};
self.manifest.objects.insert(name.to_string(), obj);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn add_csr<T: TensorElement + bytemuck::Pod>(
&mut self,
name: &str,
shape: Vec<u64>,
dtype: DType,
values: &[T],
indices: &[u64],
indptr: &[u64],
compression: Compression,
checksum: Checksum,
) -> Result<(), Error> {
let values_bytes = bytemuck::cast_slice(values);
self.add_csr_bytes(
name,
shape,
dtype,
values_bytes,
indices,
indptr,
compression,
checksum,
)
}
#[allow(clippy::too_many_arguments)]
pub fn add_coo_bytes(
&mut self,
name: &str,
shape: Vec<u64>,
dtype: DType,
values: &[u8],
coords: &[u64],
compression: Compression,
checksum: Checksum,
) -> Result<(), Error> {
if self.manifest.objects.contains_key(name) {
return Err(Error::Other(format!("Duplicate tensor name: '{}'", name)));
}
let coords_bytes = bytemuck::cast_slice(coords);
let values_comp = self.write_component(values, dtype, compression, checksum)?;
let coords_comp = self.write_component(coords_bytes, DType::U64, compression, checksum)?;
let mut components = BTreeMap::new();
components.insert("values".to_string(), values_comp);
components.insert("coords".to_string(), coords_comp);
let obj = Object {
shape,
format: Format::SparseCoo,
attributes: None,
components,
};
self.manifest.objects.insert(name.to_string(), obj);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn add_coo<T: TensorElement + bytemuck::Pod>(
&mut self,
name: &str,
shape: Vec<u64>,
dtype: DType,
values: &[T],
coords: &[u64],
compression: Compression,
checksum: Checksum,
) -> Result<(), Error> {
let values_bytes = bytemuck::cast_slice(values);
self.add_coo_bytes(
name,
shape,
dtype,
values_bytes,
coords,
compression,
checksum,
)
}
pub fn add_object(
&mut self,
name: &str,
shape: Vec<u64>,
format: Format,
component_data: &[(&str, DType, Option<&str>, &[u8])],
attributes: Option<BTreeMap<String, ciborium::Value>>,
compression: Compression,
checksum: Checksum,
) -> Result<(), Error> {
if self.manifest.objects.contains_key(name) {
return Err(Error::Other(format!("Duplicate tensor name: '{}'", name)));
}
let mut components = BTreeMap::new();
for (comp_name, dtype, logical_type, bytes) in component_data {
let mut component = self.write_component(bytes, *dtype, compression, checksum)?;
if let Some(lt) = logical_type {
component.r#type = Some(lt.to_string());
}
components.insert(comp_name.to_string(), component);
}
let obj = Object {
shape,
format,
attributes,
components,
};
self.manifest.objects.insert(name.to_string(), obj);
Ok(())
}
pub fn write_object(
&mut self,
name: &str,
tensor: &Tensor<'_>,
compression: Compression,
checksum: Checksum,
) -> Result<(), Error> {
let comp_data: Vec<(&str, DType, Option<&str>, &[u8])> = tensor
.components
.iter()
.map(|(k, v)| {
(
k.as_str(),
v.dtype,
v.logical_type.as_deref(),
v.data.as_slice(),
)
})
.collect();
self.add_object(
name,
tensor.shape.clone(),
tensor.format.clone(),
&comp_data,
tensor.attributes.clone(),
compression,
checksum,
)
}
fn write_component(
&mut self,
data: &[u8],
dtype: DType,
compression: Compression,
checksum: Checksum,
) -> Result<Component, Error> {
let (aligned_offset, padding) = align_offset(self.current_offset);
if padding > 0 {
self.writer.write_all(&ZERO_PAD[..padding as usize])?;
}
self.current_offset = aligned_offset;
let mut digest_writer = DigestWriter::new(&mut self.writer, checksum);
let stored_encoding = match compression {
Compression::Raw => {
Self::write_data(&mut digest_writer, data, dtype)?;
Encoding::Raw
}
Compression::Zstd(level) => {
{
let mut encoder = zstd::stream::write::Encoder::new(&mut digest_writer, level)
.map_err(Error::ZstdCompression)?;
Self::write_data(&mut encoder, data, dtype)?;
encoder.finish().map_err(Error::ZstdCompression)?;
}
Encoding::Zstd
}
};
let length = digest_writer.bytes_written;
let digest = digest_writer.finalize();
self.current_offset += length;
Ok(Component {
dtype,
r#type: None,
offset: aligned_offset,
length,
uncompressed_length: match stored_encoding {
Encoding::Zstd => Some(data.len() as u64),
Encoding::Raw => None,
},
encoding: stored_encoding,
digest,
})
}
fn write_data<Output: Write>(
writer: &mut Output,
data: &[u8],
dtype: DType,
) -> Result<(), Error> {
let is_native_safe = is_little_endian() || !dtype.is_multi_byte();
if is_native_safe {
writer.write_all(data)?;
} else {
const CHUNK_SIZE: usize = 4096;
let mut buffer = Vec::with_capacity(CHUNK_SIZE);
for chunk in data.chunks(CHUNK_SIZE) {
buffer.clear();
buffer.extend_from_slice(chunk);
swap_endianness_in_place(&mut buffer, dtype.byte_size());
writer.write_all(&buffer)?;
}
}
Ok(())
}
pub fn finish(mut self) -> Result<u64, Error> {
let mut cbor = Vec::new();
ciborium::into_writer(&self.manifest, &mut cbor).map_err(Error::CborSerialize)?;
self.writer.write_all(&cbor)?;
let cbor_size = cbor.len() as u64;
self.writer.write_all(&cbor_size.to_le_bytes())?;
self.writer.write_all(MAGIC)?;
self.writer.flush()?;
Ok(self.current_offset + cbor_size + 8 + 8)
}
}
/// Builder returned by `Writer::add_with` for adding a dense tensor with custom
/// compression and checksum settings.
///
/// Nothing is written until `write` is called, so dropping the builder silently discards
/// the tensor. `#[must_use]` makes the compiler flag that mistake, including on the
/// values returned by the chained setter methods.
#[must_use = "an AddBuilder does nothing until `.write()` is called"]
pub struct AddBuilder<'a, W: Write + Seek, T: TensorElement + bytemuck::Pod> {
    /// Writer the tensor is added to when `write` is called.
    writer: &'a mut Writer<W>,
    /// Object name recorded in the manifest.
    name: String,
    /// Tensor shape; an empty shape is a scalar.
    shape: Vec<u64>,
    /// Elements to store; their dtype comes from `T::DTYPE`.
    data: &'a [T],
    /// Compression applied to the data component.
    compression: Compression,
    /// Checksum recorded for the data component.
    checksum: Checksum,
}
impl<'a, W, T> AddBuilder<'a, W, T>
where
    W: Write + Seek,
    T: TensorElement + bytemuck::Pod,
{
    /// Sets the compression for the tensor data, replacing the current choice.
    pub fn compress(self, compression: Compression) -> Self {
        Self { compression, ..self }
    }

    /// Sets the checksum recorded for the tensor data, replacing the current choice.
    pub fn checksum(self, checksum: Checksum) -> Self {
        Self { checksum, ..self }
    }

    /// Consumes the builder and adds the tensor through `Writer::add_bytes`, using
    /// `T::DTYPE` as the dtype and the elements reinterpreted as bytes.
    pub fn write(self) -> Result<(), Error> {
        let Self {
            writer,
            name,
            shape,
            data,
            compression,
            checksum,
        } = self;
        writer.add_bytes(
            &name,
            shape,
            T::DTYPE,
            compression,
            bytemuck::cast_slice(data),
            checksum,
        )
    }
}