hyper_auth_proxy/
lib.rs

1#![doc = include_str!("../README.md")]
2#[macro_use]
3extern crate log;
4
5use std::borrow::Borrow;
6use std::convert::Infallible;
7use std::future::Future;
8use std::net::{IpAddr, SocketAddr};
9use std::sync::Arc;
10
11use base64::decode;
12use hmac::Hmac;
13use hyper::{Body, Request, Response, Server, StatusCode};
14use hyper::http::HeaderValue;
15use hyper::server::conn::AddrStream;
16use hyper::service::{make_service_fn, service_fn};
17use jwt::{Header, Token, VerifyWithKey};
18use serde::{Deserialize, Serialize};
19use sha2::digest::KeyInit;
20use sha2::Sha512;
21use tokio::sync::oneshot::Receiver;
22
23use crate::cookies::get_auth_cookie;
24use crate::errors::AuthProxyError;
25use crate::redis_session::RedisSessionStore;
26
27mod cookies;
28pub mod redis_session;
29pub mod errors;
30
/// A single double-quote character, stripped from around a quoted token value.
const DOUBLE_QUOTES: &str = "\"";
32
33#[derive(Debug, Serialize, Deserialize, Clone)]
34pub struct ProxyConfig {
35    pub jwt_key: String,
36    pub credentials_key: String,
37    pub back_uri: String,
38    pub redis_uri: String,
39    pub address: SocketAddr,
40}
41
42impl ProxyConfig {
43    pub fn from_address(address_str: &str) -> ProxyConfig {
44        let Self { jwt_key: key, credentials_key, back_uri, redis_uri, address: _ } = ProxyConfig::default();
45        Self { jwt_key: key, credentials_key, back_uri, redis_uri, address: address_str.parse().unwrap() }
46    }
47}
48
49impl Default for ProxyConfig {
50    fn default() -> Self {
51        ProxyConfig {
52            jwt_key: "testsecretpourlestests".to_string(),
53            credentials_key: "credentials_key".to_string(),
54            back_uri: "http://127.0.0.1:5000".to_string(),
55            redis_uri: "redis://redis".to_string(),
56            address: "127.0.0.1:3000".parse().unwrap(),
57        }
58    }
59}
60
61#[derive(Serialize, Deserialize)]
62pub struct SessionToken {
63    pub sub: String,
64    pub sid: String,
65    pub iat: i64,
66    pub exp: i64,
67}
68
69fn decode_token(token_str_from_header: String, key: Hmac<Sha512>) -> Result<SessionToken, AuthProxyError> {
70    let token_bytes = &decode(token_str_from_header)?;
71    let token_str = String::from_utf8_lossy(token_bytes).to_string();
72    let stripped = match token_str.starts_with("\"") {
73        true => token_str.strip_prefix(DOUBLE_QUOTES).unwrap().strip_suffix(DOUBLE_QUOTES).unwrap(),
74        false => token_str.borrow()
75    };
76    let token_checked: Token<Header, SessionToken, _> = VerifyWithKey::verify_with_key(stripped, &key)?;
77    Ok(SessionToken {
78        exp: token_checked.claims().exp,
79        iat: token_checked.claims().iat,
80        sid: token_checked.claims().sid.clone(),
81        sub: token_checked.claims().sub.clone(),
82    })
83}
84
85fn set_basic_auth(req: &mut Request<Body>, credentials: &str) {
86    let mut header_value: String = "Basic ".to_owned() + credentials;
87    if header_value.ends_with("\n") {
88        header_value.pop();
89    }
90    req.headers_mut().insert("Authorization", HeaderValue::from_str(header_value.as_str()).unwrap());
91}
92
93/// main handler processor for each request
94/// all the workflow is defined here
95async fn handle(client_ip: IpAddr, mut req: Request<Body>, store: Arc<RedisSessionStore>,
96                config: Arc<ProxyConfig>, decode_credentials: fn(&str, &str) -> Result<String, AuthProxyError>) -> Result<Response<Body>, Infallible> {
97    match get_auth_cookie(&req) {
98        Ok(auth_cookie) => {
99            let key: Hmac<Sha512> = Hmac::new_from_slice(config.jwt_key.as_bytes()).unwrap();
100            match decode_token(auth_cookie.value().to_string(), key) {
101                Ok(session_token) => {
102                    match store.get(session_token.sid.as_str()).await {
103                        Ok(Some(session)) => {
104                            match decode_credentials(session.credentials.as_str(), config.credentials_key.as_str()) {
105                                Ok(credentials) => {
106                                    set_basic_auth(&mut req, credentials.as_str());
107                                    match hyper_reverse_proxy::call(client_ip, config.back_uri.as_str(), req).await {
108                                        Ok(response) => { Ok(response) }
109                                        Err(_error) => {
110                                            Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap())
111                                        }
112                                    }
113                                }
114                                Err(e) => {
115                                    debug!("credentials decode error {} for sid={}", e, session_token.sid);
116                                    Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap())
117                                }
118                            }
119                        }
120                        Ok(None) => {
121                            debug!("no session {} found", session_token.sid);
122                            Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Body::empty()).unwrap())
123                        }
124                        Err(e) => {
125                            debug!("err getting session from redis: {}", e);
126                            Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap())
127                        }
128                    }
129                }
130                Err(e) => {
131                    debug!("cannot decode jwt token: {}", e);
132                    Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Body::empty()).unwrap())
133                }
134            }
135        }
136        Err(e) => {
137            debug!("cannot find auth cookie: {}", e);
138            Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Body::empty()).unwrap())
139        }
140    }
141}
142
143fn identity_fn_credentials(credentials: &str, _key_str: &str) -> Result<String, AuthProxyError> {
144    Ok(String::from(credentials))
145}
146
147/// Runs the proxy without credential decoder. The string in Redis credential field is used
148///  as `Authorization` header
149pub async fn run_service(config: ProxyConfig, rx: Receiver<()>) -> impl Future<Output=Result<(), hyper::Error>> {
150    run_service_with_decoder(config, rx, identity_fn_credentials).await
151}
152
153/// Runs the proxy with a credential decoder function. It should be with the signature :
154/// ```rust,no_run
155/// use hyper_auth_proxy::errors::AuthProxyError;
156/// type F = fn(&str, &str) -> Result<String, AuthProxyError>;
157/// ```
158///
159pub async fn run_service_with_decoder(config: ProxyConfig, rx: Receiver<()>, decode_credentials: fn(&str, &str) -> Result<String, AuthProxyError>) -> impl Future<Output=Result<(), hyper::Error>> {
160    let cloned_config = config.clone();
161    let shared_config = Arc::new(config);
162    let shared_store = Arc::new(RedisSessionStore::new(shared_config.redis_uri.to_owned()).unwrap());
163    let make_svc = make_service_fn(move |conn: &AddrStream| {
164        let remote_addr = conn.remote_addr().ip();
165        let config_capture = shared_config.clone();
166        let store_capture = shared_store.clone();
167        async move {
168            Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req, store_capture.clone(), config_capture.clone(), decode_credentials)))
169        }
170    });
171    Server::bind(&cloned_config.address).serve(make_svc).with_graceful_shutdown(async { rx.await.ok(); })
172}
173
#[cfg(test)]
mod test {
    use super::ProxyConfig;
    use std::net::SocketAddr;

    /// `from_address` must use the given string as the listening address.
    #[test]
    fn build_from_uri() {
        let expected: SocketAddr = "127.0.0.1:12345".parse().unwrap();
        let config = ProxyConfig::from_address("127.0.0.1:12345");
        assert_eq!(config.address, expected);
    }
}