oxios_kernel/space/
conversation_buffer.rs1use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::VecDeque;
10
11use super::SpaceId;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ConversationTurn {
16 pub user: String,
18 pub agent: String,
20 pub space_id: SpaceId,
22 pub timestamp: DateTime<Utc>,
24}
25
26#[derive(Debug, Clone)]
28pub struct ConversationBuffer {
29 turns: VecDeque<ConversationTurn>,
31 max_turns: usize,
33 turns_since_topic_check: usize,
35 last_space_id: Option<SpaceId>,
37}
38
39impl Default for ConversationBuffer {
40 fn default() -> Self {
41 Self::new(50)
42 }
43}
44
45impl ConversationBuffer {
46 pub fn new(max_turns: usize) -> Self {
48 Self {
49 turns: VecDeque::with_capacity(max_turns),
50 max_turns,
51 turns_since_topic_check: 0,
52 last_space_id: None,
53 }
54 }
55
56 pub fn push_user(&mut self, message: &str) {
58 let turn = ConversationTurn {
59 user: message.to_string(),
60 agent: String::new(), space_id: SpaceId::nil(), timestamp: Utc::now(),
63 };
64
65 if let Some(last) = self.turns.back_mut() {
67 if last.agent.is_empty() && last.space_id == SpaceId::nil() {
68 last.user = message.to_string();
69 last.timestamp = Utc::now();
70 return;
71 }
72 }
73
74 self.turns.push_back(turn);
75
76 while self.turns.len() > self.max_turns {
78 self.turns.pop_front();
79 }
80 }
81
82 pub fn push_agent(&mut self, response: &str, space_id: &SpaceId) {
84 if let Some(last) = self.turns.back_mut() {
85 last.agent = truncate_response(response, 200);
86 last.space_id = *space_id;
87 self.last_space_id = Some(*space_id);
88 }
89 }
90
91 pub fn recent(&self, n: usize) -> Vec<&ConversationTurn> {
93 self.turns.iter().rev().take(n).collect()
94 }
95
96 pub fn turns(&self) -> std::collections::VecDeque<ConversationTurn> {
98 self.turns.clone()
99 }
100
101 pub fn len(&self) -> usize {
103 self.turns.len()
104 }
105
106 pub fn is_empty(&self) -> bool {
108 self.turns.is_empty()
109 }
110
111 pub fn should_check_topic(&self, min_turns: usize) -> bool {
116 self.turns_since_topic_check >= min_turns || self.pattern_changed()
117 }
118
119 pub fn mark_topic_checked(&mut self) {
121 self.turns_since_topic_check = 0;
122 }
123
124 pub fn record_turn(&mut self, min_turns: usize) -> bool {
126 self.turns_since_topic_check += 1;
127 self.should_check_topic(min_turns)
128 }
129
130 pub fn should_check_topic_from_messages(turns: &[ConversationTurn], _min_turns: usize) -> bool {
136 if turns.len() < 4 {
138 return false;
139 }
140
141 let recent = &turns[turns.len() - 2..];
143 let previous = &turns[turns.len() - 4..turns.len() - 2];
144
145 let avg_recent =
146 recent.iter().map(|t| word_count(&t.user)).sum::<usize>() as f64 / recent.len() as f64;
147 let avg_prev = previous.iter().map(|t| word_count(&t.user)).sum::<usize>() as f64
148 / previous.len() as f64;
149
150 let ratio = avg_recent / avg_prev.max(1.0);
151 !(0.5..=2.0).contains(&ratio)
152 }
153
154 pub fn pattern_changed(&self) -> bool {
159 if self.turns.len() < 4 {
160 return false;
161 }
162
163 let all_turns: Vec<_> = self.turns.iter().collect();
164
165 let recent = &all_turns[all_turns.len() - 2..];
167 let previous = &all_turns[all_turns.len() - 4..all_turns.len() - 2];
168
169 let avg_word_count_recent =
170 recent.iter().map(|t| word_count(&t.user)).sum::<usize>() as f64 / recent.len() as f64;
171
172 let avg_word_count_prev = previous.iter().map(|t| word_count(&t.user)).sum::<usize>()
173 as f64
174 / previous.len() as f64;
175
176 let ratio = avg_word_count_recent / avg_word_count_prev.max(1.0);
178 if !(0.5..=2.0).contains(&ratio) {
179 return true;
180 }
181
182 let domain_shift_keywords = [
184 ("code", "food"),
185 ("rust", "요리"),
186 ("bug", "저녁"),
187 ("file", "운동"),
188 ("import", "영화"),
189 ("commit", "음식"),
190 ("function", "게임"),
191 ("Cargo", "장보기"),
192 ];
193
194 let recent_text = recent
195 .iter()
196 .map(|t| t.user.to_lowercase())
197 .collect::<String>();
198 let prev_text = previous
199 .iter()
200 .map(|t| t.user.to_lowercase())
201 .collect::<String>();
202
203 for (prev_kw, recent_kw) in domain_shift_keywords {
204 let has_prev = prev_text.contains(prev_kw);
205 let has_recent = recent_text.contains(recent_kw);
206 if has_prev && !has_recent {
207 return true;
210 }
211 }
212
213 false
214 }
215
216 pub fn space_changed(&self) -> bool {
218 if self.turns.len() < 2 {
219 return false;
220 }
221
222 let all_turns: Vec<_> = self.turns.iter().collect();
223 let last = &all_turns[all_turns.len() - 1];
224 let prev = &all_turns[all_turns.len() - 2];
225
226 last.space_id != prev.space_id
227 }
228
229 pub fn last_space_id(&self) -> Option<SpaceId> {
231 self.last_space_id
232 }
233
234 pub fn clear(&mut self) {
236 self.turns.clear();
237 self.turns_since_topic_check = 0;
238 }
239}
240
/// Counts the words in `s`, where a word is a maximal run of non-whitespace characters.
///
/// Whitespace follows `char::is_whitespace` (Unicode `White_Space`), so tabs,
/// newlines and other Unicode spaces also separate words.
fn word_count(s: &str) -> usize {
    s.split(char::is_whitespace)
        .filter(|piece| !piece.is_empty())
        .count()
}
245
/// Truncates `response` to at most `max_len` bytes, appending `...` when anything was cut.
///
/// The cut is made at the nearest UTF-8 character boundary at or below `max_len`,
/// so multi-byte characters are never split and the kept prefix never exceeds the
/// byte budget. If no whole character fits, only `"..."` is returned. Strings that
/// already fit are returned unchanged.
fn truncate_response(response: &str, max_len: usize) -> String {
    if response.len() <= max_len {
        return response.to_string();
    }

    // Here max_len < response.len(), so every candidate index is in range; index 0
    // is always a char boundary, so the search always succeeds. A UTF-8 character
    // is at most 4 bytes, so this walks back at most 3 positions.
    let end = (0..=max_len)
        .rev()
        .find(|&idx| response.is_char_boundary(idx))
        .unwrap_or(0);

    if end == 0 {
        "...".to_string()
    } else {
        format!("{}...", &response[..end])
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Records one complete exchange: a user message followed by the agent's reply.
    fn exchange(buffer: &mut ConversationBuffer, user: &str, agent: &str, space: &SpaceId) {
        buffer.push_user(user);
        buffer.push_agent(agent, space);
    }

    #[test]
    fn test_push_user_and_agent() {
        let mut buffer = ConversationBuffer::new(10);
        assert!(buffer.is_empty());

        buffer.push_user("Hello, how are you?");
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer.turns[0].user, "Hello, how are you?");
        assert!(buffer.turns[0].agent.is_empty());

        buffer.push_agent("I'm doing well!", &SpaceId::nil());
        assert_eq!(buffer.turns[0].agent, "I'm doing well!");
    }

    #[test]
    fn test_max_capacity() {
        let mut buffer = ConversationBuffer::new(3);
        let space = SpaceId::nil();

        for i in 1..=5 {
            exchange(&mut buffer, &format!("msg{}", i), &format!("r{}", i), &space);
        }

        // Only the three newest exchanges survive eviction.
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.recent(1)[0].user, "msg5");
    }

    #[test]
    fn test_should_check_topic() {
        let mut buffer = ConversationBuffer::new(10);
        assert!(!buffer.should_check_topic(3));

        for _ in 0..3 {
            buffer.push_user("test");
            buffer.mark_topic_checked();
        }
        assert!(!buffer.should_check_topic(3));
    }

    #[test]
    fn test_pattern_changed_word_count() {
        let mut buffer = ConversationBuffer::new(10);
        let space = SpaceId::nil();

        for _ in 0..3 {
            exchange(&mut buffer, "hi", "hi", &space);
        }
        assert!(!buffer.pattern_changed());

        let long_message = "This is a very long message that contains many many many many many words to trigger the pattern detection";
        exchange(&mut buffer, long_message, "ok", &space);
        assert!(buffer.pattern_changed());
    }

    #[test]
    fn test_truncate_response() {
        assert_eq!(truncate_response("Hello", 10), "Hello");

        let truncated = truncate_response("This is a very long response", 10);
        assert!(truncated.ends_with("..."));
        assert_eq!(truncated.len(), 13);
    }

    #[test]
    fn test_recent_turns() {
        let mut buffer = ConversationBuffer::new(10);
        let space = SpaceId::nil();

        for i in 0..5 {
            exchange(&mut buffer, &format!("msg{}", i), &format!("resp{}", i), &space);
        }

        let newest_first = buffer.recent(3);
        assert_eq!(newest_first.len(), 3);
        assert_eq!(newest_first[0].user, "msg4");
        assert_eq!(newest_first[2].user, "msg2");
    }
}