use std::pin::Pin;
use std::task::{Context, Poll};
use async_trait::async_trait;
use object_store::{path::Path, MultipartId, ObjectStore};
use pin_project::pin_project;
use snafu::{location, Location};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use lance_core::{Error, Result};
use crate::traits::Writer;
/// An [`AsyncWrite`] sink backed by a multipart upload on an [`ObjectStore`].
///
/// Besides forwarding bytes to the upload, it keeps a running count of how
/// many bytes have been accepted, which is exposed through `tell()`.
#[pin_project]
pub struct ObjectWriter {
    /// Writer half of the multipart upload obtained from `put_multipart`.
    #[pin]
    writer: Box<dyn AsyncWrite + Send + Unpin>,
    /// Total number of bytes successfully handed to `writer` so far.
    cursor: usize,
    /// Identifier of the multipart upload, as returned by `put_multipart`.
    pub multipart_id: MultipartId,
}
impl ObjectWriter {
pub async fn new(object_store: &dyn ObjectStore, path: &Path) -> Result<Self> {
let (multipart_id, writer) =
object_store
.put_multipart(path)
.await
.map_err(|e| Error::IO {
message: format!("failed to create object writer for {}: {}", path, e),
location: location!(),
})?;
Ok(Self {
writer,
multipart_id,
cursor: 0,
})
}
pub async fn shutdown(&mut self) -> Result<()> {
Ok(self.writer.as_mut().shutdown().await?)
}
}
#[async_trait]
impl Writer for ObjectWriter {
    /// Returns the current write position: the number of bytes that have been
    /// successfully written through this writer so far.
    async fn tell(&mut self) -> Result<usize> {
        let position = self.cursor;
        Ok(position)
    }
}
impl AsyncWrite for ObjectWriter {
    /// Forwards `buf` to the inner writer. When that write completes
    /// successfully, the cursor advances by the number of bytes it accepted.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let projected = self.project();
        let outcome = projected.writer.poll_write(cx, buf);
        if let Poll::Ready(Ok(written)) = &outcome {
            *projected.cursor += *written;
        }
        outcome
    }

    /// Delegates flushing to the inner writer unchanged.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let projected = self.project();
        projected.writer.poll_flush(cx)
    }

    /// Delegates shutdown to the inner writer unchanged.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let projected = self.project();
        projected.writer.poll_shutdown(cx)
    }
}
#[cfg(test)]
mod tests {
    use object_store::{memory::InMemory, path::Path};
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Writes three 256-byte chunks, checks that `tell()` tracks the running
    /// byte count after each one, then completes the upload and verifies the
    /// stored object holds exactly the bytes that were written.
    #[tokio::test]
    async fn test_write() {
        let store = InMemory::new();
        let path = Path::from("/foo");
        let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let buf = vec![0_u8; 256];
        for i in 1..=3 {
            // `write_all` rather than `write`: `AsyncWrite` allows short writes,
            // so asserting that one `write` consumed the whole buffer would
            // depend on the backend's buffering behaviour.
            object_writer.write_all(buf.as_slice()).await.unwrap();
            assert_eq!(object_writer.tell().await.unwrap(), 256 * i);
        }
        object_writer.shutdown().await.unwrap();

        // After shutdown, every byte written must be visible in the store.
        let stored = store.get(&path).await.unwrap().bytes().await.unwrap();
        assert_eq!(stored.len(), 256 * 3);
        assert!(stored.iter().all(|&b| b == 0));
    }
}