1use futures_util::future::BoxFuture;
2use http::{HeaderValue, Request, Response};
3use secstr::SecStr;
4use std::{
5 sync::Arc,
6 task::{Context, Poll},
7};
8use tower_cookies::{
9 cookie::{Expiration, SameSite},
10 CookieManager, Cookies,
11};
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::{guard::GuardService, Error, Token};
16
17#[derive(Clone)]
18pub(crate) struct Config {
19 pub(crate) secret: SecStr,
20 pub(crate) cookie_name: String,
21 pub(crate) expires: Expiration,
22 pub(crate) header_name: String,
23 pub(crate) hsts: bool,
24 pub(crate) http_only: bool,
25 pub(crate) prefix: bool,
26 pub(crate) preload: bool,
27 pub(crate) same_site: SameSite,
28 pub(crate) secure: bool,
29}
30
31impl Config {
32 pub(crate) fn cookie_name(&self) -> String {
33 if self.prefix {
34 format!("__HOST-{}", self.cookie_name)
35 } else {
36 self.cookie_name.clone()
37 }
38 }
39}
40
41#[derive(Clone)]
47pub struct Surf {
48 pub(crate) config: Config,
49}
50
51impl Surf {
52 pub fn new(secret: impl Into<String>) -> Self {
54 Self {
55 config: Config {
56 secret: SecStr::from(secret.into()),
57 cookie_name: "csrf_token".into(),
58 expires: Expiration::Session,
59 header_name: "X-CSRF-Token".into(),
60 hsts: true,
61 http_only: true,
62 prefix: true,
63 preload: false,
64 same_site: SameSite::Strict,
65 secure: true,
66 },
67 }
68 }
69
70 pub fn cookie_name(mut self, cookie_name: impl Into<String>) -> Self {
73 self.config.cookie_name = cookie_name.into();
74
75 self
76 }
77
78 pub fn expires(mut self, expires: Expiration) -> Self {
80 self.config.expires = expires;
81
82 self
83 }
84
85 pub fn header_name(mut self, header_name: impl Into<String>) -> Self {
88 self.config.header_name = header_name.into();
89
90 self
91 }
92
93 pub fn hsts(mut self, hsts: bool) -> Self {
97 self.config.hsts = hsts;
98
99 self
100 }
101
102 pub fn http_only(mut self, http_only: bool) -> Self {
107 self.config.http_only = http_only;
108
109 self
110 }
111
112 pub fn prefix(mut self, prefix: bool) -> Self {
117 self.config.prefix = prefix;
118
119 self
120 }
121
122 pub fn preload(mut self, preload: bool) -> Self {
126 self.config.preload = preload;
127
128 self
129 }
130
131 pub fn same_site(mut self, same_site: SameSite) -> Self {
135 self.config.same_site = same_site;
136
137 self
138 }
139
140 pub fn secure(mut self, secure: bool) -> Self {
145 self.config.secure = secure;
146
147 self
148 }
149}
150
151impl<S> Layer<S> for Surf {
152 type Service = CookieManager<SurfService<GuardService<S>>>;
153
154 fn layer(&self, inner: S) -> Self::Service {
155 CookieManager::new(SurfService {
156 config: Arc::new(self.config.clone()),
157 inner: GuardService::new(inner),
158 })
159 }
160}
161
162#[derive(Clone)]
163pub struct SurfService<S> {
164 config: Arc<Config>,
165 inner: S,
166}
167
168impl<S, Q, R> Service<Request<Q>> for SurfService<S>
169where
170 S: Service<Request<Q>, Response = Response<R>> + Send + 'static,
171 S::Future: Send + 'static,
172 Q: Send + 'static,
173 R: Default + Send,
174{
175 type Response = S::Response;
176 type Error = S::Error;
177 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
178
179 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
180 self.inner.poll_ready(cx)
181 }
182
183 fn call(&mut self, mut request: Request<Q>) -> Self::Future {
184 let cookies = match request
185 .extensions()
186 .get::<Cookies>()
187 .ok_or(Error::ExtensionNotFound("Cookies".into()))
188 {
189 Ok(cookies) => cookies,
190 Err(err) => return Box::pin(async move { Error::make_layer_error(err) }),
191 };
192
193 let token = Token {
194 config: self.config.clone(),
195 cookies: cookies.clone(),
196 };
197
198 if cookies.get(&self.config.cookie_name()).is_none() {
199 if let Err(err) = token.create() {
200 return Box::pin(async move { Error::make_layer_error(err) });
201 };
202 }
203
204 request.extensions_mut().insert(self.config.clone());
205 request.extensions_mut().insert(token);
206
207 let config = self.config.clone();
208
209 if config.hsts {
210 let future = self.inner.call(request);
211
212 Box::pin(async move {
213 let mut response = future.await?;
214
215 let mut value = "max-age=31536000; includeSubDomains".to_owned();
216
217 if config.preload {
218 value.push_str("; preload");
219 }
220
221 let value = match HeaderValue::from_str(&value) {
222 Ok(value) => value,
223 Err(err) => return Error::make_layer_error(err),
224 };
225
226 response
227 .headers_mut()
228 .insert("Strict-Transport-Security", value);
229
230 Ok(response)
231 })
232 } else {
233 Box::pin(self.inner.call(request))
234 }
235 }
236}