skardi 0.4.0

High-performance query engine for both offline compute and online serving
use anyhow::{Result, anyhow};
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
use datafusion::sql::sqlparser::ast::{Expr, SetExpr, Statement};
use datafusion::sql::sqlparser::dialect::GenericDialect;
use datafusion::sql::sqlparser::parser::Parser;
use std::collections::HashMap;
use std::sync::Arc;

use super::types::{InferredFieldType, RequestSchema, ResponseSchema};

/// SQL schema inference engine for extracting parameter and response schemas from SQL queries
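///
/// # Example
///
/// A minimal sketch of the intended call flow (hypothetical `products` table
/// and query; assumes an async context and a `SessionContext` with the
/// relevant tables already registered):
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::prelude::SessionContext;
///
/// let inferrer = SqlSchemaInferrer::new(Arc::new(SessionContext::new()))?;
/// let sql = "SELECT product_id, name FROM products WHERE brand = {brand}";
/// // Request schema: {brand} inferred as Utf8 from the products table.
/// let request = inferrer.extract_request_schema(sql).await?;
/// // Response schema: product_id and name, typed from the table schema.
/// let response = inferrer.extract_response_schema(sql).await?;
/// ```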
/// TODO: Fix the issue that extra parameters in the request are allowed
pub struct SqlSchemaInferrer {
    datafusion_ctx: Arc<SessionContext>,
}

impl std::fmt::Debug for SqlSchemaInferrer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SqlSchemaInferrer")
            .field("datafusion_ctx", &"<SessionContext>")
            .finish()
    }
}

/// Named parameter with binding metadata inferred from AST and schema analysis
#[derive(Debug, Clone)]
pub struct NamedParameter {
    pub name: String,        // Parameter name from {name} syntax
    pub column_name: String, // Inferred column that this parameter binds to
    pub field_type: DataType,
    pub nullable: bool,
    pub source_location: String,
    pub occurrences: usize, // Number of times this parameter appears
}

/// Validation report for SQL queries
#[derive(Debug, Clone)]
pub struct ValidationReport {
    pub is_valid: bool,
    pub errors: Vec<String>,
    pub warnings: Vec<String>,
    pub parameter_count: usize,
    pub response_field_count: usize,
}

impl SqlSchemaInferrer {
    /// Create a new SQL schema inferrer with the provided `Arc<SessionContext>`
    pub fn new(datafusion_ctx: Arc<SessionContext>) -> Result<Self> {
        Ok(Self { datafusion_ctx })
    }

    /// Extract named parameters from SQL query using Rust-native {name} syntax
    /// and infer types from table schemas in the SessionContext
    pub async fn extract_parameters(
        &self,
        sql: &str,
    ) -> Result<HashMap<String, InferredFieldType>> {
        let named_params = self.extract_named_parameters(sql).await?;
        let mut parameters = HashMap::new();

        for (param_name, named_param) in named_params {
            parameters.insert(
                param_name,
                InferredFieldType {
                    field_type: named_param.field_type,
                    nullable: named_param.nullable,
                    source_location: named_param.source_location,
                },
            );
        }

        Ok(parameters)
    }

    /// Extract request schema from SQL query using named parameter analysis
    pub async fn extract_request_schema(&self, sql: &str) -> Result<RequestSchema> {
        let parameters = self.extract_parameters(sql).await?;
        Ok(RequestSchema { fields: parameters })
    }

    /// Extract response schema from SQL query using SELECT clause analysis
    pub async fn extract_response_schema(&self, sql: &str) -> Result<ResponseSchema> {
        let fields = self.extract_response_fields(sql).await?;
        Ok(ResponseSchema { fields })
    }

    /// Extract named parameters with full metadata using Rust-native {name} syntax
    /// Uses AST analysis for robust column-to-parameter mapping
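    ///
    /// A sketch (assumes a `products` table with a `Utf8` column `brand` is
    /// registered on the context):
    ///
    /// ```ignore
    /// // {brand} becomes placeholder index 0; AST analysis sees it compared
    /// // against column `brand`, whose type is then read from the table schema.
    /// let params = inferrer
    ///     .extract_named_parameters("SELECT * FROM products WHERE brand = {brand}")
    ///     .await?;
    /// assert_eq!(params["brand"].column_name, "brand");
    /// assert_eq!(params["brand"].occurrences, 1);
    /// ```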
    pub async fn extract_named_parameters(
        &self,
        sql: &str,
    ) -> Result<HashMap<String, NamedParameter>> {
        // Step 1: Convert {parameter_name} to ? placeholders and track parameter mapping
        let (sql_with_placeholders, parameter_order) = self.convert_named_to_placeholders(sql)?;

        // Step 2: Parse SQL with placeholders to get AST
        let statements = self.parse_sql_syntax(&sql_with_placeholders)?;

        // Step 3: Use AST to find column-parameter relationships
        let placeholder_to_column = self.extract_placeholder_columns(&statements)?;

        // Step 4: Build named parameter metadata using AST analysis + table schemas
        let mut named_parameters = HashMap::new();

        for (placeholder_index, param_name) in parameter_order.iter().enumerate() {
            // Count occurrences of this parameter in original SQL
            let param_pattern = format!(r"\{{{}\}}", regex::escape(param_name));
            let occurrences = regex::Regex::new(&param_pattern)
                .map_err(|e| {
                    anyhow!(
                        "Failed to compile parameter regex for '{}': {}",
                        param_name,
                        e
                    )
                })?
                .find_iter(sql)
                .count();

            // Get column name from AST analysis (more reliable than pattern matching)
            let column_name = placeholder_to_column
                .get(&placeholder_index)
                .cloned()
                .unwrap_or_else(|| self.infer_column_from_parameter_name(param_name));

            // Get column type from registered table schema or special context
            let column_type = self
                .get_column_type_with_special_cases(&column_name, sql)
                .await?;

            named_parameters.insert(
                param_name.clone(),
                NamedParameter {
                    name: param_name.clone(),
                    column_name: column_name.clone(),
                    field_type: column_type,
                    nullable: true, // Named parameters are typically nullable
                    source_location: format!(
                        "AST analysis: parameter '{}' → column '{}' (placeholder index {})",
                        param_name, column_name, placeholder_index
                    ),
                    occurrences,
                },
            );
        }

        Ok(named_parameters)
    }

    /// Convert {parameter_name} syntax to ? placeholders while tracking parameter order
    /// Returns (sql_with_placeholders, parameter_order)
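    ///
    /// For example, `WHERE brand = {brand} AND price > {price}` becomes
    /// `WHERE brand = ? AND price > ?` with parameter order `["brand", "price"]`.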
    fn convert_named_to_placeholders(&self, sql: &str) -> Result<(String, Vec<String>)> {
        let mut parameter_order = Vec::new();
        let mut sql_with_placeholders = sql.to_string();

        // Extract all {parameter_name} patterns and replace with ? placeholders
        let parameter_pattern = regex::Regex::new(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}")
            .map_err(|e| anyhow!("Failed to compile parameter regex: {}", e))?;

        // Find all matches and collect unique parameter names in order
        let mut seen_params = std::collections::HashSet::new();
        for cap in parameter_pattern.captures_iter(sql) {
            let param_name = cap[1].to_string();
            if !seen_params.contains(&param_name) {
                seen_params.insert(param_name.clone());
                parameter_order.push(param_name);
            }
        }

        // Pre-pass: `VALUES {name}` is the multi-row tuple-list shape — the
        // server-side renderer will expand `{name}` into `(c1, c2), (c1, c2)`
        // at request time. The bare-`?` substitution would produce
        // `VALUES ?`, which sqlparser rejects (`Expected: (, found: ?`),
        // breaking pipeline load. Substitute with `VALUES (?)` so the SQL
        // parses as a single-row tuple stub; runtime types still come from
        // the renderer, not from this stub.
        let values_pattern = regex::Regex::new(r"(?i)\bVALUES\s*\{([a-zA-Z_][a-zA-Z0-9_]*)\}")
            .map_err(|e| anyhow!("Failed to compile VALUES placeholder regex: {}", e))?;
        sql_with_placeholders = values_pattern
            .replace_all(&sql_with_placeholders, "VALUES (?)")
            .to_string();

        // Replace each unique {parameter_name} with ? placeholders
        for param_name in &parameter_order {
            let pattern = format!(r"\{{{}\}}", regex::escape(param_name));
            let regex = regex::Regex::new(&pattern).map_err(|e| {
                anyhow!(
                    "Failed to compile parameter regex for '{}': {}",
                    param_name,
                    e
                )
            })?;
            sql_with_placeholders = regex.replace_all(&sql_with_placeholders, "?").to_string();
        }

        Ok((sql_with_placeholders, parameter_order))
    }

    /// Extract column names that are used with ? placeholders from AST
    /// Returns mapping from placeholder index to column name
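    ///
    /// For example, `WHERE brand = ? AND price > ? LIMIT ?` yields
    /// `{0 => "brand", 1 => "price", 2 => "limit"}`.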
    fn extract_placeholder_columns(
        &self,
        statements: &[Statement],
    ) -> Result<HashMap<usize, String>> {
        let mut placeholder_to_column = HashMap::new();
        let mut placeholder_index = 0;

        for statement in statements {
            // Extract the query to analyze (either a top-level SELECT or the
            // SELECT source of an INSERT; UPDATE/DELETE filters are handled below)
            let query = match statement {
                Statement::Query(query) => Some(query.as_ref()),
                Statement::Insert(insert) => insert.source.as_ref().map(|q| q.as_ref()),
                _ => None,
            };

            if let Some(query) = query {
                match &*query.body {
                    SetExpr::Select(select) => {
                        // Look for parameters in WHERE clauses
                        if let Some(selection) = &select.selection {
                            self.collect_parameter_columns_ast(
                                selection,
                                &mut placeholder_to_column,
                                &mut placeholder_index,
                            );
                        }
                    }
                    _ => continue,
                }

                // Look for parameters in LIMIT/OFFSET at the Query level
                // Note: In datafusion 50, LimitClause is now an enum with variants
                if let Some(limit_clause) = &query.limit_clause {
                    use datafusion::sql::sqlparser::ast::LimitClause;
                    match limit_clause {
                        LimitClause::LimitOffset { limit, offset, .. } => {
                            // Check for placeholder in LIMIT expression
                            if let Some(limit_expr) = limit {
                                if self.is_placeholder_ast(limit_expr) {
                                    placeholder_to_column
                                        .insert(placeholder_index, "limit".to_string());
                                    placeholder_index += 1;
                                }
                            }
                            // Check for placeholder in OFFSET expression
                            if let Some(offset_info) = offset {
                                if self.is_placeholder_ast(&offset_info.value) {
                                    placeholder_to_column
                                        .insert(placeholder_index, "offset".to_string());
                                    placeholder_index += 1;
                                }
                            }
                        }
                        LimitClause::OffsetCommaLimit {
                            offset: offset_expr,
                            limit: limit_expr,
                        } => {
                            // MySQL-style: LIMIT offset, limit
                            if self.is_placeholder_ast(offset_expr) {
                                placeholder_to_column
                                    .insert(placeholder_index, "offset".to_string());
                                placeholder_index += 1;
                            }
                            if self.is_placeholder_ast(limit_expr) {
                                placeholder_to_column
                                    .insert(placeholder_index, "limit".to_string());
                                placeholder_index += 1;
                            }
                        }
                    }
                }
            }

            // Handle UPDATE SET ... WHERE and DELETE ... WHERE
            match statement {
                Statement::Update {
                    selection: Some(selection),
                    ..
                } => {
                    self.collect_parameter_columns_ast(
                        selection,
                        &mut placeholder_to_column,
                        &mut placeholder_index,
                    );
                }
                Statement::Delete(delete) => {
                    if let Some(selection) = &delete.selection {
                        self.collect_parameter_columns_ast(
                            selection,
                            &mut placeholder_to_column,
                            &mut placeholder_index,
                        );
                    }
                }
                _ => {}
            }
        }

        Ok(placeholder_to_column)
    }

    /// Recursively collect column names that are used with ? placeholders in AST expressions
    fn collect_parameter_columns_ast(
        &self,
        expr: &Expr,
        placeholder_to_column: &mut HashMap<usize, String>,
        placeholder_index: &mut usize,
    ) {
        match expr {
            Expr::BinaryOp { left, op: _, right } => {
                // Look for patterns like: column_name = ? or ? = column_name
                if let (Some(column_name), true) = (
                    self.extract_column_name_ast(left),
                    self.is_placeholder_ast(right),
                ) {
                    placeholder_to_column.insert(*placeholder_index, column_name);
                    *placeholder_index += 1;
                }
                if let (Some(column_name), true) = (
                    self.extract_column_name_ast(right),
                    self.is_placeholder_ast(left),
                ) {
                    placeholder_to_column.insert(*placeholder_index, column_name);
                    *placeholder_index += 1;
                }

                // Recursively search in sub-expressions
                self.collect_parameter_columns_ast(left, placeholder_to_column, placeholder_index);
                self.collect_parameter_columns_ast(right, placeholder_to_column, placeholder_index);
            }
            Expr::Nested(inner) => {
                self.collect_parameter_columns_ast(inner, placeholder_to_column, placeholder_index);
            }
            Expr::InList { expr, list, .. } => {
                // Handle IN clauses: column_name IN (?, ?, ?)
                if let Some(column_name) = self.extract_column_name_ast(expr) {
                    for item in list {
                        if self.is_placeholder_ast(item) {
                            placeholder_to_column.insert(*placeholder_index, column_name.clone());
                            *placeholder_index += 1;
                        }
                    }
                }
            }
            _ => {}
        }
    }

    /// Check if an AST expression is a placeholder (?)
    fn is_placeholder_ast(&self, expr: &Expr) -> bool {
        matches!(
            expr,
            Expr::Value(datafusion::sql::sqlparser::ast::ValueWithSpan {
                value: datafusion::sql::sqlparser::ast::Value::Placeholder(_),
                ..
            })
        )
    }

    /// Extract column name from an AST expression if it's a simple identifier
    fn extract_column_name_ast(&self, expr: &Expr) -> Option<String> {
        match expr {
            Expr::Identifier(ident) => Some(ident.value.clone()),
            Expr::CompoundIdentifier(parts) => {
                // Handle compound identifiers like table.column; `Ident::value`
                // holds the name with any quoting already stripped
                if parts.len() == 1 {
                    Some(parts[0].value.clone())
                } else if parts.len() == 2 {
                    Some(parts[1].value.clone()) // table.column -> column
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    /// Infer likely column name from parameter name using common patterns
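    /// (e.g. `min_price`, `max_price`, `price_min`, and `price_max` all map to `price`)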
    fn infer_column_from_parameter_name(&self, param_name: &str) -> String {
        // Handle common parameter naming patterns
        if param_name.starts_with("min_") {
            param_name
                .strip_prefix("min_")
                .unwrap_or(param_name)
                .to_string()
        } else if param_name.starts_with("max_") {
            param_name
                .strip_prefix("max_")
                .unwrap_or(param_name)
                .to_string()
        } else if param_name.ends_with("_min") {
            param_name
                .strip_suffix("_min")
                .unwrap_or(param_name)
                .to_string()
        } else if param_name.ends_with("_max") {
            param_name
                .strip_suffix("_max")
                .unwrap_or(param_name)
                .to_string()
        } else {
            // Default: assume parameter name matches column name
            param_name.to_string()
        }
    }

    /// Get column type from registered tables using the SessionContext
    async fn get_column_type_from_registered_tables(
        &self,
        column_name: &str,
    ) -> Result<Option<DataType>> {
        // Get table names from the default catalog
        // TODO: Allow custom catalog and schema
        let catalog = self
            .datafusion_ctx
            .catalog("datafusion")
            .ok_or_else(|| anyhow!("Default catalog not found"))?;
        let schema = catalog
            .schema("public")
            .ok_or_else(|| anyhow!("Public schema not found"))?;

        for table_name in schema.table_names() {
            if let Ok(Some(table)) = schema.table(&table_name).await {
                let table_schema = table.schema();

                // Look for the column in this table's schema
                for field in table_schema.fields() {
                    if field.name() == column_name
                        || field.name() == &format!("\"{}\"", column_name)
                    {
                        return Ok(Some(field.data_type().clone()));
                    }
                }
            }
        }

        Ok(None)
    }

    /// Get column type from context (registered tables or inferred)
    async fn get_column_type_from_context(&self, column_name: &str, sql: &str) -> Result<DataType> {
        // First try to get type from registered table schema
        if let Some(schema_type) = self
            .get_column_type_from_registered_tables(column_name)
            .await?
        {
            Ok(schema_type)
        } else {
            // Fallback to pattern-based inference from SQL context
            Ok(self.infer_type_from_sql_context(sql, column_name))
        }
    }

    /// Get column type with special case handling for LIMIT/OFFSET parameters
    async fn get_column_type_with_special_cases(
        &self,
        column_name: &str,
        sql: &str,
    ) -> Result<DataType> {
        // Handle special cases for SQL keywords that don't correspond to table columns
        match column_name.to_lowercase().as_str() {
            "limit" | "offset" => {
                // LIMIT and OFFSET parameters are always integers
                Ok(DataType::Int64)
            }
            _ => {
                // Use normal column type resolution for actual table columns
                self.get_column_type_from_context(column_name, sql).await
            }
        }
    }

    /// Extract response fields from SQL SELECT clause using DataFusion parsing
    /// against registered tables in the SessionContext.
    /// For DML statements (INSERT/UPDATE/DELETE), returns a `count` field since
    /// these operations return the number of affected rows, not data rows.
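    ///
    /// For example, `DELETE FROM t WHERE id = {id}` yields a single non-nullable
    /// `count: UInt64` field, while a SELECT yields one field per projected column.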
    pub async fn extract_response_fields(
        &self,
        sql: &str,
    ) -> Result<HashMap<String, InferredFieldType>> {
        // Check if this is a DML statement (INSERT/UPDATE/DELETE)
        // DML statements return a count of affected rows, not data
        if self.is_dml_statement(sql) {
            let mut fields = HashMap::new();
            fields.insert(
                "count".to_string(),
                InferredFieldType {
                    field_type: DataType::UInt64,
                    nullable: false,
                    source_location: "DML statement result (rows affected)".to_string(),
                },
            );
            return Ok(fields);
        }

        // Parse the SQL to get the logical plan using the SessionContext
        let logical_plan = self.parse_sql_to_logical_plan(sql).await?;

        // Extract schema from the logical plan
        let schema = logical_plan.schema();
        let mut fields = HashMap::new();

        for field in schema.fields() {
            let field_name = field.name().clone();
            fields.insert(
                field_name.clone(),
                InferredFieldType {
                    field_type: field.data_type().clone(),
                    nullable: field.is_nullable(),
                    source_location: format!("SELECT field '{}'", field_name),
                },
            );
        }

        Ok(fields)
    }

    /// Validate SQL syntax and structure against registered tables in the SessionContext
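    ///
    /// A sketch of consuming the report (hypothetical query):
    ///
    /// ```ignore
    /// let report = inferrer
    ///     .validate_sql("SELECT id FROM t WHERE id = {id}")
    ///     .await?;
    /// if !report.is_valid {
    ///     for error in &report.errors {
    ///         eprintln!("pipeline SQL rejected: {error}");
    ///     }
    /// }
    /// ```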
    pub async fn validate_sql(&self, sql: &str) -> Result<ValidationReport> {
        let mut errors = Vec::new();
        let mut warnings = Vec::new();

        // Replace parameters for parsing and syntax validation
        let sql_for_parsing = self.replace_parameters_for_parsing(sql)?;

        // Check SQL syntax with parameters replaced
        let parse_result = self.parse_sql_syntax(&sql_for_parsing);
        if let Err(e) = parse_result {
            errors.push(format!("SQL syntax error: {}", e));
            return Ok(ValidationReport {
                is_valid: false,
                errors,
                warnings,
                parameter_count: 0,
                response_field_count: 0,
            });
        }

        // Try to create logical plan
        let plan_result = self.parse_sql_to_logical_plan(sql).await;
        let (parameter_count, response_field_count) = match plan_result {
            Ok(_) => {
                // Count parameters and response fields
                let params = self.extract_parameters(sql).await.unwrap_or_default();
                let fields = self.extract_response_fields(sql).await.unwrap_or_default();
                (params.len(), fields.len())
            }
            Err(e) => {
                warnings.push(format!("Logical plan creation warning: {}", e));
                // Still try to count parameters from the original SQL
                let params = self.extract_parameters(sql).await.unwrap_or_default();
                (params.len(), 0)
            }
        };

        // Check for common issues
        if parameter_count == 0 {
            warnings.push("No parameters found in SQL query".to_string());
        }

        if response_field_count == 0 {
            warnings.push("No response fields found in SQL query".to_string());
        }

        // Validate named parameters follow Rust identifier rules
        let named_params = self.extract_named_parameters(sql).await.unwrap_or_default();
        for (param_name, _) in &named_params {
            if !self.is_valid_rust_identifier(param_name) {
                warnings.push(format!(
                    "Parameter '{}' does not follow Rust identifier naming rules",
                    param_name
                ));
            }
        }

        Ok(ValidationReport {
            is_valid: errors.is_empty(),
            errors,
            warnings,
            parameter_count,
            response_field_count,
        })
    }

    /// Check if the SQL statement is a DML statement (INSERT/UPDATE/DELETE)
    fn is_dml_statement(&self, sql: &str) -> bool {
        let trimmed = sql.trim().to_uppercase();
        trimmed.starts_with("INSERT")
            || trimmed.starts_with("UPDATE")
            || trimmed.starts_with("DELETE")
    }

    /// Check if parameter name follows Rust identifier rules
    fn is_valid_rust_identifier(&self, name: &str) -> bool {
        let mut chars = name.chars();
        let Some(first_char) = chars.next() else {
            return false;
        };
        if !first_char.is_alphabetic() && first_char != '_' {
            return false;
        }

        chars.all(|c| c.is_alphanumeric() || c == '_')
    }

    /// Parse SQL syntax using sqlparser
    fn parse_sql_syntax(&self, sql: &str) -> Result<Vec<Statement>> {
        let dialect = GenericDialect {};
        Parser::parse_sql(&dialect, sql).map_err(|e| anyhow!("Failed to parse SQL syntax: {}", e))
    }

    /// Parse SQL to DataFusion logical plan using the SessionContext
    async fn parse_sql_to_logical_plan(&self, sql: &str) -> Result<LogicalPlan> {
        // Replace parameters with placeholder values for parsing
        let sql_for_parsing = self.replace_parameters_for_parsing(sql)?;

        // Create logical plan using the SessionContext
        let plan = self
            .datafusion_ctx
            .sql(&sql_for_parsing)
            .await?
            .into_optimized_plan()
            .map_err(|e| anyhow!("Failed to create logical plan: {}", e))?;

        Ok(plan)
    }

    /// Replace {parameter_name} placeholders with NULL for DataFusion parsing.
    ///
    /// NULL works for all SQL contexts:
    /// - `LIMIT NULL` is valid in DataFusion (treated as no limit)
    /// - UDFs like `candle()` have hardcoded return types independent of argument types,
    ///   so `candle('model', NULL)` still resolves to `List<Float32>` as long as the UDF
    ///   is registered in the SessionContext before pipeline loading.
    /// - `VALUES {rows}` (the multi-row tuple-list shape) gets a special-case
    ///   pre-pass that emits `VALUES (NULL)` rather than the malformed
    ///   `VALUES NULL`, since DataFusion rejects the latter at parse time.
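    ///
    /// For example, `SELECT * FROM t WHERE brand = {brand} LIMIT {limit}`
    /// becomes `SELECT * FROM t WHERE brand = NULL LIMIT NULL`.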
    fn replace_parameters_for_parsing(&self, sql: &str) -> Result<String> {
        // Pre-pass: `VALUES {name}` → `VALUES (NULL)` so the multi-row
        // tuple-list shape parses (`VALUES NULL` is not valid SQL).
        let values_pattern = regex::Regex::new(r"(?i)\bVALUES\s*\{[a-zA-Z_][a-zA-Z0-9_]*\}")
            .map_err(|e| anyhow!("Failed to compile VALUES placeholder regex: {}", e))?;
        let sql = values_pattern.replace_all(sql, "VALUES (NULL)").to_string();

        let parameter_pattern = regex::Regex::new(r"\{[a-zA-Z_][a-zA-Z0-9_]*\}")
            .map_err(|e| anyhow!("Failed to compile parameter regex: {}", e))?;
        Ok(parameter_pattern.replace_all(&sql, "NULL").to_string())
    }

    /// Infer data type from SQL context when table schema is not available
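    /// (e.g. `WHERE score > 100` infers Float64, `LIMIT count` infers Int64,
    /// and anything else defaults to Utf8)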
    fn infer_type_from_sql_context(&self, sql: &str, column_name: &str) -> DataType {
        if self.is_integer_context(sql, column_name) {
            DataType::Int64
        } else if self.is_numeric_context(sql, column_name) {
            DataType::Float64
        } else {
            // Default to string type
            DataType::Utf8
        }
    }

    /// Check if a column appears in numeric context (e.g., comparison with numbers)
    /// TODO: Handle comparisons with timestamps
    fn is_numeric_context(&self, sql: &str, column_name: &str) -> bool {
        let column_pattern = format!(r"\b{}\b", regex::escape(column_name));
        // `\d+\.?\d*` matches both integer and decimal literals, so one
        // column-on-left and one column-on-right pattern cover all comparisons.
        let numeric_patterns = [
            format!(r"{}\s*[<>=!]+\s*\d+\.?\d*", column_pattern),
            format!(r"\d+\.?\d*\s*[<>=!]+\s*{}", column_pattern),
        ];

        numeric_patterns
            .iter()
            .any(|pattern| regex::Regex::new(pattern).map_or(false, |regex| regex.is_match(sql)))
    }

    /// Check if a column appears in integer context (e.g., LIMIT clause)
    fn is_integer_context(&self, sql: &str, column_name: &str) -> bool {
        let column_pattern = regex::escape(column_name);
        // (?i) makes the whole pattern case-insensitive, so neither the SQL
        // nor the column name needs to be uppercased before matching.
        let integer_patterns = [
            format!(r"(?i)LIMIT\s+{}", column_pattern),
            format!(r"(?i)OFFSET\s+{}", column_pattern),
            format!(r"(?i)TOP\s+{}", column_pattern),
        ];

        integer_patterns
            .iter()
            .any(|pattern| regex::Regex::new(pattern).map_or(false, |regex| regex.is_match(sql)))
    }
}

impl Default for SqlSchemaInferrer {
    fn default() -> Self {
        Self::new(Arc::new(SessionContext::new())).expect("Failed to create SqlSchemaInferrer")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Float64Array, Int64Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    async fn create_test_context() -> SessionContext {
        let ctx = SessionContext::new();

        // Create a test products table schema
        let schema = Arc::new(Schema::new(vec![
            Field::new("product_id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("brand", DataType::Utf8, true),
            Field::new("price", DataType::Float64, false),
            Field::new("category", DataType::Utf8, true),
            Field::new("stock", DataType::Int64, false),
        ]));

        // Create sample data
        let product_ids = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
        let names = Arc::new(StringArray::from(vec![
            "Laptop Pro",
            "Gaming Mouse",
            "Wireless Headphones",
            "Tablet",
            "Smartphone",
        ]));
        let brands = Arc::new(StringArray::from(vec![
            Some("Apple"),
            Some("Logitech"),
            Some("Sony"),
            Some("Samsung"),
            Some("Google"),
        ]));
        let prices = Arc::new(Float64Array::from(vec![
            1299.99, 79.99, 199.99, 599.99, 899.99,
        ]));
        let categories = Arc::new(StringArray::from(vec![
            Some("Electronics"),
            Some("Accessories"),
            Some("Audio"),
            Some("Electronics"),
            Some("Electronics"),
        ]));
        let stock = Arc::new(Int64Array::from(vec![50, 120, 75, 30, 85]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![product_ids, names, brands, prices, categories, stock],
        )
        .unwrap();

        // Register the table
        ctx.register_batch("products", batch).unwrap();

        ctx
    }

    #[tokio::test]
    async fn test_extract_parameters() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let sql = r#"
            SELECT product_id, name, price
            FROM products
            WHERE brand = {brand}
            AND price < {price}
        "#;

        let params = inferrer.extract_parameters(sql).await.unwrap();

        assert_eq!(params.len(), 2);
        assert!(params.contains_key("brand"));
        assert!(params.contains_key("price"));

        // All parameters should be nullable
        assert!(params["brand"].nullable);
        assert!(params["price"].nullable);

        // Check type inference based on registered table schema
        assert_eq!(params["brand"].field_type, DataType::Utf8);
        assert_eq!(params["price"].field_type, DataType::Float64);
    }

    #[tokio::test]
    async fn test_extract_response_fields() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let sql = r#"
            SELECT
                product_id,
                name as product_name,
                price,
                brand
            FROM products
            WHERE price > {price}
        "#;

        let fields = inferrer.extract_response_fields(sql).await.unwrap();

        assert_eq!(fields.len(), 4);
        assert!(fields.contains_key("product_id"));
        assert!(fields.contains_key("product_name"));
        assert!(fields.contains_key("price"));
        assert!(fields.contains_key("brand"));

        // Check field types match table schema
        assert_eq!(fields["product_id"].field_type, DataType::Int64);
        assert_eq!(fields["product_name"].field_type, DataType::Utf8);
        assert_eq!(fields["price"].field_type, DataType::Float64);
        assert_eq!(fields["brand"].field_type, DataType::Utf8);

        // Check nullability
        assert!(!fields["product_id"].nullable);
        assert!(!fields["product_name"].nullable);
        assert!(!fields["price"].nullable);
        assert!(fields["brand"].nullable);
    }

    #[tokio::test]
    async fn test_validate_sql_with_registered_table() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        // Valid SQL with registered table
        let valid_sql = "SELECT product_id, name FROM products WHERE price > {price}";
        let report = inferrer.validate_sql(valid_sql).await.unwrap();
        assert!(report.is_valid);
        assert_eq!(report.parameter_count, 1);
        assert_eq!(report.response_field_count, 2);

        // Invalid SQL - clear syntax error
        let invalid_sql = "SELECT * FROM table WHERE column =";
        let report = inferrer.validate_sql(invalid_sql).await.unwrap();
        assert!(!report.is_valid);
        assert!(!report.errors.is_empty());

        // Valid syntax but references non-existent table
        let nonexistent_table_sql = "SELECT id FROM missing_table WHERE value = {value}";
        let report = inferrer.validate_sql(nonexistent_table_sql).await.unwrap();
        // This should generate warnings but not necessarily be invalid
        assert_eq!(report.parameter_count, 1); // Fallback inference should still detect the parameter
    }

    #[test]
    fn test_parameter_type_inference() {
        let ctx = SessionContext::new();
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        // Test numeric context
        assert!(inferrer.is_numeric_context("WHERE price > 100", "price"));
        assert!(inferrer.is_numeric_context("WHERE score <= 100.5", "score"));
        assert!(inferrer.is_numeric_context("WHERE value = 42", "value"));

        // Test integer context - note that these patterns need the column in a LIMIT/OFFSET context
        assert!(inferrer.is_integer_context("SELECT * FROM table LIMIT count", "count"));
        assert!(inferrer.is_integer_context("SELECT * FROM table OFFSET skip", "skip"));

        // Test string context (default)
        assert!(!inferrer.is_numeric_context("WHERE name = 'brand'", "name"));
        assert!(!inferrer.is_integer_context("WHERE category = 'type'", "category"));
    }

    #[tokio::test]
    async fn test_fallback_type_inference() {
        // Test fallback type inference when no tables are registered
        let ctx = SessionContext::new(); // No tables registered
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        // Test the pattern-based inference logic directly
        assert_eq!(
            inferrer.infer_type_from_sql_context("WHERE score > 100", "score"),
            DataType::Float64
        );
        assert_eq!(
            inferrer.infer_type_from_sql_context("LIMIT count", "count"),
            DataType::Int64
        );
        assert_eq!(
            inferrer.infer_type_from_sql_context("WHERE name = 'test'", "name"),
            DataType::Utf8
        );

        // Test full parameter extraction with fallback
        let sql = r#"
            SELECT * FROM unknown_table
            WHERE name = {name}
            AND description = {description}
        "#;

        let params = inferrer.extract_parameters(sql).await.unwrap();

        assert_eq!(params.len(), 2);

        // Since there are no numeric patterns in this SQL, both should default to string
        assert_eq!(params["name"].field_type, DataType::Utf8);
        assert_eq!(params["description"].field_type, DataType::Utf8);

        // Test that all parameters are nullable
        assert!(params["name"].nullable);
        assert!(params["description"].nullable);
    }

    #[tokio::test]
    async fn test_complex_parameter_patterns() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let complex_sql = r#"
            SELECT product_id, name, price, brand
            FROM products
            WHERE brand = {brand}
            AND price > {price}
            AND stock > {stock}
            ORDER BY price DESC
        "#;

        let params = inferrer.extract_parameters(complex_sql).await.unwrap();

        // Should find parameters for the named parameters
        assert_eq!(params.len(), 3); // brand, price, stock
        assert!(params.contains_key("brand"));
        assert!(params.contains_key("price"));
        assert!(params.contains_key("stock"));

        // Check type inference from table schema
        assert_eq!(params["brand"].field_type, DataType::Utf8);
        assert_eq!(params["price"].field_type, DataType::Float64);
        assert_eq!(params["stock"].field_type, DataType::Int64);
    }

    #[tokio::test]
    async fn test_named_parameter_extraction() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let sql = r#"
            SELECT product_id, name, price, brand
            FROM products
            WHERE brand = {brand}
            AND price >= {min_price}
            AND price <= {max_price}
            AND category = {category}
            LIMIT {limit}
        "#;

        let named_params = inferrer.extract_named_parameters(sql).await.unwrap();

        // Should find all 5 named parameters
        assert_eq!(named_params.len(), 5);
        assert!(named_params.contains_key("brand"));
        assert!(named_params.contains_key("min_price"));
        assert!(named_params.contains_key("max_price"));
        assert!(named_params.contains_key("category"));
        assert!(named_params.contains_key("limit"));

        // Check parameter metadata
        let brand_param = &named_params["brand"];
        assert_eq!(brand_param.name, "brand");
        assert_eq!(brand_param.column_name, "brand");
        assert_eq!(brand_param.field_type, DataType::Utf8);
        assert!(brand_param.nullable);
        assert_eq!(brand_param.occurrences, 1);

        // Check min_price and max_price both map to price column
        let min_price_param = &named_params["min_price"];
        assert_eq!(min_price_param.name, "min_price");
        assert_eq!(min_price_param.column_name, "price");
        assert_eq!(min_price_param.field_type, DataType::Float64);

        let max_price_param = &named_params["max_price"];
        assert_eq!(max_price_param.name, "max_price");
        assert_eq!(max_price_param.column_name, "price");
        assert_eq!(max_price_param.field_type, DataType::Float64);

        // Check limit parameter
        let limit_param = &named_params["limit"];
        assert_eq!(limit_param.name, "limit");
        assert_eq!(limit_param.field_type, DataType::Int64);
    }

    #[tokio::test]
    async fn test_parameter_replacement_for_parsing() {
        let ctx = SessionContext::new();
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let sql =
            "SELECT * FROM products WHERE brand = {brand} AND price > {min_price} LIMIT {limit}";
        let replaced = inferrer.replace_parameters_for_parsing(sql).unwrap();

        // All parameters replaced with NULL
        let expected = "SELECT * FROM products WHERE brand = NULL AND price > NULL LIMIT NULL";
        assert_eq!(replaced, expected);
    }

    /// `VALUES {rows}` must produce `VALUES (NULL)` rather than the
    /// malformed `VALUES NULL` so the logical-plan path also accepts the
    /// multi-row tuple-list shape.
    #[tokio::test]
    async fn test_parameter_replacement_for_parsing_values_tuple_list() {
        let ctx = SessionContext::new();
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let replaced = inferrer
            .replace_parameters_for_parsing("INSERT INTO t (a, b) VALUES {rows}")
            .unwrap();
        assert_eq!(replaced, "INSERT INTO t (a, b) VALUES (NULL)");

        // Multi-line YAML form survives the regex.
        let replaced = inferrer
            .replace_parameters_for_parsing(
                "INSERT INTO products (product_id, name)\nVALUES {rows}",
            )
            .unwrap();
        assert_eq!(
            replaced,
            "INSERT INTO products (product_id, name)\nVALUES (NULL)"
        );
    }

    #[tokio::test]
    async fn test_rust_identifier_validation() {
        let ctx = SessionContext::new();
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        // Valid Rust identifiers
        assert!(inferrer.is_valid_rust_identifier("valid_name"));
        assert!(inferrer.is_valid_rust_identifier("_underscore"));
        assert!(inferrer.is_valid_rust_identifier("name123"));
        assert!(inferrer.is_valid_rust_identifier("CamelCase"));

        // Invalid Rust identifiers
        assert!(!inferrer.is_valid_rust_identifier("123invalid"));
        assert!(!inferrer.is_valid_rust_identifier("invalid-name"));
        assert!(!inferrer.is_valid_rust_identifier("invalid.name"));
        assert!(!inferrer.is_valid_rust_identifier(""));
    }

    #[tokio::test]
    async fn test_parameter_name_to_column_inference() {
        let ctx = SessionContext::new();
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        // Test common naming patterns
        assert_eq!(
            inferrer.infer_column_from_parameter_name("min_price"),
            "price"
        );
        assert_eq!(
            inferrer.infer_column_from_parameter_name("max_price"),
            "price"
        );
        assert_eq!(
            inferrer.infer_column_from_parameter_name("price_min"),
            "price"
        );
        assert_eq!(
            inferrer.infer_column_from_parameter_name("price_max"),
            "price"
        );
        assert_eq!(inferrer.infer_column_from_parameter_name("brand"), "brand");
        assert_eq!(
            inferrer.infer_column_from_parameter_name("category"),
            "category"
        );
    }

    #[tokio::test]
    async fn test_ast_based_parameter_extraction() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        // Complex SQL mixing several parameter shapes; exercises the accuracy of AST-based column mapping
        let sql = r#"
            SELECT product_id, name, price, brand, category
            FROM products
            WHERE (brand = {brand_filter} OR brand = {backup_brand})
            AND price >= {min_price} AND price <= {max_price}
            AND category IN ({primary_category}, {secondary_category})
            AND stock > {min_stock}
            ORDER BY price ASC
            LIMIT {limit}
        "#;

        let named_params = inferrer.extract_named_parameters(sql).await.unwrap();

        // Should find all unique named parameters
        assert_eq!(named_params.len(), 8); // brand_filter, backup_brand, min_price, max_price, primary_category, secondary_category, min_stock, limit

        // Test AST-based column mapping accuracy
        let brand_filter = &named_params["brand_filter"];
        assert_eq!(brand_filter.column_name, "brand");
        assert_eq!(brand_filter.field_type, DataType::Utf8);

        let backup_brand = &named_params["backup_brand"];
        assert_eq!(backup_brand.column_name, "brand");
        assert_eq!(backup_brand.field_type, DataType::Utf8);

        // Test range comparison mapping (price >= {min_price} AND price <= {max_price})
        let min_price = &named_params["min_price"];
        assert_eq!(min_price.column_name, "price");
        assert_eq!(min_price.field_type, DataType::Float64);

        let max_price = &named_params["max_price"];
        assert_eq!(max_price.column_name, "price");
        assert_eq!(max_price.field_type, DataType::Float64);

        // Test IN clause mapping (multiple parameters for same column)
        let primary_category = &named_params["primary_category"];
        assert_eq!(primary_category.column_name, "category");
        assert_eq!(primary_category.field_type, DataType::Utf8);

        let secondary_category = &named_params["secondary_category"];
        assert_eq!(secondary_category.column_name, "category");
        assert_eq!(secondary_category.field_type, DataType::Utf8);

        // Test stock comparison
        let min_stock = &named_params["min_stock"];
        assert_eq!(min_stock.column_name, "stock");
        assert_eq!(min_stock.field_type, DataType::Int64);

        // Test LIMIT special case handling
        let limit_param = &named_params["limit"];
        assert_eq!(limit_param.column_name, "limit");
        assert_eq!(limit_param.field_type, DataType::Int64); // Special case: always integer

        // Verify source location indicates AST analysis
        assert!(brand_filter.source_location.contains("AST analysis"));
        assert!(limit_param.source_location.contains("placeholder index"));
    }

    #[tokio::test]
    async fn test_convert_named_to_placeholders() {
        let ctx = SessionContext::new();
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let sql = "SELECT * FROM products WHERE brand = {brand} AND price > {min_price} AND price < {min_price} LIMIT {limit}";
        let (sql_with_placeholders, parameter_order) =
            inferrer.convert_named_to_placeholders(sql).unwrap();

        // Should replace all {param} with ? and preserve unique parameter order
        assert_eq!(
            sql_with_placeholders,
            "SELECT * FROM products WHERE brand = ? AND price > ? AND price < ? LIMIT ?"
        );
        assert_eq!(parameter_order, vec!["brand", "min_price", "limit"]); // Unique parameters only

        // Test parameter order preservation with repeated parameters
        assert_eq!(parameter_order.len(), 3); // Only unique parameters
        assert!(parameter_order.contains(&"brand".to_string()));
        assert!(parameter_order.contains(&"min_price".to_string()));
        assert!(parameter_order.contains(&"limit".to_string()));
    }

    /// `INSERT … VALUES {rows}` is the multi-row tuple-list shape the
    /// server-side renderer expands at request time. The parser stub must
    /// emit `VALUES (?)`, not the malformed `VALUES ?`, otherwise pipeline
    /// load fails with `Expected: (, found: ?` and the pipeline never
    /// becomes available — the same class of bug the SQL validator fix
    /// addressed at config-load time.
    #[tokio::test]
    async fn test_convert_named_to_placeholders_values_tuple_list() {
        let ctx = SessionContext::new();
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let (sql_with_placeholders, parameter_order) = inferrer
            .convert_named_to_placeholders("INSERT INTO users (name, email) VALUES {rows}")
            .unwrap();
        assert_eq!(
            sql_with_placeholders,
            "INSERT INTO users (name, email) VALUES (?)"
        );
        assert_eq!(parameter_order, vec!["rows"]);

        // Multi-line YAML form (`query: |` produces a newline between
        // `(cols)` and `VALUES`) — the regex must still match.
        let (sql_with_placeholders, _) = inferrer
            .convert_named_to_placeholders("INSERT INTO products (product_id, name)\nVALUES {rows}")
            .unwrap();
        assert_eq!(
            sql_with_placeholders,
            "INSERT INTO products (product_id, name)\nVALUES (?)"
        );

        // The rendered stub must be a parseable SQL statement — i.e. the
        // pipeline-loader's `parse_sql_syntax` accepts it.
        Parser::parse_sql(&GenericDialect {}, &sql_with_placeholders)
            .expect("VALUES (?) stub must parse");
    }

    #[tokio::test]
    async fn test_dml_response_schema_insert() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        // INSERT statement should return a count field, not try to create a logical plan
        let sql = "INSERT INTO products (name, price) VALUES ({name}, {price})";
        let fields = inferrer.extract_response_fields(sql).await.unwrap();

        assert_eq!(fields.len(), 1);
        assert!(fields.contains_key("count"));
        assert_eq!(fields["count"].field_type, DataType::UInt64);
        assert!(!fields["count"].nullable);
    }

    #[tokio::test]
    async fn test_dml_response_schema_update() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let sql = "UPDATE products SET price = {price} WHERE name = {name}";
        let fields = inferrer.extract_response_fields(sql).await.unwrap();

        assert_eq!(fields.len(), 1);
        assert!(fields.contains_key("count"));
        assert_eq!(fields["count"].field_type, DataType::UInt64);
    }

    #[tokio::test]
    async fn test_dml_response_schema_delete() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        let sql = "DELETE FROM products WHERE name = {name}";
        let fields = inferrer.extract_response_fields(sql).await.unwrap();

        assert_eq!(fields.len(), 1);
        assert!(fields.contains_key("count"));
        assert_eq!(fields["count"].field_type, DataType::UInt64);
    }

    #[tokio::test]
    async fn test_insert_select_parameter_extraction() {
        let ctx = create_test_context().await;
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        // INSERT INTO ... SELECT with WHERE clause parameter
        let sql = r#"
            INSERT INTO products (name, price)
            SELECT name, price
            FROM products
            WHERE brand = {brand}
        "#;

        let named_params = inferrer.extract_named_parameters(sql).await.unwrap();

        assert_eq!(named_params.len(), 1);
        assert!(named_params.contains_key("brand"));
        assert_eq!(named_params["brand"].column_name, "brand");
        assert_eq!(named_params["brand"].field_type, DataType::Utf8);
    }

    #[tokio::test]
    async fn test_dml_is_detected() {
        let ctx = SessionContext::new();
        let inferrer = SqlSchemaInferrer::new(Arc::new(ctx)).unwrap();

        assert!(inferrer.is_dml_statement("INSERT INTO t (a) VALUES (1)"));
        assert!(inferrer.is_dml_statement("  UPDATE t SET a = 1"));
        assert!(inferrer.is_dml_statement("DELETE FROM t WHERE id = 1"));
        assert!(inferrer.is_dml_statement("  insert into t (a) values (1)"));
        assert!(!inferrer.is_dml_statement("SELECT * FROM t"));
        assert!(!inferrer.is_dml_statement("  select * from t"));
    }
}