capnweb_server/
promise_table.rs1use capnweb_core::{CallId, PendingPromise, PromiseDependencyGraph, PromiseId, RpcError};
2use dashmap::DashMap;
3use serde_json::Value;
4use std::collections::HashSet;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8pub struct PromiseTable {
10 promises: Arc<DashMap<PromiseId, Arc<RwLock<PendingPromise>>>>,
12 call_to_promise: Arc<DashMap<CallId, PromiseId>>,
14 dependency_graph: Arc<RwLock<PromiseDependencyGraph>>,
16 resolved_promises: Arc<RwLock<HashSet<PromiseId>>>,
18}
19
20impl PromiseTable {
21 pub fn new() -> Self {
22 Self {
23 promises: Arc::new(DashMap::new()),
24 call_to_promise: Arc::new(DashMap::new()),
25 dependency_graph: Arc::new(RwLock::new(PromiseDependencyGraph::new())),
26 resolved_promises: Arc::new(RwLock::new(HashSet::new())),
27 }
28 }
29
30 pub async fn register_promise(&self, promise_id: PromiseId, call_id: CallId) {
32 let promise = Arc::new(RwLock::new(PendingPromise::new(promise_id, call_id)));
33 self.promises.insert(promise_id, promise);
34 self.call_to_promise.insert(call_id, promise_id);
35 }
36
37 pub async fn add_dependency(
39 &self,
40 promise_id: PromiseId,
41 depends_on: PromiseId,
42 ) -> Result<(), RpcError> {
43 let mut graph = self.dependency_graph.write().await;
44
45 if graph.would_create_cycle(promise_id, depends_on) {
47 return Err(RpcError::bad_request(format!(
48 "Adding dependency from {:?} to {:?} would create a cycle",
49 promise_id, depends_on
50 )));
51 }
52
53 graph.add_dependency(promise_id, depends_on);
54
55 if let Some(promise_arc) = self.promises.get(&promise_id) {
57 let mut promise = promise_arc.write().await;
58 promise.add_dependency(depends_on);
59 }
60
61 Ok(())
62 }
63
64 pub async fn resolve_by_call(&self, call_id: CallId, result: Value) -> Option<PromiseId> {
66 if let Some(promise_id) = self.call_to_promise.remove(&call_id) {
67 let promise_id = promise_id.1; self.resolve_promise(promise_id, result).await;
69 Some(promise_id)
70 } else {
71 None
72 }
73 }
74
75 pub async fn resolve_promise(&self, promise_id: PromiseId, result: Value) {
77 if let Some(promise_arc) = self.promises.get(&promise_id) {
78 let mut promise = promise_arc.write().await;
79 promise.resolve(result);
80 }
81
82 let mut resolved = self.resolved_promises.write().await;
83 resolved.insert(promise_id);
84 }
85
86 pub async fn get_ready_promises(&self) -> Vec<PromiseId> {
88 let resolved = self.resolved_promises.read().await;
89 let mut ready = Vec::new();
90
91 for entry in self.promises.iter() {
92 let promise_id = *entry.key();
93 let promise = entry.value().read().await;
94
95 if !promise.resolved && promise.is_ready(&resolved) {
96 ready.push(promise_id);
97 }
98 }
99
100 ready
101 }
102
103 pub async fn get_result(&self, promise_id: &PromiseId) -> Option<Value> {
105 if let Some(promise_arc) = self.promises.get(promise_id) {
106 let promise = promise_arc.read().await;
107 promise.result.clone()
108 } else {
109 None
110 }
111 }
112
113 pub async fn get_execution_order(&self) -> Option<Vec<PromiseId>> {
115 let graph = self.dependency_graph.read().await;
116 graph.topological_sort()
117 }
118
119 pub async fn is_resolved(&self, promise_id: &PromiseId) -> bool {
121 let resolved = self.resolved_promises.read().await;
122 resolved.contains(promise_id)
123 }
124
125 pub async fn clear_resolved(&self) {
127 let resolved = self.resolved_promises.read().await;
128 for promise_id in resolved.iter() {
129 self.promises.remove(promise_id);
130 }
131 drop(resolved);
132
133 let mut resolved = self.resolved_promises.write().await;
134 resolved.clear();
135 }
136
137 pub async fn stats(&self) -> PromiseTableStats {
139 let resolved = self.resolved_promises.read().await;
140 PromiseTableStats {
141 total_promises: self.promises.len(),
142 resolved_promises: resolved.len(),
143 pending_promises: self.promises.len() - resolved.len(),
144 }
145 }
146}
147
148impl Default for PromiseTable {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
/// Point-in-time size snapshot of a [`PromiseTable`].
///
/// `Default` is the snapshot of an empty table. `PartialEq`/`Eq` allow whole
/// snapshots to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PromiseTableStats {
    /// Number of promises currently registered in the table.
    pub total_promises: usize,
    /// Number of promise ids currently recorded as resolved.
    pub resolved_promises: usize,
    /// Number of registered promises not yet resolved.
    pub pending_promises: usize,
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Registers every `(promise, call)` pair on `table`, in order.
    async fn register_all(table: &PromiseTable, pairs: &[(PromiseId, CallId)]) {
        for &(promise_id, call_id) in pairs {
            table.register_promise(promise_id, call_id).await;
        }
    }

    #[tokio::test]
    async fn test_promise_registration() {
        let table = PromiseTable::new();
        register_all(&table, &[(PromiseId::new(1), CallId::new(1))]).await;

        let PromiseTableStats {
            total_promises,
            resolved_promises,
            pending_promises,
        } = table.stats().await;
        assert_eq!(total_promises, 1);
        assert_eq!(pending_promises, 1);
        assert_eq!(resolved_promises, 0);
    }

    #[tokio::test]
    async fn test_promise_resolution() {
        let table = PromiseTable::new();
        let (promise_id, call_id) = (PromiseId::new(1), CallId::new(1));
        register_all(&table, &[(promise_id, call_id)]).await;

        let expected = json!({"result": true});
        table.resolve_by_call(call_id, expected.clone()).await;

        assert!(table.is_resolved(&promise_id).await);
        assert_eq!(table.get_result(&promise_id).await, Some(expected));
    }

    #[tokio::test]
    async fn test_dependencies() {
        let table = PromiseTable::new();
        let (p1, p2, p3) = (PromiseId::new(1), PromiseId::new(2), PromiseId::new(3));
        register_all(
            &table,
            &[(p1, CallId::new(1)), (p2, CallId::new(2)), (p3, CallId::new(3))],
        )
        .await;

        // Chain: p3 waits on p2, which waits on p1.
        table.add_dependency(p2, p1).await.unwrap();
        table.add_dependency(p3, p2).await.unwrap();

        assert_eq!(table.get_ready_promises().await, vec![p1]);

        // Each resolution should unlock exactly the next link of the chain.
        for (done, next, value) in [(p1, p2, json!(1)), (p2, p3, json!(2))] {
            table.resolve_promise(done, value).await;
            assert_eq!(table.get_ready_promises().await, vec![next]);
        }
    }

    #[tokio::test]
    async fn test_cycle_detection() {
        let table = PromiseTable::new();
        let (p1, p2) = (PromiseId::new(1), PromiseId::new(2));
        register_all(&table, &[(p1, CallId::new(1)), (p2, CallId::new(2))]).await;

        table.add_dependency(p2, p1).await.unwrap();

        // Closing the loop p1 -> p2 -> p1 must be rejected.
        assert!(table.add_dependency(p1, p2).await.is_err());
    }

    #[tokio::test]
    async fn test_execution_order() {
        let table = PromiseTable::new();
        let (p1, p2, p3) = (PromiseId::new(1), PromiseId::new(2), PromiseId::new(3));
        register_all(
            &table,
            &[(p1, CallId::new(1)), (p2, CallId::new(2)), (p3, CallId::new(3))],
        )
        .await;

        for (dependent, dependency) in [(p2, p1), (p3, p2)] {
            table.add_dependency(dependent, dependency).await.unwrap();
        }

        let order = table
            .get_execution_order()
            .await
            .expect("acyclic chain must have a topological order");
        assert_eq!(order, vec![p1, p2, p3]);
    }

    #[tokio::test]
    async fn test_clear_resolved() {
        let table = PromiseTable::new();

        for i in 1..=3 {
            let promise_id = PromiseId::new(i);
            table.register_promise(promise_id, CallId::new(i)).await;
            table.resolve_promise(promise_id, json!(i)).await;
        }
        assert_eq!(table.stats().await.resolved_promises, 3);

        table.clear_resolved().await;

        let after = table.stats().await;
        assert_eq!(after.total_promises, 0);
        assert_eq!(after.resolved_promises, 0);
    }
}