use crate::db;
use crate::error::AppError;
use chrono::NaiveDate;
use duckdb::arrow::datatypes::DataType;
use duckdb::Connection;
use serde_json::{json, Value};
use std::sync::Mutex;
/// Default row limit.
///
/// Not referenced in the visible part of this file.
/// NOTE(review): presumably used by callers when a request gives no explicit
/// limit — confirm against the call sites.
pub const DEFAULT_LIMIT: usize = 1000;
/// Upper bound on the number of result rows.
///
/// Not enforced anywhere in the visible part of this file.
/// NOTE(review): presumably callers clamp `limit` to this value before calling
/// the query helpers — confirm against the call sites.
pub const MAX_RESULT_LIMIT: usize = 100_000;
/// Builds a SQL predicate matching rows whose `geom` intersects a bounding box.
///
/// The parameters arrive in latitude-then-longitude order, but the envelope is
/// emitted as `ST_MakeEnvelope(min_lon, min_lat, max_lon, max_lat)`. Each
/// number is rendered with `f64`'s default `Display` formatting.
pub fn bbox_filter(min_lat: f64, min_lon: f64, max_lat: f64, max_lon: f64) -> String {
    // Envelope corners in the order the SQL function receives them.
    let corners = [min_lon, min_lat, max_lon, max_lat];
    let envelope_args = corners
        .iter()
        .map(|coord| coord.to_string())
        .collect::<Vec<_>>()
        .join(", ");
    format!("ST_Intersects(geom, ST_MakeEnvelope({}))", envelope_args)
}
/// Builds a SQL predicate selecting rows within `radius_m` of `(lat, lon)`.
///
/// The predicate is a bounding-box pre-filter (built by `bbox_filter`) `AND`-ed
/// with `ST_Distance_Sphere(geom, ST_Point(lon, lat)) <= radius_m`.
///
/// The box half-height is `radius_m / 111_000 * 1.5` degrees of latitude. The
/// half-width is that value divided by the cosine of the box's most poleward
/// latitude, because a degree of longitude shrinks towards the poles; using
/// the latitude offset for both axes made the box too narrow at high
/// latitudes, so it rejected points that were inside the radius. When the box
/// reaches a pole, or the half-width would be 180 degrees or more, the full
/// `[-180, 180]` longitude range is used.
///
/// Limitation: a box that crosses the antimeridian is not wrapped; the
/// longitude bounds are passed on as computed.
///
/// NOTE(review): DuckDB's spatial docs describe `ST_Distance_Sphere` as taking
/// points in [latitude, longitude] axis order, while this passes
/// `ST_Point(lon, lat)` — confirm against how `geom` is populated.
pub fn radius_filter(lat: f64, lon: f64, radius_m: f64) -> String {
    // Approximate metres per degree of latitude, and the padding factor
    // applied on top so the pre-filter errs on the side of including rows.
    const METERS_PER_DEGREE: f64 = 111_000.0;
    const PADDING: f64 = 1.5;
    const FULL_LON_RANGE: (f64, f64) = (-180.0, 180.0);

    let lat_offset = radius_m / METERS_PER_DEGREE * PADDING;
    let (min_lat, max_lat) = (lat - lat_offset, lat + lat_offset);

    // Longitude degrees are narrowest at the box edge farthest from the
    // equator, so size the half-width for that latitude.
    let farthest_lat = min_lat.abs().max(max_lat.abs());
    let (min_lon, max_lon) = if farthest_lat >= 90.0 {
        // The box touches a pole: every longitude can be within range.
        FULL_LON_RANGE
    } else {
        let lon_offset = lat_offset / farthest_lat.to_radians().cos();
        if lon_offset >= 180.0 {
            FULL_LON_RANGE
        } else {
            (lon - lon_offset, lon + lon_offset)
        }
    };

    let bbox = bbox_filter(min_lat, min_lon, max_lat, max_lon);
    format!(
        "{} AND ST_Distance_Sphere(geom, ST_Point({}, {})) <= {}",
        bbox, lon, lat, radius_m
    )
}
/// Builds a SELECT-list expression for the distance from `geom` to a point.
///
/// The expression is `ST_Distance_Sphere(geom, ST_Point(lon, lat))` divided by
/// `1000.0` and aliased as `_distance_km`. Note the point is written
/// longitude first, although the parameters are latitude first.
pub fn distance_select(lat: f64, lon: f64) -> String {
    let mut expr = String::from("ST_Distance_Sphere(geom, ST_Point(");
    expr.push_str(&lon.to_string());
    expr.push_str(", ");
    expr.push_str(&lat.to_string());
    expr.push_str(")) / 1000.0 AS _distance_km");
    expr
}
/// Builds a SQL predicate matching rows whose `geom` is contained in a
/// GeoJSON geometry.
///
/// `geojson_str` is embedded as a single-quoted SQL string literal; every
/// single quote in it is doubled so it cannot terminate the literal. The
/// GeoJSON itself is not parsed or validated here.
pub fn within_filter_geojson(geojson_str: &str) -> String {
    let mut literal = String::with_capacity(geojson_str.len());
    for ch in geojson_str.chars() {
        if ch == '\'' {
            literal.push_str("''");
        } else {
            literal.push(ch);
        }
    }
    [
        "ST_Contains(ST_GeomFromGeoJSON('",
        literal.as_str(),
        "'), geom)",
    ]
    .concat()
}
/// Builds and runs a `SELECT` over the dataset, returning one JSON object per row.
///
/// # Arguments
/// * `spatial_where` - optional SQL predicate inserted verbatim into `WHERE`.
///   When present the query reads `raw_data` and a wildcard select drops
///   `geom`; otherwise it reads `data`.
/// * `where_clauses` - `(column, value)` pairs turned into
///   `CAST("column" AS VARCHAR) = 'value'` equality tests, joined with `AND`.
///   Column names are validated and values have single quotes doubled.
/// * `select_cols` - columns to return. `None` or an empty slice selects every
///   column. Names are validated and emitted as double-quoted identifiers,
///   matching how the `WHERE` and aggregate paths reference columns, so
///   reserved-word column names work.
/// * `group_by` / `agg` - when BOTH are `Some`, the call is delegated to
///   `query_aggregate`; `select_cols`, `extra_select` and `order_by` are not
///   used on that path. If only one of them is set it is ignored.
/// * `limit` - appended as `LIMIT n` without any clamping.
/// * `extra_select` - extra SELECT-list SQL appended verbatim.
/// * `order_by` - `ORDER BY` body appended verbatim.
///
/// `spatial_where`, `extra_select` and `order_by` are spliced into the SQL
/// unescaped, so they must come from trusted code (e.g. the builder functions
/// in this module), never directly from user input.
///
/// # Errors
/// Propagates errors from `db::validate_column_name` and from
/// `execute_query_to_json`.
// The parameters mirror the independent query options one-to-one; the
// signature is kept flat so existing callers keep working.
#[allow(clippy::too_many_arguments)]
pub fn query(
    db: &Mutex<Connection>,
    spatial_where: Option<&str>,
    where_clauses: &[(String, String)],
    select_cols: Option<&[String]>,
    group_by: Option<&str>,
    agg: Option<&str>,
    limit: usize,
    extra_select: Option<&str>,
    order_by: Option<&str>,
) -> Result<Vec<Value>, AppError> {
    if let (Some(gb), Some(a)) = (group_by, agg) {
        db::validate_column_name(gb)?;
        return query_aggregate(db, spatial_where, where_clauses, gb, a, limit);
    }

    let is_spatial = spatial_where.is_some();
    let table = if is_spatial { "raw_data" } else { "data" };

    // An empty column list would otherwise produce `SELECT  FROM ...`, which
    // is invalid SQL, so it is treated like "no explicit selection".
    let mut select = match select_cols.filter(|cols| !cols.is_empty()) {
        Some(cols) => {
            let mut quoted = Vec::with_capacity(cols.len());
            for c in cols {
                db::validate_column_name(c)?;
                quoted.push(format!("\"{}\"", c));
            }
            quoted.join(", ")
        }
        None if is_spatial => String::from("* EXCLUDE (geom)"),
        None => String::from("*"),
    };
    if let Some(extra) = extra_select {
        select.push_str(", ");
        select.push_str(extra);
    }

    let mut sql = format!("SELECT {} FROM {}", select, table);

    let mut conditions = Vec::new();
    if let Some(sw) = spatial_where {
        conditions.push(sw.to_string());
    }
    for (col, val) in where_clauses {
        db::validate_column_name(col)?;
        conditions.push(format!(
            "CAST(\"{}\" AS VARCHAR) = '{}'",
            col,
            escape_sql_value(val)
        ));
    }
    if !conditions.is_empty() {
        sql.push_str(" WHERE ");
        sql.push_str(&conditions.join(" AND "));
    }

    if let Some(ob) = order_by {
        sql.push_str(" ORDER BY ");
        sql.push_str(ob);
    }
    sql.push_str(&format!(" LIMIT {}", limit));

    execute_query_to_json(db, &sql)
}
/// Runs a single-column `GROUP BY` aggregation and returns the rows as JSON.
///
/// `agg` is either `"count"` or `"<fn>:<column>"` with `<fn>` one of `sum`,
/// `avg`, `min`, `max`; the aggregate is aliased `<fn>_<column>`. Any other
/// spec falls back to `COUNT(*) AS count`. The aggregated column and every
/// `where_clauses` column are validated here; `group_by` is interpolated as
/// given (the visible caller, `query`, validates it first).
///
/// As in `query`, a `spatial_where` predicate is inserted verbatim and
/// switches the source table from `data` to `raw_data`.
fn query_aggregate(
    db: &Mutex<Connection>,
    spatial_where: Option<&str>,
    where_clauses: &[(String, String)],
    group_by: &str,
    agg: &str,
    limit: usize,
) -> Result<Vec<Value>, AppError> {
    // (spec prefix, SQL function, alias prefix) for the per-column aggregates.
    const COLUMN_AGGREGATES: [(&str, &str, &str); 4] = [
        ("sum:", "SUM", "sum"),
        ("avg:", "AVG", "avg"),
        ("min:", "MIN", "min"),
        ("max:", "MAX", "max"),
    ];

    // Row count is both the explicit "count" spec and the fallback.
    let mut aggregate = String::from("COUNT(*) AS count");
    if agg != "count" {
        for (prefix, function, alias) in COLUMN_AGGREGATES.iter() {
            if let Some(target) = agg.strip_prefix(*prefix) {
                db::validate_column_name(target)?;
                aggregate = format!("{}(\"{}\") AS {}_{}", function, target, alias, target);
                break;
            }
        }
    }

    let source = match spatial_where {
        Some(_) => "raw_data",
        None => "data",
    };

    let mut predicates: Vec<String> = spatial_where.map(String::from).into_iter().collect();
    for (column, value) in where_clauses {
        db::validate_column_name(column)?;
        predicates.push(format!(
            "CAST(\"{}\" AS VARCHAR) = '{}'",
            column,
            escape_sql_value(value)
        ));
    }
    let where_sql = if predicates.is_empty() {
        String::new()
    } else {
        format!(" WHERE {}", predicates.join(" AND "))
    };

    let sql = format!(
        "SELECT \"{0}\", {1} FROM {2}{3} GROUP BY \"{0}\" LIMIT {4}",
        group_by, aggregate, source, where_sql, limit
    );
    execute_query_to_json(db, &sql)
}
/// Fetches `(lat, lon)` pairs from `raw_data` for rows whose `geom` intersects
/// the given bounding box.
///
/// `lat_col` and `lon_col` are validated and then selected as double-quoted
/// identifiers; both are read as `f64`. No `LIMIT` is applied.
///
/// # Errors
/// Returns validation errors from `db::validate_column_name`, lock errors from
/// `db::lock_db`, and `AppError::Internal` when preparing, executing or
/// reading the query fails. A value that cannot be read as `f64` aborts the
/// whole call.
pub fn query_points_in_bbox(
    db: &Mutex<Connection>,
    lat_col: &str,
    lon_col: &str,
    min_lat: f64,
    min_lon: f64,
    max_lat: f64,
    max_lon: f64,
) -> Result<Vec<(f64, f64)>, AppError> {
    db::validate_column_name(lat_col)?;
    db::validate_column_name(lon_col)?;

    let sql = format!(
        "SELECT \"{}\", \"{}\" FROM raw_data WHERE {}",
        lat_col,
        lon_col,
        bbox_filter(min_lat, min_lon, max_lat, max_lon)
    );

    let conn = db::lock_db(db)?;
    let mut stmt = match conn.prepare(&sql) {
        Ok(prepared) => prepared,
        Err(e) => {
            return Err(AppError::Internal(anyhow::anyhow!(
                "SQL prepare error: {} — {}",
                e,
                sql
            )))
        }
    };
    let mut rows = match stmt.query([]) {
        Ok(cursor) => cursor,
        Err(e) => {
            return Err(AppError::Internal(anyhow::anyhow!(
                "SQL query error: {} — {}",
                e,
                sql
            )))
        }
    };

    let mut points = Vec::new();
    loop {
        let row = match rows.next() {
            Ok(Some(current)) => current,
            Ok(None) => break,
            Err(e) => {
                return Err(AppError::Internal(anyhow::anyhow!(
                    "Row fetch error: {}",
                    e
                )))
            }
        };
        // Column 0 is the latitude column, column 1 the longitude column.
        let lat = row
            .get::<_, f64>(0)
            .map_err(|e| AppError::Internal(anyhow::anyhow!("lat read error: {}", e)))?;
        let lon = row
            .get::<_, f64>(1)
            .map_err(|e| AppError::Internal(anyhow::anyhow!("lon read error: {}", e)))?;
        points.push((lat, lon));
    }
    Ok(points)
}
/// Returns up to `limit` rows of `raw_data` whose `geom` intersects the given
/// bounding box, as JSON objects with the `geom` column excluded.
///
/// # Errors
/// Propagates any error from `execute_query_to_json`.
pub fn query_rows_in_bbox(
    db: &Mutex<Connection>,
    min_lat: f64,
    min_lon: f64,
    max_lat: f64,
    max_lon: f64,
    limit: usize,
) -> Result<Vec<Value>, AppError> {
    let mut sql = String::from("SELECT * EXCLUDE (geom) FROM raw_data WHERE ");
    sql.push_str(&bbox_filter(min_lat, min_lon, max_lat, max_lon));
    sql.push_str(" LIMIT ");
    sql.push_str(&limit.to_string());
    execute_query_to_json(db, &sql)
}
/// Executes `sql` (with no bound parameters) and converts every result row
/// into a JSON object keyed by column name.
///
/// Column names and Arrow types are read once from the executed statement;
/// each cell is then converted by `row_value_to_json`. A column whose name
/// cannot be read is keyed as `"?"`. If two columns share a name, the later
/// one overwrites the earlier one in the object.
///
/// `sql` is run as-is, so it must already be safe to execute.
///
/// # Errors
/// Returns lock errors from `db::lock_db`, and `AppError::Internal` when the
/// statement cannot be prepared or executed (the message includes the SQL
/// text), when the result set has no backing statement, or when fetching a
/// row fails.
pub fn execute_query_to_json(db: &Mutex<Connection>, sql: &str) -> Result<Vec<Value>, AppError> {
    let conn = db::lock_db(db)?;
    let mut stmt = match conn.prepare(sql) {
        Ok(prepared) => prepared,
        Err(e) => {
            return Err(AppError::Internal(anyhow::anyhow!(
                "SQL prepare error: {} — {}",
                e,
                sql
            )))
        }
    };
    let mut rows = match stmt.query([]) {
        Ok(cursor) => cursor,
        Err(e) => {
            return Err(AppError::Internal(anyhow::anyhow!(
                "SQL query error: {} — {}",
                e,
                sql
            )))
        }
    };

    // Read the column metadata inside its own scope so the shared borrow of
    // `rows` ends before rows are pulled from it below.
    let (col_count, col_names, col_types) = {
        let backing = match rows.as_ref() {
            Some(statement) => statement,
            None => {
                return Err(AppError::Internal(anyhow::anyhow!(
                    "No statement backing query results"
                )))
            }
        };
        let count = backing.column_count();

        let mut names: Vec<String> = Vec::with_capacity(count);
        for i in 0..count {
            names.push(match backing.column_name(i) {
                Ok(name) => name.to_string(),
                Err(_) => "?".to_string(),
            });
        }

        let mut types: Vec<DataType> = Vec::with_capacity(count);
        for i in 0..count {
            types.push(backing.column_type(i));
        }

        (count, names, types)
    };

    let mut rows_out = Vec::new();
    loop {
        let row = match rows.next() {
            Ok(Some(current)) => current,
            Ok(None) => break,
            Err(e) => {
                return Err(AppError::Internal(anyhow::anyhow!(
                    "Row fetch error: {}",
                    e
                )))
            }
        };
        let mut obj = serde_json::Map::with_capacity(col_count);
        for (idx, (name, ty)) in col_names.iter().zip(col_types.iter()).enumerate() {
            obj.insert(name.clone(), row_value_to_json(row, idx, ty));
        }
        rows_out.push(Value::Object(obj));
    }
    Ok(rows_out)
}
/// Converts the cell at `idx` of `row` to JSON, choosing the Rust type to read
/// from the column's Arrow `col_type`.
///
/// Mapping:
/// * signed ints and unsigned ints up to 32 bits -> JSON number (read as `i64`)
/// * `UInt64` -> JSON number when it fits in `i64`, otherwise a decimal string
/// * floats and decimals -> JSON number (read as `f64`); NaN and infinities
///   become `null`
/// * `Boolean` -> JSON bool
/// * `Date32` -> `"YYYY-MM-DD"` string, or the raw day count when the value is
///   outside the range `NaiveDate` can represent
/// * `Date64` / `Timestamp` -> string if readable as one, else `i64`
/// * everything else -> string
///
/// Any failed read yields `Value::Null` rather than an error.
/// NOTE(review): presumably this is also how SQL NULLs surface — confirm
/// against the `duckdb` crate's `FromSql` behaviour.
fn row_value_to_json(row: &duckdb::Row, idx: usize, col_type: &DataType) -> Value {
    // Day number (days since 0001-01-01 = day 1) of 1970-01-01, used to shift
    // the `Date32` epoch offset onto chrono's common-era day numbering.
    const UNIX_EPOCH_DAYS_FROM_CE: i32 = 719_163;

    match col_type {
        DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32 => match row.get::<_, i64>(idx) {
            Ok(v) => json!(v),
            Err(_) => Value::Null,
        },
        DataType::UInt64 => match row.get::<_, u64>(idx) {
            Ok(v) if v <= i64::MAX as u64 => json!(v),
            // Values above i64::MAX are emitted as strings.
            Ok(v) => json!(v.to_string()),
            Err(_) => Value::Null,
        },
        DataType::Float16
        | DataType::Float32
        | DataType::Float64
        | DataType::Decimal128(_, _)
        | DataType::Decimal256(_, _) => match row.get::<_, f64>(idx) {
            // JSON has no representation for NaN or infinities.
            Ok(v) if v.is_nan() || v.is_infinite() => Value::Null,
            Ok(v) => json!(v),
            Err(_) => Value::Null,
        },
        DataType::Boolean => match row.get::<_, bool>(idx) {
            Ok(v) => json!(v),
            Err(_) => Value::Null,
        },
        DataType::Date32 => match row.get::<_, i32>(idx) {
            Ok(days) => {
                // `checked_add` keeps extreme day counts (e.g. sentinel dates
                // near i32::MAX) from overflowing: a plain `+` panics in debug
                // builds and wraps in release builds.
                let formatted = days
                    .checked_add(UNIX_EPOCH_DAYS_FROM_CE)
                    .and_then(NaiveDate::from_num_days_from_ce_opt)
                    .map(|d| d.format("%Y-%m-%d").to_string());
                match formatted {
                    Some(s) => json!(s),
                    // Not representable as a calendar date: keep the raw value.
                    None => json!(days),
                }
            }
            Err(_) => Value::Null,
        },
        DataType::Date64 | DataType::Timestamp(_, _) => match row.get::<_, String>(idx) {
            Ok(v) => json!(v),
            // Fall back to the raw integer when no string form is available.
            Err(_) => match row.get::<_, i64>(idx) {
                Ok(v) => json!(v),
                Err(_) => Value::Null,
            },
        },
        _ => match row.get::<_, String>(idx) {
            Ok(v) => json!(v),
            Err(_) => Value::Null,
        },
    }
}
/// Escapes `s` for use inside a single-quoted SQL string literal by doubling
/// every single quote. No other character is altered.
fn escape_sql_value(s: &str) -> String {
    // Splitting on the quote and re-joining with two quotes doubles each one.
    s.split('\'').collect::<Vec<_>>().join("''")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single quotes are doubled; all other text passes through unchanged.
    #[test]
    fn escape_doubles_single_quotes() {
        let cases = [
            ("o'brien", "o''brien"),
            ("'; DROP --", "''; DROP --"),
            ("plain", "plain"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(escape_sql_value(input), *expected);
        }
    }

    /// The envelope lists longitude before latitude for each corner.
    #[test]
    fn bbox_filter_uses_lon_lat_order() {
        let predicate = bbox_filter(35.0, -84.0, 37.0, -81.0);
        assert!(predicate.starts_with("ST_Intersects(geom,"));
        assert!(predicate.contains("ST_MakeEnvelope(-84, 35, -81, 37)"));
    }

    /// The distance expression targets the given point and carries the alias.
    #[test]
    fn distance_select_aliases_km() {
        let expr = distance_select(36.5, -82.5);
        assert!(expr.ends_with("AS _distance_km"));
        assert!(expr.contains("ST_Distance_Sphere(geom, ST_Point(-82.5, 36.5))"));
    }

    /// Quotes inside the GeoJSON are doubled within the SQL literal.
    #[test]
    fn within_filter_escapes_geojson() {
        let predicate = within_filter_geojson("{\"x\":\"o'brien\"}");
        assert!(predicate.starts_with("ST_Contains(ST_GeomFromGeoJSON("));
        assert!(predicate.contains("o''brien"));
    }
}