capnweb_server/
runner.rs

1use crate::{RpcTarget, ServerConfig};
2use capnweb_core::{CapId, Op, Plan, RpcError, Source};
3use serde_json::{Map, Value};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8/// Executes IL Plans against capabilities
9pub struct PlanRunner {
10    /// Configuration for the runner
11    #[allow(dead_code)]
12    config: ServerConfig,
13}
14
15impl PlanRunner {
16    /// Create a new plan runner
17    pub fn new(config: ServerConfig) -> Self {
18        Self { config }
19    }
20
21    /// Execute a plan with captured capabilities
22    pub async fn execute(
23        &self,
24        plan: &Plan,
25        params: Option<Value>,
26        captures: &HashMap<u32, Arc<RwLock<dyn RpcTarget>>>,
27    ) -> Result<Value, RpcError> {
28        // Validate the plan
29        plan.validate()
30            .map_err(|e| RpcError::bad_request(format!("Invalid plan: {}", e)))?;
31
32        // Track results from operations
33        let mut results: HashMap<u32, Value> = HashMap::new();
34
35        // Execute operations in order
36        for op in &plan.ops {
37            let result = self
38                .execute_op(op, params.as_ref(), captures, &results)
39                .await?;
40
41            match op {
42                Op::Call { call } => {
43                    results.insert(call.result, result);
44                }
45                Op::Object { object } => {
46                    results.insert(object.result, result);
47                }
48                Op::Array { array } => {
49                    results.insert(array.result, result);
50                }
51            }
52        }
53
54        // Get the final result
55        self.resolve_source(&plan.result, params.as_ref(), captures, &results)
56    }
57
58    /// Execute a single operation
59    async fn execute_op(
60        &self,
61        op: &Op,
62        params: Option<&Value>,
63        captures: &HashMap<u32, Arc<RwLock<dyn RpcTarget>>>,
64        results: &HashMap<u32, Value>,
65    ) -> Result<Value, RpcError> {
66        match op {
67            Op::Call { call } => {
68                let target_value = self.resolve_source(&call.target, params, captures, results)?;
69
70                // Get the capability ID from the target value
71                let cap_id = if let Value::Object(obj) = &target_value {
72                    if let Some(Value::Number(n)) = obj.get("cap") {
73                        CapId::new(
74                            n.as_u64()
75                                .ok_or_else(|| RpcError::bad_request("Invalid capability ID"))?,
76                        )
77                    } else {
78                        return Err(RpcError::bad_request("Target is not a capability"));
79                    }
80                } else {
81                    return Err(RpcError::bad_request("Target is not a capability"));
82                };
83
84                // Get the capability from captures
85                let capability = captures.get(&(cap_id.as_u64() as u32)).ok_or_else(|| {
86                    RpcError::not_found(format!("Capability not found: {:?}", cap_id))
87                })?;
88
89                // Resolve arguments
90                let mut resolved_args = Vec::new();
91                for arg_source in &call.args {
92                    resolved_args.push(self.resolve_source(arg_source, params, captures, results)?);
93                }
94
95                // Call the method on the capability
96                let target = capability.read().await;
97                target.call(&call.member, resolved_args).await
98            }
99
100            Op::Object { object } => {
101                let mut obj = Map::new();
102                for (key, source) in &object.fields {
103                    let value = self.resolve_source(source, params, captures, results)?;
104                    obj.insert(key.clone(), value);
105                }
106                Ok(Value::Object(obj))
107            }
108
109            Op::Array { array } => {
110                let mut arr = Vec::new();
111                for source in &array.items {
112                    arr.push(self.resolve_source(source, params, captures, results)?);
113                }
114                Ok(Value::Array(arr))
115            }
116        }
117    }
118
119    /// Resolve a source to its value
120    fn resolve_source(
121        &self,
122        source: &Source,
123        params: Option<&Value>,
124        captures: &HashMap<u32, Arc<RwLock<dyn RpcTarget>>>,
125        results: &HashMap<u32, Value>,
126    ) -> Result<Value, RpcError> {
127        match source {
128            Source::Capture { capture } => {
129                // Convert capability to a reference value
130                captures
131                    .get(&capture.index)
132                    .map(|_| {
133                        // Return a capability reference
134                        serde_json::json!({ "cap": capture.index })
135                    })
136                    .ok_or_else(|| {
137                        RpcError::not_found(format!("Capture {} not found", capture.index))
138                    })
139            }
140
141            Source::Result { result } => results
142                .get(&result.index)
143                .cloned()
144                .ok_or_else(|| RpcError::not_found(format!("Result {} not found", result.index))),
145
146            Source::Param { param } => {
147                let params =
148                    params.ok_or_else(|| RpcError::bad_request("No parameters provided"))?;
149
150                // Navigate the path through the params
151                let mut current = params;
152                for segment in &param.path {
153                    current = current.get(segment).ok_or_else(|| {
154                        RpcError::bad_request(format!(
155                            "Parameter path not found: {}",
156                            param.path.join(".")
157                        ))
158                    })?;
159                }
160                Ok(current.clone())
161            }
162
163            Source::ByValue { by_value } => Ok(by_value.value.clone()),
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use capnweb_core::{Op, Plan, Source};

    /// Minimal capability used by the tests: reports its name, adds two
    /// numbers, or echoes its first argument.
    struct TestTarget {
        name: String,
    }

    #[async_trait]
    impl RpcTarget for TestTarget {
        async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, RpcError> {
            match method {
                "getName" => Ok(Value::String(self.name.clone())),
                "add" => match args.as_slice() {
                    [lhs, rhs] => {
                        let a = lhs.as_f64().ok_or_else(|| {
                            RpcError::bad_request("First argument must be a number")
                        })?;
                        let b = rhs.as_f64().ok_or_else(|| {
                            RpcError::bad_request("Second argument must be a number")
                        })?;
                        Ok(serde_json::json!(a + b))
                    }
                    _ => Err(RpcError::bad_request("add requires 2 arguments")),
                },
                "echo" => Ok(args.into_iter().next().unwrap_or(Value::Null)),
                other => Err(RpcError::not_found(format!("Method not found: {}", other))),
            }
        }
    }

    /// Runner built from the default server configuration.
    fn default_runner() -> PlanRunner {
        PlanRunner::new(ServerConfig::default())
    }

    /// Capture table holding a single `TestTarget` at index 0.
    fn single_capture(name: &str) -> HashMap<u32, Arc<RwLock<dyn RpcTarget>>> {
        let target: Arc<RwLock<dyn RpcTarget>> = Arc::new(RwLock::new(TestTarget {
            name: name.to_string(),
        }));

        let mut table = HashMap::new();
        table.insert(0, target);
        table
    }

    #[tokio::test]
    async fn test_execute_simple_call() {
        let plan = Plan::new(
            vec![CapId::new(1)],
            vec![Op::call(
                Source::capture(0),
                "getName".to_string(),
                vec![],
                0,
            )],
            Source::result(0),
        );
        let captures = single_capture("test");

        let outcome = default_runner()
            .execute(&plan, None, &captures)
            .await
            .unwrap();

        assert_eq!(outcome, Value::String("test".to_string()));
    }

    #[tokio::test]
    async fn test_execute_with_params() {
        let operands = vec![
            Source::param(vec!["a".to_string()]),
            Source::param(vec!["b".to_string()]),
        ];
        let plan = Plan::new(
            vec![CapId::new(1)],
            vec![Op::call(
                Source::capture(0),
                "add".to_string(),
                operands,
                0,
            )],
            Source::result(0),
        );
        let captures = single_capture("calculator");
        let params = serde_json::json!({ "a": 5, "b": 3 });

        let outcome = default_runner()
            .execute(&plan, Some(params), &captures)
            .await
            .unwrap();

        assert_eq!(outcome, serde_json::json!(8.0));
    }

    #[tokio::test]
    async fn test_execute_object_construction() {
        let fields = vec![
            (
                "name".to_string(),
                Source::by_value(Value::String("test".to_string())),
            ),
            ("value".to_string(), Source::by_value(serde_json::json!(42))),
        ];
        let plan = Plan::new(
            vec![],
            vec![Op::object(fields.into_iter().collect(), 0)],
            Source::result(0),
        );
        let captures = HashMap::new();

        let outcome = default_runner()
            .execute(&plan, None, &captures)
            .await
            .unwrap();

        let expected = serde_json::json!({
            "name": "test",
            "value": 42
        });
        assert_eq!(outcome, expected);
    }

    #[tokio::test]
    async fn test_execute_array_construction() {
        let items = vec![
            Source::by_value(serde_json::json!(1)),
            Source::by_value(serde_json::json!(2)),
            Source::by_value(serde_json::json!(3)),
        ];
        let plan = Plan::new(vec![], vec![Op::array(items, 0)], Source::result(0));
        let captures = HashMap::new();

        let outcome = default_runner()
            .execute(&plan, None, &captures)
            .await
            .unwrap();

        assert_eq!(outcome, serde_json::json!([1, 2, 3]));
    }

    #[tokio::test]
    async fn test_execute_chained_operations() {
        // First op echoes "hello"; second op wraps that result in an object.
        let echo_op = Op::call(
            Source::capture(0),
            "echo".to_string(),
            vec![Source::by_value(Value::String("hello".to_string()))],
            0,
        );
        let wrap_op = Op::object(
            vec![
                ("message".to_string(), Source::result(0)),
                (
                    "timestamp".to_string(),
                    Source::by_value(serde_json::json!(12345)),
                ),
            ]
            .into_iter()
            .collect(),
            1,
        );
        let plan = Plan::new(
            vec![CapId::new(1)],
            vec![echo_op, wrap_op],
            Source::result(1),
        );
        let captures = single_capture("echo");

        let outcome = default_runner()
            .execute(&plan, None, &captures)
            .await
            .unwrap();

        let expected = serde_json::json!({
            "message": "hello",
            "timestamp": 12345
        });
        assert_eq!(outcome, expected);
    }

    #[tokio::test]
    async fn test_invalid_plan() {
        // The call targets result 1, which no operation has produced.
        let forward_reference = Op::call(Source::result(1), "test".to_string(), vec![], 0);
        let plan = Plan::new(vec![], vec![forward_reference], Source::result(0));
        let captures = HashMap::new();

        let outcome = default_runner().execute(&plan, None, &captures).await;

        assert!(outcome.is_err());
    }
}