// capnweb_core/promise_map.rs

1use crate::il_executor::ILExecutor;
2use crate::il_extended::{ILContext, ILExpression};
3use crate::promise::PromiseDependencyGraph;
4use crate::protocol::tables::Value as TablesValue;
5use crate::{CallId, PromiseId, RpcError, RpcTarget};
6use dashmap::DashMap;
7use serde_json::Value;
8use std::sync::Arc;
9
10/// MapOperation represents a .map() call on a promise
11#[derive(Debug, Clone)]
12pub struct MapOperation {
13    /// The promise we're mapping over
14    pub source_promise: PromiseId,
15    /// The IL expression to apply to each element
16    pub map_function: ILExpression,
17    /// The resulting promise ID
18    pub result_promise: PromiseId,
19}
20
21/// PipelinedCall represents a method call on an unresolved promise
22#[derive(Debug, Clone)]
23pub struct PipelinedCall {
24    /// The promise we're calling a method on
25    pub target_promise: PromiseId,
26    /// The method name
27    pub method: String,
28    /// The arguments (which may themselves reference promises)
29    pub args: Vec<Value>,
30    /// The resulting promise ID
31    pub result_promise: PromiseId,
32    /// The call ID for tracking
33    pub call_id: CallId,
34}
35
36/// PromiseMapExecutor handles .map() operations and promise pipelining
37pub struct PromiseMapExecutor {
38    /// Pending map operations indexed by source promise
39    map_operations: Arc<DashMap<PromiseId, Vec<MapOperation>>>,
40    /// Pending pipelined calls indexed by target promise
41    pipelined_calls: Arc<DashMap<PromiseId, Vec<PipelinedCall>>>,
42    /// The IL executor for running map functions
43    il_executor: Arc<ILExecutor>,
44    /// Dependency graph for tracking promise dependencies
45    dependency_graph: Arc<tokio::sync::RwLock<PromiseDependencyGraph>>,
46}
47
48impl PromiseMapExecutor {
49    pub fn new() -> Self {
50        Self {
51            map_operations: Arc::new(DashMap::new()),
52            pipelined_calls: Arc::new(DashMap::new()),
53            il_executor: Arc::new(ILExecutor::new()),
54            dependency_graph: Arc::new(tokio::sync::RwLock::new(PromiseDependencyGraph::new())),
55        }
56    }
57
58    /// Register a .map() operation on a promise
59    pub async fn register_map(
60        &self,
61        source_promise: PromiseId,
62        map_function: ILExpression,
63        result_promise: PromiseId,
64    ) -> Result<(), RpcError> {
65        let operation = MapOperation {
66            source_promise,
67            map_function,
68            result_promise,
69        };
70
71        self.map_operations
72            .entry(source_promise)
73            .or_default()
74            .push(operation);
75
76        // Add dependency
77        let mut graph = self.dependency_graph.write().await;
78        graph.add_dependency(result_promise, source_promise);
79
80        Ok(())
81    }
82
83    /// Register a pipelined method call on a promise
84    pub async fn register_pipelined_call(
85        &self,
86        target_promise: PromiseId,
87        method: String,
88        args: Vec<Value>,
89        result_promise: PromiseId,
90        call_id: CallId,
91    ) -> Result<(), RpcError> {
92        let call = PipelinedCall {
93            target_promise,
94            method,
95            args,
96            result_promise,
97            call_id,
98        };
99
100        self.pipelined_calls
101            .entry(target_promise)
102            .or_default()
103            .push(call);
104
105        // Add dependency
106        let mut graph = self.dependency_graph.write().await;
107        graph.add_dependency(result_promise, target_promise);
108
109        Ok(())
110    }
111
112    /// Execute map operations when a promise resolves
113    pub async fn execute_map_on_resolution(
114        &self,
115        promise_id: PromiseId,
116        resolved_value: Value,
117    ) -> Vec<(PromiseId, Result<Value, RpcError>)> {
118        let mut results = Vec::new();
119
120        // Check for map operations on this promise
121        if let Some((_, operations)) = self.map_operations.remove(&promise_id) {
122            for operation in operations {
123                let result = self
124                    .execute_single_map(&resolved_value, &operation.map_function)
125                    .await;
126                results.push((operation.result_promise, result));
127            }
128        }
129
130        results
131    }
132
133    /// Execute a single map operation
134    async fn execute_single_map(
135        &self,
136        value: &Value,
137        map_function: &ILExpression,
138    ) -> Result<Value, RpcError> {
139        match value {
140            Value::Array(items) => {
141                let mut mapped_results = Vec::new();
142                let mut context = ILContext::new(vec![]);
143
144                for item in items {
145                    // Set the current item as a variable in the context
146                    context
147                        .set_variable(0, item.clone())
148                        .map_err(|e| RpcError::internal(format!("IL error: {}", e)))?;
149
150                    // Execute the map function
151                    let result = self
152                        .il_executor
153                        .execute(map_function, &mut context)
154                        .await
155                        .map_err(|e| RpcError::internal(format!("Map execution failed: {}", e)))?;
156
157                    mapped_results.push(result);
158                }
159
160                Ok(Value::Array(mapped_results))
161            }
162            _ => {
163                // For non-arrays, apply the function directly
164                let mut context = ILContext::new(vec![]);
165                context
166                    .set_variable(0, value.clone())
167                    .map_err(|e| RpcError::internal(format!("IL error: {}", e)))?;
168
169                self.il_executor
170                    .execute(map_function, &mut context)
171                    .await
172                    .map_err(|e| RpcError::internal(format!("Map execution failed: {}", e)))
173            }
174        }
175    }
176
177    /// Execute pipelined calls when a promise resolves to a capability
178    pub async fn execute_pipelined_calls(
179        &self,
180        promise_id: PromiseId,
181        capability: Arc<dyn RpcTarget>,
182    ) -> Vec<(CallId, PromiseId, Result<Value, RpcError>)> {
183        let mut results = Vec::new();
184
185        // Check for pipelined calls on this promise
186        if let Some((_, calls)) = self.pipelined_calls.remove(&promise_id) {
187            for call in calls {
188                // Convert serde_json::Value args to tables::Value
189                let converted_args = call.args.into_iter().map(json_to_tables_value).collect();
190
191                let result = capability.call(&call.method, converted_args).await;
192
193                // Convert result back to serde_json::Value
194                let converted_result = result.map(tables_to_json_value);
195                results.push((call.call_id, call.result_promise, converted_result));
196            }
197        }
198
199        results
200    }
201
202    /// Get all promises that depend on a given promise
203    pub async fn get_dependent_promises(&self, promise_id: PromiseId) -> Vec<PromiseId> {
204        let graph = self.dependency_graph.read().await;
205        graph
206            .dependents_of(&promise_id)
207            .map(|deps| deps.iter().copied().collect())
208            .unwrap_or_default()
209    }
210
211    /// Check if there would be a cycle when adding a dependency
212    pub async fn would_create_cycle(&self, promise: PromiseId, depends_on: PromiseId) -> bool {
213        let graph = self.dependency_graph.read().await;
214        graph.would_create_cycle(promise, depends_on)
215    }
216}
217
218impl Default for PromiseMapExecutor {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224/// Convert serde_json::Value to tables::Value
225fn json_to_tables_value(json: Value) -> TablesValue {
226    match json {
227        Value::Null => TablesValue::Null,
228        Value::Bool(b) => TablesValue::Bool(b),
229        Value::Number(n) => TablesValue::Number(n),
230        Value::String(s) => TablesValue::String(s),
231        Value::Array(arr) => {
232            TablesValue::Array(arr.into_iter().map(json_to_tables_value).collect())
233        }
234        Value::Object(obj) => TablesValue::Object(
235            obj.into_iter()
236                .map(|(k, v)| (k, Box::new(json_to_tables_value(v))))
237                .collect(),
238        ),
239    }
240}
241
242/// Convert tables::Value to serde_json::Value
243fn tables_to_json_value(value: TablesValue) -> Value {
244    match value {
245        TablesValue::Null => Value::Null,
246        TablesValue::Bool(b) => Value::Bool(b),
247        TablesValue::Number(n) => Value::Number(n),
248        TablesValue::String(s) => Value::String(s),
249        TablesValue::Array(arr) => {
250            Value::Array(arr.into_iter().map(tables_to_json_value).collect())
251        }
252        TablesValue::Object(obj) => Value::Object(
253            obj.into_iter()
254                .map(|(k, v)| (k, tables_to_json_value(*v)))
255                .collect(),
256        ),
257        TablesValue::Date(timestamp) => {
258            // Convert Date to a JSON object representation
259            serde_json::json!({
260                "_type": "date",
261                "timestamp": timestamp
262            })
263        }
264        TablesValue::Error {
265            error_type,
266            message,
267            stack,
268        } => {
269            // Convert Error to a JSON object representation
270            serde_json::json!({
271                "_type": "error",
272                "error_type": error_type,
273                "message": message,
274                "stack": stack
275            })
276        }
277        TablesValue::Stub(stub_ref) => {
278            // Convert Stub to a JSON object representation
279            serde_json::json!({
280                "_type": "stub",
281                "id": stub_ref.id
282            })
283        }
284        TablesValue::Promise(promise_ref) => {
285            // Convert Promise to a JSON object representation
286            serde_json::json!({
287                "_type": "promise",
288                "id": promise_ref.id
289            })
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal `RpcTarget` standing in for a resolved capability.
    ///
    /// `getName` returns `"TestCap"`. `double` doubles a numeric first
    /// argument and fails with `bad_request` otherwise. Any other method
    /// fails with `not_found`.
    #[derive(Debug)]
    struct TestCapability;

    #[async_trait::async_trait]
    impl RpcTarget for TestCapability {
        async fn call(
            &self,
            method: &str,
            args: Vec<TablesValue>,
        ) -> Result<TablesValue, RpcError> {
            match method {
                "getName" => Ok(TablesValue::String("TestCap".to_string())),
                "double" => {
                    let operand = match args.first() {
                        Some(TablesValue::Number(n)) => n.as_f64(),
                        _ => None,
                    };
                    match operand {
                        Some(v) => Ok(TablesValue::Number(
                            serde_json::Number::from_f64(v * 2.0).unwrap(),
                        )),
                        None => Err(RpcError::bad_request("Invalid argument")),
                    }
                }
                unknown => Err(RpcError::not_found(format!("Method {} not found", unknown))),
            }
        }

        async fn get_property(&self, _property: &str) -> Result<TablesValue, RpcError> {
            Ok(TablesValue::Null)
        }
    }

    #[tokio::test]
    async fn test_map_on_array() {
        let executor = PromiseMapExecutor::new();
        let source = PromiseId::new(1);
        let target = PromiseId::new(2);

        // Map function: call `double` on whatever is bound to variable 0.
        let doubler = ILExpression::call(ILExpression::var(0), "double".to_string(), vec![]);
        executor.register_map(source, doubler, target).await.unwrap();

        let outcomes = executor
            .execute_map_on_resolution(source, json!([1, 2, 3, 4, 5]))
            .await;

        assert_eq!(outcomes.len(), 1);
        let (resolved_id, outcome) = &outcomes[0];
        assert_eq!(*resolved_id, target);

        // Expected to fail for now: the elements are plain numbers, and
        // capability resolution is not wired up yet. This test only pins
        // down the structure of the result.
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_pipelined_call() {
        let executor = PromiseMapExecutor::new();
        let target = PromiseId::new(1);
        let result_id = PromiseId::new(2);
        let call_id = CallId::new(1);

        executor
            .register_pipelined_call(target, "getName".to_string(), vec![], result_id, call_id)
            .await
            .unwrap();

        // Resolve the target promise to a live capability.
        let capability: Arc<dyn RpcTarget> = Arc::new(TestCapability);
        let outcomes = executor.execute_pipelined_calls(target, capability).await;

        assert_eq!(outcomes.len(), 1);
        let (seen_call_id, seen_promise, outcome) = &outcomes[0];
        assert_eq!(*seen_call_id, call_id);
        assert_eq!(*seen_promise, result_id);
        assert_eq!(
            outcome.as_ref().ok(),
            Some(&Value::String("TestCap".to_string()))
        );
    }

    #[tokio::test]
    async fn test_dependency_tracking() {
        let executor = PromiseMapExecutor::new();
        let p1 = PromiseId::new(1);
        let p2 = PromiseId::new(2);
        let p3 = PromiseId::new(3);

        // Chain registrations so that p2 waits on p1 and p3 waits on p2.
        for (source, result) in [(p1, p2), (p2, p3)] {
            executor
                .register_map(source, ILExpression::var(0), result)
                .await
                .unwrap();
        }

        assert!(executor.get_dependent_promises(p1).await.contains(&p2));
        assert!(executor.get_dependent_promises(p2).await.contains(&p3));

        // Making p1 wait on p3 would close the loop; p3 waiting on p1 would not.
        assert!(executor.would_create_cycle(p1, p3).await);
        assert!(!executor.would_create_cycle(p3, p1).await);
    }
}