// thruster/context/basic_context.rs

1use bytes::Bytes;
2use serde::Serialize;
3use serde_json::to_vec;
4use std::collections::HashMap;
5use std::str;
6
7use crate::core::context::Context;
8use crate::core::request::Request;
9use crate::core::response::Response;
10
11use crate::middleware::cookies::{Cookie, CookieOptions, HasCookies, SameSite};
12use crate::middleware::query_params::HasQueryParams;
13
14pub fn generate_context<S>(request: Request, _state: &S, _path: &str) -> BasicContext {
15    let mut ctx = BasicContext::new();
16    ctx.params = request.params().clone();
17    ctx.request = request;
18
19    ctx
20}
21
22#[derive(Default)]
23pub struct BasicContext {
24    response: Response,
25    pub cookies: Vec<Cookie>,
26    pub params: Option<HashMap<String, String>>,
27    pub query_params: Option<HashMap<String, String>>,
28    pub request: Request,
29    pub status: u32,
30    pub headers: HashMap<String, String>,
31}
32
33impl Clone for BasicContext {
34    fn clone(&self) -> Self {
35        warn!("You should not be calling this method -- it just returns a default context.");
36        BasicContext::new()
37    }
38}
39
40impl BasicContext {
41    pub fn new() -> BasicContext {
42        let mut ctx = BasicContext {
43            response: Response::new(),
44            cookies: Vec::new(),
45            params: None,
46            query_params: None,
47            request: Request::new(),
48            headers: HashMap::new(),
49            status: 200,
50        };
51
52        ctx.set("Server", "Thruster");
53
54        ctx
55    }
56
57    ///
58    /// Set the body as a string
59    ///
60    pub fn body(&mut self, body_string: &str) {
61        self.response
62            .body_bytes_from_vec(body_string.as_bytes().to_vec());
63    }
64
65    ///
66    /// Set Generic Serialize as body and sets header Content-Type to application/json
67    ///
68    pub fn json<T: Serialize>(&mut self, body: T) {
69        self.set("Content-Type", "application/json");
70        self.response.body_bytes_from_vec(to_vec(&body).unwrap());
71    }
72
73    ///
74    /// Set the response status code
75    ///
76    pub fn set_status(&mut self, code: u32) -> &mut BasicContext {
77        self.status = code;
78        self
79    }
80
81    pub fn body_string(&self) -> String {
82        str::from_utf8(&self.response.response)
83            .unwrap_or("")
84            .to_owned()
85    }
86
87    ///
88    /// Set the response `Content-Type`. A shortcode for
89    ///
90    /// ```ignore
91    /// ctx.set("Content-Type", "some-val");
92    /// ```
93    ///
94    pub fn content_type(&mut self, c_type: &str) {
95        self.set("Content-Type", c_type);
96    }
97
98    ///
99    /// Set up a redirect, will default to 302, but can be changed after
100    /// the fact.
101    ///
102    /// ```ignore
103    /// ctx.set("Location", "/some-path");
104    /// ctx.status(302);
105    /// ```
106    ///
107    pub fn redirect(&mut self, destination: &str) {
108        self.status(302);
109
110        self.set("Location", destination);
111    }
112
113    ///
114    /// Sets a cookie on the response
115    ///
116    pub fn cookie(&mut self, name: &str, value: &str, options: &CookieOptions) {
117        let cookie_value = match self.headers.get("Set-Cookie") {
118            Some(val) => format!("{}, {}", val, self.cookify_options(name, value, &options)),
119            None => self.cookify_options(name, value, &options),
120        };
121
122        self.set("Set-Cookie", &cookie_value);
123    }
124
125    fn cookify_options(&self, name: &str, value: &str, options: &CookieOptions) -> String {
126        let mut pieces = vec![format!("Path={}", options.path)];
127
128        if options.expires > 0 {
129            pieces.push(format!("Expires={}", options.expires));
130        }
131
132        if options.max_age > 0 {
133            pieces.push(format!("Max-Age={}", options.max_age));
134        }
135
136        if !options.domain.is_empty() {
137            pieces.push(format!("Domain={}", options.domain));
138        }
139
140        if options.secure {
141            pieces.push("Secure".to_owned());
142        }
143
144        if options.http_only {
145            pieces.push("HttpOnly".to_owned());
146        }
147
148        if let Some(ref same_site) = options.same_site {
149            match same_site {
150                SameSite::Strict => pieces.push("SameSite=Strict".to_owned()),
151                SameSite::Lax => pieces.push("SameSite=Lax".to_owned()),
152            };
153        }
154
155        format!("{}={}; {}", name, value, pieces.join("; "))
156    }
157}
158
159impl Context for BasicContext {
160    type Response = Response;
161
162    fn get_response(mut self) -> Self::Response {
163        self.response.status_code(self.status, "");
164
165        for (key, value) in self.headers {
166            self.response.header(&key, &value);
167        }
168
169        self.response
170    }
171
172    fn set_body(&mut self, body: Vec<u8>) {
173        self.response.body_bytes_from_vec(body);
174    }
175
176    fn set_body_bytes(&mut self, body_bytes: Bytes) {
177        self.response.body_bytes(&body_bytes);
178    }
179
180    fn route(&self) -> &str {
181        self.request.path()
182    }
183
184    fn set(&mut self, key: &str, value: &str) {
185        self.headers.insert(key.to_owned(), value.to_owned());
186    }
187
188    fn remove(&mut self, key: &str) {
189        self.headers.remove(key);
190    }
191
192    fn status(&mut self, code: u16) {
193        self.set_status(code as u32);
194    }
195}
196
197impl HasQueryParams for BasicContext {
198    fn set_query_params(&mut self, query_params: HashMap<String, String>) {
199        self.query_params = Some(query_params);
200    }
201}
202
203impl HasCookies for BasicContext {
204    fn set_cookies(&mut self, cookies: Vec<Cookie>) {
205        self.cookies = cookies;
206    }
207
208    fn get_cookies(&self) -> Vec<String> {
209        self.request
210            .headers()
211            .get("cookie")
212            .cloned()
213            .unwrap_or_else(std::vec::Vec::new)
214    }
215
216    fn get_header(&self, key: &str) -> Vec<String> {
217        match self.request.headers().get(key) {
218            Some(v) => v.clone(),
219            None => vec![],
220        }
221    }
222}