1use crate::error::AuthzError;
4use crate::traits::Tuple;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU32, Ordering};
9
/// Outcome of resolving a single permission check.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CheckResult {
    /// The check succeeded.
    Allowed,
    /// The check failed.
    Denied,
    /// The outcome could not be decided without further input.
    ///
    /// NOTE(review): the strings are presumably the names of the missing
    /// condition parameters — confirm against the resolver that builds this.
    ConditionRequired(Vec<String>),
}
17
/// Traversal strategy stored in `RecursionConfig::strategy`.
///
/// `DepthFirst` is the `Default` variant.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum RecursionStrategy {
    /// Depth-first traversal (the default).
    #[default]
    DepthFirst,
    /// Breadth-first traversal.
    BreadthFirst,
}
25
26#[derive(Debug, Clone)]
28pub struct RecursionConfig {
29 pub strategy: RecursionStrategy,
30 pub max_depth: u32,
31 pub enable_cycle_detection: bool,
32}
33
34impl Default for RecursionConfig {
35 fn default() -> Self {
36 Self {
37 strategy: RecursionStrategy::DepthFirst,
38 max_depth: 25,
39 enable_cycle_detection: true,
40 }
41 }
42}
43
44impl RecursionConfig {
45 pub fn depth_first() -> Self {
47 Self::default()
48 }
49
50 pub fn breadth_first() -> Self {
52 Self {
53 strategy: RecursionStrategy::BreadthFirst,
54 max_depth: 50,
55 enable_cycle_detection: true,
56 }
57 }
58
59 pub fn max_depth(mut self, depth: u32) -> Self {
61 self.max_depth = depth;
62 self
63 }
64
65 pub fn cycle_detection(mut self, enabled: bool) -> Self {
67 self.enable_cycle_detection = enabled;
68 self
69 }
70
71 pub fn strategy(mut self, strategy: RecursionStrategy) -> Self {
73 self.strategy = strategy;
74 self
75 }
76}
77
/// Consistency requirement attached to a check request.
///
/// `FullyConsistent` is the `Default` variant.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum Consistency {
    /// Fully consistent evaluation (the default).
    #[default]
    FullyConsistent,
    /// Evaluate at data at least as fresh as the given token.
    ///
    /// NOTE(review): the `String` is presumably a revision/snapshot token —
    /// confirm its format with the datastore layer.
    AtLeastAsFresh(String),
    /// Prefer lower latency over freshness.
    MinimizeLatency,
}
86
/// Counters accumulated while a check is being resolved.
///
/// Each counter sits behind an `Arc<AtomicU32>`. Cloning a `ResolverMetadata`
/// therefore produces another handle to the *same* counters, not a snapshot
/// of their values. `Default` starts every counter at zero.
#[derive(Clone, Debug, Default)]
pub struct ResolverMetadata {
    /// Dispatch counter. NOTE(review): incremented outside this file —
    /// presumably once per sub-check dispatch; confirm.
    pub dispatch_count: Arc<AtomicU32>,
    /// Datastore query counter.
    pub datastore_queries: Arc<AtomicU32>,
    /// Cache hit counter.
    pub cache_hits: Arc<AtomicU32>,
    /// Deepest recursion level recorded. NOTE(review): nothing in this file
    /// writes it — confirm how callers maintain it.
    pub max_depth_reached: Arc<AtomicU32>,
}

impl ResolverMetadata {
    /// Reads one counter with `Relaxed` ordering.
    ///
    /// Each counter is an independent statistic, so no ordering with other
    /// memory operations is needed. A set of several reads is not a consistent
    /// snapshot under concurrent updates.
    fn load_relaxed(counter: &AtomicU32) -> u32 {
        counter.load(Ordering::Relaxed)
    }

    /// Current value of the `dispatch_count` counter.
    pub fn get_dispatch_count(&self) -> u32 {
        Self::load_relaxed(&self.dispatch_count)
    }

    /// Current value of the `datastore_queries` counter.
    pub fn get_datastore_queries(&self) -> u32 {
        Self::load_relaxed(&self.datastore_queries)
    }

    /// Current value of the `cache_hits` counter.
    pub fn get_cache_hits(&self) -> u32 {
        Self::load_relaxed(&self.cache_hits)
    }

    /// Current value of the `max_depth_reached` counter.
    pub fn get_max_depth_reached(&self) -> u32 {
        Self::load_relaxed(&self.max_depth_reached)
    }
}
131
132#[derive(Debug, Clone)]
134pub struct ResolveCheckRequest {
135 pub object_type: String,
136 pub object_id: String,
137 pub relation: String,
138 pub subject_type: String,
139 pub subject_id: String,
140 pub contextual_tuples: Vec<Tuple>,
141 pub depth_remaining: u32,
142 pub consistency: Consistency,
143 pub metadata: ResolverMetadata,
144 pub recursion_config: RecursionConfig,
145 pub visited: Vec<(String, String, String)>,
147 pub context: HashMap<String, serde_json::Value>,
149 pub at_revision: String,
154}
155
156impl ResolveCheckRequest {
157 pub fn new(
159 object_type: String,
160 object_id: String,
161 relation: String,
162 subject_type: String,
163 subject_id: String,
164 ) -> Self {
165 Self {
166 object_type,
167 object_id,
168 relation,
169 subject_type,
170 subject_id,
171 contextual_tuples: Vec::new(),
172 depth_remaining: 25, consistency: Consistency::default(),
174 metadata: ResolverMetadata::default(),
175 recursion_config: RecursionConfig::default(),
176 visited: Vec::new(),
177 context: HashMap::new(),
178 at_revision: String::new(),
179 }
180 }
181
182 pub fn with_config(
184 object_type: String,
185 object_id: String,
186 relation: String,
187 subject_type: String,
188 subject_id: String,
189 recursion_config: RecursionConfig,
190 ) -> Self {
191 Self {
192 object_type,
193 object_id,
194 relation,
195 subject_type,
196 subject_id,
197 contextual_tuples: Vec::new(),
198 depth_remaining: recursion_config.max_depth,
199 consistency: Consistency::default(),
200 metadata: ResolverMetadata::default(),
201 recursion_config,
202 visited: Vec::new(),
203 context: HashMap::new(),
204 at_revision: String::new(),
205 }
206 }
207
208 pub fn child_request(
213 &self,
214 object_type: String,
215 object_id: String,
216 relation: String,
217 subject_type: String,
218 subject_id: String,
219 ) -> Self {
220 Self {
221 object_type,
222 object_id,
223 relation,
224 subject_type,
225 subject_id,
226 contextual_tuples: self.contextual_tuples.clone(),
227 depth_remaining: self.depth_remaining.saturating_sub(1),
228 consistency: self.consistency.clone(),
229 metadata: self.metadata.clone(), recursion_config: self.recursion_config.clone(),
231 visited: self.visited.clone(),
232 context: self.context.clone(),
233 at_revision: self.at_revision.clone(),
234 }
235 }
236}
237
238#[derive(Debug, Clone, PartialEq, Eq)]
240pub struct ExpandNode {
241 pub object_type: String,
242 pub object_id: String,
243 pub relation: String,
244 pub result: CheckResult,
245 pub children: Vec<ExpandNode>,
246}
247
248#[async_trait]
250pub trait CheckResolver: Send + Sync {
251 async fn resolve_check(&self, request: ResolveCheckRequest) -> Result<CheckResult, AuthzError>;
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `document:readme#viewer@user:alice` request with `config`.
    fn sample_request(config: RecursionConfig) -> ResolveCheckRequest {
        ResolveCheckRequest::with_config(
            "document".into(),
            "readme".into(),
            "viewer".into(),
            "user".into(),
            "alice".into(),
            config,
        )
    }

    #[test]
    fn check_result_variants() {
        // Exercise the derived equality instead of only constructing values.
        assert_eq!(CheckResult::Allowed, CheckResult::Allowed);
        assert_ne!(CheckResult::Allowed, CheckResult::Denied);
        assert_eq!(
            CheckResult::ConditionRequired(vec!["param".into()]),
            CheckResult::ConditionRequired(vec!["param".to_string()]),
        );
        assert_ne!(
            CheckResult::ConditionRequired(vec!["a".into()]),
            CheckResult::ConditionRequired(vec!["b".into()]),
        );
    }

    #[test]
    fn recursion_config_defaults_and_builders() {
        let defaults = RecursionConfig::default();
        assert_eq!(defaults.strategy, RecursionStrategy::DepthFirst);
        assert_eq!(defaults.max_depth, 25);
        assert!(defaults.enable_cycle_detection);

        let depth_first = RecursionConfig::depth_first();
        assert_eq!(depth_first.strategy, defaults.strategy);
        assert_eq!(depth_first.max_depth, defaults.max_depth);

        let breadth_first = RecursionConfig::breadth_first();
        assert_eq!(breadth_first.strategy, RecursionStrategy::BreadthFirst);
        assert_eq!(breadth_first.max_depth, 50);
        assert!(breadth_first.enable_cycle_detection);

        let custom = RecursionConfig::depth_first()
            .max_depth(3)
            .cycle_detection(false)
            .strategy(RecursionStrategy::BreadthFirst);
        assert_eq!(custom.max_depth, 3);
        assert!(!custom.enable_cycle_detection);
        assert_eq!(custom.strategy, RecursionStrategy::BreadthFirst);
    }

    #[test]
    fn new_request_starts_from_default_config() {
        let req = ResolveCheckRequest::new(
            "document".into(),
            "readme".into(),
            "viewer".into(),
            "user".into(),
            "alice".into(),
        );
        assert_eq!(req.depth_remaining, RecursionConfig::default().max_depth);
        assert_eq!(req.recursion_config.max_depth, req.depth_remaining);
        assert_eq!(req.consistency, Consistency::FullyConsistent);
        assert!(req.contextual_tuples.is_empty());
        assert!(req.visited.is_empty());
        assert!(req.context.is_empty());
        assert!(req.at_revision.is_empty());
        assert_eq!(req.metadata.get_dispatch_count(), 0);
    }

    #[test]
    fn with_config_seeds_depth_from_max_depth() {
        let req = sample_request(RecursionConfig::breadth_first().max_depth(7));
        assert_eq!(req.depth_remaining, 7);
        assert_eq!(req.recursion_config.strategy, RecursionStrategy::BreadthFirst);
    }

    #[test]
    fn child_request_inherits_state_and_shares_metadata() {
        let mut parent = sample_request(RecursionConfig::default().max_depth(1));
        parent
            .visited
            .push(("document".into(), "readme".into(), "viewer".into()));
        parent.at_revision = "rev-1".into();

        let child = parent.child_request(
            "group".into(),
            "eng".into(),
            "member".into(),
            "user".into(),
            "alice".into(),
        );
        assert_eq!(child.object_type, "group");
        assert_eq!(child.relation, "member");
        assert_eq!(child.depth_remaining, 0);
        assert_eq!(child.visited, parent.visited);
        assert_eq!(child.at_revision, "rev-1");

        // The depth budget saturates at zero instead of underflowing.
        let grandchild = child.child_request(
            "group".into(),
            "root".into(),
            "member".into(),
            "user".into(),
            "alice".into(),
        );
        assert_eq!(grandchild.depth_remaining, 0);

        // Metadata counters are Arc-shared, so a child's update is visible on the parent.
        grandchild
            .metadata
            .cache_hits
            .fetch_add(2, std::sync::atomic::Ordering::Relaxed);
        assert_eq!(parent.metadata.get_cache_hits(), 2);
    }
}