1pub use awaken_runtime_contract::contract::storage::{
12 CheckpointSnapshot, MessageSeqRange, RunMessageInput, RunMessageOutput, RunOutcome, RunRecord,
13 RunRequestOrigin, RunRequestSnapshot, RunResumeDecision, RunWaitingState, RunWaitingTicket,
14 RuntimeCheckpointStore, StorageError, WaitingReason, message_append,
15};
16pub use super::store_traits::{RunStore, ThreadRunCheckpointStore, ThreadRunStore, ThreadStore};
18
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use awaken_runtime_contract::contract::lifecycle::RunStatus;
23use awaken_runtime_contract::contract::message::{Message, MessageRecord, Visibility};
24use awaken_runtime_contract::thread::{Thread, ThreadMetadata, normalize_lineage_id};
25use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
26use serde::de::DeserializeOwned;
27use serde::{Deserialize, Serialize};
28
29use crate::contract::scope::{ScopeId, scoped_key, unscoped_key};
30
/// Prefix marking an opaque message-pagination cursor produced by `MessageQuery::encode_cursor`.
const MESSAGE_CURSOR_PREFIX: &str = "msg_";
/// Prefix marking an opaque thread-pagination cursor produced by `ThreadQuery::encode_cursor`.
const THREAD_CURSOR_PREFIX: &str = "thr_";
33
/// Pagination and filter parameters for listing a thread's message records.
///
/// `offset` and `limit` carry no `#[serde(default)]`, so they must be present when
/// deserializing; every other field may be omitted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageQuery {
    /// Number of matching records to skip before the page starts.
    pub offset: usize,
    /// Maximum number of records on a page; `MessageQuery::normalized` caps it at 200.
    pub limit: usize,
    /// When set, only records whose `seq` is strictly greater than this value match.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub after: Option<u64>,
    /// When set, only records whose `seq` is strictly less than this value match.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub before: Option<u64>,
    /// Sort direction over `seq`.
    #[serde(default)]
    pub order: MessageOrder,
    /// Which records to keep based on their message's visibility.
    #[serde(default)]
    pub visibility: MessageVisibilityFilter,
    /// When set, only records whose `produced_by_run_id` equals this id match.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub run_id: Option<String>,
}
59
60impl Default for MessageQuery {
61 fn default() -> Self {
62 Self {
63 offset: 0,
64 limit: 50,
65 after: None,
66 before: None,
67 order: MessageOrder::Asc,
68 visibility: MessageVisibilityFilter::Any,
69 run_id: None,
70 }
71 }
72}
73
74impl MessageQuery {
75 #[must_use]
77 pub fn normalized(&self) -> Self {
78 Self {
79 offset: self.offset,
80 limit: self.limit.min(200),
81 after: self.after,
82 before: self.before,
83 order: self.order,
84 visibility: self.visibility,
85 run_id: self.run_id.clone(),
86 }
87 }
88
89 #[must_use]
91 pub fn encode_cursor(&self, offset: usize) -> String {
92 let normalized = self.normalized();
93 encode_cursor_token(
94 MESSAGE_CURSOR_PREFIX,
95 &MessageCursorToken {
96 offset,
97 after: normalized.after,
98 before: normalized.before,
99 order: normalized.order,
100 visibility: normalized.visibility,
101 run_id: normalized.run_id,
102 },
103 )
104 }
105
106 pub fn decode_cursor(&self, cursor: &str) -> Result<usize, String> {
108 if let Ok(offset) = cursor.parse::<usize>() {
109 return Ok(offset);
110 }
111
112 let normalized = self.normalized();
113 let token: MessageCursorToken = decode_cursor_token(MESSAGE_CURSOR_PREFIX, cursor)?;
114 if token.after != normalized.after
115 || token.before != normalized.before
116 || token.order != normalized.order
117 || token.visibility != normalized.visibility
118 || token.run_id != normalized.run_id
119 {
120 return Err("cursor does not match message query filters".to_string());
121 }
122 Ok(token.offset)
123 }
124
125 #[must_use]
127 pub fn matches_record(&self, record: &MessageRecord) -> bool {
128 if self.after.is_some_and(|after| record.seq <= after) {
129 return false;
130 }
131 if self.before.is_some_and(|before| record.seq >= before) {
132 return false;
133 }
134 if self
135 .run_id
136 .as_deref()
137 .is_some_and(|run_id| record.produced_by_run_id.as_deref() != Some(run_id))
138 {
139 return false;
140 }
141 match self.visibility {
142 MessageVisibilityFilter::Any => true,
143 MessageVisibilityFilter::External => record.message.visibility != Visibility::Internal,
144 MessageVisibilityFilter::Internal => record.message.visibility == Visibility::Internal,
145 }
146 }
147}
148
/// Sort direction for message records, keyed on their `seq`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageOrder {
    /// Ascending `seq` (the default).
    #[default]
    Asc,
    /// Descending `seq`.
    Desc,
}
159
/// Filters message records by their message's `Visibility`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageVisibilityFilter {
    /// Keep every record regardless of visibility (the default).
    #[default]
    Any,
    /// Keep only records whose visibility is not `Visibility::Internal`.
    External,
    /// Keep only records whose visibility is `Visibility::Internal`.
    Internal,
}
172
/// One page of message records plus pagination metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePage {
    /// Records on this page, in the query's order.
    pub records: Vec<MessageRecord>,
    /// Number of records matching the query before slicing into a page
    /// (as computed by `paginate_message_records`).
    pub total: usize,
    /// Whether more matching records follow this page.
    pub has_more: bool,
    /// Opaque cursor for the next page, present only when `has_more` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub next_cursor: Option<String>,
    /// Opaque cursor for the previous page, present only when this page is not the first.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prev_cursor: Option<String>,
}
184
185impl MessagePage {
186 #[must_use]
188 pub fn empty() -> Self {
189 Self {
190 records: Vec::new(),
191 total: 0,
192 has_more: false,
193 next_cursor: None,
194 prev_cursor: None,
195 }
196 }
197}
198
/// Filters threads by their parent thread (lineage).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ThreadParentFilter {
    /// No parent filtering (the default).
    #[default]
    Any,
    /// Only threads whose parent id normalizes to nothing (root threads).
    Root,
    /// Only threads whose normalized parent id equals this id.
    Parent(String),
}
211
212impl ThreadParentFilter {
213 #[must_use]
214 pub fn is_any(&self) -> bool {
215 matches!(self, Self::Any)
216 }
217
218 #[must_use]
219 pub fn normalized(&self) -> Self {
220 match self {
221 Self::Any => Self::Any,
222 Self::Root => Self::Root,
223 Self::Parent(parent_thread_id) => normalize_lineage_id(Some(parent_thread_id))
224 .map(Self::Parent)
225 .unwrap_or(Self::Any),
226 }
227 }
228}
229
/// JSON payload behind a message cursor.
///
/// It holds the resume offset plus the normalized filters the cursor was issued under.
/// `MessageQuery::decode_cursor` compares those filters against the current query.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct MessageCursorToken {
    offset: usize,
    after: Option<u64>,
    before: Option<u64>,
    order: MessageOrder,
    visibility: MessageVisibilityFilter,
    run_id: Option<String>,
}
239
/// JSON payload behind a thread cursor.
///
/// It holds the resume offset plus the normalized filters the cursor was issued under.
/// `ThreadQuery::decode_cursor` compares those filters against the current query.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct ThreadCursorToken {
    offset: usize,
    resource_id: Option<String>,
    parent_filter: ThreadParentFilter,
    id_prefix: Option<String>,
}
247
/// Pagination and filter parameters for listing threads.
///
/// `offset` and `limit` carry no `#[serde(default)]`, so they must be present when
/// deserializing; every other field may be omitted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ThreadQuery {
    /// Number of matching threads to skip before the page starts.
    pub offset: usize,
    /// Maximum number of threads on a page; `ThreadQuery::normalized` caps it at 200.
    pub limit: usize,
    /// When set (and non-empty after `normalize_lineage_id`), only threads whose
    /// normalized `resource_id` equals this value match.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resource_id: Option<String>,
    /// Parent/lineage filter; omitted from serialized output when it is `Any`.
    #[serde(default, skip_serializing_if = "ThreadParentFilter::is_any")]
    pub parent_filter: ThreadParentFilter,
    /// When set, only threads whose id starts with this prefix match.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id_prefix: Option<String>,
}
269
270impl Default for ThreadQuery {
271 fn default() -> Self {
272 Self {
273 offset: 0,
274 limit: 50,
275 resource_id: None,
276 parent_filter: ThreadParentFilter::Any,
277 id_prefix: None,
278 }
279 }
280}
281
282impl ThreadQuery {
283 #[must_use]
285 pub fn has_filters(&self) -> bool {
286 normalize_lineage_id(self.resource_id.as_deref()).is_some()
287 || !self.parent_filter.is_any()
288 || self.id_prefix.is_some()
289 }
290
291 #[must_use]
293 pub fn normalized(&self) -> Self {
294 Self {
295 offset: self.offset,
296 limit: self.limit.min(200),
297 resource_id: normalize_lineage_id(self.resource_id.as_deref()),
298 parent_filter: self.parent_filter.normalized(),
299 id_prefix: self.id_prefix.clone(),
300 }
301 }
302
303 #[must_use]
305 pub fn encode_cursor(&self, offset: usize) -> String {
306 let normalized = self.normalized();
307 encode_cursor_token(
308 THREAD_CURSOR_PREFIX,
309 &ThreadCursorToken {
310 offset,
311 resource_id: normalized.resource_id,
312 parent_filter: normalized.parent_filter,
313 id_prefix: normalized.id_prefix,
314 },
315 )
316 }
317
318 pub fn decode_cursor(&self, cursor: &str) -> Result<usize, String> {
320 let normalized = self.normalized();
321 if let Ok(offset) = cursor.parse::<usize>() {
322 return if normalized.has_filters() {
323 Err("cursor does not match thread query filters".to_string())
324 } else {
325 Ok(offset)
326 };
327 }
328
329 let token: ThreadCursorToken = decode_cursor_token(THREAD_CURSOR_PREFIX, cursor)?;
330 if token.resource_id != normalized.resource_id
331 || token.parent_filter != normalized.parent_filter
332 || token.id_prefix != normalized.id_prefix
333 {
334 return Err("cursor does not match thread query filters".to_string());
335 }
336 Ok(token.offset)
337 }
338
339 #[must_use]
341 pub fn matches_thread(&self, thread: &Thread) -> bool {
342 let normalized = self.normalized();
343 if normalized
344 .id_prefix
345 .as_deref()
346 .is_some_and(|prefix| !thread.id.starts_with(prefix))
347 {
348 return false;
349 }
350 if normalized
351 .resource_id
352 .as_deref()
353 .is_some_and(|resource_id| {
354 normalize_lineage_id(thread.resource_id.as_deref()).as_deref() != Some(resource_id)
355 })
356 {
357 return false;
358 }
359 match &normalized.parent_filter {
360 ThreadParentFilter::Any => {}
361 ThreadParentFilter::Root => {
362 if normalize_lineage_id(thread.parent_thread_id.as_deref()).is_some() {
363 return false;
364 }
365 }
366 ThreadParentFilter::Parent(parent_thread_id) => {
367 if normalize_lineage_id(thread.parent_thread_id.as_deref()).as_deref()
368 != Some(parent_thread_id.as_str())
369 {
370 return false;
371 }
372 }
373 }
374 true
375 }
376}
377
/// One page of thread ids plus pagination metadata.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ThreadPage {
    /// Thread ids on this page.
    pub items: Vec<String>,
    /// Number of threads matching the query before slicing into a page
    /// (as computed by `paginate_threads`).
    pub total: usize,
    /// Whether more matching threads follow this page.
    pub has_more: bool,
    /// Opaque cursor for the next page, present only when `has_more` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub next_cursor: Option<String>,
    /// Opaque cursor for the previous page, present only when this page is not the first.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prev_cursor: Option<String>,
}
389
390impl ThreadPage {
391 #[must_use]
393 pub fn empty() -> Self {
394 Self {
395 items: Vec::new(),
396 total: 0,
397 has_more: false,
398 next_cursor: None,
399 prev_cursor: None,
400 }
401 }
402}
403
/// How child threads are treated when their parent thread is deleted.
///
/// NOTE(review): the variant semantics below are inferred from their names; the code that
/// consumes this enum is not visible here — confirm against the delete implementation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChildThreadDeleteStrategy {
    /// Presumably: refuse to delete a thread that still has children.
    Reject,
    /// Presumably: keep children but clear their parent link (the default).
    #[default]
    Detach,
    /// Presumably: delete children along with the parent.
    Cascade,
}
416
417#[must_use]
419pub fn checkpoint_parent_thread_id<'a>(
420 existing_thread: Option<&'a Thread>,
421 run: &'a RunRecord,
422) -> Option<&'a str> {
423 existing_thread
424 .and_then(|thread| thread.parent_thread_id.as_deref())
425 .or_else(|| {
426 run.request
427 .as_ref()
428 .and_then(|request| request.parent_thread_id.as_deref())
429 })
430}
431
432pub fn sort_threads_by_recent_activity(threads: &mut [Thread]) {
434 threads.sort_by(|a, b| {
435 let a_updated = a.metadata.updated_at.or(a.metadata.created_at).unwrap_or(0);
436 let b_updated = b.metadata.updated_at.or(b.metadata.created_at).unwrap_or(0);
437 b_updated.cmp(&a_updated).then_with(|| a.id.cmp(&b.id))
438 });
439}
440
441#[must_use]
443pub fn paginate_threads(mut threads: Vec<Thread>, query: &ThreadQuery) -> ThreadPage {
444 let query = query.normalized();
445 sort_threads_by_recent_activity(&mut threads);
446 let filtered: Vec<Thread> = threads
447 .into_iter()
448 .filter(|thread| query.matches_thread(thread))
449 .collect();
450 let total = filtered.len();
451 let start = query.offset.min(total);
452 let items: Vec<String> = filtered
453 .into_iter()
454 .skip(start)
455 .take(query.limit)
456 .map(|thread| thread.id)
457 .collect();
458 let next_offset = start + items.len();
459 let has_more = query.limit > 0 && next_offset < total;
460 ThreadPage {
461 items,
462 total,
463 has_more,
464 next_cursor: has_more.then(|| query.encode_cursor(next_offset)),
465 prev_cursor: (query.limit > 0 && start > 0)
466 .then(|| query.encode_cursor(start.saturating_sub(query.limit))),
467 }
468}
469
470#[must_use]
472pub fn paginate_message_records(
473 mut records: Vec<MessageRecord>,
474 query: &MessageQuery,
475) -> MessagePage {
476 let query = query.normalized();
477 records.retain(|record| query.matches_record(record));
478 match query.order {
479 MessageOrder::Asc => records.sort_by_key(|record| record.seq),
480 MessageOrder::Desc => records.sort_by(|a, b| b.seq.cmp(&a.seq)),
481 }
482 let total = records.len();
483 let start = query.offset.min(total);
484 let page_records: Vec<MessageRecord> =
485 records.into_iter().skip(start).take(query.limit).collect();
486 let next_offset = start + page_records.len();
487 let has_more = query.limit > 0 && next_offset < total;
488 MessagePage {
489 records: page_records,
490 total,
491 has_more,
492 next_cursor: has_more.then(|| query.encode_cursor(next_offset)),
493 prev_cursor: (query.limit > 0 && start > 0)
494 .then(|| query.encode_cursor(start.saturating_sub(query.limit))),
495 }
496}
497
498fn encode_cursor_token<T: Serialize>(prefix: &str, token: &T) -> String {
499 let bytes = serde_json::to_vec(token).expect("cursor token serialization should succeed");
500 format!("{prefix}{}", URL_SAFE_NO_PAD.encode(bytes))
501}
502
503fn decode_cursor_token<T: DeserializeOwned>(prefix: &str, cursor: &str) -> Result<T, String> {
504 let payload = cursor
505 .strip_prefix(prefix)
506 .ok_or_else(|| "cursor must be a valid pagination token".to_string())?;
507 let decoded = URL_SAFE_NO_PAD
508 .decode(payload)
509 .map_err(|_| "cursor must be a valid pagination token".to_string())?;
510 serde_json::from_slice(&decoded)
511 .map_err(|_| "cursor must be a valid pagination token".to_string())
512}
513
/// Pagination and filter parameters for listing runs.
#[derive(Debug, Clone)]
pub struct RunQuery {
    /// Pagination offset: number of runs to skip.
    pub offset: usize,
    /// Page size: maximum number of runs to return.
    pub limit: usize,
    /// Presumably restricts results to runs of this thread — confirm against store impls.
    pub thread_id: Option<String>,
    /// Presumably restricts results to runs in this status — confirm against store impls.
    pub status: Option<RunStatus>,
    /// Optional id prefix filter, checked by `RunQuery::matches_id_prefix`.
    /// NOTE(review): whether stores apply it to run ids or thread ids is not visible
    /// here — confirm against the store implementations.
    pub id_prefix: Option<String>,
}
531
532impl RunQuery {
533 #[must_use]
535 pub fn matches_id_prefix(&self, thread_id: &str) -> bool {
536 self.id_prefix
537 .as_deref()
538 .is_none_or(|prefix| thread_id.starts_with(prefix))
539 }
540}
541
542impl Default for RunQuery {
543 fn default() -> Self {
544 Self {
545 offset: 0,
546 limit: 50,
547 thread_id: None,
548 status: None,
549 id_prefix: None,
550 }
551 }
552}
553
/// One page of run records returned by `RunStore::list_runs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunPage {
    /// Runs on this page.
    pub items: Vec<RunRecord>,
    /// Total reported by the backing store (presumably all matching runs across pages).
    pub total: usize,
    /// Whether the backing store reports more runs after this page.
    pub has_more: bool,
}
561
/// A `ThreadRunStore` adapter that namespaces ids under a `ScopeId`.
///
/// Thread ids, run ids and the related id fields are rewritten with `scoped_key` before
/// delegating to `inner`. On the way back out they are stripped with `unscoped_key`.
#[derive(Clone)]
pub struct ScopedThreadRunStore {
    /// Store that actually holds the scoped keys.
    inner: Arc<dyn ThreadRunStore>,
    /// Scope applied to every id passed through this adapter.
    scope_id: ScopeId,
}
567
568impl ScopedThreadRunStore {
569 pub fn new(inner: Arc<dyn ThreadRunStore>, scope_id: ScopeId) -> Self {
570 Self { inner, scope_id }
571 }
572
573 pub fn scope_id(&self) -> &ScopeId {
574 &self.scope_id
575 }
576
577 pub fn inner(&self) -> &dyn ThreadRunStore {
578 self.inner.as_ref()
579 }
580
581 fn scoped(&self, id: &str) -> String {
582 scoped_key(&self.scope_id, id)
583 }
584
585 fn scope_prefix(&self) -> String {
588 scoped_key(&self.scope_id, "")
589 }
590
591 fn unscoped<'a>(&self, id: &'a str) -> Option<&'a str> {
592 unscoped_key(&self.scope_id, id)
593 }
594
595 fn encode_thread(&self, thread: &Thread) -> Thread {
596 let mut thread = thread.clone();
597 thread.id = self.scoped(&thread.id);
598 thread.parent_thread_id = thread.parent_thread_id.as_deref().map(|id| self.scoped(id));
599 thread
600 }
601
602 fn decode_thread(&self, mut thread: Thread) -> Option<Thread> {
603 thread.id = self.unscoped(&thread.id)?.to_string();
604 thread.parent_thread_id = match thread.parent_thread_id.as_deref() {
605 Some(id) => Some(self.unscoped(id)?.to_string()),
606 None => None,
607 };
608 Some(thread)
609 }
610
611 fn encode_run(&self, run: &RunRecord) -> RunRecord {
612 let mut run = run.clone();
613 run.run_id = self.scoped(&run.run_id);
614 run.thread_id = self.scoped(&run.thread_id);
615 run.parent_run_id = run.parent_run_id.as_deref().map(|id| self.scoped(id));
616 if let Some(input) = run.input.as_mut() {
617 input.thread_id = self.scoped(&input.thread_id);
618 }
619 if let Some(output) = run.output.as_mut() {
620 output.thread_id = self.scoped(&output.thread_id);
621 }
622 if let Some(request) = run.request.as_mut() {
623 request.parent_thread_id = request
624 .parent_thread_id
625 .as_deref()
626 .map(|id| self.scoped(id));
627 }
628 run
629 }
630
631 fn decode_run(&self, mut run: RunRecord) -> Option<RunRecord> {
632 run.run_id = self.unscoped(&run.run_id)?.to_string();
633 run.thread_id = self.unscoped(&run.thread_id)?.to_string();
634 run.parent_run_id = match run.parent_run_id.as_deref() {
635 Some(id) => Some(self.unscoped(id)?.to_string()),
636 None => None,
637 };
638 if let Some(input) = run.input.as_mut() {
639 input.thread_id = self.unscoped(&input.thread_id)?.to_string();
640 }
641 if let Some(output) = run.output.as_mut() {
642 output.thread_id = self.unscoped(&output.thread_id)?.to_string();
643 }
644 if let Some(request) = run.request.as_mut() {
645 request.parent_thread_id = match request.parent_thread_id.as_deref() {
646 Some(id) => Some(self.unscoped(id)?.to_string()),
647 None => None,
648 };
649 }
650 Some(run)
651 }
652
653 fn decode_message_record(&self, mut record: MessageRecord) -> Option<MessageRecord> {
654 record.thread_id = self.unscoped(&record.thread_id)?.to_string();
655 if let Some(run_id) = record.produced_by_run_id.as_deref()
656 && let Some(unscoped) = self.unscoped(run_id)
657 {
658 record.produced_by_run_id = Some(unscoped.to_string());
659 }
660 Some(record)
661 }
662
663 fn encode_message_query(&self, query: &MessageQuery) -> MessageQuery {
664 let mut query = query.clone();
665 query.run_id = query.run_id.as_deref().map(|id| self.scoped(id));
666 query
667 }
668}
669
670#[async_trait]
671impl ThreadStore for ScopedThreadRunStore {
672 async fn load_thread(&self, thread_id: &str) -> Result<Option<Thread>, StorageError> {
673 Ok(self
674 .inner
675 .load_thread(&self.scoped(thread_id))
676 .await?
677 .and_then(|thread| self.decode_thread(thread)))
678 }
679
680 async fn save_thread(&self, thread: &Thread) -> Result<(), StorageError> {
681 self.inner.save_thread(&self.encode_thread(thread)).await
682 }
683
684 async fn delete_thread(&self, thread_id: &str) -> Result<(), StorageError> {
685 self.inner.delete_thread(&self.scoped(thread_id)).await
686 }
687
688 async fn save_thread_state(
689 &self,
690 thread_id: &str,
691 state: &awaken_runtime_contract::state::PersistedState,
692 ) -> Result<(), StorageError> {
693 self.inner
694 .save_thread_state(&self.scoped(thread_id), state)
695 .await
696 }
697
698 async fn load_thread_state(
699 &self,
700 thread_id: &str,
701 ) -> Result<Option<awaken_runtime_contract::state::PersistedState>, StorageError> {
702 self.inner.load_thread_state(&self.scoped(thread_id)).await
703 }
704
705 async fn list_threads(&self, offset: usize, limit: usize) -> Result<Vec<String>, StorageError> {
706 if limit == 0 {
707 return Ok(Vec::new());
708 }
709
710 let scope_prefix = self.scope_prefix();
715 let mut next_offset = offset;
716 let mut items = Vec::with_capacity(limit.min(200));
717 while items.len() < limit {
718 let page_limit = (limit - items.len()).min(200);
719 let page = self
720 .inner
721 .list_threads_query(&ThreadQuery {
722 offset: next_offset,
723 limit: page_limit,
724 resource_id: None,
725 parent_filter: ThreadParentFilter::Any,
726 id_prefix: Some(scope_prefix.clone()),
727 })
728 .await?;
729 let page_len = page.items.len();
730 items.extend(
731 page.items
732 .into_iter()
733 .filter_map(|id| self.unscoped(&id).map(str::to_string)),
734 );
735 if !page.has_more || page_len == 0 {
736 break;
737 }
738 next_offset = next_offset.saturating_add(page_len);
739 }
740 Ok(items)
741 }
742
743 async fn load_messages(&self, thread_id: &str) -> Result<Option<Vec<Message>>, StorageError> {
744 self.inner.load_messages(&self.scoped(thread_id)).await
745 }
746
747 async fn load_committed_messages(
748 &self,
749 thread_id: &str,
750 ) -> Result<Option<Vec<Message>>, StorageError> {
751 self.inner
752 .load_committed_messages(&self.scoped(thread_id))
753 .await
754 }
755
756 async fn load_message_records(
757 &self,
758 thread_id: &str,
759 ) -> Result<Option<Vec<MessageRecord>>, StorageError> {
760 Ok(self
761 .inner
762 .load_message_records(&self.scoped(thread_id))
763 .await?
764 .map(|records| {
765 records
766 .into_iter()
767 .filter_map(|record| self.decode_message_record(record))
768 .collect()
769 }))
770 }
771
772 async fn list_message_records(
773 &self,
774 thread_id: &str,
775 query: &MessageQuery,
776 ) -> Result<MessagePage, StorageError> {
777 let query = self.encode_message_query(query);
778 let mut page = self
779 .inner
780 .list_message_records(&self.scoped(thread_id), &query)
781 .await?;
782 page.records = page
783 .records
784 .into_iter()
785 .filter_map(|record| self.decode_message_record(record))
786 .collect();
787 Ok(page)
788 }
789
790 async fn append_message_records(
791 &self,
792 thread_id: &str,
793 messages: &[Message],
794 ) -> Result<Vec<MessageRecord>, StorageError> {
795 Ok(self
796 .inner
797 .append_message_records(&self.scoped(thread_id), messages)
798 .await?
799 .into_iter()
800 .filter_map(|record| self.decode_message_record(record))
801 .collect())
802 }
803
804 async fn save_messages(
805 &self,
806 thread_id: &str,
807 messages: &[Message],
808 ) -> Result<(), StorageError> {
809 self.inner
810 .save_messages(&self.scoped(thread_id), messages)
811 .await
812 }
813
814 async fn delete_messages(&self, thread_id: &str) -> Result<(), StorageError> {
815 self.inner.delete_messages(&self.scoped(thread_id)).await
816 }
817
818 async fn update_thread_metadata(
819 &self,
820 id: &str,
821 metadata: ThreadMetadata,
822 ) -> Result<(), StorageError> {
823 self.inner
824 .update_thread_metadata(&self.scoped(id), metadata)
825 .await
826 }
827}
828
829#[async_trait]
830impl RunStore for ScopedThreadRunStore {
831 async fn create_run(&self, record: &RunRecord) -> Result<(), StorageError> {
832 self.inner.create_run(&self.encode_run(record)).await
833 }
834
835 async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, StorageError> {
836 Ok(self
837 .inner
838 .load_run(&self.scoped(run_id))
839 .await?
840 .and_then(|record| self.decode_run(record)))
841 }
842
843 async fn latest_run(&self, thread_id: &str) -> Result<Option<RunRecord>, StorageError> {
844 Ok(self
845 .inner
846 .latest_run(&self.scoped(thread_id))
847 .await?
848 .and_then(|record| self.decode_run(record)))
849 }
850
851 async fn list_runs(&self, query: &RunQuery) -> Result<RunPage, StorageError> {
852 if let Some(thread_id) = query.thread_id.as_deref() {
857 let inner_page = self
858 .inner
859 .list_runs(&RunQuery {
860 offset: query.offset,
861 limit: query.limit,
862 thread_id: Some(self.scoped(thread_id)),
863 status: query.status,
864 id_prefix: None,
865 })
866 .await?;
867 let items = inner_page
868 .items
869 .into_iter()
870 .filter_map(|record| self.decode_run(record))
871 .collect();
872 return Ok(RunPage {
873 items,
874 total: inner_page.total,
875 has_more: inner_page.has_more,
876 });
877 }
878
879 let inner_page = self
884 .inner
885 .list_runs(&RunQuery {
886 offset: query.offset,
887 limit: query.limit,
888 thread_id: None,
889 status: query.status,
890 id_prefix: Some(self.scope_prefix()),
891 })
892 .await?;
893 let items = inner_page
894 .items
895 .into_iter()
896 .filter_map(|record| self.decode_run(record))
897 .collect();
898 Ok(RunPage {
899 items,
900 total: inner_page.total,
901 has_more: inner_page.has_more,
902 })
903 }
904}
905
906#[async_trait]
907impl ThreadRunStore for ScopedThreadRunStore {
908 #[allow(deprecated)]
909 async fn checkpoint(
910 &self,
911 thread_id: &str,
912 messages: &[Message],
913 run: &RunRecord,
914 ) -> Result<(), StorageError> {
915 self.inner
916 .checkpoint(&self.scoped(thread_id), messages, &self.encode_run(run))
917 .await
918 }
919
920 async fn checkpoint_append(
921 &self,
922 thread_id: &str,
923 messages: &[Message],
924 expected_version: Option<u64>,
925 run: &RunRecord,
926 ) -> Result<u64, StorageError> {
927 self.inner
928 .checkpoint_append(
929 &self.scoped(thread_id),
930 messages,
931 expected_version,
932 &self.encode_run(run),
933 )
934 .await
935 }
936}
937
#[cfg(test)]
mod query_tests {
    use super::*;
    use awaken_runtime_contract::contract::lifecycle::RunStatus;

    /// Builds a minimal run record on thread `t-1` for serialization round-trips.
    fn run_record(run_id: &str, status: RunStatus, parent: Option<&str>) -> RunRecord {
        RunRecord {
            run_id: run_id.into(),
            thread_id: "t-1".into(),
            agent_id: "a-1".into(),
            parent_run_id: parent.map(str::to_string),
            resolution_id: None,
            activation: None,
            request: None,
            input: None,
            output: None,
            status,
            termination_reason: None,
            final_output: None,
            error_payload: None,
            dispatch_id: None,
            session_id: None,
            transport_request_id: None,
            waiting: None,
            outcome: None,
            created_at: 100,
            started_at: None,
            finished_at: None,
            updated_at: 200,
            steps: 1,
            input_tokens: 0,
            output_tokens: 0,
            state: None,
        }
    }

    #[test]
    fn run_page_with_multiple_records_roundtrips() {
        let original = RunPage {
            items: vec![
                run_record("r-1", RunStatus::Done, None),
                run_record("r-2", RunStatus::Running, Some("r-1")),
            ],
            total: 5,
            has_more: true,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RunPage = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.items.len(), 2);
        assert_eq!(decoded.total, 5);
        assert!(decoded.has_more);
    }

    #[test]
    fn query_defaults_are_sensible() {
        let message_query = MessageQuery::default();
        assert_eq!(message_query.offset, 0);
        assert_eq!(message_query.limit, 50);

        let run_query = RunQuery::default();
        assert_eq!(run_query.offset, 0);
        assert_eq!(run_query.limit, 50);
        assert!(run_query.thread_id.is_none());
        assert!(run_query.status.is_none());
    }
}