use std::collections::HashMap;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use crate::kernel::state::KernelState;
use crate::schemas::messages::Message;
/// A state value that can be combined with another value of the same type.
///
/// Implementors must be cloneable, shareable across threads (`Send + Sync`), and
/// round-trippable through serde (`Serialize + DeserializeOwned`).
pub trait State
where
    Self: Clone + Send + Sync + Serialize + DeserializeOwned,
{
    /// Combines `self` with `other`, taking both by shared reference and returning a
    /// freshly built value. Each implementation chooses the combination rule.
    fn merge(&self, other: &Self) -> Self;
}
pub type StateUpdate = HashMap<String, Value>;
/// State holding an ordered list of [`Message`]s.
///
/// Its [`State::merge`] implementation concatenates two lists, so order is significant.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MessagesState {
    /// The messages, in insertion order.
    pub messages: Vec<Message>,
}
impl MessagesState {
    /// Creates a state containing no messages.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
        }
    }

    /// Creates a state that takes ownership of `messages` as its initial list.
    pub fn with_messages(messages: Vec<Message>) -> Self {
        Self { messages }
    }
}
impl State for MessagesState {
    /// Concatenates the two message lists: every message of `self` first, then every
    /// message of `other`.
    ///
    /// Both inputs are only borrowed; the result owns clones of all messages.
    fn merge(&self, other: &Self) -> Self {
        // Reserve the final length once and clone directly out of both slices, so no
        // temporary Vec is allocated for `other`'s messages.
        let mut messages = Vec::with_capacity(self.messages.len() + other.messages.len());
        messages.extend_from_slice(&self.messages);
        messages.extend_from_slice(&other.messages);
        Self { messages }
    }
}
impl KernelState for MessagesState {
    /// Reports version `1` for this state type.
    // NOTE(review): presumably meant to be bumped when `MessagesState`'s layout
    // changes — confirm against the `KernelState` trait documentation.
    fn version(&self) -> u32 {
        const VERSION: u32 = 1;
        VERSION
    }
}
/// Builds a [`StateUpdate`] with a single `"messages"` entry holding `messages` as JSON.
///
/// If serialization fails, the entry is an empty JSON array and the error is discarded.
pub fn messages_state_update(messages: Vec<Message>) -> StateUpdate {
    let encoded = serde_json::to_value(messages).unwrap_or_else(|_| Value::Array(Vec::new()));
    HashMap::from([("messages".to_string(), encoded)])
}
pub fn extract_messages_from_update(update: &StateUpdate) -> Vec<Message> {
update
.get("messages")
.and_then(|v| {
if let Some(arr) = v.as_array() {
arr.iter()
.filter_map(|item| serde_json::from_value::<Message>(item.clone()).ok())
.collect::<Vec<_>>()
.into()
} else {
None
}
})
.unwrap_or_default()
}
pub fn apply_update_to_messages_state(
state: &MessagesState,
update: &StateUpdate,
) -> MessagesState {
let mut new_state = state.clone();
if let Some(messages_value) = update.get("messages") {
if let Ok(new_messages) = serde_json::from_value::<Vec<Message>>(messages_value.clone()) {
new_state.messages.extend(new_messages);
}
}
new_state
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_messages_state_merge() {
        let state1 = MessagesState {
            messages: vec![Message::new_human_message("Hello")],
        };
        let state2 = MessagesState {
            messages: vec![Message::new_ai_message("Hi there!")],
        };
        let merged = state1.merge(&state2);
        assert_eq!(merged.messages.len(), 2);
        assert_eq!(merged.messages[0].content, "Hello");
        assert_eq!(merged.messages[1].content, "Hi there!");
        // Merging borrows its inputs; neither operand is changed.
        assert_eq!(state1.messages.len(), 1);
        assert_eq!(state2.messages.len(), 1);
    }

    #[test]
    fn test_messages_state_update() {
        let messages = vec![
            Message::new_human_message("Hello"),
            Message::new_ai_message("Hi!"),
        ];
        let update = messages_state_update(messages);
        assert!(update.contains_key("messages"));
        let extracted = extract_messages_from_update(&update);
        // The encode/decode round trip must preserve count, content and order.
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].content, "Hello");
        assert_eq!(extracted[1].content, "Hi!");
    }

    #[test]
    fn test_extract_missing_key_is_empty() {
        assert!(extract_messages_from_update(&StateUpdate::new()).is_empty());
    }

    #[test]
    fn test_non_array_messages_value_is_ignored() {
        let mut update = StateUpdate::new();
        update.insert(
            "messages".to_string(),
            serde_json::Value::String("not a list".to_string()),
        );
        assert!(extract_messages_from_update(&update).is_empty());

        let state = MessagesState {
            messages: vec![Message::new_human_message("Hello")],
        };
        let new_state = apply_update_to_messages_state(&state, &update);
        assert_eq!(new_state.messages.len(), 1);
        assert_eq!(new_state.messages[0].content, "Hello");
    }

    #[test]
    fn test_apply_update() {
        let state = MessagesState {
            messages: vec![Message::new_human_message("Hello")],
        };
        let update = messages_state_update(vec![Message::new_ai_message("Hi!")]);
        let new_state = apply_update_to_messages_state(&state, &update);
        assert_eq!(new_state.messages.len(), 2);
        // Existing messages come first; the update's messages are appended after them.
        assert_eq!(new_state.messages[0].content, "Hello");
        assert_eq!(new_state.messages[1].content, "Hi!");
        // The input state is left untouched.
        assert_eq!(state.messages.len(), 1);
    }
}