1use std::sync::Arc;
16use std::time::Duration;
17
18use serde_json::{json, Value};
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::net::TcpListener;
21
22use crate::backend::client::QueryResult;
23use crate::backend::types::TextValue;
24use crate::backend::{
25 tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode,
26};
27use crate::config::HttpGatewayConfig;
28use crate::{ProxyError, Result};
29
30pub struct HttpGateway {
31 config: HttpGatewayConfig,
32}
33
34impl HttpGateway {
35 pub fn new(config: HttpGatewayConfig) -> Self {
36 Self { config }
37 }
38
39 pub async fn run(self) -> Result<()> {
40 let listener = TcpListener::bind(&self.config.listen_address)
41 .await
42 .map_err(|e| {
43 ProxyError::Network(format!(
44 "HTTP gateway bind {}: {}",
45 self.config.listen_address, e
46 ))
47 })?;
48 tracing::info!(addr = %self.config.listen_address, "HTTP SQL gateway listening");
49 let cfg = Arc::new(self.config);
50 loop {
51 let (stream, peer) = match listener.accept().await {
52 Ok(x) => x,
53 Err(e) => {
54 tracing::warn!("HTTP gateway accept error: {}", e);
55 continue;
56 }
57 };
58 let cfg = cfg.clone();
59 tokio::spawn(async move {
60 if let Err(e) = Self::handle(stream, cfg).await {
61 tracing::debug!(%peer, "HTTP gateway error: {}", e);
62 }
63 });
64 }
65 }
66
67 async fn handle(mut stream: tokio::net::TcpStream, cfg: Arc<HttpGatewayConfig>) -> Result<()> {
68 use tokio::io::AsyncBufReadExt;
69 let (reader, mut writer) = stream.split();
70 let mut reader = BufReader::new(reader);
71 let mut line = String::new();
72 let mut content_length = 0usize;
73 let mut method = String::new();
74 let mut path = String::new();
75 let mut authorized = cfg.auth_token.is_none();
76 let mut array_mode = false;
77 let mut first = true;
78 loop {
79 line.clear();
80 let n = reader
81 .read_line(&mut line)
82 .await
83 .map_err(|e| ProxyError::Network(format!("HTTP gw read: {}", e)))?;
84 if n == 0 || line == "\r\n" {
85 break;
86 }
87 if first {
88 let mut parts = line.split_whitespace();
89 method = parts.next().unwrap_or("").to_string();
90 path = parts.next().unwrap_or("").to_string();
91 first = false;
92 continue;
93 }
94 let lower = line.to_ascii_lowercase();
95 if lower.starts_with("content-length:") {
96 content_length = line
97 .split(':')
98 .nth(1)
99 .and_then(|v| v.trim().parse().ok())
100 .unwrap_or(0);
101 } else if lower.starts_with("neon-array-mode:") {
102 array_mode = line
103 .split(':')
104 .nth(1)
105 .map(|v| v.trim().eq_ignore_ascii_case("true"))
106 .unwrap_or(false);
107 } else if lower.starts_with("authorization:") {
108 if let Some(tok) = cfg.auth_token.as_ref() {
109 let v = line.split_once(':').map(|x| x.1).unwrap_or("").trim();
110 authorized = v == format!("Bearer {}", tok);
111 }
112 }
113 }
114
115 if method == "GET" && (path == "/health" || path == "/") {
117 return Self::respond(&mut writer, 200, &json!({"status":"ok"})).await;
118 }
119 if !authorized {
120 return Self::respond(&mut writer, 401, &json!({"error":"unauthorized"})).await;
121 }
122 if method != "POST" {
123 return Self::respond(&mut writer, 405, &json!({"error":"use POST /sql"})).await;
124 }
125
126 let mut body_buf = vec![0u8; content_length];
127 if content_length > 0 {
128 reader
129 .read_exact(&mut body_buf)
130 .await
131 .map_err(|e| ProxyError::Network(format!("HTTP gw body: {}", e)))?;
132 }
133 let req: Value = match serde_json::from_slice(&body_buf) {
134 Ok(v) => v,
135 Err(e) => {
136 return Self::respond(
137 &mut writer,
138 400,
139 &json!({"error": format!("invalid JSON: {}", e)}),
140 )
141 .await
142 }
143 };
144 let sql = req
145 .get("query")
146 .and_then(|q| q.as_str())
147 .unwrap_or("")
148 .trim();
149 if sql.is_empty() {
150 return Self::respond(&mut writer, 400, &json!({"error":"missing 'query'"})).await;
151 }
152 let params = parse_params(req.get("params"));
153
154 match Self::run_sql(&cfg, sql, ¶ms).await {
155 Ok(qr) => {
156 let body = neon_result(&qr, array_mode);
157 Self::respond(&mut writer, 200, &body).await
158 }
159 Err(e) => Self::respond(&mut writer, 400, &json!({ "error": e })).await,
160 }
161 }
162
163 async fn run_sql(
164 cfg: &HttpGatewayConfig,
165 sql: &str,
166 params: &[ParamValue],
167 ) -> std::result::Result<QueryResult, String> {
168 let bcfg = BackendConfig {
169 host: cfg.backend_host.clone(),
170 port: cfg.backend_port,
171 user: cfg.backend_user.clone(),
172 password: cfg.backend_password.clone(),
173 database: cfg.backend_database.clone(),
174 application_name: Some("heliosproxy-http".to_string()),
175 tls_mode: TlsMode::Disable,
176 connect_timeout: Duration::from_secs(5),
177 query_timeout: Duration::from_secs(30),
178 tls_config: default_client_config(),
179 };
180 let mut client = BackendClient::connect(&bcfg)
181 .await
182 .map_err(|e| format!("backend connect: {}", e))?;
183 let res = if params.is_empty() {
184 client.simple_query(sql).await
185 } else {
186 client.query_with_params(sql, params).await
187 };
188 client.close().await;
189 res.map_err(|e| format!("{}", e))
190 }
191
192 async fn respond(
193 writer: &mut tokio::net::tcp::WriteHalf<'_>,
194 status: u16,
195 body: &Value,
196 ) -> Result<()> {
197 let payload = serde_json::to_vec(body).unwrap_or_default();
198 let status_text = match status {
199 200 => "OK",
200 400 => "Bad Request",
201 401 => "Unauthorized",
202 405 => "Method Not Allowed",
203 _ => "Error",
204 };
205 let head = format!(
206 "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
207 status, status_text, payload.len()
208 );
209 writer
210 .write_all(head.as_bytes())
211 .await
212 .map_err(|e| ProxyError::Network(format!("write: {}", e)))?;
213 writer
214 .write_all(&payload)
215 .await
216 .map_err(|e| ProxyError::Network(format!("write: {}", e)))?;
217 Ok(())
218 }
219}
220
221fn parse_params(v: Option<&Value>) -> Vec<ParamValue> {
223 match v.and_then(|v| v.as_array()) {
224 None => Vec::new(),
225 Some(arr) => arr
226 .iter()
227 .map(|p| match p {
228 Value::Null => ParamValue::Null,
229 Value::Bool(b) => ParamValue::Bool(*b),
230 Value::Number(n) => {
231 if let Some(i) = n.as_i64() {
232 ParamValue::Int(i)
233 } else {
234 ParamValue::Float(n.as_f64().unwrap_or(0.0))
235 }
236 }
237 Value::String(s) => ParamValue::Text(s.clone()),
238 other => ParamValue::Text(other.to_string()),
239 })
240 .collect(),
241 }
242}
243
244fn neon_result(qr: &QueryResult, array_mode: bool) -> Value {
246 let command = qr
247 .command_tag
248 .split_whitespace()
249 .next()
250 .unwrap_or("")
251 .to_string();
252 let fields: Vec<Value> = qr
253 .columns
254 .iter()
255 .map(|c| json!({ "name": c.name, "dataTypeID": c.type_oid }))
256 .collect();
257 let rows: Vec<Value> = qr
258 .rows
259 .iter()
260 .map(|row| {
261 if array_mode {
262 Value::Array(row.iter().map(cell_to_json).collect())
263 } else {
264 let mut obj = serde_json::Map::new();
265 for (i, c) in qr.columns.iter().enumerate() {
266 let v = row.get(i).map(cell_to_json).unwrap_or(Value::Null);
267 obj.insert(c.name.clone(), v);
268 }
269 Value::Object(obj)
270 }
271 })
272 .collect();
273 let row_count = qr.rows_affected().unwrap_or(qr.rows.len() as u64);
275 json!({
276 "command": command,
277 "rowCount": row_count,
278 "rows": rows,
279 "fields": fields,
280 "rowAsArray": array_mode,
281 })
282}
283
284fn cell_to_json(v: &TextValue) -> Value {
285 match v {
286 TextValue::Null => Value::Null,
287 TextValue::Text(s) => Value::String(s.clone()),
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::client::ColumnMeta;

    /// Two-column (`id` int4, `name` text), two-row `SELECT` result. The second
    /// row has a NULL `name`.
    fn qr() -> QueryResult {
        QueryResult {
            columns: vec![
                ColumnMeta {
                    name: "id".into(),
                    type_oid: 23,
                },
                ColumnMeta {
                    name: "name".into(),
                    type_oid: 25,
                },
            ],
            rows: vec![
                vec![TextValue::Text("1".into()), TextValue::Text("alice".into())],
                vec![TextValue::Text("2".into()), TextValue::Null],
            ],
            command_tag: "SELECT 2".into(),
        }
    }

    #[test]
    fn neon_object_mode() {
        let v = neon_result(&qr(), false);
        assert_eq!(v["command"], "SELECT");
        assert_eq!(v["rowCount"], 2);
        assert_eq!(v["rows"][0]["id"], "1");
        assert_eq!(v["rows"][0]["name"], "alice");
        assert_eq!(v["rows"][1]["name"], Value::Null);
        assert_eq!(v["fields"][0]["name"], "id");
        assert_eq!(v["fields"][0]["dataTypeID"], 23);
    }

    #[test]
    fn neon_array_mode() {
        let v = neon_result(&qr(), true);
        assert_eq!(v["rowAsArray"], true);
        assert_eq!(v["rows"][0][0], "1");
        assert_eq!(v["rows"][0][1], "alice");
    }

    #[test]
    fn neon_array_mode_keeps_nulls() {
        let v = neon_result(&qr(), true);
        assert_eq!(v["rows"][1][0], "2");
        assert_eq!(v["rows"][1][1], Value::Null);
    }

    #[test]
    fn neon_command_is_first_word_of_tag() {
        let mut r = qr();
        r.command_tag = "INSERT 0 1".into();
        assert_eq!(neon_result(&r, false)["command"], "INSERT");
    }

    #[test]
    fn neon_object_mode_pads_short_rows_with_null() {
        let mut r = qr();
        r.rows = vec![vec![TextValue::Text("3".into())]];
        let v = neon_result(&r, false);
        assert_eq!(v["rows"][0]["id"], "3");
        assert_eq!(v["rows"][0]["name"], Value::Null);
    }

    #[test]
    fn params_mapping() {
        let p = parse_params(Some(&json!([1, "x", true, null, 2.5])));
        assert!(matches!(p[0], ParamValue::Int(1)));
        assert!(matches!(p[1], ParamValue::Text(ref s) if s == "x"));
        assert!(matches!(p[2], ParamValue::Bool(true)));
        assert!(matches!(p[3], ParamValue::Null));
        assert!(matches!(p[4], ParamValue::Float(_)));
    }

    #[test]
    fn params_missing_or_not_array_is_empty() {
        assert!(parse_params(None).is_empty());
        assert!(parse_params(Some(&json!({"a": 1}))).is_empty());
        assert!(parse_params(Some(&json!("x"))).is_empty());
    }

    #[test]
    fn params_nested_json_is_passed_as_text() {
        let p = parse_params(Some(&json!([[1, 2], {"a": 1}])));
        assert!(matches!(p[0], ParamValue::Text(ref s) if s == "[1,2]"));
        assert!(matches!(p[1], ParamValue::Text(ref s) if s == "{\"a\":1}"));
    }
}