use crate::ids::{CallId, CapId, PromiseId};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
/// An argument value, serialized without a tag (`#[serde(untagged)]`).
///
/// Because the representation is untagged, serde tries the variants in declaration order
/// when deserializing and keeps the first one that matches. Unknown object keys are ignored
/// by default, so more specific shapes must be declared before less specific ones:
/// `PromiseField` (`{"promise": .., "field": ..}`) is listed before `PromiseRef`
/// (`{"promise": ..}`). In the opposite order a promise-field reference would be decoded as
/// a plain `PromiseRef`, with its `field` silently dropped.
///
/// `Value` is last and accepts any JSON, so input that matches none of the reference shapes
/// becomes a literal value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ArgValue {
    /// Capability reference: `{"cap": <id>}`.
    CapRef { cap: CapId },
    /// Named field of a promise: `{"promise": <id>, "field": "<name>"}`.
    /// Must stay before `PromiseRef` (see the type-level docs).
    PromiseField { promise: PromiseId, field: String },
    /// Promise reference: `{"promise": <id>}`.
    PromiseRef { promise: PromiseId },
    /// Any other JSON value, passed through as-is.
    Value(Value),
}
/// A target that may take one of several shapes, serialized without a tag.
///
/// With `#[serde(untagged)]`, deserialization tries the variants in declaration order and
/// keeps the first one that matches the input.
///
/// NOTE(review): `CapId` and `PromiseId` both appear to serialize as bare integers (see the
/// `{"cap":1}` / `{"promise":2}` expectations in this module's tests). If they also
/// deserialize from bare integers, every integer input is claimed by `Cap` first, and
/// `Promise` can never be produced by deserialization. Confirm against `crate::ids`; a
/// distinguishable (e.g. object) form for `Promise` may be needed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ExtendedTarget {
    /// A capability, by id.
    Cap(CapId),
    /// A target identified by a string; presumably a reserved/well-known name. TODO confirm
    /// with callers.
    Special(String),
    /// A promise, by id. See the NOTE above about deserialization reachability.
    Promise(PromiseId),
    /// A named field of a promise: `{"promise": <id>, "field": "<name>"}`.
    PromiseField { promise: PromiseId, field: String },
}
#[derive(Debug, Default)]
pub struct PromiseDependencyGraph {
dependencies: HashMap<PromiseId, HashSet<PromiseId>>,
dependents: HashMap<PromiseId, HashSet<PromiseId>>,
}
impl PromiseDependencyGraph {
pub fn new() -> Self {
Self::default()
}
pub fn add_dependency(&mut self, promise: PromiseId, depends_on: PromiseId) {
self.dependencies
.entry(promise)
.or_default()
.insert(depends_on);
self.dependents
.entry(depends_on)
.or_default()
.insert(promise);
}
pub fn dependencies_of(&self, promise: &PromiseId) -> Option<&HashSet<PromiseId>> {
self.dependencies.get(promise)
}
pub fn dependents_of(&self, promise: &PromiseId) -> Option<&HashSet<PromiseId>> {
self.dependents.get(promise)
}
pub fn topological_sort(&self) -> Option<Vec<PromiseId>> {
let mut in_degree = HashMap::new();
let mut queue = Vec::new();
let mut result = Vec::new();
let mut all_nodes = HashSet::new();
for (promise, deps) in &self.dependencies {
all_nodes.insert(*promise);
for dep in deps {
all_nodes.insert(*dep);
}
}
for &node in &all_nodes {
in_degree.insert(node, 0);
}
for (promise, deps) in &self.dependencies {
*in_degree
.get_mut(promise)
.expect("Promise should exist in in_degree map") = deps.len();
}
for (&promise, °ree) in &in_degree {
if degree == 0 {
queue.push(promise);
}
}
while let Some(promise) = queue.pop() {
result.push(promise);
if let Some(dependents) = self.dependents.get(&promise) {
for &dependent in dependents {
let degree = in_degree
.get_mut(&dependent)
.expect("Dependent should exist in in_degree map");
*degree -= 1;
if *degree == 0 {
queue.push(dependent);
}
}
}
}
if result.len() == all_nodes.len() {
Some(result)
} else {
None
}
}
pub fn would_create_cycle(&self, promise: PromiseId, depends_on: PromiseId) -> bool {
let mut visited = HashSet::new();
let mut stack = vec![depends_on];
while let Some(current) = stack.pop() {
if current == promise {
return true;
}
if visited.insert(current) {
if let Some(deps) = self.dependencies.get(¤t) {
stack.extend(deps.iter().copied());
}
}
}
false
}
}
/// A promise that has been issued, plus the set of other promises it waits on.
#[derive(Debug)]
pub struct PendingPromise {
    /// This promise's own identifier.
    pub id: PromiseId,
    /// Call identifier supplied at construction.
    pub call_id: CallId,
    /// Promises that must all be resolved before [`is_ready`](Self::is_ready) reports `true`.
    pub dependencies: HashSet<PromiseId>,
    /// Becomes `true` once [`resolve`](Self::resolve) is called.
    pub resolved: bool,
    /// The value passed to [`resolve`](Self::resolve); `None` until then.
    pub result: Option<Value>,
}

impl PendingPromise {
    /// Builds an unresolved promise with an empty dependency set.
    pub fn new(id: PromiseId, call_id: CallId) -> Self {
        let dependencies = HashSet::default();
        Self {
            id,
            call_id,
            dependencies,
            resolved: false,
            result: None,
        }
    }

    /// Registers `promise` as a prerequisite; registering the same one twice has no effect.
    pub fn add_dependency(&mut self, promise: PromiseId) {
        self.dependencies.insert(promise);
    }

    /// Stores `result` and marks this promise as resolved.
    pub fn resolve(&mut self, result: Value) {
        self.result = Some(result);
        self.resolved = true;
    }

    /// Reports whether every prerequisite is present in `resolved_promises`.
    ///
    /// A promise with no prerequisites is always ready.
    pub fn is_ready(&self, resolved_promises: &HashSet<PromiseId>) -> bool {
        self.dependencies
            .iter()
            .all(|prerequisite| resolved_promises.contains(prerequisite))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the wire shape of each `ArgValue` variant and round-trips the unambiguous ones.
    #[test]
    fn test_arg_value_serialization() {
        let arg = ArgValue::Value(serde_json::json!(42));
        let json = serde_json::to_string(&arg).unwrap();
        assert_eq!(json, "42");
        let deserialized: ArgValue = serde_json::from_str(&json).unwrap();
        assert_eq!(arg, deserialized);

        let arg = ArgValue::CapRef { cap: CapId::new(1) };
        let json = serde_json::to_string(&arg).unwrap();
        assert_eq!(json, r#"{"cap":1}"#);
        let deserialized: ArgValue = serde_json::from_str(&json).unwrap();
        assert_eq!(arg, deserialized);

        let arg = ArgValue::PromiseRef {
            promise: PromiseId::new(2),
        };
        let json = serde_json::to_string(&arg).unwrap();
        assert_eq!(json, r#"{"promise":2}"#);
        let deserialized: ArgValue = serde_json::from_str(&json).unwrap();
        assert_eq!(arg, deserialized);

        let arg = ArgValue::PromiseField {
            promise: PromiseId::new(3),
            field: "result".to_string(),
        };
        let json = serde_json::to_string(&arg).unwrap();
        // Key order inside the object is not part of the contract; check keys individually.
        assert!(json.contains("\"promise\":3"));
        assert!(json.contains("\"field\":\"result\""));
    }

    #[test]
    fn test_dependency_graph() {
        let mut graph = PromiseDependencyGraph::new();
        let p1 = PromiseId::new(1);
        let p2 = PromiseId::new(2);
        let p3 = PromiseId::new(3);
        graph.add_dependency(p2, p1);
        graph.add_dependency(p3, p2);

        assert!(graph
            .dependencies_of(&p2)
            .map(|deps| deps.contains(&p1))
            .unwrap_or(false));
        assert!(graph
            .dependents_of(&p1)
            .map(|deps| deps.contains(&p2))
            .unwrap_or(false));

        let sorted = graph
            .topological_sort()
            .expect("Topological sort should succeed (no cycle)");
        assert_eq!(sorted.len(), 3);

        let index_of = |target: PromiseId| {
            sorted
                .iter()
                .position(|&p| p == target)
                .expect("every node should be in the sorted list")
        };
        assert!(index_of(p1) < index_of(p2), "p1 should come before p2");
        assert!(index_of(p2) < index_of(p3), "p2 should come before p3");
    }

    #[test]
    fn test_topological_sort_rejects_cycle() {
        let mut graph = PromiseDependencyGraph::new();
        let p1 = PromiseId::new(1);
        let p2 = PromiseId::new(2);
        graph.add_dependency(p2, p1);
        graph.add_dependency(p1, p2);
        assert!(graph.topological_sort().is_none());
    }

    #[test]
    fn test_topological_sort_empty_graph() {
        let graph = PromiseDependencyGraph::new();
        assert_eq!(graph.topological_sort(), Some(Vec::new()));
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = PromiseDependencyGraph::new();
        let p1 = PromiseId::new(1);
        let p2 = PromiseId::new(2);
        let p3 = PromiseId::new(3);
        graph.add_dependency(p2, p1);
        graph.add_dependency(p3, p2);
        assert!(graph.would_create_cycle(p1, p3));
        assert!(!graph.would_create_cycle(p3, p1));
        // A promise depending on itself is always a cycle.
        assert!(graph.would_create_cycle(p1, p1));
    }

    #[test]
    fn test_pending_promise() {
        let mut promise = PendingPromise::new(PromiseId::new(1), CallId::new(1));
        promise.add_dependency(PromiseId::new(2));
        promise.add_dependency(PromiseId::new(3));

        let mut resolved = HashSet::new();
        assert!(!promise.is_ready(&resolved));
        resolved.insert(PromiseId::new(2));
        assert!(!promise.is_ready(&resolved));
        resolved.insert(PromiseId::new(3));
        assert!(promise.is_ready(&resolved));

        promise.resolve(serde_json::json!({"result": true}));
        assert!(promise.resolved);
        assert!(promise.result.is_some());
    }
}