1use std::collections::HashMap;
6use std::sync::{Arc, Mutex};
7use std::time::Instant;
8
9use dwbase_core::{AtomKind, Importance, WorldKey};
10use dwbase_engine::{DwbaseError, Gatekeeper, NewAtom, Question, Result, WorldAction};
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13
/// Permissions granted to a caller, enforced by [`LocalGatekeeper`].
///
/// For every allow-list below, an empty `Vec` means "unrestricted": the gatekeeper only
/// rejects a value when the list is non-empty and does not contain it.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Capabilities {
    /// Worlds that may be queried (`check_ask`); empty = any world.
    pub read_worlds: Vec<WorldKey>,
    /// Worlds that may be written to or managed (remember, create/archive/resume);
    /// empty = any world.
    pub write_worlds: Vec<WorldKey>,
    /// Labels an atom may carry; every label on a remembered atom must be listed.
    /// Empty = any label.
    pub labels_write: Vec<String>,
    /// Atom kinds that may be remembered; empty = any kind.
    pub kinds_write: Vec<AtomKind>,
    /// Inclusive upper bound on atom importance (only `importance > cap` is rejected);
    /// `None` = no cap.
    pub importance_cap: Option<f32>,
    /// Token-bucket rate limits for remember/ask calls.
    pub rate_limits: RateLimits,
    /// Offline-operation policy.
    /// NOTE(review): not consulted by `LocalGatekeeper` here — presumably enforced elsewhere.
    pub offline_policy: OfflinePolicy,
}
24
/// Token-bucket rate limits applied by [`LocalGatekeeper`]; `None` disables a limit.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct RateLimits {
    /// Sustained `remember` rate in calls per second, tracked per worker.
    /// The burst size is `max(rate, 1.0)` tokens.
    pub remember_per_sec: Option<f64>,
    /// Sustained `ask` rate in calls per minute. It is converted to a per-second refill,
    /// and all askers share a single bucket.
    pub ask_per_min: Option<f64>,
}
32
/// Whether operation while offline is permitted. Defaults to [`OfflinePolicy::Allow`].
///
/// NOTE(review): nothing in this module reads this value — confirm where it is enforced.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum OfflinePolicy {
    /// Offline operation is permitted (the default).
    #[default]
    Allow,
    /// Offline operation is refused.
    Deny,
}
39
/// Trust scores keyed by worker identifier and by node identifier.
///
/// NOTE(review): `LocalGatekeeper` stores a `TrustStore` but does not consult it here;
/// the score range/meaning is not established in this module — confirm with its producer.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TrustStore {
    /// Score per worker id.
    pub worker_scores: HashMap<String, f32>,
    /// Score per node id.
    pub node_scores: HashMap<String, f32>,
}
45
46impl TrustStore {
47 pub fn get_worker(&self, worker: &str) -> Option<f32> {
48 self.worker_scores.get(worker).copied()
49 }
50}
51
/// Reasons the local gatekeeper refuses an operation.
///
/// Converted into [`DwbaseError`] by the `From` impl below. The rendered `Display` text
/// (including the "capability denied:" / "rate limited:" prefix) becomes the message.
#[derive(Debug, Error)]
pub enum SecurityError {
    /// A world, kind, label or importance check failed; the payload describes which.
    #[error("capability denied: {0}")]
    Capability(String),
    /// A token bucket had no whole token available; the payload names the bucket key.
    #[error("rate limited: {0}")]
    RateLimited(String),
}
59
60impl From<SecurityError> for DwbaseError {
61 fn from(e: SecurityError) -> Self {
62 DwbaseError::CapabilityDenied(e.to_string())
63 }
64}
65
/// Source of the current [`Instant`] for the rate limiter.
///
/// Abstracted so tests can inject a manually advanced clock. The `Send + Sync` bound
/// presumably lets `LocalGatekeeper` (which holds an `Arc<dyn Clock>`) be shared across
/// threads — confirm against the `Gatekeeper` trait's bounds.
trait Clock: Send + Sync {
    /// Returns the current instant.
    fn now(&self) -> Instant;
}
69
70#[derive(Default)]
71struct SystemClock;
72impl Clock for SystemClock {
73 fn now(&self) -> Instant {
74 Instant::now()
75 }
76}
77
/// Token bucket: holds at most `capacity` tokens, refills continuously at
/// `refill_per_sec`, and each admitted request spends exactly one token.
#[derive(Debug)]
struct TokenBucket {
    /// Maximum number of tokens held (the burst size).
    capacity: f64,
    /// Tokens currently available; fractional amounts accumulate between calls.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
    /// Tokens added per elapsed second.
    refill_per_sec: f64,
}
85
86impl TokenBucket {
87 fn new(rate_per_sec: f64, now: Instant) -> Self {
88 let capacity = rate_per_sec.max(1.0);
89 Self {
90 capacity,
91 tokens: capacity,
92 last_refill: now,
93 refill_per_sec: rate_per_sec,
94 }
95 }
96
97 fn allow(&mut self, now: Instant) -> bool {
98 let elapsed = now.duration_since(self.last_refill).as_secs_f64();
99 self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
100 self.last_refill = now;
101 if self.tokens >= 1.0 {
102 self.tokens -= 1.0;
103 true
104 } else {
105 false
106 }
107 }
108}
109
/// In-process [`Gatekeeper`] that enforces a fixed [`Capabilities`] set plus token-bucket
/// rate limits.
pub struct LocalGatekeeper {
    /// Capabilities checked on every call; never mutated after construction.
    caps: Capabilities,
    // Stored for future trust-based decisions; no check in this module reads it yet,
    // hence the dead_code allowance.
    #[allow(dead_code)]
    trust: TrustStore,
    /// Time source for the token buckets (injectable for tests).
    clock: Arc<dyn Clock>,
    /// Remember buckets keyed by worker id, created lazily on first use.
    remember_buckets: Mutex<HashMap<String, TokenBucket>>,
    /// Ask buckets; currently only the single "ask-global" key is used.
    ask_buckets: Mutex<HashMap<String, TokenBucket>>,
}
118
119impl LocalGatekeeper {
120 pub fn new(caps: Capabilities, trust: TrustStore) -> Self {
121 Self::with_clock(caps, trust, Arc::new(SystemClock))
122 }
123
124 fn with_clock(caps: Capabilities, trust: TrustStore, clock: Arc<dyn Clock>) -> Self {
125 Self {
126 caps,
127 trust,
128 clock,
129 remember_buckets: Mutex::new(HashMap::new()),
130 ask_buckets: Mutex::new(HashMap::new()),
131 }
132 }
133
134 fn ensure_world_can_write(&self, world: &WorldKey) -> Result<()> {
135 if !self.caps.write_worlds.is_empty() && !self.caps.write_worlds.contains(world) {
136 return Err(SecurityError::Capability(format!(
137 "write not allowed for world {}",
138 world.0
139 ))
140 .into());
141 }
142 Ok(())
143 }
144
145 fn ensure_world_can_read(&self, world: &WorldKey) -> Result<()> {
146 if !self.caps.read_worlds.is_empty() && !self.caps.read_worlds.contains(world) {
147 return Err(SecurityError::Capability(format!(
148 "read not allowed for world {}",
149 world.0
150 ))
151 .into());
152 }
153 Ok(())
154 }
155
156 fn check_importance(&self, importance: Importance) -> Result<()> {
157 if let Some(cap) = self.caps.importance_cap {
158 if importance.get() > cap {
159 return Err(SecurityError::Capability(format!(
160 "importance {} exceeds cap {}",
161 importance.get(),
162 cap
163 ))
164 .into());
165 }
166 }
167 Ok(())
168 }
169
170 fn check_labels(&self, labels: &[String]) -> Result<()> {
171 if self.caps.labels_write.is_empty() {
172 return Ok(());
173 }
174 if labels.iter().all(|l| self.caps.labels_write.contains(l)) {
175 Ok(())
176 } else {
177 Err(SecurityError::Capability("label not permitted".into()).into())
178 }
179 }
180
181 fn check_kind(&self, kind: &AtomKind) -> Result<()> {
182 if self.caps.kinds_write.is_empty() || self.caps.kinds_write.contains(kind) {
183 Ok(())
184 } else {
185 Err(SecurityError::Capability(format!("kind {kind:?} not permitted")).into())
186 }
187 }
188
189 fn rate_limit(
190 bucket_map: &Mutex<HashMap<String, TokenBucket>>,
191 key: &str,
192 rate_per_sec: Option<f64>,
193 clock: &Arc<dyn Clock>,
194 ) -> Result<()> {
195 if let Some(rate) = rate_per_sec {
196 let now = clock.now();
197 let mut buckets = bucket_map.lock().expect("bucket lock poisoned");
198 let bucket = buckets
199 .entry(key.to_string())
200 .or_insert_with(|| TokenBucket::new(rate, now));
201 if !bucket.allow(now) {
202 return Err(
203 SecurityError::RateLimited(format!("rate limited for key {key}")).into(),
204 );
205 }
206 }
207 Ok(())
208 }
209}
210
211impl Gatekeeper for LocalGatekeeper {
212 fn check_remember(&self, new_atom: &NewAtom) -> Result<()> {
213 self.ensure_world_can_write(&new_atom.world)?;
214 self.check_kind(&new_atom.kind)?;
215 self.check_labels(&new_atom.labels)?;
216 self.check_importance(new_atom.importance)?;
217 Self::rate_limit(
218 &self.remember_buckets,
219 &new_atom.worker.0,
220 self.caps.rate_limits.remember_per_sec,
221 &self.clock,
222 )?;
223 Ok(())
224 }
225
226 fn check_ask(&self, question: &Question) -> Result<()> {
227 self.ensure_world_can_read(&question.world)?;
228 Self::rate_limit(
229 &self.ask_buckets,
230 "ask-global",
231 self.caps.rate_limits.ask_per_min.map(|v| v / 60.0),
232 &self.clock,
233 )?;
234 Ok(())
235 }
236
237 fn check_world_action(&self, action: &WorldAction) -> Result<()> {
238 match action {
239 WorldAction::Create(meta) => self.ensure_world_can_write(&meta.world),
240 WorldAction::Archive(world) | WorldAction::Resume(world) => {
241 self.ensure_world_can_write(world)
242 }
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use dwbase_core::{Timestamp, WorkerKey};
    use std::time::Duration;

    /// Manually advanced clock so rate-limit tests never depend on wall time.
    struct MockClock {
        now: Mutex<Instant>,
    }

    impl MockClock {
        fn advance(&self, dur: Duration) {
            let mut guard = self.now.lock().unwrap();
            *guard += dur;
        }
    }

    impl Clock for MockClock {
        fn now(&self) -> Instant {
            *self.now.lock().unwrap()
        }
    }

    fn gatekeeper_with_clock(clock: Arc<dyn Clock>, caps: Capabilities) -> LocalGatekeeper {
        LocalGatekeeper::with_clock(caps, TrustStore::default(), clock)
    }

    /// Builds an atom from `worker-1` with the given world, importance, labels and kind.
    fn sample_atom(world: &str, importance: f32, labels: &[&str], kind: AtomKind) -> NewAtom {
        NewAtom {
            world: WorldKey::new(world),
            worker: WorkerKey::new("worker-1"),
            kind,
            timestamp: Timestamp::new("2024-01-01T00:00:00Z"),
            importance: Importance::new(importance).unwrap(),
            payload_json: "{}".into(),
            vector: None,
            flags: Vec::new(),
            labels: labels.iter().map(|label| (*label).to_string()).collect(),
            links: Vec::new(),
        }
    }

    /// Builds a minimal question against `world`.
    fn sample_question(world: &str) -> Question {
        Question {
            world: WorldKey::new(world),
            text: "q".into(),
            filter: None,
        }
    }

    #[test]
    fn capability_denied_on_world_and_kind_and_importance() {
        let caps = Capabilities {
            write_worlds: vec![WorldKey::new("allowed")],
            kinds_write: vec![AtomKind::Observation],
            importance_cap: Some(0.5),
            ..Default::default()
        };
        let gk = LocalGatekeeper::new(caps, TrustStore::default());
        let bad_world = sample_atom("blocked", 0.4, &[], AtomKind::Observation);
        assert!(gk.check_remember(&bad_world).is_err());
        let bad_kind = sample_atom("allowed", 0.4, &[], AtomKind::Action);
        assert!(gk.check_remember(&bad_kind).is_err());
        let bad_importance = sample_atom("allowed", 0.9, &[], AtomKind::Observation);
        assert!(gk.check_remember(&bad_importance).is_err());
        let at_cap = sample_atom("allowed", 0.5, &[], AtomKind::Observation);
        assert!(gk.check_remember(&at_cap).is_ok());
    }

    #[test]
    fn default_capabilities_are_unrestricted() {
        let gk = LocalGatekeeper::new(Capabilities::default(), TrustStore::default());
        let atom = sample_atom("anywhere", 0.9, &["any-label"], AtomKind::Action);
        assert!(gk.check_remember(&atom).is_ok());
        assert!(gk.check_ask(&sample_question("anywhere")).is_ok());
        assert!(gk
            .check_world_action(&WorldAction::Archive(WorldKey::new("anywhere")))
            .is_ok());
    }

    #[test]
    fn labels_must_all_be_permitted() {
        let caps = Capabilities {
            labels_write: vec!["public".to_string()],
            ..Default::default()
        };
        let gk = LocalGatekeeper::new(caps, TrustStore::default());
        let ok = sample_atom("w", 0.1, &["public"], AtomKind::Observation);
        assert!(gk.check_remember(&ok).is_ok());
        let unlabeled = sample_atom("w", 0.1, &[], AtomKind::Observation);
        assert!(gk.check_remember(&unlabeled).is_ok());
        let mixed = sample_atom("w", 0.1, &["public", "secret"], AtomKind::Observation);
        assert!(gk.check_remember(&mixed).is_err());
    }

    #[test]
    fn ask_denied_outside_read_worlds() {
        let caps = Capabilities {
            read_worlds: vec![WorldKey::new("readable")],
            ..Default::default()
        };
        let gk = LocalGatekeeper::new(caps, TrustStore::default());
        assert!(gk.check_ask(&sample_question("readable")).is_ok());
        assert!(gk.check_ask(&sample_question("hidden")).is_err());
    }

    #[test]
    fn world_actions_require_write_access() {
        let caps = Capabilities {
            write_worlds: vec![WorldKey::new("mine")],
            ..Default::default()
        };
        let gk = LocalGatekeeper::new(caps, TrustStore::default());
        assert!(gk
            .check_world_action(&WorldAction::Archive(WorldKey::new("mine")))
            .is_ok());
        assert!(gk
            .check_world_action(&WorldAction::Archive(WorldKey::new("theirs")))
            .is_err());
        assert!(gk
            .check_world_action(&WorldAction::Resume(WorldKey::new("theirs")))
            .is_err());
    }

    #[test]
    fn rate_limiting_enforced_deterministically() {
        let clock = Arc::new(MockClock {
            now: Mutex::new(Instant::now()),
        });
        let caps = Capabilities {
            write_worlds: vec![WorldKey::new("w")],
            rate_limits: RateLimits {
                remember_per_sec: Some(1.0),
                ask_per_min: Some(60.0),
            },
            ..Default::default()
        };
        let gk = gatekeeper_with_clock(clock.clone(), caps);
        let atom = sample_atom("w", 0.1, &[], AtomKind::Observation);
        assert!(gk.check_remember(&atom).is_ok());
        // Bucket capacity is 1 token, so an immediate second call is rejected.
        assert!(gk.check_remember(&atom).is_err());
        clock.advance(Duration::from_secs(1));
        assert!(gk.check_remember(&atom).is_ok());

        let question = sample_question("w");
        assert!(gk.check_ask(&question).is_ok());
        assert!(gk.check_ask(&question).is_err());
        // 60/min refills one token per second.
        clock.advance(Duration::from_secs(1));
        assert!(gk.check_ask(&question).is_ok());
    }
}