awaken_server_contract/contract/
store_traits.rs1use crate::contract::storage::*;
11use async_trait::async_trait;
12use awaken_runtime_contract::contract::message::{Message, MessageRecord};
13use awaken_runtime_contract::thread::{Thread, normalize_lineage_id};
14
15#[async_trait]
22pub trait ThreadStore: Send + Sync {
23 async fn load_thread(&self, thread_id: &str) -> Result<Option<Thread>, StorageError>;
25
26 async fn save_thread(&self, thread: &Thread) -> Result<(), StorageError>;
32
33 async fn save_thread_validated(&self, thread: &Thread) -> Result<(), StorageError> {
39 self.validate_thread_hierarchy(&thread.id, thread.parent_thread_id.as_deref())
40 .await?;
41 self.save_thread(thread).await
42 }
43
44 async fn delete_thread(&self, thread_id: &str) -> Result<(), StorageError>;
49
50 async fn save_thread_state(
56 &self,
57 thread_id: &str,
58 state: &awaken_runtime_contract::state::PersistedState,
59 ) -> Result<(), StorageError> {
60 let _ = (thread_id, state);
61 Ok(())
62 }
63
64 async fn load_thread_state(
66 &self,
67 thread_id: &str,
68 ) -> Result<Option<awaken_runtime_contract::state::PersistedState>, StorageError> {
69 let _ = thread_id;
70 Ok(None)
71 }
72
73 async fn delete_thread_with_strategy(
80 &self,
81 thread_id: &str,
82 strategy: ChildThreadDeleteStrategy,
83 ) -> Result<(), StorageError> {
84 if self.load_thread(thread_id).await?.is_none() {
85 return Err(StorageError::NotFound(thread_id.to_owned()));
86 }
87
88 match strategy {
89 ChildThreadDeleteStrategy::Reject => {
90 let children = self.list_child_threads(thread_id).await?;
91 if !children.is_empty() {
92 return Err(StorageError::Validation(format!(
93 "thread '{thread_id}' has child threads; choose 'detach' or 'cascade'"
94 )));
95 }
96 self.delete_thread(thread_id).await
97 }
98 ChildThreadDeleteStrategy::Detach => {
99 let mut children = self.list_child_threads(thread_id).await?;
100 let updated_at = crate::now_ms();
101 for child in &mut children {
102 child.parent_thread_id = None;
103 child.metadata.updated_at = Some(updated_at);
104 self.save_thread(child).await?;
105 }
106 self.delete_thread(thread_id).await
107 }
108 ChildThreadDeleteStrategy::Cascade => {
109 let mut visited = std::collections::HashSet::new();
110 let mut stack = vec![(thread_id.to_owned(), false)];
111 let mut delete_order = Vec::new();
112
113 while let Some((current_thread_id, expanded)) = stack.pop() {
114 if expanded {
115 delete_order.push(current_thread_id);
116 continue;
117 }
118
119 if !visited.insert(current_thread_id.clone()) {
120 return Err(StorageError::Validation(format!(
121 "thread hierarchy cycle detected while deleting '{thread_id}'"
122 )));
123 }
124
125 stack.push((current_thread_id.clone(), true));
126 let mut children = self.list_child_threads(¤t_thread_id).await?;
127 children.sort_by(|left, right| left.id.cmp(&right.id));
128 for child in children.into_iter().rev() {
129 stack.push((child.id, false));
130 }
131 }
132
133 for id in delete_order {
134 self.delete_thread(&id).await?;
135 }
136 Ok(())
137 }
138 }
139 }
140
141 async fn list_threads(&self, offset: usize, limit: usize) -> Result<Vec<String>, StorageError>;
143
144 async fn list_threads_query(&self, query: &ThreadQuery) -> Result<ThreadPage, StorageError> {
146 const SCAN_LIMIT: usize = 200;
147
148 let mut offset = 0;
149 let mut threads = Vec::new();
150 loop {
151 let ids = self.list_threads(offset, SCAN_LIMIT).await?;
152 if ids.is_empty() {
153 break;
154 }
155 let count = ids.len();
156 for id in ids {
157 if let Some(thread) = self.load_thread(&id).await? {
158 threads.push(thread);
159 }
160 }
161 if count < SCAN_LIMIT {
162 break;
163 }
164 offset += count;
165 }
166
167 Ok(paginate_threads(threads, query))
168 }
169
170 async fn list_child_threads(
172 &self,
173 parent_thread_id: &str,
174 ) -> Result<Vec<Thread>, StorageError> {
175 const PAGE_LIMIT: usize = 200;
176
177 let mut offset = 0;
178 let mut children = Vec::new();
179 loop {
180 let query = ThreadQuery {
181 offset,
182 limit: PAGE_LIMIT,
183 resource_id: None,
184 parent_filter: ThreadParentFilter::Parent(parent_thread_id.to_owned()),
185 id_prefix: None,
186 };
187 let page = self.list_threads_query(&query).await?;
188 let count = page.items.len();
189 for id in page.items {
190 if let Some(thread) = self.load_thread(&id).await? {
191 children.push(thread);
192 }
193 }
194 if !page.has_more || count == 0 {
195 break;
196 }
197 offset = page
198 .next_cursor
199 .as_deref()
200 .and_then(|cursor| query.decode_cursor(cursor).ok())
201 .unwrap_or(offset.saturating_add(count));
202 }
203 Ok(children)
204 }
205
206 async fn validate_thread_hierarchy(
208 &self,
209 thread_id: &str,
210 parent_thread_id: Option<&str>,
211 ) -> Result<(), StorageError> {
212 let Some(parent_thread_id) = normalize_lineage_id(parent_thread_id) else {
213 return Ok(());
214 };
215 if parent_thread_id == thread_id {
216 return Err(StorageError::Validation(format!(
217 "thread '{thread_id}' cannot parent itself"
218 )));
219 }
220
221 let root_parent_thread_id = parent_thread_id.to_owned();
222 let mut current_thread_id = root_parent_thread_id.clone();
223 let mut visited = std::collections::HashSet::from([thread_id.to_owned()]);
224
225 loop {
226 if !visited.insert(current_thread_id.clone()) {
227 return Err(StorageError::Validation(format!(
228 "thread hierarchy cycle detected at '{current_thread_id}'"
229 )));
230 }
231
232 let Some(thread) = self.load_thread(¤t_thread_id).await? else {
233 let message = if current_thread_id == root_parent_thread_id {
234 format!("parent thread not found: {root_parent_thread_id}")
235 } else {
236 format!("thread hierarchy references missing ancestor '{current_thread_id}'")
237 };
238 return Err(StorageError::Validation(message));
239 };
240
241 let Some(next_parent_thread_id) =
242 normalize_lineage_id(thread.parent_thread_id.as_deref())
243 else {
244 return Ok(());
245 };
246 current_thread_id = next_parent_thread_id;
247 }
248 }
249
250 async fn load_messages(&self, thread_id: &str) -> Result<Option<Vec<Message>>, StorageError>;
251
252 async fn load_committed_messages(
253 &self,
254 thread_id: &str,
255 ) -> Result<Option<Vec<Message>>, StorageError> {
256 self.load_messages(thread_id).await
257 }
258
259 async fn load_message_records(
260 &self,
261 thread_id: &str,
262 ) -> Result<Option<Vec<MessageRecord>>, StorageError> {
263 let Some(messages) = self.load_messages(thread_id).await? else {
264 return Ok(None);
265 };
266 Ok(Some(
267 messages
268 .into_iter()
269 .enumerate()
270 .map(|(index, message)| {
271 MessageRecord::from_message(thread_id.to_string(), index as u64 + 1, message)
272 })
273 .collect(),
274 ))
275 }
276
277 async fn list_message_records(
279 &self,
280 thread_id: &str,
281 query: &MessageQuery,
282 ) -> Result<MessagePage, StorageError> {
283 let Some(records) = self.load_message_records(thread_id).await? else {
284 return Ok(MessagePage::empty());
285 };
286 Ok(paginate_message_records(records, query))
287 }
288
289 async fn append_message_records(
291 &self,
292 thread_id: &str,
293 messages: &[Message],
294 ) -> Result<Vec<MessageRecord>, StorageError> {
295 let mut existing = self
296 .load_committed_messages(thread_id)
297 .await?
298 .unwrap_or_default();
299 message_append::validate_append_only_delta(&existing, messages)?;
300 let start_seq = existing.len() as u64 + 1;
301 existing.extend(messages.iter().cloned());
302 self.save_messages(thread_id, &existing).await?;
303 Ok(messages
304 .iter()
305 .cloned()
306 .enumerate()
307 .map(|(index, message)| {
308 MessageRecord::from_message(
309 thread_id.to_string(),
310 start_seq + index as u64,
311 message,
312 )
313 })
314 .collect())
315 }
316
317 async fn load_message_record(
319 &self,
320 thread_id: &str,
321 message_id: &str,
322 ) -> Result<Option<MessageRecord>, StorageError> {
323 let Some(records) = self.load_message_records(thread_id).await? else {
324 return Ok(None);
325 };
326 Ok(records
327 .into_iter()
328 .find(|record| record.message_id == message_id))
329 }
330
331 async fn load_message_records_range(
333 &self,
334 thread_id: &str,
335 range: MessageSeqRange,
336 ) -> Result<Vec<MessageRecord>, StorageError> {
337 let Some(records) = self.load_message_records(thread_id).await? else {
338 return Ok(Vec::new());
339 };
340 Ok(records
341 .into_iter()
342 .filter(|record| record.seq >= range.from_seq && record.seq <= range.to_seq)
343 .collect())
344 }
345
346 async fn save_messages(
348 &self,
349 thread_id: &str,
350 messages: &[Message],
351 ) -> Result<(), StorageError>;
352
353 async fn delete_messages(&self, thread_id: &str) -> Result<(), StorageError>;
355
356 async fn update_thread_metadata(
359 &self,
360 id: &str,
361 metadata: crate::thread::ThreadMetadata,
362 ) -> Result<(), StorageError>;
363}
364
365#[async_trait]
369pub trait RunStore: Send + Sync {
370 async fn create_run(&self, record: &RunRecord) -> Result<(), StorageError>;
372
373 async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, StorageError>;
375
376 async fn latest_run(&self, thread_id: &str) -> Result<Option<RunRecord>, StorageError>;
378
379 async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, StorageError>;
381}
382
383#[async_trait]
390pub trait ThreadRunStore: ThreadStore + RunStore + Send + Sync {
391 fn thread_run_storage_identity(&self) -> Option<String> {
396 None
397 }
398
399 #[deprecated(since = "0.6.0", note = "use CommitCoordinator (ADR-0038 D7)")]
400 async fn checkpoint(
401 &self,
402 thread_id: &str,
403 messages: &[Message],
404 run: &RunRecord,
405 ) -> Result<(), StorageError>;
406
407 #[allow(deprecated)]
409 async fn checkpoint_append(
410 &self,
411 thread_id: &str,
412 messages: &[Message],
413 expected_version: Option<u64>,
414 run: &RunRecord,
415 ) -> Result<u64, StorageError> {
416 let existing = self
417 .load_committed_messages(thread_id)
418 .await?
419 .unwrap_or_default();
420 let actual = existing.len() as u64;
421 if let Some(expected) = expected_version
422 && expected != actual
423 {
424 return Err(StorageError::VersionConflict { expected, actual });
425 }
426 let mut merged = existing;
427 message_append::merge_checkpoint_append_messages(&mut merged, messages)?;
428 let new_version = merged.len() as u64;
429 self.checkpoint(thread_id, &merged, run).await?;
430 Ok(new_version)
431 }
432
433 async fn load_checkpoint(
440 &self,
441 thread_id: &str,
442 ) -> Result<Option<CheckpointSnapshot>, StorageError> {
443 let committed = ThreadStore::load_committed_messages(self, thread_id).await?;
444 let latest_run = RunStore::latest_run(self, thread_id).await?;
445 if committed.is_none() && latest_run.is_none() {
446 return Ok(None);
447 }
448 let raw = committed.unwrap_or_default();
449 let message_version = raw.len() as u64;
450 let messages =
451 awaken_runtime_contract::contract::message::effective_committed_view(raw, thread_id);
452 let thread_state = ThreadStore::load_thread_state(self, thread_id).await?;
453 Ok(Some(CheckpointSnapshot {
454 messages,
455 message_version,
456 latest_run,
457 thread_state,
458 }))
459 }
460}
461
462pub struct ThreadRunCheckpointStore {
466 inner: std::sync::Arc<dyn ThreadRunStore>,
467}
468
469impl ThreadRunCheckpointStore {
470 pub fn new(inner: std::sync::Arc<dyn ThreadRunStore>) -> Self {
471 Self { inner }
472 }
473}
474
475#[async_trait]
476impl RuntimeCheckpointStore for ThreadRunCheckpointStore {
477 async fn load_thread(&self, thread_id: &str) -> Result<Option<Thread>, StorageError> {
478 ThreadStore::load_thread(self.inner.as_ref(), thread_id).await
479 }
480
481 async fn load_messages(&self, thread_id: &str) -> Result<Option<Vec<Message>>, StorageError> {
482 ThreadStore::load_messages(self.inner.as_ref(), thread_id).await
483 }
484
485 async fn load_committed_messages(
486 &self,
487 thread_id: &str,
488 ) -> Result<Option<Vec<Message>>, StorageError> {
489 ThreadStore::load_committed_messages(self.inner.as_ref(), thread_id).await
490 }
491
492 async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, StorageError> {
493 RunStore::load_run(self.inner.as_ref(), run_id).await
494 }
495
496 async fn latest_run(&self, thread_id: &str) -> Result<Option<RunRecord>, StorageError> {
497 RunStore::latest_run(self.inner.as_ref(), thread_id).await
498 }
499
500 async fn load_thread_state(
501 &self,
502 thread_id: &str,
503 ) -> Result<Option<awaken_runtime_contract::state::PersistedState>, StorageError> {
504 ThreadStore::load_thread_state(self.inner.as_ref(), thread_id).await
505 }
506
507 async fn load_checkpoint(
508 &self,
509 thread_id: &str,
510 ) -> Result<Option<CheckpointSnapshot>, StorageError> {
511 ThreadRunStore::load_checkpoint(self.inner.as_ref(), thread_id).await
513 }
514}
515
516#[cfg(test)]
517#[path = "store_traits_tests.rs"]
518mod tests;