Skip to main content

forge_runtime/gateway/
response.rs

1use axum::Json;
2use axum::http::StatusCode;
3use axum::response::{IntoResponse, Response};
4use serde::{Deserialize, Serialize};
5
6/// RPC response for function calls.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct RpcResponse {
9    /// Whether the call succeeded.
10    pub success: bool,
11    /// Result data (if successful).
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub data: Option<serde_json::Value>,
14    /// Error information (if failed).
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub error: Option<RpcError>,
17    /// Request ID for tracing.
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub request_id: Option<String>,
20}
21
22impl RpcResponse {
23    /// Create a successful response.
24    pub fn success(data: serde_json::Value) -> Self {
25        Self {
26            success: true,
27            data: Some(data),
28            error: None,
29            request_id: None,
30        }
31    }
32
33    /// Create an error response.
34    pub fn error(error: RpcError) -> Self {
35        Self {
36            success: false,
37            data: None,
38            error: Some(error),
39            request_id: None,
40        }
41    }
42
43    /// Add request ID to the response.
44    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
45        self.request_id = Some(request_id.into());
46        self
47    }
48}
49
50impl IntoResponse for RpcResponse {
51    fn into_response(self) -> Response {
52        let status = if self.success {
53            StatusCode::OK
54        } else {
55            self.error
56                .as_ref()
57                .map(|e| e.status_code())
58                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
59        };
60
61        (status, Json(self)).into_response()
62    }
63}
64
65/// RPC error information.
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct RpcError {
68    /// Error code.
69    pub code: String,
70    /// Human-readable error message.
71    pub message: String,
72    /// Additional error details.
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub details: Option<serde_json::Value>,
75}
76
77impl RpcError {
78    /// Create a new error.
79    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
80        Self {
81            code: code.into(),
82            message: message.into(),
83            details: None,
84        }
85    }
86
87    /// Create an error with details.
88    pub fn with_details(
89        code: impl Into<String>,
90        message: impl Into<String>,
91        details: serde_json::Value,
92    ) -> Self {
93        Self {
94            code: code.into(),
95            message: message.into(),
96            details: Some(details),
97        }
98    }
99
100    /// Get HTTP status code for this error.
101    pub fn status_code(&self) -> StatusCode {
102        match self.code.as_str() {
103            "NOT_FOUND" => StatusCode::NOT_FOUND,
104            "UNAUTHORIZED" => StatusCode::UNAUTHORIZED,
105            "FORBIDDEN" => StatusCode::FORBIDDEN,
106            "VALIDATION_ERROR" => StatusCode::BAD_REQUEST,
107            "INVALID_ARGUMENT" => StatusCode::BAD_REQUEST,
108            "TIMEOUT" => StatusCode::GATEWAY_TIMEOUT,
109            "RATE_LIMITED" => StatusCode::TOO_MANY_REQUESTS,
110            "JOB_CANCELLED" => StatusCode::CONFLICT,
111            _ => StatusCode::INTERNAL_SERVER_ERROR,
112        }
113    }
114
115    /// Create a not found error.
116    pub fn not_found(message: impl Into<String>) -> Self {
117        Self::new("NOT_FOUND", message)
118    }
119
120    /// Create an unauthorized error.
121    pub fn unauthorized(message: impl Into<String>) -> Self {
122        Self::new("UNAUTHORIZED", message)
123    }
124
125    /// Create a forbidden error.
126    pub fn forbidden(message: impl Into<String>) -> Self {
127        Self::new("FORBIDDEN", message)
128    }
129
130    /// Create a validation error.
131    pub fn validation(message: impl Into<String>) -> Self {
132        Self::new("VALIDATION_ERROR", message)
133    }
134
135    /// Create an internal error.
136    pub fn internal(message: impl Into<String>) -> Self {
137        Self::new("INTERNAL_ERROR", message)
138    }
139}
140
141impl From<forge_core::error::ForgeError> for RpcError {
142    fn from(err: forge_core::error::ForgeError) -> Self {
143        match err {
144            forge_core::error::ForgeError::NotFound(msg) => Self::not_found(msg),
145            forge_core::error::ForgeError::Unauthorized(msg) => Self::unauthorized(msg),
146            forge_core::error::ForgeError::Forbidden(msg) => Self::forbidden(msg),
147            forge_core::error::ForgeError::Validation(msg) => Self::validation(msg),
148            forge_core::error::ForgeError::InvalidArgument(msg) => {
149                Self::new("INVALID_ARGUMENT", msg)
150            }
151            forge_core::error::ForgeError::Timeout(msg) => Self::new("TIMEOUT", msg),
152            forge_core::error::ForgeError::JobCancelled(msg) => Self::new("JOB_CANCELLED", msg),
153            forge_core::error::ForgeError::Database(_)
154            | forge_core::error::ForgeError::Sql(_)
155            | forge_core::error::ForgeError::Internal(_)
156            | forge_core::error::ForgeError::Serialization(_)
157            | forge_core::error::ForgeError::Deserialization(_)
158            | forge_core::error::ForgeError::Function(_)
159            | forge_core::error::ForgeError::Config(_)
160            | forge_core::error::ForgeError::Io(_)
161            | forge_core::error::ForgeError::Cluster(_)
162            | forge_core::error::ForgeError::InvalidState(_)
163            | forge_core::error::ForgeError::WorkflowSuspended => {
164                Self::internal("Internal server error")
165            }
166            forge_core::error::ForgeError::Job(msg) => Self::internal(msg),
167            forge_core::error::ForgeError::RateLimitExceeded { retry_after, .. } => {
168                Self::with_details(
169                    "RATE_LIMITED",
170                    "Rate limit exceeded",
171                    serde_json::json!({
172                        "retry_after_secs": retry_after.as_secs(),
173                    }),
174                )
175            }
176        }
177    }
178}
179
#[cfg(test)]
// In tests, unwrapping/indexing/panicking is how a failure is reported.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_success_response() {
        let resp = RpcResponse::success(serde_json::json!({"id": 1}));
        assert!(resp.success);
        assert!(resp.data.is_some());
        assert!(resp.error.is_none());
    }

    #[test]
    fn test_error_response() {
        let resp = RpcResponse::error(RpcError::not_found("User not found"));
        assert!(!resp.success);
        assert!(resp.data.is_none());
        assert!(resp.error.is_some());
        assert_eq!(resp.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn test_error_status_codes() {
        assert_eq!(RpcError::not_found("").status_code(), StatusCode::NOT_FOUND);
        assert_eq!(
            RpcError::unauthorized("").status_code(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(RpcError::forbidden("").status_code(), StatusCode::FORBIDDEN);
        assert_eq!(
            RpcError::validation("").status_code(),
            StatusCode::BAD_REQUEST
        );
        assert_eq!(
            RpcError::internal("").status_code(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[test]
    fn test_remaining_error_status_codes() {
        // Codes without a dedicated constructor, plus the unknown-code fallback.
        let cases = [
            ("INVALID_ARGUMENT", StatusCode::BAD_REQUEST),
            ("TIMEOUT", StatusCode::GATEWAY_TIMEOUT),
            ("RATE_LIMITED", StatusCode::TOO_MANY_REQUESTS),
            ("JOB_CANCELLED", StatusCode::CONFLICT),
            ("SOMETHING_UNKNOWN", StatusCode::INTERNAL_SERVER_ERROR),
        ];
        for (code, expected) in cases {
            assert_eq!(
                RpcError::new(code, "").status_code(),
                expected,
                "code {code}"
            );
        }
    }

    #[test]
    fn test_into_response_status() {
        let ok = RpcResponse::success(serde_json::json!({})).into_response();
        assert_eq!(ok.status(), StatusCode::OK);

        let not_found = RpcResponse::error(RpcError::not_found("missing")).into_response();
        assert_eq!(not_found.status(), StatusCode::NOT_FOUND);

        // A failed response with no error object still maps to a server error.
        let bare_failure = RpcResponse {
            success: false,
            data: None,
            error: None,
            request_id: None,
        };
        assert_eq!(
            bare_failure.into_response().status(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[test]
    fn test_none_fields_are_omitted_from_json() {
        let json = serde_json::to_value(RpcResponse::success(serde_json::json!(1))).unwrap();
        assert_eq!(json, serde_json::json!({ "success": true, "data": 1 }));

        let json = serde_json::to_value(RpcError::internal("boom")).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "code": "INTERNAL_ERROR", "message": "boom" })
        );
    }

    #[test]
    fn test_with_request_id() {
        let resp = RpcResponse::success(serde_json::json!(null)).with_request_id("req-123");
        assert_eq!(resp.request_id, Some("req-123".to_string()));
    }
}