// reifydb_sub_server/protocols/http/handler.rs

// Copyright (c) reifydb.com 2025
// This file is licensed under the AGPL-3.0-or-later, see license.md file
3
4use std::io::{Read, Write};
5
6use reifydb_core::interface::{Engine, Identity, Params};
7use reifydb_type::diagnostic::Diagnostic;
8
9use super::{
10	HttpConnectionData, HttpState, ResponseType,
11	command::{CommandHandlerResult, handle_v1_command},
12	query::{QueryHandlerResult, handle_v1_query},
13};
14use crate::{
15	core::{Connection, ConnectionState},
16	protocols::{
17		ProtocolError, ProtocolHandler, ProtocolResult,
18		ws::{CommandRequest, ErrorResponse, QueryRequest},
19	},
20};
21
22#[derive(Clone)]
23pub struct HttpHandler;
24
25impl HttpHandler {
26	pub fn new() -> Self {
27		Self
28	}
29
30	/// Parse HTTP request headers
31	fn parse_request(
32		&self,
33		data: &[u8],
34	) -> Result<(String, String, std::collections::HashMap<String, String>), String> {
35		let request_str = String::from_utf8_lossy(data);
36		let lines: Vec<&str> = request_str.lines().collect();
37
38		if lines.is_empty() {
39			return Err("Empty request".to_string());
40		}
41
42		// Parse request line (GET /path HTTP/1.1)
43		let request_parts: Vec<&str> = lines[0].split_whitespace().collect();
44		if request_parts.len() != 3 {
45			return Err("Invalid request line".to_string());
46		}
47
48		let method = request_parts[0].to_string();
49		let path = request_parts[1].to_string();
50
51		// Parse headers
52		let mut headers = std::collections::HashMap::new();
53		for line in &lines[1..] {
54			if line.is_empty() {
55				break;
56			}
57			if let Some(colon_pos) = line.find(':') {
58				let key = line[..colon_pos].trim().to_lowercase();
59				let value = line[colon_pos + 1..].trim().to_string();
60				headers.insert(key, value);
61			}
62		}
63
64		Ok((method, path, headers))
65	}
66
67	/// Build HTTP response
68	fn build_response(
69		&self,
70		status_code: u16,
71		status_text: &str,
72		body: &str,
73		headers: Option<&std::collections::HashMap<String, String>>,
74	) -> String {
75		let mut response = format!("HTTP/1.1 {} {}\r\n", status_code, status_text);
76
77		// Add default headers - use byte length for Content-Length
78		response.push_str(&format!("Content-Length: {}\r\n", body.as_bytes().len()));
79		response.push_str("Content-Type: application/json\r\n");
80		response.push_str("Connection: close\r\n");
81
82		// Add custom headers if provided
83		if let Some(custom_headers) = headers {
84			for (key, value) in custom_headers {
85				response.push_str(&format!("{}: {}\r\n", key, value));
86			}
87		}
88
89		response.push_str("\r\n");
90		response.push_str(body);
91
92		response
93	}
94
95	/// Handle query execution for HTTP requests
96	fn handle_query(&self, conn: &Connection, query: &str) -> Result<String, String> {
97		match conn.engine().query_as(
98			&Identity::System {
99				id: 1,
100				name: "root".to_string(),
101			},
102			query,
103			Params::None,
104		) {
105			Ok(result) => {
106				let response_body = serde_json::json!({
107				    "success": true,
108				    "data": format!("Query executed successfully, {} frames returned", result.len()),
109				    "results": result.len()
110				});
111				Ok(response_body.to_string())
112			}
113			Err(e) => {
114				let error_body = serde_json::json!({
115				    "success": false,
116				    "error": format!("Query error: {}", e)
117				});
118				Ok(error_body.to_string())
119			}
120		}
121	}
122}
123
124impl ProtocolHandler for HttpHandler {
125	fn name(&self) -> &'static str {
126		"http"
127	}
128
129	fn can_handle(&self, buffer: &[u8]) -> bool {
130		// Check for HTTP request signature
131		if buffer.len() < 16 {
132			return false;
133		}
134
135		let request = String::from_utf8_lossy(buffer);
136		request.starts_with("GET ")
137			|| request.starts_with("POST ")
138			|| request.starts_with("PUT ")
139			|| request.starts_with("DELETE ")
140			|| request.starts_with("HEAD ")
141			|| request.starts_with("OPTIONS ")
142	}
143
144	fn handle_connection(&self, conn: &mut Connection) -> ProtocolResult<()> {
145		// Initialize HTTP state
146		let http_state = HttpState::ReadingRequest(HttpConnectionData::new());
147		conn.set_state(ConnectionState::Http(http_state));
148		Ok(())
149	}
150
151	fn handle_read(&self, conn: &mut Connection) -> ProtocolResult<()> {
152		if let ConnectionState::Http(http_state) = conn.state() {
153			match http_state {
154				HttpState::ReadingRequest(_) => self.handle_request_read(conn),
155				HttpState::Processing(_) => {
156					Ok(()) // No additional reading needed during processing
157				}
158				HttpState::ProcessingQuery {
159					..
160				} => {
161					Ok(()) // No additional reading needed while query is processing
162				}
163				HttpState::WritingResponse(_) => {
164					Ok(()) // No reading during response writing
165				}
166				HttpState::Closed => Ok(()),
167			}
168		} else {
169			Err(ProtocolError::InvalidFrame)
170		}
171	}
172
173	fn handle_write(&self, conn: &mut Connection) -> ProtocolResult<()> {
174		if let ConnectionState::Http(http_state) = conn.state() {
175			match http_state {
176				HttpState::ReadingRequest(_) => {
177					Ok(()) // No writing during request reading
178				}
179				HttpState::Processing(_) => {
180					Ok(()) // No writing during processing
181				}
182				HttpState::ProcessingQuery {
183					..
184				} => {
185					Ok(()) // No writing while query is processing
186				}
187				HttpState::WritingResponse(_) => self.handle_response_write(conn),
188				HttpState::Closed => Ok(()),
189			}
190		} else {
191			Err(ProtocolError::InvalidFrame)
192		}
193	}
194
195	fn should_close(&self, conn: &Connection) -> bool {
196		matches!(
197			conn.state(),
198			crate::core::ConnectionState::Http(HttpState::Closed) | crate::core::ConnectionState::Closed
199		)
200	}
201}
202
203impl HttpHandler {
204	fn handle_request_read(&self, conn: &mut Connection) -> ProtocolResult<()> {
205		// Check if we already found headers and are waiting for body
206		let header_end = if let ConnectionState::Http(HttpState::ReadingRequest(data)) = conn.state() {
207			data.header_end
208		} else {
209			None
210		};
211
212		// If we haven't found headers yet, look for them
213		let header_end = if header_end.is_none() {
214			// First, check any data already in the buffer
215			if !conn.buffer().is_empty() {
216				if let Some(end) = self.find_header_end(conn.buffer()) {
217					// Store the header end position
218					if let ConnectionState::Http(HttpState::ReadingRequest(data)) = conn.state_mut()
219					{
220						data.header_end = Some(end);
221					}
222					Some(end)
223				} else {
224					None
225				}
226			} else {
227				None
228			}
229		} else {
230			header_end
231		};
232
233		// Read more data if needed
234		let mut buf = [0u8; 4096];
235		loop {
236			match conn.stream().read(&mut buf) {
237				Ok(0) => {
238					// Connection closed
239					if conn.buffer().is_empty() {
240						return Err(ProtocolError::ConnectionClosed);
241					}
242					break; // Process what we have
243				}
244				Ok(n) => {
245					// Add data to connection buffer
246					conn.buffer_mut().extend_from_slice(&buf[..n]);
247
248					// If we don't have headers yet, look for them
249					if header_end.is_none() {
250						if let Some(end) = self.find_header_end(conn.buffer()) {
251							// Store the header end position
252							if let ConnectionState::Http(HttpState::ReadingRequest(data)) =
253								conn.state_mut()
254							{
255								data.header_end = Some(end);
256							}
257							// Try to process the complete request
258							return self.process_http_request(conn, end);
259						}
260					}
261				}
262				Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
263					break; // No more data available now
264				}
265				Err(e) => return Err(ProtocolError::Io(e)),
266			}
267		}
268
269		// If we have headers, try to process the request
270		if let Some(end) = header_end {
271			// Try to process - this will check if we have enough
272			// body data
273			self.process_http_request(conn, end)?;
274		} else if let ConnectionState::Http(HttpState::ReadingRequest(data)) = conn.state() {
275			// Check again if we have headers after reading
276			if let Some(end) = data.header_end {
277				self.process_http_request(conn, end)?;
278			}
279		}
280
281		Ok(())
282	}
283
284	fn handle_response_write(&self, conn: &mut Connection) -> ProtocolResult<()> {
285		// Extract response data to avoid borrowing conflicts
286		let (response_data, bytes_written, keep_alive) =
287			if let ConnectionState::Http(HttpState::WritingResponse(data)) = conn.state() {
288				(data.response_buffer.clone(), data.bytes_written, data.keep_alive)
289			} else {
290				return Ok(());
291			};
292
293		let mut total_written = bytes_written;
294
295		loop {
296			if total_written >= response_data.len() {
297				// Response completely written
298				if keep_alive {
299					// Reset to reading state for keep-alive
300					let new_data = HttpConnectionData::new();
301					conn.set_state(ConnectionState::Http(HttpState::ReadingRequest(new_data)));
302				} else {
303					// Close connection
304					conn.set_state(ConnectionState::Http(HttpState::Closed));
305				}
306				break;
307			}
308
309			match conn.stream().write(&response_data[total_written..]) {
310				Ok(0) => return Err(ProtocolError::ConnectionClosed),
311				Ok(n) => {
312					total_written += n;
313					// Update bytes written in state
314					if let ConnectionState::Http(HttpState::WritingResponse(data)) =
315						conn.state_mut()
316					{
317						data.bytes_written = total_written;
318					}
319				}
320				Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
321				Err(e) => return Err(ProtocolError::Io(e)),
322			}
323		}
324		Ok(())
325	}
326
327	/// Find the end of HTTP headers (\r\n\r\n)
328	fn find_header_end(&self, buffer: &[u8]) -> Option<usize> {
329		for i in 0..buffer.len().saturating_sub(3) {
330			if buffer[i] == b'\r'
331				&& buffer[i + 1] == b'\n' && buffer[i + 2] == b'\r'
332				&& buffer[i + 3] == b'\n'
333			{
334				return Some(i);
335			}
336		}
337		None
338	}
339
340	/// Process a complete HTTP request
341	fn process_http_request(&self, conn: &mut Connection, header_end: usize) -> ProtocolResult<()> {
342		// Parse the request
343		let (method, path, headers) = self
344			.parse_request(&conn.buffer()[..header_end])
345			.map_err(|e| ProtocolError::Custom(format!("Parse error: {}", e)))?;
346
347		// Calculate content length for POST requests
348		let content_length: usize = headers.get("content-length").and_then(|v| v.parse().ok()).unwrap_or(0);
349
350		let body_start = header_end + 4; // Skip \r\n\r\n
351		let total_needed = body_start + content_length;
352
353		// Check if we have the complete request (headers + body)
354		if method == "POST" && conn.buffer().len() < total_needed {
355			// We don't have the full body yet, keep waiting
356			return Ok(());
357		}
358
359		// Process the request based on method and path
360		let response_body = match (&method[..], &path[..]) {
361			("GET", "/health") => serde_json::json!({"status": "ok", "service": "reifydb"}).to_string(),
362			("POST", "/query") => {
363				// Body is guaranteed to be complete at this point
364				let body = &conn.buffer()[body_start..body_start + content_length];
365				let body_str = String::from_utf8_lossy(body);
366
367				// Try to parse JSON body for query
368				if let Ok(query_json) = serde_json::from_str::<serde_json::Value>(&body_str) {
369					if let Some(query) = query_json.get("query").and_then(|q| q.as_str()) {
370						self.handle_query(conn, query).map_err(|e| ProtocolError::Custom(e))?
371					} else {
372						serde_json::json!({"error": "Missing 'query' field in request body"})
373							.to_string()
374					}
375				} else {
376					serde_json::json!({"error": "Invalid JSON in request body"}).to_string()
377				}
378			}
379			("POST", "/v1/command") => {
380				// Body is guaranteed to be complete at this point
381				let body = &conn.buffer()[body_start..body_start + content_length];
382				let body_str = String::from_utf8_lossy(body);
383
384				match serde_json::from_str::<CommandRequest>(&body_str) {
385					Ok(cmd_req) => match handle_v1_command(conn, &cmd_req) {
386						CommandHandlerResult::Immediate(Ok(response)) => {
387							serde_json::to_string(&response).map_err(|e| {
388								ProtocolError::Custom(format!(
389									"Serialization error: {}",
390									e
391								))
392							})?
393						}
394						CommandHandlerResult::Immediate(Err(error_response)) => {
395							serde_json::to_string(&error_response).map_err(|e| {
396								ProtocolError::Custom(format!(
397									"Serialization error: {}",
398									e
399								))
400							})?
401						}
402						CommandHandlerResult::Pending => {
403							// Transition to ProcessingQuery state - response will come
404							// later
405							let current_state = if let ConnectionState::Http(
406								HttpState::ReadingRequest(data),
407							) = conn.state()
408							{
409								data.clone()
410							} else {
411								HttpConnectionData::new()
412							};
413
414							conn.set_state(ConnectionState::Http(
415								HttpState::ProcessingQuery {
416									original_request: current_state,
417									response_type: ResponseType::Command,
418								},
419							));
420
421							// Return empty for now - response will be handled in event loop
422							return Ok(());
423						}
424					},
425					Err(e) => {
426						let error_response = ErrorResponse {
427							diagnostic: Diagnostic {
428								code: "INVALID_JSON".to_string(),
429								message: format!("Invalid CommandRequest JSON: {}", e),
430								..Default::default()
431							},
432						};
433						serde_json::to_string(&error_response).map_err(|e| {
434							ProtocolError::Custom(format!("Serialization error: {}", e))
435						})?
436					}
437				}
438			}
439			("POST", "/v1/query") => {
440				// Body is guaranteed to be complete at this point
441				let body = &conn.buffer()[body_start..body_start + content_length];
442				let body_str = String::from_utf8_lossy(body);
443
444				match serde_json::from_str::<QueryRequest>(&body_str) {
445					Ok(query_req) => match handle_v1_query(conn, &query_req) {
446						QueryHandlerResult::Immediate(Ok(response)) => {
447							serde_json::to_string(&response).map_err(|e| {
448								ProtocolError::Custom(format!(
449									"Serialization error: {}",
450									e
451								))
452							})?
453						}
454						QueryHandlerResult::Immediate(Err(error_response)) => {
455							serde_json::to_string(&error_response).map_err(|e| {
456								ProtocolError::Custom(format!(
457									"Serialization error: {}",
458									e
459								))
460							})?
461						}
462						QueryHandlerResult::Pending => {
463							// Transition to ProcessingQuery state - response will come
464							// later
465							let current_state = if let ConnectionState::Http(
466								HttpState::ReadingRequest(data),
467							) = conn.state()
468							{
469								data.clone()
470							} else {
471								HttpConnectionData::new()
472							};
473
474							conn.set_state(ConnectionState::Http(
475								HttpState::ProcessingQuery {
476									original_request: current_state,
477									response_type: ResponseType::Query,
478								},
479							));
480
481							// Return empty for now - response will be handled in event loop
482							return Ok(());
483						}
484					},
485					Err(e) => {
486						let error_response = ErrorResponse {
487							diagnostic: Diagnostic {
488								code: "INVALID_JSON".to_string(),
489								message: format!("Invalid QueryRequest JSON: {}", e),
490								..Default::default()
491							},
492						};
493						serde_json::to_string(&error_response).map_err(|e| {
494							ProtocolError::Custom(format!("Serialization error: {}", e))
495						})?
496					}
497				}
498			}
499			("GET", path) if path.starts_with("/query?") => {
500				// Handle query via GET parameters
501				if let Some(query_start) = path.find("q=") {
502					let query_param = &path[query_start + 2..];
503					let query = urlencoding::decode(query_param)
504						.map_err(|_| ProtocolError::Custom("Invalid URL encoding".to_string()))?
505						.to_string();
506					self.handle_query(conn, &query).map_err(|e| ProtocolError::Custom(e))?
507				} else {
508					serde_json::json!({"error": "Missing 'q' query parameter"}).to_string()
509				}
510			}
511			_ => serde_json::json!({"error": "Not found", "path": path, "method": method}).to_string(),
512		};
513
514		// Build HTTP response
515		let response = if path == "/health"
516			|| (method == "POST" && path == "/query")
517			|| (method == "POST" && path == "/v1/command")
518			|| (method == "POST" && path == "/v1/query")
519			|| path.starts_with("/query?")
520		{
521			self.build_response(200, "OK", &response_body, None)
522		} else {
523			self.build_response(404, "Not Found", &response_body, None)
524		};
525
526		// Clear the processed request from the buffer
527		let bytes_consumed = if method == "POST" {
528			body_start + content_length
529		} else {
530			header_end + 4 // Just headers + \r\n\r\n
531		};
532
533		// Remove processed data from buffer
534		conn.buffer_mut().drain(0..bytes_consumed);
535
536		// Update state to writing response
537		let mut response_data = HttpConnectionData::new();
538		response_data.response_buffer = response.into_bytes();
539		response_data.keep_alive =
540			headers.get("connection").map(|v| v.to_lowercase() == "keep-alive").unwrap_or(false);
541
542		conn.set_state(ConnectionState::Http(HttpState::WritingResponse(response_data)));
543
544		Ok(())
545	}
546}