cherry_cast/
lib.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use arrow::{
5    array::{builder, Array, BinaryArray, Decimal256Array, RecordBatch, StringArray},
6    compute::CastOptions,
7    datatypes::{DataType, Field, Schema},
8};
9
10/// Casts columns according to given (column name, target data type) pairs.
11///
12/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
13/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
14pub fn cast<S: AsRef<str>>(
15    map: &[(S, DataType)],
16    data: &RecordBatch,
17    allow_cast_fail: bool,
18) -> Result<RecordBatch> {
19    let schema = cast_schema(map, data.schema_ref()).context("cast schema")?;
20
21    let mut arrays = Vec::with_capacity(data.num_columns());
22
23    let cast_opt = CastOptions {
24        safe: allow_cast_fail,
25        ..Default::default()
26    };
27
28    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
29        let cast_target = map.iter().find(|x| x.0.as_ref() == field.name());
30
31        let col = match cast_target {
32            Some(tgt) => Arc::new(
33                arrow::compute::cast_with_options(col, &tgt.1, &cast_opt)
34                    .with_context(|| format!("Failed when casting column '{}'", field.name()))?,
35            ),
36            None => col.clone(),
37        };
38
39        arrays.push(col);
40    }
41
42    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
43
44    Ok(batch)
45}
46
47/// Casts column types according to given (column name, target data type) pairs.
48pub fn cast_schema<S: AsRef<str>>(map: &[(S, DataType)], schema: &Schema) -> Result<Schema> {
49    let mut fields = schema.fields().to_vec();
50
51    for f in fields.iter_mut() {
52        let cast_target = map.iter().find(|x| x.0.as_ref() == f.name());
53
54        if let Some(tgt) = cast_target {
55            *f = Arc::new(Field::new(f.name(), tgt.1.clone(), f.is_nullable()));
56        }
57    }
58
59    Ok(Schema::new(fields))
60}
61
62/// Casts all columns with from_type to to_type.
63///
64/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
65/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
66pub fn cast_by_type(
67    data: &RecordBatch,
68    from_type: &DataType,
69    to_type: &DataType,
70    allow_cast_fail: bool,
71) -> Result<RecordBatch> {
72    let schema =
73        cast_schema_by_type(data.schema_ref(), from_type, to_type).context("cast schema")?;
74
75    let mut arrays = Vec::with_capacity(data.num_columns());
76
77    let cast_opt = CastOptions {
78        safe: allow_cast_fail,
79        ..Default::default()
80    };
81
82    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
83        let col = if col.data_type() == from_type {
84            Arc::new(
85                arrow::compute::cast_with_options(col, to_type, &cast_opt)
86                    .with_context(|| format!("Failed when casting column '{}'", field.name()))?,
87            )
88        } else {
89            col.clone()
90        };
91
92        arrays.push(col);
93    }
94
95    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
96
97    Ok(batch)
98}
99
100/// Casts columns with from_type to to_type
101pub fn cast_schema_by_type(
102    schema: &Schema,
103    from_type: &DataType,
104    to_type: &DataType,
105) -> Result<Schema> {
106    let mut fields = schema.fields().to_vec();
107
108    for f in fields.iter_mut() {
109        if f.data_type() == from_type {
110            *f = Arc::new(Field::new(f.name(), to_type.clone(), f.is_nullable()));
111        }
112    }
113
114    Ok(Schema::new(fields))
115}
116
117pub fn base58_encode(data: &RecordBatch) -> Result<RecordBatch> {
118    let schema = schema_binary_to_string(data.schema_ref());
119    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
120
121    for col in data.columns().iter() {
122        if col.data_type() == &DataType::Binary {
123            columns.push(Arc::new(base58_encode_column(
124                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
125            )));
126        } else {
127            columns.push(col.clone());
128        }
129    }
130
131    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
132}
133
134pub fn base58_encode_column(col: &BinaryArray) -> StringArray {
135    let mut arr =
136        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
137
138    for v in col.iter() {
139        match v {
140            Some(v) => {
141                let v = bs58::encode(v)
142                    .with_alphabet(bs58::Alphabet::BITCOIN)
143                    .into_string();
144                arr.append_value(v);
145            }
146            None => arr.append_null(),
147        }
148    }
149
150    arr.finish()
151}
152
153pub fn hex_encode<const PREFIXED: bool>(data: &RecordBatch) -> Result<RecordBatch> {
154    let schema = schema_binary_to_string(data.schema_ref());
155    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
156
157    for col in data.columns().iter() {
158        if col.data_type() == &DataType::Binary {
159            columns.push(Arc::new(hex_encode_column::<PREFIXED>(
160                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
161            )));
162        } else {
163            columns.push(col.clone());
164        }
165    }
166
167    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
168}
169
170pub fn hex_encode_column<const PREFIXED: bool>(col: &BinaryArray) -> StringArray {
171    let mut arr =
172        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
173
174    for v in col.iter() {
175        match v {
176            Some(v) => {
177                // TODO: avoid allocation here and use a scratch buffer to encode hex into or write to arrow buffer
178                // directly somehow.
179                let v = if PREFIXED {
180                    format!("0x{}", faster_hex::hex_string(v))
181                } else {
182                    faster_hex::hex_string(v)
183                };
184
185                arr.append_value(v);
186            }
187            None => arr.append_null(),
188        }
189    }
190
191    arr.finish()
192}
193
194/// Converts binary fields to string in the schema
195///
196/// Intended to be used with encode hex functions
197pub fn schema_binary_to_string(schema: &Schema) -> Schema {
198    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
199
200    for f in schema.fields().iter() {
201        if f.data_type() == &DataType::Binary {
202            fields.push(Arc::new(Field::new(
203                f.name().clone(),
204                DataType::Utf8,
205                f.is_nullable(),
206            )));
207        } else {
208            fields.push(f.clone());
209        }
210    }
211
212    Schema::new(fields)
213}
214
215/// Converts decimal256 fields to binary in the schema
216///
217/// Intended to be used with u256_to_binary function
218pub fn schema_decimal256_to_binary(schema: &Schema) -> Schema {
219    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
220
221    for f in schema.fields().iter() {
222        if f.data_type() == &DataType::Decimal256(76, 0) {
223            fields.push(Arc::new(Field::new(
224                f.name().clone(),
225                DataType::Binary,
226                f.is_nullable(),
227            )));
228        } else {
229            fields.push(f.clone());
230        }
231    }
232
233    Schema::new(fields)
234}
235
236pub fn base58_decode_column(col: &StringArray) -> Result<BinaryArray> {
237    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
238
239    for v in col.iter() {
240        match v {
241            // TODO: this should be optimized by removing allocations if needed
242            Some(v) => {
243                let v = bs58::decode(v)
244                    .with_alphabet(bs58::Alphabet::BITCOIN)
245                    .into_vec()
246                    .context("bs58 decode")?;
247                arr.append_value(v);
248            }
249            None => arr.append_null(),
250        }
251    }
252
253    Ok(arr.finish())
254}
255
256pub fn hex_decode_column<const PREFIXED: bool>(col: &StringArray) -> Result<BinaryArray> {
257    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
258
259    for v in col.iter() {
260        match v {
261            // TODO: this should be optimized by removing allocations if needed
262            Some(v) => {
263                let v = v.as_bytes();
264                let v = if PREFIXED {
265                    v.get(2..).context("index into prefix hex encoded value")?
266                } else {
267                    v
268                };
269
270                let len = v.len();
271                let mut dst = vec![0; (len + 1) / 2];
272
273                faster_hex::hex_decode(v, &mut dst).context("hex decode")?;
274
275                arr.append_value(dst);
276            }
277            None => arr.append_null(),
278        }
279    }
280
281    Ok(arr.finish())
282}
283
284pub fn u256_column_from_binary(col: &BinaryArray) -> Result<Decimal256Array> {
285    let mut arr = builder::Decimal256Builder::with_capacity(col.len());
286
287    for v in col.iter() {
288        match v {
289            Some(v) => {
290                let num = ruint::aliases::U256::try_from_be_slice(v).context("parse ruint u256")?;
291                let num = alloy_primitives::I256::try_from(num)
292                    .with_context(|| format!("u256 to i256. val was {}", num))?;
293
294                let val = arrow::datatypes::i256::from_be_bytes(num.to_be_bytes::<32>());
295                arr.append_value(val);
296            }
297            None => arr.append_null(),
298        }
299    }
300
301    Ok(arr.with_precision_and_scale(76, 0).unwrap().finish())
302}
303
304pub fn u256_column_to_binary(col: &Decimal256Array) -> Result<BinaryArray> {
305    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
306
307    for v in col.iter() {
308        match v {
309            Some(v) => {
310                let num = alloy_primitives::I256::from_be_bytes::<32>(v.to_be_bytes());
311                let num = ruint::aliases::U256::try_from(num).context("convert i256 to u256")?;
312                arr.append_value(num.to_be_bytes_trimmed_vec());
313            }
314            None => {
315                arr.append_null();
316            }
317        }
318    }
319
320    Ok(arr.finish())
321}
322
323/// Converts all Decimal256 (U256) columns in the batch to big endian binary values
324pub fn u256_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
325    let schema = schema_decimal256_to_binary(data.schema_ref());
326    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
327
328    for (i, col) in data.columns().iter().enumerate() {
329        if col.data_type() == &DataType::Decimal256(76, 0) {
330            let col = col.as_any().downcast_ref::<Decimal256Array>().unwrap();
331            let x = u256_column_to_binary(col)
332                .with_context(|| format!("col {} to binary", data.schema().fields()[i].name()))?;
333            columns.push(Arc::new(x));
334        } else {
335            columns.push(col.clone());
336        }
337    }
338
339    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
340}