use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use vibesql_executor::{
CursorExecutor, CursorStore, FetchResult as CursorFetchResult, PreparedStatement,
PreparedStatementCache, PreparedStatementCacheStats,
};
use vibesql_types::SqlValue;
use crate::registry::SharedDatabase;
use crate::transaction::SessionTransactionManager;
/// Per-connection state for executing SQL against a database.
///
/// The database handle (`SharedDatabase`) and the prepared-statement cache are
/// reference-counted and may be shared with other sessions. Named prepared
/// statements, cursors and transaction state are owned by this session alone.
pub struct Session {
/// Name of the database this session was opened for, stored as given.
// dead_code allowed: this module reads it only from tests.
// NOTE(review): confirm whether other modules read it.
#[allow(dead_code)]
pub database: String,
/// Name of the user this session was opened for, stored as given.
// dead_code allowed for the same reason as `database`.
#[allow(dead_code)]
pub user: String,
/// Handle to the database all statements run against. Sessions built from clones
/// of the same `Arc` see the same data.
db: SharedDatabase,
/// Cache of prepared statements, looked up by SQL text via `get_or_prepare`.
stmt_cache: Arc<PreparedStatementCache>,
/// Statements registered with `PREPARE name FROM '...'`. `EXECUTE` reads them and
/// `DEALLOCATE` removes them.
named_statements: HashMap<String, Arc<PreparedStatement>>,
/// Cursors declared, opened, fetched from and closed within this session.
cursors: CursorStore,
/// Session-level BEGIN/COMMIT/ROLLBACK bookkeeping.
txn_manager: SessionTransactionManager,
}
/// Outcome of executing a single SQL statement in a `Session`.
///
/// Row-returning statements (`SELECT`, `FETCH`) carry rows plus column descriptors.
/// DML carries an affected-row count. Prepared-statement and cursor commands carry
/// the name they acted on. DDL and transaction-control results carry no payload.
#[derive(Debug)]
pub enum ExecutionResult {
    /// Rows produced by a `SELECT`.
    Select { rows: Vec<Row>, columns: Vec<Column> },
    /// Row count written by an `INSERT`.
    Insert { rows_affected: usize },
    /// Row count changed by an `UPDATE`.
    Update { rows_affected: usize },
    /// Row count removed by a `DELETE`.
    Delete { rows_affected: usize },
    /// A `CREATE TABLE` completed.
    CreateTable,
    /// A `CREATE INDEX` completed.
    CreateIndex,
    /// A `CREATE VIEW` completed.
    CreateView,
    /// A `DROP TABLE` completed.
    DropTable,
    /// A `DROP INDEX` completed.
    DropIndex,
    /// A `DROP VIEW` completed.
    DropView,
    /// An `ANALYZE` completed over `tables_analyzed` tables.
    Analyze { tables_analyzed: usize },
    /// A `PREPARE` registered `statement_name`.
    Prepare { statement_name: String },
    /// A `DEALLOCATE` removed `statement_name`.
    Deallocate { statement_name: String },
    /// A `DECLARE CURSOR` created `cursor_name`.
    DeclareCursor { cursor_name: String },
    /// An `OPEN` opened `cursor_name`.
    OpenCursor { cursor_name: String },
    /// Rows returned by a `FETCH` from a cursor.
    Fetch { rows: Vec<Row>, columns: Vec<Column> },
    /// A `CLOSE` closed `cursor_name`.
    CloseCursor { cursor_name: String },
    /// A transaction was started.
    Begin,
    /// A transaction was committed.
    Commit,
    /// A transaction was rolled back.
    Rollback,
    /// Any other statement, described by a free-form `message`.
    Other { message: String },
}
impl ExecutionResult {
pub fn statement_type(&self) -> &str {
match self {
ExecutionResult::Select { .. } => "SELECT",
ExecutionResult::Insert { .. } => "INSERT",
ExecutionResult::Update { .. } => "UPDATE",
ExecutionResult::Delete { .. } => "DELETE",
ExecutionResult::CreateTable => "CREATE_TABLE",
ExecutionResult::CreateIndex => "CREATE_INDEX",
ExecutionResult::CreateView => "CREATE_VIEW",
ExecutionResult::DropTable => "DROP_TABLE",
ExecutionResult::DropIndex => "DROP_INDEX",
ExecutionResult::DropView => "DROP_VIEW",
ExecutionResult::Analyze { .. } => "ANALYZE",
ExecutionResult::Prepare { .. } => "PREPARE",
ExecutionResult::Deallocate { .. } => "DEALLOCATE",
ExecutionResult::DeclareCursor { .. } => "DECLARE_CURSOR",
ExecutionResult::OpenCursor { .. } => "OPEN_CURSOR",
ExecutionResult::Fetch { .. } => "FETCH",
ExecutionResult::CloseCursor { .. } => "CLOSE_CURSOR",
ExecutionResult::Begin => "BEGIN",
ExecutionResult::Commit => "COMMIT",
ExecutionResult::Rollback => "ROLLBACK",
ExecutionResult::Other { .. } => "OTHER",
}
}
pub fn rows_affected(&self) -> u64 {
match self {
ExecutionResult::Select { rows, .. } => rows.len() as u64,
ExecutionResult::Insert { rows_affected } => *rows_affected as u64,
ExecutionResult::Update { rows_affected } => *rows_affected as u64,
ExecutionResult::Delete { rows_affected } => *rows_affected as u64,
ExecutionResult::Fetch { rows, .. } => rows.len() as u64,
_ => 0,
}
}
}
/// Describes one column of a result set.
#[derive(Clone, Debug)]
pub struct Column {
    /// The column's label.
    pub name: String,
}
#[derive(Debug, Clone)]
pub struct Row {
pub values: Vec<vibesql_types::SqlValue>,
}
impl Session {
    /// Creates a session that executes against `db`.
    ///
    /// The session owns a fresh `PreparedStatementCache::default_cache()`.
    /// `database` and `user` are stored as given.
    pub fn new(database: String, user: String, db: SharedDatabase) -> Self {
        Self::with_cache(database, user, db, Arc::new(PreparedStatementCache::default_cache()))
    }

    /// Creates a session over a newly constructed `vibesql_storage::Database` that
    /// no other session shares.
    pub fn new_standalone(database: String, user: String) -> Self {
        let db = Arc::new(tokio::sync::RwLock::new(vibesql_storage::Database::new()));
        Self::new(database, user, db)
    }

    /// Returns whether this session's transaction manager has an open transaction.
    pub fn in_transaction(&self) -> bool {
        self.txn_manager.in_transaction()
    }

    /// Returns the database handle this session executes against.
    pub fn shared_database(&self) -> &SharedDatabase {
        &self.db
    }

    /// Creates a session that uses `cache` for prepared statements.
    ///
    /// The cache is reference-counted, so one cache can serve several sessions.
    /// Named statements, cursors and transaction state are freshly constructed.
    pub fn with_cache(
        database: String,
        user: String,
        db: SharedDatabase,
        cache: Arc<PreparedStatementCache>,
    ) -> Self {
        Self {
            database,
            user,
            db,
            stmt_cache: cache,
            named_statements: HashMap::new(),
            cursors: CursorStore::new(),
            txn_manager: SessionTransactionManager::new(),
        }
    }

    /// Returns the prepared form of `sql`.
    ///
    /// The statement is taken from the statement cache when present; otherwise it
    /// is prepared and cached.
    ///
    /// # Errors
    /// Fails if `sql` cannot be prepared.
    #[allow(dead_code)] // public entry point; may have no non-test caller in this crate
    pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>> {
        self.stmt_cache.get_or_prepare(sql).map_err(|e| anyhow::anyhow!("{}", e))
    }

    /// Binds `params` to `stmt` and executes the resulting statement.
    ///
    /// # Errors
    /// Fails if binding fails or if executing the bound statement fails.
    #[allow(dead_code)] // public entry point; may have no non-test caller in this crate
    pub async fn execute_prepared(
        &mut self,
        stmt: &PreparedStatement,
        params: &[SqlValue],
    ) -> Result<ExecutionResult> {
        let bound_stmt = stmt.bind(params).map_err(|e| anyhow::anyhow!("{}", e))?;
        self.execute_statement(&bound_stmt).await
    }

    /// Executes one SQL statement given as text.
    ///
    /// The text goes through the prepared-statement cache, so running the same SQL
    /// repeatedly reuses the cached prepared form.
    pub async fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
        let prepared =
            self.stmt_cache.get_or_prepare(sql).map_err(|e| anyhow::anyhow!("{}", e))?;
        self.execute_statement(prepared.statement()).await
    }

    /// Prepares `sql` through the cache, binds `params`, and executes it.
    #[allow(dead_code)] // public entry point; may have no non-test caller in this crate
    pub async fn execute_with_params(
        &mut self,
        sql: &str,
        params: &[SqlValue],
    ) -> Result<ExecutionResult> {
        let prepared = self.prepare(sql)?;
        self.execute_prepared(&prepared, params).await
    }

    /// Dispatches a parsed statement to the matching executor.
    ///
    /// The shared database is write-locked on entry. Arms that delegate to another
    /// `&mut self` method release the lock first, for two reasons:
    /// - Those methods need a mutable borrow of the whole session.
    /// - Some of them lock `self.db` again themselves, and the lock is not re-entrant.
    ///   These are `begin_transaction`, `commit`, `rollback`, and `EXECUTE`, which
    ///   re-enters this function via `execute_prepared`.
    async fn execute_statement(
        &mut self,
        statement: &vibesql_ast::Statement,
    ) -> Result<ExecutionResult> {
        let mut db = self.db.write().await;
        match statement {
            vibesql_ast::Statement::Select(select_stmt) => {
                let executor = vibesql_executor::SelectExecutor::new(&db);
                let rows = executor.execute(select_stmt)?;
                let result_rows: Vec<Row> =
                    rows.iter().map(|r| Row { values: r.values.to_vec() }).collect();
                // NOTE(review): column names are positional placeholders (`col0`,
                // `col1`, ...) sized from the first row; an empty result has no columns.
                let columns = if !rows.is_empty() {
                    (0..rows[0].values.len())
                        .map(|i| Column { name: format!("col{}", i) })
                        .collect()
                } else {
                    Vec::new()
                };
                Ok(ExecutionResult::Select { rows: result_rows, columns })
            }
            vibesql_ast::Statement::Insert(insert_stmt) => {
                let affected = vibesql_executor::InsertExecutor::execute(&mut db, insert_stmt)?;
                self.stmt_cache.invalidate_table(&insert_stmt.table_name);
                Ok(ExecutionResult::Insert { rows_affected: affected })
            }
            vibesql_ast::Statement::Update(update_stmt) => {
                let affected = vibesql_executor::UpdateExecutor::execute(update_stmt, &mut db)?;
                self.stmt_cache.invalidate_table(&update_stmt.table_name);
                Ok(ExecutionResult::Update { rows_affected: affected })
            }
            vibesql_ast::Statement::Delete(delete_stmt) => {
                let affected = vibesql_executor::DeleteExecutor::execute(delete_stmt, &mut db)?;
                self.stmt_cache.invalidate_table(&delete_stmt.table_name);
                Ok(ExecutionResult::Delete { rows_affected: affected })
            }
            vibesql_ast::Statement::CreateTable(create_stmt) => {
                vibesql_executor::CreateTableExecutor::execute(create_stmt, &mut db)?;
                Ok(ExecutionResult::CreateTable)
            }
            vibesql_ast::Statement::CreateIndex(index_stmt) => {
                vibesql_executor::CreateIndexExecutor::execute(index_stmt, &mut db)?;
                Ok(ExecutionResult::CreateIndex)
            }
            vibesql_ast::Statement::CreateView(view_stmt) => {
                vibesql_executor::advanced_objects::execute_create_view(view_stmt, &mut db)?;
                Ok(ExecutionResult::CreateView)
            }
            vibesql_ast::Statement::DropTable(drop_stmt) => {
                vibesql_executor::DropTableExecutor::execute(drop_stmt, &mut db)?;
                self.stmt_cache.invalidate_table(&drop_stmt.table_name);
                Ok(ExecutionResult::DropTable)
            }
            vibesql_ast::Statement::DropIndex(drop_stmt) => {
                vibesql_executor::DropIndexExecutor::execute(drop_stmt, &mut db)?;
                Ok(ExecutionResult::DropIndex)
            }
            vibesql_ast::Statement::DropView(drop_stmt) => {
                vibesql_executor::advanced_objects::execute_drop_view(drop_stmt, &mut db)?;
                Ok(ExecutionResult::DropView)
            }
            vibesql_ast::Statement::Analyze(analyze_stmt) => {
                // The executor's message is not surfaced; only the table count is.
                let _message = vibesql_executor::AnalyzeExecutor::execute(analyze_stmt, &mut db)?;
                let tables_analyzed =
                    if analyze_stmt.table_name.is_some() { 1 } else { db.list_tables().len() };
                Ok(ExecutionResult::Analyze { tables_analyzed })
            }
            vibesql_ast::Statement::Prepare(prepare_stmt) => {
                drop(db);
                self.execute_prepare(prepare_stmt)
            }
            vibesql_ast::Statement::Execute(execute_stmt) => {
                // Must release the lock: EXECUTE re-enters `execute_statement`.
                drop(db);
                self.execute_execute(execute_stmt).await
            }
            vibesql_ast::Statement::Deallocate(deallocate_stmt) => {
                drop(db);
                self.execute_deallocate(deallocate_stmt)
            }
            vibesql_ast::Statement::DeclareCursor(declare_stmt) => {
                drop(db);
                self.execute_declare_cursor(declare_stmt)
            }
            vibesql_ast::Statement::OpenCursor(open_stmt) => {
                // Opening reads the database, so the lock is kept; `self.cursors` is
                // a field disjoint from `self.db`, so both borrows coexist.
                CursorExecutor::open(&mut self.cursors, open_stmt, &db)
                    .map_err(|e| anyhow::anyhow!("{}", e))?;
                Ok(ExecutionResult::OpenCursor { cursor_name: open_stmt.cursor_name.clone() })
            }
            vibesql_ast::Statement::Fetch(fetch_stmt) => {
                drop(db);
                self.execute_fetch(fetch_stmt)
            }
            vibesql_ast::Statement::CloseCursor(close_stmt) => {
                drop(db);
                self.execute_close_cursor(close_stmt)
            }
            vibesql_ast::Statement::BeginTransaction(_) => {
                // The transaction helpers take the write lock themselves.
                drop(db);
                self.begin_transaction().await
            }
            vibesql_ast::Statement::Commit(_) => {
                drop(db);
                self.commit().await
            }
            vibesql_ast::Statement::Rollback(_) => {
                drop(db);
                self.rollback().await
            }
            // NOTE(review): the savepoint statements below are acknowledged but not
            // applied: they report success without touching the database.
            vibesql_ast::Statement::RollbackToSavepoint(_savepoint_stmt) => {
                Ok(ExecutionResult::Other { message: "ROLLBACK TO SAVEPOINT".to_string() })
            }
            vibesql_ast::Statement::Savepoint(_savepoint_stmt) => {
                Ok(ExecutionResult::Other { message: "SAVEPOINT".to_string() })
            }
            vibesql_ast::Statement::ReleaseSavepoint(_release_stmt) => {
                Ok(ExecutionResult::Other { message: "RELEASE SAVEPOINT".to_string() })
            }
            // NOTE(review): any statement kind not matched above reports success
            // without being executed.
            _ => {
                Ok(ExecutionResult::Other { message: "Command completed successfully".to_string() })
            }
        }
    }

    /// Handles `PREPARE name FROM 'sql'` by caching the statement under `name`.
    ///
    /// An existing statement with the same name is replaced.
    ///
    /// # Errors
    /// - `PREPARE ... AS <statement>` is rejected as not yet supported.
    /// - Fails if the SQL string cannot be prepared.
    fn execute_prepare(
        &mut self,
        prepare_stmt: &vibesql_ast::PrepareStmt,
    ) -> Result<ExecutionResult> {
        use vibesql_ast::PreparedStatementBody;
        let name = prepare_stmt.name.clone();
        let sql = match &prepare_stmt.statement {
            PreparedStatementBody::SqlString(s) => s.clone(),
            PreparedStatementBody::ParsedStatement(_stmt) => {
                return Err(anyhow::anyhow!(
                    "PREPARE ... AS syntax not yet supported. Use PREPARE ... FROM 'sql_string' instead"
                ));
            }
        };
        let prepared = self
            .stmt_cache
            .get_or_prepare(&sql)
            .map_err(|e| anyhow::anyhow!("Failed to prepare statement: {}", e))?;
        self.named_statements.insert(name.clone(), prepared);
        Ok(ExecutionResult::Prepare { statement_name: name })
    }

    /// Handles `EXECUTE name(params...)` for a statement registered by `PREPARE`.
    ///
    /// Parameter expressions are reduced to values by `evaluate_expression`.
    ///
    /// This returns a boxed future rather than being an `async fn` because the path
    /// recurses back into `execute_statement` via `execute_prepared`. A recursive
    /// future needs heap indirection to have a finite size.
    fn execute_execute(
        &mut self,
        execute_stmt: &vibesql_ast::ExecuteStmt,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ExecutionResult>> + Send + '_>>
    {
        // Clone what the future needs so it borrows only `self`, not `execute_stmt`.
        let name = execute_stmt.name.clone();
        let param_exprs = execute_stmt.params.clone();
        Box::pin(async move {
            let prepared = self
                .named_statements
                .get(&name)
                .ok_or_else(|| anyhow::anyhow!("Prepared statement '{}' not found", name))?
                .clone();
            let params: Vec<SqlValue> =
                param_exprs.iter().map(evaluate_expression).collect::<Result<Vec<_>>>()?;
            self.execute_prepared(&prepared, &params).await
        })
    }

    /// Handles `DEALLOCATE name` and `DEALLOCATE ALL`.
    ///
    /// # Errors
    /// Fails if a named statement does not exist.
    fn execute_deallocate(
        &mut self,
        deallocate_stmt: &vibesql_ast::DeallocateStmt,
    ) -> Result<ExecutionResult> {
        use vibesql_ast::DeallocateTarget;
        match &deallocate_stmt.target {
            DeallocateTarget::Name(name) => {
                if self.named_statements.remove(name).is_none() {
                    return Err(anyhow::anyhow!("Prepared statement '{}' not found", name));
                }
                Ok(ExecutionResult::Deallocate { statement_name: name.clone() })
            }
            DeallocateTarget::All => {
                let count = self.named_statements.len();
                self.named_statements.clear();
                Ok(ExecutionResult::Other {
                    message: format!("Deallocated {} prepared statement(s)", count),
                })
            }
        }
    }

    /// Returns hit/miss statistics of the prepared-statement cache.
    #[allow(dead_code)] // public entry point; may have no non-test caller in this crate
    pub fn cache_stats(&self) -> PreparedStatementCacheStats {
        self.stmt_cache.stats()
    }

    /// Clears the prepared-statement cache.
    ///
    /// The cache may be shared, so this can affect other sessions using it.
    #[allow(dead_code)] // public entry point; may have no non-test caller in this crate
    pub fn clear_cache(&self) {
        self.stmt_cache.clear();
    }

    /// Starts a transaction for this session and on the shared database.
    ///
    /// # Errors
    /// - Fails if this session already has an open transaction.
    /// - Fails if the database refuses to begin one. In that case the session-level
    ///   BEGIN is undone, so `in_transaction()` stays `false` and a later ROLLBACK
    ///   cannot reach a database transaction this session never started.
    pub async fn begin_transaction(&mut self) -> Result<ExecutionResult> {
        self.txn_manager.begin().map_err(|e| anyhow::anyhow!("{}", e))?;
        let mut db = self.db.write().await;
        if let Err(e) = db.begin_transaction() {
            // Best effort: the BEGIN just succeeded, so undoing it should not fail.
            // The database error is the one worth reporting either way.
            let _ = self.txn_manager.rollback();
            return Err(anyhow::anyhow!("Failed to begin transaction: {}", e));
        }
        Ok(ExecutionResult::Begin)
    }

    /// Commits the session's transaction, then commits on the database.
    ///
    /// # Errors
    /// Fails if the session has no open transaction, or if the database commit fails.
    ///
    /// NOTE(review): the session's transaction state is cleared before the database
    /// commit runs. After a database failure, `in_transaction()` therefore already
    /// reports `false`.
    pub async fn commit(&mut self) -> Result<ExecutionResult> {
        let _changes = self.txn_manager.commit().map_err(|e| anyhow::anyhow!("{}", e))?;
        let mut db = self.db.write().await;
        db.commit_transaction()
            .map_err(|e| anyhow::anyhow!("Failed to commit transaction: {}", e))?;
        Ok(ExecutionResult::Commit)
    }

    /// Rolls back the session's transaction, then rolls back on the database.
    ///
    /// # Errors
    /// Fails if the session-level rollback fails, or if the database rollback fails.
    /// The session state is cleared first, as in `commit`.
    pub async fn rollback(&mut self) -> Result<ExecutionResult> {
        self.txn_manager.rollback().map_err(|e| anyhow::anyhow!("{}", e))?;
        let mut db = self.db.write().await;
        db.rollback_transaction()
            .map_err(|e| anyhow::anyhow!("Failed to rollback transaction: {}", e))?;
        Ok(ExecutionResult::Rollback)
    }

    /// Registers a cursor declaration in this session's cursor store.
    fn execute_declare_cursor(
        &mut self,
        stmt: &vibesql_ast::DeclareCursorStmt,
    ) -> Result<ExecutionResult> {
        CursorExecutor::declare(&mut self.cursors, stmt).map_err(|e| anyhow::anyhow!("{}", e))?;
        Ok(ExecutionResult::DeclareCursor { cursor_name: stmt.cursor_name.clone() })
    }

    /// Fetches rows from a cursor and converts them into session result types.
    fn execute_fetch(&mut self, stmt: &vibesql_ast::FetchStmt) -> Result<ExecutionResult> {
        let fetch_result: CursorFetchResult =
            CursorExecutor::fetch(&mut self.cursors, stmt).map_err(|e| anyhow::anyhow!("{}", e))?;
        let rows: Vec<Row> =
            fetch_result.rows.iter().map(|r| Row { values: r.values.to_vec() }).collect();
        let columns: Vec<Column> =
            fetch_result.columns.iter().map(|name| Column { name: name.clone() }).collect();
        Ok(ExecutionResult::Fetch { rows, columns })
    }

    /// Closes a cursor in this session's cursor store.
    fn execute_close_cursor(
        &mut self,
        stmt: &vibesql_ast::CloseCursorStmt,
    ) -> Result<ExecutionResult> {
        CursorExecutor::close(&mut self.cursors, stmt).map_err(|e| anyhow::anyhow!("{}", e))?;
        Ok(ExecutionResult::CloseCursor { cursor_name: stmt.cursor_name.clone() })
    }
}
/// Reduces an `EXECUTE` parameter expression to a constant value.
///
/// Supported forms are literals and unary minus applied, possibly repeatedly, to a
/// numeric expression. Anything else is rejected.
///
/// # Errors
/// - Unsupported expression kinds or unary operators.
/// - Negating a non-numeric value.
/// - Negating an integer whose negation overflows (the type's minimum value).
fn evaluate_expression(expr: &vibesql_ast::Expression) -> Result<SqlValue> {
    use vibesql_ast::Expression;
    match expr {
        Expression::Literal(val) => Ok(val.clone()),
        Expression::UnaryOp { op, expr: operand } => {
            if let vibesql_ast::UnaryOperator::Minus = op {
                match evaluate_expression(operand)? {
                    // `checked_neg` turns the single overflowing case into an error.
                    // Plain `-n` would panic in debug builds and silently wrap in
                    // release builds, and this input comes from client SQL.
                    SqlValue::Integer(n) => n.checked_neg().map(SqlValue::Integer).ok_or_else(
                        || anyhow::anyhow!("Integer overflow while negating EXECUTE parameter"),
                    ),
                    SqlValue::Bigint(n) => n.checked_neg().map(SqlValue::Bigint).ok_or_else(
                        || anyhow::anyhow!("Integer overflow while negating EXECUTE parameter"),
                    ),
                    SqlValue::Float(n) => Ok(SqlValue::Float(-n)),
                    SqlValue::Double(n) => Ok(SqlValue::Double(-n)),
                    SqlValue::Numeric(n) => Ok(SqlValue::Numeric(-n)),
                    _ => Err(anyhow::anyhow!("Cannot negate non-numeric value")),
                }
            } else {
                Err(anyhow::anyhow!("Unsupported unary operator in EXECUTE parameter"))
            }
        }
        _ => Err(anyhow::anyhow!(
            "Unsupported expression type in EXECUTE parameters. Only literals are currently supported."
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::RwLock;
    use vibesql_storage::Database;

    /// Builds a fresh database handle that tests can share between sessions.
    fn create_shared_db() -> SharedDatabase {
        Arc::new(RwLock::new(Database::new()))
    }

    #[test]
    fn test_session_creation() {
        let db = create_shared_db();
        let session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        assert_eq!(session.database, "testdb");
        assert_eq!(session.user, "testuser");
        assert!(!session.in_transaction());
    }

    #[tokio::test]
    async fn test_transaction_state() {
        let db = create_shared_db();
        let mut session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        assert!(!session.in_transaction());
        assert!(session.begin_transaction().await.is_ok());
        assert!(session.in_transaction());
        // Nested BEGIN is rejected.
        assert!(session.begin_transaction().await.is_err());
        assert!(session.commit().await.is_ok());
        assert!(!session.in_transaction());
        // COMMIT without an open transaction is rejected.
        assert!(session.commit().await.is_err());
    }

    #[tokio::test]
    async fn test_rollback_ends_transaction() {
        let db = create_shared_db();
        let mut session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        assert!(session.begin_transaction().await.is_ok());
        assert!(session.in_transaction());
        assert!(session.rollback().await.is_ok());
        assert!(!session.in_transaction());
        // A new transaction can start once the previous one was rolled back.
        assert!(session.begin_transaction().await.is_ok());
        assert!(session.commit().await.is_ok());
    }

    #[tokio::test]
    async fn test_prepare_and_execute() {
        let db = create_shared_db();
        let mut session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        session.execute("CREATE TABLE users (id INT, name VARCHAR(100))").await.unwrap();
        let stmt = session.prepare("SELECT * FROM users WHERE id = 1").unwrap();
        assert_eq!(stmt.param_count(), 0);
        let result = session.execute_prepared(&stmt, &[]).await;
        assert!(result.is_ok());
        match result.unwrap() {
            ExecutionResult::Select { .. } => (),
            _ => panic!("Expected Select result"),
        }
    }

    #[test]
    fn test_cache_hit() {
        let db = create_shared_db();
        let session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        let _stmt1 = session.prepare("SELECT 1").unwrap();
        let stats = session.cache_stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
        let _stmt2 = session.prepare("SELECT 1").unwrap();
        let stats = session.cache_stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
    }

    #[tokio::test]
    async fn test_auto_caching_in_execute() {
        let db = create_shared_db();
        let mut session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        session.execute("SELECT 1").await.unwrap();
        let stats = session.cache_stats();
        assert_eq!(stats.misses, 1);
        session.execute("SELECT 1").await.unwrap();
        let stats = session.cache_stats();
        assert_eq!(stats.hits, 1);
    }

    #[tokio::test]
    async fn test_analyze_single_table() {
        let db = create_shared_db();
        let mut session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        session.execute("CREATE TABLE users (id INT, name VARCHAR(100))").await.unwrap();
        session.execute("INSERT INTO users VALUES (1, 'Alice')").await.unwrap();
        session.execute("INSERT INTO users VALUES (2, 'Bob')").await.unwrap();
        let result = session.execute("ANALYZE users").await.unwrap();
        match result {
            ExecutionResult::Analyze { tables_analyzed } => {
                assert_eq!(tables_analyzed, 1);
            }
            other => panic!("Expected Analyze result, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_analyze_all_tables() {
        let db = create_shared_db();
        let mut session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        session.execute("CREATE TABLE users (id INT, name VARCHAR(100))").await.unwrap();
        session.execute("CREATE TABLE products (id INT, price INT)").await.unwrap();
        session.execute("INSERT INTO users VALUES (1, 'Alice')").await.unwrap();
        session.execute("INSERT INTO products VALUES (1, 100)").await.unwrap();
        let result = session.execute("ANALYZE").await.unwrap();
        match result {
            ExecutionResult::Analyze { tables_analyzed } => {
                assert_eq!(tables_analyzed, 2);
            }
            other => panic!("Expected Analyze result, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_analyze_with_columns() {
        let db = create_shared_db();
        let mut session = Session::new("testdb".to_string(), "testuser".to_string(), db);
        session.execute("CREATE TABLE users (id INT, name VARCHAR(100), age INT)").await.unwrap();
        session.execute("INSERT INTO users VALUES (1, 'Alice', 30)").await.unwrap();
        let result = session.execute("ANALYZE users (id, name)").await.unwrap();
        match result {
            ExecutionResult::Analyze { tables_analyzed } => {
                assert_eq!(tables_analyzed, 1);
            }
            other => panic!("Expected Analyze result, got {:?}", other),
        }
    }

    #[test]
    fn test_analyze_statement_type() {
        let result = ExecutionResult::Analyze { tables_analyzed: 1 };
        assert_eq!(result.statement_type(), "ANALYZE");
    }

    #[tokio::test]
    async fn test_shared_database_across_sessions() {
        let db = create_shared_db();
        let mut session1 = Session::new("testdb".to_string(), "user1".to_string(), Arc::clone(&db));
        let mut session2 = Session::new("testdb".to_string(), "user2".to_string(), Arc::clone(&db));
        session1
            .execute("CREATE TABLE shared_test (id INT, value VARCHAR(100))")
            .await
            .unwrap();
        session1
            .execute("INSERT INTO shared_test VALUES (1, 'from session 1')")
            .await
            .unwrap();
        // Data written by session1 is visible to session2 ...
        let result = session2.execute("SELECT * FROM shared_test").await.unwrap();
        match result {
            ExecutionResult::Select { rows, .. } => {
                assert_eq!(rows.len(), 1);
            }
            _ => panic!("Expected Select result"),
        }
        session2
            .execute("INSERT INTO shared_test VALUES (2, 'from session 2')")
            .await
            .unwrap();
        // ... and vice versa.
        let result = session1.execute("SELECT * FROM shared_test").await.unwrap();
        match result {
            ExecutionResult::Select { rows, .. } => {
                assert_eq!(rows.len(), 2);
            }
            _ => panic!("Expected Select result"),
        }
    }
}