1use std::collections::HashMap;
37use std::str::FromStr;
38use std::sync::Arc;
39use std::time::Duration;
40
41use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
42use sqlparser::ast::{ColumnDef, DataType as SqlDataType};
43
44use crate::parser::ParseError;
45use crate::parser::{CreateSinkStatement, CreateSourceStatement, SinkFrom, WatermarkDef};
46
/// Smallest buffer capacity accepted via the `buffer_size` option.
pub const MIN_BUFFER_SIZE: usize = 4;

/// Largest buffer capacity accepted via the `buffer_size` option (2^20 = 1,048,576).
pub const MAX_BUFFER_SIZE: usize = 1 << 20;

/// Capacity used when no `buffer_size` option is supplied.
pub const DEFAULT_BUFFER_SIZE: usize = 2048;
55
/// Requested behavior when a source's buffer is full.
///
/// Selected via the `backpressure` WITH-option (see the `FromStr` impl
/// for accepted spellings). This type only records the choice;
/// enforcement happens wherever the configuration is consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackpressureStrategy {
    /// Wait until space becomes available (the default).
    #[default]
    Block,
    /// Make room by discarding the oldest buffered entry.
    DropOldest,
    /// Refuse the new entry instead of waiting.
    Reject,
}
67
68impl std::str::FromStr for BackpressureStrategy {
69 type Err = ParseError;
70
71 fn from_str(s: &str) -> Result<Self, Self::Err> {
72 match s.to_lowercase().as_str() {
73 "block" | "blocking" => Ok(Self::Block),
74 "drop" | "drop_oldest" | "dropoldest" => Ok(Self::DropOldest),
75 "reject" | "error" => Ok(Self::Reject),
76 _ => Err(ParseError::ValidationError(format!(
77 "invalid backpressure strategy: '{}'. Valid values: block, drop_oldest, reject",
78 s
79 ))),
80 }
81 }
82}
83
/// Requested strategy for waiting on new data.
///
/// Selected via the `wait_strategy` WITH-option. This type only records
/// the choice; enforcement happens wherever the configuration is consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WaitStrategy {
    /// Busy-spin.
    Spin,
    /// Spin, then yield (the default).
    #[default]
    SpinYield,
    /// Park the thread.
    Park,
}
95
96impl std::str::FromStr for WaitStrategy {
97 type Err = ParseError;
98
99 fn from_str(s: &str) -> Result<Self, Self::Err> {
100 match s.to_lowercase().as_str() {
101 "spin" => Ok(Self::Spin),
102 "spin_yield" | "spinyield" | "yield" => Ok(Self::SpinYield),
103 "park" | "parking" => Ok(Self::Park),
104 _ => Err(ParseError::ValidationError(format!(
105 "invalid wait strategy: '{}'. Valid values: spin, spin_yield, park",
106 s
107 ))),
108 }
109 }
110}
111
/// Translated `WATERMARK FOR <column> AS <expr>` clause.
#[derive(Debug, Clone)]
pub struct WatermarkSpec {
    /// Name of the event-time column the watermark tracks.
    pub column: String,
    /// Allowed lateness extracted from the watermark expression
    /// (e.g. `ts - INTERVAL '5' SECOND` yields 5 seconds).
    pub max_out_of_orderness: Duration,
}
120
/// Runtime options parsed from a CREATE SOURCE / CREATE SINK WITH-clause.
#[derive(Debug, Clone)]
pub struct SourceConfigOptions {
    /// Buffer capacity; validated against `MIN_BUFFER_SIZE`/`MAX_BUFFER_SIZE`
    /// at parse time (see `parse_buffer_size`).
    pub buffer_size: usize,
    /// Requested behavior when the buffer is full.
    pub backpressure: BackpressureStrategy,
    /// Requested strategy for waiting on new data.
    pub wait_strategy: WaitStrategy,
    /// Whether runtime statistics collection was requested.
    pub track_stats: bool,
}
133
134impl Default for SourceConfigOptions {
135 fn default() -> Self {
136 Self {
137 buffer_size: DEFAULT_BUFFER_SIZE,
138 backpressure: BackpressureStrategy::Block,
139 wait_strategy: WaitStrategy::SpinYield,
140 track_stats: false,
141 }
142 }
143}
144
/// A single column of a source, translated from the SQL column list.
#[derive(Debug, Clone)]
pub struct ColumnDefinition {
    /// Column name as written in the statement.
    pub name: String,
    /// Arrow type mapped from the SQL type (see `sql_type_to_arrow`).
    pub data_type: DataType,
    /// False iff the column was declared NOT NULL.
    pub nullable: bool,
}
155
/// Fully translated CREATE SOURCE statement.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Source name.
    pub name: String,
    /// Translated column list, in declaration order.
    pub columns: Vec<ColumnDefinition>,
    /// Arrow schema built from `columns`.
    pub schema: SchemaRef,
    /// Optional watermark clause, validated against `columns`.
    pub watermark: Option<WatermarkSpec>,
    /// Options from the WITH-clause (defaults where absent).
    pub config: SourceConfigOptions,
}
173
impl TryFrom<CreateSourceStatement> for SourceDefinition {
    type Error = ParseError;

    /// Delegates to [`translate_create_source`].
    fn try_from(stmt: CreateSourceStatement) -> Result<Self, Self::Error> {
        translate_create_source(stmt)
    }
}
181
/// Fully translated CREATE SINK statement.
#[derive(Debug, Clone)]
pub struct SinkDefinition {
    /// Sink name.
    pub name: String,
    /// Name of the table/view the sink reads from.
    pub input: String,
    /// Options from the WITH-clause (sinks share the source option set).
    pub config: SourceConfigOptions,
}
192
impl TryFrom<CreateSinkStatement> for SinkDefinition {
    type Error = ParseError;

    /// Delegates to [`translate_create_sink`].
    fn try_from(stmt: CreateSinkStatement) -> Result<Self, Self::Error> {
        translate_create_sink(stmt)
    }
}
200
201pub fn translate_create_source(
210 stmt: CreateSourceStatement,
211) -> Result<SourceDefinition, ParseError> {
212 validate_source_options(&stmt.with_options)?;
214
215 let config = parse_source_options(&stmt.with_options)?;
217
218 let columns = convert_columns(&stmt.columns)?;
220
221 let fields: Vec<Field> = columns
223 .iter()
224 .map(|col| Field::new(&col.name, col.data_type.clone(), col.nullable))
225 .collect();
226 let schema = Arc::new(Schema::new(fields));
227
228 let watermark = if let Some(wm) = stmt.watermark {
230 Some(parse_watermark(&wm, &columns)?)
231 } else {
232 None
233 };
234
235 Ok(SourceDefinition {
236 name: stmt.name.to_string(),
237 columns,
238 schema,
239 watermark,
240 config,
241 })
242}
243
244pub fn translate_create_sink(stmt: CreateSinkStatement) -> Result<SinkDefinition, ParseError> {
252 validate_source_options(&stmt.with_options)?;
254
255 let config = parse_source_options(&stmt.with_options)?;
257
258 let input = match stmt.from {
260 SinkFrom::Table(name) => name.to_string(),
261 SinkFrom::Query(_) => {
262 return Err(ParseError::ValidationError(
264 "inline queries not yet supported in CREATE SINK - use a view".to_string(),
265 ));
266 }
267 };
268
269 Ok(SinkDefinition {
270 name: stmt.name.to_string(),
271 input,
272 config,
273 })
274}
275
276fn validate_source_options(options: &HashMap<String, String>) -> Result<(), ParseError> {
278 if options.contains_key("channel") {
280 return Err(ParseError::ValidationError(
281 "the 'channel' option is not user-configurable - channel type is automatically derived from usage patterns".to_string(),
282 ));
283 }
284
285 if options.contains_key("type") {
287 return Err(ParseError::ValidationError(
288 "the 'type' option is not user-configurable for in-memory streaming sources"
289 .to_string(),
290 ));
291 }
292
293 Ok(())
294}
295
296fn parse_source_options(
298 options: &HashMap<String, String>,
299) -> Result<SourceConfigOptions, ParseError> {
300 let mut config = SourceConfigOptions::default();
301
302 for (key, value) in options {
303 match key.to_lowercase().as_str() {
304 "buffer_size" | "buffersize" => {
305 config.buffer_size = parse_buffer_size(value)?;
306 }
307 "backpressure" => {
308 config.backpressure = BackpressureStrategy::from_str(value)?;
309 }
310 "wait_strategy" | "waitstrategy" => {
311 config.wait_strategy = WaitStrategy::from_str(value)?;
312 }
313 "track_stats" | "trackstats" | "stats" => {
314 config.track_stats = parse_bool(value)?;
315 }
316 _ => {}
320 }
321 }
322
323 Ok(config)
324}
325
326fn parse_buffer_size(value: &str) -> Result<usize, ParseError> {
328 let size: usize = value.parse().map_err(|_| {
329 ParseError::ValidationError(format!(
330 "invalid buffer_size: '{}' - must be a number",
331 value
332 ))
333 })?;
334
335 if size < MIN_BUFFER_SIZE {
336 return Err(ParseError::ValidationError(format!(
337 "buffer_size {} is too small - minimum is {}",
338 size, MIN_BUFFER_SIZE
339 )));
340 }
341
342 if size > MAX_BUFFER_SIZE {
343 return Err(ParseError::ValidationError(format!(
344 "buffer_size {} is too large - maximum is {}",
345 size, MAX_BUFFER_SIZE
346 )));
347 }
348
349 Ok(size)
350}
351
352fn parse_bool(value: &str) -> Result<bool, ParseError> {
354 match value.to_lowercase().as_str() {
355 "true" | "yes" | "on" | "1" => Ok(true),
356 "false" | "no" | "off" | "0" => Ok(false),
357 _ => Err(ParseError::ValidationError(format!(
358 "invalid boolean value: '{}' - expected true/false",
359 value
360 ))),
361 }
362}
363
364fn convert_columns(columns: &[ColumnDef]) -> Result<Vec<ColumnDefinition>, ParseError> {
366 columns.iter().map(convert_column).collect()
367}
368
369fn convert_column(col: &ColumnDef) -> Result<ColumnDefinition, ParseError> {
371 let data_type = sql_type_to_arrow(&col.data_type)?;
372
373 let nullable = !col
375 .options
376 .iter()
377 .any(|opt| matches!(opt.option, sqlparser::ast::ColumnOption::NotNull));
378
379 Ok(ColumnDefinition {
380 name: col.name.to_string(),
381 data_type,
382 nullable,
383 })
384}
385
/// Maps a SQL column type from the sqlparser AST to the Arrow `DataType`
/// used for the in-memory schema.
///
/// # Errors
/// Returns `ParseError::ValidationError` for SQL types with no mapping.
pub fn sql_type_to_arrow(sql_type: &SqlDataType) -> Result<DataType, ParseError> {
    match sql_type {
        // Signed integers map 1:1 onto Arrow's fixed-width ints.
        SqlDataType::TinyInt(_) => Ok(DataType::Int8),
        SqlDataType::SmallInt(_) => Ok(DataType::Int16),
        SqlDataType::Int(_) | SqlDataType::Integer(_) => Ok(DataType::Int32),
        SqlDataType::BigInt(_) => Ok(DataType::Int64),

        SqlDataType::Float(_) | SqlDataType::Real => Ok(DataType::Float32),
        SqlDataType::Double(_) | SqlDataType::DoublePrecision => Ok(DataType::Float64),

        SqlDataType::Decimal(info) | SqlDataType::Numeric(info) => {
            // Precision/scale are declared as u64/i64 in the AST but
            // Decimal128 takes u8/i8; values above 255/127 would wrap.
            #[allow(clippy::cast_possible_truncation)]
            let (precision, scale) = match info {
                sqlparser::ast::ExactNumberInfo::PrecisionAndScale(p, s) => (*p as u8, *s as i8),
                sqlparser::ast::ExactNumberInfo::Precision(p) => (*p as u8, 0),
                // No precision given: fall back to (38, 9) — presumably the
                // engine-wide decimal default; confirm against the runtime.
                sqlparser::ast::ExactNumberInfo::None => (38, 9),
            };
            Ok(DataType::Decimal128(precision, scale))
        }

        // All textual (and JSON/UUID) types collapse to Utf8.
        SqlDataType::Char(_)
        | SqlDataType::Character(_)
        | SqlDataType::Varchar(_)
        | SqlDataType::CharacterVarying(_)
        | SqlDataType::Text
        | SqlDataType::String(_)
        | SqlDataType::JSON
        | SqlDataType::JSONB
        | SqlDataType::Uuid => Ok(DataType::Utf8),

        SqlDataType::Binary(_)
        | SqlDataType::Varbinary(_)
        | SqlDataType::Blob(_)
        | SqlDataType::Bytea => Ok(DataType::Binary),

        SqlDataType::Boolean | SqlDataType::Bool => Ok(DataType::Boolean),

        // Temporal types: microsecond precision, no timezone on timestamps.
        SqlDataType::Date => Ok(DataType::Date32),
        SqlDataType::Time(_, _) => Ok(DataType::Time64(TimeUnit::Microsecond)),
        SqlDataType::Timestamp(_, _) => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)),

        SqlDataType::Interval { .. } => Ok(DataType::Interval(
            arrow::datatypes::IntervalUnit::MonthDayNano,
        )),

        _ => Err(ParseError::ValidationError(format!(
            "unsupported data type: {:?}",
            sql_type
        ))),
    }
}
454
455fn parse_watermark(
457 wm: &WatermarkDef,
458 columns: &[ColumnDefinition],
459) -> Result<WatermarkSpec, ParseError> {
460 let column_name = wm.column.to_string();
461
462 let col = columns
464 .iter()
465 .find(|c| c.name == column_name)
466 .ok_or_else(|| {
467 ParseError::ValidationError(format!(
468 "watermark column '{}' not found in column list",
469 column_name
470 ))
471 })?;
472
473 if !matches!(
475 col.data_type,
476 DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64
477 ) {
478 return Err(ParseError::ValidationError(format!(
479 "watermark column '{}' must be a timestamp type, found {:?}",
480 column_name, col.data_type
481 )));
482 }
483
484 let max_out_of_orderness = parse_watermark_expression(&wm.expression);
487
488 Ok(WatermarkSpec {
489 column: column_name,
490 max_out_of_orderness,
491 })
492}
493
494fn parse_watermark_expression(expr: &sqlparser::ast::Expr) -> Duration {
496 use sqlparser::ast::Expr;
497
498 match expr {
499 Expr::BinaryOp { op, right, .. } => match op {
500 sqlparser::ast::BinaryOperator::Minus => parse_interval_expr(right),
501 _ => Duration::ZERO,
502 },
503 Expr::Identifier(_) => Duration::ZERO,
505 _ => Duration::from_secs(1),
507 }
508}
509
/// Converts an `INTERVAL` expression into a `Duration`.
///
/// Best-effort: any shape this function does not recognize — and any
/// unparsable magnitude — silently becomes a 1-second duration rather
/// than an error.
fn parse_interval_expr(expr: &sqlparser::ast::Expr) -> Duration {
    use sqlparser::ast::Expr;

    let Expr::Interval(interval) = expr else {
        return Duration::from_secs(1);
    };

    // Pull the literal out of the interval body. NOTE(review): this relies
    // on the sqlparser AST shape where `interval.value` wraps an
    // `Expr::Value` with an inner `.value` — confirm against the pinned
    // sqlparser version when upgrading.
    let value_str = match interval.value.as_ref() {
        Expr::Value(v) => {
            match &v.value {
                // `INTERVAL '5' SECOND` and `INTERVAL 5 SECOND` forms.
                sqlparser::ast::Value::SingleQuotedString(s) => s.clone(),
                sqlparser::ast::Value::Number(n, _) => n.clone(),
                _ => return Duration::from_secs(1),
            }
        }
        _ => return Duration::from_secs(1),
    };

    // Non-numeric magnitudes default to 1 instead of erroring.
    let value: u64 = value_str.parse().unwrap_or(1);

    // Map the leading field (SECOND, MINUTE, ...) to a unit keyword.
    // A missing field and any unhandled field both mean seconds.
    let unit = interval
        .leading_field
        .as_ref()
        .map_or("second", |u| match u {
            sqlparser::ast::DateTimeField::Microsecond => "microsecond",
            sqlparser::ast::DateTimeField::Millisecond => "millisecond",
            sqlparser::ast::DateTimeField::Minute => "minute",
            sqlparser::ast::DateTimeField::Hour => "hour",
            sqlparser::ast::DateTimeField::Day => "day",
            _ => "second",
        });

    // The plural arms are unreachable today (the mapping above only emits
    // singular forms); they are harmless and kept as written.
    match unit {
        "microsecond" | "microseconds" => Duration::from_micros(value),
        "millisecond" | "milliseconds" => Duration::from_millis(value),
        "minute" | "minutes" => Duration::from_secs(value * 60),
        "hour" | "hours" => Duration::from_secs(value * 3600),
        "day" | "days" => Duration::from_secs(value * 86400),
        _ => Duration::from_secs(value),
    }
}
555
#[cfg(test)]
mod tests {
    //! End-to-end tests: parse SQL text with the project parser, run the
    //! translation, and inspect the resulting definitions.

    use super::*;
    use crate::parser::{parse_streaming_sql, StreamingStatement};

    /// Parses `sql` and translates its first statement, which must be a
    /// CREATE SOURCE.
    fn parse_and_translate(sql: &str) -> Result<SourceDefinition, ParseError> {
        let statements = parse_streaming_sql(sql)?;
        let stmt = statements
            .into_iter()
            .next()
            .ok_or_else(|| ParseError::StreamingError("No statement found".to_string()))?;
        match stmt {
            StreamingStatement::CreateSource(source) => translate_create_source(*source),
            _ => Err(ParseError::StreamingError(
                "Expected CREATE SOURCE".to_string(),
            )),
        }
    }

    #[test]
    fn test_basic_source() {
        let def =
            parse_and_translate("CREATE SOURCE events (id BIGINT NOT NULL, name VARCHAR)").unwrap();

        assert_eq!(def.name, "events");
        assert_eq!(def.columns.len(), 2);
        assert_eq!(def.columns[0].name, "id");
        assert_eq!(def.columns[0].data_type, DataType::Int64);
        // NOT NULL must flip nullability off; plain columns stay nullable.
        assert!(!def.columns[0].nullable);
        assert_eq!(def.columns[1].name, "name");
        assert!(def.columns[1].nullable);
    }

    #[test]
    fn test_source_with_options() {
        let def = parse_and_translate(
            "CREATE SOURCE events (id BIGINT) WITH (
                'buffer_size' = '4096',
                'backpressure' = 'reject'
            )",
        )
        .unwrap();

        assert_eq!(def.config.buffer_size, 4096);
        assert_eq!(def.config.backpressure, BackpressureStrategy::Reject);
    }

    #[test]
    fn test_source_with_watermark() {
        let def = parse_and_translate(
            "CREATE SOURCE events (
                id BIGINT,
                ts TIMESTAMP,
                WATERMARK FOR ts AS ts - INTERVAL '5' SECOND
            )",
        )
        .unwrap();

        assert!(def.watermark.is_some());
        let wm = def.watermark.unwrap();
        assert_eq!(wm.column, "ts");
        assert_eq!(wm.max_out_of_orderness, Duration::from_secs(5));
    }

    #[test]
    fn test_reject_channel_option() {
        // 'channel' is a reserved key; translation must fail with a message
        // naming it.
        let result =
            parse_and_translate("CREATE SOURCE events (id BIGINT) WITH ('channel' = 'mpsc')");

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("channel"));
    }

    #[test]
    fn test_reject_type_option() {
        // 'type' is also reserved.
        let result = parse_and_translate("CREATE SOURCE events (id BIGINT) WITH ('type' = 'spsc')");

        assert!(result.is_err());
    }

    #[test]
    fn test_buffer_size_bounds() {
        // Below MIN_BUFFER_SIZE.
        let result =
            parse_and_translate("CREATE SOURCE events (id BIGINT) WITH ('buffer_size' = '1')");
        assert!(result.is_err());

        // Above MAX_BUFFER_SIZE.
        let result = parse_and_translate(
            "CREATE SOURCE events (id BIGINT) WITH ('buffer_size' = '999999999')",
        );
        assert!(result.is_err());

        // In range.
        let result =
            parse_and_translate("CREATE SOURCE events (id BIGINT) WITH ('buffer_size' = '1024')");
        assert!(result.is_ok());
    }

    #[test]
    fn test_backpressure_strategies() {
        assert_eq!(
            BackpressureStrategy::from_str("block").unwrap(),
            BackpressureStrategy::Block
        );
        assert_eq!(
            BackpressureStrategy::from_str("drop_oldest").unwrap(),
            BackpressureStrategy::DropOldest
        );
        assert_eq!(
            BackpressureStrategy::from_str("reject").unwrap(),
            BackpressureStrategy::Reject
        );
        assert!(BackpressureStrategy::from_str("invalid").is_err());
    }

    #[test]
    fn test_wait_strategies() {
        assert_eq!(WaitStrategy::from_str("spin").unwrap(), WaitStrategy::Spin);
        assert_eq!(
            WaitStrategy::from_str("spin_yield").unwrap(),
            WaitStrategy::SpinYield
        );
        assert_eq!(WaitStrategy::from_str("park").unwrap(), WaitStrategy::Park);
        assert!(WaitStrategy::from_str("invalid").is_err());
    }

    #[test]
    fn test_sql_type_conversions() {
        let def = parse_and_translate(
            "CREATE SOURCE types (
                a TINYINT,
                b SMALLINT,
                c INT,
                d BIGINT,
                e FLOAT,
                f DOUBLE,
                g DECIMAL(10,2),
                h VARCHAR(255),
                i TEXT,
                j BOOLEAN,
                k TIMESTAMP,
                l DATE
            )",
        )
        .unwrap();

        assert_eq!(def.columns.len(), 12);
        assert_eq!(def.columns[0].data_type, DataType::Int8);
        assert_eq!(def.columns[1].data_type, DataType::Int16);
        assert_eq!(def.columns[2].data_type, DataType::Int32);
        assert_eq!(def.columns[3].data_type, DataType::Int64);
        assert_eq!(def.columns[4].data_type, DataType::Float32);
        assert_eq!(def.columns[5].data_type, DataType::Float64);
        assert_eq!(def.columns[6].data_type, DataType::Decimal128(10, 2));
        assert_eq!(def.columns[7].data_type, DataType::Utf8);
        assert_eq!(def.columns[8].data_type, DataType::Utf8);
        assert_eq!(def.columns[9].data_type, DataType::Boolean);
        assert!(matches!(
            def.columns[10].data_type,
            DataType::Timestamp(_, _)
        ));
        assert_eq!(def.columns[11].data_type, DataType::Date32);
    }

    #[test]
    fn test_schema_generation() {
        let def = parse_and_translate(
            "CREATE SOURCE events (id BIGINT NOT NULL, name VARCHAR NOT NULL, value DOUBLE)",
        )
        .unwrap();

        // Schema fields must mirror the column list, including nullability.
        let schema = def.schema;
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "id");
        assert!(!schema.field(0).is_nullable());
        assert_eq!(schema.field(1).name(), "name");
        assert!(!schema.field(1).is_nullable());
        assert_eq!(schema.field(2).name(), "value");
        assert!(schema.field(2).is_nullable());
    }

    #[test]
    fn test_watermark_column_not_found() {
        let result = parse_and_translate(
            "CREATE SOURCE events (
                id BIGINT,
                WATERMARK FOR nonexistent AS nonexistent - INTERVAL '1' SECOND
            )",
        );

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_watermark_wrong_type() {
        // Watermark on a non-temporal column must be rejected.
        let result = parse_and_translate(
            "CREATE SOURCE events (
                id BIGINT,
                WATERMARK FOR id AS id - INTERVAL '1' SECOND
            )",
        );

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timestamp type"));
    }

    #[test]
    fn test_watermark_milliseconds() {
        let def = parse_and_translate(
            "CREATE SOURCE events (
                ts TIMESTAMP,
                WATERMARK FOR ts AS ts - INTERVAL '100' MILLISECOND
            )",
        )
        .unwrap();

        let wm = def.watermark.unwrap();
        assert_eq!(wm.max_out_of_orderness, Duration::from_millis(100));
    }

    #[test]
    fn test_watermark_minutes() {
        let def = parse_and_translate(
            "CREATE SOURCE events (
                ts TIMESTAMP,
                WATERMARK FOR ts AS ts - INTERVAL '5' MINUTE
            )",
        )
        .unwrap();

        let wm = def.watermark.unwrap();
        assert_eq!(wm.max_out_of_orderness, Duration::from_secs(300));
    }

    #[test]
    fn test_track_stats_option() {
        let def =
            parse_and_translate("CREATE SOURCE events (id BIGINT) WITH ('track_stats' = 'true')")
                .unwrap();

        assert!(def.config.track_stats);
    }

    #[test]
    fn test_wait_strategy_option() {
        let def =
            parse_and_translate("CREATE SOURCE events (id BIGINT) WITH ('wait_strategy' = 'park')")
                .unwrap();

        assert_eq!(def.config.wait_strategy, WaitStrategy::Park);
    }

    #[test]
    fn test_default_config() {
        let def = parse_and_translate("CREATE SOURCE events (id BIGINT)").unwrap();

        assert_eq!(def.config.buffer_size, DEFAULT_BUFFER_SIZE);
        assert_eq!(def.config.backpressure, BackpressureStrategy::Block);
        assert_eq!(def.config.wait_strategy, WaitStrategy::SpinYield);
        assert!(!def.config.track_stats);
    }

    #[test]
    fn test_external_connector_options_ignored() {
        // Unknown keys must pass through silently while known keys still apply.
        let def = parse_and_translate(
            "CREATE SOURCE events (id BIGINT) WITH (
                'connector' = 'kafka',
                'topic' = 'events',
                'bootstrap.servers' = 'localhost:9092',
                'buffer_size' = '8192'
            )",
        )
        .unwrap();

        assert_eq!(def.config.buffer_size, 8192);
    }
}