lance_io/object_writer.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::io;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use std::task::Poll;

use crate::object_store::ObjectStore as LanceObjectStore;
use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::FutureExt;
use object_store::MultipartUpload;
use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
use rand::Rng;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::task::JoinSet;

use lance_core::{Error, Result};
use tracing::Instrument;

use crate::traits::Writer;
use snafu::location;
use tokio::runtime::Handle;

/// Start at 5MB.
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;

fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(10)
    })
}

fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    *MAX_CONN_RESET_RETRIES.get_or_init(|| {
        std::env::var("LANCE_CONN_RESET_RETRIES")
            .ok()
            .and_then(|s| s.parse::<u16>().ok())
            .unwrap_or(20)
    })
}

fn initial_upload_size() -> usize {
    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
        std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .inspect(|size| {
                if *size < INITIAL_UPLOAD_STEP {
                    // Minimum part size in GCS and S3
                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
                } else if *size > 1024 * 1024 * 1024 * 5 {
                    // Maximum part size in GCS and S3
                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
                }
            })
            .unwrap_or(INITIAL_UPLOAD_STEP)
    })
}
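
// The tuning knobs above are read from the environment once per process.
// Illustrative example (the values below are hypothetical, not recommendations):
//
//   LANCE_UPLOAD_CONCURRENCY=16        # parts uploaded in parallel (default 10)
//   LANCE_CONN_RESET_RETRIES=5         # retries on "connection reset by peer" (default 20)
//   LANCE_INITIAL_UPLOAD_SIZE=8388608  # starting part size in bytes, 5MB..5GB (default 5MB)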

/// Writer to an object in an object store.
///
/// If the object is small enough, the writer will upload the object in a single
/// PUT request. If the object is larger, the writer will create a multipart
/// upload and upload parts in parallel.
///
/// This implements the `AsyncWrite` trait.
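///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest), assuming an in-memory
/// `LanceObjectStore` like the one used in the tests below:
///
/// ```ignore
/// use object_store::path::Path;
/// use tokio::io::AsyncWriteExt;
///
/// let store = LanceObjectStore::memory();
/// let mut writer = ObjectWriter::new(&store, &Path::from("/example")).await?;
/// writer.write_all(b"hello world").await?; // buffered in memory
/// let result = writer.shutdown().await?;   // small buffer => single PUT request
/// assert_eq!(result.size, 11);
/// ```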
pub struct ObjectWriter {
    state: UploadState,
    path: Arc<Path>,
    cursor: usize,
    connection_resets: u16,
    buffer: Vec<u8>,
    // TODO: use constant size to support R2
    use_constant_size_upload_parts: bool,
}

#[derive(Debug, Clone, Default)]
pub struct WriteResult {
    pub size: usize,
    pub e_tag: Option<String>,
}

enum UploadState {
    /// The writer has been opened but no data has been written yet. Will be in
    /// this state until the buffer is full or the writer is shut down.
    Started(Arc<dyn ObjectStore>),
    /// The writer is in the process of creating a multipart upload.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// The writer is in the process of uploading parts.
    InProgress {
        part_idx: u16,
        upload: Box<dyn MultipartUpload>,
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// The writer is in the process of uploading data in a single PUT request.
    /// This happens when shutdown is called before the buffer is full.
    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
    /// The writer is in the process of completing the multipart upload.
    Completing(BoxFuture<'static, OSResult<WriteResult>>),
    /// The writer has been shut down and all data has been written.
    Done(WriteResult),
}
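
// State transitions at a glance (driven by the `AsyncWrite` impl below):
//
//   Started --(buffer fills)------------------> CreatingUpload --> InProgress
//   Started --(shutdown before buffer fills)--> PuttingSingle ---> Done
//   InProgress --(shutdown, all parts done)---> Completing ------> Done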

/// Methods for state transitions.
impl UploadState {
    fn started_to_putting_single(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
        // To get owned self, we temporarily swap with Done.
        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
        *self = match this {
            Self::Started(store) => {
                let fut = async move {
                    let size = buffer.len();
                    let res = store.put(&path, buffer.into()).await?;
                    Ok(WriteResult {
                        size,
                        e_tag: res.e_tag,
                    })
                };
                Self::PuttingSingle(Box::pin(fut))
            }
            _ => unreachable!(),
        }
    }

    fn in_progress_to_completing(&mut self) {
        // To get owned self, we temporarily swap with Done.
        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
        *self = match this {
            Self::InProgress {
                mut upload,
                futures,
                ..
            } => {
                debug_assert!(futures.is_empty());
                let fut = async move {
                    let res = upload.complete().await?;
                    Ok(WriteResult {
                        size: 0, // This will be set properly later.
                        e_tag: res.e_tag,
                    })
                };
                Self::Completing(Box::pin(fut))
            }
            _ => unreachable!(),
        };
    }
}

impl ObjectWriter {
    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
        Ok(Self {
            state: UploadState::Started(object_store.inner.clone()),
            cursor: 0,
            path: Arc::new(path.clone()),
            connection_resets: 0,
            buffer: Vec::with_capacity(initial_upload_size()),
            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
        })
    }

    /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`.
    /// The new capacity of `buffer` is determined by the current part index.
    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
        let new_capacity = if constant_upload_size {
            // The store does not support variable part sizes, so use the initial size.
            initial_upload_size()
        } else {
            // Increase the part size by 5MB every 100 parts. With the 10,000-part limit,
            // this allows a total upload of roughly 2.5TB (5MB * 100 * (1 + 2 + ... + 100)).
            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
        };
        let new_buffer = Vec::with_capacity(new_capacity);
        let part = std::mem::replace(buffer, new_buffer);
        Bytes::from(part)
    }

    fn put_part(
        upload: &mut dyn MultipartUpload,
        buffer: Bytes,
        part_idx: u16,
        sleep: Option<std::time::Duration>,
    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
        log::debug!(
            "MultipartUpload submitting part with {} bytes",
            buffer.len()
        );
        let fut = upload.put_part(buffer.clone().into());
        Box::pin(async move {
            if let Some(sleep) = sleep {
                tokio::time::sleep(sleep).await;
            }
            fut.await.map_err(|source| UploadPutError {
                part_idx,
                buffer,
                source,
            })?;
            Ok(())
        })
    }

    fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::result::Result<(), io::Error> {
        let mut_self = &mut *self;
        loop {
            match &mut mut_self.state {
                UploadState::Started(_) | UploadState::Done(_) => break,
                UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(mut upload)) => {
                        let mut futures = JoinSet::new();

                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            0,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));

                        mut_self.state = UploadState::InProgress {
                            part_idx: 1, // We just used 0
                            futures,
                            upload,
                        };
                    }
                    Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
                    Poll::Pending => break,
                },
                UploadState::InProgress {
                    upload, futures, ..
                } => {
                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
                        match res {
                            Ok(Ok(())) => {}
                            Err(err) => return Err(std::io::Error::other(err)),
                            Ok(Err(UploadPutError {
                                source: OSError::Generic { source, .. },
                                part_idx,
                                buffer,
                            })) if source
                                .to_string()
                                .to_lowercase()
                                .contains("connection reset by peer") =>
                            {
                                if mut_self.connection_resets < max_conn_reset_retries() {
                                    // Retry, but only up to max_conn_reset_retries of them.
                                    mut_self.connection_resets += 1;

                                    // Resubmit with random jitter
                                    let sleep_time_ms = rand::rng().random_range(2_000..8_000);
                                    let sleep_time =
                                        std::time::Duration::from_millis(sleep_time_ms);

                                    futures.spawn(Self::put_part(
                                        upload.as_mut(),
                                        buffer,
                                        part_idx,
                                        Some(sleep_time),
                                    ));
                                } else {
                                    return Err(io::Error::new(
                                        io::ErrorKind::ConnectionReset,
                                        Box::new(ConnectionResetError {
                                            message: format!(
                                                "Hit max retries ({}) for connection reset",
                                                max_conn_reset_retries()
                                            ),
                                            source,
                                        }),
                                    ));
                                }
                            }
                            Ok(Err(err)) => return Err(err.source.into()),
                        }
                    }
                    break;
                }
                UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
                    match fut.poll_unpin(cx) {
                        Poll::Ready(Ok(mut res)) => {
                            res.size = mut_self.cursor;
                            mut_self.state = UploadState::Done(res)
                        }
                        Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
                        Poll::Pending => break,
                    }
                }
            }
        }
        Ok(())
    }

    pub async fn shutdown(&mut self) -> Result<WriteResult> {
        AsyncWriteExt::shutdown(self).await.map_err(|e| {
            Error::io(
                format!("failed to shutdown object writer for {}: {}", self.path, e),
                location!(),
            )
        })?;
        if let UploadState::Done(result) = &self.state {
            Ok(result.clone())
        } else {
            unreachable!()
        }
    }

    pub async fn abort(&mut self) {
        let state = std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
        if let UploadState::InProgress { mut upload, .. } = state {
            let _ = upload.abort().await;
        }
    }
}

impl Drop for ObjectWriter {
    fn drop(&mut self) {
        // If there is a multipart upload started but not finished, we should abort it.
        if matches!(self.state, UploadState::InProgress { .. }) {
            // Take ownership of the state.
            let state =
                std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
            if let UploadState::InProgress { mut upload, .. } = state {
                if let Ok(handle) = Handle::try_current() {
                    handle.spawn(async move {
                        let _ = upload.abort().await;
                    });
                }
            }
        }
    }
}

/// Returned error from trying to upload a part.
/// Has the part_idx and buffer so we can pass
/// them to the retry logic.
struct UploadPutError {
    part_idx: u16,
    buffer: Bytes,
    source: OSError,
}

#[derive(Debug)]
struct ConnectionResetError {
    message: String,
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::error::Error for ConnectionResetError {}

impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}

impl AsyncWrite for ObjectWriter {
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
        self.as_mut().poll_tasks(cx)?;

        // Fill buffer up to remaining capacity.
        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
        self.cursor += bytes_to_write;

        // Rust needs a little help to borrow self mutably and immutably at the same time
        // through a Pin.
        let mut_self = &mut *self;

        // Instantiate next request, if available.
        if mut_self.buffer.capacity() == mut_self.buffer.len() {
            match &mut mut_self.state {
                UploadState::Started(store) => {
                    let path = mut_self.path.clone();
                    let store = store.clone();
                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
                    self.state = UploadState::CreatingUpload(fut);
                }
                UploadState::InProgress {
                    upload,
                    part_idx,
                    futures,
                    ..
                } => {
                    // TODO: Make max concurrency configurable from storage options.
                    if futures.len() < max_upload_parallelism() {
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            *part_idx,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(
                            Self::put_part(upload.as_mut(), data, *part_idx, None)
                                .instrument(tracing::Span::current()),
                        );
                        *part_idx += 1;
                    }
                }
                _ => {}
            }
        }

        self.poll_tasks(cx)?;

        match bytes_to_write {
            0 => Poll::Pending,
            _ => Poll::Ready(Ok(bytes_to_write)),
        }
    }

    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.as_mut().poll_tasks(cx)?;

        match &self.state {
            UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
            UploadState::CreatingUpload(_)
            | UploadState::Completing(_)
            | UploadState::PuttingSingle(_) => Poll::Pending,
            UploadState::InProgress { futures, .. } => {
                if futures.is_empty() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            }
        }
    }

    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        loop {
            self.as_mut().poll_tasks(cx)?;

            // Rust needs a little help to borrow self mutably and immutably at the same time
            // through a Pin.
            let mut_self = &mut *self;
            match &mut mut_self.state {
                UploadState::Done(_) => return Poll::Ready(Ok(())),
                UploadState::CreatingUpload(_)
                | UploadState::PuttingSingle(_)
                | UploadState::Completing(_) => return Poll::Pending,
                UploadState::Started(_) => {
                    // If we didn't start a multipart upload, we can just do a single put.
                    let part = std::mem::take(&mut mut_self.buffer);
                    let path = mut_self.path.clone();
                    self.state.started_to_putting_single(path, part);
                }
                UploadState::InProgress {
                    upload,
                    futures,
                    part_idx,
                } => {
                    // Flush final batch
                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
                        // We can just use `take` since we don't need the buffer anymore.
                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
                        futures.spawn(
                            Self::put_part(upload.as_mut(), data, *part_idx, None)
                                .instrument(tracing::Span::current()),
                        );
                        // We need to go back to the beginning of the loop to poll the
                        // new future and get the waker registered on the cx.
                        continue;
                    }

                    // We handle the transition from in progress to completing here.
                    if futures.is_empty() {
                        self.state.in_progress_to_completing();
                    } else {
                        return Poll::Pending;
                    }
                }
            }
        }
    }
}

#[async_trait]
impl Writer for ObjectWriter {
    async fn tell(&mut self) -> Result<usize> {
        Ok(self.cursor)
    }
}

#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let buf = vec![0; 256];
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 512);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256 * 3);

        let res = object_writer.shutdown().await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Trigger a multipart upload
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let buf = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for i in 0..5 {
            // Write enough data to trigger a multipart upload
            object_writer.write_all(buf.as_slice()).await.unwrap();
            // Check the cursor
            assert_eq!(object_writer.tell().await.unwrap(), (i + 1) * buf.len());
        }
        let res = object_writer.shutdown().await.unwrap();
        assert_eq!(res.size, buf.len() * 5);
    }
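
    // A minimal illustrative sketch of the part-size schedule in `next_part_buffer`
    // (assumes LANCE_INITIAL_UPLOAD_SIZE is unset, so the 5MB default applies).
    #[test]
    fn test_part_size_growth() {
        let mut buffer = Vec::new();

        // With variable part sizes, the capacity grows by 5MB for every 100 parts.
        let _ = ObjectWriter::next_part_buffer(&mut buffer, 0, false);
        assert!(buffer.capacity() >= INITIAL_UPLOAD_STEP);
        let _ = ObjectWriter::next_part_buffer(&mut buffer, 250, false);
        assert!(buffer.capacity() >= 3 * INITIAL_UPLOAD_STEP);

        // With constant-size parts, the capacity stays at the initial upload size.
        let _ = ObjectWriter::next_part_buffer(&mut buffer, 250, true);
        assert!(buffer.capacity() >= initial_upload_size());
    }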

    #[tokio::test]
    async fn test_abort_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        object_writer.abort().await;
    }
}