use http::StatusCode;
use thiserror::Error;
/// Shorthand for `std::result::Result` with the error type fixed to this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)]
pub enum Error {
#[error("Invalid path: {0}")]
InvalidPath(String),
#[error("Invalid query parameter: {0}")]
InvalidQueryParam(String),
#[error("Invalid header: {0}")]
InvalidHeader(&'static str),
#[error("Invalid request body: {0}")]
InvalidBody(String),
#[error("Unsupported HTTP method: {0}")]
UnsupportedMethod(String),
#[error("Unacceptable schema: {0}")]
UnacceptableSchema(String),
#[error("Unknown column: {0}")]
UnknownColumn(String),
#[error("Invalid range: {0}")]
InvalidRange(String),
#[error("Invalid media type: {0}")]
InvalidMediaType(String),
#[error("Missing required parameter: {0}")]
MissingParameter(String),
#[error("Ambiguous request: {0}")]
AmbiguousRequest(String),
#[error("Invalid JWT: {0}")]
InvalidJwt(String),
#[error("JWT expired")]
JwtExpired,
#[error("Missing authentication")]
MissingAuth,
#[error("Insufficient permissions: {0}")]
InsufficientPermissions(String),
#[error("Resource not found: {0}")]
NotFound(String),
#[error("Table not found: {0}")]
TableNotFound(String),
#[error("Function not found: {0}")]
FunctionNotFound(String),
#[error("Column not found: {0}")]
ColumnNotFound(String),
#[error("Relationship not found: {0}")]
RelationshipNotFound(String),
#[error("Schema cache not loaded")]
SchemaCacheNotLoaded,
#[error("Schema cache load failed: {0}")]
SchemaCacheLoadFailed(String),
#[error("Database error: {0}")]
Database(#[from] DatabaseError),
#[error("Connection pool error: {0}")]
ConnectionPool(String),
#[error("Internal error: {0}")]
Internal(String),
#[error("Configuration error: {0}")]
Config(String),
#[error("Invalid plan: {0}")]
InvalidPlan(String),
#[error("Embedding error: {0}")]
EmbeddingError(String),
}
impl Error {
    /// Returns the HTTP status code associated with this error.
    ///
    /// A wrapped database error picks its own status through
    /// [`DatabaseError::status_code`]; every other variant maps to a fixed status.
    pub fn status_code(&self) -> StatusCode {
        match self {
            // Delegated: the database error decides.
            Self::Database(inner) => inner.status_code(),

            // 401
            Self::MissingAuth | Self::JwtExpired | Self::InvalidJwt(_) => {
                StatusCode::UNAUTHORIZED
            }

            // 403
            Self::InsufficientPermissions(_) => StatusCode::FORBIDDEN,

            // 405
            Self::UnsupportedMethod(_) => StatusCode::METHOD_NOT_ALLOWED,

            // 406
            Self::UnacceptableSchema(_) => StatusCode::NOT_ACCEPTABLE,

            // 404
            Self::RelationshipNotFound(_)
            | Self::ColumnNotFound(_)
            | Self::FunctionNotFound(_)
            | Self::TableNotFound(_)
            | Self::NotFound(_) => StatusCode::NOT_FOUND,

            // 500
            Self::Config(_)
            | Self::Internal(_)
            | Self::ConnectionPool(_)
            | Self::SchemaCacheLoadFailed(_)
            | Self::SchemaCacheNotLoaded => StatusCode::INTERNAL_SERVER_ERROR,

            // 400
            Self::InvalidPath(_)
            | Self::InvalidQueryParam(_)
            | Self::InvalidHeader(_)
            | Self::InvalidBody(_)
            | Self::UnknownColumn(_)
            | Self::InvalidRange(_)
            | Self::InvalidMediaType(_)
            | Self::MissingParameter(_)
            | Self::AmbiguousRequest(_)
            | Self::InvalidPlan(_)
            | Self::EmbeddingError(_) => StatusCode::BAD_REQUEST,
        }
    }

    /// Returns the `PGRST…` identifier for this error.
    ///
    /// A wrapped database error supplies its identifier through [`DatabaseError::code`].
    pub fn code(&self) -> &'static str {
        match self {
            // 1xx: request errors
            Self::InvalidPath(_) => "PGRST100",
            Self::InvalidQueryParam(_) => "PGRST101",
            Self::InvalidHeader(_) => "PGRST102",
            Self::InvalidBody(_) => "PGRST103",
            Self::UnsupportedMethod(_) => "PGRST104",
            Self::UnacceptableSchema(_) => "PGRST105",
            Self::UnknownColumn(_) => "PGRST106",
            Self::InvalidRange(_) => "PGRST107",
            Self::InvalidMediaType(_) => "PGRST108",
            Self::MissingParameter(_) => "PGRST109",
            Self::AmbiguousRequest(_) => "PGRST110",

            // 2xx: authentication / authorization
            Self::InvalidJwt(_) => "PGRST200",
            Self::JwtExpired => "PGRST201",
            Self::MissingAuth => "PGRST202",
            Self::InsufficientPermissions(_) => "PGRST203",

            // 3xx: lookups
            Self::NotFound(_) => "PGRST300",
            Self::TableNotFound(_) => "PGRST301",
            Self::FunctionNotFound(_) => "PGRST302",
            Self::ColumnNotFound(_) => "PGRST303",
            Self::RelationshipNotFound(_) => "PGRST304",

            // 4xx: schema cache
            Self::SchemaCacheNotLoaded => "PGRST400",
            Self::SchemaCacheLoadFailed(_) => "PGRST401",

            // 5xx: database / connection pool
            Self::ConnectionPool(_) => "PGRST500",
            Self::Database(inner) => inner.code(),

            // 6xx: planning
            Self::InvalidPlan(_) => "PGRST600",
            Self::EmbeddingError(_) => "PGRST601",

            // 9xx: internal
            Self::Internal(_) => "PGRST900",
            Self::Config(_) => "PGRST901",
        }
    }

    /// Renders the error as a JSON object with the keys `code`, `message`, `details`
    /// and `hint`. `details` and `hint` are `null` when the error has none.
    pub fn to_json(&self) -> serde_json::Value {
        let code = self.code();
        let message = self.to_string();
        let details = self.details();
        let hint = self.hint();

        serde_json::json!({
            "code": code,
            "message": message,
            "details": details,
            "hint": hint,
        })
    }

    /// Extra detail text; only a wrapped database error can provide one.
    fn details(&self) -> Option<String> {
        if let Self::Database(inner) = self {
            inner.details.clone()
        } else {
            None
        }
    }

    /// Remediation hint: a fixed sentence for a few variants, the database's own hint
    /// for a wrapped database error, and nothing otherwise.
    fn hint(&self) -> Option<String> {
        let fixed = match self {
            Self::Database(inner) => return inner.hint.clone(),
            Self::InvalidJwt(_) => "Check that the JWT is properly signed and not expired",
            Self::MissingAuth => "Provide a valid JWT in the Authorization header",
            Self::TableNotFound(_) => "Check the table name and schema",
            Self::UnknownColumn(_) => "Check column names against the table schema",
            _ => return None,
        };
        Some(fixed.to_owned())
    }
}
/// Error reported by the database, displayed as `Database error [<code>]: <message>`.
///
/// NOTE(review): the values matched in `status_code`/`code` ("23", "42", "28", "P0001", ...)
/// look like PostgreSQL SQLSTATE codes; confirm against whatever constructs this type.
#[derive(Error, Debug)]
#[error("Database error [{code}]: {message}")]
pub struct DatabaseError {
/// Error code of the failure. Matched by two-character prefix (and exactly, for `P0001`)
/// to choose the HTTP status; also shown in the display message.
pub code: String,
/// Message text, shown in the display message.
pub message: String,
/// Optional detail text; surfaced as `details` by `Error::to_json`.
pub details: Option<String>,
/// Optional hint text; surfaced as `hint` by `Error::to_json`.
pub hint: Option<String>,
/// Presumably the name of the constraint involved, if any. Not read in this file; TODO confirm.
pub constraint: Option<String>,
/// Presumably the name of the table involved, if any. Not read in this file; TODO confirm.
pub table: Option<String>,
/// Presumably the name of the column involved, if any. Not read in this file; TODO confirm.
pub column: Option<String>,
}
impl DatabaseError {
    /// Returns the HTTP status for this database error.
    ///
    /// The status is selected by the first two bytes of `code`, except for the exact
    /// code `P0001`, which maps to 400. Anything unrecognised (including codes shorter
    /// than two bytes) maps to 500.
    pub fn status_code(&self) -> StatusCode {
        // `get(..2)` yields `None` for short codes or when byte 2 is not a character
        // boundary; neither case can begin with one of the ASCII prefixes below.
        match self.code.get(..2) {
            Some("23") | Some("40") => StatusCode::CONFLICT,
            Some("42") => StatusCode::BAD_REQUEST,
            Some("28") => StatusCode::FORBIDDEN,
            Some("53") => StatusCode::SERVICE_UNAVAILABLE,
            Some("54") => StatusCode::PAYLOAD_TOO_LARGE,
            // Exact match only: other `P0…` codes fall through to the default.
            _ if self.code == "P0001" => StatusCode::BAD_REQUEST,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    /// Returns the `PGRST…` identifier for this database error, selected by the first
    /// two bytes of `code`; unrecognised codes yield `PGRST500`.
    pub fn code(&self) -> &'static str {
        match self.code.get(..2) {
            Some("23") => "PGRST503",
            Some("42") => "PGRST504",
            Some("28") => "PGRST505",
            _ => "PGRST500",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sampled variants report the HTTP status they are mapped to.
    #[test]
    fn test_error_status_codes() {
        let cases = vec![
            (Error::InvalidQueryParam("test".into()), StatusCode::BAD_REQUEST),
            (Error::MissingAuth, StatusCode::UNAUTHORIZED),
            (Error::TableNotFound("users".into()), StatusCode::NOT_FOUND),
            (
                Error::UnsupportedMethod("TRACE".into()),
                StatusCode::METHOD_NOT_ALLOWED,
            ),
        ];

        for (error, expected) in cases {
            assert_eq!(error.status_code(), expected);
        }
    }

    /// Sampled variants report their `PGRST` identifier.
    #[test]
    fn test_error_codes() {
        let cases = vec![
            (Error::InvalidQueryParam("test".into()), "PGRST101"),
            (Error::MissingAuth, "PGRST202"),
            (Error::TableNotFound("users".into()), "PGRST301"),
        ];

        for (error, expected) in cases {
            assert_eq!(error.code(), expected);
        }
    }

    /// A code starting with `23` is reported as a conflict.
    #[test]
    fn test_database_error_status() {
        let duplicate_key = DatabaseError {
            code: String::from("23505"),
            message: String::from("Duplicate key"),
            details: None,
            hint: None,
            constraint: Some(String::from("users_pkey")),
            table: Some(String::from("users")),
            column: None,
        };

        assert_eq!(duplicate_key.status_code(), StatusCode::CONFLICT);
    }

    /// The JSON rendering carries the identifier and the payload text.
    #[test]
    fn test_error_to_json() {
        let payload = Error::InvalidQueryParam("bad filter".into()).to_json();

        assert_eq!(payload["code"], "PGRST101");

        let message = payload["message"].as_str().unwrap();
        assert!(message.contains("bad filter"));
    }
}