Skip to main content

deepstrike_core/memory/
session.rs

1use std::collections::VecDeque;
2
3use crate::types::message::Message;
4
/// How much of a stored session is replayed into a new run; applied by [`restore`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RestorePolicy {
    /// Restore nothing (the default).
    #[default]
    None,
    /// Restore every stored message unchanged.
    TranscriptOnly,
    /// Restore only the newest messages that fit the [`RestoreConfig`] limits.
    Window,
    /// Currently restores every stored message, exactly like `TranscriptOnly`.
    /// NOTE(review): presumably intended to also restore runtime state — confirm.
    RuntimeState,
}
13
/// Limits and switches used when restoring a session at the start of a run.
///
/// `max_messages` and `max_chars` bound the [`RestorePolicy::Window`] policy in
/// [`restore`]. `include_context` and `include_events` are not read by `restore`.
/// NOTE(review): confirm which caller consumes those two flags.
#[derive(Debug, Clone)]
pub struct RestoreConfig {
    /// Largest number of messages the window policy may keep.
    pub max_messages: usize,
    /// Largest summed `text_len()` of the messages the window policy may keep.
    pub max_chars: usize,
    /// Switch carried with the config (defaults to `true`).
    pub include_context: bool,
    /// Switch carried with the config (defaults to `false`).
    pub include_events: bool,
}

impl Default for RestoreConfig {
    /// Window of 20 messages / 8000 characters, context on, events off.
    fn default() -> Self {
        RestoreConfig {
            include_events: false,
            include_context: true,
            max_chars: 8_000,
            max_messages: 20,
        }
    }
}
32
33/// Apply `policy` to `messages` and return the subset to inject at run start.
34pub fn restore(
35    policy: &RestorePolicy,
36    config: &RestoreConfig,
37    messages: &[Message],
38) -> Vec<Message> {
39    match policy {
40        RestorePolicy::None => vec![],
41        RestorePolicy::TranscriptOnly => messages.to_vec(),
42        RestorePolicy::Window => {
43            let mut result: Vec<Message> = Vec::new();
44            let mut chars = 0usize;
45            for msg in messages.iter().rev() {
46                let len = msg.content.text_len();
47                if result.len() >= config.max_messages || chars + len > config.max_chars {
48                    break;
49                }
50                result.push(msg.clone());
51                chars += len;
52            }
53            result.reverse();
54            result
55        }
56        RestorePolicy::RuntimeState => messages.to_vec(),
57    }
58}
59
60/// Session memory: message history that persists across runs within a session.
61#[derive(Debug)]
62pub struct SessionMemory {
63    messages: VecDeque<Message>,
64    pub max_messages: usize,
65    pub max_tokens: u32,
66    current_tokens: u32,
67}
68
69impl Default for SessionMemory {
70    fn default() -> Self {
71        Self::new(100, u32::MAX)
72    }
73}
74
75impl SessionMemory {
76    pub fn new(max_messages: usize, max_tokens: u32) -> Self {
77        Self {
78            messages: VecDeque::new(),
79            max_messages,
80            max_tokens,
81            current_tokens: 0,
82        }
83    }
84
85    pub fn push(&mut self, msg: Message) {
86        let tokens = msg.token_count.unwrap_or(0);
87        self.messages.push_back(msg);
88        self.current_tokens += tokens;
89        while (self.messages.len() > self.max_messages || self.current_tokens > self.max_tokens)
90            && !self.messages.is_empty()
91        {
92            let removed = self.messages.pop_front().unwrap();
93            self.current_tokens = self
94                .current_tokens
95                .saturating_sub(removed.token_count.unwrap_or(0));
96        }
97    }
98
99    pub fn token_count(&self) -> u32 {
100        self.current_tokens
101    }
102
103    /// Returns messages as a slice for read-only access.
104    pub fn as_slice(&self) -> Vec<&Message> {
105        self.messages.iter().collect()
106    }
107
108    pub fn recent(&self, n: usize) -> Vec<&Message> {
109        let start = self.messages.len().saturating_sub(n);
110        self.messages.iter().skip(start).collect()
111    }
112
113    pub fn to_vec(&self) -> Vec<Message> {
114        self.messages.iter().cloned().collect()
115    }
116
117    pub fn len(&self) -> usize {
118        self.messages.len()
119    }
120    pub fn is_empty(&self) -> bool {
121        self.messages.is_empty()
122    }
123
124    pub fn clear(&mut self) {
125        self.messages.clear();
126        self.current_tokens = 0;
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `msg` with its `token_count` set; `SessionMemory::push` reads this field.
    fn with_tokens(mut msg: Message, tokens: u32) -> Message {
        msg.token_count = Some(tokens);
        msg
    }

    #[test]
    fn evicts_oldest_when_full() {
        let mut mem = SessionMemory::new(2, 10000);
        let mut m1 = Message::user("first");
        m1.token_count = Some(10);
        let mut m2 = Message::user("second");
        m2.token_count = Some(10);
        let mut m3 = Message::user("third");
        m3.token_count = Some(10);

        mem.push(m1);
        mem.push(m2);
        mem.push(m3);

        let msgs = mem.to_vec();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].content.as_text().unwrap(), "second");
        // The evicted message's tokens must be subtracted from the running total.
        assert_eq!(mem.token_count(), 20);
    }

    #[test]
    fn evicts_oldest_when_over_token_budget() {
        let mut mem = SessionMemory::new(10, 25);
        mem.push(with_tokens(Message::user("first"), 10));
        mem.push(with_tokens(Message::user("second"), 10));
        mem.push(with_tokens(Message::user("third"), 10));

        assert_eq!(mem.len(), 2);
        assert_eq!(mem.token_count(), 20);
        assert_eq!(mem.to_vec()[0].content.as_text().unwrap(), "second");
    }

    #[test]
    fn message_over_token_budget_is_not_retained() {
        let mut mem = SessionMemory::new(10, 5);
        mem.push(with_tokens(Message::user("huge"), 6));

        assert!(mem.is_empty());
        assert_eq!(mem.token_count(), 0);
    }

    #[test]
    fn recent_returns_newest_messages_oldest_first() {
        let mut mem = SessionMemory::new(10, 10000);
        mem.push(with_tokens(Message::user("first"), 1));
        mem.push(with_tokens(Message::user("second"), 1));
        mem.push(with_tokens(Message::user("third"), 1));

        let recent = mem.recent(2);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].content.as_text().unwrap(), "second");
        assert_eq!(recent[1].content.as_text().unwrap(), "third");
        assert!(mem.recent(0).is_empty());
        assert_eq!(mem.recent(99).len(), 3);
    }

    #[test]
    fn clear_resets_messages_and_tokens() {
        let mut mem = SessionMemory::new(10, 10000);
        mem.push(with_tokens(Message::user("first"), 7));
        mem.clear();

        assert!(mem.is_empty());
        assert_eq!(mem.token_count(), 0);
    }

    #[test]
    fn restore_applies_policy() {
        let msgs = vec![
            Message::user("first"),
            Message::user("second"),
            Message::user("third"),
        ];
        let config = RestoreConfig {
            max_messages: 2,
            ..RestoreConfig::default()
        };

        assert!(restore(&RestorePolicy::None, &config, &msgs).is_empty());
        assert_eq!(restore(&RestorePolicy::TranscriptOnly, &config, &msgs).len(), 3);

        let window = restore(&RestorePolicy::Window, &config, &msgs);
        assert_eq!(window.len(), 2);
        assert_eq!(window[0].content.as_text().unwrap(), "second");
        assert_eq!(window[1].content.as_text().unwrap(), "third");
    }
}