pack_it/
repack.rs

1use std::io::{Read, Seek, Write};
2
3use anyhow::{anyhow, bail, ensure, Context, Result};
4use arrow2::array::{
5    Array, BooleanArray, MutableBooleanArray, MutablePrimitiveArray, MutableUtf8Array,
6    PrimitiveArray, TryExtend, Utf8Array,
7};
8use arrow2::datatypes::{DataType, Field, Schema};
9use arrow2::io::parquet::read;
10use arrow2::io::parquet::read::RowGroupMetaData;
11use arrow2::io::parquet::write::Encoding;
12use log::info;
13
14use crate::table::VarArray;
15use crate::{Kind, Packer, TableField};
16
17#[derive(Clone)]
18pub struct OutField {
19    pub name: String,
20    pub data_type: DataType,
21    pub nullable: bool,
22
23    pub encoding: Encoding,
24}
25
26// struct Transform {
27//     input: String,
28//     output: OutField,
29//     func: Box<dyn FnMut(Box<dyn Array>, &mut Table, usize) -> Result<()>>,
30// }
31
32pub struct Split {
33    pub output: Vec<OutField>,
34    pub func: Box<dyn Send + FnMut(Box<dyn Array>, &mut [&mut VarArray]) -> Result<()>>,
35}
36
37pub enum Action {
38    ErrorOut,
39    Drop,
40    Copy,
41    // Transform(Transform),
42    Split(Split),
43}
44
45pub struct Op {
46    pub input: String,
47    pub action: Action,
48}
49
50pub struct Repack {
51    pub ops: Vec<Op>,
52}
53
54#[inline]
55pub fn find_field<'f>(schema: &'f Schema, name: &str) -> Option<(usize, &'f Field)> {
56    schema
57        .fields
58        .iter()
59        .enumerate()
60        .find(|(_, f)| f.name == name)
61}
62
/// Verdict returned by the row-group filter passed to [`transform`].
///
/// Derives the usual value traits so callers can compare, copy and log decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoopDecision {
    /// Process this row group.
    Include,
    /// Ignore this row group and continue with the next one.
    Skip,
    /// Stop: neither this row group nor any later one is processed.
    Break,
}
68
69pub fn read_single_column(
70    mut f: impl Read + Seek,
71    rg_meta: &RowGroupMetaData,
72    field_meta: Field,
73) -> Result<Box<dyn Array>> {
74    let name = field_meta.name.to_string();
75    let col = read::read_columns(&mut f, rg_meta.columns(), &name)?;
76    let mut des = read::to_deserializer(
77        col,
78        field_meta,
79        rg_meta
80            .num_rows()
81            .try_into()
82            .expect("row count fits in memory"),
83        None,
84        None,
85    )?;
86
87    let ret = des
88        .next()
89        .ok_or_else(|| anyhow!("expected at least one column"))??;
90    ensure!(des.next().is_none(), "expected exactly one column");
91
92    Ok(ret)
93}
94
95pub fn transform<W: Write + Send + 'static>(
96    mut f: impl Read + Seek,
97    out: W,
98    repack: &mut Repack,
99    mut rg_filter: impl FnMut(usize, &RowGroupMetaData) -> LoopDecision,
100) -> Result<W> {
101    let metadata = read::read_metadata(&mut f)?;
102    let in_schema = read::infer_schema(&metadata)?;
103
104    let out_schema = repack
105        .ops
106        .iter()
107        .flat_map(|op| -> Vec<Result<OutField>> {
108            match &op.action {
109                Action::Drop | Action::ErrorOut => Vec::new(),
110                Action::Copy => vec![
111                    try {
112                        let (_, x) = find_field(&in_schema, &op.input)
113                            .ok_or_else(|| anyhow!("field has gone missing?"))?;
114                        OutField {
115                            name: x.name.to_string(),
116                            data_type: x.data_type.clone(),
117                            nullable: x.is_nullable,
118                            encoding: Encoding::Plain,
119                        }
120                    },
121                ],
122
123                Action::Split(split) => split.output.iter().cloned().map(Ok).collect(),
124            }
125        })
126        .collect::<Result<Vec<OutField>>>()?;
127
128    let table_schema = out_schema
129        .iter()
130        .map(|v| -> Result<TableField> {
131            Ok(TableField {
132                name: v.name.to_string(),
133                kind: Kind::from_arrow(&v.data_type)
134                    .with_context(|| anyhow!("converting {:?} to a Kind", v.name))?,
135                nullable: false,
136                encoding: Encoding::Plain,
137            })
138        })
139        .collect::<Result<Vec<_>>>()
140        .with_context(|| anyhow!("generating an internal schema for the output"))?;
141
142    let mut writer = Packer::new(out, &table_schema)?;
143
144    for (rg, rg_meta) in metadata.row_groups.iter().enumerate() {
145        info!(
146            "handling rg {}/{} ({} rows)",
147            rg,
148            metadata.row_groups.len(),
149            rg_meta.num_rows()
150        );
151
152        match rg_filter(rg, rg_meta) {
153            LoopDecision::Include => (),
154            LoopDecision::Skip => continue,
155            LoopDecision::Break => break,
156        };
157
158        for op in &mut repack.ops {
159            let (_field, field_meta) = find_field(&in_schema, &op.input)
160                .ok_or_else(|| anyhow!("looking up input field {:?}", op.input))?;
161
162            let arr = read_single_column(&mut f, rg_meta, field_meta.clone())?;
163
164            match &mut op.action {
165                Action::ErrorOut => bail!("asked to error out after loading {:?}", field_meta.name),
166                Action::Drop => unimplemented!("drop"),
167                Action::Copy => {
168                    let (output, _) = writer.find_field(&op.input).expect("created above");
169
170                    let output = writer.table().get(output);
171
172                    if let Some(output) = output.downcast_mut::<MutableUtf8Array<i32>>() {
173                        output
174                            .try_extend(
175                                arr.as_any()
176                                    .downcast_ref::<Utf8Array<i32>>()
177                                    .expect("input=output")
178                                    .iter(),
179                            )
180                            .with_context(|| {
181                                anyhow!("copying {} rows of {:?}", metadata.num_rows, op.input)
182                            })?;
183                    } else if let Some(output) = output.downcast_mut::<MutablePrimitiveArray<i64>>()
184                    {
185                        output.extend(
186                            arr.as_any()
187                                .downcast_ref::<PrimitiveArray<i64>>()
188                                .expect("input=output")
189                                .iter()
190                                .map(|v| v.map(|x| *x)),
191                        );
192                    } else if let Some(output) = output.downcast_mut::<MutablePrimitiveArray<i32>>()
193                    {
194                        output.extend(
195                            arr.as_any()
196                                .downcast_ref::<PrimitiveArray<i32>>()
197                                .expect("input=output")
198                                .iter()
199                                .map(|v| v.map(|x| *x)),
200                        );
201                    } else if let Some(output) = output.downcast_mut::<MutableBooleanArray>() {
202                        output.extend(
203                            arr.as_any()
204                                .downcast_ref::<BooleanArray>()
205                                .expect("input=output")
206                                .iter(),
207                        );
208                    } else if let Some(output) = output.downcast_mut::<MutablePrimitiveArray<f64>>()
209                    {
210                        output.extend(
211                            arr.as_any()
212                                .downcast_ref::<PrimitiveArray<f64>>()
213                                .expect("input=output")
214                                .iter()
215                                .map(|v| v.map(|x| *x)),
216                        );
217                    } else {
218                        bail!(
219                            "copy for {:?} columns ({:?})",
220                            field_meta.data_type,
221                            field_meta.name
222                        )
223                    }
224                }
225                Action::Split(s) => {
226                    let fields: Vec<usize> = s
227                        .output
228                        .iter()
229                        .map(|f| {
230                            writer
231                                .find_field(&f.name)
232                                .expect("created based on input")
233                                .0
234                        })
235                        .collect();
236                    (s.func)(arr, &mut writer.table().get_many(&fields))?;
237                }
238            }
239        }
240
241        writer.table().finish_bulk_push()?;
242        writer.consider_flushing()?;
243    }
244
245    writer.finish()
246}