use fraiseql_error::{FraiseQLError, Result};
/// Result of classifying a SQL string.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlClassification {
    /// The SQL was accepted as read-only.
    ReadOnly,
    /// The SQL was rejected; the payload explains why.
    Rejected(RejectionReason),
}

/// Why a SQL string was rejected by the classifier.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RejectionReason {
    /// A data-modifying statement; the payload names its kind (e.g. `INSERT`).
    WriteStatement(String),
    /// A schema-changing statement; the payload names its kind (e.g. `CREATE TABLE`).
    DdlStatement(String),
    /// A common table expression that contains a data-modifying statement.
    WritableCte,
    /// A statement that changes session state; the classifier uses this for `SET`.
    PrivilegeEscalation,
    /// A procedural code block.
    ProceduralBlock,
    /// A `CALL` of a procedure.
    ProcedureCall,
    /// A `COPY` statement.
    CopyStatement,
    /// `EXPLAIN ANALYZE`, which executes the explained statement.
    ExplainAnalyze,
    /// Any other disallowed statement; the payload is a detail message or a
    /// truncated debug rendering of the statement.
    Unknown(String),
}

impl std::fmt::Display for RejectionReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Payload-carrying variants render as "<label>: <payload>"; the others
        // render as a fixed sentence.
        let (label, detail) = match self {
            Self::WriteStatement(stmt) => ("write statement not allowed", Some(stmt)),
            Self::DdlStatement(stmt) => ("DDL statement not allowed", Some(stmt)),
            Self::Unknown(stmt) => ("unknown or disallowed statement", Some(stmt)),
            Self::WritableCte => ("CTE with writable statement not allowed", None),
            Self::PrivilegeEscalation => ("privilege escalation not allowed", None),
            Self::ProceduralBlock => ("procedural block not allowed", None),
            Self::ProcedureCall => ("procedure call not allowed", None),
            Self::CopyStatement => ("COPY statement not allowed", None),
            Self::ExplainAnalyze => {
                ("EXPLAIN ANALYZE not allowed (executes the statement)", None)
            },
        };
        match detail {
            Some(detail) => write!(f, "{}: {}", label, detail),
            None => f.write_str(label),
        }
    }
}
/// Parses `sql` with the PostgreSQL dialect and classifies every statement in it.
///
/// Statements are checked in order and the first rejection is returned. If all
/// statements are accepted, or parsing yields no statements at all, the result
/// is [`SqlClassification::ReadOnly`].
///
/// # Errors
///
/// Returns `FraiseQLError::Validation` when `sql` cannot be parsed, and
/// propagates any error produced while classifying an individual statement.
pub fn classify_sql(sql: &str) -> Result<SqlClassification> {
    use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    let parsed = Parser::parse_sql(&PostgreSqlDialect {}, sql).map_err(|err| {
        FraiseQLError::Validation {
            message: format!("invalid SQL: {}", err),
            path: None,
        }
    })?;

    for statement in &parsed {
        // One rejected statement rejects the whole input.
        if let SqlClassification::Rejected(reason) = classify_statement(statement)? {
            return Ok(SqlClassification::Rejected(reason));
        }
    }

    Ok(SqlClassification::ReadOnly)
}
/// Classifies a single parsed statement.
///
/// Queries are accepted only after [`classify_query`] finds no data-modifying
/// part in their CTEs or body. `EXPLAIN` is accepted only without `ANALYZE`
/// and only when the explained statement is itself read-only. Every other
/// statement kind is rejected, either with a specific [`RejectionReason`] or,
/// via the catch-all arm, with [`RejectionReason::Unknown`] carrying the first
/// 50 characters of the statement's `Debug` rendering (default-deny).
///
/// Limitations: expressions inside a query (projections, predicates, function
/// calls, subqueries in expressions) and row-locking clauses are not inspected.
/// A `SELECT` that calls a side-effecting function is therefore still
/// classified as read-only.
///
/// # Errors
///
/// Currently never returns `Err`.
fn classify_statement(stmt: &sqlparser::ast::Statement) -> Result<SqlClassification> {
    use sqlparser::ast::Statement;

    let classification = match stmt {
        // A query can still carry writes: `WITH … DELETE …`, data-modifying
        // CTEs, or `SELECT … INTO`, so its structure must be inspected.
        Statement::Query(query) => classify_query(query),
        Statement::Explain { analyze, statement, .. } => {
            if *analyze {
                SqlClassification::Rejected(RejectionReason::ExplainAnalyze)
            } else {
                // NOTE(review): depending on the sqlparser version, PostgreSQL's
                // parenthesised `EXPLAIN (ANALYZE) …` may be recorded in an options
                // list rather than in `analyze`. Classifying the wrapped statement
                // guarantees that, even then, only a read-only statement can run.
                classify_statement(statement)?
            }
        },
        Statement::Insert { .. } => reject_write("INSERT"),
        Statement::Update { .. } => reject_write("UPDATE"),
        Statement::Delete { .. } => reject_write("DELETE"),
        Statement::CreateTable { .. } => reject_ddl("CREATE TABLE"),
        Statement::CreateView { .. } => reject_ddl("CREATE VIEW"),
        Statement::CreateIndex { .. } => reject_ddl("CREATE INDEX"),
        Statement::CreateSchema { .. } => reject_ddl("CREATE SCHEMA"),
        Statement::CreateRole { .. } => reject_ddl("CREATE ROLE"),
        Statement::CreateExtension { .. } => reject_ddl("CREATE EXTENSION"),
        Statement::CreateSecret { .. } => reject_ddl("CREATE SECRET"),
        Statement::CreateVirtualTable { .. } => reject_ddl("CREATE VIRTUAL TABLE"),
        Statement::Drop { .. } => reject_ddl("DROP"),
        Statement::DropFunction { .. } => reject_ddl("DROP FUNCTION"),
        Statement::DropSecret { .. } => reject_ddl("DROP SECRET"),
        Statement::AlterTable { .. } => reject_ddl("ALTER TABLE"),
        Statement::AlterIndex { .. } => reject_ddl("ALTER INDEX"),
        Statement::AlterView { .. } => reject_ddl("ALTER VIEW"),
        Statement::AlterRole { .. } => reject_ddl("ALTER ROLE"),
        Statement::Truncate { .. } => reject_ddl("TRUNCATE"),
        Statement::Set(_) => SqlClassification::Rejected(RejectionReason::PrivilegeEscalation),
        Statement::Call(_) => SqlClassification::Rejected(RejectionReason::ProcedureCall),
        Statement::Copy { .. } | Statement::CopyIntoSnowflake { .. } => {
            SqlClassification::Rejected(RejectionReason::CopyStatement)
        },
        Statement::Analyze { .. } => reject_unknown("ANALYZE statement not allowed"),
        Statement::Install { .. } => reject_unknown("INSTALL not allowed"),
        Statement::Load { .. } => reject_unknown("LOAD not allowed"),
        Statement::Directory { .. } => reject_unknown("DIRECTORY not allowed"),
        Statement::AttachDatabase { .. } => reject_unknown("ATTACH DATABASE not allowed"),
        Statement::AttachDuckDBDatabase { .. } => {
            reject_unknown("ATTACH DUCKDB DATABASE not allowed")
        },
        Statement::DetachDuckDBDatabase { .. } => {
            reject_unknown("DETACH DUCKDB DATABASE not allowed")
        },
        Statement::Declare { .. } => reject_unknown("DECLARE not allowed"),
        Statement::Close { .. } => reject_unknown("CLOSE not allowed"),
        Statement::Fetch { .. } => reject_unknown("FETCH not allowed"),
        Statement::Flush { .. } => reject_unknown("FLUSH not allowed"),
        Statement::Discard { .. } => reject_unknown("DISCARD not allowed"),
        Statement::StartTransaction { .. } => reject_unknown("START TRANSACTION not allowed"),
        Statement::Msck { .. } => reject_unknown("MSCK not allowed"),
        // Default-deny: anything not explicitly recognised is rejected.
        other => reject_unknown(debug_prefix(other)),
    };
    Ok(classification)
}

/// Classifies a query by inspecting its `WITH` clause and its body.
///
/// Any CTE whose own query is not read-only (checked recursively, so nested
/// CTEs are covered) yields [`RejectionReason::WritableCte`]; otherwise the
/// body decides.
fn classify_query(query: &sqlparser::ast::Query) -> SqlClassification {
    if let Some(with) = &query.with {
        let has_writable_cte = with
            .cte_tables
            .iter()
            .any(|cte| matches!(classify_query(&cte.query), SqlClassification::Rejected(_)));
        if has_writable_cte {
            return SqlClassification::Rejected(RejectionReason::WritableCte);
        }
    }
    classify_set_expr(&query.body)
}

/// Classifies a query body.
///
/// Only known read-only shapes are accepted. Data-modifying bodies such as
/// `WITH … INSERT/UPDATE/DELETE …`, and any variant this code does not
/// recognise, are rejected.
fn classify_set_expr(body: &sqlparser::ast::SetExpr) -> SqlClassification {
    use sqlparser::ast::SetExpr;

    match body {
        // `SELECT … INTO new_table` creates a table rather than just reading.
        SetExpr::Select(select) if select.into.is_some() => reject_ddl("SELECT INTO"),
        SetExpr::Select(_) | SetExpr::Values(_) | SetExpr::Table(_) => {
            SqlClassification::ReadOnly
        },
        SetExpr::Query(query) => classify_query(query),
        SetExpr::SetOperation { left, right, .. } => match classify_set_expr(left) {
            SqlClassification::ReadOnly => classify_set_expr(right),
            rejected => rejected,
        },
        SetExpr::Insert(_) => reject_write("INSERT"),
        SetExpr::Update(_) => reject_write("UPDATE"),
        // Covers `DELETE` bodies and any variant added by later sqlparser versions.
        other => reject_unknown(debug_prefix(other)),
    }
}

/// Rejection for a data-modifying statement of the given kind.
fn reject_write(kind: &str) -> SqlClassification {
    SqlClassification::Rejected(RejectionReason::WriteStatement(kind.to_string()))
}

/// Rejection for a schema-changing statement of the given kind.
fn reject_ddl(kind: &str) -> SqlClassification {
    SqlClassification::Rejected(RejectionReason::DdlStatement(kind.to_string()))
}

/// Rejection with a free-form detail message.
fn reject_unknown(detail: impl Into<String>) -> SqlClassification {
    SqlClassification::Rejected(RejectionReason::Unknown(detail.into()))
}

/// First 50 characters of a value's `Debug` rendering, used as a short
/// description of unrecognised AST nodes.
fn debug_prefix(value: &dyn std::fmt::Debug) -> String {
    format!("{:?}", value).chars().take(50).collect()
}
// Unit tests are loaded from a separate `tests` module file. `unwrap()` is
// allowed there because a panic in a test is simply a test failure.
#[cfg(test)]
#[allow(clippy::unwrap_used)] mod tests;