use crate::error::{Result, SQLRiteError};
use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
use crate::sql::parser::create::CreateQuery;
use std::collections::{BTreeMap, HashMap};
use std::fmt;
use std::sync::{Arc, Mutex};
use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
/// Column data types understood by the table layer.
#[derive(PartialEq, Debug, Clone)]
pub enum DataType {
    Integer,
    Text,
    Real,
    Bool,
    /// Vector column; the payload is the declared dimension, which is always
    /// greater than zero when the value comes from [`DataType::new`].
    Vector(usize),
    /// What the literal type name `none` parses to.
    None,
    /// Returned by [`DataType::new`] for any type string it cannot parse.
    Invalid,
}

impl DataType {
    /// Parses a type name (case-insensitively) into a [`DataType`].
    ///
    /// Accepts `integer`, `text`, `real`, `bool`, `none` and `vector(N)` with
    /// `N > 0` (whitespace around `N` is tolerated). Anything else yields
    /// [`DataType::Invalid`] and a diagnostic on stderr.
    pub fn new(cmd: String) -> DataType {
        let normalized = cmd.to_lowercase();

        // Plain scalar names first.
        let scalar = match normalized.as_str() {
            "integer" => Some(DataType::Integer),
            "text" => Some(DataType::Text),
            "real" => Some(DataType::Real),
            "bool" => Some(DataType::Bool),
            "none" => Some(DataType::None),
            _ => None,
        };
        if let Some(datatype) = scalar {
            return datatype;
        }

        // Otherwise the only remaining valid shape is `vector(<dim>)`.
        let dim_text = normalized
            .strip_prefix("vector(")
            .and_then(|rest| rest.strip_suffix(')'));
        let Some(dim_text) = dim_text else {
            eprintln!("Invalid data type given {}", cmd);
            return DataType::Invalid;
        };

        match dim_text.trim().parse::<usize>() {
            Ok(dim) if dim > 0 => DataType::Vector(dim),
            _ => {
                eprintln!("Invalid VECTOR dimension in {cmd}");
                DataType::Invalid
            }
        }
    }

    /// Renders the type in a form that [`DataType::new`] parses back to the
    /// same variant for every variant it can produce from a valid name.
    pub fn to_wire_string(&self) -> String {
        let label = match self {
            DataType::Vector(dim) => return format!("vector({dim})"),
            DataType::Integer => "Integer",
            DataType::Text => "Text",
            DataType::Real => "Real",
            DataType::Bool => "Bool",
            DataType::None => "None",
            DataType::Invalid => "Invalid",
        };
        label.to_string()
    }
}

/// Human-readable name; note that `Bool` is shown as `Boolean` here, unlike
/// [`DataType::to_wire_string`].
impl fmt::Display for DataType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            DataType::Vector(dim) => return write!(f, "Vector({dim})"),
            DataType::Integer => "Integer",
            DataType::Text => "Text",
            DataType::Real => "Real",
            DataType::Bool => "Boolean",
            DataType::None => "None",
            DataType::Invalid => "Invalid",
        };
        f.write_str(label)
    }
}
/// In-memory table: column definitions plus column-oriented row storage.
#[derive(Debug)]
pub struct Table {
/// Table name, taken from the `CREATE TABLE` query in `Table::new`.
pub tb_name: String,
/// Column definitions in declaration order.
pub columns: Vec<Column>,
/// Per-column storage keyed by column name; each `Row` maps rowid -> cell.
pub rows: Arc<Mutex<HashMap<String, Row>>>,
/// Secondary indexes; `Table::new` auto-creates one for every PRIMARY KEY /
/// UNIQUE column whose type is Integer or Text.
pub secondary_indexes: Vec<SecondaryIndex>,
/// Starts at 0; updated by `insert_row` and raised by `restore_row`.
pub last_rowid: i64,
/// Name of the PRIMARY KEY column, or the sentinel "-1" when there is none.
pub primary_key: String,
}
impl Table {
/// Builds an empty table from a parsed `CREATE TABLE` query.
///
/// For every declared column this registers a [`Column`], allocates empty
/// row storage matching its data type, and — for PRIMARY KEY / UNIQUE
/// columns of type Integer or Text — creates an automatic unique index.
/// `primary_key` stays at the sentinel `"-1"` when no column is a primary key.
pub fn new(create_query: CreateQuery) -> Self {
    let tb_name = create_query.table_name;
    let declared_columns = create_query.columns;

    let mut primary_key: String = String::from("-1");
    let mut columns: Vec<Column> = Vec::new();
    let mut storage: HashMap<String, Row> = HashMap::new();
    let mut secondary_indexes: Vec<SecondaryIndex> = Vec::new();

    for declared in &declared_columns {
        if declared.is_pk {
            primary_key = declared.name.to_string();
        }

        columns.push(Column::new(
            declared.name.to_string(),
            declared.datatype.to_string(),
            declared.is_pk,
            declared.not_null,
            declared.is_unique,
        ));

        // Storage variant follows the declared type; unusable types get no storage.
        let dt = DataType::new(declared.datatype.to_string());
        let empty_storage = match &dt {
            DataType::Integer => Row::Integer(BTreeMap::new()),
            DataType::Real => Row::Real(BTreeMap::new()),
            DataType::Text => Row::Text(BTreeMap::new()),
            DataType::Bool => Row::Bool(BTreeMap::new()),
            DataType::Vector(_dim) => Row::Vector(BTreeMap::new()),
            DataType::Invalid | DataType::None => Row::None,
        };
        storage.insert(declared.name.to_string(), empty_storage);

        let wants_index = declared.is_pk || declared.is_unique;
        let indexable = matches!(dt, DataType::Integer | DataType::Text);
        if wants_index && indexable {
            let index_name = SecondaryIndex::auto_name(&tb_name, &declared.name);
            // A failed index construction is deliberately ignored: the table is
            // still created, just without the automatic index.
            if let Ok(index) = SecondaryIndex::new(
                index_name,
                tb_name.clone(),
                declared.name.clone(),
                &dt,
                true,
                IndexOrigin::Auto,
            ) {
                secondary_indexes.push(index);
            }
        }
    }

    Table {
        tb_name,
        columns,
        rows: Arc::new(Mutex::new(storage)),
        secondary_indexes,
        last_rowid: 0,
        primary_key,
    }
}
/// Returns a copy of the table whose row storage lives behind a fresh
/// `Arc<Mutex<..>>`, so the copy shares no row data with `self`.
///
/// # Panics
/// Panics if the row mutex is poisoned.
pub fn deep_clone(&self) -> Self {
    // The guard is a temporary, so the lock is released at the end of this statement.
    let snapshot: HashMap<String, Row> = self.rows.lock().expect("row mutex poisoned").clone();
    Table {
        rows: Arc::new(Mutex::new(snapshot)),
        tb_name: self.tb_name.clone(),
        columns: self.columns.clone(),
        secondary_indexes: self.secondary_indexes.clone(),
        primary_key: self.primary_key.clone(),
        last_rowid: self.last_rowid,
    }
}
/// Returns the first secondary index built on `column`, if any.
pub fn index_for_column(&self, column: &str) -> Option<&SecondaryIndex> {
    for index in &self.secondary_indexes {
        if index.column_name == column {
            return Some(index);
        }
    }
    None
}
/// Mutable counterpart of `index_for_column`: first index on `column`, if any.
fn index_for_column_mut(&mut self, column: &str) -> Option<&mut SecondaryIndex> {
    for index in &mut self.secondary_indexes {
        if index.column_name == column {
            return Some(index);
        }
    }
    None
}
/// Looks up a secondary index by its own name (not by column).
#[allow(dead_code)]
pub fn index_by_name(&self, name: &str) -> Option<&SecondaryIndex> {
    for index in &self.secondary_indexes {
        if index.name == name {
            return Some(index);
        }
    }
    None
}
/// Reports whether the table declares a column named exactly `column`.
pub fn contains_column(&self, column: String) -> bool {
    for declared in &self.columns {
        if declared.column_name == column {
            return true;
        }
    }
    false
}
/// Returns the column names in declaration order.
pub fn column_names(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.columns.len());
    for declared in &self.columns {
        names.push(declared.column_name.clone());
    }
    names
}
/// Returns the rowids present in the storage of the *first* declared column.
///
/// Yields an empty vector when the table has no columns or the first column
/// has no storage entry.
///
/// # Panics
/// Panics if the row mutex is poisoned.
pub fn rowids(&self) -> Vec<i64> {
    match self.columns.first() {
        // No columns: answer without touching the lock.
        None => Vec::new(),
        Some(first) => {
            let storage = self.rows.lock().expect("rows mutex poisoned");
            match storage.get(&first.column_name) {
                Some(col_data) => col_data.rowids(),
                None => Vec::new(),
            }
        }
    }
}
/// Reads the cell at (`column`, `rowid`); `None` when either the column
/// storage or the row entry is missing.
///
/// # Panics
/// Panics if the row mutex is poisoned.
pub fn get_value(&self, column: &str, rowid: i64) -> Option<Value> {
    let storage = self.rows.lock().expect("rows mutex poisoned");
    storage.get(column)?.get(rowid)
}
pub fn delete_row(&mut self, rowid: i64) {
let per_column_values: Vec<(String, Option<Value>)> = self
.columns
.iter()
.map(|c| (c.column_name.clone(), self.get_value(&c.column_name, rowid)))
.collect();
{
let rows_clone = Arc::clone(&self.rows);
let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
for col in &self.columns {
if let Some(r) = row_data.get_mut(&col.column_name) {
match r {
Row::Integer(m) => {
m.remove(&rowid);
}
Row::Text(m) => {
m.remove(&rowid);
}
Row::Real(m) => {
m.remove(&rowid);
}
Row::Bool(m) => {
m.remove(&rowid);
}
Row::Vector(m) => {
m.remove(&rowid);
}
Row::None => {}
}
}
}
}
for (col_name, value) in per_column_values {
if let Some(idx) = self.index_for_column_mut(&col_name) {
if let Some(v) = value {
idx.remove(&v, rowid);
}
}
}
}
/// Re-inserts a row under an explicit `rowid`.
///
/// `values` must contain exactly one entry per table column, in declaration
/// order; `None` stands for NULL and is only accepted for Text columns
/// (stored as the `"Null"` marker string). Each non-NULL value is also
/// inserted into the index covering its column, if any. `last_rowid` is
/// raised to `rowid` when `rowid` is larger.
///
/// # Errors
/// Returns `SQLRiteError::Internal` when the value count does not match the
/// column count, a column has no storage, a non-Text column receives NULL,
/// a value's type does not match the column storage, or an integer does not
/// fit the 32-bit column storage. Index insertion errors are propagated.
/// Columns processed before the failing one stay written.
///
/// # Panics
/// Panics if the row mutex is poisoned.
pub fn restore_row(&mut self, rowid: i64, values: Vec<Option<Value>>) -> Result<()> {
    if values.len() != self.columns.len() {
        return Err(SQLRiteError::Internal(format!(
            "cell has {} values but table '{}' has {} columns",
            values.len(),
            self.tb_name,
            self.columns.len()
        )));
    }
    // Owned names so the loop can take `&mut self` for index updates.
    let column_names: Vec<String> =
        self.columns.iter().map(|c| c.column_name.clone()).collect();
    for (col_name, value) in column_names.iter().zip(values) {
        {
            let mut row_data = self.rows.lock().expect("rows mutex poisoned");
            let cell = row_data.get_mut(col_name).ok_or_else(|| {
                SQLRiteError::Internal(format!("Row storage missing for column '{col_name}'"))
            })?;
            match (cell, &value) {
                (Row::Integer(map), Some(Value::Integer(v))) => {
                    // Storage is i32; a plain `as` cast would silently wrap a
                    // value that does not fit, so reject it instead.
                    let narrowed = i32::try_from(*v).map_err(|_| {
                        SQLRiteError::Internal(format!(
                            "Integer column '{col_name}' cannot store out-of-range value {v} — corrupt cell?"
                        ))
                    })?;
                    map.insert(rowid, narrowed);
                }
                (Row::Integer(_), None) => {
                    return Err(SQLRiteError::Internal(format!(
                        "Integer column '{col_name}' cannot store NULL — corrupt cell?"
                    )));
                }
                (Row::Text(map), Some(Value::Text(s))) => {
                    map.insert(rowid, s.clone());
                }
                (Row::Text(map), None) => {
                    map.insert(rowid, "Null".to_string());
                }
                (Row::Real(map), Some(Value::Real(v))) => {
                    // Real storage is f32 by design; narrowing only loses precision.
                    map.insert(rowid, *v as f32);
                }
                (Row::Real(_), None) => {
                    return Err(SQLRiteError::Internal(format!(
                        "Real column '{col_name}' cannot store NULL — corrupt cell?"
                    )));
                }
                (Row::Bool(map), Some(Value::Bool(v))) => {
                    map.insert(rowid, *v);
                }
                (Row::Bool(_), None) => {
                    return Err(SQLRiteError::Internal(format!(
                        "Bool column '{col_name}' cannot store NULL — corrupt cell?"
                    )));
                }
                (Row::Vector(map), Some(Value::Vector(v))) => {
                    map.insert(rowid, v.clone());
                }
                (Row::Vector(_), None) => {
                    return Err(SQLRiteError::Internal(format!(
                        "Vector column '{col_name}' cannot store NULL — corrupt cell?"
                    )));
                }
                (row, v) => {
                    return Err(SQLRiteError::Internal(format!(
                        "Type mismatch restoring column '{col_name}': storage {row:?} vs value {v:?}"
                    )));
                }
            }
        }
        if let Some(v) = &value {
            if let Some(idx) = self.index_for_column_mut(col_name) {
                idx.insert(v, rowid)?;
            }
        }
    }
    if rowid > self.last_rowid {
        self.last_rowid = rowid;
    }
    Ok(())
}
/// Collects the cells of `rowid` in column declaration order.
///
/// A missing cell and a stored NULL both come back as `None`.
pub fn extract_row(&self, rowid: i64) -> Vec<Option<Value>> {
    let mut cells = Vec::with_capacity(self.columns.len());
    for declared in &self.columns {
        let cell = self
            .get_value(&declared.column_name, rowid)
            .filter(|value| !matches!(value, Value::Null));
        cells.push(cell);
    }
    cells
}
/// Overwrites the cell at (`column`, `rowid`) with `new_val`, keeping the
/// column's index (if any) in sync.
///
/// Writing a value equal to the current one is a no-op. For UNIQUE columns a
/// non-NULL value that already exists on another row is rejected. Integers
/// may be assigned to Real columns; NULL is only accepted by Text columns
/// (stored as the `"Null"` marker string).
///
/// # Errors
/// `SQLRiteError::General` when the column does not exist, the UNIQUE
/// constraint would be violated, the value's type does not match the column,
/// a vector has the wrong dimension, or an integer does not fit the 32-bit
/// column storage. `SQLRiteError::Internal` when the column has no storage.
/// Index insertion errors are propagated.
///
/// # Panics
/// Panics if the row mutex is poisoned.
pub fn set_value(&mut self, column: &str, rowid: i64, new_val: Value) -> Result<()> {
    let col_index = self
        .columns
        .iter()
        .position(|c| c.column_name == column)
        .ok_or_else(|| SQLRiteError::General(format!("Column '{column}' not found")))?;
    let current = self.get_value(column, rowid);
    if current.as_ref() == Some(&new_val) {
        return Ok(());
    }
    if self.columns[col_index].is_unique && !matches!(new_val, Value::Null) {
        if let Some(idx) = self.index_for_column(column) {
            for other in idx.lookup(&new_val) {
                if other != rowid {
                    return Err(SQLRiteError::General(format!(
                        "UNIQUE constraint violated for column '{column}'"
                    )));
                }
            }
        } else {
            // No index on this column: fall back to scanning every other row.
            for other in self.rowids() {
                if other == rowid {
                    continue;
                }
                if self.get_value(column, other).as_ref() == Some(&new_val) {
                    return Err(SQLRiteError::General(format!(
                        "UNIQUE constraint violated for column '{column}'"
                    )));
                }
            }
        }
    }
    // Validate and write the cell BEFORE touching the index. Every rejection
    // below returns early, and doing it in this order guarantees that a
    // rejected update leaves the index entry for the still-stored old value
    // in place.
    {
        let declared = &self.columns[col_index].datatype;
        let mut row_data = self.rows.lock().expect("rows mutex poisoned");
        let cell = row_data.get_mut(column).ok_or_else(|| {
            SQLRiteError::Internal(format!("Row storage missing for column '{column}'"))
        })?;
        match (cell, &new_val, declared) {
            (Row::Integer(m), Value::Integer(v), _) => {
                // Storage is i32; reject instead of silently wrapping.
                let narrowed = i32::try_from(*v).map_err(|_| {
                    SQLRiteError::General(format!(
                        "Integer value {v} is out of range for column '{column}'"
                    ))
                })?;
                m.insert(rowid, narrowed);
            }
            (Row::Real(m), Value::Real(v), _) => {
                // Real storage is f32 by design; narrowing only loses precision.
                m.insert(rowid, *v as f32);
            }
            (Row::Real(m), Value::Integer(v), _) => {
                m.insert(rowid, *v as f32);
            }
            (Row::Text(m), Value::Text(v), _) => {
                m.insert(rowid, v.clone());
            }
            (Row::Bool(m), Value::Bool(v), _) => {
                m.insert(rowid, *v);
            }
            (Row::Vector(m), Value::Vector(v), DataType::Vector(declared_dim)) => {
                if v.len() != *declared_dim {
                    return Err(SQLRiteError::General(format!(
                        "Vector dimension mismatch for column '{column}': declared {declared_dim}, got {}",
                        v.len()
                    )));
                }
                m.insert(rowid, v.clone());
            }
            (Row::Text(m), Value::Null, _) => {
                m.insert(rowid, "Null".to_string());
            }
            (_, new, dt) => {
                return Err(SQLRiteError::General(format!(
                    "Type mismatch: cannot assign {} to column '{column}' of type {dt}",
                    new.to_display_string()
                )));
            }
        }
    }
    // The write succeeded: swap the index entry (old out, new in).
    if let Some(idx) = self.index_for_column_mut(column) {
        if let Some(old) = &current {
            idx.remove(old, rowid);
        }
        if !matches!(new_val, Value::Null) {
            idx.insert(&new_val, rowid)?;
        }
    }
    Ok(())
}
/// Returns the first column named `column_name`.
///
/// # Errors
/// `SQLRiteError::General("Column not found.")` when no column matches.
#[allow(dead_code)]
pub fn get_column(&mut self, column_name: String) -> Result<&Column> {
    self.columns
        .iter()
        .find(|candidate| candidate.column_name == column_name)
        .ok_or_else(|| SQLRiteError::General(String::from("Column not found.")))
}
/// Checks that inserting `values` (paired positionally with `cols`) would
/// not duplicate an existing value in any UNIQUE column.
///
/// Non-unique columns are skipped without looking at their value. For a
/// unique column the textual value is parsed according to the column type,
/// then checked against the column's index when one exists, otherwise
/// against every stored row.
///
/// # Errors
/// `SQLRiteError::General` when a named column does not exist, a unique
/// column has no corresponding entry in `values`, the value cannot be parsed
/// as the column type, a vector has the wrong dimension, or the value
/// already exists. `SQLRiteError::Internal` when a unique column has an
/// unsupported datatype.
pub fn validate_unique_constraint(
    &mut self,
    cols: &Vec<String>,
    values: &Vec<String>,
) -> Result<()> {
    for (idx, name) in cols.iter().enumerate() {
        let column = self
            .columns
            .iter()
            .find(|c| &c.column_name == name)
            .ok_or_else(|| SQLRiteError::General(format!("Column '{name}' not found")))?;
        if !column.is_unique {
            continue;
        }
        let datatype = &column.datatype;
        // Report a short `values` list as an error instead of panicking on
        // an out-of-bounds index.
        let val = values.get(idx).ok_or_else(|| {
            SQLRiteError::General(format!("No value supplied for column '{name}'"))
        })?;
        let parsed = match datatype {
            DataType::Integer => val.parse::<i64>().map(Value::Integer).map_err(|_| {
                SQLRiteError::General(format!(
                    "Type mismatch: expected INTEGER for column '{name}', got '{val}'"
                ))
            })?,
            DataType::Text => Value::Text(val.clone()),
            // Real cells are stored as f32 and read back widened to f64.
            // Parse through f32 as well so the candidate compares equal to
            // an already-stored duplicate (e.g. "0.1"); parsing straight to
            // f64 would never match the widened f32 value.
            DataType::Real => val
                .parse::<f32>()
                .map(|narrow| Value::Real(f64::from(narrow)))
                .map_err(|_| {
                    SQLRiteError::General(format!(
                        "Type mismatch: expected REAL for column '{name}', got '{val}'"
                    ))
                })?,
            DataType::Bool => val.parse::<bool>().map(Value::Bool).map_err(|_| {
                SQLRiteError::General(format!(
                    "Type mismatch: expected BOOL for column '{name}', got '{val}'"
                ))
            })?,
            DataType::Vector(declared_dim) => {
                let parsed_vec = parse_vector_literal(val).map_err(|e| {
                    SQLRiteError::General(format!(
                        "Type mismatch: expected VECTOR({declared_dim}) for column '{name}', {e}"
                    ))
                })?;
                if parsed_vec.len() != *declared_dim {
                    return Err(SQLRiteError::General(format!(
                        "Vector dimension mismatch for column '{name}': declared {declared_dim}, got {}",
                        parsed_vec.len()
                    )));
                }
                Value::Vector(parsed_vec)
            }
            DataType::None | DataType::Invalid => {
                return Err(SQLRiteError::Internal(format!(
                    "column '{name}' has an unsupported datatype"
                )));
            }
        };
        if let Some(secondary) = self.index_for_column(name) {
            if secondary.would_violate_unique(&parsed) {
                return Err(SQLRiteError::General(format!(
                    "UNIQUE constraint violated for column '{name}': value '{val}' already exists"
                )));
            }
        } else {
            // No index on this column: compare against every stored row.
            for other in self.rowids() {
                if self.get_value(name, other).as_ref() == Some(&parsed) {
                    return Err(SQLRiteError::General(format!(
                        "UNIQUE constraint violated for column '{name}': value '{val}' already exists"
                    )));
                }
            }
        }
    }
    Ok(())
}
/// Inserts one row given parallel lists of column names and textual values.
///
/// The rowid is `last_rowid + 1`, unless the PRIMARY KEY column is listed in
/// `cols`, in which case the supplied key (parsed as an integer) becomes the
/// rowid. When the primary key is omitted and its storage is Integer, the
/// rowid is written as the key value. Columns are matched positionally:
/// `cols` is walked in step with the table's declaration order, and a table
/// column that is not the next supplied one receives the text `"Null"`.
/// Every successfully typed value is also added to the index covering its
/// column, if any.
///
/// # Errors
/// `SQLRiteError::General` when fewer values than columns are supplied, a
/// value cannot be parsed as its column type, a vector has the wrong
/// dimension, or an auto-assigned rowid does not fit the 32-bit integer
/// storage of the primary key. `SQLRiteError::Internal` when a column has no
/// usable storage. Index insertion errors are propagated. Columns written
/// before the failing one stay written.
///
/// # Panics
/// Panics if the row mutex is poisoned.
pub fn insert_row(&mut self, cols: &Vec<String>, values: &Vec<String>) -> Result<()> {
    // Every supplied column needs a value; without this check the positional
    // lookups below would panic on an out-of-bounds index.
    if values.len() < cols.len() {
        return Err(SQLRiteError::General(format!(
            "INSERT supplies {} values for {} columns",
            values.len(),
            cols.len()
        )));
    }
    let mut next_rowid = self.last_rowid + 1;
    if self.primary_key != "-1" {
        if !cols.iter().any(|col| col == &self.primary_key) {
            // Primary key omitted: auto-assign it from the rowid, but only
            // when the key column is stored as Integer.
            let auto_pk: Option<i32> = {
                let mut row_data = self.rows.lock().expect("rows mutex poisoned");
                let table_col_data = row_data.get_mut(&self.primary_key).ok_or_else(|| {
                    SQLRiteError::Internal(format!(
                        "Row storage missing for primary key column '{}'",
                        self.primary_key
                    ))
                })?;
                match table_col_data {
                    Row::Integer(tree) => {
                        // Storage is i32; an `as` cast would silently store a
                        // wrapped key once the rowid exceeds i32::MAX.
                        let val = i32::try_from(next_rowid).map_err(|_| {
                            SQLRiteError::General(format!(
                                "Auto-assigned rowid {next_rowid} is out of range for PRIMARY KEY column '{}'",
                                self.primary_key
                            ))
                        })?;
                        tree.insert(next_rowid, val);
                        Some(val)
                    }
                    _ => None,
                }
            };
            if let Some(val) = auto_pk {
                let pk = self.primary_key.clone();
                if let Some(idx) = self.index_for_column_mut(&pk) {
                    idx.insert(&Value::Integer(i64::from(val)), next_rowid)?;
                }
            }
        } else {
            // Primary key supplied: its value becomes the rowid.
            for (col, val) in cols.iter().zip(values.iter()) {
                if col == &self.primary_key {
                    next_rowid = val.parse::<i64>().map_err(|_| {
                        SQLRiteError::General(format!(
                            "Type mismatch: PRIMARY KEY column '{}' expects INTEGER, got '{val}'",
                            self.primary_key
                        ))
                    })?;
                }
            }
        }
    }
    let column_names = self
        .columns
        .iter()
        .map(|col| col.column_name.to_string())
        .collect::<Vec<String>>();
    // `j` walks the supplied columns in step with the table's own order.
    let mut j: usize = 0;
    for (i, key) in column_names.iter().enumerate() {
        let mut val = String::from("Null");
        match cols.get(j) {
            Some(supplied_key) if supplied_key == key => {
                val = values[j].to_string();
                j += 1;
            }
            // An omitted primary key was handled above (or has no Integer
            // storage); either way nothing is written for it here.
            _ if self.primary_key == *key => continue,
            _ => {}
        }
        let typed_value: Option<Value> = {
            let mut row_data = self.rows.lock().expect("rows mutex poisoned");
            let table_col_data = row_data.get_mut(key).ok_or_else(|| {
                SQLRiteError::Internal(format!("Row storage missing for column '{key}'"))
            })?;
            match table_col_data {
                Row::Integer(tree) => {
                    let parsed = val.parse::<i32>().map_err(|_| {
                        SQLRiteError::General(format!(
                            "Type mismatch: expected INTEGER for column '{key}', got '{val}'"
                        ))
                    })?;
                    tree.insert(next_rowid, parsed);
                    Some(Value::Integer(i64::from(parsed)))
                }
                Row::Text(tree) => {
                    tree.insert(next_rowid, val.to_string());
                    // The "Null" marker is stored but never indexed.
                    if val != "Null" {
                        Some(Value::Text(val.to_string()))
                    } else {
                        None
                    }
                }
                Row::Real(tree) => {
                    let parsed = val.parse::<f32>().map_err(|_| {
                        SQLRiteError::General(format!(
                            "Type mismatch: expected REAL for column '{key}', got '{val}'"
                        ))
                    })?;
                    tree.insert(next_rowid, parsed);
                    Some(Value::Real(f64::from(parsed)))
                }
                Row::Bool(tree) => {
                    let parsed = val.parse::<bool>().map_err(|_| {
                        SQLRiteError::General(format!(
                            "Type mismatch: expected BOOL for column '{key}', got '{val}'"
                        ))
                    })?;
                    tree.insert(next_rowid, parsed);
                    Some(Value::Bool(parsed))
                }
                Row::Vector(tree) => {
                    let parsed = parse_vector_literal(&val).map_err(|e| {
                        SQLRiteError::General(format!(
                            "Type mismatch: expected VECTOR for column '{key}', {e}"
                        ))
                    })?;
                    let declared_dim = match &self.columns[i].datatype {
                        DataType::Vector(d) => *d,
                        other => {
                            return Err(SQLRiteError::Internal(format!(
                                "Row::Vector storage on non-Vector column '{key}' (declared as {other})"
                            )));
                        }
                    };
                    if parsed.len() != declared_dim {
                        return Err(SQLRiteError::General(format!(
                            "Vector dimension mismatch for column '{key}': declared {declared_dim}, got {}",
                            parsed.len()
                        )));
                    }
                    tree.insert(next_rowid, parsed.clone());
                    Some(Value::Vector(parsed))
                }
                Row::None => {
                    return Err(SQLRiteError::Internal(format!(
                        "Column '{key}' has no row storage"
                    )));
                }
            }
        };
        if let Some(v) = typed_value {
            if let Some(idx) = self.index_for_column_mut(key) {
                idx.insert(&v, next_rowid)?;
            }
        }
    }
    // NOTE(review): with a caller-supplied primary key smaller than an
    // existing rowid this moves `last_rowid` backwards, whereas `restore_row`
    // only ever raises it. Confirm whether callers rely on "last inserted"
    // semantics before switching this to a running maximum.
    self.last_rowid = next_rowid;
    Ok(())
}
/// Prints the table schema to stdout as a formatted table with one line per
/// column (name, data type, and the three constraint flags).
///
/// Returns twice the number of rows in the printed table (header included)
/// plus one — presumably the number of text lines emitted; verify against
/// the table formatter's output format.
pub fn print_table_schema(&self) -> Result<usize> {
    let mut schema = PrintTable::new();
    schema.add_row(row![
        "Column Name",
        "Data Type",
        "PRIMARY KEY",
        "UNIQUE",
        "NOT NULL"
    ]);
    self.columns.iter().for_each(|column| {
        schema.add_row(row![
            column.column_name,
            column.datatype,
            column.is_pk,
            column.is_unique,
            column.not_null
        ]);
    });
    schema.printstd();
    let reported = schema.len() * 2 + 1;
    Ok(reported)
}
pub fn print_table_data(&self) {
let mut print_table = PrintTable::new();
let column_names = self
.columns
.iter()
.map(|col| col.column_name.to_string())
.collect::<Vec<String>>();
let header_row = PrintRow::new(
column_names
.iter()
.map(|col| PrintCell::new(col))
.collect::<Vec<PrintCell>>(),
);
let rows_clone = Arc::clone(&self.rows);
let row_data = rows_clone.lock().expect("rows mutex poisoned");
let first_col_data = row_data
.get(&self.columns.first().unwrap().column_name)
.unwrap();
let num_rows = first_col_data.count();
let mut print_table_rows: Vec<PrintRow> = vec![PrintRow::new(vec![]); num_rows];
for col_name in &column_names {
let col_val = row_data
.get(col_name)
.expect("Can't find any rows with the given column");
let columns: Vec<String> = col_val.get_serialized_col_data();
for i in 0..num_rows {
if let Some(cell) = &columns.get(i) {
print_table_rows[i].add_cell(PrintCell::new(cell));
} else {
print_table_rows[i].add_cell(PrintCell::new(""));
}
}
}
print_table.add_row(header_row);
for row in print_table_rows {
print_table.add_row(row);
}
print_table.printstd();
}
}
/// Definition of a single table column.
#[derive(PartialEq, Debug, Clone)]
pub struct Column {
    pub column_name: String,
    /// Parsed from the textual type name by [`Column::new`].
    pub datatype: DataType,
    pub is_pk: bool,
    pub not_null: bool,
    pub is_unique: bool,
}

impl Column {
    /// Creates a column definition, parsing `datatype` with `DataType::new`
    /// (an unrecognised type name therefore becomes `DataType::Invalid`).
    pub fn new(
        name: String,
        datatype: String,
        is_pk: bool,
        not_null: bool,
        is_unique: bool,
    ) -> Self {
        Self {
            column_name: name,
            datatype: DataType::new(datatype),
            is_pk,
            not_null,
            is_unique,
        }
    }
}
/// Storage for one column: a map from rowid to the cell value, with one
/// variant per supported cell type. `None` means the column has no storage.
#[derive(PartialEq, Debug, Clone)]
pub enum Row {
    Integer(BTreeMap<i64, i32>),
    Text(BTreeMap<i64, String>),
    Real(BTreeMap<i64, f32>),
    Bool(BTreeMap<i64, bool>),
    Vector(BTreeMap<i64, Vec<f32>>),
    None,
}

impl Row {
    /// Renders every cell as text, in ascending rowid order.
    ///
    /// # Panics
    /// Panics on `Row::None`.
    fn get_serialized_col_data(&self) -> Vec<String> {
        fn stringify<T: ToString>(cells: &BTreeMap<i64, T>) -> Vec<String> {
            cells.values().map(ToString::to_string).collect()
        }
        match self {
            Row::Integer(cd) => stringify(cd),
            Row::Real(cd) => stringify(cd),
            Row::Text(cd) => stringify(cd),
            Row::Bool(cd) => stringify(cd),
            Row::Vector(cd) => cd.values().map(format_vector_for_display).collect(),
            Row::None => panic!("Found None in columns"),
        }
    }

    /// Number of stored cells.
    ///
    /// # Panics
    /// Panics on `Row::None`.
    fn count(&self) -> usize {
        let stored = match self {
            Row::Integer(cd) => cd.len(),
            Row::Real(cd) => cd.len(),
            Row::Text(cd) => cd.len(),
            Row::Bool(cd) => cd.len(),
            Row::Vector(cd) => cd.len(),
            Row::None => panic!("Found None in columns"),
        };
        stored
    }

    /// Rowids that have a cell in this column, in ascending order; empty for
    /// `Row::None`.
    pub fn rowids(&self) -> Vec<i64> {
        fn keys_of<T>(cells: &BTreeMap<i64, T>) -> Vec<i64> {
            cells.keys().copied().collect()
        }
        match self {
            Row::Integer(m) => keys_of(m),
            Row::Text(m) => keys_of(m),
            Row::Real(m) => keys_of(m),
            Row::Bool(m) => keys_of(m),
            Row::Vector(m) => keys_of(m),
            Row::None => Vec::new(),
        }
    }

    /// Reads the cell for `rowid` as a [`Value`], widening i32/f32 storage to
    /// i64/f64. A Text cell holding the marker string `"Null"` is reported as
    /// `Value::Null`. Returns `None` when the row has no cell here.
    pub fn get(&self, rowid: i64) -> Option<Value> {
        match self {
            Row::Integer(m) => m.get(&rowid).copied().map(i64::from).map(Value::Integer),
            Row::Text(m) => {
                let text = m.get(&rowid)?;
                Some(if text == "Null" {
                    Value::Null
                } else {
                    Value::Text(text.clone())
                })
            }
            Row::Real(m) => m.get(&rowid).copied().map(f64::from).map(Value::Real),
            Row::Bool(m) => m.get(&rowid).copied().map(Value::Bool),
            Row::Vector(m) => m.get(&rowid).cloned().map(Value::Vector),
            Row::None => None,
        }
    }
}
/// Formats a vector as `[a, b, c]`, using each element's default `Display`
/// rendering; an empty vector becomes `[]`.
fn format_vector_for_display(v: &Vec<f32>) -> String {
    let rendered: Vec<String> = v.iter().map(|element| element.to_string()).collect();
    format!("[{}]", rendered.join(", "))
}
/// A single cell value as seen by callers of the table API.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    Integer(i64),
    Text(String),
    Real(f64),
    Bool(bool),
    Vector(Vec<f32>),
    Null,
}

impl Value {
    /// Renders the value for display; NULL becomes the text `NULL` and
    /// vectors use the bracketed `[a, b, c]` form.
    pub fn to_display_string(&self) -> String {
        match self {
            Value::Null => String::from("NULL"),
            Value::Text(text) => text.clone(),
            Value::Vector(elements) => format_vector_for_display(elements),
            Value::Integer(number) => format!("{number}"),
            Value::Real(number) => format!("{number}"),
            Value::Bool(flag) => format!("{flag}"),
        }
    }
}
/// Parses a bracketed, comma-separated list such as `[0.1, 2, -3.5]` into a
/// vector of `f32`.
///
/// Whitespace around the literal and around each element is ignored, and
/// `[]` yields an empty vector.
///
/// # Errors
/// `SQLRiteError::General` when the text is not wrapped in `[` … `]` or when
/// an element does not parse as `f32`; the message names the zero-based
/// position of the first offending element.
pub fn parse_vector_literal(s: &str) -> Result<Vec<f32>> {
    let unwrapped = s
        .trim()
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'));
    let Some(body) = unwrapped else {
        return Err(SQLRiteError::General(format!(
            "expected bracket-array literal `[...]`, got `{s}`"
        )));
    };

    let body = body.trim();
    if body.is_empty() {
        return Ok(Vec::new());
    }

    // Collecting into `Result` stops at the first element that fails to parse.
    body.split(',')
        .enumerate()
        .map(|(i, part)| {
            let element = part.trim();
            element.parse::<f32>().map_err(|_| {
                SQLRiteError::General(format!("vector element {i} (`{element}`) is not a number"))
            })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Parses a single `CREATE TABLE` statement and builds the table for it.
    fn table_from_sql(sql: &str) -> Table {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).unwrap();
        if ast.len() > 1 {
            panic!("Expected a single query statement, but there are more then 1.")
        }
        let query = ast.pop().unwrap();
        let create_query = CreateQuery::new(&query).unwrap();
        Table::new(create_query)
    }

    #[test]
    fn datatype_display_trait_test() {
        let expectations = [
            (DataType::Integer, "Integer"),
            (DataType::Text, "Text"),
            (DataType::Real, "Real"),
            (DataType::Bool, "Boolean"),
            (DataType::Vector(384), "Vector(384)"),
            (DataType::None, "None"),
            (DataType::Invalid, "Invalid"),
        ];
        for (datatype, shown) in expectations {
            assert_eq!(format!("{}", datatype), shown);
        }
    }

    #[test]
    fn datatype_new_parses_vector_dim() {
        let cases = [
            ("vector(1)", 1),
            ("vector(384)", 384),
            ("vector(1536)", 1536),
            // Type names are matched case-insensitively.
            ("VECTOR(384)", 384),
            // Whitespace around the dimension is tolerated.
            ("vector( 64 )", 64),
        ];
        for (text, dim) in cases {
            assert_eq!(DataType::new(text.to_string()), DataType::Vector(dim));
        }
    }

    #[test]
    fn datatype_new_rejects_bad_vector_strings() {
        for text in ["vector(0)", "vector(abc)", "vector()", "vector(-3)"] {
            assert_eq!(DataType::new(text.to_string()), DataType::Invalid);
        }
    }

    #[test]
    fn datatype_to_wire_string_round_trips_vector() {
        let wire = DataType::Vector(384).to_wire_string();
        assert_eq!(wire, "vector(384)");
        assert_eq!(DataType::new(wire), DataType::Vector(384));
    }

    #[test]
    fn parse_vector_literal_accepts_floats() {
        let parsed = parse_vector_literal("[0.1, 0.2, 0.3]").expect("parse");
        assert_eq!(parsed, vec![0.1f32, 0.2, 0.3]);
    }

    #[test]
    fn parse_vector_literal_accepts_ints_widening_to_f32() {
        let parsed = parse_vector_literal("[1, 2, 3]").expect("parse");
        assert_eq!(parsed, vec![1.0f32, 2.0, 3.0]);
    }

    #[test]
    fn parse_vector_literal_handles_negatives_and_whitespace() {
        let parsed = parse_vector_literal("[ -1.5 , 2.0, -3.5 ]").expect("parse");
        assert_eq!(parsed, vec![-1.5f32, 2.0, -3.5]);
    }

    #[test]
    fn parse_vector_literal_empty_brackets_is_empty_vec() {
        let parsed = parse_vector_literal("[]").expect("parse");
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_vector_literal_rejects_non_bracketed() {
        for text in ["0.1, 0.2", "(0.1, 0.2)", "[0.1, 0.2", "0.1, 0.2]"] {
            assert!(parse_vector_literal(text).is_err());
        }
    }

    #[test]
    fn parse_vector_literal_rejects_non_numeric_elements() {
        let err = parse_vector_literal("[1.0, 'foo', 3.0]").unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("vector element 1") && msg.contains("'foo'"),
            "error message should pinpoint the bad element: got `{msg}`"
        );
    }

    #[test]
    fn value_vector_display_format() {
        let filled = Value::Vector(vec![0.1, 0.2, 0.3]);
        assert_eq!(filled.to_display_string(), "[0.1, 0.2, 0.3]");
        let empty = Value::Vector(vec![]);
        assert_eq!(empty.to_display_string(), "[]");
    }

    #[test]
    fn create_new_table_test() {
        let table = table_from_sql(
            "CREATE TABLE contacts (
id INTEGER PRIMARY KEY,
first_name TEXT NOT NULL,
last_name TEXT NOT NULl,
email TEXT NOT NULL UNIQUE,
active BOOL,
score REAL
);",
        );
        assert_eq!(table.columns.len(), 6);
        assert_eq!(table.last_rowid, 0);

        let id_column = "id".to_string();
        let column = table
            .columns
            .iter()
            .find(|c| c.column_name == id_column)
            .expect("column not found");
        assert!(column.is_pk);
        assert_eq!(column.datatype, DataType::Integer);
    }

    #[test]
    fn print_table_schema_test() {
        let table = table_from_sql(
            "CREATE TABLE contacts (
id INTEGER PRIMARY KEY,
first_name TEXT NOT NULL,
last_name TEXT NOT NULl
);",
        );
        let lines_printed = table.print_table_schema();
        assert_eq!(lines_printed, Ok(9));
    }
}