1use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use futures::FutureExt;
14use object_store::MultipartUpload;
15use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use snafu::location;
25
/// 5 MiB. Serves as the default part-buffer size, the minimum accepted value
/// of `LANCE_INITIAL_UPLOAD_SIZE`, and the increment by which part buffers grow
/// in `ObjectWriter::next_part_buffer` when constant-size parts are disabled.
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;
28
/// Maximum number of part uploads that may be in flight at once.
///
/// Read once from `LANCE_UPLOAD_CONCURRENCY` and cached for the rest of the
/// process. Defaults to 10 when the variable is unset or is not a valid `usize`.
///
/// A configured `0` is raised to 1. The writer only submits a new part while
/// fewer than this many uploads are running. With zero slots, no part after the
/// first could ever be submitted: writes would stall, and shutdown would finish
/// the upload without sending the bytes still sitting in the buffer.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(10)
            // At least one upload slot is required for the writer to make progress.
            .max(1)
    })
}
38
/// Maximum number of times a part upload that failed with
/// "connection reset by peer" may be resubmitted.
///
/// Read once from `LANCE_CONN_RESET_RETRIES` and cached. A missing variable,
/// or one that does not parse as `u16`, yields the default of 20.
fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    let configured = MAX_CONN_RESET_RETRIES.get_or_init(|| {
        let parsed = match std::env::var("LANCE_CONN_RESET_RETRIES") {
            Ok(raw) => raw.parse::<u16>().ok(),
            Err(_) => None,
        };
        parsed.unwrap_or(20)
    });
    *configured
}
48
49fn initial_upload_size() -> usize {
50 static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
51 *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
52 std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
53 .ok()
54 .and_then(|s| s.parse::<usize>().ok())
55 .inspect(|size| {
56 if *size < INITIAL_UPLOAD_STEP {
57 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
59 } else if *size > 1024 * 1024 * 1024 * 5 {
60 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
62 }
63 })
64 .unwrap_or(INITIAL_UPLOAD_STEP)
65 })
66}
67
/// Buffered `AsyncWrite` that uploads its bytes to a single object-store path.
///
/// Data accumulates in an in-memory buffer. If shutdown happens before the
/// buffer ever fills, the payload is sent with one `put`. Otherwise a multipart
/// upload is created and each filled buffer is sent as a part.
pub struct ObjectWriter {
    /// Current phase of the upload state machine.
    state: UploadState,
    /// Destination path in the object store.
    path: Arc<Path>,
    /// Total number of bytes accepted by `poll_write` so far (what `tell` returns).
    cursor: usize,
    /// Connection-reset retries consumed so far, counted across all parts.
    connection_resets: u16,
    /// Bytes of the part currently being filled; a part is cut when it reaches capacity.
    buffer: Vec<u8>,
    /// When true, every part buffer is sized to `initial_upload_size()`;
    /// otherwise part sizes grow with the part index.
    use_constant_size_upload_parts: bool,
}
84
/// Outcome of a finished upload, returned by `ObjectWriter::shutdown`.
#[derive(Debug, Clone, Default)]
pub struct WriteResult {
    /// Total number of bytes written (the writer's cursor when the upload finished).
    pub size: usize,
    /// Entity tag reported by the object store, if it returned one.
    pub e_tag: Option<String>,
}
90
/// Phases of an `ObjectWriter` upload.
enum UploadState {
    /// Nothing has been sent yet; data is only buffered. Holds the store so
    /// either a single `put` or a multipart upload can be started later.
    Started(Arc<dyn ObjectStore>),
    /// The buffer filled up and the multipart upload is being created.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// A multipart upload is open and parts are being uploaded.
    InProgress {
        /// Index assigned to the next part submitted.
        part_idx: u16,
        /// Handle to the open multipart upload.
        upload: Box<dyn MultipartUpload>,
        /// Part uploads currently in flight.
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// Shutdown happened before the buffer ever filled; the whole payload is
    /// being written with a single `put`.
    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
    /// All parts are uploaded and the multipart upload is being completed.
    Completing(BoxFuture<'static, OSResult<WriteResult>>),
    /// The upload finished. Also used as a temporary placeholder while
    /// swapping states with `std::mem::replace`.
    Done(WriteResult),
}
111
112impl UploadState {
114 fn started_to_completing(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
115 let this = std::mem::replace(self, Self::Done(WriteResult::default()));
117 *self = match this {
118 Self::Started(store) => {
119 let fut = async move {
120 let size = buffer.len();
121 let res = store.put(&path, buffer.into()).await?;
122 Ok(WriteResult {
123 size,
124 e_tag: res.e_tag,
125 })
126 };
127 Self::PuttingSingle(Box::pin(fut))
128 }
129 _ => unreachable!(),
130 }
131 }
132
133 fn in_progress_to_completing(&mut self) {
134 let this = std::mem::replace(self, Self::Done(WriteResult::default()));
136 *self = match this {
137 Self::InProgress {
138 mut upload,
139 futures,
140 ..
141 } => {
142 debug_assert!(futures.is_empty());
143 let fut = async move {
144 let res = upload.complete().await?;
145 Ok(WriteResult {
146 size: 0, e_tag: res.e_tag,
148 })
149 };
150 Self::Completing(Box::pin(fut))
151 }
152 _ => unreachable!(),
153 };
154 }
155}
156
157impl ObjectWriter {
158 pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
159 Ok(Self {
160 state: UploadState::Started(object_store.inner.clone()),
161 cursor: 0,
162 path: Arc::new(path.clone()),
163 connection_resets: 0,
164 buffer: Vec::with_capacity(initial_upload_size()),
165 use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
166 })
167 }
168
169 fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
172 let new_capacity = if constant_upload_size {
173 initial_upload_size()
175 } else {
176 initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
178 };
179 let new_buffer = Vec::with_capacity(new_capacity);
180 let part = std::mem::replace(buffer, new_buffer);
181 Bytes::from(part)
182 }
183
184 fn put_part(
185 upload: &mut dyn MultipartUpload,
186 buffer: Bytes,
187 part_idx: u16,
188 sleep: Option<std::time::Duration>,
189 ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
190 log::debug!(
191 "MultipartUpload submitting part with {} bytes",
192 buffer.len()
193 );
194 let fut = upload.put_part(buffer.clone().into());
195 Box::pin(async move {
196 if let Some(sleep) = sleep {
197 tokio::time::sleep(sleep).await;
198 }
199 fut.await.map_err(|source| UploadPutError {
200 part_idx,
201 buffer,
202 source,
203 })?;
204 Ok(())
205 })
206 }
207
208 fn poll_tasks(
209 mut self: Pin<&mut Self>,
210 cx: &mut std::task::Context<'_>,
211 ) -> std::result::Result<(), io::Error> {
212 let mut_self = &mut *self;
213 loop {
214 match &mut mut_self.state {
215 UploadState::Started(_) | UploadState::Done(_) => break,
216 UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
217 Poll::Ready(Ok(mut upload)) => {
218 let mut futures = JoinSet::new();
219
220 let data = Self::next_part_buffer(
221 &mut mut_self.buffer,
222 0,
223 mut_self.use_constant_size_upload_parts,
224 );
225 futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
226
227 mut_self.state = UploadState::InProgress {
228 part_idx: 1, futures,
230 upload,
231 };
232 }
233 Poll::Ready(Err(e)) => {
234 return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
235 }
236 Poll::Pending => break,
237 },
238 UploadState::InProgress {
239 upload, futures, ..
240 } => {
241 while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
242 match res {
243 Ok(Ok(())) => {}
244 Err(err) => {
245 return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
246 }
247 Ok(Err(UploadPutError {
248 source: OSError::Generic { source, .. },
249 part_idx,
250 buffer,
251 })) if source
252 .to_string()
253 .to_lowercase()
254 .contains("connection reset by peer") =>
255 {
256 if mut_self.connection_resets < max_conn_reset_retries() {
257 mut_self.connection_resets += 1;
259
260 let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000);
262 let sleep_time =
263 std::time::Duration::from_millis(sleep_time_ms);
264
265 futures.spawn(Self::put_part(
266 upload.as_mut(),
267 buffer,
268 part_idx,
269 Some(sleep_time),
270 ));
271 } else {
272 return Err(io::Error::new(
273 io::ErrorKind::ConnectionReset,
274 Box::new(ConnectionResetError {
275 message: format!(
276 "Hit max retries ({}) for connection reset",
277 max_conn_reset_retries()
278 ),
279 source,
280 }),
281 ));
282 }
283 }
284 Ok(Err(err)) => return Err(err.source.into()),
285 }
286 }
287 break;
288 }
289 UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
290 match fut.poll_unpin(cx) {
291 Poll::Ready(Ok(mut res)) => {
292 res.size = mut_self.cursor;
293 mut_self.state = UploadState::Done(res)
294 }
295 Poll::Ready(Err(e)) => {
296 return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
297 }
298 Poll::Pending => break,
299 }
300 }
301 }
302 }
303 Ok(())
304 }
305
306 pub async fn shutdown(&mut self) -> Result<WriteResult> {
307 AsyncWriteExt::shutdown(self).await.map_err(|e| {
308 Error::io(
309 format!("failed to shutdown object writer for {}: {}", self.path, e),
310 location!(),
312 )
313 })?;
314 if let UploadState::Done(result) = &self.state {
315 Ok(result.clone())
316 } else {
317 unreachable!()
318 }
319 }
320}
321
322impl Drop for ObjectWriter {
323 fn drop(&mut self) {
324 if matches!(self.state, UploadState::InProgress { .. }) {
326 let state =
328 std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
329 if let UploadState::InProgress { mut upload, .. } = state {
330 tokio::task::spawn(async move {
331 let _ = upload.abort().await;
332 });
333 }
334 }
335 }
336}
337
/// Failure of a single part upload, carrying what is needed to resubmit it.
struct UploadPutError {
    /// Index the part was submitted with.
    part_idx: u16,
    /// The part's payload, kept so the part can be retried.
    buffer: Bytes,
    /// The error returned by the object store.
    source: OSError,
}
346
/// Error returned once part uploads have failed with "connection reset by
/// peer" more times than the retry budget allows.
#[derive(Debug)]
struct ConnectionResetError {
    /// Summary of the failure, including the retry limit that was hit.
    message: String,
    /// The underlying error from the last failed attempt.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::error::Error for ConnectionResetError {
    /// Exposes the wrapped error so error-chain walkers (`Error::source`,
    /// reporting libraries) can reach the original cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &*self.source;
        Some(cause)
    }
}

impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}
360
361impl AsyncWrite for ObjectWriter {
362 fn poll_write(
363 mut self: std::pin::Pin<&mut Self>,
364 cx: &mut std::task::Context<'_>,
365 buf: &[u8],
366 ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
367 self.as_mut().poll_tasks(cx)?;
368
369 let remaining_capacity = self.buffer.capacity() - self.buffer.len();
371 let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
372 self.buffer.extend_from_slice(&buf[..bytes_to_write]);
373 self.cursor += bytes_to_write;
374
375 let mut_self = &mut *self;
378
379 if mut_self.buffer.capacity() == mut_self.buffer.len() {
381 match &mut mut_self.state {
382 UploadState::Started(store) => {
383 let path = mut_self.path.clone();
384 let store = store.clone();
385 let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
386 self.state = UploadState::CreatingUpload(fut);
387 }
388 UploadState::InProgress {
389 upload,
390 part_idx,
391 futures,
392 ..
393 } => {
394 if futures.len() < max_upload_parallelism() {
396 let data = Self::next_part_buffer(
397 &mut mut_self.buffer,
398 *part_idx,
399 mut_self.use_constant_size_upload_parts,
400 );
401 futures.spawn(
402 Self::put_part(upload.as_mut(), data, *part_idx, None)
403 .instrument(tracing::Span::current()),
404 );
405 *part_idx += 1;
406 }
407 }
408 _ => {}
409 }
410 }
411
412 self.poll_tasks(cx)?;
413
414 match bytes_to_write {
415 0 => Poll::Pending,
416 _ => Poll::Ready(Ok(bytes_to_write)),
417 }
418 }
419
420 fn poll_flush(
421 mut self: std::pin::Pin<&mut Self>,
422 cx: &mut std::task::Context<'_>,
423 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
424 self.as_mut().poll_tasks(cx)?;
425
426 match &self.state {
427 UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
428 UploadState::CreatingUpload(_)
429 | UploadState::Completing(_)
430 | UploadState::PuttingSingle(_) => Poll::Pending,
431 UploadState::InProgress { futures, .. } => {
432 if futures.is_empty() {
433 Poll::Ready(Ok(()))
434 } else {
435 Poll::Pending
436 }
437 }
438 }
439 }
440
441 fn poll_shutdown(
442 mut self: std::pin::Pin<&mut Self>,
443 cx: &mut std::task::Context<'_>,
444 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
445 loop {
446 self.as_mut().poll_tasks(cx)?;
447
448 let mut_self = &mut *self;
451 match &mut mut_self.state {
452 UploadState::Done(_) => return Poll::Ready(Ok(())),
453 UploadState::CreatingUpload(_)
454 | UploadState::PuttingSingle(_)
455 | UploadState::Completing(_) => return Poll::Pending,
456 UploadState::Started(_) => {
457 let part = std::mem::take(&mut mut_self.buffer);
459 let path = mut_self.path.clone();
460 self.state.started_to_completing(path, part);
461 }
462 UploadState::InProgress {
463 upload,
464 futures,
465 part_idx,
466 } => {
467 if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
469 let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
471 futures.spawn(
472 Self::put_part(upload.as_mut(), data, *part_idx, None)
473 .instrument(tracing::Span::current()),
474 );
475 continue;
478 }
479
480 if futures.is_empty() {
482 self.state.in_progress_to_completing();
483 } else {
484 return Poll::Pending;
485 }
486 }
487 }
488 }
489 }
490}
491
492#[async_trait]
493impl Writer for ObjectWriter {
494 async fn tell(&mut self) -> Result<usize> {
495 Ok(self.cursor)
496 }
497}
498
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Exercises a small payload (a few hundred bytes) and a payload several
    /// times larger than the default part size, checking `tell` after every
    /// write and the reported size after shutdown.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        // Small payload: three 256-byte writes.
        let mut small_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(small_writer.tell().await.unwrap(), 0);

        let chunk = vec![0; 256];
        for chunks_written in 1..=3 {
            assert_eq!(small_writer.write(chunk.as_slice()).await.unwrap(), 256);
            assert_eq!(small_writer.tell().await.unwrap(), 256 * chunks_written);
        }

        let small_result = small_writer.shutdown().await.unwrap();
        assert_eq!(small_result.size, 256 * 3);

        // Large payload: five writes of two thirds of INITIAL_UPLOAD_STEP each.
        let mut large_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let big_chunk = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for chunks_written in 1..=5 {
            large_writer.write_all(big_chunk.as_slice()).await.unwrap();
            assert_eq!(
                large_writer.tell().await.unwrap(),
                chunks_written * big_chunk.len()
            );
        }

        let large_result = large_writer.shutdown().await.unwrap();
        assert_eq!(large_result.size, big_chunk.len() * 5);
    }
}