Skip to main content

gateway_runtime/
forward.rs

1//! # Forward
2//!
3//! ## Purpose
4//! Provides utilities for constructing HTTP responses from gRPC messages. This module
5//! bridges the gap between the gRPC client's output (Protobuf messages or streams) and
6//! the HTTP server's response format (bytes).
7//!
8//! ## Scope
9//! This module exposes functions to:
10//! -   `forward_response_message`: Convert a single gRPC message into a unary HTTP response.
//! -   `forward_response_stream`: Convert a stream of gRPC messages into a streaming HTTP response whose body yields one encoded message per data frame.
12//!
13//! ## Position in the Architecture
14//! These functions are called by the code generated by `gateway-codegen` at the end of
15//! a request handler. They take the result from the gRPC client, encode it using the
16//! specified `Codec`, and wrap it in an `http::Response`.
17//!
18//! ## Design Constraints
19//! -   **Codec Agnostic**: Functions are generic over `Codec`, allowing them to work with any supported wire format.
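//! -   **Streaming Responses**: Stream items are encoded lazily as the response body is polled; nothing is buffered. Because the `200 OK` status and headers are fixed before any item is seen, mid-stream errors surface as body errors rather than as HTTP status codes.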
21
22use crate::codec::Codec;
23use crate::errors::GatewayError;
24use crate::{BoxBody, GatewayRequest};
25use futures::{Stream, StreamExt};
26use http::{Response, StatusCode};
27use http_body::Frame;
28use http_body_util::{BodyExt, Full, StreamBody};
29use prost::Message;
30
31/// Constructs an HTTP response from a single gRPC message.
32///
33/// This function encodes the message using the provided codec and builds a standard
34/// HTTP 200 OK response with the appropriate `Content-Type`.
35///
36/// # Parameters
37/// *   `codec`: The codec used to encode the message.
38/// *   `msg`: The gRPC message to send.
39/// *   `req`: The incoming request (used to determine `Accept` header).
40///
41/// # Returns
42/// A `Result` containing the HTTP response with the encoded body, or a `GatewayError` if encoding fails.
43pub fn forward_response_message<T: Message + serde::Serialize, C: Codec>(
44    codec: &C,
45    msg: &T,
46    req: &GatewayRequest,
47) -> Result<Response<BoxBody>, GatewayError> {
48    let accept = req
49        .headers()
50        .get(http::header::ACCEPT)
51        .and_then(|h| h.to_str().ok());
52    let content_type = codec.encoder_content_type(accept);
53    let body_bytes = codec.encode(msg, Some(&content_type))?;
54    let body = BodyExt::boxed_unsync(Full::new(body_bytes).map_err(|never| match never {}));
55
56    Response::builder()
57        .status(StatusCode::OK)
58        .header("Content-Type", content_type)
59        .body(body)
60        .map_err(GatewayError::Http)
61}
62
63/// Constructs an HTTP response from a gRPC stream.
64///
65/// This function collects all items from the stream, encodes them, and aggregates them
66/// into a single byte vector.
67///
68/// **Note**: This implementation buffers the entire stream in memory. This is a simplification
69/// to ensure compatibility with `tower::Service` return types that expect a concrete body type.
70///
71/// # Parameters
72/// *   `codec`: The codec used to encode stream items.
73/// *   `stream`: The incoming gRPC stream.
74/// *   `req`: The incoming request (used to determine `Accept` header).
75///
76/// # Returns
77/// A `Result` containing the HTTP response with the aggregated body, or a `GatewayError` if processing fails.
78pub async fn forward_response_stream<S, T, C>(
79    codec: &C,
80    stream: S,
81    req: &GatewayRequest,
82) -> Result<Response<BoxBody>, GatewayError>
83where
84    S: Stream<Item = Result<T, tonic::Status>> + Send + 'static,
85    T: Message + serde::Serialize,
86    C: Codec + Clone + Send + Sync + 'static,
87{
88    let accept = req
89        .headers()
90        .get(http::header::ACCEPT)
91        .and_then(|h| h.to_str().ok());
92    let content_type = codec.encoder_content_type(accept);
93    let codec = codec.clone();
94    let content_type_clone = content_type.clone();
95
96    let stream = stream.map(move |result| match result {
97        Ok(msg) => match codec.encode(&msg, Some(&content_type_clone)) {
98            Ok(bytes) => Ok(Frame::data(bytes)),
99            Err(e) => Err(e),
100        },
101        Err(e) => Err(GatewayError::Upstream(e)),
102    });
103
104    let body = BodyExt::boxed_unsync(StreamBody::new(stream));
105
106    Response::builder()
107        .status(StatusCode::OK)
108        .header("Content-Type", content_type)
109        .body(body)
110        .map_err(GatewayError::Http)
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Codec stub: encodes every message as the two bytes `ok` and always reports
    /// `text/plain`, ignoring the `Accept` value it is given.
    #[derive(Clone)]
    struct MockCodec;
    impl Codec for MockCodec {
        fn encode<T: Message + serde::Serialize>(
            &self,
            _item: &T,
            _buf: Option<&str>,
        ) -> Result<crate::bytes::Bytes, GatewayError> {
            Ok(crate::bytes::Bytes::from_static(b"ok"))
        }
        fn decode<T: Message + Default + serde::de::DeserializeOwned>(
            &self,
            _buf: &[u8],
            _content_type: Option<&str>,
        ) -> Result<T, GatewayError> {
            unimplemented!()
        }
        fn encoder_content_type(&self, _accept: Option<&str>) -> String {
            "text/plain".to_string()
        }
    }

    /// Codec stub that reports the `Accept` value it receives (or `none` when absent) as the
    /// content type, so tests can observe what the forwarding functions pass through.
    #[derive(Clone)]
    struct EchoAcceptCodec;
    impl Codec for EchoAcceptCodec {
        fn encode<T: Message + serde::Serialize>(
            &self,
            _item: &T,
            _buf: Option<&str>,
        ) -> Result<crate::bytes::Bytes, GatewayError> {
            Ok(crate::bytes::Bytes::from_static(b"ok"))
        }
        fn decode<T: Message + Default + serde::de::DeserializeOwned>(
            &self,
            _buf: &[u8],
            _content_type: Option<&str>,
        ) -> Result<T, GatewayError> {
            unimplemented!()
        }
        fn encoder_content_type(&self, accept: Option<&str>) -> String {
            accept.unwrap_or("none").to_string()
        }
    }

    #[derive(serde::Serialize, prost::Message)]
    struct Dummy {
        #[prost(string, tag = "1")]
        foo: String,
    }

    /// Builds an empty-bodied request, optionally carrying an `Accept` header.
    fn request(accept: Option<&str>) -> GatewayRequest {
        let mut builder = http::Request::builder();
        if let Some(value) = accept {
            builder = builder.header(http::header::ACCEPT, value);
        }
        builder.body(Vec::new()).unwrap()
    }

    #[tokio::test]
    async fn test_forward_response_message() {
        let resp =
            forward_response_message(&MockCodec, &Dummy::default(), &request(None)).unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("content-type").unwrap(), "text/plain");

        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ok");
    }

    #[tokio::test]
    async fn test_forward_response_stream() {
        let stream = futures::stream::iter(vec![Ok(Dummy::default()), Ok(Dummy::default())]);

        let resp = forward_response_stream(&MockCodec, stream, &request(None))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("content-type").unwrap(), "text/plain");

        // Each message becomes its own frame; frames are concatenated without a delimiter.
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"okok");
    }

    #[test]
    fn test_forward_response_message_accept() {
        let req = request(Some("application/json"));
        let resp = forward_response_message(&EchoAcceptCodec, &Dummy::default(), &req).unwrap();
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/json"
        );
    }

    #[test]
    fn test_forward_response_message_without_accept() {
        let resp =
            forward_response_message(&EchoAcceptCodec, &Dummy::default(), &request(None))
                .unwrap();
        assert_eq!(resp.headers().get("content-type").unwrap(), "none");
    }

    #[tokio::test]
    async fn test_forward_response_stream_accept() {
        let stream = futures::stream::iter(vec![Ok(Dummy::default())]);
        let req = request(Some("application/json"));

        let resp = forward_response_stream(&EchoAcceptCodec, stream, &req)
            .await
            .unwrap();
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/json"
        );
    }

    #[tokio::test]
    async fn test_forward_response_stream_error() {
        let stream = futures::stream::iter(vec![Err::<Dummy, tonic::Status>(
            tonic::Status::internal("fail"),
        )]);

        // The response is built before the stream is polled, so an upstream error cannot
        // change the status code; it only surfaces when the body is read.
        let resp = forward_response_stream(&MockCodec, stream, &request(None))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let collected = resp.into_body().collect().await;
        assert!(collected.is_err());
    }
}