use crate::internal::{consts, MiniAllocator, ObjType, SectorInit};
use std::io::{self, BufRead, Read, Seek, SeekFrom, Write};
use std::sync::{Arc, RwLock, Weak};
use crate::internal::stream_buffer::StreamBuffer;
/// A buffered handle on a single stream stored inside a compound file.
///
/// Reads and writes go through an in-memory `StreamBuffer` window. Modified
/// buffer contents are written back lazily, by a `Flusher` installed on first
/// modification. The handle keeps only a weak reference to the shared
/// allocator, so operations fail with an error once the owning `CompoundFile`
/// has been dropped.
pub struct Stream<F> {
    // Weak handle to the shared allocator; upgraded for each operation that
    // needs it.
    minialloc: Weak<RwLock<MiniAllocator<F>>>,
    // Id of this stream's directory entry.
    stream_id: u32,
    // Logical stream length in bytes, including buffered writes that have not
    // been flushed yet.
    total_len: u64,
    // In-memory window over part of the stream's data.
    buffer: StreamBuffer,
    // Absolute stream offset that corresponds to the start of `buffer`.
    buf_offset_from_start: u64,
    // `Some` while `buffer` holds unflushed modifications.
    flusher: Option<Box<dyn Flusher<F>>>,
}
impl<F> Stream<F> {
    /// Opens a handle on the stream whose directory entry is `stream_id`.
    ///
    /// The initial length is read from the directory entry. `max_buffer_size`
    /// is passed to `StreamBuffer::new`. Only a weak reference to `minialloc`
    /// is retained.
    pub(crate) fn new(
        minialloc: &Arc<RwLock<MiniAllocator<F>>>,
        stream_id: u32,
        max_buffer_size: usize,
    ) -> Stream<F> {
        let initial_len = minialloc
            .read()
            .unwrap()
            .dir_entry(stream_id)
            .stream_len;
        Stream {
            minialloc: Arc::downgrade(minialloc),
            stream_id,
            total_len: initial_len,
            buffer: StreamBuffer::new(max_buffer_size),
            buf_offset_from_start: 0,
            flusher: None,
        }
    }

    /// Upgrades the weak allocator handle.
    ///
    /// # Errors
    ///
    /// Fails if the owning compound file (and with it the allocator) has
    /// already been dropped.
    fn minialloc(&self) -> io::Result<Arc<RwLock<MiniAllocator<F>>>> {
        match self.minialloc.upgrade() {
            Some(allocator) => Ok(allocator),
            None => Err(io::Error::other("CompoundFile was dropped")),
        }
    }

    /// Returns the stream length in bytes, counting buffered writes that have
    /// not been flushed yet.
    pub fn len(&self) -> u64 {
        self.total_len
    }

    /// Returns `true` when the stream is zero bytes long.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Absolute read/write position, measured from the start of the stream.
    fn current_position(&self) -> u64 {
        let within_buffer = self.buffer.cursor() as u64;
        self.buf_offset_from_start + within_buffer
    }

    /// Persists pending buffered modifications, if there are any.
    ///
    /// The flusher is removed before it runs. If flushing fails, the pending
    /// changes are therefore not retried by later calls.
    fn flush_changes(&mut self) -> io::Result<()> {
        match self.flusher.take() {
            Some(pending) => pending.flush_changes(self),
            None => Ok(()),
        }
    }
}
impl<F: Read + Write + Seek> Stream<F> {
    /// Truncates or extends the stream to exactly `size` bytes.
    ///
    /// Pending buffered writes are flushed first. The stream is then resized
    /// through the allocator, the cursor is clamped to the new length, and the
    /// in-memory buffer is discarded. Does nothing if `size` already equals
    /// the current length.
    pub fn set_len(&mut self, size: u64) -> io::Result<()> {
        if size == self.total_len {
            return Ok(());
        }
        // Capture the cursor before the buffer (which holds part of it) is
        // discarded below.
        let clamped_position = size.min(self.current_position());
        self.flush_changes()?;
        {
            let allocator = self.minialloc()?;
            let mut guard = allocator.write().unwrap();
            resize_stream(&mut guard, self.stream_id, size)?;
        }
        self.total_len = size;
        self.buf_offset_from_start = clamped_position;
        self.buffer.clear();
        Ok(())
    }

    /// Records that the buffer now holds unsaved modifications.
    ///
    /// Installs a `FlushBuffer` flusher unless one is already pending.
    fn mark_modified(&mut self) {
        if self.flusher.is_some() {
            return;
        }
        self.flusher = Some(Box::new(FlushBuffer) as Box<dyn Flusher<F>>);
    }
}
impl<F: Read + Seek> BufRead for Stream<F> {
    /// Returns the not-yet-consumed bytes of the internal buffer.
    ///
    /// When the buffer is exhausted and the stream still has data past the
    /// cursor, pending changes are flushed first. The buffer window is then
    /// moved to the current position and refilled from the stream.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        let needs_refill = !self.buffer.has_remaining()
            && self.current_position() < self.total_len;
        if needs_refill {
            // Buffered modifications must be persisted before the buffer
            // contents are replaced.
            self.flush_changes()?;
            let window_start =
                self.buf_offset_from_start + self.buffer.cursor() as u64;
            self.buf_offset_from_start = window_start;
            let bytes_left = self.total_len - window_start;
            // Copied into locals so the closure does not borrow `self` while
            // `self.buffer` is mutably borrowed.
            let id = self.stream_id;
            let allocator = self.minialloc()?;
            self.buffer.refill_with(bytes_left, |dest| {
                read_data_from_stream(
                    &mut allocator.write().unwrap(),
                    id,
                    window_start,
                    dest,
                )
            })?;
        }
        Ok(self.buffer.remaining_slice())
    }

    /// Marks `amt` buffered bytes as consumed.
    fn consume(&mut self, amt: usize) {
        self.buffer.consume(amt);
    }
}
impl<F: Read + Seek> Read for Stream<F> {
    /// Copies as many buffered bytes as fit into `buf`, refilling the buffer
    /// first if it is exhausted. Returns the number of bytes copied; the
    /// result is `0` at end of stream or when `buf` is empty.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let available = self.fill_buf()?;
        let count = available.len().min(buf.len());
        buf[..count].copy_from_slice(&available[..count]);
        self.consume(count);
        Ok(count)
    }
}
impl<F: Read + Seek> Seek for Stream<F> {
fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
let new_pos: u64 =
match pos {
SeekFrom::Start(delta) => {
if delta > self.total_len {
invalid_input!(
"Cannot seek to {} bytes from start, because stream \
length is only {} bytes",
delta, self.total_len,
);
}
delta
}
SeekFrom::End(delta) => {
if delta > 0 {
invalid_input!(
"Cannot seek to {} bytes past the end of the stream",
delta,
);
} else {
let delta = (-delta) as u64;
if delta > self.total_len {
invalid_input!(
"Cannot seek to {} bytes before end, because \
stream length is only {} bytes",
delta,
self.total_len,
);
}
self.total_len - delta
}
}
SeekFrom::Current(delta) => {
let old_pos = self.current_position();
debug_assert!(old_pos <= self.total_len);
if delta < 0 {
let delta = (-delta) as u64;
if delta > old_pos {
invalid_input!(
"Cannot seek to {} bytes before current position, \
which is only {}",
delta, old_pos,
);
}
old_pos - delta
} else {
let delta = delta as u64;
let remaining = self.total_len - old_pos;
if delta > remaining {
invalid_input!(
"Cannot seek to {} bytes after current position, \
because there are only {} bytes remaining in the \
stream",
delta, remaining,
);
}
old_pos + delta
}
}
};
if new_pos < self.buf_offset_from_start
|| new_pos
> self.buf_offset_from_start + self.buffer.filled_len() as u64
{
self.flush_changes()?;
self.buf_offset_from_start = new_pos;
self.buffer.clear();
} else {
self.buffer.seek((new_pos - self.buf_offset_from_start) as usize);
}
Ok(new_pos)
}
}
impl<F: Read + Write + Seek> Write for Stream<F> {
    /// Copies as much of `buf` as fits into the internal buffer.
    ///
    /// When the buffer is full, pending changes are flushed and the buffer
    /// window is advanced to the current position. The copy is then retried
    /// once. Returns the number of bytes accepted, which may be fewer than
    /// `buf.len()`. Writing past the old end grows the stream.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let accepted = if let Some(count) = self.buffer.write_bytes(buf) {
            count
        } else {
            self.flush_changes()?;
            let cursor = self.buffer.cursor() as u64;
            self.buf_offset_from_start += cursor;
            self.buffer.clear();
            self.buffer.write_bytes(buf).unwrap_or(0)
        };
        if accepted == 0 {
            return Ok(0);
        }
        self.mark_modified();
        let buffered_end =
            self.buf_offset_from_start + self.buffer.filled_len() as u64;
        self.total_len = self.total_len.max(buffered_end);
        Ok(accepted)
    }

    /// Writes pending buffered changes back, then flushes the allocator.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_changes()?;
        let allocator = self.minialloc()?;
        allocator.write().unwrap().flush()?;
        Ok(())
    }
}
impl<F> Drop for Stream<F> {
    /// Makes a best-effort attempt to persist buffered modifications.
    ///
    /// `Drop` cannot report failures, so any error is discarded. Call
    /// `Write::flush` before dropping to observe it.
    fn drop(&mut self) {
        let _best_effort = self.flush_changes();
    }
}
/// Deferred action that writes a stream's pending buffered modifications back
/// through the allocator.
///
/// `Stream::mark_modified` installs one the first time the buffer is modified.
/// `Stream::flush_changes` takes it out of the stream and runs it.
trait Flusher<F> {
    /// Persists `stream`'s pending buffered changes.
    fn flush_changes(&self, stream: &mut Stream<F>) -> io::Result<()>;
}
/// Flusher that writes the filled region of the stream's buffer at the
/// buffer's absolute offset.
struct FlushBuffer;
impl<F: Read + Write + Seek> Flusher<F> for FlushBuffer {
    /// Writes the filled part of `stream`'s buffer at `buf_offset_from_start`.
    ///
    /// In debug builds it then checks that the directory entry's recorded
    /// length agrees with the length tracked by the stream.
    fn flush_changes(&self, stream: &mut Stream<F>) -> io::Result<()> {
        let allocator = stream.minialloc()?;
        {
            // The write lock is released before the read lock below is taken.
            let mut guard = allocator.write().unwrap();
            write_data_to_stream(
                &mut guard,
                stream.stream_id,
                stream.buf_offset_from_start,
                stream.buffer.filled_slice(),
            )?;
        }
        debug_assert_eq!(
            allocator.read().unwrap().dir_entry(stream.stream_id).stream_len,
            stream.total_len
        );
        Ok(())
    }
}
/// Reads stream data starting at `buf_offset_from_start` into the front of
/// `buf`.
///
/// Reads `min(buf.len(), stream_len - buf_offset_from_start)` bytes, or none
/// if the offset is at or past the end, and returns that count. Streams
/// shorter than `consts::MINI_STREAM_CUTOFF` are read from their mini chain;
/// longer streams are read from their regular chain.
fn read_data_from_stream<F: Read + Seek>(
    minialloc: &mut MiniAllocator<F>,
    stream_id: u32,
    buf_offset_from_start: u64,
    buf: &mut [u8],
) -> io::Result<usize> {
    let (start_sector, stream_len) = {
        let entry = minialloc.dir_entry(stream_id);
        debug_assert_eq!(entry.obj_type, ObjType::Stream);
        (entry.start_sector, entry.stream_len)
    };
    // Bytes available past the offset, capped by the destination size; the
    // result is <= buf.len(), so the cast back to usize is lossless.
    let to_read = stream_len
        .saturating_sub(buf_offset_from_start)
        .min(buf.len() as u64) as usize;
    if to_read == 0 {
        return Ok(0);
    }
    let dest = &mut buf[..to_read];
    if stream_len < consts::MINI_STREAM_CUTOFF as u64 {
        let mut mini = minialloc.open_mini_chain(start_sector)?;
        mini.seek(SeekFrom::Start(buf_offset_from_start))?;
        mini.read_exact(dest)?;
    } else {
        let mut regular =
            minialloc.open_chain(start_sector, SectorInit::Zero)?;
        regular.seek(SeekFrom::Start(buf_offset_from_start))?;
        regular.read_exact(dest)?;
    }
    Ok(to_read)
}
/// Writes `buf` into stream `stream_id` starting at byte offset
/// `buf_offset_from_start`.
///
/// Afterwards the stream's directory entry records the (possibly new) start
/// sector and length. The stream may grow but never shrinks. Streams shorter
/// than `consts::MINI_STREAM_CUTOFF` live in a mini chain; longer ones live in
/// a regular chain. If the write pushes a mini-chain stream to the cutoff or
/// beyond, its data is moved into a newly allocated regular chain.
///
/// The offset must not lie past the current end of the stream. This is
/// checked only by a `debug_assert!`.
fn write_data_to_stream<F: Read + Write + Seek>(
    minialloc: &mut MiniAllocator<F>,
    stream_id: u32,
    buf_offset_from_start: u64,
    buf: &[u8],
) -> io::Result<()> {
    let (old_start_sector, old_stream_len) = {
        let dir_entry = minialloc.dir_entry(stream_id);
        debug_assert_eq!(dir_entry.obj_type, ObjType::Stream);
        (dir_entry.start_sector, dir_entry.stream_len)
    };
    debug_assert!(buf_offset_from_start <= old_stream_len);
    let new_stream_len =
        old_stream_len.max(buf_offset_from_start + buf.len() as u64);
    let new_start_sector = if old_start_sector == consts::END_OF_CHAIN {
        // Case 1: no chain is allocated yet. The stream must be empty, and
        // the write must begin at offset 0.
        debug_assert_eq!(old_stream_len, 0);
        debug_assert_eq!(buf_offset_from_start, 0);
        if new_stream_len < consts::MINI_STREAM_CUTOFF as u64 {
            // Case 1a: small enough for a new mini chain.
            let mut chain = minialloc.open_mini_chain(consts::END_OF_CHAIN)?;
            chain.write_all(buf)?;
            chain.start_sector_id()
        } else {
            // Case 1b: too large for the mini stream; allocate a new regular
            // chain.
            let mut chain = minialloc
                .open_chain(consts::END_OF_CHAIN, SectorInit::Zero)?;
            chain.write_all(buf)?;
            chain.start_sector_id()
        }
    } else if old_stream_len < consts::MINI_STREAM_CUTOFF as u64 {
        // Case 2: the stream currently lives in a mini chain.
        if new_stream_len < consts::MINI_STREAM_CUTOFF as u64 {
            // Case 2a: it still fits after the write, so write in place.
            let mut chain = minialloc.open_mini_chain(old_start_sector)?;
            chain.seek(SeekFrom::Start(buf_offset_from_start))?;
            chain.write_all(buf)?;
            debug_assert_eq!(chain.start_sector_id(), old_start_sector);
            old_start_sector
        } else {
            // Case 2b: the write crosses the cutoff, so migrate to a regular
            // chain. Only the bytes before the write offset need copying.
            // The old length was below the cutoff and the new one is not,
            // so `buf` extends past the old end and replaces everything
            // after the offset.
            debug_assert!(
                buf_offset_from_start < consts::MINI_STREAM_CUTOFF as u64
            );
            let mut tmp = vec![0u8; buf_offset_from_start as usize];
            let mut chain = minialloc.open_mini_chain(old_start_sector)?;
            chain.read_exact(&mut tmp)?;
            // Release the old mini chain before allocating the replacement.
            chain.free()?;
            let mut chain = minialloc
                .open_chain(consts::END_OF_CHAIN, SectorInit::Zero)?;
            chain.write_all(&tmp)?;
            chain.write_all(buf)?;
            chain.start_sector_id()
        }
    } else {
        // Case 3: the stream already lives in a regular chain. Writes never
        // shrink it, so it stays there; write in place.
        debug_assert!(new_stream_len >= consts::MINI_STREAM_CUTOFF as u64);
        let mut chain =
            minialloc.open_chain(old_start_sector, SectorInit::Zero)?;
        chain.seek(SeekFrom::Start(buf_offset_from_start))?;
        chain.write_all(buf)?;
        debug_assert_eq!(chain.start_sector_id(), old_start_sector);
        old_start_sector
    };
    // Record where the data now lives and how long the stream is.
    minialloc.with_dir_entry_mut(stream_id, |dir_entry| {
        dir_entry.start_sector = new_start_sector;
        dir_entry.stream_len = new_stream_len;
    })
}
/// Sets the length of stream `stream_id` to `new_stream_len` bytes.
///
/// Afterwards the stream's directory entry records the resulting start sector
/// and length. What happens depends on the old and new lengths relative to
/// `consts::MINI_STREAM_CUTOFF`:
///
/// - the chain is resized in place;
/// - the data is moved between a mini chain and a regular chain; or
/// - the chain is freed when the new length is 0.
///
/// Moving into a mini chain keeps only the first `new_stream_len` bytes.
/// Moving into a regular chain copies all existing bytes and then extends the
/// chain with `set_len`.
fn resize_stream<F: Read + Write + Seek>(
    minialloc: &mut MiniAllocator<F>,
    stream_id: u32,
    new_stream_len: u64,
) -> io::Result<()> {
    let (old_start_sector, old_stream_len) = {
        let dir_entry = minialloc.dir_entry(stream_id);
        debug_assert_eq!(dir_entry.obj_type, ObjType::Stream);
        (dir_entry.start_sector, dir_entry.stream_len)
    };
    let new_start_sector = if old_start_sector == consts::END_OF_CHAIN {
        // Case 1: no chain is allocated yet (the stream is empty). Allocate
        // one of the requested length.
        debug_assert_eq!(old_stream_len, 0);
        if new_stream_len < consts::MINI_STREAM_CUTOFF as u64 {
            // Case 1a: small enough for a new mini chain.
            let mut chain = minialloc.open_mini_chain(consts::END_OF_CHAIN)?;
            chain.set_len(new_stream_len)?;
            chain.start_sector_id()
        } else {
            // Case 1b: needs a new regular chain. New sectors are requested
            // with `SectorInit::Zero` (presumably zero-filled; TODO confirm).
            let mut chain = minialloc
                .open_chain(consts::END_OF_CHAIN, SectorInit::Zero)?;
            chain.set_len(new_stream_len)?;
            chain.start_sector_id()
        }
    } else if old_stream_len < consts::MINI_STREAM_CUTOFF as u64 {
        // Case 2: the stream currently lives in a mini chain.
        if new_stream_len == 0 {
            // Case 2a: truncate to nothing; free the mini chain.
            minialloc.free_mini_chain(old_start_sector)?;
            consts::END_OF_CHAIN
        } else if new_stream_len < consts::MINI_STREAM_CUTOFF as u64 {
            // Case 2b: still fits in the mini stream; resize in place.
            let mut chain = minialloc.open_mini_chain(old_start_sector)?;
            chain.set_len(new_stream_len)?;
            debug_assert_eq!(chain.start_sector_id(), old_start_sector);
            old_start_sector
        } else {
            // Case 2c: grows past the cutoff. Copy all existing bytes into
            // a new regular chain, then extend it to the new length.
            let mut tmp = vec![0u8; old_stream_len as usize];
            let mut chain = minialloc.open_mini_chain(old_start_sector)?;
            chain.read_exact(&mut tmp)?;
            chain.free()?;
            let mut chain = minialloc
                .open_chain(consts::END_OF_CHAIN, SectorInit::Zero)?;
            chain.write_all(&tmp)?;
            chain.set_len(new_stream_len)?;
            chain.start_sector_id()
        }
    } else {
        // Case 3: the stream currently lives in a regular chain.
        if new_stream_len == 0 {
            // Case 3a: truncate to nothing; free the regular chain.
            minialloc.free_chain(old_start_sector)?;
            consts::END_OF_CHAIN
        } else if new_stream_len < consts::MINI_STREAM_CUTOFF as u64 {
            // Case 3b: shrinks below the cutoff. Copy the surviving prefix
            // into a new mini chain and free the regular chain.
            debug_assert!(new_stream_len < old_stream_len);
            let mut tmp = vec![0u8; new_stream_len as usize];
            let mut chain =
                minialloc.open_chain(old_start_sector, SectorInit::Zero)?;
            chain.read_exact(&mut tmp)?;
            chain.free()?;
            let mut chain = minialloc.open_mini_chain(consts::END_OF_CHAIN)?;
            chain.write_all(&tmp)?;
            chain.start_sector_id()
        } else {
            // Case 3c: stays at or above the cutoff; resize in place.
            let mut chain =
                minialloc.open_chain(old_start_sector, SectorInit::Zero)?;
            chain.set_len(new_stream_len)?;
            debug_assert_eq!(chain.start_sector_id(), old_start_sector);
            old_start_sector
        }
    };
    // Record where the data now lives and the new length.
    minialloc.with_dir_entry_mut(stream_id, |dir_entry| {
        dir_entry.start_sector = new_start_sector;
        dir_entry.stream_len = new_stream_len;
    })
}