use std::sync::Arc;
use futures::Future;
use futures::FutureExt;
use futures::select;
use crate::raw::*;
use crate::*;
/// Backend contract that [`PositionWriter`] drives to persist buffers at explicit offsets.
///
/// All methods take `&self`: `PositionWriter` keeps the implementation in an
/// `Arc` and hands clones of that `Arc` to the tasks it submits to `ConcurrentTasks`.
pub trait PositionWrite: Send + Sync + Unpin + 'static {
/// Write `buf` at `offset`.
///
/// `PositionWriter` passes, as `offset`, the combined length of every buffer
/// it dispatched before this one.
/// NOTE(review): presumably the whole buffer must be persisted before `Ok(())`
/// is returned — confirm against implementations.
fn write_all_at(
&self,
offset: u64,
buf: Buffer,
) -> impl Future<Output = Result<()>> + MaybeSend;
/// Finish the write and return the resulting [`Metadata`].
///
/// `PositionWriter::close` passes its final offset as `size`, i.e. the
/// combined length of every buffer it dispatched.
fn close(&self, size: u64) -> impl Future<Output = Result<Metadata>> + MaybeSend;
/// Abort the write.
///
/// `PositionWriter::abort` calls this after clearing its pending tasks and
/// its cached buffer.
fn abort(&self) -> impl Future<Output = Result<()>> + MaybeSend;
}
/// Everything a single `write_all_at` task needs, bundled so it can be moved
/// into the task and handed back together with the task's result.
struct WriteInput<W: PositionWrite> {
/// Shared handle to the backend writer.
w: Arc<W>,
/// Executor whose `timeout()` is consulted inside the task.
executor: Executor,
/// Offset passed to `write_all_at`.
offset: u64,
/// Data passed to `write_all_at`.
bytes: Buffer,
}
/// Writer that forwards buffers to a [`PositionWrite`] backend at increasing offsets.
///
/// Each `write` call parks the incoming buffer in `cache` and submits the
/// previously cached buffer as a task; `close` waits for the tasks, writes the
/// last cached buffer directly and then closes the backend.
pub struct PositionWriter<W: PositionWrite> {
/// Backend writer, shared with the submitted tasks.
w: Arc<W>,
/// Executor cloned into every `WriteInput`.
executor: Executor,
/// Offset for the next buffer to dispatch; advanced by the length of each
/// dispatched buffer.
next_offset: u64,
/// Most recent buffer received by `write` that has not been dispatched yet.
cache: Option<Buffer>,
/// Task queue that runs the `write_all_at` calls.
tasks: ConcurrentTasks<WriteInput<W>, ()>,
}
#[allow(dead_code)]
impl<W: PositionWrite> PositionWriter<W> {
pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
let executor = info.executor();
Self {
w: Arc::new(inner),
executor: executor.clone(),
next_offset: 0,
cache: None,
tasks: ConcurrentTasks::new(executor, concurrent, 8192, |input| {
Box::pin(async move {
let fut = input.w.write_all_at(input.offset, input.bytes.clone());
match input.executor.timeout() {
None => {
let result = fut.await;
(input, result)
}
Some(timeout) => {
let result = select! {
result = fut.fuse() => {
result
}
_ = timeout.fuse() => {
Err(Error::new(
ErrorKind::Unexpected, "write position timeout")
.with_context("offset", input.offset.to_string())
.set_temporary())
}
};
(input, result)
}
}
})
}),
}
}
fn fill_cache(&mut self, bs: Buffer) -> usize {
let size = bs.len();
assert!(self.cache.is_none());
self.cache = Some(bs);
size
}
}
impl<W: PositionWrite> oio::Write for PositionWriter<W> {
    /// Accept `bs`, dispatching the previously cached buffer (if any) first.
    ///
    /// The very first buffer is only cached. On later calls the cached buffer
    /// is submitted to the task queue at `next_offset`; only once that
    /// submission succeeds is the offset advanced and `bs` cached in its place.
    /// If submission fails, the cached buffer and offset are left untouched.
    async fn write(&mut self, bs: Buffer) -> Result<()> {
        let Some(pending) = self.cache.clone() else {
            self.fill_cache(bs);
            return Ok(());
        };
        let pending_len = pending.len() as u64;

        let input = WriteInput {
            w: self.w.clone(),
            executor: self.executor.clone(),
            offset: self.next_offset,
            bytes: pending,
        };
        self.tasks.execute(input).await?;

        // The pending buffer has been handed over; make room for the new one.
        self.cache = None;
        self.next_offset += pending_len;
        self.fill_cache(bs);
        Ok(())
    }

    /// Wait for every submitted task, write the last cached buffer, then close
    /// the backend with the total size.
    async fn close(&mut self) -> Result<Metadata> {
        // Drain the task queue; any task error is returned to the caller.
        loop {
            if self.tasks.next().await.transpose()?.is_none() {
                break;
            }
        }

        // The cache is cleared only after the write succeeds, so a failed
        // attempt leaves the buffer in place.
        if let Some(tail) = self.cache.clone() {
            let tail_len = tail.len() as u64;
            self.w.write_all_at(self.next_offset, tail).await?;
            self.cache = None;
            self.next_offset += tail_len;
        }

        self.w.close(self.next_offset).await
    }

    /// Clear pending tasks and the cached buffer, then abort the backend.
    async fn abort(&mut self) -> Result<()> {
        self.cache = None;
        self.tasks.clear();
        self.w.abort().await
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::Mutex;
    // Imported explicitly so this module does not rely on a glob pulled in by
    // `use super::*` happening to re-export `Duration`.
    use std::time::Duration;

    use pretty_assertions::assert_eq;
    use rand::Rng;
    use rand::RngCore;
    use rand::thread_rng;
    use tokio::time::sleep;

    use super::*;
    use crate::raw::oio::Write;

    /// In-memory backend that records which byte positions have been written.
    struct TestWrite {
        /// Sum of the lengths of all successfully written buffers.
        length: u64,
        /// Every byte position covered by a successful write.
        bytes: HashSet<u64>,
    }

    impl TestWrite {
        /// Create an empty, shareable test backend.
        pub fn new() -> Arc<Mutex<Self>> {
            let v = Self {
                bytes: HashSet::new(),
                length: 0,
            };

            Arc::new(Mutex::new(v))
        }
    }

    impl PositionWrite for Arc<Mutex<TestWrite>> {
        /// Record a write after a short delay, failing with a temporary error
        /// one time in ten. Panics if the written range overlaps an earlier one.
        async fn write_all_at(&self, offset: u64, buf: Buffer) -> Result<()> {
            // Add an async sleep here to enforce some pending.
            sleep(Duration::from_millis(50)).await;

            // We will have 10% percent rate for write part to fail.
            if thread_rng().gen_bool(1.0 / 10.0) {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            // The guard is taken after the last await point, so it is never
            // held across an await.
            let mut test = self.lock().unwrap();
            let size = buf.len() as u64;
            test.length += size;

            let input = (offset..offset + size).collect::<HashSet<_>>();

            assert!(
                test.bytes.is_disjoint(&input),
                "input should not have overlap"
            );
            test.bytes.extend(input);

            Ok(())
        }

        async fn close(&self, _size: u64) -> Result<Metadata> {
            Ok(Metadata::default())
        }

        async fn abort(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Write 1000 random-sized buffers through a writer whose backend fails
    /// randomly, retrying every failed `write`/`close`, and check that each
    /// byte position ends up written exactly once.
    #[tokio::test]
    async fn test_position_writer_with_concurrent_errors() {
        let mut rng = thread_rng();

        let mut w = PositionWriter::new(Arc::default(), TestWrite::new(), 200);
        let mut total_size = 0u64;

        for _ in 0..1000 {
            let size = rng.gen_range(1..1024);
            total_size += size as u64;

            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            // Retry the same buffer until the writer accepts it.
            loop {
                match w.write(bs.clone().into()).await {
                    Ok(_) => break,
                    Err(e) => {
                        println!("write error: {e:?}");
                        continue;
                    }
                }
            }
        }

        // Retry close until it succeeds.
        loop {
            match w.close().await {
                Ok(n) => {
                    println!("close: {n:?}");
                    break;
                }
                Err(e) => {
                    println!("close error: {e:?}");
                    continue;
                }
            }
        }

        let actual_bytes = w.w.lock().unwrap().bytes.clone();
        let expected_bytes: HashSet<_> = (0..total_size).collect();
        assert_eq!(actual_bytes, expected_bytes);

        let actual_size = w.w.lock().unwrap().length;
        assert_eq!(actual_size, total_size);
    }
}