Skip to main content

moloch_core/agent/
session.rs

1//! Session types for agent accountability.
2//!
3//! A session is a bounded context for agent operations. Every agent action
4//! occurs within a session, which defines:
5//!
6//! - The principal who initiated the session
7//! - Maximum duration and depth limits
8//! - Session-level capability constraints
9
10use serde::{Deserialize, Serialize};
11use std::fmt;
12use std::time::Duration;
13
14use crate::crypto::{hash, Hash};
15use crate::error::{Error, Result};
16
17use super::capability::CapabilitySet;
18use super::principal::PrincipalId;
19
/// Unique session identifier.
///
/// An opaque 16-byte value. [`SessionId::random`] fills it from
/// `rand::thread_rng`, [`SessionId::to_hex`] renders all 16 bytes as 32 hex
/// characters, and the `Display` impl shows only the first 8 hex characters
/// as a short, log-friendly form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionId(pub [u8; 16]);
23
24impl SessionId {
25    /// Create a new random session ID.
26    pub fn random() -> Self {
27        use rand::Rng;
28        let mut bytes = [0u8; 16];
29        rand::thread_rng().fill(&mut bytes);
30        Self(bytes)
31    }
32
33    /// Create from bytes.
34    pub fn from_bytes(bytes: [u8; 16]) -> Self {
35        Self(bytes)
36    }
37
38    /// Get the raw bytes.
39    pub fn as_bytes(&self) -> &[u8; 16] {
40        &self.0
41    }
42
43    /// Convert to hex string.
44    pub fn to_hex(&self) -> String {
45        hex::encode(self.0)
46    }
47
48    /// Parse from hex string.
49    pub fn from_hex(s: &str) -> Result<Self> {
50        let bytes = hex::decode(s).map_err(|e| Error::invalid_input(e.to_string()))?;
51        if bytes.len() != 16 {
52            return Err(Error::invalid_input(format!(
53                "SessionId must be 16 bytes, got {}",
54                bytes.len()
55            )));
56        }
57        let mut arr = [0u8; 16];
58        arr.copy_from_slice(&bytes);
59        Ok(Self(arr))
60    }
61}
62
63impl fmt::Display for SessionId {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        write!(f, "{}", &self.to_hex()[..8]) // Short display
66    }
67}
68
/// A bounded context for agent operations.
///
/// Sessions provide:
/// - Traceability: All events link to a session
/// - Time bounds: Sessions expire after max_duration
/// - Depth limits: Prevents runaway agent spawning
/// - Scope: Session-level capability constraints
///
/// Construct one with [`Session::builder`]; fields are private and exposed
/// through read-only accessors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Unique session identifier.
    id: SessionId,

    /// Human principal who initiated the session.
    principal: PrincipalId,

    /// When the session started (Unix timestamp ms).
    ///
    /// The builder falls back to `chrono::Utc::now()` when no start time is given.
    started_at: i64,

    /// When the session ended (None if active). Set once by `Session::end`.
    ended_at: Option<i64>,

    /// Maximum session duration.
    max_duration: Duration,

    /// Maximum causal depth permitted.
    ///
    /// NOTE(review): nothing in this type enforces it; presumably checked by
    /// the code that creates nested actions — confirm.
    max_depth: u32,

    /// Session-level capability constraints (Section 3.2.2).
    ///
    /// Defines what the agent is allowed to do within this session.
    /// None means capabilities are managed externally.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    capabilities: Option<CapabilitySet>,

    /// Human-readable session purpose (empty string if none was given).
    purpose: String,

    /// Total events in this session, incremented by `Session::record_event`.
    event_count: u64,

    /// Total actions taken in this session, incremented by `Session::record_action`.
    action_count: u64,
}
112
113impl Session {
114    /// Default maximum depth for causal chains.
115    pub const DEFAULT_MAX_DEPTH: u32 = 10;
116
117    /// Default maximum session duration (1 hour).
118    pub const DEFAULT_MAX_DURATION: Duration = Duration::from_secs(3600);
119
120    /// Create a new session builder.
121    pub fn builder() -> SessionBuilder {
122        SessionBuilder::new()
123    }
124
125    /// Get the session ID.
126    pub fn id(&self) -> SessionId {
127        self.id
128    }
129
130    /// Get the principal who initiated this session.
131    pub fn principal(&self) -> &PrincipalId {
132        &self.principal
133    }
134
135    /// Get when the session started.
136    pub fn started_at(&self) -> i64 {
137        self.started_at
138    }
139
140    /// Get when the session ended, if it has.
141    pub fn ended_at(&self) -> Option<i64> {
142        self.ended_at
143    }
144
145    /// Get the maximum allowed duration.
146    pub fn max_duration(&self) -> Duration {
147        self.max_duration
148    }
149
150    /// Get the maximum allowed causal depth.
151    pub fn max_depth(&self) -> u32 {
152        self.max_depth
153    }
154
155    /// Get the session-level capabilities (Section 3.2.2).
156    pub fn capabilities(&self) -> Option<&CapabilitySet> {
157        self.capabilities.as_ref()
158    }
159
160    /// Get the session purpose.
161    pub fn purpose(&self) -> &str {
162        &self.purpose
163    }
164
165    /// Check if the session is still active (not ended).
166    pub fn is_active(&self) -> bool {
167        self.ended_at.is_none()
168    }
169
170    /// Check if the session has expired based on max_duration.
171    pub fn is_expired(&self, current_time: i64) -> bool {
172        let elapsed_ms = current_time.saturating_sub(self.started_at);
173        let max_ms = self.max_duration.as_millis() as i64;
174        elapsed_ms > max_ms
175    }
176
177    /// Get remaining duration before expiry.
178    ///
179    /// Returns None if already expired.
180    pub fn remaining_duration(&self, current_time: i64) -> Option<Duration> {
181        let elapsed_ms = current_time.saturating_sub(self.started_at);
182        let max_ms = self.max_duration.as_millis() as i64;
183        if elapsed_ms >= max_ms {
184            None
185        } else {
186            Some(Duration::from_millis((max_ms - elapsed_ms) as u64))
187        }
188    }
189
190    /// End the session.
191    ///
192    /// # Errors
193    /// Returns error if session is already ended.
194    pub fn end(&mut self, current_time: i64, reason: SessionEndReason) -> Result<SessionSummary> {
195        if self.ended_at.is_some() {
196            return Err(Error::invalid_input("Session already ended"));
197        }
198
199        self.ended_at = Some(current_time);
200
201        Ok(SessionSummary {
202            session_id: self.id,
203            reason,
204            duration: Duration::from_millis((current_time - self.started_at) as u64),
205            event_count: self.event_count,
206            action_count: self.action_count,
207        })
208    }
209
210    /// Record an event in this session.
211    pub fn record_event(&mut self) {
212        self.event_count += 1;
213    }
214
215    /// Record an action in this session.
216    pub fn record_action(&mut self) {
217        self.action_count += 1;
218    }
219
220    /// Compute a unique hash for this session.
221    pub fn hash(&self) -> Hash {
222        let mut data = Vec::new();
223        data.extend_from_slice(&self.id.0);
224        data.extend_from_slice(self.principal.id().as_bytes());
225        data.extend_from_slice(&self.started_at.to_le_bytes());
226        hash(&data)
227    }
228}
229
/// Builder for creating sessions.
///
/// Every field is optional. [`SessionBuilder::build`] fails only when no
/// principal was supplied; otherwise it fills the gaps with a random ID, the
/// current time, [`Session::DEFAULT_MAX_DURATION`], [`Session::DEFAULT_MAX_DEPTH`],
/// no capabilities, and an empty purpose.
#[derive(Debug, Default)]
pub struct SessionBuilder {
    id: Option<SessionId>,
    principal: Option<PrincipalId>,
    /// Unix timestamp in milliseconds.
    started_at: Option<i64>,
    max_duration: Option<Duration>,
    max_depth: Option<u32>,
    capabilities: Option<CapabilitySet>,
    purpose: Option<String>,
}
241
242impl SessionBuilder {
243    /// Create a new session builder.
244    pub fn new() -> Self {
245        Self::default()
246    }
247
248    /// Set the session ID (random if not specified).
249    pub fn id(mut self, id: SessionId) -> Self {
250        self.id = Some(id);
251        self
252    }
253
254    /// Set the principal who initiated the session.
255    pub fn principal(mut self, principal: PrincipalId) -> Self {
256        self.principal = Some(principal);
257        self
258    }
259
260    /// Set the start time (current time if not specified).
261    pub fn started_at(mut self, timestamp: i64) -> Self {
262        self.started_at = Some(timestamp);
263        self
264    }
265
266    /// Set the maximum duration.
267    pub fn max_duration(mut self, duration: Duration) -> Self {
268        self.max_duration = Some(duration);
269        self
270    }
271
272    /// Set the maximum causal depth.
273    pub fn max_depth(mut self, depth: u32) -> Self {
274        self.max_depth = Some(depth);
275        self
276    }
277
278    /// Set session-level capabilities (Section 3.2.2).
279    pub fn capabilities(mut self, capabilities: CapabilitySet) -> Self {
280        self.capabilities = Some(capabilities);
281        self
282    }
283
284    /// Set the session purpose.
285    pub fn purpose(mut self, purpose: impl Into<String>) -> Self {
286        self.purpose = Some(purpose.into());
287        self
288    }
289
290    /// Build the session.
291    ///
292    /// # Errors
293    /// Returns error if principal is not set.
294    pub fn build(self) -> Result<Session> {
295        let principal = self
296            .principal
297            .ok_or_else(|| Error::invalid_input("Session requires a principal"))?;
298
299        let started_at = self
300            .started_at
301            .unwrap_or_else(|| chrono::Utc::now().timestamp_millis());
302
303        Ok(Session {
304            id: self.id.unwrap_or_else(SessionId::random),
305            principal,
306            started_at,
307            ended_at: None,
308            max_duration: self.max_duration.unwrap_or(Session::DEFAULT_MAX_DURATION),
309            max_depth: self.max_depth.unwrap_or(Session::DEFAULT_MAX_DEPTH),
310            capabilities: self.capabilities,
311            purpose: self.purpose.unwrap_or_default(),
312            event_count: 0,
313            action_count: 0,
314        })
315    }
316}
317
/// Reason for session ending.
///
/// Serialized as an internally tagged enum: the snake_case variant name is
/// stored under the `"reason"` key next to any variant fields, e.g.
/// `{"reason":"error_terminated","error":"..."}` in JSON.
///
/// NOTE(review): `SessionSummary` also names its field `reason`, so a summary
/// serializes with a nested `"reason": {"reason": "..."}` object.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "reason")]
pub enum SessionEndReason {
    /// Session completed normally.
    Completed,

    /// Session timed out.
    Timeout,

    /// User terminated the session.
    UserTerminated,

    /// Session terminated due to error.
    ErrorTerminated {
        /// Error description.
        error: String,
    },

    /// Session terminated by emergency action.
    EmergencyTerminated {
        /// ID of the emergency event.
        emergency_id: String,
    },
}
343
/// Summary of a completed session, returned by `Session::end`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSummary {
    /// Session ID.
    pub session_id: SessionId,

    /// Reason the session ended.
    pub reason: SessionEndReason,

    /// Total session duration: the end time minus `started_at`, at
    /// millisecond resolution.
    pub duration: Duration,

    /// Total events recorded (value of the session's event counter at end).
    pub event_count: u64,

    /// Total actions taken (value of the session's action counter at end).
    pub action_count: u64,
}
362
#[cfg(test)]
mod tests {
    use super::*;

    fn test_principal() -> PrincipalId {
        PrincipalId::user("alice").unwrap()
    }

    fn now_ms() -> i64 {
        chrono::Utc::now().timestamp_millis()
    }

    /// A session for the test principal with every other field defaulted.
    fn default_session() -> Session {
        Session::builder()
            .principal(test_principal())
            .build()
            .expect("a builder with a principal must succeed")
    }

    /// A session starting at `start` that may run for exactly 60 seconds.
    fn minute_session(start: i64) -> Session {
        Session::builder()
            .principal(test_principal())
            .started_at(start)
            .max_duration(Duration::from_secs(60))
            .build()
            .expect("a builder with a principal must succeed")
    }

    // === SessionId Tests ===

    #[test]
    fn session_id_generates_unique() {
        assert_ne!(SessionId::random(), SessionId::random());
    }

    #[test]
    fn session_id_hex_roundtrip() {
        let original = SessionId::random();
        let round_tripped = SessionId::from_hex(&original.to_hex()).unwrap();
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn session_id_from_bytes() {
        let raw = [1u8; 16];
        assert_eq!(SessionId::from_bytes(raw).as_bytes(), &raw);
    }

    // === Session Lifecycle Tests ===

    #[test]
    fn session_requires_principal() {
        assert!(Session::builder().build().is_err());
    }

    #[test]
    fn session_created_with_defaults() {
        let session = default_session();

        assert!(session.is_active());
        assert_eq!(session.max_depth(), Session::DEFAULT_MAX_DEPTH);
        assert_eq!(session.max_duration(), Session::DEFAULT_MAX_DURATION);
    }

    #[test]
    fn session_tracks_started_at() {
        let before = now_ms();
        let session = default_session();
        let after = now_ms();

        assert!((before..=after).contains(&session.started_at()));
    }

    #[test]
    fn session_ended_at_none_when_active() {
        let session = default_session();

        assert_eq!(session.ended_at(), None);
        assert!(session.is_active());
    }

    #[test]
    fn session_end_sets_ended_at() {
        let mut session = default_session();
        let finish = now_ms() + 1000;

        let summary = session.end(finish, SessionEndReason::Completed).unwrap();

        assert!(!session.is_active());
        assert_eq!(session.ended_at(), Some(finish));
        assert_eq!(summary.reason, SessionEndReason::Completed);
    }

    #[test]
    fn session_cannot_end_twice() {
        let mut session = default_session();

        session.end(now_ms(), SessionEndReason::Completed).unwrap();
        let second = session.end(now_ms(), SessionEndReason::Completed);

        assert!(second.is_err());
    }

    // === Duration Tests ===

    #[test]
    fn session_is_expired_after_max_duration() {
        let start = now_ms();
        let session = minute_session(start);

        // Still live at the start and at 59 seconds.
        assert!(!session.is_expired(start));
        assert!(!session.is_expired(start + 59_000));

        // Past the 60-second limit.
        assert!(session.is_expired(start + 61_000));
    }

    #[test]
    fn session_remaining_duration_decreases() {
        let start = now_ms();
        let session = minute_session(start);

        let at_start = session.remaining_duration(start).unwrap();
        let ten_seconds_in = session.remaining_duration(start + 10_000).unwrap();

        assert!(ten_seconds_in < at_start);
    }

    #[test]
    fn session_remaining_duration_none_when_expired() {
        let start = now_ms();
        let session = minute_session(start);

        assert!(session.remaining_duration(start + 70_000).is_none());
    }

    // === Depth and Purpose Tests ===

    #[test]
    fn session_max_depth_configurable() {
        let session = Session::builder()
            .principal(test_principal())
            .max_depth(5)
            .build()
            .unwrap();

        assert_eq!(session.max_depth(), 5);
    }

    #[test]
    fn session_purpose_recorded() {
        let session = Session::builder()
            .principal(test_principal())
            .purpose("Code review task")
            .build()
            .unwrap();

        assert_eq!(session.purpose(), "Code review task");
    }

    // === Event Counting Tests ===

    #[test]
    fn session_counts_events() {
        let mut session = default_session();

        for _ in 0..2 {
            session.record_event();
        }
        session.record_action();

        let summary = session.end(now_ms(), SessionEndReason::Completed).unwrap();
        assert_eq!(summary.event_count, 2);
        assert_eq!(summary.action_count, 1);
    }

    // === Hash Tests ===

    #[test]
    fn same_session_same_hash() {
        let id = SessionId::random();
        let principal = test_principal();
        let started = now_ms();

        let make = || {
            Session::builder()
                .id(id)
                .principal(principal.clone())
                .started_at(started)
                .build()
                .unwrap()
        };

        assert_eq!(make().hash(), make().hash());
    }

    #[test]
    fn different_sessions_different_hash() {
        // Each session gets its own random ID, so the hashes must differ.
        assert_ne!(default_session().hash(), default_session().hash());
    }
}