use serde::{Deserialize, Serialize};
/// A single message in a conversation.
///
/// Messages are identified by `id`; [`messages_reducer`] uses that id to decide
/// whether an incoming message replaces an existing one, is appended, or encodes
/// a removal request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Identifier; the constructors in this module generate a random UUID string.
    pub id: String,
    /// Author of the message.
    pub role: Role,
    /// Message body, either plain text or a list of parts.
    pub content: Content,
    /// Tool invocations carried by this message (empty when there are none).
    pub tool_calls: Vec<ToolCall>,
    /// Id of the tool call this message answers; set by [`Message::tool_result`].
    pub tool_call_id: Option<String>,
    /// Optional name; none of the constructors in this module set it.
    pub name: Option<String>,
    /// Token usage associated with this message, if any.
    pub usage: Option<TokenUsage>,
}
/// Author of a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Role {
    /// System message; also used by the removal markers built by
    /// [`Message::remove`] and [`Message::remove_all`].
    System,
    /// Message from the human user (see [`Message::human`]).
    Human,
    /// Message from the model (see [`Message::ai`]).
    Ai,
    /// Result of a tool call (see [`Message::tool_result`]).
    Tool,
}
/// Body of a [`Message`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Content {
    /// A single plain-text body.
    Text(String),
    /// An ordered list of parts (text, images, thinking blocks).
    MultiPart(Vec<ContentPart>),
}
/// One element of a [`Content::MultiPart`] body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContentPart {
    /// A text segment.
    Text { text: String },
    /// An image, either inline or referenced by URL.
    Image(ImageData),
    /// A "thinking" text block.
    ///
    /// NOTE(review): `signature` is presumably a provider-issued token that must be
    /// echoed back unchanged — confirm against the model adapter.
    Thinking {
        text: String,
        signature: Option<String>,
    },
}
/// An image attached to a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
    /// Media type of the image; presumably a MIME string such as `image/png` — confirm.
    pub media_type: String,
    /// Where the image bytes come from.
    pub source: ImageSource,
}
/// Location of an image's data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImageSource {
    /// Image bytes embedded as a base64 string.
    Base64(String),
    /// Image referenced by URL.
    Url(String),
}
/// A tool invocation carried by a [`Message`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier that a [`Role::Tool`] reply references via `tool_call_id`.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}
/// Token counts associated with a message; `Default` is all zeros.
///
/// NOTE(review): nothing here enforces `total_tokens == input_tokens + output_tokens`;
/// the producer fills all three fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the input/prompt.
    pub input_tokens: u64,
    /// Tokens produced in the output.
    pub output_tokens: u64,
    /// Total token count as reported.
    pub total_tokens: u64,
}
/// Sentinel message id: when a message with this id passes through
/// [`messages_reducer`], every message accumulated so far is discarded.
/// Build such a marker with [`Message::remove_all`].
pub const REMOVE_ALL_MESSAGES: &str = "__remove_all__";
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct MessagesState {
pub messages: Vec<Message>,
}
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct MessagesStateUpdate {
pub messages: Option<Vec<Message>>,
}
impl crate::State for MessagesState {
    type Update = MessagesStateUpdate;
    type FieldVersions = crate::state::FieldVersions;

    /// Merges `update` into the state via [`messages_reducer`].
    ///
    /// Returns a mask with bit 0 (the `messages` field) set whenever the update
    /// carried a `messages` list, even an empty one, and no bits otherwise.
    fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
        match update.messages {
            Some(incoming) => {
                messages_reducer(&mut self.messages, incoming);
                // Bit 0 corresponds to the `messages` field.
                crate::FieldsChanged(1 << 0)
            }
            None => crate::FieldsChanged(0),
        }
    }

    /// Nothing to reset: the only field, `messages`, is left as is.
    fn reset_ephemeral(&mut self) {}
}
impl MessagesState {
pub fn try_apply_messages(
&mut self,
update: MessagesStateUpdate,
) -> Result<crate::FieldsChanged, crate::error::InvalidUpdateError> {
Ok(crate::State::apply(self, update))
}
}
/// Merges `incoming` messages into `current`, processing them in order.
///
/// Each incoming message is interpreted by its `id`:
/// - [`REMOVE_ALL_MESSAGES`]: clears every message in `current` (including any
///   appended earlier in the same `incoming` batch).
/// - `"__remove__:<id>"` (as built by [`Message::remove`]): removes every message
///   whose id equals `<id>`; a no-op if none match. The marker is not stored.
/// - an id already present in `current`: replaces the first message with that id,
///   keeping its position.
/// - any other id: appended to the end.
pub fn messages_reducer(current: &mut Vec<Message>, incoming: Vec<Message>) {
    // Must stay in sync with the id format produced by `Message::remove`.
    const REMOVE_PREFIX: &str = "__remove__:";

    for msg in incoming {
        if msg.id == REMOVE_ALL_MESSAGES {
            current.clear();
        } else if let Some(target_id) = msg.id.strip_prefix(REMOVE_PREFIX) {
            current.retain(|m| m.id != target_id);
        } else if let Some(existing) = current.iter_mut().find(|m| m.id == msg.id) {
            *existing = msg;
        } else {
            current.push(msg);
        }
    }
}
impl Message {
pub fn human(content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: Role::Human,
content: Content::Text(content.into()),
tool_calls: vec![],
tool_call_id: None,
name: None,
usage: None,
}
}
pub fn ai(content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: Role::Ai,
content: Content::Text(content.into()),
tool_calls: vec![],
tool_call_id: None,
name: None,
usage: None,
}
}
pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: Role::Ai,
content: Content::Text(content.into()),
tool_calls,
tool_call_id: None,
name: None,
usage: None,
}
}
pub fn system(content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: Role::System,
content: Content::Text(content.into()),
tool_calls: vec![],
tool_call_id: None,
name: None,
usage: None,
}
}
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: Role::Tool,
content: Content::Text(content.into()),
tool_calls: vec![],
tool_call_id: Some(tool_call_id.into()),
name: None,
usage: None,
}
}
#[must_use]
pub const fn has_tool_calls(&self) -> bool {
!self.tool_calls.is_empty()
}
#[must_use]
pub fn content_text(&self) -> &str {
match &self.content {
Content::Text(s) => s,
Content::MultiPart(parts) => parts
.iter()
.find_map(|p| match p {
ContentPart::Text { text } => Some(text.as_str()),
_ => None,
})
.unwrap_or(""),
}
}
#[must_use]
pub fn remove(id: impl Into<String>) -> Self {
let id = id.into();
Self {
id: format!("__remove__:{id}"),
role: Role::System,
content: Content::Text(String::new()),
tool_calls: vec![],
tool_call_id: None,
name: None,
usage: None,
}
}
#[must_use]
pub fn remove_all() -> Self {
Self {
id: REMOVE_ALL_MESSAGES.to_string(),
role: Role::System,
content: Content::Text(String::new()),
tool_calls: vec![],
tool_call_id: None,
name: None,
usage: None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::trait_::State;

    #[test]
    fn test_messages_state_default() {
        let state = MessagesState::default();
        assert!(state.messages.is_empty());
    }

    #[test]
    fn test_messages_state_apply() {
        let mut state = MessagesState::default();
        let update = MessagesStateUpdate {
            messages: Some(vec![Message::human("Hello")]),
        };
        let changed = state.apply(update);
        assert_eq!(state.messages.len(), 1);
        assert!(!changed.is_empty());
        assert!(changed.has_field(0));
    }

    #[test]
    fn test_messages_state_apply_merge() {
        let mut state = MessagesState {
            messages: vec![Message::human("Hello")],
        };
        let update = MessagesStateUpdate {
            messages: Some(vec![Message::ai("Hi there!")]),
        };
        state.apply(update);
        assert_eq!(state.messages.len(), 2);
    }

    #[test]
    fn test_messages_state_apply_none() {
        let mut state = MessagesState {
            messages: vec![Message::human("Hello")],
        };
        let update = MessagesStateUpdate { messages: None };
        let changed = state.apply(update);
        assert_eq!(state.messages.len(), 1);
        assert!(changed.is_empty());
    }

    #[test]
    fn test_messages_state_reset_ephemeral() {
        let mut state = MessagesState {
            messages: vec![Message::human("Hello")],
        };
        state.reset_ephemeral();
        assert_eq!(state.messages.len(), 1);
    }

    #[test]
    fn test_messages_state_serialization() {
        let state = MessagesState {
            messages: vec![Message::human("Hello")],
        };
        let json = serde_json::to_string(&state).unwrap();
        let deserialized: MessagesState = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.messages.len(), 1);
        assert_eq!(deserialized.messages[0].role, Role::Human);
    }

    #[test]
    fn test_messages_state_update_serialization() {
        let update = MessagesStateUpdate {
            messages: Some(vec![Message::ai("Hi!")]),
        };
        let json = serde_json::to_string(&update).unwrap();
        let deserialized: MessagesStateUpdate = serde_json::from_str(&json).unwrap();
        assert!(deserialized.messages.is_some());
        assert_eq!(deserialized.messages.unwrap().len(), 1);
    }

    #[test]
    fn test_try_apply_messages_ok() {
        let mut state = MessagesState::default();
        let update = MessagesStateUpdate {
            messages: Some(vec![Message::human("Hello")]),
        };
        let changed = match state.try_apply_messages(update) {
            Ok(changed) => changed,
            Err(_) => panic!("try_apply_messages should accept a plain append"),
        };
        assert!(changed.has_field(0));
        assert_eq!(state.messages.len(), 1);
    }

    #[test]
    fn test_messages_reducer_replaces_existing_id_in_place() {
        let first = Message::human("Hello");
        let first_id = first.id.clone();
        let mut current = vec![first, Message::ai("Hi there!")];

        let mut edited = Message::human("Hello again");
        edited.id = first_id.clone();
        messages_reducer(&mut current, vec![edited]);

        assert_eq!(current.len(), 2);
        assert_eq!(current[0].id, first_id);
        assert_eq!(current[0].content_text(), "Hello again");
        assert_eq!(current[1].content_text(), "Hi there!");
    }

    #[test]
    fn test_messages_reducer_remove_by_id() {
        let keep = Message::human("keep");
        let stale = Message::ai("stale");
        let stale_id = stale.id.clone();
        let mut current = vec![keep, stale];

        messages_reducer(&mut current, vec![Message::remove(stale_id)]);

        assert_eq!(current.len(), 1);
        assert_eq!(current[0].content_text(), "keep");
    }

    #[test]
    fn test_messages_reducer_remove_unknown_id_is_noop() {
        let mut current = vec![Message::human("Hello")];
        messages_reducer(&mut current, vec![Message::remove("does-not-exist")]);
        assert_eq!(current.len(), 1);
        assert_eq!(current[0].content_text(), "Hello");
    }

    #[test]
    fn test_messages_reducer_remove_all_then_append() {
        let mut current = vec![Message::human("a"), Message::ai("b")];
        messages_reducer(
            &mut current,
            vec![Message::remove_all(), Message::human("c")],
        );
        assert_eq!(current.len(), 1);
        assert_eq!(current[0].role, Role::Human);
        assert_eq!(current[0].content_text(), "c");
    }

    #[test]
    fn test_content_text_multipart_returns_first_text_part() {
        let msg = Message {
            content: Content::MultiPart(vec![
                ContentPart::Thinking {
                    text: "hmm".to_string(),
                    signature: None,
                },
                ContentPart::Text {
                    text: "first".to_string(),
                },
                ContentPart::Text {
                    text: "second".to_string(),
                },
            ]),
            ..Message::ai("")
        };
        assert_eq!(msg.content_text(), "first");
    }

    #[test]
    fn test_content_text_multipart_without_text_is_empty() {
        let msg = Message {
            content: Content::MultiPart(vec![ContentPart::Image(ImageData {
                media_type: "image/png".to_string(),
                source: ImageSource::Url("https://example.com/a.png".to_string()),
            })]),
            ..Message::ai("")
        };
        assert_eq!(msg.content_text(), "");
    }
}