1use crate::ids::{CallId, CapId, PromiseId};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::{HashMap, HashSet};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8#[serde(untagged)]
9pub enum ArgValue {
10 CapRef { cap: CapId },
12 PromiseRef { promise: PromiseId },
14 PromiseField { promise: PromiseId, field: String },
16 Value(Value),
18}
19
/// Target of an operation: a capability, a special string-named target, a promise, or a
/// field of a promise's result.
///
/// Serialized with `#[serde(untagged)]`: each variant is encoded as its bare payload, and
/// deserialization picks the first variant, in declaration order, whose shape matches the
/// input. Variant order is therefore part of the wire format and must not be changed casually.
///
/// NOTE(review): `CapId` and `PromiseId` both appear to serialize as bare JSON numbers (compare
/// the `{"cap":1}` / `{"promise":2}` expectations in this file's tests). If both also
/// deserialize from any number, a serialized `Promise(..)` comes back as `Cap(..)` and the
/// `Promise` variant is unreachable on input. Confirm this, and consider a distinguishing
/// encoding if so.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ExtendedTarget {
    /// A capability, encoded as its bare id.
    Cap(CapId),
    /// A target identified by a string; the accepted names are not defined here.
    Special(String),
    /// A promise, encoded as its bare id (see the NOTE above about ambiguity with `Cap`).
    Promise(PromiseId),
    /// A field of a promise's result, encoded as `{"promise": <id>, "field": "<name>"}`.
    PromiseField { promise: PromiseId, field: String },
}
33
/// Directed graph of promise-to-promise dependencies.
///
/// Each edge is stored twice, forward in `dependencies` and reversed in `dependents`, so that
/// both directions can be looked up directly. `add_dependency` writes both maps together,
/// which keeps them mirror images of each other.
#[derive(Debug, Default)]
pub struct PromiseDependencyGraph {
    /// `promise -> set of promises it depends on`.
    dependencies: HashMap<PromiseId, HashSet<PromiseId>>,
    /// `promise -> set of promises that depend on it` (the reverse of `dependencies`).
    dependents: HashMap<PromiseId, HashSet<PromiseId>>,
}
42
43impl PromiseDependencyGraph {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn add_dependency(&mut self, promise: PromiseId, depends_on: PromiseId) {
50 self.dependencies
51 .entry(promise)
52 .or_default()
53 .insert(depends_on);
54
55 self.dependents
56 .entry(depends_on)
57 .or_default()
58 .insert(promise);
59 }
60
61 pub fn dependencies_of(&self, promise: &PromiseId) -> Option<&HashSet<PromiseId>> {
63 self.dependencies.get(promise)
64 }
65
66 pub fn dependents_of(&self, promise: &PromiseId) -> Option<&HashSet<PromiseId>> {
68 self.dependents.get(promise)
69 }
70
71 pub fn topological_sort(&self) -> Option<Vec<PromiseId>> {
74 let mut in_degree = HashMap::new();
75 let mut queue = Vec::new();
76 let mut result = Vec::new();
77
78 let mut all_nodes = HashSet::new();
80 for (promise, deps) in &self.dependencies {
81 all_nodes.insert(*promise);
82 for dep in deps {
83 all_nodes.insert(*dep);
84 }
85 }
86
87 for &node in &all_nodes {
89 in_degree.insert(node, 0);
90 }
91
92 for (promise, deps) in &self.dependencies {
95 *in_degree
96 .get_mut(promise)
97 .expect("Promise should exist in in_degree map") = deps.len();
98 }
99
100 for (&promise, °ree) in &in_degree {
102 if degree == 0 {
103 queue.push(promise);
104 }
105 }
106
107 while let Some(promise) = queue.pop() {
109 result.push(promise);
110
111 if let Some(dependents) = self.dependents.get(&promise) {
113 for &dependent in dependents {
114 let degree = in_degree
115 .get_mut(&dependent)
116 .expect("Dependent should exist in in_degree map");
117 *degree -= 1;
118 if *degree == 0 {
119 queue.push(dependent);
120 }
121 }
122 }
123 }
124
125 if result.len() == all_nodes.len() {
127 Some(result)
128 } else {
129 None
130 }
131 }
132
133 pub fn would_create_cycle(&self, promise: PromiseId, depends_on: PromiseId) -> bool {
135 let mut visited = HashSet::new();
137 let mut stack = vec![depends_on];
138
139 while let Some(current) = stack.pop() {
140 if current == promise {
141 return true;
142 }
143
144 if visited.insert(current) {
145 if let Some(deps) = self.dependencies.get(¤t) {
146 stack.extend(deps.iter().copied());
147 }
148 }
149 }
150
151 false
152 }
153}
154
/// A promise awaiting resolution, together with the promises it depends on.
#[derive(Debug)]
pub struct PendingPromise {
    /// Identifier of this promise.
    pub id: PromiseId,
    /// Call associated with this promise (presumably the one that produces it; confirm).
    pub call_id: CallId,
    /// Promises that must all be resolved before this one is ready (checked by `is_ready`).
    pub dependencies: HashSet<PromiseId>,
    /// `false` when created by `new`; set to `true` by `resolve`. It is stored alongside
    /// `result` rather than derived from it, so direct field writes can desynchronize the two.
    pub resolved: bool,
    /// `None` when created by `new`; set to `Some(value)` by `resolve`.
    pub result: Option<Value>,
}
164
165impl PendingPromise {
166 pub fn new(id: PromiseId, call_id: CallId) -> Self {
167 Self {
168 id,
169 call_id,
170 dependencies: HashSet::new(),
171 resolved: false,
172 result: None,
173 }
174 }
175
176 pub fn add_dependency(&mut self, promise: PromiseId) {
177 self.dependencies.insert(promise);
178 }
179
180 pub fn resolve(&mut self, result: Value) {
181 self.resolved = true;
182 self.result = Some(result);
183 }
184
185 pub fn is_ready(&self, resolved_promises: &HashSet<PromiseId>) -> bool {
186 self.dependencies.is_subset(resolved_promises)
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arg_value_serialization() {
        // Plain JSON values pass through unchanged and round-trip.
        let arg = ArgValue::Value(serde_json::json!(42));
        let json = serde_json::to_string(&arg).unwrap();
        assert_eq!(json, "42");
        let deserialized: ArgValue = serde_json::from_str(&json).unwrap();
        assert_eq!(arg, deserialized);

        // References are untagged objects keyed by their field names, and must round-trip.
        let arg = ArgValue::CapRef { cap: CapId::new(1) };
        let json = serde_json::to_string(&arg).unwrap();
        assert_eq!(json, r#"{"cap":1}"#);
        let deserialized: ArgValue = serde_json::from_str(&json).unwrap();
        assert_eq!(arg, deserialized);

        let arg = ArgValue::PromiseRef {
            promise: PromiseId::new(2),
        };
        let json = serde_json::to_string(&arg).unwrap();
        assert_eq!(json, r#"{"promise":2}"#);
        let deserialized: ArgValue = serde_json::from_str(&json).unwrap();
        assert_eq!(arg, deserialized);

        // Key order within the object is not guaranteed, so only check presence.
        let arg = ArgValue::PromiseField {
            promise: PromiseId::new(3),
            field: "result".to_string(),
        };
        let json = serde_json::to_string(&arg).unwrap();
        assert!(json.contains("\"promise\":3"));
        assert!(json.contains("\"field\":\"result\""));
    }

    #[test]
    fn test_dependency_graph() {
        let mut graph = PromiseDependencyGraph::new();

        let p1 = PromiseId::new(1);
        let p2 = PromiseId::new(2);
        let p3 = PromiseId::new(3);

        // Chain: p3 -> p2 -> p1.
        graph.add_dependency(p2, p1);
        graph.add_dependency(p3, p2);

        assert!(graph
            .dependencies_of(&p2)
            .map(|deps| deps.contains(&p1))
            .unwrap_or(false));
        assert!(graph
            .dependents_of(&p1)
            .map(|deps| deps.contains(&p2))
            .unwrap_or(false));

        let sorted = graph.topological_sort();
        assert!(
            sorted.is_some(),
            "Topological sort should succeed (no cycle)"
        );
        let sorted = sorted.unwrap();

        assert_eq!(sorted.len(), 3);

        let p1_index = sorted
            .iter()
            .position(|&p| p == p1)
            .expect("p1 should be in sorted list");
        let p2_index = sorted
            .iter()
            .position(|&p| p == p2)
            .expect("p2 should be in sorted list");
        let p3_index = sorted
            .iter()
            .position(|&p| p == p3)
            .expect("p3 should be in sorted list");

        assert!(p1_index < p2_index, "p1 should come before p2");
        assert!(p2_index < p3_index, "p2 should come before p3");
    }

    #[test]
    fn test_topological_sort_empty_graph() {
        let graph = PromiseDependencyGraph::new();
        assert_eq!(graph.topological_sort(), Some(Vec::new()));
    }

    #[test]
    fn test_topological_sort_reports_cycle() {
        let mut graph = PromiseDependencyGraph::new();

        let p1 = PromiseId::new(1);
        let p2 = PromiseId::new(2);

        graph.add_dependency(p1, p2);
        graph.add_dependency(p2, p1);

        assert!(graph.topological_sort().is_none());
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = PromiseDependencyGraph::new();

        let p1 = PromiseId::new(1);
        let p2 = PromiseId::new(2);
        let p3 = PromiseId::new(3);

        graph.add_dependency(p2, p1);
        graph.add_dependency(p3, p2);

        // p1 -> p3 would close the loop p1 -> p3 -> p2 -> p1.
        assert!(graph.would_create_cycle(p1, p3));
        // p3 -> p1 only adds a shortcut along the existing direction.
        assert!(!graph.would_create_cycle(p3, p1));
        // A promise depending on itself is always a cycle.
        assert!(graph.would_create_cycle(p1, p1));
    }

    #[test]
    fn test_pending_promise() {
        let mut promise = PendingPromise::new(PromiseId::new(1), CallId::new(1));

        promise.add_dependency(PromiseId::new(2));
        promise.add_dependency(PromiseId::new(3));

        let mut resolved = HashSet::new();
        assert!(!promise.is_ready(&resolved));

        resolved.insert(PromiseId::new(2));
        assert!(!promise.is_ready(&resolved));

        resolved.insert(PromiseId::new(3));
        assert!(promise.is_ready(&resolved));

        promise.resolve(serde_json::json!({"result": true}));
        assert!(promise.resolved);
        assert!(promise.result.is_some());
    }
}