capnweb_server/
ws_wire.rs

1use crate::server_wire_handler::{value_to_wire_expr, wire_expr_to_values};
2use crate::Server;
3use axum::{
4    extract::{
5        ws::{Message as WsMessage, WebSocket, WebSocketUpgrade},
6        State,
7    },
8    response::Response,
9};
10use capnweb_core::{
11    parse_wire_batch, serialize_wire_batch, CapId, PropertyKey, WireExpression, WireMessage,
12};
13use futures::{SinkExt, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
18/// WebSocket session state that persists across messages
19struct WsSession {
20    #[allow(dead_code)]
21    session_id: String,
22    next_import_id: i64,
23    #[allow(dead_code)]
24    next_export_id: i64,
25    // Map import IDs to their expressions
26    imports: HashMap<i64, WireExpression>,
27    // Map export IDs to their values
28    #[allow(dead_code)]
29    exports: HashMap<i64, WireExpression>,
30}
31
32impl WsSession {
33    fn new(session_id: String) -> Self {
34        Self {
35            session_id,
36            next_import_id: 1,  // Client imports start at 1
37            next_export_id: -1, // Server exports start at -1
38            imports: HashMap::new(),
39            exports: HashMap::new(),
40        }
41    }
42}
43
44/// WebSocket handler for Cap'n Web wire protocol
45pub async fn websocket_wire_handler(
46    ws: WebSocketUpgrade,
47    State(server): State<Arc<Server>>,
48) -> Response {
49    ws.on_upgrade(move |socket| handle_wire_socket(socket, server))
50}
51
52async fn handle_wire_socket(socket: WebSocket, server: Arc<Server>) {
53    let session_id = uuid::Uuid::new_v4().to_string();
54    tracing::info!(
55        "WebSocket wire protocol connection established: {}",
56        session_id
57    );
58
59    let session = Arc::new(RwLock::new(WsSession::new(session_id.clone())));
60    let (mut sender, mut receiver) = socket.split();
61
62    // Handle incoming messages
63    while let Some(result) = receiver.next().await {
64        match result {
65            Ok(msg) => {
66                match msg {
67                    WsMessage::Text(text) => {
68                        tracing::debug!("WS received: {}", text);
69
70                        // Parse wire protocol messages
71                        match parse_wire_batch(&text) {
72                            Ok(messages) => {
73                                let mut responses = Vec::new();
74                                let mut session_guard = session.write().await;
75
76                                for msg in messages {
77                                    tracing::debug!("Processing WS message: {:?}", msg);
78
79                                    match msg {
80                                        WireMessage::Push(expr) => {
81                                            // Assign import ID
82                                            let import_id = session_guard.next_import_id;
83                                            session_guard.next_import_id += 1;
84
85                                            tracing::info!(
86                                                "WS Push assigned import ID: {}",
87                                                import_id
88                                            );
89                                            session_guard.imports.insert(import_id, expr.clone());
90
91                                            // Process pipeline expression
92                                            if let WireExpression::Pipeline {
93                                                import_id: target_id,
94                                                property_path,
95                                                args,
96                                            } = expr
97                                            {
98                                                let cap_id = if target_id == 0 {
99                                                    CapId::new(1) // Main capability
100                                                } else {
101                                                    CapId::new(target_id as u64)
102                                                };
103
104                                                if let Some(capability) =
105                                                    server.cap_table().lookup(&cap_id)
106                                                {
107                                                    if let Some(path) = property_path {
108                                                        if let Some(PropertyKey::String(method)) =
109                                                            path.first()
110                                                        {
111                                                            let json_args = args
112                                                                .as_ref()
113                                                                .map(|a| wire_expr_to_values(a))
114                                                                .unwrap_or_else(Vec::new);
115
116                                                            match capability
117                                                                .call(method, json_args)
118                                                                .await
119                                                            {
120                                                                Ok(result) => {
121                                                                    let result_expr =
122                                                                        value_to_wire_expr(result);
123                                                                    session_guard.imports.insert(
124                                                                        import_id,
125                                                                        result_expr,
126                                                                    );
127                                                                }
128                                                                Err(err) => {
129                                                                    session_guard.imports.insert(
130                                                                        import_id,
131                                                                        WireExpression::Error {
132                                                                            error_type: err
133                                                                                .code
134                                                                                .to_string(),
135                                                                            message: err.message,
136                                                                            stack: None,
137                                                                        },
138                                                                    );
139                                                                }
140                                                            }
141                                                        }
142                                                    }
143                                                } else {
144                                                    session_guard.imports.insert(
145                                                        import_id,
146                                                        WireExpression::Error {
147                                                            error_type: "not_found".to_string(),
148                                                            message: format!(
149                                                                "Capability {} not found",
150                                                                target_id
151                                                            ),
152                                                            stack: None,
153                                                        },
154                                                    );
155                                                }
156                                            }
157                                        }
158
159                                        WireMessage::Pull(import_id) => {
160                                            tracing::debug!("WS Pull for import_id: {}", import_id);
161
162                                            if let Some(result) =
163                                                session_guard.imports.get(&import_id)
164                                            {
165                                                if let WireExpression::Error { .. } = result {
166                                                    responses.push(WireMessage::Reject(
167                                                        import_id,
168                                                        result.clone(),
169                                                    ));
170                                                } else {
171                                                    responses.push(WireMessage::Resolve(
172                                                        import_id,
173                                                        result.clone(),
174                                                    ));
175                                                }
176                                            } else {
177                                                responses.push(WireMessage::Reject(
178                                                    import_id,
179                                                    WireExpression::Error {
180                                                        error_type: "not_found".to_string(),
181                                                        message: format!(
182                                                            "No result for import ID {}",
183                                                            import_id
184                                                        ),
185                                                        stack: None,
186                                                    },
187                                                ));
188                                            }
189                                        }
190
191                                        WireMessage::Release(ids) => {
192                                            tracing::info!("WS Release for IDs: {:?}", ids);
193                                            // Remove released imports
194                                            for id in ids {
195                                                session_guard.imports.remove(&id);
196                                            }
197                                        }
198
199                                        _ => {
200                                            tracing::warn!("WS unhandled message type: {:?}", msg);
201                                        }
202                                    }
203                                }
204
205                                // Send responses
206                                if !responses.is_empty() {
207                                    let response_text = serialize_wire_batch(&responses);
208                                    tracing::debug!("WS sending: {}", response_text);
209
210                                    if let Err(e) =
211                                        sender.send(WsMessage::Text(response_text.into())).await
212                                    {
213                                        tracing::error!("Failed to send WS response: {}", e);
214                                        break;
215                                    }
216                                }
217                            }
218                            Err(e) => {
219                                tracing::error!("Failed to parse WS wire protocol: {}", e);
220                                let error_response = WireMessage::Reject(
221                                    -1,
222                                    WireExpression::Error {
223                                        error_type: "bad_request".to_string(),
224                                        message: format!("Invalid wire protocol: {}", e),
225                                        stack: None,
226                                    },
227                                );
228                                let response_text = serialize_wire_batch(&[error_response]);
229                                if let Err(e) =
230                                    sender.send(WsMessage::Text(response_text.into())).await
231                                {
232                                    tracing::error!("Failed to send error response: {}", e);
233                                    break;
234                                }
235                            }
236                        }
237                    }
238                    WsMessage::Binary(data) => {
239                        tracing::warn!("Received binary WS message, trying as UTF-8");
240                        if let Ok(_text) = String::from_utf8(data.to_vec()) {
241                            // Process as text
242                            continue;
243                        }
244                    }
245                    WsMessage::Close(frame) => {
246                        tracing::info!("WebSocket closing: {} (reason: {:?})", session_id, frame);
247                        break;
248                    }
249                    _ => {}
250                }
251            }
252            Err(e) => {
253                tracing::error!("WebSocket error: {}", e);
254                break;
255            }
256        }
257    }
258
259    tracing::info!("WebSocket disconnected: {}", session_id);
260}