// cherry_cast/lib.rs

1#![allow(clippy::manual_div_ceil)]
2
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use arrow::{
7    array::{
8        builder, Array, BinaryArray, Decimal256Array, GenericBinaryArray, GenericStringArray,
9        LargeBinaryArray, OffsetSizeTrait, RecordBatch,
10    },
11    compute::CastOptions,
12    datatypes::{DataType, Field, Schema},
13};
14
15/// Casts columns according to given (column name, target data type) pairs.
16///
17/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
18/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
19pub fn cast<S: AsRef<str>>(
20    map: &[(S, DataType)],
21    data: &RecordBatch,
22    allow_cast_fail: bool,
23) -> Result<RecordBatch> {
24    let schema = cast_schema(map, data.schema_ref()).context("cast schema")?;
25
26    let mut arrays = Vec::with_capacity(data.num_columns());
27
28    let cast_opt = CastOptions {
29        safe: allow_cast_fail,
30        ..Default::default()
31    };
32
33    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
34        let cast_target = map.iter().find(|x| x.0.as_ref() == field.name());
35
36        let col = match cast_target {
37            Some(tgt) => {
38                // allow precision loss for decimal types into floating point types
39                if matches!(
40                    col.data_type(),
41                    DataType::Decimal256(..) | DataType::Decimal128(..)
42                ) && tgt.1.is_floating()
43                {
44                    let string_col =
45                        arrow::compute::cast_with_options(col, &DataType::Utf8, &cast_opt)
46                            .with_context(|| {
47                                format!(
48                            "Failed when casting column '{}' to string as intermediate step",
49                            field.name()
50                        )
51                            })?;
52                    Arc::new(
53                        arrow::compute::cast_with_options(&string_col, &tgt.1, &cast_opt)
54                            .with_context(|| {
55                                format!(
56                                    "Failed when casting column '{}' to {:?}",
57                                    field.name(),
58                                    tgt.1
59                                )
60                            })?,
61                    )
62                } else {
63                    Arc::new(
64                        arrow::compute::cast_with_options(col, &tgt.1, &cast_opt).with_context(
65                            || {
66                                format!(
67                                    "Failed when casting column '{}' from {:?} to {:?}",
68                                    field.name(),
69                                    col.data_type(),
70                                    tgt.1
71                                )
72                            },
73                        )?,
74                    )
75                }
76            }
77            None => col.clone(),
78        };
79
80        arrays.push(col);
81    }
82
83    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
84
85    Ok(batch)
86}
87
88/// Casts column types according to given (column name, target data type) pairs.
89pub fn cast_schema<S: AsRef<str>>(map: &[(S, DataType)], schema: &Schema) -> Result<Schema> {
90    let mut fields = schema.fields().to_vec();
91
92    for f in fields.iter_mut() {
93        let cast_target = map.iter().find(|x| x.0.as_ref() == f.name());
94
95        if let Some(tgt) = cast_target {
96            *f = Arc::new(Field::new(f.name(), tgt.1.clone(), f.is_nullable()));
97        }
98    }
99
100    Ok(Schema::new(fields))
101}
102
103/// Casts all columns with from_type to to_type.
104///
105/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
106/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
107pub fn cast_by_type(
108    data: &RecordBatch,
109    from_type: &DataType,
110    to_type: &DataType,
111    allow_cast_fail: bool,
112) -> Result<RecordBatch> {
113    let schema =
114        cast_schema_by_type(data.schema_ref(), from_type, to_type).context("cast schema")?;
115
116    let mut arrays = Vec::with_capacity(data.num_columns());
117
118    let cast_opt = CastOptions {
119        safe: allow_cast_fail,
120        ..Default::default()
121    };
122
123    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
124        let col = if col.data_type() == from_type {
125            // allow precision loss for decimal types into floating point types
126            if matches!(
127                col.data_type(),
128                DataType::Decimal256(..) | DataType::Decimal128(..)
129            ) && to_type.is_floating()
130            {
131                let string_col = arrow::compute::cast_with_options(col, &DataType::Utf8, &cast_opt)
132                    .with_context(|| {
133                        format!(
134                            "Failed when casting_by_type column '{}' to string as intermediate step",
135                            field.name()
136                        )
137                    })?;
138                Arc::new(
139                    arrow::compute::cast_with_options(&string_col, to_type, &cast_opt)
140                        .with_context(|| {
141                            format!(
142                                "Failed when casting_by_type column '{}' to {:?}",
143                                field.name(),
144                                to_type
145                            )
146                        })?,
147                )
148            } else {
149                Arc::new(
150                    arrow::compute::cast_with_options(col, to_type, &cast_opt).with_context(
151                        || {
152                            format!(
153                                "Failed when casting_by_type column '{}' to {:?}",
154                                field.name(),
155                                to_type
156                            )
157                        },
158                    )?,
159                )
160            }
161        } else {
162            col.clone()
163        };
164
165        arrays.push(col);
166    }
167
168    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
169
170    Ok(batch)
171}
172
173/// Casts columns with from_type to to_type
174pub fn cast_schema_by_type(
175    schema: &Schema,
176    from_type: &DataType,
177    to_type: &DataType,
178) -> Result<Schema> {
179    let mut fields = schema.fields().to_vec();
180
181    for f in fields.iter_mut() {
182        if f.data_type() == from_type {
183            *f = Arc::new(Field::new(f.name(), to_type.clone(), f.is_nullable()));
184        }
185    }
186
187    Ok(Schema::new(fields))
188}
189
190pub fn base58_encode(data: &RecordBatch) -> Result<RecordBatch> {
191    let schema = schema_binary_to_string(data.schema_ref());
192    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
193
194    for col in data.columns().iter() {
195        if col.data_type() == &DataType::Binary {
196            columns.push(Arc::new(base58_encode_column(
197                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
198            )));
199        } else if col.data_type() == &DataType::LargeBinary {
200            columns.push(Arc::new(base58_encode_column(
201                col.as_any().downcast_ref::<LargeBinaryArray>().unwrap(),
202            )));
203        } else {
204            columns.push(col.clone());
205        }
206    }
207
208    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
209}
210
211pub fn base58_encode_column<I: OffsetSizeTrait>(
212    col: &GenericBinaryArray<I>,
213) -> GenericStringArray<I> {
214    let mut arr = builder::GenericStringBuilder::<I>::with_capacity(
215        col.len(),
216        (col.value_data().len() + 2) * 2,
217    );
218
219    for v in col.iter() {
220        match v {
221            Some(v) => {
222                let v = bs58::encode(v)
223                    .with_alphabet(bs58::Alphabet::BITCOIN)
224                    .into_string();
225                arr.append_value(v);
226            }
227            None => arr.append_null(),
228        }
229    }
230
231    arr.finish()
232}
233
234pub fn hex_encode<const PREFIXED: bool>(data: &RecordBatch) -> Result<RecordBatch> {
235    let schema = schema_binary_to_string(data.schema_ref());
236    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
237
238    for col in data.columns().iter() {
239        if col.data_type() == &DataType::Binary {
240            columns.push(Arc::new(hex_encode_column::<PREFIXED, i32>(
241                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
242            )));
243        } else if col.data_type() == &DataType::LargeBinary {
244            columns.push(Arc::new(hex_encode_column::<PREFIXED, i64>(
245                col.as_any().downcast_ref::<LargeBinaryArray>().unwrap(),
246            )));
247        } else {
248            columns.push(col.clone());
249        }
250    }
251
252    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
253}
254
255pub fn hex_encode_column<const PREFIXED: bool, I: OffsetSizeTrait>(
256    col: &GenericBinaryArray<I>,
257) -> GenericStringArray<I> {
258    let mut arr = builder::GenericStringBuilder::<I>::with_capacity(
259        col.len(),
260        (col.value_data().len() + 2) * 2,
261    );
262
263    for v in col.iter() {
264        match v {
265            Some(v) => {
266                // TODO: avoid allocation here and use a scratch buffer to encode hex into or write to arrow buffer
267                // directly somehow.
268                let v = if PREFIXED {
269                    format!("0x{}", faster_hex::hex_string(v))
270                } else {
271                    faster_hex::hex_string(v)
272                };
273
274                arr.append_value(v);
275            }
276            None => arr.append_null(),
277        }
278    }
279
280    arr.finish()
281}
282
283/// Converts binary fields to string in the schema
284///
285/// Intended to be used with encode hex functions
286pub fn schema_binary_to_string(schema: &Schema) -> Schema {
287    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
288
289    for f in schema.fields().iter() {
290        if f.data_type() == &DataType::Binary {
291            fields.push(Arc::new(Field::new(
292                f.name().clone(),
293                DataType::Utf8,
294                f.is_nullable(),
295            )));
296        } else if f.data_type() == &DataType::LargeBinary {
297            fields.push(Arc::new(Field::new(
298                f.name().clone(),
299                DataType::LargeUtf8,
300                f.is_nullable(),
301            )));
302        } else {
303            fields.push(f.clone());
304        }
305    }
306
307    Schema::new(fields)
308}
309
310/// Converts decimal256 fields to binary in the schema
311///
312/// Intended to be used with u256_to_binary function
313pub fn schema_decimal256_to_binary(schema: &Schema) -> Schema {
314    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
315
316    for f in schema.fields().iter() {
317        if f.data_type() == &DataType::Decimal256(76, 0) {
318            fields.push(Arc::new(Field::new(
319                f.name().clone(),
320                DataType::Binary,
321                f.is_nullable(),
322            )));
323        } else {
324            fields.push(f.clone());
325        }
326    }
327
328    Schema::new(fields)
329}
330
331pub fn base58_decode_column<I: OffsetSizeTrait>(
332    col: &GenericStringArray<I>,
333) -> Result<GenericBinaryArray<I>> {
334    let mut arr =
335        builder::GenericBinaryBuilder::<I>::with_capacity(col.len(), col.value_data().len() / 2);
336
337    for v in col.iter() {
338        match v {
339            // TODO: this should be optimized by removing allocations if needed
340            Some(v) => {
341                let v = bs58::decode(v)
342                    .with_alphabet(bs58::Alphabet::BITCOIN)
343                    .into_vec()
344                    .context("bs58 decode")?;
345                arr.append_value(v);
346            }
347            None => arr.append_null(),
348        }
349    }
350
351    Ok(arr.finish())
352}
353
354pub fn hex_decode_column<const PREFIXED: bool, I: OffsetSizeTrait>(
355    col: &GenericStringArray<I>,
356) -> Result<GenericBinaryArray<I>> {
357    let mut arr =
358        builder::GenericBinaryBuilder::<I>::with_capacity(col.len(), col.value_data().len() / 2);
359
360    for v in col.iter() {
361        match v {
362            // TODO: this should be optimized by removing allocations if needed
363            Some(v) => {
364                let v = v.as_bytes();
365                let v = if PREFIXED {
366                    v.get(2..).context("index into prefix hex encoded value")?
367                } else {
368                    v
369                };
370
371                let len = v.len();
372                let mut dst = vec![0; (len + 1) / 2];
373
374                faster_hex::hex_decode(v, &mut dst).context("hex decode")?;
375
376                arr.append_value(dst);
377            }
378            None => arr.append_null(),
379        }
380    }
381
382    Ok(arr.finish())
383}
384
385pub fn u256_column_from_binary<I: OffsetSizeTrait>(
386    col: &GenericBinaryArray<I>,
387) -> Result<Decimal256Array> {
388    let mut arr = builder::Decimal256Builder::with_capacity(col.len());
389
390    for v in col.iter() {
391        match v {
392            Some(v) => {
393                let num = ruint::aliases::U256::try_from_be_slice(v).context("parse ruint u256")?;
394                let num = alloy_primitives::I256::try_from(num)
395                    .with_context(|| format!("u256 to i256. val was {}", num))?;
396
397                let val = arrow::datatypes::i256::from_be_bytes(num.to_be_bytes::<32>());
398                arr.append_value(val);
399            }
400            None => arr.append_null(),
401        }
402    }
403
404    Ok(arr.with_precision_and_scale(76, 0).unwrap().finish())
405}
406
407pub fn u256_column_to_binary(col: &Decimal256Array) -> Result<BinaryArray> {
408    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
409
410    for v in col.iter() {
411        match v {
412            Some(v) => {
413                let num = alloy_primitives::I256::from_be_bytes::<32>(v.to_be_bytes());
414                let num = ruint::aliases::U256::try_from(num).context("convert i256 to u256")?;
415                arr.append_value(num.to_be_bytes_trimmed_vec());
416            }
417            None => {
418                arr.append_null();
419            }
420        }
421    }
422
423    Ok(arr.finish())
424}
425
426/// Converts all Decimal256 (U256) columns in the batch to big endian binary values
427pub fn u256_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
428    let schema = schema_decimal256_to_binary(data.schema_ref());
429    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
430
431    for (i, col) in data.columns().iter().enumerate() {
432        if col.data_type() == &DataType::Decimal256(76, 0) {
433            let col = col.as_any().downcast_ref::<Decimal256Array>().unwrap();
434            let x = u256_column_to_binary(col)
435                .with_context(|| format!("col {} to binary", data.schema().fields()[i].name()))?;
436            columns.push(Arc::new(x));
437        } else {
438            columns.push(col.clone());
439        }
440    }
441
442    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use std::fs::File;

    // Manual smoke test: reads a local parquet file, casts a handful of
    // columns, and writes the result back out for inspection. Requires a
    // `data.parquet` fixture in the working directory, hence `#[ignore]`.
    #[test]
    #[ignore]
    fn test_cast() {
        use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

        let builder =
            ParquetRecordBatchReaderBuilder::try_new(File::open("data.parquet").unwrap()).unwrap();
        let mut reader = builder.build().unwrap();
        // Only the first record batch of the file is exercised.
        let table = reader.next().unwrap().unwrap();

        // (column name, target type) pairs covering decimal, float and int targets.
        let type_mappings = vec![
            ("amount0In", DataType::Decimal128(15, 0)),
            ("amount1In", DataType::Float32),
            ("amount0Out", DataType::Float64),
            ("amount1Out", DataType::Decimal128(38, 0)),
            ("timestamp", DataType::Int64),
        ];

        // `allow_cast_fail = true`: rows that fail to cast become null.
        let result = cast(&type_mappings, &table, true).unwrap();

        // Save the filtered instructions to a new parquet file
        let mut file = File::create("result.parquet").unwrap();
        let mut writer =
            parquet::arrow::ArrowWriter::try_new(&mut file, result.schema(), None).unwrap();
        writer.write(&result).unwrap();
        writer.close().unwrap();
    }
}