cherry_cast/
lib.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use arrow::{
5    array::{builder, Array, BinaryArray, Decimal256Array, RecordBatch, StringArray},
6    compute::CastOptions,
7    datatypes::{DataType, Field, Schema},
8};
9
10/// Casts columns according to given (column name, target data type) pairs.
11///
12/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
13/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
14pub fn cast<S: AsRef<str>>(
15    map: &[(S, DataType)],
16    data: &RecordBatch,
17    allow_cast_fail: bool,
18) -> Result<RecordBatch> {
19    let schema = cast_schema(map, data.schema_ref()).context("cast schema")?;
20
21    let mut arrays = Vec::with_capacity(data.num_columns());
22
23    let cast_opt = CastOptions {
24        safe: !allow_cast_fail,
25        ..Default::default()
26    };
27
28    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
29        let cast_target = map.iter().find(|x| x.0.as_ref() == field.name());
30
31        let col = match cast_target {
32            Some(tgt) => Arc::new(
33                arrow::compute::cast_with_options(col, &tgt.1, &cast_opt)
34                    .with_context(|| format!("Failed when casting column '{}'", field.name()))?,
35            ),
36            None => col.clone(),
37        };
38
39        arrays.push(col);
40    }
41
42    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
43
44    Ok(batch)
45}
46
47/// Casts column types according to given (column name, target data type) pairs.
48pub fn cast_schema<S: AsRef<str>>(map: &[(S, DataType)], schema: &Schema) -> Result<Schema> {
49    let mut fields = schema.fields().to_vec();
50
51    for f in fields.iter_mut() {
52        let cast_target = map.iter().find(|x| x.0.as_ref() == f.name());
53
54        if let Some(tgt) = cast_target {
55            *f = Arc::new(Field::new(f.name(), tgt.1.clone(), f.is_nullable()));
56        }
57    }
58
59    Ok(Schema::new(fields))
60}
61
62pub fn base58_encode(data: &RecordBatch) -> Result<RecordBatch> {
63    let schema = schema_binary_to_string(data.schema_ref());
64    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
65
66    for col in data.columns().iter() {
67        if col.data_type() == &DataType::Binary {
68            columns.push(Arc::new(base58_encode_column(
69                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
70            )));
71        } else {
72            columns.push(col.clone());
73        }
74    }
75
76    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
77}
78
79pub fn base58_encode_column(col: &BinaryArray) -> StringArray {
80    let mut arr =
81        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
82
83    for v in col.iter() {
84        match v {
85            Some(v) => {
86                let v = bs58::encode(v)
87                    .with_alphabet(bs58::Alphabet::BITCOIN)
88                    .into_string();
89                arr.append_value(v);
90            }
91            None => arr.append_null(),
92        }
93    }
94
95    arr.finish()
96}
97
98pub fn hex_encode<const PREFIXED: bool>(data: &RecordBatch) -> Result<RecordBatch> {
99    let schema = schema_binary_to_string(data.schema_ref());
100    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
101
102    for col in data.columns().iter() {
103        if col.data_type() == &DataType::Binary {
104            columns.push(Arc::new(hex_encode_column::<PREFIXED>(
105                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
106            )));
107        } else {
108            columns.push(col.clone());
109        }
110    }
111
112    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
113}
114
115pub fn hex_encode_column<const PREFIXED: bool>(col: &BinaryArray) -> StringArray {
116    let mut arr =
117        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
118
119    for v in col.iter() {
120        match v {
121            Some(v) => {
122                // TODO: avoid allocation here and use a scratch buffer to encode hex into or write to arrow buffer
123                // directly somehow.
124                let v = if PREFIXED {
125                    format!("0x{}", faster_hex::hex_string(v))
126                } else {
127                    faster_hex::hex_string(v)
128                };
129
130                arr.append_value(v);
131            }
132            None => arr.append_null(),
133        }
134    }
135
136    arr.finish()
137}
138
139/// Converts binary fields to string in the schema
140///
141/// Intended to be used with encode hex functions
142pub fn schema_binary_to_string(schema: &Schema) -> Schema {
143    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
144
145    for f in schema.fields().iter() {
146        if f.data_type() == &DataType::Binary {
147            fields.push(Arc::new(Field::new(
148                f.name().clone(),
149                DataType::Utf8,
150                f.is_nullable(),
151            )));
152        } else {
153            fields.push(f.clone());
154        }
155    }
156
157    Schema::new(fields)
158}
159
160/// Converts decimal256 fields to binary in the schema
161///
162/// Intended to be used with u256_to_binary function
163pub fn schema_decimal256_to_binary(schema: &Schema) -> Schema {
164    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
165
166    for f in schema.fields().iter() {
167        if f.data_type() == &DataType::Decimal256(76, 0) {
168            fields.push(Arc::new(Field::new(
169                f.name().clone(),
170                DataType::Binary,
171                f.is_nullable(),
172            )));
173        } else {
174            fields.push(f.clone());
175        }
176    }
177
178    Schema::new(fields)
179}
180
181pub fn base58_decode_column(col: &StringArray) -> Result<BinaryArray> {
182    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
183
184    for v in col.iter() {
185        match v {
186            // TODO: this should be optimized by removing allocations if needed
187            Some(v) => {
188                let v = bs58::decode(v)
189                    .with_alphabet(bs58::Alphabet::BITCOIN)
190                    .into_vec()
191                    .context("bs58 decode")?;
192                arr.append_value(v);
193            }
194            None => arr.append_null(),
195        }
196    }
197
198    Ok(arr.finish())
199}
200
201pub fn hex_decode_column<const PREFIXED: bool>(col: &StringArray) -> Result<BinaryArray> {
202    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
203
204    for v in col.iter() {
205        match v {
206            // TODO: this should be optimized by removing allocations if needed
207            Some(v) => {
208                let v = v.as_bytes();
209                let v = if PREFIXED {
210                    v.get(2..).context("index into prefix hex encoded value")?
211                } else {
212                    v
213                };
214
215                let len = v.len();
216                let mut dst = vec![0; (len + 1) / 2];
217
218                faster_hex::hex_decode(v, &mut dst).context("hex decode")?;
219
220                arr.append_value(dst);
221            }
222            None => arr.append_null(),
223        }
224    }
225
226    Ok(arr.finish())
227}
228
229pub fn u256_column_from_binary(col: &BinaryArray) -> Result<Decimal256Array> {
230    let mut arr = builder::Decimal256Builder::with_capacity(col.len());
231
232    for v in col.iter() {
233        match v {
234            Some(v) => {
235                let num = ruint::aliases::U256::try_from_be_slice(v).context("parse ruint u256")?;
236                let num = alloy_primitives::I256::try_from(num)
237                    .with_context(|| format!("u256 to i256. val was {}", num))?;
238
239                let val = arrow::datatypes::i256::from_be_bytes(num.to_be_bytes::<32>());
240                arr.append_value(val);
241            }
242            None => arr.append_null(),
243        }
244    }
245
246    Ok(arr.with_precision_and_scale(76, 0).unwrap().finish())
247}
248
249pub fn u256_column_to_binary(col: &Decimal256Array) -> Result<BinaryArray> {
250    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
251
252    for v in col.iter() {
253        match v {
254            Some(v) => {
255                let num = alloy_primitives::I256::from_be_bytes::<32>(v.to_be_bytes());
256                let num = ruint::aliases::U256::try_from(num).context("convert i256 to u256")?;
257                arr.append_value(num.to_be_bytes_trimmed_vec());
258            }
259            None => {
260                arr.append_null();
261            }
262        }
263    }
264
265    Ok(arr.finish())
266}
267
268/// Converts all Decimal256 (U256) columns in the batch to big endian binary values
269pub fn u256_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
270    let schema = schema_decimal256_to_binary(data.schema_ref());
271    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
272
273    for (i, col) in data.columns().iter().enumerate() {
274        if col.data_type() == &DataType::Decimal256(76, 0) {
275            let col = col.as_any().downcast_ref::<Decimal256Array>().unwrap();
276            let x = u256_column_to_binary(col)
277                .with_context(|| format!("col {} to binary", data.schema().fields()[i].name()))?;
278            columns.push(Arc::new(x));
279        } else {
280            columns.push(col.clone());
281        }
282    }
283
284    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
285}