1use std::{
5 io::{Read, Write},
6 net::{SocketAddr, TcpStream, ToSocketAddrs},
7 sync::{Arc, Mutex, mpsc},
8 thread::JoinHandle,
9};
10
11use crate::{
12 Request, Response, ResponseMessage, WsBlockingSession, WsCallbackSession, WsChannelSession,
13 ws::{
14 message::InternalMessage,
15 protocol::{
16 build_ws_frame, calculate_accept_key, calculate_frame_size, find_header_end,
17 generate_websocket_key, parse_ws_frame,
18 },
19 router::RequestRouter,
20 worker,
21 },
22};
23
/// Handle to a WebSocket client whose connection is serviced by a background
/// worker thread.
///
/// The handle is cheap to clone: every clone points at the same
/// [`ClientInner`] through an `Arc`, so all clones talk to the same worker.
#[derive(Clone)]
pub struct WsClient {
    /// State shared between all clones of this client and the sessions
    /// created from it.
    inner: Arc<ClientInner>,
}
29
/// Shared state behind a [`WsClient`]: the command channel to the worker
/// thread and the handle used to join that thread.
pub(crate) struct ClientInner {
    /// Sending half of the channel whose receiving half is handed to the
    /// worker thread when the client is constructed.
    pub(crate) command_tx: mpsc::Sender<InternalMessage>,
    /// Join handle of the worker thread. It is wrapped in `Option` so that
    /// `WsClient::close` can take it out and join it exactly once.
    // NOTE(review): `ClientInner` is itself only ever stored inside an `Arc`
    // in this file, so the extra `Arc` here looks redundant — confirm before
    // simplifying, since the constructors build this exact type.
    worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
34
35impl WsClient {
40 pub fn from_url(url: &str) -> Result<Self, Box<dyn std::error::Error>> {
42 let socket_addr = Self::parse_ws_url(url)?;
44
45 let (command_tx, command_rx) = mpsc::channel();
46 let router = Arc::new(Mutex::new(RequestRouter::new()));
47
48 let test_client = WebSocketClient::connect(socket_addr)?;
50 drop(test_client); let router_clone = router.clone();
54 let socket_addr_clone = socket_addr;
55 let worker_handle = std::thread::spawn(move || {
56 worker::worker_thread_with_addr(socket_addr_clone, command_rx, router_clone);
57 });
58
59 Ok(Self {
60 inner: Arc::new(ClientInner {
61 command_tx,
62 worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
63 }),
64 })
65 }
66
67 fn parse_ws_url(url: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
69 let addr_str = if url.starts_with("ws://") {
70 &url[5..] } else if url.starts_with("wss://") {
72 return Err("WSS (secure WebSocket) is not yet supported".into());
73 } else {
74 url
75 };
76
77 if addr_str.starts_with('[') {
85 addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
87 } else if addr_str.starts_with("::") {
88 let colon_count = addr_str.matches(':').count();
92 if colon_count > 2 {
93 if let Some(port_start) = addr_str.rfind(':') {
95 if addr_str[port_start + 1..].chars().all(|c| c.is_ascii_digit()) {
98 let ipv6_part = &addr_str[..port_start];
99 let port_part = &addr_str[port_start + 1..];
100 let formatted = format!("[{}]:{}", ipv6_part, port_part);
101 return formatted
102 .to_socket_addrs()?
103 .next()
104 .ok_or_else(|| "Failed to resolve address".into());
105 }
106 }
107 }
108 addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
110 } else {
111 addr_str.to_socket_addrs()?.next().ok_or_else(|| "Failed to resolve address".into())
113 }
114 }
115
116 pub fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn std::error::Error>> {
118 let socket_addr = addr.to_socket_addrs()?.next().ok_or("Failed to resolve address")?;
120
121 let (command_tx, command_rx) = mpsc::channel();
122 let router = Arc::new(Mutex::new(RequestRouter::new()));
123
124 let test_client = WebSocketClient::connect(socket_addr)?;
126 drop(test_client); let router_clone = router.clone();
130 let socket_addr_clone = socket_addr;
131 let worker_handle = std::thread::spawn(move || {
132 worker::worker_thread_with_addr(socket_addr_clone, command_rx, router_clone);
133 });
134
135 Ok(Self {
136 inner: Arc::new(ClientInner {
137 command_tx,
138 worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
139 }),
140 })
141 }
142
143 pub fn blocking_session(&self, token: Option<String>) -> Result<WsBlockingSession, reifydb_type::Error> {
145 WsBlockingSession::new(self.inner.clone(), token)
146 }
147
148 pub fn callback_session(&self, token: Option<String>) -> Result<WsCallbackSession, reifydb_type::Error> {
150 WsCallbackSession::new(self.inner.clone(), token)
151 }
152
153 pub fn channel_session(
155 &self,
156 token: Option<String>,
157 ) -> Result<(WsChannelSession, mpsc::Receiver<ResponseMessage>), reifydb_type::Error> {
158 WsChannelSession::new(self.inner.clone(), token)
159 }
160
161 pub fn close(self) -> Result<(), Box<dyn std::error::Error>> {
163 self.inner.command_tx.send(InternalMessage::Close)?;
164
165 if let Ok(mut handle_guard) = self.inner.worker_handle.lock() {
167 if let Some(handle) = handle_guard.take() {
168 let _ = handle.join();
169 }
170 }
171 Ok(())
172 }
173}
174
175impl Drop for WsClient {
176 fn drop(&mut self) {
177 let _ = self.inner.command_tx.send(InternalMessage::Close);
178 }
179}
180
/// A single WebSocket connection over a plain (non-TLS) TCP stream.
pub struct WebSocketClient {
    /// Underlying TCP stream; `connect` switches it to non-blocking mode.
    pub(crate) stream: TcpStream,
    /// Bytes received from the peer that have not yet been consumed as a
    /// complete frame.
    read_buffer: Vec<u8>,
    /// Set to `true` once the opening handshake succeeds and back to `false`
    /// when the connection is closed by either side.
    pub(crate) is_connected: bool,
}
187
188impl WebSocketClient {
189 pub fn connect(addr: SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
191 let stream = TcpStream::connect(addr)?;
193 stream.set_nonblocking(true)?;
194
195 let mut client = WebSocketClient {
196 stream,
197 read_buffer: Vec::with_capacity(4096),
198 is_connected: false,
199 };
200
201 client.handshake()?;
203
204 Ok(client)
205 }
206
207 fn handshake(&mut self) -> Result<(), Box<dyn std::error::Error>> {
209 let key = generate_websocket_key();
211
212 let request = format!(
214 "GET / HTTP/1.1\r\n\
215Host: localhost\r\n\
216Upgrade: websocket\r\n\
217Connection: Upgrade\r\n\
218Sec-WebSocket-Key: {}\r\n\
219Sec-WebSocket-Version: 13\r\n\
220\r\n",
221 key
222 );
223
224 self.stream.write_all(request.as_bytes())?;
226 self.stream.flush()?;
227
228 let mut response = Vec::new();
230 let mut buffer = [0u8; 1024];
231 let start = std::time::Instant::now();
232 let timeout = std::time::Duration::from_secs(5);
233
234 loop {
235 match self.stream.read(&mut buffer) {
236 Ok(0) => return Err("Connection closed during handshake".into()),
237 Ok(n) => {
238 response.extend_from_slice(&buffer[..n]);
239
240 if let Some(end_pos) = find_header_end(&response) {
243 response.truncate(end_pos);
244 break;
245 }
246 }
247 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
248 if start.elapsed() > timeout {
250 return Err("Handshake timeout".into());
251 }
252 std::thread::sleep(std::time::Duration::from_millis(10));
253 continue;
254 }
255 Err(e) => return Err(e.into()),
256 }
257 }
258
259 let response_str = String::from_utf8_lossy(&response);
261 if !response_str.contains("HTTP/1.1 101") {
262 return Err(format!("Invalid handshake response: {}", response_str).into());
263 }
264
265 let expected_accept = calculate_accept_key(&key);
267 let response_lower = response_str.to_lowercase();
268 let accept_pattern = format!("sec-websocket-accept: {}", expected_accept).to_lowercase();
269 if !response_lower.contains(&accept_pattern) {
270 return Err(format!(
271 "Invalid Sec-WebSocket-Accept. Expected: {}, Response: {}",
272 expected_accept, response_str
273 )
274 .into());
275 }
276
277 self.is_connected = true;
278 Ok(())
279 }
280
281 pub(crate) fn send_request(&mut self, request: &Request) -> Result<(), Box<dyn std::error::Error>> {
283 if !self.is_connected {
284 return Err("Not connected".into());
285 }
286
287 let json = serde_json::to_string(request)?;
289 let payload = json.as_bytes();
290
291 let frame = build_ws_frame(0x01, payload, true);
293
294 self.stream.write_all(&frame)?;
296 self.stream.flush()?;
297
298 Ok(())
299 }
300
301 pub fn receive(&mut self) -> Result<Option<Response>, Box<dyn std::error::Error>> {
303 if !self.is_connected {
304 return Err("Not connected".into());
305 }
306
307 let mut buf = vec![0u8; 4096];
309 match self.stream.read(&mut buf) {
310 Ok(0) => {
311 self.is_connected = false;
312 return Err("Connection closed".into());
313 }
314 Ok(n) => {
315 self.read_buffer.extend_from_slice(&buf[..n]);
316 }
317 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
318 return Ok(None);
320 }
321 Err(e) => return Err(e.into()),
322 }
323
324 if let Some((opcode, payload)) = parse_ws_frame(&self.read_buffer)? {
326 let frame_size = calculate_frame_size(&payload, false);
328 self.read_buffer.drain(..frame_size);
329
330 match opcode {
331 0x01 | 0x02 => {
332 let response: Response = serde_json::from_slice(&payload)?;
334 return Ok(Some(response));
335 }
336 0x08 => {
337 self.is_connected = false;
339 return Err("Connection closed by server".into());
340 }
341 0x09 => {
342 let pong = build_ws_frame(0x0A, &payload, true);
344 self.stream.write_all(&pong)?;
345 self.stream.flush()?;
346 }
347 0x0A => {
348 }
350 _ => {
351 return Err(format!("Unknown opcode: {}", opcode).into());
353 }
354 }
355 }
356
357 Ok(None)
358 }
359
360 pub fn close(&mut self) -> Result<(), Box<dyn std::error::Error>> {
362 if self.is_connected {
363 let close_frame = build_ws_frame(0x08, &[], true);
365 self.stream.write_all(&close_frame)?;
366 self.stream.flush()?;
367 self.is_connected = false;
368 }
369 Ok(())
370 }
371
372 pub fn is_connected(&self) -> bool {
374 self.is_connected
375 }
376}
377
378impl Drop for WebSocketClient {
379 fn drop(&mut self) {
380 let _ = self.close();
381 }
382}