use sqlparser::ast::{ColumnOption, CreateTable, DataType, ObjectName, ObjectNamePart, Statement};
use crate::error::{Result, SQLRiteError};
/// Returns `true` when `name` is exactly one plain identifier spelling
/// `VECTOR`, compared ASCII case-insensitively.
///
/// Multi-part names (e.g. `schema.VECTOR`) and non-identifier name parts
/// never match.
fn is_vector_type(name: &ObjectName) -> bool {
    matches!(
        name.0.as_slice(),
        [ObjectNamePart::Identifier(ident)] if ident.value.eq_ignore_ascii_case("VECTOR")
    )
}
/// Extracts the dimension from the argument list of a `VECTOR(...)` type.
///
/// Exactly one argument is accepted. After trimming surrounding whitespace it
/// must parse as a `usize` that is at least 1.
///
/// # Errors
///
/// Returns a human-readable message when no argument or more than one
/// argument is given, when the argument is not an unsigned integer, or when
/// it is zero.
fn parse_vector_dim(args: &[String]) -> std::result::Result<usize, String> {
    let raw: &str = match args.len() {
        0 => return Err("VECTOR requires a dimension, e.g. `VECTOR(384)`".to_string()),
        1 => args[0].trim(),
        count => {
            return Err(format!(
                "VECTOR takes exactly one dimension argument (got {count})"
            ))
        }
    };

    let dimension = raw
        .parse::<usize>()
        .map_err(|_| format!("VECTOR dimension must be a positive integer (got `{raw}`)"))?;

    // `usize` parsing accepts 0, but a zero-width vector is meaningless.
    if dimension == 0 {
        return Err(format!("VECTOR dimension must be ≥ 1 (got `{raw}`)"));
    }
    Ok(dimension)
}
/// One column definition extracted from a `CREATE TABLE` statement.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ParsedColumn {
    /// Column name, taken from the column identifier by `CreateQuery::new`.
    pub name: String,
    /// Engine storage type name. `CreateQuery::new` produces `"Integer"`,
    /// `"Bool"`, `"Text"`, `"Real"`, `"vector(N)"`, or `"Invalid"` for SQL
    /// types it does not recognise.
    pub datatype: String,
    /// Whether the column is the table's primary key. When set by
    /// `CreateQuery::new`, `is_unique` and `not_null` are set as well.
    pub is_pk: bool,
    /// Whether the column was declared `NOT NULL` (or is the primary key).
    pub not_null: bool,
    /// Whether the column carries a uniqueness constraint (`UNIQUE` or
    /// primary key).
    pub is_unique: bool,
}
/// The result of analysing a `CREATE TABLE` statement: the table name plus
/// its column definitions, in the order they were declared.
#[derive(Debug, PartialEq)]
pub struct CreateQuery {
    /// Table name, rendered from the statement's object name.
    pub table_name: String,
    /// Parsed column definitions, in declaration order.
    pub columns: Vec<ParsedColumn>,
}
impl CreateQuery {
pub fn new(statement: &Statement) -> Result<CreateQuery> {
match statement {
Statement::CreateTable(CreateTable {
name,
columns,
constraints,
..
}) => {
let table_name = name;
let mut parsed_columns: Vec<ParsedColumn> = vec![];
for col in columns {
let name = col.name.to_string();
if parsed_columns.iter().any(|col| col.name == name) {
return Err(SQLRiteError::Internal(format!(
"Duplicate column name: {}",
&name
)));
}
let datatype: String = match &col.data_type {
DataType::TinyInt(_)
| DataType::SmallInt(_)
| DataType::Int2(_)
| DataType::Int(_)
| DataType::Int4(_)
| DataType::Int8(_)
| DataType::Integer(_)
| DataType::BigInt(_) => "Integer".to_string(),
DataType::Boolean => "Bool".to_string(),
DataType::Text => "Text".to_string(),
DataType::Varchar(_bytes) => "Text".to_string(),
DataType::Real => "Real".to_string(),
DataType::Float(_precision) => "Real".to_string(),
DataType::Double(_) => "Real".to_string(),
DataType::Decimal(_) => "Real".to_string(),
DataType::Custom(name, args) if is_vector_type(name) => {
match parse_vector_dim(args) {
Ok(dim) => format!("vector({dim})"),
Err(e) => {
return Err(SQLRiteError::General(format!(
"Invalid VECTOR column '{}': {e}",
col.name
)));
}
}
}
other => {
eprintln!("not matched on custom type: {other:?}");
"Invalid".to_string()
}
};
let mut is_pk: bool = false;
let mut is_unique: bool = false;
let mut not_null: bool = false;
for column_option in &col.options {
match &column_option.option {
ColumnOption::PrimaryKey(_) => {
if datatype != "Real" && datatype != "Bool" {
if parsed_columns.iter().any(|col| col.is_pk) {
return Err(SQLRiteError::Internal(format!(
"Table '{}' has more than one primary key",
&table_name
)));
}
is_pk = true;
is_unique = true;
not_null = true;
}
}
ColumnOption::Unique(_) => {
if datatype != "Real" && datatype != "Bool" {
is_unique = true;
}
}
ColumnOption::NotNull => {
not_null = true;
}
_ => (),
};
}
parsed_columns.push(ParsedColumn {
name,
datatype: datatype.to_string(),
is_pk,
not_null,
is_unique,
});
}
for constraint in constraints {
println!("{constraint:?}");
}
Ok(CreateQuery {
table_name: table_name.to_string(),
columns: parsed_columns,
})
}
_ => Err(SQLRiteError::Internal("Error parsing query".to_string())),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::*;

    /// Parses `sql` with the SQLite dialect and returns its only statement.
    /// Fails the test if parsing fails or more than one statement is produced.
    fn parse_single(sql: &str) -> Statement {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("SQL should parse");
        assert_eq!(ast.len(), 1, "ast has more than one Statement");
        ast.pop().expect("length was just checked to be 1")
    }

    #[test]
    fn create_table_validate_tablename_test() {
        let query = parse_single(
            "CREATE TABLE contacts (
                id INTEGER PRIMARY KEY,
                first_name TEXT NOT NULL,
                last_name TEXT NOT NULl,
                email TEXT NOT NULL UNIQUE
            );",
        );
        // Fail loudly instead of silently passing on an unexpected statement kind.
        assert!(
            matches!(query, Statement::CreateTable(_)),
            "expected a CREATE TABLE statement"
        );
        let payload = CreateQuery::new(&query)
            .expect("an error occurred during parsing CREATE TABLE Statement");
        assert_eq!(payload.table_name, "contacts");
    }

    #[test]
    fn create_table_column_flags_test() {
        let query = parse_single(
            "CREATE TABLE contacts (id INTEGER PRIMARY KEY, email TEXT NOT NULL UNIQUE, score REAL);",
        );
        let payload = CreateQuery::new(&query).expect("valid CREATE TABLE");
        assert_eq!(
            payload.columns,
            vec![
                ParsedColumn {
                    name: "id".to_string(),
                    datatype: "Integer".to_string(),
                    is_pk: true,
                    not_null: true,
                    is_unique: true,
                },
                ParsedColumn {
                    name: "email".to_string(),
                    datatype: "Text".to_string(),
                    is_pk: false,
                    not_null: true,
                    is_unique: true,
                },
                ParsedColumn {
                    name: "score".to_string(),
                    datatype: "Real".to_string(),
                    is_pk: false,
                    not_null: false,
                    is_unique: false,
                },
            ]
        );
    }

    #[test]
    fn create_table_duplicate_column_is_rejected() {
        let query = parse_single("CREATE TABLE t (id INTEGER, id TEXT);");
        assert!(CreateQuery::new(&query).is_err());
    }

    #[test]
    fn create_table_multiple_primary_keys_is_rejected() {
        let query = parse_single("CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER PRIMARY KEY);");
        assert!(CreateQuery::new(&query).is_err());
    }

    #[test]
    fn non_create_statement_is_rejected() {
        let query = parse_single("SELECT 1;");
        assert!(CreateQuery::new(&query).is_err());
    }

    #[test]
    fn vector_dimension_parsing() {
        assert_eq!(parse_vector_dim(&["384".to_string()]), Ok(384));
        assert!(parse_vector_dim(&["0".to_string()]).is_err());
        assert!(parse_vector_dim(&["x".to_string()]).is_err());
        assert!(parse_vector_dim(&[]).is_err());
        assert!(parse_vector_dim(&["1".to_string(), "2".to_string()]).is_err());
    }
}