1use std::{
5 collections::HashMap,
6 io::{BufRead, BufReader, Read, Write},
7 net::{SocketAddr, TcpStream, ToSocketAddrs},
8 sync::{Arc, Mutex, mpsc},
9 thread::{self, JoinHandle},
10 time::Duration,
11};
12
13use serde_json;
14
15use crate::{
16 CommandRequest, CommandResponse, ErrResponse, QueryRequest, QueryResponse,
17 http::{message::HttpInternalMessage, worker::http_worker_thread},
18};
19
20#[derive(Clone)]
22pub struct HttpClient {
23 inner: Arc<HttpClientInner>,
24}
25
26pub(crate) struct HttpClientInner {
27 pub(crate) command_tx: mpsc::Sender<HttpInternalMessage>,
28 worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
29}
30
/// Connection parameters used by the worker to talk to the server.
#[derive(Clone)]
pub(crate) struct HttpClientConfig {
    /// Server host. `HttpClient::new` stores a resolved IP literal here
    /// (IPv6 without brackets).
    pub(crate) host: String,
    /// Server TCP port.
    pub(crate) port: u16,
    /// Timeout budget; `HttpClient::new` sets it to 30 seconds.
    pub(crate) _timeout: Duration,
}
38
39impl Drop for HttpClient {
40 fn drop(&mut self) {
41 let _ = self.inner.command_tx.send(HttpInternalMessage::Close);
42 }
43}
44
45impl HttpClient {
46 pub fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn std::error::Error>> {
48 let socket_addr = addr.to_socket_addrs()?.next().ok_or("Failed to resolve address")?;
50
51 let host = socket_addr.ip().to_string();
52 let port = socket_addr.port();
53
54 let config = HttpClientConfig {
55 host,
56 port,
57 _timeout: Duration::from_secs(30),
58 };
59
60 Self::with_config(config)
61 }
62
63 pub fn from_url(url: &str) -> Result<Self, Box<dyn std::error::Error>> {
65 let url = if url.starts_with("http://") {
66 &url[7..] } else if url.starts_with("https://") {
68 return Err("HTTPS is not yet supported".into());
69 } else {
70 url
71 };
72
73 let (host, port) = if url.starts_with('[') {
75 if let Some(end_bracket) = url.find(']') {
77 let host = &url[1..end_bracket];
78 let port_str = &url[end_bracket + 1..];
79 let port = if port_str.starts_with(':') {
80 port_str[1..].parse()?
81 } else {
82 80
83 };
84 (host.to_string(), port)
85 } else {
86 return Err("Invalid IPv6 address format".into());
87 }
88 } else if url.starts_with("::") || url.contains("::") {
89 if let Some(port_idx) = url.rfind(':') {
92 if url[port_idx + 1..].chars().all(|c| c.is_ascii_digit()) {
95 let host = &url[..port_idx];
96 let port: u16 = url[port_idx + 1..].parse()?;
97 (host.to_string(), port)
98 } else {
99 (url.to_string(), 80)
101 }
102 } else {
103 (url.to_string(), 80)
104 }
105 } else {
106 if let Some(colon_idx) = url.find(':') {
108 let host = &url[..colon_idx];
109 let port: u16 = url[colon_idx + 1..].parse()?;
110 (host.to_string(), port)
111 } else {
112 (url.to_string(), 80)
113 }
114 };
115
116 Self::new((host.as_str(), port))
117 }
118
119 fn with_config(config: HttpClientConfig) -> Result<Self, Box<dyn std::error::Error>> {
121 let (command_tx, command_rx) = mpsc::channel();
122
123 let test_config = config.clone();
125 test_config.test_connection()?;
126
127 let worker_config = config.clone();
129 let worker_handle = thread::spawn(move || {
130 http_worker_thread(worker_config, command_rx);
131 });
132
133 Ok(Self {
134 inner: Arc::new(HttpClientInner {
135 command_tx,
136 worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
137 }),
138 })
139 }
140
141 pub(crate) fn command_tx(&self) -> &mpsc::Sender<HttpInternalMessage> {
143 &self.inner.command_tx
144 }
145
146 pub fn close(self) -> Result<(), Box<dyn std::error::Error>> {
148 self.inner.command_tx.send(HttpInternalMessage::Close)?;
149
150 if let Ok(mut handle_guard) = self.inner.worker_handle.lock() {
152 if let Some(handle) = handle_guard.take() {
153 let _ = handle.join();
154 }
155 }
156 Ok(())
157 }
158
159 pub fn test_connection(&self) -> Result<(), Box<dyn std::error::Error>> {
161 Ok(())
163 }
164
165 pub fn blocking_session(
167 &self,
168 token: Option<String>,
169 ) -> Result<crate::http::HttpBlockingSession, reifydb_type::Error> {
170 crate::http::HttpBlockingSession::from_client(self.clone(), token)
171 }
172
173 pub fn callback_session(
175 &self,
176 token: Option<String>,
177 ) -> Result<crate::http::HttpCallbackSession, reifydb_type::Error> {
178 crate::http::HttpCallbackSession::from_client(self.clone(), token)
179 }
180
181 pub fn channel_session(
183 &self,
184 token: Option<String>,
185 ) -> Result<
186 (crate::http::HttpChannelSession, mpsc::Receiver<crate::http::HttpResponseMessage>),
187 reifydb_type::Error,
188 > {
189 crate::http::HttpChannelSession::from_client(self.clone(), token)
190 }
191}
192
193impl HttpClientConfig {
194 pub fn send_command(&self, request: &CommandRequest) -> Result<CommandResponse, reifydb_type::Error> {
196 let json_body = serde_json::to_string(request).map_err(|e| {
197 reifydb_type::Error(reifydb_type::diagnostic::internal(format!(
198 "Failed to serialize request: {}",
199 e
200 )))
201 })?;
202 let response_body = self.send_request("/v1/command", &json_body).map_err(|e| {
203 reifydb_type::Error(reifydb_type::diagnostic::internal(format!("Request failed: {}", e)))
204 })?;
205
206 match serde_json::from_str::<CommandResponse>(&response_body) {
208 Ok(response) => Ok(response),
209 Err(_) => {
210 match serde_json::from_str::<ErrResponse>(&response_body) {
212 Ok(err_response) => Err(reifydb_type::Error(err_response.diagnostic)),
213 Err(_) => Err(reifydb_type::Error(reifydb_type::diagnostic::internal(
214 format!("Failed to parse response: {}", response_body),
215 ))),
216 }
217 }
218 }
219 }
220
221 pub fn send_query(&self, request: &QueryRequest) -> Result<QueryResponse, reifydb_type::Error> {
223 let json_body = serde_json::to_string(request).map_err(|e| {
224 reifydb_type::Error(reifydb_type::diagnostic::internal(format!(
225 "Failed to serialize request: {}",
226 e
227 )))
228 })?;
229 let response_body = self.send_request("/v1/query", &json_body).map_err(|e| {
230 reifydb_type::Error(reifydb_type::diagnostic::internal(format!("Request failed: {}", e)))
231 })?;
232
233 match serde_json::from_str::<QueryResponse>(&response_body) {
235 Ok(response) => Ok(response),
236 Err(_) => {
237 match serde_json::from_str::<ErrResponse>(&response_body) {
239 Ok(err_response) => Err(reifydb_type::Error(err_response.diagnostic)),
240 Err(_) => Err(reifydb_type::Error(reifydb_type::diagnostic::internal(
241 format!("Failed to parse response: {}", response_body),
242 ))),
243 }
244 }
245 }
246 }
247
248 fn send_request(&self, path: &str, body: &str) -> Result<String, Box<dyn std::error::Error>> {
250 let addr_str = if self.host.contains(':') {
253 format!("[{}]:{}", self.host, self.port)
254 } else {
255 format!("{}:{}", self.host, self.port)
256 };
257 let addr: SocketAddr = addr_str.parse()?;
258
259 let mut stream = TcpStream::connect(addr)?;
261
262 let body_bytes = body.as_bytes();
264
265 let header = format!(
267 "POST {} HTTP/1.1\r\n\
268 Host: {}\r\n\
269 Content-Type: application/json\r\n\
270 Content-Length: {}\r\n\
271 Connection: close\r\n\
272 \r\n",
273 path,
274 self.host,
275 body_bytes.len()
276 );
277
278 stream.write_all(header.as_bytes())?;
280 stream.write_all(body_bytes)?;
281 stream.flush()?;
282
283 self.parse_http_response_buffered(stream)
285 }
286
287 fn parse_http_response_buffered(&self, stream: TcpStream) -> Result<String, Box<dyn std::error::Error>> {
289 let mut reader = BufReader::new(stream);
290 let mut line = String::new();
291
292 reader.read_line(&mut line)?;
294 let status_line = line.trim_end();
295 let status_parts: Vec<&str> = status_line.split_whitespace().collect();
296
297 if status_parts.len() < 3 {
298 return Err("Invalid HTTP status line".into());
299 }
300
301 let status_code: u16 = status_parts[1].parse()?;
302 if status_code < 200 || status_code >= 300 {
303 return Err(
304 format!("HTTP error {}: {}", status_code, status_parts.get(2).unwrap_or(&"")).into()
305 );
306 }
307
308 let mut headers = HashMap::new();
310 let mut content_length: Option<usize> = None;
311 let mut is_chunked = false;
312
313 loop {
314 line.clear();
315 reader.read_line(&mut line)?;
316
317 if line == "\r\n" || line == "\n" {
318 break; }
320
321 if let Some(colon_pos) = line.find(':') {
322 let key = line[..colon_pos].trim().to_lowercase();
323 let value = line[colon_pos + 1..].trim().to_string();
324
325 if key == "content-length" {
326 content_length = value.parse().ok();
327 } else if key == "transfer-encoding" && value.contains("chunked") {
328 is_chunked = true;
329 }
330
331 headers.insert(key, value);
332 }
333 }
334
335 let body = if is_chunked {
337 self.read_chunked_body(&mut reader)?
338 } else if let Some(length) = content_length {
339 let mut body = vec![0u8; length];
341 reader.read_exact(&mut body)?;
342 String::from_utf8(body)?
343 } else {
344 let mut body = String::new();
346 reader.read_to_string(&mut body)?;
347 body
348 };
349
350 Ok(body)
351 }
352
353 fn read_chunked_body(&self, reader: &mut BufReader<TcpStream>) -> Result<String, Box<dyn std::error::Error>> {
355 let mut result = Vec::new();
356 let mut line = String::new();
357
358 loop {
359 line.clear();
361 reader.read_line(&mut line)?;
362
363 let size_str = line.trim().split(';').next().unwrap_or("0");
366 let chunk_size = usize::from_str_radix(size_str, 16)?;
367
368 if chunk_size == 0 {
369 loop {
371 line.clear();
372 reader.read_line(&mut line)?;
373 if line == "\r\n" || line == "\n" {
374 break;
375 }
376 }
377 break;
378 }
379
380 let mut chunk = vec![0u8; chunk_size];
382 reader.read_exact(&mut chunk)?;
383 result.extend_from_slice(&chunk);
384
385 line.clear();
387 reader.read_line(&mut line)?;
388 }
389
390 String::from_utf8(result).map_err(|e| e.into())
391 }
392
393 pub fn test_connection(&self) -> Result<(), Box<dyn std::error::Error>> {
395 let addr_str = if self.host.contains(':') {
397 format!("[{}]:{}", self.host, self.port)
398 } else {
399 format!("{}:{}", self.host, self.port)
400 };
401 let addr: SocketAddr = addr_str.parse()?;
402 let _stream = TcpStream::connect(addr)?;
403 Ok(())
404 }
405}