1use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use futures::FutureExt;
14use object_store::MultipartUpload;
15use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use snafu::location;
25use tokio::runtime::Handle;
26
/// 5 MiB: the default (and minimum accepted) initial part buffer size, and the
/// step by which part buffers grow as the part index increases.
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;
29
/// Returns the maximum number of part uploads that may be in flight at once.
///
/// The value is read from the `LANCE_UPLOAD_CONCURRENCY` environment variable the
/// first time this function is called and is cached for the rest of the process.
/// It falls back to 10 when the variable is unset or is not a valid `usize`.
///
/// The result is never zero. The writer only starts a part upload while fewer
/// than this many uploads are in flight, so a limit of zero would prevent any
/// further part from ever being uploaded; a configured `0` is treated as `1`.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            // Clamp instead of rejecting: one upload at a time is the closest
            // workable interpretation of "0".
            .map(|limit| limit.max(1))
            .unwrap_or(10)
    })
}
39
/// Returns how many "connection reset by peer" part failures may be retried
/// before the writer gives up.
///
/// Read once from the `LANCE_CONN_RESET_RETRIES` environment variable and then
/// cached; defaults to 20 when the variable is unset or is not a valid `u16`.
fn max_conn_reset_retries() -> u16 {
    const DEFAULT_RETRIES: u16 = 20;
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();

    let cached = MAX_CONN_RESET_RETRIES.get_or_init(|| {
        match std::env::var("LANCE_CONN_RESET_RETRIES") {
            Ok(raw) => raw.parse::<u16>().unwrap_or(DEFAULT_RETRIES),
            Err(_) => DEFAULT_RETRIES,
        }
    });
    *cached
}
49
50fn initial_upload_size() -> usize {
51 static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
52 *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
53 std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
54 .ok()
55 .and_then(|s| s.parse::<usize>().ok())
56 .inspect(|size| {
57 if *size < INITIAL_UPLOAD_STEP {
58 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
60 } else if *size > 1024 * 1024 * 1024 * 5 {
61 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
63 }
64 })
65 .unwrap_or(INITIAL_UPLOAD_STEP)
66 })
67}
68
/// Writer that buffers bytes in memory and uploads them to an object store.
///
/// Data that never fills the first buffer is uploaded with a single `put` on
/// shutdown; once a buffer fills, the writer switches to a multipart upload and
/// sends each full buffer as one part.
pub struct ObjectWriter {
    /// Current stage of the upload state machine.
    state: UploadState,
    /// Destination path of the object being written.
    path: Arc<Path>,
    /// Total number of bytes accepted by `poll_write` so far; this is what `tell` reports.
    cursor: usize,
    /// Number of "connection reset by peer" part failures retried so far.
    connection_resets: u16,
    /// Bytes accepted but not yet handed off to an upload.
    buffer: Vec<u8>,
    /// When true, every part buffer is allocated with the initial upload size
    /// instead of growing with the part index.
    use_constant_size_upload_parts: bool,
}
85
/// Outcome of a finished write.
#[derive(Debug, Clone, Default)]
pub struct WriteResult {
    /// Number of bytes written to the object.
    pub size: usize,
    /// Entity tag returned by the object store for the written object, if any.
    pub e_tag: Option<String>,
}
91
/// Stages an [`ObjectWriter`] moves through while uploading.
enum UploadState {
    /// Nothing has been sent yet; holds the store to upload to.
    Started(Arc<dyn ObjectStore>),
    /// The first buffer filled up and the multipart upload is being created.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// A multipart upload exists and parts are being uploaded.
    InProgress {
        /// Index assigned to the next part spawned from `poll_write`.
        part_idx: u16,
        /// Handle used to submit parts and to complete or abort the upload.
        upload: Box<dyn MultipartUpload>,
        /// Part uploads currently in flight.
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// The whole object is being uploaded with a single `put`.
    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
    /// All parts are uploaded and the multipart upload is being completed.
    Completing(BoxFuture<'static, OSResult<WriteResult>>),
    /// The upload has finished; holds its result.
    Done(WriteResult),
}
112
113impl UploadState {
115 fn started_to_putting_single(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
116 let this = std::mem::replace(self, Self::Done(WriteResult::default()));
118 *self = match this {
119 Self::Started(store) => {
120 let fut = async move {
121 let size = buffer.len();
122 let res = store.put(&path, buffer.into()).await?;
123 Ok(WriteResult {
124 size,
125 e_tag: res.e_tag,
126 })
127 };
128 Self::PuttingSingle(Box::pin(fut))
129 }
130 _ => unreachable!(),
131 }
132 }
133
134 fn in_progress_to_completing(&mut self) {
135 let this = std::mem::replace(self, Self::Done(WriteResult::default()));
137 *self = match this {
138 Self::InProgress {
139 mut upload,
140 futures,
141 ..
142 } => {
143 debug_assert!(futures.is_empty());
144 let fut = async move {
145 let res = upload.complete().await?;
146 Ok(WriteResult {
147 size: 0, e_tag: res.e_tag,
149 })
150 };
151 Self::Completing(Box::pin(fut))
152 }
153 _ => unreachable!(),
154 };
155 }
156}
157
158impl ObjectWriter {
159 pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
160 Ok(Self {
161 state: UploadState::Started(object_store.inner.clone()),
162 cursor: 0,
163 path: Arc::new(path.clone()),
164 connection_resets: 0,
165 buffer: Vec::with_capacity(initial_upload_size()),
166 use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
167 })
168 }
169
170 fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
173 let new_capacity = if constant_upload_size {
174 initial_upload_size()
176 } else {
177 initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
179 };
180 let new_buffer = Vec::with_capacity(new_capacity);
181 let part = std::mem::replace(buffer, new_buffer);
182 Bytes::from(part)
183 }
184
185 fn put_part(
186 upload: &mut dyn MultipartUpload,
187 buffer: Bytes,
188 part_idx: u16,
189 sleep: Option<std::time::Duration>,
190 ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
191 log::debug!(
192 "MultipartUpload submitting part with {} bytes",
193 buffer.len()
194 );
195 let fut = upload.put_part(buffer.clone().into());
196 Box::pin(async move {
197 if let Some(sleep) = sleep {
198 tokio::time::sleep(sleep).await;
199 }
200 fut.await.map_err(|source| UploadPutError {
201 part_idx,
202 buffer,
203 source,
204 })?;
205 Ok(())
206 })
207 }
208
209 fn poll_tasks(
210 mut self: Pin<&mut Self>,
211 cx: &mut std::task::Context<'_>,
212 ) -> std::result::Result<(), io::Error> {
213 let mut_self = &mut *self;
214 loop {
215 match &mut mut_self.state {
216 UploadState::Started(_) | UploadState::Done(_) => break,
217 UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
218 Poll::Ready(Ok(mut upload)) => {
219 let mut futures = JoinSet::new();
220
221 let data = Self::next_part_buffer(
222 &mut mut_self.buffer,
223 0,
224 mut_self.use_constant_size_upload_parts,
225 );
226 futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
227
228 mut_self.state = UploadState::InProgress {
229 part_idx: 1, futures,
231 upload,
232 };
233 }
234 Poll::Ready(Err(e)) => {
235 return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
236 }
237 Poll::Pending => break,
238 },
239 UploadState::InProgress {
240 upload, futures, ..
241 } => {
242 while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
243 match res {
244 Ok(Ok(())) => {}
245 Err(err) => {
246 return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
247 }
248 Ok(Err(UploadPutError {
249 source: OSError::Generic { source, .. },
250 part_idx,
251 buffer,
252 })) if source
253 .to_string()
254 .to_lowercase()
255 .contains("connection reset by peer") =>
256 {
257 if mut_self.connection_resets < max_conn_reset_retries() {
258 mut_self.connection_resets += 1;
260
261 let sleep_time_ms = rand::rng().random_range(2_000..8_000);
263 let sleep_time =
264 std::time::Duration::from_millis(sleep_time_ms);
265
266 futures.spawn(Self::put_part(
267 upload.as_mut(),
268 buffer,
269 part_idx,
270 Some(sleep_time),
271 ));
272 } else {
273 return Err(io::Error::new(
274 io::ErrorKind::ConnectionReset,
275 Box::new(ConnectionResetError {
276 message: format!(
277 "Hit max retries ({}) for connection reset",
278 max_conn_reset_retries()
279 ),
280 source,
281 }),
282 ));
283 }
284 }
285 Ok(Err(err)) => return Err(err.source.into()),
286 }
287 }
288 break;
289 }
290 UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
291 match fut.poll_unpin(cx) {
292 Poll::Ready(Ok(mut res)) => {
293 res.size = mut_self.cursor;
294 mut_self.state = UploadState::Done(res)
295 }
296 Poll::Ready(Err(e)) => {
297 return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
298 }
299 Poll::Pending => break,
300 }
301 }
302 }
303 }
304 Ok(())
305 }
306
307 pub async fn shutdown(&mut self) -> Result<WriteResult> {
308 AsyncWriteExt::shutdown(self).await.map_err(|e| {
309 Error::io(
310 format!("failed to shutdown object writer for {}: {}", self.path, e),
311 location!(),
313 )
314 })?;
315 if let UploadState::Done(result) = &self.state {
316 Ok(result.clone())
317 } else {
318 unreachable!()
319 }
320 }
321
322 pub async fn abort(&mut self) {
323 let state = std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
324 if let UploadState::InProgress { mut upload, .. } = state {
325 let _ = upload.abort().await;
326 }
327 }
328}
329
330impl Drop for ObjectWriter {
331 fn drop(&mut self) {
332 if matches!(self.state, UploadState::InProgress { .. }) {
334 let state =
336 std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
337 if let UploadState::InProgress { mut upload, .. } = state {
338 if let Ok(handle) = Handle::try_current() {
339 handle.spawn(async move {
340 let _ = upload.abort().await;
341 });
342 }
343 }
344 }
345 }
346}
347
/// Failure of a single part upload, carrying what is needed to resubmit it.
struct UploadPutError {
    /// Index of the part whose upload failed.
    part_idx: u16,
    /// Bytes of the failed part.
    buffer: Bytes,
    /// The object store error that caused the failure.
    source: OSError,
}
356
/// Error reported when a part upload keeps failing with a connection reset
/// after the retry budget is exhausted.
#[derive(Debug)]
struct ConnectionResetError {
    /// Description of the failure.
    message: String,
    /// The connection-reset error that triggered it.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::fmt::Display for ConnectionResetError {
    /// Formats the error as `<message>: <source>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}: {}", self.message, self.source))
    }
}

// `source()` is left at its default (`None`): the underlying error is already
// included in the `Display` output.
impl std::error::Error for ConnectionResetError {}
370
371impl AsyncWrite for ObjectWriter {
372 fn poll_write(
373 mut self: std::pin::Pin<&mut Self>,
374 cx: &mut std::task::Context<'_>,
375 buf: &[u8],
376 ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
377 self.as_mut().poll_tasks(cx)?;
378
379 let remaining_capacity = self.buffer.capacity() - self.buffer.len();
381 let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
382 self.buffer.extend_from_slice(&buf[..bytes_to_write]);
383 self.cursor += bytes_to_write;
384
385 let mut_self = &mut *self;
388
389 if mut_self.buffer.capacity() == mut_self.buffer.len() {
391 match &mut mut_self.state {
392 UploadState::Started(store) => {
393 let path = mut_self.path.clone();
394 let store = store.clone();
395 let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
396 self.state = UploadState::CreatingUpload(fut);
397 }
398 UploadState::InProgress {
399 upload,
400 part_idx,
401 futures,
402 ..
403 } => {
404 if futures.len() < max_upload_parallelism() {
406 let data = Self::next_part_buffer(
407 &mut mut_self.buffer,
408 *part_idx,
409 mut_self.use_constant_size_upload_parts,
410 );
411 futures.spawn(
412 Self::put_part(upload.as_mut(), data, *part_idx, None)
413 .instrument(tracing::Span::current()),
414 );
415 *part_idx += 1;
416 }
417 }
418 _ => {}
419 }
420 }
421
422 self.poll_tasks(cx)?;
423
424 match bytes_to_write {
425 0 => Poll::Pending,
426 _ => Poll::Ready(Ok(bytes_to_write)),
427 }
428 }
429
430 fn poll_flush(
431 mut self: std::pin::Pin<&mut Self>,
432 cx: &mut std::task::Context<'_>,
433 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
434 self.as_mut().poll_tasks(cx)?;
435
436 match &self.state {
437 UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
438 UploadState::CreatingUpload(_)
439 | UploadState::Completing(_)
440 | UploadState::PuttingSingle(_) => Poll::Pending,
441 UploadState::InProgress { futures, .. } => {
442 if futures.is_empty() {
443 Poll::Ready(Ok(()))
444 } else {
445 Poll::Pending
446 }
447 }
448 }
449 }
450
451 fn poll_shutdown(
452 mut self: std::pin::Pin<&mut Self>,
453 cx: &mut std::task::Context<'_>,
454 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
455 loop {
456 self.as_mut().poll_tasks(cx)?;
457
458 let mut_self = &mut *self;
461 match &mut mut_self.state {
462 UploadState::Done(_) => return Poll::Ready(Ok(())),
463 UploadState::CreatingUpload(_)
464 | UploadState::PuttingSingle(_)
465 | UploadState::Completing(_) => return Poll::Pending,
466 UploadState::Started(_) => {
467 let part = std::mem::take(&mut mut_self.buffer);
469 let path = mut_self.path.clone();
470 self.state.started_to_putting_single(path, part);
471 }
472 UploadState::InProgress {
473 upload,
474 futures,
475 part_idx,
476 } => {
477 if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
479 let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
481 futures.spawn(
482 Self::put_part(upload.as_mut(), data, *part_idx, None)
483 .instrument(tracing::Span::current()),
484 );
485 continue;
488 }
489
490 if futures.is_empty() {
492 self.state.in_progress_to_completing();
493 } else {
494 return Poll::Pending;
495 }
496 }
497 }
498 }
499 }
500}
501
502#[async_trait]
503impl Writer for ObjectWriter {
504 async fn tell(&mut self) -> Result<usize> {
505 Ok(self.cursor)
506 }
507}
508
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Checks the reported position and final size for a small write and for a
    /// write spanning several buffers.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        // Small object: three 256-byte writes never fill the first buffer.
        let mut small_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(small_writer.tell().await.unwrap(), 0);

        let chunk = vec![0; 256];
        for chunks_written in 1..=3 {
            assert_eq!(small_writer.write(chunk.as_slice()).await.unwrap(), 256);
            assert_eq!(small_writer.tell().await.unwrap(), 256 * chunks_written);
        }

        let res = small_writer.shutdown().await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Larger object: each write is two thirds of the upload step, so with
        // the default initial size the buffer fills during the loop.
        let mut large_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let chunk = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for i in 0..5 {
            large_writer.write_all(chunk.as_slice()).await.unwrap();
            assert_eq!(large_writer.tell().await.unwrap(), (i + 1) * chunk.len());
        }
        let res = large_writer.shutdown().await.unwrap();
        assert_eq!(res.size, chunk.len() * 5);
    }

    /// Aborting a writer that has not written anything must not fail.
    #[tokio::test]
    async fn test_abort_write() {
        let store = LanceObjectStore::memory();

        let mut writer = futures::executor::block_on(async move {
            ObjectWriter::new(&store, &Path::from("/foo"))
                .await
                .unwrap()
        });
        writer.abort().await;
    }
}