Skip to main content

sqlrite/sql/parser/
create.rs

1use sqlparser::ast::{
2    ColumnDef, ColumnOption, CreateTable, DataType, Expr, ObjectName, ObjectNamePart, Statement,
3    UnaryOperator, Value as AstValue,
4};
5
6use crate::error::{Result, SQLRiteError};
7use crate::sql::db::table::Value;
8
9/// True when an `ObjectName` resolves to a single identifier `VECTOR`
10/// (case-insensitive). Phase 7a adds the `VECTOR(N)` column type as a
11/// sqlparser `DataType::Custom` — the engine recognizes it via this
12/// helper so the regular DataType match arm above stays uncluttered.
13fn is_vector_type(name: &ObjectName) -> bool {
14    name.0.len() == 1
15        && match &name.0[0] {
16            ObjectNamePart::Identifier(ident) => ident.value.eq_ignore_ascii_case("VECTOR"),
17            // Function-form ObjectNamePart shouldn't appear in a CREATE TABLE
18            // column type position. If it ever does, treat it as not-a-vector
19            // and the outer match falls through to the "Invalid" arm.
20            _ => false,
21        }
22}
23
/// Extracts the dimension `N` from the parenthesized arguments of a
/// `VECTOR(N)` column type.
///
/// `args` is the list of argument strings sqlparser attaches to a
/// `Custom` data type — `["384"]` for `VECTOR(384)`.
///
/// Returns the dimension when exactly one argument is present and it
/// parses (after trimming surrounding whitespace) as a `usize` of at
/// least 1.
///
/// # Errors
/// Returns a human-readable message when no argument is given, when more
/// than one is given, when the argument is `0`, or when it does not parse
/// as a `usize`.
fn parse_vector_dim(args: &[String]) -> std::result::Result<usize, String> {
    if args.is_empty() {
        return Err("VECTOR requires a dimension, e.g. `VECTOR(384)`".to_string());
    }
    if args.len() > 1 {
        return Err(format!(
            "VECTOR takes exactly one dimension argument (got {})",
            args.len()
        ));
    }

    // Exactly one argument remains at this point.
    let trimmed = args[0].trim();
    match trimmed.parse::<usize>() {
        Ok(0) => Err(format!("VECTOR dimension must be ≥ 1 (got `{trimmed}`)")),
        Ok(dim) => Ok(dim),
        Err(_) => Err(format!(
            "VECTOR dimension must be a positive integer (got `{trimmed}`)"
        )),
    }
}
47
48/// The schema for each SQL column in every table is represented by
49/// the following structure after parsed and tokenized
50#[derive(PartialEq, Debug, Clone)]
51pub struct ParsedColumn {
52    /// Name of the column
53    pub name: String,
54    /// Datatype of the column in String format
55    pub datatype: String,
56    /// Value representing if column is PRIMARY KEY
57    pub is_pk: bool,
58    /// Value representing if column was declared with the NOT NULL Constraint
59    pub not_null: bool,
60    /// Value representing if column was declared with the UNIQUE Constraint
61    pub is_unique: bool,
62    /// Literal value to use when this column is omitted from an INSERT.
63    /// Restricted to literal expressions (integer, real, text, bool, NULL);
64    /// non-literal `DEFAULT` expressions are rejected at CREATE TABLE time.
65    pub default: Option<Value>,
66}
67
68/// The following structure represents a CREATE TABLE query already parsed
69/// and broken down into name and a Vector of `ParsedColumn` metadata
70///
71#[derive(Debug)]
72pub struct CreateQuery {
73    /// name of table after parking and tokenizing of query
74    pub table_name: String,
75    /// Vector of `ParsedColumn` type with column metadata information
76    pub columns: Vec<ParsedColumn>,
77    /// `true` when the statement was `CREATE TABLE IF NOT EXISTS …`.
78    /// When set, re-creating an existing table is a no-op rather than
79    /// an error — matching `CREATE INDEX IF NOT EXISTS` and SQLite.
80    pub if_not_exists: bool,
81}
82
83/// Parses a single sqlparser `ColumnDef` into our internal `ParsedColumn`
84/// representation. Extracted from `CreateQuery::new` so `ALTER TABLE ADD
85/// COLUMN` can reuse the same column-shape parsing without re-implementing
86/// the type / constraint / default plumbing.
87///
88/// Caller-side responsibilities not handled here:
89/// - duplicate column name detection (a multi-column invariant)
90/// - "more than one PRIMARY KEY" detection (a multi-column invariant)
91pub fn parse_one_column(col: &ColumnDef) -> Result<ParsedColumn> {
92    let name = col.name.to_string();
93
94    // Parsing each column for it data type
95    // For now only accepting basic data types
96    let datatype: String = match &col.data_type {
97        DataType::TinyInt(_)
98        | DataType::SmallInt(_)
99        | DataType::Int2(_)
100        | DataType::Int(_)
101        | DataType::Int4(_)
102        | DataType::Int8(_)
103        | DataType::Integer(_)
104        | DataType::BigInt(_) => "Integer".to_string(),
105        DataType::Boolean => "Bool".to_string(),
106        DataType::Text => "Text".to_string(),
107        DataType::Varchar(_bytes) => "Text".to_string(),
108        DataType::Real => "Real".to_string(),
109        DataType::Float(_precision) => "Real".to_string(),
110        DataType::Double(_) => "Real".to_string(),
111        DataType::Decimal(_) => "Real".to_string(),
112        // Phase 7e — `JSON` parses as a unit variant in
113        // sqlparser's DataType enum. JSONB is treated as
114        // an alias (matches PostgreSQL's permissive
115        // behaviour); both store as text under the hood.
116        DataType::JSON | DataType::JSONB => "Json".to_string(),
117        // Phase 7a — `VECTOR(N)` parses as Custom("VECTOR", ["N"]).
118        // sqlparser's SQLite dialect doesn't have a built-in
119        // Vector variant; Custom is what unrecognized type
120        // names + their parenthesized args fall through to.
121        DataType::Custom(name, args) if is_vector_type(name) => match parse_vector_dim(args) {
122            Ok(dim) => format!("vector({dim})"),
123            Err(e) => {
124                return Err(SQLRiteError::General(format!(
125                    "Invalid VECTOR column '{}': {e}",
126                    col.name
127                )));
128            }
129        },
130        other => {
131            eprintln!("not matched on custom type: {other:?}");
132            "Invalid".to_string()
133        }
134    };
135
136    let mut is_pk: bool = false;
137    let mut is_unique: bool = false;
138    let mut not_null: bool = false;
139    let mut default: Option<Value> = None;
140    for column_option in &col.options {
141        match &column_option.option {
142            ColumnOption::PrimaryKey(_) => {
143                // For now, only Integer and Text types can be PRIMARY KEY and Unique
144                // Therefore Indexed.
145                if datatype != "Real" && datatype != "Bool" {
146                    is_pk = true;
147                    is_unique = true;
148                    not_null = true;
149                }
150            }
151            ColumnOption::Unique(_) => {
152                // For now, only Integer and Text types can be UNIQUE
153                // Therefore Indexed.
154                if datatype != "Real" && datatype != "Bool" {
155                    is_unique = true;
156                }
157            }
158            ColumnOption::NotNull => {
159                not_null = true;
160            }
161            ColumnOption::Default(expr) => {
162                default = Some(eval_literal_default(expr, &datatype, &name)?);
163            }
164            _ => (),
165        };
166    }
167
168    Ok(ParsedColumn {
169        name,
170        datatype,
171        is_pk,
172        not_null,
173        is_unique,
174        default,
175    })
176}
177
178/// Evaluates a `DEFAULT <expr>` clause to a runtime `Value`. Restricted to
179/// literal expressions — anything else (function calls, column references,
180/// arithmetic on non-literals, `CURRENT_TIMESTAMP`, …) is rejected with a
181/// typed error so users see the limit at `CREATE TABLE` time rather than
182/// silently accepting a `DEFAULT` we can't honour at INSERT time.
183///
184/// Negative numeric literals come through sqlparser as `UnaryOp { Minus, Value(N) }`;
185/// we unwrap one level of leading `+`/`-` to support `DEFAULT -1` / `DEFAULT +3.14`.
186///
187/// Type-checks the literal against the column's declared datatype and
188/// rejects mismatches (e.g. `INTEGER ... DEFAULT 'foo'`).
189fn eval_literal_default(expr: &Expr, datatype: &str, col_name: &str) -> Result<Value> {
190    let value = match expr {
191        Expr::Value(v) => &v.value,
192        Expr::UnaryOp {
193            op: UnaryOperator::Minus,
194            expr: inner,
195        } => {
196            return match inner.as_ref() {
197                Expr::Value(v) => match &v.value {
198                    AstValue::Number(n, _) => {
199                        let neg = format!("-{n}");
200                        coerce_number_default(&neg, datatype, col_name)
201                    }
202                    _ => Err(SQLRiteError::General(format!(
203                        "DEFAULT for column '{col_name}' must be a literal value"
204                    ))),
205                },
206                _ => Err(SQLRiteError::General(format!(
207                    "DEFAULT for column '{col_name}' must be a literal value"
208                ))),
209            };
210        }
211        Expr::UnaryOp {
212            op: UnaryOperator::Plus,
213            expr: inner,
214        } => {
215            return eval_literal_default(inner, datatype, col_name);
216        }
217        _ => {
218            return Err(SQLRiteError::General(format!(
219                "DEFAULT for column '{col_name}' must be a literal value"
220            )));
221        }
222    };
223
224    match value {
225        AstValue::Null => Ok(Value::Null),
226        AstValue::Boolean(b) => {
227            if datatype == "Bool" {
228                Ok(Value::Bool(*b))
229            } else {
230                Err(SQLRiteError::General(format!(
231                    "DEFAULT type mismatch for column '{col_name}': boolean is not a {datatype}"
232                )))
233            }
234        }
235        AstValue::SingleQuotedString(s) => {
236            if datatype == "Text" {
237                Ok(Value::Text(s.clone()))
238            } else if datatype == "Json" {
239                // JSON columns accept text literals only if they parse as
240                // JSON — otherwise an ALTER TABLE ADD COLUMN ... JSON
241                // DEFAULT '<garbage>' would silently backfill every row
242                // with invalid JSON (insert_row's per-row JSON validation
243                // is bypassed during the backfill path).
244                serde_json::from_str::<serde_json::Value>(s).map_err(|e| {
245                    SQLRiteError::General(format!(
246                        "DEFAULT type mismatch for column '{col_name}': '{s}' is not valid JSON: {e}"
247                    ))
248                })?;
249                Ok(Value::Text(s.clone()))
250            } else {
251                Err(SQLRiteError::General(format!(
252                    "DEFAULT type mismatch for column '{col_name}': text is not a {datatype}"
253                )))
254            }
255        }
256        AstValue::Number(n, _) => coerce_number_default(n, datatype, col_name),
257        _ => Err(SQLRiteError::General(format!(
258            "DEFAULT for column '{col_name}' must be a literal value"
259        ))),
260    }
261}
262
263fn coerce_number_default(n: &str, datatype: &str, col_name: &str) -> Result<Value> {
264    match datatype {
265        "Integer" => n.parse::<i64>().map(Value::Integer).map_err(|_| {
266            SQLRiteError::General(format!(
267                "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid INTEGER"
268            ))
269        }),
270        "Real" => n.parse::<f64>().map(Value::Real).map_err(|_| {
271            SQLRiteError::General(format!(
272                "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid REAL"
273            ))
274        }),
275        other => Err(SQLRiteError::General(format!(
276            "DEFAULT type mismatch for column '{col_name}': numeric literal is not a {other}"
277        ))),
278    }
279}
280
281impl CreateQuery {
282    pub fn new(statement: &Statement) -> Result<CreateQuery> {
283        match statement {
284            // Confirming the Statement is sqlparser::ast:Statement::CreateTable
285            Statement::CreateTable(CreateTable {
286                name,
287                columns,
288                constraints,
289                if_not_exists,
290                ..
291            }) => {
292                let table_name = name;
293                let mut parsed_columns: Vec<ParsedColumn> = vec![];
294
295                // Iterating over the columns returned form the Parser::parse:sql
296                // in the mod sql
297                for col in columns {
298                    // Checks if columm already added to parsed_columns, if so, returns an error
299                    let name = col.name.to_string();
300                    if parsed_columns.iter().any(|c| c.name == name) {
301                        return Err(SQLRiteError::Internal(format!(
302                            "Duplicate column name: {}",
303                            &name
304                        )));
305                    }
306
307                    let parsed = parse_one_column(col)?;
308
309                    // Multi-column invariant: only one PRIMARY KEY per table.
310                    if parsed.is_pk && parsed_columns.iter().any(|c| c.is_pk) {
311                        return Err(SQLRiteError::Internal(format!(
312                            "Table '{}' has more than one primary key",
313                            &table_name
314                        )));
315                    }
316
317                    parsed_columns.push(parsed);
318                }
319                // TODO: handle constraints + check constraints + ON DELETE /
320                // ON UPDATE referential actions properly. They're currently
321                // parsed by `sqlparser` and dropped on the floor here.
322                // (Previously we `println!`-ed them to stdout as a debug
323                // aid — removed in the engine-stdout-pollution cleanup;
324                // flip to a `tracing` span if we ever want them visible in
325                // dev builds.)
326                let _ = constraints;
327                Ok(CreateQuery {
328                    table_name: table_name.to_string(),
329                    columns: parsed_columns,
330                    if_not_exists: *if_not_exists,
331                })
332            }
333
334            _ => Err(SQLRiteError::Internal("Error parsing query".to_string())),
335        }
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::*;

    /// Parses `sql` with the SQLRite dialect and returns its single
    /// statement, failing the test if parsing fails or if the input does
    /// not contain exactly one statement.
    fn parse_single_statement(sql: &str) -> Statement {
        let dialect = SqlriteDialect::new();
        let mut ast = Parser::parse_sql(&dialect, sql).unwrap();
        assert_eq!(ast.len(), 1, "expected exactly one Statement");
        ast.pop().unwrap()
    }

    /// The table name of a `CREATE TABLE` statement must surface on
    /// `CreateQuery::table_name`.
    #[test]
    fn create_table_validate_tablename_test() {
        let query = parse_single_statement(
            "CREATE TABLE contacts (
            id INTEGER PRIMARY KEY,
            first_name TEXT NOT NULL,
            last_name TEXT NOT NULl,
            email TEXT NOT NULL UNIQUE
        );",
        );

        // `CreateQuery::new` is called unconditionally: it rejects any
        // statement that is not a CREATE TABLE, so the test can no longer
        // pass vacuously, and `expect` reports the underlying error
        // instead of discarding it.
        let payload = CreateQuery::new(&query)
            .expect("CreateQuery::new should accept a valid CREATE TABLE statement");
        assert_eq!(payload.table_name, "contacts");
    }

    /// SQLR-10 — the `IF NOT EXISTS` clause must surface on `CreateQuery`
    /// so the executor can treat a re-create as a no-op.
    #[test]
    fn create_query_captures_if_not_exists_flag() {
        // Without IF NOT EXISTS → flag is false.
        let q = parse_single_statement("CREATE TABLE t (id INTEGER PRIMARY KEY);");
        assert!(!CreateQuery::new(&q).unwrap().if_not_exists);

        // With IF NOT EXISTS → flag is true.
        let q = parse_single_statement("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY);");
        assert!(CreateQuery::new(&q).unwrap().if_not_exists);
    }
}