1use std::{
5 io::{BufRead, BufReader, Read, Write},
6 net::{SocketAddr, TcpStream, ToSocketAddrs},
7 sync::{Arc, Mutex, mpsc},
8 thread::{self, JoinHandle},
9 time::Duration,
10};
11
12use serde_json;
13
14use crate::{
15 CommandRequest, CommandResponse, ErrResponse, QueryRequest, QueryResponse,
16 http::{message::HttpInternalMessage, worker::http_worker_thread},
17};
18
19#[derive(Clone)]
21pub struct HttpClient {
22 inner: Arc<HttpClientInner>,
23}
24
25pub(crate) struct HttpClientInner {
26 pub(crate) command_tx: mpsc::Sender<HttpInternalMessage>,
27 worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
28}
29
30#[derive(Clone)]
32pub(crate) struct HttpClientConfig {
33 pub(crate) host: String,
34 pub(crate) port: u16,
35 pub(crate) _timeout: Duration,
36}
37
38impl Drop for HttpClient {
39 fn drop(&mut self) {
40 let _ = self.inner.command_tx.send(HttpInternalMessage::Close);
41 }
42}
43
44impl HttpClient {
45 pub fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn std::error::Error>> {
47 let socket_addr = addr.to_socket_addrs()?.next().ok_or("Failed to resolve address")?;
49
50 let host = socket_addr.ip().to_string();
51 let port = socket_addr.port();
52
53 let config = HttpClientConfig {
54 host,
55 port,
56 _timeout: Duration::from_secs(30),
57 };
58
59 Self::with_config(config)
60 }
61
62 pub fn from_url(url: &str) -> Result<Self, Box<dyn std::error::Error>> {
64 let url = if url.starts_with("http://") {
65 &url[7..] } else if url.starts_with("https://") {
67 return Err("HTTPS is not yet supported".into());
68 } else {
69 url
70 };
71
72 let (host, port) = if url.starts_with('[') {
74 if let Some(end_bracket) = url.find(']') {
76 let host = &url[1..end_bracket];
77 let port_str = &url[end_bracket + 1..];
78 let port = if port_str.starts_with(':') {
79 port_str[1..].parse()?
80 } else {
81 80
82 };
83 (host.to_string(), port)
84 } else {
85 return Err("Invalid IPv6 address format".into());
86 }
87 } else if url.starts_with("::") || url.contains("::") {
88 if let Some(port_idx) = url.rfind(':') {
91 if url[port_idx + 1..].chars().all(|c| c.is_ascii_digit()) {
94 let host = &url[..port_idx];
95 let port: u16 = url[port_idx + 1..].parse()?;
96 (host.to_string(), port)
97 } else {
98 (url.to_string(), 80)
100 }
101 } else {
102 (url.to_string(), 80)
103 }
104 } else {
105 if let Some(colon_idx) = url.find(':') {
107 let host = &url[..colon_idx];
108 let port: u16 = url[colon_idx + 1..].parse()?;
109 (host.to_string(), port)
110 } else {
111 (url.to_string(), 80)
112 }
113 };
114
115 Self::new((host.as_str(), port))
116 }
117
118 fn with_config(config: HttpClientConfig) -> Result<Self, Box<dyn std::error::Error>> {
120 let (command_tx, command_rx) = mpsc::channel();
121
122 let test_config = config.clone();
124 test_config.test_connection()?;
125
126 let worker_config = config.clone();
128 let worker_handle = thread::spawn(move || {
129 http_worker_thread(worker_config, command_rx);
130 });
131
132 Ok(Self {
133 inner: Arc::new(HttpClientInner {
134 command_tx,
135 worker_handle: Arc::new(Mutex::new(Some(worker_handle))),
136 }),
137 })
138 }
139
140 pub(crate) fn command_tx(&self) -> &mpsc::Sender<HttpInternalMessage> {
142 &self.inner.command_tx
143 }
144
145 pub fn close(self) -> Result<(), Box<dyn std::error::Error>> {
147 self.inner.command_tx.send(HttpInternalMessage::Close)?;
148
149 if let Ok(mut handle_guard) = self.inner.worker_handle.lock() {
151 if let Some(handle) = handle_guard.take() {
152 let _ = handle.join();
153 }
154 }
155 Ok(())
156 }
157
158 pub fn test_connection(&self) -> Result<(), Box<dyn std::error::Error>> {
160 Ok(())
162 }
163
164 pub fn blocking_session(
166 &self,
167 token: Option<String>,
168 ) -> Result<crate::http::HttpBlockingSession, reifydb_type::Error> {
169 crate::http::HttpBlockingSession::from_client(self.clone(), token)
170 }
171
172 pub fn callback_session(
174 &self,
175 token: Option<String>,
176 ) -> Result<crate::http::HttpCallbackSession, reifydb_type::Error> {
177 crate::http::HttpCallbackSession::from_client(self.clone(), token)
178 }
179
180 pub fn channel_session(
182 &self,
183 token: Option<String>,
184 ) -> Result<
185 (crate::http::HttpChannelSession, mpsc::Receiver<crate::http::HttpResponseMessage>),
186 reifydb_type::Error,
187 > {
188 crate::http::HttpChannelSession::from_client(self.clone(), token)
189 }
190}
191
192impl HttpClientConfig {
193 pub fn send_command(&self, request: &CommandRequest) -> Result<CommandResponse, reifydb_type::Error> {
195 let json_body = serde_json::to_string(request).map_err(|e| {
196 reifydb_type::Error(reifydb_type::diagnostic::internal(format!(
197 "Failed to serialize request: {}",
198 e
199 )))
200 })?;
201 let response_body = self.send_request("/v1/command", &json_body).map_err(|e| {
202 reifydb_type::Error(reifydb_type::diagnostic::internal(format!("Request failed: {}", e)))
203 })?;
204
205 match serde_json::from_str::<CommandResponse>(&response_body) {
207 Ok(response) => Ok(response),
208 Err(_) => {
209 match serde_json::from_str::<ErrResponse>(&response_body) {
211 Ok(err_response) => Err(reifydb_type::Error(err_response.diagnostic)),
212 Err(_) => Err(reifydb_type::Error(reifydb_type::diagnostic::internal(
213 format!("Failed to parse response: {}", response_body),
214 ))),
215 }
216 }
217 }
218 }
219
220 pub fn send_query(&self, request: &QueryRequest) -> Result<QueryResponse, reifydb_type::Error> {
222 let json_body = serde_json::to_string(request).map_err(|e| {
223 reifydb_type::Error(reifydb_type::diagnostic::internal(format!(
224 "Failed to serialize request: {}",
225 e
226 )))
227 })?;
228 let response_body = self.send_request("/v1/query", &json_body).map_err(|e| {
229 reifydb_type::Error(reifydb_type::diagnostic::internal(format!("Request failed: {}", e)))
230 })?;
231
232 match serde_json::from_str::<QueryResponse>(&response_body) {
234 Ok(response) => Ok(response),
235 Err(_) => {
236 match serde_json::from_str::<ErrResponse>(&response_body) {
238 Ok(err_response) => Err(reifydb_type::Error(err_response.diagnostic)),
239 Err(_) => Err(reifydb_type::Error(reifydb_type::diagnostic::internal(
240 format!("Failed to parse response: {}", response_body),
241 ))),
242 }
243 }
244 }
245 }
246
247 fn send_request(&self, path: &str, body: &str) -> Result<String, Box<dyn std::error::Error>> {
249 let addr_str = if self.host.contains(':') {
252 format!("[{}]:{}", self.host, self.port)
253 } else {
254 format!("{}:{}", self.host, self.port)
255 };
256 let addr: SocketAddr = addr_str.parse()?;
257
258 let mut stream = TcpStream::connect(addr)?;
260
261 let body_bytes = body.as_bytes();
263
264 let header = format!(
266 "POST {} HTTP/1.1\r\n\
267 Host: {}\r\n\
268 Content-Type: application/json\r\n\
269 Content-Length: {}\r\n\
270 Authorization: Bearer mysecrettoken\r\n\
271 Connection: close\r\n\
272 \r\n",
273 path,
274 self.host,
275 body_bytes.len()
276 );
277
278 stream.write_all(header.as_bytes())?;
280 stream.write_all(body_bytes)?;
281 stream.flush()?;
282
283 self.parse_http_response_buffered(stream)
285 }
286
287 fn parse_http_response_buffered(&self, stream: TcpStream) -> Result<String, Box<dyn std::error::Error>> {
289 let mut reader = BufReader::new(stream);
290 let mut line = String::new();
291
292 reader.read_line(&mut line)?;
294 let status_line = line.trim_end();
295 let status_parts: Vec<&str> = status_line.split_whitespace().collect();
296
297 if status_parts.len() < 3 {
298 return Err("Invalid HTTP status line".into());
299 }
300
301 let mut content_length: Option<usize> = None;
303 let mut is_chunked = false;
304
305 loop {
306 line.clear();
307 reader.read_line(&mut line)?;
308
309 if line == "\r\n" || line == "\n" {
310 break; }
312
313 if let Some(colon_pos) = line.find(':') {
314 let key = line[..colon_pos].trim().to_lowercase();
315 let value = line[colon_pos + 1..].trim();
316
317 if key == "content-length" {
318 content_length = value.parse().ok();
319 } else if key == "transfer-encoding" && value.contains("chunked") {
320 is_chunked = true;
321 }
322 }
323 }
324
325 let body = if is_chunked {
327 self.read_chunked_body(&mut reader)?
328 } else if let Some(length) = content_length {
329 let mut body = vec![0u8; length];
331 reader.read_exact(&mut body)?;
332 String::from_utf8(body)?
333 } else {
334 let mut body = String::new();
336 reader.read_to_string(&mut body)?;
337 body
338 };
339
340 Ok(body)
341 }
342
343 fn read_chunked_body(&self, reader: &mut BufReader<TcpStream>) -> Result<String, Box<dyn std::error::Error>> {
345 let mut result = Vec::new();
346 let mut line = String::new();
347
348 loop {
349 line.clear();
351 reader.read_line(&mut line)?;
352
353 let size_str = line.trim().split(';').next().unwrap_or("0");
356 let chunk_size = usize::from_str_radix(size_str, 16)?;
357
358 if chunk_size == 0 {
359 loop {
361 line.clear();
362 reader.read_line(&mut line)?;
363 if line == "\r\n" || line == "\n" {
364 break;
365 }
366 }
367 break;
368 }
369
370 let mut chunk = vec![0u8; chunk_size];
372 reader.read_exact(&mut chunk)?;
373 result.extend_from_slice(&chunk);
374
375 line.clear();
377 reader.read_line(&mut line)?;
378 }
379
380 String::from_utf8(result).map_err(|e| e.into())
381 }
382
383 pub fn test_connection(&self) -> Result<(), Box<dyn std::error::Error>> {
385 let addr_str = if self.host.contains(':') {
387 format!("[{}]:{}", self.host, self.port)
388 } else {
389 format!("{}:{}", self.host, self.port)
390 };
391 let addr: SocketAddr = addr_str.parse()?;
392 let _stream = TcpStream::connect(addr)?;
393 Ok(())
394 }
395}