use super::*;
/// Reports whether `action` is one of the fixed set of DynamoDB operation names treated as
/// state-changing.
///
/// The comparison is exact and case-sensitive. PartiQL operations (`ExecuteStatement`,
/// `ExecuteTransaction`, `BatchExecuteStatement`) are deliberately absent: whether they mutate
/// depends on the statement text, which `is_mutating_request` inspects separately.
pub(crate) fn is_mutating_action(action: &str) -> bool {
    const MUTATING_ACTIONS: &[&str] = &[
        // Tables
        "CreateTable",
        "DeleteTable",
        "UpdateTable",
        // Items
        "PutItem",
        "DeleteItem",
        "UpdateItem",
        "BatchWriteItem",
        "TransactWriteItems",
        // Tags, TTL and resource policies
        "TagResource",
        "UntagResource",
        "UpdateTimeToLive",
        "PutResourcePolicy",
        "DeleteResourcePolicy",
        // Backups and restores
        "CreateBackup",
        "DeleteBackup",
        "RestoreTableFromBackup",
        "RestoreTableToPointInTime",
        "UpdateContinuousBackups",
        // Global tables and replica scaling
        "CreateGlobalTable",
        "UpdateGlobalTable",
        "UpdateGlobalTableSettings",
        "UpdateTableReplicaAutoScaling",
        // Kinesis streaming destinations
        "EnableKinesisStreamingDestination",
        "DisableKinesisStreamingDestination",
        "UpdateKinesisStreamingDestination",
        // Contributor insights, exports and imports
        "UpdateContributorInsights",
        "ExportTableToPointInTime",
        "ImportTable",
    ];
    MUTATING_ACTIONS.contains(&action)
}
/// Reports whether a PartiQL statement is a write, judged only by its leading keyword.
///
/// Leading whitespace and leading `--` / `/* */` comments are skipped via
/// `strip_sql_comments`. The first keyword is then compared case-insensitively (ASCII) against
/// `INSERT`, `UPDATE`, `DELETE`, `UPSERT` and `MERGE`. Any other leading word (e.g. `SELECT`,
/// `EXISTS`), and an empty statement, is reported as a read.
fn partiql_is_write(stmt: &str) -> bool {
    // `strip_sql_comments` trims leading whitespace itself.
    let body = strip_sql_comments(stmt);
    // The keyword ends at the first character that cannot be part of an identifier, not only
    // at whitespace. Otherwise writes such as `DELETE/*x*/FROM t`, `UPDATE"t"SET ...` or
    // `INSERT--c\nINTO ...` would be misclassified as reads.
    let end = body
        .find(|c: char| !(c.is_alphanumeric() || c == '_'))
        .unwrap_or(body.len());
    let keyword = &body[..end];
    ["INSERT", "UPDATE", "DELETE", "UPSERT", "MERGE"]
        .iter()
        .any(|kw| keyword.eq_ignore_ascii_case(kw))
}
/// Skips leading whitespace and any leading SQL comments, returning the remainder of `s`.
///
/// Recognises any number of `-- ...` line comments (up to and including the next `\n`) and
/// `/* ... */` block comments, in any order, with whitespace between them. Block comments do
/// not nest. An unterminated comment consumes the rest of the input and yields `""`.
///
/// This is a loop rather than recursion. The input is request-supplied statement text, and one
/// stack frame per comment would let a statement made of very many comments exhaust the stack.
fn strip_sql_comments(s: &str) -> &str {
    let mut rest = s.trim_start();
    loop {
        if let Some(after) = rest.strip_prefix("--") {
            match after.find('\n') {
                Some(nl) => rest = after[nl + 1..].trim_start(),
                None => return "",
            }
        } else if let Some(after) = rest.strip_prefix("/*") {
            match after.find("*/") {
                Some(end) => rest = after[end + 2..].trim_start(),
                None => return "",
            }
        } else {
            return rest;
        }
    }
}
pub(crate) fn is_mutating_request(action: &str, body: &Value) -> bool {
if is_mutating_action(action) {
return true;
}
match action {
"ExecuteStatement" => body
.get("Statement")
.and_then(Value::as_str)
.map(partiql_is_write)
.unwrap_or(true),
"ExecuteTransaction" => body
.get("TransactStatements")
.and_then(Value::as_array)
.map(|arr| {
arr.iter().any(|s| {
s.get("Statement")
.and_then(Value::as_str)
.map(partiql_is_write)
.unwrap_or(true)
})
})
.unwrap_or(true),
"BatchExecuteStatement" => body
.get("Statements")
.and_then(Value::as_array)
.map(|arr| {
arr.iter().any(|s| {
s.get("Statement")
.and_then(Value::as_str)
.map(partiql_is_write)
.unwrap_or(true)
})
})
.unwrap_or(true),
_ => false,
}
}
/// Returns the string stored under `field` in the request `body`.
///
/// # Errors
///
/// Delegates to `require_str_with_code` with the error code `"ValidationException"`. It fails
/// when `field` is absent or its value is not a JSON string.
pub(crate) fn require_str<'a>(body: &'a Value, field: &str) -> Result<&'a str, AwsServiceError> {
    const DEFAULT_CODE: &str = "ValidationException";
    require_str_with_code(body, field, DEFAULT_CODE)
}
pub(crate) fn require_str_with_code<'a>(
body: &'a Value,
field: &str,
code: &str,
) -> Result<&'a str, AwsServiceError> {
body[field].as_str().ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
code,
format!("{field} is required"),
)
})
}
pub(crate) fn require_object(
body: &Value,
field: &str,
) -> Result<HashMap<String, AttributeValue>, AwsServiceError> {
let obj = body[field].as_object().ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"ValidationException",
format!("{field} is required"),
)
})?;
Ok(obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
}