proof_of_sql_planner/
util.rs

1use super::{PlannerError, PlannerResult};
2use arrow::datatypes::{Field, Schema};
3use datafusion::{
4    catalog::TableReference,
5    common::{Column, ScalarValue},
6    logical_expr::expr::Placeholder,
7};
8use proof_of_sql::{
9    base::{
10        database::{ColumnField, ColumnRef, ColumnType, LiteralValue, TableRef},
11        math::decimal::Precision,
12        posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
13    },
14    sql::proof_exprs::DynProofExpr,
15};
16use sqlparser::ast::Ident;
17
/// Parse a placeholder id of the form `$1`, `$2`, ... into a `usize`.
///
/// Returns `None` unless the input is a `$` followed by ASCII digits with no leading zero
/// (so `$0` and `$01` are rejected) whose value fits in a `usize`.
fn parse_placeholder_id(s: &str) -> Option<usize> {
    let digits = s.strip_prefix('$')?;
    // Only plain ASCII digits are allowed; this also rejects a sign such as `$+1`,
    // which `str::parse` would otherwise accept.
    let well_formed = digits.bytes().all(|b| b.is_ascii_digit()) && !digits.starts_with('0');
    if !well_formed {
        return None;
    }
    // An empty digit string or an out-of-range value fails to parse and yields `None`.
    digits.parse().ok()
}
26
27/// Convert a datafusion [`Placeholder`] to a Proof of SQL [`PlaceholderExpr`]
28#[expect(clippy::missing_panics_doc, reason = "can not actually panic")]
29pub(crate) fn placeholder_to_placeholder_expr(
30    placeholder: &Placeholder,
31) -> PlannerResult<DynProofExpr> {
32    let df_id = placeholder.id.clone();
33    let df_type = placeholder.data_type.clone();
34    let posql_id = parse_placeholder_id(&df_id)
35        .ok_or_else(|| PlannerError::InvalidPlaceholderId { id: df_id.clone() })?;
36    let posql_type = df_type
37        .clone()
38        .ok_or(PlannerError::UntypedPlaceholder {
39            placeholder: placeholder.clone(),
40        })?
41        .try_into()
42        .map_err(|_| PlannerError::UnsupportedDataType {
43            data_type: df_type.clone().unwrap(),
44        })?;
45    Ok(DynProofExpr::try_new_placeholder(posql_id, posql_type)?)
46}
47
48/// Convert a [`TableReference`] to a [`TableRef`]
49///
50/// If catalog is provided it errors out
51pub(crate) fn table_reference_to_table_ref(table: &TableReference) -> PlannerResult<TableRef> {
52    match table {
53        TableReference::Bare { table } => Ok(TableRef::from_names(None, table)),
54        TableReference::Partial { schema, table } => Ok(TableRef::from_names(Some(schema), table)),
55        TableReference::Full { .. } => Err(PlannerError::CatalogNotSupported),
56    }
57}
58
59/// Convert a [`ScalarValue`] to a [`LiteralValue`]
60///
61/// TODO: add other types supported in `PoSQL`
62pub(crate) fn scalar_value_to_literal_value(value: ScalarValue) -> PlannerResult<LiteralValue> {
63    match value {
64        ScalarValue::Boolean(Some(v)) => Ok(LiteralValue::Boolean(v)),
65        ScalarValue::Int8(Some(v)) => Ok(LiteralValue::TinyInt(v)),
66        ScalarValue::Int16(Some(v)) => Ok(LiteralValue::SmallInt(v)),
67        ScalarValue::Int32(Some(v)) => Ok(LiteralValue::Int(v)),
68        ScalarValue::Int64(Some(v)) => Ok(LiteralValue::BigInt(v)),
69        ScalarValue::UInt8(Some(v)) => Ok(LiteralValue::Uint8(v)),
70        ScalarValue::Utf8(Some(v)) => Ok(LiteralValue::VarChar(v)),
71        ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => {
72            Ok(LiteralValue::VarBinary(v))
73        }
74        ScalarValue::TimestampSecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
75            PoSQLTimeUnit::Second,
76            PoSQLTimeZone::utc(),
77            v,
78        )),
79        ScalarValue::TimestampMillisecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
80            PoSQLTimeUnit::Millisecond,
81            PoSQLTimeZone::utc(),
82            v,
83        )),
84        ScalarValue::TimestampMicrosecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
85            PoSQLTimeUnit::Microsecond,
86            PoSQLTimeZone::utc(),
87            v,
88        )),
89        ScalarValue::TimestampNanosecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
90            PoSQLTimeUnit::Nanosecond,
91            PoSQLTimeZone::utc(),
92            v,
93        )),
94        ScalarValue::Decimal128(Some(v), precision, scale) => Ok(LiteralValue::Decimal75(
95            Precision::new(precision)?,
96            scale,
97            v.into(),
98        )),
99        ScalarValue::Decimal256(Some(v), precision, scale) => Ok(LiteralValue::Decimal75(
100            Precision::new(precision)?,
101            scale,
102            v.into(),
103        )),
104        _ => Err(PlannerError::UnsupportedDataType {
105            data_type: value.data_type().clone(),
106        }),
107    }
108}
109
110/// Find a column in a schema and return its info as a [`ColumnRef`]
111///
112/// Note that the table name must be provided in the column which resolved logical plans do
113/// Otherwise we error out
114pub(crate) fn column_to_column_ref(
115    column: &Column,
116    schema: &[(Ident, ColumnType)],
117) -> PlannerResult<ColumnRef> {
118    let relation = column
119        .relation
120        .as_ref()
121        .ok_or_else(|| PlannerError::UnresolvedLogicalPlan)?;
122    let table_ref = table_reference_to_table_ref(relation)?;
123    let ident: Ident = column.name.as_str().into();
124    let column_type = schema
125        .iter()
126        .find(|(i, _t)| *i == ident)
127        .ok_or(PlannerError::ColumnNotFound)?
128        .1;
129    Ok(ColumnRef::new(table_ref, ident, column_type))
130}
131
132/// Convert a Vec<ColumnField> to a Schema
133#[must_use]
134pub fn column_fields_to_schema(column_fields: Vec<ColumnField>) -> Schema {
135    Schema::new(
136        column_fields
137            .into_iter()
138            .map(|column_field| {
139                //TODO: Make columns nullable
140                let data_type = (&column_field.data_type()).into();
141                Field::new(column_field.name().value.as_str(), data_type, false)
142            })
143            .collect::<Vec<_>>(),
144    )
145}
146
147/// Convert a [`DFSchema`] to a Vec<ColumnField>
148///
149/// Note that this returns an error if any column has an unsupported `DataType`
150pub(crate) fn schema_to_column_fields(schema: Vec<(Ident, ColumnType)>) -> Vec<ColumnField> {
151    schema
152        .into_iter()
153        .map(|(name, column_type)| ColumnField::new(name, column_type))
154        .collect()
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;

    // parse_placeholder_id
    #[test]
    fn we_can_parse_valid_placeholder_id() {
        // "$1" => Some(1)
        assert_eq!(parse_placeholder_id("$1"), Some(1));
        // "$123" => Some(123)
        assert_eq!(parse_placeholder_id("$123"), Some(123));
    }

    #[test]
    fn we_cannot_parse_placeholder_id_without_dollar_sign() {
        // "" => None
        assert_eq!(parse_placeholder_id(""), None);
        // "1" => None
        assert_eq!(parse_placeholder_id("1"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_empty_after_dollar_sign() {
        // "$" => None
        assert_eq!(parse_placeholder_id("$"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_non_digits() {
        // "$abc" => None
        assert_eq!(parse_placeholder_id("$abc"), None);
        // "$1x" => None
        assert_eq!(parse_placeholder_id("$1x"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_leading_zero() {
        // "$0" => None
        assert_eq!(parse_placeholder_id("$0"), None);
        // "$01" => None
        assert_eq!(parse_placeholder_id("$01"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_sign_or_overflow() {
        // "$+1" => None (`str::parse` alone would accept a leading '+')
        assert_eq!(parse_placeholder_id("$+1"), None);
        // "$-1" => None
        assert_eq!(parse_placeholder_id("$-1"), None);
        // A value that does not fit in a `usize` => None
        assert_eq!(parse_placeholder_id("$99999999999999999999999999"), None);
    }

    // placeholder_to_placeholder_expr
    #[test]
    fn we_can_convert_valid_placeholder_to_placeholder_expr() {
        let placeholder = Placeholder {
            id: "$42".to_string(),
            data_type: Some(DataType::Int32),
        };
        let expected = DynProofExpr::try_new_placeholder(42, ColumnType::Int).unwrap();
        let result = placeholder_to_placeholder_expr(&placeholder).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn we_cannot_convert_placeholder_without_type() {
        let placeholder = Placeholder {
            id: "$1".to_string(),
            data_type: None,
        };
        assert!(matches!(
            placeholder_to_placeholder_expr(&placeholder),
            Err(PlannerError::UntypedPlaceholder { .. })
        ));
    }

    #[test]
    fn we_cannot_convert_placeholder_with_invalid_id() {
        let placeholder = Placeholder {
            // Something invalid like "$0" or "1"
            id: "$0".to_string(),
            data_type: Some(DataType::Int32),
        };
        assert!(matches!(
            placeholder_to_placeholder_expr(&placeholder),
            Err(PlannerError::InvalidPlaceholderId { .. })
        ));
    }

    #[test]
    fn we_cannot_convert_placeholder_with_unsupported_type() {
        // Floating point types have no Proof of SQL column type
        let placeholder = Placeholder {
            id: "$1".to_string(),
            data_type: Some(DataType::Float32),
        };
        assert!(matches!(
            placeholder_to_placeholder_expr(&placeholder),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    // TableReference to TableRef
    #[test]
    fn we_can_convert_table_reference_to_table_ref() {
        // Bare
        let table = TableReference::bare("table");
        assert_eq!(
            table_reference_to_table_ref(&table).unwrap(),
            TableRef::from_names(None, "table")
        );

        // Partial
        let table = TableReference::partial("schema", "table");
        assert_eq!(
            table_reference_to_table_ref(&table).unwrap(),
            TableRef::from_names(Some("schema"), "table")
        );
    }

    #[test]
    fn we_cannot_convert_full_table_reference_to_table_ref() {
        let table = TableReference::full("catalog", "schema", "table");
        assert!(matches!(
            table_reference_to_table_ref(&table),
            Err(PlannerError::CatalogNotSupported)
        ));
    }

    // ScalarValue to LiteralValue
    #[test]
    fn we_can_convert_scalar_value_to_literal_value() {
        // Boolean
        let value = ScalarValue::Boolean(Some(true));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Boolean(true)
        );

        // Int8
        let value = ScalarValue::Int8(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TinyInt(1)
        );

        // Int16
        let value = ScalarValue::Int16(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::SmallInt(1)
        );

        // Int32
        let value = ScalarValue::Int32(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Int(1)
        );

        // Int64
        let value = ScalarValue::Int64(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::BigInt(1)
        );

        // UInt8
        let value = ScalarValue::UInt8(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Uint8(1)
        );

        // Utf8
        let value = ScalarValue::Utf8(Some("value".to_string()));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::VarChar("value".to_string())
        );

        // Binary
        let value = ScalarValue::Binary(Some(vec![72, 97, 108, 108, 101, 108, 117, 106, 97, 104]));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::VarBinary(vec![72, 97, 108, 108, 101, 108, 117, 106, 97, 104])
        );

        // LargeBinary
        let value = ScalarValue::LargeBinary(Some(vec![1, 2, 3]));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::VarBinary(vec![1, 2, 3])
        );

        // TimestampSecond
        // Thu Mar 06 2025 04:43:12 GMT+0000
        let value = ScalarValue::TimestampSecond(Some(1_741_236_192_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Second,
                PoSQLTimeZone::utc(),
                1_741_236_192_i64
            )
        );

        // TimestampMillisecond
        let value = ScalarValue::TimestampMillisecond(Some(1_741_236_192_004_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Millisecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_004_i64
            )
        );

        // TimestampMicrosecond
        let value = ScalarValue::TimestampMicrosecond(Some(1_741_236_192_004_000_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Microsecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_004_000_i64
            )
        );

        // TimestampNanosecond
        let value = ScalarValue::TimestampNanosecond(Some(1_741_236_192_123_456_789_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Nanosecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_123_456_789_i64
            )
        );
    }

    #[expect(clippy::cast_sign_loss)]
    #[test]
    fn we_can_convert_scalar_value_to_literal_value_for_decimals() {
        // Decimal128
        let value = ScalarValue::Decimal128(Some(123), 38, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(123i128)
            )
        );

        // Test edge cases for Decimal128
        let value = ScalarValue::Decimal128(Some(i128::MIN), 38, 10);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                10,
                proof_of_sql::base::math::i256::I256::from(i128::MIN)
            )
        );

        let value = ScalarValue::Decimal128(Some(i128::MAX), 28, -5);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(28).unwrap(),
                -5,
                proof_of_sql::base::math::i256::I256::from(i128::MAX)
            )
        );

        let value = ScalarValue::Decimal128(Some(0), 38, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(0i128)
            )
        );

        // Decimal256
        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::from_i128(-456)), 75, 120);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                120,
                proof_of_sql::base::math::i256::I256::from(-456i128)
            )
        );

        // Test edge cases for Decimal256
        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::MIN), 75, 127);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                127,
                proof_of_sql::base::math::i256::I256::new([0, 0, 0, i64::MIN as u64])
            )
        );
        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::MAX), 75, -128);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                -128,
                proof_of_sql::base::math::i256::I256::new([
                    u64::MAX,
                    u64::MAX,
                    u64::MAX,
                    i64::MAX as u64
                ])
            )
        );
        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::ZERO), 75, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(0i128)
            )
        );
    }

    #[test]
    fn we_cannot_convert_scalar_value_to_literal_value_if_unsupported() {
        // Unsupported
        let value = ScalarValue::Float32(Some(1.0));
        assert!(matches!(
            scalar_value_to_literal_value(value),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    #[test]
    fn we_cannot_convert_null_scalar_value_to_literal_value() {
        // Null values of otherwise supported types are rejected
        for value in [
            ScalarValue::Boolean(None),
            ScalarValue::Int32(None),
            ScalarValue::Utf8(None),
            ScalarValue::Decimal128(None, 38, 0),
        ] {
            assert!(matches!(
                scalar_value_to_literal_value(value),
                Err(PlannerError::UnsupportedDataType { .. })
            ));
        }
    }

    // Column to ColumnRef
    #[test]
    fn we_can_convert_column_to_column_ref() {
        let column = Column::new(Some("namespace.table"), "a");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert_eq!(
            column_to_column_ref(&column, &schema).unwrap(),
            ColumnRef::new(
                TableRef::from_names(Some("namespace"), "table"),
                "a".into(),
                ColumnType::Int
            )
        );
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_without_relation() {
        let column = Column::new(None::<&str>, "a");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert!(matches!(
            column_to_column_ref(&column, &schema),
            Err(PlannerError::UnresolvedLogicalPlan)
        ));
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_with_invalid_column_name() {
        let column = Column::new(Some("namespace.table"), "b");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert!(matches!(
            column_to_column_ref(&column, &schema),
            Err(PlannerError::ColumnNotFound)
        ));
    }

    // ColumnFields to Schema
    #[test]
    fn we_can_convert_column_fields_to_schema() {
        // Empty
        let column_fields = vec![];
        let schema = column_fields_to_schema(column_fields);
        assert_eq!(schema.all_fields(), Vec::<&Field>::new());

        // Non-empty
        let column_fields = vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ];
        let schema = column_fields_to_schema(column_fields);
        assert_eq!(
            schema.all_fields(),
            vec![
                &Field::new("a", DataType::Int16, false),
                &Field::new("b", DataType::Utf8, false),
            ]
        );
    }

    // (Ident, ColumnType) pairs to Vec<ColumnField>
    #[test]
    fn we_can_convert_df_schema_to_column_fields() {
        // Empty
        let column_fields = schema_to_column_fields(Vec::new());
        assert_eq!(column_fields, Vec::<ColumnField>::new());

        // Non-empty
        let schema = vec![
            ("a".into(), ColumnType::SmallInt),
            ("b".into(), ColumnType::VarChar),
        ];
        let column_fields = schema_to_column_fields(schema);
        assert_eq!(
            column_fields,
            vec![
                ColumnField::new("a".into(), ColumnType::SmallInt),
                ColumnField::new("b".into(), ColumnType::VarChar),
            ]
        );
    }
}