reifydb_sub_server/protocols/ws/handler.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the AGPL-3.0-or-later, see license.md file
3
4use std::io::{Read, Write};
5
6use reifydb_core::interface::{Engine, Identity};
7
8use super::{
9	CommandRequest, CommandResponse, QueryRequest, QueryResponse, Request, Response, ResponsePayload,
10	WebSocketConnectionData, WsState,
11};
12use crate::{
13	core::Connection,
14	protocols::{
15		ProtocolError, ProtocolHandler, ProtocolResult,
16		convert::{convert_frames, convert_params},
17		utils::{build_ws_frame, build_ws_response, find_header_end, parse_ws_frame},
18	},
19};
20
/// Stateless protocol handler that upgrades HTTP/1.1 connections to
/// WebSocket and answers JSON-encoded requests carried in text frames.
///
/// All per-connection data lives in the connection's state, so the handler
/// itself is a unit struct that is cheap to clone and default-construct.
#[derive(Clone, Debug, Default)]
pub struct WebSocketHandler;
23
24impl WebSocketHandler {
25	pub fn new() -> Self {
26		Self
27	}
28}
29
30impl ProtocolHandler for WebSocketHandler {
31	fn name(&self) -> &'static str {
32		"ws"
33	}
34
35	fn can_handle(&self, buffer: &[u8]) -> bool {
36		// Check for WebSocket handshake signature
37		if buffer.len() < 16 {
38			return false;
39		}
40
41		let request = String::from_utf8_lossy(buffer);
42
43		let request_lower = request.to_lowercase();
44
45		request_lower.contains("get ")
46			&& request_lower.contains("http/1.1")
47			&& request_lower.contains("upgrade: websocket")
48			&& (request_lower.contains("connection: upgrade")
49				|| request_lower.contains("connection: keep-alive, upgrade"))
50	}
51
52	fn handle_connection(&self, conn: &mut Connection) -> ProtocolResult<()> {
53		// Initialize WebSocket state
54		let ws_state = WsState::Handshake(WebSocketConnectionData::new());
55		conn.set_state(crate::core::ConnectionState::WebSocket(ws_state));
56		Ok(())
57	}
58
59	fn handle_read(&self, conn: &mut Connection) -> ProtocolResult<()> {
60		if let crate::core::ConnectionState::WebSocket(ws_state) = conn.state() {
61			match ws_state {
62				WsState::Handshake(_) => self.handle_handshake_read(conn),
63				WsState::Active(_) => self.handle_ws_read(conn),
64				WsState::Closed => Ok(()),
65			}
66		} else {
67			Err(ProtocolError::InvalidFrame)
68		}
69	}
70
71	fn handle_write(&self, conn: &mut Connection) -> ProtocolResult<()> {
72		if let crate::core::ConnectionState::WebSocket(ws_state) = conn.state() {
73			match ws_state {
74				WsState::Handshake(_) => self.handle_handshake_write(conn),
75				WsState::Active(_) => self.handle_ws_write(conn),
76				WsState::Closed => Ok(()),
77			}
78		} else {
79			Err(ProtocolError::InvalidFrame)
80		}
81	}
82
83	fn should_close(&self, conn: &Connection) -> bool {
84		matches!(
85			conn.state(),
86			crate::core::ConnectionState::WebSocket(WsState::Closed) | crate::core::ConnectionState::Closed
87		)
88	}
89}
90
91impl WebSocketHandler {
92	fn handle_handshake_read(&self, conn: &mut Connection) -> ProtocolResult<()> {
93		// First, check if we already have complete headers in the
94		// buffer (from protocol detection)
95		if !conn.buffer().is_empty() {
96			if let Some(hlen) = find_header_end(conn.buffer()) {
97				let (resp, _key) = build_ws_response(&conn.buffer()[..hlen])
98					.map_err(|e| ProtocolError::Custom(format!("Handshake error: {}", e)))?;
99
100				// Update WebSocket state with response
101				if let crate::core::ConnectionState::WebSocket(WsState::Handshake(data)) =
102					conn.state_mut()
103				{
104					data.handshake_response = Some(resp);
105				}
106
107				// Clear the handshake data from buffer
108				conn.buffer_mut().drain(0..hlen);
109				return Ok(());
110			}
111		}
112
113		// If we don't have complete headers yet, read more data
114		let mut buf = [0u8; 2048];
115		loop {
116			match conn.stream().read(&mut buf) {
117				Ok(0) => return Err(ProtocolError::ConnectionClosed),
118				Ok(n) => {
119					conn.buffer_mut().extend_from_slice(&buf[..n]);
120					if let Some(hlen) = find_header_end(conn.buffer()) {
121						let (resp, _key) =
122							build_ws_response(&conn.buffer()[..hlen]).map_err(|e| {
123								ProtocolError::Custom(format!("Handshake error: {}", e))
124							})?;
125
126						// Update WebSocket state with response
127						if let crate::core::ConnectionState::WebSocket(WsState::Handshake(
128							data,
129						)) = conn.state_mut()
130						{
131							data.handshake_response = Some(resp);
132						}
133
134						// Clear the handshake data from buffer
135						conn.buffer_mut().drain(0..hlen);
136						return Ok(());
137					}
138					if conn.buffer().len() > 16 * 1024 {
139						return Err(ProtocolError::BufferOverflow);
140					}
141					if n < buf.len() {
142						break;
143					}
144				}
145				Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
146				Err(e) => return Err(ProtocolError::Io(e)),
147			}
148		}
149		Ok(())
150	}
151
152	fn handle_handshake_write(&self, conn: &mut Connection) -> ProtocolResult<()> {
153		// Extract the necessary data to avoid borrowing issues
154		let (response, written) =
155			if let crate::core::ConnectionState::WebSocket(WsState::Handshake(data)) = conn.state() {
156				if let Some(ref response) = data.handshake_response {
157					(response.clone(), data.written)
158				} else {
159					return Ok(());
160				}
161			} else {
162				return Ok(());
163			};
164
165		let mut bytes_written = written;
166		loop {
167			if bytes_written >= response.len() {
168				// Handshake complete, transition to active
169				// state
170				let active_data = WebSocketConnectionData::active();
171				conn.set_state(crate::core::ConnectionState::WebSocket(WsState::Active(active_data)));
172				break;
173			}
174
175			match conn.stream().write(&response[bytes_written..]) {
176				Ok(0) => return Err(ProtocolError::ConnectionClosed),
177				Ok(n) => {
178					bytes_written += n;
179					// Update the state with the new written count
180					if let crate::core::ConnectionState::WebSocket(WsState::Handshake(data)) =
181						conn.state_mut()
182					{
183						data.written = bytes_written;
184					}
185				}
186				Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
187				Err(e) => return Err(ProtocolError::Io(e)),
188			}
189		}
190		Ok(())
191	}
192
193	fn handle_ws_read(&self, conn: &mut Connection) -> ProtocolResult<()> {
194		let mut buf = [0u8; 8192];
195
196		loop {
197			match conn.stream().read(&mut buf) {
198				Ok(0) => return Err(ProtocolError::ConnectionClosed),
199				Ok(n) => {
200					// Add data to connection buffer
201					conn.buffer_mut().extend_from_slice(&buf[..n]);
202
203					// Process complete frames from buffer
204					self.process_buffered_ws_data(conn)?;
205
206					if n < buf.len() {
207						break;
208					}
209				}
210				Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
211				Err(e) => return Err(ProtocolError::Io(e)),
212			}
213		}
214		Ok(())
215	}
216
217	fn handle_ws_write(&self, conn: &mut Connection) -> ProtocolResult<()> {
218		loop {
219			// Check if there's a frame to send
220			let frame_to_send =
221				if let crate::core::ConnectionState::WebSocket(WsState::Active(data)) = conn.state() {
222					data.outbox.front().cloned()
223				} else {
224					break;
225				};
226
227			if let Some(frame) = frame_to_send {
228				match conn.stream().write(&frame) {
229					Ok(n) => {
230						if n == frame.len() {
231							// Full frame written
232							if let crate::core::ConnectionState::WebSocket(
233								WsState::Active(data),
234							) = conn.state_mut()
235							{
236								let written_frame = data.outbox.pop_front().unwrap();
237								data.outbox_bytes = data
238									.outbox_bytes
239									.saturating_sub(written_frame.len());
240							}
241						} else {
242							// Partial write - update the frame
243							if let crate::core::ConnectionState::WebSocket(
244								WsState::Active(data),
245							) = conn.state_mut()
246							{
247								let mut remaining = data.outbox.pop_front().unwrap();
248								remaining.drain(0..n);
249								data.outbox.push_front(remaining);
250								data.outbox_bytes = data.outbox_bytes.saturating_sub(n);
251							}
252							break;
253						}
254					}
255					Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
256					Err(e) => return Err(ProtocolError::Io(e)),
257				}
258			} else {
259				break;
260			}
261		}
262		Ok(())
263	}
264
265	fn process_buffered_ws_data(&self, conn: &mut Connection) -> ProtocolResult<()> {
266		let mut total_processed = 0;
267
268		// Process frames directly from buffer to avoid copies
269		loop {
270			let remaining_len = {
271				let buffer = conn.buffer();
272				if total_processed >= buffer.len() {
273					break;
274				}
275				buffer.len() - total_processed
276			};
277
278			if remaining_len == 0 {
279				break;
280			}
281
282			// Parse frame directly from buffer slice
283			let frame_result = {
284				let buffer = conn.buffer();
285				let remaining_data = &buffer[total_processed..];
286
287				match parse_ws_frame(remaining_data)
288					.map_err(|e| ProtocolError::Custom(format!("Frame parse error: {}", e)))?
289				{
290					Some((opcode, payload)) => {
291						let frame_size = self.calculate_frame_size(remaining_data)?;
292						Some((opcode, payload, frame_size))
293					}
294					None => None, // Incomplete frame
295				}
296			};
297
298			match frame_result {
299				Some((opcode, payload, frame_size)) => {
300					total_processed += frame_size;
301
302					// Process frame immediately to avoid
303					// storing payload
304					self.process_ws_frame(conn, opcode, payload)?;
305				}
306				None => {
307					// Incomplete frame, wait for more data
308					break;
309				}
310			}
311		}
312
313		// Remove processed data from connection buffer
314		if total_processed > 0 {
315			conn.buffer_mut().drain(0..total_processed);
316
317			// Optimize buffer after processing to manage memory
318			// efficiently
319			conn.optimize_buffer();
320		}
321
322		Ok(())
323	}
324
325	fn calculate_frame_size(&self, data: &[u8]) -> ProtocolResult<usize> {
326		if data.len() < 2 {
327			return Ok(0);
328		}
329
330		let second_byte = data[1];
331		let masked = (second_byte & 0x80) != 0;
332		let mut payload_len = (second_byte & 0x7F) as usize;
333		let mut pos = 2;
334
335		// Extended payload length
336		if payload_len == 126 {
337			if data.len() < pos + 2 {
338				return Ok(0);
339			}
340			payload_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
341			pos += 2;
342		} else if payload_len == 127 {
343			if data.len() < pos + 8 {
344				return Ok(0);
345			}
346			payload_len = u64::from_be_bytes([
347				data[pos],
348				data[pos + 1],
349				data[pos + 2],
350				data[pos + 3],
351				data[pos + 4],
352				data[pos + 5],
353				data[pos + 6],
354				data[pos + 7],
355			]) as usize;
356			pos += 8;
357		}
358
359		// Add masking key size
360		if masked {
361			pos += 4;
362		}
363
364		// Add payload size
365		pos += payload_len;
366
367		Ok(pos)
368	}
369
370	fn process_ws_frame(&self, conn: &mut Connection, opcode: u8, payload: Vec<u8>) -> ProtocolResult<()> {
371		match opcode {
372			1 => {
373				// Text frame - try to parse as WebSocket
374				// Request
375				let text = String::from_utf8_lossy(&payload);
376
377				match serde_json::from_str::<Request>(&text) {
378					Ok(request) => {
379						let response_payload = self.handle_request(conn, &request)?;
380						let response = Response {
381							id: request.id,
382							payload: response_payload,
383						};
384						let response_json = serde_json::to_string(&response).map_err(|e| {
385							ProtocolError::Custom(format!("JSON error: {}", e))
386						})?;
387						let response_frame = build_ws_frame(1, response_json.as_bytes());
388						self.send_frame(conn, response_frame)?;
389					}
390					Err(parse_error) => {
391						// Not a valid WebSocket Request
392						// - send error response
393						eprintln!("WebSocket request parse error: {}", parse_error);
394						let error_response = serde_json::json!({
395							"error": "Invalid request format",
396							"message": format!("Failed to parse WebSocket request: {}", parse_error)
397						});
398						let error_json =
399							serde_json::to_string(&error_response).map_err(|e| {
400								ProtocolError::Custom(format!("JSON error: {}", e))
401							})?;
402						let error_frame = build_ws_frame(1, error_json.as_bytes());
403						self.send_frame(conn, error_frame)?;
404					}
405				}
406			}
407			2 => {
408				// Binary frame - echo it back
409				let response_frame = build_ws_frame(2, &payload);
410				self.send_frame(conn, response_frame)?;
411			}
412			8 => {
413				// Close frame - send close response and mark
414				// connection for closure
415				let close_code = if payload.len() >= 2 {
416					u16::from_be_bytes([payload[0], payload[1]])
417				} else {
418					1000 // Normal closure
419				};
420
421				let _close_reason = if payload.len() > 2 {
422					String::from_utf8_lossy(&payload[2..]).to_string()
423				} else {
424					"Connection closed by client".to_string()
425				};
426
427				// Send close response with same code
428				let mut close_payload = close_code.to_be_bytes().to_vec();
429				close_payload.extend_from_slice(b"Server closing connection");
430				let close_response = build_ws_frame(8, &close_payload);
431				self.send_frame(conn, close_response)?;
432
433				// Mark connection as closed
434				conn.set_state(crate::core::ConnectionState::WebSocket(WsState::Closed));
435			}
436			9 => {
437				// Ping frame - respond with pong
438				let pong_response = build_ws_frame(10, &payload);
439				self.send_frame(conn, pong_response)?;
440			}
441			10 => {
442				// Pong frame - client response to our ping, no
443				// action needed
444			}
445			_ => {
446				// Ignore other opcodes for now
447			}
448		}
449		Ok(())
450	}
451
452	fn handle_request(&self, conn: &mut Connection, request: &Request) -> ProtocolResult<ResponsePayload> {
453		use super::{AuthResponse, RequestPayload};
454
455		match &request.payload {
456			RequestPayload::Auth(_auth_req) => {
457				// For now, always return success for auth
458				Ok(ResponsePayload::Auth(AuthResponse {}))
459			}
460			RequestPayload::Command(cmd_req) => self.handle_command_request(conn, cmd_req),
461			RequestPayload::Query(query_req) => self.handle_query_request(conn, query_req),
462		}
463	}
464
465	fn handle_command_request(
466		&self,
467		conn: &mut Connection,
468		cmd_req: &CommandRequest,
469	) -> ProtocolResult<ResponsePayload> {
470		// Execute each statement and collect results
471		let mut all_frames = Vec::new();
472
473		for statement in &cmd_req.statements {
474			let params = convert_params(&cmd_req.params)?;
475
476			match conn.engine().command_as(
477				&Identity::System {
478					id: 1,
479					name: "root".to_string(),
480				},
481				statement,
482				params,
483			) {
484				Ok(result) => {
485					let frames = convert_frames(result)?;
486					all_frames.extend(frames);
487				}
488				Err(e) => {
489					// Get the diagnostic from the error and
490					// add statement context
491					let mut diagnostic = e.diagnostic();
492					diagnostic.with_statement(statement.clone());
493
494					return Ok(ResponsePayload::Err(super::ErrorResponse {
495						diagnostic,
496					}));
497				}
498			}
499		}
500
501		Ok(ResponsePayload::Command(CommandResponse {
502			frames: all_frames,
503		}))
504	}
505
506	fn handle_query_request(
507		&self,
508		conn: &mut Connection,
509		query_req: &QueryRequest,
510	) -> ProtocolResult<ResponsePayload> {
511		// Execute each statement and collect results
512		let mut all_frames = Vec::new();
513
514		for statement in &query_req.statements {
515			let params = convert_params(&query_req.params)?;
516
517			match conn.engine().query_as(
518				&Identity::System {
519					id: 1,
520					name: "root".to_string(),
521				},
522				statement,
523				params,
524			) {
525				Ok(result) => {
526					let frames = convert_frames(result)?;
527					all_frames.extend(frames);
528				}
529				Err(e) => {
530					// Get the diagnostic from the error and
531					// add statement context
532					let mut diagnostic = e.diagnostic();
533					diagnostic.with_statement(statement.clone());
534
535					return Ok(ResponsePayload::Err(super::ErrorResponse {
536						diagnostic,
537					}));
538				}
539			}
540		}
541
542		Ok(ResponsePayload::Query(QueryResponse {
543			frames: all_frames,
544		}))
545	}
546
547	fn send_frame(&self, conn: &mut Connection, frame: Vec<u8>) -> ProtocolResult<()> {
548		if let crate::core::ConnectionState::WebSocket(WsState::Active(data)) = conn.state_mut() {
549			if data.outbox_bytes + frame.len() > data.max_outbox_bytes {
550				return Err(ProtocolError::BufferOverflow);
551			}
552
553			data.outbox_bytes += frame.len();
554			data.outbox.push_back(frame);
555		}
556		Ok(())
557	}
558}