//! cherry_cast (lib.rs): helpers for casting Arrow record-batch columns and
//! encoding/decoding binary columns (hex, base58, u256 <-> Decimal256).

#![allow(clippy::manual_div_ceil)]

use std::sync::Arc;

use anyhow::{Context, Result};
use arrow::{
    array::{builder, Array, BinaryArray, Decimal256Array, RecordBatch, StringArray},
    compute::CastOptions,
    datatypes::{DataType, Field, Schema},
};

12/// Casts columns according to given (column name, target data type) pairs.
13///
14/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
15/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
16pub fn cast<S: AsRef<str>>(
17    map: &[(S, DataType)],
18    data: &RecordBatch,
19    allow_cast_fail: bool,
20) -> Result<RecordBatch> {
21    let schema = cast_schema(map, data.schema_ref()).context("cast schema")?;
22
23    let mut arrays = Vec::with_capacity(data.num_columns());
24
25    let cast_opt = CastOptions {
26        safe: allow_cast_fail,
27        ..Default::default()
28    };
29
30    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
31        let cast_target = map.iter().find(|x| x.0.as_ref() == field.name());
32
33        let col = match cast_target {
34            Some(tgt) => {
35                // allow precision loss for decimal types into floating point types
36                if matches!(
37                    col.data_type(),
38                    DataType::Decimal256(..) | DataType::Decimal128(..)
39                ) && tgt.1.is_floating()
40                {
41                    let string_col =
42                        arrow::compute::cast_with_options(col, &DataType::Utf8, &cast_opt)
43                            .with_context(|| {
44                                format!(
45                            "Failed when casting column '{}' to string as intermediate step",
46                            field.name()
47                        )
48                            })?;
49                    Arc::new(
50                        arrow::compute::cast_with_options(&string_col, &tgt.1, &cast_opt)
51                            .with_context(|| {
52                                format!(
53                                    "Failed when casting column '{}' to {:?}",
54                                    field.name(),
55                                    tgt.1
56                                )
57                            })?,
58                    )
59                } else {
60                    Arc::new(
61                        arrow::compute::cast_with_options(col, &tgt.1, &cast_opt).with_context(
62                            || {
63                                format!(
64                                    "Failed when casting column '{}' from {:?} to {:?}",
65                                    field.name(),
66                                    col.data_type(),
67                                    tgt.1
68                                )
69                            },
70                        )?,
71                    )
72                }
73            }
74            None => col.clone(),
75        };
76
77        arrays.push(col);
78    }
79
80    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
81
82    Ok(batch)
83}
84
85/// Casts column types according to given (column name, target data type) pairs.
86pub fn cast_schema<S: AsRef<str>>(map: &[(S, DataType)], schema: &Schema) -> Result<Schema> {
87    let mut fields = schema.fields().to_vec();
88
89    for f in fields.iter_mut() {
90        let cast_target = map.iter().find(|x| x.0.as_ref() == f.name());
91
92        if let Some(tgt) = cast_target {
93            *f = Arc::new(Field::new(f.name(), tgt.1.clone(), f.is_nullable()));
94        }
95    }
96
97    Ok(Schema::new(fields))
98}
99
100/// Casts all columns with from_type to to_type.
101///
102/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
103/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
104pub fn cast_by_type(
105    data: &RecordBatch,
106    from_type: &DataType,
107    to_type: &DataType,
108    allow_cast_fail: bool,
109) -> Result<RecordBatch> {
110    let schema =
111        cast_schema_by_type(data.schema_ref(), from_type, to_type).context("cast schema")?;
112
113    let mut arrays = Vec::with_capacity(data.num_columns());
114
115    let cast_opt = CastOptions {
116        safe: allow_cast_fail,
117        ..Default::default()
118    };
119
120    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
121        let col = if col.data_type() == from_type {
122            // allow precision loss for decimal types into floating point types
123            if matches!(
124                col.data_type(),
125                DataType::Decimal256(..) | DataType::Decimal128(..)
126            ) && to_type.is_floating()
127            {
128                let string_col = arrow::compute::cast_with_options(col, &DataType::Utf8, &cast_opt)
129                    .with_context(|| {
130                        format!(
131                            "Failed when casting_by_type column '{}' to string as intermediate step",
132                            field.name()
133                        )
134                    })?;
135                Arc::new(
136                    arrow::compute::cast_with_options(&string_col, to_type, &cast_opt)
137                        .with_context(|| {
138                            format!(
139                                "Failed when casting_by_type column '{}' to {:?}",
140                                field.name(),
141                                to_type
142                            )
143                        })?,
144                )
145            } else {
146                Arc::new(
147                    arrow::compute::cast_with_options(col, to_type, &cast_opt).with_context(
148                        || {
149                            format!(
150                                "Failed when casting_by_type column '{}' to {:?}",
151                                field.name(),
152                                to_type
153                            )
154                        },
155                    )?,
156                )
157            }
158        } else {
159            col.clone()
160        };
161
162        arrays.push(col);
163    }
164
165    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
166
167    Ok(batch)
168}
169
170/// Casts columns with from_type to to_type
171pub fn cast_schema_by_type(
172    schema: &Schema,
173    from_type: &DataType,
174    to_type: &DataType,
175) -> Result<Schema> {
176    let mut fields = schema.fields().to_vec();
177
178    for f in fields.iter_mut() {
179        if f.data_type() == from_type {
180            *f = Arc::new(Field::new(f.name(), to_type.clone(), f.is_nullable()));
181        }
182    }
183
184    Ok(Schema::new(fields))
185}
186
187pub fn base58_encode(data: &RecordBatch) -> Result<RecordBatch> {
188    let schema = schema_binary_to_string(data.schema_ref());
189    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
190
191    for col in data.columns().iter() {
192        if col.data_type() == &DataType::Binary {
193            columns.push(Arc::new(base58_encode_column(
194                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
195            )));
196        } else {
197            columns.push(col.clone());
198        }
199    }
200
201    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
202}
203
204pub fn base58_encode_column(col: &BinaryArray) -> StringArray {
205    let mut arr =
206        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
207
208    for v in col.iter() {
209        match v {
210            Some(v) => {
211                let v = bs58::encode(v)
212                    .with_alphabet(bs58::Alphabet::BITCOIN)
213                    .into_string();
214                arr.append_value(v);
215            }
216            None => arr.append_null(),
217        }
218    }
219
220    arr.finish()
221}
222
223pub fn hex_encode<const PREFIXED: bool>(data: &RecordBatch) -> Result<RecordBatch> {
224    let schema = schema_binary_to_string(data.schema_ref());
225    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
226
227    for col in data.columns().iter() {
228        if col.data_type() == &DataType::Binary {
229            columns.push(Arc::new(hex_encode_column::<PREFIXED>(
230                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
231            )));
232        } else {
233            columns.push(col.clone());
234        }
235    }
236
237    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
238}
239
240pub fn hex_encode_column<const PREFIXED: bool>(col: &BinaryArray) -> StringArray {
241    let mut arr =
242        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
243
244    for v in col.iter() {
245        match v {
246            Some(v) => {
247                // TODO: avoid allocation here and use a scratch buffer to encode hex into or write to arrow buffer
248                // directly somehow.
249                let v = if PREFIXED {
250                    format!("0x{}", faster_hex::hex_string(v))
251                } else {
252                    faster_hex::hex_string(v)
253                };
254
255                arr.append_value(v);
256            }
257            None => arr.append_null(),
258        }
259    }
260
261    arr.finish()
262}
263
264/// Converts binary fields to string in the schema
265///
266/// Intended to be used with encode hex functions
267pub fn schema_binary_to_string(schema: &Schema) -> Schema {
268    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
269
270    for f in schema.fields().iter() {
271        if f.data_type() == &DataType::Binary {
272            fields.push(Arc::new(Field::new(
273                f.name().clone(),
274                DataType::Utf8,
275                f.is_nullable(),
276            )));
277        } else {
278            fields.push(f.clone());
279        }
280    }
281
282    Schema::new(fields)
283}
284
285/// Converts decimal256 fields to binary in the schema
286///
287/// Intended to be used with u256_to_binary function
288pub fn schema_decimal256_to_binary(schema: &Schema) -> Schema {
289    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
290
291    for f in schema.fields().iter() {
292        if f.data_type() == &DataType::Decimal256(76, 0) {
293            fields.push(Arc::new(Field::new(
294                f.name().clone(),
295                DataType::Binary,
296                f.is_nullable(),
297            )));
298        } else {
299            fields.push(f.clone());
300        }
301    }
302
303    Schema::new(fields)
304}
305
306pub fn base58_decode_column(col: &StringArray) -> Result<BinaryArray> {
307    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
308
309    for v in col.iter() {
310        match v {
311            // TODO: this should be optimized by removing allocations if needed
312            Some(v) => {
313                let v = bs58::decode(v)
314                    .with_alphabet(bs58::Alphabet::BITCOIN)
315                    .into_vec()
316                    .context("bs58 decode")?;
317                arr.append_value(v);
318            }
319            None => arr.append_null(),
320        }
321    }
322
323    Ok(arr.finish())
324}
325
326pub fn hex_decode_column<const PREFIXED: bool>(col: &StringArray) -> Result<BinaryArray> {
327    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
328
329    for v in col.iter() {
330        match v {
331            // TODO: this should be optimized by removing allocations if needed
332            Some(v) => {
333                let v = v.as_bytes();
334                let v = if PREFIXED {
335                    v.get(2..).context("index into prefix hex encoded value")?
336                } else {
337                    v
338                };
339
340                let len = v.len();
341                let mut dst = vec![0; (len + 1) / 2];
342
343                faster_hex::hex_decode(v, &mut dst).context("hex decode")?;
344
345                arr.append_value(dst);
346            }
347            None => arr.append_null(),
348        }
349    }
350
351    Ok(arr.finish())
352}
353
354pub fn u256_column_from_binary(col: &BinaryArray) -> Result<Decimal256Array> {
355    let mut arr = builder::Decimal256Builder::with_capacity(col.len());
356
357    for v in col.iter() {
358        match v {
359            Some(v) => {
360                let num = ruint::aliases::U256::try_from_be_slice(v).context("parse ruint u256")?;
361                let num = alloy_primitives::I256::try_from(num)
362                    .with_context(|| format!("u256 to i256. val was {}", num))?;
363
364                let val = arrow::datatypes::i256::from_be_bytes(num.to_be_bytes::<32>());
365                arr.append_value(val);
366            }
367            None => arr.append_null(),
368        }
369    }
370
371    Ok(arr.with_precision_and_scale(76, 0).unwrap().finish())
372}
373
374pub fn u256_column_to_binary(col: &Decimal256Array) -> Result<BinaryArray> {
375    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
376
377    for v in col.iter() {
378        match v {
379            Some(v) => {
380                let num = alloy_primitives::I256::from_be_bytes::<32>(v.to_be_bytes());
381                let num = ruint::aliases::U256::try_from(num).context("convert i256 to u256")?;
382                arr.append_value(num.to_be_bytes_trimmed_vec());
383            }
384            None => {
385                arr.append_null();
386            }
387        }
388    }
389
390    Ok(arr.finish())
391}
392
393/// Converts all Decimal256 (U256) columns in the batch to big endian binary values
394pub fn u256_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
395    let schema = schema_decimal256_to_binary(data.schema_ref());
396    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
397
398    for (i, col) in data.columns().iter().enumerate() {
399        if col.data_type() == &DataType::Decimal256(76, 0) {
400            let col = col.as_any().downcast_ref::<Decimal256Array>().unwrap();
401            let x = u256_column_to_binary(col)
402                .with_context(|| format!("col {} to binary", data.schema().fields()[i].name()))?;
403            columns.push(Arc::new(x));
404        } else {
405            columns.push(col.clone());
406        }
407    }
408
409    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use std::fs::File;

    // Manual smoke test: reads a local `data.parquet` fixture, casts a handful
    // of columns, and writes the result to `result.parquet`. It depends on a
    // file that is not checked in, hence `#[ignore]` — run explicitly with
    // `cargo test -- --ignored` when the fixture is present.
    #[test]
    #[ignore]
    fn test_cast() {
        use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

        // Load only the first record batch from the fixture.
        let builder =
            ParquetRecordBatchReaderBuilder::try_new(File::open("data.parquet").unwrap()).unwrap();
        let mut reader = builder.build().unwrap();
        let table = reader.next().unwrap().unwrap();

        // Mix of decimal, float, and integer targets so both cast paths are
        // exercised (including the decimal -> string -> float route).
        let type_mappings = vec![
            ("amount0In", DataType::Decimal128(15, 0)),
            ("amount1In", DataType::Float32),
            ("amount0Out", DataType::Float64),
            ("amount1Out", DataType::Decimal128(38, 0)),
            ("timestamp", DataType::Int64),
        ];

        // `allow_cast_fail = true`: rows that fail to cast become null instead
        // of aborting the whole cast.
        let result = cast(&type_mappings, &table, true).unwrap();

        // Save the filtered instructions to a new parquet file
        let mut file = File::create("result.parquet").unwrap();
        let mut writer =
            parquet::arrow::ArrowWriter::try_new(&mut file, result.schema(), None).unwrap();
        writer.write(&result).unwrap();
        writer.close().unwrap();
    }
}