// lance_io/object_writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::FutureExt;
13use futures::future::BoxFuture;
14use object_store::MultipartUpload;
15use object_store::{Error as OSError, ObjectStore, Result as OSResult, path::Path};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use crate::utils::tracking_store::IOTracker;
25use tokio::runtime::Handle;
26
/// Start at 5MB. This is the initial multipart part size, and also the
/// minimum part size accepted by S3 and GCS (see `initial_upload_size`).
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;
29
/// Maximum number of part uploads allowed in flight at once.
///
/// Overridable via the `LANCE_UPLOAD_CONCURRENCY` environment variable;
/// defaults to 10 when unset or unparsable. Read once and cached for the
/// lifetime of the process.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| match std::env::var("LANCE_UPLOAD_CONCURRENCY") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(10),
        Err(_) => 10,
    })
}
39
/// Maximum number of retries for "connection reset by peer" failures during
/// multipart uploads.
///
/// Overridable via the `LANCE_CONN_RESET_RETRIES` environment variable;
/// defaults to 20 when unset or unparsable. Read once and cached for the
/// lifetime of the process.
fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    *MAX_CONN_RESET_RETRIES.get_or_init(|| match std::env::var("LANCE_CONN_RESET_RETRIES") {
        Ok(raw) => raw.parse::<u16>().unwrap_or(20),
        Err(_) => 20,
    })
}
49
50fn initial_upload_size() -> usize {
51    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
52    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
53        std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
54            .ok()
55            .and_then(|s| s.parse::<usize>().ok())
56            .inspect(|size| {
57                if *size < INITIAL_UPLOAD_STEP {
58                    // Minimum part size in GCS and S3
59                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
60                } else if *size > 1024 * 1024 * 1024 * 5 {
61                    // Maximum part size in GCS and S3
62                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
63                }
64            })
65            .unwrap_or(INITIAL_UPLOAD_STEP)
66    })
67}
68
/// Writer to an object in an object store.
///
/// If the object is small enough, the writer will upload the object in a single
/// PUT request. If the object is larger, the writer will create a multipart
/// upload and upload parts in parallel.
///
/// This implements the `AsyncWrite` trait.
pub struct ObjectWriter {
    /// Current phase of the upload; see [`UploadState`].
    state: UploadState,
    /// Destination path in the object store.
    path: Arc<Path>,
    /// Total number of bytes accepted so far (the logical write position).
    cursor: usize,
    /// Number of "connection reset by peer" retries performed so far;
    /// compared against `max_conn_reset_retries()`.
    connection_resets: u16,
    /// Staging buffer; a part is cut and submitted when it reaches capacity.
    buffer: Vec<u8>,
    // TODO: use constant size to support R2
    /// When true, every part uses `initial_upload_size()` instead of the
    /// growing-part-size schedule (for stores requiring constant part sizes).
    use_constant_size_upload_parts: bool,
}
85
/// Outcome of a completed write: the final object size and the ETag the
/// store returned, if any.
#[derive(Debug, Clone, Default)]
pub struct WriteResult {
    /// Total number of bytes written.
    pub size: usize,
    /// ETag returned by the object store, if provided.
    pub e_tag: Option<String>,
}
91
/// State machine for [`ObjectWriter`].
///
/// Transitions are `Started -> CreatingUpload -> InProgress -> Completing ->
/// Done` for multipart uploads, or `Started -> PuttingSingle -> Done` when
/// the whole object fits in the first buffer.
enum UploadState {
    /// The writer has been opened but no data has been written yet. Will be in
    /// this state until the buffer is full or the writer is shut down.
    Started(Arc<dyn ObjectStore>),
    /// The writer is in the process of creating a multipart upload.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// The writer is in the process of uploading parts.
    InProgress {
        /// Index assigned to the next part to be submitted.
        part_idx: u16,
        /// Handle to the store's multipart upload session.
        upload: Box<dyn MultipartUpload>,
        /// In-flight part uploads (bounded by `max_upload_parallelism()`).
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// The writer is in the process of uploading data in a single PUT request.
    /// This happens when shutdown is called before the buffer is full.
    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
    /// The writer is in the process of completing the multipart upload.
    Completing(BoxFuture<'static, OSResult<WriteResult>>),
    /// The writer has been shut down and all data has been written.
    Done(WriteResult),
}
112
/// Methods for state transitions.
impl UploadState {
    /// Transition `Started -> PuttingSingle`: upload the entire buffer in a
    /// single PUT request. Panics (`unreachable!`) if called in any other
    /// state; callers check the state first.
    fn started_to_putting_single(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
        // To get owned self, we temporarily swap with Done.
        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
        *self = match this {
            Self::Started(store) => {
                let fut = async move {
                    let size = buffer.len();
                    let res = store.put(&path, buffer.into()).await?;
                    Ok(WriteResult {
                        size,
                        e_tag: res.e_tag,
                    })
                };
                Self::PuttingSingle(Box::pin(fut))
            }
            // Only reachable from Started; guarded by the caller.
            _ => unreachable!(),
        }
    }

    /// Transition `InProgress -> Completing`: all parts have finished
    /// uploading, so issue the multipart-complete request. Panics
    /// (`unreachable!`) if called in any other state; expects the part
    /// futures set to be empty.
    fn in_progress_to_completing(&mut self) {
        // To get owned self, we temporarily swap with Done.
        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
        *self = match this {
            Self::InProgress {
                mut upload,
                futures,
                ..
            } => {
                debug_assert!(futures.is_empty());
                let fut = async move {
                    let res = upload.complete().await?;
                    Ok(WriteResult {
                        // The store doesn't know our byte count; poll_tasks
                        // fills this in from `cursor` when the future resolves.
                        size: 0, // This will be set properly later.
                        e_tag: res.e_tag,
                    })
                };
                Self::Completing(Box::pin(fut))
            }
            _ => unreachable!(),
        };
    }
}
157
impl ObjectWriter {
    /// Creates a writer targeting `path` in the given store. No request is
    /// issued yet: data accumulates in the staging buffer, and the first
    /// upload happens when the buffer fills (multipart) or on shutdown
    /// (single PUT).
    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
        Ok(Self {
            state: UploadState::Started(object_store.inner.clone()),
            cursor: 0,
            path: Arc::new(path.clone()),
            connection_resets: 0,
            // Pre-size so a part is cut exactly at the configured upload size.
            buffer: Vec::with_capacity(initial_upload_size()),
            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
        })
    }

    /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`.
    /// The new capacity of `buffer` is determined by the current part index.
    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
        let new_capacity = if constant_upload_size {
            // The store does not support variable part sizes, so use the initial size.
            initial_upload_size()
        } else {
            // Increase the upload size every 100 parts. This gives maximum part size of 2.5TB.
            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
        };
        let new_buffer = Vec::with_capacity(new_capacity);
        let part = std::mem::replace(buffer, new_buffer);
        Bytes::from(part)
    }

    /// Builds a future that uploads one part, optionally after a delay (used
    /// as retry backoff). On failure the part index and data are carried in
    /// the error so the caller can resubmit the same part.
    fn put_part(
        upload: &mut dyn MultipartUpload,
        buffer: Bytes,
        part_idx: u16,
        sleep: Option<std::time::Duration>,
    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
        log::debug!(
            "MultipartUpload submitting part with {} bytes",
            buffer.len()
        );
        // `Bytes::clone` is cheap (shared buffer); the copy is kept for retry.
        let fut = upload.put_part(buffer.clone().into());
        Box::pin(async move {
            if let Some(sleep) = sleep {
                tokio::time::sleep(sleep).await;
            }
            fut.await.map_err(|source| UploadPutError {
                part_idx,
                buffer,
                source,
            })?;
            Ok(())
        })
    }

    /// Drives whatever async work the current state holds: upload creation,
    /// in-flight part uploads (with connection-reset retry), the single PUT,
    /// or multipart completion. Loops through state transitions without
    /// blocking; breaks when the state is quiescent or pending so wakers
    /// stay registered on `cx`.
    fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::result::Result<(), io::Error> {
        let mut_self = &mut *self;
        loop {
            match &mut mut_self.state {
                // Nothing in flight in these states.
                UploadState::Started(_) | UploadState::Done(_) => break,
                UploadState::CreatingUpload(fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(mut upload)) => {
                        let mut futures = JoinSet::new();

                        // The buffer is already full; submit it as part 0.
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            0,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));

                        mut_self.state = UploadState::InProgress {
                            part_idx: 1, // We just used 0
                            futures,
                            upload,
                        };
                    }
                    Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
                    Poll::Pending => break,
                },
                UploadState::InProgress {
                    upload, futures, ..
                } => {
                    // Drain all completed part uploads without blocking.
                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
                        match res {
                            Ok(Ok(())) => {}
                            // The spawned task itself failed (panic/cancel).
                            Err(err) => return Err(std::io::Error::other(err)),
                            // Connection reset: resubmit the same part with
                            // jittered backoff, up to a bounded retry count.
                            Ok(Err(UploadPutError {
                                source: OSError::Generic { source, .. },
                                part_idx,
                                buffer,
                            })) if source
                                .to_string()
                                .to_lowercase()
                                .contains("connection reset by peer") =>
                            {
                                if mut_self.connection_resets < max_conn_reset_retries() {
                                    // Retry, but only up to max_conn_reset_retries of them.
                                    mut_self.connection_resets += 1;

                                    // Resubmit with random jitter
                                    let sleep_time_ms = rand::rng().random_range(2_000..8_000);
                                    let sleep_time =
                                        std::time::Duration::from_millis(sleep_time_ms);

                                    futures.spawn(Self::put_part(
                                        upload.as_mut(),
                                        buffer,
                                        part_idx,
                                        Some(sleep_time),
                                    ));
                                } else {
                                    return Err(io::Error::new(
                                        io::ErrorKind::ConnectionReset,
                                        Box::new(ConnectionResetError {
                                            message: format!(
                                                "Hit max retries ({}) for connection reset",
                                                max_conn_reset_retries()
                                            ),
                                            source,
                                        }),
                                    ));
                                }
                            }
                            // Any other upload error is fatal.
                            Ok(Err(err)) => return Err(err.source.into()),
                        }
                    }
                    break;
                }
                UploadState::PuttingSingle(fut) | UploadState::Completing(fut) => {
                    match fut.poll_unpin(cx) {
                        Poll::Ready(Ok(mut res)) => {
                            // Fill in the size from our own byte count; the
                            // store response doesn't carry it.
                            res.size = mut_self.cursor;
                            mut_self.state = UploadState::Done(res)
                        }
                        Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
                        Poll::Pending => break,
                    }
                }
            }
        }
        Ok(())
    }

    /// Aborts an in-progress multipart upload (best effort, errors ignored)
    /// and marks the writer as done. A no-op in any other state.
    pub async fn abort(&mut self) {
        let state = std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
        if let UploadState::InProgress { mut upload, .. } = state {
            let _ = upload.abort().await;
        }
    }
}
308
309impl Drop for ObjectWriter {
310    fn drop(&mut self) {
311        // If there is a multipart upload started but not finished, we should abort it.
312        if matches!(self.state, UploadState::InProgress { .. }) {
313            // Take ownership of the state.
314            let state =
315                std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
316            if let UploadState::InProgress { mut upload, .. } = state
317                && let Ok(handle) = Handle::try_current()
318            {
319                handle.spawn(async move {
320                    let _ = upload.abort().await;
321                });
322            }
323        }
324    }
325}
326
/// Returned error from trying to upload a part.
/// Has the part_idx and buffer so we can pass
/// them to the retry logic.
struct UploadPutError {
    /// Zero-based index of the failed part.
    part_idx: u16,
    /// The part's data, retained so the part can be resubmitted on retry.
    buffer: Bytes,
    /// Underlying object-store error.
    source: OSError,
}
335
/// Error raised after exhausting the retry budget for "connection reset by
/// peer" failures during a multipart upload.
#[derive(Debug)]
struct ConnectionResetError {
    /// Human-readable description (includes the retry limit that was hit).
    message: String,
    /// The underlying error that triggered the final failure.
    source: Box<dyn std::error::Error + Send + Sync>,
}
341
// Marker impl: the required Debug and Display implementations exist, so no
// methods need overriding.
impl std::error::Error for ConnectionResetError {}
343
344impl std::fmt::Display for ConnectionResetError {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        write!(f, "{}: {}", self.message, self.source)
347    }
348}
349
impl AsyncWrite for ObjectWriter {
    /// Buffers as much of `buf` as fits in the staging buffer and, when the
    /// buffer fills, starts the multipart upload or submits the next part.
    /// Returns `Pending` only when zero bytes could be accepted (buffer full
    /// and the parallelism cap reached); wakers are registered by the
    /// `poll_tasks` calls.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
        // Drive any in-flight work first (also frees part-upload slots).
        self.as_mut().poll_tasks(cx)?;

        // Fill buffer up to remaining capacity.
        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
        self.cursor += bytes_to_write;

        // Rust needs a little help to borrow self mutably and immutably at the same time
        // through a Pin.
        let mut_self = &mut *self;

        // Instantiate next request, if available.
        if mut_self.buffer.capacity() == mut_self.buffer.len() {
            match &mut mut_self.state {
                UploadState::Started(store) => {
                    // First full buffer: begin creating the multipart upload.
                    let path = mut_self.path.clone();
                    let store = store.clone();
                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
                    self.state = UploadState::CreatingUpload(fut);
                }
                UploadState::InProgress {
                    upload,
                    part_idx,
                    futures,
                    ..
                } => {
                    // TODO: Make max concurrency configurable from storage options.
                    // Only submit below the parallelism cap; otherwise the
                    // buffer stays full and we report Pending below.
                    if futures.len() < max_upload_parallelism() {
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            *part_idx,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(
                            Self::put_part(upload.as_mut(), data, *part_idx, None)
                                .instrument(tracing::Span::current()),
                        );
                        *part_idx += 1;
                    }
                }
                _ => {}
            }
        }

        // Poll again so any future created above registers its waker.
        self.poll_tasks(cx)?;

        match bytes_to_write {
            // Nothing accepted: buffer full and no free part slot.
            0 => Poll::Pending,
            _ => Poll::Ready(Ok(bytes_to_write)),
        }
    }

    /// Ready once no part uploads are in flight. Note that data still in the
    /// staging buffer is not pushed out by flush; it is uploaded when the
    /// buffer fills or at shutdown.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.as_mut().poll_tasks(cx)?;

        match &self.state {
            UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
            UploadState::CreatingUpload(_)
            | UploadState::Completing(_)
            | UploadState::PuttingSingle(_) => Poll::Pending,
            UploadState::InProgress { futures, .. } => {
                if futures.is_empty() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            }
        }
    }

    /// Finishes the write: a single PUT if no multipart upload was started,
    /// otherwise uploads the remaining partial part and completes the
    /// multipart upload. Resolves `Ok` only once the state reaches `Done`.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        loop {
            self.as_mut().poll_tasks(cx)?;

            // Rust needs a little help to borrow self mutably and immutably at the same time
            // through a Pin.
            let mut_self = &mut *self;
            match &mut mut_self.state {
                UploadState::Done(_) => return Poll::Ready(Ok(())),
                UploadState::CreatingUpload(_)
                | UploadState::PuttingSingle(_)
                | UploadState::Completing(_) => return Poll::Pending,
                UploadState::Started(_) => {
                    // If we didn't start a multipart upload, we can just do a single put.
                    let part = std::mem::take(&mut mut_self.buffer);
                    let path = mut_self.path.clone();
                    self.state.started_to_putting_single(path, part);
                }
                UploadState::InProgress {
                    upload,
                    futures,
                    part_idx,
                } => {
                    // Flush final batch
                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
                        // We can just use `take` since we don't need the buffer anymore.
                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
                        futures.spawn(
                            Self::put_part(upload.as_mut(), data, *part_idx, None)
                                .instrument(tracing::Span::current()),
                        );
                        // We need to go back to beginning of loop to poll the
                        // new future and get the waker registered on the ctx.
                        continue;
                    }

                    // We handle the transition from in progress to completing here.
                    if futures.is_empty() {
                        self.state.in_progress_to_completing();
                    } else {
                        return Poll::Pending;
                    }
                }
            }
        }
    }
}
480
481#[async_trait]
482impl Writer for ObjectWriter {
483    async fn tell(&mut self) -> Result<usize> {
484        Ok(self.cursor)
485    }
486
487    async fn shutdown(&mut self) -> Result<WriteResult> {
488        AsyncWriteExt::shutdown(self).await.map_err(|e| {
489            Error::io(format!(
490                "failed to shutdown object writer for {}: {}",
491                self.path, e
492            ))
493        })?;
494        if let UploadState::Done(result) = &self.state {
495            Ok(result.clone())
496        } else {
497            unreachable!()
498        }
499    }
500}
501
/// Writer for the local filesystem. Data goes to a temp file that is
/// persisted to `path` only at shutdown, so readers never observe a
/// partially written file.
pub struct LocalWriter {
    /// Final destination path.
    path: Path,
    /// Current phase; see [`LocalWriteState`].
    state: LocalWriteState,
}
506
/// State machine for [`LocalWriter`].
#[derive(Default)]
enum LocalWriteState {
    /// Actively buffering writes into the temp file.
    Writing(WritingState),
    /// Shutdown has begun: the temp file is being persisted to the final path.
    Finishing {
        /// Byte count captured at shutdown; reported by `tell()`.
        size: usize,
        /// Future resolving to the final write result.
        future: BoxFuture<'static, Result<WriteResult>>,
    },
    /// Persisted successfully; holds the final result.
    Done(WriteResult),
    /// Transient placeholder used while transitioning out of `Writing`;
    /// observed by callers only if that transition was interrupted.
    #[default]
    Poisoned,
}
518
/// Live write state: the buffered file handle plus bookkeeping.
struct WritingState {
    /// Buffered writer over the temp file.
    writer: tokio::io::BufWriter<tokio::fs::File>,
    /// Bytes written so far.
    cursor: usize,
    /// Temp path that auto-deletes on drop; ownership is moved into
    /// `persist()` during shutdown, which renames it to the final path.
    temp_path: tempfile::TempPath,
    /// Records write IOPS and byte counts for stats.
    io_tracker: Arc<IOTracker>,
}
526
impl LocalWriter {
    /// Wraps an already-open temp `file` that will be persisted to `path` on
    /// shutdown. `temp_path` owns the temp file's on-disk name and deletes
    /// it if the writer is dropped before shutdown completes.
    pub fn new(
        file: tokio::fs::File,
        path: Path,
        temp_path: tempfile::TempPath,
        io_tracker: Arc<IOTracker>,
    ) -> Self {
        Self {
            path,
            state: LocalWriteState::Writing(WritingState {
                writer: tokio::io::BufWriter::new(file),
                cursor: 0,
                temp_path,
                io_tracker,
            }),
        }
    }

    /// Error returned for writes/flushes attempted after shutdown.
    fn already_closed_err(path: &Path) -> io::Error {
        io::Error::other(format!(
            "cannot write to LocalWriter for {} after shutdown",
            path
        ))
    }

    /// Error returned when a state transition was interrupted mid-way.
    fn poisoned_err(path: &Path) -> io::Error {
        io::Error::other(format!("LocalWriter for {} is in poisoned state", path))
    }

    /// Renames the temp file to `final_path` (via `TempPath::persist`),
    /// computes an ETag from the resulting metadata, and records the write
    /// in the IO tracker. The blocking filesystem work runs on the blocking
    /// thread pool.
    async fn persist(
        temp_path: tempfile::TempPath,
        final_path: Path,
        size: usize,
        io_tracker: Arc<IOTracker>,
    ) -> Result<WriteResult> {
        let local_path = crate::local::to_local_path(&final_path);
        let e_tag = tokio::task::spawn_blocking(move || -> Result<String> {
            temp_path.persist(&local_path).map_err(|e| {
                Error::io(format!(
                    "failed to persist temp file to {}: {}",
                    local_path, e.error
                ))
            })?;

            let metadata = std::fs::metadata(&local_path).map_err(|e| {
                Error::io(format!("failed to read metadata for {}: {}", local_path, e))
            })?;
            Ok(get_etag(&metadata))
        })
        .await
        .map_err(|e| Error::io(format!("spawn_blocking failed: {}", e)))??;

        io_tracker.record_write("put", final_path, size as u64);

        Ok(WriteResult {
            size,
            e_tag: Some(e_tag),
        })
    }
}
587
588impl AsyncWrite for LocalWriter {
589    fn poll_write(
590        mut self: Pin<&mut Self>,
591        cx: &mut std::task::Context<'_>,
592        buf: &[u8],
593    ) -> Poll<std::result::Result<usize, std::io::Error>> {
594        if let LocalWriteState::Writing(state) = &mut self.state {
595            let poll = Pin::new(&mut state.writer).poll_write(cx, buf);
596            if let Poll::Ready(Ok(n)) = &poll {
597                state.cursor += *n;
598            }
599            poll
600        } else {
601            Poll::Ready(Err(Self::already_closed_err(&self.path)))
602        }
603    }
604
605    fn poll_flush(
606        mut self: Pin<&mut Self>,
607        cx: &mut std::task::Context<'_>,
608    ) -> Poll<std::result::Result<(), std::io::Error>> {
609        if let LocalWriteState::Writing(state) = &mut self.state {
610            Pin::new(&mut state.writer).poll_flush(cx)
611        } else {
612            Poll::Ready(Err(Self::already_closed_err(&self.path)))
613        }
614    }
615
616    fn poll_shutdown(
617        mut self: Pin<&mut Self>,
618        cx: &mut std::task::Context<'_>,
619    ) -> Poll<std::result::Result<(), std::io::Error>> {
620        let mut_self = &mut *self;
621        loop {
622            match &mut mut_self.state {
623                LocalWriteState::Writing(state) => {
624                    if Pin::new(&mut state.writer).poll_shutdown(cx).is_pending() {
625                        return Poll::Pending;
626                    }
627
628                    // Write is complete, we can transition to persisting.
629                    let LocalWriteState::Writing(state) =
630                        std::mem::replace(&mut mut_self.state, LocalWriteState::Poisoned)
631                    else {
632                        unreachable!()
633                    };
634                    let size = state.cursor;
635                    mut_self.state = LocalWriteState::Finishing {
636                        size,
637                        future: Box::pin(Self::persist(
638                            state.temp_path,
639                            mut_self.path.clone(),
640                            size,
641                            state.io_tracker,
642                        )),
643                    };
644                }
645                LocalWriteState::Finishing { future, .. } => match future.poll_unpin(cx) {
646                    Poll::Ready(Ok(result)) => mut_self.state = LocalWriteState::Done(result),
647                    Poll::Ready(Err(e)) => {
648                        return Poll::Ready(Err(io::Error::other(e)));
649                    }
650                    Poll::Pending => return Poll::Pending,
651                },
652                LocalWriteState::Done(_) => return Poll::Ready(Ok(())),
653                LocalWriteState::Poisoned => {
654                    return Poll::Ready(Err(Self::poisoned_err(&self.path)));
655                }
656            }
657        }
658    }
659}
660
661#[async_trait]
662impl Writer for LocalWriter {
663    async fn tell(&mut self) -> Result<usize> {
664        match &mut self.state {
665            LocalWriteState::Writing(state) => Ok(state.cursor),
666            LocalWriteState::Finishing { size, .. } => Ok(*size),
667            LocalWriteState::Done(result) => Ok(result.size),
668            LocalWriteState::Poisoned => Err(Self::poisoned_err(&self.path).into()),
669        }
670    }
671
672    async fn shutdown(&mut self) -> Result<WriteResult> {
673        AsyncWriteExt::shutdown(self).await.map_err(|e| {
674            Error::io(format!(
675                "failed to shutdown local writer for {}: {}",
676                self.path, e
677            ))
678        })?;
679
680        match &self.state {
681            LocalWriteState::Done(result) => Ok(result.clone()),
682            _ => unreachable!(),
683        }
684    }
685}
686
// Based on object store's implementation.
/// Synthesizes an HTTP-style ETag from local file metadata: inode, mtime
/// (microseconds since the Unix epoch), and size, each hex-encoded and
/// joined with hyphens. This mirrors the scheme used by many popular HTTP
/// servers (<https://httpd.apache.org/docs/2.2/mod/core.html#fileetag>).
pub fn get_etag(metadata: &std::fs::Metadata) -> String {
    let size = metadata.len();
    // Missing/undeterminable mtime (or one before the epoch) degrades to 0.
    let mtime = match metadata.modified() {
        Ok(t) => t
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_micros(),
        Err(_) => 0,
    };
    format!("{:x}-{:x}-{:x}", get_inode(metadata), mtime, size)
}

/// Inode number of the file; the first ETag component on Unix.
#[cfg(unix)]
fn get_inode(metadata: &std::fs::Metadata) -> u64 {
    use std::os::unix::fs::MetadataExt;
    metadata.ino()
}

/// Non-Unix platforms have no inode; a constant placeholder is used.
#[cfg(not(unix))]
fn get_inode(_metadata: &std::fs::Metadata) -> u64 {
    0
}
712
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Small writes stay under the 5MB buffer (single PUT path); the second
    /// half writes ~3.3MB five times to force the multipart path.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let buf = vec![0; 256];
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 512);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256 * 3);

        let res = Writer::shutdown(&mut object_writer).await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Trigger multi part upload
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let buf = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for i in 0..5 {
            // Write more data to trigger the multipart upload
            // This should be enough to trigger a multipart upload
            object_writer.write_all(buf.as_slice()).await.unwrap();
            // Check the cursor
            assert_eq!(object_writer.tell().await.unwrap(), (i + 1) * buf.len());
        }
        let res = Writer::shutdown(&mut object_writer).await.unwrap();
        assert_eq!(res.size, buf.len() * 5);
    }

    /// Aborting before any data is written must be a no-op (no panic).
    #[tokio::test]
    async fn test_abort_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        object_writer.abort().await;
    }

    /// Shutdown must persist the temp file to the final path, return size +
    /// ETag, and record exactly one write in the IO tracker.
    #[tokio::test]
    async fn test_local_writer_shutdown() {
        let tmp = lance_core::utils::tempfile::TempStdDir::default();
        let file_path = tmp.join("test_local_writer.bin");
        let os_path = Path::from_absolute_path(&file_path).unwrap();
        let io_tracker = Arc::new(IOTracker::default());

        let named_temp = tempfile::NamedTempFile::new_in(&*tmp).unwrap();
        let temp_file_path = named_temp.path().to_owned();
        let (std_file, temp_path) = named_temp.into_parts();
        let file = tokio::fs::File::from_std(std_file);
        let mut writer = LocalWriter::new(file, os_path, temp_path, io_tracker.clone());

        let data = b"hello local writer";
        writer.write_all(data).await.unwrap();

        // Before shutdown, the final path should not exist
        assert!(!file_path.exists());
        // But the temp file should exist
        assert!(temp_file_path.exists());

        let result = Writer::shutdown(&mut writer).await.unwrap();
        assert_eq!(result.size, data.len());
        assert!(result.e_tag.is_some());
        assert!(!result.e_tag.as_ref().unwrap().is_empty());

        // After shutdown, the final path should exist and temp should be gone
        assert!(file_path.exists());
        assert!(!temp_file_path.exists());

        let stats = io_tracker.stats();
        assert_eq!(stats.write_iops, 1);
        assert_eq!(stats.written_bytes, data.len() as u64);
    }

    /// Dropping without shutdown must delete the temp file and never create
    /// the final file.
    #[tokio::test]
    async fn test_local_writer_drop_cleans_up() {
        let tmp = lance_core::utils::tempfile::TempStdDir::default();
        let file_path = tmp.join("test_drop.bin");
        let os_path = Path::from_absolute_path(&file_path).unwrap();
        let io_tracker = Arc::new(IOTracker::default());

        let named_temp = tempfile::NamedTempFile::new_in(&*tmp).unwrap();
        let temp_file_path = named_temp.path().to_owned();
        let (std_file, temp_path) = named_temp.into_parts();
        let file = tokio::fs::File::from_std(std_file);
        let mut writer = LocalWriter::new(file, os_path, temp_path, io_tracker);

        writer.write_all(b"some data").await.unwrap();
        assert!(temp_file_path.exists());

        // Drop without shutdown should clean up the temp file
        drop(writer);
        assert!(!temp_file_path.exists());
        assert!(!file_path.exists());
    }
}
823}