Skip to main content

ditto_os/security/
auth.rs

1use chrono::{DateTime, Utc};
2use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::RwLock;
8use tracing::{debug, info};
9use uuid::Uuid;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
12pub enum Permission {
13    CreateSession,
14    ExecuteCommand,
15    ViewSessions,
16    DeleteSession,
17    ManageAgents,
18    SystemAdmin,
19}
20
21impl Permission {
22    pub fn as_str(&self) -> &'static str {
23        match self {
24            Permission::CreateSession => "create_session",
25            Permission::ExecuteCommand => "execute_command",
26            Permission::ViewSessions => "view_sessions",
27            Permission::DeleteSession => "delete_session",
28            Permission::ManageAgents => "manage_agents",
29            Permission::SystemAdmin => "system_admin",
30        }
31    }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct AgentCredentials {
36    pub id: String,
37    pub name: String,
38    pub api_key: String,
39    pub permissions: Vec<Permission>,
40    pub created_at: DateTime<Utc>,
41    pub last_used: Option<DateTime<Utc>>,
42    pub active: bool,
43    pub metadata: HashMap<String, String>,
44}
45
46#[derive(Debug, Serialize, Deserialize)]
47pub struct JwtClaims {
48    pub sub: String, // agent_id
49    pub name: String,
50    pub permissions: Vec<String>,
51    pub exp: usize,  // expiration time
52    pub iat: usize,  // issued at
53    pub jti: String, // JWT ID
54}
55
56pub struct AuthManager {
57    agents: Arc<RwLock<HashMap<String, AgentCredentials>>>,
58    jwt_secret: String,
59    token_expiry: Duration,
60}
61
62impl AuthManager {
63    pub fn new(jwt_secret: String) -> Self {
64        Self {
65            agents: Arc::new(RwLock::new(HashMap::new())),
66            jwt_secret,
67            token_expiry: Duration::from_secs(24 * 60 * 60), // 24 hours
68        }
69    }
70
71    pub async fn add_agent(&self, agent: AgentCredentials) -> Result<(), anyhow::Error> {
72        let mut agents = self.agents.write().await;
73        if agents.contains_key(&agent.id) {
74            return Err(anyhow::anyhow!("Agent with ID {} already exists", agent.id));
75        }
76
77        agents.insert(agent.id.clone(), agent.clone());
78        info!("Added new agent: {}", agent.name);
79        Ok(())
80    }
81
82    pub async fn remove_agent(&self, agent_id: &str) -> Result<(), anyhow::Error> {
83        let mut agents = self.agents.write().await;
84        if agents.remove(agent_id).is_none() {
85            return Err(anyhow::anyhow!("Agent with ID {} not found", agent_id));
86        }
87
88        info!("Removed agent: {}", agent_id);
89        Ok(())
90    }
91
92    pub async fn get_agent(&self, agent_id: &str) -> Option<AgentCredentials> {
93        let agents = self.agents.read().await;
94        agents.get(agent_id).cloned()
95    }
96
97    pub async fn list_agents(&self) -> Vec<AgentCredentials> {
98        let agents = self.agents.read().await;
99        agents.values().cloned().collect()
100    }
101
102    pub async fn validate_token(&self, token: &str) -> Result<String, anyhow::Error> {
103        let token_data = decode::<JwtClaims>(
104            token,
105            &DecodingKey::from_secret(self.jwt_secret.as_ref()),
106            &Validation::default(),
107        )?;
108
109        let claims = token_data.claims;
110        let agent_id = claims.sub.clone();
111
112        // Check if agent exists and is active
113        let agent = self
114            .get_agent(&agent_id)
115            .await
116            .ok_or_else(|| anyhow::anyhow!("Agent not found: {}", agent_id))?;
117
118        if !agent.active {
119            return Err(anyhow::anyhow!("Agent is inactive: {}", agent_id));
120        }
121
122        debug!("Successfully validated token for agent: {}", agent_id);
123        Ok(agent_id)
124    }
125
126    pub async fn authenticate_api_key(&self, api_key: &str) -> Result<String, anyhow::Error> {
127        let agents = self.agents.read().await;
128        let agent = agents
129            .values()
130            .find(|a| a.api_key == api_key && a.active)
131            .ok_or_else(|| anyhow::anyhow!("Invalid API key"))?;
132
133        info!("Agent authenticated via API key: {}", agent.name);
134        Ok(agent.id.clone())
135    }
136
137    pub async fn generate_token(&self, agent_id: &str) -> Result<String, anyhow::Error> {
138        let agent = self
139            .get_agent(agent_id)
140            .await
141            .ok_or_else(|| anyhow::anyhow!("Agent not found: {}", agent_id))?;
142
143        let now = Utc::now();
144        let exp = now
145            + chrono::Duration::from_std(self.token_expiry)
146                .map_err(|e| anyhow::anyhow!("Invalid expiration duration: {}", e))?;
147
148        let claims = JwtClaims {
149            sub: agent.id.clone(),
150            name: agent.name.clone(),
151            permissions: agent
152                .permissions
153                .iter()
154                .map(|p| p.as_str().to_string())
155                .collect(),
156            exp: exp.timestamp() as usize,
157            iat: now.timestamp() as usize,
158            jti: Uuid::new_v4().to_string(),
159        };
160
161        let token = encode(
162            &Header::default(),
163            &claims,
164            &EncodingKey::from_secret(self.jwt_secret.as_ref()),
165        )?;
166
167        info!("Generated JWT token for agent: {}", agent.name);
168        Ok(token)
169    }
170
171    pub async fn check_permission(&self, agent_id: &str, permission: &Permission) -> bool {
172        let agents = self.agents.read().await;
173        if let Some(agent) = agents.get(agent_id) {
174            agent.permissions.contains(permission)
175                || agent.permissions.contains(&Permission::SystemAdmin)
176        } else {
177            false
178        }
179    }
180
181    pub async fn update_last_used(&self, agent_id: &str) -> Result<(), anyhow::Error> {
182        let mut agents = self.agents.write().await;
183        if let Some(agent) = agents.get_mut(agent_id) {
184            agent.last_used = Some(Utc::now());
185        }
186        Ok(())
187    }
188
189    pub async fn get_agent_stats(&self) -> AgentStats {
190        let agents = self.agents.read().await;
191        let total_agents = agents.len();
192        let active_agents = agents.values().filter(|a| a.active).count();
193        let agents_with_recent_activity = agents
194            .values()
195            .filter(|a| {
196                if let Some(last_used) = a.last_used {
197                    Utc::now() - last_used < chrono::Duration::hours(24)
198                } else {
199                    false
200                }
201            })
202            .count();
203
204        AgentStats {
205            total_agents,
206            active_agents,
207            agents_with_recent_activity,
208        }
209    }
210
211    // Initialize with a default admin agent for development
212    pub async fn init_default_agent(&self) -> Result<String, anyhow::Error> {
213        let admin_agent = AgentCredentials {
214            id: "admin".to_string(),
215            name: "Default Admin Agent".to_string(),
216            api_key: "sk-ditto-admin-2024".to_string(),
217            permissions: vec![
218                Permission::CreateSession,
219                Permission::ExecuteCommand,
220                Permission::ViewSessions,
221                Permission::DeleteSession,
222                Permission::ManageAgents,
223                Permission::SystemAdmin,
224            ],
225            created_at: Utc::now(),
226            last_used: None,
227            active: true,
228            metadata: HashMap::new(),
229        };
230
231        let agent_id = admin_agent.id.clone();
232        self.add_agent(admin_agent).await?;
233        info!("Initialized default admin agent");
234
235        // Generate a JWT token for the admin agent
236        self.generate_token(&agent_id).await
237    }
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct AgentStats {
242    pub total_agents: usize,
243    pub active_agents: usize,
244    pub agents_with_recent_activity: usize,
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// HMAC secret shared by every manager built in these tests.
    const SECRET: &str = "test_secret";

    /// Fixture for all tests below: an active agent with ID `test_agent` and
    /// API key `test_key`, holding only `CreateSession`.
    fn sample_agent() -> AgentCredentials {
        AgentCredentials {
            id: "test_agent".to_string(),
            name: "Test Agent".to_string(),
            api_key: "test_key".to_string(),
            permissions: vec![Permission::CreateSession],
            created_at: Utc::now(),
            last_used: None,
            active: true,
            metadata: HashMap::new(),
        }
    }

    /// Builds a fresh manager with `sample_agent` already registered.
    async fn manager_with_sample_agent() -> AuthManager {
        let manager = AuthManager::new(SECRET.to_string());
        manager
            .add_agent(sample_agent())
            .await
            .expect("fixture agent should register on an empty manager");
        manager
    }

    #[tokio::test]
    async fn test_agent_authentication() {
        let manager = manager_with_sample_agent().await;

        // A registered key resolves to its agent's ID.
        let accepted = manager.authenticate_api_key("test_key").await;
        assert!(accepted.is_ok());
        assert_eq!(accepted.unwrap(), "test_agent");

        // An unknown key is rejected.
        let rejected = manager.authenticate_api_key("invalid_key").await;
        assert!(rejected.is_err());
    }

    #[tokio::test]
    async fn test_jwt_token_generation() {
        let manager = manager_with_sample_agent().await;

        // Issue a token, then verify that it round-trips to the same agent.
        let token = manager.generate_token("test_agent").await.unwrap();
        assert!(!token.is_empty());

        let subject = manager.validate_token(&token).await.unwrap();
        assert_eq!(subject, "test_agent");
    }

    #[tokio::test]
    async fn test_permission_checking() {
        let manager = manager_with_sample_agent().await;

        // (agent, permission, expected). The rows cover a granted permission,
        // a permission the agent lacks, and an agent that does not exist.
        let cases = [
            ("test_agent", Permission::CreateSession, true),
            ("test_agent", Permission::DeleteSession, false),
            ("non_existent", Permission::CreateSession, false),
        ];
        for (agent_id, permission, expected) in cases {
            assert_eq!(
                manager.check_permission(agent_id, &permission).await,
                expected,
                "check_permission({}, {:?})",
                agent_id,
                permission
            );
        }
    }
}