1use super::types::{AuthorizationDecision, OperationEntity, ServerConfigEntity};
8use super::PolicyEvaluationError;
9use cedar_policy::{
10 Authorizer, Context, Entities, Entity, EntityId, EntityTypeName, EntityUid, PolicySet, Request,
11 Schema,
12};
13use std::collections::{HashMap, HashSet};
14use std::str::FromStr;
15
16#[derive(Debug, Clone)]
18pub struct CedarPolicyConfig {
19 pub schema_json: serde_json::Value,
21 pub policies: Vec<(String, String, String)>,
23}
24
25pub struct CedarPolicyEvaluator {
29 authorizer: Authorizer,
30 policy_set: PolicySet,
31 schema: Schema,
32}
33
34impl CedarPolicyEvaluator {
35 pub fn new(config: CedarPolicyConfig) -> Result<Self, PolicyEvaluationError> {
37 let schema_json = serde_json::to_string(&config.schema_json).map_err(|e| {
38 PolicyEvaluationError::ConfigError(format!("Invalid schema JSON: {}", e))
39 })?;
40
41 let schema = Schema::from_json_str(&schema_json).map_err(|e| {
42 PolicyEvaluationError::ConfigError(format!("Invalid Cedar schema: {}", e))
43 })?;
44
45 let mut policy_set = PolicySet::new();
46 for (id, _description, policy_text) in &config.policies {
47 let policy = cedar_policy::Policy::parse(
48 Some(cedar_policy::PolicyId::from_str(id).unwrap()),
49 policy_text,
50 )
51 .map_err(|e| {
52 PolicyEvaluationError::ConfigError(format!("Invalid policy '{}': {}", id, e))
53 })?;
54 policy_set.add(policy).map_err(|e| {
55 PolicyEvaluationError::ConfigError(format!("Duplicate policy '{}': {}", id, e))
56 })?;
57 }
58
59 Ok(Self {
60 authorizer: Authorizer::new(),
61 policy_set,
62 schema,
63 })
64 }
65
66 pub fn graphql_default() -> Result<Self, PolicyEvaluationError> {
68 let schema_json = super::types::get_code_mode_schema_json();
69 let baseline = super::types::get_baseline_policies();
70
71 let policies = baseline
72 .into_iter()
73 .map(|(id, desc, text)| (id.to_string(), desc.to_string(), text.to_string()))
74 .collect();
75
76 Self::new(CedarPolicyConfig {
77 schema_json,
78 policies,
79 })
80 }
81
82 fn build_operation_entity(&self, operation: &OperationEntity) -> Entity {
83 let uid = EntityUid::from_type_name_and_id(
84 EntityTypeName::from_str("CodeMode::Operation").expect("valid type name"),
85 EntityId::from_str(&operation.id).expect("valid entity id"),
86 );
87
88 let mut attrs: HashMap<String, cedar_policy::RestrictedExpression> =
89 HashMap::with_capacity(11);
90
91 attrs.insert(
92 "operationType".to_string(),
93 cedar_policy::RestrictedExpression::new_string(operation.operation_type.clone()),
94 );
95 attrs.insert(
96 "operationName".to_string(),
97 cedar_policy::RestrictedExpression::new_string(operation.operation_name.clone()),
98 );
99 attrs.insert(
100 "depth".to_string(),
101 cedar_policy::RestrictedExpression::new_long(operation.depth as i64),
102 );
103 attrs.insert(
104 "fieldCount".to_string(),
105 cedar_policy::RestrictedExpression::new_long(operation.field_count as i64),
106 );
107 attrs.insert(
108 "estimatedCost".to_string(),
109 cedar_policy::RestrictedExpression::new_long(operation.estimated_cost as i64),
110 );
111 attrs.insert(
112 "hasIntrospection".to_string(),
113 cedar_policy::RestrictedExpression::new_bool(operation.has_introspection),
114 );
115 attrs.insert(
116 "accessesSensitiveData".to_string(),
117 cedar_policy::RestrictedExpression::new_bool(operation.accesses_sensitive_data),
118 );
119
120 attrs.insert(
122 "rootFields".to_string(),
123 Self::string_set_expr(&operation.root_fields),
124 );
125 attrs.insert(
126 "accessedTypes".to_string(),
127 Self::string_set_expr(&operation.accessed_types),
128 );
129 attrs.insert(
130 "accessedFields".to_string(),
131 Self::string_set_expr(&operation.accessed_fields),
132 );
133 attrs.insert(
134 "sensitiveCategories".to_string(),
135 Self::string_set_expr(&operation.sensitive_categories),
136 );
137
138 Entity::new(uid, attrs, HashSet::new()).expect("valid entity")
139 }
140
141 fn build_server_entity(&self, config: &ServerConfigEntity) -> Entity {
142 let uid = EntityUid::from_type_name_and_id(
143 EntityTypeName::from_str("CodeMode::Server").expect("valid type name"),
144 EntityId::from_str(&config.server_id).expect("valid entity id"),
145 );
146
147 let mut attrs: HashMap<String, cedar_policy::RestrictedExpression> =
148 HashMap::with_capacity(12);
149
150 attrs.insert(
151 "serverId".to_string(),
152 cedar_policy::RestrictedExpression::new_string(config.server_id.clone()),
153 );
154 attrs.insert(
155 "serverType".to_string(),
156 cedar_policy::RestrictedExpression::new_string(config.server_type.clone()),
157 );
158 attrs.insert(
159 "allowWrite".to_string(),
160 cedar_policy::RestrictedExpression::new_bool(config.allow_write),
161 );
162 attrs.insert(
163 "allowDelete".to_string(),
164 cedar_policy::RestrictedExpression::new_bool(config.allow_delete),
165 );
166 attrs.insert(
167 "allowAdmin".to_string(),
168 cedar_policy::RestrictedExpression::new_bool(config.allow_admin),
169 );
170 attrs.insert(
171 "maxDepth".to_string(),
172 cedar_policy::RestrictedExpression::new_long(config.max_depth as i64),
173 );
174 attrs.insert(
175 "maxFieldCount".to_string(),
176 cedar_policy::RestrictedExpression::new_long(config.max_field_count as i64),
177 );
178 attrs.insert(
179 "maxCost".to_string(),
180 cedar_policy::RestrictedExpression::new_long(config.max_cost as i64),
181 );
182 attrs.insert(
183 "maxApiCalls".to_string(),
184 cedar_policy::RestrictedExpression::new_long(config.max_api_calls as i64),
185 );
186
187 attrs.insert(
188 "allowedOperations".to_string(),
189 Self::string_set_expr(&config.allowed_operations),
190 );
191 attrs.insert(
192 "blockedOperations".to_string(),
193 Self::string_set_expr(&config.blocked_operations),
194 );
195 attrs.insert(
196 "blockedFields".to_string(),
197 Self::string_set_expr(&config.blocked_fields),
198 );
199
200 Entity::new(uid, attrs, HashSet::new()).expect("valid entity")
201 }
202
203 fn string_set_expr(set: &HashSet<String>) -> cedar_policy::RestrictedExpression {
204 cedar_policy::RestrictedExpression::new_set(
205 set.iter()
206 .map(|s| cedar_policy::RestrictedExpression::new_string(s.clone())),
207 )
208 }
209}
210
211#[async_trait::async_trait]
212impl super::PolicyEvaluator for CedarPolicyEvaluator {
213 async fn evaluate_operation(
214 &self,
215 operation: &OperationEntity,
216 server_config: &ServerConfigEntity,
217 ) -> Result<AuthorizationDecision, PolicyEvaluationError> {
218 let principal = EntityUid::from_type_name_and_id(
219 EntityTypeName::from_str("CodeMode::Operation").expect("valid type name"),
220 EntityId::from_str(&operation.id).expect("valid entity id"),
221 );
222
223 let action_id = if operation.has_introspection {
225 "Admin"
226 } else {
227 match operation.operation_type.as_str() {
228 "mutation" => {
229 let op_name = operation.operation_name.to_lowercase();
230 if op_name.starts_with("delete")
231 || op_name.starts_with("remove")
232 || op_name.starts_with("purge")
233 {
234 "Delete"
235 } else {
236 "Write"
237 }
238 },
239 "subscription" => "Read",
240 _ => "Read",
241 }
242 };
243
244 let action = EntityUid::from_type_name_and_id(
245 EntityTypeName::from_str("CodeMode::Action").expect("valid type name"),
246 EntityId::from_str(action_id).expect("valid entity id"),
247 );
248
249 let resource = EntityUid::from_type_name_and_id(
250 EntityTypeName::from_str("CodeMode::Server").expect("valid type name"),
251 EntityId::from_str(&server_config.server_id).expect("valid entity id"),
252 );
253
254 let op_entity = self.build_operation_entity(operation);
255 let server_entity = self.build_server_entity(server_config);
256
257 let entities = Entities::from_entities([op_entity, server_entity], Some(&self.schema))
258 .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Entity error: {}", e)))?;
259
260 let context = Context::from_pairs([
261 (
262 "serverId".to_string(),
263 cedar_policy::RestrictedExpression::new_string(server_config.server_id.clone()),
264 ),
265 (
266 "serverType".to_string(),
267 cedar_policy::RestrictedExpression::new_string(server_config.server_type.clone()),
268 ),
269 ])
270 .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Context error: {}", e)))?;
271
272 let request = Request::new(principal, action, resource, context, Some(&self.schema))
273 .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Request error: {}", e)))?;
274
275 let response = self
276 .authorizer
277 .is_authorized(&request, &self.policy_set, &entities);
278
279 let allowed = matches!(response.decision(), cedar_policy::Decision::Allow);
280
281 let determining_policies: Vec<String> = response
282 .diagnostics()
283 .reason()
284 .map(|p| p.to_string())
285 .collect();
286
287 let errors: Vec<String> = response
288 .diagnostics()
289 .errors()
290 .map(|e| e.to_string())
291 .collect();
292
293 Ok(AuthorizationDecision {
294 allowed,
295 determining_policies,
296 errors,
297 })
298 }
299
300 fn name(&self) -> &str {
301 "cedar-local"
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::super::PolicyEvaluator;
    use super::*;

    /// Collects string literals into the owned set type used by the entity fields.
    fn owned_set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// A plain read query must be allowed under the default server config.
    #[tokio::test]
    async fn test_cedar_evaluator_permits_reads() {
        let evaluator = CedarPolicyEvaluator::graphql_default().unwrap();

        let read_query = OperationEntity {
            id: "test-query".to_string(),
            operation_type: "query".to_string(),
            operation_name: "GetUsers".to_string(),
            root_fields: owned_set(&["users"]),
            accessed_types: owned_set(&["User"]),
            accessed_fields: owned_set(&["User.id", "User.name"]),
            depth: 2,
            field_count: 2,
            estimated_cost: 2,
            has_introspection: false,
            accesses_sensitive_data: false,
            sensitive_categories: HashSet::new(),
        };
        let default_server = ServerConfigEntity::default();

        let decision = evaluator
            .evaluate_operation(&read_query, &default_server)
            .await
            .unwrap();

        assert!(
            decision.allowed,
            "Read queries should be permitted by default"
        );
    }

    /// A non-destructive mutation must be denied when writes are disabled.
    #[tokio::test]
    async fn test_cedar_evaluator_denies_mutations_when_disabled() {
        let evaluator = CedarPolicyEvaluator::graphql_default().unwrap();

        let create_mutation = OperationEntity {
            id: "test-mutation".to_string(),
            operation_type: "mutation".to_string(),
            operation_name: "CreateUser".to_string(),
            root_fields: owned_set(&["createUser"]),
            accessed_types: owned_set(&["User"]),
            accessed_fields: owned_set(&["User.id"]),
            depth: 1,
            field_count: 1,
            estimated_cost: 1,
            has_introspection: false,
            accesses_sensitive_data: false,
            sensitive_categories: HashSet::new(),
        };
        let read_only_server = ServerConfigEntity {
            allow_write: false,
            ..Default::default()
        };

        let decision = evaluator
            .evaluate_operation(&create_mutation, &read_only_server)
            .await
            .unwrap();

        assert!(
            !decision.allowed,
            "Mutations should be denied when allow_write is false"
        );
    }
}
376}