Skip to main content

sochdb_query/
capability_token.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Capability Tokens + ACLs (Task 8)
19//!
20//! This module implements staged ACLs via capability tokens for the local-first
21//! architecture. The design prioritizes:
22//!
23//! 1. **Simplicity** - Easy to reason about, hard to misapply
24//! 2. **Local-first** - No external auth service required
25//! 3. **Composability** - ACLs integrate with existing filter infrastructure
26//!
27//! ## Token Structure
28//!
29//! ```text
30//! CapabilityToken {
31//!     allowed_namespaces: ["prod", "staging"],
32//!     tenant_id: Option<"acme_corp">,
33//!     project_id: Option<"project_123">,
34//!     capabilities: { read: true, write: false, ... },
35//!     expires_at: 1735689600,
36//!     signature: HMAC-SHA256(...)
37//! }
38//! ```
39//!
40//! ## Verification
41//!
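//! Token verification is O(1) in the number of issued tokens (its cost depends
//! only on the size of the token being checked):
//! - HMAC-SHA256 for symmetric tokens (implemented in this module by `TokenSigner`)
//! - Ed25519 for asymmetric tokens (planned, with cached verification; not yet
//!   implemented in this module)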
45//!
46//! ## Row-Level ACLs (Future)
47//!
48//! Row-level ACL tags become "just another metadata atom":
49//! ```text
50//! HasTag(acl_tag) → bitmap lookup → AllowedSet intersection
51//! ```
52//!
53//! This composes cleanly with existing filter infrastructure.
54
55use std::collections::HashSet;
56use std::time::{Duration, SystemTime, UNIX_EPOCH};
57
58use serde::{Deserialize, Serialize};
59
60use crate::filter_ir::{AuthCapabilities, AuthScope};
61
62// ============================================================================
63// Capability Token
64// ============================================================================
65
66/// A capability token that encodes access permissions
67///
68/// This is the serializable form that can be passed across API boundaries.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct CapabilityToken {
71    /// Token version (for future upgrades)
72    pub version: u8,
73
74    /// Token ID (for revocation tracking)
75    pub token_id: String,
76
77    /// Allowed namespaces (non-empty)
78    pub allowed_namespaces: Vec<String>,
79
80    /// Optional tenant ID
81    pub tenant_id: Option<String>,
82
83    /// Optional project ID
84    pub project_id: Option<String>,
85
86    /// Capability flags
87    pub capabilities: TokenCapabilities,
88
89    /// Issued at (Unix timestamp)
90    pub issued_at: u64,
91
92    /// Expires at (Unix timestamp)
93    pub expires_at: u64,
94
95    /// ACL tags the token holder can access (for row-level ACLs)
96    pub acl_tags: Vec<String>,
97
98    /// Signature (HMAC-SHA256 or Ed25519)
99    pub signature: Vec<u8>,
100}
101
102/// Capability flags in the token
103#[derive(Debug, Clone, Default, Serialize, Deserialize)]
104pub struct TokenCapabilities {
105    /// Can read/query vectors
106    pub can_read: bool,
107    /// Can insert vectors
108    pub can_write: bool,
109    /// Can delete vectors
110    pub can_delete: bool,
111    /// Can perform admin operations (create/drop indexes)
112    pub can_admin: bool,
113    /// Can create new tokens (delegation)
114    pub can_delegate: bool,
115}
116
117impl CapabilityToken {
118    /// Current token version
119    pub const CURRENT_VERSION: u8 = 1;
120
121    /// Check if the token is expired
122    pub fn is_expired(&self) -> bool {
123        let now = SystemTime::now()
124            .duration_since(UNIX_EPOCH)
125            .map(|d| d.as_secs())
126            .unwrap_or(0);
127        now > self.expires_at
128    }
129
130    /// Check if a namespace is allowed
131    pub fn is_namespace_allowed(&self, namespace: &str) -> bool {
132        self.allowed_namespaces.iter().any(|ns| ns == namespace)
133    }
134
135    /// Convert to AuthScope for use with FilterIR
136    pub fn to_auth_scope(&self) -> AuthScope {
137        AuthScope {
138            allowed_namespaces: self.allowed_namespaces.clone(),
139            tenant_id: self.tenant_id.clone(),
140            project_id: self.project_id.clone(),
141            expires_at: Some(self.expires_at),
142            capabilities: AuthCapabilities {
143                can_read: self.capabilities.can_read,
144                can_write: self.capabilities.can_write,
145                can_delete: self.capabilities.can_delete,
146                can_admin: self.capabilities.can_admin,
147            },
148            acl_tags: self.acl_tags.clone(),
149        }
150    }
151
152    /// Get remaining validity duration
153    pub fn remaining_validity(&self) -> Option<Duration> {
154        let now = SystemTime::now()
155            .duration_since(UNIX_EPOCH)
156            .map(|d| d.as_secs())
157            .unwrap_or(0);
158
159        if now >= self.expires_at {
160            None
161        } else {
162            Some(Duration::from_secs(self.expires_at - now))
163        }
164    }
165}
166
167// ============================================================================
168// Token Builder
169// ============================================================================
170
171/// Builder for creating capability tokens
172pub struct TokenBuilder {
173    namespaces: Vec<String>,
174    tenant_id: Option<String>,
175    project_id: Option<String>,
176    capabilities: TokenCapabilities,
177    validity: Duration,
178    acl_tags: Vec<String>,
179}
180
181impl TokenBuilder {
182    /// Create a new token builder for a namespace
183    pub fn new(namespace: impl Into<String>) -> Self {
184        Self {
185            namespaces: vec![namespace.into()],
186            tenant_id: None,
187            project_id: None,
188            capabilities: TokenCapabilities {
189                can_read: true,
190                ..Default::default()
191            },
192            validity: Duration::from_secs(3600), // 1 hour default
193            acl_tags: Vec::new(),
194        }
195    }
196
197    /// Add another namespace
198    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
199        self.namespaces.push(namespace.into());
200        self
201    }
202
203    /// Set tenant ID
204    pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
205        self.tenant_id = Some(tenant_id.into());
206        self
207    }
208
209    /// Set project ID
210    pub fn with_project(mut self, project_id: impl Into<String>) -> Self {
211        self.project_id = Some(project_id.into());
212        self
213    }
214
215    /// Enable read capability
216    pub fn can_read(mut self) -> Self {
217        self.capabilities.can_read = true;
218        self
219    }
220
221    /// Enable write capability
222    pub fn can_write(mut self) -> Self {
223        self.capabilities.can_write = true;
224        self
225    }
226
227    /// Enable delete capability
228    pub fn can_delete(mut self) -> Self {
229        self.capabilities.can_delete = true;
230        self
231    }
232
233    /// Enable admin capability
234    pub fn can_admin(mut self) -> Self {
235        self.capabilities.can_admin = true;
236        self
237    }
238
239    /// Enable all capabilities
240    pub fn full_access(mut self) -> Self {
241        self.capabilities = TokenCapabilities {
242            can_read: true,
243            can_write: true,
244            can_delete: true,
245            can_admin: true,
246            can_delegate: false,
247        };
248        self
249    }
250
251    /// Set validity duration
252    pub fn valid_for(mut self, duration: Duration) -> Self {
253        self.validity = duration;
254        self
255    }
256
257    /// Add ACL tags
258    pub fn with_acl_tags(mut self, tags: Vec<String>) -> Self {
259        self.acl_tags = tags;
260        self
261    }
262
263    /// Build the token (unsigned - call sign() on TokenSigner)
264    pub fn build_unsigned(self) -> CapabilityToken {
265        let now = SystemTime::now()
266            .duration_since(UNIX_EPOCH)
267            .map(|d| d.as_secs())
268            .unwrap_or(0);
269
270        CapabilityToken {
271            version: CapabilityToken::CURRENT_VERSION,
272            token_id: generate_token_id(),
273            allowed_namespaces: self.namespaces,
274            tenant_id: self.tenant_id,
275            project_id: self.project_id,
276            capabilities: self.capabilities,
277            issued_at: now,
278            expires_at: now + self.validity.as_secs(),
279            acl_tags: self.acl_tags,
280            signature: Vec::new(),
281        }
282    }
283}
284
/// Generate a token ID that is unique within this process.
///
/// The ID combines the current wall-clock time (nanoseconds since the Unix
/// epoch, 0 if the clock reads before the epoch) with a process-wide sequence
/// number. Two tokens minted within the same clock tick therefore still get
/// distinct IDs, which matters because revocation is keyed by token ID.
///
/// IDs are neither secret nor unpredictable; a token's integrity comes from
/// its signature, which covers the ID.
fn generate_token_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Monotonic per-process counter; `Relaxed` suffices because only the
    // atomicity of the increment (not ordering with other memory) is needed.
    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    format!("tok_{:x}_{:x}", nanos, seq)
}
296
297// ============================================================================
298// Token Signing and Verification
299// ============================================================================
300
301/// Token signer using HMAC-SHA256
302pub struct TokenSigner {
303    /// Secret key for HMAC
304    secret: Vec<u8>,
305}
306
307impl TokenSigner {
308    /// Create a new signer with a secret key
309    pub fn new(secret: impl AsRef<[u8]>) -> Self {
310        Self {
311            secret: secret.as_ref().to_vec(),
312        }
313    }
314
315    /// Sign a token
316    pub fn sign(&self, token: &mut CapabilityToken) {
317        let payload = self.compute_payload(token);
318        token.signature = self.hmac_sha256(&payload);
319    }
320
321    /// Verify a token signature
322    pub fn verify(&self, token: &CapabilityToken) -> Result<(), TokenError> {
323        // Check version
324        if token.version != CapabilityToken::CURRENT_VERSION {
325            return Err(TokenError::UnsupportedVersion(token.version));
326        }
327
328        // Check expiry
329        if token.is_expired() {
330            return Err(TokenError::Expired);
331        }
332
333        // Verify signature
334        let payload = self.compute_payload(token);
335        let expected = self.hmac_sha256(&payload);
336
337        if !constant_time_eq(&token.signature, &expected) {
338            return Err(TokenError::InvalidSignature);
339        }
340
341        Ok(())
342    }
343
344    /// Compute the payload to sign
345    fn compute_payload(&self, token: &CapabilityToken) -> Vec<u8> {
346        // Deterministic serialization of token fields (excluding signature)
347        let mut payload = Vec::new();
348
349        payload.push(token.version);
350        payload.extend(token.token_id.as_bytes());
351
352        for ns in &token.allowed_namespaces {
353            payload.extend(ns.as_bytes());
354            payload.push(0); // Separator
355        }
356
357        if let Some(ref tenant) = token.tenant_id {
358            payload.extend(tenant.as_bytes());
359        }
360        payload.push(0);
361
362        if let Some(ref project) = token.project_id {
363            payload.extend(project.as_bytes());
364        }
365        payload.push(0);
366
367        // Capabilities as flags
368        let caps = (token.capabilities.can_read as u8)
369            | ((token.capabilities.can_write as u8) << 1)
370            | ((token.capabilities.can_delete as u8) << 2)
371            | ((token.capabilities.can_admin as u8) << 3)
372            | ((token.capabilities.can_delegate as u8) << 4);
373        payload.push(caps);
374
375        payload.extend(&token.issued_at.to_le_bytes());
376        payload.extend(&token.expires_at.to_le_bytes());
377
378        for tag in &token.acl_tags {
379            payload.extend(tag.as_bytes());
380            payload.push(0);
381        }
382
383        payload
384    }
385
386    /// HMAC-SHA256 using the `ring` crate for cryptographic security.
387    fn hmac_sha256(&self, data: &[u8]) -> Vec<u8> {
388        use ring::hmac;
389        let key = hmac::Key::new(hmac::HMAC_SHA256, &self.secret);
390        let tag = hmac::sign(&key, data);
391        tag.as_ref().to_vec() // 32 bytes
392    }
393}
394
/// Compare two byte slices without short-circuiting on the first mismatch.
///
/// Returns `true` only when `a` and `b` have the same length and contents.
/// A length mismatch returns early; that reveals only the lengths, which are
/// not secret here (HMAC-SHA256 tags are always 32 bytes). For equal lengths,
/// every byte pair is XOR-ed and the results OR-ed together, so the work done
/// does not depend on where the first differing byte is.
///
/// This is best-effort: Rust does not guarantee that the optimizer preserves
/// the data-independent timing of this loop.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
410
411/// Token errors
412#[derive(Debug, Clone, thiserror::Error)]
413pub enum TokenError {
414    #[error("token has expired")]
415    Expired,
416
417    #[error("invalid signature")]
418    InvalidSignature,
419
420    #[error("unsupported token version: {0}")]
421    UnsupportedVersion(u8),
422
423    #[error("token revoked")]
424    Revoked,
425
426    #[error("namespace not allowed: {0}")]
427    NamespaceNotAllowed(String),
428
429    #[error("insufficient capabilities")]
430    InsufficientCapabilities,
431}
432
433// ============================================================================
434// Token Revocation (Simple In-Memory)
435// ============================================================================
436
/// In-memory set of revoked token IDs, shareable across threads via `&self`.
///
/// Nothing is persisted: the list starts empty and is lost when dropped.
pub struct RevocationList {
    /// IDs of every token revoked so far.
    revoked: std::sync::RwLock<HashSet<String>>,
}

impl RevocationList {
    /// Create an empty revocation list.
    pub fn new() -> Self {
        let empty: HashSet<String> = HashSet::new();
        RevocationList {
            revoked: std::sync::RwLock::new(empty),
        }
    }

    /// Mark `token_id` as revoked; revoking an ID twice has no extra effect.
    ///
    /// Panics if the lock has been poisoned.
    pub fn revoke(&self, token_id: &str) {
        let mut ids = self.revoked.write().unwrap();
        ids.insert(String::from(token_id));
    }

    /// Whether `token_id` has been revoked.
    ///
    /// Panics if the lock has been poisoned.
    pub fn is_revoked(&self, token_id: &str) -> bool {
        let ids = self.revoked.read().unwrap();
        ids.contains(token_id)
    }

    /// Number of distinct revoked token IDs.
    ///
    /// Panics if the lock has been poisoned.
    pub fn count(&self) -> usize {
        let ids = self.revoked.read().unwrap();
        ids.len()
    }
}

impl Default for RevocationList {
    /// Equivalent to [`RevocationList::new`].
    fn default() -> Self {
        RevocationList::new()
    }
}
472
473// ============================================================================
474// Token Validator (Combines Signer + Revocation)
475// ============================================================================
476
477/// Complete token validator
478pub struct TokenValidator {
479    signer: TokenSigner,
480    revocation_list: RevocationList,
481}
482
483impl TokenValidator {
484    /// Create a new validator
485    pub fn new(secret: impl AsRef<[u8]>) -> Self {
486        Self {
487            signer: TokenSigner::new(secret),
488            revocation_list: RevocationList::new(),
489        }
490    }
491
492    /// Issue a new token
493    pub fn issue(&self, builder: TokenBuilder) -> CapabilityToken {
494        let mut token = builder.build_unsigned();
495        self.signer.sign(&mut token);
496        token
497    }
498
499    /// Validate a token
500    pub fn validate(&self, token: &CapabilityToken) -> Result<AuthScope, TokenError> {
501        // Check revocation
502        if self.revocation_list.is_revoked(&token.token_id) {
503            return Err(TokenError::Revoked);
504        }
505
506        // Verify signature and expiry
507        self.signer.verify(token)?;
508
509        // Convert to AuthScope
510        Ok(token.to_auth_scope())
511    }
512
513    /// Revoke a token
514    pub fn revoke(&self, token_id: &str) {
515        self.revocation_list.revoke(token_id);
516    }
517}
518
519// ============================================================================
520// Row-Level ACL Tags (Future Extension)
521// ============================================================================
522
523/// A row-level ACL tag
524///
525/// In the future, documents can have ACL tags and tokens can specify
526/// which tags they can access. This integrates with the filter IR:
527///
528/// ```text
529/// FilterAtom::HasTag("confidential") → bitmap lookup → intersection
530/// ```
531#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
532pub struct AclTag(String);
533
534impl AclTag {
535    /// Create a new ACL tag
536    pub fn new(tag: impl Into<String>) -> Self {
537        Self(tag.into())
538    }
539
540    /// Get the tag name
541    pub fn name(&self) -> &str {
542        &self.0
543    }
544}
545
/// ACL tag index for row-level security.
///
/// Maps each tag to the set of document IDs carrying it, stored as a sorted,
/// duplicate-free `Vec<u64>`. This would be integrated with MetadataIndex to
/// provide: tag → bitmap of doc_ids with that tag.
#[derive(Debug, Default)]
pub struct AclTagIndex {
    /// Map from tag to doc_ids; each list is sorted ascending with no duplicates.
    tag_to_docs: std::collections::HashMap<String, Vec<u64>>,
}

impl AclTagIndex {
    /// Create an empty ACL tag index.
    pub fn new() -> Self {
        Self::default()
    }

    /// Tag document `doc_id` with `tag`.
    ///
    /// Idempotent: tagging the same document with the same tag again has no
    /// further effect.
    pub fn add_tag(&mut self, doc_id: u64, tag: &str) {
        let docs = self.tag_to_docs.entry(tag.to_string()).or_default();
        // Keep the list sorted and unique so it behaves like a set/bitmap.
        if let Err(pos) = docs.binary_search(&doc_id) {
            docs.insert(pos, doc_id);
        }
    }

    /// Document IDs carrying `tag`, in ascending order (empty if the tag is unknown).
    pub fn docs_with_tag(&self, tag: &str) -> &[u64] {
        self.tag_to_docs
            .get(tag)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Document IDs carrying at least one of `allowed_tags` (set union),
    /// returned in ascending order without duplicates.
    pub fn accessible_docs(&self, allowed_tags: &[String]) -> Vec<u64> {
        let mut result: Vec<u64> = allowed_tags
            .iter()
            .filter_map(|tag| self.tag_to_docs.get(tag))
            .flatten()
            .copied()
            .collect();
        // Sorting gives a deterministic order; dedup removes docs that carry
        // several of the allowed tags.
        result.sort_unstable();
        result.dedup();
        result
    }
}
589
590// ============================================================================
591// Tests
592// ============================================================================
593
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_builder() {
        let token = TokenBuilder::new("production")
            .with_namespace("staging")
            .with_tenant("acme")
            .can_read()
            .can_write()
            .valid_for(Duration::from_secs(3600))
            .build_unsigned();

        assert_eq!(token.allowed_namespaces.len(), 2);
        assert_eq!(token.tenant_id, Some("acme".to_string()));
        assert!(token.capabilities.can_read);
        assert!(token.capabilities.can_write);
        assert!(!token.capabilities.can_delete);
    }

    #[test]
    fn test_token_signing_and_verification() {
        let signer = TokenSigner::new("super_secret_key");

        let mut token = TokenBuilder::new("production")
            .can_read()
            .valid_for(Duration::from_secs(3600))
            .build_unsigned();

        signer.sign(&mut token);
        assert!(!token.signature.is_empty());

        // Verification should succeed
        assert!(signer.verify(&token).is_ok());

        // Tamper with token
        token.allowed_namespaces.push("hacked".to_string());
        assert!(signer.verify(&token).is_err());
    }

    #[test]
    fn test_tampered_capabilities_rejected() {
        let signer = TokenSigner::new("super_secret_key");
        let mut token = TokenBuilder::new("production").can_read().build_unsigned();
        signer.sign(&mut token);

        // Escalating privileges after signing must invalidate the signature.
        token.capabilities.can_admin = true;
        assert!(matches!(
            signer.verify(&token),
            Err(TokenError::InvalidSignature)
        ));
    }

    #[test]
    fn test_wrong_key_rejected() {
        let mut token = TokenBuilder::new("production").build_unsigned();
        TokenSigner::new("key_a").sign(&mut token);

        assert!(matches!(
            TokenSigner::new("key_b").verify(&token),
            Err(TokenError::InvalidSignature)
        ));
    }

    #[test]
    fn test_unsupported_version_rejected() {
        let signer = TokenSigner::new("super_secret_key");
        let mut token = TokenBuilder::new("production").build_unsigned();
        signer.sign(&mut token);
        token.version = CapabilityToken::CURRENT_VERSION.wrapping_add(1);

        assert!(matches!(
            signer.verify(&token),
            Err(TokenError::UnsupportedVersion(_))
        ));
    }

    #[test]
    fn test_token_expiry() {
        // Build a normal one-hour token, then move its expiry back to the
        // Unix epoch so that it is already in the past.
        let mut token = TokenBuilder::new("production")
            .valid_for(Duration::from_secs(3600))
            .build_unsigned();
        token.expires_at = 0;

        assert!(token.is_expired());
        assert!(token.remaining_validity().is_none());

        // A correctly signed but expired token is reported as expired.
        let signer = TokenSigner::new("super_secret_key");
        signer.sign(&mut token);
        assert!(matches!(signer.verify(&token), Err(TokenError::Expired)));
    }

    #[test]
    fn test_fresh_token_has_remaining_validity() {
        let token = TokenBuilder::new("production")
            .valid_for(Duration::from_secs(3600))
            .build_unsigned();

        assert!(!token.is_expired());
        let remaining = token
            .remaining_validity()
            .expect("a token issued just now should still be valid");
        assert!(remaining <= Duration::from_secs(3600));
    }

    #[test]
    fn test_token_to_auth_scope() {
        let token = TokenBuilder::new("production")
            .with_tenant("acme")
            .can_read()
            .can_write()
            .with_acl_tags(vec!["public".to_string(), "internal".to_string()])
            .build_unsigned();

        let scope = token.to_auth_scope();
        assert!(scope.is_namespace_allowed("production"));
        assert!(!scope.is_namespace_allowed("staging"));
        assert_eq!(scope.tenant_id, Some("acme".to_string()));
        assert!(scope.capabilities.can_read);
        assert!(scope.capabilities.can_write);
        assert_eq!(scope.acl_tags.len(), 2);
    }

    #[test]
    fn test_revocation() {
        let validator = TokenValidator::new("secret");

        let token = validator.issue(
            TokenBuilder::new("production")
                .can_read()
                .valid_for(Duration::from_secs(3600)),
        );

        // Should validate
        assert!(validator.validate(&token).is_ok());

        // Revoke
        validator.revoke(&token.token_id);

        // Should fail validation
        assert!(matches!(
            validator.validate(&token),
            Err(TokenError::Revoked)
        ));
    }

    #[test]
    fn test_acl_tag_index() {
        let mut index = AclTagIndex::new();

        index.add_tag(1, "public");
        index.add_tag(2, "public");
        index.add_tag(3, "internal");
        index.add_tag(4, "confidential");

        assert_eq!(index.docs_with_tag("public").len(), 2);
        assert_eq!(index.docs_with_tag("internal").len(), 1);

        let accessible: HashSet<u64> = index
            .accessible_docs(&["public".to_string(), "internal".to_string()])
            .into_iter()
            .collect();
        let expected: HashSet<u64> = vec![1u64, 2, 3].into_iter().collect();
        assert_eq!(accessible, expected);
    }
}