Skip to main content

heliosdb_proxy/
http_gateway.rs

//! HTTP SQL gateway — `@neondatabase/serverless`-compatible.
//!
//! When `[http_gateway] enabled = true`, the proxy exposes a `POST /sql`
//! endpoint that runs one SQL statement over the backend PG-wire client and
//! returns a Neon-style JSON result (`{ command, rowCount, rows, fields, rowAsArray }`).
//! This lets edge/serverless runtimes (Cloudflare Workers, Vercel Edge) that
//! cannot hold a TCP socket talk to vanilla PostgreSQL or HeliosDB-Nano over
//! plain HTTP.
//!
//! Parameterised queries (`$1`, `$2`) are supported via a JSON `params` array.
//! `Neon-Array-Mode: true` returns each row as an array instead of an object.
//! A WebSocket session/transaction mode is the planned follow-on; this is the
//! one-shot HTTP path the serverless driver uses for non-transactional queries.

15use std::sync::Arc;
16use std::time::Duration;
17
18use serde_json::{json, Value};
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::net::TcpListener;
21
22use crate::backend::client::QueryResult;
23use crate::backend::types::TextValue;
24use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode};
25use crate::config::HttpGatewayConfig;
26use crate::{ProxyError, Result};
27
28pub struct HttpGateway {
29    config: HttpGatewayConfig,
30}
31
32impl HttpGateway {
33    pub fn new(config: HttpGatewayConfig) -> Self {
34        Self { config }
35    }
36
37    pub async fn run(self) -> Result<()> {
38        let listener = TcpListener::bind(&self.config.listen_address).await.map_err(|e| {
39            ProxyError::Network(format!("HTTP gateway bind {}: {}", self.config.listen_address, e))
40        })?;
41        tracing::info!(addr = %self.config.listen_address, "HTTP SQL gateway listening");
42        let cfg = Arc::new(self.config);
43        loop {
44            let (stream, peer) = match listener.accept().await {
45                Ok(x) => x,
46                Err(e) => {
47                    tracing::warn!("HTTP gateway accept error: {}", e);
48                    continue;
49                }
50            };
51            let cfg = cfg.clone();
52            tokio::spawn(async move {
53                if let Err(e) = Self::handle(stream, cfg).await {
54                    tracing::debug!(%peer, "HTTP gateway error: {}", e);
55                }
56            });
57        }
58    }
59
60    async fn handle(mut stream: tokio::net::TcpStream, cfg: Arc<HttpGatewayConfig>) -> Result<()> {
61        use tokio::io::AsyncBufReadExt;
62        let (reader, mut writer) = stream.split();
63        let mut reader = BufReader::new(reader);
64        let mut line = String::new();
65        let mut content_length = 0usize;
66        let mut method = String::new();
67        let mut path = String::new();
68        let mut authorized = cfg.auth_token.is_none();
69        let mut array_mode = false;
70        let mut first = true;
71        loop {
72            line.clear();
73            let n = reader
74                .read_line(&mut line)
75                .await
76                .map_err(|e| ProxyError::Network(format!("HTTP gw read: {}", e)))?;
77            if n == 0 || line == "\r\n" {
78                break;
79            }
80            if first {
81                let mut parts = line.split_whitespace();
82                method = parts.next().unwrap_or("").to_string();
83                path = parts.next().unwrap_or("").to_string();
84                first = false;
85                continue;
86            }
87            let lower = line.to_ascii_lowercase();
88            if lower.starts_with("content-length:") {
89                content_length = line.split(':').nth(1).and_then(|v| v.trim().parse().ok()).unwrap_or(0);
90            } else if lower.starts_with("neon-array-mode:") {
91                array_mode = line.split(':').nth(1).map(|v| v.trim().eq_ignore_ascii_case("true")).unwrap_or(false);
92            } else if lower.starts_with("authorization:") {
93                if let Some(tok) = cfg.auth_token.as_ref() {
94                    let v = line.splitn(2, ':').nth(1).unwrap_or("").trim();
95                    authorized = v == format!("Bearer {}", tok);
96                }
97            }
98        }
99
100        // Liveness probe.
101        if method == "GET" && (path == "/health" || path == "/") {
102            return Self::respond(&mut writer, 200, &json!({"status":"ok"})).await;
103        }
104        if !authorized {
105            return Self::respond(&mut writer, 401, &json!({"error":"unauthorized"})).await;
106        }
107        if method != "POST" {
108            return Self::respond(&mut writer, 405, &json!({"error":"use POST /sql"})).await;
109        }
110
111        let mut body_buf = vec![0u8; content_length];
112        if content_length > 0 {
113            reader
114                .read_exact(&mut body_buf)
115                .await
116                .map_err(|e| ProxyError::Network(format!("HTTP gw body: {}", e)))?;
117        }
118        let req: Value = match serde_json::from_slice(&body_buf) {
119            Ok(v) => v,
120            Err(e) => return Self::respond(&mut writer, 400, &json!({"error": format!("invalid JSON: {}", e)})).await,
121        };
122        let sql = req.get("query").and_then(|q| q.as_str()).unwrap_or("").trim();
123        if sql.is_empty() {
124            return Self::respond(&mut writer, 400, &json!({"error":"missing 'query'"})).await;
125        }
126        let params = parse_params(req.get("params"));
127
128        match Self::run_sql(&cfg, sql, &params).await {
129            Ok(qr) => {
130                let body = neon_result(&qr, array_mode);
131                Self::respond(&mut writer, 200, &body).await
132            }
133            Err(e) => Self::respond(&mut writer, 400, &json!({ "error": e })).await,
134        }
135    }
136
137    async fn run_sql(
138        cfg: &HttpGatewayConfig,
139        sql: &str,
140        params: &[ParamValue],
141    ) -> std::result::Result<QueryResult, String> {
142        let bcfg = BackendConfig {
143            host: cfg.backend_host.clone(),
144            port: cfg.backend_port,
145            user: cfg.backend_user.clone(),
146            password: cfg.backend_password.clone(),
147            database: cfg.backend_database.clone(),
148            application_name: Some("heliosproxy-http".to_string()),
149            tls_mode: TlsMode::Disable,
150            connect_timeout: Duration::from_secs(5),
151            query_timeout: Duration::from_secs(30),
152            tls_config: default_client_config(),
153        };
154        let mut client = BackendClient::connect(&bcfg)
155            .await
156            .map_err(|e| format!("backend connect: {}", e))?;
157        let res = if params.is_empty() {
158            client.simple_query(sql).await
159        } else {
160            client.query_with_params(sql, params).await
161        };
162        client.close().await;
163        res.map_err(|e| format!("{}", e))
164    }
165
166    async fn respond(
167        writer: &mut tokio::net::tcp::WriteHalf<'_>,
168        status: u16,
169        body: &Value,
170    ) -> Result<()> {
171        let payload = serde_json::to_vec(body).unwrap_or_default();
172        let status_text = match status {
173            200 => "OK",
174            400 => "Bad Request",
175            401 => "Unauthorized",
176            405 => "Method Not Allowed",
177            _ => "Error",
178        };
179        let head = format!(
180            "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
181            status, status_text, payload.len()
182        );
183        writer.write_all(head.as_bytes()).await.map_err(|e| ProxyError::Network(format!("write: {}", e)))?;
184        writer.write_all(&payload).await.map_err(|e| ProxyError::Network(format!("write: {}", e)))?;
185        Ok(())
186    }
187}
188
189/// Map a JSON `params` array to text-format ParamValues.
190fn parse_params(v: Option<&Value>) -> Vec<ParamValue> {
191    match v.and_then(|v| v.as_array()) {
192        None => Vec::new(),
193        Some(arr) => arr
194            .iter()
195            .map(|p| match p {
196                Value::Null => ParamValue::Null,
197                Value::Bool(b) => ParamValue::Bool(*b),
198                Value::Number(n) => {
199                    if let Some(i) = n.as_i64() {
200                        ParamValue::Int(i)
201                    } else {
202                        ParamValue::Float(n.as_f64().unwrap_or(0.0))
203                    }
204                }
205                Value::String(s) => ParamValue::Text(s.clone()),
206                other => ParamValue::Text(other.to_string()),
207            })
208            .collect(),
209    }
210}
211
212/// Build the Neon-serverless-style result body.
213fn neon_result(qr: &QueryResult, array_mode: bool) -> Value {
214    let command = qr
215        .command_tag
216        .split_whitespace()
217        .next()
218        .unwrap_or("")
219        .to_string();
220    let fields: Vec<Value> = qr
221        .columns
222        .iter()
223        .map(|c| json!({ "name": c.name, "dataTypeID": c.type_oid }))
224        .collect();
225    let rows: Vec<Value> = qr
226        .rows
227        .iter()
228        .map(|row| {
229            if array_mode {
230                Value::Array(row.iter().map(cell_to_json).collect())
231            } else {
232                let mut obj = serde_json::Map::new();
233                for (i, c) in qr.columns.iter().enumerate() {
234                    let v = row.get(i).map(cell_to_json).unwrap_or(Value::Null);
235                    obj.insert(c.name.clone(), v);
236                }
237                Value::Object(obj)
238            }
239        })
240        .collect();
241    // rowCount mirrors libpq: affected rows for writes, else row count.
242    let row_count = qr.rows_affected().unwrap_or(qr.rows.len() as u64);
243    json!({
244        "command": command,
245        "rowCount": row_count,
246        "rows": rows,
247        "fields": fields,
248        "rowAsArray": array_mode,
249    })
250}
251
252fn cell_to_json(v: &TextValue) -> Value {
253    match v {
254        TextValue::Null => Value::Null,
255        TextValue::Text(s) => Value::String(s.clone()),
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::client::ColumnMeta;

    /// Two-column result (`id` oid 23, `name` oid 25) whose second row has a NULL name.
    fn sample_result() -> QueryResult {
        let columns = vec![
            ColumnMeta { name: "id".into(), type_oid: 23 },
            ColumnMeta { name: "name".into(), type_oid: 25 },
        ];
        let rows = vec![
            vec![TextValue::Text("1".into()), TextValue::Text("alice".into())],
            vec![TextValue::Text("2".into()), TextValue::Null],
        ];
        QueryResult {
            columns,
            rows,
            command_tag: "SELECT 2".into(),
        }
    }

    #[test]
    fn neon_object_mode() {
        let body = neon_result(&sample_result(), false);
        assert_eq!(body["command"], "SELECT");
        assert_eq!(body["rowCount"], 2);

        let first_row = &body["rows"][0];
        assert_eq!(first_row["id"], "1");
        assert_eq!(first_row["name"], "alice");
        assert_eq!(body["rows"][1]["name"], Value::Null);

        let id_field = &body["fields"][0];
        assert_eq!(id_field["name"], "id");
        assert_eq!(id_field["dataTypeID"], 23);
    }

    #[test]
    fn neon_array_mode() {
        let body = neon_result(&sample_result(), true);
        assert_eq!(body["rowAsArray"], true);
        let first_row = &body["rows"][0];
        assert_eq!(first_row[0], "1");
        assert_eq!(first_row[1], "alice");
    }

    #[test]
    fn params_mapping() {
        let input = json!([1, "x", true, null, 2.5]);
        let mapped = parse_params(Some(&input));
        assert!(matches!(mapped[0], ParamValue::Int(1)));
        assert!(matches!(mapped[1], ParamValue::Text(ref s) if s == "x"));
        assert!(matches!(mapped[2], ParamValue::Bool(true)));
        assert!(matches!(mapped[3], ParamValue::Null));
        assert!(matches!(mapped[4], ParamValue::Float(_)));
    }
}