lance_io/
object_writer.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

4use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use futures::FutureExt;
14use object_store::MultipartUpload;
15use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use snafu::location;
25use tokio::runtime::Handle;
26
/// Default size of the first upload part (5MB).
///
/// Also the lower bound accepted for `LANCE_INITIAL_UPLOAD_SIZE`, and the step by
/// which the part size grows as more parts are uploaded.
const INITIAL_UPLOAD_STEP: usize = 5 * 1024 * 1024;
29
/// Maximum number of part uploads allowed in flight at once.
///
/// Read once from the `LANCE_UPLOAD_CONCURRENCY` environment variable. Unset or
/// unparsable values fall back to 10. The result is clamped to at least 1: a limit
/// of 0 would keep the writer from ever submitting another part, so writes would
/// stall and the final buffered data would never be uploaded.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(10)
            // At least one upload must be allowed for the writer to make progress.
            .max(1)
    })
}
39
/// How many times, in total, a writer may resubmit part uploads that failed
/// with "connection reset by peer".
///
/// Read once from `LANCE_CONN_RESET_RETRIES`; unset or unparsable values mean 20.
fn max_conn_reset_retries() -> u16 {
    static RETRY_LIMIT: OnceLock<u16> = OnceLock::new();
    *RETRY_LIMIT.get_or_init(|| match std::env::var("LANCE_CONN_RESET_RETRIES") {
        Ok(raw) => raw.parse::<u16>().unwrap_or(20),
        Err(_) => 20,
    })
}
49
50fn initial_upload_size() -> usize {
51    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
52    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
53        std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
54            .ok()
55            .and_then(|s| s.parse::<usize>().ok())
56            .inspect(|size| {
57                if *size < INITIAL_UPLOAD_STEP {
58                    // Minimum part size in GCS and S3
59                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
60                } else if *size > 1024 * 1024 * 1024 * 5 {
61                    // Maximum part size in GCS and S3
62                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
63                }
64            })
65            .unwrap_or(INITIAL_UPLOAD_STEP)
66    })
67}
68
69/// Writer to an object in an object store.
70///
71/// If the object is small enough, the writer will upload the object in a single
72/// PUT request. If the object is larger, the writer will create a multipart
73/// upload and upload parts in parallel.
74///
75/// This implements the `AsyncWrite` trait.
76pub struct ObjectWriter {
77    state: UploadState,
78    path: Arc<Path>,
79    cursor: usize,
80    connection_resets: u16,
81    buffer: Vec<u8>,
82    // TODO: use constant size to support R2
83    use_constant_size_upload_parts: bool,
84}
85
/// Outcome of a finished write.
#[derive(Clone, Debug, Default)]
pub struct WriteResult {
    /// Number of bytes written to the object.
    pub size: usize,
    /// Entity tag reported by the store for the written object, if any.
    pub e_tag: Option<String>,
}
91
92enum UploadState {
93    /// The writer has been opened but no data has been written yet. Will be in
94    /// this state until the buffer is full or the writer is shut down.
95    Started(Arc<dyn ObjectStore>),
96    /// The writer is in the process of creating a multipart upload.
97    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
98    /// The writer is in the process of uploading parts.
99    InProgress {
100        part_idx: u16,
101        upload: Box<dyn MultipartUpload>,
102        futures: JoinSet<std::result::Result<(), UploadPutError>>,
103    },
104    /// The writer is in the process of uploading data in a single PUT request.
105    /// This happens when shutdown is called before the buffer is full.
106    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
107    /// The writer is in the process of completing the multipart upload.
108    Completing(BoxFuture<'static, OSResult<WriteResult>>),
109    /// The writer has been shut down and all data has been written.
110    Done(WriteResult),
111}
112
113/// Methods for state transitions.
114impl UploadState {
115    fn started_to_putting_single(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
116        // To get owned self, we temporarily swap with Done.
117        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
118        *self = match this {
119            Self::Started(store) => {
120                let fut = async move {
121                    let size = buffer.len();
122                    let res = store.put(&path, buffer.into()).await?;
123                    Ok(WriteResult {
124                        size,
125                        e_tag: res.e_tag,
126                    })
127                };
128                Self::PuttingSingle(Box::pin(fut))
129            }
130            _ => unreachable!(),
131        }
132    }
133
134    fn in_progress_to_completing(&mut self) {
135        // To get owned self, we temporarily swap with Done.
136        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
137        *self = match this {
138            Self::InProgress {
139                mut upload,
140                futures,
141                ..
142            } => {
143                debug_assert!(futures.is_empty());
144                let fut = async move {
145                    let res = upload.complete().await?;
146                    Ok(WriteResult {
147                        size: 0, // This will be set properly later.
148                        e_tag: res.e_tag,
149                    })
150                };
151                Self::Completing(Box::pin(fut))
152            }
153            _ => unreachable!(),
154        };
155    }
156}
157
158impl ObjectWriter {
159    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
160        Ok(Self {
161            state: UploadState::Started(object_store.inner.clone()),
162            cursor: 0,
163            path: Arc::new(path.clone()),
164            connection_resets: 0,
165            buffer: Vec::with_capacity(initial_upload_size()),
166            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
167        })
168    }
169
170    /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`.
171    /// The new capacity of `buffer` is determined by the current part index.
172    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
173        let new_capacity = if constant_upload_size {
174            // The store does not support variable part sizes, so use the initial size.
175            initial_upload_size()
176        } else {
177            // Increase the upload size every 100 parts. This gives maximum part size of 2.5TB.
178            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
179        };
180        let new_buffer = Vec::with_capacity(new_capacity);
181        let part = std::mem::replace(buffer, new_buffer);
182        Bytes::from(part)
183    }
184
185    fn put_part(
186        upload: &mut dyn MultipartUpload,
187        buffer: Bytes,
188        part_idx: u16,
189        sleep: Option<std::time::Duration>,
190    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
191        log::debug!(
192            "MultipartUpload submitting part with {} bytes",
193            buffer.len()
194        );
195        let fut = upload.put_part(buffer.clone().into());
196        Box::pin(async move {
197            if let Some(sleep) = sleep {
198                tokio::time::sleep(sleep).await;
199            }
200            fut.await.map_err(|source| UploadPutError {
201                part_idx,
202                buffer,
203                source,
204            })?;
205            Ok(())
206        })
207    }
208
209    fn poll_tasks(
210        mut self: Pin<&mut Self>,
211        cx: &mut std::task::Context<'_>,
212    ) -> std::result::Result<(), io::Error> {
213        let mut_self = &mut *self;
214        loop {
215            match &mut mut_self.state {
216                UploadState::Started(_) | UploadState::Done(_) => break,
217                UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
218                    Poll::Ready(Ok(mut upload)) => {
219                        let mut futures = JoinSet::new();
220
221                        let data = Self::next_part_buffer(
222                            &mut mut_self.buffer,
223                            0,
224                            mut_self.use_constant_size_upload_parts,
225                        );
226                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
227
228                        mut_self.state = UploadState::InProgress {
229                            part_idx: 1, // We just used 0
230                            futures,
231                            upload,
232                        };
233                    }
234                    Poll::Ready(Err(e)) => {
235                        return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
236                    }
237                    Poll::Pending => break,
238                },
239                UploadState::InProgress {
240                    upload, futures, ..
241                } => {
242                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
243                        match res {
244                            Ok(Ok(())) => {}
245                            Err(err) => {
246                                return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
247                            }
248                            Ok(Err(UploadPutError {
249                                source: OSError::Generic { source, .. },
250                                part_idx,
251                                buffer,
252                            })) if source
253                                .to_string()
254                                .to_lowercase()
255                                .contains("connection reset by peer") =>
256                            {
257                                if mut_self.connection_resets < max_conn_reset_retries() {
258                                    // Retry, but only up to max_conn_reset_retries of them.
259                                    mut_self.connection_resets += 1;
260
261                                    // Resubmit with random jitter
262                                    let sleep_time_ms = rand::rng().random_range(2_000..8_000);
263                                    let sleep_time =
264                                        std::time::Duration::from_millis(sleep_time_ms);
265
266                                    futures.spawn(Self::put_part(
267                                        upload.as_mut(),
268                                        buffer,
269                                        part_idx,
270                                        Some(sleep_time),
271                                    ));
272                                } else {
273                                    return Err(io::Error::new(
274                                        io::ErrorKind::ConnectionReset,
275                                        Box::new(ConnectionResetError {
276                                            message: format!(
277                                                "Hit max retries ({}) for connection reset",
278                                                max_conn_reset_retries()
279                                            ),
280                                            source,
281                                        }),
282                                    ));
283                                }
284                            }
285                            Ok(Err(err)) => return Err(err.source.into()),
286                        }
287                    }
288                    break;
289                }
290                UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
291                    match fut.poll_unpin(cx) {
292                        Poll::Ready(Ok(mut res)) => {
293                            res.size = mut_self.cursor;
294                            mut_self.state = UploadState::Done(res)
295                        }
296                        Poll::Ready(Err(e)) => {
297                            return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
298                        }
299                        Poll::Pending => break,
300                    }
301                }
302            }
303        }
304        Ok(())
305    }
306
307    pub async fn shutdown(&mut self) -> Result<WriteResult> {
308        AsyncWriteExt::shutdown(self).await.map_err(|e| {
309            Error::io(
310                format!("failed to shutdown object writer for {}: {}", self.path, e),
311                // and wrap it in here.
312                location!(),
313            )
314        })?;
315        if let UploadState::Done(result) = &self.state {
316            Ok(result.clone())
317        } else {
318            unreachable!()
319        }
320    }
321
322    pub async fn abort(&mut self) {
323        let state = std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
324        if let UploadState::InProgress { mut upload, .. } = state {
325            let _ = upload.abort().await;
326        }
327    }
328}
329
330impl Drop for ObjectWriter {
331    fn drop(&mut self) {
332        // If there is a multipart upload started but not finished, we should abort it.
333        if matches!(self.state, UploadState::InProgress { .. }) {
334            // Take ownership of the state.
335            let state =
336                std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
337            if let UploadState::InProgress { mut upload, .. } = state {
338                if let Ok(handle) = Handle::try_current() {
339                    handle.spawn(async move {
340                        let _ = upload.abort().await;
341                    });
342                }
343            }
344        }
345    }
346}
347
348/// Returned error from trying to upload a part.
349/// Has the part_idx and buffer so we can pass
350/// them to the retry logic.
351struct UploadPutError {
352    part_idx: u16,
353    buffer: Bytes,
354    source: OSError,
355}
356
/// Error reported once part uploads have hit the connection-reset retry limit.
///
/// `Display` prints `message` followed by the underlying error. `source()` also
/// exposes the underlying error, so code walking the error chain (or downcasting)
/// can reach it.
#[derive(Debug)]
struct ConnectionResetError {
    /// Summary of what went wrong, e.g. which retry limit was hit.
    message: String,
    /// The error that reported the connection reset.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::error::Error for ConnectionResetError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = &*self.source;
        Some(inner)
    }
}

impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}
370
371impl AsyncWrite for ObjectWriter {
372    fn poll_write(
373        mut self: std::pin::Pin<&mut Self>,
374        cx: &mut std::task::Context<'_>,
375        buf: &[u8],
376    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
377        self.as_mut().poll_tasks(cx)?;
378
379        // Fill buffer up to remaining capacity.
380        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
381        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
382        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
383        self.cursor += bytes_to_write;
384
385        // Rust needs a little help to borrow self mutably and immutably at the same time
386        // through a Pin.
387        let mut_self = &mut *self;
388
389        // Instantiate next request, if available.
390        if mut_self.buffer.capacity() == mut_self.buffer.len() {
391            match &mut mut_self.state {
392                UploadState::Started(store) => {
393                    let path = mut_self.path.clone();
394                    let store = store.clone();
395                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
396                    self.state = UploadState::CreatingUpload(fut);
397                }
398                UploadState::InProgress {
399                    upload,
400                    part_idx,
401                    futures,
402                    ..
403                } => {
404                    // TODO: Make max concurrency configurable from storage options.
405                    if futures.len() < max_upload_parallelism() {
406                        let data = Self::next_part_buffer(
407                            &mut mut_self.buffer,
408                            *part_idx,
409                            mut_self.use_constant_size_upload_parts,
410                        );
411                        futures.spawn(
412                            Self::put_part(upload.as_mut(), data, *part_idx, None)
413                                .instrument(tracing::Span::current()),
414                        );
415                        *part_idx += 1;
416                    }
417                }
418                _ => {}
419            }
420        }
421
422        self.poll_tasks(cx)?;
423
424        match bytes_to_write {
425            0 => Poll::Pending,
426            _ => Poll::Ready(Ok(bytes_to_write)),
427        }
428    }
429
430    fn poll_flush(
431        mut self: std::pin::Pin<&mut Self>,
432        cx: &mut std::task::Context<'_>,
433    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
434        self.as_mut().poll_tasks(cx)?;
435
436        match &self.state {
437            UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
438            UploadState::CreatingUpload(_)
439            | UploadState::Completing(_)
440            | UploadState::PuttingSingle(_) => Poll::Pending,
441            UploadState::InProgress { futures, .. } => {
442                if futures.is_empty() {
443                    Poll::Ready(Ok(()))
444                } else {
445                    Poll::Pending
446                }
447            }
448        }
449    }
450
451    fn poll_shutdown(
452        mut self: std::pin::Pin<&mut Self>,
453        cx: &mut std::task::Context<'_>,
454    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
455        loop {
456            self.as_mut().poll_tasks(cx)?;
457
458            // Rust needs a little help to borrow self mutably and immutably at the same time
459            // through a Pin.
460            let mut_self = &mut *self;
461            match &mut mut_self.state {
462                UploadState::Done(_) => return Poll::Ready(Ok(())),
463                UploadState::CreatingUpload(_)
464                | UploadState::PuttingSingle(_)
465                | UploadState::Completing(_) => return Poll::Pending,
466                UploadState::Started(_) => {
467                    // If we didn't start a multipart upload, we can just do a single put.
468                    let part = std::mem::take(&mut mut_self.buffer);
469                    let path = mut_self.path.clone();
470                    self.state.started_to_putting_single(path, part);
471                }
472                UploadState::InProgress {
473                    upload,
474                    futures,
475                    part_idx,
476                } => {
477                    // Flush final batch
478                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
479                        // We can just use `take` since we don't need the buffer anymore.
480                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
481                        futures.spawn(
482                            Self::put_part(upload.as_mut(), data, *part_idx, None)
483                                .instrument(tracing::Span::current()),
484                        );
485                        // We need to go back to beginning of loop to poll the
486                        // new feature and get the waker registered on the ctx.
487                        continue;
488                    }
489
490                    // We handle the transition from in progress to completing here.
491                    if futures.is_empty() {
492                        self.state.in_progress_to_completing();
493                    } else {
494                        return Poll::Pending;
495                    }
496                }
497            }
498        }
499    }
500}
501
502#[async_trait]
503impl Writer for ObjectWriter {
504    async fn tell(&mut self) -> Result<usize> {
505        Ok(self.cursor)
506    }
507}
508
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let buf = vec![0; 256];
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 512);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256 * 3);

        let res = object_writer.shutdown().await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Trigger multi part upload
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let buf = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for i in 0..5 {
            // Write more data to trigger the multipart upload
            // This should be enough to trigger a multipart upload
            object_writer.write_all(buf.as_slice()).await.unwrap();
            // Check the cursor
            assert_eq!(object_writer.tell().await.unwrap(), (i + 1) * buf.len());
        }
        let res = object_writer.shutdown().await.unwrap();
        assert_eq!(res.size, buf.len() * 5);
    }

    #[tokio::test]
    async fn test_abort_write() {
        let store = LanceObjectStore::memory();

        // Aborting before anything was written must not fail or hang.
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        object_writer.abort().await;
    }

    #[tokio::test]
    async fn test_abort_multipart_write() {
        let store = LanceObjectStore::memory();
        let path = Path::from("/aborted");

        let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
        // With the default part size, two full buffers start a multipart upload.
        let buf = vec![0; INITIAL_UPLOAD_STEP];
        object_writer.write_all(&buf).await.unwrap();
        object_writer.write_all(&buf).await.unwrap();
        object_writer.abort().await;

        // The upload was never completed, so the object must not exist.
        let head = store.inner.head(&path).await;
        assert!(matches!(head, Err(OSError::NotFound { .. })), "{head:?}");
    }
}