1use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use futures::FutureExt;
14use object_store::MultipartUpload;
15use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use snafu::location;
25use tokio::runtime::Handle;
26
/// 5 MiB. Default size of the first staging buffer, lower bound accepted for
/// `LANCE_INITIAL_UPLOAD_SIZE`, and the increment by which non-constant part buffers grow.
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;
29
/// Maximum number of multipart part uploads that may be in flight at once for one writer.
///
/// Read once from the `LANCE_UPLOAD_CONCURRENCY` environment variable and cached for the
/// rest of the process. Unset or unparsable values fall back to 10.
///
/// A value of `0` is clamped to `1`. Part submission is gated on
/// `in_flight < max_upload_parallelism()`, and with a limit of zero that check can never
/// pass. Writes would then stall once the buffer fills, and shutdown could complete the
/// upload without its final part.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(10)
            .max(1)
    })
}
39
/// How many times, per writer, a part upload that failed with "connection reset by peer"
/// is re-submitted before the writer gives up with an error.
///
/// Taken from `LANCE_CONN_RESET_RETRIES` on first use and cached afterwards. Falls back
/// to 20 when the variable is unset or is not a valid `u16`.
fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    let cached = MAX_CONN_RESET_RETRIES.get_or_init(|| {
        match std::env::var("LANCE_CONN_RESET_RETRIES") {
            Ok(raw) => raw.parse::<u16>().unwrap_or(20),
            Err(_) => 20,
        }
    });
    *cached
}
49
50fn initial_upload_size() -> usize {
51 static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
52 *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
53 std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
54 .ok()
55 .and_then(|s| s.parse::<usize>().ok())
56 .inspect(|size| {
57 if *size < INITIAL_UPLOAD_STEP {
58 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
60 } else if *size > 1024 * 1024 * 1024 * 5 {
61 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
63 }
64 })
65 .unwrap_or(INITIAL_UPLOAD_STEP)
66 })
67}
68
/// Writes one object to an object store through the [`AsyncWrite`] interface.
///
/// Written data is staged in `buffer`.
/// - If the writer is shut down before the first buffer fills, the whole object is sent
///   with a single `put`.
/// - Otherwise a multipart upload is created and each full buffer is submitted as a part,
///   with at most `max_upload_parallelism()` parts in flight.
///
/// If dropped while a multipart upload is in progress, an abort of that upload is spawned
/// on the current Tokio runtime, when one is available.
pub struct ObjectWriter {
    /// Current stage of the upload state machine.
    state: UploadState,
    /// Destination path of the object.
    path: Arc<Path>,
    /// Total bytes accepted by `poll_write` so far. Reported by `tell` and used as the
    /// final `WriteResult::size`.
    cursor: usize,
    /// Number of part uploads re-submitted after "connection reset by peer", capped by
    /// `max_conn_reset_retries()`.
    connection_resets: u16,
    /// Staging buffer holding data not yet handed to the store.
    buffer: Vec<u8>,
    /// When true, every part buffer is `initial_upload_size()` bytes. Otherwise part
    /// buffers grow with the part index.
    use_constant_size_upload_parts: bool,
}
85
/// Outcome of a finished object write.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare results directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WriteResult {
    /// Number of bytes written to the object.
    pub size: usize,
    /// Entity tag returned by the store for the written object, if it returned one.
    pub e_tag: Option<String>,
}
91
/// Stages of an `ObjectWriter` upload.
enum UploadState {
    /// Nothing has been sent yet, and all data written so far fits in the first buffer.
    /// Holds the store so the writer can later start either a single put or a multipart
    /// upload.
    Started(Arc<dyn ObjectStore>),
    /// The first buffer filled, and the multipart upload is being created.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// A multipart upload exists and parts are being submitted.
    InProgress {
        /// Index of the next part to submit.
        part_idx: u16,
        /// Handle to the multipart upload.
        upload: Box<dyn MultipartUpload>,
        /// Part uploads that are currently in flight.
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// Shut down before the first buffer filled. The whole object is being sent with a
    /// single put.
    PuttingSingle(BoxFuture<'static, OSResult<WriteResult>>),
    /// Every part has finished, and the multipart upload is being completed.
    Completing(BoxFuture<'static, OSResult<WriteResult>>),
    /// Finished or abandoned (abort/drop leave a default result here).
    Done(WriteResult),
}
112
113impl UploadState {
115 fn started_to_putting_single(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
116 let this = std::mem::replace(self, Self::Done(WriteResult::default()));
118 *self = match this {
119 Self::Started(store) => {
120 let fut = async move {
121 let size = buffer.len();
122 let res = store.put(&path, buffer.into()).await?;
123 Ok(WriteResult {
124 size,
125 e_tag: res.e_tag,
126 })
127 };
128 Self::PuttingSingle(Box::pin(fut))
129 }
130 _ => unreachable!(),
131 }
132 }
133
134 fn in_progress_to_completing(&mut self) {
135 let this = std::mem::replace(self, Self::Done(WriteResult::default()));
137 *self = match this {
138 Self::InProgress {
139 mut upload,
140 futures,
141 ..
142 } => {
143 debug_assert!(futures.is_empty());
144 let fut = async move {
145 let res = upload.complete().await?;
146 Ok(WriteResult {
147 size: 0, e_tag: res.e_tag,
149 })
150 };
151 Self::Completing(Box::pin(fut))
152 }
153 _ => unreachable!(),
154 };
155 }
156}
157
158impl ObjectWriter {
159 pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
160 Ok(Self {
161 state: UploadState::Started(object_store.inner.clone()),
162 cursor: 0,
163 path: Arc::new(path.clone()),
164 connection_resets: 0,
165 buffer: Vec::with_capacity(initial_upload_size()),
166 use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
167 })
168 }
169
170 fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
173 let new_capacity = if constant_upload_size {
174 initial_upload_size()
176 } else {
177 initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
179 };
180 let new_buffer = Vec::with_capacity(new_capacity);
181 let part = std::mem::replace(buffer, new_buffer);
182 Bytes::from(part)
183 }
184
185 fn put_part(
186 upload: &mut dyn MultipartUpload,
187 buffer: Bytes,
188 part_idx: u16,
189 sleep: Option<std::time::Duration>,
190 ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
191 log::debug!(
192 "MultipartUpload submitting part with {} bytes",
193 buffer.len()
194 );
195 let fut = upload.put_part(buffer.clone().into());
196 Box::pin(async move {
197 if let Some(sleep) = sleep {
198 tokio::time::sleep(sleep).await;
199 }
200 fut.await.map_err(|source| UploadPutError {
201 part_idx,
202 buffer,
203 source,
204 })?;
205 Ok(())
206 })
207 }
208
209 fn poll_tasks(
210 mut self: Pin<&mut Self>,
211 cx: &mut std::task::Context<'_>,
212 ) -> std::result::Result<(), io::Error> {
213 let mut_self = &mut *self;
214 loop {
215 match &mut mut_self.state {
216 UploadState::Started(_) | UploadState::Done(_) => break,
217 UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
218 Poll::Ready(Ok(mut upload)) => {
219 let mut futures = JoinSet::new();
220
221 let data = Self::next_part_buffer(
222 &mut mut_self.buffer,
223 0,
224 mut_self.use_constant_size_upload_parts,
225 );
226 futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
227
228 mut_self.state = UploadState::InProgress {
229 part_idx: 1, futures,
231 upload,
232 };
233 }
234 Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
235 Poll::Pending => break,
236 },
237 UploadState::InProgress {
238 upload, futures, ..
239 } => {
240 while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
241 match res {
242 Ok(Ok(())) => {}
243 Err(err) => return Err(std::io::Error::other(err)),
244 Ok(Err(UploadPutError {
245 source: OSError::Generic { source, .. },
246 part_idx,
247 buffer,
248 })) if source
249 .to_string()
250 .to_lowercase()
251 .contains("connection reset by peer") =>
252 {
253 if mut_self.connection_resets < max_conn_reset_retries() {
254 mut_self.connection_resets += 1;
256
257 let sleep_time_ms = rand::rng().random_range(2_000..8_000);
259 let sleep_time =
260 std::time::Duration::from_millis(sleep_time_ms);
261
262 futures.spawn(Self::put_part(
263 upload.as_mut(),
264 buffer,
265 part_idx,
266 Some(sleep_time),
267 ));
268 } else {
269 return Err(io::Error::new(
270 io::ErrorKind::ConnectionReset,
271 Box::new(ConnectionResetError {
272 message: format!(
273 "Hit max retries ({}) for connection reset",
274 max_conn_reset_retries()
275 ),
276 source,
277 }),
278 ));
279 }
280 }
281 Ok(Err(err)) => return Err(err.source.into()),
282 }
283 }
284 break;
285 }
286 UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
287 match fut.poll_unpin(cx) {
288 Poll::Ready(Ok(mut res)) => {
289 res.size = mut_self.cursor;
290 mut_self.state = UploadState::Done(res)
291 }
292 Poll::Ready(Err(e)) => return Err(std::io::Error::other(e)),
293 Poll::Pending => break,
294 }
295 }
296 }
297 }
298 Ok(())
299 }
300
301 pub async fn shutdown(&mut self) -> Result<WriteResult> {
302 AsyncWriteExt::shutdown(self).await.map_err(|e| {
303 Error::io(
304 format!("failed to shutdown object writer for {}: {}", self.path, e),
305 location!(),
307 )
308 })?;
309 if let UploadState::Done(result) = &self.state {
310 Ok(result.clone())
311 } else {
312 unreachable!()
313 }
314 }
315
316 pub async fn abort(&mut self) {
317 let state = std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
318 if let UploadState::InProgress { mut upload, .. } = state {
319 let _ = upload.abort().await;
320 }
321 }
322}
323
324impl Drop for ObjectWriter {
325 fn drop(&mut self) {
326 if matches!(self.state, UploadState::InProgress { .. }) {
328 let state =
330 std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default()));
331 if let UploadState::InProgress { mut upload, .. } = state {
332 if let Ok(handle) = Handle::try_current() {
333 handle.spawn(async move {
334 let _ = upload.abort().await;
335 });
336 }
337 }
338 }
339 }
340}
341
/// A failed part upload, carrying what is needed to re-submit it.
struct UploadPutError {
    /// Index of the part that failed.
    part_idx: u16,
    /// The part's data.
    buffer: Bytes,
    /// Error returned by the object store.
    source: OSError,
}
350
/// Error produced once "connection reset by peer" failures exceed
/// `max_conn_reset_retries()`.
#[derive(Debug)]
struct ConnectionResetError {
    /// Summary of the failure, including the retry limit that was hit.
    message: String,
    /// Error from the last failed attempt.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::fmt::Display for ConnectionResetError {
    /// Renders as `"<message>: <source>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message, source } = self;
        write!(f, "{message}: {source}")
    }
}

// `source()` keeps its default (`None`). The underlying error is already part of the
// `Display` output.
impl std::error::Error for ConnectionResetError {}
364
365impl AsyncWrite for ObjectWriter {
366 fn poll_write(
367 mut self: std::pin::Pin<&mut Self>,
368 cx: &mut std::task::Context<'_>,
369 buf: &[u8],
370 ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
371 self.as_mut().poll_tasks(cx)?;
372
373 let remaining_capacity = self.buffer.capacity() - self.buffer.len();
375 let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
376 self.buffer.extend_from_slice(&buf[..bytes_to_write]);
377 self.cursor += bytes_to_write;
378
379 let mut_self = &mut *self;
382
383 if mut_self.buffer.capacity() == mut_self.buffer.len() {
385 match &mut mut_self.state {
386 UploadState::Started(store) => {
387 let path = mut_self.path.clone();
388 let store = store.clone();
389 let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
390 self.state = UploadState::CreatingUpload(fut);
391 }
392 UploadState::InProgress {
393 upload,
394 part_idx,
395 futures,
396 ..
397 } => {
398 if futures.len() < max_upload_parallelism() {
400 let data = Self::next_part_buffer(
401 &mut mut_self.buffer,
402 *part_idx,
403 mut_self.use_constant_size_upload_parts,
404 );
405 futures.spawn(
406 Self::put_part(upload.as_mut(), data, *part_idx, None)
407 .instrument(tracing::Span::current()),
408 );
409 *part_idx += 1;
410 }
411 }
412 _ => {}
413 }
414 }
415
416 self.poll_tasks(cx)?;
417
418 match bytes_to_write {
419 0 => Poll::Pending,
420 _ => Poll::Ready(Ok(bytes_to_write)),
421 }
422 }
423
424 fn poll_flush(
425 mut self: std::pin::Pin<&mut Self>,
426 cx: &mut std::task::Context<'_>,
427 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
428 self.as_mut().poll_tasks(cx)?;
429
430 match &self.state {
431 UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())),
432 UploadState::CreatingUpload(_)
433 | UploadState::Completing(_)
434 | UploadState::PuttingSingle(_) => Poll::Pending,
435 UploadState::InProgress { futures, .. } => {
436 if futures.is_empty() {
437 Poll::Ready(Ok(()))
438 } else {
439 Poll::Pending
440 }
441 }
442 }
443 }
444
445 fn poll_shutdown(
446 mut self: std::pin::Pin<&mut Self>,
447 cx: &mut std::task::Context<'_>,
448 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
449 loop {
450 self.as_mut().poll_tasks(cx)?;
451
452 let mut_self = &mut *self;
455 match &mut mut_self.state {
456 UploadState::Done(_) => return Poll::Ready(Ok(())),
457 UploadState::CreatingUpload(_)
458 | UploadState::PuttingSingle(_)
459 | UploadState::Completing(_) => return Poll::Pending,
460 UploadState::Started(_) => {
461 let part = std::mem::take(&mut mut_self.buffer);
463 let path = mut_self.path.clone();
464 self.state.started_to_putting_single(path, part);
465 }
466 UploadState::InProgress {
467 upload,
468 futures,
469 part_idx,
470 } => {
471 if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
473 let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
475 futures.spawn(
476 Self::put_part(upload.as_mut(), data, *part_idx, None)
477 .instrument(tracing::Span::current()),
478 );
479 continue;
482 }
483
484 if futures.is_empty() {
486 self.state.in_progress_to_completing();
487 } else {
488 return Poll::Pending;
489 }
490 }
491 }
492 }
493 }
494}
495
496#[async_trait]
497impl Writer for ObjectWriter {
498 async fn tell(&mut self) -> Result<usize> {
499 Ok(self.cursor)
500 }
501}
502
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Checks `tell` and the final size on both the single-put and multipart paths.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        // Small writes never fill the first buffer, so shutdown takes the single-put path.
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let small_chunk = vec![0; 256];
        for written_chunks in 1..=3usize {
            assert_eq!(object_writer.write(small_chunk.as_slice()).await.unwrap(), 256);
            assert_eq!(object_writer.tell().await.unwrap(), 256 * written_chunks);
        }

        let res = object_writer.shutdown().await.unwrap();
        assert_eq!(res.size, 256 * 3);

        // Five chunks of about 3.3 MiB overflow the default 5 MiB first buffer, so this
        // exercises the multipart path.
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar"))
            .await
            .unwrap();
        let big_chunk = vec![0; INITIAL_UPLOAD_STEP / 3 * 2];
        for written_chunks in 1..=5usize {
            object_writer.write_all(big_chunk.as_slice()).await.unwrap();
            assert_eq!(
                object_writer.tell().await.unwrap(),
                written_chunks * big_chunk.len()
            );
        }
        let res = object_writer.shutdown().await.unwrap();
        assert_eq!(res.size, big_chunk.len() * 5);
    }

    /// Aborting a writer that never wrote anything must not fail.
    #[tokio::test]
    async fn test_abort_write() {
        let store = LanceObjectStore::memory();

        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        object_writer.abort().await;
    }
}