1use std::sync::Arc;
16use std::time::Duration;
17
18use serde_json::{json, Value};
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::net::TcpListener;
21
22use crate::backend::client::QueryResult;
23use crate::backend::types::TextValue;
24use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode};
25use crate::config::HttpGatewayConfig;
26use crate::{ProxyError, Result};
27
28pub struct HttpGateway {
29 config: HttpGatewayConfig,
30}
31
32impl HttpGateway {
33 pub fn new(config: HttpGatewayConfig) -> Self {
34 Self { config }
35 }
36
37 pub async fn run(self) -> Result<()> {
38 let listener = TcpListener::bind(&self.config.listen_address).await.map_err(|e| {
39 ProxyError::Network(format!("HTTP gateway bind {}: {}", self.config.listen_address, e))
40 })?;
41 tracing::info!(addr = %self.config.listen_address, "HTTP SQL gateway listening");
42 let cfg = Arc::new(self.config);
43 loop {
44 let (stream, peer) = match listener.accept().await {
45 Ok(x) => x,
46 Err(e) => {
47 tracing::warn!("HTTP gateway accept error: {}", e);
48 continue;
49 }
50 };
51 let cfg = cfg.clone();
52 tokio::spawn(async move {
53 if let Err(e) = Self::handle(stream, cfg).await {
54 tracing::debug!(%peer, "HTTP gateway error: {}", e);
55 }
56 });
57 }
58 }
59
60 async fn handle(mut stream: tokio::net::TcpStream, cfg: Arc<HttpGatewayConfig>) -> Result<()> {
61 use tokio::io::AsyncBufReadExt;
62 let (reader, mut writer) = stream.split();
63 let mut reader = BufReader::new(reader);
64 let mut line = String::new();
65 let mut content_length = 0usize;
66 let mut method = String::new();
67 let mut path = String::new();
68 let mut authorized = cfg.auth_token.is_none();
69 let mut array_mode = false;
70 let mut first = true;
71 loop {
72 line.clear();
73 let n = reader
74 .read_line(&mut line)
75 .await
76 .map_err(|e| ProxyError::Network(format!("HTTP gw read: {}", e)))?;
77 if n == 0 || line == "\r\n" {
78 break;
79 }
80 if first {
81 let mut parts = line.split_whitespace();
82 method = parts.next().unwrap_or("").to_string();
83 path = parts.next().unwrap_or("").to_string();
84 first = false;
85 continue;
86 }
87 let lower = line.to_ascii_lowercase();
88 if lower.starts_with("content-length:") {
89 content_length = line.split(':').nth(1).and_then(|v| v.trim().parse().ok()).unwrap_or(0);
90 } else if lower.starts_with("neon-array-mode:") {
91 array_mode = line.split(':').nth(1).map(|v| v.trim().eq_ignore_ascii_case("true")).unwrap_or(false);
92 } else if lower.starts_with("authorization:") {
93 if let Some(tok) = cfg.auth_token.as_ref() {
94 let v = line.splitn(2, ':').nth(1).unwrap_or("").trim();
95 authorized = v == format!("Bearer {}", tok);
96 }
97 }
98 }
99
100 if method == "GET" && (path == "/health" || path == "/") {
102 return Self::respond(&mut writer, 200, &json!({"status":"ok"})).await;
103 }
104 if !authorized {
105 return Self::respond(&mut writer, 401, &json!({"error":"unauthorized"})).await;
106 }
107 if method != "POST" {
108 return Self::respond(&mut writer, 405, &json!({"error":"use POST /sql"})).await;
109 }
110
111 let mut body_buf = vec![0u8; content_length];
112 if content_length > 0 {
113 reader
114 .read_exact(&mut body_buf)
115 .await
116 .map_err(|e| ProxyError::Network(format!("HTTP gw body: {}", e)))?;
117 }
118 let req: Value = match serde_json::from_slice(&body_buf) {
119 Ok(v) => v,
120 Err(e) => return Self::respond(&mut writer, 400, &json!({"error": format!("invalid JSON: {}", e)})).await,
121 };
122 let sql = req.get("query").and_then(|q| q.as_str()).unwrap_or("").trim();
123 if sql.is_empty() {
124 return Self::respond(&mut writer, 400, &json!({"error":"missing 'query'"})).await;
125 }
126 let params = parse_params(req.get("params"));
127
128 match Self::run_sql(&cfg, sql, ¶ms).await {
129 Ok(qr) => {
130 let body = neon_result(&qr, array_mode);
131 Self::respond(&mut writer, 200, &body).await
132 }
133 Err(e) => Self::respond(&mut writer, 400, &json!({ "error": e })).await,
134 }
135 }
136
137 async fn run_sql(
138 cfg: &HttpGatewayConfig,
139 sql: &str,
140 params: &[ParamValue],
141 ) -> std::result::Result<QueryResult, String> {
142 let bcfg = BackendConfig {
143 host: cfg.backend_host.clone(),
144 port: cfg.backend_port,
145 user: cfg.backend_user.clone(),
146 password: cfg.backend_password.clone(),
147 database: cfg.backend_database.clone(),
148 application_name: Some("heliosproxy-http".to_string()),
149 tls_mode: TlsMode::Disable,
150 connect_timeout: Duration::from_secs(5),
151 query_timeout: Duration::from_secs(30),
152 tls_config: default_client_config(),
153 };
154 let mut client = BackendClient::connect(&bcfg)
155 .await
156 .map_err(|e| format!("backend connect: {}", e))?;
157 let res = if params.is_empty() {
158 client.simple_query(sql).await
159 } else {
160 client.query_with_params(sql, params).await
161 };
162 client.close().await;
163 res.map_err(|e| format!("{}", e))
164 }
165
166 async fn respond(
167 writer: &mut tokio::net::tcp::WriteHalf<'_>,
168 status: u16,
169 body: &Value,
170 ) -> Result<()> {
171 let payload = serde_json::to_vec(body).unwrap_or_default();
172 let status_text = match status {
173 200 => "OK",
174 400 => "Bad Request",
175 401 => "Unauthorized",
176 405 => "Method Not Allowed",
177 _ => "Error",
178 };
179 let head = format!(
180 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
181 status, status_text, payload.len()
182 );
183 writer.write_all(head.as_bytes()).await.map_err(|e| ProxyError::Network(format!("write: {}", e)))?;
184 writer.write_all(&payload).await.map_err(|e| ProxyError::Network(format!("write: {}", e)))?;
185 Ok(())
186 }
187}
188
189fn parse_params(v: Option<&Value>) -> Vec<ParamValue> {
191 match v.and_then(|v| v.as_array()) {
192 None => Vec::new(),
193 Some(arr) => arr
194 .iter()
195 .map(|p| match p {
196 Value::Null => ParamValue::Null,
197 Value::Bool(b) => ParamValue::Bool(*b),
198 Value::Number(n) => {
199 if let Some(i) = n.as_i64() {
200 ParamValue::Int(i)
201 } else {
202 ParamValue::Float(n.as_f64().unwrap_or(0.0))
203 }
204 }
205 Value::String(s) => ParamValue::Text(s.clone()),
206 other => ParamValue::Text(other.to_string()),
207 })
208 .collect(),
209 }
210}
211
212fn neon_result(qr: &QueryResult, array_mode: bool) -> Value {
214 let command = qr
215 .command_tag
216 .split_whitespace()
217 .next()
218 .unwrap_or("")
219 .to_string();
220 let fields: Vec<Value> = qr
221 .columns
222 .iter()
223 .map(|c| json!({ "name": c.name, "dataTypeID": c.type_oid }))
224 .collect();
225 let rows: Vec<Value> = qr
226 .rows
227 .iter()
228 .map(|row| {
229 if array_mode {
230 Value::Array(row.iter().map(cell_to_json).collect())
231 } else {
232 let mut obj = serde_json::Map::new();
233 for (i, c) in qr.columns.iter().enumerate() {
234 let v = row.get(i).map(cell_to_json).unwrap_or(Value::Null);
235 obj.insert(c.name.clone(), v);
236 }
237 Value::Object(obj)
238 }
239 })
240 .collect();
241 let row_count = qr.rows_affected().unwrap_or(qr.rows.len() as u64);
243 json!({
244 "command": command,
245 "rowCount": row_count,
246 "rows": rows,
247 "fields": fields,
248 "rowAsArray": array_mode,
249 })
250}
251
252fn cell_to_json(v: &TextValue) -> Value {
253 match v {
254 TextValue::Null => Value::Null,
255 TextValue::Text(s) => Value::String(s.clone()),
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::client::ColumnMeta;

    /// Builds a two-row `SELECT 2` result with columns `id` (OID 23) and `name`
    /// (OID 25); the second row's `name` is NULL.
    fn sample_result() -> QueryResult {
        let cell = |s: &str| TextValue::Text(s.into());
        QueryResult {
            columns: vec![
                ColumnMeta { name: "id".into(), type_oid: 23 },
                ColumnMeta { name: "name".into(), type_oid: 25 },
            ],
            rows: vec![
                vec![cell("1"), cell("alice")],
                vec![cell("2"), TextValue::Null],
            ],
            command_tag: "SELECT 2".into(),
        }
    }

    #[test]
    fn neon_object_mode() {
        let out = neon_result(&sample_result(), false);
        assert_eq!(out["command"], "SELECT");
        assert_eq!(out["rowCount"], 2);

        let (first, second) = (&out["rows"][0], &out["rows"][1]);
        assert_eq!(first["id"], "1");
        assert_eq!(first["name"], "alice");
        assert_eq!(second["name"], Value::Null);

        let id_field = &out["fields"][0];
        assert_eq!(id_field["name"], "id");
        assert_eq!(id_field["dataTypeID"], 23);
    }

    #[test]
    fn neon_array_mode() {
        let out = neon_result(&sample_result(), true);
        assert_eq!(out["rowAsArray"], true);

        let first = &out["rows"][0];
        assert_eq!(first[0], "1");
        assert_eq!(first[1], "alice");
    }

    #[test]
    fn params_mapping() {
        let mapped = parse_params(Some(&json!([1, "x", true, null, 2.5])));
        assert!(matches!(mapped[0], ParamValue::Int(1)));
        assert!(matches!(&mapped[1], ParamValue::Text(text) if text == "x"));
        assert!(matches!(mapped[2], ParamValue::Bool(true)));
        assert!(matches!(mapped[3], ParamValue::Null));
        assert!(matches!(mapped[4], ParamValue::Float(_)));
    }
}