tirea-contract 0.5.0

Agent runtime contracts: 8-phase plugin lifecycle, typed tool traits, and state scope system
Documentation
//! Thread model and persistent history primitives.
//!
//! `Thread` (formerly `AgentState`) represents persisted agent state with
//! message history and patches.

use crate::thread::message::Message;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use tirea_state::{apply_patches, TireaError, TireaResult, TrackedPatch};

/// Persisted thread state with messages and state history.
///
/// `Thread` uses an owned builder pattern: `with_*` methods consume `self`
/// and return a new `Thread` (e.g., `thread.with_message(msg)`).
/// `Thread::with_initial_state` is the one exception: it is a constructor
/// and does not take an existing `Thread`.
///
/// The current state is not stored directly. It is `state` with every entry
/// of `patches` applied in order (see `Thread::rebuild_state`), and
/// `Thread::snapshot` folds the pending patches back into `state`.
///
/// On deserialization, `resource_id` and `parent_thread_id` may be omitted
/// (they become `None`), and `metadata` may be omitted (it falls back to
/// `ThreadMetadata::default()`). `id`, `messages`, `state`, and `patches`
/// carry no `#[serde(default)]` and are therefore required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
    /// Unique thread identifier.
    pub id: String,
    /// Owner/resource identifier (e.g., user_id, org_id).
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resource_id: Option<String>,
    /// Parent thread identifier (links child → parent for sub-agent lineage).
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_thread_id: Option<String>,
    /// Messages in insertion order. Each is `Arc`-wrapped, so cloning a
    /// `Thread` shares the message bodies instead of deep-copying them.
    pub messages: Vec<Arc<Message>>,
    /// Base (snapshot) state; `patches` are applied on top of this.
    pub state: Value,
    /// Patches recorded since the last snapshot, applied in order.
    pub patches: Vec<TrackedPatch>,
    /// Metadata (defaults to empty when absent from serialized input).
    #[serde(default)]
    pub metadata: ThreadMetadata,
}

/// Thread metadata.
///
/// Every named field is optional and is omitted from the serialized form
/// when `None`. Any other keys live in `extra`, which is flattened into the
/// same JSON object rather than nested under an `"extra"` key.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ThreadMetadata {
    /// Creation timestamp (unix millis).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<u64>,
    /// Last update timestamp (unix millis).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<u64>,
    /// Persisted state cursor version.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<u64>,
    /// Timestamp of the latest committed version.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_timestamp: Option<u64>,
    /// Custom metadata.
    ///
    /// Flattened: on deserialization this receives every key that does not
    /// match one of the named fields above. NOTE(review): inserting a key that
    /// collides with a named field (e.g. `"created_at"`) would likely emit a
    /// duplicate key when serialized, so avoid such keys.
    #[serde(flatten)]
    pub extra: serde_json::Map<String, Value>,
}

impl Thread {
    /// Create a new, empty thread with the given ID.
    ///
    /// The base state is an empty JSON object (`{}`). There are no messages,
    /// patches, owner (`resource_id`), or parent, and metadata is default.
    pub fn new(id: impl Into<String>) -> Self {
        // Delegate so the field defaults are spelled out in exactly one place.
        Self::with_initial_state(id, Value::Object(serde_json::Map::new()))
    }

    /// Create a new thread whose base state is `state`.
    ///
    /// Despite the `with_` prefix, this is a constructor rather than a builder
    /// step: it does not take an existing `Thread`. Every other field starts
    /// empty, `None`, or default.
    pub fn with_initial_state(id: impl Into<String>, state: Value) -> Self {
        Self {
            id: id.into(),
            resource_id: None,
            parent_thread_id: None,
            messages: Vec::new(),
            state,
            patches: Vec::new(),
            metadata: ThreadMetadata::default(),
        }
    }

    /// Set the owner/resource identifier, consuming and returning the thread.
    #[must_use]
    pub fn with_resource_id(mut self, resource_id: impl Into<String>) -> Self {
        self.resource_id = Some(resource_id.into());
        self
    }

    /// Set the parent thread identifier (sub-agent lineage), consuming and
    /// returning the thread.
    #[must_use]
    pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
        self.parent_thread_id = Some(parent_thread_id.into());
        self
    }

    /// Append a message, consuming and returning the thread.
    ///
    /// The message is `Arc`-wrapped so later clones of the thread (e.g. during
    /// agent loops) share it instead of copying it.
    #[must_use]
    pub fn with_message(mut self, msg: Message) -> Self {
        self.messages.push(Arc::new(msg));
        self
    }

    /// Append several messages in iteration order, consuming and returning the
    /// thread.
    #[must_use]
    pub fn with_messages(mut self, msgs: impl IntoIterator<Item = Message>) -> Self {
        // Extend straight from the iterator; no intermediate Vec is allocated.
        self.messages.extend(msgs.into_iter().map(Arc::new));
        self
    }

    /// Append a patch, consuming and returning the thread.
    #[must_use]
    pub fn with_patch(mut self, patch: TrackedPatch) -> Self {
        self.patches.push(patch);
        self
    }

    /// Append several patches in iteration order, consuming and returning the
    /// thread.
    #[must_use]
    pub fn with_patches(mut self, patches: impl IntoIterator<Item = TrackedPatch>) -> Self {
        self.patches.extend(patches);
        self
    }

    /// Rebuild the current state: the base `state` with every patch applied in
    /// order.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `apply_patches` while applying the patches.
    pub fn rebuild_state(&self) -> TireaResult<Value> {
        if self.patches.is_empty() {
            // Nothing to apply; skip the patch machinery entirely.
            return Ok(self.state.clone());
        }
        apply_patches(&self.state, self.patches.iter().map(|p| p.patch()))
    }

    /// Replay state up to and including the patch at `patch_index` (0-based).
    ///
    /// - `patch_index = 0`: returns the state after applying the first patch only
    /// - `patch_index = n`: returns the state after applying patches `0..=n`
    ///
    /// Replaying to the last index gives the same value as
    /// [`Thread::rebuild_state`]. This enables time-travel debugging by
    /// accessing any historical state point.
    ///
    /// # Errors
    ///
    /// Returns an invalid-operation error if `patch_index >= self.patch_count()`.
    /// This includes every index on a thread with no patches. Otherwise returns
    /// any error produced by `apply_patches`.
    pub fn replay_to(&self, patch_index: usize) -> TireaResult<Value> {
        // Checked slicing: `None` exactly when `patch_index` is out of range.
        let history = match self.patches.get(..=patch_index) {
            Some(history) => history,
            None => {
                return Err(TireaError::invalid_operation(format!(
                    "replay index {patch_index} out of bounds (history len: {})",
                    self.patches.len()
                )))
            }
        };

        apply_patches(&self.state, history.iter().map(|p| p.patch()))
    }

    /// Create a snapshot by collapsing all pending patches into the base state.
    ///
    /// Returns the same thread, with `id`, `resource_id`, `parent_thread_id`,
    /// messages, and metadata unchanged. Its `state` is set to the rebuilt
    /// current state and it has no patches.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Thread::rebuild_state`]. The thread is consumed
    /// either way.
    pub fn snapshot(mut self) -> TireaResult<Self> {
        self.state = self.rebuild_state()?;
        // Updating `self` in place, rather than rebuilding the struct field by
        // field, means fields added to `Thread` later are carried over
        // automatically.
        self.patches = Vec::new();
        Ok(self)
    }

    /// Return `true` once the number of pending patches reaches `threshold`.
    ///
    /// A `threshold` of 0 always returns `true`.
    pub fn needs_snapshot(&self, threshold: usize) -> bool {
        self.patches.len() >= threshold
    }

    /// Number of messages in the thread.
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }

    /// Number of patches recorded since the last snapshot.
    pub fn patch_count(&self) -> usize {
        self.patches.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Op, Patch};

    // Tests use Thread directly (the canonical name).

    #[test]
    fn test_thread_with_messages_batch() {
        let msgs = vec![
            Message::user("a"),
            Message::assistant("b"),
            Message::user("c"),
        ];
        let thread = Thread::new("t-1").with_messages(msgs);
        assert_eq!(thread.messages.len(), 3);
        assert_eq!(thread.messages[0].content, "a");
        assert_eq!(thread.messages[2].content, "c");
    }

    #[test]
    fn test_thread_with_patches_batch() {
        let thread = Thread::new("t-1").with_patches(vec![
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(1)))),
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("b"), json!(2)))),
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("c"), json!(3)))),
        ]);
        assert_eq!(thread.patches.len(), 3);
    }

    #[test]
    fn test_thread_new() {
        let thread = Thread::new("test-1");
        assert_eq!(thread.id, "test-1");
        assert!(thread.resource_id.is_none());
        assert!(thread.parent_thread_id.is_none());
        assert!(thread.messages.is_empty());
        assert!(thread.patches.is_empty());
        assert_eq!(thread.state, json!({}));
    }

    #[test]
    fn test_thread_with_resource_id() {
        let thread = Thread::new("t-1").with_resource_id("user-123");
        assert_eq!(thread.resource_id.as_deref(), Some("user-123"));
    }

    #[test]
    fn test_thread_with_parent_thread_id() {
        let thread = Thread::new("child").with_parent_thread_id("parent-1");
        assert_eq!(thread.parent_thread_id.as_deref(), Some("parent-1"));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();
        assert_eq!(restored.parent_thread_id.as_deref(), Some("parent-1"));
    }

    #[test]
    fn test_thread_with_initial_state() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());
        assert_eq!(thread.state, state);
    }

    #[test]
    fn test_thread_with_message() {
        let thread = Thread::new("test-1")
            .with_message(Message::user("Hello"))
            .with_message(Message::assistant("Hi!"));

        assert_eq!(thread.message_count(), 2);
        assert_eq!(thread.messages[0].content, "Hello");
        assert_eq!(thread.messages[1].content, "Hi!");
    }

    #[test]
    fn test_thread_with_patch() {
        let thread = Thread::new("test-1");
        let patch = TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(1))));

        let thread = thread.with_patch(patch);
        assert_eq!(thread.patch_count(), 1);
    }

    #[test]
    fn test_thread_rebuild_state_empty() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());

        let rebuilt = thread.rebuild_state().unwrap();
        assert_eq!(rebuilt, state);
    }

    #[test]
    fn test_thread_rebuild_state_with_patches() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state)
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(1))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("name"), json!("test"))),
            ));

        let rebuilt = thread.rebuild_state().unwrap();
        assert_eq!(rebuilt["counter"], 1);
        assert_eq!(rebuilt["name"], "test");
    }

    #[test]
    fn test_thread_snapshot() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state).with_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("counter"), json!(5))),
        ));

        assert_eq!(thread.patch_count(), 1);

        let snapshotted = thread.snapshot().unwrap();
        assert_eq!(snapshotted.patch_count(), 0);
        assert_eq!(snapshotted.state["counter"], 5);
    }

    #[test]
    fn test_thread_snapshot_preserves_identity_and_messages() {
        let thread = Thread::new("t-1")
            .with_resource_id("user-1")
            .with_parent_thread_id("parent-1")
            .with_message(Message::user("hi"))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("a"), json!(1))),
            ));

        let snapshotted = thread.snapshot().unwrap();
        assert_eq!(snapshotted.id, "t-1");
        assert_eq!(snapshotted.resource_id.as_deref(), Some("user-1"));
        assert_eq!(snapshotted.parent_thread_id.as_deref(), Some("parent-1"));
        assert_eq!(snapshotted.message_count(), 1);
        assert_eq!(snapshotted.messages[0].content, "hi");
        assert_eq!(snapshotted.patch_count(), 0);
        assert_eq!(snapshotted.state, json!({"a": 1}));
    }

    #[test]
    fn test_thread_needs_snapshot() {
        let thread = Thread::new("test-1");
        assert!(!thread.needs_snapshot(10));

        let thread = (0..10).fold(thread, |s, i| {
            s.with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("field").key(i.to_string()), json!(i))),
            ))
        });

        assert!(thread.needs_snapshot(10));
        assert!(!thread.needs_snapshot(20));
    }

    #[test]
    fn test_thread_serialization() {
        let thread = Thread::new("test-1").with_message(Message::user("Hello"));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        assert_eq!(restored.id, "test-1");
        assert_eq!(restored.message_count(), 1);
    }

    #[test]
    fn test_thread_serialization_skips_absent_optional_fields() {
        let json_str = serde_json::to_string(&Thread::new("t-1")).unwrap();
        assert!(!json_str.contains("resource_id"));
        assert!(!json_str.contains("parent_thread_id"));
    }

    #[test]
    fn test_thread_deserialization_defaults_missing_metadata() {
        let restored: Thread = serde_json::from_value(json!({
            "id": "t-1",
            "messages": [],
            "state": {},
            "patches": []
        }))
        .unwrap();

        assert_eq!(restored.id, "t-1");
        assert!(restored.resource_id.is_none());
        assert!(restored.metadata.version.is_none());
        assert!(restored.metadata.extra.is_empty());
    }

    #[test]
    fn test_thread_metadata_extra_roundtrip() {
        let mut thread = Thread::new("t-1");
        thread.metadata.created_at = Some(1);
        thread
            .metadata
            .extra
            .insert("title".to_string(), json!("hello"));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        assert_eq!(restored.metadata.created_at, Some(1));
        assert_eq!(restored.metadata.extra.get("title"), Some(&json!("hello")));
        assert!(!restored.metadata.extra.contains_key("created_at"));
    }

    #[test]
    fn test_state_persists_after_serialization() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0})).with_patch(
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("counter"), json!(5)))),
        );

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        let rebuilt = restored.rebuild_state().unwrap();
        assert_eq!(
            rebuilt["counter"], 5,
            "persisted state should survive serialization"
        );
    }

    #[test]
    fn test_thread_serialization_includes_resource_id() {
        let thread = Thread::new("t-1").with_resource_id("org-42");
        let json_str = serde_json::to_string(&thread).unwrap();
        assert!(json_str.contains("org-42"));

        let restored: Thread = serde_json::from_str(&json_str).unwrap();
        assert_eq!(restored.resource_id.as_deref(), Some("org-42"));
    }

    #[test]
    fn test_thread_replay_to() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state)
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(10))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(20))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(30))),
            ));

        let state_at_0 = thread.replay_to(0).unwrap();
        assert_eq!(state_at_0["counter"], 10);

        let state_at_1 = thread.replay_to(1).unwrap();
        assert_eq!(state_at_1["counter"], 20);

        let state_at_2 = thread.replay_to(2).unwrap();
        assert_eq!(state_at_2["counter"], 30);

        let err = thread.replay_to(100).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 100 out of bounds (history len: 3)"));
    }

    #[test]
    fn test_thread_replay_to_last_matches_rebuild_state() {
        let thread = Thread::with_initial_state("t-1", json!({"counter": 0}))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("counter"), json!(1))),
            ))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("name"), json!("x"))),
            ));

        let last = thread.patch_count() - 1;
        assert_eq!(
            thread.replay_to(last).unwrap(),
            thread.rebuild_state().unwrap()
        );
    }

    #[test]
    fn test_thread_replay_to_empty() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0}));

        let err = thread.replay_to(0).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 0 out of bounds (history len: 0)"));
    }
}