capnweb_server/server.rs
1use crate::CapTable;
2use async_trait::async_trait;
3use axum::{
4 body::Bytes,
5 extract::State,
6 http::{HeaderMap, StatusCode},
7 response::IntoResponse,
8 routing::{get, post},
9 Json, Router,
10};
11use capnweb_core::{
12 parse_wire_batch, serialize_wire_batch, CapId, PropertyKey, RpcError, WireExpression,
13 WireMessage,
14};
15use serde_json::Value;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::net::TcpListener;
19
20// Using wire protocol helper functions from the server_wire_handler module
21use crate::server_wire_handler::{value_to_wire_expr, wire_expr_to_values_with_evaluation};
22
23#[async_trait]
24pub trait RpcTarget: Send + Sync {
25 async fn call(&self, member: &str, args: Vec<Value>) -> Result<Value, RpcError>;
26}
27
/// Network and limit settings for a [`Server`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ServerConfig {
    /// TCP port to listen on.
    pub port: u16,
    /// Host name or IP address to bind, e.g. `"127.0.0.1"`.
    pub host: String,
    /// Batch-size limit for `/rpc/batch` requests (number of wire messages); also reported by
    /// the `/health` endpoint.
    pub max_batch_size: usize,
}

impl Default for ServerConfig {
    /// Binds `127.0.0.1:8080` with a batch-size limit of 100 messages.
    fn default() -> Self {
        Self {
            port: 8080,
            host: String::from("127.0.0.1"),
            max_batch_size: 100,
        }
    }
}
44
45#[derive(Clone)]
46pub struct Server {
47 config: ServerConfig,
48 cap_table: Arc<CapTable>,
49}
50
51impl Server {
52 pub fn new(config: ServerConfig) -> Self {
53 Server {
54 config,
55 cap_table: Arc::new(CapTable::new()),
56 }
57 }
58
59 pub fn cap_table(&self) -> &Arc<CapTable> {
60 &self.cap_table
61 }
62
63 pub fn register_capability(&self, id: CapId, target: Arc<dyn RpcTarget>) {
64 self.cap_table.insert(id, target);
65 }
66
67 pub async fn run(self) -> Result<(), std::io::Error> {
68 let mut app = Router::new()
69 .route("/rpc/batch", post(handle_batch))
70 .route("/health", get(handle_health));
71
72 // Add WebSocket support if the feature is enabled
73 #[cfg(feature = "all-transports")]
74 {
75 // Use the new wire protocol WebSocket handler
76 app = app.route("/rpc/ws", get(crate::ws_wire::websocket_wire_handler));
77 }
78
79 let app = app.with_state(Arc::new(self.clone()));
80
81 let addr = format!("{}:{}", self.config.host, self.config.port);
82 let listener = TcpListener::bind(&addr).await?;
83
84 println!("Server listening on {}", addr);
85 println!(" HTTP Batch endpoint: http://{}/rpc/batch", addr);
86
87 #[cfg(feature = "all-transports")]
88 println!(" WebSocket endpoint: ws://{}/rpc/ws", addr);
89
90 println!(" Health endpoint: http://{}/health", addr);
91
92 axum::serve(listener, app).await?;
93
94 Ok(())
95 }
96
97 // Remove legacy message processing - we only support wire protocol now
98}
99
100// Session state to track push/pull flow per HTTP batch request
101struct BatchSession {
102 next_import_id: i64,
103 // Map import IDs to their pushed expressions
104 pushed_expressions: HashMap<i64, WireExpression>,
105 // Map import IDs to their computed results
106 results: HashMap<i64, WireExpression>,
107}
108
109async fn handle_batch(
110 State(server): State<Arc<Server>>,
111 headers: HeaderMap,
112 body: Bytes,
113) -> impl IntoResponse {
114 // Enhanced tracing for debugging
115 tracing::debug!("=== INCOMING CAP'N WEB WIRE PROTOCOL REQUEST ===");
116 tracing::debug!("Headers: {:?}", headers);
117 tracing::debug!("Body size: {} bytes", body.len());
118
119 // Convert body to string
120 let body_str = String::from_utf8_lossy(&body);
121 tracing::debug!(
122 "Raw body (first 500 chars): {}",
123 &body_str.chars().take(500).collect::<String>()
124 );
125
126 // Create session state for this batch request
127 let mut session = BatchSession {
128 next_import_id: 1, // Start from 1 per protocol spec
129 pushed_expressions: HashMap::new(),
130 results: HashMap::new(),
131 };
132
133 // Parse the official Cap'n Web wire protocol (newline-delimited arrays ONLY)
134 match parse_wire_batch(&body_str) {
135 Ok(wire_messages) => {
136 tracing::info!(
137 "✅ Successfully parsed {} wire messages",
138 wire_messages.len()
139 );
140 tracing::trace!("Messages: {:#?}", wire_messages);
141
142 let mut responses = Vec::new();
143
144 for (i, msg) in wire_messages.iter().enumerate() {
145 tracing::debug!(
146 "Processing wire message {}/{}: {:?}",
147 i + 1,
148 wire_messages.len(),
149 msg
150 );
151
152 match msg {
153 WireMessage::Push(expr) => {
154 tracing::trace!(" PUSH expression details: {:#?}", expr);
155
156 // Assign the next import ID to this push
157 let assigned_import_id = session.next_import_id;
158 session.next_import_id += 1;
159
160 tracing::info!(" PUSH assigned import ID: {}", assigned_import_id);
161
162 // Store the expression for later evaluation
163 session
164 .pushed_expressions
165 .insert(assigned_import_id, expr.clone());
166
167 // Evaluate the expression immediately and store result
168 match expr {
169 WireExpression::Pipeline {
170 import_id,
171 property_path,
172 args,
173 } => {
174 tracing::info!(
175 " Pipeline call: import_id={}, path={:?}",
176 import_id,
177 property_path
178 );
179 tracing::info!(" Pipeline args raw wire expression: {:#?}", args);
180
181 // Validate and map import_id to capability
182 // Official protocol: import_id 0 is the main capability/bootstrap interface
183 // All import_ids map directly to their corresponding capability IDs
184
185 // Check for negative import_id values
186 if *import_id < 0 {
187 tracing::error!(
188 "Invalid negative import_id: {}. Import IDs must be non-negative.",
189 import_id
190 );
191 session.results.insert(
192 assigned_import_id,
193 WireExpression::Error {
194 error_type: "bad_request".to_string(),
195 message: format!(
196 "Invalid import_id: {}. Import IDs must be non-negative",
197 import_id
198 ),
199 stack: None,
200 },
201 );
202 continue;
203 }
204
205 // Safe to convert to u64 now that we've validated it's non-negative
206 let cap_id = CapId::new(*import_id as u64);
207
208 tracing::debug!(
209 " Mapped import_id {} to capability {}",
210 import_id,
211 cap_id
212 );
213
214 if let Some(capability) = server.cap_table.lookup(&cap_id) {
215 if let Some(path) = property_path {
216 if let Some(PropertyKey::String(method)) = path.first() {
217 tracing::info!(
218 " Calling method '{}' on capability {}",
219 method,
220 cap_id
221 );
222
223 // Convert args from WireExpression to Value (with pipeline evaluation)
224 let json_args = if let Some(args_expr) = args {
225 wire_expr_to_values_with_evaluation(
226 args_expr,
227 &session.results,
228 )
229 } else {
230 vec![]
231 };
232
233 tracing::info!(
234 " Method args (converted): {:?}",
235 json_args
236 );
237
238 match capability.call(method, json_args).await {
239 Ok(result) => {
240 tracing::info!(
241 " ✅ Method '{}' succeeded",
242 method
243 );
244 tracing::trace!(" Result: {:?}", result);
245
246 // Store the result for this import ID
247 session.results.insert(
248 assigned_import_id,
249 value_to_wire_expr(result),
250 );
251 }
252 Err(err) => {
253 tracing::error!(
254 " ❌ Method '{}' failed: {:?}",
255 method,
256 err
257 );
258
259 // Store the error for this import ID
260 session.results.insert(
261 assigned_import_id,
262 WireExpression::Error {
263 error_type: err.code.to_string(),
264 message: err.message.clone(),
265 stack: None,
266 },
267 );
268 }
269 }
270 } else {
271 tracing::warn!(
272 " No method name in property path: {:?}",
273 path
274 );
275 session.results.insert(
276 assigned_import_id,
277 WireExpression::Error {
278 error_type: "bad_request".to_string(),
279 message: "No method specified".to_string(),
280 stack: None,
281 },
282 );
283 }
284 } else {
285 tracing::warn!(" No property path in pipeline expression");
286 session.results.insert(
287 assigned_import_id,
288 WireExpression::Error {
289 error_type: "bad_request".to_string(),
290 message: "No property path in pipeline".to_string(),
291 stack: None,
292 },
293 );
294 }
295 } else {
296 tracing::error!(
297 " Capability {} not found in cap_table",
298 cap_id
299 );
300 session.results.insert(
301 assigned_import_id,
302 WireExpression::Error {
303 error_type: "not_found".to_string(),
304 message: format!("Capability {} not found", import_id),
305 stack: None,
306 },
307 );
308 }
309 }
310 WireExpression::Call {
311 cap_id,
312 property_path,
313 args,
314 } => {
315 tracing::info!(
316 " Call: cap_id={}, path={:?}",
317 cap_id,
318 property_path
319 );
320 tracing::trace!(" Call args: {:#?}", args);
321
322 let cap_id = CapId::new(*cap_id as u64);
323
324 if let Some(capability) = server.cap_table.lookup(&cap_id) {
325 if let Some(PropertyKey::String(method)) = property_path.first()
326 {
327 tracing::info!(
328 " Calling method '{}' on capability {}",
329 method,
330 cap_id
331 );
332
333 // Convert args from WireExpression to Value (with pipeline evaluation)
334 let json_args = wire_expr_to_values_with_evaluation(
335 args,
336 &session.results,
337 );
338
339 tracing::trace!(
340 " Method args (converted): {:?}",
341 json_args
342 );
343
344 match capability.call(method, json_args).await {
345 Ok(result) => {
346 tracing::info!(
347 " ✅ Method '{}' succeeded",
348 method
349 );
350 tracing::trace!(" Result: {:?}", result);
351
352 // Store the result for this import ID
353 session.results.insert(
354 assigned_import_id,
355 value_to_wire_expr(result),
356 );
357 }
358 Err(err) => {
359 tracing::error!(
360 " ❌ Method '{}' failed: {:?}",
361 method,
362 err
363 );
364
365 // Store the error for this import ID
366 session.results.insert(
367 assigned_import_id,
368 WireExpression::Error {
369 error_type: err.code.to_string(),
370 message: err.message.clone(),
371 stack: None,
372 },
373 );
374 }
375 }
376 } else {
377 tracing::warn!(
378 " No method name in property path: {:?}",
379 property_path
380 );
381 session.results.insert(
382 assigned_import_id,
383 WireExpression::Error {
384 error_type: "bad_request".to_string(),
385 message: "No method specified".to_string(),
386 stack: None,
387 },
388 );
389 }
390 } else {
391 tracing::error!(
392 " Capability {} not found in cap_table",
393 cap_id
394 );
395 session.results.insert(
396 assigned_import_id,
397 WireExpression::Error {
398 error_type: "not_found".to_string(),
399 message: format!("Capability {} not found", cap_id),
400 stack: None,
401 },
402 );
403 }
404 }
405 _ => {
406 tracing::warn!(" Push expression is not a pipeline or call (unsupported): {:?}", expr);
407 session.results.insert(
408 assigned_import_id,
409 WireExpression::Error {
410 error_type: "not_implemented".to_string(),
411 message: "Only pipeline and call expressions are supported"
412 .to_string(),
413 stack: None,
414 },
415 );
416 }
417 }
418 }
419
420 WireMessage::Pull(import_id) => {
421 tracing::debug!(" PULL for import_id: {}", import_id);
422
423 // Look up the result for this import ID
424 if let Some(result) = session.results.get(import_id) {
425 tracing::info!(" Found result for import ID {}", import_id);
426
427 // Check if it's an error
428 if let WireExpression::Error { .. } = result {
429 // Use the import ID as the export ID (per protocol spec)
430 responses.push(WireMessage::Reject(
431 *import_id, // Use import ID as export ID
432 result.clone(),
433 ));
434 } else {
435 // Use the import ID as the export ID (per protocol spec)
436 responses.push(WireMessage::Resolve(
437 *import_id, // Use import ID as export ID
438 result.clone(),
439 ));
440 }
441 } else {
442 tracing::warn!(" No result found for import ID {}", import_id);
443 responses.push(WireMessage::Reject(
444 *import_id,
445 WireExpression::Error {
446 error_type: "not_found".to_string(),
447 message: format!("No result for import ID {}", import_id),
448 stack: None,
449 },
450 ));
451 }
452 }
453
454 WireMessage::Release(ids) => {
455 tracing::info!(" RELEASE capabilities: {:?}", ids);
456 // Handle capability disposal
457 for id in ids {
458 let cap_id = CapId::new(*id as u64);
459 server.cap_table.remove(&cap_id);
460 }
461 }
462
463 other => {
464 tracing::warn!(
465 " Unhandled message type (not yet implemented): {:?}",
466 other
467 );
468 }
469 }
470 }
471
472 // Serialize responses using official wire protocol (newline-delimited)
473 let response_body = serialize_wire_batch(&responses);
474 tracing::info!("📤 Sending {} response messages", responses.len());
475 tracing::debug!(
476 "Response body (first 500 chars): {}",
477 &response_body.chars().take(500).collect::<String>()
478 );
479
480 (
481 StatusCode::OK,
482 [("content-type", "text/plain")],
483 response_body,
484 )
485 }
486 Err(e) => {
487 tracing::error!("Failed to parse wire protocol: {}", e);
488 tracing::debug!(
489 "Invalid input was: {}",
490 &body_str.chars().take(1000).collect::<String>()
491 );
492 let error_response = WireMessage::Reject(
493 -1,
494 WireExpression::Error {
495 error_type: "bad_request".to_string(),
496 message: format!("Invalid wire protocol: {}", e),
497 stack: None,
498 },
499 );
500 let response = serialize_wire_batch(&[error_response]);
501 (
502 StatusCode::BAD_REQUEST,
503 [("content-type", "text/plain")],
504 response,
505 )
506 }
507 }
508}
509
510async fn handle_health(State(server): State<Arc<Server>>) -> impl IntoResponse {
511 let capability_count = server.cap_table.len();
512
513 let mut endpoints = serde_json::json!({
514 "batch": "/rpc/batch",
515 "health": "/health"
516 });
517
518 // Add WebSocket endpoint if available
519 #[cfg(feature = "all-transports")]
520 {
521 endpoints["websocket"] = serde_json::json!("/rpc/ws");
522 }
523
524 let health_response = serde_json::json!({
525 "status": "healthy",
526 "server": "capnweb-rust",
527 "version": env!("CARGO_PKG_VERSION"),
528 "capabilities": capability_count,
529 "max_batch_size": server.config.max_batch_size,
530 "features": {
531 "websocket": cfg!(feature = "all-transports"),
532 "h3": cfg!(feature = "h3-server")
533 },
534 "endpoints": endpoints
535 });
536
537 (StatusCode::OK, Json(health_response))
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal capability for the tests: supports `echo` and numeric `add`.
    struct TestTarget;

    #[async_trait]
    impl RpcTarget for TestTarget {
        async fn call(&self, member: &str, args: Vec<Value>) -> Result<Value, RpcError> {
            if member == "echo" {
                // Echo the first argument back, or null when called without arguments.
                return Ok(args.into_iter().next().unwrap_or(Value::Null));
            }
            if member != "add" {
                return Err(RpcError::not_found(format!(
                    "Method '{}' not found",
                    member
                )));
            }

            let [first, second] = args.as_slice() else {
                return Err(RpcError::bad_request("add requires 2 arguments"));
            };
            let lhs = first
                .as_f64()
                .ok_or_else(|| RpcError::bad_request("First arg must be number"))?;
            let rhs = second
                .as_f64()
                .ok_or_else(|| RpcError::bad_request("Second arg must be number"))?;
            Ok(json!(lhs + rhs))
        }
    }

    #[tokio::test]
    async fn test_server_creation() {
        let server = Server::new(ServerConfig::default());
        assert_eq!(server.config.port, 8080);
    }

    #[tokio::test]
    async fn test_register_capability() {
        let server = Server::new(ServerConfig::default());
        let id = CapId::new(42);

        server.register_capability(id, Arc::new(TestTarget));

        assert!(server.cap_table.lookup(&id).is_some());
    }

    #[tokio::test]
    async fn test_wire_protocol_push() {
        // Invokes the registered capability directly, the same way the batch handler does for
        // a pipelined `echo` call with a single string argument.
        let server = Server::new(ServerConfig::default());
        let id = CapId::new(1);
        server.register_capability(id, Arc::new(TestTarget));

        let echoed = server
            .cap_table
            .lookup(&id)
            .expect("capability was just registered")
            .call("echo", vec![json!("hello")])
            .await
            .expect("echo never fails");

        assert_eq!(echoed, json!("hello"));
    }

    #[tokio::test]
    async fn test_wire_protocol_release() {
        // Removing a capability from the table makes it unreachable by ID.
        let server = Server::new(ServerConfig::default());
        let id = CapId::new(1);
        server.register_capability(id, Arc::new(TestTarget));
        assert!(server.cap_table.lookup(&id).is_some());

        server.cap_table.remove(&id);

        assert!(server.cap_table.lookup(&id).is_none());
    }

    #[tokio::test]
    async fn test_wire_protocol_unknown_capability() {
        // Nothing was registered, so any ID lookup must miss.
        let server = Server::new(ServerConfig::default());
        assert!(server.cap_table.lookup(&CapId::new(999)).is_none());
    }
}