1use std::collections::{HashMap, HashSet};
12use std::sync::Arc;
13
14use ailake_catalog::{read_equality_delete_values, EqualityDeleteFile};
15use ailake_core::{AilakeError, AilakeResult};
16use ailake_store::Store;
17use arrow_array::{
18 Array, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch,
19 StringArray,
20};
21use arrow_schema::DataType;
22
/// Row filter that drops rows whose column values appear in a delete set.
///
/// For each column name the filter stores the string-rendered values that
/// mark a row as deleted. A filter with no entries deletes nothing.
#[derive(Debug, Clone, Default)]
pub struct EqualityDeleteFilter {
    /// Column name mapped to the set of string-rendered values to delete.
    filters: HashMap<String, HashSet<String>>,
}
32
33impl EqualityDeleteFilter {
34 pub async fn from_files(
39 store: &Arc<dyn Store>,
40 files: &[EqualityDeleteFile],
41 ) -> AilakeResult<Self> {
42 let mut filters: HashMap<String, HashSet<String>> = HashMap::new();
43 for edf in files {
44 let bytes = store.get(&edf.path).await?;
45 let pairs = read_equality_delete_values(&bytes)
46 .map_err(|e| AilakeError::Catalog(e.to_string()))?;
47 for (col, val) in pairs {
48 filters.entry(col).or_default().insert(val);
49 }
50 }
51 Ok(Self { filters })
52 }
53
54 pub fn empty() -> Self {
55 Self {
56 filters: HashMap::new(),
57 }
58 }
59
60 pub fn is_empty(&self) -> bool {
61 self.filters.is_empty()
62 }
63
64 pub fn should_delete_row(&self, batch: &RecordBatch, row_idx: usize) -> bool {
70 if self.filters.is_empty() {
71 return false;
72 }
73 for (col_name, delete_values) in &self.filters {
74 let col_idx = match batch.schema().index_of(col_name.as_str()) {
75 Ok(i) => i,
76 Err(_) => continue,
77 };
78 let array = batch.column(col_idx);
79 if array.is_null(row_idx) {
80 continue;
81 }
82 let val_str: Option<String> = match array.data_type() {
83 DataType::Utf8 => array
84 .as_any()
85 .downcast_ref::<StringArray>()
86 .map(|a| a.value(row_idx).to_string()),
87 DataType::LargeUtf8 => array
88 .as_any()
89 .downcast_ref::<arrow_array::LargeStringArray>()
90 .map(|a| a.value(row_idx).to_string()),
91 DataType::Int32 => array
92 .as_any()
93 .downcast_ref::<Int32Array>()
94 .map(|a| a.value(row_idx).to_string()),
95 DataType::Int64 => array
96 .as_any()
97 .downcast_ref::<Int64Array>()
98 .map(|a| a.value(row_idx).to_string()),
99 DataType::Float32 => array
100 .as_any()
101 .downcast_ref::<Float32Array>()
102 .map(|a| a.value(row_idx).to_string()),
103 DataType::Float64 => array
104 .as_any()
105 .downcast_ref::<Float64Array>()
106 .map(|a| a.value(row_idx).to_string()),
107 _ => None,
108 };
109 if let Some(s) = val_str {
110 if delete_values.contains(&s) {
111 return true;
112 }
113 }
114 }
115 false
116 }
117
118 pub fn apply(&self, batch: RecordBatch) -> AilakeResult<RecordBatch> {
124 if self.filters.is_empty() {
125 return Ok(batch);
126 }
127
128 let n = batch.num_rows();
129 let mut keep = vec![true; n];
131
132 for (col_name, delete_values) in &self.filters {
133 let col_idx = match batch.schema().index_of(col_name.as_str()) {
134 Ok(i) => i,
135 Err(_) => continue, };
137 let array = batch.column(col_idx);
138 let dtype = array.data_type();
139
140 for (i, keep_slot) in keep.iter_mut().enumerate().take(n) {
141 if !*keep_slot {
142 continue;
143 }
144 if array.is_null(i) {
145 continue; }
147 let val_str: Option<String> = match dtype {
148 DataType::Utf8 => Some(
149 array
150 .as_any()
151 .downcast_ref::<StringArray>()
152 .map(|a| a.value(i).to_string())
153 .unwrap_or_default(),
154 ),
155 DataType::LargeUtf8 => Some(
156 array
157 .as_any()
158 .downcast_ref::<arrow_array::LargeStringArray>()
159 .map(|a| a.value(i).to_string())
160 .unwrap_or_default(),
161 ),
162 DataType::Int32 => Some(
163 array
164 .as_any()
165 .downcast_ref::<Int32Array>()
166 .map(|a| a.value(i).to_string())
167 .unwrap_or_default(),
168 ),
169 DataType::Int64 => Some(
170 array
171 .as_any()
172 .downcast_ref::<Int64Array>()
173 .map(|a| a.value(i).to_string())
174 .unwrap_or_default(),
175 ),
176 DataType::Float32 => Some(
177 array
178 .as_any()
179 .downcast_ref::<Float32Array>()
180 .map(|a| a.value(i).to_string())
181 .unwrap_or_default(),
182 ),
183 DataType::Float64 => Some(
184 array
185 .as_any()
186 .downcast_ref::<Float64Array>()
187 .map(|a| a.value(i).to_string())
188 .unwrap_or_default(),
189 ),
190 _ => None,
191 };
192 if let Some(s) = val_str {
193 if delete_values.contains(&s) {
194 *keep_slot = false;
195 }
196 }
197 }
198 }
199
200 let mask = BooleanArray::from(keep);
201 let filtered = arrow_select::filter::filter_record_batch(&batch, &mask)
202 .map_err(|e| AilakeError::Arrow(e.to_string()))?;
203 Ok(filtered)
204 }
205}
206
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    use super::EqualityDeleteFilter;

    /// Builds a two-column batch (`doc_id: Utf8`, `score: Int32`) from optional cells.
    fn batch_of(doc_ids: Vec<Option<&str>>, scores: Vec<Option<i32>>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("doc_id", DataType::Utf8, true),
            Field::new("score", DataType::Int32, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(doc_ids)),
                Arc::new(Int32Array::from(scores)),
            ],
        )
        .unwrap()
    }

    /// Four fully populated rows: `doc-a..doc-d` with scores `1..4`.
    fn make_batch() -> RecordBatch {
        batch_of(
            vec![Some("doc-a"), Some("doc-b"), Some("doc-c"), Some("doc-d")],
            vec![Some(1), Some(2), Some(3), Some(4)],
        )
    }

    fn filter_with(filters: HashMap<String, HashSet<String>>) -> EqualityDeleteFilter {
        EqualityDeleteFilter { filters }
    }

    /// Builds a filter that deletes `values` from the single column `column`.
    fn filter_on(column: &str, values: &[&str]) -> EqualityDeleteFilter {
        let mut filters = HashMap::new();
        filters.insert(
            column.to_string(),
            values.iter().map(|v| v.to_string()).collect(),
        );
        filter_with(filters)
    }

    /// Collects the `doc_id` column (column 0) of `batch` as owned strings.
    fn doc_ids(batch: &RecordBatch) -> Vec<String> {
        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        (0..batch.num_rows())
            .map(|i| ids.value(i).to_string())
            .collect()
    }

    /// Collects the `score` column (column 1) of `batch`.
    fn scores(batch: &RecordBatch) -> Vec<i32> {
        let values = batch
            .column(1)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        (0..batch.num_rows()).map(|i| values.value(i)).collect()
    }

    #[test]
    fn empty_filter_is_no_op() {
        let batch = make_batch();
        let f = filter_with(HashMap::new());
        assert!(f.is_empty());
        let result = f.apply(batch.clone()).unwrap();
        assert_eq!(result.num_rows(), 4);
    }

    #[test]
    fn empty_constructor_yields_empty_filter() {
        let f = EqualityDeleteFilter::empty();
        assert!(f.is_empty());
        assert!(!f.should_delete_row(&make_batch(), 0));
    }

    #[test]
    fn single_value_deleted() {
        let f = filter_on("doc_id", &["doc-b"]);
        assert!(!f.is_empty());
        let result = f.apply(make_batch()).unwrap();
        assert_eq!(result.num_rows(), 3);
        assert_eq!(doc_ids(&result), ["doc-a", "doc-c", "doc-d"]);
    }

    #[test]
    fn multiple_values_deleted() {
        let f = filter_on("doc_id", &["doc-a", "doc-c"]);
        let result = f.apply(make_batch()).unwrap();
        assert_eq!(result.num_rows(), 2);
        assert_eq!(doc_ids(&result), ["doc-b", "doc-d"]);
    }

    #[test]
    fn column_absent_from_batch_is_skipped() {
        let f = filter_on("nonexistent_col", &["x"]);
        let result = f.apply(make_batch()).unwrap();
        assert_eq!(result.num_rows(), 4);
    }

    #[test]
    fn numeric_column_deletion() {
        let f = filter_on("score", &["2", "4"]);
        let result = f.apply(make_batch()).unwrap();
        assert_eq!(result.num_rows(), 2);
        assert_eq!(doc_ids(&result), ["doc-a", "doc-c"]);
    }

    #[test]
    fn null_cells_are_never_deleted() {
        let batch = batch_of(
            vec![Some("doc-a"), None, Some("doc-c")],
            vec![Some(1), Some(2), Some(3)],
        );
        // The empty string must not be confused with a null cell.
        let f = filter_on("doc_id", &["doc-a", ""]);
        let result = f.apply(batch).unwrap();
        assert_eq!(scores(&result), [2, 3]);
    }

    #[test]
    fn any_matching_column_deletes_the_row() {
        let mut filters = HashMap::new();
        filters.insert("doc_id".to_string(), ["doc-a".to_string()].into());
        filters.insert("score".to_string(), ["4".to_string()].into());
        let f = filter_with(filters);
        let result = f.apply(make_batch()).unwrap();
        assert_eq!(doc_ids(&result), ["doc-b", "doc-c"]);
    }

    #[test]
    fn should_delete_row_agrees_with_apply() {
        let batch = make_batch();
        let f = filter_on("doc_id", &["doc-b"]);
        let verdicts: Vec<bool> = (0..batch.num_rows())
            .map(|i| f.should_delete_row(&batch, i))
            .collect();
        assert_eq!(verdicts, [false, true, false, false]);
    }
}