1use crate::il_executor::ILExecutor;
2use crate::il_extended::{ILContext, ILExpression};
3use crate::promise::PromiseDependencyGraph;
4use crate::protocol::tables::Value as TablesValue;
5use crate::{CallId, PromiseId, RpcError, RpcTarget};
6use dashmap::DashMap;
7use serde_json::Value;
8use std::sync::Arc;
9
10#[derive(Debug, Clone)]
12pub struct MapOperation {
13 pub source_promise: PromiseId,
15 pub map_function: ILExpression,
17 pub result_promise: PromiseId,
19}
20
21#[derive(Debug, Clone)]
23pub struct PipelinedCall {
24 pub target_promise: PromiseId,
26 pub method: String,
28 pub args: Vec<Value>,
30 pub result_promise: PromiseId,
32 pub call_id: CallId,
34}
35
36pub struct PromiseMapExecutor {
38 map_operations: Arc<DashMap<PromiseId, Vec<MapOperation>>>,
40 pipelined_calls: Arc<DashMap<PromiseId, Vec<PipelinedCall>>>,
42 il_executor: Arc<ILExecutor>,
44 dependency_graph: Arc<tokio::sync::RwLock<PromiseDependencyGraph>>,
46}
47
48impl PromiseMapExecutor {
49 pub fn new() -> Self {
50 Self {
51 map_operations: Arc::new(DashMap::new()),
52 pipelined_calls: Arc::new(DashMap::new()),
53 il_executor: Arc::new(ILExecutor::new()),
54 dependency_graph: Arc::new(tokio::sync::RwLock::new(PromiseDependencyGraph::new())),
55 }
56 }
57
58 pub async fn register_map(
60 &self,
61 source_promise: PromiseId,
62 map_function: ILExpression,
63 result_promise: PromiseId,
64 ) -> Result<(), RpcError> {
65 let operation = MapOperation {
66 source_promise,
67 map_function,
68 result_promise,
69 };
70
71 self.map_operations
72 .entry(source_promise)
73 .or_default()
74 .push(operation);
75
76 let mut graph = self.dependency_graph.write().await;
78 graph.add_dependency(result_promise, source_promise);
79
80 Ok(())
81 }
82
83 pub async fn register_pipelined_call(
85 &self,
86 target_promise: PromiseId,
87 method: String,
88 args: Vec<Value>,
89 result_promise: PromiseId,
90 call_id: CallId,
91 ) -> Result<(), RpcError> {
92 let call = PipelinedCall {
93 target_promise,
94 method,
95 args,
96 result_promise,
97 call_id,
98 };
99
100 self.pipelined_calls
101 .entry(target_promise)
102 .or_default()
103 .push(call);
104
105 let mut graph = self.dependency_graph.write().await;
107 graph.add_dependency(result_promise, target_promise);
108
109 Ok(())
110 }
111
112 pub async fn execute_map_on_resolution(
114 &self,
115 promise_id: PromiseId,
116 resolved_value: Value,
117 ) -> Vec<(PromiseId, Result<Value, RpcError>)> {
118 let mut results = Vec::new();
119
120 if let Some((_, operations)) = self.map_operations.remove(&promise_id) {
122 for operation in operations {
123 let result = self
124 .execute_single_map(&resolved_value, &operation.map_function)
125 .await;
126 results.push((operation.result_promise, result));
127 }
128 }
129
130 results
131 }
132
133 async fn execute_single_map(
135 &self,
136 value: &Value,
137 map_function: &ILExpression,
138 ) -> Result<Value, RpcError> {
139 match value {
140 Value::Array(items) => {
141 let mut mapped_results = Vec::new();
142 let mut context = ILContext::new(vec![]);
143
144 for item in items {
145 context
147 .set_variable(0, item.clone())
148 .map_err(|e| RpcError::internal(format!("IL error: {}", e)))?;
149
150 let result = self
152 .il_executor
153 .execute(map_function, &mut context)
154 .await
155 .map_err(|e| RpcError::internal(format!("Map execution failed: {}", e)))?;
156
157 mapped_results.push(result);
158 }
159
160 Ok(Value::Array(mapped_results))
161 }
162 _ => {
163 let mut context = ILContext::new(vec![]);
165 context
166 .set_variable(0, value.clone())
167 .map_err(|e| RpcError::internal(format!("IL error: {}", e)))?;
168
169 self.il_executor
170 .execute(map_function, &mut context)
171 .await
172 .map_err(|e| RpcError::internal(format!("Map execution failed: {}", e)))
173 }
174 }
175 }
176
177 pub async fn execute_pipelined_calls(
179 &self,
180 promise_id: PromiseId,
181 capability: Arc<dyn RpcTarget>,
182 ) -> Vec<(CallId, PromiseId, Result<Value, RpcError>)> {
183 let mut results = Vec::new();
184
185 if let Some((_, calls)) = self.pipelined_calls.remove(&promise_id) {
187 for call in calls {
188 let converted_args = call.args.into_iter().map(json_to_tables_value).collect();
190
191 let result = capability.call(&call.method, converted_args).await;
192
193 let converted_result = result.map(tables_to_json_value);
195 results.push((call.call_id, call.result_promise, converted_result));
196 }
197 }
198
199 results
200 }
201
202 pub async fn get_dependent_promises(&self, promise_id: PromiseId) -> Vec<PromiseId> {
204 let graph = self.dependency_graph.read().await;
205 graph
206 .dependents_of(&promise_id)
207 .map(|deps| deps.iter().copied().collect())
208 .unwrap_or_default()
209 }
210
211 pub async fn would_create_cycle(&self, promise: PromiseId, depends_on: PromiseId) -> bool {
213 let graph = self.dependency_graph.read().await;
214 graph.would_create_cycle(promise, depends_on)
215 }
216}
217
218impl Default for PromiseMapExecutor {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
224fn json_to_tables_value(json: Value) -> TablesValue {
226 match json {
227 Value::Null => TablesValue::Null,
228 Value::Bool(b) => TablesValue::Bool(b),
229 Value::Number(n) => TablesValue::Number(n),
230 Value::String(s) => TablesValue::String(s),
231 Value::Array(arr) => {
232 TablesValue::Array(arr.into_iter().map(json_to_tables_value).collect())
233 }
234 Value::Object(obj) => TablesValue::Object(
235 obj.into_iter()
236 .map(|(k, v)| (k, Box::new(json_to_tables_value(v))))
237 .collect(),
238 ),
239 }
240}
241
242fn tables_to_json_value(value: TablesValue) -> Value {
244 match value {
245 TablesValue::Null => Value::Null,
246 TablesValue::Bool(b) => Value::Bool(b),
247 TablesValue::Number(n) => Value::Number(n),
248 TablesValue::String(s) => Value::String(s),
249 TablesValue::Array(arr) => {
250 Value::Array(arr.into_iter().map(tables_to_json_value).collect())
251 }
252 TablesValue::Object(obj) => Value::Object(
253 obj.into_iter()
254 .map(|(k, v)| (k, tables_to_json_value(*v)))
255 .collect(),
256 ),
257 TablesValue::Date(timestamp) => {
258 serde_json::json!({
260 "_type": "date",
261 "timestamp": timestamp
262 })
263 }
264 TablesValue::Error {
265 error_type,
266 message,
267 stack,
268 } => {
269 serde_json::json!({
271 "_type": "error",
272 "error_type": error_type,
273 "message": message,
274 "stack": stack
275 })
276 }
277 TablesValue::Stub(stub_ref) => {
278 serde_json::json!({
280 "_type": "stub",
281 "id": stub_ref.id
282 })
283 }
284 TablesValue::Promise(promise_ref) => {
285 serde_json::json!({
287 "_type": "promise",
288 "id": promise_ref.id
289 })
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal capability used to drive pipelined calls in these tests.
    ///
    /// - `double`: doubles a numeric first argument; anything else is a bad request.
    /// - `getName`: returns `"TestCap"`.
    /// - any other method: not found.
    #[derive(Debug)]
    struct TestCapability;

    #[async_trait::async_trait]
    impl RpcTarget for TestCapability {
        async fn call(
            &self,
            method: &str,
            args: Vec<TablesValue>,
        ) -> Result<TablesValue, RpcError> {
            match method {
                "double" => args
                    .first()
                    .and_then(|arg| match arg {
                        TablesValue::Number(n) => n.as_f64(),
                        _ => None,
                    })
                    .map(|v| {
                        TablesValue::Number(serde_json::Number::from_f64(v * 2.0).unwrap())
                    })
                    .ok_or_else(|| RpcError::bad_request("Invalid argument")),
                "getName" => Ok(TablesValue::String(String::from("TestCap"))),
                other => Err(RpcError::not_found(format!("Method {} not found", other))),
            }
        }

        async fn get_property(&self, _property: &str) -> Result<TablesValue, RpcError> {
            Ok(TablesValue::Null)
        }
    }

    #[tokio::test]
    async fn test_map_on_array() {
        let executor = PromiseMapExecutor::new();
        let source = PromiseId::new(1);
        let target = PromiseId::new(2);

        // Map each element through a `double` call on IL variable 0.
        let doubler = ILExpression::call(ILExpression::var(0), String::from("double"), vec![]);
        executor.register_map(source, doubler, target).await.unwrap();

        let outcomes = executor
            .execute_map_on_resolution(source, json!([1, 2, 3, 4, 5]))
            .await;

        assert_eq!(outcomes.len(), 1);
        let (promise_id, outcome) = &outcomes[0];
        assert_eq!(*promise_id, target);

        // NOTE(review): despite the test name, this pins that the map currently fails.
        // Presumably the IL context built by `execute_single_map` has nothing able to
        // serve `double`. Confirm, and update this expectation if that changes.
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_pipelined_call() {
        let executor = PromiseMapExecutor::new();
        let target = PromiseId::new(1);
        let result_slot = PromiseId::new(2);
        let call_id = CallId::new(1);

        executor
            .register_pipelined_call(target, String::from("getName"), vec![], result_slot, call_id)
            .await
            .unwrap();

        let capability: Arc<dyn RpcTarget> = Arc::new(TestCapability);
        let outcomes = executor.execute_pipelined_calls(target, capability).await;

        assert_eq!(outcomes.len(), 1);
        let (seen_call_id, seen_promise, outcome) = &outcomes[0];
        assert_eq!(*seen_call_id, call_id);
        assert_eq!(*seen_promise, result_slot);
        assert!(outcome.is_ok());
        assert_eq!(
            outcome.as_ref().unwrap(),
            &Value::String(String::from("TestCap"))
        );
    }

    #[tokio::test]
    async fn test_dependency_tracking() {
        let executor = PromiseMapExecutor::new();
        let [p1, p2, p3] = [PromiseId::new(1), PromiseId::new(2), PromiseId::new(3)];

        // Build the chain p1 -> p2 -> p3 (each result depends on its source).
        for (source, result) in [(p1, p2), (p2, p3)] {
            executor
                .register_map(source, ILExpression::var(0), result)
                .await
                .unwrap();
        }

        assert!(executor.get_dependent_promises(p1).await.contains(&p2));
        assert!(executor.get_dependent_promises(p2).await.contains(&p3));

        // p3 already depends (transitively) on p1. Making p1 depend on p3 would close a
        // loop; the reverse edge would not.
        assert!(executor.would_create_cycle(p1, p3).await);
        assert!(!executor.would_create_cycle(p3, p1).await);
    }
}