1use crate::abi::{CapabilityMask, InstanceId, Principal};
14use core::marker::PhantomData;
15
16use super::op::Op;
17
/// Zero-sized marker that makes the lifetime `'i` invariant.
///
/// `fn(&'i ()) -> &'i ()` mentions `'i` both in argument (contravariant) and
/// return (covariant) position, so the compiler can neither shorten nor
/// lengthen `'i` when it sees a type containing this marker.
pub(crate) type InvariantLifetime<'i> = PhantomData<fn(&'i ()) -> &'i ()>;
21
/// Private module holding the `Sealed` supertrait of `AuthState`.
///
/// Because `seal` is not public, code outside this module cannot name
/// `seal::Sealed` and therefore cannot implement `AuthState`.
mod seal {
    /// Supertrait used only to close the set of `AuthState` implementors.
    pub trait Sealed {}
}
25
/// Typestate marker: the [`Effect`] has not yet been checked by `authorize`.
///
/// The enum has no variants, so no value of this type can exist; it is used
/// only as the `S` parameter of [`Effect`].
#[derive(Debug)]
pub enum Unverified {}
29
/// Typestate marker: the [`Effect`] has passed `authorize`.
///
/// Like `Unverified`, this enum has no variants and exists only at the type
/// level as the `S` parameter of [`Effect`].
#[derive(Debug)]
pub enum Authorized {}
33
// The two typestate markers are the only sealed types; since `seal` is private,
// code outside this module cannot add more.
impl seal::Sealed for Unverified {}
impl seal::Sealed for Authorized {}
36
/// Authorization state of an [`Effect`], implemented by [`Unverified`] and
/// [`Authorized`].
///
/// The private `seal::Sealed` supertrait prevents implementations outside
/// this module, so the set of states is closed.
pub trait AuthState: seal::Sealed {}

impl AuthState for Unverified {}
impl AuthState for Authorized {}
41
/// An operation requested by a principal for an instance, tagged at the type
/// level with its authorization state `S`.
///
/// Because `_state` and `_brand` are private, only this module can build an
/// `Effect`: `Effect::new` produces `Unverified` effects and `authorize`
/// re-tags them as `Authorized`.
#[derive(Debug)]
pub struct Effect<'i, S: AuthState> {
    /// Instance id carried by this effect.
    pub(crate) instance_id: InstanceId,
    /// Who requested the effect; `authorize` decides based on this.
    pub(crate) principal: Principal,
    /// The requested operation.
    pub(crate) op: Op,
    /// Zero-sized typestate marker (`Unverified` or `Authorized`).
    _state: PhantomData<S>,
    /// Zero-sized invariant lifetime brand.
    /// NOTE(review): `new` leaves `'i` unconstrained, so the brand only
    /// separates effects when callers tie `'i` to something — confirm intent.
    _brand: InvariantLifetime<'i>,
}
55
56impl<'i> Effect<'i, Unverified> {
57 pub(crate) fn new(instance_id: InstanceId, principal: Principal, op: Op) -> Self {
60 Self {
61 instance_id,
62 principal,
63 op,
64 _state: PhantomData,
65 _brand: PhantomData,
66 }
67 }
68}
69
70impl<'i, S: AuthState> Effect<'i, S> {
71 pub fn instance_id(&self) -> InstanceId {
74 self.instance_id
75 }
76}
77
/// Why an [`Effect`] was refused authorization.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DenyReason {
    /// The principal may not perform the operation. Returned by `authorize`
    /// for unauthenticated principals and for external principals whose
    /// capability mask does not permit the op.
    CapabilityDenied,
    /// Presumably: the effect's instance did not match the expected one.
    /// NOTE(review): not produced by `authorize` in this module.
    InstanceMismatch,
    /// Presumably: the operation itself is restricted.
    /// NOTE(review): not produced by `authorize` in this module.
    OperationRestricted,
    /// Displayed as "authorize not implemented"; looks like a placeholder from
    /// before `authorize` had real logic — confirm whether still needed.
    NotImplemented,
}
92
93impl core::fmt::Display for DenyReason {
94 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
95 match self {
96 Self::CapabilityDenied => write!(f, "capability denied"),
97 Self::InstanceMismatch => write!(f, "instance mismatch"),
98 Self::OperationRestricted => write!(f, "operation restricted"),
99 Self::NotImplemented => write!(f, "authorize not implemented"),
100 }
101 }
102}
103
/// `DenyReason` wraps no underlying cause, so the default `source()` (none) is
/// kept; the message comes from the `Display` impl.
impl std::error::Error for DenyReason {}
105
106pub(crate) fn authorize<'i>(
114 caps: CapabilityMask,
115 effect: Effect<'i, Unverified>,
116) -> Result<Effect<'i, Authorized>, DenyReason> {
117 match &effect.principal {
118 Principal::System => { }
119 Principal::Unauthenticated => return Err(DenyReason::CapabilityDenied),
120 Principal::External(_) => {
121 if !caps.contains(CapabilityMask::SYSTEM) && !match_op_cap(&effect.op, caps) {
122 return Err(DenyReason::CapabilityDenied);
123 }
124 }
125 }
126 Ok(Effect {
127 instance_id: effect.instance_id,
128 principal: effect.principal,
129 op: effect.op,
130 _state: PhantomData,
131 _brand: PhantomData,
132 })
133}
134
135fn match_op_cap(op: &Op, caps: CapabilityMask) -> bool {
138 match op {
139 Op::SpawnEntity { .. }
141 | Op::DespawnEntity { .. }
142 | Op::SetComponent { .. }
143 | Op::RemoveComponent { .. }
144 | Op::EmitEvent { .. } => true,
145 Op::ScheduleAction { .. } | Op::SendSignal { .. } => caps.contains(CapabilityMask::SYSTEM),
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::abi::{EntityId, ExternalId, RouteId, Tick, TypeCode};
    use bytes::Bytes;

    fn inst() -> InstanceId {
        InstanceId::new(1).unwrap()
    }
    fn ent() -> EntityId {
        EntityId::new(1).unwrap()
    }

    // Compile-time proof: `match x {}` type-checks only for an uninhabited
    // type, so there is nothing to execute at runtime.
    #[test]
    fn unverified_is_uninhabited() {
        fn _proof(x: Unverified) -> ! {
            match x {}
        }
    }

    // Compile-time proof, as above.
    #[test]
    fn authorized_is_uninhabited() {
        fn _proof(x: Authorized) -> ! {
            match x {}
        }
    }

    #[test]
    fn effect_carries_instance_id_and_op() {
        let e: Effect<'_, Unverified> = Effect::new(
            inst(),
            Principal::System,
            Op::SpawnEntity {
                id: ent(),
                owner: Principal::System,
            },
        );
        assert_eq!(e.instance_id().get(), 1);
        assert!(matches!(e.op, Op::SpawnEntity { .. }));
    }

    // Positive check only: that the seal rejects foreign impls can only be
    // verified with a compile-fail test, not a runtime one.
    #[test]
    fn marker_states_implement_auth_state() {
        fn assert_authstate<T: AuthState>() {}
        assert_authstate::<Unverified>();
        assert_authstate::<Authorized>();
    }

    #[test]
    fn deny_reason_display_and_error() {
        assert_eq!(DenyReason::CapabilityDenied.to_string(), "capability denied");
        assert_eq!(DenyReason::InstanceMismatch.to_string(), "instance mismatch");
        assert_eq!(
            DenyReason::OperationRestricted.to_string(),
            "operation restricted"
        );
        assert_eq!(
            DenyReason::NotImplemented.to_string(),
            "authorize not implemented"
        );
        fn assert_err<E: std::error::Error>() {}
        assert_err::<DenyReason>();
    }

    fn spawn_op() -> Op {
        Op::SpawnEntity {
            id: ent(),
            owner: Principal::System,
        }
    }
    fn schedule_op() -> Op {
        Op::ScheduleAction {
            at: Tick(0),
            actor: None,
            action_type_code: TypeCode(0),
            action_bytes: Bytes::new(),
            action_principal: Principal::System,
        }
    }
    fn signal_op() -> Op {
        Op::SendSignal {
            target: inst(),
            route: RouteId(1),
            payload: Bytes::new(),
        }
    }

    #[test]
    fn system_principal_authorized_for_all_ops() {
        for op in [spawn_op(), schedule_op(), signal_op()] {
            let e = Effect::new(inst(), Principal::System, op);
            let result = authorize(CapabilityMask::default(), e);
            assert!(result.is_ok(), "System principal must always pass");
        }
    }

    #[test]
    fn unauthenticated_principal_always_denied() {
        // Even the SYSTEM capability must not rescue an unauthenticated caller.
        for op in [spawn_op(), schedule_op(), signal_op()] {
            let e = Effect::new(inst(), Principal::Unauthenticated, op);
            let result = authorize(CapabilityMask::SYSTEM, e);
            assert_eq!(result.unwrap_err(), DenyReason::CapabilityDenied);
        }
        let e = Effect::new(inst(), Principal::Unauthenticated, spawn_op());
        let result = authorize(CapabilityMask::default(), e);
        assert_eq!(result.unwrap_err(), DenyReason::CapabilityDenied);
    }

    #[test]
    fn external_with_system_cap_authorized() {
        let e = Effect::new(inst(), Principal::External(ExternalId(7)), schedule_op());
        let result = authorize(CapabilityMask::SYSTEM, e);
        assert!(result.is_ok());
    }

    #[test]
    fn external_with_system_cap_authorized_for_send_signal() {
        let e = Effect::new(inst(), Principal::External(ExternalId(7)), signal_op());
        let result = authorize(CapabilityMask::SYSTEM, e);
        assert!(result.is_ok());
    }

    #[test]
    fn external_without_cap_denied_for_schedule() {
        let e = Effect::new(inst(), Principal::External(ExternalId(7)), schedule_op());
        let result = authorize(CapabilityMask::default(), e);
        assert_eq!(result.unwrap_err(), DenyReason::CapabilityDenied);
    }

    #[test]
    fn external_without_cap_denied_for_send_signal() {
        let e = Effect::new(inst(), Principal::External(ExternalId(7)), signal_op());
        let result = authorize(CapabilityMask::default(), e);
        assert_eq!(result.unwrap_err(), DenyReason::CapabilityDenied);
    }

    // Formerly named `external_with_basic_cap_...`, but no capability is
    // granted here: the mask is empty.
    #[test]
    fn external_without_caps_authorized_for_state_op() {
        let e = Effect::new(inst(), Principal::External(ExternalId(7)), spawn_op());
        let result = authorize(CapabilityMask::default(), e);
        assert!(
            result.is_ok(),
            "External with basic state op (no SYSTEM) is allowed in MVP"
        );
    }

    #[test]
    fn authorize_preserves_effect_contents() {
        let e = Effect::new(inst(), Principal::External(ExternalId(7)), spawn_op());
        let authorized = authorize(CapabilityMask::default(), e)
            .expect("external state op is allowed without caps");
        assert_eq!(authorized.instance_id().get(), 1);
        assert!(matches!(authorized.principal, Principal::External(_)));
        assert!(matches!(authorized.op, Op::SpawnEntity { .. }));
    }
}