1use super::{PlannerError, PlannerResult};
2use arrow::datatypes::{Field, Schema};
3use datafusion::{
4 catalog::TableReference,
5 common::{Column, ScalarValue},
6 logical_expr::expr::Placeholder,
7};
8use proof_of_sql::{
9 base::{
10 database::{ColumnField, ColumnRef, ColumnType, LiteralValue, TableRef},
11 math::decimal::Precision,
12 posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
13 },
14 sql::proof_exprs::DynProofExpr,
15};
16use sqlparser::ast::Ident;
17
/// Parses a placeholder id of the form `$N` into `N`.
///
/// `N` must be a non-empty run of ASCII digits with no leading zero (so `$0` and `$01`
/// are rejected) that fits in a `usize`. Any other input yields `None`.
fn parse_placeholder_id(s: &str) -> Option<usize> {
    let digits = s.strip_prefix('$')?;
    // Reject signs, non-ASCII digits and leading zeros before handing off to `parse`,
    // which would otherwise accept e.g. `+1` or `01`.
    let is_canonical = !digits.starts_with('0') && digits.bytes().all(|b| b.is_ascii_digit());
    if is_canonical {
        digits.parse().ok()
    } else {
        None
    }
}
26
27#[expect(clippy::missing_panics_doc, reason = "can not actually panic")]
29pub(crate) fn placeholder_to_placeholder_expr(
30 placeholder: &Placeholder,
31) -> PlannerResult<DynProofExpr> {
32 let df_id = placeholder.id.clone();
33 let df_type = placeholder.data_type.clone();
34 let posql_id = parse_placeholder_id(&df_id)
35 .ok_or_else(|| PlannerError::InvalidPlaceholderId { id: df_id.clone() })?;
36 let posql_type = df_type
37 .clone()
38 .ok_or(PlannerError::UntypedPlaceholder {
39 placeholder: placeholder.clone(),
40 })?
41 .try_into()
42 .map_err(|_| PlannerError::UnsupportedDataType {
43 data_type: df_type.clone().unwrap(),
44 })?;
45 Ok(DynProofExpr::try_new_placeholder(posql_id, posql_type)?)
46}
47
48pub(crate) fn table_reference_to_table_ref(table: &TableReference) -> PlannerResult<TableRef> {
52 match table {
53 TableReference::Bare { table } => Ok(TableRef::from_names(None, table)),
54 TableReference::Partial { schema, table } => Ok(TableRef::from_names(Some(schema), table)),
55 TableReference::Full { .. } => Err(PlannerError::CatalogNotSupported),
56 }
57}
58
59pub(crate) fn scalar_value_to_literal_value(value: ScalarValue) -> PlannerResult<LiteralValue> {
63 match value {
64 ScalarValue::Boolean(Some(v)) => Ok(LiteralValue::Boolean(v)),
65 ScalarValue::Int8(Some(v)) => Ok(LiteralValue::TinyInt(v)),
66 ScalarValue::Int16(Some(v)) => Ok(LiteralValue::SmallInt(v)),
67 ScalarValue::Int32(Some(v)) => Ok(LiteralValue::Int(v)),
68 ScalarValue::Int64(Some(v)) => Ok(LiteralValue::BigInt(v)),
69 ScalarValue::UInt8(Some(v)) => Ok(LiteralValue::Uint8(v)),
70 ScalarValue::Utf8(Some(v)) => Ok(LiteralValue::VarChar(v)),
71 ScalarValue::Binary(Some(v)) => Ok(LiteralValue::VarBinary(v)),
72 ScalarValue::TimestampSecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
73 PoSQLTimeUnit::Second,
74 PoSQLTimeZone::utc(),
75 v,
76 )),
77 ScalarValue::TimestampMillisecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
78 PoSQLTimeUnit::Millisecond,
79 PoSQLTimeZone::utc(),
80 v,
81 )),
82 ScalarValue::TimestampMicrosecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
83 PoSQLTimeUnit::Microsecond,
84 PoSQLTimeZone::utc(),
85 v,
86 )),
87 ScalarValue::TimestampNanosecond(Some(v), None) => Ok(LiteralValue::TimeStampTZ(
88 PoSQLTimeUnit::Nanosecond,
89 PoSQLTimeZone::utc(),
90 v,
91 )),
92 ScalarValue::Decimal128(Some(v), precision, scale) => Ok(LiteralValue::Decimal75(
93 Precision::new(precision)?,
94 scale,
95 v.into(),
96 )),
97 ScalarValue::Decimal256(Some(v), precision, scale) => Ok(LiteralValue::Decimal75(
98 Precision::new(precision)?,
99 scale,
100 v.into(),
101 )),
102 _ => Err(PlannerError::UnsupportedDataType {
103 data_type: value.data_type().clone(),
104 }),
105 }
106}
107
108pub(crate) fn column_to_column_ref(
113 column: &Column,
114 schema: &[(Ident, ColumnType)],
115) -> PlannerResult<ColumnRef> {
116 let relation = column
117 .relation
118 .as_ref()
119 .ok_or_else(|| PlannerError::UnresolvedLogicalPlan)?;
120 let table_ref = table_reference_to_table_ref(relation)?;
121 let ident: Ident = column.name.as_str().into();
122 let column_type = schema
123 .iter()
124 .find(|(i, _t)| *i == ident)
125 .ok_or(PlannerError::ColumnNotFound)?
126 .1;
127 Ok(ColumnRef::new(table_ref, ident, column_type))
128}
129
130#[must_use]
132pub fn column_fields_to_schema(column_fields: Vec<ColumnField>) -> Schema {
133 Schema::new(
134 column_fields
135 .into_iter()
136 .map(|column_field| {
137 let data_type = (&column_field.data_type()).into();
139 Field::new(column_field.name().value.as_str(), data_type, false)
140 })
141 .collect::<Vec<_>>(),
142 )
143}
144
145pub(crate) fn schema_to_column_fields(schema: Vec<(Ident, ColumnType)>) -> Vec<ColumnField> {
149 schema
150 .into_iter()
151 .map(|(name, column_type)| ColumnField::new(name, column_type))
152 .collect()
153}
154
#[cfg(test)]
mod tests {
    //! Unit tests for the DataFusion -> Proof of SQL conversion helpers.

    use super::*;
    use arrow::datatypes::DataType;

    #[test]
    fn we_can_parse_valid_placeholder_id() {
        assert_eq!(parse_placeholder_id("$1"), Some(1));
        assert_eq!(parse_placeholder_id("$123"), Some(123));
    }

    #[test]
    fn we_cannot_parse_placeholder_id_without_dollar_sign() {
        assert_eq!(parse_placeholder_id(""), None);
        assert_eq!(parse_placeholder_id("1"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_empty_after_dollar_sign() {
        assert_eq!(parse_placeholder_id("$"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_non_digits() {
        assert_eq!(parse_placeholder_id("$abc"), None);
        assert_eq!(parse_placeholder_id("$1x"), None);
        // `str::parse` alone would accept a leading `+`, so these pin the digit filter.
        assert_eq!(parse_placeholder_id("$+1"), None);
        assert_eq!(parse_placeholder_id("$-1"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_with_leading_zero() {
        assert_eq!(parse_placeholder_id("$0"), None);
        assert_eq!(parse_placeholder_id("$01"), None);
    }

    #[test]
    fn we_cannot_parse_placeholder_id_that_overflows_usize() {
        assert_eq!(
            parse_placeholder_id("$100000000000000000000000000000"),
            None
        );
    }

    #[test]
    fn we_can_convert_valid_placeholder_to_placeholder_expr() {
        let placeholder = Placeholder {
            id: "$42".to_string(),
            data_type: Some(DataType::Int32),
        };
        let expected = DynProofExpr::try_new_placeholder(42, ColumnType::Int).unwrap();
        let result = placeholder_to_placeholder_expr(&placeholder).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn we_cannot_convert_placeholder_without_type() {
        let placeholder = Placeholder {
            id: "$1".to_string(),
            data_type: None,
        };
        assert!(matches!(
            placeholder_to_placeholder_expr(&placeholder),
            Err(PlannerError::UntypedPlaceholder { .. })
        ));
    }

    #[test]
    fn we_cannot_convert_placeholder_with_invalid_id() {
        let placeholder = Placeholder {
            id: "$0".to_string(),
            data_type: Some(DataType::Int32),
        };
        assert!(matches!(
            placeholder_to_placeholder_expr(&placeholder),
            Err(PlannerError::InvalidPlaceholderId { .. })
        ));
    }

    #[test]
    fn we_can_convert_table_reference_to_table_ref() {
        let table = TableReference::bare("table");
        assert_eq!(
            table_reference_to_table_ref(&table).unwrap(),
            TableRef::from_names(None, "table")
        );

        let table = TableReference::partial("schema", "table");
        assert_eq!(
            table_reference_to_table_ref(&table).unwrap(),
            TableRef::from_names(Some("schema"), "table")
        );
    }

    #[test]
    fn we_cannot_convert_full_table_reference_to_table_ref() {
        let table = TableReference::full("catalog", "schema", "table");
        assert!(matches!(
            table_reference_to_table_ref(&table),
            Err(PlannerError::CatalogNotSupported)
        ));
    }

    #[test]
    fn we_can_convert_scalar_value_to_literal_value() {
        let value = ScalarValue::Boolean(Some(true));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Boolean(true)
        );

        let value = ScalarValue::Int8(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TinyInt(1)
        );

        let value = ScalarValue::Int16(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::SmallInt(1)
        );

        let value = ScalarValue::Int32(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Int(1)
        );

        let value = ScalarValue::Int64(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::BigInt(1)
        );

        let value = ScalarValue::UInt8(Some(1));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Uint8(1)
        );

        let value = ScalarValue::Utf8(Some("value".to_string()));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::VarChar("value".to_string())
        );

        let value = ScalarValue::Binary(Some(vec![72, 97, 108, 108, 101, 108, 117, 106, 97, 104]));
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::VarBinary(vec![72, 97, 108, 108, 101, 108, 117, 106, 97, 104])
        );

        let value = ScalarValue::TimestampSecond(Some(1_741_236_192_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Second,
                PoSQLTimeZone::utc(),
                1_741_236_192_i64
            )
        );

        let value = ScalarValue::TimestampMillisecond(Some(1_741_236_192_004_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Millisecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_004_i64
            )
        );

        let value = ScalarValue::TimestampMicrosecond(Some(1_741_236_192_004_000_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Microsecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_004_000_i64
            )
        );

        let value = ScalarValue::TimestampNanosecond(Some(1_741_236_192_123_456_789_i64), None);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::TimeStampTZ(
                PoSQLTimeUnit::Nanosecond,
                PoSQLTimeZone::utc(),
                1_741_236_192_123_456_789_i64
            )
        );
    }

    #[expect(clippy::cast_sign_loss)]
    #[test]
    fn we_can_convert_scalar_value_to_literal_value_for_decimals() {
        let value = ScalarValue::Decimal128(Some(123), 38, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(123i128)
            )
        );

        let value = ScalarValue::Decimal128(Some(i128::MIN), 38, 10);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                10,
                proof_of_sql::base::math::i256::I256::from(i128::MIN)
            )
        );

        let value = ScalarValue::Decimal128(Some(i128::MAX), 28, -5);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(28).unwrap(),
                -5,
                proof_of_sql::base::math::i256::I256::from(i128::MAX)
            )
        );

        let value = ScalarValue::Decimal128(Some(0), 38, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(38).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(0i128)
            )
        );

        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::from_i128(-456)), 75, 120);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                120,
                proof_of_sql::base::math::i256::I256::from(-456i128)
            )
        );

        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::MIN), 75, 127);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                127,
                proof_of_sql::base::math::i256::I256::new([0, 0, 0, i64::MIN as u64])
            )
        );
        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::MAX), 75, -128);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                -128,
                proof_of_sql::base::math::i256::I256::new([
                    u64::MAX,
                    u64::MAX,
                    u64::MAX,
                    i64::MAX as u64
                ])
            )
        );
        let value = ScalarValue::Decimal256(Some(arrow::datatypes::i256::ZERO), 75, 0);
        assert_eq!(
            scalar_value_to_literal_value(value).unwrap(),
            LiteralValue::Decimal75(
                Precision::new(75).unwrap(),
                0,
                proof_of_sql::base::math::i256::I256::from(0i128)
            )
        );
    }

    #[test]
    fn we_cannot_convert_scalar_value_to_literal_value_if_unsupported() {
        let value = ScalarValue::Float32(Some(1.0));
        assert!(matches!(
            scalar_value_to_literal_value(value),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    #[test]
    fn we_cannot_convert_null_scalar_value_to_literal_value() {
        let value = ScalarValue::Int32(None);
        assert!(matches!(
            scalar_value_to_literal_value(value),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    #[test]
    fn we_cannot_convert_timestamp_with_time_zone_to_literal_value() {
        // Only zone-less timestamps are mapped (to UTC); zoned ones are rejected.
        let value = ScalarValue::TimestampSecond(Some(1_741_236_192_i64), Some("+01:00".into()));
        assert!(matches!(
            scalar_value_to_literal_value(value),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    #[test]
    fn we_can_convert_column_to_column_ref() {
        let column = Column::new(Some("namespace.table"), "a");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert_eq!(
            column_to_column_ref(&column, &schema).unwrap(),
            ColumnRef::new(
                TableRef::from_names(Some("namespace"), "table"),
                "a".into(),
                ColumnType::Int
            )
        );
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_without_relation() {
        let column = Column::new(None::<&str>, "a");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert!(matches!(
            column_to_column_ref(&column, &schema),
            Err(PlannerError::UnresolvedLogicalPlan)
        ));
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_with_catalog() {
        let column = Column::new(Some(TableReference::full("catalog", "schema", "table")), "a");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert!(matches!(
            column_to_column_ref(&column, &schema),
            Err(PlannerError::CatalogNotSupported)
        ));
    }

    #[test]
    fn we_cannot_convert_column_to_column_ref_with_invalid_column_name() {
        let column = Column::new(Some("namespace.table"), "b");
        let schema = vec![("a".into(), ColumnType::Int)];
        assert!(matches!(
            column_to_column_ref(&column, &schema),
            Err(PlannerError::ColumnNotFound)
        ));
    }

    #[test]
    fn we_can_convert_column_fields_to_schema() {
        let column_fields = vec![];
        let schema = column_fields_to_schema(column_fields);
        assert_eq!(schema.all_fields(), Vec::<&Field>::new());

        let column_fields = vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ];
        let schema = column_fields_to_schema(column_fields);
        assert_eq!(
            schema.all_fields(),
            vec![
                &Field::new("a", DataType::Int16, false),
                &Field::new("b", DataType::Utf8, false),
            ]
        );
    }

    #[test]
    fn we_can_convert_df_schema_to_column_fields() {
        let column_fields = schema_to_column_fields(Vec::new());
        assert_eq!(column_fields, Vec::<ColumnField>::new());

        let schema = vec![
            ("a".into(), ColumnType::SmallInt),
            ("b".into(), ColumnType::VarChar),
        ];
        let column_fields = schema_to_column_fields(schema);
        assert_eq!(
            column_fields,
            vec![
                ColumnField::new("a".into(), ColumnType::SmallInt),
                ColumnField::new("b".into(), ColumnType::VarChar),
            ]
        );
    }
}