Skip to main content

lance_io/
object_writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::FutureExt;
13use futures::future::BoxFuture;
14use object_store::{Error as OSError, ObjectStore, Result as OSResult, path::Path};
15use object_store::{MultipartUpload, ObjectStoreExt};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use crate::utils::tracking_store::IOTracker;
25use tokio::runtime::Handle;
26
/// Smallest part size and part-size growth increment for multipart uploads:
/// 5 MiB (5 * 2^20 bytes). Also the default size of the first buffer.
const INITIAL_UPLOAD_STEP: usize = 5 << 20;
29
/// Maximum number of part uploads an `ObjectWriter` keeps in flight at once.
///
/// Read once per process from `LANCE_UPLOAD_CONCURRENCY`; see
/// [`parse_upload_parallelism`] for how the value is interpreted.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        parse_upload_parallelism(std::env::var("LANCE_UPLOAD_CONCURRENCY").ok().as_deref())
    })
}

/// Interprets a raw `LANCE_UPLOAD_CONCURRENCY` value.
///
/// Returns the parsed value when it is a positive integer, and the default of 10
/// otherwise (unset, unparsable, or zero). Zero is rejected because a limit of 0
/// never lets a part be submitted: writes would stall forever and shutdown would
/// complete the upload without its final buffered part.
fn parse_upload_parallelism(raw: Option<&str>) -> usize {
    raw.and_then(|s| s.parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(10)
}
39
/// Maximum number of retryable part-upload failures that are resubmitted
/// before the writer gives up.
///
/// Read once per process from `LANCE_CONN_RESET_RETRIES`; falls back to 20 when
/// the variable is unset or is not a valid `u16`.
fn max_conn_reset_retries() -> u16 {
    static CACHED_RETRY_LIMIT: OnceLock<u16> = OnceLock::new();
    let limit = CACHED_RETRY_LIMIT.get_or_init(|| match std::env::var("LANCE_CONN_RESET_RETRIES") {
        Ok(raw) => raw.parse::<u16>().unwrap_or(20),
        Err(_) => 20,
    });
    *limit
}
49
50/// Maximum part size in GCS and S3: 5GB.
51const MAX_UPLOAD_PART_SIZE: usize = 1024 * 1024 * 1024 * 5;
52
53/// Clamps a requested upload part size to the valid [5MB, 5GB] range.
54/// Returns the clamped value and whether clamping was necessary.
55fn clamp_initial_upload_size(raw: usize) -> (usize, bool) {
56    let clamped = raw.clamp(INITIAL_UPLOAD_STEP, MAX_UPLOAD_PART_SIZE);
57    (clamped, clamped != raw)
58}
59
60fn initial_upload_size() -> usize {
61    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
62    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
63        let Some(raw) = std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
64            .ok()
65            .and_then(|s| s.parse::<usize>().ok())
66        else {
67            return INITIAL_UPLOAD_STEP;
68        };
69        let (clamped, was_clamped) = clamp_initial_upload_size(raw);
70        if was_clamped {
71            // OnceLock caches the result, so this warning fires at most once per process.
72            tracing::warn!(
73                requested = raw,
74                clamped,
75                "LANCE_INITIAL_UPLOAD_SIZE must be between 5MB and 5GB; clamping to valid range"
76            );
77        }
78        clamped
79    })
80}
81
/// Writer to an object in an object store.
///
/// If the object is small enough, the writer will upload the object in a single
/// PUT request. If the object is larger, the writer will create a multipart
/// upload and upload parts in parallel.
///
/// This implements the `AsyncWrite` trait (and the crate's `Writer` trait).
pub struct ObjectWriter {
    /// Current stage of the upload; see `UploadState`.
    state: UploadState,
    /// Destination path of the object.
    path: Arc<Path>,
    /// Total number of bytes accepted by `poll_write` so far. Returned by
    /// `tell()` and used as the final `WriteResult::size`.
    cursor: usize,
    /// Number of retryable part-upload failures resubmitted so far; compared
    /// against `max_conn_reset_retries()`.
    connection_resets: u16,
    /// Bytes waiting to be sent as the next part (or as the single PUT body).
    /// Its capacity is the target size of that part.
    buffer: Vec<u8>,
    /// When true, every part uses `initial_upload_size()` bytes instead of
    /// growing with the part index, for stores that do not support variable part
    /// sizes (presumably e.g. R2 — confirm). Copied from the store in `new`.
    use_constant_size_upload_parts: bool,
}
98
/// Outcome of a completed write.
#[derive(Debug, Clone, Default)]
pub struct WriteResult {
    /// Total number of bytes written.
    pub size: usize,
    /// ETag reported for the written object, if any.
    pub e_tag: Option<String>,
}
104
/// Stages of an `ObjectWriter` upload.
///
/// Progressions driven by the writer:
/// - small object: `Started` -> `PuttingSingle` -> `Done`
/// - large object: `Started` -> `CreatingUpload` -> `InProgress` -> `Completing` -> `Done`
enum UploadState {
    /// The writer has been opened but no data has been written yet. Will be in
    /// this state until the buffer is full or the writer is shut down.
    Started(Arc<dyn ObjectStore>),
    /// The writer is in the process of creating a multipart upload.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// The writer is in the process of uploading parts.
    InProgress {
        /// Index to assign to the next submitted part.
        part_idx: u16,
        /// Handle to the multipart upload that parts are submitted to.
        upload: Box<dyn MultipartUpload>,
        /// Part uploads currently in flight.
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// The writer is in the process of uploading data in a single PUT request.
    /// This happens when shutdown is called before the buffer is full.
    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
    /// The writer is in the process of completing the multipart upload.
    Completing(BoxFuture<'static, OSResult<WriteResult>>),
    /// The writer has been shut down and all data has been written.
    Done(WriteResult),
}
125
126/// Methods for state transitions.
127impl UploadState {
128    fn started_to_putting_single(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
129        // To get owned self, we temporarily swap with Done.
130        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
131        *self = match this {
132            Self::Started(store) => {
133                let fut = async move {
134                    let size = buffer.len();
135                    let res = store.put(&path, buffer.into()).await?;
136                    Ok(WriteResult {
137                        size,
138                        e_tag: res.e_tag,
139                    })
140                };
141                Self::PuttingSingle(Box::pin(fut))
142            }
143            _ => unreachable!(),
144        }
145    }
146
147    fn in_progress_to_completing(&mut self) {
148        // To get owned self, we temporarily swap with Done.
149        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
150        *self = match this {
151            Self::InProgress {
152                mut upload,
153                futures,
154                ..
155            } => {
156                debug_assert!(futures.is_empty());
157                let fut = async move {
158                    let res = upload.complete().await?;
159                    Ok(WriteResult {
160                        size: 0, // This will be set properly later.
161                        e_tag: res.e_tag,
162                    })
163                };
164                Self::Completing(Box::pin(fut))
165            }
166            _ => unreachable!(),
167        };
168    }
169}
170
171impl ObjectWriter {
172    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
173        Ok(Self {
174            state: UploadState::Started(object_store.inner.clone()),
175            cursor: 0,
176            path: Arc::new(path.clone()),
177            connection_resets: 0,
178            buffer: Vec::with_capacity(initial_upload_size()),
179            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
180        })
181    }
182
183    /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`.
184    /// The new capacity of `buffer` is determined by the current part index.
185    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
186        let new_capacity = if constant_upload_size {
187            // The store does not support variable part sizes, so use the initial size.
188            initial_upload_size()
189        } else {
190            // Increase the upload size every 100 parts. This gives maximum part size of 2.5TB.
191            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
192        };
193        let new_buffer = Vec::with_capacity(new_capacity);
194        let part = std::mem::replace(buffer, new_buffer);
195        Bytes::from(part)
196    }
197
198    fn put_part(
199        upload: &mut dyn MultipartUpload,
200        buffer: Bytes,
201        part_idx: u16,
202        sleep: Option<std::time::Duration>,
203    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
204        log::debug!(
205            "MultipartUpload submitting part with {} bytes",
206            buffer.len()
207        );
208        let fut = upload.put_part(buffer.clone().into());
209        Box::pin(async move {
210            if let Some(sleep) = sleep {
211                tokio::time::sleep(sleep).await;
212            }
213            fut.await.map_err(|source| UploadPutError {
214                part_idx,
215                buffer,
216                source,
217            })?;
218            Ok(())
219        })
220    }
221
222    fn poll_tasks(
223        mut self: Pin<&mut Self>,
224        cx: &mut std::task::Context<'_>,
225    ) -> std::result::Result<(), io::Error> {
226        let mut_self = &mut *self;
227        loop {
228            match &mut mut_self.state {
229                UploadState::Started(_) | UploadState::Done(_) => break,
230                UploadState::CreatingUpload(fut) => match fut.poll_unpin(cx) {
231                    Poll::Ready(Ok(mut upload)) => {
232                        let mut futures = JoinSet::new();
233
234                        let data = Self::next_part_buffer(
235                            &mut mut_self.buffer,
236                            0,
237                            mut_self.use_constant_size_upload_parts,
238                        );
239                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
240
241                        mut_self.state = UploadState::InProgress {
242                            part_idx: 1, // We just used 0
243                            futures,
244                            upload,
245                        };
246                    }
247                    Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
248                    Poll::Pending => break,
249                },
250                UploadState::InProgress {
251                    upload, futures, ..
252                } => {
253                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
254                        match res {
255                            Ok(Ok(())) => {}
256                            Err(err) => return Err(std::io::Error::other(err)),
257                            Ok(Err(err)) if should_retry_upload_put(&err.source) => {
258                                if mut_self.connection_resets < max_conn_reset_retries() {
259                                    // Retry, but only up to max_conn_reset_retries of them.
260                                    mut_self.connection_resets += 1;
261
262                                    // Resubmit with random jitter
263                                    let sleep_time_ms = rand::rng().random_range(2_000..8_000);
264                                    let sleep_time =
265                                        std::time::Duration::from_millis(sleep_time_ms);
266
267                                    futures.spawn(Self::put_part(
268                                        upload.as_mut(),
269                                        err.buffer,
270                                        err.part_idx,
271                                        Some(sleep_time),
272                                    ));
273                                } else {
274                                    return Err(io::Error::new(
275                                        io::ErrorKind::ConnectionReset,
276                                        Box::new(ConnectionResetError {
277                                            message: format!(
278                                                "Hit max retries ({}) for retryable upload error",
279                                                max_conn_reset_retries()
280                                            ),
281                                            source: Box::new(err.source),
282                                        }),
283                                    ));
284                                }
285                            }
286                            Ok(Err(err)) => return Err(err.source.into()),
287                        }
288                    }
289                    break;
290                }
291                UploadState::PuttingSingle(fut) | UploadState::Completing(fut) => {
292                    match fut.poll_unpin(cx) {
293                        Poll::Ready(Ok(mut res)) => {
294                            res.size = mut_self.cursor;
295                            mut_self.state = UploadState::Done(res)
296                        }
297                        Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
298                        Poll::Pending => break,
299                    }
300                }
301            }
302        }
303        Ok(())
304    }
305
306    pub async fn abort(&mut self) {
307        let state = std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
308        if let UploadState::InProgress { mut upload, .. } = state {
309            let _ = upload.abort().await;
310        }
311    }
312}
313
314impl Drop for ObjectWriter {
315    fn drop(&mut self) {
316        // If there is a multipart upload started but not finished, we should abort it.
317        if matches!(self.state, UploadState::InProgress { .. }) {
318            // Take ownership of the state.
319            let state =
320                std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
321            if let UploadState::InProgress { mut upload, .. } = state
322                && let Ok(handle) = Handle::try_current()
323            {
324                handle.spawn(async move {
325                    let _ = upload.abort().await;
326                });
327            }
328        }
329    }
330}
331
/// Returned error from trying to upload a part.
/// Has the part_idx and buffer so we can pass
/// them to the retry logic.
struct UploadPutError {
    /// Index the failed part was submitted with.
    part_idx: u16,
    /// The part's data, kept so it can be resubmitted.
    buffer: Bytes,
    /// The error reported by the object store.
    source: OSError,
}
340
341fn should_retry_upload_put(source: &OSError) -> bool {
342    let OSError::Generic { source, .. } = source else {
343        return false;
344    };
345
346    let message = source.to_string().to_ascii_lowercase();
347    message.contains("connection reset by peer") || message.contains("requesttimeout")
348}
349
/// Error produced after a retryable part-upload failure has been resubmitted the
/// maximum number of times without success.
///
/// `Display` renders as `"{message}: {source}"`. The wrapped error is also exposed
/// through `std::error::Error::source`, so callers walking the error chain can
/// inspect it.
#[derive(Debug)]
struct ConnectionResetError {
    /// Human-readable summary, e.g. which retry limit was hit.
    message: String,
    /// The last upload error, which caused the writer to give up.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::error::Error for ConnectionResetError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        Some(&*self.source)
    }
}

impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}
363
364impl AsyncWrite for ObjectWriter {
365    fn poll_write(
366        mut self: std::pin::Pin<&mut Self>,
367        cx: &mut std::task::Context<'_>,
368        buf: &[u8],
369    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
370        self.as_mut().poll_tasks(cx)?;
371
372        // Fill buffer up to remaining capacity.
373        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
374        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
375        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
376        self.cursor += bytes_to_write;
377
378        // Rust needs a little help to borrow self mutably and immutably at the same time
379        // through a Pin.
380        let mut_self = &mut *self;
381
382        // Instantiate next request, if available.
383        if mut_self.buffer.capacity() == mut_self.buffer.len() {
384            match &mut mut_self.state {
385                UploadState::Started(store) => {
386                    let path = mut_self.path.clone();
387                    let store = store.clone();
388                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
389                    self.state = UploadState::CreatingUpload(fut);
390                }
391                UploadState::InProgress {
392                    upload,
393                    part_idx,
394                    futures,
395                    ..
396                } => {
397                    // TODO: Make max concurrency configurable from storage options.
398                    if futures.len() < max_upload_parallelism() {
399                        let data = Self::next_part_buffer(
400                            &mut mut_self.buffer,
401                            *part_idx,
402                            mut_self.use_constant_size_upload_parts,
403                        );
404                        futures.spawn(
405                            Self::put_part(upload.as_mut(), data, *part_idx, None)
406                                .instrument(tracing::Span::current()),
407                        );
408                        *part_idx += 1;
409                    }
410                }
411                _ => {}
412            }
413        }
414
415        self.poll_tasks(cx)?;
416
417        match bytes_to_write {
418            0 => Poll::Pending,
419            _ => Poll::Ready(Ok(bytes_to_write)),
420        }
421    }
422
423    fn poll_flush(
424        mut self: std::pin::Pin<&mut Self>,
425        cx: &mut std::task::Context<'_>,
426    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
427        self.as_mut().poll_tasks(cx)?;
428
429        match &self.state {
430            UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
431            UploadState::CreatingUpload(_)
432            | UploadState::Completing(_)
433            | UploadState::PuttingSingle(_) => Poll::Pending,
434            UploadState::InProgress { futures, .. } => {
435                if futures.is_empty() {
436                    Poll::Ready(Ok(()))
437                } else {
438                    Poll::Pending
439                }
440            }
441        }
442    }
443
444    fn poll_shutdown(
445        mut self: std::pin::Pin<&mut Self>,
446        cx: &mut std::task::Context<'_>,
447    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
448        loop {
449            self.as_mut().poll_tasks(cx)?;
450
451            // Rust needs a little help to borrow self mutably and immutably at the same time
452            // through a Pin.
453            let mut_self = &mut *self;
454            match &mut mut_self.state {
455                UploadState::Done(_) => return Poll::Ready(Ok(())),
456                UploadState::CreatingUpload(_)
457                | UploadState::PuttingSingle(_)
458                | UploadState::Completing(_) => return Poll::Pending,
459                UploadState::Started(_) => {
460                    // If we didn't start a multipart upload, we can just do a single put.
461                    let part = std::mem::take(&mut mut_self.buffer);
462                    let path = mut_self.path.clone();
463                    self.state.started_to_putting_single(path, part);
464                }
465                UploadState::InProgress {
466                    upload,
467                    futures,
468                    part_idx,
469                } => {
470                    // Flush final batch
471                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
472                        // We can just use `take` since we don't need the buffer anymore.
473                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
474                        futures.spawn(
475                            Self::put_part(upload.as_mut(), data, *part_idx, None)
476                                .instrument(tracing::Span::current()),
477                        );
478                        // We need to go back to beginning of loop to poll the
479                        // new feature and get the waker registered on the ctx.
480                        continue;
481                    }
482
483                    // We handle the transition from in progress to completing here.
484                    if futures.is_empty() {
485                        self.state.in_progress_to_completing();
486                    } else {
487                        return Poll::Pending;
488                    }
489                }
490            }
491        }
492    }
493}
494
495#[async_trait]
496impl Writer for ObjectWriter {
497    async fn tell(&mut self) -> Result<usize> {
498        Ok(self.cursor)
499    }
500
501    async fn shutdown(&mut self) -> Result<WriteResult> {
502        AsyncWriteExt::shutdown(self).await.map_err(|e| {
503            Error::io(format!(
504                "failed to shutdown object writer for {}: {}",
505                self.path, e
506            ))
507        })?;
508        if let UploadState::Done(result) = &self.state {
509            Ok(result.clone())
510        } else {
511            unreachable!()
512        }
513    }
514}
515
/// Writer for files on the local filesystem.
///
/// Data is written to a temporary file, which is only moved to `path` during
/// shutdown. Dropping the writer before shutdown discards the temporary file.
pub struct LocalWriter {
    /// Final object-store path of the file.
    path: Path,
    /// Current stage of the write; see `LocalWriteState`.
    state: LocalWriteState,
}
520
/// Stages of a `LocalWriter`: `Writing` -> `Finishing` -> `Done`.
#[derive(Default)]
enum LocalWriteState {
    /// Accepting writes into the temporary file.
    Writing(WritingState),
    /// Buffered data has been flushed and the temp file is being persisted.
    Finishing {
        /// Total number of bytes written; reported by `tell()` while persisting.
        size: usize,
        /// Future that persists the temp file to the final path.
        future: BoxFuture<'static, Result<WriteResult>>,
    },
    /// The file has been persisted to its final path.
    Done(WriteResult),
    /// Placeholder left behind while moving out of `Writing`. Observed only if a
    /// transition did not complete.
    #[default]
    Poisoned,
}
532
/// State held while a `LocalWriter` is accepting writes.
struct WritingState {
    /// Buffered writer over the temporary file.
    writer: tokio::io::BufWriter<tokio::fs::File>,
    /// Total number of bytes accepted so far.
    cursor: usize,
    /// Temp path that auto-deletes on drop. Moved into `LocalWriter::persist()`,
    /// which renames the file to its final location.
    temp_path: tempfile::TempPath,
    /// Receives a `put` record once the file has been persisted.
    io_tracker: Arc<IOTracker>,
}
540
541impl LocalWriter {
542    pub fn new(
543        file: tokio::fs::File,
544        path: Path,
545        temp_path: tempfile::TempPath,
546        io_tracker: Arc<IOTracker>,
547    ) -> Self {
548        Self {
549            path,
550            state: LocalWriteState::Writing(WritingState {
551                writer: tokio::io::BufWriter::new(file),
552                cursor: 0,
553                temp_path,
554                io_tracker,
555            }),
556        }
557    }
558
559    fn already_closed_err(path: &Path) -> io::Error {
560        io::Error::other(format!(
561            "cannot write to LocalWriter for {} after shutdown",
562            path
563        ))
564    }
565
566    fn poisoned_err(path: &Path) -> io::Error {
567        io::Error::other(format!("LocalWriter for {} is in poisoned state", path))
568    }
569
570    async fn persist(
571        temp_path: tempfile::TempPath,
572        final_path: Path,
573        size: usize,
574        io_tracker: Arc<IOTracker>,
575    ) -> Result<WriteResult> {
576        let local_path = crate::local::to_local_path(&final_path);
577        let e_tag = tokio::task::spawn_blocking(move || -> Result<String> {
578            temp_path.persist(&local_path).map_err(|e| {
579                Error::io(format!(
580                    "failed to persist temp file to {}: {}",
581                    local_path, e.error
582                ))
583            })?;
584
585            let metadata = std::fs::metadata(&local_path).map_err(|e| {
586                Error::io(format!("failed to read metadata for {}: {}", local_path, e))
587            })?;
588            Ok(get_etag(&metadata))
589        })
590        .await
591        .map_err(|e| Error::io(format!("spawn_blocking failed: {}", e)))??;
592
593        io_tracker.record_write("put", final_path, size as u64);
594
595        Ok(WriteResult {
596            size,
597            e_tag: Some(e_tag),
598        })
599    }
600}
601
602impl AsyncWrite for LocalWriter {
603    fn poll_write(
604        mut self: Pin<&mut Self>,
605        cx: &mut std::task::Context<'_>,
606        buf: &[u8],
607    ) -> Poll<std::result::Result<usize, std::io::Error>> {
608        if let LocalWriteState::Writing(state) = &mut self.state {
609            let poll = Pin::new(&mut state.writer).poll_write(cx, buf);
610            if let Poll::Ready(Ok(n)) = &poll {
611                state.cursor += *n;
612            }
613            poll
614        } else {
615            Poll::Ready(Err(Self::already_closed_err(&self.path)))
616        }
617    }
618
619    fn poll_flush(
620        mut self: Pin<&mut Self>,
621        cx: &mut std::task::Context<'_>,
622    ) -> Poll<std::result::Result<(), std::io::Error>> {
623        if let LocalWriteState::Writing(state) = &mut self.state {
624            Pin::new(&mut state.writer).poll_flush(cx)
625        } else {
626            Poll::Ready(Err(Self::already_closed_err(&self.path)))
627        }
628    }
629
630    fn poll_shutdown(
631        mut self: Pin<&mut Self>,
632        cx: &mut std::task::Context<'_>,
633    ) -> Poll<std::result::Result<(), std::io::Error>> {
634        let mut_self = &mut *self;
635        loop {
636            match &mut mut_self.state {
637                LocalWriteState::Writing(state) => {
638                    if Pin::new(&mut state.writer).poll_shutdown(cx).is_pending() {
639                        return Poll::Pending;
640                    }
641
642                    // Write is complete, we can transition to persisting.
643                    let LocalWriteState::Writing(state) =
644                        std::mem::replace(&mut mut_self.state, LocalWriteState::Poisoned)
645                    else {
646                        unreachable!()
647                    };
648                    let size = state.cursor;
649                    mut_self.state = LocalWriteState::Finishing {
650                        size,
651                        future: Box::pin(Self::persist(
652                            state.temp_path,
653                            mut_self.path.clone(),
654                            size,
655                            state.io_tracker,
656                        )),
657                    };
658                }
659                LocalWriteState::Finishing { future, .. } => match future.poll_unpin(cx) {
660                    Poll::Ready(Ok(result)) => mut_self.state = LocalWriteState::Done(result),
661                    Poll::Ready(Err(e)) => {
662                        return Poll::Ready(Err(io::Error::other(e)));
663                    }
664                    Poll::Pending => return Poll::Pending,
665                },
666                LocalWriteState::Done(_) => return Poll::Ready(Ok(())),
667                LocalWriteState::Poisoned => {
668                    return Poll::Ready(Err(Self::poisoned_err(&self.path)));
669                }
670            }
671        }
672    }
673}
674
675#[async_trait]
676impl Writer for LocalWriter {
677    async fn tell(&mut self) -> Result<usize> {
678        match &mut self.state {
679            LocalWriteState::Writing(state) => Ok(state.cursor),
680            LocalWriteState::Finishing { size, .. } => Ok(*size),
681            LocalWriteState::Done(result) => Ok(result.size),
682            LocalWriteState::Poisoned => Err(Self::poisoned_err(&self.path).into()),
683        }
684    }
685
686    async fn shutdown(&mut self) -> Result<WriteResult> {
687        AsyncWriteExt::shutdown(self).await.map_err(|e| {
688            Error::io(format!(
689                "failed to shutdown local writer for {}: {}",
690                self.path, e
691            ))
692        })?;
693
694        match &self.state {
695            LocalWriteState::Done(result) => Ok(result.clone()),
696            _ => unreachable!(),
697        }
698    }
699}
700
701// Based on object store's implementation.
702pub fn get_etag(metadata: &std::fs::Metadata) -> String {
703    let inode = get_inode(metadata);
704    let size = metadata.len();
705    let mtime = metadata
706        .modified()
707        .ok()
708        .and_then(|mtime| mtime.duration_since(std::time::SystemTime::UNIX_EPOCH).ok())
709        .unwrap_or_default()
710        .as_micros();
711
712    // Use an ETag scheme based on that used by many popular HTTP servers
713    // <https://httpd.apache.org/docs/2.2/mod/core.html#fileetag>
714    format!("{inode:x}-{mtime:x}-{size:x}")
715}
716
/// Returns the inode number of the file described by `metadata`.
#[cfg(unix)]
fn get_inode(metadata: &std::fs::Metadata) -> u64 {
    use std::os::unix::fs::MetadataExt;
    metadata.ino()
}

/// Platforms without inode numbers always report 0.
#[cfg(not(unix))]
fn get_inode(_metadata: &std::fs::Metadata) -> u64 {
    0
}
726
// Unit tests for `ObjectWriter`, `LocalWriter`, and the upload-size / retry helpers.
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    // Exercises both upload paths against the in-memory store: small writes that
    // end in a single PUT, and larger writes that go through a multipart upload.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        // 768 bytes total stays within the initial buffer (assuming the default
        // initial upload size), so shutdown performs a single PUT.
        let buf = vec![0; 256];
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 512);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256 * 3);

        let res = Writer::shutdown(&mut object_writer).await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Trigger multi part upload
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        // Each chunk is about 2/3 of INITIAL_UPLOAD_STEP, so five of them overflow
        // the first buffer and subsequent part buffers.
        let buf = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for i in 0..5 {
            // Write more data to trigger the multipart upload
            // This should be enough to trigger a multipart upload
            object_writer.write_all(buf.as_slice()).await.unwrap();
            // Check the cursor
            assert_eq!(object_writer.tell().await.unwrap(), (i + 1) * buf.len());
        }
        let res = Writer::shutdown(&mut object_writer).await.unwrap();
        assert_eq!(res.size, buf.len() * 5);
    }

    // Aborting a writer that never started a multipart upload must not fail.
    #[tokio::test]
    async fn test_abort_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        object_writer.abort().await;
    }

    // Shutdown persists the temp file to the final path, reports size and ETag,
    // and records exactly one write with the IO tracker.
    #[tokio::test]
    async fn test_local_writer_shutdown() {
        let tmp = lance_core::utils::tempfile::TempStdDir::default();
        let file_path = tmp.join("test_local_writer.bin");
        let os_path = Path::from_absolute_path(&file_path).unwrap();
        let io_tracker = Arc::new(IOTracker::default());

        let named_temp = tempfile::NamedTempFile::new_in(&*tmp).unwrap();
        let temp_file_path = named_temp.path().to_owned();
        let (std_file, temp_path) = named_temp.into_parts();
        let file = tokio::fs::File::from_std(std_file);
        let mut writer = LocalWriter::new(file, os_path, temp_path, io_tracker.clone());

        let data = b"hello local writer";
        writer.write_all(data).await.unwrap();

        // Before shutdown, the final path should not exist
        assert!(!file_path.exists());
        // But the temp file should exist
        assert!(temp_file_path.exists());

        let result = Writer::shutdown(&mut writer).await.unwrap();
        assert_eq!(result.size, data.len());
        assert!(result.e_tag.is_some());
        assert!(!result.e_tag.as_ref().unwrap().is_empty());

        // After shutdown, the final path should exist and temp should be gone
        assert!(file_path.exists());
        assert!(!temp_file_path.exists());

        let stats = io_tracker.stats();
        assert_eq!(stats.write_iops, 1);
        assert_eq!(stats.written_bytes, data.len() as u64);
    }

    // Dropping a LocalWriter without shutdown leaves neither the temp file nor
    // the final file behind.
    #[tokio::test]
    async fn test_local_writer_drop_cleans_up() {
        let tmp = lance_core::utils::tempfile::TempStdDir::default();
        let file_path = tmp.join("test_drop.bin");
        let os_path = Path::from_absolute_path(&file_path).unwrap();
        let io_tracker = Arc::new(IOTracker::default());

        let named_temp = tempfile::NamedTempFile::new_in(&*tmp).unwrap();
        let temp_file_path = named_temp.path().to_owned();
        let (std_file, temp_path) = named_temp.into_parts();
        let file = tokio::fs::File::from_std(std_file);
        let mut writer = LocalWriter::new(file, os_path, temp_path, io_tracker);

        writer.write_all(b"some data").await.unwrap();
        assert!(temp_file_path.exists());

        // Drop without shutdown should clean up the temp file
        drop(writer);
        assert!(!temp_file_path.exists());
        assert!(!file_path.exists());
    }

    // Values below the 5MB minimum are raised to INITIAL_UPLOAD_STEP and flagged.
    #[test]
    fn clamp_initial_upload_size_below_min_is_clamped_up() {
        assert_eq!(clamp_initial_upload_size(0), (INITIAL_UPLOAD_STEP, true));
        assert_eq!(
            clamp_initial_upload_size(INITIAL_UPLOAD_STEP - 1),
            (INITIAL_UPLOAD_STEP, true)
        );
    }

    // Values inside [5MB, 5GB], including both bounds, pass through unflagged.
    #[test]
    fn clamp_initial_upload_size_within_range_is_unchanged() {
        assert_eq!(
            clamp_initial_upload_size(INITIAL_UPLOAD_STEP),
            (INITIAL_UPLOAD_STEP, false)
        );
        assert_eq!(
            clamp_initial_upload_size(MAX_UPLOAD_PART_SIZE),
            (MAX_UPLOAD_PART_SIZE, false)
        );
        let mid = INITIAL_UPLOAD_STEP * 8; // 40MB, in range
        assert_eq!(clamp_initial_upload_size(mid), (mid, false));
    }

    // S3 RequestTimeout responses and connection resets are retried; other
    // generic errors (e.g. access denied) are not.
    #[test]
    fn should_retry_upload_put_detects_transient_errors() {
        let request_timeout = OSError::Generic {
            store: "S3",
            source: Box::new(io::Error::other(
                "Server returned non-2xx status code: 400 Bad Request: \
                 <Error><Code>RequestTimeout</Code><Message>Your socket connection to the server \
                 was not read from or written to within the timeout period. Idle connections will \
                 be closed.</Message></Error>",
            )),
        };
        assert!(should_retry_upload_put(&request_timeout));

        let connection_reset = OSError::Generic {
            store: "S3",
            source: Box::new(io::Error::new(
                io::ErrorKind::ConnectionReset,
                "connection reset by peer",
            )),
        };
        assert!(should_retry_upload_put(&connection_reset));

        let not_retryable = OSError::Generic {
            store: "S3",
            source: Box::new(io::Error::other("access denied")),
        };
        assert!(!should_retry_upload_put(&not_retryable));
    }

    // Values above the 5GB maximum (up to usize::MAX) are lowered and flagged.
    #[test]
    fn clamp_initial_upload_size_above_max_is_clamped_down() {
        assert_eq!(
            clamp_initial_upload_size(MAX_UPLOAD_PART_SIZE + 1),
            (MAX_UPLOAD_PART_SIZE, true)
        );
        assert_eq!(
            clamp_initial_upload_size(usize::MAX),
            (MAX_UPLOAD_PART_SIZE, true)
        );
    }
}