Skip to main content

tirea_contract/thread/
model.rs

//! Thread model and persistent history primitives.
//!
//! [`Thread`] (formerly `AgentState`) represents persisted agent state: the message
//! history plus a base state and the patches recorded on top of it.
5
6use crate::thread::message::Message;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::sync::Arc;
10use tirea_state::{apply_patches, TireaError, TireaResult, TrackedPatch};
11
12/// Persisted thread state with messages and state history.
13///
14/// `Thread` uses an owned builder pattern: `with_*` methods consume `self`
15/// and return a new `Thread` (e.g., `thread.with_message(msg)`).
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Thread {
18    /// Unique thread identifier.
19    pub id: String,
20    /// Owner/resource identifier (e.g., user_id, org_id).
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub resource_id: Option<String>,
23    /// Parent thread identifier (links child → parent for sub-agent lineage).
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub parent_thread_id: Option<String>,
26    /// Messages (Arc-wrapped for efficient cloning).
27    pub messages: Vec<Arc<Message>>,
28    /// Initial/snapshot state.
29    pub state: Value,
30    /// Patches applied since the last snapshot.
31    pub patches: Vec<TrackedPatch>,
32    /// Metadata.
33    #[serde(default)]
34    pub metadata: ThreadMetadata,
35}
36
37/// Thread metadata.
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub struct ThreadMetadata {
40    /// Creation timestamp (unix millis).
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub created_at: Option<u64>,
43    /// Last update timestamp (unix millis).
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub updated_at: Option<u64>,
46    /// Persisted state cursor version.
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub version: Option<u64>,
49    /// Timestamp of the latest committed version.
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub version_timestamp: Option<u64>,
52    /// Custom metadata.
53    #[serde(flatten)]
54    pub extra: serde_json::Map<String, Value>,
55}
56
57impl Thread {
58    /// Create a new thread with the given ID.
59    pub fn new(id: impl Into<String>) -> Self {
60        Self {
61            id: id.into(),
62            resource_id: None,
63            parent_thread_id: None,
64            messages: Vec::new(),
65            state: Value::Object(serde_json::Map::new()),
66            patches: Vec::new(),
67            metadata: ThreadMetadata::default(),
68        }
69    }
70
71    /// Create a new thread with initial state.
72    pub fn with_initial_state(id: impl Into<String>, state: Value) -> Self {
73        Self {
74            id: id.into(),
75            resource_id: None,
76            parent_thread_id: None,
77            messages: Vec::new(),
78            state,
79            patches: Vec::new(),
80            metadata: ThreadMetadata::default(),
81        }
82    }
83
84    /// Set the resource_id (pure function, returns new Thread).
85    #[must_use]
86    pub fn with_resource_id(mut self, resource_id: impl Into<String>) -> Self {
87        self.resource_id = Some(resource_id.into());
88        self
89    }
90
91    /// Set the parent_thread_id (pure function, returns new Thread).
92    #[must_use]
93    pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
94        self.parent_thread_id = Some(parent_thread_id.into());
95        self
96    }
97
98    /// Add a message to the thread (pure function, returns new Thread).
99    ///
100    /// Messages are Arc-wrapped for efficient cloning during agent loops.
101    #[must_use]
102    pub fn with_message(mut self, msg: Message) -> Self {
103        self.messages.push(Arc::new(msg));
104        self
105    }
106
107    /// Add multiple messages (pure function, returns new Thread).
108    #[must_use]
109    pub fn with_messages(mut self, msgs: impl IntoIterator<Item = Message>) -> Self {
110        let arcs: Vec<Arc<Message>> = msgs.into_iter().map(Arc::new).collect();
111        self.messages.extend(arcs);
112        self
113    }
114
115    /// Add a patch to the thread (pure function, returns new Thread).
116    #[must_use]
117    pub fn with_patch(mut self, patch: TrackedPatch) -> Self {
118        self.patches.push(patch);
119        self
120    }
121
122    /// Add multiple patches (pure function, returns new Thread).
123    #[must_use]
124    pub fn with_patches(mut self, patches: impl IntoIterator<Item = TrackedPatch>) -> Self {
125        self.patches.extend(patches);
126        self
127    }
128
129    /// Rebuild the current state (base + thread patches).
130    pub fn rebuild_state(&self) -> TireaResult<Value> {
131        if self.patches.is_empty() {
132            return Ok(self.state.clone());
133        }
134        apply_patches(&self.state, self.patches.iter().map(|p| p.patch()))
135    }
136
137    /// Replay state to a specific patch index (0-based).
138    ///
139    /// - `patch_index = 0`: Returns state after applying the first patch only
140    /// - `patch_index = n`: Returns state after applying patches 0..=n
141    /// - `patch_index >= patch_count`: Returns error
142    ///
143    /// This enables time-travel debugging by accessing any historical state point.
144    pub fn replay_to(&self, patch_index: usize) -> TireaResult<Value> {
145        if patch_index >= self.patches.len() {
146            return Err(TireaError::invalid_operation(format!(
147                "replay index {patch_index} out of bounds (history len: {})",
148                self.patches.len()
149            )));
150        }
151
152        apply_patches(
153            &self.state,
154            self.patches[..=patch_index].iter().map(|p| p.patch()),
155        )
156    }
157
158    /// Create a snapshot, collapsing patches into the base state.
159    ///
160    /// Returns a new Thread with the current state as base and empty patches.
161    pub fn snapshot(self) -> TireaResult<Self> {
162        let current_state = self.rebuild_state()?;
163        Ok(Self {
164            id: self.id,
165            resource_id: self.resource_id,
166            parent_thread_id: self.parent_thread_id,
167            messages: self.messages,
168            state: current_state,
169            patches: Vec::new(),
170            metadata: self.metadata,
171        })
172    }
173
174    /// Check if a snapshot is needed (e.g., too many patches).
175    pub fn needs_snapshot(&self, threshold: usize) -> bool {
176        self.patches.len() >= threshold
177    }
178
179    /// Get the number of messages.
180    pub fn message_count(&self) -> usize {
181        self.messages.len()
182    }
183
184    /// Get the number of patches.
185    pub fn patch_count(&self) -> usize {
186        self.patches.len()
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Op, Patch};

    /// Build a tracked patch that sets `counter` to `value`.
    fn set_counter(value: i64) -> TrackedPatch {
        TrackedPatch::new(Patch::new().with_op(Op::set(path!("counter"), json!(value))))
    }

    #[test]
    fn test_thread_with_messages_batch() {
        let msgs = vec![
            Message::user("a"),
            Message::assistant("b"),
            Message::user("c"),
        ];
        let thread = Thread::new("t-1").with_messages(msgs);
        assert_eq!(thread.messages.len(), 3);
        assert_eq!(thread.messages[0].content, "a");
        assert_eq!(thread.messages[2].content, "c");
    }

    #[test]
    fn test_thread_with_patches_batch() {
        let thread = Thread::new("t-1").with_patches(vec![
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(1)))),
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("b"), json!(2)))),
            TrackedPatch::new(Patch::new().with_op(Op::set(path!("c"), json!(3)))),
        ]);
        assert_eq!(thread.patches.len(), 3);
    }

    #[test]
    fn test_thread_new() {
        let thread = Thread::new("test-1");
        assert_eq!(thread.id, "test-1");
        assert!(thread.resource_id.is_none());
        assert!(thread.parent_thread_id.is_none());
        assert!(thread.messages.is_empty());
        assert!(thread.patches.is_empty());
        assert_eq!(thread.state, json!({}));
    }

    #[test]
    fn test_thread_with_resource_id() {
        let thread = Thread::new("t-1").with_resource_id("user-123");
        assert_eq!(thread.resource_id.as_deref(), Some("user-123"));
    }

    #[test]
    fn test_thread_with_parent_thread_id() {
        let thread = Thread::new("child").with_parent_thread_id("parent-1");
        assert_eq!(thread.parent_thread_id.as_deref(), Some("parent-1"));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();
        assert_eq!(restored.parent_thread_id.as_deref(), Some("parent-1"));
    }

    #[test]
    fn test_thread_with_initial_state() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());
        assert_eq!(thread.state, state);
    }

    #[test]
    fn test_thread_with_message() {
        let thread = Thread::new("test-1")
            .with_message(Message::user("Hello"))
            .with_message(Message::assistant("Hi!"));

        assert_eq!(thread.message_count(), 2);
        assert_eq!(thread.messages[0].content, "Hello");
        assert_eq!(thread.messages[1].content, "Hi!");
    }

    #[test]
    fn test_thread_with_patch() {
        let thread = Thread::new("test-1");
        let patch = TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(1))));

        let thread = thread.with_patch(patch);
        assert_eq!(thread.patch_count(), 1);
    }

    #[test]
    fn test_thread_rebuild_state_empty() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state.clone());

        let rebuilt = thread.rebuild_state().unwrap();
        assert_eq!(rebuilt, state);
    }

    #[test]
    fn test_thread_rebuild_state_with_patches() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state)
            .with_patch(set_counter(1))
            .with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("name"), json!("test"))),
            ));

        let rebuilt = thread.rebuild_state().unwrap();
        assert_eq!(rebuilt["counter"], 1);
        assert_eq!(rebuilt["name"], "test");
    }

    #[test]
    fn test_thread_snapshot() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state).with_patch(set_counter(5));

        assert_eq!(thread.patch_count(), 1);

        let snapshotted = thread.snapshot().unwrap();
        assert_eq!(snapshotted.patch_count(), 0);
        assert_eq!(snapshotted.state["counter"], 5);
    }

    #[test]
    fn test_thread_snapshot_preserves_messages_and_ids() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0}))
            .with_resource_id("user-1")
            .with_parent_thread_id("parent-1")
            .with_message(Message::user("Hello"))
            .with_patch(set_counter(7));

        let snapshotted = thread.snapshot().unwrap();
        assert_eq!(snapshotted.id, "test-1");
        assert_eq!(snapshotted.resource_id.as_deref(), Some("user-1"));
        assert_eq!(snapshotted.parent_thread_id.as_deref(), Some("parent-1"));
        assert_eq!(snapshotted.message_count(), 1);
        assert_eq!(snapshotted.messages[0].content, "Hello");
        assert_eq!(snapshotted.rebuild_state().unwrap(), json!({"counter": 7}));
    }

    #[test]
    fn test_thread_snapshot_without_patches_keeps_state() {
        let state = json!({"counter": 3});
        let snapshotted = Thread::with_initial_state("test-1", state.clone())
            .snapshot()
            .unwrap();
        assert_eq!(snapshotted.state, state);
        assert_eq!(snapshotted.patch_count(), 0);
    }

    #[test]
    fn test_thread_needs_snapshot() {
        let thread = Thread::new("test-1");
        assert!(!thread.needs_snapshot(10));

        let thread = (0..10).fold(thread, |s, i| {
            s.with_patch(TrackedPatch::new(
                Patch::new().with_op(Op::set(path!("field").key(i.to_string()), json!(i))),
            ))
        });

        assert!(thread.needs_snapshot(10));
        assert!(!thread.needs_snapshot(20));
    }

    #[test]
    fn test_thread_serialization() {
        let thread = Thread::new("test-1").with_message(Message::user("Hello"));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        assert_eq!(restored.id, "test-1");
        assert_eq!(restored.message_count(), 1);
    }

    #[test]
    fn test_thread_serialization_omits_unset_optional_fields() {
        let value = serde_json::to_value(Thread::new("t-1")).unwrap();
        assert!(value.get("resource_id").is_none());
        assert!(value.get("parent_thread_id").is_none());
        assert!(value["metadata"].get("created_at").is_none());
        assert!(value["metadata"].get("version").is_none());
    }

    #[test]
    fn test_thread_deserialization_defaults_missing_metadata() {
        let restored: Thread = serde_json::from_value(json!({
            "id": "t-1",
            "messages": [],
            "state": {"counter": 0},
            "patches": []
        }))
        .unwrap();
        assert_eq!(restored.id, "t-1");
        assert!(restored.resource_id.is_none());
        assert!(restored.metadata.version.is_none());
        assert!(restored.metadata.extra.is_empty());
    }

    #[test]
    fn test_thread_metadata_extra_round_trip() {
        let mut thread = Thread::new("t-1");
        thread.metadata.version = Some(3);
        thread
            .metadata
            .extra
            .insert("tag".to_string(), json!("blue"));

        let value = serde_json::to_value(&thread).unwrap();
        // `extra` is flattened, so custom keys sit next to the named fields.
        assert_eq!(value["metadata"]["tag"], "blue");
        assert_eq!(value["metadata"]["version"], 3);

        let restored: Thread = serde_json::from_value(value).unwrap();
        assert_eq!(restored.metadata.version, Some(3));
        assert_eq!(restored.metadata.extra.get("tag"), Some(&json!("blue")));
        // Named fields must not leak into the flattened map.
        assert!(!restored.metadata.extra.contains_key("version"));
    }

    #[test]
    fn test_state_persists_after_serialization() {
        let thread =
            Thread::with_initial_state("test-1", json!({"counter": 0})).with_patch(set_counter(5));

        let json_str = serde_json::to_string(&thread).unwrap();
        let restored: Thread = serde_json::from_str(&json_str).unwrap();

        let rebuilt = restored.rebuild_state().unwrap();
        assert_eq!(
            rebuilt["counter"], 5,
            "persisted state should survive serialization"
        );
    }

    #[test]
    fn test_thread_serialization_includes_resource_id() {
        let thread = Thread::new("t-1").with_resource_id("org-42");
        let json_str = serde_json::to_string(&thread).unwrap();
        assert!(json_str.contains("org-42"));

        let restored: Thread = serde_json::from_str(&json_str).unwrap();
        assert_eq!(restored.resource_id.as_deref(), Some("org-42"));
    }

    #[test]
    fn test_thread_replay_to() {
        let state = json!({"counter": 0});
        let thread = Thread::with_initial_state("test-1", state)
            .with_patch(set_counter(10))
            .with_patch(set_counter(20))
            .with_patch(set_counter(30));

        let state_at_0 = thread.replay_to(0).unwrap();
        assert_eq!(state_at_0["counter"], 10);

        let state_at_1 = thread.replay_to(1).unwrap();
        assert_eq!(state_at_1["counter"], 20);

        let state_at_2 = thread.replay_to(2).unwrap();
        assert_eq!(state_at_2["counter"], 30);
        assert_eq!(state_at_2, thread.rebuild_state().unwrap());

        let err = thread.replay_to(100).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 100 out of bounds (history len: 3)"));

        assert!(thread.replay_to(usize::MAX).is_err());
    }

    #[test]
    fn test_thread_replay_to_does_not_mutate() {
        let thread =
            Thread::with_initial_state("test-1", json!({"counter": 0})).with_patch(set_counter(1));

        assert_eq!(thread.replay_to(0).unwrap()["counter"], 1);
        assert_eq!(thread.state["counter"], 0);
        assert_eq!(thread.patch_count(), 1);
    }

    #[test]
    fn test_thread_replay_to_empty() {
        let thread = Thread::with_initial_state("test-1", json!({"counter": 0}));

        let err = thread.replay_to(0).unwrap_err();
        assert!(err
            .to_string()
            .contains("replay index 0 out of bounds (history len: 0)"));
    }
}