use capnweb_core::{CallId, PendingPromise, PromiseDependencyGraph, PromiseId, RpcError};
use dashmap::DashMap;
use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Tracks pending promises, the calls whose results resolve them, and the
/// dependency edges between them.
///
/// Every field is shared through an `Arc`; mutable state lives in `DashMap`s
/// or behind `tokio::sync::RwLock`s.
pub struct PromiseTable {
    /// Promise state keyed by promise id; each entry carries its own async lock.
    promises: Arc<DashMap<PromiseId, Arc<RwLock<PendingPromise>>>>,
    /// Maps an outstanding call to the promise its result resolves.
    call_to_promise: Arc<DashMap<CallId, PromiseId>>,
    /// Dependency edges between promises (used for cycle checks and ordering).
    dependency_graph: Arc<RwLock<PromiseDependencyGraph>>,
    /// Ids that have been marked resolved; may include ids never registered in `promises`.
    resolved_promises: Arc<RwLock<HashSet<PromiseId>>>,
}
impl PromiseTable {
pub fn new() -> Self {
Self {
promises: Arc::new(DashMap::new()),
call_to_promise: Arc::new(DashMap::new()),
dependency_graph: Arc::new(RwLock::new(PromiseDependencyGraph::new())),
resolved_promises: Arc::new(RwLock::new(HashSet::new())),
}
}
pub async fn register_promise(&self, promise_id: PromiseId, call_id: CallId) {
let promise = Arc::new(RwLock::new(PendingPromise::new(promise_id, call_id)));
self.promises.insert(promise_id, promise);
self.call_to_promise.insert(call_id, promise_id);
}
pub async fn add_dependency(
&self,
promise_id: PromiseId,
depends_on: PromiseId,
) -> Result<(), RpcError> {
let mut graph = self.dependency_graph.write().await;
if graph.would_create_cycle(promise_id, depends_on) {
return Err(RpcError::bad_request(format!(
"Adding dependency from {:?} to {:?} would create a cycle",
promise_id, depends_on
)));
}
graph.add_dependency(promise_id, depends_on);
if let Some(promise_arc) = self.promises.get(&promise_id) {
let mut promise = promise_arc.write().await;
promise.add_dependency(depends_on);
}
Ok(())
}
pub async fn resolve_by_call(&self, call_id: CallId, result: Value) -> Option<PromiseId> {
if let Some(promise_id) = self.call_to_promise.remove(&call_id) {
let promise_id = promise_id.1; self.resolve_promise(promise_id, result).await;
Some(promise_id)
} else {
None
}
}
pub async fn resolve_promise(&self, promise_id: PromiseId, result: Value) {
if let Some(promise_arc) = self.promises.get(&promise_id) {
let mut promise = promise_arc.write().await;
promise.resolve(result);
}
let mut resolved = self.resolved_promises.write().await;
resolved.insert(promise_id);
}
pub async fn get_ready_promises(&self) -> Vec<PromiseId> {
let resolved = self.resolved_promises.read().await;
let mut ready = Vec::new();
for entry in self.promises.iter() {
let promise_id = *entry.key();
let promise = entry.value().read().await;
if !promise.resolved && promise.is_ready(&resolved) {
ready.push(promise_id);
}
}
ready
}
pub async fn get_result(&self, promise_id: &PromiseId) -> Option<Value> {
if let Some(promise_arc) = self.promises.get(promise_id) {
let promise = promise_arc.read().await;
promise.result.clone()
} else {
None
}
}
pub async fn get_execution_order(&self) -> Option<Vec<PromiseId>> {
let graph = self.dependency_graph.read().await;
graph.topological_sort()
}
pub async fn is_resolved(&self, promise_id: &PromiseId) -> bool {
let resolved = self.resolved_promises.read().await;
resolved.contains(promise_id)
}
pub async fn clear_resolved(&self) {
let resolved = self.resolved_promises.read().await;
for promise_id in resolved.iter() {
self.promises.remove(promise_id);
}
drop(resolved);
let mut resolved = self.resolved_promises.write().await;
resolved.clear();
}
pub async fn stats(&self) -> PromiseTableStats {
let resolved = self.resolved_promises.read().await;
PromiseTableStats {
total_promises: self.promises.len(),
resolved_promises: resolved.len(),
pending_promises: self.promises.len() - resolved.len(),
}
}
}
impl Default for PromiseTable {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time counters describing a `PromiseTable`.
///
/// A small plain-value struct, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PromiseTableStats {
    /// Number of promises currently registered in the table.
    pub total_promises: usize,
    /// Number of promises counted as resolved.
    pub resolved_promises: usize,
    /// Number of promises counted as still pending.
    pub pending_promises: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A freshly registered promise is counted as pending, not resolved.
    #[tokio::test]
    async fn test_promise_registration() {
        let table = PromiseTable::new();
        table.register_promise(PromiseId::new(1), CallId::new(1)).await;

        let snapshot = table.stats().await;
        assert_eq!(
            (
                snapshot.total_promises,
                snapshot.pending_promises,
                snapshot.resolved_promises
            ),
            (1, 1, 0)
        );
    }

    /// Resolving through the call id marks the promise resolved and stores the payload.
    #[tokio::test]
    async fn test_promise_resolution() {
        let table = PromiseTable::new();
        let (pid, cid) = (PromiseId::new(1), CallId::new(1));
        table.register_promise(pid, cid).await;

        let payload = json!({"result": true});
        table.resolve_by_call(cid, payload.clone()).await;

        assert!(table.is_resolved(&pid).await);
        assert_eq!(table.get_result(&pid).await, Some(payload));
    }

    /// In a chain third -> second -> first, each resolution unblocks the next link.
    #[tokio::test]
    async fn test_dependencies() {
        let table = PromiseTable::new();
        let (first, second, third) = (PromiseId::new(1), PromiseId::new(2), PromiseId::new(3));
        table.register_promise(first, CallId::new(1)).await;
        table.register_promise(second, CallId::new(2)).await;
        table.register_promise(third, CallId::new(3)).await;

        table.add_dependency(second, first).await.unwrap();
        table.add_dependency(third, second).await.unwrap();

        assert_eq!(table.get_ready_promises().await, vec![first]);

        table.resolve_promise(first, json!(1)).await;
        assert_eq!(table.get_ready_promises().await, vec![second]);

        table.resolve_promise(second, json!(2)).await;
        assert_eq!(table.get_ready_promises().await, vec![third]);
    }

    /// Closing a two-node loop is rejected.
    #[tokio::test]
    async fn test_cycle_detection() {
        let table = PromiseTable::new();
        let (left, right) = (PromiseId::new(1), PromiseId::new(2));
        table.register_promise(left, CallId::new(1)).await;
        table.register_promise(right, CallId::new(2)).await;

        table.add_dependency(right, left).await.unwrap();
        assert!(table.add_dependency(left, right).await.is_err());
    }

    /// A linear chain sorts from its root to its leaf.
    #[tokio::test]
    async fn test_execution_order() {
        let table = PromiseTable::new();
        let (first, second, third) = (PromiseId::new(1), PromiseId::new(2), PromiseId::new(3));
        table.register_promise(first, CallId::new(1)).await;
        table.register_promise(second, CallId::new(2)).await;
        table.register_promise(third, CallId::new(3)).await;

        table.add_dependency(second, first).await.unwrap();
        table.add_dependency(third, second).await.unwrap();

        let sorted = table.get_execution_order().await.unwrap();
        assert_eq!(sorted, vec![first, second, third]);
    }

    /// Clearing removes every resolved promise and empties the resolved set.
    #[tokio::test]
    async fn test_clear_resolved() {
        let table = PromiseTable::new();
        for n in 1..=3 {
            let id = PromiseId::new(n);
            table.register_promise(id, CallId::new(n)).await;
            table.resolve_promise(id, json!(n)).await;
        }
        assert_eq!(table.stats().await.resolved_promises, 3);

        table.clear_resolved().await;

        let after = table.stats().await;
        assert_eq!((after.total_promises, after.resolved_promises), (0, 0));
    }
}