1use super::{PlannerError, PlannerResult};
2use arrow::datatypes::{Field, Schema};
3use datafusion::{
4 catalog::TableReference,
5 common::{Column, ScalarValue},
6 logical_expr::expr::Placeholder,
7};
8use proof_of_sql::{
9 base::{
10 database::{ColumnField, ColumnRef, ColumnType, LiteralValue, TableRef},
11 math::decimal::Precision,
12 posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
13 },
14 sql::proof_exprs::DynProofExpr,
15};
16use sqlparser::ast::Ident;
17
/// Parses a placeholder identifier of the form `$N` into `N`.
///
/// `N` must be a non-empty run of ASCII digits with no leading zero (so `$0` and
/// `$01` are rejected) that fits in a `usize`. Any other shape yields `None`.
fn parse_placeholder_id(s: &str) -> Option<usize> {
    let digits = s.strip_prefix('$')?;
    // Only plain ASCII digits are allowed; this also rules out signs such as `+`
    // that `str::parse` would otherwise accept.
    let is_canonical = !digits.starts_with('0') && digits.bytes().all(|b| b.is_ascii_digit());
    if is_canonical {
        // An empty `digits` or a value too large for `usize` fails here.
        digits.parse::<usize>().ok()
    } else {
        None
    }
}
26
27#[expect(clippy::missing_panics_doc, reason = "can not actually panic")]
29pub(crate) fn placeholder_to_placeholder_expr(
30 placeholder: &Placeholder,
31) -> PlannerResult<DynProofExpr> {
32 let df_id = placeholder.id.clone();
33 let df_type = placeholder.data_type.clone();
34 let posql_id = parse_placeholder_id(&df_id)
35 .ok_or_else(|| PlannerError::InvalidPlaceholderId { id: df_id.clone() })?;
36 let posql_type = df_type
37 .clone()
38 .ok_or(PlannerError::UntypedPlaceholder {
39 placeholder: placeholder.clone(),
40 })?
41 .try_into()
42 .map_err(|_| PlannerError::UnsupportedDataType {
43 data_type: df_type.clone().unwrap(),
44 })?;
45 Ok(DynProofExpr::try_new_placeholder(posql_id, posql_type)?)
46}
47
48pub(crate) fn table_reference_to_table_ref(table: &TableReference) -> PlannerResult<TableRef> {
52 match table {
53 TableReference::Bare { table } => Ok(TableRef::from_names(None, table)),
54 TableReference::Partial { schema, table } => Ok(TableRef::from_names(Some(schema), table)),
55 TableReference::Full { .. } => Err(PlannerError::CatalogNotSupported),
56 }
57}
58
59pub(crate) fn scalar_value_to_literal_value(value: ScalarValue) -> PlannerResult<LiteralValue> {
63 match value {
64 ScalarValue::Boolean(Some(v)) => Ok(LiteralValue::Boolean(v)),
65 ScalarValue::Int8(Some(v)) => Ok(LiteralValue::TinyInt(v)),
66 ScalarValue::Int16(Some(v)) => Ok(LiteralValue::SmallInt(v)),
67 ScalarValue::Int32(Some(v)) => Ok(LiteralValue::Int(v)),
68 ScalarValue::Int64(Some(v)) => Ok(LiteralValue::BigInt(v)),
69 ScalarValue::UInt8(Some(v)) => Ok(LiteralValue::Uint8(v)),
70 ScalarValue::Utf8(Some(v)) => Ok(LiteralValue::VarChar(v)),
71 ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => {
72 Ok(LiteralValue::VarBinary(v))
73 }
74 ScalarValue::TimestampSecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
75 PoSQLTimeUnit::Second,
76 PoSQLTimeZone::utc(),
77 v,
78 )),
79 ScalarValue::TimestampMillisecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
80 PoSQLTimeUnit::Millisecond,
81 PoSQLTimeZone::utc(),
82 v,
83 )),
84 ScalarValue::TimestampMicrosecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
85 PoSQLTimeUnit::Microsecond,
86 PoSQLTimeZone::utc(),
87 v,
88 )),
89 ScalarValue::TimestampNanosecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
90 PoSQLTimeUnit::Nanosecond,
91 PoSQLTimeZone::utc(),
92 v,
93 )),
94 ScalarValue::Decimal128(Some(v), precision, scale) => Ok(LiteralValue::Decimal75(
95 Precision::new(precision)?,
96 scale,
97 v.into(),
98 )),
99 ScalarValue::Decimal256(Some(v), precision, scale) => Ok(LiteralValue::Decimal75(
100 Precision::new(precision)?,
101 scale,
102 v.into(),
103 )),
104 _ => Err(PlannerError::UnsupportedDataType {
105 data_type: value.data_type().clone(),
106 }),
107 }
108}
109
110pub(crate) fn column_to_column_ref(
115 column: &Column,
116 schema: &[(Ident, ColumnType)],
117) -> PlannerResult<ColumnRef> {
118 let relation = column
119 .relation
120 .as_ref()
121 .ok_or_else(|| PlannerError::UnresolvedLogicalPlan)?;
122 let table_ref = table_reference_to_table_ref(relation)?;
123 let ident: Ident = column.name.as_str().into();
124 let column_type = schema
125 .iter()
126 .find(|(i, _t)| *i == ident)
127 .ok_or(PlannerError::ColumnNotFound)?
128 .1;
129 Ok(ColumnRef::new(table_ref, ident, column_type))
130}
131
132#[must_use]
134pub fn column_fields_to_schema(column_fields: Vec<ColumnField>) -> Schema {
135 Schema::new(
136 column_fields
137 .into_iter()
138 .map(|column_field| {
139 let data_type = (&column_field.data_type()).into();
141 Field::new(column_field.name().value.as_str(), data_type, false)
142 })
143 .collect::<Vec<_>>(),
144 )
145}
146
147pub(crate) fn schema_to_column_fields(schema: Vec<(Ident, ColumnType)>) -> Vec<ColumnField> {
151 schema
152 .into_iter()
153 .map(|(name, column_type)| ColumnField::new(name, column_type))
154 .collect()
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;

    // ---- parse_placeholder_id ----

    #[test]
    fn we_can_parse_valid_placeholder_id() {
        assert_eq!(parse_placeholder_id("$1"), Some(1));
        assert_eq!(parse_placeholder_id("$123"), Some(123));
    }

    #[test]
    fn we_cannot_parse_placeholder_id_without_dollar_sign() {
        assert_eq!(parse_placeholder_id(""), None);
        assert_eq!(parse_placeholder_id("1"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_empty_after_dollar_sign() {
        assert_eq!(parse_placeholder_id("$"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_non_digits() {
        assert_eq!(parse_placeholder_id("$abc"), None);
        assert_eq!(parse_placeholder_id("$1x"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_leading_zero() {
        assert_eq!(parse_placeholder_id("$0"), None);
        assert_eq!(parse_placeholder_id("$01"), None);
    }

    // `str::parse::<usize>` would accept a leading `+`; the digit filter must reject it.
    #[test]
    fn we_cannot_parse_placeholder_id_with_sign() {
        assert_eq!(parse_placeholder_id("$+1"), None);
        assert_eq!(parse_placeholder_id("$-1"), None);
    }

    // Well-formed digits that do not fit in `usize` must be rejected rather than wrap.
    #[test]
    fn we_cannot_parse_placeholder_id_that_overflows_usize() {
        assert_eq!(parse_placeholder_id("$99999999999999999999999"), None);
    }

    // ---- placeholder_to_placeholder_expr ----

    #[test]
    fn we_can_convert_valid_placeholder_to_placeholder_expr() {
        let placeholder = Placeholder {
            id: "$42".to_string(),
            data_type: Some(DataType::Int32),
        };
        let expected = DynProofExpr::try_new_placeholder(42, ColumnType::Int).unwrap();
        let result = placeholder_to_placeholder_expr(&placeholder).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn we_cannot_convert_placeholder_without_type() {
        let placeholder = Placeholder {
            id: "$1".to_string(),
            data_type: None,
        };
        assert!(matches!(
            placeholder_to_placeholder_expr(&placeholder),
            Err(PlannerError::UntypedPlaceholder { .. })
        ));
    }

    #[test]
    fn we_cannot_convert_placeholder_with_invalid_id() {
        let placeholder = Placeholder {
            id: "$0".to_string(),
            data_type: Some(DataType::Int32),
        };
        assert!(matches!(
            placeholder_to_placeholder_expr(&placeholder),
            Err(PlannerError::InvalidPlaceholderId { .. })
        ));
    }

    // ---- table_reference_to_table_ref ----

    #[test]
    fn we_can_convert_table_reference_to_table_ref() {
        let table = TableReference::bare("table");
        assert_eq!(
            table_reference_to_table_ref(&table).unwrap(),
            TableRef::from_names(None, "table")
        );

        let table = TableReference::partial("schema", "table");
        assert_eq!(
            table_reference_to_table_ref(&table).unwrap(),
            TableRef::from_names(Some("schema"), "table")
        );
    }

    #[test]
    fn we_cannot_convert_full_table_reference_to_table_ref() {
        let table = TableReference::full("catalog", "schema", "table");
        assert!(matches!(
            table_reference_to_table_ref(&table),
            Err(PlannerError::CatalogNotSupported)
        ));
    }

    // ---- scalar_value_to_literal_value ----

    #[test]
    fn we_can_convert_scalar_value_to_literal_value() {
        let value = ScalarValue::Boolean(Some(true));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Boolean(true)
        );

        let value = ScalarValue::Int8(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TinyInt(1)
        );

        let value = ScalarValue::Int16(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::SmallInt(1)
        );

        let value = ScalarValue::Int32(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Int(1)
        );

        let value = ScalarValue::Int64(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::BigInt(1)
        );

        let value = ScalarValue::UInt8(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Uint8(1)
        );

        let value = ScalarValue::Utf8(Some("value".to_string()));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::VarChar("value".to_string())
        );

        let value = ScalarValue::Binary(Some(vec![72, 97, 108, 108, 101, 108, 117, 106, 97, 104]));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::VarBinary(vec![72, 97, 108, 108, 101, 108, 117, 106, 97, 104])
        );

        let value = ScalarValue::TimestampSecond(Some(1_741_236_192_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Second,
                PoSQLTimeZone::utc(),
                1_741_236_192_i64
            )
        );

        let value = ScalarValue::TimestampMillisecond(Some(1_741_236_192_004_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Millisecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_004_i64
            )
        );

        let value = ScalarValue::TimestampMicrosecond(Some(1_741_236_192_004_000_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Microsecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_004_000_i64
            )
        );

        let value = ScalarValue::TimestampNanosecond(Some(1_741_236_192_123_456_789_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Nanosecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_123_456_789_i64
            )
        );
    }

    // Covers boundary values (MIN/MAX/zero) and both positive and negative scales for
    // Decimal128 and Decimal256.
    #[expect(clippy::cast_sign_loss)]
    #[test]
    fn we_can_convert_scalar_value_to_literal_value_for_decimals() {
        let value = ScalarValue::Decimal128(Some(123), 38, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(123i128)
            )
        );

        let value = ScalarValue::Decimal128(Some(i128::MIN), 38, 10);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                10,
                proof_of_sql::base::math::i256::I256::from(i128::MIN)
            )
        );

        let value = ScalarValue::Decimal128(Some(i128::MAX), 28, -5);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(28).unwrap(),
                -5,
                proof_of_sql::base::math::i256::I256::from(i128::MAX)
            )
        );

        let value = ScalarValue::Decimal128(Some(0), 38, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(0i128)
            )
        );

        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::from_i128(-456)), 75, 120);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                120,
                proof_of_sql::base::math::i256::I256::from(-456i128)
            )
        );

        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::MIN), 75, 127);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                127,
                proof_of_sql::base::math::i256::I256::new([0, 0, 0, i64::MIN as u64])
            )
        );
        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::MAX), 75, -128);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                -128,
                proof_of_sql::base::math::i256::I256::new([
                    u64::MAX,
                    u64::MAX,
                    u64::MAX,
                    i64::MAX as u64
                ])
            )
        );
        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::ZERO), 75, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(0i128)
            )
        );
    }

    #[test]
    fn we_cannot_convert_scalar_value_to_literal_value_if_unsupported() {
        let value = ScalarValue::Float32(Some(1.0));
        assert!(matches!(
            scalar_value_to_literal_value(value),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    // Null scalars of an otherwise supported type have no literal representation.
    #[test]
    fn we_cannot_convert_null_scalar_value_to_literal_value() {
        let value = ScalarValue::Int32(None);
        assert!(matches!(
            scalar_value_to_literal_value(value),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    // Only timezone-less timestamps are accepted (and interpreted as UTC).
    #[test]
    fn we_cannot_convert_timestamp_with_explicit_timezone_to_literal_value() {
        let value = ScalarValue::TimestampSecond(Some(1_741_236_192_i64), Some("+00:00".into()));
        assert!(matches!(
            scalar_value_to_literal_value(value),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    // ---- column_to_column_ref ----

    #[test]
    fn we_can_convert_column_to_column_ref() {
        let column = Column::new(Some("namespace.table"), "a");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert_eq!(
            column_to_column_ref(&column, &schema).unwrap(),
            ColumnRef::new(
                TableRef::from_names(Some("namespace"), "table"),
                "a".into(),
                ColumnType::Int
            )
        );
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_without_relation() {
        let column = Column::new(None::<&str>, "a");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert!(matches!(
            column_to_column_ref(&column, &schema),
            Err(PlannerError::UnresolvedLogicalPlan)
        ));
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_with_invalid_column_name() {
        let column = Column::new(Some("namespace.table"), "b");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert!(matches!(
            column_to_column_ref(&column, &schema),
            Err(PlannerError::ColumnNotFound)
        ));
    }

    // ---- column_fields_to_schema / schema_to_column_fields ----

    #[test]
    fn we_can_convert_column_fields_to_schema() {
        let column_fields = vec![];
        let schema = column_fields_to_schema(column_fields);
        assert_eq!(schema.all_fields(), Vec::<&Field>::new());

        let column_fields = vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ];
        let schema = column_fields_to_schema(column_fields);
        assert_eq!(
            schema.all_fields(),
            vec![
                &Field::new("a", DataType::Int16, false),
                &Field::new("b", DataType::Utf8, false),
            ]
        );
    }

    #[test]
    fn we_can_convert_df_schema_to_column_fields() {
        let column_fields = schema_to_column_fields(Vec::new());
        assert_eq!(column_fields, Vec::<ColumnField>::new());

        let schema = vec![
            ("a".into(), ColumnType::SmallInt),
            ("b".into(), ColumnType::VarChar),
        ];
        let column_fields = schema_to_column_fields(schema);
        assert_eq!(
            column_fields,
            vec![
                ColumnField::new("a".into(), ColumnType::SmallInt),
                ColumnField::new("b".into(), ColumnType::VarChar),
            ]
        );
    }
}