1use anyhow::{Result, anyhow};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::{RwLock, broadcast};
13
/// A JSON value published into a [`ResultStore`], together with metadata
/// describing who published it, when, and how it can be found.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedResult {
    /// Key the result is stored under; publishing the same key again replaces it.
    pub key: String,

    /// Identifier of the publisher (the subtask id when published through
    /// [`SubTaskStoreHandle`]).
    pub producer_id: String,

    /// The payload, as produced by `serde_json::to_value` in [`ResultStore::publish`].
    pub value: Value,

    /// Shape of `value`, inferred with [`ResultSchema::from_value`] at publish time.
    pub schema: ResultSchema,

    /// UTC time at which the result was published.
    pub published_at: chrono::DateTime<chrono::Utc>,

    /// Optional expiry time. [`ResultStore::cleanup_expired`] removes results whose
    /// expiry is at or before "now"; plain lookups do not check it.
    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,

    /// Free-form labels used by [`ResultStore::query_by_tags`] and
    /// [`SubscriptionPattern::Tag`].
    pub tags: Vec<String>,
}
38
/// Lightweight, shallow description of the shape of a [`SharedResult`] value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultSchema {
    /// JSON type name. [`ResultSchema::from_value`] produces one of `"null"`,
    /// `"boolean"`, `"integer"`, `"number"`, `"string"`, `"array"` or `"object"`.
    pub type_name: String,

    /// Optional human-readable description; `from_value` always leaves it `None`.
    pub description: Option<String>,

    /// For objects, each top-level key mapped to the type name of its value;
    /// `from_value` sets this to `None` for every non-object value.
    pub fields: Option<HashMap<String, String>>,
}
51
52impl ResultSchema {
53 pub fn from_value(value: &Value) -> Self {
55 let type_name = match value {
56 Value::Null => "null".to_string(),
57 Value::Bool(_) => "boolean".to_string(),
58 Value::Number(n) => {
59 if n.is_i64() || n.is_u64() {
60 "integer".to_string()
61 } else {
62 "number".to_string()
63 }
64 }
65 Value::String(_) => "string".to_string(),
66 Value::Array(_) => "array".to_string(),
67 Value::Object(_) => "object".to_string(),
68 };
69
70 let fields = if let Value::Object(obj) = value {
71 let mut field_types = HashMap::new();
72 for (key, val) in obj {
73 field_types.insert(key.clone(), Self::from_value(val).type_name);
74 }
75 Some(field_types)
76 } else {
77 None
78 };
79
80 Self {
81 type_name,
82 description: None,
83 fields,
84 }
85 }
86}
87
/// A rule describing which published results a subscriber is interested in.
///
/// Evaluated against a [`SharedResult`] by [`ResultStore::matches_pattern`].
/// Patterns compare and hash by value, so they can be deduplicated or kept in
/// sets and maps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubscriptionPattern {
    /// Matches a result whose key equals this string exactly.
    Exact(String),
    /// Matches a result whose key starts with this prefix.
    Prefix(String),
    /// Matches a result carrying at least one of these tags (an empty list
    /// matches nothing).
    Tag(Vec<String>),
    /// Matches results whose `producer_id` equals this id.
    Producer(String),
    /// Matches every result.
    All,
}
102
/// Message broadcast to [`ResultStore::subscribe`] receivers each time a
/// result is published. It does not carry the value itself; fetch that with
/// [`ResultStore::get`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultNotification {
    /// Key of the newly published result.
    pub key: String,

    /// Producer that published the result.
    pub producer_id: String,

    /// Tags attached to the result at publish time.
    pub tags: Vec<String>,
}
115
/// In-memory store of [`SharedResult`]s guarded by tokio `RwLock`s, with a
/// broadcast channel announcing every publish.
pub struct ResultStore {
    /// Published results keyed by their `key`.
    results: RwLock<HashMap<String, SharedResult>>,

    /// Sender side of the channel that announces each successful publish.
    notification_tx: broadcast::Sender<ResultNotification>,

    /// Subscription patterns registered per subtask id.
    /// NOTE(review): nothing in this file consults these when publishing;
    /// presumably callers evaluate them via `matches_pattern` — confirm.
    subscriptions: RwLock<HashMap<String, Vec<SubscriptionPattern>>>,
}
127
128impl ResultStore {
129 pub fn new() -> Self {
131 let (notification_tx, _) = broadcast::channel(1000);
132 Self {
133 results: RwLock::new(HashMap::new()),
134 notification_tx,
135 subscriptions: RwLock::new(HashMap::new()),
136 }
137 }
138
139 pub fn new_arc() -> Arc<Self> {
141 Arc::new(Self::new())
142 }
143
144 pub async fn publish(
146 &self,
147 key: impl Into<String>,
148 producer_id: impl Into<String>,
149 value: impl Serialize,
150 tags: Vec<String>,
151 expires_at: Option<chrono::DateTime<chrono::Utc>>,
152 ) -> Result<SharedResult> {
153 let key = key.into();
154 let producer_id = producer_id.into();
155
156 let value = serde_json::to_value(value)?;
157 let schema = ResultSchema::from_value(&value);
158
159 let result = SharedResult {
160 key: key.clone(),
161 producer_id: producer_id.clone(),
162 value,
163 schema,
164 published_at: chrono::Utc::now(),
165 expires_at,
166 tags: tags.clone(),
167 };
168
169 {
171 let mut results = self.results.write().await;
172 results.insert(key.clone(), result.clone());
173 }
174
175 let notification = ResultNotification {
177 key: key.clone(),
178 producer_id,
179 tags,
180 };
181
182 let _ = self.notification_tx.send(notification);
184
185 tracing::info!(key = %key, "Published shared result");
186
187 Ok(result)
188 }
189
190 pub async fn get(&self, key: &str) -> Option<SharedResult> {
192 let results = self.results.read().await;
193 results.get(key).cloned()
194 }
195
196 pub async fn get_typed<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Result<T> {
198 let result = self
199 .get(key)
200 .await
201 .ok_or_else(|| anyhow!("Result not found: {}", key))?;
202
203 serde_json::from_value(result.value)
204 .map_err(|e| anyhow!("Failed to deserialize result: {}", e))
205 }
206
207 pub async fn query_by_tags(&self, tags: &[String]) -> Vec<SharedResult> {
209 let results = self.results.read().await;
210 results
211 .values()
212 .filter(|r| tags.iter().any(|t| r.tags.contains(t)))
213 .cloned()
214 .collect()
215 }
216
217 pub async fn query_by_producer(&self, producer_id: &str) -> Vec<SharedResult> {
219 let results = self.results.read().await;
220 results
221 .values()
222 .filter(|r| r.producer_id == producer_id)
223 .cloned()
224 .collect()
225 }
226
227 pub async fn query_by_prefix(&self, prefix: &str) -> Vec<SharedResult> {
229 let results = self.results.read().await;
230 results
231 .values()
232 .filter(|r| r.key.starts_with(prefix))
233 .cloned()
234 .collect()
235 }
236
237 pub fn subscribe(&self) -> broadcast::Receiver<ResultNotification> {
239 self.notification_tx.subscribe()
240 }
241
242 pub async fn register_subscription(
244 &self,
245 subtask_id: impl Into<String>,
246 pattern: SubscriptionPattern,
247 ) {
248 let subtask_id = subtask_id.into();
249 let mut subscriptions = self.subscriptions.write().await;
250 subscriptions.entry(subtask_id).or_default().push(pattern);
251 }
252
253 pub async fn unregister_subscriptions(&self, subtask_id: &str) {
255 let mut subscriptions = self.subscriptions.write().await;
256 subscriptions.remove(subtask_id);
257 }
258
259 pub fn matches_pattern(result: &SharedResult, pattern: &SubscriptionPattern) -> bool {
261 match pattern {
262 SubscriptionPattern::Exact(key) => result.key == *key,
263 SubscriptionPattern::Prefix(prefix) => result.key.starts_with(prefix),
264 SubscriptionPattern::Tag(tags) => tags.iter().any(|t| result.tags.contains(t)),
265 SubscriptionPattern::Producer(producer) => result.producer_id == *producer,
266 SubscriptionPattern::All => true,
267 }
268 }
269
270 pub async fn get_all(&self) -> Vec<SharedResult> {
272 let results = self.results.read().await;
273 results.values().cloned().collect()
274 }
275
276 pub async fn cleanup_expired(&self) -> usize {
278 let now = chrono::Utc::now();
279 let mut results = self.results.write().await;
280 let keys_to_remove: Vec<String> = results
281 .values()
282 .filter(|r| r.expires_at.map(|exp| exp <= now).unwrap_or(false))
283 .map(|r| r.key.clone())
284 .collect();
285
286 for key in &keys_to_remove {
287 results.remove(key);
288 }
289
290 keys_to_remove.len()
291 }
292
293 pub async fn clear(&self) {
295 let mut results = self.results.write().await;
296 results.clear();
297 }
298}
299
300impl Default for ResultStore {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
/// Access to a [`ResultStore`] from the point of view of a single producer.
///
/// Implementors decide how published results are attributed; the
/// [`SubTaskStoreHandle`] implementation records its subtask id as the producer.
/// All returned futures are `Send`, so they can be awaited inside spawned tasks.
pub trait ResultStoreContext {
    /// Publishes `value` under `key` with the given `tags`, resolving to the
    /// stored [`SharedResult`] or a serialization error.
    fn publish_result(
        &self,
        key: impl Into<String> + Send,
        value: impl Serialize + Send,
        tags: Vec<String>,
    ) -> impl std::future::Future<Output = Result<SharedResult>> + Send;

    /// Looks up the raw result stored under `key`, if any.
    fn get_result(
        &self,
        key: &str,
    ) -> impl std::future::Future<Output = Option<SharedResult>> + Send;

    /// Looks up the result under `key` and deserializes its value into `T`.
    fn get_result_typed<T: for<'de> Deserialize<'de>>(
        &self,
        key: &str,
    ) -> impl std::future::Future<Output = Result<T>> + Send;
}
328
329pub struct SubTaskStoreHandle {
331 store: Arc<ResultStore>,
332 subtask_id: String,
333}
334
335impl SubTaskStoreHandle {
336 pub fn new(store: Arc<ResultStore>, subtask_id: impl Into<String>) -> Self {
338 Self {
339 store,
340 subtask_id: subtask_id.into(),
341 }
342 }
343}
344
345impl ResultStoreContext for SubTaskStoreHandle {
346 async fn publish_result(
347 &self,
348 key: impl Into<String> + Send,
349 value: impl Serialize + Send,
350 tags: Vec<String>,
351 ) -> Result<SharedResult> {
352 self.store
353 .publish(key, &self.subtask_id, value, tags, None)
354 .await
355 }
356
357 async fn get_result(&self, key: &str) -> Option<SharedResult> {
358 self.store.get(key).await
359 }
360
361 async fn get_result_typed<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Result<T> {
362 self.store.get_typed(key).await
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_publish_and_get() {
        let store = ResultStore::new();

        store
            .publish("test-key", "task-1", "hello world", vec![], None)
            .await
            .unwrap();

        let result = store.get("test-key").await.unwrap();
        assert_eq!(result.key, "test-key");
        assert_eq!(result.producer_id, "task-1");
        assert_eq!(result.value, Value::String("hello world".to_string()));
    }

    #[tokio::test]
    async fn test_get_typed() {
        let store = ResultStore::new();

        #[derive(Debug, Serialize, Deserialize, PartialEq)]
        struct TestData {
            name: String,
            count: i32,
        }

        let data = TestData {
            name: "test".to_string(),
            count: 42,
        };

        store
            .publish("typed-key", "task-1", &data, vec![], None)
            .await
            .unwrap();

        let retrieved: TestData = store.get_typed("typed-key").await.unwrap();
        assert_eq!(retrieved, data);
    }

    #[tokio::test]
    async fn test_get_missing_key() {
        let store = ResultStore::new();
        assert!(store.get("missing").await.is_none());

        let err = store.get_typed::<String>("missing").await.unwrap_err();
        assert!(err.to_string().contains("Result not found: missing"));
    }

    #[tokio::test]
    async fn test_get_typed_wrong_type() {
        let store = ResultStore::new();
        store.publish("num", "task-1", 7, vec![], None).await.unwrap();

        let err = store.get_typed::<String>("num").await.unwrap_err();
        assert!(err.to_string().starts_with("Failed to deserialize result"));
    }

    #[tokio::test]
    async fn test_publish_overwrites_existing_key() {
        let store = ResultStore::new();
        store.publish("k", "task-1", 1, vec![], None).await.unwrap();
        store.publish("k", "task-2", 2, vec![], None).await.unwrap();

        let result = store.get("k").await.unwrap();
        assert_eq!(result.producer_id, "task-2");
        assert_eq!(result.value, serde_json::json!(2));
        assert_eq!(store.get_all().await.len(), 1);
    }

    #[tokio::test]
    async fn test_query_by_tags() {
        let store = ResultStore::new();

        store
            .publish(
                "key-1",
                "task-1",
                "value-1",
                vec!["tag-a".to_string()],
                None,
            )
            .await
            .unwrap();

        store
            .publish(
                "key-2",
                "task-2",
                "value-2",
                vec!["tag-b".to_string()],
                None,
            )
            .await
            .unwrap();

        store
            .publish(
                "key-3",
                "task-1",
                "value-3",
                vec!["tag-a".to_string(), "tag-c".to_string()],
                None,
            )
            .await
            .unwrap();

        let results = store.query_by_tags(&["tag-a".to_string()]).await;
        assert_eq!(results.len(), 2);

        // An empty tag list matches nothing.
        assert!(store.query_by_tags(&[]).await.is_empty());
    }

    #[tokio::test]
    async fn test_query_by_prefix() {
        let store = ResultStore::new();

        store
            .publish("prefix/key-1", "task-1", "value-1", vec![], None)
            .await
            .unwrap();

        store
            .publish("prefix/key-2", "task-2", "value-2", vec![], None)
            .await
            .unwrap();

        store
            .publish("other/key-3", "task-1", "value-3", vec![], None)
            .await
            .unwrap();

        let results = store.query_by_prefix("prefix/").await;
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_query_by_producer() {
        let store = ResultStore::new();
        store.publish("a", "task-1", 1, vec![], None).await.unwrap();
        store.publish("b", "task-2", 2, vec![], None).await.unwrap();
        store.publish("c", "task-1", 3, vec![], None).await.unwrap();

        let mut keys: Vec<String> = store
            .query_by_producer("task-1")
            .await
            .into_iter()
            .map(|r| r.key)
            .collect();
        keys.sort();
        assert_eq!(keys, vec!["a".to_string(), "c".to_string()]);
        assert!(store.query_by_producer("nobody").await.is_empty());
    }

    #[tokio::test]
    async fn test_cleanup_expired_removes_only_expired() {
        let store = ResultStore::new();
        let now = chrono::Utc::now();
        let past = now - chrono::Duration::seconds(60);
        let future = now + chrono::Duration::seconds(3600);

        store
            .publish("expired", "task-1", 1, vec![], Some(past))
            .await
            .unwrap();
        store
            .publish("fresh", "task-1", 2, vec![], Some(future))
            .await
            .unwrap();
        store
            .publish("forever", "task-1", 3, vec![], None)
            .await
            .unwrap();

        assert_eq!(store.cleanup_expired().await, 1);
        assert!(store.get("expired").await.is_none());
        assert!(store.get("fresh").await.is_some());
        assert!(store.get("forever").await.is_some());

        // A second pass finds nothing left to remove.
        assert_eq!(store.cleanup_expired().await, 0);
    }

    #[tokio::test]
    async fn test_clear() {
        let store = ResultStore::new();
        store.publish("a", "task-1", 1, vec![], None).await.unwrap();
        store.publish("b", "task-1", 2, vec![], None).await.unwrap();

        store.clear().await;
        assert!(store.get_all().await.is_empty());
    }

    #[tokio::test]
    async fn test_subscription_notifications() {
        let store = ResultStore::new();
        let mut rx = store.subscribe();

        store
            .publish(
                "notify-key",
                "task-1",
                "value",
                vec!["tag-1".to_string()],
                None,
            )
            .await
            .unwrap();

        let notification = rx.try_recv().unwrap();
        assert_eq!(notification.key, "notify-key");
        assert_eq!(notification.producer_id, "task-1");
        assert_eq!(notification.tags, vec!["tag-1".to_string()]);
    }

    #[tokio::test]
    async fn test_register_and_unregister_subscriptions() {
        let store = ResultStore::new();
        store
            .register_subscription("sub-1", SubscriptionPattern::All)
            .await;
        store
            .register_subscription("sub-1", SubscriptionPattern::Prefix("a/".to_string()))
            .await;

        assert_eq!(
            store.subscriptions.read().await.get("sub-1").map(Vec::len),
            Some(2)
        );

        store.unregister_subscriptions("sub-1").await;
        assert!(store.subscriptions.read().await.get("sub-1").is_none());

        // Unregistering an unknown id is a no-op.
        store.unregister_subscriptions("unknown").await;
    }

    #[test]
    fn test_schema_from_value() {
        let schema = ResultSchema::from_value(&serde_json::json!({
            "int": 1,
            "neg": -3,
            "float": 1.5,
            "text": "x",
            "flag": true,
            "list": [1, 2],
            "nested": { "inner": 1 },
            "nothing": null
        }));
        assert_eq!(schema.type_name, "object");
        assert!(schema.description.is_none());

        let fields = schema.fields.expect("objects report their fields");
        assert_eq!(fields.len(), 8);
        assert_eq!(fields["int"], "integer");
        assert_eq!(fields["neg"], "integer");
        assert_eq!(fields["float"], "number");
        assert_eq!(fields["text"], "string");
        assert_eq!(fields["flag"], "boolean");
        assert_eq!(fields["list"], "array");
        assert_eq!(fields["nested"], "object");
        assert_eq!(fields["nothing"], "null");

        let scalar = ResultSchema::from_value(&serde_json::json!(2.5));
        assert_eq!(scalar.type_name, "number");
        assert!(scalar.fields.is_none());
    }

    #[tokio::test]
    async fn test_subtask_handle_attributes_producer() {
        let store = ResultStore::new_arc();
        let handle = SubTaskStoreHandle::new(Arc::clone(&store), "subtask-7");

        let published = handle
            .publish_result("h-key", 5, vec!["t".to_string()])
            .await
            .unwrap();
        assert_eq!(published.producer_id, "subtask-7");
        assert!(published.expires_at.is_none());

        let n: i64 = handle.get_result_typed("h-key").await.unwrap();
        assert_eq!(n, 5);
        assert!(handle.get_result("h-key").await.is_some());
        assert_eq!(store.query_by_producer("subtask-7").await.len(), 1);
    }

    #[tokio::test]
    async fn test_matches_pattern() {
        let result = SharedResult {
            key: "test/key".to_string(),
            producer_id: "task-1".to_string(),
            value: Value::Null,
            schema: ResultSchema::from_value(&Value::Null),
            published_at: chrono::Utc::now(),
            expires_at: None,
            tags: vec!["tag-a".to_string()],
        };

        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Exact("test/key".to_string())
        ));
        assert!(!ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Exact("other".to_string())
        ));
        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Prefix("test/".to_string())
        ));
        assert!(!ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Prefix("other/".to_string())
        ));
        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Tag(vec!["tag-a".to_string()])
        ));
        assert!(!ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Tag(vec![])
        ));
        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Producer("task-1".to_string())
        ));
        assert!(!ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::Producer("task-2".to_string())
        ));
        assert!(ResultStore::matches_pattern(
            &result,
            &SubscriptionPattern::All
        ));
    }
}