ubiquity-core 0.1.1

Core types and traits for Ubiquity consciousness-aware mesh
Documentation
//! Authentication and authorization

use crate::{AuthToken, AuthType, AgentCapability, UbiquityError};
use chrono::{Utc, Duration};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

/// Authentication service.
///
/// Combines the configured authentication method with in-memory stores for
/// issued tokens and registered API keys. The stores are wrapped in
/// `Arc<RwLock<..>>`, which lets the `&self` async methods update them.
pub struct AuthService {
    /// Method that `authenticate` uses to interpret incoming credentials.
    auth_type: AuthType,
    /// Random key generated when the service is constructed.
    // NOTE(review): not read anywhere in this file — presumably reserved for
    // token signing; confirm before relying on it.
    secret_key: String,
    /// Registered API keys: api_key -> agent_id.
    api_keys: Arc<RwLock<HashMap<String, String>>>,
    /// Issued tokens, keyed by their token string.
    tokens: Arc<RwLock<HashMap<String, AuthToken>>>,
}

impl AuthService {
    pub fn new(auth_type: AuthType) -> Self {
        Self {
            auth_type,
            tokens: Arc::new(RwLock::new(HashMap::new())),
            api_keys: Arc::new(RwLock::new(HashMap::new())),
            secret_key: generate_secret_key(),
        }
    }
    
    /// Generate a new authentication token
    pub async fn generate_token(
        &self,
        agent_id: String,
        capabilities: AgentCapability,
        expires_in: Option<Duration>,
    ) -> Result<AuthToken, UbiquityError> {
        let token_string = generate_token_string();
        let expires_at = expires_in.map(|d| Utc::now() + d);
        
        let token = AuthToken {
            token: token_string.clone(),
            agent_id: agent_id.clone(),
            capabilities,
            expires_at,
        };
        
        self.tokens.write().await.insert(token_string, token.clone());
        
        Ok(token)
    }
    
    /// Validate a token
    pub async fn validate_token(&self, token_str: &str) -> Result<AuthToken, UbiquityError> {
        let tokens = self.tokens.read().await;
        
        match tokens.get(token_str) {
            Some(token) => {
                // Check expiration
                if let Some(expires_at) = token.expires_at {
                    if expires_at < Utc::now() {
                        return Err(UbiquityError::ConfigError("Token expired".into()));
                    }
                }
                Ok(token.clone())
            }
            None => Err(UbiquityError::ConfigError("Invalid token".into())),
        }
    }
    
    /// Register an API key
    pub async fn register_api_key(&self, api_key: String, agent_id: String) -> Result<(), UbiquityError> {
        self.api_keys.write().await.insert(api_key, agent_id);
        Ok(())
    }
    
    /// Validate an API key
    pub async fn validate_api_key(&self, api_key: &str) -> Result<String, UbiquityError> {
        self.api_keys
            .read()
            .await
            .get(api_key)
            .cloned()
            .ok_or_else(|| UbiquityError::ConfigError("Invalid API key".into()))
    }
    
    /// Authenticate based on configured method
    pub async fn authenticate(&self, credentials: &str) -> Result<String, UbiquityError> {
        match self.auth_type {
            AuthType::None => Ok("anonymous".to_string()),
            AuthType::Token => {
                let token = self.validate_token(credentials).await?;
                Ok(token.agent_id)
            }
            AuthType::ApiKey => {
                self.validate_api_key(credentials).await
            }
            AuthType::Certificate => {
                // TODO: Implement certificate validation
                Err(UbiquityError::ConfigError("Certificate auth not implemented".into()))
            }
        }
    }
    
    /// Check if an agent has a specific capability
    pub async fn check_capability(
        &self,
        token_str: &str,
        check: impl FnOnce(&AgentCapability) -> bool,
    ) -> Result<bool, UbiquityError> {
        let token = self.validate_token(token_str).await?;
        Ok(check(&token.capabilities))
    }
    
    /// Revoke a token
    pub async fn revoke_token(&self, token_str: &str) -> Result<(), UbiquityError> {
        self.tokens.write().await.remove(token_str);
        Ok(())
    }
    
    /// Clean up expired tokens
    pub async fn cleanup_expired_tokens(&self) {
        let now = Utc::now();
        let mut tokens = self.tokens.write().await;
        
        tokens.retain(|_, token| {
            token.expires_at.map(|exp| exp > now).unwrap_or(true)
        });
    }
}

/// Produce `len` characters chosen uniformly from `[A-Za-z0-9]`, using
/// `rand::thread_rng`.
fn random_alphanumeric(len: usize) -> String {
    use rand::Rng;

    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";

    let mut generator = rand::thread_rng();
    std::iter::repeat_with(|| ALPHABET[generator.gen_range(0..ALPHABET.len())] as char)
        .take(len)
        .collect()
}

/// Generate a secure random token string of 32 alphanumeric characters.
fn generate_token_string() -> String {
    random_alphanumeric(32)
}

/// Generate a secret key for signing. It has the same format as a token string.
fn generate_secret_key() -> String {
    generate_token_string()
}

/// Authorization guard: a list of capability predicates that a token must
/// satisfy in full.
///
/// Build one with the `require*` builder methods, then evaluate it with
/// [`AuthGuard::check`]. A guard with no requirements accepts any valid token.
pub struct AuthGuard {
    required_capabilities: Vec<Box<dyn Fn(&AgentCapability) -> bool + Send + Sync>>,
}

impl Default for AuthGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl AuthGuard {
    /// Create a guard with no requirements.
    pub fn new() -> Self {
        Self {
            required_capabilities: Vec::new(),
        }
    }

    /// Require an arbitrary capability predicate.
    ///
    /// The named `require_*` helpers are shorthands for this method.
    #[must_use = "the builder returns the updated guard; dropping it discards the requirement"]
    pub fn require(
        mut self,
        predicate: impl Fn(&AgentCapability) -> bool + Send + Sync + 'static,
    ) -> Self {
        self.required_capabilities.push(Box::new(predicate));
        self
    }

    /// Require write code capability
    #[must_use = "the builder returns the updated guard; dropping it discards the requirement"]
    pub fn require_write_code(self) -> Self {
        self.require(|cap| cap.can_write_code)
    }

    /// Require execute commands capability
    #[must_use = "the builder returns the updated guard; dropping it discards the requirement"]
    pub fn require_execute_commands(self) -> Self {
        self.require(|cap| cap.can_execute_commands)
    }

    /// Require coordination capability
    #[must_use = "the builder returns the updated guard; dropping it discards the requirement"]
    pub fn require_coordination(self) -> Self {
        self.require(|cap| cap.can_coordinate)
    }

    /// Check that `token` is valid and satisfies every registered requirement.
    ///
    /// # Errors
    /// - Any error from `AuthService::validate_token` (unknown or expired token).
    /// - `ConfigError("Insufficient capabilities")` if any requirement fails.
    pub async fn check(&self, auth_service: &AuthService, token: &str) -> Result<(), UbiquityError> {
        let token_data = auth_service.validate_token(token).await?;

        let satisfied = self
            .required_capabilities
            .iter()
            .all(|predicate| predicate(&token_data.capabilities));

        if satisfied {
            Ok(())
        } else {
            Err(UbiquityError::ConfigError("Insufficient capabilities".into()))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_token_generation_and_validation() {
        let auth = AuthService::new(AuthType::Token);
        let capabilities = AgentCapability::default();

        let token = auth.generate_token(
            "test-agent".to_string(),
            capabilities,
            Some(Duration::hours(1)),
        ).await.unwrap();

        let validated = auth.validate_token(&token.token).await.unwrap();
        assert_eq!(validated.agent_id, "test-agent");
    }

    #[tokio::test]
    async fn test_expired_token() {
        let auth = AuthService::new(AuthType::Token);
        let capabilities = AgentCapability::default();

        let token = auth.generate_token(
            "test-agent".to_string(),
            capabilities,
            Some(Duration::seconds(-1)), // Already expired
        ).await.unwrap();

        let result = auth.validate_token(&token.token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unknown_token_is_rejected() {
        let auth = AuthService::new(AuthType::Token);
        assert!(auth.validate_token("does-not-exist").await.is_err());
    }

    #[tokio::test]
    async fn test_revoke_token() {
        let auth = AuthService::new(AuthType::Token);
        let token = auth.generate_token(
            "test-agent".to_string(),
            AgentCapability::default(),
            None,
        ).await.unwrap();

        auth.revoke_token(&token.token).await.unwrap();
        assert!(auth.validate_token(&token.token).await.is_err());

        // Revoking again (unknown token) is not an error.
        assert!(auth.revoke_token(&token.token).await.is_ok());
    }

    #[tokio::test]
    async fn test_cleanup_expired_tokens_keeps_live_tokens() {
        let auth = AuthService::new(AuthType::Token);

        let live = auth.generate_token(
            "live".to_string(),
            AgentCapability::default(),
            Some(Duration::hours(1)),
        ).await.unwrap();
        let forever = auth.generate_token(
            "forever".to_string(),
            AgentCapability::default(),
            None,
        ).await.unwrap();
        auth.generate_token(
            "stale".to_string(),
            AgentCapability::default(),
            Some(Duration::seconds(-1)),
        ).await.unwrap();

        auth.cleanup_expired_tokens().await;

        assert_eq!(auth.tokens.read().await.len(), 2);
        assert!(auth.validate_token(&live.token).await.is_ok());
        assert!(auth.validate_token(&forever.token).await.is_ok());
    }

    #[tokio::test]
    async fn test_api_key() {
        let auth = AuthService::new(AuthType::ApiKey);

        auth.register_api_key("test-key".to_string(), "test-agent".to_string())
            .await
            .unwrap();

        let agent_id = auth.validate_api_key("test-key").await.unwrap();
        assert_eq!(agent_id, "test-agent");
    }

    #[tokio::test]
    async fn test_unknown_api_key_is_rejected() {
        let auth = AuthService::new(AuthType::ApiKey);
        assert!(auth.validate_api_key("missing-key").await.is_err());
    }

    #[tokio::test]
    async fn test_authenticate_dispatches_on_auth_type() {
        let anonymous = AuthService::new(AuthType::None);
        assert_eq!(anonymous.authenticate("anything").await.unwrap(), "anonymous");

        let by_token = AuthService::new(AuthType::Token);
        let token = by_token.generate_token(
            "token-agent".to_string(),
            AgentCapability::default(),
            None,
        ).await.unwrap();
        assert_eq!(by_token.authenticate(&token.token).await.unwrap(), "token-agent");

        let by_key = AuthService::new(AuthType::ApiKey);
        by_key
            .register_api_key("key".to_string(), "key-agent".to_string())
            .await
            .unwrap();
        assert_eq!(by_key.authenticate("key").await.unwrap(), "key-agent");
        assert!(by_key.authenticate("wrong-key").await.is_err());
    }

    #[tokio::test]
    async fn test_auth_guard() {
        let auth = AuthService::new(AuthType::Token);
        // Set every capability the guard requires explicitly instead of relying
        // on whatever `AgentCapability::default()` happens to enable.
        let mut capabilities = AgentCapability::default();
        capabilities.can_write_code = true;
        capabilities.can_coordinate = true;

        let token = auth.generate_token(
            "test-agent".to_string(),
            capabilities,
            None,
        ).await.unwrap();

        let guard = AuthGuard::new()
            .require_write_code()
            .require_coordination();

        let result = guard.check(&auth, &token.token).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_auth_guard_rejects_missing_capability() {
        let auth = AuthService::new(AuthType::Token);
        let mut capabilities = AgentCapability::default();
        capabilities.can_coordinate = false;

        let token = auth.generate_token(
            "test-agent".to_string(),
            capabilities,
            None,
        ).await.unwrap();

        let guard = AuthGuard::new().require_coordination();
        assert!(guard.check(&auth, &token.token).await.is_err());
    }
}