use std::ops::Range;
use async_trait::async_trait;
use bytes::Bytes;
use deepsize::DeepSizeOf;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use object_store::path::Path;
use prost::Message;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use lance_core::Result;
use crate::object_writer::WriteResult;
/// Pairs a Rust type with the protobuf message type that represents it.
///
/// [`WriteExt::write_struct`] uses this pairing: it converts `&T` into
/// `T::Proto` (via `From`) and writes the resulting message.
pub trait ProtoStruct {
    /// Protobuf message type corresponding to `Self`.
    type Proto: Message;
}
/// Boxed, `Send`, `'static` stream of byte chunks; each chunk may instead be an
/// `object_store` error.
pub type ByteStream = BoxStream<'static, object_store::Result<Bytes>>;
/// Asynchronous byte sink that can report its position and be finalized.
///
/// Besides the [`AsyncWrite`] byte-writing interface, implementors expose
/// their current position and a finalization step.
#[async_trait]
pub trait Writer: AsyncWrite + Unpin + Send {
    /// Returns the writer's current position.
    ///
    /// [`WriteExt::write_protobuf`] reports this value (read before writing)
    /// as the offset of the record it writes.
    async fn tell(&mut self) -> Result<usize>;

    /// Finishes writing and returns the implementation's [`WriteResult`].
    async fn shutdown(&mut self) -> Result<WriteResult>;
}
#[async_trait]
impl Writer for Box<dyn Writer> {
async fn tell(&mut self) -> Result<usize> {
self.as_mut().tell().await
}
async fn shutdown(&mut self) -> Result<WriteResult> {
self.as_mut().shutdown().await
}
}
/// Higher-level write helpers layered on top of [`Writer`].
///
/// A blanket implementation covers every `W: Writer + ?Sized`.
#[async_trait]
pub trait WriteExt {
    /// Writes `msg` as a length-prefixed protobuf record and returns the
    /// writer position at which the record begins.
    async fn write_protobuf(&mut self, msg: &impl Message) -> Result<usize>;

    /// Converts `obj` into its protobuf form (`T::Proto`) and writes it with
    /// [`WriteExt::write_protobuf`], returning that method's offset.
    async fn write_struct<
        'b,
        M: Message + From<&'b T>,
        T: ProtoStruct<Proto = M> + Send + Sync + 'b,
    >(
        &mut self,
        obj: &'b T,
    ) -> Result<usize> {
        let proto: M = obj.into();
        self.write_protobuf(&proto).await
    }

    /// Writes `pos` as a little-endian `i64`, then `major_version` and
    /// `minor_version` as little-endian `i16`s, then the raw `magic` bytes.
    async fn write_magics(
        &mut self,
        pos: usize,
        major_version: i16,
        minor_version: i16,
        magic: &[u8],
    ) -> Result<()>;

    /// Copies every chunk from [`Reader::get_stream`] into this writer and
    /// returns the number of bytes copied.
    async fn copy_from_reader(&mut self, reader: &dyn Reader) -> Result<usize>;

    /// Copies every chunk from [`Reader::get_range_stream`] for `range` into
    /// this writer and returns the number of bytes copied.
    async fn copy_range_from_reader(
        &mut self,
        reader: &dyn Reader,
        range: Range<usize>,
    ) -> Result<usize>;
}
#[async_trait]
impl<W: Writer + ?Sized> WriteExt for W {
async fn write_protobuf(&mut self, msg: &impl Message) -> Result<usize> {
let offset = self.tell().await?;
let len = msg.encoded_len();
self.write_u32_le(len as u32).await?;
self.write_all(&msg.encode_to_vec()).await?;
Ok(offset)
}
async fn write_magics(
&mut self,
pos: usize,
major_version: i16,
minor_version: i16,
magic: &[u8],
) -> Result<()> {
self.write_i64_le(pos as i64).await?;
self.write_i16_le(major_version).await?;
self.write_i16_le(minor_version).await?;
self.write_all(magic).await?;
Ok(())
}
async fn copy_from_reader(&mut self, reader: &dyn Reader) -> Result<usize> {
let mut stream = reader.get_stream().await?;
let mut copied = 0usize;
while let Some(chunk) = stream.next().await {
let bytes = chunk?;
copied += bytes.len();
self.write_all(&bytes).await?;
}
Ok(copied)
}
async fn copy_range_from_reader(
&mut self,
reader: &dyn Reader,
range: Range<usize>,
) -> Result<usize> {
let mut stream = reader.get_range_stream(range).await?;
let mut copied = 0usize;
while let Some(chunk) = stream.next().await {
let bytes = chunk?;
copied += bytes.len();
self.write_all(&bytes).await?;
}
Ok(copied)
}
}
/// Asynchronous, random-access reader over the object at [`Reader::path`].
///
/// Implementors supply the primitive fetch operations. [`Reader::get_stream`]
/// and [`Reader::get_range_stream`] have default bodies that wrap a single
/// fetch in a one-chunk [`ByteStream`]; implementors may override them.
pub trait Reader: std::fmt::Debug + Send + Sync + DeepSizeOf {
    /// Location of the underlying object.
    fn path(&self) -> &Path;

    /// Block size reported by this reader.
    ///
    /// NOTE(review): presumably the preferred I/O granularity in bytes; the
    /// unit is not established here — confirm with implementors.
    fn block_size(&self) -> usize;

    /// Degree of I/O parallelism reported by this reader.
    ///
    /// NOTE(review): presumably a hint for how many concurrent requests to
    /// issue — confirm with implementors.
    fn io_parallelism(&self) -> usize;

    /// Resolves to the size of the object.
    fn size(&self) -> BoxFuture<'_, object_store::Result<usize>>;

    /// Fetches the bytes in `range`.
    ///
    /// The returned future is `'static`, so it does not borrow `self`.
    fn get_range(&self, range: Range<usize>) -> BoxFuture<'static, object_store::Result<Bytes>>;

    /// Fetches the object's contents as a single buffer.
    fn get_all(&self) -> BoxFuture<'_, object_store::Result<Bytes>>;

    /// Streams the object's contents.
    ///
    /// The default awaits [`Reader::get_all`] and yields its result as one chunk.
    fn get_stream(&self) -> BoxFuture<'_, object_store::Result<ByteStream>> {
        Box::pin(async move {
            let whole = self.get_all().await?;
            let chunks: ByteStream = futures::stream::once(async move { Ok(whole) }).boxed();
            Ok(chunks)
        })
    }

    /// Streams the bytes in `range`.
    ///
    /// The default awaits [`Reader::get_range`] and yields its result as one chunk.
    fn get_range_stream(
        &self,
        range: Range<usize>,
    ) -> BoxFuture<'_, object_store::Result<ByteStream>> {
        Box::pin(async move {
            let slice = self.get_range(range).await?;
            let chunks: ByteStream = futures::stream::once(async move { Ok(slice) }).boxed();
            Ok(chunks)
        })
    }
}