use std::{
collections::{btree_map, BTreeMap},
io,
sync::{Arc, RwLock, RwLockWriteGuard},
};
use crate::segment::FileLike;
use super::Repo;
/// A value shared between handles and guarded by a reader-writer lock.
type SharedLock<T> = Arc<RwLock<T>>;

/// The growable byte buffer backing one in-memory segment.
type SharedBytes = SharedLock<Vec<u8>>;

/// A handle onto an in-memory segment.
///
/// Cloning a `Segment` copies the cursor (`pos`) but shares the underlying
/// byte buffer, because the buffer lives behind an [`Arc`].
#[derive(Debug, Default, Clone)]
pub struct Segment {
    /// Cursor of this handle, as a byte offset into `buf`.
    pos: u64,
    /// The segment's bytes, shared with every other handle to the same segment.
    buf: SharedBytes,
}
impl Segment {
    /// Number of bytes currently stored in the shared buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's lock is poisoned.
    pub fn len(&self) -> usize {
        let bytes = self.buf.read().unwrap();
        bytes.len()
    }

    /// Whether the shared buffer holds no bytes at all.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's lock is poisoned.
    pub fn is_empty(&self) -> bool {
        self.buf.read().unwrap().is_empty()
    }

    /// Acquires the write lock on the shared buffer and hands out the guard,
    /// giving direct mutable access to the raw bytes.
    ///
    /// The cursor of this handle is not touched.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's lock is poisoned.
    pub fn buf_mut(&self) -> RwLockWriteGuard<'_, Vec<u8>> {
        let guard = self.buf.write().unwrap();
        guard
    }
}
impl From<SharedBytes> for Segment {
fn from(buf: SharedBytes) -> Self {
Self { pos: 0, buf }
}
}
impl super::SegmentLen for Segment {
    /// Reports the current size of the shared buffer in bytes.
    ///
    /// This implementation never returns an error.
    fn segment_len(&mut self) -> io::Result<u64> {
        let bytes = self.len();
        Ok(bytes as u64)
    }
}
impl FileLike for Segment {
fn fsync(&mut self) -> io::Result<()> {
Ok(())
}
fn ftruncate(&mut self, _tx_offset: u64, size: u64) -> io::Result<()> {
let mut inner = self.buf.write().unwrap();
inner.resize(size as usize, 0);
Ok(())
}
}
impl io::Write for Segment {
    /// Appends all of `buf` to the end of the shared buffer and advances this
    /// handle's cursor by the number of bytes written.
    ///
    /// The bytes are always appended, regardless of where the cursor points.
    /// The whole slice is accepted in one call.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's lock is poisoned.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let written = buf.len();
        self.buf.write().unwrap().extend_from_slice(buf);
        self.pos += written as u64;
        Ok(written)
    }

    /// No-op: writes land in the shared buffer immediately.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
impl io::Read for Segment {
    /// Copies bytes from the shared buffer, starting at this handle's cursor,
    /// into `buf` and advances the cursor by the number of bytes copied.
    ///
    /// Returns `Ok(0)` when the cursor sits exactly at the end of the buffer.
    ///
    /// # Errors
    ///
    /// Returns raw OS error 9 (`EBADF`, "bad file descriptor", on Unix-like
    /// platforms) if the cursor lies beyond the end of the buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's lock is poisoned.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let inner = self.buf.read().unwrap();
        // Compare in `u64`: casting the cursor to `usize` first would truncate
        // it on targets with a narrower `usize`, letting an out-of-range
        // cursor wrap around and read from the wrong place.
        if self.pos > inner.len() as u64 {
            return Err(io::Error::from_raw_os_error(9));
        }
        // Lossless: the cursor is bounded by `inner.len()`, which is a `usize`.
        let start = self.pos as usize;
        let n = io::Read::read(&mut &inner[start..], buf)?;
        self.pos += n as u64;
        Ok(n)
    }
}
impl io::Seek for Segment {
    /// Moves this handle's cursor and returns the new absolute position.
    ///
    /// * `Start(n)` sets the cursor to `n` unconditionally.
    /// * `End(n)` is relative to the current length of the shared buffer.
    /// * `Current(n)` is relative to the present cursor.
    ///
    /// The target is not clamped to the buffer length, so the cursor may end
    /// up past the end of the data.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] if a relative seek would land
    /// before offset zero or overflow `u64`; the cursor is unchanged then.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let target = match pos {
            io::SeekFrom::Start(absolute) => Some(absolute),
            io::SeekFrom::End(delta) => (self.len() as u64).checked_add_signed(delta),
            io::SeekFrom::Current(delta) => self.pos.checked_add_signed(delta),
        };
        let new_pos = target.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        })?;
        self.pos = new_pos;
        Ok(new_pos)
    }
}
/// Tokio-based async trait implementations for the in-memory repo.
///
/// Compiled only when the `streaming` feature is enabled. Inside this module
/// `io` refers to `tokio::io`, not `std::io`.
#[cfg(feature = "streaming")]
mod async_impls {
    use super::*;

    use std::{
        io::{Seek as _, Write as _},
        pin::Pin,
        task::{Context, Poll},
    };

    use tokio::io::{self, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

    use crate::stream::{AsyncFsync, AsyncLen, AsyncRepo, IntoAsyncWriter};

    impl AsyncRepo for Memory {
        type AsyncSegmentWriter = io::BufWriter<Segment>;
        type AsyncSegmentReader = io::BufReader<Segment>;

        /// Opens the segment at `offset` and wraps the handle in a tokio
        /// `BufReader`. Fails if the segment does not exist.
        async fn open_segment_reader_async(&self, offset: u64) -> io::Result<Self::AsyncSegmentReader> {
            self.open_segment_writer(offset).map(io::BufReader::new)
        }
    }

    impl IntoAsyncWriter for Segment {
        type AsyncWriter = tokio::io::BufWriter<Self>;

        /// Wraps this handle in a tokio `BufWriter`.
        fn into_async_writer(self) -> Self::AsyncWriter {
            tokio::io::BufWriter::new(self)
        }
    }

    impl AsyncRead for Segment {
        /// Reads from the shared buffer at this handle's cursor into `buf`,
        /// advancing the cursor by however many bytes were added to `buf`.
        ///
        /// Returns raw OS error 9 (`EBADF` on Unix-like platforms) if the
        /// cursor lies beyond the end of the buffer.
        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
            let this = self.get_mut();
            let inner = this.buf.read().unwrap();
            // Compare in `u64`: casting the cursor to `usize` first would
            // truncate it on targets with a narrower `usize`, letting an
            // out-of-range cursor wrap around and read from the wrong place.
            if this.pos > inner.len() as u64 {
                return Poll::Ready(Err(io::Error::from_raw_os_error(9)));
            }
            // Lossless: the cursor is bounded by `inner.len()`.
            let pos = this.pos as usize;
            // Remember the fill level so the number of bytes read can be
            // derived after delegating to the slice reader.
            let filled = buf.filled().len();
            AsyncRead::poll_read(Pin::new(&mut &inner[pos..]), cx, buf).map_ok(|()| {
                this.pos += (buf.filled().len() - filled) as u64;
            })
        }
    }

    impl AsyncSeek for Segment {
        /// Performs the seek immediately using the synchronous `Seek` impl.
        fn start_seek(self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> {
            self.get_mut().seek(position).map(drop)
        }

        /// Reports the current cursor; the seek already happened in
        /// `start_seek`, so this is always ready.
        fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            Poll::Ready(self.get_mut().stream_position())
        }
    }

    impl AsyncWrite for Segment {
        /// Delegates to the synchronous `Write` impl and completes immediately.
        fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
            Poll::Ready(self.get_mut().write(buf))
        }

        /// No-op; always ready.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        /// No-op; always ready.
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncFsync for Segment {
        /// No-op: the data lives only in memory.
        async fn fsync(&self) {}
    }

    impl AsyncLen for Segment {
        /// Reports the current size of the shared buffer in bytes.
        async fn segment_len(&mut self) -> io::Result<u64> {
            Ok(self.len() as u64)
        }
    }
}
/// An in-memory segment store: a lock-protected, ordered map from segment
/// offset to that segment's shared byte buffer.
///
/// Clones share the same underlying map.
#[derive(Debug, Default, Clone)]
pub struct Memory(SharedLock<BTreeMap<u64, SharedBytes>>);
impl Memory {
pub fn new() -> Self {
Self::default()
}
}
impl Repo for Memory {
    type SegmentWriter = Segment;
    type SegmentReader = io::BufReader<Segment>;

    /// Creates the segment at `offset` and returns a handle to it.
    ///
    /// An existing but empty segment is reused rather than rejected.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::AlreadyExists`] if a non-empty segment is
    /// already registered at `offset`.
    fn create_segment(&self, offset: u64) -> io::Result<Self::SegmentWriter> {
        let mut segments = self.0.write().unwrap();
        let buf = match segments.entry(offset) {
            btree_map::Entry::Vacant(slot) => Arc::clone(slot.insert(Default::default())),
            btree_map::Entry::Occupied(slot) => {
                let existing = slot.get();
                if !existing.read().unwrap().is_empty() {
                    return Err(io::Error::new(
                        io::ErrorKind::AlreadyExists,
                        format!("segment {offset} already exists"),
                    ));
                }
                Arc::clone(existing)
            }
        };
        Ok(Segment::from(buf))
    }

    /// Returns a fresh handle (cursor at zero) onto the segment at `offset`.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::NotFound`] if there is no such segment.
    fn open_segment_writer(&self, offset: u64) -> io::Result<Self::SegmentWriter> {
        let segments = self.0.read().unwrap();
        let found = segments.get(&offset).map(|buf| Segment::from(Arc::clone(buf)));
        found.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("segment {offset} does not exist"),
            )
        })
    }

    /// Same lookup as [`Self::open_segment_writer`], with the handle wrapped
    /// in a [`io::BufReader`].
    fn open_segment_reader(&self, offset: u64) -> io::Result<Self::SegmentReader> {
        let segment = self.open_segment_writer(offset)?;
        Ok(io::BufReader::new(segment))
    }

    /// Drops the segment at `offset` from the map.
    ///
    /// Handles obtained earlier keep their own reference to the buffer.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::NotFound`] if there is no such segment.
    fn remove_segment(&self, offset: u64) -> io::Result<()> {
        match self.0.write().unwrap().remove(&offset) {
            Some(_) => Ok(()),
            None => Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("segment {offset} does not exist"),
            )),
        }
    }

    /// No-op: in-memory segments are never compressed.
    fn compress_segment(&self, _offset: u64) -> io::Result<()> {
        Ok(())
    }

    /// Lists the offsets of all registered segments in ascending order.
    fn existing_offsets(&self) -> io::Result<Vec<u64>> {
        let segments = self.0.read().unwrap();
        let offsets: Vec<u64> = segments.keys().copied().collect();
        Ok(offsets)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Read, Seek, Write};

    /// Exercises writing, then reading back after absolute, relative and
    /// end-relative seeks on a single `Segment` handle.
    #[test]
    fn segment_read_write_seek() {
        let mut seg = Segment::default();
        seg.write_all(b"alonso").unwrap();

        // Rewind and read everything back.
        seg.seek(io::SeekFrom::Start(0)).unwrap();
        let mut scratch = [0u8; 6];
        seg.read_exact(&mut scratch).unwrap();
        assert_eq!(&scratch, b"alonso");

        // Absolute seek into the middle.
        seg.seek(io::SeekFrom::Start(2)).unwrap();
        let got = seg.read(&mut scratch).unwrap();
        assert_eq!(got, 4);
        assert_eq!(&scratch[..4], b"onso");

        // Step back from the current position (now at the end).
        seg.seek(io::SeekFrom::Current(-4)).unwrap();
        let got = seg.read(&mut scratch).unwrap();
        assert_eq!(got, 4);
        assert_eq!(&scratch[..4], b"onso");

        // Seek relative to the end of the data.
        seg.seek(io::SeekFrom::End(-3)).unwrap();
        let got = seg.read(&mut scratch).unwrap();
        assert_eq!(got, 3);
        assert_eq!(&scratch[..3], b"nso");
    }
}