gateway_runtime/
forward.rs1use crate::codec::Codec;
23use crate::errors::GatewayError;
24use crate::{BoxBody, GatewayRequest};
25use futures::{Stream, StreamExt};
26use http::{Response, StatusCode};
27use http_body::Frame;
28use http_body_util::{BodyExt, Full, StreamBody};
29use prost::Message;
30
31pub fn forward_response_message<T: Message + serde::Serialize, C: Codec>(
44 codec: &C,
45 msg: &T,
46 req: &GatewayRequest,
47) -> Result<Response<BoxBody>, GatewayError> {
48 let accept = req
49 .headers()
50 .get(http::header::ACCEPT)
51 .and_then(|h| h.to_str().ok());
52 let content_type = codec.encoder_content_type(accept);
53 let body_bytes = codec.encode(msg, Some(&content_type))?;
54 let body = BodyExt::boxed_unsync(Full::new(body_bytes).map_err(|never| match never {}));
55
56 Response::builder()
57 .status(StatusCode::OK)
58 .header("Content-Type", content_type)
59 .body(body)
60 .map_err(GatewayError::Http)
61}
62
63pub async fn forward_response_stream<S, T, C>(
79 codec: &C,
80 stream: S,
81 req: &GatewayRequest,
82) -> Result<Response<BoxBody>, GatewayError>
83where
84 S: Stream<Item = Result<T, tonic::Status>> + Send + 'static,
85 T: Message + serde::Serialize,
86 C: Codec + Clone + Send + Sync + 'static,
87{
88 let accept = req
89 .headers()
90 .get(http::header::ACCEPT)
91 .and_then(|h| h.to_str().ok());
92 let content_type = codec.encoder_content_type(accept);
93 let codec = codec.clone();
94 let content_type_clone = content_type.clone();
95
96 let stream = stream.map(move |result| match result {
97 Ok(msg) => match codec.encode(&msg, Some(&content_type_clone)) {
98 Ok(bytes) => Ok(Frame::data(bytes)),
99 Err(e) => Err(e),
100 },
101 Err(e) => Err(GatewayError::Upstream(e)),
102 });
103
104 let body = BodyExt::boxed_unsync(StreamBody::new(stream));
105
106 Response::builder()
107 .status(StatusCode::OK)
108 .header("Content-Type", content_type)
109 .body(body)
110 .map_err(GatewayError::Http)
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Test codec: every message encodes to `b"ok"`.
    ///
    /// The negotiated content type is the `Accept` value passed in, or `text/plain` when there
    /// is none. Echoing the value lets tests observe whether the request's `Accept` header
    /// actually reached the codec.
    #[derive(Clone)]
    struct MockCodec;

    impl Codec for MockCodec {
        fn encode<T: Message + serde::Serialize>(
            &self,
            _item: &T,
            _content_type: Option<&str>,
        ) -> Result<crate::bytes::Bytes, GatewayError> {
            Ok(crate::bytes::Bytes::from_static(b"ok"))
        }

        fn decode<T: Message + Default + serde::de::DeserializeOwned>(
            &self,
            _buf: &[u8],
            _content_type: Option<&str>,
        ) -> Result<T, GatewayError> {
            unimplemented!()
        }

        fn encoder_content_type(&self, accept: Option<&str>) -> String {
            accept.unwrap_or("text/plain").to_string()
        }
    }

    #[derive(serde::Serialize, prost::Message)]
    struct Dummy {
        #[prost(string, tag = "1")]
        foo: String,
    }

    #[tokio::test]
    async fn test_forward_response_message() {
        let codec = MockCodec;
        let msg = Dummy::default();
        let req = http::Request::builder().body(Vec::new()).unwrap();

        let resp = forward_response_message(&codec, &msg, &req).unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("content-type").unwrap(), "text/plain");

        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ok");
    }

    #[tokio::test]
    async fn test_forward_response_stream() {
        let codec = MockCodec;
        let stream = futures::stream::iter(vec![Ok(Dummy::default()), Ok(Dummy::default())]);
        let req = http::Request::builder().body(Vec::new()).unwrap();

        let resp = forward_response_stream(&codec, stream, &req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("content-type").unwrap(), "text/plain");

        // Each message becomes one data frame; collecting concatenates them in order.
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"okok");
    }

    #[test]
    fn test_forward_response_message_accept() {
        let codec = MockCodec;
        let req = http::Request::builder()
            .header("accept", "application/json")
            .body(Vec::new())
            .unwrap();

        let resp = forward_response_message(&codec, &Dummy::default(), &req).unwrap();
        // MockCodec echoes the Accept value, so this proves the header reached the codec.
        assert_eq!(resp.headers().get("content-type").unwrap(), "application/json");
    }

    #[tokio::test]
    async fn test_forward_response_stream_error() {
        let codec = MockCodec;
        let stream = futures::stream::iter(vec![Err::<Dummy, tonic::Status>(
            tonic::Status::internal("fail"),
        )]);
        let req = http::Request::builder().body(Vec::new()).unwrap();

        // Headers are sent before the stream is polled, so the status stays 200 even though
        // the upstream stream fails.
        let resp = forward_response_stream(&codec, stream, &req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let collected = resp.into_body().collect().await;
        assert!(matches!(
            collected,
            Err(GatewayError::Upstream(ref status)) if status.message() == "fail"
        ));
    }
}