1use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use futures::FutureExt;
14use object_store::MultipartUpload;
15use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21use tracing::Instrument;
22
23use crate::traits::Writer;
24use snafu::location;
25
/// Smallest upload buffer size in bytes (5 MiB).
///
/// Used as the default and lower bound for the initial upload size, and as the
/// increment by which part buffers grow when parts are not constant-sized.
const INITIAL_UPLOAD_STEP: usize = 5 * 1024 * 1024;
28
/// Returns the maximum number of part uploads that may be in flight at once.
///
/// The value is read once from the `LANCE_UPLOAD_CONCURRENCY` environment
/// variable and cached for the rest of the process. If the variable is unset
/// or does not parse as a `usize`, the default of 10 is used.
///
/// A configured value of `0` is raised to `1`. The writer only starts a new
/// part upload while the number of in-flight uploads is strictly below this
/// limit, so a limit of zero would mean buffered data is never uploaded.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            // Clamp rather than reject: 1 is the closest workable setting to
            // "no parallelism".
            .map_or(10, |limit| limit.max(1))
    })
}
38
/// Returns how many times a part upload may be retried after a connection
/// reset before the writer gives up.
///
/// Read once from the `LANCE_CONN_RESET_RETRIES` environment variable and
/// cached; falls back to 20 when the variable is unset or is not a valid `u16`.
fn max_conn_reset_retries() -> u16 {
    static RETRY_LIMIT: OnceLock<u16> = OnceLock::new();
    let limit = RETRY_LIMIT.get_or_init(|| match std::env::var("LANCE_CONN_RESET_RETRIES") {
        Ok(raw) => raw.parse::<u16>().unwrap_or(20),
        Err(_) => 20,
    });
    *limit
}
48
49fn initial_upload_size() -> usize {
50 static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
51 *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
52 std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
53 .ok()
54 .and_then(|s| s.parse::<usize>().ok())
55 .inspect(|size| {
56 if *size < INITIAL_UPLOAD_STEP {
57 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
59 } else if *size > 1024 * 1024 * 1024 * 5 {
60 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
62 }
63 })
64 .unwrap_or(INITIAL_UPLOAD_STEP)
65 })
66}
67
68pub struct ObjectWriter {
76 state: UploadState,
77 path: Arc<Path>,
78 cursor: usize,
79 connection_resets: u16,
80 buffer: Vec<u8>,
81 use_constant_size_upload_parts: bool,
83}
84
85enum UploadState {
86 Started(Arc<dyn ObjectStore>),
89 CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
91 InProgress {
93 part_idx: u16,
94 upload: Box<dyn MultipartUpload>,
95 futures: JoinSet<std::result::Result<(), UploadPutError>>,
96 },
97 PuttingSingle(BoxFuture<'static, OSResult<()>>),
100 Completing(BoxFuture<'static, OSResult<()>>),
102 Done,
104}
105
106impl UploadState {
108 fn started_to_completing(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
109 let this = std::mem::replace(self, Self::Done);
111 *self = match this {
112 Self::Started(store) => {
113 let fut = async move {
114 store.put(&path, buffer.into()).await?;
115 Ok(())
116 };
117 Self::PuttingSingle(Box::pin(fut))
118 }
119 _ => unreachable!(),
120 }
121 }
122
123 fn in_progress_to_completing(&mut self) {
124 let this = std::mem::replace(self, Self::Done);
126 *self = match this {
127 Self::InProgress {
128 mut upload,
129 futures,
130 ..
131 } => {
132 debug_assert!(futures.is_empty());
133 let fut = async move {
134 upload.complete().await?;
135 Ok(())
136 };
137 Self::Completing(Box::pin(fut))
138 }
139 _ => unreachable!(),
140 };
141 }
142}
143
144impl ObjectWriter {
145 pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
146 Ok(Self {
147 state: UploadState::Started(object_store.inner.clone()),
148 cursor: 0,
149 path: Arc::new(path.clone()),
150 connection_resets: 0,
151 buffer: Vec::with_capacity(initial_upload_size()),
152 use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
153 })
154 }
155
156 fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
159 let new_capacity = if constant_upload_size {
160 initial_upload_size()
162 } else {
163 initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
165 };
166 let new_buffer = Vec::with_capacity(new_capacity);
167 let part = std::mem::replace(buffer, new_buffer);
168 Bytes::from(part)
169 }
170
171 fn put_part(
172 upload: &mut dyn MultipartUpload,
173 buffer: Bytes,
174 part_idx: u16,
175 sleep: Option<std::time::Duration>,
176 ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
177 log::debug!(
178 "MultipartUpload submitting part with {} bytes",
179 buffer.len()
180 );
181 let fut = upload.put_part(buffer.clone().into());
182 Box::pin(async move {
183 if let Some(sleep) = sleep {
184 tokio::time::sleep(sleep).await;
185 }
186 fut.await.map_err(|source| UploadPutError {
187 part_idx,
188 buffer,
189 source,
190 })?;
191 Ok(())
192 })
193 }
194
195 fn poll_tasks(
196 mut self: Pin<&mut Self>,
197 cx: &mut std::task::Context<'_>,
198 ) -> std::result::Result<(), io::Error> {
199 let mut_self = &mut *self;
200 loop {
201 match &mut mut_self.state {
202 UploadState::Started(_) | UploadState::Done => break,
203 UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
204 Poll::Ready(Ok(mut upload)) => {
205 let mut futures = JoinSet::new();
206
207 let data = Self::next_part_buffer(
208 &mut mut_self.buffer,
209 0,
210 mut_self.use_constant_size_upload_parts,
211 );
212 futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
213
214 mut_self.state = UploadState::InProgress {
215 part_idx: 1, futures,
217 upload,
218 };
219 }
220 Poll::Ready(Err(e)) => {
221 return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
222 }
223 Poll::Pending => break,
224 },
225 UploadState::InProgress {
226 upload, futures, ..
227 } => {
228 while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
229 match res {
230 Ok(Ok(())) => {}
231 Err(err) => {
232 return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
233 }
234 Ok(Err(UploadPutError {
235 source: OSError::Generic { source, .. },
236 part_idx,
237 buffer,
238 })) if source
239 .to_string()
240 .to_lowercase()
241 .contains("connection reset by peer") =>
242 {
243 if mut_self.connection_resets < max_conn_reset_retries() {
244 mut_self.connection_resets += 1;
246
247 let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000);
249 let sleep_time =
250 std::time::Duration::from_millis(sleep_time_ms);
251
252 futures.spawn(Self::put_part(
253 upload.as_mut(),
254 buffer,
255 part_idx,
256 Some(sleep_time),
257 ));
258 } else {
259 return Err(io::Error::new(
260 io::ErrorKind::ConnectionReset,
261 Box::new(ConnectionResetError {
262 message: format!(
263 "Hit max retries ({}) for connection reset",
264 max_conn_reset_retries()
265 ),
266 source,
267 }),
268 ));
269 }
270 }
271 Ok(Err(err)) => return Err(err.source.into()),
272 }
273 }
274 break;
275 }
276 UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
277 match fut.poll_unpin(cx) {
278 Poll::Ready(Ok(())) => mut_self.state = UploadState::Done,
279 Poll::Ready(Err(e)) => {
280 return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
281 }
282 Poll::Pending => break,
283 }
284 }
285 }
286 }
287 Ok(())
288 }
289
290 pub async fn shutdown(&mut self) -> Result<()> {
291 AsyncWriteExt::shutdown(self).await.map_err(|e| {
292 Error::io(
293 format!("failed to shutdown object writer for {}: {}", self.path, e),
294 location!(),
296 )
297 })
298 }
299}
300
301impl Drop for ObjectWriter {
302 fn drop(&mut self) {
303 if matches!(self.state, UploadState::InProgress { .. }) {
305 let state = std::mem::replace(&mut self.state, UploadState::Done);
307 if let UploadState::InProgress { mut upload, .. } = state {
308 tokio::task::spawn(async move {
309 let _ = upload.abort().await;
310 });
311 }
312 }
313 }
314}
315
316struct UploadPutError {
320 part_idx: u16,
321 buffer: Bytes,
322 source: OSError,
323}
324
/// Error reported when a part upload keeps failing with connection resets.
#[derive(Debug)]
struct ConnectionResetError {
    /// Human-readable description of the failure.
    message: String,
    /// Underlying error that caused the last reset.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::error::Error for ConnectionResetError {}

impl std::fmt::Display for ConnectionResetError {
    /// Renders the error as `<message>: <source>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message, source } = self;
        write!(f, "{}: {}", message, source)
    }
}
338
339impl AsyncWrite for ObjectWriter {
340 fn poll_write(
341 mut self: std::pin::Pin<&mut Self>,
342 cx: &mut std::task::Context<'_>,
343 buf: &[u8],
344 ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
345 self.as_mut().poll_tasks(cx)?;
346
347 let remaining_capacity = self.buffer.capacity() - self.buffer.len();
349 let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
350 self.buffer.extend_from_slice(&buf[..bytes_to_write]);
351 self.cursor += bytes_to_write;
352
353 let mut_self = &mut *self;
356
357 if mut_self.buffer.capacity() == mut_self.buffer.len() {
359 match &mut mut_self.state {
360 UploadState::Started(store) => {
361 let path = mut_self.path.clone();
362 let store = store.clone();
363 let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
364 self.state = UploadState::CreatingUpload(fut);
365 }
366 UploadState::InProgress {
367 upload,
368 part_idx,
369 futures,
370 ..
371 } => {
372 if futures.len() < max_upload_parallelism() {
374 let data = Self::next_part_buffer(
375 &mut mut_self.buffer,
376 *part_idx,
377 mut_self.use_constant_size_upload_parts,
378 );
379 futures.spawn(
380 Self::put_part(upload.as_mut(), data, *part_idx, None)
381 .instrument(tracing::Span::current()),
382 );
383 *part_idx += 1;
384 }
385 }
386 _ => {}
387 }
388 }
389
390 self.poll_tasks(cx)?;
391
392 match bytes_to_write {
393 0 => Poll::Pending,
394 _ => Poll::Ready(Ok(bytes_to_write)),
395 }
396 }
397
398 fn poll_flush(
399 mut self: std::pin::Pin<&mut Self>,
400 cx: &mut std::task::Context<'_>,
401 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
402 self.as_mut().poll_tasks(cx)?;
403
404 match &self.state {
405 UploadState::Started(_) | UploadState::Done => Poll::Ready(Ok(())),
406 UploadState::CreatingUpload(_)
407 | UploadState::Completing(_)
408 | UploadState::PuttingSingle(_) => Poll::Pending,
409 UploadState::InProgress { futures, .. } => {
410 if futures.is_empty() {
411 Poll::Ready(Ok(()))
412 } else {
413 Poll::Pending
414 }
415 }
416 }
417 }
418
419 fn poll_shutdown(
420 mut self: std::pin::Pin<&mut Self>,
421 cx: &mut std::task::Context<'_>,
422 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
423 loop {
424 self.as_mut().poll_tasks(cx)?;
425
426 let mut_self = &mut *self;
429 match &mut mut_self.state {
430 UploadState::Done => return Poll::Ready(Ok(())),
431 UploadState::CreatingUpload(_)
432 | UploadState::PuttingSingle(_)
433 | UploadState::Completing(_) => return Poll::Pending,
434 UploadState::Started(_) => {
435 let part = std::mem::take(&mut mut_self.buffer);
437 let path = mut_self.path.clone();
438 self.state.started_to_completing(path, part);
439 }
440 UploadState::InProgress {
441 upload,
442 futures,
443 part_idx,
444 } => {
445 if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
447 let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
449 futures.spawn(
450 Self::put_part(upload.as_mut(), data, *part_idx, None)
451 .instrument(tracing::Span::current()),
452 );
453 continue;
456 }
457
458 if futures.is_empty() {
460 self.state.in_progress_to_completing();
461 } else {
462 return Poll::Pending;
463 }
464 }
465 }
466 }
467 }
468}
469
470#[async_trait]
471impl Writer for ObjectWriter {
472 async fn tell(&mut self) -> Result<usize> {
473 Ok(self.cursor)
474 }
475}
476
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Three small writes each advance the cursor by the bytes written, and
    /// the writer then shuts down without error.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();
        let path = Path::from("/foo");

        let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let buf = vec![0; 256];
        for expected_cursor in [256, 512, 256 * 3] {
            assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
            assert_eq!(object_writer.tell().await.unwrap(), expected_cursor);
        }

        object_writer.shutdown().await.unwrap();
    }
}