1use crate::{
7 access::{self, AccessError, metrics::AccessMetrics},
8 cdk::types::Principal,
9 ids::{AccessMetricKind, EndpointCall},
10 log,
11 log::Topic,
12};
13use async_trait::async_trait;
14use std::{future::Future, pin::Pin, sync::Arc};
15
16#[derive(Clone, Debug)]
21pub struct AccessContext {
22 pub caller: Principal,
23 pub call: EndpointCall,
24}
25
26#[derive(Clone)]
31pub enum AccessExpr {
32 All(Vec<Self>),
33 Any(Vec<Self>),
34 Not(Box<Self>),
35 Pred(AccessPredicate),
36}
37
38#[derive(Clone)]
43pub enum AccessPredicate {
44 Builtin(BuiltinPredicate),
45 Custom(Arc<dyn AsyncAccessPredicate>),
46}
47
/// Access checks implemented by this crate.
///
/// Each variant is dispatched by `eval_builtin` to the `crate::access`
/// helper named in its doc comment.
#[derive(Debug, Clone, Copy)]
pub enum BuiltinPredicate {
    /// `access::guard::guard_app_update`.
    AppAllowsUpdates,
    /// `access::guard::guard_app_query`.
    AppIsQueryable,
    /// `access::env::is_prime_subnet`.
    SelfIsPrimeSubnet,
    /// `access::env::is_prime_root`.
    SelfIsPrimeRoot,
    /// `access::auth::is_controller`, applied to the caller.
    CallerIsController,
    /// `access::auth::is_parent`, applied to the caller.
    CallerIsParent,
    /// `access::auth::is_child`, applied to the caller.
    CallerIsChild,
    /// `access::auth::caller_is_root`, applied to the caller.
    CallerIsRoot,
    /// `access::auth::is_same_canister`, applied to the caller.
    CallerIsSameCanister,
    /// `access::auth::is_registered_to_subnet`, applied to the caller.
    CallerIsRegisteredToSubnet,
    /// `access::auth::is_whitelisted`, applied to the caller.
    CallerIsWhitelisted,
    /// `access::auth::verify_delegated_token`.
    DelegatedTokenValid,
    /// `access::rule::build_network_ic`.
    BuildIcOnly,
    /// `access::rule::build_network_local`.
    BuildLocalOnly,
}
69
70#[async_trait]
71pub trait AsyncAccessPredicate: Send + Sync {
72 async fn eval(&self, ctx: &AccessContext) -> Result<(), AccessError>;
79 fn name(&self) -> &'static str;
80}
81
82pub fn all<I>(exprs: I) -> AccessExpr
83where
84 I: IntoIterator<Item = AccessExpr>,
85{
86 AccessExpr::All(exprs.into_iter().collect())
87}
88
89pub fn any<I>(exprs: I) -> AccessExpr
90where
91 I: IntoIterator<Item = AccessExpr>,
92{
93 AccessExpr::Any(exprs.into_iter().collect())
94}
95
96#[must_use]
97pub fn not(expr: AccessExpr) -> AccessExpr {
98 AccessExpr::Not(Box::new(expr))
99}
100
101pub fn requires<I>(exprs: I) -> AccessExpr
102where
103 I: IntoIterator<Item = AccessExpr>,
104{
105 all(exprs)
106}
107
108pub fn custom<P>(pred: P) -> AccessExpr
109where
110 P: AsyncAccessPredicate + 'static,
111{
112 AccessExpr::Pred(AccessPredicate::Custom(Arc::new(pred)))
113}
114
115pub mod app {
116 use super::{AccessExpr, BuiltinPredicate, builtin};
117
118 #[must_use]
119 pub const fn allows_updates() -> AccessExpr {
120 builtin(BuiltinPredicate::AppAllowsUpdates)
121 }
122
123 #[must_use]
124 pub const fn is_queryable() -> AccessExpr {
125 builtin(BuiltinPredicate::AppIsQueryable)
126 }
127}
128
129pub mod caller {
130 use super::{AccessExpr, BuiltinPredicate, builtin};
131
132 #[must_use]
133 pub const fn is_controller() -> AccessExpr {
134 builtin(BuiltinPredicate::CallerIsController)
135 }
136
137 #[must_use]
138 pub const fn is_parent() -> AccessExpr {
139 builtin(BuiltinPredicate::CallerIsParent)
140 }
141
142 #[must_use]
143 pub const fn is_child() -> AccessExpr {
144 builtin(BuiltinPredicate::CallerIsChild)
145 }
146
147 #[must_use]
148 pub const fn is_root() -> AccessExpr {
149 builtin(BuiltinPredicate::CallerIsRoot)
150 }
151
152 #[must_use]
153 pub const fn is_same_canister() -> AccessExpr {
154 builtin(BuiltinPredicate::CallerIsSameCanister)
155 }
156
157 #[must_use]
158 pub const fn is_registered_to_subnet() -> AccessExpr {
159 builtin(BuiltinPredicate::CallerIsRegisteredToSubnet)
160 }
161
162 #[must_use]
163 pub const fn is_whitelisted() -> AccessExpr {
164 builtin(BuiltinPredicate::CallerIsWhitelisted)
165 }
166}
167
168pub mod env {
169 use super::{AccessExpr, BuiltinPredicate, builtin};
170
171 #[must_use]
172 pub const fn is_prime_subnet() -> AccessExpr {
173 builtin(BuiltinPredicate::SelfIsPrimeSubnet)
174 }
175
176 #[must_use]
177 pub const fn is_prime_root() -> AccessExpr {
178 builtin(BuiltinPredicate::SelfIsPrimeRoot)
179 }
180}
181
182pub mod auth {
183 use super::{AccessExpr, BuiltinPredicate, builtin};
184
185 #[must_use]
186 pub const fn delegated_token_valid() -> AccessExpr {
187 builtin(BuiltinPredicate::DelegatedTokenValid)
188 }
189}
190
191pub mod rule {
192 use super::{AccessExpr, BuiltinPredicate, builtin};
193
194 #[must_use]
195 pub const fn build_ic_only() -> AccessExpr {
196 builtin(BuiltinPredicate::BuildIcOnly)
197 }
198
199 #[must_use]
200 pub const fn build_local_only() -> AccessExpr {
201 builtin(BuiltinPredicate::BuildLocalOnly)
202 }
203}
204
205#[allow(clippy::future_not_send)]
206pub async fn eval_access(expr: &AccessExpr, ctx: &AccessContext) -> Result<(), AccessError> {
207 match eval_access_inner(expr, ctx).await {
208 Ok(()) => Ok(()),
209 Err(failure) => {
210 AccessMetrics::increment(ctx.call, failure.metric_kind, failure.predicate);
211 log!(
212 Topic::Auth,
213 Warn,
214 "access denied kind={} predicate={} context={:?}: {}",
215 failure.metric_kind.as_str(),
216 failure.predicate,
217 failure.context,
218 failure.error,
219 );
220 Err(failure.error)
221 }
222 }
223}
224
225type AccessEvalFuture<'a> = Pin<Box<dyn Future<Output = Result<(), AccessFailure>> + Send + 'a>>;
226
227fn eval_access_inner<'a>(expr: &'a AccessExpr, ctx: &'a AccessContext) -> AccessEvalFuture<'a> {
228 Box::pin(async move {
229 match expr {
230 AccessExpr::All(exprs) => {
231 if exprs.is_empty() {
232 return Err(AccessFailure::no_predicates("all"));
233 }
234 for expr in exprs {
235 if let Err(failure) = eval_access_inner(expr, ctx).await {
236 return Err(failure.with_context("all"));
237 }
238 }
239 Ok(())
240 }
241 AccessExpr::Any(exprs) => {
242 if exprs.is_empty() {
243 return Err(AccessFailure::no_predicates("any"));
244 }
245 let mut last = None;
246 for expr in exprs {
247 match eval_access_inner(expr, ctx).await {
248 Ok(()) => return Ok(()),
249 Err(failure) => last = Some(failure.with_context("any")),
250 }
251 }
252 Err(last.unwrap_or_else(|| AccessFailure::no_predicates("any")))
253 }
254 AccessExpr::Not(expr) => match eval_access_inner(expr, ctx).await {
255 Ok(()) => Err(AccessFailure::negated()),
256 Err(_) => Ok(()),
257 },
258 AccessExpr::Pred(pred) => match pred {
259 AccessPredicate::Builtin(builtin) => eval_builtin(builtin, ctx)
260 .await
261 .map_err(|err| AccessFailure::from_builtin(*builtin, err)),
262 AccessPredicate::Custom(custom) => custom
263 .eval(ctx)
264 .await
265 .map_err(|err| AccessFailure::from_custom(custom.name(), err)),
266 },
267 }
268 })
269}
270
271async fn eval_builtin(pred: &BuiltinPredicate, ctx: &AccessContext) -> Result<(), AccessError> {
272 match pred {
273 BuiltinPredicate::AppAllowsUpdates => access::guard::guard_app_update(),
274 BuiltinPredicate::AppIsQueryable => access::guard::guard_app_query(),
275 BuiltinPredicate::SelfIsPrimeSubnet => access::env::is_prime_subnet().await,
276 BuiltinPredicate::SelfIsPrimeRoot => access::env::is_prime_root().await,
277 BuiltinPredicate::CallerIsController => access::auth::is_controller(ctx.caller).await,
278 BuiltinPredicate::CallerIsParent => access::auth::is_parent(ctx.caller).await,
279 BuiltinPredicate::CallerIsChild => access::auth::is_child(ctx.caller).await,
280 BuiltinPredicate::CallerIsRoot => access::auth::caller_is_root(ctx.caller).await,
281 BuiltinPredicate::CallerIsSameCanister => access::auth::is_same_canister(ctx.caller).await,
282 BuiltinPredicate::CallerIsRegisteredToSubnet => {
283 access::auth::is_registered_to_subnet(ctx.caller).await
284 }
285 BuiltinPredicate::CallerIsWhitelisted => access::auth::is_whitelisted(ctx.caller).await,
286 BuiltinPredicate::DelegatedTokenValid => access::auth::verify_delegated_token().await,
287 BuiltinPredicate::BuildIcOnly => access::rule::build_network_ic().await,
288 BuiltinPredicate::BuildLocalOnly => access::rule::build_network_local().await,
289 }
290}
291
292const fn builtin(pred: BuiltinPredicate) -> AccessExpr {
293 AccessExpr::Pred(AccessPredicate::Builtin(pred))
294}
295
296#[derive(Debug)]
297struct AccessFailure {
298 error: AccessError,
299 metric_kind: AccessMetricKind,
300 predicate: &'static str,
301 context: Option<&'static str>,
302}
303
304impl AccessFailure {
305 const fn from_builtin(pred: BuiltinPredicate, error: AccessError) -> Self {
306 Self {
307 error,
308 metric_kind: pred.metric_kind(),
309 predicate: pred.name(),
310 context: None,
311 }
312 }
313
314 const fn from_custom(name: &'static str, error: AccessError) -> Self {
315 Self {
316 error,
317 metric_kind: AccessMetricKind::Custom,
318 predicate: name,
319 context: None,
320 }
321 }
322
323 fn no_predicates(context: &'static str) -> Self {
324 Self {
325 error: AccessError::Denied("one or more rules must be defined".to_string()),
326 metric_kind: AccessMetricKind::Auth,
327 predicate: "no_rules",
328 context: Some(context),
329 }
330 }
331
332 fn negated() -> Self {
333 Self {
334 error: AccessError::Denied("negated predicate matched".to_string()),
335 metric_kind: AccessMetricKind::Auth,
336 predicate: "not",
337 context: Some("not"),
338 }
339 }
340
341 fn with_context(mut self, context: &'static str) -> Self {
342 self.context.get_or_insert(context);
343 self
344 }
345}
346
347impl BuiltinPredicate {
348 const fn name(self) -> &'static str {
349 match self {
350 Self::AppAllowsUpdates => "app_allows_updates",
351 Self::AppIsQueryable => "app_is_queryable",
352 Self::SelfIsPrimeSubnet => "self_is_prime_subnet",
353 Self::SelfIsPrimeRoot => "self_is_prime_root",
354 Self::CallerIsController => "caller_is_controller",
355 Self::CallerIsParent => "caller_is_parent",
356 Self::CallerIsChild => "caller_is_child",
357 Self::CallerIsRoot => "caller_is_root",
358 Self::CallerIsSameCanister => "caller_is_same_canister",
359 Self::CallerIsRegisteredToSubnet => "caller_is_registered_to_subnet",
360 Self::CallerIsWhitelisted => "caller_is_whitelisted",
361 Self::DelegatedTokenValid => "delegated_token_valid",
362 Self::BuildIcOnly => "build_ic_only",
363 Self::BuildLocalOnly => "build_local_only",
364 }
365 }
366
367 const fn metric_kind(self) -> AccessMetricKind {
368 match self {
369 Self::AppAllowsUpdates | Self::AppIsQueryable => AccessMetricKind::Guard,
370 Self::SelfIsPrimeSubnet | Self::SelfIsPrimeRoot => AccessMetricKind::Env,
371 Self::BuildIcOnly | Self::BuildLocalOnly => AccessMetricKind::Rule,
372 Self::CallerIsController
373 | Self::CallerIsParent
374 | Self::CallerIsChild
375 | Self::CallerIsRoot
376 | Self::CallerIsSameCanister
377 | Self::CallerIsRegisteredToSubnet
378 | Self::CallerIsWhitelisted
379 | Self::DelegatedTokenValid => AccessMetricKind::Auth,
380 }
381 }
382}