use fraiseql_error::{FraiseQLError, Result};
use serde::{Deserialize, Serialize};
use crate::{types::db_types::DatabaseType, utils::to_snake_case};
/// Type hint for the value sorted by an [`OrderByClause`].
///
/// [`OrderByFieldType::Text`] is the default (`#[default]` below). [`OrderByClause::new`]
/// uses it, and serde falls back to it when `field_type` is missing from the input.
///
/// NOTE(review): presumably consumed by SQL generation to pick a cast/comparison for the
/// sort expression — that usage is not visible in this module; confirm against callers.
/// Do not reorder variants if any index-based serde format is in use: derived serde impls
/// encode the variant index there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub enum OrderByFieldType {
    /// Default hint when nothing more specific is known.
    #[default]
    Text,
    Integer,
    Numeric,
    Boolean,
    DateTime,
    Date,
    Time,
}
/// One `ORDER BY` entry: a field plus the direction to sort it in.
///
/// Built directly with [`OrderByClause::new`] or parsed from GraphQL input with
/// [`OrderByClause::from_graphql_json`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OrderByClause {
    /// Field name exactly as supplied (e.g. camelCase `createdAt`).
    /// [`OrderByClause::storage_key`] gives its snake_case form.
    pub field: String,
    /// Sort direction.
    pub direction: OrderDirection,
    /// Type hint for the sorted value. Missing from serialized input means
    /// `OrderByFieldType::default()` (`Text`).
    #[serde(default)]
    pub field_type: OrderByFieldType,
    /// Optional native column name. Defaults to `None` and is omitted from serialized output
    /// when `None`.
    /// NOTE(review): presumably lets SQL generation sort on a real column instead of the
    /// field's storage key — confirm against callers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub native_column: Option<String>,
}
/// Sort direction of an [`OrderByClause`].
///
/// [`OrderDirection::as_sql`] maps each variant to its SQL keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum OrderDirection {
    /// Ascending order (`ASC`).
    Asc,
    /// Descending order (`DESC`).
    Desc,
}
impl OrderDirection {
    /// Returns the SQL keyword for this direction.
    ///
    /// The result is always one of the fixed literals `"DESC"` or `"ASC"` and never
    /// depends on user input. It can therefore be placed into generated SQL as-is.
    #[must_use]
    pub const fn as_sql(self) -> &'static str {
        match self {
            Self::Desc => "DESC",
            Self::Asc => "ASC",
        }
    }
}
impl OrderByClause {
    /// Creates a clause that sorts `field` in `direction`.
    ///
    /// `field_type` starts as `OrderByFieldType::default()` (`Text`) and `native_column` as
    /// `None`. `field` is stored unchanged and is NOT validated here. Call
    /// [`Self::validate_field_name`] first if it comes from untrusted input.
    #[must_use]
    pub fn new(field: String, direction: OrderDirection) -> Self {
        Self {
            field,
            direction,
            field_type: OrderByFieldType::default(),
            native_column: None,
        }
    }

    /// Returns `field` converted to snake_case (e.g. `createdAt` -> `created_at`), the form
    /// used as the storage key.
    #[must_use]
    pub fn storage_key(&self) -> String {
        to_snake_case(&self.field)
    }

    /// Checks that `field` is a plain identifier matching `[_A-Za-z][_0-9A-Za-z]*`.
    ///
    /// Only ASCII letters, digits and `_` are accepted, and the first character must not be
    /// a digit. This rejects quotes, whitespace and punctuation, which guards against
    /// injection should the name end up in generated SQL.
    ///
    /// # Errors
    ///
    /// Returns [`FraiseQLError::Validation`] if `field` is empty or does not match the
    /// pattern above.
    pub fn validate_field_name(field: &str) -> Result<()> {
        let mut chars = field.chars();
        let first_ok = chars.next().is_some_and(|c| c.is_ascii_alphabetic() || c == '_');
        let rest_ok = chars.all(|c| c.is_ascii_alphanumeric() || c == '_');
        if first_ok && rest_ok {
            Ok(())
        } else {
            Err(FraiseQLError::Validation {
                message: format!(
                    "orderBy field name '{field}' contains invalid characters; \
                     only [_A-Za-z][_0-9A-Za-z]* is allowed"
                ),
                path: None,
            })
        }
    }

    /// Parses a GraphQL `orderBy` argument into clauses.
    ///
    /// Two input shapes are accepted:
    ///
    /// * an object mapping field names to direction strings, e.g. `{"createdAt": "DESC"}`;
    /// * an array of objects, e.g. `[{"field": "createdAt", "direction": "DESC"}]`. Here a
    ///   missing or `null` `direction` means ASC.
    ///
    /// Directions are matched against `ASC` / `DESC`, ignoring ASCII case. Every field name
    /// is checked with [`Self::validate_field_name`]. The first invalid entry aborts parsing.
    ///
    /// NOTE(review): in the object form, clause order follows `serde_json::Map` iteration
    /// order. That order is sorted by key unless serde_json's `preserve_order` feature is
    /// enabled, so clients sorting by several fields should use the array form.
    ///
    /// # Errors
    ///
    /// Returns [`FraiseQLError::Validation`] if any of the following holds:
    ///
    /// * `value` is neither an object nor an array;
    /// * an array item is not an object, or has no string `field`;
    /// * a direction is not a string (`null` is also allowed in the array form);
    /// * a direction is not ASC/DESC;
    /// * a field name fails validation.
    pub fn from_graphql_json(value: &serde_json::Value) -> Result<Vec<Self>> {
        if let Some(obj) = value.as_object() {
            obj.iter()
                .map(|(field, dir_val)| Self::from_object_entry(field, dir_val))
                .collect()
        } else if let Some(arr) = value.as_array() {
            arr.iter().map(Self::from_array_item).collect()
        } else {
            Err(FraiseQLError::Validation {
                message: "orderBy must be an object or array".to_string(),
                path: None,
            })
        }
    }

    /// Builds a clause from one `field: direction` entry of the object form.
    fn from_object_entry(field: &str, dir_val: &serde_json::Value) -> Result<Self> {
        let dir_str = dir_val.as_str().ok_or_else(|| FraiseQLError::Validation {
            message: format!("orderBy direction for '{field}' must be a string"),
            path: None,
        })?;
        let direction = Self::parse_direction(dir_str)?;
        Self::validate_field_name(field)?;
        Ok(Self::new(field.to_string(), direction))
    }

    /// Builds a clause from one `{"field": ..., "direction": ...}` item of the array form.
    fn from_array_item(item: &serde_json::Value) -> Result<Self> {
        let obj = item.as_object().ok_or_else(|| FraiseQLError::Validation {
            message: "orderBy array items must be objects".to_string(),
            path: None,
        })?;
        let field = obj
            .get("field")
            .and_then(|v| v.as_str())
            .ok_or_else(|| FraiseQLError::Validation {
                message: "orderBy item missing 'field' string".to_string(),
                path: None,
            })?;
        // An absent or explicit `null` direction means "use the default" (nullable GraphQL
        // input). Any other non-string value is a client error, as in the object form; it
        // used to be silently treated as ASC.
        let dir_str = match obj.get("direction") {
            None | Some(serde_json::Value::Null) => "ASC",
            Some(other) => other.as_str().ok_or_else(|| FraiseQLError::Validation {
                message: format!("orderBy direction for '{field}' must be a string"),
                path: None,
            })?,
        };
        let direction = Self::parse_direction(dir_str)?;
        Self::validate_field_name(field)?;
        Ok(Self::new(field.to_string(), direction))
    }

    /// Maps `ASC` / `DESC` to an [`OrderDirection`], ignoring ASCII case and without
    /// allocating.
    fn parse_direction(dir_str: &str) -> Result<OrderDirection> {
        if dir_str.eq_ignore_ascii_case("ASC") {
            Ok(OrderDirection::Asc)
        } else if dir_str.eq_ignore_ascii_case("DESC") {
            Ok(OrderDirection::Desc)
        } else {
            Err(FraiseQLError::Validation {
                message: format!("orderBy direction '{dir_str}' must be ASC or DESC"),
                path: None,
            })
        }
    }
}
/// A database-specific SQL projection suggestion.
///
/// NOTE(review): field semantics below are inferred from the field names only; no code in
/// this module builds or reads this type. Confirm with its producers and consumers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SqlProjectionHint {
    /// Database the projection template is written for.
    pub database: DatabaseType,
    /// Projection SQL template text; any placeholder syntax is defined elsewhere.
    pub projection_template: String,
    /// Estimated reduction as a percentage (presumably 0-100, of result/payload size —
    /// TODO confirm).
    pub estimated_reduction_percent: u32,
}
#[cfg(test)]
mod tests {
    // Tests are expected to panic on an unexpected `Err`, so `unwrap` is acceptable here.
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[test]
    fn test_storage_key_camel_to_snake() {
        let clause = OrderByClause::new("createdAt".into(), OrderDirection::Asc);
        assert_eq!(clause.storage_key(), "created_at");
    }

    #[test]
    fn test_storage_key_multi_word() {
        let clause = OrderByClause::new("firstName".into(), OrderDirection::Desc);
        assert_eq!(clause.storage_key(), "first_name");
    }

    #[test]
    fn test_storage_key_already_snake() {
        let clause = OrderByClause::new("id".into(), OrderDirection::Asc);
        assert_eq!(clause.storage_key(), "id");
    }

    #[test]
    fn test_storage_key_long_camel() {
        let clause = OrderByClause::new("updatedAtTimestamp".into(), OrderDirection::Asc);
        assert_eq!(clause.storage_key(), "updated_at_timestamp");
    }

    #[test]
    fn test_order_direction_as_sql() {
        assert_eq!(OrderDirection::Asc.as_sql(), "ASC");
        assert_eq!(OrderDirection::Desc.as_sql(), "DESC");
    }

    #[test]
    fn test_validate_field_name_accepts_valid() {
        assert!(OrderByClause::validate_field_name("id").is_ok());
        assert!(OrderByClause::validate_field_name("createdAt").is_ok());
        assert!(OrderByClause::validate_field_name("_private").is_ok());
        assert!(OrderByClause::validate_field_name("field123").is_ok());
    }

    #[test]
    fn test_validate_field_name_rejects_injection() {
        assert!(OrderByClause::validate_field_name("'; DROP TABLE users; --").is_err());
        assert!(OrderByClause::validate_field_name("field name").is_err());
        assert!(OrderByClause::validate_field_name("123start").is_err());
        assert!(OrderByClause::validate_field_name("").is_err());
    }

    #[test]
    fn test_from_graphql_json_object_form_is_case_insensitive() {
        let clauses =
            OrderByClause::from_graphql_json(&serde_json::json!({"createdAt": "desc"})).unwrap();
        assert_eq!(clauses, vec![OrderByClause::new("createdAt".into(), OrderDirection::Desc)]);
    }

    #[test]
    fn test_from_graphql_json_array_form_keeps_order_and_defaults_to_asc() {
        let input = serde_json::json!([
            {"field": "lastName", "direction": "DESC"},
            {"field": "firstName"},
            {"field": "id", "direction": null}
        ]);
        let clauses = OrderByClause::from_graphql_json(&input).unwrap();
        assert_eq!(
            clauses,
            vec![
                OrderByClause::new("lastName".into(), OrderDirection::Desc),
                OrderByClause::new("firstName".into(), OrderDirection::Asc),
                OrderByClause::new("id".into(), OrderDirection::Asc),
            ]
        );
    }

    #[test]
    fn test_from_graphql_json_empty_object_yields_no_clauses() {
        assert!(OrderByClause::from_graphql_json(&serde_json::json!({})).unwrap().is_empty());
    }

    #[test]
    fn test_from_graphql_json_rejects_invalid_input() {
        let cases = [
            serde_json::json!({"id": "UP"}),
            serde_json::json!({"id": 1}),
            serde_json::json!({"'; DROP TABLE users; --": "ASC"}),
            serde_json::json!([{"field": "id", "direction": "sideways"}]),
            serde_json::json!([{"direction": "ASC"}]),
            serde_json::json!(["id"]),
            serde_json::json!([{"field": "bad name"}]),
            serde_json::json!("id"),
            serde_json::json!(null),
        ];
        for bad in &cases {
            assert!(OrderByClause::from_graphql_json(bad).is_err(), "accepted {bad}");
        }
    }
}