lance_io/
utils.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{cmp::min, num::NonZero, sync::atomic::AtomicU64};
5
6use arrow_array::{
7    types::{BinaryType, LargeBinaryType, LargeUtf8Type, Utf8Type},
8    ArrayRef,
9};
10use arrow_schema::DataType;
11use byteorder::{ByteOrder, LittleEndian};
12use bytes::Bytes;
13use deepsize::DeepSizeOf;
14use lance_arrow::*;
15use prost::Message;
16use serde::{Deserialize, Serialize};
17use snafu::location;
18
19use crate::{
20    encodings::{binary::BinaryDecoder, plain::PlainDecoder, AsyncIndex, Decoder},
21    traits::ProtoStruct,
22};
23use crate::{traits::Reader, ReadBatchParams};
24use lance_core::{Error, Result};
25
26/// Read a binary array from a [Reader].
27///
28pub async fn read_binary_array(
29    reader: &dyn Reader,
30    data_type: &DataType,
31    nullable: bool,
32    position: usize,
33    length: usize,
34    params: impl Into<ReadBatchParams>,
35) -> Result<ArrayRef> {
36    use arrow_schema::DataType::*;
37    let decoder: Box<dyn Decoder<Output = Result<ArrayRef>> + Send> = match data_type {
38        Utf8 => Box::new(BinaryDecoder::<Utf8Type>::new(
39            reader, position, length, nullable,
40        )),
41        Binary => Box::new(BinaryDecoder::<BinaryType>::new(
42            reader, position, length, nullable,
43        )),
44        LargeUtf8 => Box::new(BinaryDecoder::<LargeUtf8Type>::new(
45            reader, position, length, nullable,
46        )),
47        LargeBinary => Box::new(BinaryDecoder::<LargeBinaryType>::new(
48            reader, position, length, nullable,
49        )),
50        _ => {
51            return Err(Error::io(
52                format!("Unsupported binary type: {}", data_type),
53                location!(),
54            ));
55        }
56    };
57    let fut = decoder.as_ref().get(params.into());
58    fut.await
59}
60
61/// Read a fixed stride array from disk.
62///
63pub async fn read_fixed_stride_array(
64    reader: &dyn Reader,
65    data_type: &DataType,
66    position: usize,
67    length: usize,
68    params: impl Into<ReadBatchParams>,
69) -> Result<ArrayRef> {
70    if !data_type.is_fixed_stride() {
71        return Err(Error::Schema {
72            message: format!("{data_type} is not a fixed stride type"),
73            location: location!(),
74        });
75    }
76    // TODO: support more than plain encoding here.
77    let decoder = PlainDecoder::new(reader, data_type, position, length)?;
78    decoder.get(params.into()).await
79}
80
81/// Read a protobuf message at file position 'pos'.
82///
83/// We write protobuf by first writing the length of the message as a u32,
84/// followed by the message itself.
85pub async fn read_message<M: Message + Default>(reader: &dyn Reader, pos: usize) -> Result<M> {
86    let file_size = reader.size().await?;
87    if pos > file_size {
88        return Err(Error::io("file size is too small".to_string(), location!()));
89    }
90
91    let range = pos..min(pos + reader.block_size(), file_size);
92    let buf = reader.get_range(range.clone()).await?;
93    let msg_len = LittleEndian::read_u32(&buf) as usize;
94
95    if msg_len + 4 > buf.len() {
96        let remaining_range = range.end..min(4 + pos + msg_len, file_size);
97        let remaining_bytes = reader.get_range(remaining_range).await?;
98        let buf = [buf, remaining_bytes].concat();
99        assert!(buf.len() >= msg_len + 4);
100        Ok(M::decode(&buf[4..4 + msg_len])?)
101    } else {
102        Ok(M::decode(&buf[4..4 + msg_len])?)
103    }
104}
105
106/// Read a Protobuf-backed struct at file position: `pos`.
107// TODO: pub(crate)
108pub async fn read_struct<
109    M: Message + Default + 'static,
110    T: ProtoStruct<Proto = M> + TryFrom<M, Error = Error>,
111>(
112    reader: &dyn Reader,
113    pos: usize,
114) -> Result<T> {
115    let msg = read_message::<M>(reader, pos).await?;
116    T::try_from(msg)
117}
118
119pub async fn read_last_block(reader: &dyn Reader) -> object_store::Result<Bytes> {
120    let file_size = reader.size().await?;
121    let block_size = reader.block_size();
122    let begin = file_size.saturating_sub(block_size);
123    reader.get_range(begin..file_size).await
124}
125
126pub fn read_metadata_offset(bytes: &Bytes) -> Result<usize> {
127    let len = bytes.len();
128    if len < 16 {
129        return Err(Error::io(
130            format!(
131                "does not have sufficient data, len: {}, bytes: {:?}",
132                len, bytes
133            ),
134            location!(),
135        ));
136    }
137    let offset_bytes = bytes.slice(len - 16..len - 8);
138    Ok(LittleEndian::read_u64(offset_bytes.as_ref()) as usize)
139}
140
141/// Read the version from the footer bytes
142pub fn read_version(bytes: &Bytes) -> Result<(u16, u16)> {
143    let len = bytes.len();
144    if len < 8 {
145        return Err(Error::io(
146            format!(
147                "does not have sufficient data, len: {}, bytes: {:?}",
148                len, bytes
149            ),
150            location!(),
151        ));
152    }
153
154    let major_version = LittleEndian::read_u16(bytes.slice(len - 8..len - 6).as_ref());
155    let minor_version = LittleEndian::read_u16(bytes.slice(len - 6..len - 4).as_ref());
156    Ok((major_version, minor_version))
157}
158
159/// Read protobuf from a buffer.
160pub fn read_message_from_buf<M: Message + Default>(buf: &Bytes) -> Result<M> {
161    let msg_len = LittleEndian::read_u32(buf) as usize;
162    Ok(M::decode(&buf[4..4 + msg_len])?)
163}
164
165/// Read a Protobuf-backed struct from a buffer.
166pub fn read_struct_from_buf<
167    M: Message + Default,
168    T: ProtoStruct<Proto = M> + TryFrom<M, Error = Error>,
169>(
170    buf: &Bytes,
171) -> Result<T> {
172    let msg: M = read_message_from_buf(buf)?;
173    T::try_from(msg)
174}
175
/// A cached file size.
///
/// This wraps an atomic u64 to allow setting the cached file size without
/// needing a mutable reference.
///
/// Zero is interpreted as unknown. The accessors and trait impls in this file
/// load and store the value with `Ordering::Relaxed`, so the type caches a
/// number only and is not meant to synchronize other memory accesses.
#[derive(Debug, DeepSizeOf)]
pub struct CachedFileSize(AtomicU64);
184
185impl<'de> Deserialize<'de> for CachedFileSize {
186    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
187    where
188        D: serde::Deserializer<'de>,
189    {
190        let size = Option::<u64>::deserialize(deserializer)?.unwrap_or(0);
191        Ok(Self::new(size))
192    }
193}
194
195impl Serialize for CachedFileSize {
196    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
197    where
198        S: serde::Serializer,
199    {
200        let size = self.0.load(std::sync::atomic::Ordering::Relaxed);
201        if size == 0 {
202            serializer.serialize_none()
203        } else {
204            serializer.serialize_u64(size)
205        }
206    }
207}
208
209impl From<Option<NonZero<u64>>> for CachedFileSize {
210    fn from(size: Option<NonZero<u64>>) -> Self {
211        match size {
212            Some(size) => Self(AtomicU64::new(size.into())),
213            None => Self(AtomicU64::new(0)),
214        }
215    }
216}
217
218impl Default for CachedFileSize {
219    fn default() -> Self {
220        Self(AtomicU64::new(0))
221    }
222}
223
224impl Clone for CachedFileSize {
225    fn clone(&self) -> Self {
226        Self(AtomicU64::new(
227            self.0.load(std::sync::atomic::Ordering::Relaxed),
228        ))
229    }
230}
231
232impl PartialEq for CachedFileSize {
233    fn eq(&self, other: &Self) -> bool {
234        self.0.load(std::sync::atomic::Ordering::Relaxed)
235            == other.0.load(std::sync::atomic::Ordering::Relaxed)
236    }
237}
238
239impl Eq for CachedFileSize {}
240
241impl CachedFileSize {
242    pub fn new(size: u64) -> Self {
243        Self(AtomicU64::new(size))
244    }
245
246    pub fn unknown() -> Self {
247        Self(AtomicU64::new(0))
248    }
249
250    pub fn get(&self) -> Option<NonZero<u64>> {
251        NonZero::new(self.0.load(std::sync::atomic::Ordering::Relaxed))
252    }
253
254    pub fn set(&self, size: NonZero<u64>) {
255        self.0
256            .store(size.into(), std::sync::atomic::Ordering::Relaxed);
257    }
258}
259
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use object_store::path::Path;

    use crate::{
        object_reader::CloudObjectReader,
        object_store::{ObjectStore, DEFAULT_DOWNLOAD_RETRY_COUNT},
        object_writer::ObjectWriter,
        traits::{ProtoStruct, WriteExt, Writer},
        utils::read_struct,
        Error, Result,
    };

    // Bytes implements prost::Message. Since there are no .proto files in this
    // crate, we use it to simulate a real message object.
    #[derive(Debug, PartialEq)]
    struct BytesWrapper(Bytes);

    impl ProtoStruct for BytesWrapper {
        type Proto = Bytes;
    }

    impl From<&BytesWrapper> for Bytes {
        fn from(value: &BytesWrapper) -> Self {
            value.0.clone()
        }
    }

    impl TryFrom<Bytes> for BytesWrapper {
        type Error = Error;
        fn try_from(value: Bytes) -> Result<Self> {
            Ok(Self(value))
        }
    }

    /// Writes `message` to a fresh in-memory object at position 0, then reads it
    /// back through a [CloudObjectReader] that uses the given `block_size`.
    async fn write_then_read(message: &BytesWrapper, block_size: usize) -> BytesWrapper {
        let store = ObjectStore::memory();
        let path = Path::from("/foo");

        let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);

        let pos = object_writer.write_struct(message).await.unwrap();
        assert_eq!(pos, 0);
        object_writer.shutdown().await.unwrap();

        let object_reader = CloudObjectReader::new(
            store.inner,
            path,
            block_size,
            None,
            DEFAULT_DOWNLOAD_RETRY_COUNT,
        )
        .unwrap();
        read_struct(&object_reader, pos).await.unwrap()
    }

    #[tokio::test]
    async fn test_write_proto_structs() {
        let some_message = BytesWrapper(Bytes::from(vec![10, 20, 30]));
        let actual = write_then_read(&some_message, 1024).await;
        assert_eq!(some_message, actual);
    }

    /// A message larger than one block forces `read_message` to issue a
    /// second range request for the remainder.
    #[tokio::test]
    async fn test_read_struct_spanning_blocks() {
        let payload: Vec<u8> = (0..=255u8).collect();
        let some_message = BytesWrapper(Bytes::from(payload));
        let actual = write_then_read(&some_message, 16).await;
        assert_eq!(some_message, actual);
    }
}