1use crate::{
7 access::{self, AccessError, metrics::AccessMetrics},
8 cdk::types::Principal,
9 ids::{AccessMetricKind, EndpointCall},
10 log,
11 log::Topic,
12};
13use async_trait::async_trait;
14use std::{future::Future, pin::Pin, sync::Arc};
15
/// Per-call input to access evaluation.
#[derive(Clone, Debug)]
pub struct AccessContext {
    /// Principal that made the call; handed to the caller-based builtin checks.
    pub caller: Principal,
    /// Endpoint being invoked; used as the label when access-failure metrics are recorded.
    pub call: EndpointCall,
}
25
/// App-mode guard that can be checked synchronously with `eval_default_app_guard`.
///
/// Derives `Copy`/`Eq`/`Debug` so callers can pass, compare and log guards freely
/// (it is a fieldless enum, so all of these are trivially correct).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DefaultAppGuard {
    /// Checked with `access::guard::guard_app_update`; failures are attributed to
    /// `BuiltinPredicate::AppAllowsUpdates`.
    AllowsUpdates,
    /// Checked with `access::guard::guard_app_query`; failures are attributed to
    /// `BuiltinPredicate::AppIsQueryable`.
    IsQueryable,
}
35
/// Boolean access-rule expression evaluated by `eval_access`.
///
/// Build values with the helpers in this module (`all`, `any`, `not`, `requires`,
/// `custom`, and the `app` / `caller` / `env` / `auth` / `rule` submodules).
#[derive(Clone)]
pub enum AccessExpr {
    /// Passes only if every child passes (evaluated in order, stopping at the first
    /// denial). An empty list is always denied.
    All(Vec<Self>),
    /// Passes if at least one child passes (tried in order, stopping at the first
    /// success). An empty list is always denied.
    Any(Vec<Self>),
    /// Negation: denied whenever the inner expression passes.
    Not(Box<Self>),
    /// A single builtin or custom predicate.
    Pred(AccessPredicate),
}
47
/// Leaf predicate of an `AccessExpr`.
#[derive(Clone)]
pub enum AccessPredicate {
    /// One of the built-in checks enumerated by `BuiltinPredicate`.
    Builtin(BuiltinPredicate),
    /// User-supplied asynchronous check. Held behind an `Arc`, so cloning an
    /// expression only bumps a reference count.
    Custom(Arc<dyn AsyncAccessPredicate>),
}
57
/// Built-in access checks, each dispatched to a function in the `access` module.
///
/// The enum is fieldless, so equality and hashing are derived as well, which lets
/// callers compare predicates or use them as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BuiltinPredicate {
    /// Checked by `access::guard::guard_app_update`.
    AppAllowsUpdates,
    /// Checked by `access::guard::guard_app_query`.
    AppIsQueryable,
    /// Checked by `access::env::is_prime_subnet`.
    SelfIsPrimeSubnet,
    /// Checked by `access::env::is_prime_root`.
    SelfIsPrimeRoot,
    /// Checked by `access::auth::is_controller` against the caller principal.
    CallerIsController,
    /// Checked by `access::auth::is_parent` against the caller principal.
    CallerIsParent,
    /// Checked by `access::auth::is_child` against the caller principal.
    CallerIsChild,
    /// Checked by `access::auth::caller_is_root` against the caller principal.
    CallerIsRoot,
    /// Checked by `access::auth::is_same_canister` against the caller principal.
    CallerIsSameCanister,
    /// Checked by `access::auth::is_registered_to_subnet` against the caller principal.
    CallerIsRegisteredToSubnet,
    /// Checked by `access::auth::is_whitelisted` against the caller principal.
    CallerIsWhitelisted,
    /// Checked by `access::auth::verify_delegated_token`.
    DelegatedTokenValid,
    /// Checked by `access::rule::build_network_ic`.
    BuildIcOnly,
    /// Checked by `access::rule::build_network_local`.
    BuildLocalOnly,
}
79
/// User-defined asynchronous access check, plugged into an expression with `custom`.
///
/// Implementors must be `Send + Sync`: they are shared through an `Arc` and awaited
/// inside the `Send` boxed future built by the recursive evaluator.
#[async_trait]
pub trait AsyncAccessPredicate: Send + Sync {
    /// Checks the predicate for the call described by `ctx`.
    ///
    /// # Errors
    /// Return `Err` to deny. When this failure is the one reported for the whole
    /// expression, it is recorded under `AccessMetricKind::Custom` with `name()` as
    /// the predicate label, and the error is returned to the caller.
    async fn eval(&self, ctx: &AccessContext) -> Result<(), AccessError>;

    /// Identifier used as the predicate label in access-failure metrics and logs.
    fn name(&self) -> &'static str;
}
91
92pub fn all<I>(exprs: I) -> AccessExpr
93where
94 I: IntoIterator<Item = AccessExpr>,
95{
96 AccessExpr::All(exprs.into_iter().collect())
97}
98
99pub fn any<I>(exprs: I) -> AccessExpr
100where
101 I: IntoIterator<Item = AccessExpr>,
102{
103 AccessExpr::Any(exprs.into_iter().collect())
104}
105
106#[must_use]
107pub fn not(expr: AccessExpr) -> AccessExpr {
108 AccessExpr::Not(Box::new(expr))
109}
110
111pub fn requires<I>(exprs: I) -> AccessExpr
112where
113 I: IntoIterator<Item = AccessExpr>,
114{
115 all(exprs)
116}
117
118pub fn custom<P>(pred: P) -> AccessExpr
119where
120 P: AsyncAccessPredicate + 'static,
121{
122 AccessExpr::Pred(AccessPredicate::Custom(Arc::new(pred)))
123}
124
125pub mod app {
126 use super::{AccessExpr, BuiltinPredicate, builtin};
127
128 #[must_use]
129 pub const fn allows_updates() -> AccessExpr {
130 builtin(BuiltinPredicate::AppAllowsUpdates)
131 }
132
133 #[must_use]
134 pub const fn is_queryable() -> AccessExpr {
135 builtin(BuiltinPredicate::AppIsQueryable)
136 }
137}
138
139pub mod caller {
140 use super::{AccessExpr, BuiltinPredicate, builtin};
141
142 #[must_use]
143 pub const fn is_controller() -> AccessExpr {
144 builtin(BuiltinPredicate::CallerIsController)
145 }
146
147 #[must_use]
148 pub const fn is_parent() -> AccessExpr {
149 builtin(BuiltinPredicate::CallerIsParent)
150 }
151
152 #[must_use]
153 pub const fn is_child() -> AccessExpr {
154 builtin(BuiltinPredicate::CallerIsChild)
155 }
156
157 #[must_use]
158 pub const fn is_root() -> AccessExpr {
159 builtin(BuiltinPredicate::CallerIsRoot)
160 }
161
162 #[must_use]
163 pub const fn is_same_canister() -> AccessExpr {
164 builtin(BuiltinPredicate::CallerIsSameCanister)
165 }
166
167 #[must_use]
168 pub const fn is_registered_to_subnet() -> AccessExpr {
169 builtin(BuiltinPredicate::CallerIsRegisteredToSubnet)
170 }
171
172 #[must_use]
173 pub const fn is_whitelisted() -> AccessExpr {
174 builtin(BuiltinPredicate::CallerIsWhitelisted)
175 }
176}
177
178pub mod env {
179 use super::{AccessExpr, BuiltinPredicate, builtin};
180
181 #[must_use]
182 pub const fn is_prime_subnet() -> AccessExpr {
183 builtin(BuiltinPredicate::SelfIsPrimeSubnet)
184 }
185
186 #[must_use]
187 pub const fn is_prime_root() -> AccessExpr {
188 builtin(BuiltinPredicate::SelfIsPrimeRoot)
189 }
190}
191
192pub mod auth {
193 use super::{AccessExpr, BuiltinPredicate, builtin};
194
195 #[must_use]
196 pub const fn delegated_token_valid() -> AccessExpr {
197 builtin(BuiltinPredicate::DelegatedTokenValid)
198 }
199}
200
201pub mod rule {
202 use super::{AccessExpr, BuiltinPredicate, builtin};
203
204 #[must_use]
205 pub const fn build_ic_only() -> AccessExpr {
206 builtin(BuiltinPredicate::BuildIcOnly)
207 }
208
209 #[must_use]
210 pub const fn build_local_only() -> AccessExpr {
211 builtin(BuiltinPredicate::BuildLocalOnly)
212 }
213}
214
215#[allow(clippy::future_not_send)]
216pub async fn eval_access(expr: &AccessExpr, ctx: &AccessContext) -> Result<(), AccessError> {
217 match eval_access_inner(expr, ctx).await {
218 Ok(()) => Ok(()),
219 Err(failure) => Err(record_access_failure(ctx, failure)),
220 }
221}
222
/// Boxed `Send` future produced by `eval_access_inner`; boxing is needed because
/// that function calls itself recursively.
type AccessEvalFuture<'a> = Pin<Box<dyn Future<Output = Result<(), AccessFailure>> + Send + 'a>>;
224
225pub fn eval_default_app_guard(
226 guard: DefaultAppGuard,
227 ctx: &AccessContext,
228) -> Result<(), AccessError> {
229 let result = match guard {
230 DefaultAppGuard::AllowsUpdates => access::guard::guard_app_update(),
231 DefaultAppGuard::IsQueryable => access::guard::guard_app_query(),
232 };
233
234 match result {
235 Ok(()) => Ok(()),
236 Err(err) => {
237 let predicate = match guard {
238 DefaultAppGuard::AllowsUpdates => BuiltinPredicate::AppAllowsUpdates,
239 DefaultAppGuard::IsQueryable => BuiltinPredicate::AppIsQueryable,
240 };
241 Err(record_access_failure(
242 ctx,
243 AccessFailure::from_builtin(predicate, err),
244 ))
245 }
246 }
247}
248
249fn eval_access_inner<'a>(expr: &'a AccessExpr, ctx: &'a AccessContext) -> AccessEvalFuture<'a> {
250 Box::pin(async move {
251 match expr {
252 AccessExpr::All(exprs) => {
253 if exprs.is_empty() {
254 return Err(AccessFailure::no_predicates("all"));
255 }
256 for expr in exprs {
257 if let Err(failure) = eval_access_inner(expr, ctx).await {
258 return Err(failure.with_context("all"));
259 }
260 }
261 Ok(())
262 }
263 AccessExpr::Any(exprs) => {
264 if exprs.is_empty() {
265 return Err(AccessFailure::no_predicates("any"));
266 }
267 let mut last = None;
268 for expr in exprs {
269 match eval_access_inner(expr, ctx).await {
270 Ok(()) => return Ok(()),
271 Err(failure) => last = Some(failure.with_context("any")),
272 }
273 }
274 Err(last.unwrap_or_else(|| AccessFailure::no_predicates("any")))
275 }
276 AccessExpr::Not(expr) => match eval_access_inner(expr, ctx).await {
277 Ok(()) => Err(AccessFailure::negated()),
278 Err(_) => Ok(()),
279 },
280 AccessExpr::Pred(pred) => match pred {
281 AccessPredicate::Builtin(builtin) => eval_builtin(builtin, ctx)
282 .await
283 .map_err(|err| AccessFailure::from_builtin(*builtin, err)),
284 AccessPredicate::Custom(custom) => custom
285 .eval(ctx)
286 .await
287 .map_err(|err| AccessFailure::from_custom(custom.name(), err)),
288 },
289 }
290 })
291}
292
293async fn eval_builtin(pred: &BuiltinPredicate, ctx: &AccessContext) -> Result<(), AccessError> {
294 match pred {
295 BuiltinPredicate::AppAllowsUpdates => access::guard::guard_app_update(),
296 BuiltinPredicate::AppIsQueryable => access::guard::guard_app_query(),
297 BuiltinPredicate::SelfIsPrimeSubnet => access::env::is_prime_subnet().await,
298 BuiltinPredicate::SelfIsPrimeRoot => access::env::is_prime_root().await,
299 BuiltinPredicate::CallerIsController => access::auth::is_controller(ctx.caller).await,
300 BuiltinPredicate::CallerIsParent => access::auth::is_parent(ctx.caller).await,
301 BuiltinPredicate::CallerIsChild => access::auth::is_child(ctx.caller).await,
302 BuiltinPredicate::CallerIsRoot => access::auth::caller_is_root(ctx.caller).await,
303 BuiltinPredicate::CallerIsSameCanister => access::auth::is_same_canister(ctx.caller).await,
304 BuiltinPredicate::CallerIsRegisteredToSubnet => {
305 access::auth::is_registered_to_subnet(ctx.caller).await
306 }
307 BuiltinPredicate::CallerIsWhitelisted => access::auth::is_whitelisted(ctx.caller).await,
308 BuiltinPredicate::DelegatedTokenValid => access::auth::verify_delegated_token().await,
309 BuiltinPredicate::BuildIcOnly => access::rule::build_network_ic().await,
310 BuiltinPredicate::BuildLocalOnly => access::rule::build_network_local().await,
311 }
312}
313
314const fn builtin(pred: BuiltinPredicate) -> AccessExpr {
315 AccessExpr::Pred(AccessPredicate::Builtin(pred))
316}
317
318#[derive(Debug)]
319struct AccessFailure {
320 error: AccessError,
321 metric_kind: AccessMetricKind,
322 predicate: &'static str,
323 context: Option<&'static str>,
324}
325
326impl AccessFailure {
327 const fn from_builtin(pred: BuiltinPredicate, error: AccessError) -> Self {
328 Self {
329 error,
330 metric_kind: pred.metric_kind(),
331 predicate: pred.name(),
332 context: None,
333 }
334 }
335
336 const fn from_custom(name: &'static str, error: AccessError) -> Self {
337 Self {
338 error,
339 metric_kind: AccessMetricKind::Custom,
340 predicate: name,
341 context: None,
342 }
343 }
344
345 fn no_predicates(context: &'static str) -> Self {
346 Self {
347 error: AccessError::Denied("one or more rules must be defined".to_string()),
348 metric_kind: AccessMetricKind::Auth,
349 predicate: "no_rules",
350 context: Some(context),
351 }
352 }
353
354 fn negated() -> Self {
355 Self {
356 error: AccessError::Denied("negated predicate matched".to_string()),
357 metric_kind: AccessMetricKind::Auth,
358 predicate: "not",
359 context: Some("not"),
360 }
361 }
362
363 fn with_context(mut self, context: &'static str) -> Self {
364 self.context.get_or_insert(context);
365 self
366 }
367}
368
369impl BuiltinPredicate {
370 const fn name(self) -> &'static str {
371 match self {
372 Self::AppAllowsUpdates => "app_allows_updates",
373 Self::AppIsQueryable => "app_is_queryable",
374 Self::SelfIsPrimeSubnet => "self_is_prime_subnet",
375 Self::SelfIsPrimeRoot => "self_is_prime_root",
376 Self::CallerIsController => "caller_is_controller",
377 Self::CallerIsParent => "caller_is_parent",
378 Self::CallerIsChild => "caller_is_child",
379 Self::CallerIsRoot => "caller_is_root",
380 Self::CallerIsSameCanister => "caller_is_same_canister",
381 Self::CallerIsRegisteredToSubnet => "caller_is_registered_to_subnet",
382 Self::CallerIsWhitelisted => "caller_is_whitelisted",
383 Self::DelegatedTokenValid => "delegated_token_valid",
384 Self::BuildIcOnly => "build_ic_only",
385 Self::BuildLocalOnly => "build_local_only",
386 }
387 }
388
389 const fn metric_kind(self) -> AccessMetricKind {
390 match self {
391 Self::AppAllowsUpdates | Self::AppIsQueryable => AccessMetricKind::Guard,
392 Self::SelfIsPrimeSubnet | Self::SelfIsPrimeRoot => AccessMetricKind::Env,
393 Self::BuildIcOnly | Self::BuildLocalOnly => AccessMetricKind::Rule,
394 Self::CallerIsController
395 | Self::CallerIsParent
396 | Self::CallerIsChild
397 | Self::CallerIsRoot
398 | Self::CallerIsSameCanister
399 | Self::CallerIsRegisteredToSubnet
400 | Self::CallerIsWhitelisted
401 | Self::DelegatedTokenValid => AccessMetricKind::Auth,
402 }
403 }
404}
405
406fn record_access_failure(ctx: &AccessContext, failure: AccessFailure) -> AccessError {
407 AccessMetrics::increment(ctx.call, failure.metric_kind, failure.predicate);
408 log!(
409 Topic::Auth,
410 Warn,
411 "access denied kind={} predicate={} context={:?}: {}",
412 failure.metric_kind.as_str(),
413 failure.predicate,
414 failure.context,
415 failure.error,
416 );
417 failure.error
418}