proof_of_sql_planner/
util.rs

1use super::{PlannerError, PlannerResult};
2use arrow::datatypes::{Field, Schema};
3use datafusion::{
4    catalog::TableReference,
5    common::{Column, ScalarValue},
6    logical_expr::expr::Placeholder,
7};
8use proof_of_sql::{
9    base::{
10        database::{ColumnField, ColumnRef, ColumnType, LiteralValue, TableRef},
11        math::decimal::Precision,
12        posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
13    },
14    sql::proof_exprs::DynProofExpr,
15};
16use sqlparser::ast::Ident;
17
/// Parse a placeholder string of the form "$1", "$2", etc. into a `usize`.
///
/// Returns `None` when the `$` prefix is missing, when anything other than
/// ASCII digits follows it, when the digits start with `0`, or when the
/// number does not fit into a `usize`.
fn parse_placeholder_id(s: &str) -> Option<usize> {
    let digits = s.strip_prefix('$')?;
    // Reject leading zeros and anything that is not a plain ASCII digit
    // (this also rules out signs such as `+`, which `parse` would accept).
    let well_formed = !digits.starts_with('0') && digits.bytes().all(|b| b.is_ascii_digit());
    if !well_formed {
        return None;
    }
    // An empty digit string or an overflowing value fails here.
    digits.parse().ok()
}
26
27/// Convert a datafusion [`Placeholder`] to a Proof of SQL [`PlaceholderExpr`]
28#[expect(clippy::missing_panics_doc, reason = "can not actually panic")]
29pub(crate) fn placeholder_to_placeholder_expr(
30    placeholder: &Placeholder,
31) -> PlannerResult<DynProofExpr> {
32    let df_id = placeholder.id.clone();
33    let df_type = placeholder.data_type.clone();
34    let posql_id = parse_placeholder_id(&df_id)
35        .ok_or_else(|| PlannerError::InvalidPlaceholderId { id: df_id.clone() })?;
36    let posql_type = df_type
37        .clone()
38        .ok_or(PlannerError::UntypedPlaceholder {
39            placeholder: placeholder.clone(),
40        })?
41        .try_into()
42        .map_err(|_| PlannerError::UnsupportedDataType {
43            data_type: df_type.clone().unwrap(),
44        })?;
45    Ok(DynProofExpr::try_new_placeholder(posql_id, posql_type)?)
46}
47
48/// Convert a [`TableReference`] to a [`TableRef`]
49///
50/// If catalog is provided it errors out
51pub(crate) fn table_reference_to_table_ref(table: &TableReference) -> PlannerResult<TableRef> {
52    match table {
53        TableReference::Bare { table } => Ok(TableRef::from_names(None, table)),
54        TableReference::Partial { schema, table } => Ok(TableRef::from_names(Some(schema), table)),
55        TableReference::Full { .. } => Err(PlannerError::CatalogNotSupported),
56    }
57}
58
59/// Convert a [`ScalarValue`] to a [`LiteralValue`]
60///
61/// TODO: add other types supported in `PoSQL`
62pub(crate) fn scalar_value_to_literal_value(value: ScalarValue) -> PlannerResult<LiteralValue> {
63    match value {
64        ScalarValue::Boolean(Some(v)) => Ok(LiteralValue::Boolean(v)),
65        ScalarValue::Int8(Some(v)) => Ok(LiteralValue::TinyInt(v)),
66        ScalarValue::Int16(Some(v)) => Ok(LiteralValue::SmallInt(v)),
67        ScalarValue::Int32(Some(v)) => Ok(LiteralValue::Int(v)),
68        ScalarValue::Int64(Some(v)) => Ok(LiteralValue::BigInt(v)),
69        ScalarValue::UInt8(Some(v)) => Ok(LiteralValue::Uint8(v)),
70        ScalarValue::Utf8(Some(v)) => Ok(LiteralValue::VarChar(v)),
71        ScalarValue::Binary(Some(v)) => Ok(LiteralValue::VarBinary(v)),
72        ScalarValue::TimestampSecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
73            PoSQLTimeUnit::Second,
74            PoSQLTimeZone::utc(),
75            v,
76        )),
77        ScalarValue::TimestampMillisecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
78            PoSQLTimeUnit::Millisecond,
79            PoSQLTimeZone::utc(),
80            v,
81        )),
82        ScalarValue::TimestampMicrosecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
83            PoSQLTimeUnit::Microsecond,
84            PoSQLTimeZone::utc(),
85            v,
86        )),
87        ScalarValue::TimestampNanosecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
88            PoSQLTimeUnit::Nanosecond,
89            PoSQLTimeZone::utc(),
90            v,
91        )),
92        ScalarValue::Decimal128(Some(v), precision, scale) => Ok(LiteralValue::Decimal75(
93            Precision::new(precision)?,
94            scale,
95            v.into(),
96        )),
97        ScalarValue::Decimal256(Some(v), precision, scale) => Ok(LiteralValue::Decimal75(
98            Precision::new(precision)?,
99            scale,
100            v.into(),
101        )),
102        _ => Err(PlannerError::UnsupportedDataType {
103            data_type: value.data_type().clone(),
104        }),
105    }
106}
107
108/// Find a column in a schema and return its info as a [`ColumnRef`]
109///
110/// Note that the table name must be provided in the column which resolved logical plans do
111/// Otherwise we error out
112pub(crate) fn column_to_column_ref(
113    column: &Column,
114    schema: &[(Ident, ColumnType)],
115) -> PlannerResult<ColumnRef> {
116    let relation = column
117        .relation
118        .as_ref()
119        .ok_or_else(|| PlannerError::UnresolvedLogicalPlan)?;
120    let table_ref = table_reference_to_table_ref(relation)?;
121    let ident: Ident = column.name.as_str().into();
122    let column_type = schema
123        .iter()
124        .find(|(i, _t)| *i == ident)
125        .ok_or(PlannerError::ColumnNotFound)?
126        .1;
127    Ok(ColumnRef::new(table_ref, ident, column_type))
128}
129
130/// Convert a Vec<ColumnField> to a Schema
131#[must_use]
132pub fn column_fields_to_schema(column_fields: Vec<ColumnField>) -> Schema {
133    Schema::new(
134        column_fields
135            .into_iter()
136            .map(|column_field| {
137                //TODO: Make columns nullable
138                let data_type = (&column_field.data_type()).into();
139                Field::new(column_field.name().value.as_str(), data_type, false)
140            })
141            .collect::<Vec<_>>(),
142    )
143}
144
145/// Convert a [`DFSchema`] to a Vec<ColumnField>
146///
147/// Note that this returns an error if any column has an unsupported `DataType`
148pub(crate) fn schema_to_column_fields(schema: Vec<(Ident, ColumnType)>) -> Vec<ColumnField> {
149    schema
150        .into_iter()
151        .map(|(name, column_type)| ColumnField::new(name, column_type))
152        .collect()
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{i256 as ArrowI256, DataType};
    use proof_of_sql::base::math::i256::I256;

    /// Build a datafusion placeholder with the given id and optional type.
    fn make_placeholder(id: &str, data_type: Option<DataType>) -> Placeholder {
        Placeholder {
            id: id.to_string(),
            data_type,
        }
    }

    /// Schema consisting of a single `Int` column named `a`.
    fn single_int_column_schema() -> Vec<(Ident, ColumnType)> {
        vec![("a".into(), ColumnType::Int)]
    }

    /// Assert that every input is rejected by `parse_placeholder_id`.
    fn assert_placeholder_ids_rejected(inputs: &[&str]) {
        for input in inputs {
            assert_eq!(parse_placeholder_id(input), None, "input: {input:?}");
        }
    }

    // parse_placeholder_id
    #[test]
    fn we_can_parse_valid_placeholder_id() {
        for (input, expected) in [("$1", 1_usize), ("$123", 123_usize)] {
            assert_eq!(
                parse_placeholder_id(input),
                Some(expected),
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn we_cannot_parse_placeholder_id_without_dollar_sign() {
        assert_placeholder_ids_rejected(&["", "1"]);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_empty_after_dollar_sign() {
        assert_placeholder_ids_rejected(&["$"]);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_non_digits() {
        assert_placeholder_ids_rejected(&["$abc", "$1x"]);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_leading_zero() {
        assert_placeholder_ids_rejected(&["$0", "$01"]);
    }

    // placeholder_to_placeholder_expr
    #[test]
    fn we_can_convert_valid_placeholder_to_placeholder_expr() {
        let placeholder = make_placeholder("$42", Some(DataType::Int32));
        let actual = placeholder_to_placeholder_expr(&placeholder).unwrap();
        let expected = DynProofExpr::try_new_placeholder(42, ColumnType::Int).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn we_cannot_convert_placeholder_without_type() {
        let outcome = placeholder_to_placeholder_expr(&make_placeholder("$1", None));
        assert!(matches!(
            outcome,
            Err(PlannerError::UntypedPlaceholder { .. })
        ));
    }

    #[test]
    fn we_cannot_convert_placeholder_with_invalid_id() {
        // "$0" is rejected because of its leading zero
        let outcome =
            placeholder_to_placeholder_expr(&make_placeholder("$0", Some(DataType::Int32)));
        assert!(matches!(
            outcome,
            Err(PlannerError::InvalidPlaceholderId { .. })
        ));
    }

    // TableReference to TableRef
    #[test]
    fn we_can_convert_table_reference_to_table_ref() {
        let bare = TableReference::bare("table");
        assert_eq!(
            table_reference_to_table_ref(&bare).unwrap(),
            TableRef::from_names(None, "table")
        );

        let partial = TableReference::partial("schema", "table");
        assert_eq!(
            table_reference_to_table_ref(&partial).unwrap(),
            TableRef::from_names(Some("schema"), "table")
        );
    }

    #[test]
    fn we_cannot_convert_full_table_reference_to_table_ref() {
        let full = TableReference::full("catalog", "schema", "table");
        let outcome = table_reference_to_table_ref(&full);
        assert!(matches!(outcome, Err(PlannerError::CatalogNotSupported)));
    }

    // ScalarValue to LiteralValue
    #[test]
    fn we_can_convert_scalar_value_to_literal_value() {
        let hallelujah = vec![72, 97, 108, 108, 101, 108, 117, 106, 97, 104];
        let utc = |unit, ticks| LiteralValue::TimeStampTZ(unit, PoSQLTimeZone::utc(), ticks);

        let cases = vec![
            (
                ScalarValue::Boolean(Some(true)),
                LiteralValue::Boolean(true),
            ),
            (ScalarValue::Int8(Some(1)), LiteralValue::TinyInt(1)),
            (ScalarValue::Int16(Some(1)), LiteralValue::SmallInt(1)),
            (ScalarValue::Int32(Some(1)), LiteralValue::Int(1)),
            (ScalarValue::Int64(Some(1)), LiteralValue::BigInt(1)),
            (ScalarValue::UInt8(Some(1)), LiteralValue::Uint8(1)),
            (
                ScalarValue::Utf8(Some("value".to_string())),
                LiteralValue::VarChar("value".to_string()),
            ),
            (
                ScalarValue::Binary(Some(hallelujah.clone())),
                LiteralValue::VarBinary(hallelujah),
            ),
            // Thu Mar 06 2025 04:43:12 GMT+0000
            (
                ScalarValue::TimestampSecond(Some(1_741_236_192_i64), None),
                utc(PoSQLTimeUnit::Second, 1_741_236_192_i64),
            ),
            (
                ScalarValue::TimestampMillisecond(Some(1_741_236_192_004_i64), None),
                utc(PoSQLTimeUnit::Millisecond, 1_741_236_192_004_i64),
            ),
            (
                ScalarValue::TimestampMicrosecond(Some(1_741_236_192_004_000_i64), None),
                utc(PoSQLTimeUnit::Microsecond, 1_741_236_192_004_000_i64),
            ),
            (
                ScalarValue::TimestampNanosecond(Some(1_741_236_192_123_456_789_i64), None),
                utc(PoSQLTimeUnit::Nanosecond, 1_741_236_192_123_456_789_i64),
            ),
        ];

        for (scalar, expected) in cases {
            assert_eq!(scalar_value_to_literal_value(scalar).unwrap(), expected);
        }
    }

    #[expect(clippy::cast_sign_loss)]
    #[test]
    fn we_can_convert_scalar_value_to_literal_value_for_decimals() {
        let decimal75 = |precision, scale, value| {
            LiteralValue::Decimal75(Precision::new(precision).unwrap(), scale, value)
        };

        let cases = vec![
            // Decimal128, including its extreme values and zero
            (
                ScalarValue::Decimal128(Some(123), 38, 0),
                decimal75(38, 0, I256::from(123i128)),
            ),
            (
                ScalarValue::Decimal128(Some(i128::MIN), 38, 10),
                decimal75(38, 10, I256::from(i128::MIN)),
            ),
            (
                ScalarValue::Decimal128(Some(i128::MAX), 28, -5),
                decimal75(28, -5, I256::from(i128::MAX)),
            ),
            (
                ScalarValue::Decimal128(Some(0), 38, 0),
                decimal75(38, 0, I256::from(0i128)),
            ),
            // Decimal256, including its extreme values and zero
            (
                ScalarValue::Decimal256(Some(ArrowI256::from_i128(-456)), 75, 120),
                decimal75(75, 120, I256::from(-456i128)),
            ),
            (
                ScalarValue::Decimal256(Some(ArrowI256::MIN), 75, 127),
                decimal75(75, 127, I256::new([0, 0, 0, i64::MIN as u64])),
            ),
            (
                ScalarValue::Decimal256(Some(ArrowI256::MAX), 75, -128),
                decimal75(
                    75,
                    -128,
                    I256::new([u64::MAX, u64::MAX, u64::MAX, i64::MAX as u64]),
                ),
            ),
            (
                ScalarValue::Decimal256(Some(ArrowI256::ZERO), 75, 0),
                decimal75(75, 0, I256::from(0i128)),
            ),
        ];

        for (scalar, expected) in cases {
            assert_eq!(scalar_value_to_literal_value(scalar).unwrap(), expected);
        }
    }

    #[test]
    fn we_cannot_convert_scalar_value_to_literal_value_if_unsupported() {
        let outcome = scalar_value_to_literal_value(ScalarValue::Float32(Some(1.0)));
        assert!(matches!(
            outcome,
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    // Column to ColumnRef
    #[test]
    fn we_can_convert_column_to_column_ref() {
        let column = Column::new(Some("namespace.table"), "a");
        let expected = ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table"),
            "a".into(),
            ColumnType::Int,
        );
        assert_eq!(
            column_to_column_ref(&column, &single_int_column_schema()).unwrap(),
            expected
        );
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_without_relation() {
        let column = Column::new(None::<&str>, "a");
        let outcome = column_to_column_ref(&column, &single_int_column_schema());
        assert!(matches!(outcome, Err(PlannerError::UnresolvedLogicalPlan)));
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_with_invalid_column_name() {
        let column = Column::new(Some("namespace.table"), "b");
        let outcome = column_to_column_ref(&column, &single_int_column_schema());
        assert!(matches!(outcome, Err(PlannerError::ColumnNotFound)));
    }

    // ColumnFields to Schema
    #[test]
    fn we_can_convert_column_fields_to_schema() {
        // Empty
        let empty_schema = column_fields_to_schema(Vec::new());
        assert_eq!(empty_schema.all_fields(), Vec::<&Field>::new());

        // Non-empty
        let schema = column_fields_to_schema(vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ]);
        let field_a = Field::new("a", DataType::Int16, false);
        let field_b = Field::new("b", DataType::Utf8, false);
        assert_eq!(schema.all_fields(), vec![&field_a, &field_b]);
    }

    // (name, type) pairs to Vec<ColumnField>
    #[test]
    fn we_can_convert_df_schema_to_column_fields() {
        // Empty
        assert_eq!(
            schema_to_column_fields(Vec::new()),
            Vec::<ColumnField>::new()
        );

        // Non-empty
        let pairs = vec![
            ("a".into(), ColumnType::SmallInt),
            ("b".into(), ColumnType::VarChar),
        ];
        let expected = vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ];
        assert_eq!(schema_to_column_fields(pairs), expected);
    }
}