1use std::ops::Range;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use bytes::Bytes;
9use deepsize::DeepSizeOf;
10use futures::{
11 future::{BoxFuture, Shared},
12 FutureExt,
13};
14use lance_core::{error::CloneableError, Error, Result};
15use object_store::{path::Path, GetOptions, GetResult, ObjectStore, Result as OSResult};
16use tokio::sync::OnceCell;
17use tracing::instrument;
18
19use crate::{object_store::DEFAULT_CLOUD_IO_PARALLELISM, traits::Reader};
20
/// [`Reader`] implementation backed by an [`ObjectStore`].
///
/// Requests are retried: each individual request (HEAD / GET) gets a few
/// immediate retries, and a failed body download re-issues the whole GET up
/// to `download_retry_count` more times.
#[derive(Debug)]
pub struct CloudObjectReader {
    /// Store the object is read from.
    pub object_store: Arc<dyn ObjectStore>,
    /// Location of the object inside `object_store`.
    pub path: Path,
    /// Cached object size. Pre-populated when the constructor receives a
    /// known size; otherwise filled by a HEAD request on the first `size()`
    /// call and reused afterwards.
    size: OnceCell<usize>,

    /// Value reported by [`Reader::block_size`].
    block_size: usize,
    /// Number of times a failed body download is retried after the initial
    /// attempt.
    download_retry_count: usize,
}
36
37impl DeepSizeOf for CloudObjectReader {
38 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
39 self.path.as_ref().deep_size_of_children(context)
41 }
42}
43
44impl CloudObjectReader {
45 pub fn new(
47 object_store: Arc<dyn ObjectStore>,
48 path: Path,
49 block_size: usize,
50 known_size: Option<usize>,
51 download_retry_count: usize,
52 ) -> Result<Self> {
53 Ok(Self {
54 object_store,
55 path,
56 size: OnceCell::new_with(known_size),
57 block_size,
58 download_retry_count,
59 })
60 }
61
62 async fn do_with_retry<'a, O>(
66 &self,
67 f: impl Fn() -> BoxFuture<'a, OSResult<O>>,
68 ) -> OSResult<O> {
69 let mut retries = 3;
70 loop {
71 match f().await {
72 Ok(val) => return Ok(val),
73 Err(err) => {
74 if retries == 0 {
75 return Err(err);
76 }
77 retries -= 1;
78 }
79 }
80 }
81 }
82
83 async fn do_get_with_outer_retry<'a>(
90 &self,
91 f: impl Fn() -> BoxFuture<'a, OSResult<GetResult>> + Copy,
92 desc: impl Fn() -> String,
93 ) -> OSResult<Bytes> {
94 let mut retries = self.download_retry_count;
95 loop {
96 let get_result = self.do_with_retry(f).await?;
97 match get_result.bytes().await {
98 Ok(bytes) => return Ok(bytes),
99 Err(err) => {
100 if retries == 0 {
101 log::warn!("Failed to download {} from {} after {} attempts. This may indicate that cloud storage is overloaded or your timeout settings are too restrictive. Error details: {:?}", desc(), self.path, self.download_retry_count, err);
102 return Err(err);
103 }
104 log::debug!(
105 "Retrying {} from {} (remaining retries: {}). Error details: {:?}",
106 desc(),
107 self.path,
108 retries,
109 err
110 );
111 retries -= 1;
112 }
113 }
114 }
115 }
116}
117
118#[async_trait]
119impl Reader for CloudObjectReader {
120 fn path(&self) -> &Path {
121 &self.path
122 }
123
124 fn block_size(&self) -> usize {
125 self.block_size
126 }
127
128 fn io_parallelism(&self) -> usize {
129 DEFAULT_CLOUD_IO_PARALLELISM
130 }
131
132 async fn size(&self) -> object_store::Result<usize> {
134 self.size
135 .get_or_try_init(|| async move {
136 let meta = self
137 .do_with_retry(|| self.object_store.head(&self.path))
138 .await?;
139 Ok(meta.size as usize)
140 })
141 .await
142 .cloned()
143 }
144
145 #[instrument(level = "debug", skip(self))]
146 async fn get_range(&self, range: Range<usize>) -> OSResult<Bytes> {
147 self.do_get_with_outer_retry(
148 || {
149 let options = GetOptions {
150 range: Some(
151 Range {
152 start: range.start as u64,
153 end: range.end as u64,
154 }
155 .into(),
156 ),
157 ..Default::default()
158 };
159 self.object_store.get_opts(&self.path, options)
160 },
161 || format!("range {:?}", range),
162 )
163 .await
164 }
165
166 #[instrument(level = "debug", skip_all)]
167 async fn get_all(&self) -> OSResult<Bytes> {
168 self.do_get_with_outer_retry(
169 || {
170 self.object_store
171 .get_opts(&self.path, GetOptions::default())
172 },
173 || "read_all".to_string(),
174 )
175 .await
176 }
177}
178
/// [`Reader`] that downloads the whole object once and serves every read
/// from that in-memory copy. Presumably intended (per its name) for small
/// objects.
///
/// The download is performed lazily on the first read, and concurrent
/// readers share the same in-flight download.
#[derive(Debug)]
pub struct SmallReader {
    /// Location of the object inside the store.
    path: Path,
    /// Object size as supplied by the caller. `size()` returns it as-is; it
    /// is not checked against the downloaded data.
    size: usize,
    /// Download state, shared behind a lock so the first completed read can
    /// cache the outcome for everyone else.
    state: Arc<std::sync::Mutex<SmallReaderState>>,
}
192
/// Download progress of a [`SmallReader`].
enum SmallReaderState {
    /// Download not yet finished. The future is `Shared`, so every waiter
    /// clones and awaits the same download.
    Loading(Shared<BoxFuture<'static, std::result::Result<Bytes, CloneableError>>>),
    /// Download finished. The outcome is cached and returned to all later
    /// callers; an error is cached as well and is not retried.
    Finished(std::result::Result<Bytes, CloneableError>),
}
197
198impl std::fmt::Debug for SmallReaderState {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 match self {
201 Self::Loading(_) => write!(f, "Loading"),
202 Self::Finished(Ok(data)) => {
203 write!(f, "Finished({} bytes)", data.len())
204 }
205 Self::Finished(Err(err)) => {
206 write!(f, "Finished({})", err.0)
207 }
208 }
209 }
210}
211
212impl SmallReader {
213 pub fn new(
214 store: Arc<dyn ObjectStore>,
215 path: Path,
216 download_retry_count: usize,
217 size: usize,
218 ) -> Self {
219 let path_ref = path.clone();
220 let state = SmallReaderState::Loading(
221 Box::pin(async move {
222 let object_reader =
223 CloudObjectReader::new(store, path_ref, 0, None, download_retry_count)
224 .map_err(CloneableError)?;
225 object_reader
226 .get_all()
227 .await
228 .map_err(|err| CloneableError(Error::from(err)))
229 })
230 .boxed()
231 .shared(),
232 );
233 Self {
234 path,
235 size,
236 state: Arc::new(std::sync::Mutex::new(state)),
237 }
238 }
239
240 async fn wait(&self) -> OSResult<Bytes> {
241 let future = {
242 let state = self.state.lock().unwrap();
243 match &*state {
244 SmallReaderState::Loading(future) => future.clone(),
245 SmallReaderState::Finished(result) => {
246 return result.clone().map_err(|err| err.0.into());
247 }
248 }
249 };
250
251 let result = future.await;
252 let result_to_return = result.clone().map_err(|err| err.0.into());
253 let mut state = self.state.lock().unwrap();
254 if matches!(*state, SmallReaderState::Loading(_)) {
255 *state = SmallReaderState::Finished(result);
256 }
257 result_to_return
258 }
259}
260
261#[async_trait]
262impl Reader for SmallReader {
263 fn path(&self) -> &Path {
264 &self.path
265 }
266
267 fn block_size(&self) -> usize {
268 64 * 1024
269 }
270
271 fn io_parallelism(&self) -> usize {
272 1024
273 }
274
275 async fn size(&self) -> OSResult<usize> {
277 Ok(self.size)
278 }
279
280 async fn get_range(&self, range: Range<usize>) -> OSResult<Bytes> {
281 self.wait().await.and_then(|bytes| {
282 let start = range.start;
283 let end = range.end;
284 if start >= bytes.len() || end > bytes.len() {
285 return Err(object_store::Error::Generic {
286 store: "memory",
287 source: format!(
288 "Invalid range {}..{} for object of size {} bytes",
289 start,
290 end,
291 bytes.len()
292 )
293 .into(),
294 });
295 }
296 Ok(bytes.slice(range))
297 })
298 }
299
300 async fn get_all(&self) -> OSResult<Bytes> {
301 self.wait().await
302 }
303}
304
305impl DeepSizeOf for SmallReader {
306 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
307 let mut size = self.path.as_ref().deep_size_of_children(context);
308
309 if let Ok(guard) = self.state.try_lock() {
310 if let SmallReaderState::Finished(Ok(data)) = &*guard {
311 size += data.len();
312 }
313 }
314
315 size
316 }
317}