use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use sochdb_core::soch::SochValue;
use crate::ast_query::AstQueryExecutor;
use crate::connection::SochConnection;
use crate::crud::{DeleteResult, InsertResult, UpdateResult};
use crate::error::{ClientError, Result};
use crate::schema::{CreateIndexResult, CreateTableResult, DropTableResult};
pub use crate::ast_query::QueryResult;
/// Process-wide counter that hands out `query_id`s for the execute functions in
/// this module. `fetch_add` returns a distinct value per call even with
/// `Ordering::Relaxed`; the ids are labels only and are not used to
/// synchronize any other memory.
static QUERY_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Timing and classification data produced by `execute_with_stats` /
/// `execute_with_params_and_stats`. All durations are in microseconds.
#[derive(Debug, Clone)]
pub struct QueryStats {
/// Id drawn from `QUERY_COUNTER` for this call.
pub query_id: u64,
/// Time spent constructing the `AstQueryExecutor`.
/// NOTE(review): the SQL text is not handed to the executor until
/// `execute_with_params`, so real parsing cost is counted in `exec_time_us`.
pub parse_time_us: u64,
/// Time spent in `AstQueryExecutor::execute_with_params`.
pub exec_time_us: u64,
/// Time for the whole call, including dialect detection.
pub total_time_us: u64,
/// Rows returned (SELECT) or inserted/updated/deleted; 0 for DDL and other statements.
pub rows_affected: usize,
/// Heuristic dialect label from `detect_dialect`: "PostgreSQL", "MySQL", "SQLite" or "SQL-92".
pub dialect: String,
/// Statement kind: "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE TABLE",
/// "DROP TABLE", "CREATE INDEX" or "OTHER".
pub query_type: String,
}
pub fn execute(conn: &SochConnection, sql: &str) -> Result<QueryResult> {
execute_with_params(conn, sql, &[])
}
/// Runs `sql` with positional `params` bound and returns the executor's result.
///
/// Every call draws a fresh id from `QUERY_COUNTER`. When `debug_assertions`
/// is enabled, the id, elapsed wall-clock microseconds and the SQL text are
/// emitted as a `tracing` debug event, whether the query succeeded or failed.
///
/// # Errors
/// Returns whatever `AstQueryExecutor::execute_with_params` returns, unchanged.
pub fn execute_with_params(
    conn: &SochConnection,
    sql: &str,
    params: &[SochValue],
) -> Result<QueryResult> {
    let clock = Instant::now();
    let id = QUERY_COUNTER.fetch_add(1, Ordering::Relaxed);
    let executor = AstQueryExecutor::new(conn);
    let outcome = executor.execute_with_params(sql, params);
    let elapsed_us = clock.elapsed().as_micros() as u64;

    // Release builds skip the per-query trace entirely.
    if !cfg!(debug_assertions) {
        return outcome;
    }
    tracing::debug!(
        query_id = id,
        total_us = elapsed_us,
        sql = sql,
        "SQL executed"
    );
    outcome
}
pub fn execute_with_stats(
conn: &SochConnection,
sql: &str,
) -> Result<(QueryResult, QueryStats)> {
execute_with_params_and_stats(conn, sql, &[])
}
/// Runs `sql` with `params` and returns the result together with [`QueryStats`].
///
/// Timing fields (microseconds, truncated to `u64`):
/// - `parse_time_us` spans only `AstQueryExecutor::new(conn)`. The SQL text is
///   first handed to the executor in `execute_with_params`, so any parsing cost
///   is counted in `exec_time_us`, not here.
/// - `exec_time_us` spans `execute_with_params`.
/// - `total_time_us` spans the whole call, including dialect detection.
///
/// # Errors
/// Execution errors are propagated via `?`; no stats are reported for them.
pub fn execute_with_params_and_stats(
    conn: &SochConnection,
    sql: &str,
    params: &[SochValue],
) -> Result<(QueryResult, QueryStats)> {
    let overall = Instant::now();
    let query_id = QUERY_COUNTER.fetch_add(1, Ordering::Relaxed);
    let dialect = detect_dialect(sql);

    let setup_clock = Instant::now();
    let executor = AstQueryExecutor::new(conn);
    let parse_time_us = setup_clock.elapsed().as_micros() as u64;

    let run_clock = Instant::now();
    let result = executor.execute_with_params(sql, params)?;
    let exec_time_us = run_clock.elapsed().as_micros() as u64;
    let total_time_us = overall.elapsed().as_micros() as u64;

    // Label the statement and count the rows it touched (or returned).
    let (query_type, rows_affected): (&str, usize) = match &result {
        QueryResult::Select(rows) => ("SELECT", rows.len()),
        QueryResult::Insert(r) => ("INSERT", r.rows_inserted as usize),
        QueryResult::Update(r) => ("UPDATE", r.rows_updated as usize),
        QueryResult::Delete(r) => ("DELETE", r.rows_deleted as usize),
        QueryResult::CreateTable(_) => ("CREATE TABLE", 0),
        QueryResult::DropTable(_) => ("DROP TABLE", 0),
        QueryResult::CreateIndex(_) => ("CREATE INDEX", 0),
        QueryResult::Empty => ("OTHER", 0),
    };

    Ok((
        result,
        QueryStats {
            query_id,
            parse_time_us,
            exec_time_us,
            total_time_us,
            rows_affected,
            dialect: dialect.to_string(),
            query_type: query_type.to_string(),
        },
    ))
}
/// Best-effort guess at which SQL dialect `sql` is written in.
///
/// This is a keyword heuristic used only to label `QueryStats::dialect`; it
/// does not influence parsing or execution. Whitespace runs (including
/// newlines) are collapsed first, so multi-word markers such as `ON CONFLICT`
/// match regardless of formatting. Markers are checked in order, the first
/// match wins, and `"SQL-92"` is the fallback.
fn detect_dialect(sql: &str) -> &'static str {
    let upper = sql
        .to_uppercase()
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ");
    if upper.contains("ON CONFLICT") && upper.contains("DO UPDATE") {
        "PostgreSQL"
    } else if upper.contains("ON DUPLICATE KEY") {
        "MySQL"
    } else if upper.contains("INSERT OR REPLACE") || upper.contains("INSERT OR IGNORE") {
        "SQLite"
    } else if upper.contains("RETURNING") {
        "PostgreSQL"
    } else if has_mysql_limit_comma(&upper) {
        "MySQL"
    } else {
        "SQL-92"
    }
}

/// Returns true when the upper-cased `upper` contains MySQL's
/// `LIMIT <offset>, <count>` form, e.g. `LIMIT 10, 20` or `LIMIT ?, ?`.
///
/// The standard `LIMIT n OFFSET m` form is deliberately NOT matched: a comma
/// elsewhere in the statement (such as a column list) is irrelevant.
fn has_mysql_limit_comma(upper: &str) -> bool {
    upper.match_indices("LIMIT").any(|(pos, keyword)| {
        // Reject matches that are the tail of a longer identifier (`RATE_LIMIT`).
        if matches!(upper[..pos].chars().next_back(), Some(c) if c.is_alphanumeric() || c == '_')
        {
            return false;
        }
        let rest = &upper[pos + keyword.len()..];
        let operand = rest.trim_start();
        // The keyword must be followed by whitespace (rules out `LIMITS`).
        if operand.len() == rest.len() {
            return false;
        }
        // First operand: a literal number or a `?` placeholder, then a comma.
        let after_first = operand.trim_start_matches(|c: char| c.is_ascii_digit() || c == '?');
        after_first.len() < operand.len() && after_first.trim_start().starts_with(',')
    })
}
/// Outcome of `execute_batch`.
#[derive(Debug)]
pub struct BatchResult {
/// One entry per statement actually attempted, in input order. When the batch
/// stops on the first error this can be shorter than `total`.
pub results: Vec<Result<QueryResult>>,
/// Number of statements submitted (not necessarily all attempted).
pub total: usize,
/// Attempted statements that returned `Ok`.
pub succeeded: usize,
/// Attempted statements that returned `Err` (at most 1 when `stop_on_error` is set).
pub failed: usize,
}
/// Executes `statements` one by one on `conn`, collecting every result.
///
/// Each statement runs independently through `execute` (no shared
/// transaction is opened here). With `stop_on_error`, processing halts right
/// after the first failing statement, whose error is still recorded in
/// `results`; otherwise all statements are attempted.
pub fn execute_batch(
    conn: &SochConnection,
    statements: &[&str],
    stop_on_error: bool,
) -> BatchResult {
    let total = statements.len();
    let mut results = Vec::with_capacity(total);
    let (mut succeeded, mut failed) = (0, 0);

    for &sql in statements {
        let outcome = execute(conn, sql);
        let is_failure = outcome.is_err();
        results.push(outcome);
        if !is_failure {
            succeeded += 1;
            continue;
        }
        failed += 1;
        if stop_on_error {
            break;
        }
    }

    BatchResult {
        results,
        total,
        succeeded,
        failed,
    }
}
/// A SQL statement bound to a connection, with its placeholder count computed
/// once so every execution can be checked for the right number of arguments.
pub struct PreparedStatement<'a> {
    conn: &'a SochConnection,
    sql: String,
    param_count: usize,
}

impl<'a> PreparedStatement<'a> {
    /// Prepares `sql` for repeated execution on `conn`.
    ///
    /// Only the placeholder count (`?` or the highest `$n`) is computed here;
    /// the statement is deliberately **not** run. Dry-running it with NULL
    /// arguments would apply INSERT/UPDATE/DELETE/DDL for real before the
    /// caller ever executes it (e.g. preparing `DROP TABLE t` would drop the
    /// table). Syntax and schema errors therefore surface from the first call
    /// to [`PreparedStatement::execute`].
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for API compatibility.
    pub fn prepare(conn: &'a SochConnection, sql: impl Into<String>) -> Result<Self> {
        let sql = sql.into();
        let param_count = count_parameters(&sql);
        Ok(Self {
            conn,
            sql,
            param_count,
        })
    }

    /// Executes the statement with `params` bound positionally.
    ///
    /// # Errors
    /// Returns `ClientError::Parse` when `params.len()` differs from
    /// [`Self::param_count`]; otherwise propagates execution errors.
    pub fn execute(&self, params: &[SochValue]) -> Result<QueryResult> {
        if params.len() != self.param_count {
            return Err(ClientError::Parse(format!(
                "Expected {} parameters, got {}",
                self.param_count,
                params.len()
            )));
        }
        execute_with_params(self.conn, &self.sql, params)
    }

    /// The SQL text this statement was prepared from.
    pub fn sql(&self) -> &str {
        &self.sql
    }

    /// Number of parameters `execute` expects.
    pub fn param_count(&self) -> usize {
        self.param_count
    }
}
/// Counts the bind parameters referenced by `sql`.
///
/// Each `?` counts as one positional parameter; for numbered placeholders the
/// highest `$n` seen is the count (so `$1 ... $1` needs one argument). The
/// result is the larger of the two counts. Text inside single/double quotes,
/// `-- ...` line comments and `/* ... */` block comments is skipped, so a `?`
/// there is not mistaken for a parameter. Inside quotes, a backslash escapes
/// the next character.
fn count_parameters(sql: &str) -> usize {
    let chars: Vec<char> = sql.chars().collect();
    let len = chars.len();
    let mut positional = 0;
    let mut max_numbered = 0;
    let mut i = 0;
    while i < len {
        match chars[i] {
            '?' => {
                positional += 1;
                i += 1;
            }
            '$' => {
                let start = i + 1;
                let mut end = start;
                while end < len && chars[end].is_ascii_digit() {
                    end += 1;
                }
                if end > start {
                    let digits: String = chars[start..end].iter().collect();
                    // Absurdly large numbers overflow usize and are ignored.
                    if let Ok(n) = digits.parse::<usize>() {
                        max_numbered = max_numbered.max(n);
                    }
                }
                i = end;
            }
            '\'' | '"' => {
                let quote = chars[i];
                i += 1;
                while i < len && chars[i] != quote {
                    if chars[i] == '\\' && i + 1 < len {
                        i += 1;
                    }
                    i += 1;
                }
                // Step past the closing quote (or beyond the end if unterminated).
                i += 1;
            }
            '-' if chars.get(i + 1) == Some(&'-') => {
                // Line comment: skip to the newline (or end of input).
                while i < len && chars[i] != '\n' {
                    i += 1;
                }
            }
            '/' if chars.get(i + 1) == Some(&'*') => {
                // Block comment: skip to the closing `*/` (or end of input).
                i += 2;
                while i < len && !(chars[i] == '*' && chars.get(i + 1) == Some(&'/')) {
                    i += 1;
                }
                i += 2;
            }
            _ => i += 1,
        }
    }
    positional.max(max_numbered)
}
/// Fluent builder for simple single-table SELECT/INSERT/UPDATE/DELETE
/// statements; values are bound as `$n` parameters, identifiers are not.
pub struct QueryBuilder<'a> {
/// Connection the built query is executed on.
conn: &'a SochConnection,
/// Which statement kind `to_sql` renders.
query_type: QueryBuilderType,
/// Target table name, interpolated into the SQL verbatim.
table: String,
/// SELECT projection, or INSERT/UPDATE column names (appended to by `set`).
columns: Vec<String>,
/// Values bound for INSERT/UPDATE, in `set` call order.
values: Vec<SochValue>,
/// WHERE terms as (column, operator, value), joined with AND.
conditions: Vec<(String, &'static str, SochValue)>,
/// ORDER BY column and whether it sorts ascending.
order_by: Option<(String, bool)>,
/// Optional LIMIT row count.
limit: Option<usize>,
/// Optional OFFSET row count.
offset: Option<usize>,
}
/// Statement kind a [`QueryBuilder`] renders; chosen by its constructor.
#[derive(Debug, Clone, Copy)]
enum QueryBuilderType {
/// `SELECT ... FROM table`
Select,
/// `INSERT INTO table (...) VALUES (...)`
Insert,
/// `UPDATE table SET ...`
Update,
/// `DELETE FROM table`
Delete,
}
impl<'a> QueryBuilder<'a> {
pub fn select(conn: &'a SochConnection, table: impl Into<String>) -> Self {
Self {
conn,
query_type: QueryBuilderType::Select,
table: table.into(),
columns: Vec::new(),
values: Vec::new(),
conditions: Vec::new(),
order_by: None,
limit: None,
offset: None,
}
}
pub fn insert(conn: &'a SochConnection, table: impl Into<String>) -> Self {
Self {
conn,
query_type: QueryBuilderType::Insert,
table: table.into(),
columns: Vec::new(),
values: Vec::new(),
conditions: Vec::new(),
order_by: None,
limit: None,
offset: None,
}
}
pub fn update(conn: &'a SochConnection, table: impl Into<String>) -> Self {
Self {
conn,
query_type: QueryBuilderType::Update,
table: table.into(),
columns: Vec::new(),
values: Vec::new(),
conditions: Vec::new(),
order_by: None,
limit: None,
offset: None,
}
}
pub fn delete(conn: &'a SochConnection, table: impl Into<String>) -> Self {
Self {
conn,
query_type: QueryBuilderType::Delete,
table: table.into(),
columns: Vec::new(),
values: Vec::new(),
conditions: Vec::new(),
order_by: None,
limit: None,
offset: None,
}
}
pub fn columns(mut self, cols: &[&str]) -> Self {
self.columns = cols.iter().map(|s| s.to_string()).collect();
self
}
pub fn set(mut self, column: impl Into<String>, value: SochValue) -> Self {
self.columns.push(column.into());
self.values.push(value);
self
}
pub fn where_eq(mut self, column: impl Into<String>, value: SochValue) -> Self {
self.conditions.push((column.into(), "=", value));
self
}
pub fn order_by(mut self, column: impl Into<String>, asc: bool) -> Self {
self.order_by = Some((column.into(), asc));
self
}
pub fn limit(mut self, n: usize) -> Self {
self.limit = Some(n);
self
}
pub fn offset(mut self, n: usize) -> Self {
self.offset = Some(n);
self
}
pub fn to_sql(&self) -> (String, Vec<SochValue>) {
let mut sql = String::new();
let mut params = Vec::new();
let mut param_idx = 1;
match self.query_type {
QueryBuilderType::Select => {
sql.push_str("SELECT ");
if self.columns.is_empty() {
sql.push('*');
} else {
sql.push_str(&self.columns.join(", "));
}
sql.push_str(" FROM ");
sql.push_str(&self.table);
}
QueryBuilderType::Insert => {
sql.push_str("INSERT INTO ");
sql.push_str(&self.table);
sql.push_str(" (");
sql.push_str(&self.columns.join(", "));
sql.push_str(") VALUES (");
let placeholders: Vec<String> = (0..self.values.len())
.map(|i| format!("${}", i + 1))
.collect();
sql.push_str(&placeholders.join(", "));
sql.push(')');
params = self.values.clone();
param_idx = self.values.len() + 1;
}
QueryBuilderType::Update => {
sql.push_str("UPDATE ");
sql.push_str(&self.table);
sql.push_str(" SET ");
let sets: Vec<String> = self.columns.iter().enumerate()
.map(|(i, col)| format!("{} = ${}", col, i + 1))
.collect();
sql.push_str(&sets.join(", "));
params = self.values.clone();
param_idx = self.values.len() + 1;
}
QueryBuilderType::Delete => {
sql.push_str("DELETE FROM ");
sql.push_str(&self.table);
}
}
if !self.conditions.is_empty() {
sql.push_str(" WHERE ");
let conds: Vec<String> = self.conditions.iter().enumerate()
.map(|(i, (col, op, val))| {
params.push(val.clone());
let idx = param_idx + i;
format!("{} {} ${}", col, op, idx)
})
.collect();
sql.push_str(&conds.join(" AND "));
}
if let Some((col, asc)) = &self.order_by {
sql.push_str(" ORDER BY ");
sql.push_str(col);
sql.push_str(if *asc { " ASC" } else { " DESC" });
}
if let Some(n) = self.limit {
sql.push_str(&format!(" LIMIT {}", n));
}
if let Some(n) = self.offset {
sql.push_str(&format!(" OFFSET {}", n));
}
(sql, params)
}
pub fn execute(self) -> Result<QueryResult> {
let (sql, params) = self.to_sql();
execute_with_params(self.conn, &sql, ¶ms)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_parameters() {
        assert_eq!(count_parameters("SELECT * FROM t WHERE id = ?"), 1);
        assert_eq!(count_parameters("SELECT * FROM t WHERE id = ? AND name = ?"), 2);
        assert_eq!(count_parameters("SELECT * FROM t WHERE id = $1"), 1);
        assert_eq!(count_parameters("SELECT * FROM t WHERE id = $1 AND name = $2"), 2);
        assert_eq!(count_parameters("SELECT * FROM t WHERE name = 'test?'"), 0);
        // A repeated numbered placeholder refers to the same argument.
        assert_eq!(count_parameters("SELECT * FROM t WHERE a = $1 OR b = $1"), 1);
        // Quoted identifiers are skipped just like string literals.
        assert_eq!(count_parameters("SELECT \"col?\" FROM t WHERE id = ?"), 1);
        assert_eq!(count_parameters(""), 0);
    }

    #[test]
    fn test_detect_dialect() {
        assert_eq!(detect_dialect("INSERT INTO t (id) VALUES (1) ON CONFLICT (id) DO UPDATE SET x = 1"), "PostgreSQL");
        assert_eq!(detect_dialect("INSERT INTO t (id) VALUES (1) ON DUPLICATE KEY UPDATE x = 1"), "MySQL");
        assert_eq!(detect_dialect("INSERT OR REPLACE INTO t (id) VALUES (1)"), "SQLite");
        assert_eq!(detect_dialect("DELETE FROM t WHERE id = 1 RETURNING id"), "PostgreSQL");
        assert_eq!(detect_dialect("SELECT * FROM t"), "SQL-92");
    }

    #[test]
    #[ignore = "needs a SochConnection fixture to construct a QueryBuilder"]
    fn test_query_builder_sql() {
        // TODO: once a test connection is available, build SELECT/INSERT/UPDATE/DELETE
        // queries and assert on the `to_sql()` text and parameter order.
    }
}