Skip to main content

pmcp_code_mode/
types.rs

1//! Core types for Code Mode.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::collections::HashSet;
6
7/// Risk level assessed for a query or workflow.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "lowercase")]
10pub enum RiskLevel {
11    /// Read-only, small result set, no sensitive data
12    Low,
13    /// Read-only with sensitive data, or small mutations
14    Medium,
15    /// Large mutations, cross-table operations
16    High,
17    /// Schema changes, bulk deletes, admin operations
18    Critical,
19}
20
21impl RiskLevel {
22    /// Returns true if this risk level requires human approval.
23    pub fn requires_approval(&self, auto_approve_levels: &[RiskLevel]) -> bool {
24        !auto_approve_levels.contains(self)
25    }
26}
27
28impl std::fmt::Display for RiskLevel {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        match self {
31            RiskLevel::Low => write!(f, "LOW"),
32            RiskLevel::Medium => write!(f, "MEDIUM"),
33            RiskLevel::High => write!(f, "HIGH"),
34            RiskLevel::Critical => write!(f, "CRITICAL"),
35        }
36    }
37}
38
/// Type of code being validated/executed.
///
/// Serialized by serde with `rename_all = "lowercase"`, which lowercases the
/// whole variant name without inserting separators (e.g. `GraphQLQuery` <->
/// `"graphqlquery"`, `RestGet` <-> `"restget"`). Variant order is kept stable
/// because derived traits and some serde formats depend on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CodeType {
    /// GraphQL query (read-only)
    GraphQLQuery,
    /// GraphQL mutation (write)
    GraphQLMutation,
    /// SQL SELECT query
    SqlQuery,
    /// SQL INSERT/UPDATE/DELETE
    SqlMutation,
    /// REST GET request
    RestGet,
    /// REST POST/PUT/DELETE request
    RestMutation,
    /// Multi-tool workflow
    Workflow,
}
58
59impl CodeType {
60    /// Returns true if this code type is read-only.
61    pub fn is_read_only(&self) -> bool {
62        matches!(
63            self,
64            CodeType::GraphQLQuery | CodeType::SqlQuery | CodeType::RestGet
65        )
66    }
67}
68
/// Unified action model that maps to business permissions.
/// Works consistently across GraphQL, OpenAPI, and SQL servers.
///
/// Serialized by serde as lowercase names (`"read"`, `"write"`, `"delete"`,
/// `"admin"`); the same words (case-insensitive) are accepted as override tags
/// by `UnifiedAction::resolve`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum UnifiedAction {
    /// Retrieve data without modification (GraphQL query, HTTP GET/HEAD/OPTIONS, SQL SELECT)
    Read,
    /// Create or modify data (GraphQL create/update mutation, HTTP POST/PUT/PATCH,
    /// SQL INSERT/UPDATE/MERGE)
    Write,
    /// Remove data (GraphQL delete/remove/purge mutation, HTTP DELETE, SQL DELETE/TRUNCATE)
    Delete,
    /// Schema changes, permissions, admin operations (SQL DDL such as CREATE/ALTER/DROP,
    /// and GRANT/REVOKE)
    Admin,
}
83
84impl UnifiedAction {
85    /// Infer action from GraphQL operation type.
86    pub fn from_graphql(operation: &str, mutation_name: Option<&str>) -> Self {
87        match operation.to_lowercase().as_str() {
88            "query" => Self::Read,
89            "mutation" => {
90                if let Some(name) = mutation_name {
91                    let lower = name.to_lowercase();
92                    if lower.starts_with("delete")
93                        || lower.starts_with("remove")
94                        || lower.starts_with("purge")
95                    {
96                        return Self::Delete;
97                    }
98                }
99                Self::Write
100            }
101            _ => Self::Read,
102        }
103    }
104
105    /// Infer action from HTTP method.
106    pub fn from_http_method(method: &str) -> Self {
107        match method.to_uppercase().as_str() {
108            "GET" | "HEAD" | "OPTIONS" => Self::Read,
109            "POST" | "PUT" | "PATCH" => Self::Write,
110            "DELETE" => Self::Delete,
111            _ => Self::Read,
112        }
113    }
114
115    /// Infer action from SQL statement type.
116    pub fn from_sql(statement_type: &str) -> Self {
117        match statement_type.to_uppercase().as_str() {
118            "SELECT" => Self::Read,
119            "INSERT" | "UPDATE" | "MERGE" => Self::Write,
120            "DELETE" | "TRUNCATE" => Self::Delete,
121            "CREATE" | "ALTER" | "DROP" | "GRANT" | "REVOKE" => Self::Admin,
122            _ => Self::Read,
123        }
124    }
125
126    /// Resolve action with optional tag override.
127    pub fn resolve(
128        inferred: Self,
129        action_tags: &HashMap<String, String>,
130        operation_name: &str,
131    ) -> Self {
132        if let Some(tag) = action_tags.get(operation_name) {
133            match tag.to_lowercase().as_str() {
134                "read" => Self::Read,
135                "write" => Self::Write,
136                "delete" => Self::Delete,
137                "admin" => Self::Admin,
138                _ => inferred,
139            }
140        } else {
141            inferred
142        }
143    }
144}
145
146impl std::fmt::Display for UnifiedAction {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        match self {
149            Self::Read => write!(f, "Read"),
150            Self::Write => write!(f, "Write"),
151            Self::Delete => write!(f, "Delete"),
152            Self::Admin => write!(f, "Admin"),
153        }
154    }
155}
156
/// Result of validating code through the pipeline.
///
/// Normally built with `ValidationResult::success` (valid, carries an approval
/// token) or `ValidationResult::failure` (invalid, carries the violations that
/// caused the rejection).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResult {
    /// Whether the code is valid and can be executed
    pub is_valid: bool,

    /// Human-readable explanation of what the code does
    /// (left empty by `ValidationResult::failure`)
    pub explanation: String,

    /// Assessed risk level (`ValidationResult::failure` always reports `Critical`)
    pub risk_level: RiskLevel,

    /// Signed approval token; `Some` only for successful validations
    pub approval_token: Option<String>,

    /// Detailed metadata about the validation
    pub metadata: ValidationMetadata,

    /// Any policy violations found (empty on success)
    pub violations: Vec<PolicyViolation>,

    /// Warnings (non-blocking)
    pub warnings: Vec<String>,
}
181
182impl ValidationResult {
183    /// Create a successful validation result.
184    pub fn success(
185        explanation: String,
186        risk_level: RiskLevel,
187        approval_token: String,
188        metadata: ValidationMetadata,
189    ) -> Self {
190        Self {
191            is_valid: true,
192            explanation,
193            risk_level,
194            approval_token: Some(approval_token),
195            metadata,
196            violations: vec![],
197            warnings: vec![],
198        }
199    }
200
201    /// Create a failed validation result.
202    pub fn failure(violations: Vec<PolicyViolation>, metadata: ValidationMetadata) -> Self {
203        Self {
204            is_valid: false,
205            explanation: String::new(),
206            risk_level: RiskLevel::Critical,
207            approval_token: None,
208            metadata,
209            violations,
210            warnings: vec![],
211        }
212    }
213}
214
/// Detailed metadata about a validation.
///
/// Derives `Default`: booleans are `false`, collections are empty, optional
/// fields are `None` and `validation_time_ms` is `0`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ValidationMetadata {
    /// Whether the code is read-only
    pub is_read_only: bool,

    /// Estimated number of rows that will be returned/affected
    /// (presumably `None` when no estimate is available; confirm with producers)
    pub estimated_rows: Option<u64>,

    /// Tables/types accessed by the code
    pub accessed_types: Vec<String>,

    /// Fields accessed by the code
    pub accessed_fields: Vec<String>,

    /// Whether the query has aggregations
    pub has_aggregation: bool,

    /// Code type detected, if any
    pub code_type: Option<CodeType>,

    /// Unified action determined for this operation, if any
    pub action: Option<UnifiedAction>,

    /// Time taken to validate (milliseconds)
    pub validation_time_ms: u64,
}
242
/// Security analysis of code.
///
/// Input to `SecurityAnalysis::assess_risk`, which reads `is_read_only`,
/// `estimated_complexity`, `potential_issues` and `estimated_rows`.
/// Not serializable (no serde derives), unlike `ValidationMetadata`.
#[derive(Debug, Clone, Default)]
pub struct SecurityAnalysis {
    /// Whether the code is read-only
    pub is_read_only: bool,

    /// Tables/types accessed (deduplicated set)
    pub tables_accessed: HashSet<String>,

    /// Fields accessed (deduplicated set)
    pub fields_accessed: HashSet<String>,

    /// Whether the query has aggregations
    pub has_aggregation: bool,

    /// Whether the query has subqueries/nested operations
    pub has_subqueries: bool,

    /// Estimated complexity (`High` raises the assessed risk)
    pub estimated_complexity: Complexity,

    /// Potential security issues found
    pub potential_issues: Vec<SecurityIssue>,

    /// Estimated number of rows (for mutations, more than 100 is assessed as high risk)
    pub estimated_rows: Option<u64>,
}
270
271impl SecurityAnalysis {
272    /// Assess the risk level based on the security analysis.
273    pub fn assess_risk(&self) -> RiskLevel {
274        // Critical: Has critical security issues
275        if self.potential_issues.iter().any(|i| i.is_critical()) {
276            return RiskLevel::Critical;
277        }
278
279        // High: Mutations with high complexity or affecting many rows
280        if !self.is_read_only {
281            if let Some(rows) = self.estimated_rows {
282                if rows > 100 {
283                    return RiskLevel::High;
284                }
285            }
286            if matches!(self.estimated_complexity, Complexity::High) {
287                return RiskLevel::High;
288            }
289            return RiskLevel::Medium;
290        }
291
292        // Medium: Read-only but has sensitive issues or high complexity
293        if self.potential_issues.iter().any(|i| i.is_sensitive()) {
294            return RiskLevel::Medium;
295        }
296        if matches!(self.estimated_complexity, Complexity::High) {
297            return RiskLevel::Medium;
298        }
299
300        // Low: Simple read-only queries
301        RiskLevel::Low
302    }
303}
304
/// Estimated complexity of a query.
///
/// Variants are declared from least to most complex, so the derived
/// [`PartialOrd`]/[`Ord`] order them `Low < Medium < High`. `Low` is the
/// [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum Complexity {
    /// Simple query; the default.
    #[default]
    Low,
    /// Moderately complex query.
    Medium,
    /// Highly complex query; escalates the risk computed by `SecurityAnalysis::assess_risk`.
    High,
}
313
/// Potential security issues found during analysis.
///
/// Built with `SecurityIssue::new` and optionally `with_location`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityIssue {
    /// Issue type
    pub issue_type: SecurityIssueType,
    /// Human-readable message
    pub message: String,
    /// Location in code (if applicable)
    pub location: Option<CodeLocation>,
}
324
325impl SecurityIssue {
326    pub fn new(issue_type: SecurityIssueType, message: impl Into<String>) -> Self {
327        Self {
328            issue_type,
329            message: message.into(),
330            location: None,
331        }
332    }
333
334    pub fn with_location(mut self, location: CodeLocation) -> Self {
335        self.location = Some(location);
336        self
337    }
338
339    /// Returns true if this is a critical issue that should block execution.
340    /// Note: DynamicTableName is NOT critical for REST APIs - it's a common pattern
341    /// for discovery-then-use workflows (e.g., search for station ID, then use in path).
342    pub fn is_critical(&self) -> bool {
343        matches!(self.issue_type, SecurityIssueType::PotentialInjection)
344    }
345
346    /// Returns true if this issue involves sensitive data.
347    pub fn is_sensitive(&self) -> bool {
348        matches!(self.issue_type, SecurityIssueType::SensitiveFields)
349    }
350}
351
/// Types of security issues.
///
/// Serialized by serde in snake_case (e.g. `UnboundedQuery` <-> `"unbounded_query"`,
/// `PotentialInjection` <-> `"potential_injection"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SecurityIssueType {
    /// Query without LIMIT/pagination
    UnboundedQuery,
    /// Accessing PII or sensitive columns
    SensitiveFields,
    /// Joining across security boundaries
    CrossTypeJoin,
    /// Dynamic table/type name (potential injection)
    DynamicTableName,
    /// Potential injection vulnerability
    PotentialInjection,
    /// Deeply nested query
    DeepNesting,
    /// High complexity query
    HighComplexity,
}
371
/// Location in source code.
pub // NOTE(review): whether `line`/`column` are 0- or 1-based is not established
// in this file; confirm against the parser that produces them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeLocation {
    /// Line number within the validated code
    pub line: u32,
    /// Column number within `line`
    pub column: u32,
}
378
/// A policy violation found during validation.
///
/// Built with `PolicyViolation::new` plus the optional `with_location` /
/// `with_suggestion` builders. Distinct from the `ValidationError::PolicyViolation`
/// variant, which carries only a message string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyViolation {
    /// Name of the policy that was violated
    pub policy_name: String,
    /// Specific rule within the policy
    pub rule: String,
    /// Location in the code where the violation occurred
    pub location: Option<CodeLocation>,
    /// Human-readable message explaining the violation
    pub message: String,
    /// Suggestion for how to fix the violation
    pub suggestion: Option<String>,
}
393
394impl PolicyViolation {
395    pub fn new(
396        policy_name: impl Into<String>,
397        rule: impl Into<String>,
398        message: impl Into<String>,
399    ) -> Self {
400        Self {
401            policy_name: policy_name.into(),
402            rule: rule.into(),
403            location: None,
404            message: message.into(),
405            suggestion: None,
406        }
407    }
408
409    pub fn with_location(mut self, location: CodeLocation) -> Self {
410        self.location = Some(location);
411        self
412    }
413
414    pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
415        self.suggestion = Some(suggestion.into());
416        self
417    }
418}
419
/// Errors that can occur during validation.
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
    /// The code could not be parsed; `line`/`column` locate the failure.
    #[error("Parse error at line {line}, column {column}: {message}")]
    ParseError {
        message: String,
        line: u32,
        column: u32,
    },

    /// A schema check failed for the named `field`.
    #[error("Schema error for field '{field}': {message}")]
    SchemaError { message: String, field: String },

    /// The caller lacks `required_permission`.
    #[error("Permission denied: {message} (requires: {required_permission})")]
    PermissionError {
        message: String,
        required_permission: String,
    },

    /// A security check failed; `issue` classifies the problem.
    #[error("Security error: {message}")]
    SecurityError {
        message: String,
        issue: SecurityIssueType,
    },

    /// A policy rejected the code. Carries only a message string; this is
    /// distinct from the structured `PolicyViolation` type.
    #[error("Policy violation: {0}")]
    PolicyViolation(String),

    /// The validator is misconfigured.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Unexpected internal failure.
    #[error("Internal error: {0}")]
    InternalError(String),
}
454
/// Errors that can occur during execution.
///
/// Besides real failures, this enum also carries the `LoopContinue` /
/// `LoopBreak` control-flow signals used by the executor.
#[derive(Debug, thiserror::Error)]
pub enum ExecutionError {
    /// The approval token has expired.
    #[error("Token has expired — request a new approval token via validate_code")]
    TokenExpired,

    /// The approval token's signature failed verification; payload gives details.
    #[error("Token signature is invalid: {0}")]
    TokenInvalid(String),

    /// The code submitted for execution hashes differently from the validated code.
    #[error("Code hash mismatch — the code sent to execute_code does not match the code that was validated (expected {expected_hash}, got {actual_hash}). Ensure the code string is identical to what was sent to validate_code")]
    CodeMismatch {
        expected_hash: String,
        actual_hash: String,
    },

    /// Schema or permissions changed after the token was issued.
    #[error("Context has changed since validation (schema or permissions updated)")]
    ContextChanged,

    /// The token was issued for a different user than the one executing.
    #[error("User mismatch: token was issued for a different user")]
    UserMismatch,

    /// Error reported by the backend; payload is its message.
    #[error("Backend error: {0}")]
    BackendError(String),

    /// Execution timed out; payload is the number of seconds shown in the message.
    #[error("Execution timed out after {0} seconds")]
    Timeout(u32),

    /// Execution was attempted without a prior validation.
    #[error("Validation required before execution")]
    ValidationRequired,

    /// Error raised while running the code.
    #[error("Runtime error: {message}")]
    RuntimeError { message: String },

    /// Loop continue signal (not a real error, used for control flow)
    #[error("Loop continue")]
    LoopContinue,

    /// Loop break signal (not a real error, used for control flow)
    #[error("Loop break")]
    LoopBreak,
}