use axum::{
extract::{Path, Query, State},
http::{HeaderMap, StatusCode},
Json,
};
use std::collections::HashMap;
use tracing::{info, warn, debug};
use crate::api::{
models::ApiError,
server::AppState,
rest_executor::RestExecutor,
};
/// Query-string keys that the REST handlers never turn into column filters.
const RESERVED_KEYS: &[&str] = &[
    "select", // column list handed to the executor (`*` when absent)
    "order",  // ordering expression handed to the executor
    "limit",  // pagination: maximum row count
    "offset", // pagination: rows to skip
    "apikey", // excluded from filters; not otherwise read from the query string here
];
/// Converts a database value into JSON using the crate's
/// `From<&crate::Value> for serde_json::Value` conversion.
fn value_to_json(val: &crate::Value) -> serde_json::Value {
    val.into()
}
/// Returns every query parameter whose key is not in `RESERVED_KEYS` as an owned
/// `(key, value)` pair, e.g. `("name", "eq.Alice")`.
///
/// The output order follows the `HashMap`'s iteration order.
fn collect_filters(params: &HashMap<String, String>) -> Vec<(String, String)> {
    let mut filters = Vec::new();
    for (key, value) in params {
        if RESERVED_KEYS.contains(&key.as_str()) {
            continue;
        }
        filters.push((key.clone(), value.clone()));
    }
    filters
}
/// Resolves the id of the user behind an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively, because HTTP authentication schemes
/// are case-insensitive (RFC 7235 §2.1). Whitespace around the token is ignored.
///
/// Returns `None` in any of these cases:
/// - no auth bridge is configured;
/// - the header is missing or not visible ASCII;
/// - the scheme is not `Bearer`;
/// - the token is empty;
/// - the bridge rejects the token.
fn extract_user_from_headers(headers: &HeaderMap, state: &AppState) -> Option<String> {
    let bridge = state.auth_bridge.as_ref()?;
    let auth_header = headers.get("authorization")?.to_str().ok()?;
    let (scheme, token) = auth_header.trim().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = token.trim();
    if token.is_empty() {
        return None;
    }
    bridge.get_user(token).ok().map(|u| u.id)
}
/// Reports whether the request's API key belongs to a `service_role` user.
///
/// The key is read from the `apikey` header, falling back to `x-api-key` only when
/// `apikey` is absent, and is resolved through the auth bridge.
/// Returns `false` in any of these cases:
/// - no bridge is configured;
/// - no key is present;
/// - the key is not valid header text;
/// - the bridge rejects the key.
fn is_service_role(headers: &HeaderMap, state: &AppState) -> bool {
    state
        .auth_bridge
        .as_ref()
        .and_then(|bridge| {
            // A present-but-unreadable `apikey` does not fall back to `x-api-key`.
            let raw = headers.get("apikey").or_else(|| headers.get("x-api-key"))?;
            let key = raw.to_str().ok()?;
            bridge.get_user(key).ok()
        })
        .map(|user| user.role == "service_role")
        .unwrap_or(false)
}
/// Reads rows from `table` and returns them as a JSON array of objects keyed by column.
///
/// Query parameters:
/// - `select`: column list passed to the executor (defaults to `*`);
/// - `order`: ordering expression passed through to the executor;
/// - `limit` / `offset`: pagination. Values that do not parse as `usize` are ignored.
/// - any other key becomes a filter, e.g. `name=eq.Alice`.
///
/// Row-level security:
/// - A service-role API key (`apikey` / `x-api-key`) bypasses RLS.
/// - Otherwise, if a user can be resolved from the bearer token, the RLS-aware
///   executor path is used.
///
/// # Errors
/// Executor failures (e.g. an unknown table) are logged and converted into `ApiError`.
pub async fn rest_select(
    State(state): State<AppState>,
    Path(table): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    headers: HeaderMap,
) -> Result<Json<Vec<serde_json::Value>>, ApiError> {
    info!(table = %table, "REST SELECT");
    let executor = RestExecutor::new(state.db.clone());

    let select = params.get("select").map(String::as_str).unwrap_or("*");
    let order = params.get("order").map(String::as_str);
    // Malformed pagination values are dropped rather than rejected.
    let limit: Option<usize> = params.get("limit").and_then(|s| s.parse().ok());
    let offset: Option<usize> = params.get("offset").and_then(|s| s.parse().ok());
    let filters = collect_filters(&params);

    let bypass_rls = is_service_role(&headers, &state);
    let user_id = if bypass_rls {
        None
    } else {
        extract_user_from_headers(&headers, &state)
    };
    debug!(table = %table, ?user_id, bypass_rls, "RLS context");

    // `user_id` is always `None` when RLS is bypassed, so matching on it alone is enough.
    // NOTE(review): when no user can be resolved the query takes the plain, non-RLS path.
    // That covers: no auth bridge, missing bearer token, or invalid bearer token.
    // Confirm unrestricted reads are intended when an auth bridge is configured.
    let (tuples, columns) = match user_id.as_deref() {
        Some(uid) => executor
            .select_with_rls(&table, select, &filters, order, limit, offset, Some(uid))
            .map_err(|e| {
                warn!(table = %table, error = %e, "REST SELECT (RLS) failed");
                ApiError::from(e)
            })?,
        None => executor
            .select(&table, select, &filters, order, limit, offset)
            .map_err(|e| {
                warn!(table = %table, error = %e, "REST SELECT failed");
                ApiError::from(e)
            })?,
    };

    // Pair each column name with the value at the same position; a tuple with fewer
    // values than columns simply yields an object without the missing keys.
    let rows: Vec<serde_json::Value> = tuples
        .iter()
        .map(|tuple| {
            let obj: serde_json::Map<String, serde_json::Value> = columns
                .iter()
                .zip(tuple.values.iter())
                .map(|(col, val)| (col.clone(), value_to_json(val)))
                .collect();
            serde_json::Value::Object(obj)
        })
        .collect();
    Ok(Json(rows))
}
/// Inserts one JSON object, or an array of JSON objects, into `table`.
///
/// Responds `201 Created` with `{"message": "<n> row(s) inserted", "count": <n>}`.
/// If a change notifier is configured and at least one row was inserted, emits one
/// `INSERT` event per submitted row.
///
/// NOTE(review): this handler takes no headers, so no RLS or ownership checks apply to
/// inserts. Confirm that is intended.
///
/// # Errors
/// - `bad_request` when the body is neither an object nor an array made only of objects.
/// - Otherwise, executor failures are converted into `ApiError`.
pub async fn rest_insert(
    State(state): State<AppState>,
    Path(table): Path<String>,
    Json(body): Json<serde_json::Value>,
) -> Result<(StatusCode, Json<serde_json::Value>), ApiError> {
    const INVALID_BODY: &str = "Request body must be a JSON object or array of objects";

    info!(table = %table, "REST INSERT");
    let executor = RestExecutor::new(state.db.clone());

    // `body` is not used afterwards, so it is moved instead of cloned.
    let rows: Vec<serde_json::Value> = match body {
        serde_json::Value::Array(arr) => arr,
        obj @ serde_json::Value::Object(_) => vec![obj],
        _ => return Err(ApiError::bad_request(INVALID_BODY)),
    };
    // Enforce the documented contract for batches too: every element must be an object.
    if !rows.iter().all(serde_json::Value::is_object) {
        return Err(ApiError::bad_request(INVALID_BODY));
    }

    let (affected, _, _) = executor.insert(&table, &rows).map_err(|e| {
        warn!(table = %table, error = %e, "REST INSERT failed");
        ApiError::from(e)
    })?;

    if affected > 0 {
        if let Some(notifier) = &state.change_notifier {
            for row in &rows {
                notifier.notify(&table, "INSERT", Some(row.clone()), None);
            }
        }
    }

    let response = serde_json::json!({
        "message": format!("{affected} row(s) inserted"),
        "count": affected,
    });
    Ok((StatusCode::CREATED, Json(response)))
}
/// Applies the JSON object `body` as an update to the rows of `table` matching the filters.
///
/// Every non-reserved query parameter becomes a filter (e.g. `id=eq.1`).
///
/// Row-level security:
/// - A service-role API key bypasses RLS.
/// - A user resolved from the bearer token goes through `update_with_rls`.
/// - Anything else uses the plain `update`.
///
/// On success responds with `{"message": "<n> row(s) updated", "count": <n>}`. If a
/// change notifier is configured and at least one row changed, it also emits one
/// `UPDATE` event carrying the request body.
///
/// # Errors
/// - `bad_request` when `body` is not a JSON object.
/// - Otherwise, executor failures are converted into `ApiError`.
pub async fn rest_update(
    State(state): State<AppState>,
    Path(table): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    headers: HeaderMap,
    Json(body): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, ApiError> {
    info!(table = %table, "REST UPDATE");
    // An update is a column -> value mapping; reject anything else up front.
    if !body.is_object() {
        return Err(ApiError::bad_request("Request body must be a JSON object"));
    }

    let executor = RestExecutor::new(state.db.clone());
    // NOTE(review): an empty filter list is passed straight to the executor.
    // Confirm an unfiltered request is not meant to be rejected.
    let filters = collect_filters(&params);

    let bypass_rls = is_service_role(&headers, &state);
    let user_id = if bypass_rls {
        None
    } else {
        extract_user_from_headers(&headers, &state)
    };

    // `user_id` is always `None` when RLS is bypassed, so matching on it alone is enough.
    // NOTE(review): unauthenticated or invalid-token requests take the non-RLS path.
    let affected = match user_id.as_deref() {
        Some(uid) => executor
            .update_with_rls(&table, &body, &filters, Some(uid))
            .map_err(|e| {
                warn!(table = %table, error = %e, "REST UPDATE (RLS) failed");
                ApiError::from(e)
            })?,
        None => executor.update(&table, &body, &filters).map_err(|e| {
            warn!(table = %table, error = %e, "REST UPDATE failed");
            ApiError::from(e)
        })?,
    };

    if affected > 0 {
        if let Some(notifier) = &state.change_notifier {
            notifier.notify(&table, "UPDATE", Some(body), None);
        }
    }
    Ok(Json(serde_json::json!({
        "message": format!("{affected} row(s) updated"),
        "count": affected,
    })))
}
/// Deletes the rows of `table` matching the filters built from the query string.
///
/// Every non-reserved query parameter becomes a filter (e.g. `id=eq.2`).
///
/// Row-level security:
/// - A service-role API key bypasses RLS.
/// - A user resolved from the bearer token goes through `delete_with_rls`.
/// - Anything else uses the plain `delete`.
///
/// On success responds with `{"message": "<n> row(s) deleted", "count": <n>}`. If a
/// change notifier is configured and at least one row was removed, it also emits a
/// `DELETE` event with no payload.
///
/// # Errors
/// Executor failures are logged and converted into `ApiError`.
pub async fn rest_delete(
    State(state): State<AppState>,
    Path(table): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    headers: HeaderMap,
) -> Result<Json<serde_json::Value>, ApiError> {
    info!(table = %table, "REST DELETE");
    let executor = RestExecutor::new(state.db.clone());
    // NOTE(review): an empty filter list is passed straight to the executor.
    // Confirm a DELETE without filters is not meant to be rejected.
    let filters = collect_filters(&params);

    let bypass_rls = is_service_role(&headers, &state);
    let user_id = if bypass_rls {
        None
    } else {
        extract_user_from_headers(&headers, &state)
    };

    // `user_id` is always `None` when RLS is bypassed, so matching on it alone is enough.
    // NOTE(review): unauthenticated or invalid-token requests take the non-RLS path.
    let affected = match user_id.as_deref() {
        Some(uid) => executor
            .delete_with_rls(&table, &filters, Some(uid))
            .map_err(|e| {
                warn!(table = %table, error = %e, "REST DELETE (RLS) failed");
                ApiError::from(e)
            })?,
        None => executor.delete(&table, &filters).map_err(|e| {
            warn!(table = %table, error = %e, "REST DELETE failed");
            ApiError::from(e)
        })?,
    };

    if affected > 0 {
        if let Some(notifier) = &state.change_notifier {
            notifier.notify(&table, "DELETE", None, None);
        }
    }
    Ok(Json(serde_json::json!({
        "message": format!("{affected} row(s) deleted"),
        "count": affected,
    })))
}
/// Placeholder for calling a named database function.
///
/// Function calls are not supported yet. Every request is logged and then answered
/// with `501 Not Implemented` (error code `NotImplemented`); the message names the
/// requested function. The state and JSON body are accepted but unused.
// NOTE(review): `dead_code` is allowed presumably because no router mounts this handler
// yet — confirm and drop the attribute once it is wired up.
#[allow(dead_code)]
pub async fn rest_rpc(
    State(_state): State<AppState>,
    Path(function): Path<String>,
    Json(_body): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, ApiError> {
    info!(function = %function, "REST RPC");
    let message = format!("RPC function '{function}' is not yet supported");
    Err(ApiError::new(StatusCode::NOT_IMPLEMENTED, "NotImplemented", message))
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::EmbeddedDatabase;
    use crate::compute::QueryRegistry;
    use std::sync::Arc;

    /// Fresh in-memory database with no auth bridge, OAuth registry or change notifier,
    /// so every handler call takes the non-RLS path.
    fn test_state() -> AppState {
        let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
        let query_registry = Arc::new(QueryRegistry::new());
        AppState {
            db,
            query_registry,
            auth_bridge: None,
            oauth_registry: None,
            change_notifier: None,
        }
    }

    /// `test_state()` plus a `users(id, name, age)` table seeded with ids 1, 2 and 3.
    fn test_state_with_table() -> AppState {
        let state = test_state();
        state.db.execute("CREATE TABLE users (id INT, name TEXT, age INT)").unwrap();
        state.db.execute("INSERT INTO users VALUES (1, 'Alice', 30)").unwrap();
        state.db.execute("INSERT INTO users VALUES (2, 'Bob', 25)").unwrap();
        state.db.execute("INSERT INTO users VALUES (3, 'Carol', 35)").unwrap();
        state
    }

    fn empty_headers() -> HeaderMap {
        HeaderMap::new()
    }

    /// Builds an owned query-parameter map from borrowed `(key, value)` pairs.
    fn query_params(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// Runs `rest_select` on `table` with `params` and returns the rows, panicking on error.
    /// Used both to test reads and to verify the effect of writes.
    async fn select_rows(
        state: &AppState,
        table: &str,
        params: HashMap<String, String>,
    ) -> Vec<serde_json::Value> {
        rest_select(
            State(state.clone()),
            Path(table.to_owned()),
            Query(params),
            empty_headers(),
        )
        .await
        .expect("rest_select should succeed")
        .0
    }

    #[tokio::test]
    async fn test_rest_select_all() {
        let state = test_state_with_table();
        let rows = select_rows(&state, "users", HashMap::new()).await;
        assert_eq!(rows.len(), 3);
    }

    #[tokio::test]
    async fn test_rest_select_with_filter() {
        let state = test_state_with_table();
        let rows = select_rows(&state, "users", query_params(&[("name", "eq.Alice")])).await;
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn test_rest_select_with_limit() {
        let state = test_state_with_table();
        let rows = select_rows(&state, "users", query_params(&[("limit", "2")])).await;
        assert_eq!(rows.len(), 2);
    }

    #[tokio::test]
    async fn test_rest_select_ignores_malformed_limit() {
        // A `limit` that does not parse as `usize` is dropped, and being reserved it is
        // not treated as a filter either, so every row comes back.
        let state = test_state_with_table();
        let rows = select_rows(&state, "users", query_params(&[("limit", "not-a-number")])).await;
        assert_eq!(rows.len(), 3);
    }

    #[tokio::test]
    async fn test_rest_insert_single() {
        let state = test_state();
        state.db.execute("CREATE TABLE items (id INT, label TEXT)").unwrap();
        let body = serde_json::json!({"id": 1, "label": "test"});
        let (status, json) = rest_insert(
            State(state.clone()),
            Path("items".to_string()),
            Json(body),
        )
        .await
        .expect("rest_insert should succeed");
        assert_eq!(status, StatusCode::CREATED);
        assert_eq!(json.0["count"], 1);
        assert_eq!(select_rows(&state, "items", HashMap::new()).await.len(), 1);
    }

    #[tokio::test]
    async fn test_rest_insert_batch() {
        let state = test_state();
        state.db.execute("CREATE TABLE items (id INT, label TEXT)").unwrap();
        let body = serde_json::json!([
            {"id": 1, "label": "a"},
            {"id": 2, "label": "b"},
        ]);
        let (_, json) = rest_insert(
            State(state.clone()),
            Path("items".to_string()),
            Json(body),
        )
        .await
        .expect("rest_insert should succeed");
        assert_eq!(json.0["count"], 2);
        assert_eq!(select_rows(&state, "items", HashMap::new()).await.len(), 2);
    }

    #[tokio::test]
    async fn test_rest_update() {
        let state = test_state_with_table();
        let body = serde_json::json!({"name": "Alicia"});
        let json = rest_update(
            State(state.clone()),
            Path("users".to_string()),
            Query(query_params(&[("id", "eq.1")])),
            empty_headers(),
            Json(body),
        )
        .await
        .expect("rest_update should succeed")
        .0;
        assert_eq!(json["count"], 1);
        // The change must actually be visible to subsequent reads.
        let renamed = select_rows(&state, "users", query_params(&[("name", "eq.Alicia")])).await;
        assert_eq!(renamed.len(), 1);
    }

    #[tokio::test]
    async fn test_rest_delete() {
        let state = test_state_with_table();
        let json = rest_delete(
            State(state.clone()),
            Path("users".to_string()),
            Query(query_params(&[("id", "eq.2")])),
            empty_headers(),
        )
        .await
        .expect("rest_delete should succeed")
        .0;
        assert_eq!(json["count"], 1);
        // Previously the remaining rows were queried but never checked.
        let remaining = select_rows(&state, "users", HashMap::new()).await;
        assert_eq!(remaining.len(), 2);
    }

    #[tokio::test]
    async fn test_rest_select_nonexistent_table() {
        let state = test_state();
        let result = rest_select(
            State(state),
            Path("nonexistent".to_string()),
            Query(HashMap::new()),
            empty_headers(),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_rest_insert_invalid_body() {
        let state = test_state();
        state.db.execute("CREATE TABLE t (id INT)").unwrap();
        let body = serde_json::json!("not an object");
        let result = rest_insert(State(state), Path("t".to_string()), Json(body)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_rest_select_rls_no_auth_bridge_returns_all() {
        let state = test_state_with_table();
        state.db.execute("ALTER TABLE users ADD COLUMN owner_id TEXT").unwrap();
        state.db.execute("UPDATE users SET owner_id = 'u1' WHERE id = 1").unwrap();
        state.db.execute("UPDATE users SET owner_id = 'u2' WHERE id = 2").unwrap();
        state.db.execute("UPDATE users SET owner_id = 'u1' WHERE id = 3").unwrap();
        // Without an auth bridge no user can be resolved, so RLS is not applied.
        let rows = select_rows(&state, "users", HashMap::new()).await;
        assert_eq!(rows.len(), 3);
    }
}