use bytes::BytesMut;
use deadpool_postgres::{Config, Pool, Runtime};
use serde_json::Value;
use tokio_postgres::{NoTls, Row};
use tokio_postgres::types::{FromSql, IsNull, ToSql, Type};
use uuid::Uuid;
use crate::models::{Account, AccountRole, AppState, DatabaseAccess, FieldMetadata, Permission, QueryResult};
/// Owned copy of a column's raw wire bytes, captured without any decoding.
///
/// Used to read values whose binary layout is decoded by hand elsewhere in
/// this file.
struct RawBytes(Vec<u8>);

impl<'a> FromSql<'a> for RawBytes {
    /// Copies the raw value verbatim; the declared column type is ignored.
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let owned = Vec::from(raw);
        Ok(Self(owned))
    }

    /// Accepts every column type, since no interpretation happens here.
    fn accepts(_ty: &Type) -> bool {
        true
    }
}
/// Decodes a PostgreSQL `numeric` value from its binary wire format into
/// decimal text.
///
/// The payload is an 8-byte header of four big-endian 16-bit fields
/// (`ndigits`, `weight`, `sign`, `dscale`) followed by `ndigits` base-10000
/// digit groups. `weight` is the power of 10000 of the first stored group and
/// `dscale` is the number of decimal places to display.
///
/// Returns `None` when the payload is shorter than the header, when its
/// length does not match `ndigits`, when `ndigits` is negative, or when a
/// digit group lies outside `0..=9999`. The special sign markers are rendered
/// as `"NaN"`, `"Infinity"` and `"-Infinity"`.
fn numeric_binary_to_string(raw: &[u8]) -> Option<String> {
    const HEADER_LEN: usize = 8;
    const SIGN_NEGATIVE: u16 = 0x4000;
    const SIGN_NAN: u16 = 0xC000;
    const SIGN_POS_INFINITY: u16 = 0xD000;
    const SIGN_NEG_INFINITY: u16 = 0xF000;

    if raw.len() < HEADER_LEN {
        return None;
    }

    // A negative count would wrap to an enormous usize and overflow the
    // length computation below, so reject it up front.
    let ndigits_raw = i16::from_be_bytes([raw[0], raw[1]]);
    if ndigits_raw < 0 {
        return None;
    }
    let ndigits = ndigits_raw as usize;
    // Widened to i32 so `weight + 1` cannot overflow at i16::MAX.
    let weight = i32::from(i16::from_be_bytes([raw[2], raw[3]]));
    let sign = u16::from_be_bytes([raw[4], raw[5]]);
    let dscale = usize::from(u16::from_be_bytes([raw[6], raw[7]]));

    if raw.len() != HEADER_LEN + ndigits * 2 {
        return None;
    }

    match sign {
        SIGN_NAN => return Some("NaN".to_string()),
        SIGN_POS_INFINITY => return Some("Infinity".to_string()),
        SIGN_NEG_INFINITY => return Some("-Infinity".to_string()),
        _ => {}
    }

    let digits: Vec<i16> = raw[HEADER_LEN..]
        .chunks_exact(2)
        .map(|pair| i16::from_be_bytes([pair[0], pair[1]]))
        .collect();
    if digits.iter().any(|d| !(0..=9999).contains(d)) {
        return None;
    }

    let mut result = String::new();
    if sign == SIGN_NEGATIVE {
        result.push('-');
    }

    // Number of base-10000 groups that sit left of the decimal point.
    let int_groups = if weight >= 0 { (weight + 1) as usize } else { 0 };
    if ndigits == 0 || int_groups == 0 {
        result.push('0');
    } else {
        for i in 0..int_groups {
            // Trailing zero groups are not stored on the wire.
            let group = digits.get(i).copied().unwrap_or(0);
            if i == 0 {
                // The most significant group carries no zero padding.
                result.push_str(&group.to_string());
            } else {
                result.push_str(&format!("{:04}", group));
            }
        }
    }

    if dscale > 0 {
        result.push('.');
        let mut fraction = String::with_capacity(dscale + 4);

        // For weight < -1 the groups between the decimal point and the first
        // stored group are implicit zeros (e.g. 0.00000001 has weight -2).
        let leading_zero_groups = if weight < -1 {
            (-(weight + 1)) as usize
        } else {
            0
        };
        for _ in 0..leading_zero_groups {
            if fraction.len() >= dscale {
                break;
            }
            fraction.push_str("0000");
        }
        for group in digits.iter().skip(int_groups) {
            if fraction.len() >= dscale {
                break;
            }
            fraction.push_str(&format!("{:04}", group));
        }

        // Pad or cut to exactly `dscale` decimal places.
        while fraction.len() < dscale {
            fraction.push('0');
        }
        fraction.truncate(dscale);
        result.push_str(&fraction);
    }

    Some(result)
}
/// Returns the connection pool registered under `pool_key`, building and
/// registering a new one when none is present.
///
/// A new pool targets the host/port of the instance referenced by
/// `account.instance_id` and authenticates with the credentials and database
/// name carried by `db_access`. Connections are created without TLS.
///
/// # Errors
///
/// Returns `"Instance not found"` when the account's instance is not present
/// in `state.instances`, or the stringified pool-creation error.
pub async fn get_or_create_pool(
    state: &AppState,
    pool_key: &str,
    account: &Account,
    db_access: &DatabaseAccess,
) -> Result<Pool, String> {
    // Fast path: a pool for this key already exists.
    if let Some(cached) = state.connections.get(pool_key) {
        return Ok(cached.clone());
    }

    let instance_map = state.instances.read().await;
    let target = match instance_map.get(&account.instance_id) {
        Some(found) => found,
        None => return Err("Instance not found".into()),
    };

    let mut pg_config = Config::new();
    pg_config.dbname = Some(db_access.database.clone());
    pg_config.user = Some(db_access.username.clone());
    pg_config.password = Some(db_access.password.clone());
    pg_config.host = Some(target.host.clone());
    pg_config.port = Some(target.port);

    let created = pg_config
        .create_pool(Some(Runtime::Tokio1), NoTls)
        .map_err(|err| err.to_string())?;

    // Register a clone so later calls with the same key reuse this pool.
    state.connections.insert(pool_key.to_string(), created.clone());
    Ok(created)
}
/// Checks whether an account may run `query`, based on the statement's
/// leading keyword.
///
/// `Superuser` accounts are always allowed. For everyone else the leading
/// keyword (case-insensitive, surrounding whitespace ignored) selects the
/// required permission: `SELECT`, `INSERT`, `UPDATE`, `DELETE`, `CREATE`,
/// `DROP` or `TRUNCATE`. Any other statement is rejected as unsupported, even
/// when `Permission::All` is granted.
///
/// Statements starting with `WITH` require `Select` and, in addition, the
/// permission for every `INSERT` / `UPDATE` / `DELETE` keyword that appears
/// anywhere in the statement. PostgreSQL allows data-modifying statements
/// inside and after a `WITH` clause, so treating every `WITH` query as a plain
/// read would let a read-only account modify data. The scan is word-based and
/// deliberately does not skip string literals or quoted identifiers: it may
/// demand a permission that is not strictly needed, but it cannot be tricked
/// into missing one.
///
/// # Errors
///
/// * `"Unsupported query type"` for an unrecognised leading keyword.
/// * `"Missing permission: <Permission>"` for the first required permission
///   that is not granted.
pub fn validate_query_permissions(
    query: &str,
    permissions: &[Permission],
    role: &AccountRole,
) -> Result<(), String> {
    if *role == AccountRole::Superuser {
        return Ok(());
    }

    let query_upper = query.to_uppercase();
    let statement = query_upper.trim();

    let mut required: Vec<Permission> = Vec::new();
    if statement.starts_with("SELECT") {
        required.push(Permission::Select);
    } else if statement.starts_with("WITH") {
        required.push(Permission::Select);
        // Collect write permissions implied by data-modifying keywords.
        for word in statement.split(|c: char| !(c.is_alphanumeric() || c == '_')) {
            let extra = match word {
                "INSERT" => Permission::Insert,
                "UPDATE" => Permission::Update,
                "DELETE" => Permission::Delete,
                _ => continue,
            };
            if !required.contains(&extra) {
                required.push(extra);
            }
        }
    } else if statement.starts_with("INSERT") {
        required.push(Permission::Insert);
    } else if statement.starts_with("UPDATE") {
        required.push(Permission::Update);
    } else if statement.starts_with("DELETE") {
        required.push(Permission::Delete);
    } else if statement.starts_with("CREATE") {
        required.push(Permission::Create);
    } else if statement.starts_with("DROP") {
        required.push(Permission::Drop);
    } else if statement.starts_with("TRUNCATE") {
        required.push(Permission::Truncate);
    } else {
        return Err("Unsupported query type".to_string());
    }

    if permissions.contains(&Permission::All) {
        return Ok(());
    }

    match required
        .into_iter()
        .find(|needed| !permissions.contains(needed))
    {
        Some(missing) => Err(format!("Missing permission: {:?}", missing)),
        None => Ok(()),
    }
}
/// Converts result rows into JSON objects keyed by column name.
///
/// Each column is decoded according to its PostgreSQL type name. A value that
/// is SQL `NULL`, or that cannot be decoded as the Rust type chosen for its
/// column, becomes JSON `null`.
///
/// Type handling notes:
/// * `int2` is read as `i16` and `int4` as `i32`; the driver's integer
///   conversions are width-specific, so reading `int2` as `i32` never succeeds.
/// * `timestamptz` is read as a UTC `DateTime` and rendered as RFC 3339;
///   `timestamp` keeps the `YYYY-MM-DD HH:MM:SS` rendering.
/// * The single-byte `"char"` type is read as `i8` and rendered as a
///   one-character string.
/// * Non-finite floats (`NaN`, `±Infinity`) become `null`, because JSON
///   numbers cannot represent them; they are no longer reported as `0`.
/// * `numeric` is decoded from its raw binary form into a decimal string so
///   no precision is lost.
/// * NOTE(review): `timetz` is still read as `chrono::NaiveTime`; confirm the
///   driver accepts that pairing, otherwise such columns yield `null`.
pub fn rows_to_json(rows: &[Row]) -> Vec<Value> {
    rows.iter()
        .map(|row| {
            let mut obj = serde_json::Map::new();
            for (i, column) in row.columns().iter().enumerate() {
                let value = match column.type_().name() {
                    "bool" => row.try_get::<_, bool>(i).ok().map(Value::Bool),
                    "int2" => row
                        .try_get::<_, i16>(i)
                        .ok()
                        .map(|v| Value::Number(v.into())),
                    "int4" => row
                        .try_get::<_, i32>(i)
                        .ok()
                        .map(|v| Value::Number(v.into())),
                    "int8" => row
                        .try_get::<_, i64>(i)
                        .ok()
                        .map(|v| Value::Number(v.into())),
                    "float4" => row
                        .try_get::<_, f32>(i)
                        .ok()
                        .and_then(|v| serde_json::Number::from_f64(f64::from(v)))
                        .map(Value::Number),
                    "float8" => row
                        .try_get::<_, f64>(i)
                        .ok()
                        .and_then(serde_json::Number::from_f64)
                        .map(Value::Number),
                    "char" => row
                        .try_get::<_, i8>(i)
                        .ok()
                        .map(|c| Value::String(char::from(c as u8).to_string())),
                    "text" | "varchar" | "bpchar" => {
                        row.try_get::<_, String>(i).ok().map(Value::String)
                    }
                    "json" | "jsonb" => row.try_get::<_, Value>(i).ok(),
                    "timestamp" => row
                        .try_get::<_, chrono::NaiveDateTime>(i)
                        .ok()
                        .map(|dt| Value::String(dt.format("%Y-%m-%d %H:%M:%S").to_string())),
                    "timestamptz" => row
                        .try_get::<_, chrono::DateTime<chrono::Utc>>(i)
                        .ok()
                        .map(|dt| Value::String(dt.to_rfc3339())),
                    "date" => row
                        .try_get::<_, chrono::NaiveDate>(i)
                        .ok()
                        .map(|d| Value::String(d.format("%Y-%m-%d").to_string())),
                    "uuid" => row
                        .try_get::<_, Uuid>(i)
                        .ok()
                        .map(|u| Value::String(u.to_string())),
                    "numeric" | "decimal" => row
                        .try_get::<_, RawBytes>(i)
                        .ok()
                        .and_then(|raw| numeric_binary_to_string(&raw.0))
                        .map(Value::String),
                    "time" | "timetz" => row
                        .try_get::<_, chrono::NaiveTime>(i)
                        .ok()
                        .map(|t| Value::String(t.format("%H:%M:%S").to_string())),
                    // Anything else is attempted as text; failure yields null.
                    _ => row.try_get::<_, String>(i).ok().map(Value::String),
                };
                obj.insert(column.name().to_string(), value.unwrap_or(Value::Null));
            }
            Value::Object(obj)
        })
        .collect()
}
#[derive(Debug)]
enum PgParam {
Null,
Bool(bool),
I64(i64),
F64(f64),
String(String),
}
impl ToSql for PgParam {
fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
match self {
PgParam::Null => Ok(IsNull::Yes),
PgParam::Bool(b) => b.to_sql(_ty, out),
PgParam::I64(n) => n.to_sql(_ty, out),
PgParam::F64(n) => n.to_sql(_ty, out),
PgParam::String(s) => s.to_sql(_ty, out),
}
}
fn accepts(_ty: &Type) -> bool {
true
}
fn to_sql_checked(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
self.to_sql(ty, out)
}
}
/// Maps a JSON value onto the query-parameter representation.
///
/// Numbers become `I64` when they fit a signed 64-bit integer and `F64`
/// otherwise (falling back to `0.0` if neither view is available). Arrays and
/// objects are passed along as their serialized JSON text.
fn json_to_pg_param(value: &Value) -> PgParam {
    match value {
        Value::Null => PgParam::Null,
        Value::Bool(flag) => PgParam::Bool(*flag),
        Value::String(text) => PgParam::String(text.clone()),
        Value::Number(num) => match (num.as_i64(), num.as_f64()) {
            (Some(int), _) => PgParam::I64(int),
            (None, Some(float)) => PgParam::F64(float),
            (None, None) => PgParam::F64(0.0),
        },
        Value::Array(_) | Value::Object(_) => PgParam::String(value.to_string()),
    }
}
/// Runs `query` with the given JSON parameters on a connection taken from
/// `pool` and returns the rows as JSON together with column metadata.
///
/// The statement is prepared first so that column metadata comes from the
/// statement itself; `fields` is therefore populated even when the query
/// returns no rows. `nullable` is always reported as `true` and `max_length`
/// as `None`, and `query_plan` is left unset.
///
/// # Errors
///
/// Returns the stringified error from acquiring a connection, preparing the
/// statement, or executing it.
pub async fn execute_query_with_pool(
    pool: Pool,
    query: String,
    params: Vec<Value>,
) -> Result<QueryResult, String> {
    let client = pool.get().await.map_err(|e| e.to_string())?;

    let pg_params: Vec<PgParam> = params.iter().map(json_to_pg_param).collect();
    let param_refs: Vec<&(dyn ToSql + Sync)> = pg_params
        .iter()
        .map(|p| p as &(dyn ToSql + Sync))
        .collect();

    let statement = client.prepare(&query).await.map_err(|e| e.to_string())?;
    // An empty parameter slice is valid, so no special case is needed.
    let rows = client
        .query(&statement, &param_refs)
        .await
        .map_err(|e| e.to_string())?;

    let fields = statement
        .columns()
        .iter()
        .map(|col| FieldMetadata {
            name: col.name().to_string(),
            data_type: col.type_().name().to_string(),
            nullable: true,
            max_length: None,
        })
        .collect();

    Ok(QueryResult {
        rows: rows_to_json(&rows),
        fields,
        query_plan: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the permission check for a non-superuser `Developer` account.
    fn check_as_developer(sql: &str, granted: &[Permission]) -> Result<(), String> {
        validate_query_permissions(sql, granted, &AccountRole::Developer)
    }

    /// Decodes a raw `numeric` payload given as a byte slice.
    fn decode_numeric(bytes: &[u8]) -> Option<String> {
        numeric_binary_to_string(bytes)
    }

    // --- validate_query_permissions -------------------------------------

    #[test]
    fn test_validate_query_permissions_select() {
        let outcome = check_as_developer("SELECT * FROM users", &[Permission::Select]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_insert_denied() {
        let outcome = check_as_developer("INSERT INTO users VALUES (1)", &[Permission::Select]);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("Missing permission"));
    }

    #[test]
    fn test_validate_query_permissions_owner_bypass() {
        // A superuser needs no explicit permissions at all.
        let outcome =
            validate_query_permissions("DROP TABLE users", &[], &AccountRole::Superuser);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_all_permission() {
        let outcome = check_as_developer("DELETE FROM users", &[Permission::All]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_update() {
        let granted = [Permission::Update, Permission::Select];
        let outcome = check_as_developer("UPDATE users SET name = 'test'", &granted);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_delete() {
        let outcome = check_as_developer("DELETE FROM users WHERE id = 1", &[Permission::Delete]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_create() {
        let outcome = check_as_developer("CREATE TABLE test (id INT)", &[Permission::Create]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_drop() {
        let outcome = check_as_developer("DROP TABLE test", &[Permission::Drop]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_truncate() {
        let outcome = check_as_developer("TRUNCATE TABLE test", &[Permission::Truncate]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_unsupported() {
        let outcome =
            check_as_developer("ALTER TABLE users ADD COLUMN x INT", &[Permission::Select]);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("Unsupported"));
    }

    #[test]
    fn test_validate_query_permissions_case_insensitive() {
        let outcome = check_as_developer("select * from users", &[Permission::Select]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_query_permissions_with_cte() {
        let outcome = check_as_developer(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte",
            &[Permission::Select],
        );
        assert!(outcome.is_ok());
    }

    // --- json_to_pg_param -----------------------------------------------

    #[test]
    fn test_json_to_pg_param_null() {
        let param = json_to_pg_param(&Value::Null);
        assert!(matches!(param, PgParam::Null));
    }

    #[test]
    fn test_json_to_pg_param_bool() {
        let param = json_to_pg_param(&Value::Bool(true));
        assert!(matches!(param, PgParam::Bool(true)));
    }

    #[test]
    fn test_json_to_pg_param_number() {
        let integer = json_to_pg_param(&json!(42));
        assert!(matches!(integer, PgParam::I64(42)));

        let float = json_to_pg_param(&json!(3.14));
        assert!(matches!(float, PgParam::F64(_)));
    }

    #[test]
    fn test_json_to_pg_param_string() {
        let param = json_to_pg_param(&json!("hello"));
        assert!(matches!(param, PgParam::String(ref s) if s == "hello"));
    }

    // --- numeric_binary_to_string ---------------------------------------

    #[test]
    fn test_numeric_binary_to_string_zero() {
        // All-zero header: no digit groups, positive sign, zero scale.
        let decoded = decode_numeric(&[0x00; 8]);
        assert_eq!(decoded, Some("0".to_string()));
    }

    #[test]
    fn test_numeric_binary_to_string_nan() {
        // Sign field 0xC000 marks NaN.
        let decoded = decode_numeric(&[0x00, 0x00, 0x00, 0x00, 0xC0, 0x00, 0x00, 0x00]);
        assert_eq!(decoded, Some("NaN".to_string()));
    }

    #[test]
    fn test_numeric_binary_to_string_invalid_short() {
        // Fewer bytes than the 8-byte header.
        let decoded = decode_numeric(&[0x00, 0x00, 0x00]);
        assert_eq!(decoded, None);
    }
}