1use std::fmt;
8use std::sync::Arc;
9
10use arrow_array::cast::{as_dictionary_array, as_primitive_array};
11use arrow_array::types::{
12 ArrowDictionaryKeyType, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type,
13 UInt32Type, UInt64Type,
14};
15use arrow_array::{Array, ArrayRef, DictionaryArray, PrimitiveArray, UInt32Array};
16use arrow_schema::DataType;
17use async_trait::async_trait;
18
19use crate::{
20 ReadBatchParams,
21 traits::{Reader, Writer},
22};
23use lance_core::{Error, Result};
24
25use super::AsyncIndex;
26use super::plain::PlainEncoder;
27use crate::encodings::plain::PlainDecoder;
28use crate::encodings::{Decoder, Encoder};
29
/// Encoder that writes only the keys of dictionary arrays.
///
/// The keys are written with a [`PlainEncoder`]; the dictionary values are not
/// written by this encoder.
pub struct DictionaryEncoder<'a> {
    /// Destination the plain-encoded keys are written to.
    writer: &'a mut dyn Writer,
    /// Arrow type of the dictionary keys (e.g. `Int32`); selects the concrete
    /// key type used when encoding.
    key_type: &'a DataType,
}
35
36impl<'a> DictionaryEncoder<'a> {
37 pub fn new(writer: &'a mut dyn Writer, key_type: &'a DataType) -> Self {
38 Self { writer, key_type }
39 }
40
41 async fn write_typed_array<T: ArrowDictionaryKeyType>(
42 &mut self,
43 arrs: &[&dyn Array],
44 ) -> Result<usize> {
45 assert!(!arrs.is_empty());
46 let data_type = arrs[0].data_type();
47 let pos = self.writer.tell().await?;
48 let mut plain_encoder = PlainEncoder::new(self.writer, data_type);
49
50 let keys = arrs
51 .iter()
52 .map(|a| {
53 let dict_arr = as_dictionary_array::<T>(*a);
54 dict_arr.keys() as &dyn Array
55 })
56 .collect::<Vec<_>>();
57
58 plain_encoder.encode(keys.as_slice()).await?;
59 Ok(pos)
60 }
61}
62
63#[async_trait]
64impl Encoder for DictionaryEncoder<'_> {
65 async fn encode(&mut self, array: &[&dyn Array]) -> Result<usize> {
66 use DataType::*;
67
68 match self.key_type {
69 UInt8 => self.write_typed_array::<UInt8Type>(array).await,
70 UInt16 => self.write_typed_array::<UInt16Type>(array).await,
71 UInt32 => self.write_typed_array::<UInt32Type>(array).await,
72 UInt64 => self.write_typed_array::<UInt64Type>(array).await,
73 Int8 => self.write_typed_array::<Int8Type>(array).await,
74 Int16 => self.write_typed_array::<Int16Type>(array).await,
75 Int32 => self.write_typed_array::<Int32Type>(array).await,
76 Int64 => self.write_typed_array::<Int64Type>(array).await,
77 _ => Err(Error::schema(format!(
78 "DictionaryEncoder: unsupported key type: {:?}",
79 self.key_type
80 ))),
81 }
82 }
83}
84
/// Decoder that reads plain-encoded dictionary keys and pairs them with an
/// in-memory values array to produce `DictionaryArray`s.
pub struct DictionaryDecoder<'a> {
    /// Source the encoded keys are read from.
    reader: &'a dyn Reader,
    /// Position in `reader` where the encoded keys start.
    position: usize,
    /// Number of keys stored at `position`.
    length: usize,
    /// Dictionary data type (`DataType::Dictionary`) of the decoded arrays; its
    /// key type selects how the keys are read.
    data_type: &'a DataType,
    /// Dictionary values shared by every array this decoder produces.
    value_arr: ArrayRef,
}
97
98impl fmt::Debug for DictionaryDecoder<'_> {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 f.debug_struct("DictionaryDecoder")
101 .field("position", &self.position)
102 .field("length", &self.length)
103 .field("data_type", &self.data_type)
104 .field("value_arr", &self.value_arr)
105 .finish()
106 }
107}
108
109impl<'a> DictionaryDecoder<'a> {
110 pub fn new(
111 reader: &'a dyn Reader,
112 position: usize,
113 length: usize,
114 data_type: &'a DataType,
115 value_arr: ArrayRef,
116 ) -> Self {
117 assert!(matches!(data_type, DataType::Dictionary(_, _)));
118 Self {
119 reader,
120 position,
121 length,
122 data_type,
123 value_arr,
124 }
125 }
126
127 async fn decode_impl(&self, params: impl Into<ReadBatchParams>) -> Result<ArrayRef> {
128 let index_type = if let DataType::Dictionary(key_type, _) = &self.data_type {
129 assert!(key_type.as_ref().is_dictionary_key_type());
130 key_type.as_ref()
131 } else {
132 return Err(Error::arrow(format!(
133 "Not a dictionary type: {}",
134 self.data_type
135 )));
136 };
137
138 let decoder = PlainDecoder::new(self.reader, index_type, self.position, self.length)?;
139 let keys = decoder.get(params.into()).await?;
140
141 match index_type {
142 DataType::Int8 => self.make_dict_array::<Int8Type>(keys).await,
143 DataType::Int16 => self.make_dict_array::<Int16Type>(keys).await,
144 DataType::Int32 => self.make_dict_array::<Int32Type>(keys).await,
145 DataType::Int64 => self.make_dict_array::<Int64Type>(keys).await,
146 DataType::UInt8 => self.make_dict_array::<UInt8Type>(keys).await,
147 DataType::UInt16 => self.make_dict_array::<UInt16Type>(keys).await,
148 DataType::UInt32 => self.make_dict_array::<UInt32Type>(keys).await,
149 DataType::UInt64 => self.make_dict_array::<UInt64Type>(keys).await,
150 _ => Err(Error::arrow(format!(
151 "Dictionary encoding does not support index type: {index_type}",
152 ))),
153 }
154 }
155
156 async fn make_dict_array<T: ArrowDictionaryKeyType + Sync + Send>(
157 &self,
158 index_array: ArrayRef,
159 ) -> Result<ArrayRef> {
160 let keys: PrimitiveArray<T> = as_primitive_array(index_array.as_ref()).clone();
161 Ok(Arc::new(DictionaryArray::try_new(
162 keys,
163 self.value_arr.clone(),
164 )?))
165 }
166}
167
168#[async_trait]
169impl Decoder for DictionaryDecoder<'_> {
170 async fn decode(&self) -> Result<ArrayRef> {
171 self.decode_impl(..).await
172 }
173
174 async fn take(&self, indices: &UInt32Array) -> Result<ArrayRef> {
175 self.decode_impl(indices.clone()).await
176 }
177}
178
179#[async_trait]
180impl AsyncIndex<usize> for DictionaryDecoder<'_> {
181 type Output = Result<ArrayRef>;
182
183 async fn get(&self, _index: usize) -> Self::Output {
184 Err(Error::not_supported_source(
185 "DictionaryDecoder does not support get()"
186 .to_string()
187 .into(),
188 ))
189 }
190}
191
192#[async_trait]
193impl AsyncIndex<ReadBatchParams> for DictionaryDecoder<'_> {
194 type Output = Result<ArrayRef>;
195
196 async fn get(&self, params: ReadBatchParams) -> Self::Output {
197 self.decode_impl(params.clone()).await
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    use crate::local::LocalObjectReader;
    use arrow_array::StringArray;
    use arrow_buffer::ArrowNativeType;
    use lance_core::utils::tempfile::TempStdFile;
    use tokio::io::AsyncWriteExt;

    /// Builds a `DictionaryArray<T>` with the given key positions over `values`.
    fn dict_with_keys<T: ArrowDictionaryKeyType>(
        keys: &[usize],
        values: &ArrayRef,
    ) -> DictionaryArray<T> {
        let keys: PrimitiveArray<T> = keys.iter().map(|&k| T::Native::from_usize(k)).collect();
        DictionaryArray::try_new(keys, values.clone()).unwrap()
    }

    /// Round-trips keys through `DictionaryEncoder` and reads them back with
    /// `DictionaryDecoder`, both in full (`decode`) and by index (`take`).
    async fn test_dict_decoder_for_type<T: ArrowDictionaryKeyType>() {
        let value_array: StringArray = vec![Some("a"), Some("b"), Some("c"), Some("d")]
            .into_iter()
            .collect();
        let value_array_ref = Arc::new(value_array) as ArrayRef;

        let arr1 = dict_with_keys::<T>(&[0, 1], &value_array_ref);
        let arr2 = dict_with_keys::<T>(&[1, 3], &value_array_ref);

        let path = TempStdFile::default();

        let pos;
        {
            let mut object_writer = tokio::fs::File::create(&path).await.unwrap();
            let mut encoder =
                DictionaryEncoder::new(&mut object_writer, arr1.keys().data_type());
            let arrs: [&dyn Array; 2] = [&arr1, &arr2];
            pos = encoder.encode(&arrs).await.unwrap();
            AsyncWriteExt::shutdown(&mut object_writer).await.unwrap();
        }

        let reader = LocalObjectReader::open_local_path(&path, 2048, None)
            .await
            .unwrap();
        let decoder = DictionaryDecoder::new(
            reader.as_ref(),
            pos,
            arr1.len() + arr2.len(),
            arr1.data_type(),
            value_array_ref.clone(),
        );

        // Full decode: keys [0, 1, 1, 3] over ["a", "b", "c", "d"].
        let decoded_data = decoder.decode().await.unwrap();
        let expected_data: DictionaryArray<T> = vec!["a", "b", "b", "d"].into_iter().collect();
        assert_eq!(
            &expected_data,
            decoded_data
                .as_any()
                .downcast_ref::<DictionaryArray<T>>()
                .unwrap()
        );

        // Indexed decode: positions 1 and 3 hold keys 1 and 3.
        let taken = decoder.take(&UInt32Array::from(vec![1, 3])).await.unwrap();
        let expected_taken: DictionaryArray<T> = vec!["b", "d"].into_iter().collect();
        assert_eq!(
            &expected_taken,
            taken
                .as_any()
                .downcast_ref::<DictionaryArray<T>>()
                .unwrap()
        );
    }

    #[tokio::test]
    async fn test_dict_decoder() {
        test_dict_decoder_for_type::<Int8Type>().await;
        test_dict_decoder_for_type::<Int16Type>().await;
        test_dict_decoder_for_type::<Int32Type>().await;
        test_dict_decoder_for_type::<Int64Type>().await;

        test_dict_decoder_for_type::<UInt8Type>().await;
        test_dict_decoder_for_type::<UInt16Type>().await;
        test_dict_decoder_for_type::<UInt32Type>().await;
        test_dict_decoder_for_type::<UInt64Type>().await;
    }
}