Skip to main content

heliosdb_proxy/
http_gateway.rs

1//! HTTP SQL gateway — `@neondatabase/serverless`-compatible.
2//!
3//! When `[http_gateway] enabled = true`, the proxy exposes a `POST /sql`
4//! endpoint that runs one SQL statement over the backend PG-wire client and
5//! returns a Neon-style JSON result (`{ command, rowCount, rows, fields }`).
6//! This lets edge/serverless runtimes (Cloudflare Workers, Vercel Edge) that
7//! cannot hold a TCP socket talk to vanilla PostgreSQL or HeliosDB-Nano over
8//! plain HTTP.
9//!
10//! Parameterised queries (`$1`,`$2`) are supported via a JSON `params` array.
11//! `Neon-Array-Mode: true` returns each row as an array instead of an object.
12//! A WebSocket session/transaction mode is the planned follow-on; this is the
13//! one-shot HTTP path the serverless driver uses for non-transactional queries.
14
15use std::sync::Arc;
16use std::time::Duration;
17
18use serde_json::{json, Value};
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::net::TcpListener;
21
22use crate::backend::client::QueryResult;
23use crate::backend::types::TextValue;
24use crate::backend::{
25    tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode,
26};
27use crate::config::HttpGatewayConfig;
28use crate::{ProxyError, Result};
29
30pub struct HttpGateway {
31    config: HttpGatewayConfig,
32}
33
34impl HttpGateway {
35    pub fn new(config: HttpGatewayConfig) -> Self {
36        Self { config }
37    }
38
39    pub async fn run(self) -> Result<()> {
40        let listener = TcpListener::bind(&self.config.listen_address)
41            .await
42            .map_err(|e| {
43                ProxyError::Network(format!(
44                    "HTTP gateway bind {}: {}",
45                    self.config.listen_address, e
46                ))
47            })?;
48        tracing::info!(addr = %self.config.listen_address, "HTTP SQL gateway listening");
49        let cfg = Arc::new(self.config);
50        loop {
51            let (stream, peer) = match listener.accept().await {
52                Ok(x) => x,
53                Err(e) => {
54                    tracing::warn!("HTTP gateway accept error: {}", e);
55                    continue;
56                }
57            };
58            let cfg = cfg.clone();
59            tokio::spawn(async move {
60                if let Err(e) = Self::handle(stream, cfg).await {
61                    tracing::debug!(%peer, "HTTP gateway error: {}", e);
62                }
63            });
64        }
65    }
66
67    async fn handle(mut stream: tokio::net::TcpStream, cfg: Arc<HttpGatewayConfig>) -> Result<()> {
68        use tokio::io::AsyncBufReadExt;
69        let (reader, mut writer) = stream.split();
70        let mut reader = BufReader::new(reader);
71        let mut line = String::new();
72        let mut content_length = 0usize;
73        let mut method = String::new();
74        let mut path = String::new();
75        let mut authorized = cfg.auth_token.is_none();
76        let mut array_mode = false;
77        let mut first = true;
78        loop {
79            line.clear();
80            let n = reader
81                .read_line(&mut line)
82                .await
83                .map_err(|e| ProxyError::Network(format!("HTTP gw read: {}", e)))?;
84            if n == 0 || line == "\r\n" {
85                break;
86            }
87            if first {
88                let mut parts = line.split_whitespace();
89                method = parts.next().unwrap_or("").to_string();
90                path = parts.next().unwrap_or("").to_string();
91                first = false;
92                continue;
93            }
94            let lower = line.to_ascii_lowercase();
95            if lower.starts_with("content-length:") {
96                content_length = line
97                    .split(':')
98                    .nth(1)
99                    .and_then(|v| v.trim().parse().ok())
100                    .unwrap_or(0);
101            } else if lower.starts_with("neon-array-mode:") {
102                array_mode = line
103                    .split(':')
104                    .nth(1)
105                    .map(|v| v.trim().eq_ignore_ascii_case("true"))
106                    .unwrap_or(false);
107            } else if lower.starts_with("authorization:") {
108                if let Some(tok) = cfg.auth_token.as_ref() {
109                    let v = line.split_once(':').map(|x| x.1).unwrap_or("").trim();
110                    authorized = v == format!("Bearer {}", tok);
111                }
112            }
113        }
114
115        // Liveness probe.
116        if method == "GET" && (path == "/health" || path == "/") {
117            return Self::respond(&mut writer, 200, &json!({"status":"ok"})).await;
118        }
119        if !authorized {
120            return Self::respond(&mut writer, 401, &json!({"error":"unauthorized"})).await;
121        }
122        if method != "POST" {
123            return Self::respond(&mut writer, 405, &json!({"error":"use POST /sql"})).await;
124        }
125
126        let mut body_buf = vec![0u8; content_length];
127        if content_length > 0 {
128            reader
129                .read_exact(&mut body_buf)
130                .await
131                .map_err(|e| ProxyError::Network(format!("HTTP gw body: {}", e)))?;
132        }
133        let req: Value = match serde_json::from_slice(&body_buf) {
134            Ok(v) => v,
135            Err(e) => {
136                return Self::respond(
137                    &mut writer,
138                    400,
139                    &json!({"error": format!("invalid JSON: {}", e)}),
140                )
141                .await
142            }
143        };
144        let sql = req
145            .get("query")
146            .and_then(|q| q.as_str())
147            .unwrap_or("")
148            .trim();
149        if sql.is_empty() {
150            return Self::respond(&mut writer, 400, &json!({"error":"missing 'query'"})).await;
151        }
152        let params = parse_params(req.get("params"));
153
154        match Self::run_sql(&cfg, sql, &params).await {
155            Ok(qr) => {
156                let body = neon_result(&qr, array_mode);
157                Self::respond(&mut writer, 200, &body).await
158            }
159            Err(e) => Self::respond(&mut writer, 400, &json!({ "error": e })).await,
160        }
161    }
162
163    async fn run_sql(
164        cfg: &HttpGatewayConfig,
165        sql: &str,
166        params: &[ParamValue],
167    ) -> std::result::Result<QueryResult, String> {
168        let bcfg = BackendConfig {
169            host: cfg.backend_host.clone(),
170            port: cfg.backend_port,
171            user: cfg.backend_user.clone(),
172            password: cfg.backend_password.clone(),
173            database: cfg.backend_database.clone(),
174            application_name: Some("heliosproxy-http".to_string()),
175            tls_mode: TlsMode::Disable,
176            connect_timeout: Duration::from_secs(5),
177            query_timeout: Duration::from_secs(30),
178            tls_config: default_client_config(),
179        };
180        let mut client = BackendClient::connect(&bcfg)
181            .await
182            .map_err(|e| format!("backend connect: {}", e))?;
183        let res = if params.is_empty() {
184            client.simple_query(sql).await
185        } else {
186            client.query_with_params(sql, params).await
187        };
188        client.close().await;
189        res.map_err(|e| format!("{}", e))
190    }
191
192    async fn respond(
193        writer: &mut tokio::net::tcp::WriteHalf<'_>,
194        status: u16,
195        body: &Value,
196    ) -> Result<()> {
197        let payload = serde_json::to_vec(body).unwrap_or_default();
198        let status_text = match status {
199            200 => "OK",
200            400 => "Bad Request",
201            401 => "Unauthorized",
202            405 => "Method Not Allowed",
203            _ => "Error",
204        };
205        let head = format!(
206            "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
207            status, status_text, payload.len()
208        );
209        writer
210            .write_all(head.as_bytes())
211            .await
212            .map_err(|e| ProxyError::Network(format!("write: {}", e)))?;
213        writer
214            .write_all(&payload)
215            .await
216            .map_err(|e| ProxyError::Network(format!("write: {}", e)))?;
217        Ok(())
218    }
219}
220
221/// Map a JSON `params` array to text-format ParamValues.
222fn parse_params(v: Option<&Value>) -> Vec<ParamValue> {
223    match v.and_then(|v| v.as_array()) {
224        None => Vec::new(),
225        Some(arr) => arr
226            .iter()
227            .map(|p| match p {
228                Value::Null => ParamValue::Null,
229                Value::Bool(b) => ParamValue::Bool(*b),
230                Value::Number(n) => {
231                    if let Some(i) = n.as_i64() {
232                        ParamValue::Int(i)
233                    } else {
234                        ParamValue::Float(n.as_f64().unwrap_or(0.0))
235                    }
236                }
237                Value::String(s) => ParamValue::Text(s.clone()),
238                other => ParamValue::Text(other.to_string()),
239            })
240            .collect(),
241    }
242}
243
244/// Build the Neon-serverless-style result body.
245fn neon_result(qr: &QueryResult, array_mode: bool) -> Value {
246    let command = qr
247        .command_tag
248        .split_whitespace()
249        .next()
250        .unwrap_or("")
251        .to_string();
252    let fields: Vec<Value> = qr
253        .columns
254        .iter()
255        .map(|c| json!({ "name": c.name, "dataTypeID": c.type_oid }))
256        .collect();
257    let rows: Vec<Value> = qr
258        .rows
259        .iter()
260        .map(|row| {
261            if array_mode {
262                Value::Array(row.iter().map(cell_to_json).collect())
263            } else {
264                let mut obj = serde_json::Map::new();
265                for (i, c) in qr.columns.iter().enumerate() {
266                    let v = row.get(i).map(cell_to_json).unwrap_or(Value::Null);
267                    obj.insert(c.name.clone(), v);
268                }
269                Value::Object(obj)
270            }
271        })
272        .collect();
273    // rowCount mirrors libpq: affected rows for writes, else row count.
274    let row_count = qr.rows_affected().unwrap_or(qr.rows.len() as u64);
275    json!({
276        "command": command,
277        "rowCount": row_count,
278        "rows": rows,
279        "fields": fields,
280        "rowAsArray": array_mode,
281    })
282}
283
284fn cell_to_json(v: &TextValue) -> Value {
285    match v {
286        TextValue::Null => Value::Null,
287        TextValue::Text(s) => Value::String(s.clone()),
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::client::ColumnMeta;

    /// Fixture: `id` (OID 23) and `name` (OID 25) columns with two rows, the
    /// second having a NULL `name`, tagged `SELECT 2`.
    fn sample() -> QueryResult {
        let column = |name: &str, type_oid| ColumnMeta {
            name: name.into(),
            type_oid,
        };
        let text = |s: &str| TextValue::Text(s.into());
        QueryResult {
            columns: vec![column("id", 23), column("name", 25)],
            rows: vec![
                vec![text("1"), text("alice")],
                vec![text("2"), TextValue::Null],
            ],
            command_tag: "SELECT 2".into(),
        }
    }

    #[test]
    fn neon_object_mode() {
        let body = neon_result(&sample(), false);
        let (first_row, second_row) = (&body["rows"][0], &body["rows"][1]);
        let id_field = &body["fields"][0];

        assert_eq!(body["command"], "SELECT");
        assert_eq!(body["rowCount"], 2);
        assert_eq!(first_row["id"], "1");
        assert_eq!(first_row["name"], "alice");
        assert_eq!(second_row["name"], Value::Null);
        assert_eq!(id_field["name"], "id");
        assert_eq!(id_field["dataTypeID"], 23);
    }

    #[test]
    fn neon_array_mode() {
        let body = neon_result(&sample(), true);
        let first_row = &body["rows"][0];

        assert_eq!(body["rowAsArray"], true);
        assert_eq!(first_row[0], "1");
        assert_eq!(first_row[1], "alice");
    }

    #[test]
    fn params_mapping() {
        let params = parse_params(Some(&json!([1, "x", true, null, 2.5])));
        assert!(matches!(
            params.as_slice(),
            [
                ParamValue::Int(1),
                ParamValue::Text(s),
                ParamValue::Bool(true),
                ParamValue::Null,
                ParamValue::Float(_),
            ] if s == "x"
        ));
    }
}