cherry_cast/
lib.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use arrow::{
5    array::{builder, Array, BinaryArray, Decimal256Array, RecordBatch, StringArray},
6    compute::CastOptions,
7    datatypes::{DataType, Field, Schema},
8};
9use ruint::aliases::U256;
10
11/// Casts columns according to given (column name, target data type) pairs.
12///
13/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
14/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
15pub fn cast<S: AsRef<str>>(
16    map: &[(S, DataType)],
17    data: &RecordBatch,
18    allow_cast_fail: bool,
19) -> Result<RecordBatch> {
20    let schema = cast_schema(map, data.schema_ref()).context("cast schema")?;
21
22    let mut arrays = Vec::with_capacity(data.num_columns());
23
24    let cast_opt = CastOptions {
25        safe: !allow_cast_fail,
26        ..Default::default()
27    };
28
29    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
30        let cast_target = map.iter().find(|x| x.0.as_ref() == field.name());
31
32        let col = match cast_target {
33            Some(tgt) => Arc::new(
34                arrow::compute::cast_with_options(col, &tgt.1, &cast_opt)
35                    .with_context(|| format!("Failed when casting column '{}'", field.name()))?,
36            ),
37            None => col.clone(),
38        };
39
40        arrays.push(col);
41    }
42
43    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
44
45    Ok(batch)
46}
47
48/// Casts column types according to given (column name, target data type) pairs.
49pub fn cast_schema<S: AsRef<str>>(map: &[(S, DataType)], schema: &Schema) -> Result<Schema> {
50    let mut fields = schema.fields().to_vec();
51
52    for f in fields.iter_mut() {
53        let cast_target = map.iter().find(|x| x.0.as_ref() == f.name());
54
55        if let Some(tgt) = cast_target {
56            *f = Arc::new(Field::new(f.name(), tgt.1.clone(), f.is_nullable()));
57        }
58    }
59
60    Ok(Schema::new(fields))
61}
62
63pub fn base58_encode(data: &RecordBatch) -> Result<RecordBatch> {
64    let schema = schema_binary_to_string(data.schema_ref());
65    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
66
67    for col in data.columns().iter() {
68        if col.data_type() == &DataType::Binary {
69            columns.push(Arc::new(base58_encode_column(
70                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
71            )));
72        } else {
73            columns.push(col.clone());
74        }
75    }
76
77    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
78}
79
80pub fn base58_encode_column(col: &BinaryArray) -> StringArray {
81    let mut arr =
82        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
83
84    for v in col.iter() {
85        match v {
86            Some(v) => {
87                let v = bs58::encode(v)
88                    .with_alphabet(bs58::Alphabet::BITCOIN)
89                    .into_string();
90                arr.append_value(v);
91            }
92            None => arr.append_null(),
93        }
94    }
95
96    arr.finish()
97}
98
99pub fn hex_encode<const PREFIXED: bool>(data: &RecordBatch) -> Result<RecordBatch> {
100    let schema = schema_binary_to_string(data.schema_ref());
101    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
102
103    for col in data.columns().iter() {
104        if col.data_type() == &DataType::Binary {
105            columns.push(Arc::new(hex_encode_column::<PREFIXED>(
106                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
107            )));
108        } else {
109            columns.push(col.clone());
110        }
111    }
112
113    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
114}
115
116pub fn hex_encode_column<const PREFIXED: bool>(col: &BinaryArray) -> StringArray {
117    let mut arr =
118        builder::StringBuilder::with_capacity(col.len(), (col.value_data().len() + 2) * 2);
119
120    for v in col.iter() {
121        match v {
122            Some(v) => {
123                // TODO: avoid allocation here and use a scratch buffer to encode hex into or write to arrow buffer
124                // directly somehow.
125                let v = if PREFIXED {
126                    format!("0x{}", faster_hex::hex_string(v))
127                } else {
128                    faster_hex::hex_string(v)
129                };
130
131                arr.append_value(v);
132            }
133            None => arr.append_null(),
134        }
135    }
136
137    arr.finish()
138}
139
140/// Converts binary fields to string in the schema
141///
142/// Intended to be used with encode hex functions
143pub fn schema_binary_to_string(schema: &Schema) -> Schema {
144    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
145
146    for f in schema.fields().iter() {
147        if f.data_type() == &DataType::Binary {
148            fields.push(Arc::new(Field::new(
149                f.name().clone(),
150                DataType::Utf8,
151                f.is_nullable(),
152            )));
153        } else {
154            fields.push(f.clone());
155        }
156    }
157
158    Schema::new(fields)
159}
160
161/// Converts decimal256 fields to binary in the schema
162///
163/// Intended to be used with u256_to_binary function
164pub fn schema_decimal256_to_binary(schema: &Schema) -> Schema {
165    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
166
167    for f in schema.fields().iter() {
168        if f.data_type() == &DataType::Decimal256(76, 0) {
169            fields.push(Arc::new(Field::new(
170                f.name().clone(),
171                DataType::Binary,
172                f.is_nullable(),
173            )));
174        } else {
175            fields.push(f.clone());
176        }
177    }
178
179    Schema::new(fields)
180}
181
182pub fn base58_decode_column(col: &StringArray) -> Result<BinaryArray> {
183    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
184
185    for v in col.iter() {
186        match v {
187            // TODO: this should be optimized by removing allocations if needed
188            Some(v) => {
189                let v = bs58::decode(v)
190                    .with_alphabet(bs58::Alphabet::BITCOIN)
191                    .into_vec()
192                    .context("bs58 decode")?;
193                arr.append_value(v);
194            }
195            None => arr.append_null(),
196        }
197    }
198
199    Ok(arr.finish())
200}
201
202pub fn hex_decode_column<const PREFIXED: bool>(col: &StringArray) -> Result<BinaryArray> {
203    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.value_data().len() / 2);
204
205    for v in col.iter() {
206        match v {
207            // TODO: this should be optimized by removing allocations if needed
208            Some(v) => {
209                let v = v.as_bytes();
210                let v = if PREFIXED {
211                    v.get(2..).context("index into prefix hex encoded value")?
212                } else {
213                    v
214                };
215
216                let len = v.len();
217                let mut dst = vec![0; (len + 1) / 2];
218
219                faster_hex::hex_decode(v, &mut dst).context("hex decode")?;
220
221                arr.append_value(dst);
222            }
223            None => arr.append_null(),
224        }
225    }
226
227    Ok(arr.finish())
228}
229
230pub fn u256_column_from_binary(col: &BinaryArray) -> Result<Decimal256Array> {
231    let mut arr = builder::Decimal256Builder::with_capacity(col.len());
232
233    for v in col.iter() {
234        match v {
235            Some(v) => {
236                let num = U256::try_from_be_slice(v).context("parse u256")?;
237                let num = arrow::datatypes::i256::from_be_bytes(num.to_be_bytes::<32>());
238                arr.append_value(num);
239            }
240            None => arr.append_null(),
241        }
242    }
243
244    Ok(arr.with_precision_and_scale(76, 0).unwrap().finish())
245}
246
247pub fn u256_column_to_binary(col: &Decimal256Array) -> BinaryArray {
248    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
249
250    for v in col.iter() {
251        match v {
252            Some(v) => {
253                let num = U256::from_be_bytes::<32>(v.to_be_bytes());
254                arr.append_value(num.to_be_bytes_trimmed_vec());
255            }
256            None => {
257                arr.append_null();
258            }
259        }
260    }
261
262    arr.finish()
263}
264
265/// Converts all Decimal256 (U256) columns in the batch to big endian binary values
266pub fn u256_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
267    let schema = schema_binary_to_string(data.schema_ref());
268    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
269
270    for col in data.columns().iter() {
271        if col.data_type() == &DataType::Decimal256(76, 0) {
272            let mut arr = builder::BinaryBuilder::new();
273
274            let col = col.as_any().downcast_ref::<Decimal256Array>().unwrap();
275
276            for val in col.iter() {
277                arr.append_option(val.map(|v| v.to_be_bytes()));
278            }
279
280            columns.push(Arc::new(arr.finish()));
281        } else {
282            columns.push(col.clone());
283        }
284    }
285
286    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
287}