use crate::sim::WeakSimWorld;
use crate::sim::state::FileId;
use async_trait::async_trait;
use moonpool_core::StorageFile;
use std::cell::Cell;
use std::io::{self, SeekFrom};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use super::futures::{SetLenFuture, SyncFuture};
use super::sim_shutdown_error;
/// Progress of a seek request issued through `AsyncSeek`.
///
/// `start_seek` stores the resolved absolute target in `Seeking`; `poll_complete`
/// applies that target to the simulated file position and returns to `Idle`.
#[derive(Clone, Copy, Debug)]
enum SeekState {
    /// No seek is waiting to be applied.
    Idle,
    /// A seek to this absolute byte offset is waiting for `poll_complete`.
    Seeking(u64),
}
/// File handle whose I/O is routed through the simulation world.
///
/// In-flight operation state lives in `Cell`s, so it can be updated through a
/// shared reference.
#[derive(Debug)]
pub struct SimStorageFile {
    /// Weak handle to the simulation; when it can no longer be upgraded,
    /// operations fail with `sim_shutdown_error()`.
    sim: WeakSimWorld,
    /// Identifier of the underlying file in the simulation state.
    file_id: FileId,
    /// Scheduled read not yet delivered to the caller: `(op_seq, offset, len)`.
    pending_read: Cell<Option<(u64, u64, usize)>>,
    /// Scheduled write not yet acknowledged to the caller: `(op_seq, len)`.
    pending_write: Cell<Option<(u64, usize)>>,
    /// Seek recorded by `start_seek`, applied by `poll_complete`.
    seek_state: Cell<SeekState>,
}
impl SimStorageFile {
    /// Builds a handle for `file_id` with no read, write, or seek in flight.
    pub(crate) fn new(sim: WeakSimWorld, file_id: FileId) -> Self {
        let idle_seek = Cell::new(SeekState::Idle);
        Self {
            file_id,
            sim,
            pending_write: Cell::default(),
            pending_read: Cell::default(),
            seek_state: idle_seek,
        }
    }

    /// Returns the `FileId` this handle refers to.
    pub fn file_id(&self) -> FileId {
        self.file_id
    }
}
#[async_trait(?Send)]
impl StorageFile for SimStorageFile {
    /// Waits on a simulated sync of this file (same `SyncFuture` as `sync_data`).
    async fn sync_all(&self) -> io::Result<()> {
        let sync = SyncFuture::new(self.sim.clone(), self.file_id);
        sync.await
    }

    /// Waits on a simulated sync of this file (same `SyncFuture` as `sync_all`).
    async fn sync_data(&self) -> io::Result<()> {
        let sync = SyncFuture::new(self.sim.clone(), self.file_id);
        sync.await
    }

    /// Returns the current simulated file size.
    ///
    /// Fails with `sim_shutdown_error()` if the simulation is gone; lookup
    /// errors are re-wrapped as `io::Error::other` carrying their message.
    async fn size(&self) -> io::Result<u64> {
        let world = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        let lookup = world.file_size(self.file_id);
        lookup.map_err(|cause| io::Error::other(cause.to_string()))
    }

    /// Truncates or extends the file to `size` bytes via a simulated operation.
    async fn set_len(&self, size: u64) -> io::Result<()> {
        let resize = SetLenFuture::new(self.sim.clone(), self.file_id, size);
        resize.await
    }
}
impl AsyncRead for SimStorageFile {
    /// Reads through the simulation's storage scheduler in two phases.
    ///
    /// 1. With no read pending, schedules a read of up to `buf.remaining()` bytes
    ///    (capped at the bytes left in the file) starting at the current file
    ///    position, records `(op_seq, offset, len)`, registers the waker and
    ///    returns `Pending`.
    /// 2. Once the simulation reports that operation complete, copies the data
    ///    read from the recorded `offset` into `buf`, moves the file position to
    ///    `offset + bytes_read`, and clears the pending read.
    ///
    /// Returning `Ready(Ok(()))` without filling `buf` means EOF (position at or
    /// past the file size) or that `buf` had no room.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;

        if let Some((op_seq, offset, len)) = self.pending_read.get() {
            if sim.is_storage_op_complete(self.file_id, op_seq) {
                self.pending_read.set(None);
                // The caller may retry with a smaller buffer; never write past it.
                let bytes_to_read = buf.remaining().min(len);
                if bytes_to_read == 0 {
                    return Poll::Ready(Ok(()));
                }
                let mut temp_buf = vec![0u8; bytes_to_read];
                let bytes_read = sim.read_from_file(self.file_id, offset, &mut temp_buf)?;
                // Advance from the offset the read was scheduled at.
                let new_position = offset + bytes_read as u64;
                sim.set_file_position(self.file_id, new_position)?;
                buf.put_slice(&temp_buf[..bytes_read]);
                return Poll::Ready(Ok(()));
            }
            sim.register_storage_waker(self.file_id, op_seq, cx.waker().clone());
            return Poll::Pending;
        }

        let position = sim.file_position(self.file_id)?;
        let file_size = sim.file_size(self.file_id)?;
        if position >= file_size {
            // EOF: report success without filling `buf`.
            return Poll::Ready(Ok(()));
        }
        // Clamp before casting: a bare `as usize` truncates on 32-bit targets,
        // which can turn a large remainder into 0 and fake an EOF.
        let remaining_in_file = (file_size - position).min(usize::MAX as u64) as usize;
        let len = buf.remaining().min(remaining_in_file);
        if len == 0 {
            return Poll::Ready(Ok(()));
        }
        let op_seq = sim.schedule_read(self.file_id, position, len)?;
        self.pending_read.set(Some((op_seq, position, len)));
        sim.register_storage_waker(self.file_id, op_seq, cx.waker().clone());
        Poll::Pending
    }
}
impl AsyncWrite for SimStorageFile {
    /// Writes through the simulation's storage scheduler in two phases.
    ///
    /// The first poll schedules a write of the whole `buf` at the current file
    /// position and returns `Pending`. A later poll, once the simulation reports
    /// the operation complete, advances the file position by the scheduled length
    /// and returns that length.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        let file_id = self.file_id;

        if let Some((seq, scheduled_len)) = self.pending_write.get() {
            if !sim.is_storage_op_complete(file_id, seq) {
                sim.register_storage_waker(file_id, seq, cx.waker().clone());
                return Poll::Pending;
            }
            self.pending_write.set(None);
            // NOTE(review): unlike reads, this advances from the *current*
            // position rather than the offset the write was scheduled at. If the
            // position changes in between (e.g. a seek after a dropped write
            // future), the two can differ. Confirm this is intended.
            let advanced = sim.file_position(file_id)? + scheduled_len as u64;
            sim.set_file_position(file_id, advanced)?;
            return Poll::Ready(Ok(scheduled_len));
        }

        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let start = sim.file_position(file_id)?;
        let seq = sim.schedule_write(file_id, start, buf.to_vec())?;
        self.pending_write.set(Some((seq, buf.len())));
        sim.register_storage_waker(file_id, seq, cx.waker().clone());
        Poll::Pending
    }

    /// Always ready: this handle keeps no write buffer of its own, because
    /// `poll_write` hands the data straight to the simulation.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Always ready: there is nothing to tear down on this handle.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
impl AsyncSeek for SimStorageFile {
    /// Resolves `position` to an absolute byte offset and records it as the
    /// pending seek. `poll_complete` applies it to the simulated file position.
    ///
    /// # Errors
    ///
    /// - the simulation has shut down (`sim_shutdown_error()`);
    /// - looking up the file size or current position fails;
    /// - `InvalidInput` if the resolved offset would be negative or would
    ///   overflow `u64`, matching `std::io::Seek` semantics instead of silently
    ///   clamping.
    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        // `checked_add_signed` rejects both going below 0 and overflowing past
        // `u64::MAX`, and (unlike negating the offset) is sound for `i64::MIN`.
        let target = match position {
            SeekFrom::Start(pos) => Some(pos),
            SeekFrom::End(delta) => sim.file_size(self.file_id)?.checked_add_signed(delta),
            SeekFrom::Current(delta) => {
                sim.file_position(self.file_id)?.checked_add_signed(delta)
            }
        }
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        })?;
        self.seek_state.set(SeekState::Seeking(target));
        Ok(())
    }

    /// Applies a pending seek, if any, and returns the resulting file position.
    ///
    /// With no pending seek, returns the current simulated file position.
    fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let sim = self.sim.upgrade().map_err(|_| sim_shutdown_error())?;
        match self.seek_state.get() {
            SeekState::Idle => {
                let position = sim.file_position(self.file_id)?;
                Poll::Ready(Ok(position))
            }
            SeekState::Seeking(target) => {
                sim.set_file_position(self.file_id, target)?;
                self.seek_state.set(SeekState::Idle);
                Poll::Ready(Ok(target))
            }
        }
    }
}