use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::HashSet;
/// Risk classification assigned to validated code.
///
/// Serialized in lowercase (`"low"`, `"medium"`, `"high"`, `"critical"`); the
/// `Display` impl renders the uppercase label instead (`LOW`, `MEDIUM`, ...).
///
/// NOTE(review): variants are declared from least to most severe, but no
/// `PartialOrd`/`Ord` is derived, so levels cannot be compared with `<`/`>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RiskLevel {
/// Lowest risk level.
Low,
/// Elevated risk level.
Medium,
/// High risk level.
High,
/// Highest risk level; also the level reported by `ValidationResult::failure`.
Critical,
}
impl RiskLevel {
    /// Returns `true` when this level is *not* listed in `auto_approve_levels`,
    /// i.e. when code at this level needs explicit approval.
    pub fn requires_approval(&self, auto_approve_levels: &[RiskLevel]) -> bool {
        let auto_approved = auto_approve_levels.iter().any(|level| level == self);
        !auto_approved
    }
}
impl std::fmt::Display for RiskLevel {
    /// Writes the uppercase label (`LOW`, `MEDIUM`, `HIGH`, `CRITICAL`).
    /// Width/fill flags from the format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            RiskLevel::Low => "LOW",
            RiskLevel::Medium => "MEDIUM",
            RiskLevel::High => "HIGH",
            RiskLevel::Critical => "CRITICAL",
        })
    }
}
/// Kind of code being validated, split by backend and by read vs. mutating intent.
///
/// NOTE: `rename_all = "lowercase"` lowercases the whole variant name without
/// inserting separators, so `GraphQLQuery` serializes as `"graphqlquery"`,
/// `SqlMutation` as `"sqlmutation"`, `RestGet` as `"restget"`, and so on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CodeType {
/// GraphQL query; read-only according to `is_read_only`.
GraphQLQuery,
/// GraphQL mutation.
GraphQLMutation,
/// SQL query; read-only according to `is_read_only`.
SqlQuery,
/// SQL mutation.
SqlMutation,
/// REST GET request; read-only according to `is_read_only`.
RestGet,
/// REST mutation.
RestMutation,
/// Workflow; not read-only according to `is_read_only`.
Workflow,
}
impl CodeType {
pub fn is_read_only(&self) -> bool {
matches!(
self,
CodeType::GraphQLQuery | CodeType::SqlQuery | CodeType::RestGet
)
}
}
/// Backend-agnostic action category.
///
/// Inferred from GraphQL operations, HTTP methods or SQL statement types by the
/// `from_*` constructors, and optionally overridden by `UnifiedAction::resolve`.
/// Serialized in lowercase (`"read"`, `"write"`, `"delete"`, `"admin"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum UnifiedAction {
/// Non-modifying access (e.g. GraphQL query, HTTP GET, SQL SELECT).
Read,
/// Creates or modifies data (e.g. HTTP POST/PUT/PATCH, SQL INSERT/UPDATE/MERGE).
Write,
/// Removes data (e.g. HTTP DELETE, SQL DELETE/TRUNCATE).
Delete,
/// Schema/permission-level operation (e.g. SQL CREATE/ALTER/DROP/GRANT/REVOKE).
Admin,
}
impl UnifiedAction {
    /// Classifies a GraphQL operation.
    ///
    /// `operation` is the operation keyword, matched case-insensitively after
    /// trimming surrounding whitespace. For a `mutation`, a `mutation_name`
    /// starting with `delete`, `remove` or `purge` (case-insensitive) yields
    /// [`UnifiedAction::Delete`]; any other mutation yields [`UnifiedAction::Write`].
    /// Every other operation (`query`, `subscription`, unknown) yields
    /// [`UnifiedAction::Read`].
    pub fn from_graphql(operation: &str, mutation_name: Option<&str>) -> Self {
        // Mutation-name prefixes that mark a mutation as destructive.
        const DESTRUCTIVE_PREFIXES: [&str; 3] = ["delete", "remove", "purge"];

        match operation.trim().to_lowercase().as_str() {
            "mutation" => match mutation_name.map(|name| name.trim().to_lowercase()) {
                Some(name) if DESTRUCTIVE_PREFIXES.iter().any(|&p| name.starts_with(p)) => {
                    Self::Delete
                },
                _ => Self::Write,
            },
            // NOTE(review): unrecognized operations fall back to Read (fail-open).
            _ => Self::Read,
        }
    }

    /// Classifies an HTTP method (case-insensitive, surrounding whitespace ignored).
    ///
    /// `GET`/`HEAD`/`OPTIONS` → `Read`, `POST`/`PUT`/`PATCH` → `Write`,
    /// `DELETE` → `Delete`. Any other method falls back to `Read`.
    pub fn from_http_method(method: &str) -> Self {
        match method.trim().to_uppercase().as_str() {
            "GET" | "HEAD" | "OPTIONS" => Self::Read,
            "POST" | "PUT" | "PATCH" => Self::Write,
            "DELETE" => Self::Delete,
            // NOTE(review): unrecognized methods (CONNECT, TRACE, custom verbs)
            // are treated as reads, i.e. this fails open.
            _ => Self::Read,
        }
    }

    /// Classifies a SQL statement by its leading keyword (case-insensitive,
    /// surrounding whitespace ignored).
    ///
    /// `SELECT` → `Read`; `INSERT`/`UPDATE`/`MERGE`/`REPLACE`/`UPSERT` → `Write`;
    /// `DELETE`/`TRUNCATE` → `Delete`; `CREATE`/`ALTER`/`DROP`/`GRANT`/`REVOKE`
    /// → `Admin`. Any other keyword falls back to `Read`.
    pub fn from_sql(statement_type: &str) -> Self {
        match statement_type.trim().to_uppercase().as_str() {
            "SELECT" => Self::Read,
            // REPLACE (MySQL/SQLite) and UPSERT (e.g. CockroachDB) insert or
            // overwrite rows, so they are writes just like INSERT/MERGE.
            "INSERT" | "UPDATE" | "MERGE" | "REPLACE" | "UPSERT" => Self::Write,
            "DELETE" | "TRUNCATE" => Self::Delete,
            "CREATE" | "ALTER" | "DROP" | "GRANT" | "REVOKE" => Self::Admin,
            // NOTE(review): unrecognized statement types fall back to Read (fail-open).
            _ => Self::Read,
        }
    }

    /// Applies an explicit per-operation override on top of an inferred action.
    ///
    /// If `action_tags` contains `operation_name` and its tag is one of `read`,
    /// `write`, `delete`, `admin` (case-insensitive, surrounding whitespace
    /// ignored), that action is returned. A missing or unrecognized tag leaves
    /// `inferred` unchanged.
    pub fn resolve(
        inferred: Self,
        action_tags: &HashMap<String, String>,
        operation_name: &str,
    ) -> Self {
        match action_tags
            .get(operation_name)
            .map(|tag| tag.trim().to_lowercase())
            .as_deref()
        {
            Some("read") => Self::Read,
            Some("write") => Self::Write,
            Some("delete") => Self::Delete,
            Some("admin") => Self::Admin,
            _ => inferred,
        }
    }
}
impl std::fmt::Display for UnifiedAction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Read => write!(f, "Read"),
Self::Write => write!(f, "Write"),
Self::Delete => write!(f, "Delete"),
Self::Admin => write!(f, "Admin"),
}
}
}
/// Outcome of validating a piece of code.
///
/// Normally built with `ValidationResult::success` (valid, carries an approval
/// token) or `ValidationResult::failure` (invalid, `Critical` risk, no token).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResult {
/// Whether the code passed validation.
pub is_valid: bool,
/// Human-readable explanation; empty for results built with `failure`.
pub explanation: String,
/// Assessed risk level (`failure` always reports `Critical`).
pub risk_level: RiskLevel,
/// Approval token: `Some` for `success` results, `None` for `failure` results.
pub approval_token: Option<String>,
/// Facts gathered about the code during validation.
pub metadata: ValidationMetadata,
/// Policy violations; empty for `success` results.
pub violations: Vec<PolicyViolation>,
/// Non-fatal warnings; both constructors start with this empty.
pub warnings: Vec<String>,
}
impl ValidationResult {
    /// Builds a passing result that carries `approval_token`.
    ///
    /// The result starts with no violations and no warnings.
    pub fn success(
        explanation: String,
        risk_level: RiskLevel,
        approval_token: String,
        metadata: ValidationMetadata,
    ) -> Self {
        let token = Some(approval_token);
        Self {
            is_valid: true,
            explanation,
            risk_level,
            approval_token: token,
            metadata,
            violations: Vec::new(),
            warnings: Vec::new(),
        }
    }

    /// Builds a failing result from the given policy violations.
    ///
    /// The result is always `RiskLevel::Critical`, carries no approval token,
    /// and has an empty explanation and no warnings.
    pub fn failure(violations: Vec<PolicyViolation>, metadata: ValidationMetadata) -> Self {
        Self {
            violations,
            metadata,
            approval_token: None,
            risk_level: RiskLevel::Critical,
            explanation: String::new(),
            is_valid: false,
            warnings: Vec::new(),
        }
    }
}
/// Facts about the validated code, carried inside a `ValidationResult`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ValidationMetadata {
/// `true` when the code was judged not to modify anything.
pub is_read_only: bool,
/// Estimated number of affected rows, when an estimate is available.
pub estimated_rows: Option<u64>,
/// Names of the types accessed by the code.
/// NOTE(review): presumably tables for SQL code — confirm against the producer.
pub accessed_types: Vec<String>,
/// Names of the fields accessed by the code.
pub accessed_fields: Vec<String>,
/// Whether the code uses aggregation.
pub has_aggregation: bool,
/// Kind of code, if it was determined.
pub code_type: Option<CodeType>,
/// Unified action category, if it was determined.
pub action: Option<UnifiedAction>,
/// Time spent validating, in milliseconds (per the field name).
pub validation_time_ms: u64,
}
/// Result of analysing code for security-relevant properties.
///
/// `SecurityAnalysis::assess_risk` condenses it into a `RiskLevel`.
#[derive(Debug, Clone, Default)]
pub struct SecurityAnalysis {
/// `true` when the code does not write. Defaults to `false`, which
/// `assess_risk` treats as a write.
pub is_read_only: bool,
/// Tables accessed by the code.
pub tables_accessed: HashSet<String>,
/// Fields accessed by the code.
pub fields_accessed: HashSet<String>,
/// Whether aggregation is used.
pub has_aggregation: bool,
/// Whether subqueries are used.
pub has_subqueries: bool,
/// Coarse complexity estimate; only `High` raises the assessed risk.
pub estimated_complexity: Complexity,
/// Issues found; critical or sensitive ones raise the assessed risk.
pub potential_issues: Vec<SecurityIssue>,
/// Estimated affected rows; for writes, more than 100 yields `High` risk.
pub estimated_rows: Option<u64>,
}
impl SecurityAnalysis {
pub fn assess_risk(&self) -> RiskLevel {
if self.potential_issues.iter().any(|i| i.is_critical()) {
return RiskLevel::Critical;
}
if !self.is_read_only {
if let Some(rows) = self.estimated_rows {
if rows > 100 {
return RiskLevel::High;
}
}
if matches!(self.estimated_complexity, Complexity::High) {
return RiskLevel::High;
}
return RiskLevel::Medium;
}
if self.potential_issues.iter().any(|i| i.is_sensitive()) {
return RiskLevel::Medium;
}
if matches!(self.estimated_complexity, Complexity::High) {
return RiskLevel::Medium;
}
RiskLevel::Low
}
}
/// Coarse complexity estimate used by `SecurityAnalysis`; defaults to `Low`.
///
/// Only `High` currently influences `SecurityAnalysis::assess_risk`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Complexity {
#[default]
Low,
Medium,
High,
}
/// A potential security problem found while analysing code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityIssue {
/// Category of the issue.
pub issue_type: SecurityIssueType,
/// Human-readable description.
pub message: String,
/// Where in the code the issue was found, if known.
pub location: Option<CodeLocation>,
}
impl SecurityIssue {
    /// Creates an issue of the given type with no source location attached.
    pub fn new(issue_type: SecurityIssueType, message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            issue_type,
            message,
            location: None,
        }
    }

    /// Builder-style setter that attaches `location`, replacing any previous one.
    pub fn with_location(self, location: CodeLocation) -> Self {
        Self {
            location: Some(location),
            ..self
        }
    }

    /// Returns `true` only for `SecurityIssueType::PotentialInjection`.
    pub fn is_critical(&self) -> bool {
        self.issue_type == SecurityIssueType::PotentialInjection
    }

    /// Returns `true` only for `SecurityIssueType::SensitiveFields`.
    pub fn is_sensitive(&self) -> bool {
        self.issue_type == SecurityIssueType::SensitiveFields
    }
}
/// Category of a `SecurityIssue`, serialized in snake_case (e.g. `"unbounded_query"`,
/// `"potential_injection"`).
///
/// `PotentialInjection` is the only category `SecurityIssue::is_critical` reports,
/// and `SensitiveFields` is the only one `SecurityIssue::is_sensitive` reports.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SecurityIssueType {
UnboundedQuery,
SensitiveFields,
CrossTypeJoin,
DynamicTableName,
PotentialInjection,
DeepNesting,
HighComplexity,
}
/// Position in the submitted code, attached to security issues and policy violations.
///
/// NOTE(review): whether `line`/`column` are 0- or 1-based is not established
/// here — confirm against the code that fills them in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeLocation {
/// Line number.
pub line: u32,
/// Column number.
pub column: u32,
}
/// A rule violation reported by a policy during validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyViolation {
/// Name of the policy that was violated.
pub policy_name: String,
/// The specific rule within the policy.
pub rule: String,
/// Where in the code the violation occurred, if known.
pub location: Option<CodeLocation>,
/// Human-readable description of the violation.
pub message: String,
/// Optional hint on how to fix the code.
pub suggestion: Option<String>,
}
impl PolicyViolation {
    /// Creates a violation with no location and no suggestion.
    pub fn new(
        policy_name: impl Into<String>,
        rule: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        let policy_name = policy_name.into();
        let rule = rule.into();
        let message = message.into();
        Self {
            policy_name,
            rule,
            message,
            location: None,
            suggestion: None,
        }
    }

    /// Builder-style setter that attaches `location`, replacing any previous one.
    pub fn with_location(self, location: CodeLocation) -> Self {
        Self {
            location: Some(location),
            ..self
        }
    }

    /// Builder-style setter that attaches a fix suggestion, replacing any previous one.
    pub fn with_suggestion(self, suggestion: impl Into<String>) -> Self {
        let hint = suggestion.into();
        Self {
            suggestion: Some(hint),
            ..self
        }
    }
}
/// Errors that can occur while validating code.
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
/// The code could not be parsed; `line`/`column` locate the failure.
#[error("Parse error at line {line}, column {column}: {message}")]
ParseError {
message: String,
line: u32,
column: u32,
},
/// A schema-related problem concerning `field`.
#[error("Schema error for field '{field}': {message}")]
SchemaError { message: String, field: String },
/// The caller lacks `required_permission`.
#[error("Permission denied: {message} (requires: {required_permission})")]
PermissionError {
message: String,
required_permission: String,
},
/// A security issue was detected.
/// NOTE: `issue` is carried for callers but is not included in the Display message.
#[error("Security error: {message}")]
SecurityError {
message: String,
issue: SecurityIssueType,
},
/// A policy violation summarized as a plain string (distinct from the
/// structured `PolicyViolation` type).
#[error("Policy violation: {0}")]
PolicyViolation(String),
/// Invalid or missing configuration.
#[error("Configuration error: {0}")]
ConfigError(String),
/// Unexpected internal failure.
#[error("Internal error: {0}")]
InternalError(String),
}
/// Errors that can occur when executing previously validated code.
#[derive(Debug, thiserror::Error)]
pub enum ExecutionError {
/// The approval token is past its validity window.
#[error("Token has expired — request a new approval token via validate_code")]
TokenExpired,
/// The approval token's signature failed verification; the payload carries details.
#[error("Token signature is invalid: {0}")]
TokenInvalid(String),
/// The hash of the code being executed differs from the hash recorded at validation.
#[error("Code hash mismatch — the code sent to execute_code does not match the code that was validated (expected {expected_hash}, got {actual_hash}). Ensure the code string is identical to what was sent to validate_code")]
CodeMismatch {
expected_hash: String,
actual_hash: String,
},
/// Schema or permissions changed after the token was issued.
#[error("Context has changed since validation (schema or permissions updated)")]
ContextChanged,
/// The token was issued to a different user than the one executing.
#[error("User mismatch: token was issued for a different user")]
UserMismatch,
/// The backend reported an error; the payload carries its message.
#[error("Backend error: {0}")]
BackendError(String),
/// Execution exceeded its time limit; the payload is the limit in seconds.
#[error("Execution timed out after {0} seconds")]
Timeout(u32),
/// Execution was attempted without a prior validation.
#[error("Validation required before execution")]
ValidationRequired,
/// The code failed while running.
#[error("Runtime error: {message}")]
RuntimeError { message: String },
}
/// Language of a submitted code block.
///
/// Serialized in lowercase, which matches `CodeLanguage::as_str`
/// (`"graphql"`, `"javascript"`, `"sql"`, `"mcp"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CodeLanguage {
GraphQL,
JavaScript,
Sql,
Mcp,
}
impl CodeLanguage {
    /// Parses a language attribute such as `"graphql"`, `"sql"` or `"js"`.
    ///
    /// Matching ignores ASCII case and surrounding whitespace, so `"GraphQL"`,
    /// `" sql "` and `"JS"` are all accepted (consistent with the
    /// case-insensitive parsers on `UnifiedAction`). `"js"` is an alias for
    /// [`CodeLanguage::JavaScript`]. Returns `None` for unknown languages.
    pub fn from_attr(s: &str) -> Option<Self> {
        match s.trim().to_ascii_lowercase().as_str() {
            "graphql" => Some(Self::GraphQL),
            "javascript" | "js" => Some(Self::JavaScript),
            "sql" => Some(Self::Sql),
            "mcp" => Some(Self::Mcp),
            _ => None,
        }
    }

    /// Canonical lowercase name of the language; `from_attr` accepts it back.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::GraphQL => "graphql",
            Self::JavaScript => "javascript",
            Self::Sql => "sql",
            Self::Mcp => "mcp",
        }
    }

    /// Name of the feature associated with this language, or `None` for GraphQL.
    ///
    /// NOTE(review): these look like Cargo feature names gating each language's
    /// support — confirm against the crate's `Cargo.toml`.
    pub fn required_feature(&self) -> Option<&'static str> {
        match self {
            Self::GraphQL => None,
            Self::JavaScript => Some("openapi-code-mode"),
            Self::Sql => Some("sql-code-mode"),
            Self::Mcp => Some("mcp-code-mode"),
        }
    }
}
impl std::fmt::Display for CodeLanguage {
    /// Writes the canonical lowercase name returned by `as_str`.
    /// Width/fill flags from the outer format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        write!(f, "{}", name)
    }
}
/// Errors related to configuring the HMAC approval-token secret.
#[derive(Debug, thiserror::Error)]
pub enum TokenError {
/// The provided HMAC secret is shorter than the required minimum.
#[error("HMAC token secret must be at least {minimum} bytes, got {actual}")]
SecretTooShort {
/// Minimum accepted secret length, in bytes.
minimum: usize,
/// Length of the provided secret, in bytes.
actual: usize,
},
}