capnweb_server/
wire_server.rs

1// Official Cap'n Web Wire Protocol Server
2// Implements the official Cap'n Web protocol using newline-delimited arrays
3
4use async_trait::async_trait;
5use axum::{
6    body::Bytes,
7    extract::State,
8    http::{HeaderMap, StatusCode},
9    response::IntoResponse,
10    routing::{get, post},
11    Json, Router,
12};
13use capnweb_core::{
14    parse_wire_batch, serialize_wire_batch, PropertyKey, RpcError, WireExpression, WireMessage,
15};
16use dashmap::DashMap;
17use serde_json::Value;
18use std::sync::Arc;
19use tokio::net::TcpListener;
20use tracing::{debug, error, info, warn};
21
22#[async_trait]
23pub trait WireCapability: Send + Sync {
24    async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, RpcError>;
25}
26
27#[derive(Clone)]
28pub struct WireServer {
29    config: WireServerConfig,
30    capabilities: Arc<DashMap<i64, Arc<dyn WireCapability>>>,
31    next_export_id: Arc<std::sync::atomic::AtomicI64>,
32}
33
34#[derive(Clone)]
35pub struct WireServerConfig {
36    pub port: u16,
37    pub host: String,
38    pub max_batch_size: usize,
39}
40
41impl Default for WireServerConfig {
42    fn default() -> Self {
43        WireServerConfig {
44            port: 8080,
45            host: "127.0.0.1".to_string(),
46            max_batch_size: 100,
47        }
48    }
49}
50
51impl WireServer {
52    pub fn new(config: WireServerConfig) -> Self {
53        WireServer {
54            config,
55            capabilities: Arc::new(DashMap::new()),
56            next_export_id: Arc::new(std::sync::atomic::AtomicI64::new(-1)),
57        }
58    }
59
60    pub fn register_capability(&self, id: i64, capability: Arc<dyn WireCapability>) {
61        info!("Registering capability with ID: {}", id);
62        self.capabilities.insert(id, capability);
63    }
64
65    pub async fn run(self) -> Result<(), std::io::Error> {
66        let addr = format!("{}:{}", self.config.host, self.config.port);
67
68        let app = Router::new()
69            .route("/rpc/batch", post(handle_wire_batch))
70            .route("/health", get(handle_health))
71            .with_state(Arc::new(self));
72        let listener = TcpListener::bind(&addr).await?;
73
74        info!("🚀 Cap'n Web server listening on {}", addr);
75        info!("  HTTP Batch endpoint: http://{}/rpc/batch", addr);
76        info!("  Health endpoint: http://{}/health", addr);
77
78        axum::serve(listener, app).await?;
79
80        Ok(())
81    }
82
83    async fn process_wire_message(&self, message: WireMessage) -> Vec<WireMessage> {
84        debug!("Processing wire message: {:?}", message);
85
86        match message {
87            WireMessage::Push(expr) => {
88                info!("Processing PUSH message");
89                self.handle_push_expression(expr).await
90            }
91
92            WireMessage::Pull(import_id) => {
93                info!("Processing PULL message for import ID: {}", import_id);
94                // For now, just return an error since we don't have promise resolution implemented
95                vec![WireMessage::Reject(
96                    -1, // Use a generic export ID
97                    WireExpression::Error {
98                        error_type: "not_implemented".to_string(),
99                        message: "Promise resolution not yet implemented".to_string(),
100                        stack: None,
101                    },
102                )]
103            }
104
105            WireMessage::Release(import_ids) => {
106                info!("Processing RELEASE message for IDs: {:?}", import_ids);
107                // Release capabilities
108                for id in import_ids {
109                    self.capabilities.remove(&id);
110                }
111                vec![] // No response for release
112            }
113
114            _ => {
115                warn!("Unhandled message type: {:?}", message);
116                vec![]
117            }
118        }
119    }
120
121    async fn handle_push_expression(&self, expr: WireExpression) -> Vec<WireMessage> {
122        match expr {
123            WireExpression::Pipeline {
124                import_id,
125                property_path,
126                args,
127            } => {
128                info!(
129                    "Handling pipeline expression: import_id={}, property_path={:?}",
130                    import_id, property_path
131                );
132
133                // For now, treat import_id as a capability ID to call
134                if let Some(capability) = self.capabilities.get(&import_id) {
135                    // Extract method name from property path
136                    if let Some(property_path) = property_path {
137                        if let Some(PropertyKey::String(method)) = property_path.first() {
138                            // Convert args from WireExpression to serde_json::Value
139                            let json_args = if let Some(args_expr) = args {
140                                self.wire_expression_to_json_args(*args_expr)
141                            } else {
142                                vec![]
143                            };
144
145                            // Call the capability
146                            match capability.call(method, json_args).await {
147                                Ok(result) => {
148                                    let export_id = self
149                                        .next_export_id
150                                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
151
152                                    vec![WireMessage::Resolve(
153                                        export_id,
154                                        self.json_to_wire_expression(result),
155                                    )]
156                                }
157                                Err(err) => {
158                                    let export_id = self
159                                        .next_export_id
160                                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
161
162                                    vec![WireMessage::Reject(
163                                        export_id,
164                                        WireExpression::Error {
165                                            error_type: err.code.to_string(),
166                                            message: err.message.to_string(),
167                                            stack: None,
168                                        },
169                                    )]
170                                }
171                            }
172                        } else {
173                            let export_id = self
174                                .next_export_id
175                                .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
176
177                            vec![WireMessage::Reject(
178                                export_id,
179                                WireExpression::Error {
180                                    error_type: "bad_request".to_string(),
181                                    message: "Invalid property path".to_string(),
182                                    stack: None,
183                                },
184                            )]
185                        }
186                    } else {
187                        let export_id = self
188                            .next_export_id
189                            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
190
191                        vec![WireMessage::Reject(
192                            export_id,
193                            WireExpression::Error {
194                                error_type: "bad_request".to_string(),
195                                message: "Missing property path".to_string(),
196                                stack: None,
197                            },
198                        )]
199                    }
200                } else {
201                    let export_id = self
202                        .next_export_id
203                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
204
205                    vec![WireMessage::Reject(
206                        export_id,
207                        WireExpression::Error {
208                            error_type: "not_found".to_string(),
209                            message: format!("Capability {} not found", import_id),
210                            stack: None,
211                        },
212                    )]
213                }
214            }
215
216            other => {
217                warn!("Unhandled push expression: {:?}", other);
218                let export_id = self
219                    .next_export_id
220                    .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
221
222                vec![WireMessage::Reject(
223                    export_id,
224                    WireExpression::Error {
225                        error_type: "not_implemented".to_string(),
226                        message: "Expression type not implemented".to_string(),
227                        stack: None,
228                    },
229                )]
230            }
231        }
232    }
233
234    fn wire_expression_to_json_args(&self, expr: WireExpression) -> Vec<Value> {
235        match expr {
236            WireExpression::Array(items) => items
237                .into_iter()
238                .map(|item| self.wire_expression_to_json_value(item))
239                .collect(),
240            single => vec![self.wire_expression_to_json_value(single)],
241        }
242    }
243
244    #[allow(clippy::only_used_in_recursion)]
245    fn wire_expression_to_json_value(&self, expr: WireExpression) -> Value {
246        match expr {
247            WireExpression::Null => Value::Null,
248            WireExpression::Bool(b) => Value::Bool(b),
249            WireExpression::Number(n) => Value::Number(n),
250            WireExpression::String(s) => Value::String(s),
251            WireExpression::Array(items) => Value::Array(
252                items
253                    .into_iter()
254                    .map(|item| self.wire_expression_to_json_value(item))
255                    .collect(),
256            ),
257            WireExpression::Object(map) => Value::Object(
258                map.into_iter()
259                    .map(|(k, v)| (k, self.wire_expression_to_json_value(v)))
260                    .collect(),
261            ),
262            _ => Value::String(format!("Unsupported expression: {:?}", expr)),
263        }
264    }
265
266    #[allow(clippy::only_used_in_recursion)]
267    fn json_to_wire_expression(&self, value: Value) -> WireExpression {
268        match value {
269            Value::Null => WireExpression::Null,
270            Value::Bool(b) => WireExpression::Bool(b),
271            Value::Number(n) => WireExpression::Number(n),
272            Value::String(s) => WireExpression::String(s),
273            Value::Array(items) => WireExpression::Array(
274                items
275                    .into_iter()
276                    .map(|item| self.json_to_wire_expression(item))
277                    .collect(),
278            ),
279            Value::Object(map) => WireExpression::Object(
280                map.into_iter()
281                    .map(|(k, v)| (k, self.json_to_wire_expression(v)))
282                    .collect(),
283            ),
284        }
285    }
286}
287
288async fn handle_wire_batch(
289    State(server): State<Arc<WireServer>>,
290    headers: HeaderMap,
291    body: Bytes,
292) -> impl IntoResponse {
293    info!("=== WIRE PROTOCOL REQUEST ===");
294    info!("Headers: {:?}", headers);
295    info!("Body size: {} bytes", body.len());
296
297    let body_str = String::from_utf8_lossy(&body);
298    info!("Raw body: {}", body_str);
299
300    // Parse wire protocol messages
301    let wire_messages = match parse_wire_batch(&body_str) {
302        Ok(messages) => {
303            info!("Successfully parsed {} wire messages", messages.len());
304            for (i, msg) in messages.iter().enumerate() {
305                debug!("Message {}: {:?}", i, msg);
306            }
307            messages
308        }
309        Err(e) => {
310            error!("Failed to parse wire protocol: {}", e);
311            let error_response = WireMessage::Reject(
312                -1,
313                WireExpression::Error {
314                    error_type: "bad_request".to_string(),
315                    message: format!("Invalid wire protocol: {}", e),
316                    stack: None,
317                },
318            );
319            let response = serialize_wire_batch(&[error_response]);
320            return (
321                StatusCode::BAD_REQUEST,
322                [("Content-Type", "text/plain")],
323                response,
324            );
325        }
326    };
327
328    // Check batch size
329    if wire_messages.len() > server.config.max_batch_size {
330        let error_response = WireMessage::Reject(
331            -1,
332            WireExpression::Error {
333                error_type: "bad_request".to_string(),
334                message: format!(
335                    "Batch size {} exceeds maximum {}",
336                    wire_messages.len(),
337                    server.config.max_batch_size
338                ),
339                stack: None,
340            },
341        );
342        let response = serialize_wire_batch(&[error_response]);
343        return (
344            StatusCode::BAD_REQUEST,
345            [("Content-Type", "text/plain")],
346            response,
347        );
348    }
349
350    // Process each wire message
351    let mut response_messages = Vec::new();
352    for message in wire_messages {
353        let responses = server.process_wire_message(message).await;
354        response_messages.extend(responses);
355    }
356
357    // Serialize response
358    let response_body = serialize_wire_batch(&response_messages);
359    info!("Response: {}", response_body);
360
361    (
362        StatusCode::OK,
363        [("Content-Type", "text/plain")],
364        response_body,
365    )
366}
367
368async fn handle_health() -> impl IntoResponse {
369    Json(serde_json::json!({
370        "status": "healthy",
371        "server": "capnweb-rust",
372        "version": "0.1.0",
373        "protocol": "cap'n web wire protocol",
374        "endpoints": {
375            "batch": "/rpc/batch",
376            "health": "/health"
377        }
378    }))
379}
380
381// Adapter for existing RpcTarget trait
382pub struct RpcTargetAdapter<T: crate::RpcTarget> {
383    inner: T,
384}
385
386impl<T: crate::RpcTarget> RpcTargetAdapter<T> {
387    pub fn new(inner: T) -> Self {
388        RpcTargetAdapter { inner }
389    }
390}
391
392#[async_trait]
393impl<T: crate::RpcTarget> WireCapability for RpcTargetAdapter<T> {
394    async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, RpcError> {
395        self.inner.call(method, args).await
396    }
397}