Skip to main content

pmcp_code_mode/policy/
cedar.rs

1//! Local Cedar policy evaluator.
2//!
3//! Provides in-process Cedar policy evaluation using the `cedar-policy` crate.
4//! This enables local agent projects to get real Cedar policy enforcement
5//! without an AWS account.
6
7use super::types::{AuthorizationDecision, OperationEntity, ServerConfigEntity};
8use super::PolicyEvaluationError;
9use cedar_policy::{
10    Authorizer, Context, Entities, Entity, EntityId, EntityTypeName, EntityUid, PolicySet, Request,
11    Schema,
12};
13use std::collections::{HashMap, HashSet};
14use std::str::FromStr;
15
16/// Configuration for the local Cedar policy evaluator.
17#[derive(Debug, Clone)]
18pub struct CedarPolicyConfig {
19    /// Cedar schema in JSON format
20    pub schema_json: serde_json::Value,
21    /// Cedar policies (id, description, policy_text)
22    pub policies: Vec<(String, String, String)>,
23}
24
25/// Local Cedar policy evaluator.
26///
27/// Uses the `cedar-policy` crate for in-process policy evaluation.
28pub struct CedarPolicyEvaluator {
29    authorizer: Authorizer,
30    policy_set: PolicySet,
31    schema: Schema,
32}
33
34impl CedarPolicyEvaluator {
35    /// Create a new evaluator from config.
36    pub fn new(config: CedarPolicyConfig) -> Result<Self, PolicyEvaluationError> {
37        let schema_json = serde_json::to_string(&config.schema_json).map_err(|e| {
38            PolicyEvaluationError::ConfigError(format!("Invalid schema JSON: {}", e))
39        })?;
40
41        let schema = Schema::from_json_str(&schema_json).map_err(|e| {
42            PolicyEvaluationError::ConfigError(format!("Invalid Cedar schema: {}", e))
43        })?;
44
45        let mut policy_set = PolicySet::new();
46        for (id, _description, policy_text) in &config.policies {
47            let policy = cedar_policy::Policy::parse(
48                Some(cedar_policy::PolicyId::from_str(id).unwrap()),
49                policy_text,
50            )
51            .map_err(|e| {
52                PolicyEvaluationError::ConfigError(format!("Invalid policy '{}': {}", id, e))
53            })?;
54            policy_set.add(policy).map_err(|e| {
55                PolicyEvaluationError::ConfigError(format!("Duplicate policy '{}': {}", id, e))
56            })?;
57        }
58
59        Ok(Self {
60            authorizer: Authorizer::new(),
61            policy_set,
62            schema,
63        })
64    }
65
66    /// Create a default evaluator for GraphQL Code Mode using built-in schemas and policies.
67    pub fn graphql_default() -> Result<Self, PolicyEvaluationError> {
68        let schema_json = super::types::get_code_mode_schema_json();
69        let baseline = super::types::get_baseline_policies();
70
71        let policies = baseline
72            .into_iter()
73            .map(|(id, desc, text)| (id.to_string(), desc.to_string(), text.to_string()))
74            .collect();
75
76        Self::new(CedarPolicyConfig {
77            schema_json,
78            policies,
79        })
80    }
81
82    fn build_operation_entity(&self, operation: &OperationEntity) -> Entity {
83        let uid = EntityUid::from_type_name_and_id(
84            EntityTypeName::from_str("CodeMode::Operation").expect("valid type name"),
85            EntityId::from_str(&operation.id).expect("valid entity id"),
86        );
87
88        let mut attrs: HashMap<String, cedar_policy::RestrictedExpression> =
89            HashMap::with_capacity(11);
90
91        attrs.insert(
92            "operationType".to_string(),
93            cedar_policy::RestrictedExpression::new_string(operation.operation_type.clone()),
94        );
95        attrs.insert(
96            "operationName".to_string(),
97            cedar_policy::RestrictedExpression::new_string(operation.operation_name.clone()),
98        );
99        attrs.insert(
100            "depth".to_string(),
101            cedar_policy::RestrictedExpression::new_long(operation.depth as i64),
102        );
103        attrs.insert(
104            "fieldCount".to_string(),
105            cedar_policy::RestrictedExpression::new_long(operation.field_count as i64),
106        );
107        attrs.insert(
108            "estimatedCost".to_string(),
109            cedar_policy::RestrictedExpression::new_long(operation.estimated_cost as i64),
110        );
111        attrs.insert(
112            "hasIntrospection".to_string(),
113            cedar_policy::RestrictedExpression::new_bool(operation.has_introspection),
114        );
115        attrs.insert(
116            "accessesSensitiveData".to_string(),
117            cedar_policy::RestrictedExpression::new_bool(operation.accesses_sensitive_data),
118        );
119
120        // Set attributes
121        attrs.insert(
122            "rootFields".to_string(),
123            Self::string_set_expr(&operation.root_fields),
124        );
125        attrs.insert(
126            "accessedTypes".to_string(),
127            Self::string_set_expr(&operation.accessed_types),
128        );
129        attrs.insert(
130            "accessedFields".to_string(),
131            Self::string_set_expr(&operation.accessed_fields),
132        );
133        attrs.insert(
134            "sensitiveCategories".to_string(),
135            Self::string_set_expr(&operation.sensitive_categories),
136        );
137
138        Entity::new(uid, attrs, HashSet::new()).expect("valid entity")
139    }
140
141    fn build_server_entity(&self, config: &ServerConfigEntity) -> Entity {
142        let uid = EntityUid::from_type_name_and_id(
143            EntityTypeName::from_str("CodeMode::Server").expect("valid type name"),
144            EntityId::from_str(&config.server_id).expect("valid entity id"),
145        );
146
147        let mut attrs: HashMap<String, cedar_policy::RestrictedExpression> =
148            HashMap::with_capacity(12);
149
150        attrs.insert(
151            "serverId".to_string(),
152            cedar_policy::RestrictedExpression::new_string(config.server_id.clone()),
153        );
154        attrs.insert(
155            "serverType".to_string(),
156            cedar_policy::RestrictedExpression::new_string(config.server_type.clone()),
157        );
158        attrs.insert(
159            "allowWrite".to_string(),
160            cedar_policy::RestrictedExpression::new_bool(config.allow_write),
161        );
162        attrs.insert(
163            "allowDelete".to_string(),
164            cedar_policy::RestrictedExpression::new_bool(config.allow_delete),
165        );
166        attrs.insert(
167            "allowAdmin".to_string(),
168            cedar_policy::RestrictedExpression::new_bool(config.allow_admin),
169        );
170        attrs.insert(
171            "maxDepth".to_string(),
172            cedar_policy::RestrictedExpression::new_long(config.max_depth as i64),
173        );
174        attrs.insert(
175            "maxFieldCount".to_string(),
176            cedar_policy::RestrictedExpression::new_long(config.max_field_count as i64),
177        );
178        attrs.insert(
179            "maxCost".to_string(),
180            cedar_policy::RestrictedExpression::new_long(config.max_cost as i64),
181        );
182        attrs.insert(
183            "maxApiCalls".to_string(),
184            cedar_policy::RestrictedExpression::new_long(config.max_api_calls as i64),
185        );
186
187        attrs.insert(
188            "allowedOperations".to_string(),
189            Self::string_set_expr(&config.allowed_operations),
190        );
191        attrs.insert(
192            "blockedOperations".to_string(),
193            Self::string_set_expr(&config.blocked_operations),
194        );
195        attrs.insert(
196            "blockedFields".to_string(),
197            Self::string_set_expr(&config.blocked_fields),
198        );
199
200        Entity::new(uid, attrs, HashSet::new()).expect("valid entity")
201    }
202
203    fn string_set_expr(set: &HashSet<String>) -> cedar_policy::RestrictedExpression {
204        cedar_policy::RestrictedExpression::new_set(
205            set.iter()
206                .map(|s| cedar_policy::RestrictedExpression::new_string(s.clone())),
207        )
208    }
209}
210
211#[async_trait::async_trait]
212impl super::PolicyEvaluator for CedarPolicyEvaluator {
213    async fn evaluate_operation(
214        &self,
215        operation: &OperationEntity,
216        server_config: &ServerConfigEntity,
217    ) -> Result<AuthorizationDecision, PolicyEvaluationError> {
218        let principal = EntityUid::from_type_name_and_id(
219            EntityTypeName::from_str("CodeMode::Operation").expect("valid type name"),
220            EntityId::from_str(&operation.id).expect("valid entity id"),
221        );
222
223        // Determine action
224        let action_id = if operation.has_introspection {
225            "Admin"
226        } else {
227            match operation.operation_type.as_str() {
228                "mutation" => {
229                    let op_name = operation.operation_name.to_lowercase();
230                    if op_name.starts_with("delete")
231                        || op_name.starts_with("remove")
232                        || op_name.starts_with("purge")
233                    {
234                        "Delete"
235                    } else {
236                        "Write"
237                    }
238                },
239                "subscription" => "Read",
240                _ => "Read",
241            }
242        };
243
244        let action = EntityUid::from_type_name_and_id(
245            EntityTypeName::from_str("CodeMode::Action").expect("valid type name"),
246            EntityId::from_str(action_id).expect("valid entity id"),
247        );
248
249        let resource = EntityUid::from_type_name_and_id(
250            EntityTypeName::from_str("CodeMode::Server").expect("valid type name"),
251            EntityId::from_str(&server_config.server_id).expect("valid entity id"),
252        );
253
254        let op_entity = self.build_operation_entity(operation);
255        let server_entity = self.build_server_entity(server_config);
256
257        let entities = Entities::from_entities([op_entity, server_entity], Some(&self.schema))
258            .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Entity error: {}", e)))?;
259
260        let context = Context::from_pairs([
261            (
262                "serverId".to_string(),
263                cedar_policy::RestrictedExpression::new_string(server_config.server_id.clone()),
264            ),
265            (
266                "serverType".to_string(),
267                cedar_policy::RestrictedExpression::new_string(server_config.server_type.clone()),
268            ),
269        ])
270        .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Context error: {}", e)))?;
271
272        let request = Request::new(principal, action, resource, context, Some(&self.schema))
273            .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Request error: {}", e)))?;
274
275        let response = self
276            .authorizer
277            .is_authorized(&request, &self.policy_set, &entities);
278
279        let allowed = matches!(response.decision(), cedar_policy::Decision::Allow);
280
281        let determining_policies: Vec<String> = response
282            .diagnostics()
283            .reason()
284            .map(|p| p.to_string())
285            .collect();
286
287        let errors: Vec<String> = response
288            .diagnostics()
289            .errors()
290            .map(|e| e.to_string())
291            .collect();
292
293        Ok(AuthorizationDecision {
294            allowed,
295            determining_policies,
296            errors,
297        })
298    }
299
300    fn name(&self) -> &str {
301        "cedar-local"
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::super::PolicyEvaluator;
    use super::*;

    /// Collects string literals into the owned set type used by `OperationEntity`.
    fn set_of(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[tokio::test]
    async fn test_cedar_evaluator_permits_reads() {
        let evaluator = CedarPolicyEvaluator::graphql_default().unwrap();

        let read_query = OperationEntity {
            id: "test-query".to_string(),
            operation_type: "query".to_string(),
            operation_name: "GetUsers".to_string(),
            root_fields: set_of(&["users"]),
            accessed_types: set_of(&["User"]),
            accessed_fields: set_of(&["User.id", "User.name"]),
            depth: 2,
            field_count: 2,
            estimated_cost: 2,
            has_introspection: false,
            accesses_sensitive_data: false,
            sensitive_categories: HashSet::new(),
        };
        let defaults = ServerConfigEntity::default();

        let outcome = evaluator
            .evaluate_operation(&read_query, &defaults)
            .await
            .unwrap();

        assert!(outcome.allowed, "Read queries should be permitted by default");
    }

    #[tokio::test]
    async fn test_cedar_evaluator_denies_mutations_when_disabled() {
        let evaluator = CedarPolicyEvaluator::graphql_default().unwrap();

        let create_mutation = OperationEntity {
            id: "test-mutation".to_string(),
            operation_type: "mutation".to_string(),
            operation_name: "CreateUser".to_string(),
            root_fields: set_of(&["createUser"]),
            accessed_types: set_of(&["User"]),
            accessed_fields: set_of(&["User.id"]),
            depth: 1,
            field_count: 1,
            estimated_cost: 1,
            has_introspection: false,
            accesses_sensitive_data: false,
            sensitive_categories: HashSet::new(),
        };
        let read_only_server = ServerConfigEntity {
            allow_write: false,
            ..Default::default()
        };

        let outcome = evaluator
            .evaluate_operation(&create_mutation, &read_only_server)
            .await
            .unwrap();

        assert!(
            !outcome.allowed,
            "Mutations should be denied when allow_write is false"
        );
    }
}