lance_io/
object_writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use futures::FutureExt;
14use object_store::MultipartUpload;
15use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use snafu::location;
25
/// Step size in bytes (5 * 1024 * 1024) used when sizing upload parts.
///
/// It is the default and the smallest accepted value for
/// `initial_upload_size()`, and the increment by which
/// `ObjectWriter::next_part_buffer` grows part buffers.
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;
28
/// Maximum number of part uploads a single writer keeps in flight at once.
///
/// Read once from the `LANCE_UPLOAD_CONCURRENCY` environment variable and
/// cached for the rest of the process; later changes to the variable are not
/// observed. Defaults to 10 when the variable is unset or is not a valid
/// `usize`.
///
/// The result is clamped to at least 1. Callers gate every new part on
/// `futures.len() < max_upload_parallelism()`, so a limit of 0 would never
/// allow another part to be submitted: writes would stall and the final
/// partial buffer would be skipped at shutdown.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(10)
            // A zero limit can never be satisfied; treat it as "one at a time".
            .max(1)
    })
}
38
/// Total number of "connection reset" retries a writer may perform.
///
/// Taken from the `LANCE_CONN_RESET_RETRIES` environment variable the first
/// time this is called and cached afterwards. Falls back to 20 when the
/// variable is unset or does not parse as a `u16`.
fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    let limit = MAX_CONN_RESET_RETRIES.get_or_init(|| {
        match std::env::var("LANCE_CONN_RESET_RETRIES") {
            Ok(raw) => raw.parse::<u16>().unwrap_or(20),
            Err(_) => 20,
        }
    });
    *limit
}
48
49fn initial_upload_size() -> usize {
50    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
51    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
52        std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
53            .ok()
54            .and_then(|s| s.parse::<usize>().ok())
55            .inspect(|size| {
56                if *size < INITIAL_UPLOAD_STEP {
57                    // Minimum part size in GCS and S3
58                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
59                } else if *size > 1024 * 1024 * 1024 * 5 {
60                    // Maximum part size in GCS and S3
61                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
62                }
63            })
64            .unwrap_or(INITIAL_UPLOAD_STEP)
65    })
66}
67
/// Writer to an object in an object store.
///
/// If the object is small enough, the writer will upload the object in a single
/// PUT request. If the object is larger, the writer will create a multipart
/// upload and upload parts in parallel.
///
/// "Small enough" means the writer is shut down before its first buffer fills
/// up; the multipart upload is created the first time the buffer becomes full.
///
/// This implements the `AsyncWrite` trait.
pub struct ObjectWriter {
    /// Current position in the upload state machine.
    state: UploadState,
    /// Destination of the object inside the store.
    path: Arc<Path>,
    /// Total number of bytes accepted by `poll_write` so far; reported by `tell`.
    cursor: usize,
    /// Number of "connection reset" retries performed so far, checked against
    /// `max_conn_reset_retries()`.
    connection_resets: u16,
    /// Bytes accepted but not yet handed to an upload request.
    buffer: Vec<u8>,
    // TODO: use constant size to support R2
    /// When true, every part buffer is allocated with `initial_upload_size()`
    /// instead of growing with the part index.
    use_constant_size_upload_parts: bool,
}
84
/// State machine driving an `ObjectWriter`.
///
/// Transitions visible in this file: `Started` moves to `CreatingUpload` once
/// the buffer fills, or to `PuttingSingle` on shutdown; `CreatingUpload`
/// moves to `InProgress`; `InProgress` moves to `Completing` on shutdown once
/// no part upload is outstanding; `PuttingSingle` and `Completing` move to
/// `Done`.
enum UploadState {
    /// The writer has been opened but no data has been written yet. Will be in
    /// this state until the buffer is full or the writer is shut down.
    Started(Arc<dyn ObjectStore>),
    /// The writer is in the process of creating a multipart upload.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// The writer is in the process of uploading parts.
    InProgress {
        /// Index that will be given to the next part submitted.
        part_idx: u16,
        /// Handle used to submit parts and to complete or abort the upload.
        upload: Box<dyn MultipartUpload>,
        /// Part uploads that have been spawned and not yet joined.
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// The writer is in the process of uploading data in a single PUT request.
    /// This happens when shutdown is called before the buffer is full.
    PuttingSingle(BoxFuture<'static, OSResult<()>>),
    /// The writer is in the process of completing the multipart upload.
    Completing(BoxFuture<'static, OSResult<()>>),
    /// The writer has been shut down and all data has been written.
    Done,
}
105
106/// Methods for state transitions.
107impl UploadState {
108    fn started_to_completing(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
109        // To get owned self, we temporarily swap with Done.
110        let this = std::mem::replace(self, Self::Done);
111        *self = match this {
112            Self::Started(store) => {
113                let fut = async move {
114                    store.put(&path, buffer.into()).await?;
115                    Ok(())
116                };
117                Self::PuttingSingle(Box::pin(fut))
118            }
119            _ => unreachable!(),
120        }
121    }
122
123    fn in_progress_to_completing(&mut self) {
124        // To get owned self, we temporarily swap with Done.
125        let this = std::mem::replace(self, Self::Done);
126        *self = match this {
127            Self::InProgress {
128                mut upload,
129                futures,
130                ..
131            } => {
132                debug_assert!(futures.is_empty());
133                let fut = async move {
134                    upload.complete().await?;
135                    Ok(())
136                };
137                Self::Completing(Box::pin(fut))
138            }
139            _ => unreachable!(),
140        };
141    }
142}
143
144impl ObjectWriter {
145    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
146        Ok(Self {
147            state: UploadState::Started(object_store.inner.clone()),
148            cursor: 0,
149            path: Arc::new(path.clone()),
150            connection_resets: 0,
151            buffer: Vec::with_capacity(initial_upload_size()),
152            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
153        })
154    }
155
156    /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`.
157    /// The new capacity of `buffer` is determined by the current part index.
158    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
159        let new_capacity = if constant_upload_size {
160            // The store does not support variable part sizes, so use the initial size.
161            initial_upload_size()
162        } else {
163            // Increase the upload size every 100 parts. This gives maximum part size of 2.5TB.
164            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
165        };
166        let new_buffer = Vec::with_capacity(new_capacity);
167        let part = std::mem::replace(buffer, new_buffer);
168        Bytes::from(part)
169    }
170
171    fn put_part(
172        upload: &mut dyn MultipartUpload,
173        buffer: Bytes,
174        part_idx: u16,
175        sleep: Option<std::time::Duration>,
176    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
177        log::debug!(
178            "MultipartUpload submitting part with {} bytes",
179            buffer.len()
180        );
181        let fut = upload.put_part(buffer.clone().into());
182        Box::pin(async move {
183            if let Some(sleep) = sleep {
184                tokio::time::sleep(sleep).await;
185            }
186            fut.await.map_err(|source| UploadPutError {
187                part_idx,
188                buffer,
189                source,
190            })?;
191            Ok(())
192        })
193    }
194
195    fn poll_tasks(
196        mut self: Pin<&mut Self>,
197        cx: &mut std::task::Context<'_>,
198    ) -> std::result::Result<(), io::Error> {
199        let mut_self = &mut *self;
200        loop {
201            match &mut mut_self.state {
202                UploadState::Started(_) | UploadState::Done => break,
203                UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
204                    Poll::Ready(Ok(mut upload)) => {
205                        let mut futures = JoinSet::new();
206
207                        let data = Self::next_part_buffer(
208                            &mut mut_self.buffer,
209                            0,
210                            mut_self.use_constant_size_upload_parts,
211                        );
212                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
213
214                        mut_self.state = UploadState::InProgress {
215                            part_idx: 1, // We just used 0
216                            futures,
217                            upload,
218                        };
219                    }
220                    Poll::Ready(Err(e)) => {
221                        return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
222                    }
223                    Poll::Pending => break,
224                },
225                UploadState::InProgress {
226                    upload, futures, ..
227                } => {
228                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
229                        match res {
230                            Ok(Ok(())) => {}
231                            Err(err) => {
232                                return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
233                            }
234                            Ok(Err(UploadPutError {
235                                source: OSError::Generic { source, .. },
236                                part_idx,
237                                buffer,
238                            })) if source
239                                .to_string()
240                                .to_lowercase()
241                                .contains("connection reset by peer") =>
242                            {
243                                if mut_self.connection_resets < max_conn_reset_retries() {
244                                    // Retry, but only up to max_conn_reset_retries of them.
245                                    mut_self.connection_resets += 1;
246
247                                    // Resubmit with random jitter
248                                    let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000);
249                                    let sleep_time =
250                                        std::time::Duration::from_millis(sleep_time_ms);
251
252                                    futures.spawn(Self::put_part(
253                                        upload.as_mut(),
254                                        buffer,
255                                        part_idx,
256                                        Some(sleep_time),
257                                    ));
258                                } else {
259                                    return Err(io::Error::new(
260                                        io::ErrorKind::ConnectionReset,
261                                        Box::new(ConnectionResetError {
262                                            message: format!(
263                                                "Hit max retries ({}) for connection reset",
264                                                max_conn_reset_retries()
265                                            ),
266                                            source,
267                                        }),
268                                    ));
269                                }
270                            }
271                            Ok(Err(err)) => return Err(err.source.into()),
272                        }
273                    }
274                    break;
275                }
276                UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
277                    match fut.poll_unpin(cx) {
278                        Poll::Ready(Ok(())) => mut_self.state = UploadState::Done,
279                        Poll::Ready(Err(e)) => {
280                            return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
281                        }
282                        Poll::Pending => break,
283                    }
284                }
285            }
286        }
287        Ok(())
288    }
289
290    pub async fn shutdown(&mut self) -> Result<()> {
291        AsyncWriteExt::shutdown(self).await.map_err(|e| {
292            Error::io(
293                format!("failed to shutdown object writer for {}: {}", self.path, e),
294                // and wrap it in here.
295                location!(),
296            )
297        })
298    }
299}
300
301impl Drop for ObjectWriter {
302    fn drop(&mut self) {
303        // If there is a multipart upload started but not finished, we should abort it.
304        if matches!(self.state, UploadState::InProgress { .. }) {
305            // Take ownership of the state.
306            let state = std::mem::replace(&mut self.state, UploadState::Done);
307            if let UploadState::InProgress { mut upload, .. } = state {
308                tokio::task::spawn(async move {
309                    let _ = upload.abort().await;
310                });
311            }
312        }
313    }
314}
315
/// Returned error from trying to upload a part.
/// Has the part_idx and buffer so we can pass
/// them to the retry logic.
struct UploadPutError {
    /// Index the writer assigned to the part that failed.
    part_idx: u16,
    /// The data of the failed part, kept so it can be resubmitted.
    buffer: Bytes,
    /// The error returned by the object store for this part.
    source: OSError,
}
324
/// Error payload used when a part upload keeps failing with a connection
/// reset after the retry limit has been reached.
#[derive(Debug)]
struct ConnectionResetError {
    /// Human-readable description of the limit that was hit.
    message: String,
    /// The underlying error of the last failed attempt.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::fmt::Display for ConnectionResetError {
    /// Renders as `<message>: <source>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message, source } = self;
        f.write_fmt(format_args!("{}: {}", message, source))
    }
}

// The underlying error is already part of the `Display` output; the default
// `Error::source` (none) is kept.
impl std::error::Error for ConnectionResetError {}
338
339impl AsyncWrite for ObjectWriter {
340    fn poll_write(
341        mut self: std::pin::Pin<&mut Self>,
342        cx: &mut std::task::Context<'_>,
343        buf: &[u8],
344    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
345        self.as_mut().poll_tasks(cx)?;
346
347        // Fill buffer up to remaining capacity.
348        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
349        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
350        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
351        self.cursor += bytes_to_write;
352
353        // Rust needs a little help to borrow self mutably and immutably at the same time
354        // through a Pin.
355        let mut_self = &mut *self;
356
357        // Instantiate next request, if available.
358        if mut_self.buffer.capacity() == mut_self.buffer.len() {
359            match &mut mut_self.state {
360                UploadState::Started(store) => {
361                    let path = mut_self.path.clone();
362                    let store = store.clone();
363                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
364                    self.state = UploadState::CreatingUpload(fut);
365                }
366                UploadState::InProgress {
367                    upload,
368                    part_idx,
369                    futures,
370                    ..
371                } => {
372                    // TODO: Make max concurrency configurable from storage options.
373                    if futures.len() < max_upload_parallelism() {
374                        let data = Self::next_part_buffer(
375                            &mut mut_self.buffer,
376                            *part_idx,
377                            mut_self.use_constant_size_upload_parts,
378                        );
379                        futures.spawn(
380                            Self::put_part(upload.as_mut(), data, *part_idx, None)
381                                .instrument(tracing::Span::current()),
382                        );
383                        *part_idx += 1;
384                    }
385                }
386                _ => {}
387            }
388        }
389
390        self.poll_tasks(cx)?;
391
392        match bytes_to_write {
393            0 => Poll::Pending,
394            _ => Poll::Ready(Ok(bytes_to_write)),
395        }
396    }
397
398    fn poll_flush(
399        mut self: std::pin::Pin<&mut Self>,
400        cx: &mut std::task::Context<'_>,
401    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
402        self.as_mut().poll_tasks(cx)?;
403
404        match &self.state {
405            UploadState::Started(_) | UploadState::Done => Poll::Ready(Ok(())),
406            UploadState::CreatingUpload(_)
407            | UploadState::Completing(_)
408            | UploadState::PuttingSingle(_) => Poll::Pending,
409            UploadState::InProgress { futures, .. } => {
410                if futures.is_empty() {
411                    Poll::Ready(Ok(()))
412                } else {
413                    Poll::Pending
414                }
415            }
416        }
417    }
418
419    fn poll_shutdown(
420        mut self: std::pin::Pin<&mut Self>,
421        cx: &mut std::task::Context<'_>,
422    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
423        loop {
424            self.as_mut().poll_tasks(cx)?;
425
426            // Rust needs a little help to borrow self mutably and immutably at the same time
427            // through a Pin.
428            let mut_self = &mut *self;
429            match &mut mut_self.state {
430                UploadState::Done => return Poll::Ready(Ok(())),
431                UploadState::CreatingUpload(_)
432                | UploadState::PuttingSingle(_)
433                | UploadState::Completing(_) => return Poll::Pending,
434                UploadState::Started(_) => {
435                    // If we didn't start a multipart upload, we can just do a single put.
436                    let part = std::mem::take(&mut mut_self.buffer);
437                    let path = mut_self.path.clone();
438                    self.state.started_to_completing(path, part);
439                }
440                UploadState::InProgress {
441                    upload,
442                    futures,
443                    part_idx,
444                } => {
445                    // Flush final batch
446                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
447                        // We can just use `take` since we don't need the buffer anymore.
448                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
449                        futures.spawn(
450                            Self::put_part(upload.as_mut(), data, *part_idx, None)
451                                .instrument(tracing::Span::current()),
452                        );
453                        // We need to go back to beginning of loop to poll the
454                        // new feature and get the waker registered on the ctx.
455                        continue;
456                    }
457
458                    // We handle the transition from in progress to completing here.
459                    if futures.is_empty() {
460                        self.state.in_progress_to_completing();
461                    } else {
462                        return Poll::Pending;
463                    }
464                }
465            }
466        }
467    }
468}
469
470#[async_trait]
471impl Writer for ObjectWriter {
472    async fn tell(&mut self) -> Result<usize> {
473        Ok(self.cursor)
474    }
475}
476
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Writes three 256-byte chunks to an in-memory store, checking that each
    /// write is accepted in full and that `tell` follows the running total,
    /// then shuts the writer down.
    #[tokio::test]
    async fn test_write() {
        const CHUNK: usize = 256;

        let store = LanceObjectStore::memory();
        let location = Path::from("/foo");

        let mut object_writer = ObjectWriter::new(&store, &location).await.unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let buf = vec![0; CHUNK];
        for chunks_written in 1..=3 {
            assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), CHUNK);
            assert_eq!(
                object_writer.tell().await.unwrap(),
                CHUNK * chunks_written
            );
        }

        object_writer.shutdown().await.unwrap();
    }
}