Skip to main content

lance_io/encodings/
dictionary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Dictionary encoding.
5//!
6
7use std::fmt;
8use std::sync::Arc;
9
10use arrow_array::cast::{as_dictionary_array, as_primitive_array};
11use arrow_array::types::{
12    ArrowDictionaryKeyType, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type,
13    UInt32Type, UInt64Type,
14};
15use arrow_array::{Array, ArrayRef, DictionaryArray, PrimitiveArray, UInt32Array};
16use arrow_schema::DataType;
17use async_trait::async_trait;
18
19use crate::{
20    ReadBatchParams,
21    traits::{Reader, Writer},
22};
23use lance_core::{Error, Result};
24
25use super::AsyncIndex;
26use super::plain::PlainEncoder;
27use crate::encodings::plain::PlainDecoder;
28use crate::encodings::{Decoder, Encoder};
29
30/// Encoder for Dictionary encoding.
31pub struct DictionaryEncoder<'a> {
32    writer: &'a mut dyn Writer,
33    key_type: &'a DataType,
34}
35
36impl<'a> DictionaryEncoder<'a> {
37    pub fn new(writer: &'a mut dyn Writer, key_type: &'a DataType) -> Self {
38        Self { writer, key_type }
39    }
40
41    async fn write_typed_array<T: ArrowDictionaryKeyType>(
42        &mut self,
43        arrs: &[&dyn Array],
44    ) -> Result<usize> {
45        assert!(!arrs.is_empty());
46        let data_type = arrs[0].data_type();
47        let pos = self.writer.tell().await?;
48        let mut plain_encoder = PlainEncoder::new(self.writer, data_type);
49
50        let keys = arrs
51            .iter()
52            .map(|a| {
53                let dict_arr = as_dictionary_array::<T>(*a);
54                dict_arr.keys() as &dyn Array
55            })
56            .collect::<Vec<_>>();
57
58        plain_encoder.encode(keys.as_slice()).await?;
59        Ok(pos)
60    }
61}
62
63#[async_trait]
64impl Encoder for DictionaryEncoder<'_> {
65    async fn encode(&mut self, array: &[&dyn Array]) -> Result<usize> {
66        use DataType::*;
67
68        match self.key_type {
69            UInt8 => self.write_typed_array::<UInt8Type>(array).await,
70            UInt16 => self.write_typed_array::<UInt16Type>(array).await,
71            UInt32 => self.write_typed_array::<UInt32Type>(array).await,
72            UInt64 => self.write_typed_array::<UInt64Type>(array).await,
73            Int8 => self.write_typed_array::<Int8Type>(array).await,
74            Int16 => self.write_typed_array::<Int16Type>(array).await,
75            Int32 => self.write_typed_array::<Int32Type>(array).await,
76            Int64 => self.write_typed_array::<Int64Type>(array).await,
77            _ => Err(Error::schema(format!(
78                "DictionaryEncoder: unsupported key type: {:?}",
79                self.key_type
80            ))),
81        }
82    }
83}
84
85/// Decoder for Dictionary encoding.
86pub struct DictionaryDecoder<'a> {
87    reader: &'a dyn Reader,
88    /// The start position of the key array in the file.
89    position: usize,
90    /// Number of the rows in this batch.
91    length: usize,
92    /// The dictionary data type
93    data_type: &'a DataType,
94    /// Value array,
95    value_arr: ArrayRef,
96}
97
98impl fmt::Debug for DictionaryDecoder<'_> {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        f.debug_struct("DictionaryDecoder")
101            .field("position", &self.position)
102            .field("length", &self.length)
103            .field("data_type", &self.data_type)
104            .field("value_arr", &self.value_arr)
105            .finish()
106    }
107}
108
109impl<'a> DictionaryDecoder<'a> {
110    pub fn new(
111        reader: &'a dyn Reader,
112        position: usize,
113        length: usize,
114        data_type: &'a DataType,
115        value_arr: ArrayRef,
116    ) -> Self {
117        assert!(matches!(data_type, DataType::Dictionary(_, _)));
118        Self {
119            reader,
120            position,
121            length,
122            data_type,
123            value_arr,
124        }
125    }
126
127    async fn decode_impl(&self, params: impl Into<ReadBatchParams>) -> Result<ArrayRef> {
128        let index_type = if let DataType::Dictionary(key_type, _) = &self.data_type {
129            assert!(key_type.as_ref().is_dictionary_key_type());
130            key_type.as_ref()
131        } else {
132            return Err(Error::arrow(format!(
133                "Not a dictionary type: {}",
134                self.data_type
135            )));
136        };
137
138        let decoder = PlainDecoder::new(self.reader, index_type, self.position, self.length)?;
139        let keys = decoder.get(params.into()).await?;
140
141        match index_type {
142            DataType::Int8 => self.make_dict_array::<Int8Type>(keys).await,
143            DataType::Int16 => self.make_dict_array::<Int16Type>(keys).await,
144            DataType::Int32 => self.make_dict_array::<Int32Type>(keys).await,
145            DataType::Int64 => self.make_dict_array::<Int64Type>(keys).await,
146            DataType::UInt8 => self.make_dict_array::<UInt8Type>(keys).await,
147            DataType::UInt16 => self.make_dict_array::<UInt16Type>(keys).await,
148            DataType::UInt32 => self.make_dict_array::<UInt32Type>(keys).await,
149            DataType::UInt64 => self.make_dict_array::<UInt64Type>(keys).await,
150            _ => Err(Error::arrow(format!(
151                "Dictionary encoding does not support index type: {index_type}",
152            ))),
153        }
154    }
155
156    async fn make_dict_array<T: ArrowDictionaryKeyType + Sync + Send>(
157        &self,
158        index_array: ArrayRef,
159    ) -> Result<ArrayRef> {
160        let keys: PrimitiveArray<T> = as_primitive_array(index_array.as_ref()).clone();
161        Ok(Arc::new(DictionaryArray::try_new(
162            keys,
163            self.value_arr.clone(),
164        )?))
165    }
166}
167
168#[async_trait]
169impl Decoder for DictionaryDecoder<'_> {
170    async fn decode(&self) -> Result<ArrayRef> {
171        self.decode_impl(..).await
172    }
173
174    async fn take(&self, indices: &UInt32Array) -> Result<ArrayRef> {
175        self.decode_impl(indices.clone()).await
176    }
177}
178
179#[async_trait]
180impl AsyncIndex<usize> for DictionaryDecoder<'_> {
181    type Output = Result<ArrayRef>;
182
183    async fn get(&self, _index: usize) -> Self::Output {
184        Err(Error::not_supported_source(
185            "DictionaryDecoder does not support get()"
186                .to_string()
187                .into(),
188        ))
189    }
190}
191
192#[async_trait]
193impl AsyncIndex<ReadBatchParams> for DictionaryDecoder<'_> {
194    type Output = Result<ArrayRef>;
195
196    async fn get(&self, params: ReadBatchParams) -> Self::Output {
197        self.decode_impl(params.clone()).await
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    use crate::local::LocalObjectReader;
    use arrow_array::StringArray;
    use arrow_buffer::ArrowNativeType;
    use lance_core::utils::tempfile::TempStdFile;
    use tokio::io::AsyncWriteExt;

    /// Writes the keys of two dictionary arrays with the plain encoder, then
    /// checks that `DictionaryDecoder` reads them back as one dictionary array.
    async fn test_dict_decoder_for_type<T: ArrowDictionaryKeyType>() {
        let values: StringArray = vec![Some("a"), Some("b"), Some("c"), Some("d")]
            .into_iter()
            .collect();
        let values: ArrayRef = Arc::new(values);

        // Builds a key array of type `T` from plain indices.
        let make_keys = |indices: &[usize]| -> PrimitiveArray<T> {
            indices
                .iter()
                .map(|&idx| T::Native::from_usize(idx))
                .collect()
        };

        let first: DictionaryArray<T> =
            DictionaryArray::try_new(make_keys(&[0, 1]), values.clone()).unwrap();
        let second: DictionaryArray<T> =
            DictionaryArray::try_new(make_keys(&[1, 3]), values.clone()).unwrap();

        let key_arrays: Vec<&dyn Array> = vec![
            first.keys() as &dyn Array,
            second.keys() as &dyn Array,
        ];

        let path = TempStdFile::default();

        // Write both key arrays back to back and remember where they start.
        let pos = {
            let mut file = tokio::fs::File::create(&path).await.unwrap();
            let mut encoder = PlainEncoder::new(&mut file, first.keys().data_type());
            let written_at = encoder.encode(key_arrays.as_slice()).await.unwrap();
            AsyncWriteExt::shutdown(&mut file).await.unwrap();
            written_at
        };

        let reader = LocalObjectReader::open_local_path(&path, 2048, None)
            .await
            .unwrap();
        let total_rows = first.len() + second.len();
        let decoder = DictionaryDecoder::new(
            reader.as_ref(),
            pos,
            total_rows,
            first.data_type(),
            values.clone(),
        );

        let decoded = decoder.decode().await.unwrap();
        let actual = decoded
            .as_any()
            .downcast_ref::<DictionaryArray<T>>()
            .unwrap();
        let expected: DictionaryArray<T> = vec!["a", "b", "b", "d"].into_iter().collect();
        assert_eq!(&expected, actual);
    }

    /// Runs the round-trip check for every signed and unsigned integer key type.
    #[tokio::test]
    async fn test_dict_decoder() {
        test_dict_decoder_for_type::<Int8Type>().await;
        test_dict_decoder_for_type::<Int16Type>().await;
        test_dict_decoder_for_type::<Int32Type>().await;
        test_dict_decoder_for_type::<Int64Type>().await;

        test_dict_decoder_for_type::<UInt8Type>().await;
        test_dict_decoder_for_type::<UInt16Type>().await;
        test_dict_decoder_for_type::<UInt32Type>().await;
        test_dict_decoder_for_type::<UInt64Type>().await;
    }
}