1use std::{
5 io::{BufRead, BufReader, Read, Write},
6 net::{SocketAddr, TcpStream, ToSocketAddrs},
7 sync::{Arc, Mutex, mpsc},
8 thread::{self, JoinHandle},
9 time::Duration,
10};
11
12use serde_json;
13
14use crate::{
15 CommandRequest, CommandResponse, ErrResponse, QueryRequest, QueryResponse,
16 http::{message::HttpInternalMessage, worker::http_worker_thread},
17};
18
/// Error payload shape the server may send in place of a regular response.
///
/// `HttpClientConfig` tries this shape after the expected response type fails to parse.
#[derive(Debug, serde::Deserialize)]
struct HttpErrorResponse {
	/// Error code; used as the diagnostic code when `diagnostic` is absent.
	code: String,
	/// Error message; used as the diagnostic message when `diagnostic` is absent.
	error: String,
	/// Full diagnostic if the server included one; a missing field deserializes to `None`.
	#[serde(default)]
	diagnostic: Option<reifydb_type::diagnostic::Diagnostic>,
}
28
/// Handle to an HTTP client whose requests are dispatched, as messages over a channel, to a
/// background worker thread.
///
/// Cloning is cheap: every clone shares the same `HttpClientInner` (and therefore the same
/// worker) through an `Arc`.
#[derive(Clone)]
pub struct HttpClient {
	inner: Arc<HttpClientInner>,
}
34
/// State shared by all clones of an [`HttpClient`].
pub(crate) struct HttpClientInner {
	/// Sending half of the channel whose receiver is owned by the worker thread.
	pub(crate) command_tx: mpsc::Sender<HttpInternalMessage>,
	/// Join handle of the worker thread; taken (left as `None`) by `HttpClient::close`.
	worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
39
/// Connection settings used both for the initial connection probe and by the worker thread.
#[derive(Clone)]
pub(crate) struct HttpClientConfig {
	/// Host to connect to. `HttpClient::new` stores the resolved IP address here; a host
	/// containing ':' is treated as an IPv6 literal and bracketed when forming `host:port`.
	pub(crate) host: String,
	/// TCP port of the server.
	pub(crate) port: u16,
	/// Timeout value; `HttpClient::new` sets it to 30 seconds.
	pub(crate) _timeout: Duration,
}
47
48impl Drop for HttpClient {
49 fn drop(&mut self) {
50 let _ = self.inner.command_tx.send(HttpInternalMessage::Close);
51 }
52}
53
54impl HttpClient {
55 pub fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn std::error::Error>> {
57 let socket_addr = addr.to_socket_addrs()?.next().ok_or("Failed to resolve address")?;
59
60 let host = socket_addr.ip().to_string();
61 let port = socket_addr.port();
62
63 let config = HttpClientConfig {
64 host,
65 port,
66 _timeout: Duration::from_secs(30),
67 };
68
69 Self::with_config(config)
70 }
71
72 pub fn from_url(url: &str) -> Result<Self, Box<dyn std::error::Error>> {
74 let url = if url.starts_with("http://") {
75 &url[7..] } else if url.starts_with("https://") {
77 return Err("HTTPS is not yet supported".into());
78 } else {
79 url
80 };
81
82 let (host, port) = if url.starts_with('[') {
84 if let Some(end_bracket) = url.find(']') {
86 let host = &url[1..end_bracket];
87 let port_str = &url[end_bracket + 1..];
88 let port = if port_str.starts_with(':') {
89 port_str[1..].parse()?
90 } else {
91 80
92 };
93 (host.to_string(), port)
94 } else {
95 return Err("Invalid IPv6 address format".into());
96 }
97 } else if url.starts_with("::") || url.contains("::") {
98 if let Some(port_idx) = url.rfind(':') {
101 if url[port_idx + 1..].chars().all(|c| c.is_ascii_digit()) {
104 let host = &url[..port_idx];
105 let port: u16 = url[port_idx + 1..].parse()?;
106 (host.to_string(), port)
107 } else {
108 (url.to_string(), 80)
110 }
111 } else {
112 (url.to_string(), 80)
113 }
114 } else {
115 if let Some(colon_idx) = url.find(':') {
117 let host = &url[..colon_idx];
118 let port: u16 = url[colon_idx + 1..].parse()?;
119 (host.to_string(), port)
120 } else {
121 (url.to_string(), 80)
122 }
123 };
124
125 Self::new((host.as_str(), port))
126 }
127
128 fn with_config(config: HttpClientConfig) -> Result<Self, Box<dyn std::error::Error>> {
130 let (command_tx, command_rx) = mpsc::channel();
131
132 let test_config = config.clone();
134 test_config.test_connection()?;
135
136 let worker_config = config.clone();
138 let worker_handle = thread::spawn(move || {
139 http_worker_thread(worker_config, command_rx);
140 });
141
142 Ok(Self {
143 inner: Arc::new(HttpClientInner {
144 command_tx,
145 worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
146 }),
147 })
148 }
149
150 pub(crate) fn command_tx(&self) -> &mpsc::Sender<HttpInternalMessage> {
152 &self.inner.command_tx
153 }
154
155 pub fn close(self) -> Result<(), Box<dyn std::error::Error>> {
157 self.inner.command_tx.send(HttpInternalMessage::Close)?;
158
159 if let Ok(mut handle_guard) = self.inner.worker_handle.lock() {
161 if let Some(handle) = handle_guard.take() {
162 let _ = handle.join();
163 }
164 }
165 Ok(())
166 }
167
168 pub fn test_connection(&self) -> Result<(), Box<dyn std::error::Error>> {
170 Ok(())
172 }
173
174 pub fn blocking_session(
176 &self,
177 token: Option<String>,
178 ) -> Result<crate::http::HttpBlockingSession, reifydb_type::Error> {
179 crate::http::HttpBlockingSession::from_client(self.clone(), token)
180 }
181
182 pub fn callback_session(
184 &self,
185 token: Option<String>,
186 ) -> Result<crate::http::HttpCallbackSession, reifydb_type::Error> {
187 crate::http::HttpCallbackSession::from_client(self.clone(), token)
188 }
189
190 pub fn channel_session(
192 &self,
193 token: Option<String>,
194 ) -> Result<
195 (crate::http::HttpChannelSession, mpsc::Receiver<crate::http::HttpResponseMessage>),
196 reifydb_type::Error,
197 > {
198 crate::http::HttpChannelSession::from_client(self.clone(), token)
199 }
200}
201
202impl HttpClientConfig {
203 pub fn send_command(&self, request: &CommandRequest) -> Result<CommandResponse, reifydb_type::Error> {
205 let json_body = serde_json::to_string(request).map_err(|e| {
206 reifydb_type::Error(reifydb_type::diagnostic::internal(format!(
207 "Failed to serialize request: {}",
208 e
209 )))
210 })?;
211 let response_body = self.send_request("/v1/command", &json_body).map_err(|e| {
212 reifydb_type::Error(reifydb_type::diagnostic::internal(format!("Request failed: {}", e)))
213 })?;
214
215 match serde_json::from_str::<CommandResponse>(&response_body) {
217 Ok(response) => Ok(response),
218 Err(_) => {
219 if let Ok(http_err) = serde_json::from_str::<HttpErrorResponse>(&response_body) {
221 let diagnostic = http_err.diagnostic.unwrap_or_else(|| {
223 reifydb_type::diagnostic::Diagnostic {
224 code: http_err.code,
225 message: http_err.error,
226 ..Default::default()
227 }
228 });
229 return Err(reifydb_type::Error(diagnostic));
230 }
231 match serde_json::from_str::<ErrResponse>(&response_body) {
233 Ok(err_response) => Err(reifydb_type::Error(err_response.diagnostic)),
234 Err(_) => Err(reifydb_type::Error(reifydb_type::diagnostic::internal(
235 format!("Failed to parse response: {}", response_body),
236 ))),
237 }
238 }
239 }
240 }
241
242 pub fn send_query(&self, request: &QueryRequest) -> Result<QueryResponse, reifydb_type::Error> {
244 let json_body = serde_json::to_string(request).map_err(|e| {
245 reifydb_type::Error(reifydb_type::diagnostic::internal(format!(
246 "Failed to serialize request: {}",
247 e
248 )))
249 })?;
250 let response_body = self.send_request("/v1/query", &json_body).map_err(|e| {
251 reifydb_type::Error(reifydb_type::diagnostic::internal(format!("Request failed: {}", e)))
252 })?;
253
254 match serde_json::from_str::<QueryResponse>(&response_body) {
256 Ok(response) => Ok(response),
257 Err(_) => {
258 if let Ok(http_err) = serde_json::from_str::<HttpErrorResponse>(&response_body) {
260 let diagnostic = http_err.diagnostic.unwrap_or_else(|| {
262 reifydb_type::diagnostic::Diagnostic {
263 code: http_err.code,
264 message: http_err.error,
265 ..Default::default()
266 }
267 });
268 return Err(reifydb_type::Error(diagnostic));
269 }
270 match serde_json::from_str::<ErrResponse>(&response_body) {
272 Ok(err_response) => Err(reifydb_type::Error(err_response.diagnostic)),
273 Err(_) => Err(reifydb_type::Error(reifydb_type::diagnostic::internal(
274 format!("Failed to parse response: {}", response_body),
275 ))),
276 }
277 }
278 }
279 }
280
281 fn send_request(&self, path: &str, body: &str) -> Result<String, Box<dyn std::error::Error>> {
283 let addr_str = if self.host.contains(':') {
286 format!("[{}]:{}", self.host, self.port)
287 } else {
288 format!("{}:{}", self.host, self.port)
289 };
290 let addr: SocketAddr = addr_str.parse()?;
291
292 let mut stream = TcpStream::connect(addr)?;
294
295 let body_bytes = body.as_bytes();
297
298 let header = format!(
300 "POST {} HTTP/1.1\r\n\
301 Host: {}\r\n\
302 Content-Type: application/json\r\n\
303 Content-Length: {}\r\n\
304 Authorization: Bearer mysecrettoken\r\n\
305 Connection: close\r\n\
306 \r\n",
307 path,
308 self.host,
309 body_bytes.len()
310 );
311
312 stream.write_all(header.as_bytes())?;
314 stream.write_all(body_bytes)?;
315 stream.flush()?;
316
317 self.parse_http_response_buffered(stream)
319 }
320
321 fn parse_http_response_buffered(&self, stream: TcpStream) -> Result<String, Box<dyn std::error::Error>> {
323 let mut reader = BufReader::new(stream);
324 let mut line = String::new();
325
326 reader.read_line(&mut line)?;
328 let status_line = line.trim_end();
329 let status_parts: Vec<&str> = status_line.split_whitespace().collect();
330
331 if status_parts.len() < 3 {
332 return Err("Invalid HTTP status line".into());
333 }
334
335 let mut content_length: Option<usize> = None;
337 let mut is_chunked = false;
338
339 loop {
340 line.clear();
341 reader.read_line(&mut line)?;
342
343 if line == "\r\n" || line == "\n" {
344 break; }
346
347 if let Some(colon_pos) = line.find(':') {
348 let key = line[..colon_pos].trim().to_lowercase();
349 let value = line[colon_pos + 1..].trim();
350
351 if key == "content-length" {
352 content_length = value.parse().ok();
353 } else if key == "transfer-encoding" && value.contains("chunked") {
354 is_chunked = true;
355 }
356 }
357 }
358
359 let body = if is_chunked {
361 self.read_chunked_body(&mut reader)?
362 } else if let Some(length) = content_length {
363 let mut body = vec![0u8; length];
365 reader.read_exact(&mut body)?;
366 String::from_utf8(body)?
367 } else {
368 let mut body = String::new();
370 reader.read_to_string(&mut body)?;
371 body
372 };
373
374 Ok(body)
375 }
376
377 fn read_chunked_body(&self, reader: &mut BufReader<TcpStream>) -> Result<String, Box<dyn std::error::Error>> {
379 let mut result = Vec::new();
380 let mut line = String::new();
381
382 loop {
383 line.clear();
385 reader.read_line(&mut line)?;
386
387 let size_str = line.trim().split(';').next().unwrap_or("0");
390 let chunk_size = usize::from_str_radix(size_str, 16)?;
391
392 if chunk_size == 0 {
393 loop {
395 line.clear();
396 reader.read_line(&mut line)?;
397 if line == "\r\n" || line == "\n" {
398 break;
399 }
400 }
401 break;
402 }
403
404 let mut chunk = vec![0u8; chunk_size];
406 reader.read_exact(&mut chunk)?;
407 result.extend_from_slice(&chunk);
408
409 line.clear();
411 reader.read_line(&mut line)?;
412 }
413
414 String::from_utf8(result).map_err(|e| e.into())
415 }
416
417 pub fn test_connection(&self) -> Result<(), Box<dyn std::error::Error>> {
419 let addr_str = if self.host.contains(':') {
421 format!("[{}]:{}", self.host, self.port)
422 } else {
423 format!("{}:{}", self.host, self.port)
424 };
425 let addr: SocketAddr = addr_str.parse()?;
426 let _stream = TcpStream::connect(addr)?;
427 Ok(())
428 }
429}