1use async_trait::async_trait;
5use axum::{
6 body::Bytes,
7 extract::State,
8 http::{HeaderMap, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use capnweb_core::{
14 parse_wire_batch, serialize_wire_batch, PropertyKey, RpcError, WireExpression, WireMessage,
15};
16use dashmap::DashMap;
17use serde_json::Value;
18use std::sync::Arc;
19use tokio::net::TcpListener;
20use tracing::{debug, error, info, warn};
21
22#[async_trait]
23pub trait WireCapability: Send + Sync {
24 async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, RpcError>;
25}
26
27#[derive(Clone)]
28pub struct WireServer {
29 config: WireServerConfig,
30 capabilities: Arc<DashMap<i64, Arc<dyn WireCapability>>>,
31 next_export_id: Arc<std::sync::atomic::AtomicI64>,
32}
33
/// Network and request-size settings for `WireServer`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireServerConfig {
    /// TCP port to listen on.
    pub port: u16,
    /// Host to bind to (for example `"127.0.0.1"`).
    pub host: String,
    /// Maximum number of wire messages accepted in a single batch request.
    pub max_batch_size: usize,
}

impl Default for WireServerConfig {
    /// Listens on `127.0.0.1:8080` and accepts up to 100 messages per batch.
    fn default() -> Self {
        Self {
            port: 8080,
            host: "127.0.0.1".to_owned(),
            max_batch_size: 100,
        }
    }
}
50
51impl WireServer {
52 pub fn new(config: WireServerConfig) -> Self {
53 WireServer {
54 config,
55 capabilities: Arc::new(DashMap::new()),
56 next_export_id: Arc::new(std::sync::atomic::AtomicI64::new(-1)),
57 }
58 }
59
60 pub fn register_capability(&self, id: i64, capability: Arc<dyn WireCapability>) {
61 info!("Registering capability with ID: {}", id);
62 self.capabilities.insert(id, capability);
63 }
64
65 pub async fn run(self) -> Result<(), std::io::Error> {
66 let addr = format!("{}:{}", self.config.host, self.config.port);
67
68 let app = Router::new()
69 .route("/rpc/batch", post(handle_wire_batch))
70 .route("/health", get(handle_health))
71 .with_state(Arc::new(self));
72 let listener = TcpListener::bind(&addr).await?;
73
74 info!("🚀 Cap'n Web server listening on {}", addr);
75 info!(" HTTP Batch endpoint: http://{}/rpc/batch", addr);
76 info!(" Health endpoint: http://{}/health", addr);
77
78 axum::serve(listener, app).await?;
79
80 Ok(())
81 }
82
83 async fn process_wire_message(&self, message: WireMessage) -> Vec<WireMessage> {
84 debug!("Processing wire message: {:?}", message);
85
86 match message {
87 WireMessage::Push(expr) => {
88 info!("Processing PUSH message");
89 self.handle_push_expression(expr).await
90 }
91
92 WireMessage::Pull(import_id) => {
93 info!("Processing PULL message for import ID: {}", import_id);
94 vec![WireMessage::Reject(
96 -1, WireExpression::Error {
98 error_type: "not_implemented".to_string(),
99 message: "Promise resolution not yet implemented".to_string(),
100 stack: None,
101 },
102 )]
103 }
104
105 WireMessage::Release(import_ids) => {
106 info!("Processing RELEASE message for IDs: {:?}", import_ids);
107 for id in import_ids {
109 self.capabilities.remove(&id);
110 }
111 vec![] }
113
114 _ => {
115 warn!("Unhandled message type: {:?}", message);
116 vec![]
117 }
118 }
119 }
120
121 async fn handle_push_expression(&self, expr: WireExpression) -> Vec<WireMessage> {
122 match expr {
123 WireExpression::Pipeline {
124 import_id,
125 property_path,
126 args,
127 } => {
128 info!(
129 "Handling pipeline expression: import_id={}, property_path={:?}",
130 import_id, property_path
131 );
132
133 if let Some(capability) = self.capabilities.get(&import_id) {
135 if let Some(property_path) = property_path {
137 if let Some(PropertyKey::String(method)) = property_path.first() {
138 let json_args = if let Some(args_expr) = args {
140 self.wire_expression_to_json_args(*args_expr)
141 } else {
142 vec![]
143 };
144
145 match capability.call(method, json_args).await {
147 Ok(result) => {
148 let export_id = self
149 .next_export_id
150 .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
151
152 vec![WireMessage::Resolve(
153 export_id,
154 self.json_to_wire_expression(result),
155 )]
156 }
157 Err(err) => {
158 let export_id = self
159 .next_export_id
160 .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
161
162 vec![WireMessage::Reject(
163 export_id,
164 WireExpression::Error {
165 error_type: err.code.to_string(),
166 message: err.message.to_string(),
167 stack: None,
168 },
169 )]
170 }
171 }
172 } else {
173 let export_id = self
174 .next_export_id
175 .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
176
177 vec![WireMessage::Reject(
178 export_id,
179 WireExpression::Error {
180 error_type: "bad_request".to_string(),
181 message: "Invalid property path".to_string(),
182 stack: None,
183 },
184 )]
185 }
186 } else {
187 let export_id = self
188 .next_export_id
189 .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
190
191 vec![WireMessage::Reject(
192 export_id,
193 WireExpression::Error {
194 error_type: "bad_request".to_string(),
195 message: "Missing property path".to_string(),
196 stack: None,
197 },
198 )]
199 }
200 } else {
201 let export_id = self
202 .next_export_id
203 .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
204
205 vec![WireMessage::Reject(
206 export_id,
207 WireExpression::Error {
208 error_type: "not_found".to_string(),
209 message: format!("Capability {} not found", import_id),
210 stack: None,
211 },
212 )]
213 }
214 }
215
216 other => {
217 warn!("Unhandled push expression: {:?}", other);
218 let export_id = self
219 .next_export_id
220 .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
221
222 vec![WireMessage::Reject(
223 export_id,
224 WireExpression::Error {
225 error_type: "not_implemented".to_string(),
226 message: "Expression type not implemented".to_string(),
227 stack: None,
228 },
229 )]
230 }
231 }
232 }
233
234 fn wire_expression_to_json_args(&self, expr: WireExpression) -> Vec<Value> {
235 match expr {
236 WireExpression::Array(items) => items
237 .into_iter()
238 .map(|item| self.wire_expression_to_json_value(item))
239 .collect(),
240 single => vec![self.wire_expression_to_json_value(single)],
241 }
242 }
243
244 #[allow(clippy::only_used_in_recursion)]
245 fn wire_expression_to_json_value(&self, expr: WireExpression) -> Value {
246 match expr {
247 WireExpression::Null => Value::Null,
248 WireExpression::Bool(b) => Value::Bool(b),
249 WireExpression::Number(n) => Value::Number(n),
250 WireExpression::String(s) => Value::String(s),
251 WireExpression::Array(items) => Value::Array(
252 items
253 .into_iter()
254 .map(|item| self.wire_expression_to_json_value(item))
255 .collect(),
256 ),
257 WireExpression::Object(map) => Value::Object(
258 map.into_iter()
259 .map(|(k, v)| (k, self.wire_expression_to_json_value(v)))
260 .collect(),
261 ),
262 _ => Value::String(format!("Unsupported expression: {:?}", expr)),
263 }
264 }
265
266 #[allow(clippy::only_used_in_recursion)]
267 fn json_to_wire_expression(&self, value: Value) -> WireExpression {
268 match value {
269 Value::Null => WireExpression::Null,
270 Value::Bool(b) => WireExpression::Bool(b),
271 Value::Number(n) => WireExpression::Number(n),
272 Value::String(s) => WireExpression::String(s),
273 Value::Array(items) => WireExpression::Array(
274 items
275 .into_iter()
276 .map(|item| self.json_to_wire_expression(item))
277 .collect(),
278 ),
279 Value::Object(map) => WireExpression::Object(
280 map.into_iter()
281 .map(|(k, v)| (k, self.json_to_wire_expression(v)))
282 .collect(),
283 ),
284 }
285 }
286}
287
288async fn handle_wire_batch(
289 State(server): State<Arc<WireServer>>,
290 headers: HeaderMap,
291 body: Bytes,
292) -> impl IntoResponse {
293 info!("=== WIRE PROTOCOL REQUEST ===");
294 info!("Headers: {:?}", headers);
295 info!("Body size: {} bytes", body.len());
296
297 let body_str = String::from_utf8_lossy(&body);
298 info!("Raw body: {}", body_str);
299
300 let wire_messages = match parse_wire_batch(&body_str) {
302 Ok(messages) => {
303 info!("Successfully parsed {} wire messages", messages.len());
304 for (i, msg) in messages.iter().enumerate() {
305 debug!("Message {}: {:?}", i, msg);
306 }
307 messages
308 }
309 Err(e) => {
310 error!("Failed to parse wire protocol: {}", e);
311 let error_response = WireMessage::Reject(
312 -1,
313 WireExpression::Error {
314 error_type: "bad_request".to_string(),
315 message: format!("Invalid wire protocol: {}", e),
316 stack: None,
317 },
318 );
319 let response = serialize_wire_batch(&[error_response]);
320 return (
321 StatusCode::BAD_REQUEST,
322 [("Content-Type", "text/plain")],
323 response,
324 );
325 }
326 };
327
328 if wire_messages.len() > server.config.max_batch_size {
330 let error_response = WireMessage::Reject(
331 -1,
332 WireExpression::Error {
333 error_type: "bad_request".to_string(),
334 message: format!(
335 "Batch size {} exceeds maximum {}",
336 wire_messages.len(),
337 server.config.max_batch_size
338 ),
339 stack: None,
340 },
341 );
342 let response = serialize_wire_batch(&[error_response]);
343 return (
344 StatusCode::BAD_REQUEST,
345 [("Content-Type", "text/plain")],
346 response,
347 );
348 }
349
350 let mut response_messages = Vec::new();
352 for message in wire_messages {
353 let responses = server.process_wire_message(message).await;
354 response_messages.extend(responses);
355 }
356
357 let response_body = serialize_wire_batch(&response_messages);
359 info!("Response: {}", response_body);
360
361 (
362 StatusCode::OK,
363 [("Content-Type", "text/plain")],
364 response_body,
365 )
366}
367
368async fn handle_health() -> impl IntoResponse {
369 Json(serde_json::json!({
370 "status": "healthy",
371 "server": "capnweb-rust",
372 "version": "0.1.0",
373 "protocol": "cap'n web wire protocol",
374 "endpoints": {
375 "batch": "/rpc/batch",
376 "health": "/health"
377 }
378 }))
379}
380
/// Wraps a `crate::RpcTarget` so it can be registered with `WireServer` as a
/// `WireCapability`.
///
/// The `RpcTarget` bound is placed on the `impl` blocks that need it, not on the
/// struct. That way the type can be named without repeating the bound everywhere.
#[derive(Debug)]
pub struct RpcTargetAdapter<T> {
    /// The target that receives every forwarded call.
    inner: T,
}
385
386impl<T: crate::RpcTarget> RpcTargetAdapter<T> {
387 pub fn new(inner: T) -> Self {
388 RpcTargetAdapter { inner }
389 }
390}
391
392#[async_trait]
393impl<T: crate::RpcTarget> WireCapability for RpcTargetAdapter<T> {
394 async fn call(&self, method: &str, args: Vec<Value>) -> Result<Value, RpcError> {
395 self.inner.call(method, args).await
396 }
397}