use sqlparser::ast::{
ColumnDef, ColumnOption, CreateTable, DataType, Expr, ObjectName, ObjectNamePart, Statement,
UnaryOperator, Value as AstValue,
};
use crate::error::{Result, SQLRiteError};
use crate::sql::db::table::Value;
/// Returns `true` when `name` is a single, unqualified identifier spelling
/// `VECTOR` (compared ASCII case-insensitively).
///
/// Multi-part names and non-identifier name parts are never treated as the
/// vector type.
fn is_vector_type(name: &ObjectName) -> bool {
    match name.0.as_slice() {
        [ObjectNamePart::Identifier(ident)] => ident.value.eq_ignore_ascii_case("VECTOR"),
        _ => false,
    }
}
/// Extracts the dimension from the argument list of a `VECTOR(...)` type.
///
/// Exactly one argument is accepted. After trimming surrounding whitespace
/// it must parse as a `usize` that is at least 1.
///
/// # Errors
///
/// Returns a human-readable message when the list is empty, holds more than
/// one argument, or the single argument is zero or not parseable as `usize`.
fn parse_vector_dim(args: &[String]) -> std::result::Result<usize, String> {
    if args.is_empty() {
        return Err("VECTOR requires a dimension, e.g. `VECTOR(384)`".to_string());
    }
    if args.len() > 1 {
        return Err(format!(
            "VECTOR takes exactly one dimension argument (got {})",
            args.len()
        ));
    }

    // Exactly one argument remains at this point.
    let trimmed = args[0].trim();
    match trimmed.parse::<usize>() {
        Ok(0) => Err(format!("VECTOR dimension must be ≥ 1 (got `{trimmed}`)")),
        Ok(dim) => Ok(dim),
        Err(_) => Err(format!(
            "VECTOR dimension must be a positive integer (got `{trimmed}`)"
        )),
    }
}
/// One column of a `CREATE TABLE` statement after parsing, as produced by
/// `parse_one_column`.
#[derive(PartialEq, Debug, Clone)]
pub struct ParsedColumn {
/// Column name, rendered from the parsed column definition.
pub name: String,
/// Type tag as a string. `parse_one_column` emits `"Integer"`, `"Bool"`,
/// `"Text"`, `"Real"`, `"Json"`, `"vector(N)"` for `VECTOR(N)` columns, or
/// `"Invalid"` when the SQL type is not recognised.
pub datatype: String,
/// `true` when the column carries a PRIMARY KEY option that was accepted
/// (the option is ignored for `"Real"` and `"Bool"` columns).
pub is_pk: bool,
/// `true` for NOT NULL columns; also set by an accepted PRIMARY KEY.
pub not_null: bool,
/// `true` for UNIQUE columns; also set by an accepted PRIMARY KEY
/// (UNIQUE is ignored for `"Real"` and `"Bool"` columns).
pub is_unique: bool,
/// Literal value of the column's DEFAULT clause, or `None` if it has none.
pub default: Option<Value>,
}
/// Parsed form of a `CREATE TABLE` statement, built by `CreateQuery::new`.
#[derive(Debug)]
pub struct CreateQuery {
/// Table name, rendered from the statement's object name.
pub table_name: String,
/// Parsed columns, in the order they appear in the statement.
pub columns: Vec<ParsedColumn>,
}
/// Converts a single sqlparser column definition into a [`ParsedColumn`].
///
/// The SQL data type is collapsed into a string tag:
/// integer-family types become `"Integer"`, `BOOLEAN` becomes `"Bool"`,
/// `TEXT`/`VARCHAR` become `"Text"`, `REAL`/`FLOAT`/`DOUBLE`/`DECIMAL` become
/// `"Real"`, `JSON`/`JSONB` become `"Json"`, and a custom `VECTOR(N)` type
/// becomes `"vector(N)"`. Any other type is reported on stderr and tagged
/// `"Invalid"`.
///
/// Column options are then applied in order. PRIMARY KEY and UNIQUE are
/// ignored on `"Real"` and `"Bool"` columns; PRIMARY KEY additionally implies
/// UNIQUE and NOT NULL. Options other than PRIMARY KEY, UNIQUE, NOT NULL and
/// DEFAULT are skipped.
///
/// # Errors
///
/// Returns `SQLRiteError::General` when a `VECTOR` type has a malformed
/// dimension, and propagates any error from evaluating a DEFAULT expression.
pub fn parse_one_column(col: &ColumnDef) -> Result<ParsedColumn> {
    let name = col.name.to_string();

    let datatype: String = match &col.data_type {
        DataType::TinyInt(_)
        | DataType::SmallInt(_)
        | DataType::Int2(_)
        | DataType::Int(_)
        | DataType::Int4(_)
        | DataType::Int8(_)
        | DataType::Integer(_)
        | DataType::BigInt(_) => String::from("Integer"),
        DataType::Boolean => String::from("Bool"),
        DataType::Text | DataType::Varchar(_) => String::from("Text"),
        DataType::Real | DataType::Float(_) | DataType::Double(_) | DataType::Decimal(_) => {
            String::from("Real")
        }
        DataType::JSON | DataType::JSONB => String::from("Json"),
        DataType::Custom(type_name, type_args) if is_vector_type(type_name) => {
            let dim = parse_vector_dim(type_args).map_err(|e| {
                SQLRiteError::General(format!("Invalid VECTOR column '{}': {e}", col.name))
            })?;
            format!("vector({dim})")
        }
        other => {
            eprintln!("not matched on custom type: {other:?}");
            String::from("Invalid")
        }
    };

    // PRIMARY KEY / UNIQUE are only honoured on types other than Real and Bool.
    let keyable = !matches!(datatype.as_str(), "Real" | "Bool");

    let mut is_pk = false;
    let mut is_unique = false;
    let mut not_null = false;
    let mut default: Option<Value> = None;

    for option_def in &col.options {
        match &option_def.option {
            ColumnOption::PrimaryKey(_) if keyable => {
                is_pk = true;
                is_unique = true;
                not_null = true;
            }
            ColumnOption::Unique(_) if keyable => is_unique = true,
            ColumnOption::NotNull => not_null = true,
            ColumnOption::Default(expr) => {
                default = Some(eval_literal_default(expr, &datatype, &name)?);
            }
            // Rejected key options and every other option kind are ignored.
            _ => {}
        }
    }

    Ok(ParsedColumn {
        name,
        datatype,
        is_pk,
        not_null,
        is_unique,
        default,
    })
}
/// Evaluates a DEFAULT expression that must be a literal, checking it against
/// the column's type tag.
///
/// Accepted forms are a bare literal, a unary `+` applied to an accepted
/// form, and a unary `-` applied directly to a numeric literal. `NULL` is
/// accepted for any type. Booleans require a `"Bool"` column, single-quoted
/// strings require `"Text"` or `"Json"` (JSON text is validated with
/// `serde_json` but stored as `Value::Text`), and numbers are handed to
/// `coerce_number_default`.
///
/// # Errors
///
/// Returns `SQLRiteError::General` when the expression is not one of the
/// accepted literal forms or when the literal does not fit `datatype`.
fn eval_literal_default(expr: &Expr, datatype: &str, col_name: &str) -> Result<Value> {
    let not_literal = || {
        SQLRiteError::General(format!(
            "DEFAULT for column '{col_name}' must be a literal value"
        ))
    };

    let literal = match expr {
        Expr::Value(wrapped) => &wrapped.value,
        Expr::UnaryOp {
            op: UnaryOperator::Plus,
            expr: operand,
        } => return eval_literal_default(operand, datatype, col_name),
        Expr::UnaryOp {
            op: UnaryOperator::Minus,
            expr: operand,
        } => {
            // Only `-<number>` is supported; the sign is folded into the
            // literal text before numeric coercion.
            if let Expr::Value(wrapped) = operand.as_ref() {
                if let AstValue::Number(n, _) = &wrapped.value {
                    let neg = format!("-{n}");
                    return coerce_number_default(&neg, datatype, col_name);
                }
            }
            return Err(not_literal());
        }
        _ => return Err(not_literal()),
    };

    match literal {
        AstValue::Null => Ok(Value::Null),
        AstValue::Boolean(flag) if datatype == "Bool" => Ok(Value::Bool(*flag)),
        AstValue::Boolean(_) => Err(SQLRiteError::General(format!(
            "DEFAULT type mismatch for column '{col_name}': boolean is not a {datatype}"
        ))),
        AstValue::SingleQuotedString(s) => match datatype {
            "Text" => Ok(Value::Text(s.clone())),
            "Json" => {
                // Parse only to validate; the original text is what gets stored.
                if let Err(e) = serde_json::from_str::<serde_json::Value>(s) {
                    return Err(SQLRiteError::General(format!(
                        "DEFAULT type mismatch for column '{col_name}': '{s}' is not valid JSON: {e}"
                    )));
                }
                Ok(Value::Text(s.clone()))
            }
            _ => Err(SQLRiteError::General(format!(
                "DEFAULT type mismatch for column '{col_name}': text is not a {datatype}"
            ))),
        },
        AstValue::Number(n, _) => coerce_number_default(n, datatype, col_name),
        _ => Err(not_literal()),
    }
}
/// Converts the text of a numeric DEFAULT literal into a [`Value`] matching
/// the column's type tag.
///
/// `"Integer"` columns parse `n` as `i64`; `"Real"` columns parse it as
/// `f64`.
///
/// # Errors
///
/// Returns `SQLRiteError::General` when `n` does not parse as the target
/// numeric type, or when `datatype` is neither `"Integer"` nor `"Real"`.
fn coerce_number_default(n: &str, datatype: &str, col_name: &str) -> Result<Value> {
    if datatype == "Integer" {
        return match n.parse::<i64>() {
            Ok(int) => Ok(Value::Integer(int)),
            Err(_) => Err(SQLRiteError::General(format!(
                "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid INTEGER"
            ))),
        };
    }

    if datatype == "Real" {
        return match n.parse::<f64>() {
            Ok(real) => Ok(Value::Real(real)),
            Err(_) => Err(SQLRiteError::General(format!(
                "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid REAL"
            ))),
        };
    }

    // Any other column type cannot take a numeric default.
    let other = datatype;
    Err(SQLRiteError::General(format!(
        "DEFAULT type mismatch for column '{col_name}': numeric literal is not a {other}"
    )))
}
impl CreateQuery {
pub fn new(statement: &Statement) -> Result<CreateQuery> {
match statement {
Statement::CreateTable(CreateTable {
name,
columns,
constraints,
..
}) => {
let table_name = name;
let mut parsed_columns: Vec<ParsedColumn> = vec![];
for col in columns {
let name = col.name.to_string();
if parsed_columns.iter().any(|c| c.name == name) {
return Err(SQLRiteError::Internal(format!(
"Duplicate column name: {}",
&name
)));
}
let parsed = parse_one_column(col)?;
if parsed.is_pk && parsed_columns.iter().any(|c| c.is_pk) {
return Err(SQLRiteError::Internal(format!(
"Table '{}' has more than one primary key",
&table_name
)));
}
parsed_columns.push(parsed);
}
let _ = constraints;
Ok(CreateQuery {
table_name: table_name.to_string(),
columns: parsed_columns,
})
}
_ => Err(SQLRiteError::Internal("Error parsing query".to_string())),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::*;

    /// Parses a four-column `CREATE TABLE` and checks that `CreateQuery::new`
    /// reports the expected table name.
    #[test]
    fn create_table_validate_tablename_test() {
        let sql_input = String::from(
            "CREATE TABLE contacts (
id INTEGER PRIMARY KEY,
first_name TEXT NOT NULL,
last_name TEXT NOT NULl,
email TEXT NOT NULL UNIQUE
);",
        );
        let expected_table_name = String::from("contacts");
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, &sql_input).unwrap();
        assert!(ast.len() == 1, "ast has more then one Statement");
        let query = ast.pop().unwrap();

        // Fail loudly if the parser produced some other statement kind. The
        // previous `if let` skipped every assertion in that case, so the
        // test passed without checking anything.
        assert!(
            matches!(query, Statement::CreateTable(_)),
            "expected a CREATE TABLE statement, got: {query:?}"
        );

        match CreateQuery::new(&query) {
            Ok(payload) => {
                assert_eq!(payload.table_name, expected_table_name);
            }
            Err(_) => panic!("an error occured during parsing CREATE TABLE Statement"),
        }
    }
}