capnweb_server/ws_wire.rs
1use crate::server_wire_handler::{value_to_wire_expr, wire_expr_to_values};
2use crate::Server;
3use axum::{
4 extract::{
5 ws::{Message as WsMessage, WebSocket, WebSocketUpgrade},
6 State,
7 },
8 response::Response,
9};
10use capnweb_core::{
11 parse_wire_batch, serialize_wire_batch, CapId, PropertyKey, WireExpression, WireMessage,
12};
13use futures::{SinkExt, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
18/// WebSocket session state that persists across messages
19struct WsSession {
20 #[allow(dead_code)]
21 session_id: String,
22 next_import_id: i64,
23 #[allow(dead_code)]
24 next_export_id: i64,
25 // Map import IDs to their expressions
26 imports: HashMap<i64, WireExpression>,
27 // Map export IDs to their values
28 #[allow(dead_code)]
29 exports: HashMap<i64, WireExpression>,
30}
31
32impl WsSession {
33 fn new(session_id: String) -> Self {
34 Self {
35 session_id,
36 next_import_id: 1, // Client imports start at 1
37 next_export_id: -1, // Server exports start at -1
38 imports: HashMap::new(),
39 exports: HashMap::new(),
40 }
41 }
42}
43
44/// WebSocket handler for Cap'n Web wire protocol
45pub async fn websocket_wire_handler(
46 ws: WebSocketUpgrade,
47 State(server): State<Arc<Server>>,
48) -> Response {
49 ws.on_upgrade(move |socket| handle_wire_socket(socket, server))
50}
51
52async fn handle_wire_socket(socket: WebSocket, server: Arc<Server>) {
53 let session_id = uuid::Uuid::new_v4().to_string();
54 tracing::info!(
55 "WebSocket wire protocol connection established: {}",
56 session_id
57 );
58
59 let session = Arc::new(RwLock::new(WsSession::new(session_id.clone())));
60 let (mut sender, mut receiver) = socket.split();
61
62 // Handle incoming messages
63 while let Some(result) = receiver.next().await {
64 match result {
65 Ok(msg) => {
66 match msg {
67 WsMessage::Text(text) => {
68 tracing::debug!("WS received: {}", text);
69
70 // Parse wire protocol messages
71 match parse_wire_batch(&text) {
72 Ok(messages) => {
73 let mut responses = Vec::new();
74 let mut session_guard = session.write().await;
75
76 for msg in messages {
77 tracing::debug!("Processing WS message: {:?}", msg);
78
79 match msg {
80 WireMessage::Push(expr) => {
81 // Assign import ID
82 let import_id = session_guard.next_import_id;
83 session_guard.next_import_id += 1;
84
85 tracing::info!(
86 "WS Push assigned import ID: {}",
87 import_id
88 );
89 session_guard.imports.insert(import_id, expr.clone());
90
91 // Process pipeline expression
92 if let WireExpression::Pipeline {
93 import_id: target_id,
94 property_path,
95 args,
96 } = expr
97 {
98 let cap_id = if target_id == 0 {
99 CapId::new(1) // Main capability
100 } else {
101 CapId::new(target_id as u64)
102 };
103
104 if let Some(capability) =
105 server.cap_table().lookup(&cap_id)
106 {
107 if let Some(path) = property_path {
108 if let Some(PropertyKey::String(method)) =
109 path.first()
110 {
111 let json_args = args
112 .as_ref()
113 .map(|a| wire_expr_to_values(a))
114 .unwrap_or_else(Vec::new);
115
116 match capability
117 .call(method, json_args)
118 .await
119 {
120 Ok(result) => {
121 let result_expr =
122 value_to_wire_expr(result);
123 session_guard.imports.insert(
124 import_id,
125 result_expr,
126 );
127 }
128 Err(err) => {
129 session_guard.imports.insert(
130 import_id,
131 WireExpression::Error {
132 error_type: err
133 .code
134 .to_string(),
135 message: err.message,
136 stack: None,
137 },
138 );
139 }
140 }
141 }
142 }
143 } else {
144 session_guard.imports.insert(
145 import_id,
146 WireExpression::Error {
147 error_type: "not_found".to_string(),
148 message: format!(
149 "Capability {} not found",
150 target_id
151 ),
152 stack: None,
153 },
154 );
155 }
156 }
157 }
158
159 WireMessage::Pull(import_id) => {
160 tracing::debug!("WS Pull for import_id: {}", import_id);
161
162 if let Some(result) =
163 session_guard.imports.get(&import_id)
164 {
165 if let WireExpression::Error { .. } = result {
166 responses.push(WireMessage::Reject(
167 import_id,
168 result.clone(),
169 ));
170 } else {
171 responses.push(WireMessage::Resolve(
172 import_id,
173 result.clone(),
174 ));
175 }
176 } else {
177 responses.push(WireMessage::Reject(
178 import_id,
179 WireExpression::Error {
180 error_type: "not_found".to_string(),
181 message: format!(
182 "No result for import ID {}",
183 import_id
184 ),
185 stack: None,
186 },
187 ));
188 }
189 }
190
191 WireMessage::Release(ids) => {
192 tracing::info!("WS Release for IDs: {:?}", ids);
193 // Remove released imports
194 for id in ids {
195 session_guard.imports.remove(&id);
196 }
197 }
198
199 _ => {
200 tracing::warn!("WS unhandled message type: {:?}", msg);
201 }
202 }
203 }
204
205 // Send responses
206 if !responses.is_empty() {
207 let response_text = serialize_wire_batch(&responses);
208 tracing::debug!("WS sending: {}", response_text);
209
210 if let Err(e) =
211 sender.send(WsMessage::Text(response_text.into())).await
212 {
213 tracing::error!("Failed to send WS response: {}", e);
214 break;
215 }
216 }
217 }
218 Err(e) => {
219 tracing::error!("Failed to parse WS wire protocol: {}", e);
220 let error_response = WireMessage::Reject(
221 -1,
222 WireExpression::Error {
223 error_type: "bad_request".to_string(),
224 message: format!("Invalid wire protocol: {}", e),
225 stack: None,
226 },
227 );
228 let response_text = serialize_wire_batch(&[error_response]);
229 if let Err(e) =
230 sender.send(WsMessage::Text(response_text.into())).await
231 {
232 tracing::error!("Failed to send error response: {}", e);
233 break;
234 }
235 }
236 }
237 }
238 WsMessage::Binary(data) => {
239 tracing::warn!("Received binary WS message, trying as UTF-8");
240 if let Ok(_text) = String::from_utf8(data.to_vec()) {
241 // Process as text
242 continue;
243 }
244 }
245 WsMessage::Close(frame) => {
246 tracing::info!("WebSocket closing: {} (reason: {:?})", session_id, frame);
247 break;
248 }
249 _ => {}
250 }
251 }
252 Err(e) => {
253 tracing::error!("WebSocket error: {}", e);
254 break;
255 }
256 }
257 }
258
259 tracing::info!("WebSocket disconnected: {}", session_id);
260}