reifydb_sub_server/protocols/ws/
handler.rs1use std::io::{Read, Write};
5
6use reifydb_core::interface::{Engine, Identity};
7
8use super::{
9 CommandRequest, CommandResponse, QueryRequest, QueryResponse, Request, Response, ResponsePayload,
10 WebSocketConnectionData, WsState,
11};
12use crate::{
13 core::Connection,
14 protocols::{
15 ProtocolError, ProtocolHandler, ProtocolResult,
16 convert::{convert_frames, convert_params},
17 utils::{build_ws_frame, build_ws_response, find_header_end, parse_ws_frame},
18 },
19};
20
/// Protocol handler for WebSocket (RFC 6455) connections.
///
/// The handler is stateless; per-connection state lives in the connection's
/// `ConnectionState::WebSocket` value.
#[derive(Clone, Debug, Default)]
pub struct WebSocketHandler;

impl WebSocketHandler {
    /// Creates a new handler (equivalent to `WebSocketHandler::default()`).
    pub fn new() -> Self {
        Self
    }
}
29
30impl ProtocolHandler for WebSocketHandler {
31 fn name(&self) -> &'static str {
32 "ws"
33 }
34
35 fn can_handle(&self, buffer: &[u8]) -> bool {
36 if buffer.len() < 16 {
38 return false;
39 }
40
41 let request = String::from_utf8_lossy(buffer);
42
43 let request_lower = request.to_lowercase();
44
45 request_lower.contains("get ")
46 && request_lower.contains("http/1.1")
47 && request_lower.contains("upgrade: websocket")
48 && (request_lower.contains("connection: upgrade")
49 || request_lower.contains("connection: keep-alive, upgrade"))
50 }
51
52 fn handle_connection(&self, conn: &mut Connection) -> ProtocolResult<()> {
53 let ws_state = WsState::Handshake(WebSocketConnectionData::new());
55 conn.set_state(crate::core::ConnectionState::WebSocket(ws_state));
56 Ok(())
57 }
58
59 fn handle_read(&self, conn: &mut Connection) -> ProtocolResult<()> {
60 if let crate::core::ConnectionState::WebSocket(ws_state) = conn.state() {
61 match ws_state {
62 WsState::Handshake(_) => self.handle_handshake_read(conn),
63 WsState::Active(_) => self.handle_ws_read(conn),
64 WsState::Closed => Ok(()),
65 }
66 } else {
67 Err(ProtocolError::InvalidFrame)
68 }
69 }
70
71 fn handle_write(&self, conn: &mut Connection) -> ProtocolResult<()> {
72 if let crate::core::ConnectionState::WebSocket(ws_state) = conn.state() {
73 match ws_state {
74 WsState::Handshake(_) => self.handle_handshake_write(conn),
75 WsState::Active(_) => self.handle_ws_write(conn),
76 WsState::Closed => Ok(()),
77 }
78 } else {
79 Err(ProtocolError::InvalidFrame)
80 }
81 }
82
83 fn should_close(&self, conn: &Connection) -> bool {
84 matches!(
85 conn.state(),
86 crate::core::ConnectionState::WebSocket(WsState::Closed) | crate::core::ConnectionState::Closed
87 )
88 }
89}
90
91impl WebSocketHandler {
92 fn handle_handshake_read(&self, conn: &mut Connection) -> ProtocolResult<()> {
93 if !conn.buffer().is_empty() {
96 if let Some(hlen) = find_header_end(conn.buffer()) {
97 let (resp, _key) = build_ws_response(&conn.buffer()[..hlen])
98 .map_err(|e| ProtocolError::Custom(format!("Handshake error: {}", e)))?;
99
100 if let crate::core::ConnectionState::WebSocket(WsState::Handshake(data)) =
102 conn.state_mut()
103 {
104 data.handshake_response = Some(resp);
105 }
106
107 conn.buffer_mut().drain(0..hlen);
109 return Ok(());
110 }
111 }
112
113 let mut buf = [0u8; 2048];
115 loop {
116 match conn.stream().read(&mut buf) {
117 Ok(0) => return Err(ProtocolError::ConnectionClosed),
118 Ok(n) => {
119 conn.buffer_mut().extend_from_slice(&buf[..n]);
120 if let Some(hlen) = find_header_end(conn.buffer()) {
121 let (resp, _key) =
122 build_ws_response(&conn.buffer()[..hlen]).map_err(|e| {
123 ProtocolError::Custom(format!("Handshake error: {}", e))
124 })?;
125
126 if let crate::core::ConnectionState::WebSocket(WsState::Handshake(
128 data,
129 )) = conn.state_mut()
130 {
131 data.handshake_response = Some(resp);
132 }
133
134 conn.buffer_mut().drain(0..hlen);
136 return Ok(());
137 }
138 if conn.buffer().len() > 16 * 1024 {
139 return Err(ProtocolError::BufferOverflow);
140 }
141 if n < buf.len() {
142 break;
143 }
144 }
145 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
146 Err(e) => return Err(ProtocolError::Io(e)),
147 }
148 }
149 Ok(())
150 }
151
152 fn handle_handshake_write(&self, conn: &mut Connection) -> ProtocolResult<()> {
153 let (response, written) =
155 if let crate::core::ConnectionState::WebSocket(WsState::Handshake(data)) = conn.state() {
156 if let Some(ref response) = data.handshake_response {
157 (response.clone(), data.written)
158 } else {
159 return Ok(());
160 }
161 } else {
162 return Ok(());
163 };
164
165 let mut bytes_written = written;
166 loop {
167 if bytes_written >= response.len() {
168 let active_data = WebSocketConnectionData::active();
171 conn.set_state(crate::core::ConnectionState::WebSocket(WsState::Active(active_data)));
172 break;
173 }
174
175 match conn.stream().write(&response[bytes_written..]) {
176 Ok(0) => return Err(ProtocolError::ConnectionClosed),
177 Ok(n) => {
178 bytes_written += n;
179 if let crate::core::ConnectionState::WebSocket(WsState::Handshake(data)) =
181 conn.state_mut()
182 {
183 data.written = bytes_written;
184 }
185 }
186 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
187 Err(e) => return Err(ProtocolError::Io(e)),
188 }
189 }
190 Ok(())
191 }
192
193 fn handle_ws_read(&self, conn: &mut Connection) -> ProtocolResult<()> {
194 let mut buf = [0u8; 8192];
195
196 loop {
197 match conn.stream().read(&mut buf) {
198 Ok(0) => return Err(ProtocolError::ConnectionClosed),
199 Ok(n) => {
200 conn.buffer_mut().extend_from_slice(&buf[..n]);
202
203 self.process_buffered_ws_data(conn)?;
205
206 if n < buf.len() {
207 break;
208 }
209 }
210 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
211 Err(e) => return Err(ProtocolError::Io(e)),
212 }
213 }
214 Ok(())
215 }
216
217 fn handle_ws_write(&self, conn: &mut Connection) -> ProtocolResult<()> {
218 loop {
219 let frame_to_send =
221 if let crate::core::ConnectionState::WebSocket(WsState::Active(data)) = conn.state() {
222 data.outbox.front().cloned()
223 } else {
224 break;
225 };
226
227 if let Some(frame) = frame_to_send {
228 match conn.stream().write(&frame) {
229 Ok(n) => {
230 if n == frame.len() {
231 if let crate::core::ConnectionState::WebSocket(
233 WsState::Active(data),
234 ) = conn.state_mut()
235 {
236 let written_frame = data.outbox.pop_front().unwrap();
237 data.outbox_bytes = data
238 .outbox_bytes
239 .saturating_sub(written_frame.len());
240 }
241 } else {
242 if let crate::core::ConnectionState::WebSocket(
244 WsState::Active(data),
245 ) = conn.state_mut()
246 {
247 let mut remaining = data.outbox.pop_front().unwrap();
248 remaining.drain(0..n);
249 data.outbox.push_front(remaining);
250 data.outbox_bytes = data.outbox_bytes.saturating_sub(n);
251 }
252 break;
253 }
254 }
255 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
256 Err(e) => return Err(ProtocolError::Io(e)),
257 }
258 } else {
259 break;
260 }
261 }
262 Ok(())
263 }
264
265 fn process_buffered_ws_data(&self, conn: &mut Connection) -> ProtocolResult<()> {
266 let mut total_processed = 0;
267
268 loop {
270 let remaining_len = {
271 let buffer = conn.buffer();
272 if total_processed >= buffer.len() {
273 break;
274 }
275 buffer.len() - total_processed
276 };
277
278 if remaining_len == 0 {
279 break;
280 }
281
282 let frame_result = {
284 let buffer = conn.buffer();
285 let remaining_data = &buffer[total_processed..];
286
287 match parse_ws_frame(remaining_data)
288 .map_err(|e| ProtocolError::Custom(format!("Frame parse error: {}", e)))?
289 {
290 Some((opcode, payload)) => {
291 let frame_size = self.calculate_frame_size(remaining_data)?;
292 Some((opcode, payload, frame_size))
293 }
294 None => None, }
296 };
297
298 match frame_result {
299 Some((opcode, payload, frame_size)) => {
300 total_processed += frame_size;
301
302 self.process_ws_frame(conn, opcode, payload)?;
305 }
306 None => {
307 break;
309 }
310 }
311 }
312
313 if total_processed > 0 {
315 conn.buffer_mut().drain(0..total_processed);
316
317 conn.optimize_buffer();
320 }
321
322 Ok(())
323 }
324
325 fn calculate_frame_size(&self, data: &[u8]) -> ProtocolResult<usize> {
326 if data.len() < 2 {
327 return Ok(0);
328 }
329
330 let second_byte = data[1];
331 let masked = (second_byte & 0x80) != 0;
332 let mut payload_len = (second_byte & 0x7F) as usize;
333 let mut pos = 2;
334
335 if payload_len == 126 {
337 if data.len() < pos + 2 {
338 return Ok(0);
339 }
340 payload_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
341 pos += 2;
342 } else if payload_len == 127 {
343 if data.len() < pos + 8 {
344 return Ok(0);
345 }
346 payload_len = u64::from_be_bytes([
347 data[pos],
348 data[pos + 1],
349 data[pos + 2],
350 data[pos + 3],
351 data[pos + 4],
352 data[pos + 5],
353 data[pos + 6],
354 data[pos + 7],
355 ]) as usize;
356 pos += 8;
357 }
358
359 if masked {
361 pos += 4;
362 }
363
364 pos += payload_len;
366
367 Ok(pos)
368 }
369
370 fn process_ws_frame(&self, conn: &mut Connection, opcode: u8, payload: Vec<u8>) -> ProtocolResult<()> {
371 match opcode {
372 1 => {
373 let text = String::from_utf8_lossy(&payload);
376
377 match serde_json::from_str::<Request>(&text) {
378 Ok(request) => {
379 let response_payload = self.handle_request(conn, &request)?;
380 let response = Response {
381 id: request.id,
382 payload: response_payload,
383 };
384 let response_json = serde_json::to_string(&response).map_err(|e| {
385 ProtocolError::Custom(format!("JSON error: {}", e))
386 })?;
387 let response_frame = build_ws_frame(1, response_json.as_bytes());
388 self.send_frame(conn, response_frame)?;
389 }
390 Err(parse_error) => {
391 eprintln!("WebSocket request parse error: {}", parse_error);
394 let error_response = serde_json::json!({
395 "error": "Invalid request format",
396 "message": format!("Failed to parse WebSocket request: {}", parse_error)
397 });
398 let error_json =
399 serde_json::to_string(&error_response).map_err(|e| {
400 ProtocolError::Custom(format!("JSON error: {}", e))
401 })?;
402 let error_frame = build_ws_frame(1, error_json.as_bytes());
403 self.send_frame(conn, error_frame)?;
404 }
405 }
406 }
407 2 => {
408 let response_frame = build_ws_frame(2, &payload);
410 self.send_frame(conn, response_frame)?;
411 }
412 8 => {
413 let close_code = if payload.len() >= 2 {
416 u16::from_be_bytes([payload[0], payload[1]])
417 } else {
418 1000 };
420
421 let _close_reason = if payload.len() > 2 {
422 String::from_utf8_lossy(&payload[2..]).to_string()
423 } else {
424 "Connection closed by client".to_string()
425 };
426
427 let mut close_payload = close_code.to_be_bytes().to_vec();
429 close_payload.extend_from_slice(b"Server closing connection");
430 let close_response = build_ws_frame(8, &close_payload);
431 self.send_frame(conn, close_response)?;
432
433 conn.set_state(crate::core::ConnectionState::WebSocket(WsState::Closed));
435 }
436 9 => {
437 let pong_response = build_ws_frame(10, &payload);
439 self.send_frame(conn, pong_response)?;
440 }
441 10 => {
442 }
445 _ => {
446 }
448 }
449 Ok(())
450 }
451
452 fn handle_request(&self, conn: &mut Connection, request: &Request) -> ProtocolResult<ResponsePayload> {
453 use super::{AuthResponse, RequestPayload};
454
455 match &request.payload {
456 RequestPayload::Auth(_auth_req) => {
457 Ok(ResponsePayload::Auth(AuthResponse {}))
459 }
460 RequestPayload::Command(cmd_req) => self.handle_command_request(conn, cmd_req),
461 RequestPayload::Query(query_req) => self.handle_query_request(conn, query_req),
462 }
463 }
464
465 fn handle_command_request(
466 &self,
467 conn: &mut Connection,
468 cmd_req: &CommandRequest,
469 ) -> ProtocolResult<ResponsePayload> {
470 let mut all_frames = Vec::new();
472
473 for statement in &cmd_req.statements {
474 let params = convert_params(&cmd_req.params)?;
475
476 match conn.engine().command_as(
477 &Identity::System {
478 id: 1,
479 name: "root".to_string(),
480 },
481 statement,
482 params,
483 ) {
484 Ok(result) => {
485 let frames = convert_frames(result)?;
486 all_frames.extend(frames);
487 }
488 Err(e) => {
489 let mut diagnostic = e.diagnostic();
492 diagnostic.with_statement(statement.clone());
493
494 return Ok(ResponsePayload::Err(super::ErrorResponse {
495 diagnostic,
496 }));
497 }
498 }
499 }
500
501 Ok(ResponsePayload::Command(CommandResponse {
502 frames: all_frames,
503 }))
504 }
505
506 fn handle_query_request(
507 &self,
508 conn: &mut Connection,
509 query_req: &QueryRequest,
510 ) -> ProtocolResult<ResponsePayload> {
511 let mut all_frames = Vec::new();
513
514 for statement in &query_req.statements {
515 let params = convert_params(&query_req.params)?;
516
517 match conn.engine().query_as(
518 &Identity::System {
519 id: 1,
520 name: "root".to_string(),
521 },
522 statement,
523 params,
524 ) {
525 Ok(result) => {
526 let frames = convert_frames(result)?;
527 all_frames.extend(frames);
528 }
529 Err(e) => {
530 let mut diagnostic = e.diagnostic();
533 diagnostic.with_statement(statement.clone());
534
535 return Ok(ResponsePayload::Err(super::ErrorResponse {
536 diagnostic,
537 }));
538 }
539 }
540 }
541
542 Ok(ResponsePayload::Query(QueryResponse {
543 frames: all_frames,
544 }))
545 }
546
547 fn send_frame(&self, conn: &mut Connection, frame: Vec<u8>) -> ProtocolResult<()> {
548 if let crate::core::ConnectionState::WebSocket(WsState::Active(data)) = conn.state_mut() {
549 if data.outbox_bytes + frame.len() > data.max_outbox_bytes {
550 return Err(ProtocolError::BufferOverflow);
551 }
552
553 data.outbox_bytes += frame.len();
554 data.outbox.push_back(frame);
555 }
556 Ok(())
557 }
558}