reifydb_client/ws/
client.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the MIT
3
4use std::{
5	io::{Read, Write},
6	net::{SocketAddr, TcpStream, ToSocketAddrs},
7	sync::{Arc, Mutex, mpsc},
8	thread::JoinHandle,
9};
10
11use crate::{
12	Request, Response, ResponseMessage, WsBlockingSession, WsCallbackSession, WsChannelSession,
13	ws::{
14		message::InternalMessage,
15		protocol::{
16			build_ws_frame, calculate_accept_key, calculate_frame_size, find_header_end,
17			generate_websocket_key, parse_ws_frame,
18		},
19		router::RequestRouter,
20		worker,
21	},
22};
23
/// Cloneable handle to a WebSocket connection that is driven by a background
/// worker thread.
///
/// Every clone shares the same [`ClientInner`] through an `Arc`, so all clones
/// send commands over the same channel to the same worker.
#[derive(Clone)]
pub struct WsClient {
	/// Shared state: the command channel and the worker's join handle.
	inner: Arc<ClientInner>,
}
29
/// State shared between a [`WsClient`], its clones and the sessions created
/// from it.
pub(crate) struct ClientInner {
	/// Sending half of the command channel; the receiving half is moved into
	/// the worker thread when the client is constructed.
	pub(crate) command_tx: mpsc::Sender<InternalMessage>,
	/// Join handle of the worker thread. `WsClient::close` takes it out
	/// (leaving `None`) when it joins the worker.
	worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
34
35// ============================================================================
36// WsClient Implementation
37// ============================================================================
38
39impl WsClient {
40	/// Create a new WebSocket client from URL string
41	pub fn from_url(url: &str) -> Result<Self, Box<dyn std::error::Error>> {
42		// Parse the URL to get a socket address
43		let socket_addr = Self::parse_ws_url(url)?;
44
45		let (command_tx, command_rx) = mpsc::channel();
46		let router = Arc::new(Mutex::new(RequestRouter::new()));
47
48		// Verify connection by creating a test WebSocket client
49		let test_client = WebSocketClient::connect(socket_addr)?;
50		drop(test_client); // Close test connection
51
52		// Start the background worker thread
53		let router_clone = router.clone();
54		let socket_addr_clone = socket_addr;
55		let worker_handle = std::thread::spawn(move || {
56			worker::worker_thread_with_addr(socket_addr_clone, command_rx, router_clone);
57		});
58
59		Ok(Self {
60			inner: Arc::new(ClientInner {
61				command_tx,
62				worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
63			}),
64		})
65	}
66
67	/// Parse a WebSocket URL to extract the socket address
68	fn parse_ws_url(url: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
69		let addr_str = if url.starts_with("ws://") {
70			&url[5..] // Remove "ws://"
71		} else if url.starts_with("wss://") {
72			return Err("WSS (secure WebSocket) is not yet supported".into());
73		} else {
74			url
75		};
76
77		// Parse the address string to SocketAddr
78		// Handle different formats:
79		// - [::1]:8080 (already properly formatted)
80		// - ::1:8080 (needs brackets added)
81		// - localhost:8080 (hostname)
82		// - 127.0.0.1:8080 (IPv4)
83
84		if addr_str.starts_with('[') {
85			// Already has brackets, parse as-is
86			addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
87		} else if addr_str.starts_with("::") {
88			// IPv6 address without brackets
89			// Find the last colon that's likely the port separator
90			// Count colons - if more than 2, it's IPv6
91			let colon_count = addr_str.matches(':').count();
92			if colon_count > 2 {
93				// Definitely IPv6, find the port
94				if let Some(port_start) = addr_str.rfind(':') {
95					// Check if what follows is a port
96					// number
97					if addr_str[port_start + 1..].chars().all(|c| c.is_ascii_digit()) {
98						let ipv6_part = &addr_str[..port_start];
99						let port_part = &addr_str[port_start + 1..];
100						let formatted = format!("[{}]:{}", ipv6_part, port_part);
101						return formatted
102							.to_socket_addrs()?
103							.next()
104							.ok_or_else(|| "Failed to resolve address".into());
105					}
106				}
107			}
108			// Try as-is
109			addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
110		} else {
111			// Regular address (hostname or IPv4)
112			addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
113		}
114	}
115
116	/// Create a new WebSocket client
117	pub fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn std::error::Error>> {
118		// Resolve the address to get the first valid SocketAddr
119		let socket_addr = addr.to_socket_addrs()?.next().ok_or("Failed to resolve address")?;
120
121		let (command_tx, command_rx) = mpsc::channel();
122		let router = Arc::new(Mutex::new(RequestRouter::new()));
123
124		// Verify connection by creating a test WebSocket client
125		let test_client = WebSocketClient::connect(socket_addr)?;
126		drop(test_client); // Close test connection
127
128		// Start the background worker thread
129		let router_clone = router.clone();
130		let socket_addr_clone = socket_addr;
131		let worker_handle = std::thread::spawn(move || {
132			worker::worker_thread_with_addr(socket_addr_clone, command_rx, router_clone);
133		});
134
135		Ok(Self {
136			inner: Arc::new(ClientInner {
137				command_tx,
138				worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
139			}),
140		})
141	}
142
143	/// Create a blocking session
144	pub fn blocking_session(&self, token: Option<String>) -> Result<WsBlockingSession, reifydb_type::Error> {
145		WsBlockingSession::new(self.inner.clone(), token)
146	}
147
148	/// Create a callback-based session
149	pub fn callback_session(&self, token: Option<String>) -> Result<WsCallbackSession, reifydb_type::Error> {
150		WsCallbackSession::new(self.inner.clone(), token)
151	}
152
153	/// Create a channel-based session
154	pub fn channel_session(
155		&self,
156		token: Option<String>,
157	) -> Result<(WsChannelSession, mpsc::Receiver<ResponseMessage>), reifydb_type::Error> {
158		WsChannelSession::new(self.inner.clone(), token)
159	}
160
161	/// Close the client connection
162	pub fn close(self) -> Result<(), Box<dyn std::error::Error>> {
163		self.inner.command_tx.send(InternalMessage::Close)?;
164
165		// Wait for worker thread to finish
166		if let Ok(mut handle_guard) = self.inner.worker_handle.lock() {
167			if let Some(handle) = handle_guard.take() {
168				let _ = handle.join();
169			}
170		}
171		Ok(())
172	}
173}
174
175impl Drop for WsClient {
176	fn drop(&mut self) {
177		let _ = self.inner.command_tx.send(InternalMessage::Close);
178	}
179}
180
/// A single WebSocket connection over a plain, non-blocking `TcpStream`.
///
/// Unlike [`WsClient`], this type owns the socket directly and performs the
/// handshake, framing and frame parsing itself.
pub struct WebSocketClient {
	/// Underlying TCP connection; switched to non-blocking mode in `connect`.
	pub(crate) stream: TcpStream,
	/// Bytes read from the socket that have not yet been consumed as frames.
	read_buffer: Vec<u8>,
	/// Set once the handshake succeeds; cleared on close, EOF or a close frame.
	pub(crate) is_connected: bool,
}
187
188impl WebSocketClient {
189	/// Create a new WebSocket client and connect to the specified address
190	pub fn connect(addr: SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
191		// Connect to the socket address
192		let stream = TcpStream::connect(addr)?;
193		stream.set_nonblocking(true)?;
194
195		let mut client = WebSocketClient {
196			stream,
197			read_buffer: Vec::with_capacity(4096),
198			is_connected: false,
199		};
200
201		// Perform WebSocket handshake
202		client.handshake()?;
203
204		Ok(client)
205	}
206
207	/// Perform WebSocket handshake
208	fn handshake(&mut self) -> Result<(), Box<dyn std::error::Error>> {
209		// Generate WebSocket key
210		let key = generate_websocket_key();
211
212		// Build handshake request
213		let request = format!(
214			"GET / HTTP/1.1\r\n\
215             Host: localhost\r\n\
216             Upgrade: websocket\r\n\
217             Connection: Upgrade\r\n\
218             Sec-WebSocket-Key: {}\r\n\
219             Sec-WebSocket-Version: 13\r\n\
220             \r\n",
221			key
222		);
223
224		// Send handshake
225		self.stream.write_all(request.as_bytes())?;
226		self.stream.flush()?;
227
228		// Read response with timeout
229		let mut response = Vec::new();
230		let mut buffer = [0u8; 1024];
231		let start = std::time::Instant::now();
232		let timeout = std::time::Duration::from_secs(5);
233
234		loop {
235			match self.stream.read(&mut buffer) {
236				Ok(0) => return Err("Connection closed during handshake".into()),
237				Ok(n) => {
238					response.extend_from_slice(&buffer[..n]);
239
240					// Check if we have the complete HTTP
241					// response
242					if let Some(end_pos) = find_header_end(&response) {
243						response.truncate(end_pos);
244						break;
245					}
246				}
247				Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
248					// No data available yet
249					if start.elapsed() > timeout {
250						return Err("Handshake timeout".into());
251					}
252					std::thread::sleep(std::time::Duration::from_millis(10));
253					continue;
254				}
255				Err(e) => return Err(e.into()),
256			}
257		}
258
259		// Verify handshake response
260		let response_str = String::from_utf8_lossy(&response);
261		if !response_str.contains("HTTP/1.1 101") {
262			return Err(format!("Invalid handshake response: {}", response_str).into());
263		}
264
265		// Verify Sec-WebSocket-Accept
266		let expected_accept = calculate_accept_key(&key);
267		if !response_str.contains(&format!("Sec-WebSocket-Accept: {}", expected_accept)) {
268			return Err("Invalid Sec-WebSocket-Accept".into());
269		}
270
271		self.is_connected = true;
272		Ok(())
273	}
274
275	/// Send a request over the WebSocket connection
276	pub(crate) fn send_request(&mut self, request: &Request) -> Result<(), Box<dyn std::error::Error>> {
277		if !self.is_connected {
278			return Err("Not connected".into());
279		}
280
281		// Serialize request to JSON
282		let json = serde_json::to_string(request)?;
283		let payload = json.as_bytes();
284
285		// Build WebSocket frame (text frame, opcode = 1)
286		let frame = build_ws_frame(0x01, payload, true);
287
288		// Send frame
289		self.stream.write_all(&frame)?;
290		self.stream.flush()?;
291
292		Ok(())
293	}
294
295	/// Receive a response from the WebSocket connection
296	pub fn receive(&mut self) -> Result<Option<Response>, Box<dyn std::error::Error>> {
297		if !self.is_connected {
298			return Err("Not connected".into());
299		}
300
301		// Read data into buffer
302		let mut buf = vec![0u8; 4096];
303		match self.stream.read(&mut buf) {
304			Ok(0) => {
305				self.is_connected = false;
306				return Err("Connection closed".into());
307			}
308			Ok(n) => {
309				self.read_buffer.extend_from_slice(&buf[..n]);
310			}
311			Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
312				// No data available
313				return Ok(None);
314			}
315			Err(e) => return Err(e.into()),
316		}
317
318		// Try to parse WebSocket frame
319		if let Some((opcode, payload)) = parse_ws_frame(&self.read_buffer)? {
320			// Remove parsed frame from buffer
321			let frame_size = calculate_frame_size(&payload, false);
322			self.read_buffer.drain(..frame_size);
323
324			match opcode {
325				0x01 | 0x02 => {
326					// Text or binary frame
327					let response: Response = serde_json::from_slice(&payload)?;
328					return Ok(Some(response));
329				}
330				0x08 => {
331					// Close frame
332					self.is_connected = false;
333					return Err("Connection closed by server".into());
334				}
335				0x09 => {
336					// Ping frame - respond with pong
337					let pong = build_ws_frame(0x0A, &payload, true);
338					self.stream.write_all(&pong)?;
339					self.stream.flush()?;
340				}
341				0x0A => {
342					// Pong frame - ignore
343				}
344				_ => {
345					// Unknown opcode
346					return Err(format!("Unknown opcode: {}", opcode).into());
347				}
348			}
349		}
350
351		Ok(None)
352	}
353
354	/// Close the WebSocket connection
355	pub fn close(&mut self) -> Result<(), Box<dyn std::error::Error>> {
356		if self.is_connected {
357			// Send close frame
358			let close_frame = build_ws_frame(0x08, &[], true);
359			self.stream.write_all(&close_frame)?;
360			self.stream.flush()?;
361			self.is_connected = false;
362		}
363		Ok(())
364	}
365
	/// Report whether the handshake has completed and no close, EOF or close
	/// frame has been observed since.
	///
	/// This reflects the locally tracked flag only; it does not probe the socket.
	pub fn is_connected(&self) -> bool {
		self.is_connected
	}
370}
371
372impl Drop for WebSocketClient {
373	fn drop(&mut self) {
374		let _ = self.close();
375	}
376}