reifydb_client/ws/
client.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the MIT
3
4use std::{
5	io::{Read, Write},
6	net::{SocketAddr, TcpStream, ToSocketAddrs},
7	sync::{Arc, Mutex, mpsc},
8	thread::JoinHandle,
9};
10
11use crate::{
12	Request, Response, ResponseMessage, WsBlockingSession, WsCallbackSession, WsChannelSession,
13	ws::{
14		message::InternalMessage,
15		protocol::{
16			build_ws_frame, calculate_accept_key, calculate_frame_size, find_header_end,
17			generate_websocket_key, parse_ws_frame,
18		},
19		router::RequestRouter,
20		worker,
21	},
22};
23
24/// WebSocket client implementation
25#[derive(Clone)]
26pub struct WsClient {
27	inner: Arc<ClientInner>,
28}
29
30pub(crate) struct ClientInner {
31	pub(crate) command_tx: mpsc::Sender<InternalMessage>,
32	worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
33}
34
35// ============================================================================
36// WsClient Implementation
37// ============================================================================
38
39impl WsClient {
40	/// Create a new WebSocket client from URL string
41	pub fn from_url(url: &str) -> Result<Self, Box<dyn std::error::Error>> {
42		// Parse the URL to get a socket address
43		let socket_addr = Self::parse_ws_url(url)?;
44
45		let (command_tx, command_rx) = mpsc::channel();
46		let router = Arc::new(Mutex::new(RequestRouter::new()));
47
48		// Verify connection by creating a test WebSocket client
49		let test_client = WebSocketClient::connect(socket_addr)?;
50		drop(test_client); // Close test connection
51
52		// Start the background worker thread
53		let router_clone = router.clone();
54		let socket_addr_clone = socket_addr;
55		let worker_handle = std::thread::spawn(move || {
56			worker::worker_thread_with_addr(socket_addr_clone, command_rx, router_clone);
57		});
58
59		Ok(Self {
60			inner: Arc::new(ClientInner {
61				command_tx,
62				worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
63			}),
64		})
65	}
66
67	/// Parse a WebSocket URL to extract the socket address
68	fn parse_ws_url(url: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
69		let addr_str = if url.starts_with("ws://") {
70			&url[5..] // Remove "ws://"
71		} else if url.starts_with("wss://") {
72			return Err("WSS (secure WebSocket) is not yet supported".into());
73		} else {
74			url
75		};
76
77		// Parse the address string to SocketAddr
78		// Handle different formats:
79		// - [::1]:8080 (already properly formatted)
80		// - ::1:8080 (needs brackets added)
81		// - localhost:8080 (hostname)
82		// - 127.0.0.1:8080 (IPv4)
83
84		if addr_str.starts_with('[') {
85			// Already has brackets, parse as-is
86			addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
87		} else if addr_str.starts_with("::") {
88			// IPv6 address without brackets
89			// Find the last colon that's likely the port separator
90			// Count colons - if more than 2, it's IPv6
91			let colon_count = addr_str.matches(':').count();
92			if colon_count > 2 {
93				// Definitely IPv6, find the port
94				if let Some(port_start) = addr_str.rfind(':') {
95					// Check if what follows is a port
96					// number
97					if addr_str[port_start + 1..].chars().all(|c| c.is_ascii_digit()) {
98						let ipv6_part = &addr_str[..port_start];
99						let port_part = &addr_str[port_start + 1..];
100						let formatted = format!("[{}]:{}", ipv6_part, port_part);
101						return formatted
102							.to_socket_addrs()?
103							.next()
104							.ok_or_else(|| "Failed to resolve address".into());
105					}
106				}
107			}
108			// Try as-is
109			addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
110		} else {
111			// Regular address (hostname or IPv4)
112			addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
113		}
114	}
115
116	/// Create a new WebSocket client
117	pub fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn std::error::Error>> {
118		// Resolve the address to get the first valid SocketAddr
119		let socket_addr = addr.to_socket_addrs()?.next().ok_or("Failed to resolve address")?;
120
121		let (command_tx, command_rx) = mpsc::channel();
122		let router = Arc::new(Mutex::new(RequestRouter::new()));
123
124		// Verify connection by creating a test WebSocket client
125		let test_client = WebSocketClient::connect(socket_addr)?;
126		drop(test_client); // Close test connection
127
128		// Start the background worker thread
129		let router_clone = router.clone();
130		let socket_addr_clone = socket_addr;
131		let worker_handle = std::thread::spawn(move || {
132			worker::worker_thread_with_addr(socket_addr_clone, command_rx, router_clone);
133		});
134
135		Ok(Self {
136			inner: Arc::new(ClientInner {
137				command_tx,
138				worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
139			}),
140		})
141	}
142
143	/// Create a blocking session
144	pub fn blocking_session(&self, token: Option<String>) -> Result<WsBlockingSession, reifydb_type::Error> {
145		WsBlockingSession::new(self.inner.clone(), token)
146	}
147
148	/// Create a callback-based session
149	pub fn callback_session(&self, token: Option<String>) -> Result<WsCallbackSession, reifydb_type::Error> {
150		WsCallbackSession::new(self.inner.clone(), token)
151	}
152
153	/// Create a channel-based session
154	pub fn channel_session(
155		&self,
156		token: Option<String>,
157	) -> Result<(WsChannelSession, mpsc::Receiver<ResponseMessage>), reifydb_type::Error> {
158		WsChannelSession::new(self.inner.clone(), token)
159	}
160
161	/// Close the client connection
162	pub fn close(self) -> Result<(), Box<dyn std::error::Error>> {
163		self.inner.command_tx.send(InternalMessage::Close)?;
164
165		// Wait for worker thread to finish
166		if let Ok(mut handle_guard) = self.inner.worker_handle.lock() {
167			if let Some(handle) = handle_guard.take() {
168				let _ = handle.join();
169			}
170		}
171		Ok(())
172	}
173}
174
175impl Drop for WsClient {
176	fn drop(&mut self) {
177		let _ = self.inner.command_tx.send(InternalMessage::Close);
178	}
179}
180
/// A single WebSocket connection over plain TCP, driven by polling.
///
/// `connect` puts the socket into non-blocking mode, so `receive` returns
/// `Ok(None)` instead of blocking when no complete frame is available.
pub struct WebSocketClient {
	/// Underlying TCP socket (non-blocking once `connect` returns).
	pub(crate) stream: TcpStream,
	/// Bytes read from the socket that have not yet been consumed as a complete frame.
	read_buffer: Vec<u8>,
	/// Set after a successful handshake; cleared when the connection is closed.
	pub(crate) is_connected: bool,
}
187
188impl WebSocketClient {
189	/// Create a new WebSocket client and connect to the specified address
190	pub fn connect(addr: SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
191		// Connect to the socket address
192		let stream = TcpStream::connect(addr)?;
193		stream.set_nonblocking(true)?;
194
195		let mut client = WebSocketClient {
196			stream,
197			read_buffer: Vec::with_capacity(4096),
198			is_connected: false,
199		};
200
201		// Perform WebSocket handshake
202		client.handshake()?;
203
204		Ok(client)
205	}
206
207	/// Perform WebSocket handshake
208	fn handshake(&mut self) -> Result<(), Box<dyn std::error::Error>> {
209		// Generate WebSocket key
210		let key = generate_websocket_key();
211
212		// Build handshake request
213		let request = format!(
214			"GET / HTTP/1.1\r\n\
215Host: localhost\r\n\
216Upgrade: websocket\r\n\
217Connection: Upgrade\r\n\
218Sec-WebSocket-Key: {}\r\n\
219Sec-WebSocket-Version: 13\r\n\
220\r\n",
221			key
222		);
223
224		// Send handshake
225		self.stream.write_all(request.as_bytes())?;
226		self.stream.flush()?;
227
228		// Read response with timeout
229		let mut response = Vec::new();
230		let mut buffer = [0u8; 1024];
231		let start = std::time::Instant::now();
232		let timeout = std::time::Duration::from_secs(5);
233
234		loop {
235			match self.stream.read(&mut buffer) {
236				Ok(0) => return Err("Connection closed during handshake".into()),
237				Ok(n) => {
238					response.extend_from_slice(&buffer[..n]);
239
240					// Check if we have the complete HTTP
241					// response
242					if let Some(end_pos) = find_header_end(&response) {
243						response.truncate(end_pos);
244						break;
245					}
246				}
247				Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
248					// No data available yet
249					if start.elapsed() > timeout {
250						return Err("Handshake timeout".into());
251					}
252					std::thread::sleep(std::time::Duration::from_millis(10));
253					continue;
254				}
255				Err(e) => return Err(e.into()),
256			}
257		}
258
259		// Verify handshake response
260		let response_str = String::from_utf8_lossy(&response);
261		if !response_str.contains("HTTP/1.1 101") {
262			return Err(format!("Invalid handshake response: {}", response_str).into());
263		}
264
265		// Verify Sec-WebSocket-Accept (case-insensitive header search)
266		let expected_accept = calculate_accept_key(&key);
267		let response_lower = response_str.to_lowercase();
268		let accept_pattern = format!("sec-websocket-accept: {}", expected_accept).to_lowercase();
269		if !response_lower.contains(&accept_pattern) {
270			return Err(format!(
271				"Invalid Sec-WebSocket-Accept. Expected: {}, Response: {}",
272				expected_accept, response_str
273			)
274			.into());
275		}
276
277		self.is_connected = true;
278		Ok(())
279	}
280
281	/// Send a request over the WebSocket connection
282	pub(crate) fn send_request(&mut self, request: &Request) -> Result<(), Box<dyn std::error::Error>> {
283		if !self.is_connected {
284			return Err("Not connected".into());
285		}
286
287		// Serialize request to JSON
288		let json = serde_json::to_string(request)?;
289		let payload = json.as_bytes();
290
291		// Build WebSocket frame (text frame, opcode = 1)
292		let frame = build_ws_frame(0x01, payload, true);
293
294		// Send frame
295		self.stream.write_all(&frame)?;
296		self.stream.flush()?;
297
298		Ok(())
299	}
300
301	/// Receive a response from the WebSocket connection
302	pub fn receive(&mut self) -> Result<Option<Response>, Box<dyn std::error::Error>> {
303		if !self.is_connected {
304			return Err("Not connected".into());
305		}
306
307		// Read data into buffer
308		let mut buf = vec![0u8; 4096];
309		match self.stream.read(&mut buf) {
310			Ok(0) => {
311				self.is_connected = false;
312				return Err("Connection closed".into());
313			}
314			Ok(n) => {
315				self.read_buffer.extend_from_slice(&buf[..n]);
316			}
317			Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
318				// No data available
319				return Ok(None);
320			}
321			Err(e) => return Err(e.into()),
322		}
323
324		// Try to parse WebSocket frame
325		if let Some((opcode, payload)) = parse_ws_frame(&self.read_buffer)? {
326			// Remove parsed frame from buffer
327			let frame_size = calculate_frame_size(&payload, false);
328			self.read_buffer.drain(..frame_size);
329
330			match opcode {
331				0x01 | 0x02 => {
332					// Text or binary frame
333					let response: Response = serde_json::from_slice(&payload)?;
334					return Ok(Some(response));
335				}
336				0x08 => {
337					// Close frame
338					self.is_connected = false;
339					return Err("Connection closed by server".into());
340				}
341				0x09 => {
342					// Ping frame - respond with pong
343					let pong = build_ws_frame(0x0A, &payload, true);
344					self.stream.write_all(&pong)?;
345					self.stream.flush()?;
346				}
347				0x0A => {
348					// Pong frame - ignore
349				}
350				_ => {
351					// Unknown opcode
352					return Err(format!("Unknown opcode: {}", opcode).into());
353				}
354			}
355		}
356
357		Ok(None)
358	}
359
360	/// Close the WebSocket connection
361	pub fn close(&mut self) -> Result<(), Box<dyn std::error::Error>> {
362		if self.is_connected {
363			// Send close frame
364			let close_frame = build_ws_frame(0x08, &[], true);
365			self.stream.write_all(&close_frame)?;
366			self.stream.flush()?;
367			self.is_connected = false;
368		}
369		Ok(())
370	}
371
372	/// Check if the client is connected
373	pub fn is_connected(&self) -> bool {
374		self.is_connected
375	}
376}
377
378impl Drop for WebSocketClient {
379	fn drop(&mut self) {
380		let _ = self.close();
381	}
382}