use crate::agent::dialogue::message::SentAgents;
use super::message::{DialogueMessage, MessageId, MessageOrigin, Speaker};
use std::collections::HashMap;
/// Insertion-ordered collection of [`DialogueMessage`]s with hash-map lookup by [`MessageId`].
///
/// The map owns the messages themselves; the vector only remembers the order in
/// which ids were pushed, so chronological iteration walks `message_order` and
/// resolves each id through `messages_by_id`.
#[derive(Clone, Debug)]
pub struct MessageStore {
    /// Every stored message, keyed by its id.
    messages_by_id: HashMap<MessageId, DialogueMessage>,
    /// Ids in push order; drives chronological iteration.
    message_order: Vec<MessageId>,
}
impl MessageStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            messages_by_id: HashMap::new(),
            message_order: Vec::new(),
        }
    }

    /// Appends `message` to the store.
    ///
    /// If a message with the same id is already stored, its contents are
    /// replaced in place and it keeps its original position. The id is NOT
    /// appended to the order a second time. Without this check a re-pushed id
    /// would show up twice in [`all_messages`](Self::all_messages) and be
    /// double-counted by [`len`](Self::len).
    pub fn push(&mut self, message: DialogueMessage) {
        let id = message.id;
        if self.messages_by_id.insert(id, message).is_none() {
            self.message_order.push(id);
        }
    }

    /// Looks up a message by id; `None` if no such message is stored.
    pub fn get(&self, id: MessageId) -> Option<&DialogueMessage> {
        self.messages_by_id.get(&id)
    }

    /// Iterates over stored messages in push order without allocating.
    fn iter_in_order(&self) -> impl Iterator<Item = &DialogueMessage> + '_ {
        self.message_order
            .iter()
            .filter_map(move |id| self.messages_by_id.get(id))
    }

    /// Returns every stored message in push order.
    pub fn all_messages(&self) -> Vec<&DialogueMessage> {
        self.iter_in_order().collect()
    }

    /// Returns the messages whose `turn` equals `turn`, in push order.
    pub fn messages_for_turn(&self, turn: usize) -> Vec<&DialogueMessage> {
        self.iter_in_order()
            .filter(|msg| msg.turn == turn)
            .collect()
    }

    /// Returns the highest `turn` among stored messages, or `0` when the store is empty.
    pub fn latest_turn(&self) -> usize {
        // Order is irrelevant for a maximum, so read the map directly instead of
        // materialising the ordered list first.
        self.messages_by_id
            .values()
            .map(|msg| msg.turn)
            .max()
            .unwrap_or(0)
    }

    /// Number of stored messages.
    pub fn len(&self) -> usize {
        self.message_order.len()
    }

    /// `true` when no messages are stored.
    pub fn is_empty(&self) -> bool {
        self.message_order.is_empty()
    }

    /// Removes every message.
    pub fn clear(&mut self) {
        self.messages_by_id.clear();
        self.message_order.clear();
    }

    /// Returns, in push order, the agent- and system-authored messages for which
    /// `sent_to_agents()` is still `false`. Messages from any other speaker kind
    /// (e.g. users) are never included.
    pub fn unsent_messages(&self) -> Vec<&DialogueMessage> {
        self.iter_in_order()
            .filter(|msg| {
                !msg.sent_to_agents()
                    && matches!(msg.speaker, Speaker::Agent { .. } | Speaker::System)
            })
            .collect()
    }

    /// Returns, in push order, the messages for which `sent_to_agents()` is
    /// `false` and whose metadata origin equals `origin`.
    pub fn unsent_messages_with_origin(&self, origin: MessageOrigin) -> Vec<&DialogueMessage> {
        self.iter_in_order()
            .filter(|msg| !msg.sent_to_agents() && msg.metadata.origin() == Some(origin))
            .collect()
    }

    /// Records that message `id` was sent to `agent`. Unknown ids are ignored.
    pub fn mark_as_sent(&mut self, id: MessageId, agent: Speaker) {
        if let Some(msg) = self.messages_by_id.get_mut(&id) {
            msg.sent(agent);
        }
    }

    /// Records that every stored message was sent to `agent`.
    pub fn mark_as_sent_all_for(&mut self, agent: Speaker) {
        // Iterate the map mutably rather than cloning `message_order` just to
        // dodge the borrow checker; ordering does not matter for marking.
        for msg in self.messages_by_id.values_mut() {
            msg.sent(agent.clone());
        }
    }

    /// Marks message `id` as sent to all agents. Unknown ids are ignored.
    pub fn mark_as_sent_all_agents(&mut self, id: MessageId) {
        if let Some(msg) = self.messages_by_id.get_mut(&id) {
            msg.sent_agents = SentAgents::All;
        }
    }

    /// Marks each message in `ids` as sent to all agents. Unknown ids are ignored.
    pub fn mark_all_as_sent(&mut self, ids: &[MessageId]) {
        for &id in ids {
            self.mark_as_sent_all_agents(id);
        }
    }
}
impl Default for MessageStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // `use super::*` already brings the parent's imports (DialogueMessage,
    // MessageId, MessageOrigin, Speaker) into scope.
    use super::*;
    use crate::agent::dialogue::MessageMetadata;

    #[test]
    fn test_message_store_basic_operations() {
        let mut store = MessageStore::new();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
        let msg1 = DialogueMessage::new(1, Speaker::System, "Hello".to_string());
        let msg1_id = msg1.id;
        store.push(msg1);
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        let retrieved = store
            .get(msg1_id)
            .expect("a pushed message must be retrievable by its id");
        assert_eq!(retrieved.content, "Hello");
    }

    #[test]
    fn test_default_is_empty() {
        let store = MessageStore::default();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.all_messages().is_empty());
        assert_eq!(store.latest_turn(), 0);
    }

    #[test]
    fn test_queries_on_missing_data() {
        let mut store = MessageStore::new();
        store.push(DialogueMessage::new(1, Speaker::System, "Only".to_string()));
        assert!(store.get(MessageId::new()).is_none());
        assert!(store.messages_for_turn(99).is_empty());
    }

    #[test]
    fn test_chronological_order() {
        let mut store = MessageStore::new();
        let msg1 = DialogueMessage::new(1, Speaker::System, "First".to_string());
        let msg2 = DialogueMessage::new(1, Speaker::agent("A", "Role"), "Second".to_string());
        let msg3 = DialogueMessage::new(2, Speaker::System, "Third".to_string());
        store.push(msg1);
        store.push(msg2);
        store.push(msg3);
        let all = store.all_messages();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].content, "First");
        assert_eq!(all[1].content, "Second");
        assert_eq!(all[2].content, "Third");
    }

    #[test]
    fn test_messages_for_turn() {
        let mut store = MessageStore::new();
        let msg1 = DialogueMessage::new(1, Speaker::System, "Turn 1".to_string());
        let msg2 = DialogueMessage::new(1, Speaker::agent("A", "Role"), "Response 1".to_string());
        let msg3 = DialogueMessage::new(2, Speaker::System, "Turn 2".to_string());
        let msg4 = DialogueMessage::new(2, Speaker::agent("B", "Role"), "Response 2".to_string());
        store.push(msg1);
        store.push(msg2);
        store.push(msg3);
        store.push(msg4);
        let turn1 = store.messages_for_turn(1);
        assert_eq!(turn1.len(), 2);
        assert_eq!(turn1[0].content, "Turn 1");
        assert_eq!(turn1[1].content, "Response 1");
        let turn2 = store.messages_for_turn(2);
        assert_eq!(turn2.len(), 2);
        assert_eq!(turn2[0].content, "Turn 2");
        assert_eq!(turn2[1].content, "Response 2");
    }

    #[test]
    fn test_latest_turn() {
        let mut store = MessageStore::new();
        assert_eq!(store.latest_turn(), 0);
        store.push(DialogueMessage::new(1, Speaker::System, "Prompt 1".to_string()));
        assert_eq!(store.latest_turn(), 1);
        store.push(DialogueMessage::new(1, Speaker::agent("A", "Role"), "Response".to_string()));
        assert_eq!(store.latest_turn(), 1);
        store.push(DialogueMessage::new(2, Speaker::System, "Prompt 2".to_string()));
        assert_eq!(store.latest_turn(), 2);
    }

    #[test]
    fn test_clear() {
        let mut store = MessageStore::new();
        store.push(DialogueMessage::new(1, Speaker::System, "Test".to_string()));
        assert_eq!(store.len(), 1);
        store.clear();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }

    #[test]
    fn test_clear_resets_lookups_and_latest_turn() {
        let mut store = MessageStore::new();
        let msg = DialogueMessage::new(3, Speaker::System, "Turn 3".to_string());
        let msg_id = msg.id;
        store.push(msg);
        assert_eq!(store.latest_turn(), 3);
        store.clear();
        assert!(store.get(msg_id).is_none());
        assert!(store.all_messages().is_empty());
        assert_eq!(store.latest_turn(), 0);
    }

    #[test]
    fn test_unsent_messages_returns_agent_and_system_messages() {
        let mut store = MessageStore::new();
        let msg1 = DialogueMessage::new(1, Speaker::System, "System prompt".to_string());
        let msg2 =
            DialogueMessage::new(1, Speaker::user("User", "Human"), "User input".to_string());
        let msg3 = DialogueMessage::new(
            1,
            Speaker::agent("Alice", "Engineer"),
            "Alice response".to_string(),
        );
        let msg4 = DialogueMessage::new(
            1,
            Speaker::agent("Bob", "Designer"),
            "Bob response".to_string(),
        );
        store.push(msg1);
        store.push(msg2);
        store.push(msg3);
        store.push(msg4);
        let unsent = store.unsent_messages();
        assert_eq!(unsent.len(), 3);
        assert_eq!(unsent[0].content, "System prompt");
        assert_eq!(unsent[1].content, "Alice response");
        assert_eq!(unsent[2].content, "Bob response");
    }

    #[test]
    fn test_unsent_messages_respects_sent_flag() {
        let mut store = MessageStore::new();
        let msg1 = DialogueMessage::new(
            1,
            Speaker::agent("Alice", "Engineer"),
            "Alice response".to_string(),
        );
        let msg1_id = msg1.id;
        let msg2 = DialogueMessage::new(
            1,
            Speaker::agent("Bob", "Designer"),
            "Bob response".to_string(),
        );
        let msg2_id = msg2.id;
        store.push(msg1);
        store.push(msg2);
        assert_eq!(store.unsent_messages().len(), 2);
        store.mark_as_sent(msg1_id, Speaker::agent("Bob", "Designer"));
        let unsent = store.unsent_messages();
        assert_eq!(unsent.len(), 1);
        assert_eq!(unsent[0].content, "Bob response");
        store.mark_as_sent(msg2_id, Speaker::agent("Alice", "Engineer"));
        assert_eq!(store.unsent_messages().len(), 0);
    }

    #[test]
    fn test_mark_as_sent_all_for_marks_every_message() {
        let mut store = MessageStore::new();
        let msg1 = DialogueMessage::new(1, Speaker::System, "Prompt".to_string());
        let msg1_id = msg1.id;
        let msg2 = DialogueMessage::new(
            1,
            Speaker::agent("Alice", "Engineer"),
            "Alice response".to_string(),
        );
        let msg2_id = msg2.id;
        store.push(msg1);
        store.push(msg2);
        assert_eq!(store.unsent_messages().len(), 2);
        store.mark_as_sent_all_for(Speaker::agent("Carol", "Reviewer"));
        assert!(store.unsent_messages().is_empty());
        assert!(store.get(msg1_id).unwrap().sent_to_agents());
        assert!(store.get(msg2_id).unwrap().sent_to_agents());
    }

    #[test]
    fn test_mark_all_as_sent() {
        let mut store = MessageStore::new();
        let msg1 = DialogueMessage::new(
            1,
            Speaker::agent("Alice", "Engineer"),
            "Alice response".to_string(),
        );
        let msg1_id = msg1.id;
        let msg2 = DialogueMessage::new(
            1,
            Speaker::agent("Bob", "Designer"),
            "Bob response".to_string(),
        );
        let msg2_id = msg2.id;
        let msg3 = DialogueMessage::new(
            1,
            Speaker::agent("Charlie", "Manager"),
            "Charlie response".to_string(),
        );
        let msg3_id = msg3.id;
        store.push(msg1);
        store.push(msg2);
        store.push(msg3);
        assert_eq!(store.unsent_messages().len(), 3);
        store.mark_all_as_sent(&[msg1_id, msg2_id]);
        let unsent = store.unsent_messages();
        assert_eq!(unsent.len(), 1);
        assert_eq!(unsent[0].content, "Charlie response");
        assert!(store.get(msg1_id).unwrap().sent_to_agents());
        assert!(store.get(msg2_id).unwrap().sent_to_agents());
        assert!(!store.get(msg3_id).unwrap().sent_to_agents());
    }

    #[test]
    fn test_unsent_messages_excludes_user_only() {
        let mut store = MessageStore::new();
        store.push(DialogueMessage::new(1, Speaker::System, "Turn 1 prompt".to_string()));
        store.push(DialogueMessage::new(
            1,
            Speaker::agent("Alice", "Engineer"),
            "Turn 1 Alice".to_string(),
        ));
        store.push(DialogueMessage::new(
            2,
            Speaker::user("User", "Human"),
            "Turn 2 user input".to_string(),
        ));
        store.push(DialogueMessage::new(
            2,
            Speaker::agent("Bob", "Designer"),
            "Turn 2 Bob".to_string(),
        ));
        let unsent = store.unsent_messages();
        assert_eq!(unsent.len(), 3);
        assert_eq!(unsent[0].content, "Turn 1 prompt");
        assert!(matches!(unsent[0].speaker, Speaker::System));
        assert_eq!(unsent[1].content, "Turn 1 Alice");
        assert!(matches!(unsent[1].speaker, Speaker::Agent { .. }));
        assert_eq!(unsent[2].content, "Turn 2 Bob");
        assert!(matches!(unsent[2].speaker, Speaker::Agent { .. }));
    }

    #[test]
    fn test_mark_as_sent_nonexistent_id() {
        let mut store = MessageStore::new();
        let msg = DialogueMessage::new(
            1,
            Speaker::agent("Alice", "Engineer"),
            "Response".to_string(),
        );
        let msg_id = msg.id;
        store.push(msg);
        let fake_id = MessageId::new();
        store.mark_as_sent(fake_id, Speaker::agent("Alice", "Engineer"));
        assert_eq!(store.unsent_messages().len(), 1);
        assert!(!store.get(msg_id).unwrap().sent_to_agents());
    }

    #[test]
    fn test_unsent_messages_with_origin_filters_results() {
        let mut store = MessageStore::new();
        let payload_metadata = MessageMetadata::new().with_origin(MessageOrigin::IncomingPayload);
        let payload_msg = DialogueMessage::new(1, Speaker::System, "Payload system".to_string())
            .with_metadata(&payload_metadata);
        let agent_metadata = MessageMetadata::new().with_origin(MessageOrigin::AgentGenerated);
        let agent_msg = DialogueMessage::new(
            1,
            Speaker::agent("Agent", "Role"),
            "Agent output".to_string(),
        )
        .with_metadata(&agent_metadata);
        store.push(payload_msg);
        store.push(agent_msg);
        let payload_results = store.unsent_messages_with_origin(MessageOrigin::IncomingPayload);
        assert_eq!(payload_results.len(), 1);
        assert_eq!(payload_results[0].content, "Payload system");
        let agent_results = store.unsent_messages_with_origin(MessageOrigin::AgentGenerated);
        assert_eq!(agent_results.len(), 1);
        assert_eq!(agent_results[0].content, "Agent output");
    }

    #[test]
    fn test_unsent_messages_with_origin_excludes_sent() {
        let mut store = MessageStore::new();
        let payload_metadata = MessageMetadata::new().with_origin(MessageOrigin::IncomingPayload);
        let payload_msg = DialogueMessage::new(1, Speaker::System, "Payload system".to_string())
            .with_metadata(&payload_metadata);
        let payload_id = payload_msg.id;
        store.push(payload_msg);
        assert_eq!(
            store
                .unsent_messages_with_origin(MessageOrigin::IncomingPayload)
                .len(),
            1
        );
        store.mark_as_sent_all_agents(payload_id);
        assert!(store
            .unsent_messages_with_origin(MessageOrigin::IncomingPayload)
            .is_empty());
    }
}