1#![doc = include_str!("../README.md")]
2#[macro_use]
3extern crate log;
4
5use std::borrow::Borrow;
6use std::convert::Infallible;
7use std::future::Future;
8use std::net::{IpAddr, SocketAddr};
9use std::sync::Arc;
10
11use base64::decode;
12use hmac::Hmac;
13use hyper::{Body, Request, Response, Server, StatusCode};
14use hyper::http::HeaderValue;
15use hyper::server::conn::AddrStream;
16use hyper::service::{make_service_fn, service_fn};
17use jwt::{Header, Token, VerifyWithKey};
18use serde::{Deserialize, Serialize};
19use sha2::digest::KeyInit;
20use sha2::Sha512;
21use tokio::sync::oneshot::Receiver;
22
23use crate::cookies::get_auth_cookie;
24use crate::errors::AuthProxyError;
25use crate::redis_session::RedisSessionStore;
26
27mod cookies;
28pub mod redis_session;
29pub mod errors;
30
/// A single double-quote character, stripped from both ends of a quoted token value.
const DOUBLE_QUOTES: &str = "\"";
32
/// Runtime configuration of the authenticating reverse proxy.
///
/// Serialized/deserialized with serde using the field names below.
///
/// NOTE(review): the derived `Debug` prints `jwt_key` and `credentials_key` verbatim; avoid
/// logging this struct with `{:?}`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ProxyConfig {
    /// Key used (as an HMAC-SHA512 key) to verify the session JWT read from the auth cookie.
    pub jwt_key: String,
    /// Passed as the second (key) argument to the credentials decoder; the default identity
    /// decoder ignores it.
    pub credentials_key: String,
    /// URI of the backend that authenticated requests are forwarded to.
    pub back_uri: String,
    /// URI used to create the Redis session store.
    pub redis_uri: String,
    /// Socket address the proxy server binds to.
    pub address: SocketAddr,
}
41
42impl ProxyConfig {
43 pub fn from_address(address_str: &str) -> ProxyConfig {
44 let Self { jwt_key: key, credentials_key, back_uri, redis_uri, address: _ } = ProxyConfig::default();
45 Self { jwt_key: key, credentials_key, back_uri, redis_uri, address: address_str.parse().unwrap() }
46 }
47}
48
49impl Default for ProxyConfig {
50 fn default() -> Self {
51 ProxyConfig {
52 jwt_key: "testsecretpourlestests".to_string(),
53 credentials_key: "credentials_key".to_string(),
54 back_uri: "http://127.0.0.1:5000".to_string(),
55 redis_uri: "redis://redis".to_string(),
56 address: "127.0.0.1:3000".parse().unwrap(),
57 }
58 }
59}
60
/// Claims of the session JWT, as verified by `decode_token`.
///
/// Deserialized by serde, so the field names below are the JSON claim names.
#[derive(Serialize, Deserialize)]
pub struct SessionToken {
    /// Presumably the standard JWT subject claim (e.g. a user id) — TODO confirm; it is only
    /// copied, never otherwise used, in this file.
    pub sub: String,
    /// Session id; `handle` passes it to `RedisSessionStore::get` to load the session.
    pub sid: String,
    /// Presumably the standard JWT issued-at timestamp — TODO confirm units; not checked here.
    pub iat: i64,
    /// Presumably the standard JWT expiry timestamp — TODO confirm units.
    /// NOTE(review): decoded but never compared with the current time in this file.
    pub exp: i64,
}
68
69fn decode_token(token_str_from_header: String, key: Hmac<Sha512>) -> Result<SessionToken, AuthProxyError> {
70 let token_bytes = &decode(token_str_from_header)?;
71 let token_str = String::from_utf8_lossy(token_bytes).to_string();
72 let stripped = match token_str.starts_with("\"") {
73 true => token_str.strip_prefix(DOUBLE_QUOTES).unwrap().strip_suffix(DOUBLE_QUOTES).unwrap(),
74 false => token_str.borrow()
75 };
76 let token_checked: Token<Header, SessionToken, _> = VerifyWithKey::verify_with_key(stripped, &key)?;
77 Ok(SessionToken {
78 exp: token_checked.claims().exp,
79 iat: token_checked.claims().iat,
80 sid: token_checked.claims().sid.clone(),
81 sub: token_checked.claims().sub.clone(),
82 })
83}
84
85fn set_basic_auth(req: &mut Request<Body>, credentials: &str) {
86 let mut header_value: String = "Basic ".to_owned() + credentials;
87 if header_value.ends_with("\n") {
88 header_value.pop();
89 }
90 req.headers_mut().insert("Authorization", HeaderValue::from_str(header_value.as_str()).unwrap());
91}
92
93async fn handle(client_ip: IpAddr, mut req: Request<Body>, store: Arc<RedisSessionStore>,
96 config: Arc<ProxyConfig>, decode_credentials: fn(&str, &str) -> Result<String, AuthProxyError>) -> Result<Response<Body>, Infallible> {
97 match get_auth_cookie(&req) {
98 Ok(auth_cookie) => {
99 let key: Hmac<Sha512> = Hmac::new_from_slice(config.jwt_key.as_bytes()).unwrap();
100 match decode_token(auth_cookie.value().to_string(), key) {
101 Ok(session_token) => {
102 match store.get(session_token.sid.as_str()).await {
103 Ok(Some(session)) => {
104 match decode_credentials(session.credentials.as_str(), config.credentials_key.as_str()) {
105 Ok(credentials) => {
106 set_basic_auth(&mut req, credentials.as_str());
107 match hyper_reverse_proxy::call(client_ip, config.back_uri.as_str(), req).await {
108 Ok(response) => { Ok(response) }
109 Err(_error) => {
110 Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap())
111 }
112 }
113 }
114 Err(e) => {
115 debug!("credentials decode error {} for sid={}", e, session_token.sid);
116 Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap())
117 }
118 }
119 }
120 Ok(None) => {
121 debug!("no session {} found", session_token.sid);
122 Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Body::empty()).unwrap())
123 }
124 Err(e) => {
125 debug!("err getting session from redis: {}", e);
126 Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap())
127 }
128 }
129 }
130 Err(e) => {
131 debug!("cannot decode jwt token: {}", e);
132 Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Body::empty()).unwrap())
133 }
134 }
135 }
136 Err(e) => {
137 debug!("cannot find auth cookie: {}", e);
138 Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Body::empty()).unwrap())
139 }
140 }
141}
142
143fn identity_fn_credentials(credentials: &str, _key_str: &str) -> Result<String, AuthProxyError> {
144 Ok(String::from(credentials))
145}
146
147pub async fn run_service(config: ProxyConfig, rx: Receiver<()>) -> impl Future<Output=Result<(), hyper::Error>> {
150 run_service_with_decoder(config, rx, identity_fn_credentials).await
151}
152
153pub async fn run_service_with_decoder(config: ProxyConfig, rx: Receiver<()>, decode_credentials: fn(&str, &str) -> Result<String, AuthProxyError>) -> impl Future<Output=Result<(), hyper::Error>> {
160 let cloned_config = config.clone();
161 let shared_config = Arc::new(config);
162 let shared_store = Arc::new(RedisSessionStore::new(shared_config.redis_uri.to_owned()).unwrap());
163 let make_svc = make_service_fn(move |conn: &AddrStream| {
164 let remote_addr = conn.remote_addr().ip();
165 let config_capture = shared_config.clone();
166 let store_capture = shared_store.clone();
167 async move {
168 Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req, store_capture.clone(), config_capture.clone(), decode_credentials)))
169 }
170 });
171 Server::bind(&cloned_config.address).serve(make_svc).with_graceful_shutdown(async { rx.await.ok(); })
172}
173
#[cfg(test)]
mod test {
    use std::net::SocketAddr;

    use crate::ProxyConfig;

    /// `from_address` uses the given socket address.
    #[test]
    fn build_from_uri() {
        let config = ProxyConfig::from_address("127.0.0.1:12345");
        let expected: SocketAddr = "127.0.0.1:12345".parse().unwrap();
        assert_eq!(config.address, expected);
    }

    /// `from_address` changes only the address; every other field keeps its default.
    #[test]
    fn build_from_uri_keeps_default_settings() {
        let config = ProxyConfig::from_address("127.0.0.1:12345");
        let defaults = ProxyConfig::default();
        assert_eq!(config.jwt_key, defaults.jwt_key);
        assert_eq!(config.credentials_key, defaults.credentials_key);
        assert_eq!(config.back_uri, defaults.back_uri);
        assert_eq!(config.redis_uri, defaults.redis_uri);
    }

    /// An unparsable address is rejected by panicking.
    #[test]
    #[should_panic]
    fn build_from_invalid_uri_panics() {
        let _ = ProxyConfig::from_address("not-an-address");
    }
}