lance_io/object_writer.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::io;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use std::task::Poll;

use crate::object_store::ObjectStore as LanceObjectStore;
use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::FutureExt;
use object_store::MultipartUpload;
use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
use rand::Rng;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::task::JoinSet;

use lance_core::{Error, Result};
use tracing::Instrument;

use crate::traits::Writer;
use snafu::location;

/// Initial part size for multipart uploads: 5 MB, the minimum part size on S3 and GCS.
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;

fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(10)
    })
}

fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    *MAX_CONN_RESET_RETRIES.get_or_init(|| {
        std::env::var("LANCE_CONN_RESET_RETRIES")
            .ok()
            .and_then(|s| s.parse::<u16>().ok())
            .unwrap_or(20)
    })
}

fn initial_upload_size() -> usize {
    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
        std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .inspect(|size| {
                if *size < INITIAL_UPLOAD_STEP {
                    // Minimum part size in GCS and S3
                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
                } else if *size > 1024 * 1024 * 1024 * 5 {
                    // Maximum part size in GCS and S3
                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
                }
            })
            .unwrap_or(INITIAL_UPLOAD_STEP)
    })
}
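
// The tunables above are read once per process (via `OnceLock`), so the
// environment variables must be set before the first write. A shell sketch with
// purely illustrative values ("my-lance-app" is a hypothetical binary name):
//
//   LANCE_UPLOAD_CONCURRENCY=4 \
//   LANCE_CONN_RESET_RETRIES=5 \
//   LANCE_INITIAL_UPLOAD_SIZE=$((8 * 1024 * 1024)) \
//   my-lance-app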

/// Writer to an object in an object store.
///
/// If the object is small enough, the writer will upload the object in a single
/// PUT request. If the object is larger, the writer will create a multipart
/// upload and upload parts in parallel.
///
/// This implements the `AsyncWrite` trait.
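///
/// # Example
///
/// A minimal usage sketch, mirroring the unit test at the bottom of this file.
/// Marked `ignore`: it is a sketch, not a doctest, and needs an async context
/// with a `LanceObjectStore` in scope.
///
/// ```ignore
/// use object_store::path::Path;
/// use tokio::io::AsyncWriteExt;
///
/// let store = LanceObjectStore::memory();
/// let mut writer = ObjectWriter::new(&store, &Path::from("/example")).await?;
/// writer.write_all(b"hello world").await?;
/// let result = writer.shutdown().await?; // small buffer, so this is a single PUT
/// assert_eq!(result.size, 11);
/// ```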
pub struct ObjectWriter {
    state: UploadState,
    path: Arc<Path>,
    cursor: usize,
    connection_resets: u16,
    buffer: Vec<u8>,
    // TODO: use constant size to support R2
    use_constant_size_upload_parts: bool,
}

/// Outcome of a finished write: the total number of bytes written and the
/// e-tag returned by the object store, if any.
#[derive(Debug, Clone, Default)]
pub struct WriteResult {
    pub size: usize,
    pub e_tag: Option<String>,
}

enum UploadState {
    /// The writer has been opened but no data has been written yet. Will be in
    /// this state until the buffer is full or the writer is shut down.
    Started(Arc<dyn ObjectStore>),
    /// The writer is in the process of creating a multipart upload.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// The writer is in the process of uploading parts.
    InProgress {
        part_idx: u16,
        upload: Box<dyn MultipartUpload>,
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// The writer is in the process of uploading data in a single PUT request.
    /// This happens when shutdown is called before the buffer is full.
    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
    /// The writer is in the process of completing the multipart upload.
    Completing(BoxFuture<'static, OSResult<WriteResult>>),
    /// The writer has been shut down and all data has been written.
    Done(WriteResult),
}
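
// State transitions, driven by `poll_write`, `poll_shutdown`, and `poll_tasks`:
//
//   Started -> CreatingUpload -> InProgress -> Completing -> Done   (buffer filled at least once)
//   Started -> PuttingSingle -> Done                                (shutdown before the buffer ever filled)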

/// Methods for state transitions.
impl UploadState {
    fn started_to_completing(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
        // To get owned self, we temporarily swap with Done.
        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
        *self = match this {
            Self::Started(store) => {
                let fut = async move {
                    let size = buffer.len();
                    let res = store.put(&path, buffer.into()).await?;
                    Ok(WriteResult {
                        size,
                        e_tag: res.e_tag,
                    })
                };
                Self::PuttingSingle(Box::pin(fut))
            }
            _ => unreachable!(),
        }
    }

    fn in_progress_to_completing(&mut self) {
        // To get owned self, we temporarily swap with Done.
        let this = std::mem::replace(self, Self::Done(WriteResult::default()));
        *self = match this {
            Self::InProgress {
                mut upload,
                futures,
                ..
            } => {
                debug_assert!(futures.is_empty());
                let fut = async move {
                    let res = upload.complete().await?;
                    Ok(WriteResult {
                        size: 0, // This will be set properly later.
                        e_tag: res.e_tag,
                    })
                };
                Self::Completing(Box::pin(fut))
            }
            _ => unreachable!(),
        };
    }
}

impl ObjectWriter {
    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
        Ok(Self {
            state: UploadState::Started(object_store.inner.clone()),
            cursor: 0,
            path: Arc::new(path.clone()),
            connection_resets: 0,
            buffer: Vec::with_capacity(initial_upload_size()),
            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
        })
    }

    /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`.
    /// The new capacity of `buffer` is determined by the current part index.
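    ///
    /// With variable part sizes and the default 5 MB step, parts 0-99 use 5 MB
    /// buffers, parts 100-199 use 10 MB, and so on; at a 10,000-part limit this
    /// sums to roughly 2.5 TB of total capacity.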
    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
        let new_capacity = if constant_upload_size {
            // The store does not support variable part sizes, so use the initial size.
            initial_upload_size()
        } else {
            // Grow the part size by 5 MB every 100 parts; at a 10,000-part limit this
            // allows a total upload of roughly 2.5 TB.
            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
        };
        let new_buffer = Vec::with_capacity(new_capacity);
        let part = std::mem::replace(buffer, new_buffer);
        Bytes::from(part)
    }

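    /// Submit a single part to the multipart upload.
    ///
    /// If `sleep` is given, the submission is delayed by that duration (used to
    /// add jitter when retrying after a connection reset). On failure, the part
    /// index and buffer are carried in the error so the part can be resubmitted.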
    fn put_part(
        upload: &mut dyn MultipartUpload,
        buffer: Bytes,
        part_idx: u16,
        sleep: Option<std::time::Duration>,
    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
        log::debug!(
            "MultipartUpload submitting part with {} bytes",
            buffer.len()
        );
        let fut = upload.put_part(buffer.clone().into());
        Box::pin(async move {
            if let Some(sleep) = sleep {
                tokio::time::sleep(sleep).await;
            }
            fut.await.map_err(|source| UploadPutError {
                part_idx,
                buffer,
                source,
            })?;
            Ok(())
        })
    }

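    /// Drive whatever asynchronous work the current state holds: creating the
    /// multipart upload, in-flight part uploads (including connection-reset
    /// retries), the single PUT, or the final completion request.
    ///
    /// Returns as soon as no further progress can be made without waiting.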
    fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::result::Result<(), io::Error> {
        let mut_self = &mut *self;
        loop {
            match &mut mut_self.state {
                UploadState::Started(_) | UploadState::Done(_) => break,
                UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(mut upload)) => {
                        let mut futures = JoinSet::new();

                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            0,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));

                        mut_self.state = UploadState::InProgress {
                            part_idx: 1, // We just used 0
                            futures,
                            upload,
                        };
                    }
                    Poll::Ready(Err(e)) => {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
                    }
                    Poll::Pending => break,
                },
                UploadState::InProgress {
                    upload, futures, ..
                } => {
                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
                        match res {
                            Ok(Ok(())) => {}
                            Err(err) => {
                                return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
                            }
                            Ok(Err(UploadPutError {
                                source: OSError::Generic { source, .. },
                                part_idx,
                                buffer,
                            })) if source
                                .to_string()
                                .to_lowercase()
                                .contains("connection reset by peer") =>
                            {
                                if mut_self.connection_resets < max_conn_reset_retries() {
                                    // Retry, but only up to max_conn_reset_retries of them.
                                    mut_self.connection_resets += 1;

                                    // Resubmit with random jitter
                                    let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000);
                                    let sleep_time =
                                        std::time::Duration::from_millis(sleep_time_ms);

                                    futures.spawn(Self::put_part(
                                        upload.as_mut(),
                                        buffer,
                                        part_idx,
                                        Some(sleep_time),
                                    ));
                                } else {
                                    return Err(io::Error::new(
                                        io::ErrorKind::ConnectionReset,
                                        Box::new(ConnectionResetError {
                                            message: format!(
                                                "Hit max retries ({}) for connection reset",
                                                max_conn_reset_retries()
                                            ),
                                            source,
                                        }),
                                    ));
                                }
                            }
                            Ok(Err(err)) => return Err(err.source.into()),
                        }
                    }
                    break;
                }
                UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
                    match fut.poll_unpin(cx) {
                        Poll::Ready(Ok(mut res)) => {
                            res.size = mut_self.cursor;
                            mut_self.state = UploadState::Done(res)
                        }
                        Poll::Ready(Err(e)) => {
                            return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
                        }
                        Poll::Pending => break,
                    }
                }
            }
        }
        Ok(())
    }

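    /// Flush any buffered data, finish the upload, and return the final
    /// [`WriteResult`].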
    pub async fn shutdown(&mut self) -> Result<WriteResult> {
        AsyncWriteExt::shutdown(self).await.map_err(|e| {
            Error::io(
                format!("failed to shutdown object writer for {}: {}", self.path, e),
                location!(),
            )
        })?;
        if let UploadState::Done(result) = &self.state {
            Ok(result.clone())
        } else {
            unreachable!()
        }
    }
}

impl Drop for ObjectWriter {
    fn drop(&mut self) {
        // If there is a multipart upload started but not finished, we should abort it.
        if matches!(self.state, UploadState::InProgress { .. }) {
            // Take ownership of the state.
            let state =
                std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
            if let UploadState::InProgress { mut upload, .. } = state {
                tokio::task::spawn(async move {
                    let _ = upload.abort().await;
                });
            }
        }
    }
}

/// Returned error from trying to upload a part.
/// Has the part_idx and buffer so we can pass
/// them to the retry logic.
struct UploadPutError {
    part_idx: u16,
    buffer: Bytes,
    source: OSError,
}

#[derive(Debug)]
struct ConnectionResetError {
    message: String,
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::error::Error for ConnectionResetError {}

impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}

impl AsyncWrite for ObjectWriter {
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
        self.as_mut().poll_tasks(cx)?;

        // Fill buffer up to remaining capacity.
        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
        self.cursor += bytes_to_write;

        // Rust needs a little help to borrow self mutably and immutably at the same time
        // through a Pin.
        let mut_self = &mut *self;

        // Instantiate next request, if available.
        if mut_self.buffer.capacity() == mut_self.buffer.len() {
            match &mut mut_self.state {
                UploadState::Started(store) => {
                    let path = mut_self.path.clone();
                    let store = store.clone();
                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
                    self.state = UploadState::CreatingUpload(fut);
                }
                UploadState::InProgress {
                    upload,
                    part_idx,
                    futures,
                    ..
                } => {
                    // TODO: Make max concurrency configurable from storage options.
                    if futures.len() < max_upload_parallelism() {
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            *part_idx,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(
                            Self::put_part(upload.as_mut(), data, *part_idx, None)
                                .instrument(tracing::Span::current()),
                        );
                        *part_idx += 1;
                    }
                }
                _ => {}
            }
        }

        self.poll_tasks(cx)?;

        match bytes_to_write {
            0 => Poll::Pending,
            _ => Poll::Ready(Ok(bytes_to_write)),
        }
    }

    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.as_mut().poll_tasks(cx)?;

        match &self.state {
            UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
            UploadState::CreatingUpload(_)
            | UploadState::Completing(_)
            | UploadState::PuttingSingle(_) => Poll::Pending,
            UploadState::InProgress { futures, .. } => {
                if futures.is_empty() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            }
        }
    }

    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        loop {
            self.as_mut().poll_tasks(cx)?;

            // Rust needs a little help to borrow self mutably and immutably at the same time
            // through a Pin.
            let mut_self = &mut *self;
            match &mut mut_self.state {
                UploadState::Done(_) => return Poll::Ready(Ok(())),
                UploadState::CreatingUpload(_)
                | UploadState::PuttingSingle(_)
                | UploadState::Completing(_) => return Poll::Pending,
                UploadState::Started(_) => {
                    // If we didn't start a multipart upload, we can just do a single put.
                    let part = std::mem::take(&mut mut_self.buffer);
                    let path = mut_self.path.clone();
                    self.state.started_to_completing(path, part);
                }
                UploadState::InProgress {
                    upload,
                    futures,
                    part_idx,
                } => {
                    // Flush final batch
                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
                        // We can just use `take` since we don't need the buffer anymore.
                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
                        futures.spawn(
                            Self::put_part(upload.as_mut(), data, *part_idx, None)
                                .instrument(tracing::Span::current()),
                        );
                        // We need to go back to the beginning of the loop to poll the
                        // new future and get the waker registered on the ctx.
                        continue;
                    }

                    // We handle the transition from in progress to completing here.
                    if futures.is_empty() {
                        self.state.in_progress_to_completing();
                    } else {
                        return Poll::Pending;
                    }
                }
            }
        }
    }
}

#[async_trait]
impl Writer for ObjectWriter {
    async fn tell(&mut self) -> Result<usize> {
        Ok(self.cursor)
    }
}

#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let buf = vec![0; 256];
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 512);

        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256 * 3);

        let res = object_writer.shutdown().await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Trigger a multipart upload
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let buf = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for i in 0..5 {
            // Each write is 2/3 of the initial part size, so several writes are
            // enough to fill the buffer and trigger a multipart upload.
            object_writer.write_all(buf.as_slice()).await.unwrap();
            // Check the cursor
            assert_eq!(object_writer.tell().await.unwrap(), (i + 1) * buf.len());
        }
        let res = object_writer.shutdown().await.unwrap();
        assert_eq!(res.size, buf.len() * 5);
    }
}