1use async_trait::async_trait;
4use cedar_policy::PolicySet;
5use thiserror::Error;
6
7#[derive(Debug, Error)]
9pub enum PolicyStoreError {
10 #[error("Database error: {0}")]
11 Database(String),
12 #[error("Policy not found: {0}")]
13 NotFound(String),
14 #[error("Policy parse error: {0}")]
15 Parse(String),
16 #[error("Internal error: {0}")]
17 Internal(String),
18}
19
20#[derive(Debug, Error)]
22pub enum CacheError {
23 #[error("Connection error: {0}")]
24 Connection(String),
25 #[error("Publish error: {0}")]
26 Publish(String),
27 #[error("Subscribe error: {0}")]
28 Subscribe(String),
29}
30
31#[async_trait]
33pub trait PolicyStore: Send + Sync {
34 async fn create_policy(&self, content: String) -> Result<String, PolicyStoreError>;
36
37 async fn get_policy(&self, id: &str) -> Result<Option<String>, PolicyStoreError>;
39
40 async fn list_policies(&self) -> Result<Vec<(String, String)>, PolicyStoreError>;
42
43 async fn update_policy(&self, id: &str, content: String) -> Result<(), PolicyStoreError>;
45
46 async fn delete_policy(&self, id: &str) -> Result<(), PolicyStoreError>;
48
49 async fn load_all_policies(&self) -> Result<PolicySet, PolicyStoreError>;
51}
52
53#[async_trait]
55pub trait CacheInvalidation: Send + Sync {
56 async fn invalidate_policies(&self) -> Result<(), CacheError>;
58
59 async fn subscribe_to_invalidations<F>(&self, callback: F) -> Result<(), CacheError>
61 where
62 F: Fn() + Send + Sync + 'static;
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Mutex, MutexGuard};

    /// In-memory [`PolicyStore`] used to exercise the trait contract.
    ///
    /// Policies live in a `HashMap` keyed by a freshly generated UUID string. The lock is
    /// never held across an `.await`, so a plain `std::sync::Mutex` is sufficient, and the
    /// store is never shared by clone, so no `Arc` is needed.
    #[derive(Default)]
    struct MockPolicyStore {
        policies: Mutex<HashMap<String, String>>,
    }

    impl MockPolicyStore {
        fn new() -> Self {
            Self::default()
        }

        /// Locks the policy map.
        ///
        /// A poisoned lock means another test thread panicked mid-update, which is a bug.
        fn lock_policies(&self) -> MutexGuard<'_, HashMap<String, String>> {
            self.policies
                .lock()
                .expect("mock policy map lock poisoned by a panicking test")
        }
    }

    #[async_trait]
    impl PolicyStore for MockPolicyStore {
        async fn create_policy(&self, content: String) -> Result<String, PolicyStoreError> {
            let id = uuid::Uuid::new_v4().to_string();
            self.lock_policies().insert(id.clone(), content);
            Ok(id)
        }

        async fn get_policy(&self, id: &str) -> Result<Option<String>, PolicyStoreError> {
            Ok(self.lock_policies().get(id).cloned())
        }

        async fn list_policies(&self) -> Result<Vec<(String, String)>, PolicyStoreError> {
            let policies = self.lock_policies();
            Ok(policies
                .iter()
                .map(|(id, content)| (id.clone(), content.clone()))
                .collect())
        }

        async fn update_policy(&self, id: &str, content: String) -> Result<(), PolicyStoreError> {
            let mut policies = self.lock_policies();
            // One lookup: overwrite in place rather than `contains_key` followed by `insert`.
            match policies.get_mut(id) {
                Some(existing) => {
                    *existing = content;
                    Ok(())
                }
                None => Err(PolicyStoreError::NotFound(id.to_string())),
            }
        }

        async fn delete_policy(&self, id: &str) -> Result<(), PolicyStoreError> {
            self.lock_policies()
                .remove(id)
                .map(|_| ())
                .ok_or_else(|| PolicyStoreError::NotFound(id.to_string()))
        }

        async fn load_all_policies(&self) -> Result<PolicySet, PolicyStoreError> {
            // The stored strings in these tests are not valid Cedar, so the mock does not try
            // to parse them and always returns an empty set.
            Ok(PolicySet::new())
        }
    }

    #[tokio::test]
    async fn test_mock_policy_store_create() {
        let store = MockPolicyStore::new();
        let id = store.create_policy("test policy".to_string()).await.unwrap();
        assert!(!id.is_empty());
    }

    #[tokio::test]
    async fn test_mock_policy_store_get() {
        let store = MockPolicyStore::new();
        let id = store.create_policy("test policy".to_string()).await.unwrap();
        let policy = store.get_policy(&id).await.unwrap();
        assert_eq!(policy, Some("test policy".to_string()));
    }

    #[tokio::test]
    async fn test_mock_policy_store_get_missing_is_none() {
        let store = MockPolicyStore::new();
        assert_eq!(store.get_policy("missing").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_mock_policy_store_update() {
        let store = MockPolicyStore::new();
        let id = store.create_policy("original".to_string()).await.unwrap();
        store.update_policy(&id, "updated".to_string()).await.unwrap();
        let policy = store.get_policy(&id).await.unwrap();
        assert_eq!(policy, Some("updated".to_string()));
    }

    #[tokio::test]
    async fn test_mock_policy_store_update_missing_is_not_found() {
        let store = MockPolicyStore::new();
        let result = store.update_policy("missing", "content".to_string()).await;
        assert!(matches!(result, Err(PolicyStoreError::NotFound(ref id)) if id == "missing"));
        // A failed update must not create the policy as a side effect.
        assert!(store.list_policies().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_mock_policy_store_delete() {
        let store = MockPolicyStore::new();
        let id = store.create_policy("test".to_string()).await.unwrap();
        store.delete_policy(&id).await.unwrap();
        let policy = store.get_policy(&id).await.unwrap();
        assert_eq!(policy, None);
    }

    #[tokio::test]
    async fn test_mock_policy_store_delete_twice_is_not_found() {
        let store = MockPolicyStore::new();
        let id = store.create_policy("test".to_string()).await.unwrap();
        store.delete_policy(&id).await.unwrap();
        let result = store.delete_policy(&id).await;
        assert!(matches!(result, Err(PolicyStoreError::NotFound(ref missing)) if missing == &id));
    }

    #[tokio::test]
    async fn test_mock_policy_store_list() {
        let store = MockPolicyStore::new();
        let id1 = store.create_policy("policy1".to_string()).await.unwrap();
        let id2 = store.create_policy("policy2".to_string()).await.unwrap();

        // HashMap iteration order is unspecified, so compare sorted snapshots.
        let mut policies = store.list_policies().await.unwrap();
        policies.sort();
        let mut expected = vec![
            (id1, "policy1".to_string()),
            (id2, "policy2".to_string()),
        ];
        expected.sort();
        assert_eq!(policies, expected);
    }
}