Skip to main content

codetether_agent/swarm/
result_store.rs

1//! Shared result store for sub-agent result sharing
2//!
3//! This module provides a mechanism for sub-agents to share intermediate results
4//! during execution, allowing dependent sub-agents to access results without
5//! waiting for the entire swarm to finish.
6
7use anyhow::{Result, anyhow};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::{RwLock, broadcast};
13
/// A typed result entry in the shared store.
///
/// Normally created by [`ResultStore::publish`], which infers `schema` from
/// `value` and stamps `published_at` with the current UTC time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedResult {
    /// Unique key for this result; publishing again under the same key replaces
    /// the previous entry
    pub key: String,

    /// The subtask ID that produced this result
    pub producer_id: String,

    /// The result value (JSON)
    pub value: Value,

    /// Schema/type information for the value
    pub schema: ResultSchema,

    /// Timestamp (UTC) when the result was published
    pub published_at: chrono::DateTime<chrono::Utc>,

    /// Optional expiration time; `None` means the result never expires.
    /// [`ResultStore::cleanup_expired`] removes results whose expiration time is
    /// at or before the moment it runs.
    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,

    /// Tags for categorization and filtering (see [`ResultStore::query_by_tags`])
    pub tags: Vec<String>,
}
38
/// Schema information for a shared result, as inferred by [`ResultSchema::from_value`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultSchema {
    /// JSON type name; `from_value` produces one of "null", "boolean", "integer",
    /// "number", "string", "array" or "object"
    pub type_name: String,

    /// Optional human-readable description of the result
    /// (`from_value` always leaves this as `None`)
    pub description: Option<String>,

    /// For objects, the type name of each top-level field keyed by field name;
    /// `None` for non-object values
    pub fields: Option<HashMap<String, String>>,
}
51
52impl ResultSchema {
53    /// Create a schema from a serde_json::Value
54    pub fn from_value(value: &Value) -> Self {
55        let type_name = match value {
56            Value::Null => "null".to_string(),
57            Value::Bool(_) => "boolean".to_string(),
58            Value::Number(n) => {
59                if n.is_i64() || n.is_u64() {
60                    "integer".to_string()
61                } else {
62                    "number".to_string()
63                }
64            }
65            Value::String(_) => "string".to_string(),
66            Value::Array(_) => "array".to_string(),
67            Value::Object(_) => "object".to_string(),
68        };
69
70        let fields = if let Value::Object(obj) = value {
71            let mut field_types = HashMap::new();
72            for (key, val) in obj {
73                field_types.insert(key.clone(), Self::from_value(val).type_name);
74            }
75            Some(field_types)
76        } else {
77            None
78        };
79
80        Self {
81            type_name,
82            description: None,
83            fields,
84        }
85    }
86}
87
/// Subscription pattern for result notifications.
///
/// Evaluated against a [`SharedResult`] by [`ResultStore::matches_pattern`].
/// Patterns implement `PartialEq`, `Eq` and `Hash`, so they can be compared,
/// deduplicated, or used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubscriptionPattern {
    /// Exact key match
    Exact(String),
    /// Prefix match (key starts with the given string)
    Prefix(String),
    /// Tag match (result has at least one of these tags; an empty list matches nothing)
    Tag(Vec<String>),
    /// Producer match (results from the given subtask ID)
    Producer(String),
    /// Matches every result
    All,
}
102
/// Notification broadcast by [`ResultStore::publish`] each time a result is stored.
///
/// Carries only identifying metadata; fetch the value itself with
/// [`ResultStore::get`] using `key`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultNotification {
    /// The key of the published result
    pub key: String,

    /// The ID of the subtask that produced the result
    pub producer_id: String,

    /// Tags associated with the result
    pub tags: Vec<String>,
}
115
116/// Shared result store for sub-agent communication
117pub struct ResultStore {
118    /// All stored results by key
119    results: RwLock<HashMap<String, SharedResult>>,
120
121    /// Broadcast channel for result notifications
122    notification_tx: broadcast::Sender<ResultNotification>,
123
124    /// Subscriptions by pattern (for efficient filtering)
125    subscriptions: RwLock<HashMap<String, Vec<SubscriptionPattern>>>,
126}
127
128impl ResultStore {
129    /// Create a new result store
130    pub fn new() -> Self {
131        let (notification_tx, _) = broadcast::channel(1000);
132        Self {
133            results: RwLock::new(HashMap::new()),
134            notification_tx,
135            subscriptions: RwLock::new(HashMap::new()),
136        }
137    }
138
139    /// Create a new result store wrapped in an Arc
140    pub fn new_arc() -> Arc<Self> {
141        Arc::new(Self::new())
142    }
143
144    /// Publish a result to the store
145    pub async fn publish(
146        &self,
147        key: impl Into<String>,
148        producer_id: impl Into<String>,
149        value: impl Serialize,
150        tags: Vec<String>,
151        expires_at: Option<chrono::DateTime<chrono::Utc>>,
152    ) -> Result<SharedResult> {
153        let key = key.into();
154        let producer_id = producer_id.into();
155
156        let value = serde_json::to_value(value)?;
157        let schema = ResultSchema::from_value(&value);
158
159        let result = SharedResult {
160            key: key.clone(),
161            producer_id: producer_id.clone(),
162            value,
163            schema,
164            published_at: chrono::Utc::now(),
165            expires_at,
166            tags: tags.clone(),
167        };
168
169        // Store the result
170        {
171            let mut results = self.results.write().await;
172            results.insert(key.clone(), result.clone());
173        }
174
175        // Notify subscribers
176        let notification = ResultNotification {
177            key: key.clone(),
178            producer_id,
179            tags,
180        };
181
182        // Broadcast to all listeners (they filter based on their subscription)
183        let _ = self.notification_tx.send(notification);
184
185        tracing::info!(key = %key, "Published shared result");
186
187        Ok(result)
188    }
189
190    /// Get a result by key
191    pub async fn get(&self, key: &str) -> Option<SharedResult> {
192        let results = self.results.read().await;
193        results.get(key).cloned()
194    }
195
196    /// Get a result and deserialize it to a specific type
197    pub async fn get_typed<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Result<T> {
198        let result = self
199            .get(key)
200            .await
201            .ok_or_else(|| anyhow!("Result not found: {}", key))?;
202
203        serde_json::from_value(result.value)
204            .map_err(|e| anyhow!("Failed to deserialize result: {}", e))
205    }
206
207    /// Query results by tags
208    pub async fn query_by_tags(&self, tags: &[String]) -> Vec<SharedResult> {
209        let results = self.results.read().await;
210        results
211            .values()
212            .filter(|r| tags.iter().any(|t| r.tags.contains(t)))
213            .cloned()
214            .collect()
215    }
216
217    /// Query results by producer
218    pub async fn query_by_producer(&self, producer_id: &str) -> Vec<SharedResult> {
219        let results = self.results.read().await;
220        results
221            .values()
222            .filter(|r| r.producer_id == producer_id)
223            .cloned()
224            .collect()
225    }
226
227    /// Query results by key prefix
228    pub async fn query_by_prefix(&self, prefix: &str) -> Vec<SharedResult> {
229        let results = self.results.read().await;
230        results
231            .values()
232            .filter(|r| r.key.starts_with(prefix))
233            .cloned()
234            .collect()
235    }
236
237    /// Subscribe to result notifications
238    pub fn subscribe(&self) -> broadcast::Receiver<ResultNotification> {
239        self.notification_tx.subscribe()
240    }
241
242    /// Register a subscription pattern for a subtask
243    pub async fn register_subscription(
244        &self,
245        subtask_id: impl Into<String>,
246        pattern: SubscriptionPattern,
247    ) {
248        let subtask_id = subtask_id.into();
249        let mut subscriptions = self.subscriptions.write().await;
250        subscriptions.entry(subtask_id).or_default().push(pattern);
251    }
252
253    /// Unregister all subscriptions for a subtask
254    pub async fn unregister_subscriptions(&self, subtask_id: &str) {
255        let mut subscriptions = self.subscriptions.write().await;
256        subscriptions.remove(subtask_id);
257    }
258
259    /// Check if a result matches a subscription pattern
260    pub fn matches_pattern(result: &SharedResult, pattern: &SubscriptionPattern) -> bool {
261        match pattern {
262            SubscriptionPattern::Exact(key) => result.key == *key,
263            SubscriptionPattern::Prefix(prefix) => result.key.starts_with(prefix),
264            SubscriptionPattern::Tag(tags) => tags.iter().any(|t| result.tags.contains(t)),
265            SubscriptionPattern::Producer(producer) => result.producer_id == *producer,
266            SubscriptionPattern::All => true,
267        }
268    }
269
270    /// Get all results (for debugging/inspection)
271    pub async fn get_all(&self) -> Vec<SharedResult> {
272        let results = self.results.read().await;
273        results.values().cloned().collect()
274    }
275
276    /// Remove expired results
277    pub async fn cleanup_expired(&self) -> usize {
278        let now = chrono::Utc::now();
279        let mut results = self.results.write().await;
280        let keys_to_remove: Vec<String> = results
281            .values()
282            .filter(|r| r.expires_at.map(|exp| exp <= now).unwrap_or(false))
283            .map(|r| r.key.clone())
284            .collect();
285
286        for key in &keys_to_remove {
287            results.remove(key);
288        }
289
290        keys_to_remove.len()
291    }
292
293    /// Clear all results
294    pub async fn clear(&self) {
295        let mut results = self.results.write().await;
296        results.clear();
297    }
298}
299
300impl Default for ResultStore {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
/// Extension trait for sub-task contexts (e.g. `SubTaskContext`, defined
/// elsewhere — NOTE(review): confirm) to integrate with [`ResultStore`].
///
/// Implemented in this module by [`SubTaskStoreHandle`]. All returned futures
/// are required to be `Send`.
pub trait ResultStoreContext {
    /// Publish `value` under `key` to the shared store with the given tags.
    ///
    /// The implementor decides which producer ID and expiration are recorded.
    fn publish_result(
        &self,
        key: impl Into<String> + Send,
        value: impl Serialize + Send,
        tags: Vec<String>,
    ) -> impl std::future::Future<Output = Result<SharedResult>> + Send;

    /// Get a result from the shared store, or `None` if it is not available.
    fn get_result(
        &self,
        key: &str,
    ) -> impl std::future::Future<Output = Option<SharedResult>> + Send;

    /// Get a result from the shared store and deserialize its value into `T`.
    fn get_result_typed<T: for<'de> Deserialize<'de>>(
        &self,
        key: &str,
    ) -> impl std::future::Future<Output = Result<T>> + Send;
}
328
329/// Handle for a sub-agent to interact with the shared result store
330pub struct SubTaskStoreHandle {
331    store: Arc<ResultStore>,
332    subtask_id: String,
333}
334
335impl SubTaskStoreHandle {
336    /// Create a new handle for a specific subtask
337    pub fn new(store: Arc<ResultStore>, subtask_id: impl Into<String>) -> Self {
338        Self {
339            store,
340            subtask_id: subtask_id.into(),
341        }
342    }
343}
344
345impl ResultStoreContext for SubTaskStoreHandle {
346    async fn publish_result(
347        &self,
348        key: impl Into<String> + Send,
349        value: impl Serialize + Send,
350        tags: Vec<String>,
351    ) -> Result<SharedResult> {
352        self.store
353            .publish(key, &self.subtask_id, value, tags, None)
354            .await
355    }
356
357    async fn get_result(&self, key: &str) -> Option<SharedResult> {
358        self.store.get(key).await
359    }
360
361    async fn get_result_typed<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Result<T> {
362        self.store.get_typed(key).await
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_publish_and_get() {
        let store = ResultStore::new();

        store
            .publish("test-key", "task-1", "hello world", vec![], None)
            .await
            .unwrap();

        let result = store.get("test-key").await.unwrap();
        assert_eq!(result.key, "test-key");
        assert_eq!(result.producer_id, "task-1");
        assert_eq!(result.value, Value::String("hello world".to_string()));
    }

    #[tokio::test]
    async fn test_get_typed() {
        let store = ResultStore::new();

        #[derive(Debug, Serialize, Deserialize, PartialEq)]
        struct TestData {
            name: String,
            count: i32,
        }

        let data = TestData {
            name: "test".to_string(),
            count: 42,
        };

        store
            .publish("typed-key", "task-1", &data, vec![], None)
            .await
            .unwrap();

        let retrieved: TestData = store.get_typed("typed-key").await.unwrap();
        assert_eq!(retrieved, data);
    }

    #[tokio::test]
    async fn test_get_typed_missing_key_is_error() {
        let store = ResultStore::new();
        assert!(store.get_typed::<String>("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_query_by_tags() {
        let store = ResultStore::new();

        store
            .publish(
                "key-1",
                "task-1",
                "value-1",
                vec!["tag-a".to_string()],
                None,
            )
            .await
            .unwrap();

        store
            .publish(
                "key-2",
                "task-2",
                "value-2",
                vec!["tag-b".to_string()],
                None,
            )
            .await
            .unwrap();

        store
            .publish(
                "key-3",
                "task-1",
                "value-3",
                vec!["tag-a".to_string(), "tag-c".to_string()],
                None,
            )
            .await
            .unwrap();

        let results = store.query_by_tags(&["tag-a".to_string()]).await;
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_query_by_producer() {
        let store = ResultStore::new();

        store
            .publish("key-1", "task-1", "value-1", vec![], None)
            .await
            .unwrap();
        store
            .publish("key-2", "task-2", "value-2", vec![], None)
            .await
            .unwrap();
        store
            .publish("key-3", "task-1", "value-3", vec![], None)
            .await
            .unwrap();

        let mut keys: Vec<String> = store
            .query_by_producer("task-1")
            .await
            .into_iter()
            .map(|r| r.key)
            .collect();
        keys.sort();
        assert_eq!(keys, vec!["key-1", "key-3"]);
    }

    #[tokio::test]
    async fn test_query_by_prefix() {
        let store = ResultStore::new();

        store
            .publish("prefix/key-1", "task-1", "value-1", vec![], None)
            .await
            .unwrap();

        store
            .publish("prefix/key-2", "task-2", "value-2", vec![], None)
            .await
            .unwrap();

        store
            .publish("other/key-3", "task-1", "value-3", vec![], None)
            .await
            .unwrap();

        let results = store.query_by_prefix("prefix/").await;
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let store = ResultStore::new();
        let past = chrono::Utc::now() - chrono::Duration::seconds(60);
        let future = chrono::Utc::now() + chrono::Duration::seconds(3600);

        store
            .publish("expired", "task-1", "old", vec![], Some(past))
            .await
            .unwrap();
        store
            .publish("fresh", "task-1", "new", vec![], Some(future))
            .await
            .unwrap();
        store
            .publish("forever", "task-1", "any", vec![], None)
            .await
            .unwrap();

        assert_eq!(store.cleanup_expired().await, 1);
        assert!(store.get("fresh").await.is_some());
        assert!(store.get("forever").await.is_some());
        assert_eq!(store.get_all().await.len(), 2);
        // A second sweep finds nothing left to remove.
        assert_eq!(store.cleanup_expired().await, 0);
    }

    #[tokio::test]
    async fn test_subscription_notifications() {
        let store = ResultStore::new();
        let mut rx = store.subscribe();

        store
            .publish(
                "notify-key",
                "task-1",
                "value",
                vec!["tag-1".to_string()],
                None,
            )
            .await
            .unwrap();

        let notification = rx.try_recv().unwrap();
        assert_eq!(notification.key, "notify-key");
        assert_eq!(notification.producer_id, "task-1");
    }

    #[tokio::test]
    async fn test_subtask_handle_attributes_producer() {
        let store = ResultStore::new_arc();
        let handle = SubTaskStoreHandle::new(Arc::clone(&store), "task-7");

        handle
            .publish_result("handle-key", 7_u32, vec![])
            .await
            .unwrap();

        let result = handle.get_result("handle-key").await.unwrap();
        assert_eq!(result.producer_id, "task-7");
        assert!(result.expires_at.is_none());

        let typed: u32 = handle.get_result_typed("handle-key").await.unwrap();
        assert_eq!(typed, 7);
        assert_eq!(store.query_by_producer("task-7").await.len(), 1);
    }

    #[test]
    fn test_schema_from_value() {
        let value = serde_json::json!({
            "name": "x",
            "count": 3,
            "ratio": 0.5,
            "nested": { "deep": [1, 2] }
        });

        let schema = ResultSchema::from_value(&value);
        assert_eq!(schema.type_name, "object");
        assert!(schema.description.is_none());

        let fields = schema.fields.unwrap();
        assert_eq!(fields.len(), 4);
        assert_eq!(fields["name"], "string");
        assert_eq!(fields["count"], "integer");
        assert_eq!(fields["ratio"], "number");
        assert_eq!(fields["nested"], "object");

        let scalar = ResultSchema::from_value(&Value::Bool(true));
        assert_eq!(scalar.type_name, "boolean");
        assert!(scalar.fields.is_none());
    }

    #[test]
    fn test_matches_pattern() {
        let result = SharedResult {
            key: "test/key".to_string(),
            producer_id: "task-1".to_string(),
            value: Value::Null,
            schema: ResultSchema::from_value(&Value::Null),
            published_at: chrono::Utc::now(),
            expires_at: None,
            tags: vec!["tag-a".to_string()],
        };

        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Exact("test/key".to_string())
        ));
        assert!(!ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Exact("other".to_string())
        ));
        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Prefix("test/".to_string())
        ));
        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Tag(vec!["tag-a".to_string()])
        ));
        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Producer("task-1".to_string())
        ));
        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::All
        ));
    }
}