Skip to main content

pmcp_code_mode/policy/
cedar.rs

1//! Local Cedar policy evaluator.
2//!
3//! Provides in-process Cedar policy evaluation using the `cedar-policy` crate.
4//! This enables local agent projects to get real Cedar policy enforcement
5//! without an AWS account.
6
7use super::types::{AuthorizationDecision, OperationEntity, ServerConfigEntity};
8use super::PolicyEvaluationError;
9use cedar_policy::{
10    Authorizer, Context, Entities, Entity, EntityId, EntityTypeName, EntityUid, PolicySet, Request,
11    Schema,
12};
13use std::collections::{HashMap, HashSet};
14use std::str::FromStr;
15
16/// Configuration for the local Cedar policy evaluator.
17#[derive(Debug, Clone)]
18pub struct CedarPolicyConfig {
19    /// Cedar schema in JSON format
20    pub schema_json: serde_json::Value,
21    /// Cedar policies (id, description, policy_text)
22    pub policies: Vec<(String, String, String)>,
23}
24
/// Local Cedar policy evaluator.
///
/// Uses the `cedar-policy` crate for in-process policy evaluation.
///
/// The policy set and schema are parsed once, at construction. Each evaluation
/// then builds fresh entities plus a request and checks them against this
/// fixed state.
pub struct CedarPolicyEvaluator {
    /// Cedar authorizer used to decide each request.
    authorizer: Authorizer,
    /// All policies parsed from the configuration.
    policy_set: PolicySet,
    /// Schema used to validate the entities and the request at evaluation time.
    schema: Schema,
}
33
34impl CedarPolicyEvaluator {
35    /// Create a new evaluator from config.
36    pub fn new(config: CedarPolicyConfig) -> Result<Self, PolicyEvaluationError> {
37        let schema_json = serde_json::to_string(&config.schema_json).map_err(|e| {
38            PolicyEvaluationError::ConfigError(format!("Invalid schema JSON: {}", e))
39        })?;
40
41        let schema = Schema::from_json_str(&schema_json).map_err(|e| {
42            PolicyEvaluationError::ConfigError(format!("Invalid Cedar schema: {}", e))
43        })?;
44
45        let mut policy_set = PolicySet::new();
46        for (id, _description, policy_text) in &config.policies {
47            let policy = cedar_policy::Policy::parse(
48                Some(cedar_policy::PolicyId::from_str(id).unwrap()),
49                policy_text,
50            )
51            .map_err(|e| {
52                PolicyEvaluationError::ConfigError(format!("Invalid policy '{}': {}", id, e))
53            })?;
54            policy_set.add(policy).map_err(|e| {
55                PolicyEvaluationError::ConfigError(format!("Duplicate policy '{}': {}", id, e))
56            })?;
57        }
58
59        Ok(Self {
60            authorizer: Authorizer::new(),
61            policy_set,
62            schema,
63        })
64    }
65
66    /// Create a default evaluator for GraphQL Code Mode using built-in schemas and policies.
67    pub fn graphql_default() -> Result<Self, PolicyEvaluationError> {
68        let schema_json = super::types::get_code_mode_schema_json();
69        let baseline = super::types::get_baseline_policies();
70
71        let policies = baseline
72            .into_iter()
73            .map(|(id, desc, text)| (id.to_string(), desc.to_string(), text.to_string()))
74            .collect();
75
76        Self::new(CedarPolicyConfig {
77            schema_json,
78            policies,
79        })
80    }
81
82    fn build_operation_entity(&self, operation: &OperationEntity) -> Entity {
83        let uid = EntityUid::from_type_name_and_id(
84            EntityTypeName::from_str("CodeMode::Operation").expect("valid type name"),
85            EntityId::from_str(&operation.id).expect("valid entity id"),
86        );
87
88        let mut attrs: HashMap<String, cedar_policy::RestrictedExpression> =
89            HashMap::with_capacity(11);
90
91        attrs.insert(
92            "operationType".to_string(),
93            cedar_policy::RestrictedExpression::new_string(operation.operation_type.clone()),
94        );
95        attrs.insert(
96            "operationName".to_string(),
97            cedar_policy::RestrictedExpression::new_string(operation.operation_name.clone()),
98        );
99        attrs.insert(
100            "depth".to_string(),
101            cedar_policy::RestrictedExpression::new_long(operation.depth as i64),
102        );
103        attrs.insert(
104            "fieldCount".to_string(),
105            cedar_policy::RestrictedExpression::new_long(operation.field_count as i64),
106        );
107        attrs.insert(
108            "estimatedCost".to_string(),
109            cedar_policy::RestrictedExpression::new_long(operation.estimated_cost as i64),
110        );
111        attrs.insert(
112            "hasIntrospection".to_string(),
113            cedar_policy::RestrictedExpression::new_bool(operation.has_introspection),
114        );
115        attrs.insert(
116            "accessesSensitiveData".to_string(),
117            cedar_policy::RestrictedExpression::new_bool(operation.accesses_sensitive_data),
118        );
119
120        // Set attributes
121        attrs.insert(
122            "rootFields".to_string(),
123            Self::string_set_expr(&operation.root_fields),
124        );
125        attrs.insert(
126            "accessedTypes".to_string(),
127            Self::string_set_expr(&operation.accessed_types),
128        );
129        attrs.insert(
130            "accessedFields".to_string(),
131            Self::string_set_expr(&operation.accessed_fields),
132        );
133        attrs.insert(
134            "sensitiveCategories".to_string(),
135            Self::string_set_expr(&operation.sensitive_categories),
136        );
137
138        Entity::new(uid, attrs, HashSet::new()).expect("valid entity")
139    }
140
141    fn build_server_entity(&self, config: &ServerConfigEntity) -> Entity {
142        let uid = EntityUid::from_type_name_and_id(
143            EntityTypeName::from_str("CodeMode::Server").expect("valid type name"),
144            EntityId::from_str(&config.server_id).expect("valid entity id"),
145        );
146
147        let mut attrs: HashMap<String, cedar_policy::RestrictedExpression> =
148            HashMap::with_capacity(10);
149
150        attrs.insert(
151            "serverId".to_string(),
152            cedar_policy::RestrictedExpression::new_string(config.server_id.clone()),
153        );
154        attrs.insert(
155            "serverType".to_string(),
156            cedar_policy::RestrictedExpression::new_string(config.server_type.clone()),
157        );
158        attrs.insert(
159            "allowWrite".to_string(),
160            cedar_policy::RestrictedExpression::new_bool(config.allow_write),
161        );
162        attrs.insert(
163            "allowDelete".to_string(),
164            cedar_policy::RestrictedExpression::new_bool(config.allow_delete),
165        );
166        attrs.insert(
167            "allowAdmin".to_string(),
168            cedar_policy::RestrictedExpression::new_bool(config.allow_admin),
169        );
170        attrs.insert(
171            "maxDepth".to_string(),
172            cedar_policy::RestrictedExpression::new_long(config.max_depth as i64),
173        );
174        attrs.insert(
175            "maxCost".to_string(),
176            cedar_policy::RestrictedExpression::new_long(config.max_cost as i64),
177        );
178        attrs.insert(
179            "maxApiCalls".to_string(),
180            cedar_policy::RestrictedExpression::new_long(config.max_api_calls as i64),
181        );
182
183        attrs.insert(
184            "allowedOperations".to_string(),
185            Self::string_set_expr(&config.allowed_operations),
186        );
187        attrs.insert(
188            "blockedOperations".to_string(),
189            Self::string_set_expr(&config.blocked_operations),
190        );
191        attrs.insert(
192            "blockedFields".to_string(),
193            Self::string_set_expr(&config.blocked_fields),
194        );
195
196        Entity::new(uid, attrs, HashSet::new()).expect("valid entity")
197    }
198
199    fn string_set_expr(set: &HashSet<String>) -> cedar_policy::RestrictedExpression {
200        cedar_policy::RestrictedExpression::new_set(
201            set.iter()
202                .map(|s| cedar_policy::RestrictedExpression::new_string(s.clone())),
203        )
204    }
205}
206
207#[async_trait::async_trait]
208impl super::PolicyEvaluator for CedarPolicyEvaluator {
209    async fn evaluate_operation(
210        &self,
211        operation: &OperationEntity,
212        server_config: &ServerConfigEntity,
213    ) -> Result<AuthorizationDecision, PolicyEvaluationError> {
214        let principal = EntityUid::from_type_name_and_id(
215            EntityTypeName::from_str("CodeMode::Operation").expect("valid type name"),
216            EntityId::from_str(&operation.id).expect("valid entity id"),
217        );
218
219        // Determine action
220        let action_id = if operation.has_introspection {
221            "Admin"
222        } else {
223            match operation.operation_type.as_str() {
224                "mutation" => {
225                    let op_name = operation.operation_name.to_lowercase();
226                    if op_name.starts_with("delete")
227                        || op_name.starts_with("remove")
228                        || op_name.starts_with("purge")
229                    {
230                        "Delete"
231                    } else {
232                        "Write"
233                    }
234                },
235                "subscription" => "Read",
236                _ => "Read",
237            }
238        };
239
240        let action = EntityUid::from_type_name_and_id(
241            EntityTypeName::from_str("CodeMode::Action").expect("valid type name"),
242            EntityId::from_str(action_id).expect("valid entity id"),
243        );
244
245        let resource = EntityUid::from_type_name_and_id(
246            EntityTypeName::from_str("CodeMode::Server").expect("valid type name"),
247            EntityId::from_str(&server_config.server_id).expect("valid entity id"),
248        );
249
250        let op_entity = self.build_operation_entity(operation);
251        let server_entity = self.build_server_entity(server_config);
252
253        let entities = Entities::from_entities([op_entity, server_entity], Some(&self.schema))
254            .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Entity error: {}", e)))?;
255
256        let context = Context::from_pairs([
257            (
258                "serverId".to_string(),
259                cedar_policy::RestrictedExpression::new_string(server_config.server_id.clone()),
260            ),
261            (
262                "serverType".to_string(),
263                cedar_policy::RestrictedExpression::new_string(server_config.server_type.clone()),
264            ),
265        ])
266        .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Context error: {}", e)))?;
267
268        let request = Request::new(principal, action, resource, context, Some(&self.schema))
269            .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Request error: {}", e)))?;
270
271        let response = self
272            .authorizer
273            .is_authorized(&request, &self.policy_set, &entities);
274
275        let allowed = matches!(response.decision(), cedar_policy::Decision::Allow);
276
277        let determining_policies: Vec<String> = response
278            .diagnostics()
279            .reason()
280            .map(|p| p.to_string())
281            .collect();
282
283        let errors: Vec<String> = response
284            .diagnostics()
285            .errors()
286            .map(|e| e.to_string())
287            .collect();
288
289        Ok(AuthorizationDecision {
290            allowed,
291            determining_policies,
292            errors,
293        })
294    }
295
296    fn name(&self) -> &str {
297        "cedar-local"
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::super::PolicyEvaluator;
    use super::*;

    /// Collect string slices into the owned string-set type used by `OperationEntity`.
    fn strings(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// A plain read query should be allowed by the built-in baseline policies.
    #[tokio::test]
    async fn test_cedar_evaluator_permits_reads() {
        let cedar = CedarPolicyEvaluator::graphql_default().unwrap();
        let config = ServerConfigEntity::default();

        let query = OperationEntity {
            id: "test-query".to_string(),
            operation_type: "query".to_string(),
            operation_name: "GetUsers".to_string(),
            root_fields: strings(&["users"]),
            accessed_types: strings(&["User"]),
            accessed_fields: strings(&["User.id", "User.name"]),
            depth: 2,
            field_count: 2,
            estimated_cost: 2,
            has_introspection: false,
            accesses_sensitive_data: false,
            sensitive_categories: HashSet::new(),
        };

        let verdict = cedar.evaluate_operation(&query, &config).await.unwrap();

        assert!(verdict.allowed, "Read queries should be permitted by default");
    }

    /// A non-destructive mutation must be rejected when writes are disabled.
    #[tokio::test]
    async fn test_cedar_evaluator_denies_mutations_when_disabled() {
        let cedar = CedarPolicyEvaluator::graphql_default().unwrap();
        let config = ServerConfigEntity {
            allow_write: false,
            ..Default::default()
        };

        let mutation = OperationEntity {
            id: "test-mutation".to_string(),
            operation_type: "mutation".to_string(),
            operation_name: "CreateUser".to_string(),
            root_fields: strings(&["createUser"]),
            accessed_types: strings(&["User"]),
            accessed_fields: strings(&["User.id"]),
            depth: 1,
            field_count: 1,
            estimated_cost: 1,
            has_introspection: false,
            accesses_sensitive_data: false,
            sensitive_categories: HashSet::new(),
        };

        let verdict = cedar.evaluate_operation(&mutation, &config).await.unwrap();

        assert!(
            !verdict.allowed,
            "Mutations should be denied when allow_write is false"
        );
    }
}