use std::borrow::Cow;
use std::str::FromStr;
use sqlparser::ast::FunctionArg;
use crate::webserver::http_request_info::ExecutionContext;
use crate::webserver::single_or_vec::SingleOrVec;
use super::{
execute_queries::DbConn, sql::function_args_to_stmt_params, sql::ParamExtractContext,
sqlpage_functions::functions::SqlPageFunctionName,
};
use anyhow::Context as _;
/// A value that SQLPage computes itself at request time, evaluated by `extract_req_param`.
///
/// Its `Display` implementation renders a SQL-like representation that ends up in log and
/// error messages (through `SqlPageFunctionCall`'s `Display`).
#[derive(Debug, PartialEq, Eq, Clone)]
pub(crate) enum StmtParam {
/// `?name`: the URL parameter `name`.
Get(String),
/// `:name`: a variable stored in the request's `set_variables` if one exists under that
/// name, otherwise the POST (form) variable `name`.
Post(String),
/// `$name`: a variable stored in the request's `set_variables` if one exists under that
/// name, otherwise only the URL parameter `name`; a deprecation warning is logged when a
/// form field with the same name exists.
PostOrGet(String),
/// Evaluating this parameter fails with the contained message.
Error(String),
/// A string literal, evaluated as-is.
Literal(String),
/// SQL `NULL`; evaluates to `None`.
Null,
/// Concatenation of all items; evaluates to NULL as soon as one item is NULL.
Concat(Vec<StmtParam>),
/// The first item that is not NULL, or NULL if there is none.
Coalesce(Vec<StmtParam>),
/// Alternating key / value items, serialized as a JSON object.
JsonObject(Vec<StmtParam>),
/// Items serialized as a JSON array.
JsonArray(Vec<StmtParam>),
/// A call to a SQLPage function, evaluated by `SqlPageFunctionCall::evaluate`.
FunctionCall(SqlPageFunctionCall),
}
/// Renders a parameter in a SQL-like syntax, for logs and error messages.
///
/// Variadic expressions are written as `NAME(item1, item2)` with no trailing separator,
/// consistent with how `SqlPageFunctionCall` renders its arguments.
impl std::fmt::Display for StmtParam {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StmtParam::Get(name) => write!(f, "?{name}"),
            StmtParam::Post(name) => write!(f, ":{name}"),
            StmtParam::PostOrGet(name) => write!(f, "${name}"),
            // Single quotes are doubled, as in a SQL string literal.
            StmtParam::Literal(x) => write!(f, "'{}'", x.replace('\'', "''")),
            StmtParam::Null => write!(f, "NULL"),
            StmtParam::Concat(items) => write_stmt_param_call(f, "CONCAT", items),
            StmtParam::Coalesce(items) => write_stmt_param_call(f, "COALESCE", items),
            StmtParam::JsonObject(items) => write_stmt_param_call(f, "JSON_OBJECT", items),
            StmtParam::JsonArray(items) => write_stmt_param_call(f, "JSON_ARRAY", items),
            StmtParam::FunctionCall(call) => write!(f, "{call}"),
            StmtParam::Error(x) => {
                // Keep diagnostics short: messages longer than 21 characters are cut to
                // their first 21 characters (on a char boundary) followed by "...".
                if let Some((i, _)) = x.char_indices().nth(21) {
                    write!(f, "## {}... ##", &x[..i])
                } else {
                    write!(f, "## {x} ##")
                }
            }
        }
    }
}

/// Writes `name(item1, item2, ...)`, separating the displayed items with `", "` and
/// leaving no trailing separator.
fn write_stmt_param_call(
    f: &mut std::fmt::Formatter<'_>,
    name: &str,
    items: &[StmtParam],
) -> std::fmt::Result {
    write!(f, "{name}(")?;
    for (index, item) in items.iter().enumerate() {
        if index > 0 {
            f.write_str(", ")?;
        }
        write!(f, "{item}")?;
    }
    f.write_str(")")
}
/// A call to a SQLPage function together with its argument expressions.
///
/// The arguments are evaluated in order, just before the call, by
/// [`SqlPageFunctionCall::evaluate`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SqlPageFunctionCall {
/// The function to call.
pub function: SqlPageFunctionName,
/// The argument expressions, each evaluated with `extract_req_param` before the call.
pub arguments: Vec<StmtParam>,
}
impl SqlPageFunctionCall {
    /// Builds a function call from its name and the raw SQL arguments found by the parser.
    ///
    /// # Errors
    ///
    /// Fails if `func_name` does not parse as a [`SqlPageFunctionName`], or if
    /// `function_args_to_stmt_params` cannot convert the arguments.
    pub fn from_func_call(
        func_name: &str,
        arguments: &mut [FunctionArg],
        ctx: &ParamExtractContext,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            function: SqlPageFunctionName::from_str(func_name)?,
            arguments: function_args_to_stmt_params(arguments, ctx)?,
        })
    }

    /// Evaluates every argument in order, then invokes the function with the results.
    ///
    /// Returns the function's output, where `None` stands for SQL `NULL`.
    ///
    /// # Errors
    ///
    /// Stops at the first argument that fails to evaluate, and propagates any error
    /// returned by the function itself.
    pub async fn evaluate<'a>(
        &self,
        request: &'a ExecutionContext,
        db_connection: &mut DbConn,
    ) -> anyhow::Result<Option<Cow<'a, str>>> {
        let mut evaluated_args = Vec::with_capacity(self.arguments.len());
        for argument in self.arguments.iter() {
            // Boxed because argument evaluation may recurse back into this method.
            let value = Box::pin(extract_req_param(argument, request, db_connection)).await?;
            evaluated_args.push(value);
        }
        log::trace!("Starting function call to {self}");
        let output = self
            .function
            .evaluate(request, db_connection, evaluated_args)
            .await?;
        log::trace!(
            "Function call to {self} returned: {}",
            output.as_deref().unwrap_or("NULL")
        );
        Ok(output)
    }
}
/// Renders the call as `function(arg1, arg2)`, with each argument in its `StmtParam`
/// display form.
impl std::fmt::Display for SqlPageFunctionCall {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}(", self.function)?;
        for (position, argument) in self.arguments.iter().enumerate() {
            if position > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{argument}")?;
        }
        f.write_str(")")
    }
}
/// Evaluates a single [`StmtParam`] against the current request, returning its string value,
/// or `None` when it evaluates to SQL NULL.
///
/// - `Get` reads the URL parameter of that name.
/// - `Post` reads a variable from `request.set_variables` if one exists under that name,
///   otherwise the POST (form) variable.
/// - `PostOrGet` reads `request.set_variables` first, otherwise only the URL parameter; it
///   logs a deprecation warning when a form field of the same name exists.
/// - `Literal` and `Null` evaluate to themselves; `Error` makes evaluation fail.
/// - `Concat`, `JsonObject`, `JsonArray` and `Coalesce` delegate to `concat_params`,
///   `json_object_params`, `json_array_params` and `coalesce_params`.
/// - `FunctionCall` evaluates the call, adding the call and the function's expected
///   signature to any error.
///
/// Request values are converted to strings with `SingleOrVec::as_json_str`.
///
/// # Errors
///
/// Returns an error for `StmtParam::Error`, and propagates errors from composite and
/// function-call evaluation.
pub(super) async fn extract_req_param<'a>(
param: &StmtParam,
request: &'a ExecutionContext,
db_connection: &mut DbConn,
) -> anyhow::Result<Option<Cow<'a, str>>> {
Ok(match param {
StmtParam::Get(x) => request.url_params.get(x).map(SingleOrVec::as_json_str),
StmtParam::Post(x) => {
// A variable in `set_variables` takes precedence, even when its stored value is NULL.
// Its value is copied into an owned `Cow` because it cannot borrow from the temporary
// `RefCell` guard returned by `borrow()`.
if let Some(val) = request.set_variables.borrow().get(x) {
val.as_ref()
.map(|v| Cow::Owned(v.as_json_str().into_owned()))
} else {
request.post_variables.get(x).map(SingleOrVec::as_json_str)
}
}
StmtParam::PostOrGet(x) => {
// Same precedence and ownership rules as `Post` for `set_variables`.
if let Some(val) = request.set_variables.borrow().get(x) {
val.as_ref()
.map(|v| Cow::Owned(v.as_json_str().into_owned()))
} else {
let url_val = request.url_params.get(x);
// Only URL parameters are used here; a same-named form field only triggers a
// deprecation warning pointing users to `:name`.
if request.post_variables.contains_key(x) {
if url_val.is_some() {
log::warn!(
"Deprecation warning! There is both a URL parameter named '{x}' and a form field named '{x}'. \
SQLPage is using the URL parameter for ${x}. Please use :{x} to reference the form field explicitly."
);
} else {
log::warn!(
"Deprecation warning! ${x} was used to reference a form field value (a POST variable). \
This now uses only URL parameters. Please use :{x} instead."
);
}
}
url_val.map(SingleOrVec::as_json_str)
}
}
StmtParam::Error(x) => anyhow::bail!("{x}"),
StmtParam::Literal(x) => Some(Cow::Owned(x.clone())),
StmtParam::Null => None,
StmtParam::Concat(args) => concat_params(&args[..], request, db_connection).await?,
StmtParam::JsonObject(args) => {
json_object_params(&args[..], request, db_connection).await?
}
StmtParam::JsonArray(args) => json_array_params(&args[..], request, db_connection).await?,
StmtParam::Coalesce(args) => coalesce_params(&args[..], request, db_connection).await?,
StmtParam::FunctionCall(func) => {
func.evaluate(request, db_connection)
.await
.with_context(|| {
format!(
"Error in function call {func}.\nExpected {:#}",
func.function
)
})?
}
})
}
/// Concatenates the string values of `args`, in order.
///
/// If any argument evaluates to NULL, the whole result is NULL (`None`) and the remaining
/// arguments are not evaluated. An empty `args` yields an empty string.
///
/// # Errors
///
/// Propagates the first error raised while evaluating an argument.
async fn concat_params<'a>(
    args: &[StmtParam],
    request: &'a ExecutionContext,
    db_connection: &mut DbConn,
) -> anyhow::Result<Option<Cow<'a, str>>> {
    let mut concatenated = String::new();
    for arg in args {
        // Boxed because evaluating an argument may recurse back into this function.
        match Box::pin(extract_req_param(arg, request, db_connection)).await? {
            Some(piece) => concatenated.push_str(&piece),
            None => return Ok(None),
        }
    }
    Ok(Some(Cow::Owned(concatenated)))
}
/// Returns the value of the first argument that does not evaluate to NULL.
///
/// Arguments after that one are not evaluated. Returns `None` if every argument is NULL or
/// `args` is empty.
///
/// # Errors
///
/// Propagates the first error raised while evaluating an argument.
async fn coalesce_params<'a>(
    args: &[StmtParam],
    request: &'a ExecutionContext,
    db_connection: &mut DbConn,
) -> anyhow::Result<Option<Cow<'a, str>>> {
    for arg in args {
        // Boxed because evaluating an argument may recurse back into this function.
        let evaluated = Box::pin(extract_req_param(arg, request, db_connection)).await?;
        if evaluated.is_some() {
            return Ok(evaluated);
        }
    }
    Ok(None)
}
/// Evaluates the arguments of a `JSON_OBJECT(...)` expression and serializes them as a JSON
/// object string.
///
/// `args` is read as alternating key / value pairs. Keys are evaluated with
/// `extract_req_param`. Nested `JSON_OBJECT` / `JSON_ARRAY` values are embedded as raw JSON
/// rather than as quoted strings. Every other value is serialized as a JSON string, or as
/// `null` when it evaluates to NULL.
///
/// # Errors
///
/// Fails before evaluating anything if `args` has an odd length. Also propagates any error
/// raised while evaluating a key or value, or while serializing.
async fn json_object_params<'a>(
    args: &[StmtParam],
    request: &'a ExecutionContext,
    db_connection: &mut DbConn,
) -> anyhow::Result<Option<Cow<'a, str>>> {
    use serde::{ser::SerializeMap, Serializer};
    // Validate up front. Otherwise every preceding pair (including function calls that may
    // touch the database) would be evaluated before the error is reported.
    if args.len() % 2 != 0 {
        anyhow::bail!("Odd number of arguments in JSON_OBJECT");
    }
    let mut result = Vec::new();
    let mut ser = serde_json::Serializer::new(&mut result);
    // The size hint is the number of entries (key/value pairs), not the number of arguments.
    let mut map_ser = ser.serialize_map(Some(args.len() / 2))?;
    for pair in args.chunks_exact(2) {
        let (key, val) = (&pair[0], &pair[1]);
        let key = Box::pin(extract_req_param(key, request, db_connection)).await?;
        map_ser.serialize_key(&key)?;
        match val {
            // Nested JSON is embedded verbatim instead of being serialized as a string.
            StmtParam::JsonObject(args) => {
                let raw_json = Box::pin(json_object_params(args, request, db_connection)).await?;
                let obj = cow_to_raw_json(raw_json.as_ref());
                map_ser.serialize_value(&obj)?;
            }
            StmtParam::JsonArray(args) => {
                let raw_json = Box::pin(json_array_params(args, request, db_connection)).await?;
                let obj = cow_to_raw_json(raw_json.as_ref());
                map_ser.serialize_value(&obj)?;
            }
            val => {
                let evaluated = Box::pin(extract_req_param(val, request, db_connection)).await?;
                map_ser.serialize_value(&evaluated)?;
            }
        }
    }
    map_ser.end()?;
    Ok(Some(Cow::Owned(String::from_utf8(result)?)))
}
/// Evaluates the arguments of a `JSON_ARRAY(...)` expression and serializes them as a JSON
/// array string.
///
/// Nested `JSON_OBJECT` / `JSON_ARRAY` elements are embedded as raw JSON rather than as quoted
/// strings. Every other element is serialized as a JSON string, or as `null` when it
/// evaluates to NULL.
///
/// # Errors
///
/// Propagates any error raised while evaluating an element or while serializing.
async fn json_array_params<'a>(
    args: &[StmtParam],
    request: &'a ExecutionContext,
    db_connection: &mut DbConn,
) -> anyhow::Result<Option<Cow<'a, str>>> {
    use serde::{ser::SerializeSeq, Serializer};
    let mut result = Vec::new();
    let mut ser = serde_json::Serializer::new(&mut result);
    let mut seq_ser = ser.serialize_seq(Some(args.len()))?;
    for element in args {
        match element {
            StmtParam::JsonObject(args) => {
                // Boxed like every other recursive call in this module. The nested state
                // machine then lives on the heap instead of being embedded in this future.
                let raw_json = Box::pin(json_object_params(args, request, db_connection)).await?;
                let obj = cow_to_raw_json(raw_json.as_ref());
                seq_ser.serialize_element(&obj)?;
            }
            StmtParam::JsonArray(args) => {
                let raw_json = Box::pin(json_array_params(args, request, db_connection)).await?;
                let obj = cow_to_raw_json(raw_json.as_ref());
                seq_ser.serialize_element(&obj)?;
            }
            element => {
                let evaluated =
                    Box::pin(extract_req_param(element, request, db_connection)).await?;
                seq_ser.serialize_element(&evaluated)?;
            }
        }
    }
    seq_ser.end()?;
    Ok(Some(Cow::Owned(String::from_utf8(result)?)))
}
/// Reinterprets an already-serialized JSON document as a [`serde_json::value::RawValue`], so
/// that it can be embedded verbatim (not re-quoted) inside an enclosing JSON value.
///
/// Returns `None` when `raw_json` is `None`.
///
/// # Panics
///
/// Panics if the string is not valid JSON. The callers in this module only pass the output
/// of `json_object_params` / `json_array_params`, which is produced by `serde_json`.
fn cow_to_raw_json<'a>(
    raw_json: Option<&'a impl AsRef<str>>,
) -> Option<&'a serde_json::value::RawValue> {
    raw_json.map(AsRef::as_ref).map(|json| {
        serde_json::from_str::<&'a serde_json::value::RawValue>(json)
            .expect("JSON produced by json_object_params/json_array_params is always valid")
    })
}