rok-core 0.6.0

Core primitives for the rok ecosystem — errors, crypto, i18n, config, DI, and more
Documentation
//! RFC 9457 — Problem Details for HTTP APIs.
//!
//! # Example
//!
//! ```rust
//! use rok_core::Problem;
//!
//! let p = Problem::not_found("User 42 does not exist.");
//! assert_eq!(p.status, 404);
//!
//! let custom = Problem::new("https://example.com/errors/quota-exceeded", "Quota Exceeded", 429)
//!     .detail("You have exceeded your daily upload quota.")
//!     .extend("limit", 100)
//!     .extend("remaining", 0);
//! assert_eq!(custom.status, 429);
//! ```

use std::fmt;

use serde::{Serialize, Serializer};
use serde_json::{Map, Value};

// ── Core struct ───────────────────────────────────────────────────────────────

/// RFC 9457 problem details response.
///
/// Serialised (see the `Serialize` impl below) as a single flat JSON object: `type`, `title`
/// and `status` are always written, `detail` and `instance` only when they are `Some`, and the
/// entries of `extensions` appear as additional top-level members.
#[derive(Debug, Clone)]
pub struct Problem {
    /// URI identifying the problem type; serialised under the JSON key `type`.
    pub type_uri: String,
    /// Short, human-readable summary of the problem type (e.g. `"Not Found"`).
    pub title: String,
    /// HTTP status code for this occurrence (e.g. `404`).
    pub status: u16,
    /// Human-readable explanation specific to this occurrence, if any.
    pub detail: Option<String>,
    /// URI reference identifying this specific occurrence, if any.
    pub instance: Option<String>,
    /// Custom extension members, written next to the standard members rather than nested.
    pub extensions: Map<String, Value>,
}

/// Serialises a [`Problem`] as a flat RFC 9457 JSON object.
///
/// `type`, `title` and `status` are always written; `detail` and `instance` only when set.
/// Extension members follow at the top level. Extensions whose key equals one of the standard
/// member names are skipped: writing them would produce duplicate JSON keys, and since most
/// parsers keep the last occurrence an extension could silently override e.g. `status`.
impl Serialize for Problem {
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeMap;

        // Member names defined by RFC 9457 itself; these always come from the typed fields.
        const STANDARD_MEMBERS: [&str; 5] = ["type", "title", "status", "detail", "instance"];

        // Re-creatable iterator over the extensions that are safe to emit, so the length
        // hint and the emitted entries are derived from exactly the same filter.
        let extensions = || {
            self.extensions
                .iter()
                .filter(|(key, _)| !STANDARD_MEMBERS.contains(&key.as_str()))
        };

        let len = 3
            + usize::from(self.detail.is_some())
            + usize::from(self.instance.is_some())
            + extensions().count();
        let mut map = s.serialize_map(Some(len))?;

        map.serialize_entry("type", &self.type_uri)?;
        map.serialize_entry("title", &self.title)?;
        map.serialize_entry("status", &self.status)?;

        if let Some(d) = &self.detail {
            map.serialize_entry("detail", d)?;
        }
        if let Some(i) = &self.instance {
            map.serialize_entry("instance", i)?;
        }
        for (key, value) in extensions() {
            map.serialize_entry(key, value)?;
        }

        map.end()
    }
}

/// Formats as `{status} {title} ({type_uri})`, followed by `: {detail}` when a detail is
/// set, e.g. `404 Not Found (https://docs.rok.rs/errors/not-found): User 42 does not exist.`
impl fmt::Display for Problem {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Parenthesise the type URI so it stays visually separate from the free-text title.
        write!(f, "{} {} ({})", self.status, self.title, self.type_uri)?;
        if let Some(d) = &self.detail {
            write!(f, ": {d}")?;
        }
        Ok(())
    }
}

// ── Builder ───────────────────────────────────────────────────────────────────

impl Problem {
    /// Base URI under which the predefined problem types are published.
    const BASE: &'static str = "https://docs.rok.rs/errors";

    // Every builder method takes `self` by value and returns the updated problem, so they are
    // all `#[must_use]`: calling one as a statement would silently discard the change.

    /// Create a problem with a fully-qualified `type_uri`.
    ///
    /// `detail`, `instance` and the extension map start out empty.
    #[must_use]
    pub fn new(type_uri: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
        Self {
            type_uri: type_uri.into(),
            title: title.into(),
            status,
            detail: None,
            instance: None,
            extensions: Map::new(),
        }
    }

    /// Human-readable explanation of the specific occurrence (replaces any previous detail).
    #[must_use]
    pub fn detail(mut self, detail: impl Into<String>) -> Self {
        self.detail = Some(detail.into());
        self
    }

    /// URI reference that identifies the specific occurrence (replaces any previous one).
    #[must_use]
    pub fn instance(mut self, instance: impl Into<String>) -> Self {
        self.instance = Some(instance.into());
        self
    }

    /// Attach a custom extension member; a later call with the same `key` overwrites it.
    #[must_use]
    pub fn extend(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        self.extensions.insert(key.into(), value.into());
        self
    }

    /// Override the title (fluent).
    #[must_use]
    pub fn title(mut self, title: impl Into<String>) -> Self {
        self.title = title.into();
        self
    }

    // ── Predefined constructors ───────────────────────────────────────────────

    /// Shared body of the predefined constructors: the type URI is `{BASE}/{slug}`.
    fn predefined(slug: &str, title: &str, status: u16, detail: impl Into<String>) -> Self {
        Self::new(format!("{}/{slug}", Self::BASE), title, status).detail(detail)
    }

    /// 404 Not Found.
    #[must_use]
    pub fn not_found(detail: impl Into<String>) -> Self {
        Self::predefined("not-found", "Not Found", 404, detail)
    }

    /// 400 Bad Request.
    #[must_use]
    pub fn bad_request(detail: impl Into<String>) -> Self {
        Self::predefined("bad-request", "Bad Request", 400, detail)
    }

    /// 401 Unauthorized.
    #[must_use]
    pub fn unauthorized(detail: impl Into<String>) -> Self {
        Self::predefined("unauthorized", "Unauthorized", 401, detail)
    }

    /// 403 Forbidden.
    #[must_use]
    pub fn forbidden(detail: impl Into<String>) -> Self {
        Self::predefined("forbidden", "Forbidden", 403, detail)
    }

    /// 409 Conflict.
    #[must_use]
    pub fn conflict(detail: impl Into<String>) -> Self {
        Self::predefined("conflict", "Conflict", 409, detail)
    }

    /// 422 Unprocessable Entity.
    #[must_use]
    pub fn unprocessable(detail: impl Into<String>) -> Self {
        Self::predefined("unprocessable-entity", "Unprocessable Entity", 422, detail)
    }

    /// 429 Too Many Requests.
    #[must_use]
    pub fn too_many_requests(detail: impl Into<String>) -> Self {
        Self::predefined("too-many-requests", "Too Many Requests", 429, detail)
    }

    /// 500 Internal Server Error.
    #[must_use]
    pub fn internal(detail: impl Into<String>) -> Self {
        Self::predefined("internal-server-error", "Internal Server Error", 500, detail)
    }

    /// 503 Service Unavailable.
    #[must_use]
    pub fn service_unavailable(detail: impl Into<String>) -> Self {
        Self::predefined("service-unavailable", "Service Unavailable", 503, detail)
    }

    // ── Validation integration ────────────────────────────────────────────────

    /// Build a 422 Unprocessable Entity problem from a map of field → error messages.
    ///
    /// The messages are attached as an `errors` extension shaped like
    /// `{"field": ["message", ...], ...}`.
    #[must_use]
    pub fn from_validation(errors: std::collections::HashMap<String, Vec<String>>) -> Self {
        let field_errors: Map<String, Value> = errors
            .into_iter()
            .map(|(field, messages)| {
                let messages = messages.into_iter().map(Value::String).collect();
                (field, Value::Array(messages))
            })
            .collect();

        Self::unprocessable("One or more fields failed validation.")
            .extend("errors", Value::Object(field_errors))
    }

    // ── Serialisation helpers ─────────────────────────────────────────────────

    /// Serialise to a JSON byte vector.
    ///
    /// If serialisation fails, the error is logged (only when the `app` feature is enabled)
    /// and `b"{}"` is returned: an empty object, in which every RFC 9457 member is absent.
    #[must_use]
    pub fn to_json_bytes(&self) -> Vec<u8> {
        match serde_json::to_vec(self) {
            Ok(bytes) => bytes,
            Err(_e) => {
                #[cfg(feature = "app")]
                tracing::error!(error = %_e, "Problem serialization failed");
                b"{}".to_vec()
            }
        }
    }
}

// ── axum IntoResponse ─────────────────────────────────────────────────────────

#[cfg(feature = "axum")]
mod axum_impl {
    use super::Problem;
    use axum::{
        body::Body,
        response::{IntoResponse, Response},
    };
    use http::{header, HeaderValue, StatusCode};

    impl IntoResponse for Problem {
        /// Render the problem as an `application/problem+json` response.
        ///
        /// A `status` that [`StatusCode::from_u16`] rejects is replaced by 500, and the
        /// serialised body is updated to match: RFC 9457 requires the `status` member to be
        /// the same status code as the actual HTTP response.
        fn into_response(mut self) -> Response {
            let status =
                StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
            self.status = status.as_u16();

            // Assembling the response directly cannot fail: the status is already a valid
            // `StatusCode` and the content type is a static, valid header value.
            let mut response = Response::new(Body::from(self.to_json_bytes()));
            *response.status_mut() = status;
            response.headers_mut().insert(
                header::CONTENT_TYPE,
                HeaderValue::from_static("application/problem+json"),
            );
            response
        }
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Base URI shared by all predefined problem types.
    fn base() -> &'static str {
        "https://docs.rok.rs/errors"
    }

    #[test]
    fn not_found_shape() {
        let p = Problem::not_found("User 42 does not exist.");
        assert_eq!(p.status, 404);
        assert_eq!(p.title, "Not Found");
        assert_eq!(p.type_uri, format!("{}/not-found", base()));
        assert_eq!(p.detail.as_deref(), Some("User 42 does not exist."));
    }

    #[test]
    fn bad_request_shape() {
        let p = Problem::bad_request("Missing required field `email`.");
        assert_eq!(p.status, 400);
        assert_eq!(p.title, "Bad Request");
    }

    #[test]
    fn unauthorized_shape() {
        let p = Problem::unauthorized("No valid credentials were supplied.");
        assert_eq!(p.status, 401);
        assert_eq!(p.title, "Unauthorized");
    }

    #[test]
    fn forbidden_shape() {
        let p = Problem::forbidden("You may not delete another user's posts.");
        assert_eq!(p.status, 403);
        assert_eq!(p.title, "Forbidden");
    }

    #[test]
    fn conflict_shape() {
        let p = Problem::conflict("Email address already registered.");
        assert_eq!(p.status, 409);
        assert_eq!(p.title, "Conflict");
    }

    #[test]
    fn unprocessable_shape() {
        let p = Problem::unprocessable("Validation failed.");
        assert_eq!(p.status, 422);
        assert_eq!(p.title, "Unprocessable Entity");
    }

    #[test]
    fn too_many_requests_shape() {
        let p = Problem::too_many_requests("Slow down!");
        assert_eq!(p.status, 429);
        assert_eq!(p.title, "Too Many Requests");
    }

    #[test]
    fn internal_shape() {
        let p = Problem::internal("An unexpected error occurred.");
        assert_eq!(p.status, 500);
        assert_eq!(p.title, "Internal Server Error");
    }

    #[test]
    fn service_unavailable_shape() {
        let p = Problem::service_unavailable("The service is temporarily offline.");
        assert_eq!(p.status, 503);
        assert_eq!(p.title, "Service Unavailable");
    }

    /// Every predefined constructor must point at `{base}/{slug}` and keep its detail.
    #[test]
    fn predefined_type_uris() {
        let cases = [
            (Problem::not_found("d"), "not-found", 404u16),
            (Problem::bad_request("d"), "bad-request", 400),
            (Problem::unauthorized("d"), "unauthorized", 401),
            (Problem::forbidden("d"), "forbidden", 403),
            (Problem::conflict("d"), "conflict", 409),
            (Problem::unprocessable("d"), "unprocessable-entity", 422),
            (Problem::too_many_requests("d"), "too-many-requests", 429),
            (Problem::internal("d"), "internal-server-error", 500),
            (Problem::service_unavailable("d"), "service-unavailable", 503),
        ];
        for (p, slug, status) in &cases {
            assert_eq!(p.type_uri, format!("{}/{}", base(), slug));
            assert_eq!(p.status, *status);
            assert_eq!(p.detail.as_deref(), Some("d"));
            assert!(p.instance.is_none());
            assert!(p.extensions.is_empty());
        }
    }

    #[test]
    fn custom_problem_with_extensions() {
        let p = Problem::new(
            "https://example.com/errors/quota-exceeded",
            "Quota Exceeded",
            429,
        )
        .detail("You have exceeded your daily upload quota.")
        .extend("limit", 100u64)
        .extend("remaining", 0u64);

        assert_eq!(p.status, 429);
        assert_eq!(p.extensions["limit"], serde_json::json!(100u64));
        assert_eq!(p.extensions["remaining"], serde_json::json!(0u64));
    }

    #[test]
    fn instance_field_roundtrip() {
        let p = Problem::not_found("Order not found.").instance("/orders/99");
        assert_eq!(p.instance.as_deref(), Some("/orders/99"));
    }

    #[test]
    fn title_override() {
        let p = Problem::not_found("Custom detail.").title("Resource Not Found");
        assert_eq!(p.title, "Resource Not Found");
    }

    #[test]
    fn serialize_mandatory_fields() {
        let p = Problem::not_found("test");
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert!(v.get("type").is_some());
        assert!(v.get("title").is_some());
        assert!(v.get("status").is_some());
    }

    #[test]
    fn serialize_field_values() {
        let p = Problem::not_found("User 42 does not exist.");
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert_eq!(v["type"], serde_json::json!(format!("{}/not-found", base())));
        assert_eq!(v["title"], serde_json::json!("Not Found"));
        assert_eq!(v["status"], serde_json::json!(404));
        assert_eq!(v["detail"], serde_json::json!("User 42 does not exist."));
    }

    #[test]
    fn serialize_omits_none_fields() {
        let p = Problem::not_found("test");
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert!(v.get("instance").is_none());
    }

    #[test]
    fn serialize_omits_detail_when_absent() {
        let p = Problem::new("https://example.com/errors/x", "X", 400);
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert!(v.get("detail").is_none());
        assert!(v.get("instance").is_none());
    }

    #[test]
    fn serialize_includes_optional_fields() {
        let p = Problem::not_found("test").instance("/foo/1");
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert_eq!(v["instance"], "/foo/1");
    }

    /// Extensions are written as top-level members, not under an `extensions` key.
    #[test]
    fn serialize_flattens_extensions() {
        let p = Problem::new("https://example.com/errors/quota", "Quota Exceeded", 429)
            .extend("limit", 100)
            .extend("remaining", 0);
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert_eq!(v["limit"], serde_json::json!(100));
        assert_eq!(v["remaining"], serde_json::json!(0));
        assert!(v.get("extensions").is_none());
    }

    #[test]
    fn display_format() {
        let p = Problem::not_found("User not found.");
        let s = p.to_string();
        assert!(s.contains("404"));
        assert!(s.contains("Not Found"));
        assert!(s.contains("User not found."));
    }

    #[test]
    fn from_validation_errors() {
        let mut errors = std::collections::HashMap::new();
        errors.insert("email".to_string(), vec!["is required".to_string()]);
        errors.insert("name".to_string(), vec!["is too short".to_string()]);

        let p = Problem::from_validation(errors);
        assert_eq!(p.status, 422);
        let errors_val = p.extensions.get("errors").expect("errors extension");
        assert!(errors_val.get("email").is_some());
        assert!(errors_val.get("name").is_some());
    }

    /// Each field maps to a JSON array holding all of its messages, in order.
    #[test]
    fn from_validation_messages_are_arrays() {
        let mut errors = std::collections::HashMap::new();
        errors.insert(
            "email".to_string(),
            vec!["is required".to_string(), "must be valid".to_string()],
        );

        let p = Problem::from_validation(errors);
        assert_eq!(
            p.extensions["errors"],
            serde_json::json!({ "email": ["is required", "must be valid"] })
        );
    }
}