Skip to main content

forge_runtime/gateway/
response.rs

1use axum::Json;
2use axum::http::StatusCode;
3use axum::response::{IntoResponse, Response};
4use serde::{Deserialize, Serialize};
5
6/// RPC response for function calls.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct RpcResponse {
9    /// Whether the call succeeded.
10    pub success: bool,
11    /// Result data (if successful).
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub data: Option<serde_json::Value>,
14    /// Error information (if failed).
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub error: Option<RpcError>,
17    /// Request ID for tracing.
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub request_id: Option<String>,
20}
21
22impl RpcResponse {
23    /// Create a successful response.
24    pub fn success(data: serde_json::Value) -> Self {
25        Self {
26            success: true,
27            data: Some(data),
28            error: None,
29            request_id: None,
30        }
31    }
32
33    /// Create an error response.
34    pub fn error(error: RpcError) -> Self {
35        Self {
36            success: false,
37            data: None,
38            error: Some(error),
39            request_id: None,
40        }
41    }
42
43    /// Add request ID to the response.
44    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
45        self.request_id = Some(request_id.into());
46        self
47    }
48}
49
50impl IntoResponse for RpcResponse {
51    fn into_response(self) -> Response {
52        let status = if self.success {
53            StatusCode::OK
54        } else {
55            self.error
56                .as_ref()
57                .map(|e| e.status_code())
58                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
59        };
60
61        (status, Json(self)).into_response()
62    }
63}
64
65/// RPC error information.
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct RpcError {
68    /// Error code.
69    pub code: String,
70    /// Human-readable error message.
71    pub message: String,
72    /// Additional error details.
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub details: Option<serde_json::Value>,
75}
76
77impl RpcError {
78    /// Create a new error.
79    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
80        Self {
81            code: code.into(),
82            message: message.into(),
83            details: None,
84        }
85    }
86
87    /// Create an error with details.
88    pub fn with_details(
89        code: impl Into<String>,
90        message: impl Into<String>,
91        details: serde_json::Value,
92    ) -> Self {
93        Self {
94            code: code.into(),
95            message: message.into(),
96            details: Some(details),
97        }
98    }
99
100    /// Get HTTP status code for this error.
101    pub fn status_code(&self) -> StatusCode {
102        match self.code.as_str() {
103            "NOT_FOUND" => StatusCode::NOT_FOUND,
104            "UNAUTHORIZED" => StatusCode::UNAUTHORIZED,
105            "FORBIDDEN" => StatusCode::FORBIDDEN,
106            "VALIDATION_ERROR" => StatusCode::BAD_REQUEST,
107            "INVALID_ARGUMENT" => StatusCode::BAD_REQUEST,
108            "TIMEOUT" => StatusCode::GATEWAY_TIMEOUT,
109            "RATE_LIMITED" => StatusCode::TOO_MANY_REQUESTS,
110            "JOB_CANCELLED" => StatusCode::CONFLICT,
111            _ => StatusCode::INTERNAL_SERVER_ERROR,
112        }
113    }
114
115    /// Create a not found error.
116    pub fn not_found(message: impl Into<String>) -> Self {
117        Self::new("NOT_FOUND", message)
118    }
119
120    /// Create an unauthorized error.
121    pub fn unauthorized(message: impl Into<String>) -> Self {
122        Self::new("UNAUTHORIZED", message)
123    }
124
125    /// Create a forbidden error.
126    pub fn forbidden(message: impl Into<String>) -> Self {
127        Self::new("FORBIDDEN", message)
128    }
129
130    /// Create a validation error.
131    pub fn validation(message: impl Into<String>) -> Self {
132        Self::new("VALIDATION_ERROR", message)
133    }
134
135    /// Create an internal error.
136    pub fn internal(message: impl Into<String>) -> Self {
137        Self::new("INTERNAL_ERROR", message)
138    }
139}
140
141impl From<forge_core::error::ForgeError> for RpcError {
142    fn from(err: forge_core::error::ForgeError) -> Self {
143        match err {
144            forge_core::error::ForgeError::NotFound(msg) => Self::not_found(msg),
145            forge_core::error::ForgeError::Unauthorized(msg) => Self::unauthorized(msg),
146            forge_core::error::ForgeError::Forbidden(msg) => Self::forbidden(msg),
147            forge_core::error::ForgeError::Validation(msg) => Self::validation(msg),
148            forge_core::error::ForgeError::InvalidArgument(msg) => {
149                Self::new("INVALID_ARGUMENT", msg)
150            }
151            forge_core::error::ForgeError::Timeout(msg) => Self::new("TIMEOUT", msg),
152            forge_core::error::ForgeError::JobCancelled(msg) => Self::new("JOB_CANCELLED", msg),
153            forge_core::error::ForgeError::Deserialization(msg) => {
154                tracing::warn!(error = %msg, "Deserialization error in RPC handler");
155                Self::new("INVALID_ARGUMENT", "Invalid input format")
156            }
157            ref e @ forge_core::error::ForgeError::Database(_)
158            | ref e @ forge_core::error::ForgeError::Sql(_) => {
159                tracing::error!(error = %e, "Database error in RPC handler");
160                Self::internal("Internal server error")
161            }
162            ref e @ (forge_core::error::ForgeError::Internal(_)
163            | forge_core::error::ForgeError::Serialization(_)
164            | forge_core::error::ForgeError::Function(_)
165            | forge_core::error::ForgeError::Config(_)
166            | forge_core::error::ForgeError::Io(_)
167            | forge_core::error::ForgeError::Cluster(_)
168            | forge_core::error::ForgeError::InvalidState(_)
169            | forge_core::error::ForgeError::WorkflowSuspended) => {
170                tracing::error!(error = %e, "Internal error in RPC handler");
171                Self::internal("Internal server error")
172            }
173            forge_core::error::ForgeError::Job(msg) => {
174                tracing::error!(error = %msg, "Job error");
175                Self::internal("Internal server error")
176            }
177            forge_core::error::ForgeError::RateLimitExceeded { retry_after, .. } => {
178                Self::with_details(
179                    "RATE_LIMITED",
180                    "Rate limit exceeded",
181                    serde_json::json!({
182                        "retry_after_secs": retry_after.as_secs(),
183                    }),
184                )
185            }
186        }
187    }
188}
189
#[cfg(test)]
// Unwrapping, indexing and panicking are the intended failure signals inside tests.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_success_response() {
        let resp = RpcResponse::success(serde_json::json!({"id": 1}));
        assert!(resp.success);
        assert!(resp.data.is_some());
        assert!(resp.error.is_none());
    }

    #[test]
    fn test_error_response() {
        let resp = RpcResponse::error(RpcError::not_found("User not found"));
        assert!(!resp.success);
        assert!(resp.data.is_none());
        assert!(resp.error.is_some());
        assert_eq!(resp.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn test_error_status_codes() {
        assert_eq!(RpcError::not_found("").status_code(), StatusCode::NOT_FOUND);
        assert_eq!(
            RpcError::unauthorized("").status_code(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(RpcError::forbidden("").status_code(), StatusCode::FORBIDDEN);
        assert_eq!(
            RpcError::validation("").status_code(),
            StatusCode::BAD_REQUEST
        );
        assert_eq!(
            RpcError::internal("").status_code(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[test]
    fn status_code_covers_every_known_code_and_defaults_to_500() {
        let cases = [
            ("NOT_FOUND", StatusCode::NOT_FOUND),
            ("UNAUTHORIZED", StatusCode::UNAUTHORIZED),
            ("FORBIDDEN", StatusCode::FORBIDDEN),
            ("VALIDATION_ERROR", StatusCode::BAD_REQUEST),
            ("INVALID_ARGUMENT", StatusCode::BAD_REQUEST),
            ("TIMEOUT", StatusCode::GATEWAY_TIMEOUT),
            ("RATE_LIMITED", StatusCode::TOO_MANY_REQUESTS),
            ("JOB_CANCELLED", StatusCode::CONFLICT),
            ("INTERNAL_ERROR", StatusCode::INTERNAL_SERVER_ERROR),
            ("SOMETHING_UNKNOWN", StatusCode::INTERNAL_SERVER_ERROR),
        ];

        for (code, expected) in cases {
            assert_eq!(
                RpcError::new(code, "").status_code(),
                expected,
                "unexpected status for code: {}",
                code
            );
        }
    }

    #[test]
    fn test_with_request_id() {
        let resp = RpcResponse::success(serde_json::json!(null)).with_request_id("req-123");
        assert_eq!(resp.request_id, Some("req-123".to_string()));
    }

    // --- RpcResponse -> HTTP response (status selection) ---

    #[test]
    fn into_response_success_is_200() {
        let resp = RpcResponse::success(serde_json::json!({"ok": true})).into_response();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[test]
    fn into_response_uses_error_status_code() {
        let resp = RpcResponse::error(RpcError::forbidden("admin only")).into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn into_response_failure_without_error_falls_back_to_500() {
        let resp = RpcResponse {
            success: false,
            data: None,
            error: None,
            request_id: None,
        }
        .into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // --- ForgeError -> RpcError conversion (HTTP boundary contract) ---

    #[test]
    fn forge_not_found_maps_to_not_found_404() {
        let rpc: RpcError = forge_core::ForgeError::NotFound("user 42".into()).into();
        assert_eq!(rpc.code, "NOT_FOUND");
        assert_eq!(rpc.message, "user 42");
        assert_eq!(rpc.status_code(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn forge_unauthorized_maps_to_401() {
        let rpc: RpcError = forge_core::ForgeError::Unauthorized("expired".into()).into();
        assert_eq!(rpc.code, "UNAUTHORIZED");
        assert_eq!(rpc.status_code(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn forge_forbidden_maps_to_403() {
        let rpc: RpcError = forge_core::ForgeError::Forbidden("admin only".into()).into();
        assert_eq!(rpc.code, "FORBIDDEN");
        assert_eq!(rpc.status_code(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn forge_validation_maps_to_400() {
        let rpc: RpcError = forge_core::ForgeError::Validation("email required".into()).into();
        assert_eq!(rpc.code, "VALIDATION_ERROR");
        assert_eq!(rpc.status_code(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn forge_invalid_argument_maps_to_400() {
        let rpc: RpcError = forge_core::ForgeError::InvalidArgument("negative id".into()).into();
        assert_eq!(rpc.code, "INVALID_ARGUMENT");
        assert_eq!(rpc.status_code(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn forge_timeout_maps_to_504() {
        let rpc: RpcError = forge_core::ForgeError::Timeout("5s".into()).into();
        assert_eq!(rpc.code, "TIMEOUT");
        assert_eq!(rpc.status_code(), StatusCode::GATEWAY_TIMEOUT);
    }

    #[test]
    fn forge_job_cancelled_maps_to_409() {
        let rpc: RpcError = forge_core::ForgeError::JobCancelled("user request".into()).into();
        assert_eq!(rpc.code, "JOB_CANCELLED");
        assert_eq!(rpc.status_code(), StatusCode::CONFLICT);
    }

    #[test]
    fn forge_rate_limit_maps_to_429_with_details() {
        let rpc: RpcError = forge_core::ForgeError::RateLimitExceeded {
            retry_after: std::time::Duration::from_secs(60),
            limit: 100,
            remaining: 0,
        }
        .into();
        assert_eq!(rpc.code, "RATE_LIMITED");
        assert_eq!(rpc.status_code(), StatusCode::TOO_MANY_REQUESTS);
        assert!(rpc.details.is_some());
        assert_eq!(rpc.details.unwrap()["retry_after_secs"], 60);
    }

    #[test]
    fn forge_deserialization_hides_internal_details() {
        let rpc: RpcError =
            forge_core::ForgeError::Deserialization("missing field `id`".into()).into();
        assert_eq!(rpc.code, "INVALID_ARGUMENT");
        // Must NOT leak internal error details to clients
        assert_eq!(rpc.message, "Invalid input format");
        assert_eq!(rpc.status_code(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn forge_database_error_hides_internals() {
        let rpc: RpcError =
            forge_core::ForgeError::Database("relation foo does not exist".into()).into();
        assert_eq!(rpc.code, "INTERNAL_ERROR");
        assert_eq!(rpc.message, "Internal server error");
        assert_eq!(rpc.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn forge_internal_variants_all_map_to_500() {
        let internals: Vec<forge_core::ForgeError> = vec![
            forge_core::ForgeError::Internal("oops".into()),
            forge_core::ForgeError::Serialization("bad".into()),
            forge_core::ForgeError::Function("handler".into()),
            forge_core::ForgeError::Config("bad toml".into()),
            forge_core::ForgeError::Cluster("split".into()),
            forge_core::ForgeError::InvalidState("done".into()),
            forge_core::ForgeError::Job("failed".into()),
            forge_core::ForgeError::WorkflowSuspended,
        ];

        for err in internals {
            let rpc: RpcError = err.into();
            assert_eq!(
                rpc.status_code(),
                StatusCode::INTERNAL_SERVER_ERROR,
                "Expected 500 for code: {}",
                rpc.code
            );
            // Must never leak internal details
            assert_eq!(rpc.message, "Internal server error");
        }
    }

    // --- JSON wire format ---

    #[test]
    fn none_fields_are_omitted_from_json() {
        let ok = serde_json::to_value(RpcResponse::success(serde_json::json!(1))).unwrap();
        assert!(ok.get("error").is_none());
        assert!(ok.get("request_id").is_none());

        let failed = serde_json::to_value(RpcResponse::error(RpcError::internal("x"))).unwrap();
        assert!(failed.get("data").is_none());
        assert!(failed["error"].get("details").is_none());
    }

    #[test]
    fn rpc_response_serialization_round_trip() {
        let resp = RpcResponse::success(serde_json::json!({"users": [1, 2, 3]}))
            .with_request_id("req-abc");
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: RpcResponse = serde_json::from_str(&json).unwrap();
        assert!(deserialized.success);
        assert_eq!(deserialized.request_id, Some("req-abc".to_string()));
        assert_eq!(deserialized.data.unwrap()["users"][0], 1);
    }

    #[test]
    fn rpc_error_with_details_serialization() {
        let err = RpcError::with_details(
            "CUSTOM_ERROR",
            "something broke",
            serde_json::json!({"field": "email"}),
        );
        let json = serde_json::to_string(&err).unwrap();
        let deserialized: RpcError = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.code, "CUSTOM_ERROR");
        assert_eq!(deserialized.details.unwrap()["field"], "email");
    }
}