capnweb_server/
promise_table.rs

1use capnweb_core::{CallId, PendingPromise, PromiseDependencyGraph, PromiseId, RpcError};
2use dashmap::DashMap;
3use serde_json::Value;
4use std::collections::HashSet;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Table for tracking pending promises and their resolution.
///
/// Promises are registered together with the call that will produce their
/// value, may declare dependencies on other promises, and are later resolved
/// either directly by promise ID or indirectly by call ID.
pub struct PromiseTable {
    /// Maps each `PromiseId` to its `PendingPromise`. Every promise sits
    /// behind its own async `RwLock` so it can be read or mutated individually.
    promises: Arc<DashMap<PromiseId, Arc<RwLock<PendingPromise>>>>,
    /// Maps `CallId` to `PromiseId` for result correlation. An entry is
    /// consumed when the promise is resolved through its call ID.
    call_to_promise: Arc<DashMap<CallId, PromiseId>>,
    /// Dependency graph used for cycle detection and topological sorting.
    dependency_graph: Arc<RwLock<PromiseDependencyGraph>>,
    /// Set of resolved promise IDs; filled on resolution and emptied by the
    /// cleanup pass.
    resolved_promises: Arc<RwLock<HashSet<PromiseId>>>,
}
19
20impl PromiseTable {
21    pub fn new() -> Self {
22        Self {
23            promises: Arc::new(DashMap::new()),
24            call_to_promise: Arc::new(DashMap::new()),
25            dependency_graph: Arc::new(RwLock::new(PromiseDependencyGraph::new())),
26            resolved_promises: Arc::new(RwLock::new(HashSet::new())),
27        }
28    }
29
30    /// Register a new pending promise
31    pub async fn register_promise(&self, promise_id: PromiseId, call_id: CallId) {
32        let promise = Arc::new(RwLock::new(PendingPromise::new(promise_id, call_id)));
33        self.promises.insert(promise_id, promise);
34        self.call_to_promise.insert(call_id, promise_id);
35    }
36
37    /// Add a dependency between promises
38    pub async fn add_dependency(
39        &self,
40        promise_id: PromiseId,
41        depends_on: PromiseId,
42    ) -> Result<(), RpcError> {
43        let mut graph = self.dependency_graph.write().await;
44
45        // Check for cycles
46        if graph.would_create_cycle(promise_id, depends_on) {
47            return Err(RpcError::bad_request(format!(
48                "Adding dependency from {:?} to {:?} would create a cycle",
49                promise_id, depends_on
50            )));
51        }
52
53        graph.add_dependency(promise_id, depends_on);
54
55        // Update the promise's dependencies
56        if let Some(promise_arc) = self.promises.get(&promise_id) {
57            let mut promise = promise_arc.write().await;
58            promise.add_dependency(depends_on);
59        }
60
61        Ok(())
62    }
63
64    /// Resolve a promise by call ID
65    pub async fn resolve_by_call(&self, call_id: CallId, result: Value) -> Option<PromiseId> {
66        if let Some(promise_id) = self.call_to_promise.remove(&call_id) {
67            let promise_id = promise_id.1; // Extract value from DashMap entry
68            self.resolve_promise(promise_id, result).await;
69            Some(promise_id)
70        } else {
71            None
72        }
73    }
74
75    /// Resolve a promise directly
76    pub async fn resolve_promise(&self, promise_id: PromiseId, result: Value) {
77        if let Some(promise_arc) = self.promises.get(&promise_id) {
78            let mut promise = promise_arc.write().await;
79            promise.resolve(result);
80        }
81
82        let mut resolved = self.resolved_promises.write().await;
83        resolved.insert(promise_id);
84    }
85
86    /// Get promises that are ready to execute (all dependencies resolved)
87    pub async fn get_ready_promises(&self) -> Vec<PromiseId> {
88        let resolved = self.resolved_promises.read().await;
89        let mut ready = Vec::new();
90
91        for entry in self.promises.iter() {
92            let promise_id = *entry.key();
93            let promise = entry.value().read().await;
94
95            if !promise.resolved && promise.is_ready(&resolved) {
96                ready.push(promise_id);
97            }
98        }
99
100        ready
101    }
102
103    /// Get the result of a resolved promise
104    pub async fn get_result(&self, promise_id: &PromiseId) -> Option<Value> {
105        if let Some(promise_arc) = self.promises.get(promise_id) {
106            let promise = promise_arc.read().await;
107            promise.result.clone()
108        } else {
109            None
110        }
111    }
112
113    /// Get all promises in topologically sorted order
114    pub async fn get_execution_order(&self) -> Option<Vec<PromiseId>> {
115        let graph = self.dependency_graph.read().await;
116        graph.topological_sort()
117    }
118
119    /// Check if a promise is resolved
120    pub async fn is_resolved(&self, promise_id: &PromiseId) -> bool {
121        let resolved = self.resolved_promises.read().await;
122        resolved.contains(promise_id)
123    }
124
125    /// Clear all resolved promises (cleanup)
126    pub async fn clear_resolved(&self) {
127        let resolved = self.resolved_promises.read().await;
128        for promise_id in resolved.iter() {
129            self.promises.remove(promise_id);
130        }
131        drop(resolved);
132
133        let mut resolved = self.resolved_promises.write().await;
134        resolved.clear();
135    }
136
137    /// Get statistics about the promise table
138    pub async fn stats(&self) -> PromiseTableStats {
139        let resolved = self.resolved_promises.read().await;
140        PromiseTableStats {
141            total_promises: self.promises.len(),
142            resolved_promises: resolved.len(),
143            pending_promises: self.promises.len() - resolved.len(),
144        }
145    }
146}
147
148impl Default for PromiseTable {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
/// Point-in-time counters describing a promise table.
///
/// Plain data: cheap to copy, comparable, and defaulting to all zeros.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PromiseTableStats {
    /// Number of promises currently stored in the table.
    pub total_promises: usize,
    /// Number of promise IDs currently marked as resolved.
    pub resolved_promises: usize,
    /// Promises not yet resolved (total minus resolved).
    pub pending_promises: usize,
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A freshly registered promise is counted as pending, not resolved.
    #[tokio::test]
    async fn test_promise_registration() {
        let table = PromiseTable::new();
        table
            .register_promise(PromiseId::new(1), CallId::new(1))
            .await;

        let PromiseTableStats {
            total_promises,
            resolved_promises,
            pending_promises,
        } = table.stats().await;
        assert_eq!(total_promises, 1);
        assert_eq!(pending_promises, 1);
        assert_eq!(resolved_promises, 0);
    }

    /// Resolving through the call ID marks the promise resolved and stores
    /// the value.
    #[tokio::test]
    async fn test_promise_resolution() {
        let table = PromiseTable::new();
        let (pid, cid) = (PromiseId::new(1), CallId::new(1));
        let payload = json!({"result": true});

        table.register_promise(pid, cid).await;
        table.resolve_by_call(cid, payload.clone()).await;

        assert!(table.is_resolved(&pid).await);
        assert_eq!(table.get_result(&pid).await, Some(payload));
    }

    /// In a chain first <- second <- third, exactly one promise is ready at
    /// each step.
    #[tokio::test]
    async fn test_dependencies() {
        let table = PromiseTable::new();

        let first = PromiseId::new(1);
        let second = PromiseId::new(2);
        let third = PromiseId::new(3);

        for (pid, cid) in [
            (first, CallId::new(1)),
            (second, CallId::new(2)),
            (third, CallId::new(3)),
        ] {
            table.register_promise(pid, cid).await;
        }

        // second waits on first, third waits on second
        table.add_dependency(second, first).await.unwrap();
        table.add_dependency(third, second).await.unwrap();

        // Nothing resolved yet: only the head of the chain can run
        assert_eq!(table.get_ready_promises().await, vec![first]);

        // Resolving the head unblocks the middle
        table.resolve_promise(first, json!(1)).await;
        assert_eq!(table.get_ready_promises().await, vec![second]);

        // Resolving the middle unblocks the tail
        table.resolve_promise(second, json!(2)).await;
        assert_eq!(table.get_ready_promises().await, vec![third]);
    }

    /// Closing a two-node loop is rejected.
    #[tokio::test]
    async fn test_cycle_detection() {
        let table = PromiseTable::new();

        let a = PromiseId::new(1);
        let b = PromiseId::new(2);

        table.register_promise(a, CallId::new(1)).await;
        table.register_promise(b, CallId::new(2)).await;

        // b -> a is fine
        table.add_dependency(b, a).await.unwrap();

        // a -> b would close the loop
        assert!(table.add_dependency(a, b).await.is_err());
    }

    /// The execution order of a linear chain follows the dependencies.
    #[tokio::test]
    async fn test_execution_order() {
        let table = PromiseTable::new();

        let first = PromiseId::new(1);
        let second = PromiseId::new(2);
        let third = PromiseId::new(3);

        table.register_promise(first, CallId::new(1)).await;
        table.register_promise(second, CallId::new(2)).await;
        table.register_promise(third, CallId::new(3)).await;

        table.add_dependency(second, first).await.unwrap();
        table.add_dependency(third, second).await.unwrap();

        let sorted = table.get_execution_order().await.unwrap();
        assert_eq!(sorted, vec![first, second, third]);
    }

    /// Cleanup drops every resolved promise and empties the resolved set.
    #[tokio::test]
    async fn test_clear_resolved() {
        let table = PromiseTable::new();

        for n in 1..=3 {
            let pid = PromiseId::new(n);
            table.register_promise(pid, CallId::new(n)).await;
            table.resolve_promise(pid, json!(n)).await;
        }

        let before = table.stats().await;
        assert_eq!(before.resolved_promises, 3);

        table.clear_resolved().await;

        let after = table.stats().await;
        assert_eq!(after.total_promises, 0);
        assert_eq!(after.resolved_promises, 0);
    }
}