Skip to main content

awaken_server_contract/contract/
store_traits.rs

1//! Server/store-owned thread & run persistence traits.
2//!
3//! Moved out of `awaken-runtime-contract`: the runtime perceives durable
4//! storage only through the narrow `RuntimeCheckpointStore` read port and the
5//! `CommitCoordinator` write boundary. The full CRUD + query surface
6//! (`ThreadStore`, `RunStore`, `ThreadRunStore`) is a server/store concern and
7//! lives here. Storage data types, query/page types, and pagination helpers
8//! stay in runtime-contract and are pulled in via the glob below.
9
10use crate::contract::storage::*;
11use async_trait::async_trait;
12use awaken_runtime_contract::contract::message::{Message, MessageRecord};
13use awaken_runtime_contract::thread::{Thread, normalize_lineage_id};
14
15// ── ThreadStore ─────────────────────────────────────────────────────
16
17/// Thread read/write persistence.
18///
19/// Thread metadata and messages are stored separately. Messages have a
20/// single source of truth through `load_messages` / `save_messages`.
21#[async_trait]
22pub trait ThreadStore: Send + Sync {
23    /// Load a thread by ID. Returns `None` if not found.
24    async fn load_thread(&self, thread_id: &str) -> Result<Option<Thread>, StorageError>;
25
26    /// Persist a thread (create or overwrite).
27    ///
28    /// This is a low-level persistence primitive. Callers that change
29    /// parent-child relationships should use [`ThreadStore::save_thread_validated`]
30    /// so hierarchy invariants are checked against current store state.
31    async fn save_thread(&self, thread: &Thread) -> Result<(), StorageError>;
32
33    /// Persist a thread after validating parent-child hierarchy invariants.
34    ///
35    /// The default implementation validates and then delegates to
36    /// [`ThreadStore::save_thread`]. It is not atomic across those steps.
37    /// with a backend-native atomic or fenced implementation.
38    async fn save_thread_validated(&self, thread: &Thread) -> Result<(), StorageError> {
39        self.validate_thread_hierarchy(&thread.id, thread.parent_thread_id.as_deref())
40            .await?;
41        self.save_thread(thread).await
42    }
43
44    /// Delete a thread and its associated messages.
45    ///
46    /// This is a low-level delete primitive. Callers that need hierarchy-aware
47    /// child handling should use [`ThreadStore::delete_thread_with_strategy`].
48    async fn delete_thread(&self, thread_id: &str) -> Result<(), StorageError>;
49
50    /// Persist thread-scoped state for `thread_id` (overwrite the prior value).
51    ///
52    /// Default is a no-op for stores that do not yet persist thread-scoped
53    /// state; the runtime then keeps state on the run record only. Production
54    /// stores override this and pair it with [`ThreadStore::load_thread_state`].
55    async fn save_thread_state(
56        &self,
57        thread_id: &str,
58        state: &awaken_runtime_contract::state::PersistedState,
59    ) -> Result<(), StorageError> {
60        let _ = (thread_id, state);
61        Ok(())
62    }
63
64    /// Load thread-scoped state for `thread_id`, if any. Default `None`.
65    async fn load_thread_state(
66        &self,
67        thread_id: &str,
68    ) -> Result<Option<awaken_runtime_contract::state::PersistedState>, StorageError> {
69        let _ = thread_id;
70        Ok(None)
71    }
72
73    /// Delete a thread while managing direct and transitive children.
74    ///
75    /// The default implementation performs multiple low-level writes and is
76    /// not atomic across child updates and the final delete. Production stores
77    /// with concurrent writers should override this method with a transactional
78    /// or otherwise fenced implementation.
79    async fn delete_thread_with_strategy(
80        &self,
81        thread_id: &str,
82        strategy: ChildThreadDeleteStrategy,
83    ) -> Result<(), StorageError> {
84        if self.load_thread(thread_id).await?.is_none() {
85            return Err(StorageError::NotFound(thread_id.to_owned()));
86        }
87
88        match strategy {
89            ChildThreadDeleteStrategy::Reject => {
90                let children = self.list_child_threads(thread_id).await?;
91                if !children.is_empty() {
92                    return Err(StorageError::Validation(format!(
93                        "thread '{thread_id}' has child threads; choose 'detach' or 'cascade'"
94                    )));
95                }
96                self.delete_thread(thread_id).await
97            }
98            ChildThreadDeleteStrategy::Detach => {
99                let mut children = self.list_child_threads(thread_id).await?;
100                let updated_at = crate::now_ms();
101                for child in &mut children {
102                    child.parent_thread_id = None;
103                    child.metadata.updated_at = Some(updated_at);
104                    self.save_thread(child).await?;
105                }
106                self.delete_thread(thread_id).await
107            }
108            ChildThreadDeleteStrategy::Cascade => {
109                let mut visited = std::collections::HashSet::new();
110                let mut stack = vec![(thread_id.to_owned(), false)];
111                let mut delete_order = Vec::new();
112
113                while let Some((current_thread_id, expanded)) = stack.pop() {
114                    if expanded {
115                        delete_order.push(current_thread_id);
116                        continue;
117                    }
118
119                    if !visited.insert(current_thread_id.clone()) {
120                        return Err(StorageError::Validation(format!(
121                            "thread hierarchy cycle detected while deleting '{thread_id}'"
122                        )));
123                    }
124
125                    stack.push((current_thread_id.clone(), true));
126                    let mut children = self.list_child_threads(&current_thread_id).await?;
127                    children.sort_by(|left, right| left.id.cmp(&right.id));
128                    for child in children.into_iter().rev() {
129                        stack.push((child.id, false));
130                    }
131                }
132
133                for id in delete_order {
134                    self.delete_thread(&id).await?;
135                }
136                Ok(())
137            }
138        }
139    }
140
141    /// List thread IDs with pagination.
142    async fn list_threads(&self, offset: usize, limit: usize) -> Result<Vec<String>, StorageError>;
143
144    /// List thread IDs with first-class filters and page metadata.
145    async fn list_threads_query(&self, query: &ThreadQuery) -> Result<ThreadPage, StorageError> {
146        const SCAN_LIMIT: usize = 200;
147
148        let mut offset = 0;
149        let mut threads = Vec::new();
150        loop {
151            let ids = self.list_threads(offset, SCAN_LIMIT).await?;
152            if ids.is_empty() {
153                break;
154            }
155            let count = ids.len();
156            for id in ids {
157                if let Some(thread) = self.load_thread(&id).await? {
158                    threads.push(thread);
159                }
160            }
161            if count < SCAN_LIMIT {
162                break;
163            }
164            offset += count;
165        }
166
167        Ok(paginate_threads(threads, query))
168    }
169
170    /// Load all direct child threads for a given parent thread.
171    async fn list_child_threads(
172        &self,
173        parent_thread_id: &str,
174    ) -> Result<Vec<Thread>, StorageError> {
175        const PAGE_LIMIT: usize = 200;
176
177        let mut offset = 0;
178        let mut children = Vec::new();
179        loop {
180            let query = ThreadQuery {
181                offset,
182                limit: PAGE_LIMIT,
183                resource_id: None,
184                parent_filter: ThreadParentFilter::Parent(parent_thread_id.to_owned()),
185                id_prefix: None,
186            };
187            let page = self.list_threads_query(&query).await?;
188            let count = page.items.len();
189            for id in page.items {
190                if let Some(thread) = self.load_thread(&id).await? {
191                    children.push(thread);
192                }
193            }
194            if !page.has_more || count == 0 {
195                break;
196            }
197            offset = page
198                .next_cursor
199                .as_deref()
200                .and_then(|cursor| query.decode_cursor(cursor).ok())
201                .unwrap_or(offset.saturating_add(count));
202        }
203        Ok(children)
204    }
205
206    /// Validate parent-child hierarchy invariants for a thread.
207    async fn validate_thread_hierarchy(
208        &self,
209        thread_id: &str,
210        parent_thread_id: Option<&str>,
211    ) -> Result<(), StorageError> {
212        let Some(parent_thread_id) = normalize_lineage_id(parent_thread_id) else {
213            return Ok(());
214        };
215        if parent_thread_id == thread_id {
216            return Err(StorageError::Validation(format!(
217                "thread '{thread_id}' cannot parent itself"
218            )));
219        }
220
221        let root_parent_thread_id = parent_thread_id.to_owned();
222        let mut current_thread_id = root_parent_thread_id.clone();
223        let mut visited = std::collections::HashSet::from([thread_id.to_owned()]);
224
225        loop {
226            if !visited.insert(current_thread_id.clone()) {
227                return Err(StorageError::Validation(format!(
228                    "thread hierarchy cycle detected at '{current_thread_id}'"
229                )));
230            }
231
232            let Some(thread) = self.load_thread(&current_thread_id).await? else {
233                let message = if current_thread_id == root_parent_thread_id {
234                    format!("parent thread not found: {root_parent_thread_id}")
235                } else {
236                    format!("thread hierarchy references missing ancestor '{current_thread_id}'")
237                };
238                return Err(StorageError::Validation(message));
239            };
240
241            let Some(next_parent_thread_id) =
242                normalize_lineage_id(thread.parent_thread_id.as_deref())
243            else {
244                return Ok(());
245            };
246            current_thread_id = next_parent_thread_id;
247        }
248    }
249
250    async fn load_messages(&self, thread_id: &str) -> Result<Option<Vec<Message>>, StorageError>;
251
252    async fn load_committed_messages(
253        &self,
254        thread_id: &str,
255    ) -> Result<Option<Vec<Message>>, StorageError> {
256        self.load_messages(thread_id).await
257    }
258
259    async fn load_message_records(
260        &self,
261        thread_id: &str,
262    ) -> Result<Option<Vec<MessageRecord>>, StorageError> {
263        let Some(messages) = self.load_messages(thread_id).await? else {
264            return Ok(None);
265        };
266        Ok(Some(
267            messages
268                .into_iter()
269                .enumerate()
270                .map(|(index, message)| {
271                    MessageRecord::from_message(thread_id.to_string(), index as u64 + 1, message)
272                })
273                .collect(),
274        ))
275    }
276
277    /// List thread-owned message records with filtering and page metadata.
278    async fn list_message_records(
279        &self,
280        thread_id: &str,
281        query: &MessageQuery,
282    ) -> Result<MessagePage, StorageError> {
283        let Some(records) = self.load_message_records(thread_id).await? else {
284            return Ok(MessagePage::empty());
285        };
286        Ok(paginate_message_records(records, query))
287    }
288
289    /// Append messages to a thread's durable log and return their records.
290    async fn append_message_records(
291        &self,
292        thread_id: &str,
293        messages: &[Message],
294    ) -> Result<Vec<MessageRecord>, StorageError> {
295        let mut existing = self
296            .load_committed_messages(thread_id)
297            .await?
298            .unwrap_or_default();
299        message_append::validate_append_only_delta(&existing, messages)?;
300        let start_seq = existing.len() as u64 + 1;
301        existing.extend(messages.iter().cloned());
302        self.save_messages(thread_id, &existing).await?;
303        Ok(messages
304            .iter()
305            .cloned()
306            .enumerate()
307            .map(|(index, message)| {
308                MessageRecord::from_message(
309                    thread_id.to_string(),
310                    start_seq + index as u64,
311                    message,
312                )
313            })
314            .collect())
315    }
316
317    /// Load one message record by message ID.
318    async fn load_message_record(
319        &self,
320        thread_id: &str,
321        message_id: &str,
322    ) -> Result<Option<MessageRecord>, StorageError> {
323        let Some(records) = self.load_message_records(thread_id).await? else {
324            return Ok(None);
325        };
326        Ok(records
327            .into_iter()
328            .find(|record| record.message_id == message_id))
329    }
330
331    /// Load message records by inclusive sequence range.
332    async fn load_message_records_range(
333        &self,
334        thread_id: &str,
335        range: MessageSeqRange,
336    ) -> Result<Vec<MessageRecord>, StorageError> {
337        let Some(records) = self.load_message_records(thread_id).await? else {
338            return Ok(Vec::new());
339        };
340        Ok(records
341            .into_iter()
342            .filter(|record| record.seq >= range.from_seq && record.seq <= range.to_seq)
343            .collect())
344    }
345
346    /// Persist messages for a thread (full overwrite).
347    async fn save_messages(
348        &self,
349        thread_id: &str,
350        messages: &[Message],
351    ) -> Result<(), StorageError>;
352
353    /// Delete all messages for a thread. Returns `NotFound` if the thread does not exist.
354    async fn delete_messages(&self, thread_id: &str) -> Result<(), StorageError>;
355
356    /// Update only the metadata of an existing thread.
357    /// Returns `NotFound` if the thread does not exist.
358    async fn update_thread_metadata(
359        &self,
360        id: &str,
361        metadata: crate::thread::ThreadMetadata,
362    ) -> Result<(), StorageError>;
363}
364
365// ── RunStore ────────────────────────────────────────────────────────
366
367/// Run record persistence.
368#[async_trait]
369pub trait RunStore: Send + Sync {
370    /// Create a new run record.
371    async fn create_run(&self, record: &RunRecord) -> Result<(), StorageError>;
372
373    /// Load a run record by `run_id`.
374    async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, StorageError>;
375
376    /// Find the latest run for a thread (by `updated_at`).
377    async fn latest_run(&self, thread_id: &str) -> Result<Option<RunRecord>, StorageError>;
378
379    /// List runs with optional filtering and pagination.
380    async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, StorageError>;
381}
382
383// ── ThreadRunStore (convenience) ────────────────────────────────────
384
385/// Atomic thread+run checkpoint persistence. ADR-0038 D7: prefer
386/// [`CommitCoordinator::commit_checkpoint`](super::commit_coordinator::CommitCoordinator::commit_checkpoint)
387/// for production writes; `checkpoint` is retained for conformance tests
388/// and coordinator-internal use.
389#[async_trait]
390pub trait ThreadRunStore: ThreadStore + RunStore + Send + Sync {
391    /// Return an identity for the backing thread/run store, when the
392    /// implementation can prove it. This is intentionally narrower than a
393    /// coordinator transaction scope: it only identifies the thread/run read
394    /// and write backend used by mailbox/server code.
395    fn thread_run_storage_identity(&self) -> Option<String> {
396        None
397    }
398
399    #[deprecated(since = "0.6.0", note = "use CommitCoordinator (ADR-0038 D7)")]
400    async fn checkpoint(
401        &self,
402        thread_id: &str,
403        messages: &[Message],
404        run: &RunRecord,
405    ) -> Result<(), StorageError>;
406
407    /// Append to the committed log and persist `run`, guarded by message count.
408    #[allow(deprecated)]
409    async fn checkpoint_append(
410        &self,
411        thread_id: &str,
412        messages: &[Message],
413        expected_version: Option<u64>,
414        run: &RunRecord,
415    ) -> Result<u64, StorageError> {
416        let existing = self
417            .load_committed_messages(thread_id)
418            .await?
419            .unwrap_or_default();
420        let actual = existing.len() as u64;
421        if let Some(expected) = expected_version
422            && expected != actual
423        {
424            return Err(StorageError::VersionConflict { expected, actual });
425        }
426        let mut merged = existing;
427        message_append::merge_checkpoint_append_messages(&mut merged, messages)?;
428        let new_version = merged.len() as u64;
429        self.checkpoint(thread_id, &merged, run).await?;
430        Ok(new_version)
431    }
432
433    /// Read a consistent [`CheckpointSnapshot`] for resume (ADR-0038 C5).
434    ///
435    /// The default composes the committed-message, latest-run, and
436    /// thread-state reads and applies the committed-history view filter.
437    /// Backends that can read atomically (a transaction or lock spanning all
438    /// three) override this to avoid torn reads against a concurrent commit.
439    async fn load_checkpoint(
440        &self,
441        thread_id: &str,
442    ) -> Result<Option<CheckpointSnapshot>, StorageError> {
443        let committed = ThreadStore::load_committed_messages(self, thread_id).await?;
444        let latest_run = RunStore::latest_run(self, thread_id).await?;
445        if committed.is_none() && latest_run.is_none() {
446            return Ok(None);
447        }
448        let raw = committed.unwrap_or_default();
449        let message_version = raw.len() as u64;
450        let messages =
451            awaken_runtime_contract::contract::message::effective_committed_view(raw, thread_id);
452        let thread_state = ThreadStore::load_thread_state(self, thread_id).await?;
453        Ok(Some(CheckpointSnapshot {
454            messages,
455            message_version,
456            latest_run,
457            thread_state,
458        }))
459    }
460}
461
462/// Adapts a [`ThreadRunStore`] into a [`RuntimeCheckpointStore`] for the agent
463/// loop. Exposed so embedders/tests can supply a checkpoint reader backed by
464/// any `ThreadRunStore`.
465pub struct ThreadRunCheckpointStore {
466    inner: std::sync::Arc<dyn ThreadRunStore>,
467}
468
469impl ThreadRunCheckpointStore {
470    pub fn new(inner: std::sync::Arc<dyn ThreadRunStore>) -> Self {
471        Self { inner }
472    }
473}
474
475#[async_trait]
476impl RuntimeCheckpointStore for ThreadRunCheckpointStore {
477    async fn load_thread(&self, thread_id: &str) -> Result<Option<Thread>, StorageError> {
478        ThreadStore::load_thread(self.inner.as_ref(), thread_id).await
479    }
480
481    async fn load_messages(&self, thread_id: &str) -> Result<Option<Vec<Message>>, StorageError> {
482        ThreadStore::load_messages(self.inner.as_ref(), thread_id).await
483    }
484
485    async fn load_committed_messages(
486        &self,
487        thread_id: &str,
488    ) -> Result<Option<Vec<Message>>, StorageError> {
489        ThreadStore::load_committed_messages(self.inner.as_ref(), thread_id).await
490    }
491
492    async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, StorageError> {
493        RunStore::load_run(self.inner.as_ref(), run_id).await
494    }
495
496    async fn latest_run(&self, thread_id: &str) -> Result<Option<RunRecord>, StorageError> {
497        RunStore::latest_run(self.inner.as_ref(), thread_id).await
498    }
499
500    async fn load_thread_state(
501        &self,
502        thread_id: &str,
503    ) -> Result<Option<awaken_runtime_contract::state::PersistedState>, StorageError> {
504        ThreadStore::load_thread_state(self.inner.as_ref(), thread_id).await
505    }
506
507    async fn load_checkpoint(
508        &self,
509        thread_id: &str,
510    ) -> Result<Option<CheckpointSnapshot>, StorageError> {
511        // Delegate to the store's (possibly atomic) consistent read.
512        ThreadRunStore::load_checkpoint(self.inner.as_ref(), thread_id).await
513    }
514}
515
516#[cfg(test)]
517#[path = "store_traits_tests.rs"]
518mod tests;