use std::collections::HashMap;
use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::ser::{SerializeStruct, Serializer};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Role attached to every [`ChatMessage`].
///
/// On the wire each variant is written as its lowercase name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"tool"`.
    Tool,
}
/// One message in a conversation.
///
/// `id` is the key [`MessagesValue`] uses to upsert messages, so two messages with the same
/// `id` are treated as versions of the same message.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatMessage {
    /// Key used for de-duplication / upsert.
    pub id: String,
    /// The message's role; see [`MessageRole`].
    pub role: MessageRole,
    /// The message body.
    pub content: String,
    /// Optional arbitrary JSON payload.
    pub metadata: Option<Value>,
}
/// An ordered collection of [`ChatMessage`]s in which each `id` appears at most once.
///
/// Messages keep their first-insertion order. Pushing a message whose `id` is already present
/// replaces the stored message in place (it keeps its original position) instead of appending
/// a duplicate.
#[derive(Debug, Clone, Default)]
pub struct MessagesValue {
    /// Messages in first-insertion order; ids are unique.
    messages: Vec<ChatMessage>,
    /// Maps each message id to its position in `messages`.
    id_index: HashMap<String, usize>,
}

impl MessagesValue {
    /// Creates an empty collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `message`, or replaces the stored message that has the same `id`.
    ///
    /// A replaced message keeps its original position; a new id is appended at the end.
    pub fn push(&mut self, message: ChatMessage) {
        if let Some(&idx) = self.id_index.get(&message.id) {
            // `id_index` only ever holds positions of live entries in `messages`.
            self.messages[idx] = message;
        } else {
            self.id_index.insert(message.id.clone(), self.messages.len());
            self.messages.push(message);
        }
    }

    /// Pushes every message from `messages` in order, with the same upsert rule as [`Self::push`].
    pub fn extend(&mut self, messages: impl IntoIterator<Item = ChatMessage>) {
        for message in messages {
            self.push(message);
        }
    }

    /// Iterates over the messages in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = &ChatMessage> {
        self.messages.iter()
    }

    /// Returns all messages whose role equals `role`, in insertion order.
    pub fn by_role(&self, role: MessageRole) -> Vec<&ChatMessage> {
        self.messages.iter().filter(|m| m.role == role).collect()
    }

    /// Number of distinct messages stored.
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// `true` when no messages are stored.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Re-establishes the id index after `messages` was filled wholesale (e.g. by
    /// deserialization), collapsing any repeated ids.
    ///
    /// Repeated ids are merged exactly as [`Self::push`] would merge them: the first
    /// occurrence's position is kept and the last occurrence's content wins. Merely re-indexing
    /// would leave the earlier duplicates in `messages`. That would break the one-entry-per-id
    /// invariant that `len`, `iter` and `push` rely on.
    fn rebuild_index(&mut self) {
        let pending = std::mem::take(&mut self.messages);
        self.id_index.clear();
        self.messages.reserve(pending.len());
        self.extend(pending);
    }
}
impl Serialize for MessagesValue {
    /// Writes the collection as a struct with a single `messages` field.
    ///
    /// `id_index` is derived from `messages`, so it is not emitted.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        const FIELD_COUNT: usize = 1;
        let mut out = serializer.serialize_struct("MessagesValue", FIELD_COUNT)?;
        out.serialize_field("messages", &self.messages)?;
        out.end()
    }
}
impl<'de> Deserialize<'de> for MessagesValue {
    /// Reads a `MessagesValue` from either a map (`{"messages": [...]}`) or a sequence whose
    /// first element is the message list.
    ///
    /// A missing `messages` entry yields an empty collection. A repeated `messages` key is an
    /// error. Any other key is rejected, because `Field` has no catch-all variant. The id
    /// index is rebuilt from the decoded messages via `rebuild_index`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        const FIELDS: &[&str] = &["messages"];

        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "lowercase")]
        enum Field {
            Messages,
        }

        /// Wraps decoded messages and rebuilds the id lookup table.
        fn assemble(messages: Vec<ChatMessage>) -> MessagesValue {
            let mut assembled = MessagesValue { messages, id_index: HashMap::new() };
            assembled.rebuild_index();
            assembled
        }

        struct MessagesValueVisitor;

        impl<'de> Visitor<'de> for MessagesValueVisitor {
            type Value = MessagesValue;

            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("struct MessagesValue")
            }

            fn visit_seq<V>(self, mut seq: V) -> Result<MessagesValue, V::Error>
            where
                V: SeqAccess<'de>,
            {
                let first: Option<Vec<ChatMessage>> = seq.next_element()?;
                Ok(assemble(first.unwrap_or_default()))
            }

            fn visit_map<V>(self, mut map: V) -> Result<MessagesValue, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut found: Option<Vec<ChatMessage>> = None;
                // `Field` has a single variant, so every key that decodes lands here.
                while let Some(Field::Messages) = map.next_key::<Field>()? {
                    // Reject the duplicate before reading its value, as the key is seen.
                    if found.is_some() {
                        return Err(de::Error::duplicate_field("messages"));
                    }
                    found = Some(map.next_value()?);
                }
                Ok(assemble(found.unwrap_or_default()))
            }
        }

        deserializer.deserialize_struct("MessagesValue", FIELDS, MessagesValueVisitor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with no metadata.
    fn msg(id: &str, role: MessageRole, content: &str) -> ChatMessage {
        ChatMessage {
            id: id.to_string(),
            role,
            content: content.to_string(),
            metadata: None,
        }
    }

    #[test]
    fn test_push_new_message() {
        let mut mv = MessagesValue::new();
        mv.push(msg("msg_1", MessageRole::User, "Hello"));
        assert_eq!(mv.len(), 1);
        assert!(!mv.is_empty());
    }

    #[test]
    fn test_push_upsert_replaces() {
        let mut mv = MessagesValue::new();
        mv.push(msg("msg_1", MessageRole::User, "Original"));
        mv.push(msg("msg_1", MessageRole::User, "Updated"));
        assert_eq!(mv.len(), 1);
        assert_eq!(mv.iter().next().unwrap().content, "Updated");
    }

    #[test]
    fn test_extend_with_upsert() {
        let mut mv = MessagesValue::new();
        mv.push(msg("msg_1", MessageRole::User, "First"));
        mv.extend(vec![
            msg("msg_1", MessageRole::User, "Updated first"),
            msg("msg_2", MessageRole::Assistant, "Second"),
        ]);
        assert_eq!(mv.len(), 2);
        let contents: Vec<&str> = mv.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, vec!["Updated first", "Second"]);
    }

    #[test]
    fn test_by_role() {
        let mut mv = MessagesValue::new();
        mv.extend(vec![
            msg("msg_1", MessageRole::User, "User msg"),
            msg("msg_2", MessageRole::Assistant, "Assistant msg"),
            msg("msg_3", MessageRole::User, "Another user msg"),
        ]);
        assert_eq!(mv.by_role(MessageRole::User).len(), 2);
        assert_eq!(mv.by_role(MessageRole::Assistant).len(), 1);
        assert!(mv.by_role(MessageRole::System).is_empty());
    }

    #[test]
    fn test_serialization_round_trip() {
        let mut mv = MessagesValue::new();
        mv.push(ChatMessage {
            metadata: Some(serde_json::json!({"key": "value"})),
            ..msg("msg_1", MessageRole::User, "Hello")
        });
        mv.push(msg("msg_2", MessageRole::Assistant, "Hi there"));

        let json = serde_json::to_string(&mv).unwrap();
        let restored: MessagesValue = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.len(), 2);
        let pairs: Vec<(&str, &str)> = restored
            .iter()
            .map(|m| (m.id.as_str(), m.content.as_str()))
            .collect();
        assert_eq!(pairs, vec![("msg_1", "Hello"), ("msg_2", "Hi there")]);
    }

    #[test]
    fn test_dedup_after_deserialization() {
        let mut mv = MessagesValue::new();
        mv.push(msg("msg_1", MessageRole::User, "Hello"));

        let json = serde_json::to_string(&mv).unwrap();
        let mut restored: MessagesValue = serde_json::from_str(&json).unwrap();
        restored.push(msg("msg_1", MessageRole::User, "Updated after deser"));

        assert_eq!(restored.len(), 1);
        assert_eq!(restored.iter().next().unwrap().content, "Updated after deser");
    }

    #[test]
    fn test_empty_messages_value() {
        let mv = MessagesValue::new();
        assert_eq!(mv.len(), 0);
        assert!(mv.is_empty());
        assert_eq!(mv.iter().count(), 0);
        assert!(mv.by_role(MessageRole::User).is_empty());
    }
}