capnweb_server/
server.rs

1use crate::CapTable;
2use async_trait::async_trait;
3use axum::{
4    body::Bytes,
5    extract::State,
6    http::{HeaderMap, StatusCode},
7    response::IntoResponse,
8    routing::{get, post},
9    Json, Router,
10};
11use capnweb_core::{
12    parse_wire_batch, serialize_wire_batch, CapId, PropertyKey, RpcError, WireExpression,
13    WireMessage,
14};
15use serde_json::Value;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::net::TcpListener;
19
20// Using wire protocol helper functions from the server_wire_handler module
21use crate::server_wire_handler::{value_to_wire_expr, wire_expr_to_values_with_evaluation};
22
23#[async_trait]
24pub trait RpcTarget: Send + Sync {
25    async fn call(&self, member: &str, args: Vec<Value>) -> Result<Value, RpcError>;
26}
27
/// Network and limit settings for a [`Server`].
///
/// `Debug` is derived so configurations can be logged and inspected.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// TCP port to listen on.
    pub port: u16,
    /// Host name or IP address to bind to.
    pub host: String,
    /// Intended maximum number of wire messages per HTTP batch request;
    /// also reported by the `/health` endpoint.
    pub max_batch_size: usize,
}

impl Default for ServerConfig {
    /// Loopback-only configuration on port 8080 with a batch limit of 100.
    fn default() -> Self {
        Self {
            port: 8080,
            host: String::from("127.0.0.1"),
            max_batch_size: 100,
        }
    }
}
44
45#[derive(Clone)]
46pub struct Server {
47    config: ServerConfig,
48    cap_table: Arc<CapTable>,
49}
50
51impl Server {
52    pub fn new(config: ServerConfig) -> Self {
53        Server {
54            config,
55            cap_table: Arc::new(CapTable::new()),
56        }
57    }
58
59    pub fn cap_table(&self) -> &Arc<CapTable> {
60        &self.cap_table
61    }
62
63    pub fn register_capability(&self, id: CapId, target: Arc<dyn RpcTarget>) {
64        self.cap_table.insert(id, target);
65    }
66
67    pub async fn run(self) -> Result<(), std::io::Error> {
68        let mut app = Router::new()
69            .route("/rpc/batch", post(handle_batch))
70            .route("/health", get(handle_health));
71
72        // Add WebSocket support if the feature is enabled
73        #[cfg(feature = "all-transports")]
74        {
75            // Use the new wire protocol WebSocket handler
76            app = app.route("/rpc/ws", get(crate::ws_wire::websocket_wire_handler));
77        }
78
79        let app = app.with_state(Arc::new(self.clone()));
80
81        let addr = format!("{}:{}", self.config.host, self.config.port);
82        let listener = TcpListener::bind(&addr).await?;
83
84        println!("Server listening on {}", addr);
85        println!("  HTTP Batch endpoint: http://{}/rpc/batch", addr);
86
87        #[cfg(feature = "all-transports")]
88        println!("  WebSocket endpoint: ws://{}/rpc/ws", addr);
89
90        println!("  Health endpoint: http://{}/health", addr);
91
92        axum::serve(listener, app).await?;
93
94        Ok(())
95    }
96
97    // Remove legacy message processing - we only support wire protocol now
98}
99
100// Session state to track push/pull flow per HTTP batch request
101struct BatchSession {
102    next_import_id: i64,
103    // Map import IDs to their pushed expressions
104    pushed_expressions: HashMap<i64, WireExpression>,
105    // Map import IDs to their computed results
106    results: HashMap<i64, WireExpression>,
107}
108
109async fn handle_batch(
110    State(server): State<Arc<Server>>,
111    headers: HeaderMap,
112    body: Bytes,
113) -> impl IntoResponse {
114    // Enhanced tracing for debugging
115    tracing::debug!("=== INCOMING CAP'N WEB WIRE PROTOCOL REQUEST ===");
116    tracing::debug!("Headers: {:?}", headers);
117    tracing::debug!("Body size: {} bytes", body.len());
118
119    // Convert body to string
120    let body_str = String::from_utf8_lossy(&body);
121    tracing::debug!(
122        "Raw body (first 500 chars): {}",
123        &body_str.chars().take(500).collect::<String>()
124    );
125
126    // Create session state for this batch request
127    let mut session = BatchSession {
128        next_import_id: 1, // Start from 1 per protocol spec
129        pushed_expressions: HashMap::new(),
130        results: HashMap::new(),
131    };
132
133    // Parse the official Cap'n Web wire protocol (newline-delimited arrays ONLY)
134    match parse_wire_batch(&body_str) {
135        Ok(wire_messages) => {
136            tracing::info!(
137                "✅ Successfully parsed {} wire messages",
138                wire_messages.len()
139            );
140            tracing::trace!("Messages: {:#?}", wire_messages);
141
142            let mut responses = Vec::new();
143
144            for (i, msg) in wire_messages.iter().enumerate() {
145                tracing::debug!(
146                    "Processing wire message {}/{}: {:?}",
147                    i + 1,
148                    wire_messages.len(),
149                    msg
150                );
151
152                match msg {
153                    WireMessage::Push(expr) => {
154                        tracing::trace!("  PUSH expression details: {:#?}", expr);
155
156                        // Assign the next import ID to this push
157                        let assigned_import_id = session.next_import_id;
158                        session.next_import_id += 1;
159
160                        tracing::info!("  PUSH assigned import ID: {}", assigned_import_id);
161
162                        // Store the expression for later evaluation
163                        session
164                            .pushed_expressions
165                            .insert(assigned_import_id, expr.clone());
166
167                        // Evaluate the expression immediately and store result
168                        match expr {
169                            WireExpression::Pipeline {
170                                import_id,
171                                property_path,
172                                args,
173                            } => {
174                                tracing::info!(
175                                    "  Pipeline call: import_id={}, path={:?}",
176                                    import_id,
177                                    property_path
178                                );
179                                tracing::info!("  Pipeline args raw wire expression: {:#?}", args);
180
181                                // Validate and map import_id to capability
182                                // Official protocol: import_id 0 is the main capability/bootstrap interface
183                                // All import_ids map directly to their corresponding capability IDs
184
185                                // Check for negative import_id values
186                                if *import_id < 0 {
187                                    tracing::error!(
188                                        "Invalid negative import_id: {}. Import IDs must be non-negative.",
189                                        import_id
190                                    );
191                                    session.results.insert(
192                                        assigned_import_id,
193                                        WireExpression::Error {
194                                            error_type: "bad_request".to_string(),
195                                            message: format!(
196                                                "Invalid import_id: {}. Import IDs must be non-negative",
197                                                import_id
198                                            ),
199                                            stack: None,
200                                        },
201                                    );
202                                    continue;
203                                }
204
205                                // Safe to convert to u64 now that we've validated it's non-negative
206                                let cap_id = CapId::new(*import_id as u64);
207
208                                tracing::debug!(
209                                    "  Mapped import_id {} to capability {}",
210                                    import_id,
211                                    cap_id
212                                );
213
214                                if let Some(capability) = server.cap_table.lookup(&cap_id) {
215                                    if let Some(path) = property_path {
216                                        if let Some(PropertyKey::String(method)) = path.first() {
217                                            tracing::info!(
218                                                "  Calling method '{}' on capability {}",
219                                                method,
220                                                cap_id
221                                            );
222
223                                            // Convert args from WireExpression to Value (with pipeline evaluation)
224                                            let json_args = if let Some(args_expr) = args {
225                                                wire_expr_to_values_with_evaluation(
226                                                    args_expr,
227                                                    &session.results,
228                                                )
229                                            } else {
230                                                vec![]
231                                            };
232
233                                            tracing::info!(
234                                                "  Method args (converted): {:?}",
235                                                json_args
236                                            );
237
238                                            match capability.call(method, json_args).await {
239                                                Ok(result) => {
240                                                    tracing::info!(
241                                                        "  ✅ Method '{}' succeeded",
242                                                        method
243                                                    );
244                                                    tracing::trace!("  Result: {:?}", result);
245
246                                                    // Store the result for this import ID
247                                                    session.results.insert(
248                                                        assigned_import_id,
249                                                        value_to_wire_expr(result),
250                                                    );
251                                                }
252                                                Err(err) => {
253                                                    tracing::error!(
254                                                        "  ❌ Method '{}' failed: {:?}",
255                                                        method,
256                                                        err
257                                                    );
258
259                                                    // Store the error for this import ID
260                                                    session.results.insert(
261                                                        assigned_import_id,
262                                                        WireExpression::Error {
263                                                            error_type: err.code.to_string(),
264                                                            message: err.message.clone(),
265                                                            stack: None,
266                                                        },
267                                                    );
268                                                }
269                                            }
270                                        } else {
271                                            tracing::warn!(
272                                                "  No method name in property path: {:?}",
273                                                path
274                                            );
275                                            session.results.insert(
276                                                assigned_import_id,
277                                                WireExpression::Error {
278                                                    error_type: "bad_request".to_string(),
279                                                    message: "No method specified".to_string(),
280                                                    stack: None,
281                                                },
282                                            );
283                                        }
284                                    } else {
285                                        tracing::warn!("  No property path in pipeline expression");
286                                        session.results.insert(
287                                            assigned_import_id,
288                                            WireExpression::Error {
289                                                error_type: "bad_request".to_string(),
290                                                message: "No property path in pipeline".to_string(),
291                                                stack: None,
292                                            },
293                                        );
294                                    }
295                                } else {
296                                    tracing::error!(
297                                        "  Capability {} not found in cap_table",
298                                        cap_id
299                                    );
300                                    session.results.insert(
301                                        assigned_import_id,
302                                        WireExpression::Error {
303                                            error_type: "not_found".to_string(),
304                                            message: format!("Capability {} not found", import_id),
305                                            stack: None,
306                                        },
307                                    );
308                                }
309                            }
310                            WireExpression::Call {
311                                cap_id,
312                                property_path,
313                                args,
314                            } => {
315                                tracing::info!(
316                                    "  Call: cap_id={}, path={:?}",
317                                    cap_id,
318                                    property_path
319                                );
320                                tracing::trace!("  Call args: {:#?}", args);
321
322                                let cap_id = CapId::new(*cap_id as u64);
323
324                                if let Some(capability) = server.cap_table.lookup(&cap_id) {
325                                    if let Some(PropertyKey::String(method)) = property_path.first()
326                                    {
327                                        tracing::info!(
328                                            "  Calling method '{}' on capability {}",
329                                            method,
330                                            cap_id
331                                        );
332
333                                        // Convert args from WireExpression to Value (with pipeline evaluation)
334                                        let json_args = wire_expr_to_values_with_evaluation(
335                                            args,
336                                            &session.results,
337                                        );
338
339                                        tracing::trace!(
340                                            "  Method args (converted): {:?}",
341                                            json_args
342                                        );
343
344                                        match capability.call(method, json_args).await {
345                                            Ok(result) => {
346                                                tracing::info!(
347                                                    "  ✅ Method '{}' succeeded",
348                                                    method
349                                                );
350                                                tracing::trace!("  Result: {:?}", result);
351
352                                                // Store the result for this import ID
353                                                session.results.insert(
354                                                    assigned_import_id,
355                                                    value_to_wire_expr(result),
356                                                );
357                                            }
358                                            Err(err) => {
359                                                tracing::error!(
360                                                    "  ❌ Method '{}' failed: {:?}",
361                                                    method,
362                                                    err
363                                                );
364
365                                                // Store the error for this import ID
366                                                session.results.insert(
367                                                    assigned_import_id,
368                                                    WireExpression::Error {
369                                                        error_type: err.code.to_string(),
370                                                        message: err.message.clone(),
371                                                        stack: None,
372                                                    },
373                                                );
374                                            }
375                                        }
376                                    } else {
377                                        tracing::warn!(
378                                            "  No method name in property path: {:?}",
379                                            property_path
380                                        );
381                                        session.results.insert(
382                                            assigned_import_id,
383                                            WireExpression::Error {
384                                                error_type: "bad_request".to_string(),
385                                                message: "No method specified".to_string(),
386                                                stack: None,
387                                            },
388                                        );
389                                    }
390                                } else {
391                                    tracing::error!(
392                                        "  Capability {} not found in cap_table",
393                                        cap_id
394                                    );
395                                    session.results.insert(
396                                        assigned_import_id,
397                                        WireExpression::Error {
398                                            error_type: "not_found".to_string(),
399                                            message: format!("Capability {} not found", cap_id),
400                                            stack: None,
401                                        },
402                                    );
403                                }
404                            }
405                            _ => {
406                                tracing::warn!("  Push expression is not a pipeline or call (unsupported): {:?}", expr);
407                                session.results.insert(
408                                    assigned_import_id,
409                                    WireExpression::Error {
410                                        error_type: "not_implemented".to_string(),
411                                        message: "Only pipeline and call expressions are supported"
412                                            .to_string(),
413                                        stack: None,
414                                    },
415                                );
416                            }
417                        }
418                    }
419
420                    WireMessage::Pull(import_id) => {
421                        tracing::debug!("  PULL for import_id: {}", import_id);
422
423                        // Look up the result for this import ID
424                        if let Some(result) = session.results.get(import_id) {
425                            tracing::info!("  Found result for import ID {}", import_id);
426
427                            // Check if it's an error
428                            if let WireExpression::Error { .. } = result {
429                                // Use the import ID as the export ID (per protocol spec)
430                                responses.push(WireMessage::Reject(
431                                    *import_id, // Use import ID as export ID
432                                    result.clone(),
433                                ));
434                            } else {
435                                // Use the import ID as the export ID (per protocol spec)
436                                responses.push(WireMessage::Resolve(
437                                    *import_id, // Use import ID as export ID
438                                    result.clone(),
439                                ));
440                            }
441                        } else {
442                            tracing::warn!("  No result found for import ID {}", import_id);
443                            responses.push(WireMessage::Reject(
444                                *import_id,
445                                WireExpression::Error {
446                                    error_type: "not_found".to_string(),
447                                    message: format!("No result for import ID {}", import_id),
448                                    stack: None,
449                                },
450                            ));
451                        }
452                    }
453
454                    WireMessage::Release(ids) => {
455                        tracing::info!("  RELEASE capabilities: {:?}", ids);
456                        // Handle capability disposal
457                        for id in ids {
458                            let cap_id = CapId::new(*id as u64);
459                            server.cap_table.remove(&cap_id);
460                        }
461                    }
462
463                    other => {
464                        tracing::warn!(
465                            "  Unhandled message type (not yet implemented): {:?}",
466                            other
467                        );
468                    }
469                }
470            }
471
472            // Serialize responses using official wire protocol (newline-delimited)
473            let response_body = serialize_wire_batch(&responses);
474            tracing::info!("📤 Sending {} response messages", responses.len());
475            tracing::debug!(
476                "Response body (first 500 chars): {}",
477                &response_body.chars().take(500).collect::<String>()
478            );
479
480            (
481                StatusCode::OK,
482                [("content-type", "text/plain")],
483                response_body,
484            )
485        }
486        Err(e) => {
487            tracing::error!("Failed to parse wire protocol: {}", e);
488            tracing::debug!(
489                "Invalid input was: {}",
490                &body_str.chars().take(1000).collect::<String>()
491            );
492            let error_response = WireMessage::Reject(
493                -1,
494                WireExpression::Error {
495                    error_type: "bad_request".to_string(),
496                    message: format!("Invalid wire protocol: {}", e),
497                    stack: None,
498                },
499            );
500            let response = serialize_wire_batch(&[error_response]);
501            (
502                StatusCode::BAD_REQUEST,
503                [("content-type", "text/plain")],
504                response,
505            )
506        }
507    }
508}
509
510async fn handle_health(State(server): State<Arc<Server>>) -> impl IntoResponse {
511    let capability_count = server.cap_table.len();
512
513    let mut endpoints = serde_json::json!({
514        "batch": "/rpc/batch",
515        "health": "/health"
516    });
517
518    // Add WebSocket endpoint if available
519    #[cfg(feature = "all-transports")]
520    {
521        endpoints["websocket"] = serde_json::json!("/rpc/ws");
522    }
523
524    let health_response = serde_json::json!({
525        "status": "healthy",
526        "server": "capnweb-rust",
527        "version": env!("CARGO_PKG_VERSION"),
528        "capabilities": capability_count,
529        "max_batch_size": server.config.max_batch_size,
530        "features": {
531            "websocket": cfg!(feature = "all-transports"),
532            "h3": cfg!(feature = "h3-server")
533        },
534        "endpoints": endpoints
535    });
536
537    (StatusCode::OK, Json(health_response))
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal capability for the tests. `echo` returns its first argument
    /// (or null), `add` sums exactly two numbers, and any other member is
    /// `not_found`.
    struct TestTarget;

    #[async_trait]
    impl RpcTarget for TestTarget {
        async fn call(&self, member: &str, args: Vec<Value>) -> Result<Value, RpcError> {
            match member {
                "echo" => Ok(args.first().cloned().unwrap_or(Value::Null)),
                "add" => {
                    if args.len() != 2 {
                        return Err(RpcError::bad_request("add requires 2 arguments"));
                    }
                    let a = args[0]
                        .as_f64()
                        .ok_or_else(|| RpcError::bad_request("First arg must be number"))?;
                    let b = args[1]
                        .as_f64()
                        .ok_or_else(|| RpcError::bad_request("Second arg must be number"))?;
                    Ok(json!(a + b))
                }
                _ => Err(RpcError::not_found(format!(
                    "Method '{}' not found",
                    member
                ))),
            }
        }
    }

    #[tokio::test]
    async fn test_server_creation() {
        let config = ServerConfig::default();
        let server = Server::new(config);
        assert_eq!(server.config.port, 8080);
    }

    #[tokio::test]
    async fn test_register_capability() {
        let server = Server::new(ServerConfig::default());
        let cap_id = CapId::new(42);
        let target = Arc::new(TestTarget);

        server.register_capability(cap_id, target);
        assert!(server.cap_table.lookup(&cap_id).is_some());
    }

    /// Exercises the dispatch step `handle_batch` performs for a pushed call:
    /// look the capability up by ID, then invoke the named method.
    #[tokio::test]
    async fn test_wire_protocol_push() {
        let server = Server::new(ServerConfig::default());
        let cap_id = CapId::new(1);
        server.register_capability(cap_id, Arc::new(TestTarget));

        let capability = server.cap_table.lookup(&cap_id).unwrap();
        let result = capability.call("echo", vec![json!("hello")]).await.unwrap();
        assert_eq!(result, json!("hello"));
    }

    /// Removing a capability from the table makes it unreachable by ID.
    #[tokio::test]
    async fn test_wire_protocol_release() {
        let server = Server::new(ServerConfig::default());
        let cap_id = CapId::new(1);
        server.register_capability(cap_id, Arc::new(TestTarget));

        assert!(server.cap_table.lookup(&cap_id).is_some());

        server.cap_table.remove(&cap_id);

        assert!(server.cap_table.lookup(&cap_id).is_none());
    }

    #[tokio::test]
    async fn test_wire_protocol_unknown_capability() {
        let server = Server::new(ServerConfig::default());
        let cap_id = CapId::new(999);

        // Looking up a capability that was never registered yields nothing.
        assert!(server.cap_table.lookup(&cap_id).is_none());
    }

    /// The health endpoint answers 200 with a JSON status document that
    /// includes the configured batch limit.
    #[tokio::test]
    async fn test_health_endpoint_reports_status() {
        let server = Arc::new(Server::new(ServerConfig::default()));
        server.register_capability(CapId::new(1), Arc::new(TestTarget));

        let response = handle_health(State(server)).await.into_response();
        assert_eq!(response.status(), StatusCode::OK);

        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("health body should be readable");
        let health: Value = serde_json::from_slice(&body).expect("health body should be JSON");
        assert_eq!(health["status"], "healthy");
        assert_eq!(health["max_batch_size"], 100);
    }
}