reifydb_client/http/
client.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the MIT
3
4use std::{
5	collections::HashMap,
6	io::{BufRead, BufReader, Read, Write},
7	net::{SocketAddr, TcpStream, ToSocketAddrs},
8	sync::{Arc, Mutex, mpsc},
9	thread::{self, JoinHandle},
10	time::Duration,
11};
12
13use serde_json;
14
15use crate::{
16	CommandRequest, CommandResponse, ErrResponse, QueryRequest, QueryResponse,
17	http::{message::HttpInternalMessage, worker::http_worker_thread},
18};
19
20/// HTTP client implementation with worker thread
21#[derive(Clone)]
22pub struct HttpClient {
23	inner: Arc<HttpClientInner>,
24}
25
26pub(crate) struct HttpClientInner {
27	pub(crate) command_tx: mpsc::Sender<HttpInternalMessage>,
28	worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
29}
30
/// Connection settings handed to the HTTP worker thread.
///
/// `host` holds the textual IP address of the server (IPv6 literals are
/// stored without brackets), and `port` its TCP port.
#[derive(Clone)]
pub(crate) struct HttpClientConfig {
	/// Server address as text, e.g. `127.0.0.1` or `::1`.
	pub(crate) host: String,
	/// Server TCP port.
	pub(crate) port: u16,
	/// Timeout setting; `HttpClient::new` uses 30 seconds.
	/// NOTE(review): the leading underscore suggests it is not used
	/// everywhere; check its uses before relying on it.
	pub(crate) _timeout: Duration,
}
38
39impl Drop for HttpClient {
40	fn drop(&mut self) {
41		let _ = self.inner.command_tx.send(HttpInternalMessage::Close);
42	}
43}
44
45impl HttpClient {
46	/// Create a new HTTP client from a socket address
47	pub fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn std::error::Error>> {
48		// Resolve the address to get the first valid SocketAddr
49		let socket_addr = addr.to_socket_addrs()?.next().ok_or("Failed to resolve address")?;
50
51		let host = socket_addr.ip().to_string();
52		let port = socket_addr.port();
53
54		let config = HttpClientConfig {
55			host,
56			port,
57			_timeout: Duration::from_secs(30),
58		};
59
60		Self::with_config(config)
61	}
62
63	/// Create HTTP client from URL (e.g., "http://localhost:8080")
64	pub fn from_url(url: &str) -> Result<Self, Box<dyn std::error::Error>> {
65		let url = if url.starts_with("http://") {
66			&url[7..] // Remove "http://"
67		} else if url.starts_with("https://") {
68			return Err("HTTPS is not yet supported".into());
69		} else {
70			url
71		};
72
73		// Parse host and port, handling IPv6 addresses
74		let (host, port) = if url.starts_with('[') {
75			// IPv6 address with brackets: [::1]:8080
76			if let Some(end_bracket) = url.find(']') {
77				let host = &url[1..end_bracket];
78				let port_str = &url[end_bracket + 1..];
79				let port = if port_str.starts_with(':') {
80					port_str[1..].parse()?
81				} else {
82					80
83				};
84				(host.to_string(), port)
85			} else {
86				return Err("Invalid IPv6 address format".into());
87			}
88		} else if url.starts_with("::") || url.contains("::") {
89			// IPv6 address without brackets: ::1:8080
90			// Find the last colon that's likely the port separator
91			if let Some(port_idx) = url.rfind(':') {
92				// Check if what follows the last colon is a
93				// port number
94				if url[port_idx + 1..].chars().all(|c| c.is_ascii_digit()) {
95					let host = &url[..port_idx];
96					let port: u16 = url[port_idx + 1..].parse()?;
97					(host.to_string(), port)
98				} else {
99					// No port specified, use default
100					(url.to_string(), 80)
101				}
102			} else {
103				(url.to_string(), 80)
104			}
105		} else {
106			// Regular hostname or IPv4 address
107			if let Some(colon_idx) = url.find(':') {
108				let host = &url[..colon_idx];
109				let port: u16 = url[colon_idx + 1..].parse()?;
110				(host.to_string(), port)
111			} else {
112				(url.to_string(), 80)
113			}
114		};
115
116		Self::new((host.as_str(), port))
117	}
118
119	/// Create HTTP client with specific configuration
120	fn with_config(config: HttpClientConfig) -> Result<Self, Box<dyn std::error::Error>> {
121		let (command_tx, command_rx) = mpsc::channel();
122
123		// Test connection first
124		let test_config = config.clone();
125		test_config.test_connection()?;
126
127		// Start the background worker thread
128		let worker_config = config.clone();
129		let worker_handle = thread::spawn(move || {
130			http_worker_thread(worker_config, command_rx);
131		});
132
133		Ok(Self {
134			inner: Arc::new(HttpClientInner {
135				command_tx,
136				worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
137			}),
138		})
139	}
140
141	/// Get the command sender for internal use
142	pub(crate) fn command_tx(&self) -> &mpsc::Sender<HttpInternalMessage> {
143		&self.inner.command_tx
144	}
145
146	/// Close the client connection
147	pub fn close(self) -> Result<(), Box<dyn std::error::Error>> {
148		self.inner.command_tx.send(HttpInternalMessage::Close)?;
149
150		// Wait for worker thread to finish
151		if let Ok(mut handle_guard) = self.inner.worker_handle.lock() {
152			if let Some(handle) = handle_guard.take() {
153				let _ = handle.join();
154			}
155		}
156		Ok(())
157	}
158
159	/// Test connection to the server
160	pub fn test_connection(&self) -> Result<(), Box<dyn std::error::Error>> {
161		// The connection was already tested during creation
162		Ok(())
163	}
164
165	/// Create a blocking session
166	pub fn blocking_session(
167		&self,
168		token: Option<String>,
169	) -> Result<crate::http::HttpBlockingSession, reifydb_type::Error> {
170		crate::http::HttpBlockingSession::from_client(self.clone(), token)
171	}
172
173	/// Create a callback session
174	pub fn callback_session(
175		&self,
176		token: Option<String>,
177	) -> Result<crate::http::HttpCallbackSession, reifydb_type::Error> {
178		crate::http::HttpCallbackSession::from_client(self.clone(), token)
179	}
180
181	/// Create a channel session
182	pub fn channel_session(
183		&self,
184		token: Option<String>,
185	) -> Result<
186		(crate::http::HttpChannelSession, mpsc::Receiver<crate::http::HttpResponseMessage>),
187		reifydb_type::Error,
188	> {
189		crate::http::HttpChannelSession::from_client(self.clone(), token)
190	}
191}
192
193impl HttpClientConfig {
194	/// Send a command request
195	pub fn send_command(&self, request: &CommandRequest) -> Result<CommandResponse, reifydb_type::Error> {
196		let json_body = serde_json::to_string(request).map_err(|e| {
197			reifydb_type::Error(reifydb_type::diagnostic::internal(format!(
198				"Failed to serialize request: {}",
199				e
200			)))
201		})?;
202		let response_body = self.send_request("/v1/command", &json_body).map_err(|e| {
203			reifydb_type::Error(reifydb_type::diagnostic::internal(format!("Request failed: {}", e)))
204		})?;
205
206		// Try to parse as CommandResponse first, then as error
207		match serde_json::from_str::<CommandResponse>(&response_body) {
208			Ok(response) => Ok(response),
209			Err(_) => {
210				// Try parsing as error response
211				match serde_json::from_str::<ErrResponse>(&response_body) {
212					Ok(err_response) => Err(reifydb_type::Error(err_response.diagnostic)),
213					Err(_) => Err(reifydb_type::Error(reifydb_type::diagnostic::internal(
214						format!("Failed to parse response: {}", response_body),
215					))),
216				}
217			}
218		}
219	}
220
221	/// Send a query request
222	pub fn send_query(&self, request: &QueryRequest) -> Result<QueryResponse, reifydb_type::Error> {
223		let json_body = serde_json::to_string(request).map_err(|e| {
224			reifydb_type::Error(reifydb_type::diagnostic::internal(format!(
225				"Failed to serialize request: {}",
226				e
227			)))
228		})?;
229		let response_body = self.send_request("/v1/query", &json_body).map_err(|e| {
230			reifydb_type::Error(reifydb_type::diagnostic::internal(format!("Request failed: {}", e)))
231		})?;
232
233		// Try to parse as QueryResponse first, then as error
234		match serde_json::from_str::<QueryResponse>(&response_body) {
235			Ok(response) => Ok(response),
236			Err(_) => {
237				// Try parsing as error response
238				match serde_json::from_str::<ErrResponse>(&response_body) {
239					Ok(err_response) => Err(reifydb_type::Error(err_response.diagnostic)),
240					Err(_) => Err(reifydb_type::Error(reifydb_type::diagnostic::internal(
241						format!("Failed to parse response: {}", response_body),
242					))),
243				}
244			}
245		}
246	}
247
248	/// Send HTTP request and return response body
249	fn send_request(&self, path: &str, body: &str) -> Result<String, Box<dyn std::error::Error>> {
250		// Parse socket address
251		// Check if host is an IPv6 address by looking for colons
252		let addr_str = if self.host.contains(':') {
253			format!("[{}]:{}", self.host, self.port)
254		} else {
255			format!("{}:{}", self.host, self.port)
256		};
257		let addr: SocketAddr = addr_str.parse()?;
258
259		// Create TCP connection
260		let mut stream = TcpStream::connect(addr)?;
261
262		// Convert body to bytes first to get accurate Content-Length
263		let body_bytes = body.as_bytes();
264
265		// Build HTTP request header
266		let header = format!(
267			"POST {} HTTP/1.1\r\n\
268			Host: {}\r\n\
269			Content-Type: application/json\r\n\
270			Content-Length: {}\r\n\
271			Connection: close\r\n\
272			\r\n",
273			path,
274			self.host,
275			body_bytes.len()
276		);
277
278		// Send request header and body
279		stream.write_all(header.as_bytes())?;
280		stream.write_all(body_bytes)?;
281		stream.flush()?;
282
283		// Parse HTTP response using buffered reader
284		self.parse_http_response_buffered(stream)
285	}
286
287	/// Parse HTTP response using buffered reading for large responses
288	fn parse_http_response_buffered(&self, stream: TcpStream) -> Result<String, Box<dyn std::error::Error>> {
289		let mut reader = BufReader::new(stream);
290		let mut line = String::new();
291
292		// Read status line
293		reader.read_line(&mut line)?;
294		let status_line = line.trim_end();
295		let status_parts: Vec<&str> = status_line.split_whitespace().collect();
296
297		if status_parts.len() < 3 {
298			return Err("Invalid HTTP status line".into());
299		}
300
301		let status_code: u16 = status_parts[1].parse()?;
302		if status_code < 200 || status_code >= 300 {
303			return Err(
304				format!("HTTP error {}: {}", status_code, status_parts.get(2).unwrap_or(&"")).into()
305			);
306		}
307
308		// Read headers
309		let mut headers = HashMap::new();
310		let mut content_length: Option<usize> = None;
311		let mut is_chunked = false;
312
313		loop {
314			line.clear();
315			reader.read_line(&mut line)?;
316
317			if line == "\r\n" || line == "\n" {
318				break; // End of headers
319			}
320
321			if let Some(colon_pos) = line.find(':') {
322				let key = line[..colon_pos].trim().to_lowercase();
323				let value = line[colon_pos + 1..].trim().to_string();
324
325				if key == "content-length" {
326					content_length = value.parse().ok();
327				} else if key == "transfer-encoding" && value.contains("chunked") {
328					is_chunked = true;
329				}
330
331				headers.insert(key, value);
332			}
333		}
334
335		// Read body based on transfer method
336		let body = if is_chunked {
337			self.read_chunked_body(&mut reader)?
338		} else if let Some(length) = content_length {
339			// Read exact content length
340			let mut body = vec![0u8; length];
341			reader.read_exact(&mut body)?;
342			String::from_utf8(body)?
343		} else {
344			// Read until EOF (Connection: close)
345			let mut body = String::new();
346			reader.read_to_string(&mut body)?;
347			body
348		};
349
350		Ok(body)
351	}
352
353	/// Read chunked HTTP response body
354	fn read_chunked_body(&self, reader: &mut BufReader<TcpStream>) -> Result<String, Box<dyn std::error::Error>> {
355		let mut result = Vec::new();
356		let mut line = String::new();
357
358		loop {
359			// Read chunk size line
360			line.clear();
361			reader.read_line(&mut line)?;
362
363			// Parse chunk size (hexadecimal), ignoring any chunk
364			// extensions after ';'
365			let size_str = line.trim().split(';').next().unwrap_or("0");
366			let chunk_size = usize::from_str_radix(size_str, 16)?;
367
368			if chunk_size == 0 {
369				// Last chunk - read trailing headers if any
370				loop {
371					line.clear();
372					reader.read_line(&mut line)?;
373					if line == "\r\n" || line == "\n" {
374						break;
375					}
376				}
377				break;
378			}
379
380			// Read exact chunk data
381			let mut chunk = vec![0u8; chunk_size];
382			reader.read_exact(&mut chunk)?;
383			result.extend_from_slice(&chunk);
384
385			// Read trailing CRLF after chunk data
386			line.clear();
387			reader.read_line(&mut line)?;
388		}
389
390		String::from_utf8(result).map_err(|e| e.into())
391	}
392
393	/// Test connection to the server
394	pub fn test_connection(&self) -> Result<(), Box<dyn std::error::Error>> {
395		// Check if host is an IPv6 address by looking for colons
396		let addr_str = if self.host.contains(':') {
397			format!("[{}]:{}", self.host, self.port)
398		} else {
399			format!("{}:{}", self.host, self.port)
400		};
401		let addr: SocketAddr = addr_str.parse()?;
402		let _stream = TcpStream::connect(addr)?;
403		Ok(())
404	}
405}