1use crate::{RpcTarget, ServerConfig};
2use capnweb_core::{CapId, Op, Plan, RpcError, Source};
3use serde_json::{Map, Value};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Evaluates a [`Plan`] against a caller-supplied set of captured capability targets.
///
/// The runner itself only holds its configuration; all per-plan state (the table of op
/// results) lives inside a single `execute` call.
pub struct PlanRunner {
    // NOTE(review): `config` is stored by `new` but not read by any method in this file,
    // which is why `dead_code` is allowed here. Presumably reserved for future execution
    // limits/options — confirm before removing.
    #[allow(dead_code)]
    config: ServerConfig,
}
14
15impl PlanRunner {
16 pub fn new(config: ServerConfig) -> Self {
18 Self { config }
19 }
20
21 pub async fn execute(
23 &self,
24 plan: &Plan,
25 params: Option<Value>,
26 captures: &HashMap<u32, Arc<RwLock<dyn RpcTarget>>>,
27 ) -> Result<Value, RpcError> {
28 plan.validate()
30 .map_err(|e| RpcError::bad_request(format!("Invalid plan: {}", e)))?;
31
32 let mut results: HashMap<u32, Value> = HashMap::new();
34
35 for op in &plan.ops {
37 let result = self
38 .execute_op(op, params.as_ref(), captures, &results)
39 .await?;
40
41 match op {
42 Op::Call { call } => {
43 results.insert(call.result, result);
44 }
45 Op::Object { object } => {
46 results.insert(object.result, result);
47 }
48 Op::Array { array } => {
49 results.insert(array.result, result);
50 }
51 }
52 }
53
54 self.resolve_source(&plan.result, params.as_ref(), captures, &results)
56 }
57
58 async fn execute_op(
60 &self,
61 op: &Op,
62 params: Option<&Value>,
63 captures: &HashMap<u32, Arc<RwLock<dyn RpcTarget>>>,
64 results: &HashMap<u32, Value>,
65 ) -> Result<Value, RpcError> {
66 match op {
67 Op::Call { call } => {
68 let target_value = self.resolve_source(&call.target, params, captures, results)?;
69
70 let cap_id = if let Value::Object(obj) = &target_value {
72 if let Some(Value::Number(n)) = obj.get("cap") {
73 CapId::new(
74 n.as_u64()
75 .ok_or_else(|| RpcError::bad_request("Invalid capability ID"))?,
76 )
77 } else {
78 return Err(RpcError::bad_request("Target is not a capability"));
79 }
80 } else {
81 return Err(RpcError::bad_request("Target is not a capability"));
82 };
83
84 let capability = captures.get(&(cap_id.as_u64() as u32)).ok_or_else(|| {
86 RpcError::not_found(format!("Capability not found: {:?}", cap_id))
87 })?;
88
89 let mut resolved_args = Vec::new();
91 for arg_source in &call.args {
92 resolved_args.push(self.resolve_source(arg_source, params, captures, results)?);
93 }
94
95 let target = capability.read().await;
97 target.call(&call.member, resolved_args).await
98 }
99
100 Op::Object { object } => {
101 let mut obj = Map::new();
102 for (key, source) in &object.fields {
103 let value = self.resolve_source(source, params, captures, results)?;
104 obj.insert(key.clone(), value);
105 }
106 Ok(Value::Object(obj))
107 }
108
109 Op::Array { array } => {
110 let mut arr = Vec::new();
111 for source in &array.items {
112 arr.push(self.resolve_source(source, params, captures, results)?);
113 }
114 Ok(Value::Array(arr))
115 }
116 }
117 }
118
119 fn resolve_source(
121 &self,
122 source: &Source,
123 params: Option<&Value>,
124 captures: &HashMap<u32, Arc<RwLock<dyn RpcTarget>>>,
125 results: &HashMap<u32, Value>,
126 ) -> Result<Value, RpcError> {
127 match source {
128 Source::Capture { capture } => {
129 captures
131 .get(&capture.index)
132 .map(|_| {
133 serde_json::json!({ "cap": capture.index })
135 })
136 .ok_or_else(|| {
137 RpcError::not_found(format!("Capture {} not found", capture.index))
138 })
139 }
140
141 Source::Result { result } => results
142 .get(&result.index)
143 .cloned()
144 .ok_or_else(|| RpcError::not_found(format!("Result {} not found", result.index))),
145
146 Source::Param { param } => {
147 let params =
148 params.ok_or_else(|| RpcError::bad_request("No parameters provided"))?;
149
150 let mut current = params;
152 for segment in ¶m.path {
153 current = current.get(segment).ok_or_else(|| {
154 RpcError::bad_request(format!(
155 "Parameter path not found: {}",
156 param.path.join(".")
157 ))
158 })?;
159 }
160 Ok(current.clone())
161 }
162
163 Source::ByValue { by_value } => Ok(by_value.value.clone()),
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use capnweb_core::{Op, Plan, Source};

    /// Minimal capability used by the runner tests.
    struct TestTarget {
        name: String,
    }

    #[async_trait]
    impl RpcTarget for TestTarget {
        async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, RpcError> {
            match method {
                "getName" => Ok(Value::String(self.name.clone())),
                "add" => {
                    if args.len() != 2 {
                        return Err(RpcError::bad_request("add requires 2 arguments"));
                    }
                    let a = args[0]
                        .as_f64()
                        .ok_or_else(|| RpcError::bad_request("First argument must be a number"))?;
                    let b = args[1]
                        .as_f64()
                        .ok_or_else(|| RpcError::bad_request("Second argument must be a number"))?;
                    Ok(serde_json::json!(a + b))
                }
                "echo" => Ok(args.first().cloned().unwrap_or(Value::Null)),
                _ => Err(RpcError::not_found(format!("Method not found: {}", method))),
            }
        }
    }

    /// Builds a capture map with a single `TestTarget` at index 0.
    fn single_capture(name: &str) -> HashMap<u32, Arc<RwLock<dyn RpcTarget>>> {
        let mut captures = HashMap::new();
        captures.insert(
            0,
            Arc::new(RwLock::new(TestTarget {
                name: name.to_string(),
            })) as Arc<RwLock<dyn RpcTarget>>,
        );
        captures
    }

    /// Plan calling `method` on capture 0 with the given args, returning result 0.
    fn call_plan(method: &str, args: Vec<Source>) -> Plan {
        Plan::new(
            vec![CapId::new(1)],
            vec![Op::call(Source::capture(0), method.to_string(), args, 0)],
            Source::result(0),
        )
    }

    /// Plan calling `add` with two parameter paths as arguments.
    fn add_plan(a_path: Vec<String>, b_path: Vec<String>) -> Plan {
        call_plan("add", vec![Source::param(a_path), Source::param(b_path)])
    }

    #[tokio::test]
    async fn test_execute_simple_call() {
        let runner = PlanRunner::new(ServerConfig::default());
        let plan = call_plan("getName", vec![]);
        let captures = single_capture("test");

        let result = runner.execute(&plan, None, &captures).await.unwrap();
        assert_eq!(result, Value::String("test".to_string()));
    }

    #[tokio::test]
    async fn test_execute_with_params() {
        let runner = PlanRunner::new(ServerConfig::default());
        let plan = add_plan(vec!["a".to_string()], vec!["b".to_string()]);
        let captures = single_capture("calculator");

        let params = serde_json::json!({
            "a": 5,
            "b": 3
        });

        let result = runner
            .execute(&plan, Some(params), &captures)
            .await
            .unwrap();
        assert_eq!(result, serde_json::json!(8.0));
    }

    #[tokio::test]
    async fn test_execute_with_nested_params() {
        let runner = PlanRunner::new(ServerConfig::default());
        let plan = add_plan(
            vec!["x".to_string(), "a".to_string()],
            vec!["x".to_string(), "b".to_string()],
        );
        let captures = single_capture("calculator");

        let params = serde_json::json!({ "x": { "a": 2, "b": 4 } });

        let result = runner
            .execute(&plan, Some(params), &captures)
            .await
            .unwrap();
        assert_eq!(result, serde_json::json!(6.0));
    }

    #[tokio::test]
    async fn test_missing_params_is_error() {
        let runner = PlanRunner::new(ServerConfig::default());
        let plan = add_plan(vec!["a".to_string()], vec!["b".to_string()]);
        let captures = single_capture("calculator");

        let result = runner.execute(&plan, None, &captures).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_missing_param_path_is_error() {
        let runner = PlanRunner::new(ServerConfig::default());
        let plan = add_plan(vec!["a".to_string()], vec!["b".to_string()]);
        let captures = single_capture("calculator");

        let params = serde_json::json!({ "a": 1 });

        let result = runner.execute(&plan, Some(params), &captures).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_missing_capture_is_error() {
        let runner = PlanRunner::new(ServerConfig::default());
        let plan = call_plan("getName", vec![]);
        let captures: HashMap<u32, Arc<RwLock<dyn RpcTarget>>> = HashMap::new();

        let result = runner.execute(&plan, None, &captures).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unknown_method_propagates_error() {
        let runner = PlanRunner::new(ServerConfig::default());
        let plan = call_plan("doesNotExist", vec![]);
        let captures = single_capture("test");

        let result = runner.execute(&plan, None, &captures).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_object_construction() {
        let runner = PlanRunner::new(ServerConfig::default());

        let plan = Plan::new(
            vec![],
            vec![Op::object(
                vec![
                    (
                        "name".to_string(),
                        Source::by_value(Value::String("test".to_string())),
                    ),
                    ("value".to_string(), Source::by_value(serde_json::json!(42))),
                ]
                .into_iter()
                .collect(),
                0,
            )],
            Source::result(0),
        );

        let captures = HashMap::new();
        let result = runner.execute(&plan, None, &captures).await.unwrap();

        assert_eq!(
            result,
            serde_json::json!({
                "name": "test",
                "value": 42
            })
        );
    }

    #[tokio::test]
    async fn test_execute_array_construction() {
        let runner = PlanRunner::new(ServerConfig::default());

        let plan = Plan::new(
            vec![],
            vec![Op::array(
                vec![
                    Source::by_value(serde_json::json!(1)),
                    Source::by_value(serde_json::json!(2)),
                    Source::by_value(serde_json::json!(3)),
                ],
                0,
            )],
            Source::result(0),
        );

        let captures = HashMap::new();
        let result = runner.execute(&plan, None, &captures).await.unwrap();

        assert_eq!(result, serde_json::json!([1, 2, 3]));
    }

    #[tokio::test]
    async fn test_execute_chained_operations() {
        let runner = PlanRunner::new(ServerConfig::default());

        let plan = Plan::new(
            vec![CapId::new(1)],
            vec![
                Op::call(
                    Source::capture(0),
                    "echo".to_string(),
                    vec![Source::by_value(Value::String("hello".to_string()))],
                    0,
                ),
                Op::object(
                    vec![
                        ("message".to_string(), Source::result(0)),
                        (
                            "timestamp".to_string(),
                            Source::by_value(serde_json::json!(12345)),
                        ),
                    ]
                    .into_iter()
                    .collect(),
                    1,
                ),
            ],
            Source::result(1),
        );

        let captures = single_capture("echo");
        let result = runner.execute(&plan, None, &captures).await.unwrap();

        assert_eq!(
            result,
            serde_json::json!({
                "message": "hello",
                "timestamp": 12345
            })
        );
    }

    #[tokio::test]
    async fn test_invalid_plan() {
        let runner = PlanRunner::new(ServerConfig::default());

        // The call targets result 1, which no earlier op produces, so validation must fail.
        let plan = Plan::new(
            vec![],
            vec![Op::call(
                Source::result(1),
                "test".to_string(),
                vec![],
                0,
            )],
            Source::result(0),
        );

        let captures = HashMap::new();
        let result = runner.execute(&plan, None, &captures).await;

        assert!(result.is_err());
    }
}