cherry_cast/
lib.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use arrow::{
5    array::{builder, Array, BinaryArray, Decimal256Array, RecordBatch, StringArray},
6    compute::CastOptions,
7    datatypes::{DataType, Field, Schema},
8};
9use ruint::aliases::U256;
10
11/// Casts columns according to given (column name, target data type) pairs.
12///
13/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
14/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
15pub fn cast<S: AsRef<str>>(
16    map: &[(S, DataType)],
17    data: &RecordBatch,
18    allow_cast_fail: bool,
19) -> Result<RecordBatch> {
20    let schema = cast_schema(map, data.schema_ref()).context("cast schema")?;
21
22    let mut arrays = Vec::with_capacity(data.num_columns());
23
24    let cast_opt = CastOptions {
25        safe: !allow_cast_fail,
26        ..Default::default()
27    };
28
29    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
30        let cast_target = map.iter().find(|x| x.0.as_ref() == field.name());
31
32        let col = match cast_target {
33            Some(tgt) => Arc::new(
34                arrow::compute::cast_with_options(col, &tgt.1, &cast_opt)
35                    .with_context(|| format!("Failed when casting column '{}'", field.name()))?,
36            ),
37            None => col.clone(),
38        };
39
40        arrays.push(col);
41    }
42
43    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
44
45    Ok(batch)
46}
47
48/// Casts column types according to given (column name, target data type) pairs.
49pub fn cast_schema<S: AsRef<str>>(map: &[(S, DataType)], schema: &Schema) -> Result<Schema> {
50    let mut fields = schema.fields().to_vec();
51
52    for f in fields.iter_mut() {
53        let cast_target = map.iter().find(|x| x.0.as_ref() == f.name());
54
55        if let Some(tgt) = cast_target {
56            *f = Arc::new(Field::new(f.name(), tgt.1.clone(), f.is_nullable()));
57        }
58    }
59
60    Ok(Schema::new(fields))
61}
62
63pub fn hex_encode<const PREFIXED: bool>(data: &RecordBatch) -> Result<RecordBatch> {
64    let schema = schema_binary_to_string(data.schema_ref());
65    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
66
67    for col in data.columns().iter() {
68        if col.data_type() == &DataType::Binary {
69            columns.push(Arc::new(hex_encode_column::<PREFIXED>(
70                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
71            )));
72        } else {
73            columns.push(col.clone());
74        }
75    }
76
77    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
78}
79
80pub fn hex_encode_column<const PREFIXED: bool>(col: &BinaryArray) -> StringArray {
81    let mut arr =
82        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
83
84    for v in col.iter() {
85        match v {
86            Some(v) => {
87                // TODO: avoid allocation here and use a scratch buffer to encode hex into or write to arrow buffer
88                // directly somehow.
89                let v = if PREFIXED {
90                    format!("0x{}", faster_hex::hex_string(v))
91                } else {
92                    faster_hex::hex_string(v)
93                };
94
95                arr.append_value(v);
96            }
97            None => arr.append_null(),
98        }
99    }
100
101    arr.finish()
102}
103
104/// Converts binary fields to string in the schema
105///
106/// Intended to be used with encode hex functions
107pub fn schema_binary_to_string(schema: &Schema) -> Schema {
108    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
109
110    for f in schema.fields().iter() {
111        if f.data_type() == &DataType::Binary {
112            fields.push(Arc::new(Field::new(
113                f.name().clone(),
114                DataType::Utf8,
115                f.is_nullable(),
116            )));
117        } else {
118            fields.push(f.clone());
119        }
120    }
121
122    Schema::new(fields)
123}
124
125/// Converts decimal256 fields to binary in the schema
126///
127/// Intended to be used with u256_to_binary function
128pub fn schema_decimal256_to_binary(schema: &Schema) -> Schema {
129    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
130
131    for f in schema.fields().iter() {
132        if f.data_type() == &DataType::Decimal256(76, 0) {
133            fields.push(Arc::new(Field::new(
134                f.name().clone(),
135                DataType::Binary,
136                f.is_nullable(),
137            )));
138        } else {
139            fields.push(f.clone());
140        }
141    }
142
143    Schema::new(fields)
144}
145
146pub fn hex_decode_column<const PREFIXED: bool>(col: &StringArray) -> Result<BinaryArray> {
147    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
148
149    for v in col.iter() {
150        match v {
151            // TODO: this should be optimized by removing allocations if needed
152            Some(v) => {
153                let v = v.as_bytes();
154                let v = if PREFIXED {
155                    v.get(2..).context("index into prefix hex encoded value")?
156                } else {
157                    v
158                };
159
160                let len = v.len();
161                let mut dst = vec![0; (len + 1) / 2];
162
163                faster_hex::hex_decode(v, &mut dst).context("hex decode")?;
164
165                arr.append_value(dst);
166            }
167            None => arr.append_null(),
168        }
169    }
170
171    Ok(arr.finish())
172}
173
174pub fn u256_column_from_binary(col: &BinaryArray) -> Result<Decimal256Array> {
175    let mut arr = builder::Decimal256Builder::with_capacity(col.len());
176
177    for v in col.iter() {
178        match v {
179            Some(v) => {
180                let num = U256::try_from_be_slice(v).context("parse u256")?;
181                let num = arrow::datatypes::i256::from_be_bytes(num.to_be_bytes::<32>());
182                arr.append_value(num);
183            }
184            None => arr.append_null(),
185        }
186    }
187
188    Ok(arr.with_precision_and_scale(76, 0).unwrap().finish())
189}
190
191pub fn u256_column_to_binary(col: &Decimal256Array) -> BinaryArray {
192    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
193
194    for v in col.iter() {
195        match v {
196            Some(v) => {
197                let num = U256::from_be_bytes::<32>(v.to_be_bytes());
198                arr.append_value(num.to_be_bytes_trimmed_vec());
199            }
200            None => {
201                arr.append_null();
202            }
203        }
204    }
205
206    arr.finish()
207}
208
209/// Converts all Decimal256 (U256) columns in the batch to big endian binary values
210pub fn u256_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
211    let schema = schema_binary_to_string(data.schema_ref());
212    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
213
214    for col in data.columns().iter() {
215        if col.data_type() == &DataType::Decimal256(76, 0) {
216            let mut arr = builder::BinaryBuilder::new();
217
218            let col = col.as_any().downcast_ref::<Decimal256Array>().unwrap();
219
220            for val in col.iter() {
221                arr.append_option(val.map(|v| v.to_be_bytes()));
222            }
223
224            columns.push(Arc::new(arr.finish()));
225        } else {
226            columns.push(col.clone());
227        }
228    }
229
230    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
231}