use crate::thread::message::Message;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use tirea_state::{apply_patches, TireaError, TireaResult, TrackedPatch};
/// A thread: an identifier, optional resource and parent-thread ids, a list
/// of messages, a base `state` value, and the patches recorded against it.
///
/// `state` is the base value; `Thread::rebuild_state` applies `patches` on top
/// of it, and `Thread::snapshot` folds them into a new base state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
/// Identifier supplied by the caller at construction.
pub id: String,
/// Optional resource identifier; left out of the serialized form when `None`.
/// NOTE(review): presumably identifies the owner of the thread — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub resource_id: Option<String>,
/// Optional parent thread identifier; left out of the serialized form when `None`.
/// NOTE(review): presumably the thread this one was derived from — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub parent_thread_id: Option<String>,
/// Messages in insertion order, each held behind an `Arc`.
pub messages: Vec<Arc<Message>>,
/// Base JSON state that `patches` are applied on top of.
pub state: Value,
/// Tracked patches in the order they are applied to `state`.
pub patches: Vec<TrackedPatch>,
/// Metadata; falls back to `ThreadMetadata::default()` when absent from the input.
#[serde(default)]
pub metadata: ThreadMetadata,
}
/// Optional bookkeeping values attached to a `Thread`.
///
/// Every named field is left out of the serialized form when it is `None`.
/// NOTE(review): the visible code never assigns these fields, and the units of
/// the timestamp-like values are not established here — confirm with the writer.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ThreadMetadata {
/// Creation time, if recorded.
#[serde(skip_serializing_if = "Option::is_none")]
pub created_at: Option<u64>,
/// Last update time, if recorded.
#[serde(skip_serializing_if = "Option::is_none")]
pub updated_at: Option<u64>,
/// Version number, if recorded.
#[serde(skip_serializing_if = "Option::is_none")]
pub version: Option<u64>,
/// Time associated with `version`, if recorded.
#[serde(skip_serializing_if = "Option::is_none")]
pub version_timestamp: Option<u64>,
/// Flattened map: its entries are serialized inline next to the named fields
/// (not under an `extra` key), and on deserialization it collects the keys
/// that no named field claims.
#[serde(flatten)]
pub extra: serde_json::Map<String, Value>,
}
impl Thread {
    /// Creates an empty thread whose base state is an empty JSON object.
    ///
    /// The thread starts with no messages, no patches, no resource or parent
    /// id, and default metadata.
    pub fn new(id: impl Into<String>) -> Self {
        Self::with_initial_state(id, Value::Object(serde_json::Map::new()))
    }

    /// Creates an empty thread that starts from the given base `state`.
    ///
    /// Apart from `state`, the result is identical to [`Thread::new`].
    pub fn with_initial_state(id: impl Into<String>, state: Value) -> Self {
        Self {
            id: id.into(),
            resource_id: None,
            parent_thread_id: None,
            messages: Vec::new(),
            state,
            patches: Vec::new(),
            metadata: ThreadMetadata::default(),
        }
    }

    /// Sets the resource id, replacing any previous value.
    #[must_use]
    pub fn with_resource_id(mut self, resource_id: impl Into<String>) -> Self {
        self.resource_id = Some(resource_id.into());
        self
    }

    /// Sets the parent thread id, replacing any previous value.
    #[must_use]
    pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
        self.parent_thread_id = Some(parent_thread_id.into());
        self
    }

    /// Appends one message, wrapping it in an `Arc`.
    #[must_use]
    pub fn with_message(mut self, msg: Message) -> Self {
        self.messages.push(Arc::new(msg));
        self
    }

    /// Appends every message from `msgs` in iteration order, wrapping each in
    /// an `Arc`.
    #[must_use]
    pub fn with_messages(mut self, msgs: impl IntoIterator<Item = Message>) -> Self {
        // Feed the iterator straight into `extend`; an intermediate `Vec`
        // would only add an allocation.
        self.messages.extend(msgs.into_iter().map(Arc::new));
        self
    }

    /// Appends one patch to the end of the patch history.
    #[must_use]
    pub fn with_patch(mut self, patch: TrackedPatch) -> Self {
        self.patches.push(patch);
        self
    }

    /// Appends every patch from `patches`, in iteration order, to the end of
    /// the patch history.
    #[must_use]
    pub fn with_patches(mut self, patches: impl IntoIterator<Item = TrackedPatch>) -> Self {
        self.patches.extend(patches);
        self
    }

    /// Computes the current state by applying all patches to the base state.
    ///
    /// With an empty patch history this returns a clone of the base state
    /// without calling `apply_patches`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `apply_patches` reports.
    pub fn rebuild_state(&self) -> TireaResult<Value> {
        if self.patches.is_empty() {
            return Ok(self.state.clone());
        }
        apply_patches(&self.state, self.patches.iter().map(|tracked| tracked.patch()))
    }

    /// Computes the state after applying patches `0..=patch_index` to the base
    /// state.
    ///
    /// # Errors
    ///
    /// Returns an invalid-operation error when `patch_index` is not a valid
    /// index into the patch history (this includes every index when the
    /// history is empty), and otherwise whatever error `apply_patches` reports.
    pub fn replay_to(&self, patch_index: usize) -> TireaResult<Value> {
        // Checked slicing yields `None` exactly when `patch_index >= len`,
        // so there is no separate bounds test and no panicking index.
        let history = self.patches.get(..=patch_index).ok_or_else(|| {
            TireaError::invalid_operation(format!(
                "replay index {patch_index} out of bounds (history len: {})",
                self.patches.len()
            ))
        })?;
        apply_patches(&self.state, history.iter().map(|tracked| tracked.patch()))
    }

    /// Folds the patch history into the base state.
    ///
    /// The returned thread has the rebuilt state as its base and an empty
    /// patch history; all other fields are carried over unchanged.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Thread::rebuild_state`]; the thread is
    /// consumed either way.
    pub fn snapshot(self) -> TireaResult<Self> {
        // Rebuild first: it borrows `self`, which is moved apart below.
        let state = self.rebuild_state()?;
        Ok(Self {
            state,
            patches: Vec::new(),
            ..self
        })
    }

    /// Returns `true` when the number of recorded patches is at least
    /// `threshold`.
    ///
    /// A `threshold` of `0` therefore always yields `true`, even with an empty
    /// patch history.
    pub fn needs_snapshot(&self, threshold: usize) -> bool {
        self.patches.len() >= threshold
    }

    /// Returns the number of messages in the thread.
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }

    /// Returns the number of patches recorded since the base state.
    pub fn patch_count(&self) -> usize {
        self.patches.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Op, Patch};

    /// Wraps a single operation in a one-op tracked patch.
    fn tracked(op: Op) -> TrackedPatch {
        TrackedPatch::new(Patch::new().with_op(op))
    }

    /// Serializes a thread to JSON text and parses it back.
    fn round_trip(thread: &Thread) -> Thread {
        let encoded = serde_json::to_string(thread).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    // Builder: batch insertion of messages keeps count and order.
    #[test]
    fn test_thread_with_messages_batch() {
        let thread = Thread::new("t-1").with_messages(vec![
            Message::user("a"),
            Message::assistant("b"),
            Message::user("c"),
        ]);

        assert_eq!(thread.messages.len(), 3);
        assert_eq!(thread.messages[0].content, "a");
        assert_eq!(thread.messages[2].content, "c");
    }

    // Builder: batch insertion of patches records every patch.
    #[test]
    fn test_thread_with_patches_batch() {
        let batch = vec![
            tracked(Op::set(path!("a"), json!(1))),
            tracked(Op::set(path!("b"), json!(2))),
            tracked(Op::set(path!("c"), json!(3))),
        ];

        let thread = Thread::new("t-1").with_patches(batch);

        assert_eq!(thread.patches.len(), 3);
    }

    // A fresh thread carries only its id.
    #[test]
    fn test_thread_new() {
        let fresh = Thread::new("test-1");

        assert_eq!(fresh.id, "test-1");
        assert!(fresh.resource_id.is_none());
        assert!(fresh.messages.is_empty());
        assert!(fresh.patches.is_empty());
    }

    #[test]
    fn test_thread_with_resource_id() {
        let owned = Thread::new("t-1").with_resource_id("user-123");

        assert_eq!(owned.resource_id.as_deref(), Some("user-123"));
    }

    #[test]
    fn test_thread_with_initial_state() {
        let initial = json!({"counter": 0});

        let thread = Thread::with_initial_state("test-1", initial.clone());

        assert_eq!(thread.state, initial);
    }

    // Builder: single-message insertion appends in call order.
    #[test]
    fn test_thread_with_message() {
        let mut thread = Thread::new("test-1");
        thread = thread.with_message(Message::user("Hello"));
        thread = thread.with_message(Message::assistant("Hi!"));

        assert_eq!(thread.message_count(), 2);
        assert_eq!(thread.messages[0].content, "Hello");
        assert_eq!(thread.messages[1].content, "Hi!");
    }

    #[test]
    fn test_thread_with_patch() {
        let thread = Thread::new("test-1").with_patch(tracked(Op::set(path!("a"), json!(1))));

        assert_eq!(thread.patch_count(), 1);
    }

    // Without patches the rebuilt state equals the base state.
    #[test]
    fn test_thread_rebuild_state_empty() {
        let base = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", base.clone());

        assert_eq!(thread.rebuild_state().unwrap(), base);
    }

    // Patches both overwrite existing keys and add new ones.
    #[test]
    fn test_thread_rebuild_state_with_patches() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_patch(tracked(Op::set(path!("counter"), json!(1))))
            .with_patch(tracked(Op::set(path!("name"), json!("test"))));

        let current = thread.rebuild_state().unwrap();

        assert_eq!(current["counter"], 1);
        assert_eq!(current["name"], "test");
    }

    // Snapshotting folds patches into the base state and clears the history.
    #[test]
    fn test_thread_snapshot() {
        let pending = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_patch(tracked(Op::set(path!("counter"), json!(5))));
        assert_eq!(pending.patch_count(), 1);

        let folded = pending.snapshot().unwrap();

        assert_eq!(folded.patch_count(), 0);
        assert_eq!(folded.state["counter"], 5);
    }

    // The threshold is inclusive: exactly ten patches meets a threshold of ten.
    #[test]
    fn test_thread_needs_snapshot() {
        let mut thread = Thread::new("test-1");
        assert!(!thread.needs_snapshot(10));

        for i in 0..10 {
            let op = Op::set(path!("field").key(i.to_string()), json!(i));
            thread = thread.with_patch(tracked(op));
        }

        assert!(thread.needs_snapshot(10));
        assert!(!thread.needs_snapshot(20));
    }

    #[test]
    fn test_thread_serialization() {
        let original = Thread::new("test-1").with_message(Message::user("Hello"));

        let restored = round_trip(&original);

        assert_eq!(restored.id, "test-1");
        assert_eq!(restored.message_count(), 1);
    }

    // Base state and patches both survive a JSON round trip.
    #[test]
    fn test_state_persists_after_serialization() {
        let original = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_patch(tracked(Op::set(path!("counter"), json!(5))));

        let rebuilt = round_trip(&original).rebuild_state().unwrap();

        assert_eq!(
            rebuilt["counter"], 5,
            "persisted state should survive serialization"
        );
    }

    #[test]
    fn test_thread_serialization_includes_resource_id() {
        let original = Thread::new("t-1").with_resource_id("org-42");

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("org-42"));

        let decoded: Thread = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.resource_id.as_deref(), Some("org-42"));
    }

    // Replaying to index i applies patches 0..=i; an index past the end fails.
    #[test]
    fn test_thread_replay_to() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_patch(tracked(Op::set(path!("counter"), json!(10))))
            .with_patch(tracked(Op::set(path!("counter"), json!(20))))
            .with_patch(tracked(Op::set(path!("counter"), json!(30))));

        let checkpoints: Vec<(usize, i64)> = vec![(0, 10), (1, 20), (2, 30)];
        for (index, expected) in checkpoints {
            let replayed = thread.replay_to(index).unwrap();
            assert_eq!(replayed["counter"], expected);
        }

        let message = thread.replay_to(100).unwrap_err().to_string();
        assert!(message.contains("replay index 100 out of bounds (history len: 3)"));
    }

    // With no patches even index 0 is out of bounds.
    #[test]
    fn test_thread_replay_to_empty() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0}));

        let message = thread.replay_to(0).unwrap_err().to_string();

        assert!(message.contains("replay index 0 out of bounds (history len: 0)"));
    }
}