use std::io;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use std::task::Poll;
use crate::object_store::ObjectStore as LanceObjectStore;
use async_trait::async_trait;
use bytes::Bytes;
use futures::FutureExt;
use futures::future::BoxFuture;
use object_store::MultipartUpload;
use object_store::{Error as OSError, ObjectStore, Result as OSResult, path::Path};
use rand::Rng;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::task::JoinSet;
use lance_core::{Error, Result};
use tracing::Instrument;
use crate::traits::Writer;
use crate::utils::tracking_store::IOTracker;
use tokio::runtime::Handle;
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;
/// Maximum number of part uploads allowed in flight at once.
///
/// Read once from the `LANCE_UPLOAD_CONCURRENCY` environment variable and
/// cached for the lifetime of the process; defaults to 10 when the variable
/// is unset or cannot be parsed as a `usize`.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        match std::env::var("LANCE_UPLOAD_CONCURRENCY") {
            Ok(raw) => raw.parse::<usize>().unwrap_or(10),
            Err(_) => 10,
        }
    })
}
/// How many "connection reset by peer" failures to retry before giving up.
///
/// Read once from the `LANCE_CONN_RESET_RETRIES` environment variable and
/// cached; defaults to 20 when unset or unparseable.
fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    *MAX_CONN_RESET_RETRIES.get_or_init(|| {
        match std::env::var("LANCE_CONN_RESET_RETRIES") {
            Ok(raw) => raw.parse::<u16>().unwrap_or(20),
            Err(_) => 20,
        }
    })
}
/// Size of the first upload part / staging buffer.
///
/// Read once from the `LANCE_INITIAL_UPLOAD_SIZE` environment variable and
/// cached. Falls back to [INITIAL_UPLOAD_STEP] (5 MiB) when unset or
/// unparseable.
///
/// # Panics
/// Panics if the configured value is below 5MB or above 5GB.
fn initial_upload_size() -> usize {
    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
        let configured = std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok());
        match configured {
            Some(size) => {
                // Reject sizes the store would refuse as multipart parts.
                if size < INITIAL_UPLOAD_STEP {
                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
                }
                if size > 1024 * 1024 * 1024 * 5 {
                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
                }
                size
            }
            None => INITIAL_UPLOAD_STEP,
        }
    })
}
/// [AsyncWrite] implementation that uploads data to an object store.
///
/// Bytes are staged in an in-memory buffer. Objects small enough to never
/// fill the buffer are written with a single `put` at shutdown; larger
/// objects use a multipart upload driven by the [UploadState] state machine.
pub struct ObjectWriter {
    // Current stage of the upload; see [UploadState].
    state: UploadState,
    // Destination path in the store.
    path: Arc<Path>,
    // Total bytes accepted so far; becomes the final WriteResult::size.
    cursor: usize,
    // "connection reset by peer" retries consumed over the writer's lifetime.
    connection_resets: u16,
    // Staging buffer; submitted as one upload part when it reaches capacity.
    buffer: Vec<u8>,
    // When true, every part uses initial_upload_size(); otherwise the part
    // size grows with the part index (see next_part_buffer).
    use_constant_size_upload_parts: bool,
}
/// Outcome of a completed write: total size plus the store-assigned etag.
#[derive(Debug, Clone, Default)]
pub struct WriteResult {
    // Total number of bytes written to the object.
    pub size: usize,
    // Entity tag returned by the store, when available.
    pub e_tag: Option<String>,
}
/// State machine for [ObjectWriter].
enum UploadState {
    /// Nothing sent yet; holds the store so an upload can be started lazily.
    Started(Arc<dyn ObjectStore>),
    /// Waiting for the store to create a multipart upload.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// Multipart upload running.
    InProgress {
        // Index of the next part to submit.
        part_idx: u16,
        upload: Box<dyn MultipartUpload>,
        // In-flight part-upload tasks.
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// Writing the whole (small) object with a single `put`.
    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
    /// Finalizing the multipart upload.
    Completing(BoxFuture<'static, OSResult<WriteResult>>),
    /// Upload finished; the result is available.
    Done(WriteResult),
}
impl UploadState {
    /// Transition `Started` -> `PuttingSingle`, uploading the whole `buffer`
    /// with a single `put` request.
    ///
    /// `mem::replace` temporarily parks the state as `Done(default)` so the
    /// store can be moved out of the old value.
    ///
    /// # Panics
    /// Panics if the current state is not `Started`.
    fn started_to_putting_single(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
        *self = match this {
            Self::Started(store) => {
                let fut = async move {
                    let size = buffer.len();
                    let res = store.put(&path, buffer.into()).await?;
                    Ok(WriteResult {
                        size,
                        e_tag: res.e_tag,
                    })
                };
                Self::PuttingSingle(Box::pin(fut))
            }
            _ => unreachable!(),
        }
    }

    /// Transition `InProgress` -> `Completing` once every part upload has
    /// finished, issuing the multipart-complete request.
    ///
    /// The size is left as 0 here; poll_tasks overwrites it with the writer's
    /// cursor when the completion future resolves.
    ///
    /// # Panics
    /// Panics if the current state is not `InProgress`.
    fn in_progress_to_completing(&mut self) {
        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
        *self = match this {
            Self::InProgress {
                mut upload,
                futures,
                ..
            } => {
                // Caller must have drained all in-flight part uploads.
                debug_assert!(futures.is_empty());
                let fut = async move {
                    let res = upload.complete().await?;
                    Ok(WriteResult {
                        size: 0,
                        e_tag: res.e_tag,
                    })
                };
                Self::Completing(Box::pin(fut))
            }
            _ => unreachable!(),
        };
    }
}
impl ObjectWriter {
    /// Create a writer targeting `path` in `object_store`.
    ///
    /// No request is issued yet: bytes are staged in an in-memory buffer and
    /// the upload is started lazily by `poll_write` / `poll_shutdown`.
    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
        Ok(Self {
            state: UploadState::Started(object_store.inner.clone()),
            cursor: 0,
            path: Arc::new(path.clone()),
            connection_resets: 0,
            buffer: Vec::with_capacity(initial_upload_size()),
            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
        })
    }

    /// Swap the filled staging buffer for a fresh one and return the old
    /// contents as the bytes of part `part_idx`.
    ///
    /// When `constant_upload_size` is false the new buffer's capacity grows
    /// by one INITIAL_UPLOAD_STEP every 100 parts — presumably to keep the
    /// total part count below store limits for very large objects (TODO
    /// confirm the intended limit).
    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
        let new_capacity = if constant_upload_size {
            initial_upload_size()
        } else {
            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
        };
        let new_buffer = Vec::with_capacity(new_capacity);
        let part = std::mem::replace(buffer, new_buffer);
        Bytes::from(part)
    }

    /// Build a future that uploads one part, optionally after sleeping
    /// (used as retry backoff).
    ///
    /// `buffer` is cloned (a cheap `Bytes` refcount bump, not a copy) so the
    /// same data can be handed back in `UploadPutError` for retries.
    fn put_part(
        upload: &mut dyn MultipartUpload,
        buffer: Bytes,
        part_idx: u16,
        sleep: Option<std::time::Duration>,
    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
        log::debug!(
            "MultipartUpload submitting part with {} bytes",
            buffer.len()
        );
        let fut = upload.put_part(buffer.clone().into());
        Box::pin(async move {
            if let Some(sleep) = sleep {
                tokio::time::sleep(sleep).await;
            }
            fut.await.map_err(|source| UploadPutError {
                part_idx,
                buffer,
                source,
            })?;
            Ok(())
        })
    }

    /// Drive all pending state-machine futures as far as possible without
    /// blocking:
    ///
    /// - finish multipart-upload creation and submit part 0,
    /// - reap completed part uploads, retrying "connection reset by peer"
    ///   failures with a randomized 2-8s backoff (up to
    ///   `max_conn_reset_retries()` times total),
    /// - resolve the single-put / completion futures into `Done`, patching
    ///   the final size from `cursor`.
    fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::result::Result<(), io::Error> {
        let mut_self = &mut *self;
        loop {
            match &mut mut_self.state {
                // Nothing in flight.
                UploadState::Started(_) | UploadState::Done(_) => break,
                UploadState::CreatingUpload(fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(mut upload)) => {
                        // Upload created: immediately submit the full buffer
                        // that triggered the transition as part 0.
                        let mut futures = JoinSet::new();
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            0,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
                        mut_self.state = UploadState::InProgress {
                            part_idx: 1,
                            futures,
                            upload,
                        };
                    }
                    Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
                    Poll::Pending => break,
                },
                UploadState::InProgress {
                    upload, futures, ..
                } => {
                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
                        match res {
                            Ok(Ok(())) => {}
                            // JoinError: the task panicked or was cancelled.
                            Err(err) => return Err(std::io::Error::other(err)),
                            Ok(Err(UploadPutError {
                                source: OSError::Generic { source, .. },
                                part_idx,
                                buffer,
                            })) if source
                                .to_string()
                                .to_lowercase()
                                .contains("connection reset by peer") =>
                            {
                                if mut_self.connection_resets < max_conn_reset_retries() {
                                    mut_self.connection_resets += 1;
                                    // Resubmit the failed part after a random
                                    // delay so retries don't stampede.
                                    let sleep_time_ms = rand::rng().random_range(2_000..8_000);
                                    let sleep_time =
                                        std::time::Duration::from_millis(sleep_time_ms);
                                    futures.spawn(Self::put_part(
                                        upload.as_mut(),
                                        buffer,
                                        part_idx,
                                        Some(sleep_time),
                                    ));
                                } else {
                                    return Err(io::Error::new(
                                        io::ErrorKind::ConnectionReset,
                                        Box::new(ConnectionResetError {
                                            message: format!(
                                                "Hit max retries ({}) for connection reset",
                                                max_conn_reset_retries()
                                            ),
                                            source,
                                        }),
                                    ));
                                }
                            }
                            // Any other upload error is fatal.
                            Ok(Err(err)) => return Err(err.source.into()),
                        }
                    }
                    break;
                }
                UploadState::PuttingSingle(fut) | UploadState::Completing(fut) => {
                    match fut.poll_unpin(cx) {
                        Poll::Ready(Ok(mut res)) => {
                            // The futures don't know the total size; fill it
                            // in from the writer's cursor.
                            res.size = mut_self.cursor;
                            mut_self.state = UploadState::Done(res)
                        }
                        Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
                        Poll::Pending => break,
                    }
                }
            }
        }
        Ok(())
    }

    /// Abort an in-progress multipart upload, asking the store to discard
    /// any parts uploaded so far. A no-op in every other state; afterwards
    /// the writer reports a default (empty) `Done` result.
    pub async fn abort(&mut self) {
        let state = std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
        if let UploadState::InProgress { mut upload, .. } = state {
            // Best effort: abort failures are ignored.
            let _ = upload.abort().await;
        }
    }
}
impl Drop for ObjectWriter {
    /// Best-effort cleanup: if a multipart upload was started but never
    /// completed, spawn a task to abort it so the store can reclaim the
    /// partially-uploaded parts. Requires a running tokio runtime; without
    /// one the upload is simply left behind.
    fn drop(&mut self) {
        if !matches!(self.state, UploadState::InProgress { .. }) {
            return;
        }
        let state =
            std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
        let UploadState::InProgress { mut upload, .. } = state else {
            return;
        };
        if let Ok(handle) = Handle::try_current() {
            handle.spawn(async move {
                // Abort failures during drop are ignored.
                let _ = upload.abort().await;
            });
        }
    }
}
/// Failure of a single part upload, carrying the part index and the part's
/// bytes so the upload can be retried (see poll_tasks).
struct UploadPutError {
    // Which part failed.
    part_idx: u16,
    // The part's data, returned so it can be resubmitted.
    buffer: Bytes,
    // The underlying object-store error.
    source: OSError,
}
/// Error returned after exhausting all retries for "connection reset by
/// peer" failures during a multipart upload.
#[derive(Debug)]
struct ConnectionResetError {
    message: String,
    source: Box<dyn std::error::Error + Send + Sync>,
}
impl std::error::Error for ConnectionResetError {
    // The previous empty impl discarded the underlying error: callers
    // walking the error chain via `source()` got `None` even though the
    // cause is stored right here. Expose it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        Some(self.source.as_ref())
    }
}
impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}
impl AsyncWrite for ObjectWriter {
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
        // Reap finished uploads / advance pending futures first.
        self.as_mut().poll_tasks(cx)?;
        // Stage as much of `buf` as fits. Capacity is never exceeded, so a
        // full buffer doubles as the "part is ready" signal below.
        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
        self.cursor += bytes_to_write;
        let mut_self = &mut *self;
        // Buffer is full: start or continue the multipart upload.
        if mut_self.buffer.capacity() == mut_self.buffer.len() {
            match &mut mut_self.state {
                UploadState::Started(store) => {
                    // First full buffer: ask the store to create a multipart
                    // upload. The buffered data is submitted as part 0 once
                    // the creation future resolves (see poll_tasks).
                    let path = mut_self.path.clone();
                    let store = store.clone();
                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
                    self.state = UploadState::CreatingUpload(fut);
                }
                UploadState::InProgress {
                    upload,
                    part_idx,
                    futures,
                    ..
                } => {
                    // Submit the full buffer as the next part, but only below
                    // the parallelism cap; otherwise the data stays buffered
                    // and we report zero bytes written below.
                    if futures.len() < max_upload_parallelism() {
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            *part_idx,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(
                            Self::put_part(upload.as_mut(), data, *part_idx, None)
                                .instrument(tracing::Span::current()),
                        );
                        *part_idx += 1;
                    }
                }
                _ => {}
            }
        }
        self.poll_tasks(cx)?;
        match bytes_to_write {
            // Nothing accepted (buffer full, uploads saturated). poll_tasks
            // polled the in-flight futures above, so a waker is registered to
            // retry once a part completes.
            0 => Poll::Pending,
            _ => Poll::Ready(Ok(bytes_to_write)),
        }
    }

    // Flush completes when no part uploads are in flight. Note this does NOT
    // force buffered bytes out — a partial buffer is only uploaded at
    // shutdown.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.as_mut().poll_tasks(cx)?;
        match &self.state {
            UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
            UploadState::CreatingUpload(_)
            | UploadState::Completing(_)
            | UploadState::PuttingSingle(_) => Poll::Pending,
            UploadState::InProgress { futures, .. } => {
                if futures.is_empty() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            }
        }
    }

    // Drive the upload to completion: flush the remaining buffer (a single
    // `put` if the upload never went multipart, otherwise a final part plus
    // multipart-complete), then return once the state reaches Done.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        loop {
            self.as_mut().poll_tasks(cx)?;
            let mut_self = &mut *self;
            match &mut mut_self.state {
                UploadState::Done(_) => return Poll::Ready(Ok(())),
                UploadState::CreatingUpload(_)
                | UploadState::PuttingSingle(_)
                | UploadState::Completing(_) => return Poll::Pending,
                UploadState::Started(_) => {
                    // Never went multipart: upload everything in one put.
                    let part = std::mem::take(&mut mut_self.buffer);
                    let path = mut_self.path.clone();
                    self.state.started_to_putting_single(path, part);
                }
                UploadState::InProgress {
                    upload,
                    futures,
                    part_idx,
                } => {
                    // Submit any remaining buffered bytes as the last part.
                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
                        futures.spawn(
                            Self::put_part(upload.as_mut(), data, *part_idx, None)
                                .instrument(tracing::Span::current()),
                        );
                        // Loop so poll_tasks can reap completions.
                        continue;
                    }
                    if futures.is_empty() {
                        // All parts uploaded: finalize the multipart upload.
                        self.state.in_progress_to_completing();
                    } else {
                        return Poll::Pending;
                    }
                }
            }
        }
    }
}
#[async_trait]
impl Writer for ObjectWriter {
async fn tell(&mut self) -> Result<usize> {
Ok(self.cursor)
}
async fn shutdown(&mut self) -> Result<WriteResult> {
AsyncWriteExt::shutdown(self).await.map_err(|e| {
Error::io(format!(
"failed to shutdown object writer for {}: {}",
self.path, e
))
})?;
if let UploadState::Done(result) = &self.state {
Ok(result.clone())
} else {
unreachable!()
}
}
}
/// [AsyncWrite] implementation for the local filesystem: data goes into a
/// temp file which is atomically persisted to `path` on shutdown.
pub struct LocalWriter {
    // Final destination path.
    path: Path,
    // Write/persist state machine; see [LocalWriteState].
    state: LocalWriteState,
}
/// State machine for [LocalWriter].
#[derive(Default)]
enum LocalWriteState {
    /// Actively accepting writes into the buffered temp file.
    Writing(WritingState),
    /// Shutdown started: persisting the temp file to its final path.
    Finishing {
        // Total bytes written, captured before the writing state was consumed.
        size: usize,
        future: BoxFuture<'static, Result<WriteResult>>,
    },
    /// Persist completed; result available.
    Done(WriteResult),
    /// A transition was interrupted; the writer is unusable. This is the
    /// placeholder left in place while the state is moved out.
    #[default]
    Poisoned,
}
/// Live write state for [LocalWriter]: the buffered file plus bookkeeping.
struct WritingState {
    // Buffered writer over the open temp file.
    writer: tokio::io::BufWriter<tokio::fs::File>,
    // Bytes written so far.
    cursor: usize,
    // Temp-file guard; deletes the file on drop unless persisted first.
    temp_path: tempfile::TempPath,
    // Records the write (iops/bytes) when the file is persisted.
    io_tracker: Arc<IOTracker>,
}
impl LocalWriter {
pub fn new(
file: tokio::fs::File,
path: Path,
temp_path: tempfile::TempPath,
io_tracker: Arc<IOTracker>,
) -> Self {
Self {
path,
state: LocalWriteState::Writing(WritingState {
writer: tokio::io::BufWriter::new(file),
cursor: 0,
temp_path,
io_tracker,
}),
}
}
fn already_closed_err(path: &Path) -> io::Error {
io::Error::other(format!(
"cannot write to LocalWriter for {} after shutdown",
path
))
}
fn poisoned_err(path: &Path) -> io::Error {
io::Error::other(format!("LocalWriter for {} is in poisoned state", path))
}
async fn persist(
temp_path: tempfile::TempPath,
final_path: Path,
size: usize,
io_tracker: Arc<IOTracker>,
) -> Result<WriteResult> {
let local_path = crate::local::to_local_path(&final_path);
let e_tag = tokio::task::spawn_blocking(move || -> Result<String> {
temp_path.persist(&local_path).map_err(|e| {
Error::io(format!(
"failed to persist temp file to {}: {}",
local_path, e.error
))
})?;
let metadata = std::fs::metadata(&local_path).map_err(|e| {
Error::io(format!("failed to read metadata for {}: {}", local_path, e))
})?;
Ok(get_etag(&metadata))
})
.await
.map_err(|e| Error::io(format!("spawn_blocking failed: {}", e)))??;
io_tracker.record_write("put", final_path, size as u64);
Ok(WriteResult {
size,
e_tag: Some(e_tag),
})
}
}
impl AsyncWrite for LocalWriter {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        if let LocalWriteState::Writing(state) = &mut self.state {
            let poll = Pin::new(&mut state.writer).poll_write(cx, buf);
            // Track only the bytes the buffered writer actually accepted.
            if let Poll::Ready(Ok(n)) = &poll {
                state.cursor += *n;
            }
            poll
        } else {
            // Writes after shutdown (or on a poisoned writer) are rejected.
            Poll::Ready(Err(Self::already_closed_err(&self.path)))
        }
    }

    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        if let LocalWriteState::Writing(state) = &mut self.state {
            Pin::new(&mut state.writer).poll_flush(cx)
        } else {
            Poll::Ready(Err(Self::already_closed_err(&self.path)))
        }
    }

    // Shutdown sequence: flush+close the buffered writer, then persist the
    // temp file to its final path, and finally report Done.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        let mut_self = &mut *self;
        loop {
            match &mut mut_self.state {
                LocalWriteState::Writing(state) => {
                    if Pin::new(&mut state.writer).poll_shutdown(cx).is_pending() {
                        return Poll::Pending;
                    }
                    // Consume the writing state, leaving Poisoned as the
                    // placeholder; it stays Poisoned only if something below
                    // panics.
                    let LocalWriteState::Writing(state) =
                        std::mem::replace(&mut mut_self.state, LocalWriteState::Poisoned)
                    else {
                        unreachable!()
                    };
                    let size = state.cursor;
                    mut_self.state = LocalWriteState::Finishing {
                        size,
                        future: Box::pin(Self::persist(
                            state.temp_path,
                            mut_self.path.clone(),
                            size,
                            state.io_tracker,
                        )),
                    };
                }
                LocalWriteState::Finishing { future, .. } => match future.poll_unpin(cx) {
                    Poll::Ready(Ok(result)) => mut_self.state = LocalWriteState::Done(result),
                    Poll::Ready(Err(e)) => {
                        return Poll::Ready(Err(io::Error::other(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                // Shutdown is idempotent once finished.
                LocalWriteState::Done(_) => return Poll::Ready(Ok(())),
                LocalWriteState::Poisoned => {
                    return Poll::Ready(Err(Self::poisoned_err(&self.path)));
                }
            }
        }
    }
}
#[async_trait]
impl Writer for LocalWriter {
async fn tell(&mut self) -> Result<usize> {
match &mut self.state {
LocalWriteState::Writing(state) => Ok(state.cursor),
LocalWriteState::Finishing { size, .. } => Ok(*size),
LocalWriteState::Done(result) => Ok(result.size),
LocalWriteState::Poisoned => Err(Self::poisoned_err(&self.path).into()),
}
}
async fn shutdown(&mut self) -> Result<WriteResult> {
AsyncWriteExt::shutdown(self).await.map_err(|e| {
Error::io(format!(
"failed to shutdown local writer for {}: {}",
self.path, e
))
})?;
match &self.state {
LocalWriteState::Done(result) => Ok(result.clone()),
_ => unreachable!(),
}
}
}
/// Build a weak etag for a local file from its inode, modification time
/// (microseconds since the Unix epoch), and size, formatted as three hex
/// fields: `{inode:x}-{mtime:x}-{size:x}`.
pub fn get_etag(metadata: &std::fs::Metadata) -> String {
    let mtime_micros = match metadata.modified() {
        Ok(mtime) => mtime
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .map(|d| d.as_micros())
            .unwrap_or(0),
        // Filesystems without mtime support fall back to zero.
        Err(_) => 0,
    };
    format!(
        "{:x}-{:x}-{:x}",
        get_inode(metadata),
        mtime_micros,
        metadata.len()
    )
}

/// Inode number of the file, on Unix platforms.
#[cfg(unix)]
fn get_inode(metadata: &std::fs::Metadata) -> u64 {
    use std::os::unix::fs::MetadataExt;
    metadata.ino()
}

/// Inodes don't exist on non-Unix platforms; use a constant placeholder.
#[cfg(not(unix))]
fn get_inode(_metadata: &std::fs::Metadata) -> u64 {
    0
}
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        // Small writes stay in the staging buffer; the cursor still advances.
        let buf = vec![0; 256];
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256);
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 512);
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256 * 3);

        // Fully qualified: AsyncWriteExt::shutdown is also in scope.
        let res = Writer::shutdown(&mut object_writer).await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Large enough to cross the part-size threshold repeatedly, forcing
        // the multipart path.
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let buf = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for i in 0..5 {
            object_writer.write_all(buf.as_slice()).await.unwrap();
            assert_eq!(object_writer.tell().await.unwrap(), (i + 1) * buf.len());
        }
        let res = Writer::shutdown(&mut object_writer).await.unwrap();
        assert_eq!(res.size, buf.len() * 5);
    }

    #[tokio::test]
    async fn test_abort_write() {
        let store = LanceObjectStore::memory();
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        // Abort before anything was uploaded must not panic or error.
        object_writer.abort().await;
    }

    #[tokio::test]
    async fn test_local_writer_shutdown() {
        let tmp = lance_core::utils::tempfile::TempStdDir::default();
        let file_path = tmp.join("test_local_writer.bin");
        let os_path = Path::from_absolute_path(&file_path).unwrap();
        let io_tracker = Arc::new(IOTracker::default());
        let named_temp = tempfile::NamedTempFile::new_in(&*tmp).unwrap();
        let temp_file_path = named_temp.path().to_owned();
        let (std_file, temp_path) = named_temp.into_parts();
        let file = tokio::fs::File::from_std(std_file);
        let mut writer = LocalWriter::new(file, os_path, temp_path, io_tracker.clone());
        let data = b"hello local writer";
        writer.write_all(data).await.unwrap();
        // Before shutdown: data lives only in the temp file.
        assert!(!file_path.exists());
        assert!(temp_file_path.exists());
        let result = Writer::shutdown(&mut writer).await.unwrap();
        assert_eq!(result.size, data.len());
        assert!(result.e_tag.is_some());
        assert!(!result.e_tag.as_ref().unwrap().is_empty());
        // After shutdown: temp file persisted to the final path.
        assert!(file_path.exists());
        assert!(!temp_file_path.exists());
        // Persist records exactly one write with the full byte count.
        let stats = io_tracker.stats();
        assert_eq!(stats.write_iops, 1);
        assert_eq!(stats.written_bytes, data.len() as u64);
    }

    #[tokio::test]
    async fn test_local_writer_drop_cleans_up() {
        let tmp = lance_core::utils::tempfile::TempStdDir::default();
        let file_path = tmp.join("test_drop.bin");
        let os_path = Path::from_absolute_path(&file_path).unwrap();
        let io_tracker = Arc::new(IOTracker::default());
        let named_temp = tempfile::NamedTempFile::new_in(&*tmp).unwrap();
        let temp_file_path = named_temp.path().to_owned();
        let (std_file, temp_path) = named_temp.into_parts();
        let file = tokio::fs::File::from_std(std_file);
        let mut writer = LocalWriter::new(file, os_path, temp_path, io_tracker);
        writer.write_all(b"some data").await.unwrap();
        assert!(temp_file_path.exists());
        // Dropping without shutdown must remove the temp file (TempPath's
        // drop) and never create the final file.
        drop(writer);
        assert!(!temp_file_path.exists());
        assert!(!file_path.exists());
    }
}