use axum::{
extract::{Query, State},
Json,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tracing::{info, error};
use crate::api::models::ApiError;
use crate::api::server::AppState;
use crate::storage::dump::DatabaseInterface;
use crate::ai::nl_query::{
NlQueryEngine, NlQueryRequest, NlQueryConfig,
SchemaContext, TableSchema, ColumnSchema, ConversationContext,
QueryHistoryEntry, SqlDialect,
};
/// Request body for `nl_to_sql`: a natural-language question plus optional
/// schema scoping, conversation history and engine config overrides.
#[derive(Debug, Deserialize)]
pub struct NlQueryApiRequest {
/// The natural-language question to translate into SQL.
pub question: String,
/// Copied into the schema context's `database` field.
pub database: Option<String>,
/// Copied into the schema context's `schema` field.
pub schema: Option<String>,
/// Tables to describe to the engine; when absent, every table returned by
/// `list_tables` is used. Tables whose schema cannot be read are skipped.
pub tables: Option<Vec<String>>,
/// Earlier question/SQL turns, forwarded to the engine as conversation history.
pub context: Option<Vec<ConversationEntry>>,
/// Conversation session identifier; only forwarded when `context` is also present.
pub session_id: Option<String>,
/// Per-request overrides applied on top of `NlQueryConfig::default()`.
pub config: Option<NlQueryConfigOverride>,
}
/// One prior conversation turn supplied by the client. It is converted into a
/// `QueryHistoryEntry` (with no timestamp) before being sent to the engine.
#[derive(Debug, Deserialize)]
pub struct ConversationEntry {
/// The question asked in that turn.
pub question: String,
/// The SQL produced for that question.
pub sql: String,
/// Whether the client reports that turn as successful.
pub success: bool,
}
/// Optional per-request overrides for the NL engine configuration. Each field
/// left as `None` keeps the `NlQueryConfig::default()` value.
#[derive(Debug, Deserialize)]
pub struct NlQueryConfigOverride {
/// Target SQL dialect name, matched case-insensitively (`postgresql`/`postgres`,
/// `mysql`, `sqlite`, `mssql`/`sqlserver`, `oracle`, `heliosdb`). Any other
/// value falls back to PostgreSQL.
pub dialect: Option<String>,
/// Overrides `NlQueryConfig::max_results`.
pub max_results: Option<usize>,
/// Overrides `NlQueryConfig::validate_sql`.
pub validate_sql: Option<bool>,
/// Overrides `NlQueryConfig::explain_results`.
pub explain_results: Option<bool>,
/// Overrides `NlQueryConfig::temperature`.
pub temperature: Option<f32>,
/// Overrides `NlQueryConfig::model`.
pub model: Option<String>,
}
/// Response of `nl_to_sql`: the generated (not executed) SQL and the engine's
/// metadata about it.
#[derive(Debug, Serialize)]
pub struct NlQueryApiResponse {
/// The generated SQL. It is not executed by `nl_to_sql`.
pub sql: String,
/// Optional natural-language explanation from the engine.
pub explanation: Option<String>,
/// Confidence score reported by the engine.
pub confidence: f32,
/// Detected intent, rendered as the lowercased `Debug` form of the engine's intent.
pub intent: String,
/// Tables identified by the engine's analysis.
pub tables: Vec<String>,
/// Validation verdict; `true` when the engine returned no validation result.
pub valid: bool,
/// Messages of any validation errors (empty when there was no validation result).
pub validation_errors: Vec<String>,
/// Warnings returned by the engine.
pub warnings: Vec<String>,
/// Follow-up suggestions returned by the engine (empty when none).
pub suggestions: Vec<SuggestionResponse>,
/// Wall-clock time spent in the handler, in milliseconds.
pub processing_time_ms: u64,
/// Whether the engine reported this translation as cached.
pub cached: bool,
}
/// A follow-up suggestion produced by the engine.
#[derive(Debug, Serialize)]
pub struct SuggestionResponse {
/// Human-readable suggestion text.
pub text: String,
/// SQL attached to the suggestion, if the engine provided any.
pub sql: Option<String>,
/// The engine's reason for the suggestion.
pub reason: String,
}
/// Request body for `nl_execute`, which translates a question to SQL and runs it.
#[derive(Debug, Deserialize)]
pub struct NlExecuteRequest {
/// The natural-language question to translate and execute.
pub question: String,
/// Requested branch (treated as `"main"` when absent).
/// NOTE(review): `nl_execute` reads this, but `state.db.query` is called without
/// it. Confirm whether branch selection is meant to be supported.
pub branch: Option<String>,
/// Copied into the schema context's `database` field.
pub database: Option<String>,
/// Copied into the schema context's `schema` field.
pub schema: Option<String>,
/// Tables to describe to the engine; all tables from `list_tables` when absent.
pub tables: Option<Vec<String>>,
/// When set, a `LIMIT` clause is appended if the generated SQL does not already have one.
pub limit: Option<usize>,
/// Earlier conversation turns, forwarded to the engine as history.
pub context: Option<Vec<ConversationEntry>>,
/// Conversation session identifier; only forwarded when `context` is also present.
pub session_id: Option<String>,
/// Per-request engine config overrides.
pub config: Option<NlQueryConfigOverride>,
}
/// Response of `nl_execute`: the generated SQL, its result rows, and timings.
#[derive(Debug, Serialize)]
pub struct NlExecuteResponse {
/// SQL produced by the engine, before any `LIMIT` appended for `limit`.
pub sql: String,
/// Optional natural-language explanation from the engine.
pub explanation: Option<String>,
/// Confidence score reported by the engine.
pub confidence: f32,
/// Synthesized positional column names (`column_0`, `column_1`, ...), one per
/// value in the first result row.
pub columns: Vec<String>,
/// Per-column type names taken from the first row's values. Each is the text of
/// the value's `Debug` form before the first `(`, lowercased.
pub column_types: Vec<String>,
/// Result rows, keyed by the synthesized column names.
pub rows: Vec<HashMap<String, serde_json::Value>>,
/// Number of rows returned.
pub row_count: usize,
/// Time spent building context and translating the question, in ms.
pub nl_processing_time_ms: u64,
/// Time spent executing the SQL, in ms.
pub sql_execution_time_ms: u64,
/// Total handler time up to response assembly, in ms.
pub total_time_ms: u64,
/// Warnings for the client, including those returned by the engine.
pub warnings: Vec<String>,
}
/// Response of `get_schema_context`: the table schemas available to NL queries.
#[derive(Debug, Serialize)]
pub struct SchemaContextResponse {
/// One entry per table whose schema could be read (after optional filtering).
pub tables: Vec<TableSchemaResponse>,
/// Echoed back from the `database` query parameter.
pub database: Option<String>,
/// Echoed back from the `schema` query parameter.
pub schema: Option<String>,
}
/// Schema of a single table as exposed by `get_schema_context`.
#[derive(Debug, Serialize)]
pub struct TableSchemaResponse {
/// Table name as returned by `list_tables`.
pub name: String,
/// Table description; `get_schema_context` always sets this to `None`.
pub description: Option<String>,
/// Column definitions, in storage order.
pub columns: Vec<ColumnSchemaResponse>,
/// Names of the primary-key columns; `None` when the table has none.
pub primary_key: Option<Vec<String>>,
/// Foreign keys; `get_schema_context` always sets this to `None`.
pub foreign_keys: Option<Vec<ForeignKeyResponse>>,
/// Row count; `get_schema_context` always sets this to `None`.
pub row_count: Option<usize>,
}
/// A single column definition.
#[derive(Debug, Serialize)]
pub struct ColumnSchemaResponse {
/// Column name.
pub name: String,
/// `Debug` rendering of the storage layer's column data type.
pub data_type: String,
/// Whether the column accepts NULL.
pub nullable: bool,
/// Column description; not populated by `get_schema_context`.
pub description: Option<String>,
/// Whether the column is part of the primary key.
pub is_primary_key: bool,
}
/// A foreign-key relation from `columns` to `ref_table(ref_columns)`.
/// `get_schema_context` does not populate these yet.
#[derive(Debug, Serialize)]
pub struct ForeignKeyResponse {
/// Referencing columns in this table.
pub columns: Vec<String>,
/// Referenced table.
pub ref_table: String,
/// Referenced columns in `ref_table`.
pub ref_columns: Vec<String>,
}
/// Request body for `nl_explain`: a SQL statement to describe in plain language.
#[derive(Debug, Deserialize)]
pub struct NlExplainRequest {
/// The SQL statement to explain.
pub sql: String,
/// Optional originating question; not read by `nl_explain`.
pub question: Option<String>,
}
/// Response of `nl_explain`.
#[derive(Debug, Serialize)]
pub struct NlExplainResponse {
/// Plain-language summary assembled from the breakdown.
pub explanation: String,
/// Structural breakdown of the statement.
pub breakdown: QueryBreakdown,
/// Heuristic improvement tips (e.g. add a LIMIT, avoid `SELECT *`).
pub suggestions: Vec<String>,
}
/// Heuristic structural breakdown of a SQL statement. It is derived by
/// keyword/regex scanning, not by a full SQL parser.
#[derive(Debug, Serialize)]
pub struct QueryBreakdown {
/// `SELECT`, `INSERT`, `UPDATE`, `DELETE` or `UNKNOWN`.
pub operation: String,
/// Table names found after FROM/JOIN/INTO/UPDATE.
pub tables: Vec<String>,
/// Select-list items between SELECT and FROM.
pub columns: Vec<String>,
/// WHERE-clause fragments.
pub conditions: Vec<String>,
/// JOIN clauses.
pub joins: Vec<String>,
/// COUNT/SUM/AVG/MIN/MAX calls.
pub aggregations: Vec<String>,
/// ORDER BY clause text, if present.
pub order_by: Option<String>,
/// Numeric LIMIT value, if one could be parsed.
pub limit: Option<usize>,
}
/// Request body for `nl_suggest`: a partially typed question to complete.
#[derive(Debug, Deserialize)]
pub struct NlSuggestRequest {
/// The partial question typed so far.
pub partial: String,
/// Not read by `nl_suggest`.
pub database: Option<String>,
/// Not read by `nl_suggest`.
pub schema: Option<String>,
/// Maximum number of suggestions to return (defaults to 5).
pub limit: Option<usize>,
}
/// Response of `nl_suggest`.
#[derive(Debug, Serialize)]
pub struct NlSuggestResponse {
/// Suggested questions, at most the requested limit.
pub suggestions: Vec<QuerySuggestionResponse>,
}
/// A single suggested question.
#[derive(Debug, Serialize)]
pub struct QuerySuggestionResponse {
/// The suggested full question.
pub question: String,
/// One of `count`, `select`, `aggregate`, `group` or `search` (as produced by `nl_suggest`).
pub category: String,
/// `simple` or `medium` (as produced by `nl_suggest`).
pub complexity: String,
}
/// Query-string parameters for `get_schema_context`.
#[derive(Debug, Deserialize)]
pub struct SchemaQueryParams {
/// Echoed back in the response.
pub database: Option<String>,
/// Echoed back in the response.
pub schema: Option<String>,
/// Comma-separated table names to include, trimmed and matched case-insensitively.
pub tables: Option<String>,
/// Not read by `get_schema_context`.
pub include_samples: Option<bool>,
}
pub async fn nl_to_sql(
State(state): State<AppState>,
Json(request): Json<NlQueryApiRequest>,
) -> Result<Json<NlQueryApiResponse>, ApiError> {
let start = Instant::now();
info!("NL Query request: {}", request.question);
let schema_context = build_schema_context(&state, &request).await?;
let conversation_context = request.context.map(|entries| {
ConversationContext {
history: entries.into_iter().map(|e| QueryHistoryEntry {
question: e.question,
sql: e.sql,
success: e.success,
timestamp: None,
}).collect(),
entities: None,
session_id: request.session_id.clone(),
}
});
let config = build_config(request.config);
let nl_request = NlQueryRequest {
question: request.question.clone(),
schema: Some(schema_context),
context: conversation_context,
config: Some(config),
user_id: None, tenant_id: None, metadata: None,
};
let engine = get_nl_engine(&state)?;
let response = engine.translate(nl_request).await
.map_err(|e| ApiError::internal(format!("NL translation failed: {}", e)))?;
let api_response = NlQueryApiResponse {
sql: response.sql,
explanation: response.explanation,
confidence: response.confidence,
intent: format!("{:?}", response.analysis.intent).to_lowercase(),
tables: response.analysis.tables,
valid: response.validation.as_ref().map(|v| v.allowed).unwrap_or(true),
validation_errors: response.validation
.as_ref()
.map(|v| v.errors.iter().map(|e| e.message.clone()).collect())
.unwrap_or_default(),
warnings: response.warnings,
suggestions: response.suggestions
.unwrap_or_default()
.into_iter()
.map(|s| SuggestionResponse {
text: s.text,
sql: s.sql,
reason: s.reason,
})
.collect(),
processing_time_ms: start.elapsed().as_millis() as u64,
cached: response.cached,
};
info!(
"NL Query completed in {}ms, confidence: {:.2}",
api_response.processing_time_ms,
api_response.confidence
);
Ok(Json(api_response))
}
/// Translates a natural-language question into SQL and executes it.
///
/// Flow:
/// 1. Build the schema and conversation context.
/// 2. Translate the question with the NL engine.
/// 3. Reject SQL the engine's validator marked as not allowed.
/// 4. Optionally append a row limit.
/// 5. Run the SQL via `state.db.query` and convert the result tuples to JSON rows.
///
/// The returned `sql` is the engine's SQL, without any appended `LIMIT`.
///
/// # Errors
/// - Internal error if translation fails.
/// - Bad request if validation rejects the SQL.
/// - Any database error, converted via `ApiError::from`.
pub async fn nl_execute(
    State(state): State<AppState>,
    Json(request): Json<NlExecuteRequest>,
) -> Result<Json<NlExecuteResponse>, ApiError> {
    /// Appends `LIMIT n` when `limit` is set and the statement has no standalone
    /// `LIMIT` keyword. Trailing `;` is stripped first so the clause is not
    /// appended after the statement terminator. A whole-word check avoids
    /// treating identifiers such as `credit_limit` as an existing LIMIT.
    fn apply_row_limit(sql: &str, limit: Option<usize>) -> String {
        let limit = match limit {
            Some(limit) => limit,
            None => return sql.to_string(),
        };
        let has_limit = sql
            .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .any(|word| word.eq_ignore_ascii_case("LIMIT"));
        if has_limit {
            return sql.to_string();
        }
        let base = sql.trim_end().trim_end_matches(';').trim_end();
        format!("{} LIMIT {}", base, limit)
    }

    let total_start = Instant::now();
    info!("NL Execute request: {}", request.question);

    // --- Natural language -> SQL ---
    let nl_start = Instant::now();
    let schema_context = build_schema_context_from_execute(&state, &request).await?;
    let conversation_context = request.context.map(|entries| ConversationContext {
        history: entries
            .into_iter()
            .map(|e| QueryHistoryEntry {
                question: e.question,
                sql: e.sql,
                success: e.success,
                timestamp: None,
            })
            .collect(),
        entities: None,
        session_id: request.session_id.clone(),
    });
    let config = build_config(request.config);
    let nl_request = NlQueryRequest {
        question: request.question.clone(),
        schema: Some(schema_context),
        context: conversation_context,
        config: Some(config),
        user_id: None,
        tenant_id: None,
        metadata: None,
    };
    let engine = get_nl_engine(&state)?;
    let nl_response = engine
        .translate(nl_request)
        .await
        .map_err(|e| ApiError::internal(format!("NL translation failed: {}", e)))?;
    let nl_processing_time = nl_start.elapsed().as_millis() as u64;

    // Refuse to run SQL the engine's validator rejected.
    // NOTE(review): when no validation result is returned (presumably when
    // `validate_sql` is disabled — confirm), the LLM-generated SQL is executed
    // unchecked. Consider restricting this endpoint to read-only statements.
    if let Some(validation) = &nl_response.validation {
        if !validation.allowed {
            return Err(ApiError::bad_request(format!(
                "Generated SQL is not valid: {}",
                validation
                    .errors
                    .iter()
                    .map(|e| e.message.as_str())
                    .collect::<Vec<_>>()
                    .join(", ")
            )));
        }
    }

    // --- SQL execution ---
    let sql_start = Instant::now();
    let mut warnings = nl_response.warnings;
    // `state.db.query` takes no branch argument, so a non-default branch cannot be
    // honoured here. Tell the client instead of silently ignoring the request.
    if let Some(branch) = request.branch.as_deref() {
        if branch != "main" {
            warnings.push(format!(
                "Branch '{}' was requested but is not applied during execution",
                branch
            ));
        }
    }
    let sql = apply_row_limit(&nl_response.sql, request.limit);
    let tuples = state.db.query(&sql, &[]).map_err(|e| {
        error!("SQL execution failed: {}", e);
        ApiError::from(e)
    })?;
    let sql_execution_time = sql_start.elapsed().as_millis() as u64;

    let (columns, column_types, rows) = match tuples.first() {
        None => (Vec::new(), Vec::new(), Vec::new()),
        Some(first) => {
            // Result tuples carry only values, so column names are synthesized positionally.
            let cols: Vec<String> = (0..first.values.len())
                .map(|i| format!("column_{}", i))
                .collect();
            // Type name = text of the value's Debug form before the first '(', lowercased.
            let types: Vec<String> = first
                .values
                .iter()
                .map(|v| {
                    format!("{:?}", v)
                        .split('(')
                        .next()
                        .unwrap_or("unknown")
                        .to_lowercase()
                })
                .collect();
            let rows: Vec<HashMap<String, serde_json::Value>> = tuples
                .iter()
                .map(|t| {
                    t.values
                        .iter()
                        .enumerate()
                        .map(|(i, v)| {
                            // Rows wider than the first still get distinct keys.
                            let key = cols
                                .get(i)
                                .cloned()
                                .unwrap_or_else(|| format!("column_{}", i));
                            let json_val: serde_json::Value = v.into();
                            (key, json_val)
                        })
                        .collect()
                })
                .collect();
            (cols, types, rows)
        }
    };

    let total_time = total_start.elapsed().as_millis() as u64;
    let response = NlExecuteResponse {
        sql: nl_response.sql,
        explanation: nl_response.explanation,
        confidence: nl_response.confidence,
        columns,
        column_types,
        row_count: rows.len(),
        rows,
        nl_processing_time_ms: nl_processing_time,
        sql_execution_time_ms: sql_execution_time,
        total_time_ms: total_time,
        warnings,
    };
    info!(
        "NL Execute completed: {} rows in {}ms (NL: {}ms, SQL: {}ms)",
        response.row_count, total_time, nl_processing_time, sql_execution_time
    );
    Ok(Json(response))
}
/// Produces a heuristic plain-language explanation of a SQL statement.
///
/// The statement is scanned with keyword and regex helpers; no LLM is
/// involved. The result reports:
/// - the operation,
/// - tables, columns, conditions and joins,
/// - aggregations, ORDER BY and LIMIT,
/// - a few generic improvement tips.
///
/// Never fails for well-formed JSON input.
pub async fn nl_explain(
    State(_state): State<AppState>,
    Json(request): Json<NlExplainRequest>,
) -> Result<Json<NlExplainResponse>, ApiError> {
    info!("NL Explain request for SQL: {}", request.sql);

    // Keyword checks only need ASCII case folding; `to_ascii_uppercase` also keeps
    // byte lengths identical to `request.sql` (unlike `to_uppercase`).
    let sql_upper = request.sql.to_ascii_uppercase();
    // Leading whitespace or newlines must not hide the statement keyword.
    let statement = sql_upper.trim_start();
    let operation = ["SELECT", "INSERT", "UPDATE", "DELETE"]
        .iter()
        .copied()
        .find(|op| statement.starts_with(*op))
        .unwrap_or("UNKNOWN");

    let tables = extract_tables_from_sql(&request.sql);
    let conditions = extract_conditions_from_sql(&request.sql);
    let joins = extract_joins_from_sql(&request.sql);

    let mut explanation_parts = vec![match operation {
        "SELECT" => format!(
            "This query retrieves data from {}.",
            if tables.is_empty() {
                "the database".to_string()
            } else {
                format!("the {} table(s)", tables.join(", "))
            }
        ),
        "INSERT" => "This query inserts new data.".to_string(),
        "UPDATE" => "This query updates existing data.".to_string(),
        "DELETE" => "This query deletes data.".to_string(),
        _ => "This is a database operation.".to_string(),
    }];
    if !conditions.is_empty() {
        explanation_parts.push(format!(
            "It filters results where {}.",
            conditions.join(" and ")
        ));
    }
    if !joins.is_empty() {
        explanation_parts.push(format!(
            "It combines data using {} join(s).",
            joins.len()
        ));
    }

    // Take the word after the first standalone LIMIT keyword. Token matching avoids
    // hits inside identifiers like `credit_limit` and never slices by byte offset.
    let mut words = request.sql.split_whitespace();
    let limit = words
        .by_ref()
        .position(|w| w.eq_ignore_ascii_case("LIMIT"))
        .and_then(|_| words.next())
        .and_then(|w| w.trim_end_matches(';').parse::<usize>().ok());
    if let Some(lim) = limit {
        explanation_parts.push(format!("Results are limited to {} rows.", lim));
    }

    let mut suggestions = Vec::new();
    if limit.is_none() && operation == "SELECT" {
        suggestions.push("Consider adding a LIMIT clause for large tables.".to_string());
    }
    if sql_upper.contains("SELECT *") {
        suggestions.push("Consider selecting specific columns instead of *.".to_string());
    }

    let response = NlExplainResponse {
        explanation: explanation_parts.join(" "),
        breakdown: QueryBreakdown {
            operation: operation.to_string(),
            tables,
            columns: extract_columns_from_sql(&request.sql),
            conditions,
            joins,
            aggregations: extract_aggregations_from_sql(&request.sql),
            order_by: extract_order_by_from_sql(&request.sql),
            limit,
        },
        suggestions,
    };
    Ok(Json(response))
}
pub async fn get_schema_context(
State(state): State<AppState>,
Query(params): Query<SchemaQueryParams>,
) -> Result<Json<SchemaContextResponse>, ApiError> {
info!("Getting schema context for NL queries");
let tables_result = state.db.list_tables()
.map_err(|e| ApiError::internal(format!("Failed to list tables: {}", e)))?;
let filter_tables: Option<Vec<String>> = params.tables
.map(|t| t.split(',').map(|s| s.trim().to_string()).collect());
let mut tables = Vec::new();
for table_name in tables_result {
if let Some(ref filter) = filter_tables {
if !filter.iter().any(|f| f.eq_ignore_ascii_case(&table_name)) {
continue;
}
}
if let Ok(schema) = state.db.get_table_schema(&table_name) {
let columns: Vec<ColumnSchemaResponse> = schema.columns.iter().map(|c| {
ColumnSchemaResponse {
name: c.name.clone(),
data_type: format!("{:?}", c.data_type),
nullable: c.nullable,
description: None,
is_primary_key: c.primary_key,
}
}).collect();
let primary_key: Vec<String> = schema.columns.iter()
.filter(|c| c.primary_key)
.map(|c| c.name.clone())
.collect();
tables.push(TableSchemaResponse {
name: table_name,
description: None,
columns,
primary_key: if primary_key.is_empty() { None } else { Some(primary_key) },
foreign_keys: None, row_count: None, });
}
}
let response = SchemaContextResponse {
tables,
database: params.database,
schema: params.schema,
};
Ok(Json(response))
}
pub async fn nl_suggest(
State(state): State<AppState>,
Json(request): Json<NlSuggestRequest>,
) -> Result<Json<NlSuggestResponse>, ApiError> {
info!("NL Suggest request: {}", request.partial);
let limit = request.limit.unwrap_or(5);
let partial_lower = request.partial.to_lowercase();
let tables = state.db.list_tables().unwrap_or_default();
let mut suggestions = Vec::new();
if partial_lower.contains("how many") || partial_lower.contains("count") {
for table in tables.iter().take(3) {
suggestions.push(QuerySuggestionResponse {
question: format!("How many records are in {}?", table),
category: "count".to_string(),
complexity: "simple".to_string(),
});
}
}
if partial_lower.contains("show") || partial_lower.contains("list") || partial_lower.contains("get") {
for table in tables.iter().take(3) {
suggestions.push(QuerySuggestionResponse {
question: format!("Show all records from {}", table),
category: "select".to_string(),
complexity: "simple".to_string(),
});
suggestions.push(QuerySuggestionResponse {
question: format!("Show the top 10 records from {}", table),
category: "select".to_string(),
complexity: "simple".to_string(),
});
}
}
if partial_lower.contains("average") || partial_lower.contains("avg") || partial_lower.contains("total") || partial_lower.contains("sum") {
suggestions.push(QuerySuggestionResponse {
question: "What is the average value?".to_string(),
category: "aggregate".to_string(),
complexity: "medium".to_string(),
});
suggestions.push(QuerySuggestionResponse {
question: "What is the total sum?".to_string(),
category: "aggregate".to_string(),
complexity: "medium".to_string(),
});
}
if partial_lower.contains("group") || partial_lower.contains("by") {
suggestions.push(QuerySuggestionResponse {
question: "Group records by category".to_string(),
category: "group".to_string(),
complexity: "medium".to_string(),
});
}
if suggestions.len() < limit {
for table in tables.iter() {
if suggestions.len() >= limit {
break;
}
suggestions.push(QuerySuggestionResponse {
question: format!("Find records in {} where...", table),
category: "search".to_string(),
complexity: "simple".to_string(),
});
}
}
suggestions.truncate(limit);
Ok(Json(NlSuggestResponse { suggestions }))
}
async fn build_schema_context(
state: &AppState,
request: &NlQueryApiRequest,
) -> Result<SchemaContext, ApiError> {
let tables_list = if let Some(ref tables) = request.tables {
tables.clone()
} else {
state.db.list_tables()
.map_err(|e| ApiError::internal(format!("Failed to list tables: {}", e)))?
};
let mut tables = Vec::new();
for table_name in tables_list {
if let Ok(schema) = state.db.get_table_schema(&table_name) {
let columns: Vec<ColumnSchema> = schema.columns.iter().map(|c| {
ColumnSchema {
name: c.name.clone(),
data_type: format!("{:?}", c.data_type),
nullable: c.nullable,
description: None,
default_value: None,
is_primary_key: c.primary_key,
is_unique: false,
enum_values: None,
}
}).collect();
let primary_key: Vec<String> = schema.columns.iter()
.filter(|c| c.primary_key)
.map(|c| c.name.clone())
.collect();
tables.push(TableSchema {
name: table_name,
description: None,
columns,
primary_key: if primary_key.is_empty() { None } else { Some(primary_key) },
foreign_keys: None,
indexes: None,
sample_values: None,
row_count: None,
});
}
}
Ok(SchemaContext {
tables,
database: request.database.clone(),
schema: request.schema.clone(),
hints: None,
})
}
async fn build_schema_context_from_execute(
state: &AppState,
request: &NlExecuteRequest,
) -> Result<SchemaContext, ApiError> {
let tables_list = if let Some(ref tables) = request.tables {
tables.clone()
} else {
state.db.list_tables()
.map_err(|e| ApiError::internal(format!("Failed to list tables: {}", e)))?
};
let mut tables = Vec::new();
for table_name in tables_list {
if let Ok(schema) = state.db.get_table_schema(&table_name) {
let columns: Vec<ColumnSchema> = schema.columns.iter().map(|c| {
ColumnSchema {
name: c.name.clone(),
data_type: format!("{:?}", c.data_type),
nullable: c.nullable,
description: None,
default_value: None,
is_primary_key: c.primary_key,
is_unique: false,
enum_values: None,
}
}).collect();
let primary_key: Vec<String> = schema.columns.iter()
.filter(|c| c.primary_key)
.map(|c| c.name.clone())
.collect();
tables.push(TableSchema {
name: table_name,
description: None,
columns,
primary_key: if primary_key.is_empty() { None } else { Some(primary_key) },
foreign_keys: None,
indexes: None,
sample_values: None,
row_count: None,
});
}
}
Ok(SchemaContext {
tables,
database: request.database.clone(),
schema: request.schema.clone(),
hints: None,
})
}
fn build_config(overrides: Option<NlQueryConfigOverride>) -> NlQueryConfig {
let mut config = NlQueryConfig::default();
if let Some(o) = overrides {
if let Some(dialect) = o.dialect {
config.dialect = match dialect.to_lowercase().as_str() {
"postgresql" | "postgres" => SqlDialect::PostgreSQL,
"mysql" => SqlDialect::MySQL,
"sqlite" => SqlDialect::SQLite,
"mssql" | "sqlserver" => SqlDialect::MSSQL,
"oracle" => SqlDialect::Oracle,
"heliosdb" => SqlDialect::HeliosDB,
_ => SqlDialect::PostgreSQL,
};
}
if let Some(max) = o.max_results {
config.max_results = max;
}
if let Some(validate) = o.validate_sql {
config.validate_sql = validate;
}
if let Some(explain) = o.explain_results {
config.explain_results = explain;
}
if let Some(temp) = o.temperature {
config.temperature = temp;
}
if let Some(model) = o.model {
config.model = Some(model);
}
}
config
}
/// Creates the NL query engine used by the handlers in this module.
///
/// The LLM provider configuration is hard-coded:
/// - provider `ollama` at `http://localhost:11434`,
/// - model `llama3.2`,
/// - 30 000 ms timeout and up to 3 retries.
///
/// A new provider and engine are constructed on every call.
///
/// NOTE(review): `_state` is not consulted yet. The provider settings (and
/// presumably a shared, reusable engine) likely belong in `AppState` — TODO.
///
/// # Errors
/// Returns an internal error if the provider cannot be created from the config.
fn get_nl_engine(_state: &AppState) -> Result<Arc<NlQueryEngine>, ApiError> {
    use crate::ai::providers::{LlmProviderConfig, ProviderRegistry};

    let provider_config = LlmProviderConfig {
        provider: "ollama".to_string(),
        api_key: None,
        endpoint: Some("http://localhost:11434".to_string()),
        model: Some("llama3.2".to_string()),
        organization: None,
        deployment: None,
        api_version: None,
        timeout_ms: Some(30000),
        max_retries: Some(3),
        headers: None,
    };
    let provider = ProviderRegistry::from_config(&provider_config)
        .map_err(|e| ApiError::internal(format!("Failed to create LLM provider: {}", e)))?;
    Ok(Arc::new(NlQueryEngine::new(provider)))
}
fn extract_tables_from_sql(sql: &str) -> Vec<String> {
let mut tables = Vec::new();
let re = regex::Regex::new(r"(?i)\b(?:FROM|JOIN|INTO|UPDATE)\s+([a-zA-Z_][a-zA-Z0-9_]*)").ok();
if let Some(re) = re {
for cap in re.captures_iter(sql) {
if let Some(m) = cap.get(1) {
let table = m.as_str().to_string();
if !tables.contains(&table) {
tables.push(table);
}
}
}
}
tables
}
/// Extracts the individual conditions of the first WHERE clause in `sql`.
///
/// The clause runs from the standalone `WHERE` keyword up to the earliest of
/// `GROUP BY`, `HAVING`, `ORDER BY` or `LIMIT`, or to the end of the string.
/// Any trailing `;` is dropped. The clause is then split on parentheses, and
/// bare `AND`/`OR` fragments and empty pieces are removed. As a result,
/// `a = 1 AND b = 2` without parentheses stays a single condition.
///
/// This is a keyword heuristic, not a SQL parser: string literals and
/// subqueries are not understood.
fn extract_conditions_from_sql(sql: &str) -> Vec<String> {
    /// Byte offset of the first occurrence of `keyword` in `upper` that is not
    /// part of a longer identifier (e.g. `LIMIT` inside `credit_limit`).
    fn find_keyword(upper: &str, keyword: &str) -> Option<usize> {
        let bytes = upper.as_bytes();
        let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
        upper.match_indices(keyword).map(|(pos, _)| pos).find(|&pos| {
            let end = pos + keyword.len();
            (pos == 0 || !is_ident(bytes[pos - 1])) && (end == bytes.len() || !is_ident(bytes[end]))
        })
    }

    // ASCII-only uppercasing preserves byte offsets, so positions found in `upper`
    // are valid char-boundary indices into `sql`. `to_uppercase` can change UTF-8
    // lengths (e.g. 'ﬁ' -> "FI"), which misaligns the slice or panics.
    let upper = sql.to_ascii_uppercase();
    let start = match find_keyword(&upper, "WHERE") {
        Some(pos) => pos + "WHERE".len(),
        None => return Vec::new(),
    };
    let after_upper = &upper[start..];
    // The clause ends at whichever terminator comes first, not the first one
    // found in a fixed priority order.
    let end = ["GROUP BY", "HAVING", "ORDER BY", "LIMIT"]
        .iter()
        .filter_map(|kw| find_keyword(after_upper, kw))
        .min()
        .map_or(sql.len(), |rel| start + rel);
    let where_clause = sql[start..end].trim().trim_end_matches(';').trim_end();

    where_clause
        .split(['(', ')'])
        .map(str::trim)
        .filter(|part| {
            !part.is_empty() && !part.eq_ignore_ascii_case("AND") && !part.eq_ignore_ascii_case("OR")
        })
        .map(str::to_string)
        .collect()
}
fn extract_joins_from_sql(sql: &str) -> Vec<String> {
let mut joins = Vec::new();
let re = regex::Regex::new(r"(?i)((?:LEFT|RIGHT|INNER|OUTER|CROSS|FULL)?\s*JOIN\s+[^\s]+\s+(?:ON\s+[^,]+)?)").ok();
if let Some(re) = re {
for cap in re.captures_iter(sql) {
if let Some(m) = cap.get(1) {
joins.push(m.as_str().trim().to_string());
}
}
}
joins
}
/// Returns the select-list items between the first standalone `SELECT` and
/// the first standalone `FROM` after it.
///
/// Items are split on top-level commas only, so `COALESCE(a, b) AS c` stays
/// one item. The result is empty when there is no SELECT, or no FROM after it.
fn extract_columns_from_sql(sql: &str) -> Vec<String> {
    /// Byte offset of the first occurrence of `keyword` in `upper` that is not
    /// part of a longer identifier (e.g. `FROM` inside `date_from`).
    fn find_keyword(upper: &str, keyword: &str) -> Option<usize> {
        let bytes = upper.as_bytes();
        let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
        upper.match_indices(keyword).map(|(pos, _)| pos).find(|&pos| {
            let end = pos + keyword.len();
            (pos == 0 || !is_ident(bytes[pos - 1])) && (end == bytes.len() || !is_ident(bytes[end]))
        })
    }

    // ASCII-only uppercasing keeps byte offsets valid for slicing `sql`.
    let upper = sql.to_ascii_uppercase();
    let list_start = match find_keyword(&upper, "SELECT") {
        Some(pos) => pos + "SELECT".len(),
        None => return Vec::new(),
    };
    // FROM is searched only after SELECT. In `DELETE FROM t WHERE id IN (SELECT ...)`
    // the first FROM precedes SELECT, and slicing backwards would panic.
    let list_end = match find_keyword(&upper[list_start..], "FROM") {
        Some(rel) => list_start + rel,
        None => return Vec::new(),
    };

    let mut columns = Vec::new();
    let mut current = String::new();
    let mut depth = 0usize;
    for ch in sql[list_start..list_end].chars() {
        match ch {
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                columns.push(std::mem::take(&mut current));
                continue;
            }
            _ => {}
        }
        current.push(ch);
    }
    columns.push(current);

    columns
        .iter()
        .map(|c| c.trim())
        .filter(|c| !c.is_empty())
        .map(str::to_string)
        .collect()
}
fn extract_aggregations_from_sql(sql: &str) -> Vec<String> {
let mut aggs = Vec::new();
let re = regex::Regex::new(r"(?i)(COUNT|SUM|AVG|MIN|MAX)\s*\([^)]+\)").ok();
if let Some(re) = re {
for cap in re.captures_iter(sql) {
if let Some(m) = cap.get(0) {
aggs.push(m.as_str().to_string());
}
}
}
aggs
}
/// Returns the text of the first standalone `ORDER BY` clause.
///
/// The clause ends at the first `LIMIT`, `OFFSET` or `FETCH` keyword (or at
/// end of input), with any trailing `;` removed. Returns `None` when there is
/// no ORDER BY or the clause is empty.
fn extract_order_by_from_sql(sql: &str) -> Option<String> {
    /// Byte offset of the first occurrence of `keyword` in `upper` that is not
    /// part of a longer identifier (e.g. `LIMIT` inside `credit_limit`).
    fn find_keyword(upper: &str, keyword: &str) -> Option<usize> {
        let bytes = upper.as_bytes();
        let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
        upper.match_indices(keyword).map(|(pos, _)| pos).find(|&pos| {
            let end = pos + keyword.len();
            (pos == 0 || !is_ident(bytes[pos - 1])) && (end == bytes.len() || !is_ident(bytes[end]))
        })
    }

    // ASCII-only uppercasing keeps byte offsets valid for slicing `sql`
    // (`to_uppercase` can change UTF-8 lengths and misalign or panic).
    let upper = sql.to_ascii_uppercase();
    let start = find_keyword(&upper, "ORDER BY")? + "ORDER BY".len();
    let after_upper = &upper[start..];
    let end = ["LIMIT", "OFFSET", "FETCH"]
        .iter()
        .filter_map(|kw| find_keyword(after_upper, kw))
        .min()
        .map_or(sql.len(), |rel| start + rel);
    let clause = sql[start..end].trim().trim_end_matches(';').trim_end();
    if clause.is_empty() {
        None
    } else {
        Some(clause.to_string())
    }
}