use std::cmp;
use std::io::{self, Read, Seek, SeekFrom, Write};
use std::ptr;
use block::{AlignedBytes, BlockSize};
use nvm::NonVolatileMemory;
use {ErrorKind, Result};
/// Buffering wrapper around a `NonVolatileMemory`.
///
/// Writes are accumulated in `write_buf` and are only written to `inner` when the
/// buffer is flushed. Reads are issued to `inner` in whole blocks through `read_buf`.
#[derive(Debug)]
pub struct JournalNvmBuffer<N: NonVolatileMemory> {
/// The wrapped NVM that all actual reads and writes are issued to.
inner: N,
/// Logical cursor of this buffer; moved by `seek`, `read` and `write`.
position: u64,
/// Bytes waiting to be written to `inner`, starting at `write_buf_offset`.
write_buf: AlignedBytes,
/// Offset in `inner` that corresponds to `write_buf[0]`.
write_buf_offset: u64,
/// Set whenever `write_buf` is modified and cleared after it has been flushed;
/// while `false`, flushing the write buffer does nothing.
maybe_dirty: bool,
/// Scratch buffer used by `read` for block-aligned reads from `inner`.
read_buf: AlignedBytes,
}
impl<N: NonVolatileMemory> JournalNvmBuffer<N> {
    /// Wraps `nvm` in a new buffer positioned at offset 0 with empty read/write buffers.
    ///
    /// Both internal buffers use the block size reported by `nvm`.
    pub fn new(nvm: N) -> Self {
        let block_size = nvm.block_size();
        JournalNvmBuffer {
            inner: nvm,
            position: 0,
            maybe_dirty: false,
            write_buf_offset: 0,
            write_buf: AlignedBytes::new(0, block_size),
            read_buf: AlignedBytes::new(0, block_size),
        }
    }

    /// Returns a reference to the wrapped NVM (test builds only).
    #[cfg(test)]
    pub fn nvm(&self) -> &N {
        &self.inner
    }

    /// Returns `true` if the range `[offset, offset + length)` overlaps bytes held in
    /// the write buffer that may not have been written to the inner NVM yet.
    fn is_dirty_area(&self, offset: u64, length: usize) -> bool {
        if !self.maybe_dirty || length == 0 || self.write_buf.is_empty() {
            return false;
        }
        // Half-open interval overlap test between the queried range and the range
        // covered by the write buffer.
        if self.write_buf_offset < offset {
            let buf_end = self.write_buf_offset + self.write_buf.len() as u64;
            offset < buf_end
        } else {
            let end = offset + length as u64;
            self.write_buf_offset < end
        }
    }

    /// Writes the write buffer back to the inner NVM if it may hold unwritten data.
    ///
    /// Afterwards only the last block of the buffer is kept (moved to the front and
    /// `write_buf_offset` advanced accordingly). A following write that lands inside
    /// that block can then be merged into the buffer without re-reading it.
    ///
    /// # Errors
    /// Propagates any seek or write error from the inner NVM.
    fn flush_write_buf(&mut self) -> Result<()> {
        if self.write_buf.is_empty() || !self.maybe_dirty {
            return Ok(());
        }
        track_io!(self.inner.seek(SeekFrom::Start(self.write_buf_offset)))?;
        // `write` may legally accept only part of the buffer. The leading bytes are
        // discarded below, so every byte must reach the NVM first; `write_all`
        // guarantees that.
        track_io!(self.inner.write_all(&self.write_buf))?;

        let block_size = self.block_size().as_u16() as usize;
        if self.write_buf.len() > block_size {
            let drop_len = self.write_buf.len() - block_size;
            // SAFETY: `drop_len + block_size == self.write_buf.len()`, so the source
            // range `[drop_len, len)` and the destination range `[0, block_size)` both
            // lie inside the buffer. `ptr::copy` allows the two ranges to overlap.
            unsafe {
                ptr::copy(
                    self.write_buf.as_ptr().add(drop_len),
                    self.write_buf.as_mut_ptr(),
                    block_size,
                );
            }
            self.write_buf.truncate(block_size);
            self.write_buf_offset += drop_len as u64;
        }
        self.maybe_dirty = false;
        Ok(())
    }

    /// Fails if writing `write_len` bytes at the current position would extend past
    /// `capacity()`.
    ///
    /// # Errors
    /// Returns `ErrorKind::InconsistentState` on overflow.
    fn check_overflow(&self, write_len: usize) -> Result<()> {
        let next_position = self.position() + write_len as u64;
        track_assert!(
            next_position <= self.capacity(),
            ErrorKind::InconsistentState,
            "self.position={}, write_len={}, self.len={}",
            self.position(),
            write_len,
            self.capacity()
        );
        Ok(())
    }
}
impl<N: NonVolatileMemory> NonVolatileMemory for JournalNvmBuffer<N> {
    /// Logical cursor of this buffer; this is tracked separately from the
    /// wrapped NVM's own position.
    fn position(&self) -> u64 {
        self.position
    }

    /// Capacity of the wrapped NVM.
    fn capacity(&self) -> u64 {
        self.inner.capacity()
    }

    /// Block size of the wrapped NVM.
    fn block_size(&self) -> BlockSize {
        self.inner.block_size()
    }

    /// Writes back any buffered data, then syncs the wrapped NVM.
    fn sync(&mut self) -> Result<()> {
        track!(self.flush_write_buf())?;
        self.inner.sync()
    }

    /// Splitting is not supported for this buffer; calling it panics.
    fn split(self, _position: u64) -> Result<(Self, Self)> {
        unreachable!()
    }
}
impl<N: NonVolatileMemory> Drop for JournalNvmBuffer<N> {
    /// Makes a best-effort final `sync` of buffered data.
    fn drop(&mut self) {
        // `drop` has no way to report failure, so the result is discarded on purpose.
        // Callers that need to observe sync errors must call `sync` explicitly.
        let _ignored = self.sync();
    }
}
impl<N: NonVolatileMemory> Seek for JournalNvmBuffer<N> {
    /// Moves the logical cursor and returns the new absolute offset.
    ///
    /// Only the cursor is updated here; the wrapped NVM is not touched.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        self.position = track!(self.convert_to_offset(pos))?;
        Ok(self.position)
    }
}
impl<N: NonVolatileMemory> Read for JournalNvmBuffer<N> {
    /// Reads from the current position into `buf` and returns the number of bytes copied.
    ///
    /// If the requested range overlaps possibly-unwritten bytes in the write buffer,
    /// the write buffer is flushed first so that the inner NVM holds the latest data.
    /// The inner NVM is then read in whole blocks covering the requested range.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.is_dirty_area(self.position, buf.len()) {
            track!(self.flush_write_buf())?;
        }

        let aligned_start = self.block_size().floor_align(self.position);
        let aligned_end = self
            .block_size()
            .ceil_align(self.position + buf.len() as u64);
        self.read_buf
            .aligned_resize((aligned_end - aligned_start) as usize);
        self.inner.seek(SeekFrom::Start(aligned_start))?;
        let inner_read_size = self.inner.read(&mut self.read_buf)?;

        // Offset of `self.position` inside `read_buf`.
        let start = (self.position - aligned_start) as usize;
        let end = cmp::min(inner_read_size, start + buf.len());
        // The inner read may return fewer bytes than requested. If it stopped before
        // `start`, nothing is copied, instead of underflowing `end - start`.
        let read_size = end.saturating_sub(start);
        buf[..read_size].copy_from_slice(&self.read_buf[start..start + read_size]);
        self.position += read_size as u64;
        Ok(read_size)
    }
}
impl<N: NonVolatileMemory> Write for JournalNvmBuffer<N> {
    /// Buffers `buf` at the current position and returns `buf.len()` on success.
    ///
    /// Data reaches the inner NVM only when the write buffer is flushed. That happens:
    /// - explicitly via `flush`/`sync`;
    /// - before a read of an overlapping range;
    /// - here, when the write does not start inside or directly after the buffered range.
    ///
    /// # Errors
    /// Fails if the write would extend past `capacity()`, or if flushing to or reading
    /// from the inner NVM fails.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        track!(self.check_overflow(buf.len()))?;

        let write_buf_start = self.write_buf_offset;
        let write_buf_end = write_buf_start + self.write_buf.len() as u64;
        if write_buf_start <= self.position && self.position <= write_buf_end {
            // The write starts inside (or right at the end of) the buffered range,
            // so it can be merged into the buffer without flushing.
            let start = (self.position - self.write_buf_offset) as usize;
            let end = start + buf.len();
            // Only ever grow the buffer. Shrinking it would discard buffered bytes
            // past `end` that have not been written to the inner NVM yet.
            if end > self.write_buf.len() {
                self.write_buf.aligned_resize(end);
            }
            self.write_buf[start..end].copy_from_slice(buf);
            self.position += buf.len() as u64;
            self.maybe_dirty = true;
            Ok(buf.len())
        } else {
            // No overlap: write back the current buffer, then start a new one at the
            // block containing `self.position`.
            track!(self.flush_write_buf())?;
            if self.block_size().is_aligned(self.position) {
                self.write_buf_offset = self.position;
                self.write_buf.aligned_resize(0);
            } else {
                // Load the existing contents of the partially overwritten block. This
                // preserves the bytes before `self.position` when it is written back.
                let size = self.block_size().as_u16();
                self.write_buf_offset = self.block_size().floor_align(self.position);
                self.write_buf.aligned_resize(size as usize);
                self.inner.seek(SeekFrom::Start(self.write_buf_offset))?;
                self.inner.read_exact(&mut self.write_buf)?;
            }
            // `self.position` now lies within the buffered range, so this call takes
            // the first branch above.
            self.write(buf)
        }
    }

    /// Writes buffered bytes back to the inner NVM.
    ///
    /// This does not call `sync` on the inner NVM; use `sync` for that.
    fn flush(&mut self) -> io::Result<()> {
        track!(self.flush_write_buf())?;
        Ok(())
    }
}
// Unit tests for `JournalNvmBuffer` backed by a 10 KiB in-memory NVM.
// The offsets used below (512, 513) assume the NVM's block size is 512 bytes;
// confirm against `MemoryNvm` if that changes.
#[cfg(test)]
mod tests {
use std::io::{Read, Seek, SeekFrom, Write};
use trackable::result::TestResult;
use super::*;
use nvm::MemoryNvm;
// Consecutive writes stay in the write buffer until an explicit `flush`.
#[test]
fn write_write_flush() -> TestResult {
let mut buffer = new_buffer();
track_io!(buffer.write_all(b"foo"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
track_io!(buffer.write_all(b"bar"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
assert_eq!(&buffer.nvm().as_bytes()[3..6], &[0; 3][..]);
track_io!(buffer.flush())?;
assert_eq!(&buffer.nvm().as_bytes()[0..6], b"foobar");
Ok(())
}
// Seeking to a position inside or directly after the buffered range keeps
// accumulating in the buffer; nothing reaches the NVM until `flush`.
#[test]
fn write_seek_write_flush() -> TestResult {
// Case 1: skip one byte forward within the same block.
let mut buffer = new_buffer();
track_io!(buffer.write_all(b"foo"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
track_io!(buffer.seek(SeekFrom::Current(1)))?;
track_io!(buffer.write_all(b"bar"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
assert_eq!(&buffer.nvm().as_bytes()[4..7], &[0; 3][..]);
track_io!(buffer.flush())?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], b"foo");
assert_eq!(&buffer.nvm().as_bytes()[4..7], b"bar");
// Case 2: seek to exactly the end of the buffered block (512), which extends
// the buffer instead of flushing it.
let mut buffer = new_buffer();
track_io!(buffer.write_all(b"foo"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
track_io!(buffer.seek(SeekFrom::Start(512)))?;
track_io!(buffer.write_all(b"bar"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
assert_eq!(&buffer.nvm().as_bytes()[512..515], &[0; 3][..]);
track_io!(buffer.flush())?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], b"foo");
assert_eq!(&buffer.nvm().as_bytes()[512..515], b"bar");
// Case 3: seek backwards so the second write overlaps the first.
let mut buffer = new_buffer();
track_io!(buffer.write_all(b"foo"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
track_io!(buffer.seek(SeekFrom::Current(-1)))?;
track_io!(buffer.write_all(b"bar"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
assert_eq!(&buffer.nvm().as_bytes()[2..5], &[0; 3][..]);
track_io!(buffer.flush())?;
assert_eq!(&buffer.nvm().as_bytes()[0..5], b"fobar");
Ok(())
}
// A write past the end of the buffered range (513 > 512) flushes the old buffer
// first; the new write itself remains buffered.
#[test]
fn write_seek_write() -> TestResult {
let mut buffer = new_buffer();
track_io!(buffer.write_all(b"foo"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
track_io!(buffer.seek(SeekFrom::Start(513)))?;
track_io!(buffer.write_all(b"bar"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], b"foo");
assert_eq!(&buffer.nvm().as_bytes()[513..516], &[0; 3][..]);
Ok(())
}
// A read overlapping the dirty buffered range forces a flush; a read outside
// that range does not.
#[test]
fn write_seek_read() -> TestResult {
let mut buffer = new_buffer();
track_io!(buffer.write_all(b"foo"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
track_io!(buffer.read_exact(&mut [0; 1][..]))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], b"foo");
let mut buffer = new_buffer();
track_io!(buffer.write_all(b"foo"))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
track_io!(buffer.seek(SeekFrom::Start(512)))?;
track_io!(buffer.read_exact(&mut [0; 1][..]))?;
assert_eq!(&buffer.nvm().as_bytes()[0..3], &[0; 3][..]);
Ok(())
}
// Overwriting one byte in the middle of an already-flushed block keeps the
// surrounding bytes intact.
#[test]
fn overwritten() -> TestResult {
let mut buffer = new_buffer();
track_io!(buffer.write_all(&[b'a'; 512]))?;
track_io!(buffer.flush())?;
assert_eq!(&buffer.nvm().as_bytes()[0..512], &[b'a'; 512][..]);
track_io!(buffer.seek(SeekFrom::Start(256)))?;
track_io!(buffer.write_all(&[b'b'; 1]))?;
track_io!(buffer.flush())?;
assert_eq!(&buffer.nvm().as_bytes()[0..256], &[b'a'; 256][..]);
assert_eq!(buffer.nvm().as_bytes()[256], b'b');
Ok(())
}
// Builds a buffer over a zero-filled 10 KiB in-memory NVM.
fn new_buffer() -> JournalNvmBuffer<MemoryNvm> {
let nvm = MemoryNvm::new(vec![0; 10 * 1024]);
JournalNvmBuffer::new(nvm)
}
}