use anyhow::{anyhow, Context};
use futures_util::stream::Stream;
use futures_util::StreamExt;
use serde_json::Value;
use std::borrow::Cow;
use std::path::Path;
use std::pin::Pin;
use super::csv_import::run_csv_import;
use super::error_highlighting::{display_stmt_db_error, display_stmt_error};
use super::sql::{
DelayedFunctionCall, ParsedSqlFile, ParsedStatement, SimpleSelectValue, StmtWithParams,
};
use crate::dynamic_component::parse_dynamic_rows;
use crate::utils::add_value_to_map;
use crate::webserver::database::sql_to_json::row_to_string;
use crate::webserver::http_request_info::ExecutionContext;
use crate::webserver::request_variables::SetVariablesMap;
use crate::webserver::single_or_vec::SingleOrVec;
use crate::webserver::ErrorWithStatus;
use super::syntax_tree::{extract_req_param, StmtParam};
use super::{error_highlighting::display_db_error, Database, DbItem};
use sqlx::any::{AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo};
use sqlx::pool::PoolConnection;
use sqlx::{
Any, AnyConnection, Arguments, Column, Either, Executor, Row as _, Statement, ValueRef,
};
/// A database connection slot that is filled lazily: it stays `None` until
/// `take_connection` acquires a connection from the pool and stores it here.
pub type DbConn = Option<PoolConnection<sqlx::Any>>;
impl Database {
    /// Prepares `query` on the connection pool, declaring the types of its parameters.
    ///
    /// # Errors
    /// A database error is converted with `display_db_error`, using the placeholder
    /// path `autogenerated sqlpage query` since there is no source file to point at.
    pub(crate) async fn prepare_with(
        &self,
        query: &str,
        param_types: &[AnyTypeInfo],
    ) -> anyhow::Result<AnyStatement<'static>> {
        match self.connection.prepare_with(query, param_types).await {
            Ok(statement) => Ok(statement.to_owned()),
            Err(db_err) => Err(display_db_error(
                Path::new("autogenerated sqlpage query"),
                query,
                db_err,
            )),
        }
    }
}
/// Runs every statement of a parsed SQL file in order and streams the results.
///
/// The database connection is acquired lazily through `take_connection` the first
/// time a statement needs it, and is stored in `db_connection` for reuse.
///
/// Any error propagated with `?` inside the stream body is turned into a
/// `DbItem::Error` item by the final `map`, so the stream's item type is a plain
/// `DbItem` rather than a `Result`.
pub fn stream_query_results_with_conn<'a>(
sql_file: &'a ParsedSqlFile,
request: &'a ExecutionContext,
db_connection: &'a mut DbConn,
) -> impl Stream<Item = DbItem> + 'a {
let source_file = &sql_file.source_path;
async_stream::try_stream! {
for res in &sql_file.statements {
match res {
ParsedStatement::CsvImport(csv_import) => {
let connection = take_connection(&request.app_state.db, db_connection, request).await?;
log::debug!("Executing CSV import: {csv_import:?}");
run_csv_import(connection, csv_import, request).await.with_context(|| format!("Failed to import the CSV file {:?} into the table {:?}", csv_import.uploaded_file, csv_import.table_name))?;
},
ParsedStatement::StmtWithParams(stmt) => {
// Parameters are evaluated before the connection is taken for the query itself;
// binding receives `db_connection` too, so it may use the database as well.
let query = bind_parameters(stmt, request, db_connection)
.await
.map_err(|e| with_stmt_position(source_file, stmt.query_position, e))?;
request.server_timing.record("bind_params");
let connection = take_connection(&request.app_state.db, db_connection, request).await?;
log::trace!("Executing query {:?}", query.sql);
let mut stream = connection.fetch_many(query);
// A database error is remembered rather than yielded right away, so that
// the result stream can be dropped and a rollback attempted first.
let mut error = None;
while let Some(elem) = stream.next().await {
let mut query_result = parse_single_sql_result(source_file, stmt, elem);
if let DbItem::Error(e) = query_result {
error = Some(e);
break;
}
apply_json_columns(&mut query_result, &stmt.json_columns);
apply_delayed_functions(request, &stmt.delayed_functions, &mut query_result).await?;
for db_item in parse_dynamic_rows(query_result) {
yield db_item;
}
}
// Presumably needed to release the result stream's borrow of `connection`
// before the connection is reused for the rollback below.
drop(stream);
if let Some(error) = error {
try_rollback_transaction(connection).await;
yield DbItem::Error(error);
}
},
ParsedStatement::SetVariable { variable, value} => {
execute_set_variable_query(db_connection, request, variable, value, source_file).await
.with_context(||
format!("Failed to set the {variable} variable to {value:?}")
)?;
},
ParsedStatement::StaticSimpleSet { variable, value} => {
execute_set_simple_static(db_connection, request, variable, value, source_file).await
.with_context(||
format!("Failed to set the {variable} variable to {value:?}")
)?;
},
ParsedStatement::StaticSimpleSelect { values, query_position } => {
// The row is built by `exec_static_simple_select` from the statement's values.
let row = exec_static_simple_select(values, request, db_connection)
.await
.map_err(|e| with_stmt_position(source_file, *query_position, e))?;
for i in parse_dynamic_rows(DbItem::Row(row)) {
yield i;
}
}
// A statement that failed to parse is reported as an error item at its position in the file.
ParsedStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(source_file, e)),
}
}
}
// Flatten `Result<DbItem, anyhow::Error>` into `DbItem`, mapping `Err` to `DbItem::Error`.
.map(|res| res.unwrap_or_else(DbItem::Error))
}
fn with_stmt_position(
source_file: &Path,
query_position: super::sql::SourceSpan,
error: anyhow::Error,
) -> anyhow::Error {
if error.downcast_ref::<ErrorWithStatus>().is_some() {
error
} else {
display_stmt_error(source_file, query_position, error)
}
}
pub fn stop_at_first_error(
results_stream: impl Stream<Item = DbItem>,
) -> impl Stream<Item = DbItem> {
let (error_tx, error_rx) = tokio::sync::oneshot::channel();
let mut error_tx = Some(error_tx);
results_stream
.inspect(move |item| {
if let DbItem::Error(err) = item {
log::error!("{err:?}");
if let Some(tx) = error_tx.take() {
let _ = tx.send(());
}
}
})
.take_until(error_rx)
}
/// Builds the single JSON row produced by a static `SELECT`.
///
/// Static values are cloned as-is; dynamic values are resolved through
/// `extract_req_param_as_json`. Each `(name, value)` pair is merged into the row
/// with `add_value_to_map`.
async fn exec_static_simple_select(
    columns: &[(String, SimpleSelectValue)],
    req: &ExecutionContext,
    db_connection: &mut DbConn,
) -> anyhow::Result<serde_json::Value> {
    let mut row = serde_json::Map::with_capacity(columns.len());
    for (column_name, column_value) in columns {
        let resolved = match column_value {
            SimpleSelectValue::Dynamic(param) => {
                extract_req_param_as_json(param, req, db_connection).await?
            }
            SimpleSelectValue::Static(constant) => constant.clone(),
        };
        row = add_value_to_map(row, (column_name.clone(), resolved));
    }
    Ok(serde_json::Value::Object(row))
}
/// Sends a best-effort `ROLLBACK` on the given connection.
///
/// A failure is only logged at debug level and never propagated.
async fn try_rollback_transaction(db_connection: &mut AnyConnection) {
    log::debug!("Attempting to rollback transaction");
    let outcome = db_connection.execute("ROLLBACK").await;
    if let Err(e) = outcome {
        log::debug!("There was probably no transaction in progress when this happened: {e:?}");
    } else {
        log::debug!("Rolled back transaction");
    }
}
/// Resolves a statement parameter and wraps it as JSON.
///
/// Returns a JSON string when the parameter has a value, and JSON `null` otherwise.
async fn extract_req_param_as_json(
    param: &StmtParam,
    request: &ExecutionContext,
    db_connection: &mut DbConn,
) -> anyhow::Result<serde_json::Value> {
    let extracted = extract_req_param(param, request, db_connection).await?;
    Ok(extracted.map_or(serde_json::Value::Null, |text| {
        serde_json::Value::String(text.into_owned())
    }))
}
/// Same as `stream_query_results_with_conn`, but returns the stream pinned on the
/// heap behind a trait object.
pub fn stream_query_results_boxed<'a>(
    sql_file: &'a ParsedSqlFile,
    request: &'a ExecutionContext,
    db_connection: &'a mut DbConn,
) -> Pin<Box<dyn Stream<Item = DbItem> + 'a>> {
    let results = stream_query_results_with_conn(sql_file, request, db_connection);
    Box::pin(results)
}
/// Runs the query of a `SET variable = (query)` statement and stores its result.
///
/// The first row, if any, is converted with `row_to_string`; an absent row stores
/// no value for the variable.
///
/// # Errors
/// Fails when parameters cannot be bound, when no connection can be taken, when
/// the query fails (a rollback is attempted first), or when `variable` is not a
/// settable kind of parameter.
async fn execute_set_variable_query<'a>(
    db_connection: &'a mut DbConn,
    request: &'a ExecutionContext,
    variable: &StmtParam,
    statement: &StmtWithParams,
    source_file: &Path,
) -> anyhow::Result<()> {
    let query = bind_parameters(statement, request, db_connection).await?;
    let connection = take_connection(&request.app_state.db, db_connection, request).await?;
    log::debug!(
        "Executing query to set the {variable:?} variable: {:?}",
        query.sql
    );
    let fetched = connection.fetch_optional(query).await;
    let value = match fetched {
        Ok(maybe_row) => maybe_row.and_then(|row| row_to_string(&row)),
        Err(e) => {
            try_rollback_transaction(connection).await;
            return Err(display_stmt_db_error(source_file, statement, e));
        }
    };
    let (mut vars, name) = vars_and_name(request, variable)?;
    log::debug!("Setting variable {name} to {value:?}");
    vars.insert(name.to_owned(), value.map(SingleOrVec::Single));
    Ok(())
}
/// Handles a `SET` whose value is known without running a SQL query.
///
/// A static JSON `null` clears the value, a static JSON string is stored verbatim,
/// and any other static JSON value is stored in its serialized form. A dynamic
/// value is resolved through `extract_req_param`.
async fn execute_set_simple_static<'a>(
    db_connection: &'a mut DbConn,
    request: &'a ExecutionContext,
    variable: &StmtParam,
    value: &SimpleSelectValue,
    _source_file: &Path,
) -> anyhow::Result<()> {
    let value_str = match value {
        SimpleSelectValue::Dynamic(stmt_param) => {
            extract_req_param(stmt_param, request, db_connection)
                .await?
                .map(|text| text.into_owned())
        }
        SimpleSelectValue::Static(serde_json::Value::Null) => None,
        SimpleSelectValue::Static(serde_json::Value::String(s)) => Some(s.clone()),
        SimpleSelectValue::Static(other) => Some(other.to_string()),
    };
    let (mut vars, name) = vars_and_name(request, variable)?;
    log::debug!("Setting variable {name} to static value {value_str:?}");
    vars.insert(name.to_owned(), value_str.map(SingleOrVec::Single));
    Ok(())
}
/// Resolves the target of a `SET` statement.
///
/// Returns a mutable borrow of the request's set-variables map together with the
/// variable name. For `PostOrGet` and `Get` parameters, a deprecation warning is
/// logged when a posted variable of the same name already exists.
///
/// # Errors
/// Fails for every other kind of parameter.
fn vars_and_name<'a, 'b>(
    request: &'a ExecutionContext,
    variable: &'b StmtParam,
) -> anyhow::Result<(std::cell::RefMut<'a, SetVariablesMap>, &'b str)> {
    // Work out the name first; the map is only borrowed once the variable is known to be settable.
    let name: &'b str = match variable {
        StmtParam::Post(name) => name,
        StmtParam::PostOrGet(name) | StmtParam::Get(name) => {
            if request.post_variables.contains_key(name) {
                log::warn!("Deprecation warning! Setting the value of ${name}, but there is already a form field named :{name}. This will stop working soon. Please rename the variable, or use :{name} directly if you intended to overwrite the posted form field value.");
            }
            name
        }
        _ => {
            return Err(anyhow!(
                "Only GET and POST variables can be set, not {variable:?}"
            ))
        }
    };
    Ok((request.set_variables.borrow_mut(), name))
}
async fn take_connection<'a>(
db: &'a Database,
conn: &'a mut DbConn,
request: &ExecutionContext,
) -> anyhow::Result<&'a mut PoolConnection<sqlx::Any>> {
if let Some(c) = conn {
return Ok(c);
}
match db.connection.acquire().await {
Ok(c) => {
log::debug!("Acquired a database connection");
request.server_timing.record("db_conn");
*conn = Some(c);
Ok(conn.as_mut().unwrap())
}
Err(e) => {
let db_name = db.connection.any_kind();
let active_count = db.connection.size();
let err_msg = format!("Unable to connect to {db_name:?}. The connection pool currently has {active_count} active connections.");
Err(anyhow::Error::new(e).context(err_msg))
}
}
}
#[inline]
fn parse_single_sql_result(
source_file: &Path,
stmt: &StmtWithParams,
res: sqlx::Result<Either<AnyQueryResult, AnyRow>>,
) -> DbItem {
match res {
Ok(Either::Right(r)) => {
if log::log_enabled!(log::Level::Trace) {
debug_row(&r);
}
DbItem::Row(super::sql_to_json::row_to_json(&r))
}
Ok(Either::Left(res)) => {
log::debug!("Finished query with result: {res:?}");
DbItem::FinishedQuery
}
Err(err) => {
let nice_err = display_stmt_db_error(source_file, stmt, err);
DbItem::Error(nice_err)
}
}
}
/// Logs a one-line trace-level description of a database row.
///
/// For each column whose raw value can be read, the name, nullness, column
/// metadata and value type are appended; unreadable columns are skipped.
fn debug_row(r: &AnyRow) {
    use std::fmt::Write;
    let mut row_str = String::new();
    for (i, col) in r.columns().iter().enumerate() {
        let Ok(value) = r.try_get_raw(i) else {
            continue;
        };
        let nullness = if value.is_null() { "NULL" } else { "NOT NULL" };
        // Writing into a `String` cannot fail, so the unwrap never triggers.
        write!(
            &mut row_str,
            "[{:?} ({}): {:?}: {:?}]",
            col.name(),
            nullness,
            col,
            value.type_info()
        )
        .unwrap();
    }
    log::trace!("Received db row: {row_str}");
}
fn clone_anyhow_err(source_file: &Path, err: &anyhow::Error) -> anyhow::Error {
if let Some(func_err) = err.downcast_ref::<super::sql::SqlPageFunctionError>() {
let line = func_err.line;
let loc = if line > 0 {
format!(":{line}")
} else {
String::new()
};
return anyhow::anyhow!("{}{loc} {}", source_file.display(), func_err);
}
let mut e = anyhow!("{} contains a syntax error preventing SQLPage from parsing and preparing its SQL statements.", source_file.display());
for c in err.chain().rev() {
e = e.context(c.to_string());
}
e
}
/// Evaluates every parameter of `stmt` and pairs the resulting arguments with the
/// statement's SQL text.
///
/// Parameters are evaluated in order through `extract_req_param`. A missing value
/// is bound as a SQL `NULL` of string type.
///
/// # Errors
/// Fails as soon as one parameter cannot be evaluated.
async fn bind_parameters<'a>(
    stmt: &'a StmtWithParams,
    request: &'a ExecutionContext,
    db_connection: &mut DbConn,
) -> anyhow::Result<StatementWithParams<'a>> {
    let sql = stmt.query.as_str();
    log::debug!("Preparing statement: {sql}");
    let mut arguments = AnyArguments::default();
    for (param_idx, param) in stmt.params.iter().enumerate() {
        // Parameters are numbered from 1 in the logs.
        let position = param_idx + 1;
        log::trace!("\tevaluating parameter {}: {}", position, param);
        let argument = extract_req_param(param, request, db_connection).await?;
        log::debug!(
            "\tparameter {}: {}",
            position,
            argument.as_deref().unwrap_or("NULL")
        );
        match argument {
            Some(Cow::Borrowed(v)) => arguments.add(v),
            Some(Cow::Owned(s)) => arguments.add(s),
            None => arguments.add(None::<String>),
        }
    }
    Ok(StatementWithParams {
        sql,
        arguments,
        has_arguments: !stmt.params.is_empty(),
    })
}
/// Applies every delayed function call, in order, to a result row.
///
/// Items that are not a row holding a JSON object are left untouched. The
/// functions get their own connection slot, which starts out empty on each call.
async fn apply_delayed_functions(
    request: &ExecutionContext,
    delayed_functions: &[DelayedFunctionCall],
    item: &mut DbItem,
) -> anyhow::Result<()> {
    let DbItem::Row(serde_json::Value::Object(results)) = item else {
        return Ok(());
    };
    let mut db_conn = None;
    for f in delayed_functions {
        log::trace!("Applying delayed function {} to {:?}", f.function, results);
        apply_single_delayed_function(request, &mut db_conn, f, results).await?;
        log::trace!(
            "Delayed function applied {}. Result: {:?}",
            f.function,
            results
        );
    }
    Ok(())
}
async fn apply_single_delayed_function(
request: &ExecutionContext,
db_connection: &mut DbConn,
f: &DelayedFunctionCall,
row: &mut serde_json::Map<String, serde_json::Value>,
) -> anyhow::Result<()> {
let mut params = Vec::new();
for arg in &f.argument_col_names {
let Some(arg_value) = row.remove(arg) else {
anyhow::bail!("The column {arg} is missing in the result set, but it is required by the {} function.", f.function);
};
params.push(json_to_fn_param(arg_value));
}
let result_str = f.function.evaluate(request, db_connection, params).await?;
let result_json = result_str
.map(Cow::into_owned)
.map_or(serde_json::Value::Null, serde_json::Value::String);
row.insert(f.target_col_name.clone(), result_json);
Ok(())
}
/// Converts a JSON value into a function argument.
///
/// JSON `null` becomes `None`, a JSON string is passed through verbatim, and any
/// other value is passed in its serialized JSON form.
fn json_to_fn_param(json: serde_json::Value) -> Option<Cow<'static, str>> {
    let text = match json {
        serde_json::Value::Null => return None,
        serde_json::Value::String(s) => s,
        other => other.to_string(),
    };
    Some(Cow::Owned(text))
}
fn apply_json_columns(item: &mut DbItem, json_columns: &[String]) {
if let DbItem::Row(Value::Object(ref mut row)) = item {
for column in json_columns {
if let Some(value) = row.get_mut(column) {
if let Value::String(json_str) = value {
if let Ok(parsed_json) = serde_json::from_str(json_str) {
log::trace!("Parsed JSON column {column}: {parsed_json}");
*value = parsed_json;
} else {
log::warn!("The column {column} contains invalid JSON: {json_str}");
}
} else if let Value::Array(array) = value {
for item in array {
if let Value::String(json_str) = item {
if let Ok(parsed_json) = serde_json::from_str(json_str) {
log::trace!("Parsed JSON array item: {parsed_json}");
*item = parsed_json;
}
}
}
}
} else {
log::warn!("The column {column} is missing from the result set, so it cannot be converted to JSON.");
}
}
}
}
/// A SQL query string bundled with its bound arguments, executable through sqlx
/// via its `sqlx::Execute` implementation.
pub struct StatementWithParams<'a> {
/// The SQL text of the query.
sql: &'a str,
/// The argument values bound to the query's parameters.
arguments: AnyArguments<'a>,
/// Whether the statement has any parameters; when `false`, `take_arguments`
/// returns `None` instead of the argument list.
has_arguments: bool,
}
impl<'q> sqlx::Execute<'q, Any> for StatementWithParams<'q> {
    /// The SQL text to execute.
    fn sql(&self) -> &'q str {
        self.sql
    }

    /// Never supplies a prepared statement.
    fn statement(&self) -> Option<&<Any as sqlx::database::HasStatement<'q>>::Statement> {
        None
    }

    /// Hands over the bound arguments, leaving a default (empty) set behind.
    /// Returns `None` when the statement has no parameters.
    fn take_arguments(&mut self) -> Option<<Any as sqlx::database::HasArguments<'q>>::Arguments> {
        if !self.has_arguments {
            return None;
        }
        Some(std::mem::take(&mut self.arguments))
    }

    /// Always reports the statement as persistent.
    fn persistent(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Wraps a JSON value into a row item.
    fn row_from(value: Value) -> DbItem {
        DbItem::Row(value)
    }

    /// Asserts that `item` is an object row whose `key` entry equals `expected`.
    fn assert_json_value(item: &DbItem, key: &str, expected: Value) {
        match item {
            DbItem::Row(Value::Object(row)) => assert_eq!(row[key], expected),
            _ => panic!("Expected DbItem::Row"),
        }
    }

    #[test]
    fn test_basic_json_string_conversion() {
        let columns = vec![String::from("json_col")];
        let mut item = row_from(json!({
            "json_col": r#"{"key": "value"}"#,
            "normal_col": "text"
        }));
        apply_json_columns(&mut item, &columns);
        assert_json_value(&item, "json_col", json!({"key": "value"}));
        assert_json_value(&item, "normal_col", json!("text"));
    }

    #[test]
    fn test_json_array_conversion() {
        let columns = vec![String::from("array_col")];
        let mut item = row_from(json!({
            "array_col": [r#"{"a": 1}"#, r#"{"b": 2}"#],
            "normal_array": ["text"]
        }));
        apply_json_columns(&mut item, &columns);
        assert_json_value(&item, "array_col", json!([{"a": 1}, {"b": 2}]));
        assert_json_value(&item, "normal_array", json!(["text"]));
    }

    #[test]
    fn test_invalid_json_handling() {
        let columns = vec![String::from("invalid_json")];
        let mut item = row_from(json!({
            "invalid_json": "{not valid json}",
            "normal_col": "text"
        }));
        apply_json_columns(&mut item, &columns);
        // Invalid JSON must be left untouched.
        assert_json_value(&item, "invalid_json", json!("{not valid json}"));
        assert_json_value(&item, "normal_col", json!("text"));
    }

    #[test]
    fn test_missing_column_handling() {
        let columns = vec![String::from("missing_col")];
        let mut item = row_from(json!({
            "existing_col": "text"
        }));
        apply_json_columns(&mut item, &columns);
        assert_json_value(&item, "existing_col", json!("text"));
    }

    #[test]
    fn test_non_row_dbitem_handling() {
        let columns = vec![String::from("json_col")];
        let mut item = DbItem::FinishedQuery;
        apply_json_columns(&mut item, &columns);
        assert!(matches!(item, DbItem::FinishedQuery));
    }

    #[test]
    fn test_duplicate_json_column_names() {
        // Listing the same column twice must behave like listing it once.
        let columns = vec![String::from("json_col"), String::from("json_col")];
        let mut item = row_from(json!({
            "json_col": r#"{"key": "value"}"#,
            "normal_col": "text"
        }));
        apply_json_columns(&mut item, &columns);
        assert_json_value(&item, "json_col", json!({"key": "value"}));
        assert_json_value(&item, "normal_col", json!("text"));
    }
}