use std::{
io::{BufReader, Read, Write},
sync::Arc,
};
use anyhow::{bail, Result};
use sha2::{Digest, Sha256};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
use zstd::stream::{read::Decoder, write::Encoder};
use crate::{
fsverity::FsVerityHashValue,
repository::Repository,
util::{read_exactish, Sha256Digest},
};
/// One (content digest → fs-verity object ID) mapping entry.
///
/// `#[repr(C)]` plus the zerocopy derives allow a packed array of entries to
/// be written to / parsed from the splitstream header without copying
/// (see `map.as_bytes()` in the writer and `read_from_io` in the reader).
#[derive(Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C)]
pub struct DigestMapEntry<ObjectID: FsVerityHashValue> {
    /// SHA-256 digest of the referenced content ("body").
    pub body: Sha256Digest,
    /// fs-verity hash value of the corresponding repository object.
    pub verity: ObjectID,
}
/// A map from SHA-256 content digests to fs-verity object IDs.
///
/// Entries are kept sorted by `body` (maintained by `insert`), so lookups
/// can use binary search and the map can be serialized as-is into the
/// splitstream header.
#[derive(Debug)]
pub struct DigestMap<ObjectID: FsVerityHashValue> {
    /// Entries sorted by `body` digest.
    pub map: Vec<DigestMapEntry<ObjectID>>,
}
impl<ObjectID: FsVerityHashValue> Default for DigestMap<ObjectID> {
fn default() -> Self {
Self::new()
}
}
impl<ObjectID: FsVerityHashValue> DigestMap<ObjectID> {
    /// Creates an empty map.
    pub fn new() -> Self {
        DigestMap { map: Vec::new() }
    }

    /// Returns the fs-verity object ID recorded for `body`, if present.
    ///
    /// Relies on `map` being sorted by body digest (binary search).
    pub fn lookup(&self, body: &Sha256Digest) -> Option<&ObjectID> {
        self.map
            .binary_search_by_key(body, |e| e.body)
            .ok()
            .map(|idx| &self.map[idx].verity)
    }

    /// Records a (body → verity) mapping, keeping the map sorted.
    ///
    /// Panics if `body` is already present with a *different* object ID —
    /// two distinct verity hashes for the same content indicate corruption.
    pub fn insert(&mut self, body: &Sha256Digest, verity: &ObjectID) {
        match self.map.binary_search_by_key(body, |e| e.body) {
            Ok(idx) => assert_eq!(self.map[idx].verity, *verity),
            Err(idx) => {
                let entry = DigestMapEntry {
                    body: *body,
                    verity: verity.clone(),
                };
                self.map.insert(idx, entry);
            }
        }
    }
}
/// Writes a splitstream: a zstd-compressed stream where large "external"
/// chunks are stored as separate repository objects and only referenced
/// here, while small content is carried inline.
pub struct SplitStreamWriter<ObjectID: FsVerityHashValue> {
    /// Repository that external objects (and the finished stream) go into.
    repo: Arc<Repository<ObjectID>>,
    /// Inline data buffered until the next flush.
    inline_content: Vec<u8>,
    /// zstd encoder writing into an in-memory buffer.
    writer: Encoder<'static, Vec<u8>>,
    /// Running hash context plus the expected digest of the complete
    /// (pre-split) content; verified when the stream is finalized.
    pub sha256: Option<(Sha256, Sha256Digest)>,
}
impl<ObjectID: FsVerityHashValue> std::fmt::Debug for SplitStreamWriter<ObjectID> {
    /// Manual Debug impl: the zstd `writer` field is intentionally omitted
    /// from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SplitStreamWriter");
        out.field("repo", &self.repo);
        out.field("inline_content", &self.inline_content);
        out.field("sha256", &self.sha256);
        out.finish()
    }
}
impl<ObjectID: FsVerityHashValue> SplitStreamWriter<ObjectID> {
    /// Creates a writer, immediately emitting the stream header into the
    /// zstd encoder: the digest-map entry count as u64 LE followed by the
    /// raw entries (or just a zero count when `refs` is `None`).
    ///
    /// If `sha256` is given, all content (inline data, external data and
    /// padding) is hashed and checked against it in [`Self::done`].
    pub fn new(
        repo: &Arc<Repository<ObjectID>>,
        refs: Option<DigestMap<ObjectID>>,
        sha256: Option<Sha256Digest>,
    ) -> Self {
        // Writing into an in-memory Vec cannot fail, so unwrap() is safe here.
        let mut writer = Encoder::new(vec![], 0).unwrap();
        match refs {
            Some(DigestMap { map }) => {
                writer.write_all(&(map.len() as u64).to_le_bytes()).unwrap();
                writer.write_all(map.as_bytes()).unwrap();
            }
            None => {
                writer.write_all(&0u64.to_le_bytes()).unwrap();
            }
        }
        Self {
            repo: Arc::clone(repo),
            inline_content: vec![],
            writer,
            sha256: sha256.map(|x| (Sha256::new(), x)),
        }
    }

    /// Writes one fragment: the size as u64 LE, then `data`.
    fn write_fragment(writer: &mut impl Write, size: usize, data: &[u8]) -> Result<()> {
        writer.write_all(&(size as u64).to_le_bytes())?;
        Ok(writer.write_all(data)?)
    }

    /// Flushes any buffered inline data as an inline fragment and installs
    /// `new_value` as the new buffer contents.
    ///
    /// BUGFIX: `new_value` is now installed unconditionally.  Previously it
    /// was only assigned when the old buffer was non-empty, so padding
    /// handed over via `write_reference` while the inline buffer was empty
    /// was silently dropped from the stream even though `write_external`
    /// had already fed it into the SHA-256 context — making the recorded
    /// digest unreachable on replay.
    fn flush_inline(&mut self, new_value: Vec<u8>) -> Result<()> {
        if !self.inline_content.is_empty() {
            Self::write_fragment(
                &mut self.writer,
                self.inline_content.len(),
                &self.inline_content,
            )?;
        }
        self.inline_content = new_value;
        Ok(())
    }

    /// Buffers `data` as inline content, feeding it to the hash context.
    pub fn write_inline(&mut self, data: &[u8]) {
        if let Some((ref mut sha256, ..)) = self.sha256 {
            sha256.update(data);
        }
        self.inline_content.extend(data);
    }

    /// Writes an external reference: a size-0 fragment followed by the raw
    /// object ID.  `padding` becomes the new buffered inline content.
    fn write_reference(&mut self, reference: &ObjectID, padding: Vec<u8>) -> Result<()> {
        self.flush_inline(padding)?;
        Self::write_fragment(&mut self.writer, 0, reference.as_bytes())
    }

    /// Stores `data` as a repository object and writes a reference to it.
    /// Both `data` and `padding` are fed to the hash context.
    pub fn write_external(&mut self, data: &[u8], padding: Vec<u8>) -> Result<()> {
        if let Some((ref mut sha256, ..)) = self.sha256 {
            sha256.update(data);
            sha256.update(&padding);
        }
        let id = self.repo.ensure_object(data)?;
        self.write_reference(&id, padding)
    }

    /// Async variant of [`Self::write_external`].
    pub async fn write_external_async(&mut self, data: Vec<u8>, padding: Vec<u8>) -> Result<()> {
        if let Some((ref mut sha256, ..)) = self.sha256 {
            sha256.update(&data);
            sha256.update(&padding);
        }
        let id = self.repo.ensure_object_async(data).await?;
        self.write_reference(&id, padding)
    }

    /// Finalizes the stream: flushes pending inline data, verifies the
    /// SHA-256 digest (when one was requested at construction), finishes
    /// compression, and stores the compressed stream as a repository
    /// object, returning its ID.
    pub fn done(mut self) -> Result<ObjectID> {
        self.flush_inline(vec![])?;
        if let Some((context, expected)) = self.sha256 {
            if Into::<Sha256Digest>::into(context.finalize()) != expected {
                bail!("Content doesn't have expected SHA256 hash value!");
            }
        }
        self.repo.ensure_object(&self.writer.finish()?)
    }
}
/// One logical chunk read back from a splitstream: either the bytes
/// themselves (inline) or the ID of the repository object holding them.
#[derive(Debug)]
pub enum SplitStreamData<ObjectID: FsVerityHashValue> {
    /// Content stored directly in the stream.
    Inline(Box<[u8]>),
    /// Content stored as a separate repository object.
    External(ObjectID),
}
/// Reads back a splitstream produced by `SplitStreamWriter`.
pub struct SplitStreamReader<R: Read, ObjectID: FsVerityHashValue> {
    /// zstd decoder over the underlying reader.
    decoder: Decoder<'static, BufReader<R>>,
    /// Digest map parsed from the stream header.
    pub refs: DigestMap<ObjectID>,
    /// Bytes remaining in the current inline fragment (0 = between chunks).
    inline_bytes: usize,
}
impl<R: Read, ObjectID: FsVerityHashValue> std::fmt::Debug for SplitStreamReader<R, ObjectID> {
    /// Manual Debug impl: the zstd `decoder` field is intentionally omitted
    /// from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SplitStreamReader");
        out.field("refs", &self.refs);
        out.field("inline_bytes", &self.inline_bytes);
        out.finish()
    }
}
/// Reads a little-endian u64 from `reader`.
///
/// Returns `Ok(None)` on a clean EOF at a value boundary (as reported by
/// `read_exactish`); a partial read is an error.
fn read_u64_le<R: Read>(reader: &mut R) -> Result<Option<usize>> {
    let mut buf = [0u8; 8];
    let filled = read_exactish(reader, &mut buf)?;
    Ok(filled.then(|| u64::from_le_bytes(buf) as usize))
}
/// Resizes `vec` to exactly `size` bytes and fills it from `reader`.
///
/// Any previous contents are overwritten; errors if the reader cannot
/// supply `size` bytes.
fn read_into_vec(reader: &mut impl Read, vec: &mut Vec<u8>, size: usize) -> Result<()> {
    vec.resize(size, 0u8);
    reader.read_exact(&mut vec[..])?;
    Ok(())
}
/// Internal result of `ensure_chunk`: what the reader is positioned at.
enum ChunkType<ObjectID: FsVerityHashValue> {
    /// Clean end of stream.
    Eof,
    /// An inline fragment; its remaining size is in `inline_bytes`.
    Inline,
    /// An external reference; the object ID has been consumed from the stream.
    External(ObjectID),
}
impl<R: Read, ObjectID: FsVerityHashValue> SplitStreamReader<R, ObjectID> {
    /// Opens a splitstream: wraps `reader` in a zstd decoder and parses the
    /// header — a u64-LE entry count followed by that many digest-map entries.
    pub fn new(reader: R) -> Result<Self> {
        let mut decoder = Decoder::new(reader)?;
        let n_map_entries = {
            let mut buf = [0u8; 8];
            decoder.read_exact(&mut buf)?;
            u64::from_le_bytes(buf)
        } as usize;
        // NOTE(review): the capacity comes straight from the stream header;
        // if the input can be untrusted this could over-allocate — confirm
        // the trust model for streams fed to this reader.
        let mut refs = DigestMap::<ObjectID> {
            map: Vec::with_capacity(n_map_entries),
        };
        for _ in 0..n_map_entries {
            refs.map.push(DigestMapEntry::read_from_io(&mut decoder)?);
        }
        Ok(Self {
            decoder,
            refs,
            inline_bytes: 0,
        })
    }
    /// Positions the reader at a usable chunk.
    ///
    /// If no inline bytes are pending, reads the next fragment header:
    /// EOF (allowed only when `eof_ok`), a size-0 fragment meaning an
    /// external reference follows (allowed only when `ext_ok`, the object
    /// ID is consumed here), or an inline fragment whose size is stored
    /// into `self.inline_bytes`.  An inline chunk smaller than
    /// `expected_bytes` is an error.
    fn ensure_chunk(
        &mut self,
        eof_ok: bool,
        ext_ok: bool,
        expected_bytes: usize,
    ) -> Result<ChunkType<ObjectID>> {
        if self.inline_bytes == 0 {
            match read_u64_le(&mut self.decoder)? {
                None => {
                    if !eof_ok {
                        bail!("Unexpected EOF when parsing splitstream");
                    }
                    return Ok(ChunkType::Eof);
                }
                Some(0) => {
                    // Size 0 marks an external reference: the ID follows.
                    if !ext_ok {
                        bail!("Unexpected external reference when parsing splitstream");
                    }
                    let id = ObjectID::read_from_io(&mut self.decoder)?;
                    return Ok(ChunkType::External(id));
                }
                Some(size) => {
                    self.inline_bytes = size;
                }
            }
        }
        if self.inline_bytes < expected_bytes {
            bail!("Unexpectedly small inline content when parsing splitstream");
        }
        Ok(ChunkType::Inline)
    }
    /// Reads exactly `buffer.len()` inline bytes.  Returns `false` on clean
    /// EOF; errors on an external reference or a too-small inline chunk.
    pub fn read_inline_exact(&mut self, buffer: &mut [u8]) -> Result<bool> {
        if let ChunkType::Inline = self.ensure_chunk(true, false, buffer.len())? {
            self.decoder.read_exact(buffer)?;
            self.inline_bytes -= buffer.len();
            Ok(true)
        } else {
            Ok(false)
        }
    }
    /// Consumes and discards `size` inline padding bytes (at most 512).
    fn discard_padding(&mut self, size: usize) -> Result<()> {
        let mut buf = [0u8; 512];
        assert!(size <= 512);
        self.ensure_chunk(false, false, size)?;
        self.decoder.read_exact(&mut buf[0..size])?;
        self.inline_bytes -= size;
        Ok(())
    }
    /// Reads one chunk of `stored_size` bytes, returning either the external
    /// object ID or the inline content truncated to `actual_size`.  The
    /// `stored_size - actual_size` trailing padding bytes are dropped; for
    /// external chunks the padding is consumed from the stream.
    pub fn read_exact(
        &mut self,
        actual_size: usize,
        stored_size: usize,
    ) -> Result<SplitStreamData<ObjectID>> {
        if let ChunkType::External(id) = self.ensure_chunk(false, true, stored_size)? {
            if actual_size < stored_size {
                self.discard_padding(stored_size - actual_size)?;
            }
            Ok(SplitStreamData::External(id))
        } else {
            let mut content = vec![];
            read_into_vec(&mut self.decoder, &mut content, stored_size)?;
            content.truncate(actual_size);
            self.inline_bytes -= stored_size;
            Ok(SplitStreamData::Inline(content.into()))
        }
    }
    /// Writes the fully reassembled stream content to `output`, calling
    /// `load_data` to fetch each external object's bytes by ID.
    pub fn cat(
        &mut self,
        output: &mut impl Write,
        mut load_data: impl FnMut(&ObjectID) -> Result<Vec<u8>>,
    ) -> Result<()> {
        let mut buffer = vec![];
        loop {
            match self.ensure_chunk(true, true, 0)? {
                ChunkType::Eof => break Ok(()),
                ChunkType::Inline => {
                    read_into_vec(&mut self.decoder, &mut buffer, self.inline_bytes)?;
                    self.inline_bytes = 0;
                    output.write_all(&buffer)?;
                }
                ChunkType::External(ref id) => {
                    output.write_all(&load_data(id)?)?;
                }
            }
        }
    }
    /// Invokes `callback` for every object this stream references: first the
    /// digest-map entries, then each external chunk in order.  Inline data
    /// is decompressed and skipped.  Consumes the stream.
    pub fn get_object_refs(&mut self, mut callback: impl FnMut(&ObjectID)) -> Result<()> {
        let mut buffer = vec![];
        for entry in &self.refs.map {
            callback(&entry.verity);
        }
        loop {
            match self.ensure_chunk(true, true, 0)? {
                ChunkType::Eof => break Ok(()),
                ChunkType::Inline => {
                    // Inline content is irrelevant here; read it only to
                    // advance the decoder.
                    read_into_vec(&mut self.decoder, &mut buffer, self.inline_bytes)?;
                    self.inline_bytes = 0;
                }
                ChunkType::External(ref id) => {
                    callback(id);
                }
            }
        }
    }
    /// Invokes `callback` with the SHA-256 body digest of every header entry.
    pub fn get_stream_refs(&mut self, mut callback: impl FnMut(&Sha256Digest)) {
        for entry in &self.refs.map {
            callback(&entry.body);
        }
    }
    /// Resolves a SHA-256 body digest through the header digest map,
    /// erroring if it is not present.
    pub fn lookup(&self, body: &Sha256Digest) -> Result<&ObjectID> {
        match self.refs.lookup(body) {
            Some(id) => Ok(id),
            None => bail!("Reference is not found in splitstream"),
        }
    }
}
impl<F: Read, ObjectID: FsVerityHashValue> Read for SplitStreamReader<F, ObjectID> {
    /// Streams the inline content only.  External references are forbidden
    /// here (`ext_ok = false`), so `External` can never be returned.
    fn read(&mut self, data: &mut [u8]) -> std::io::Result<usize> {
        let chunk = self
            .ensure_chunk(true, false, 1)
            .map_err(std::io::Error::other)?;
        match chunk {
            ChunkType::Eof => Ok(0),
            ChunkType::Inline => {
                // Serve at most what remains of the current inline fragment.
                let count = data.len().min(self.inline_bytes);
                self.decoder.read_exact(&mut data[..count])?;
                self.inline_bytes -= count;
                Ok(count)
            }
            ChunkType::External(..) => unreachable!(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Exercises `read_into_vec` over empty, partial, short, exact,
    /// pre-allocated and pre-populated buffer scenarios.
    #[test]
    fn test_read_into_vec() -> Result<()> {
        // Zero-byte read from an empty source succeeds, vec stays empty.
        let mut vec = Vec::new();
        assert!(read_into_vec(&mut Cursor::new(vec![]), &mut vec, 0).is_ok());
        assert!(vec.is_empty());

        // Partial read takes exactly `size` bytes from a longer source.
        let mut vec = Vec::new();
        assert!(read_into_vec(&mut Cursor::new(vec![1, 2, 3, 4, 5]), &mut vec, 3).is_ok());
        assert_eq!(vec.len(), 3);
        assert_eq!(vec, vec![1, 2, 3]);

        // Requesting more bytes than the source holds is an error.
        let mut vec = Vec::new();
        assert!(read_into_vec(&mut Cursor::new(vec![1, 2, 3]), &mut vec, 5).is_err());

        // Reading exactly the available length succeeds.
        let mut vec = Vec::new();
        assert!(read_into_vec(&mut Cursor::new(vec![1, 2, 3]), &mut vec, 3).is_ok());
        assert_eq!(vec.len(), 3);
        assert_eq!(vec, vec![1, 2, 3]);

        // Pre-reserved capacity survives; only the length changes.
        let mut vec = Vec::with_capacity(10);
        assert!(read_into_vec(&mut Cursor::new(vec![1, 2, 3, 4, 5]), &mut vec, 4).is_ok());
        assert_eq!(vec.len(), 4);
        assert_eq!(vec, vec![1, 2, 3, 4]);
        assert_eq!(vec.capacity(), 10);

        // An over-long vec is shrunk to `size` and overwritten.
        let mut vec = vec![9, 9, 9];
        assert!(read_into_vec(&mut Cursor::new(vec![1, 2, 3]), &mut vec, 2).is_ok());
        assert_eq!(vec.len(), 2);
        assert_eq!(vec, vec![1, 2]);

        Ok(())
    }
}