1use super::types::{AuthorizationDecision, OperationEntity, ServerConfigEntity};
8use super::PolicyEvaluationError;
9use cedar_policy::{
10 Authorizer, Context, Entities, Entity, EntityId, EntityTypeName, EntityUid, PolicySet, Request,
11 Schema,
12};
13use std::collections::{HashMap, HashSet};
14use std::str::FromStr;
15
16#[derive(Debug, Clone)]
18pub struct CedarPolicyConfig {
19 pub schema_json: serde_json::Value,
21 pub policies: Vec<(String, String, String)>,
23}
24
25pub struct CedarPolicyEvaluator {
29 authorizer: Authorizer,
30 policy_set: PolicySet,
31 schema: Schema,
32}
33
34impl CedarPolicyEvaluator {
35 pub fn new(config: CedarPolicyConfig) -> Result<Self, PolicyEvaluationError> {
37 let schema_json = serde_json::to_string(&config.schema_json).map_err(|e| {
38 PolicyEvaluationError::ConfigError(format!("Invalid schema JSON: {}", e))
39 })?;
40
41 let schema = Schema::from_json_str(&schema_json).map_err(|e| {
42 PolicyEvaluationError::ConfigError(format!("Invalid Cedar schema: {}", e))
43 })?;
44
45 let mut policy_set = PolicySet::new();
46 for (id, _description, policy_text) in &config.policies {
47 let policy = cedar_policy::Policy::parse(
48 Some(cedar_policy::PolicyId::from_str(id).unwrap()),
49 policy_text,
50 )
51 .map_err(|e| {
52 PolicyEvaluationError::ConfigError(format!("Invalid policy '{}': {}", id, e))
53 })?;
54 policy_set.add(policy).map_err(|e| {
55 PolicyEvaluationError::ConfigError(format!("Duplicate policy '{}': {}", id, e))
56 })?;
57 }
58
59 Ok(Self {
60 authorizer: Authorizer::new(),
61 policy_set,
62 schema,
63 })
64 }
65
66 pub fn graphql_default() -> Result<Self, PolicyEvaluationError> {
68 let schema_json = super::types::get_code_mode_schema_json();
69 let baseline = super::types::get_baseline_policies();
70
71 let policies = baseline
72 .into_iter()
73 .map(|(id, desc, text)| (id.to_string(), desc.to_string(), text.to_string()))
74 .collect();
75
76 Self::new(CedarPolicyConfig {
77 schema_json,
78 policies,
79 })
80 }
81
82 fn build_operation_entity(&self, operation: &OperationEntity) -> Entity {
83 let uid = EntityUid::from_type_name_and_id(
84 EntityTypeName::from_str("CodeMode::Operation").expect("valid type name"),
85 EntityId::from_str(&operation.id).expect("valid entity id"),
86 );
87
88 let mut attrs: HashMap<String, cedar_policy::RestrictedExpression> =
89 HashMap::with_capacity(11);
90
91 attrs.insert(
92 "operationType".to_string(),
93 cedar_policy::RestrictedExpression::new_string(operation.operation_type.clone()),
94 );
95 attrs.insert(
96 "operationName".to_string(),
97 cedar_policy::RestrictedExpression::new_string(operation.operation_name.clone()),
98 );
99 attrs.insert(
100 "depth".to_string(),
101 cedar_policy::RestrictedExpression::new_long(operation.depth as i64),
102 );
103 attrs.insert(
104 "fieldCount".to_string(),
105 cedar_policy::RestrictedExpression::new_long(operation.field_count as i64),
106 );
107 attrs.insert(
108 "estimatedCost".to_string(),
109 cedar_policy::RestrictedExpression::new_long(operation.estimated_cost as i64),
110 );
111 attrs.insert(
112 "hasIntrospection".to_string(),
113 cedar_policy::RestrictedExpression::new_bool(operation.has_introspection),
114 );
115 attrs.insert(
116 "accessesSensitiveData".to_string(),
117 cedar_policy::RestrictedExpression::new_bool(operation.accesses_sensitive_data),
118 );
119
120 attrs.insert(
122 "rootFields".to_string(),
123 Self::string_set_expr(&operation.root_fields),
124 );
125 attrs.insert(
126 "accessedTypes".to_string(),
127 Self::string_set_expr(&operation.accessed_types),
128 );
129 attrs.insert(
130 "accessedFields".to_string(),
131 Self::string_set_expr(&operation.accessed_fields),
132 );
133 attrs.insert(
134 "sensitiveCategories".to_string(),
135 Self::string_set_expr(&operation.sensitive_categories),
136 );
137
138 Entity::new(uid, attrs, HashSet::new()).expect("valid entity")
139 }
140
141 fn build_server_entity(&self, config: &ServerConfigEntity) -> Entity {
142 let uid = EntityUid::from_type_name_and_id(
143 EntityTypeName::from_str("CodeMode::Server").expect("valid type name"),
144 EntityId::from_str(&config.server_id).expect("valid entity id"),
145 );
146
147 let mut attrs: HashMap<String, cedar_policy::RestrictedExpression> =
148 HashMap::with_capacity(10);
149
150 attrs.insert(
151 "serverId".to_string(),
152 cedar_policy::RestrictedExpression::new_string(config.server_id.clone()),
153 );
154 attrs.insert(
155 "serverType".to_string(),
156 cedar_policy::RestrictedExpression::new_string(config.server_type.clone()),
157 );
158 attrs.insert(
159 "allowWrite".to_string(),
160 cedar_policy::RestrictedExpression::new_bool(config.allow_write),
161 );
162 attrs.insert(
163 "allowDelete".to_string(),
164 cedar_policy::RestrictedExpression::new_bool(config.allow_delete),
165 );
166 attrs.insert(
167 "allowAdmin".to_string(),
168 cedar_policy::RestrictedExpression::new_bool(config.allow_admin),
169 );
170 attrs.insert(
171 "maxDepth".to_string(),
172 cedar_policy::RestrictedExpression::new_long(config.max_depth as i64),
173 );
174 attrs.insert(
175 "maxCost".to_string(),
176 cedar_policy::RestrictedExpression::new_long(config.max_cost as i64),
177 );
178 attrs.insert(
179 "maxApiCalls".to_string(),
180 cedar_policy::RestrictedExpression::new_long(config.max_api_calls as i64),
181 );
182
183 attrs.insert(
184 "allowedOperations".to_string(),
185 Self::string_set_expr(&config.allowed_operations),
186 );
187 attrs.insert(
188 "blockedOperations".to_string(),
189 Self::string_set_expr(&config.blocked_operations),
190 );
191 attrs.insert(
192 "blockedFields".to_string(),
193 Self::string_set_expr(&config.blocked_fields),
194 );
195
196 Entity::new(uid, attrs, HashSet::new()).expect("valid entity")
197 }
198
199 fn string_set_expr(set: &HashSet<String>) -> cedar_policy::RestrictedExpression {
200 cedar_policy::RestrictedExpression::new_set(
201 set.iter()
202 .map(|s| cedar_policy::RestrictedExpression::new_string(s.clone())),
203 )
204 }
205}
206
207#[async_trait::async_trait]
208impl super::PolicyEvaluator for CedarPolicyEvaluator {
209 async fn evaluate_operation(
210 &self,
211 operation: &OperationEntity,
212 server_config: &ServerConfigEntity,
213 ) -> Result<AuthorizationDecision, PolicyEvaluationError> {
214 let principal = EntityUid::from_type_name_and_id(
215 EntityTypeName::from_str("CodeMode::Operation").expect("valid type name"),
216 EntityId::from_str(&operation.id).expect("valid entity id"),
217 );
218
219 let action_id = if operation.has_introspection {
221 "Admin"
222 } else {
223 match operation.operation_type.as_str() {
224 "mutation" => {
225 let op_name = operation.operation_name.to_lowercase();
226 if op_name.starts_with("delete")
227 || op_name.starts_with("remove")
228 || op_name.starts_with("purge")
229 {
230 "Delete"
231 } else {
232 "Write"
233 }
234 },
235 "subscription" => "Read",
236 _ => "Read",
237 }
238 };
239
240 let action = EntityUid::from_type_name_and_id(
241 EntityTypeName::from_str("CodeMode::Action").expect("valid type name"),
242 EntityId::from_str(action_id).expect("valid entity id"),
243 );
244
245 let resource = EntityUid::from_type_name_and_id(
246 EntityTypeName::from_str("CodeMode::Server").expect("valid type name"),
247 EntityId::from_str(&server_config.server_id).expect("valid entity id"),
248 );
249
250 let op_entity = self.build_operation_entity(operation);
251 let server_entity = self.build_server_entity(server_config);
252
253 let entities = Entities::from_entities([op_entity, server_entity], Some(&self.schema))
254 .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Entity error: {}", e)))?;
255
256 let context = Context::from_pairs([
257 (
258 "serverId".to_string(),
259 cedar_policy::RestrictedExpression::new_string(server_config.server_id.clone()),
260 ),
261 (
262 "serverType".to_string(),
263 cedar_policy::RestrictedExpression::new_string(server_config.server_type.clone()),
264 ),
265 ])
266 .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Context error: {}", e)))?;
267
268 let request = Request::new(principal, action, resource, context, Some(&self.schema))
269 .map_err(|e| PolicyEvaluationError::EvaluationError(format!("Request error: {}", e)))?;
270
271 let response = self
272 .authorizer
273 .is_authorized(&request, &self.policy_set, &entities);
274
275 let allowed = matches!(response.decision(), cedar_policy::Decision::Allow);
276
277 let determining_policies: Vec<String> = response
278 .diagnostics()
279 .reason()
280 .map(|p| p.to_string())
281 .collect();
282
283 let errors: Vec<String> = response
284 .diagnostics()
285 .errors()
286 .map(|e| e.to_string())
287 .collect();
288
289 Ok(AuthorizationDecision {
290 allowed,
291 determining_policies,
292 errors,
293 })
294 }
295
296 fn name(&self) -> &str {
297 "cedar-local"
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::super::PolicyEvaluator;
    use super::*;

    /// Collects borrowed names into the owned `HashSet<String>` the entity fields use.
    fn owned_set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a non-introspective, non-sensitive operation that touches `User`.
    ///
    /// Size metrics are left at zero; callers set them with struct-update syntax.
    fn user_operation(
        id: &str,
        operation_type: &str,
        operation_name: &str,
        root_field: &str,
        accessed_fields: &[&str],
    ) -> OperationEntity {
        OperationEntity {
            id: id.to_string(),
            operation_type: operation_type.to_string(),
            operation_name: operation_name.to_string(),
            root_fields: owned_set(&[root_field]),
            accessed_types: owned_set(&["User"]),
            accessed_fields: owned_set(accessed_fields),
            depth: 0,
            field_count: 0,
            estimated_cost: 0,
            has_introspection: false,
            accesses_sensitive_data: false,
            sensitive_categories: HashSet::new(),
        }
    }

    #[tokio::test]
    async fn test_cedar_evaluator_permits_reads() {
        let evaluator = CedarPolicyEvaluator::graphql_default().unwrap();
        let query = OperationEntity {
            depth: 2,
            field_count: 2,
            estimated_cost: 2,
            ..user_operation(
                "test-query",
                "query",
                "GetUsers",
                "users",
                &["User.id", "User.name"],
            )
        };
        let defaults = ServerConfigEntity::default();

        let outcome = evaluator.evaluate_operation(&query, &defaults).await.unwrap();

        assert!(
            outcome.allowed,
            "Read queries should be permitted by default"
        );
    }

    #[tokio::test]
    async fn test_cedar_evaluator_denies_mutations_when_disabled() {
        let evaluator = CedarPolicyEvaluator::graphql_default().unwrap();
        let mutation = OperationEntity {
            depth: 1,
            field_count: 1,
            estimated_cost: 1,
            ..user_operation(
                "test-mutation",
                "mutation",
                "CreateUser",
                "createUser",
                &["User.id"],
            )
        };
        let read_only = ServerConfigEntity {
            allow_write: false,
            ..Default::default()
        };

        let outcome = evaluator
            .evaluate_operation(&mutation, &read_only)
            .await
            .unwrap();

        assert!(
            !outcome.allowed,
            "Mutations should be denied when allow_write is false"
        );
    }
}