Skip to main content

apimock_server/
response_handler.rs

1use http_body_util::{BodyExt, Empty, Full};
2use hyper::{
3    body::{Body, Bytes},
4    header::{
5        HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_ORIGIN,
6        CONTENT_LENGTH, ORIGIN, VARY,
7    },
8    http::response::Builder,
9    HeaderMap, StatusCode,
10};
11
12use std::{collections::HashMap, str::FromStr};
13
14use super::{
15    constant::DEFAULT_RESPONSE_HEADERS, response::error_response::internal_server_error_response,
16};
17use crate::types::BoxBody;
18
/// Payload carried by a response built with `ResponseHandler`.
///
/// Defaults to `Empty`, i.e. a response without a body.
#[derive(Clone, Debug, Default)]
pub enum BodyKind {
    /// no body
    #[default]
    Empty,
    /// textual body (sent as its UTF-8 bytes)
    Text(String),
    /// raw binary body
    Binary(Vec<u8>),
}
31
32#[derive(Default)]
33pub struct ResponseHandler {
34    response_builder: Builder,
35    status: Option<StatusCode>,
36    headers: HashMap<String, Option<String>>,
37    body_kind: BodyKind,
38}
39
40impl ResponseHandler {
41    /// build response
42    pub fn into_response(
43        self,
44        request_headers: &HeaderMap,
45    ) -> Result<hyper::Response<BoxBody>, hyper::http::Error> {
46        // - body + content-length
47        let response = match self.body_kind {
48            BodyKind::Text(s) => self
49                .response_builder
50                .body(Full::new(Bytes::from(s.to_owned())).boxed()),
51            BodyKind::Binary(b) => self
52                .response_builder
53                .body(Full::new(Bytes::from(b)).boxed()),
54            BodyKind::Empty => self.response_builder.body(Empty::new().boxed()),
55        };
56
57        let mut response = match response {
58            Ok(x) => x,
59            Err(err) => {
60                return internal_server_error_response(
61                    &format!("failed to create response: {}", err),
62                    request_headers,
63                )
64            }
65        };
66
67        // - http status code
68        *response.status_mut() = if let Some(status) = self.status {
69            status
70        } else {
71            StatusCode::OK
72        };
73
74        // - content-length
75        let content_length = response.body().size_hint().exact().unwrap_or_default();
76
77        let headers = response.headers_mut();
78
79        headers.insert(CONTENT_LENGTH, HeaderValue::from(content_length));
80
81        // - the other default headers
82        for (header_key, header_value) in default_response_headers(request_headers).iter() {
83            headers.insert(header_key, header_value.to_owned());
84        }
85
86        // - additional custom headers passed from caller
87        for (header_key, header_value) in self.headers {
88            let _ = match HeaderName::from_str(header_key.as_str()) {
89                Ok(header_key) => {
90                    match HeaderValue::from_str(header_value.unwrap_or_default().as_str()) {
91                        Ok(header_value) => {
92                            headers.insert(header_key, header_value);
93                        }
94                        Err(err) => {
95                            log::warn!(
96                                "failed to create header with the header value (header key = {}) ({})",
97                                header_key,
98                                err
99                            );
100                            headers.insert(header_key, HeaderValue::from_static(""));
101                        }
102                    }
103                }
104                Err(err) => log::warn!(
105                    "failed to create header with the header key: {} ({})",
106                    header_key,
107                    err
108                ),
109            };
110        }
111
112        Ok(response)
113    }
114
115    /// set http status code
116    pub fn with_status(mut self, status: &StatusCode) -> Self {
117        self.status = Some(status.to_owned());
118        self
119    }
120
121    /// add custom header
122    pub fn with_header(mut self, key: impl Into<String>, value: Option<impl Into<String>>) -> Self {
123        self.headers.insert(key.into(), value.map(|x| x.into()));
124        self
125    }
126
127    /// add custom headers
128    pub fn with_headers<K, V, I>(mut self, headers: I) -> Self
129    where
130        K: Into<String>,
131        V: Into<String>,
132        I: IntoIterator<Item = (K, Option<V>)>,
133    {
134        for (key, value) in headers {
135            self.headers.insert(key.into(), value.map(|x| x.into()));
136        }
137        self
138    }
139
140    /// add text to body
141    pub fn with_text(mut self, text: impl Into<String>, content_type: Option<&str>) -> Self {
142        let content_type = if let Some(content_type) = content_type {
143            content_type.into()
144        } else {
145            "text/plain; charset=utf-8".to_owned()
146        };
147        self.headers
148            .insert("content-type".into(), Some(content_type));
149
150        self.body_kind = BodyKind::Text(text.into());
151        self
152    }
153
154    /// treat response as json
155    pub fn with_json_body(mut self, body: impl Into<String>) -> Self {
156        self.headers
157            .insert("content-type".into(), Some("application/json".into()));
158        self.body_kind = BodyKind::Text(body.into());
159        self
160    }
161
162    /// treat response as json
163    pub fn with_binary_body(
164        mut self,
165        body: Vec<u8>,
166        content_type: Option<impl Into<String>>,
167    ) -> Self {
168        let content_type = if let Some(content_type) = content_type {
169            content_type.into()
170        } else {
171            "application/octet-stream".to_owned()
172        };
173        self.headers
174            .insert("content-type".into(), Some(content_type));
175
176        self.body_kind = BodyKind::Binary(body);
177
178        self
179    }
180}
181
182/// default response headers key-value pairs
183pub fn default_response_headers(request_headers: &HeaderMap) -> HeaderMap {
184    let mut header_map_src = Vec::with_capacity(DEFAULT_RESPONSE_HEADERS.len() + 1);
185
186    // resource
187    // - the other default headers but access-control-allow-origin, vary
188    header_map_src.extend(
189        DEFAULT_RESPONSE_HEADERS
190            .iter()
191            .map(|(k, v)| (k.to_string(), v.to_string())),
192    );
193
194    // - access-control-allow-origin, vary
195    let origin = if is_likely_authenticated_request(request_headers) {
196        match request_headers.get(ORIGIN) {
197            Some(x) => Some(x.to_owned()),
198            None => None,
199        }
200    } else {
201        None
202    };
203    let (origin, vary) = if let Some(origin) = origin {
204        header_map_src.push((
205            ACCESS_CONTROL_ALLOW_CREDENTIALS.to_string(),
206            "true".to_owned(),
207        ));
208
209        (origin, HeaderValue::from_static("Origin"))
210    } else {
211        (HeaderValue::from_static("*"), HeaderValue::from_static("*"))
212    };
213    header_map_src.push((
214        ACCESS_CONTROL_ALLOW_ORIGIN.to_string(),
215        origin.to_str().unwrap_or_default().to_owned(),
216    ));
217    header_map_src.push((
218        VARY.to_string(),
219        vary.to_str().unwrap_or_default().to_owned(),
220    ));
221
222    // header map
223    let ret = header_map_src.iter().fold(HeaderMap::new(),|mut ret,(header_key, header_value)| {
224        match HeaderName::from_str(header_key) {
225            Ok(header_key) => {
226                match HeaderValue::from_str(
227                    header_value.as_str(),
228                ) {
229                    Ok(header_value) => {
230                        ret.insert(header_key, header_value);
231                        ret
232                    },
233                    Err(err) => {
234                        log::warn!(
235                            "only header key set because failed to get header value: {} [key = {}] ({})",
236                            header_value.as_str(),
237                            header_key,
238                            err
239                        );
240                        ret.insert(header_key, HeaderValue::from_static(""));
241                        ret
242                    }
243                }
244            }
245            Err(err) => {
246                log::warn!("failed to set header key: {} ({})", header_key, err);
247                ret
248            }
249    }});
250
251    ret
252}
253
254/// guess if the request is likely related to authentication
255fn is_likely_authenticated_request(request_headers: &HeaderMap) -> bool {
256    request_headers.contains_key("cookie") || request_headers.contains_key("authorization")
257}