use crate::{AuthToken, AuthType, AgentCapability, UbiquityError};
use chrono::{Utc, Duration};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
pub struct AuthService {
auth_type: AuthType,
tokens: Arc<RwLock<HashMap<String, AuthToken>>>,
api_keys: Arc<RwLock<HashMap<String, String>>>, secret_key: String,
}
impl AuthService {
pub fn new(auth_type: AuthType) -> Self {
Self {
auth_type,
tokens: Arc::new(RwLock::new(HashMap::new())),
api_keys: Arc::new(RwLock::new(HashMap::new())),
secret_key: generate_secret_key(),
}
}
pub async fn generate_token(
&self,
agent_id: String,
capabilities: AgentCapability,
expires_in: Option<Duration>,
) -> Result<AuthToken, UbiquityError> {
let token_string = generate_token_string();
let expires_at = expires_in.map(|d| Utc::now() + d);
let token = AuthToken {
token: token_string.clone(),
agent_id: agent_id.clone(),
capabilities,
expires_at,
};
self.tokens.write().await.insert(token_string, token.clone());
Ok(token)
}
pub async fn validate_token(&self, token_str: &str) -> Result<AuthToken, UbiquityError> {
let tokens = self.tokens.read().await;
match tokens.get(token_str) {
Some(token) => {
if let Some(expires_at) = token.expires_at {
if expires_at < Utc::now() {
return Err(UbiquityError::ConfigError("Token expired".into()));
}
}
Ok(token.clone())
}
None => Err(UbiquityError::ConfigError("Invalid token".into())),
}
}
pub async fn register_api_key(&self, api_key: String, agent_id: String) -> Result<(), UbiquityError> {
self.api_keys.write().await.insert(api_key, agent_id);
Ok(())
}
pub async fn validate_api_key(&self, api_key: &str) -> Result<String, UbiquityError> {
self.api_keys
.read()
.await
.get(api_key)
.cloned()
.ok_or_else(|| UbiquityError::ConfigError("Invalid API key".into()))
}
pub async fn authenticate(&self, credentials: &str) -> Result<String, UbiquityError> {
match self.auth_type {
AuthType::None => Ok("anonymous".to_string()),
AuthType::Token => {
let token = self.validate_token(credentials).await?;
Ok(token.agent_id)
}
AuthType::ApiKey => {
self.validate_api_key(credentials).await
}
AuthType::Certificate => {
Err(UbiquityError::ConfigError("Certificate auth not implemented".into()))
}
}
}
pub async fn check_capability(
&self,
token_str: &str,
check: impl FnOnce(&AgentCapability) -> bool,
) -> Result<bool, UbiquityError> {
let token = self.validate_token(token_str).await?;
Ok(check(&token.capabilities))
}
pub async fn revoke_token(&self, token_str: &str) -> Result<(), UbiquityError> {
self.tokens.write().await.remove(token_str);
Ok(())
}
pub async fn cleanup_expired_tokens(&self) {
let now = Utc::now();
let mut tokens = self.tokens.write().await;
tokens.retain(|_, token| {
token.expires_at.map(|exp| exp > now).unwrap_or(true)
});
}
}
/// Returns a fresh 32-character token drawn from `rand::thread_rng`.
///
/// Each character is chosen with `gen_range` from the ASCII alphanumerics
/// `[A-Za-z0-9]`.
fn generate_token_string() -> String {
    use rand::Rng;

    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    const LENGTH: usize = 32;

    let mut rng = rand::thread_rng();
    let mut out = String::with_capacity(LENGTH);
    for _ in 0..LENGTH {
        let pick = rng.gen_range(0..ALPHABET.len());
        out.push(char::from(ALPHABET[pick]));
    }
    out
}
/// Creates the random secret stored in `AuthService` at construction time.
///
/// Delegates to the token generator, so the secret has the same 32-character
/// alphanumeric format as a bearer token.
fn generate_secret_key() -> String {
    generate_token_string()
}
/// A reusable set of capability requirements checked against a bearer token.
///
/// Build one with [`AuthGuard::new`] (or `Default`) plus the `require*` builder
/// methods, then call [`AuthGuard::check`] for each request.
pub struct AuthGuard {
    /// Predicates that must all return `true` for the token's capabilities.
    required_capabilities: Vec<Box<dyn Fn(&AgentCapability) -> bool + Send + Sync>>,
}

impl AuthGuard {
    /// Creates a guard with no requirements; `check` then only validates the token.
    pub fn new() -> Self {
        Self {
            required_capabilities: Vec::new(),
        }
    }

    /// Adds an arbitrary capability requirement.
    ///
    /// The guard passes only if `predicate` returns `true` for the token's capabilities.
    #[must_use]
    pub fn require<F>(mut self, predicate: F) -> Self
    where
        F: Fn(&AgentCapability) -> bool + Send + Sync + 'static,
    {
        self.required_capabilities.push(Box::new(predicate));
        self
    }

    /// Requires the `can_write_code` capability.
    #[must_use]
    pub fn require_write_code(self) -> Self {
        self.require(|cap| cap.can_write_code)
    }

    /// Requires the `can_execute_commands` capability.
    #[must_use]
    pub fn require_execute_commands(self) -> Self {
        self.require(|cap| cap.can_execute_commands)
    }

    /// Requires the `can_coordinate` capability.
    #[must_use]
    pub fn require_coordination(self) -> Self {
        self.require(|cap| cap.can_coordinate)
    }

    /// Validates `token` with `auth_service` and checks every registered requirement.
    ///
    /// # Errors
    ///
    /// - Propagates any error from `AuthService::validate_token`.
    /// - `ConfigError("Insufficient capabilities")` if any requirement is unmet.
    pub async fn check(&self, auth_service: &AuthService, token: &str) -> Result<(), UbiquityError> {
        let token_data = auth_service.validate_token(token).await?;
        let satisfied = self
            .required_capabilities
            .iter()
            .all(|predicate| predicate(&token_data.capabilities));
        if satisfied {
            Ok(())
        } else {
            Err(UbiquityError::ConfigError("Insufficient capabilities".into()))
        }
    }
}

impl Default for AuthGuard {
    /// Equivalent to [`AuthGuard::new`]: a guard with no requirements.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Issues a token for `"test-agent"` with the given capabilities and lifetime.
    async fn issue(
        auth: &AuthService,
        capabilities: AgentCapability,
        expires_in: Option<Duration>,
    ) -> AuthToken {
        auth.generate_token("test-agent".to_string(), capabilities, expires_in)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_token_generation_and_validation() {
        let auth = AuthService::new(AuthType::Token);
        let token = issue(&auth, AgentCapability::default(), Some(Duration::hours(1))).await;
        let validated = auth.validate_token(&token.token).await.unwrap();
        assert_eq!(validated.agent_id, "test-agent");
    }

    #[tokio::test]
    async fn test_expired_token() {
        let auth = AuthService::new(AuthType::Token);
        let token = issue(&auth, AgentCapability::default(), Some(Duration::seconds(-1))).await;
        let result = auth.validate_token(&token.token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unknown_token_is_rejected() {
        let auth = AuthService::new(AuthType::Token);
        assert!(auth.validate_token("no-such-token").await.is_err());
    }

    #[tokio::test]
    async fn test_revoked_token_is_rejected() {
        let auth = AuthService::new(AuthType::Token);
        let token = issue(&auth, AgentCapability::default(), None).await;
        auth.revoke_token(&token.token).await.unwrap();
        assert!(auth.validate_token(&token.token).await.is_err());
    }

    #[tokio::test]
    async fn test_cleanup_removes_only_expired_tokens() {
        let auth = AuthService::new(AuthType::Token);
        issue(&auth, AgentCapability::default(), Some(Duration::seconds(-1))).await;
        let live = issue(&auth, AgentCapability::default(), None).await;

        auth.cleanup_expired_tokens().await;

        assert_eq!(auth.tokens.read().await.len(), 1);
        assert!(auth.validate_token(&live.token).await.is_ok());
    }

    #[tokio::test]
    async fn test_api_key() {
        let auth = AuthService::new(AuthType::ApiKey);
        auth.register_api_key("test-key".to_string(), "test-agent".to_string())
            .await
            .unwrap();
        let agent_id = auth.validate_api_key("test-key").await.unwrap();
        assert_eq!(agent_id, "test-agent");
        assert!(auth.validate_api_key("other-key").await.is_err());
    }

    #[tokio::test]
    async fn test_authenticate_dispatches_on_auth_type() {
        let anonymous = AuthService::new(AuthType::None);
        assert_eq!(anonymous.authenticate("anything").await.unwrap(), "anonymous");

        let keyed = AuthService::new(AuthType::ApiKey);
        keyed
            .register_api_key("test-key".to_string(), "test-agent".to_string())
            .await
            .unwrap();
        assert_eq!(keyed.authenticate("test-key").await.unwrap(), "test-agent");
        assert!(keyed.authenticate("wrong-key").await.is_err());
    }

    #[tokio::test]
    async fn test_auth_guard() {
        let auth = AuthService::new(AuthType::Token);
        // Set every required capability explicitly so the test does not depend
        // on what `AgentCapability::default()` happens to enable.
        let mut capabilities = AgentCapability::default();
        capabilities.can_write_code = true;
        capabilities.can_coordinate = true;
        let token = issue(&auth, capabilities, None).await;

        let guard = AuthGuard::new()
            .require_write_code()
            .require_coordination();
        let result = guard.check(&auth, &token.token).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_auth_guard_rejects_missing_capability() {
        let auth = AuthService::new(AuthType::Token);
        let mut capabilities = AgentCapability::default();
        capabilities.can_execute_commands = false;
        let token = issue(&auth, capabilities, None).await;

        let guard = AuthGuard::new().require_execute_commands();
        assert!(guard.check(&auth, &token.token).await.is_err());
    }
}