1use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::FutureExt;
13use futures::future::BoxFuture;
14use object_store::MultipartUpload;
15use object_store::{Error as OSError, ObjectStore, Result as OSResult, path::Path};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use crate::utils::tracking_store::IOTracker;
25use tokio::runtime::Handle;
26
/// Smallest allowed part size and the step by which part sizes grow: 5 MiB.
///
/// Also the default capacity of the first write buffer.
const INITIAL_UPLOAD_STEP: usize = 5 * 1024 * 1024;
29
/// Maximum number of multipart part uploads a single writer keeps in flight.
///
/// Read once from `LANCE_UPLOAD_CONCURRENCY` and cached. Unset or unparsable
/// values fall back to 10. A value of `0` is also rejected and falls back to 10:
/// with zero upload slots the writer could never dispatch another part, and
/// shutdown would complete the upload without sending the buffered tail.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            // Zero concurrency would deadlock/truncate uploads; treat it as unset.
            .filter(|&n| n > 0)
            .unwrap_or(10)
    })
}
39
/// How many "connection reset by peer" part failures one writer will retry in
/// total before giving up.
///
/// Read once from `LANCE_CONN_RESET_RETRIES` and cached; unset or unparsable
/// values fall back to 20.
fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    let cached = MAX_CONN_RESET_RETRIES.get_or_init(|| {
        match std::env::var("LANCE_CONN_RESET_RETRIES") {
            Ok(raw) => raw.parse::<u16>().unwrap_or(20),
            Err(_) => 20,
        }
    });
    *cached
}
49
50fn initial_upload_size() -> usize {
51 static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
52 *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
53 std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
54 .ok()
55 .and_then(|s| s.parse::<usize>().ok())
56 .inspect(|size| {
57 if *size < INITIAL_UPLOAD_STEP {
58 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
60 } else if *size > 1024 * 1024 * 1024 * 5 {
61 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
63 }
64 })
65 .unwrap_or(INITIAL_UPLOAD_STEP)
66 })
67}
68
69pub struct ObjectWriter {
77 state: UploadState,
78 path: Arc<Path>,
79 cursor: usize,
80 connection_resets: u16,
81 buffer: Vec<u8>,
82 use_constant_size_upload_parts: bool,
84}
85
/// Outcome of a finished write.
#[derive(Clone, Debug, Default)]
pub struct WriteResult {
    /// Number of bytes written.
    pub size: usize,
    /// Entity tag reported for the written object, when one is available.
    pub e_tag: Option<String>,
}
91
92enum UploadState {
93 Started(Arc<dyn ObjectStore>),
96 CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
98 InProgress {
100 part_idx: u16,
101 upload: Box<dyn MultipartUpload>,
102 futures: JoinSet<std::result::Result<(), UploadPutError>>,
103 },
104 PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
107 Completing(BoxFuture<'static, OSResult<WriteResult>>),
109 Done(WriteResult),
111}
112
113impl UploadState {
115 fn started_to_putting_single(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
116 let this = std::mem::replace(self, Self::Done(WriteResult::default()));
118 *self = match this {
119 Self::Started(store) => {
120 let fut = async move {
121 let size = buffer.len();
122 let res = store.put(&path, buffer.into()).await?;
123 Ok(WriteResult {
124 size,
125 e_tag: res.e_tag,
126 })
127 };
128 Self::PuttingSingle(Box::pin(fut))
129 }
130 _ => unreachable!(),
131 }
132 }
133
134 fn in_progress_to_completing(&mut self) {
135 let this = std::mem::replace(self, Self::Done(WriteResult::default()));
137 *self = match this {
138 Self::InProgress {
139 mut upload,
140 futures,
141 ..
142 } => {
143 debug_assert!(futures.is_empty());
144 let fut = async move {
145 let res = upload.complete().await?;
146 Ok(WriteResult {
147 size: 0, e_tag: res.e_tag,
149 })
150 };
151 Self::Completing(Box::pin(fut))
152 }
153 _ => unreachable!(),
154 };
155 }
156}
157
158impl ObjectWriter {
159 pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
160 Ok(Self {
161 state: UploadState::Started(object_store.inner.clone()),
162 cursor: 0,
163 path: Arc::new(path.clone()),
164 connection_resets: 0,
165 buffer: Vec::with_capacity(initial_upload_size()),
166 use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
167 })
168 }
169
170 fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
173 let new_capacity = if constant_upload_size {
174 initial_upload_size()
176 } else {
177 initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
179 };
180 let new_buffer = Vec::with_capacity(new_capacity);
181 let part = std::mem::replace(buffer, new_buffer);
182 Bytes::from(part)
183 }
184
185 fn put_part(
186 upload: &mut dyn MultipartUpload,
187 buffer: Bytes,
188 part_idx: u16,
189 sleep: Option<std::time::Duration>,
190 ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
191 log::debug!(
192 "MultipartUpload submitting part with {} bytes",
193 buffer.len()
194 );
195 let fut = upload.put_part(buffer.clone().into());
196 Box::pin(async move {
197 if let Some(sleep) = sleep {
198 tokio::time::sleep(sleep).await;
199 }
200 fut.await.map_err(|source| UploadPutError {
201 part_idx,
202 buffer,
203 source,
204 })?;
205 Ok(())
206 })
207 }
208
209 fn poll_tasks(
210 mut self: Pin<&mut Self>,
211 cx: &mut std::task::Context<'_>,
212 ) -> std::result::Result<(), io::Error> {
213 let mut_self = &mut *self;
214 loop {
215 match &mut mut_self.state {
216 UploadState::Started(_) | UploadState::Done(_) => break,
217 UploadState::CreatingUpload(fut) => match fut.poll_unpin(cx) {
218 Poll::Ready(Ok(mut upload)) => {
219 let mut futures = JoinSet::new();
220
221 let data = Self::next_part_buffer(
222 &mut mut_self.buffer,
223 0,
224 mut_self.use_constant_size_upload_parts,
225 );
226 futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
227
228 mut_self.state = UploadState::InProgress {
229 part_idx: 1, futures,
231 upload,
232 };
233 }
234 Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
235 Poll::Pending => break,
236 },
237 UploadState::InProgress {
238 upload, futures, ..
239 } => {
240 while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
241 match res {
242 Ok(Ok(())) => {}
243 Err(err) => return Err(std::io::Error::other(err)),
244 Ok(Err(UploadPutError {
245 source: OSError::Generic { source, .. },
246 part_idx,
247 buffer,
248 })) if source
249 .to_string()
250 .to_lowercase()
251 .contains("connection reset by peer") =>
252 {
253 if mut_self.connection_resets < max_conn_reset_retries() {
254 mut_self.connection_resets += 1;
256
257 let sleep_time_ms = rand::rng().random_range(2_000..8_000);
259 let sleep_time =
260 std::time::Duration::from_millis(sleep_time_ms);
261
262 futures.spawn(Self::put_part(
263 upload.as_mut(),
264 buffer,
265 part_idx,
266 Some(sleep_time),
267 ));
268 } else {
269 return Err(io::Error::new(
270 io::ErrorKind::ConnectionReset,
271 Box::new(ConnectionResetError {
272 message: format!(
273 "Hit max retries ({}) for connection reset",
274 max_conn_reset_retries()
275 ),
276 source,
277 }),
278 ));
279 }
280 }
281 Ok(Err(err)) => return Err(err.source.into()),
282 }
283 }
284 break;
285 }
286 UploadState::PuttingSingle(fut) | UploadState::Completing(fut) => {
287 match fut.poll_unpin(cx) {
288 Poll::Ready(Ok(mut res)) => {
289 res.size = mut_self.cursor;
290 mut_self.state = UploadState::Done(res)
291 }
292 Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
293 Poll::Pending => break,
294 }
295 }
296 }
297 }
298 Ok(())
299 }
300
301 pub async fn abort(&mut self) {
302 let state = std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
303 if let UploadState::InProgress { mut upload, .. } = state {
304 let _ = upload.abort().await;
305 }
306 }
307}
308
309impl Drop for ObjectWriter {
310 fn drop(&mut self) {
311 if matches!(self.state, UploadState::InProgress { .. }) {
313 let state =
315 std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
316 if let UploadState::InProgress { mut upload, .. } = state
317 && let Ok(handle) = Handle::try_current()
318 {
319 handle.spawn(async move {
320 let _ = upload.abort().await;
321 });
322 }
323 }
324 }
325}
326
327struct UploadPutError {
331 part_idx: u16,
332 buffer: Bytes,
333 source: OSError,
334}
335
/// Error returned once a writer has used up its connection-reset retries.
///
/// Displays as `"{message}: {source}"`.
#[derive(Debug)]
struct ConnectionResetError {
    /// Summary that includes the retry limit that was hit.
    message: String,
    /// The store error from the final failed attempt.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)?;
        write!(f, ": {}", self.source)
    }
}

// `source()` keeps its default (`None`); the cause is already part of `Display`.
impl std::error::Error for ConnectionResetError {}
349
350impl AsyncWrite for ObjectWriter {
351 fn poll_write(
352 mut self: std::pin::Pin<&mut Self>,
353 cx: &mut std::task::Context<'_>,
354 buf: &[u8],
355 ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
356 self.as_mut().poll_tasks(cx)?;
357
358 let remaining_capacity = self.buffer.capacity() - self.buffer.len();
360 let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
361 self.buffer.extend_from_slice(&buf[..bytes_to_write]);
362 self.cursor += bytes_to_write;
363
364 let mut_self = &mut *self;
367
368 if mut_self.buffer.capacity() == mut_self.buffer.len() {
370 match &mut mut_self.state {
371 UploadState::Started(store) => {
372 let path = mut_self.path.clone();
373 let store = store.clone();
374 let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
375 self.state = UploadState::CreatingUpload(fut);
376 }
377 UploadState::InProgress {
378 upload,
379 part_idx,
380 futures,
381 ..
382 } => {
383 if futures.len() < max_upload_parallelism() {
385 let data = Self::next_part_buffer(
386 &mut mut_self.buffer,
387 *part_idx,
388 mut_self.use_constant_size_upload_parts,
389 );
390 futures.spawn(
391 Self::put_part(upload.as_mut(), data, *part_idx, None)
392 .instrument(tracing::Span::current()),
393 );
394 *part_idx += 1;
395 }
396 }
397 _ => {}
398 }
399 }
400
401 self.poll_tasks(cx)?;
402
403 match bytes_to_write {
404 0 => Poll::Pending,
405 _ => Poll::Ready(Ok(bytes_to_write)),
406 }
407 }
408
409 fn poll_flush(
410 mut self: std::pin::Pin<&mut Self>,
411 cx: &mut std::task::Context<'_>,
412 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
413 self.as_mut().poll_tasks(cx)?;
414
415 match &self.state {
416 UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
417 UploadState::CreatingUpload(_)
418 | UploadState::Completing(_)
419 | UploadState::PuttingSingle(_) => Poll::Pending,
420 UploadState::InProgress { futures, .. } => {
421 if futures.is_empty() {
422 Poll::Ready(Ok(()))
423 } else {
424 Poll::Pending
425 }
426 }
427 }
428 }
429
430 fn poll_shutdown(
431 mut self: std::pin::Pin<&mut Self>,
432 cx: &mut std::task::Context<'_>,
433 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
434 loop {
435 self.as_mut().poll_tasks(cx)?;
436
437 let mut_self = &mut *self;
440 match &mut mut_self.state {
441 UploadState::Done(_) => return Poll::Ready(Ok(())),
442 UploadState::CreatingUpload(_)
443 | UploadState::PuttingSingle(_)
444 | UploadState::Completing(_) => return Poll::Pending,
445 UploadState::Started(_) => {
446 let part = std::mem::take(&mut mut_self.buffer);
448 let path = mut_self.path.clone();
449 self.state.started_to_putting_single(path, part);
450 }
451 UploadState::InProgress {
452 upload,
453 futures,
454 part_idx,
455 } => {
456 if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
458 let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
460 futures.spawn(
461 Self::put_part(upload.as_mut(), data, *part_idx, None)
462 .instrument(tracing::Span::current()),
463 );
464 continue;
467 }
468
469 if futures.is_empty() {
471 self.state.in_progress_to_completing();
472 } else {
473 return Poll::Pending;
474 }
475 }
476 }
477 }
478 }
479}
480
481#[async_trait]
482impl Writer for ObjectWriter {
483 async fn tell(&mut self) -> Result<usize> {
484 Ok(self.cursor)
485 }
486
487 async fn shutdown(&mut self) -> Result<WriteResult> {
488 AsyncWriteExt::shutdown(self).await.map_err(|e| {
489 Error::io(format!(
490 "failed to shutdown object writer for {}: {}",
491 self.path, e
492 ))
493 })?;
494 if let UploadState::Done(result) = &self.state {
495 Ok(result.clone())
496 } else {
497 unreachable!()
498 }
499 }
500}
501
502pub struct LocalWriter {
503 inner: tokio::io::BufWriter<tokio::fs::File>,
504 cursor: usize,
505 path: Path,
506 temp_path: Option<tempfile::TempPath>,
508 io_tracker: Arc<IOTracker>,
509}
510
511impl LocalWriter {
512 pub fn new(
513 file: tokio::fs::File,
514 path: Path,
515 temp_path: tempfile::TempPath,
516 io_tracker: Arc<IOTracker>,
517 ) -> Self {
518 Self {
519 inner: tokio::io::BufWriter::new(file),
520 cursor: 0,
521 path,
522 temp_path: Some(temp_path),
523 io_tracker,
524 }
525 }
526}
527
528impl AsyncWrite for LocalWriter {
529 fn poll_write(
530 mut self: Pin<&mut Self>,
531 cx: &mut std::task::Context<'_>,
532 buf: &[u8],
533 ) -> Poll<std::result::Result<usize, std::io::Error>> {
534 let poll = Pin::new(&mut self.inner).poll_write(cx, buf);
535 if let Poll::Ready(Ok(n)) = &poll {
536 self.cursor += *n;
537 }
538 poll
539 }
540
541 fn poll_flush(
542 mut self: Pin<&mut Self>,
543 cx: &mut std::task::Context<'_>,
544 ) -> Poll<std::result::Result<(), std::io::Error>> {
545 Pin::new(&mut self.inner).poll_flush(cx)
546 }
547
548 fn poll_shutdown(
549 mut self: Pin<&mut Self>,
550 cx: &mut std::task::Context<'_>,
551 ) -> Poll<std::result::Result<(), std::io::Error>> {
552 Pin::new(&mut self.inner).poll_shutdown(cx)
553 }
554}
555
556#[async_trait]
557impl Writer for LocalWriter {
558 async fn tell(&mut self) -> Result<usize> {
559 Ok(self.cursor)
560 }
561
562 async fn shutdown(&mut self) -> Result<WriteResult> {
563 AsyncWriteExt::shutdown(self).await.map_err(|e| {
564 Error::io(format!(
565 "failed to shutdown local writer for {}: {}",
566 self.path, e
567 ))
568 })?;
569
570 let final_path = crate::local::to_local_path(&self.path);
571 let temp_path = self.temp_path.take().ok_or_else(|| {
572 Error::io(format!("local writer for {} already shut down", self.path))
573 })?;
574 let path_clone = self.path.clone();
575 let e_tag = tokio::task::spawn_blocking(move || -> Result<String> {
576 temp_path.persist(&final_path).map_err(|e| {
577 Error::io(format!(
578 "failed to persist temp file to {}: {}",
579 final_path, e.error
580 ))
581 })?;
582
583 let metadata = std::fs::metadata(&final_path).map_err(|e| {
584 Error::io(format!("failed to read metadata for {}: {}", path_clone, e))
585 })?;
586 Ok(get_etag(&metadata))
587 })
588 .await
589 .map_err(|e| Error::io(format!("spawn_blocking failed: {}", e)))??;
590
591 self.io_tracker
592 .record_write("put", self.path.clone(), self.cursor as u64);
593
594 Ok(WriteResult {
595 size: self.cursor,
596 e_tag: Some(e_tag),
597 })
598 }
599}
600
/// Builds an e-tag for a local file from its metadata.
///
/// Format: `"{inode}-{mtime_micros}-{size}"`, each field in lowercase hex. The
/// modification time falls back to 0 when unavailable or earlier than the Unix
/// epoch. The inode is 0 on non-Unix platforms.
pub fn get_etag(metadata: &std::fs::Metadata) -> String {
    let size = metadata.len();
    let mtime = match metadata.modified() {
        Ok(modified) => modified
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .map(|since_epoch| since_epoch.as_micros())
            .unwrap_or(0),
        Err(_) => 0,
    };
    let inode = get_inode(metadata);
    format!("{:x}-{:x}-{:x}", inode, mtime, size)
}

/// Inode number of the file described by `metadata`.
#[cfg(unix)]
fn get_inode(metadata: &std::fs::Metadata) -> u64 {
    use std::os::unix::fs::MetadataExt;
    metadata.ino()
}

/// Inode numbers are not exposed on this platform; a constant 0 is used.
#[cfg(not(unix))]
fn get_inode(_metadata: &std::fs::Metadata) -> u64 {
    0
}
626
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        // Small object: stays below the initial buffer size, so it is written with
        // a single put on shutdown.
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let buf = vec![0; 256];
        for round in 1..=3 {
            assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
            assert_eq!(object_writer.tell().await.unwrap(), 256 * round);
        }

        let res = Writer::shutdown(&mut object_writer).await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Larger object: with the default initial upload size, five chunks of 2/3
        // of a step overflow the first buffer and go through a multipart upload.
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let buf = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for chunks_written in 1..=5 {
            object_writer.write_all(buf.as_slice()).await.unwrap();
            assert_eq!(
                object_writer.tell().await.unwrap(),
                chunks_written * buf.len()
            );
        }
        let res = Writer::shutdown(&mut object_writer).await.unwrap();
        assert_eq!(res.size, buf.len() * 5);
    }

    #[tokio::test]
    async fn test_abort_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        object_writer.abort().await;
    }

    #[tokio::test]
    async fn test_local_writer_shutdown() {
        let tmp = lance_core::utils::tempfile::TempStdDir::default();
        let file_path = tmp.join("test_local_writer.bin");
        let os_path = Path::from_absolute_path(&file_path).unwrap();
        let io_tracker = Arc::new(IOTracker::default());

        let named_temp = tempfile::NamedTempFile::new_in(&*tmp).unwrap();
        let temp_file_path = named_temp.path().to_owned();
        let (std_file, temp_path) = named_temp.into_parts();
        let file = tokio::fs::File::from_std(std_file);
        let mut writer = LocalWriter::new(file, os_path, temp_path, io_tracker.clone());

        let data = b"hello local writer";
        writer.write_all(data).await.unwrap();

        // Until shutdown, only the temporary file exists.
        assert!(!file_path.exists());
        assert!(temp_file_path.exists());

        let result = Writer::shutdown(&mut writer).await.unwrap();
        assert_eq!(result.size, data.len());
        let e_tag = result.e_tag.as_ref().expect("local writer reports an e-tag");
        assert!(!e_tag.is_empty());

        // After shutdown the temporary file has been moved into place.
        assert!(file_path.exists());
        assert!(!temp_file_path.exists());

        let stats = io_tracker.stats();
        assert_eq!(stats.write_iops, 1);
        assert_eq!(stats.written_bytes, data.len() as u64);
    }

    #[tokio::test]
    async fn test_local_writer_drop_cleans_up() {
        let tmp = lance_core::utils::tempfile::TempStdDir::default();
        let file_path = tmp.join("test_drop.bin");
        let os_path = Path::from_absolute_path(&file_path).unwrap();
        let io_tracker = Arc::new(IOTracker::default());

        let named_temp = tempfile::NamedTempFile::new_in(&*tmp).unwrap();
        let temp_file_path = named_temp.path().to_owned();
        let (std_file, temp_path) = named_temp.into_parts();
        let file = tokio::fs::File::from_std(std_file);
        let mut writer = LocalWriter::new(file, os_path, temp_path, io_tracker);

        writer.write_all(b"some data").await.unwrap();
        assert!(temp_file_path.exists());

        // Dropping without shutdown removes the temporary file and never creates
        // the final one.
        drop(writer);
        assert!(!temp_file_path.exists());
        assert!(!file_path.exists());
    }
}