dwbase_security/
lib.rs

1//! Security primitives for DWBase: capabilities, trust, and gatekeeping.
2//!
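//! The gatekeeper enforces a single capability set (readable/writable worlds, writable kinds and labels, an importance cap) plus token-bucket rate limits: remembers are limited per worker, while all asks share one bucket.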
4
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex};
7use std::time::Instant;
8
9use dwbase_core::{AtomKind, Importance, WorldKey};
10use dwbase_engine::{DwbaseError, Gatekeeper, NewAtom, Question, Result, WorldAction};
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13
/// Capability set enforced by [`LocalGatekeeper`].
///
/// Every list below is an allow-list when non-empty and means "no restriction" when empty,
/// so `Capabilities::default()` permits everything (subject only to `rate_limits`).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Capabilities {
    /// Worlds that may be asked about.
    pub read_worlds: Vec<WorldKey>,
    /// Worlds that may receive new atoms or be created / archived / resumed.
    pub write_worlds: Vec<WorldKey>,
    /// Labels a new atom may carry; every label on the atom must be listed.
    pub labels_write: Vec<String>,
    /// Atom kinds that may be written.
    pub kinds_write: Vec<AtomKind>,
    /// Maximum (inclusive) importance a new atom may have; `None` disables the cap.
    pub importance_cap: Option<f32>,
    /// Token-bucket limits for remembers and asks.
    pub rate_limits: RateLimits,
    /// NOTE(review): not consulted by `LocalGatekeeper` in this file; presumably read elsewhere.
    pub offline_policy: OfflinePolicy,
}
24
/// Rate limits applied by [`LocalGatekeeper`]; `None` disables the corresponding limit.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct RateLimits {
    /// Max remembers per second, tracked separately for each worker.
    ///
    /// Burst capacity equals this rate, but is never less than one request.
    pub remember_per_sec: Option<f64>,
    /// Max asks per minute, shared by all callers through a single bucket.
    pub ask_per_min: Option<f64>,
}
32
/// Offline-operation policy; defaults to `Allow`.
///
/// NOTE(review): no code in this file reads this value. The meaning ("whether operations
/// are permitted while offline") is inferred from the name only; confirm against callers.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum OfflinePolicy {
    #[default]
    Allow,
    Deny,
}
39
/// Trust scores keyed by worker id and (presumably) node id.
///
/// NOTE(review): the score scale/range is not established here, and `LocalGatekeeper`
/// stores a `TrustStore` without consulting it yet.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TrustStore {
    pub worker_scores: HashMap<String, f32>,
    pub node_scores: HashMap<String, f32>,
}
45
46impl TrustStore {
47    pub fn get_worker(&self, worker: &str) -> Option<f32> {
48        self.worker_scores.get(worker).copied()
49    }
50}
51
/// Failures produced by [`LocalGatekeeper`]'s checks.
///
/// Both variants are converted to `DwbaseError::CapabilityDenied` for the engine, so the
/// display prefix ("capability denied" / "rate limited") is what distinguishes them there.
#[derive(Debug, Error)]
pub enum SecurityError {
    /// A capability (world, kind, label or importance) forbade the operation.
    #[error("capability denied: {0}")]
    Capability(String),
    /// A token bucket had no token available for the request.
    #[error("rate limited: {0}")]
    RateLimited(String),
}
59
60impl From<SecurityError> for DwbaseError {
61    fn from(e: SecurityError) -> Self {
62        DwbaseError::CapabilityDenied(e.to_string())
63    }
64}
65
/// Source of "now" for rate limiting, abstracted so tests can drive time deterministically.
trait Clock: Sync + Send {
    /// Returns the current instant as seen by this clock.
    fn now(&self) -> Instant;
}

/// Production clock that simply defers to the monotonic system clock.
#[derive(Default)]
struct SystemClock;

impl Clock for SystemClock {
    fn now(&self) -> Instant {
        // Monotonic, so successive readings never go backwards.
        Instant::now()
    }
}
77
/// Token bucket: holds up to `capacity` tokens, refills continuously at `refill_per_sec`,
/// and each admitted request spends one token.
#[derive(Debug)]
struct TokenBucket {
    /// Maximum tokens held (burst size); never below 1.
    capacity: f64,
    /// Tokens currently available.
    tokens: f64,
    /// Latest instant for which refill has been credited; never moves backwards.
    last_refill: Instant,
    /// Tokens added per second of elapsed time; never negative or NaN.
    refill_per_sec: f64,
}

impl TokenBucket {
    /// Creates a full bucket that refills at `rate_per_sec`.
    ///
    /// Burst capacity equals the rate but is at least one token, so rates below 1/s still
    /// admit one request up front. NaN or negative rates are treated as zero refill (one
    /// request, then deny) instead of silently disabling the limit.
    fn new(rate_per_sec: f64, now: Instant) -> Self {
        // `f64::max` returns the non-NaN operand, so NaN and negatives both become 0.
        let refill_per_sec = rate_per_sec.max(0.0);
        let capacity = refill_per_sec.max(1.0);
        Self {
            capacity,
            tokens: capacity,
            last_refill: now,
            refill_per_sec,
        }
    }

    /// Credits refill for time elapsed since the last credited instant, then tries to spend
    /// one token. Returns `true` if the request is admitted.
    ///
    /// An instant earlier than the last credited one (e.g. timestamps read by racing
    /// threads) credits nothing and does not rewind `last_refill`, so no interval is ever
    /// credited twice.
    fn allow(&mut self, now: Instant) -> bool {
        if now > self.last_refill {
            let elapsed = now.duration_since(self.last_refill).as_secs_f64();
            self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
            self.last_refill = now;
        }
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
109
/// [`Gatekeeper`] that enforces a fixed [`Capabilities`] set plus in-process token-bucket
/// rate limits.
pub struct LocalGatekeeper {
    caps: Capabilities,
    // Not read by any check in this file, hence the allowance; presumably reserved for
    // future trust-based decisions.
    #[allow(dead_code)]
    trust: TrustStore,
    /// Time source for rate limiting; injectable so tests can control time.
    clock: Arc<dyn Clock>,
    /// Remember buckets, keyed by worker id.
    remember_buckets: Mutex<HashMap<String, TokenBucket>>,
    /// Ask buckets; currently only the shared "ask-global" key is used.
    ask_buckets: Mutex<HashMap<String, TokenBucket>>,
}
118
119impl LocalGatekeeper {
120    pub fn new(caps: Capabilities, trust: TrustStore) -> Self {
121        Self::with_clock(caps, trust, Arc::new(SystemClock))
122    }
123
124    fn with_clock(caps: Capabilities, trust: TrustStore, clock: Arc<dyn Clock>) -> Self {
125        Self {
126            caps,
127            trust,
128            clock,
129            remember_buckets: Mutex::new(HashMap::new()),
130            ask_buckets: Mutex::new(HashMap::new()),
131        }
132    }
133
134    fn ensure_world_can_write(&self, world: &WorldKey) -> Result<()> {
135        if !self.caps.write_worlds.is_empty() && !self.caps.write_worlds.contains(world) {
136            return Err(SecurityError::Capability(format!(
137                "write not allowed for world {}",
138                world.0
139            ))
140            .into());
141        }
142        Ok(())
143    }
144
145    fn ensure_world_can_read(&self, world: &WorldKey) -> Result<()> {
146        if !self.caps.read_worlds.is_empty() && !self.caps.read_worlds.contains(world) {
147            return Err(SecurityError::Capability(format!(
148                "read not allowed for world {}",
149                world.0
150            ))
151            .into());
152        }
153        Ok(())
154    }
155
156    fn check_importance(&self, importance: Importance) -> Result<()> {
157        if let Some(cap) = self.caps.importance_cap {
158            if importance.get() > cap {
159                return Err(SecurityError::Capability(format!(
160                    "importance {} exceeds cap {}",
161                    importance.get(),
162                    cap
163                ))
164                .into());
165            }
166        }
167        Ok(())
168    }
169
170    fn check_labels(&self, labels: &[String]) -> Result<()> {
171        if self.caps.labels_write.is_empty() {
172            return Ok(());
173        }
174        if labels.iter().all(|l| self.caps.labels_write.contains(l)) {
175            Ok(())
176        } else {
177            Err(SecurityError::Capability("label not permitted".into()).into())
178        }
179    }
180
181    fn check_kind(&self, kind: &AtomKind) -> Result<()> {
182        if self.caps.kinds_write.is_empty() || self.caps.kinds_write.contains(kind) {
183            Ok(())
184        } else {
185            Err(SecurityError::Capability(format!("kind {kind:?} not permitted")).into())
186        }
187    }
188
189    fn rate_limit(
190        bucket_map: &Mutex<HashMap<String, TokenBucket>>,
191        key: &str,
192        rate_per_sec: Option<f64>,
193        clock: &Arc<dyn Clock>,
194    ) -> Result<()> {
195        if let Some(rate) = rate_per_sec {
196            let now = clock.now();
197            let mut buckets = bucket_map.lock().expect("bucket lock poisoned");
198            let bucket = buckets
199                .entry(key.to_string())
200                .or_insert_with(|| TokenBucket::new(rate, now));
201            if !bucket.allow(now) {
202                return Err(
203                    SecurityError::RateLimited(format!("rate limited for key {key}")).into(),
204                );
205            }
206        }
207        Ok(())
208    }
209}
210
211impl Gatekeeper for LocalGatekeeper {
212    fn check_remember(&self, new_atom: &NewAtom) -> Result<()> {
213        self.ensure_world_can_write(&new_atom.world)?;
214        self.check_kind(&new_atom.kind)?;
215        self.check_labels(&new_atom.labels)?;
216        self.check_importance(new_atom.importance)?;
217        Self::rate_limit(
218            &self.remember_buckets,
219            &new_atom.worker.0,
220            self.caps.rate_limits.remember_per_sec,
221            &self.clock,
222        )?;
223        Ok(())
224    }
225
226    fn check_ask(&self, question: &Question) -> Result<()> {
227        self.ensure_world_can_read(&question.world)?;
228        Self::rate_limit(
229            &self.ask_buckets,
230            "ask-global",
231            self.caps.rate_limits.ask_per_min.map(|v| v / 60.0),
232            &self.clock,
233        )?;
234        Ok(())
235    }
236
237    fn check_world_action(&self, action: &WorldAction) -> Result<()> {
238        match action {
239            WorldAction::Create(meta) => self.ensure_world_can_write(&meta.world),
240            WorldAction::Archive(world) | WorldAction::Resume(world) => {
241                self.ensure_world_can_write(world)
242            }
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use dwbase_core::{Timestamp, WorkerKey};
    use std::time::Duration;

    /// Manually advanced clock so rate-limit tests are deterministic.
    struct MockClock {
        now: Mutex<Instant>,
    }

    impl MockClock {
        fn advance(&self, dur: Duration) {
            let mut guard = self.now.lock().unwrap();
            *guard += dur;
        }
    }

    impl Clock for MockClock {
        fn now(&self) -> Instant {
            *self.now.lock().unwrap()
        }
    }

    fn gatekeeper_with_clock(clock: Arc<dyn Clock>, caps: Capabilities) -> LocalGatekeeper {
        LocalGatekeeper::with_clock(caps, TrustStore::default(), clock)
    }

    /// Builds an atom authored by "worker-1" with the given world, importance, labels, kind.
    fn sample_atom(world: &str, importance: f32, labels: &[&str], kind: AtomKind) -> NewAtom {
        NewAtom {
            world: WorldKey::new(world),
            worker: WorkerKey::new("worker-1"),
            kind,
            timestamp: Timestamp::new("2024-01-01T00:00:00Z"),
            importance: Importance::new(importance).unwrap(),
            payload_json: "{}".into(),
            vector: None,
            flags: Vec::new(),
            labels: labels.iter().map(|l| (*l).to_string()).collect(),
            links: Vec::new(),
        }
    }

    fn question_for(world: &str) -> Question {
        Question {
            world: WorldKey::new(world),
            text: "q".into(),
            filter: None,
        }
    }

    #[test]
    fn capability_denied_on_world_and_kind_and_importance() {
        let caps = Capabilities {
            write_worlds: vec![WorldKey::new("allowed")],
            kinds_write: vec![AtomKind::Observation],
            importance_cap: Some(0.5),
            ..Default::default()
        };
        let gk = LocalGatekeeper::new(caps, TrustStore::default());
        let bad_world = sample_atom("blocked", 0.4, &[], AtomKind::Observation);
        assert!(gk.check_remember(&bad_world).is_err());
        let bad_kind = sample_atom("allowed", 0.4, &[], AtomKind::Action);
        assert!(gk.check_remember(&bad_kind).is_err());
        let bad_importance = sample_atom("allowed", 0.9, &[], AtomKind::Observation);
        assert!(gk.check_remember(&bad_importance).is_err());
    }

    #[test]
    fn labels_and_read_worlds_enforced() {
        let caps = Capabilities {
            read_worlds: vec![WorldKey::new("readable")],
            labels_write: vec!["public".to_string()],
            ..Default::default()
        };
        let gk = LocalGatekeeper::new(caps, TrustStore::default());
        // Every label must be allowed; one disallowed label rejects the atom.
        let ok_atom = sample_atom("any", 0.1, &["public"], AtomKind::Observation);
        assert!(gk.check_remember(&ok_atom).is_ok());
        let bad_atom = sample_atom("any", 0.1, &["public", "secret"], AtomKind::Observation);
        assert!(gk.check_remember(&bad_atom).is_err());
        assert!(gk.check_ask(&question_for("readable")).is_ok());
        assert!(gk.check_ask(&question_for("hidden")).is_err());
    }

    #[test]
    fn empty_capabilities_are_unrestricted() {
        let gk = LocalGatekeeper::new(Capabilities::default(), TrustStore::default());
        let atom = sample_atom("anywhere", 0.9, &["any-label"], AtomKind::Action);
        for _ in 0..3 {
            assert!(gk.check_remember(&atom).is_ok());
            assert!(gk.check_ask(&question_for("anywhere")).is_ok());
        }
    }

    #[test]
    fn rate_limiting_enforced_deterministically() {
        let clock = Arc::new(MockClock {
            now: Mutex::new(Instant::now()),
        });
        let caps = Capabilities {
            write_worlds: vec![WorldKey::new("w")],
            rate_limits: RateLimits {
                remember_per_sec: Some(1.0),
                ask_per_min: Some(60.0),
            },
            ..Default::default()
        };
        let gk = gatekeeper_with_clock(clock.clone(), caps);
        let atom = sample_atom("w", 0.1, &[], AtomKind::Observation);
        // First allowed
        assert!(gk.check_remember(&atom).is_ok());
        // Second within same second denied
        assert!(gk.check_remember(&atom).is_err());
        // Advance 1s, allowed again
        clock.advance(Duration::from_secs(1));
        assert!(gk.check_remember(&atom).is_ok());

        let question = question_for("w");
        assert!(gk.check_ask(&question).is_ok());
        // ask_per_min=60 => 1/sec tokens; next immediate call should fail
        assert!(gk.check_ask(&question).is_err());
        clock.advance(Duration::from_secs(1));
        assert!(gk.check_ask(&question).is_ok());
    }
}