use std::time::Duration;
use crate::biome::OAuthUserSessionStore;
use crate::error::InternalError;
use crate::oauth::OAuthClient;
use crate::rest_api::auth::{AuthorizationHeader, BearerToken};
use super::{Identity, IdentityProvider};
/// Interval used by `OAuthUserIdentityProvider::new` when no reauthentication interval is
/// supplied: one hour.
const DEFAULT_REAUTHENTICATION_INTERVAL: Duration = Duration::from_secs(60 * 60);
/// An [`IdentityProvider`] that resolves OAuth2 bearer tokens to user identities by looking
/// them up in an [`OAuthUserSessionStore`].
///
/// Sessions whose last authentication is older than the configured interval are re-checked
/// against the OAuth provider before an identity is returned.
#[derive(Clone)]
pub struct OAuthUserIdentityProvider {
    /// Client used to re-check OAuth access tokens and to exchange refresh tokens.
    oauth_client: OAuthClient,
    /// Store of OAuth user sessions, keyed by the bearer token presented by the caller.
    oauth_user_session_store: Box<dyn OAuthUserSessionStore>,
    /// How long a session may go since its last authentication before it is re-checked.
    reauthentication_interval: Duration,
}
impl OAuthUserIdentityProvider {
    /// Creates a new OAuth user identity provider.
    ///
    /// # Arguments
    ///
    /// * `oauth_client` - client used to re-check and refresh OAuth access tokens
    /// * `oauth_user_session_store` - store the sessions are looked up in
    /// * `reauthentication_interval` - how long a session may go before its OAuth access token
    ///   is re-checked; `None` selects `DEFAULT_REAUTHENTICATION_INTERVAL`
    pub fn new(
        oauth_client: OAuthClient,
        oauth_user_session_store: Box<dyn OAuthUserSessionStore>,
        reauthentication_interval: Option<Duration>,
    ) -> Self {
        let interval = match reauthentication_interval {
            Some(interval) => interval,
            None => DEFAULT_REAUTHENTICATION_INTERVAL,
        };

        Self {
            reauthentication_interval: interval,
            oauth_user_session_store,
            oauth_client,
        }
    }
}
impl IdentityProvider for OAuthUserIdentityProvider {
    /// Resolves an OAuth2 bearer token to the identity of the user that owns its session.
    ///
    /// Returns `Ok(None)` when the header is not an OAuth2 bearer token, when no session exists
    /// for the token, or when the session can no longer be re-validated (the session is then
    /// removed from the store).
    ///
    /// When the session's last authentication is at least `reauthentication_interval` old, or
    /// its age cannot be measured because the clock moved backwards, the session's OAuth
    /// access token is re-checked first:
    ///
    /// * accepted: the session is updated in the store and the identity is returned
    /// * rejected, with a refresh token: a new access token is obtained, stored in the session
    ///   and checked; if the exchange fails, the session is removed
    /// * rejected, without a refresh token: the session is removed
    ///
    /// # Errors
    ///
    /// Returns an `InternalError` in any of these cases:
    ///
    /// * the session store fails
    /// * the OAuth provider errors while checking the existing access token (the session is
    ///   removed first)
    /// * a freshly refreshed access token cannot be validated
    fn get_identity(
        &self,
        authorization: &AuthorizationHeader,
    ) -> Result<Option<Identity>, InternalError> {
        let token = match authorization {
            AuthorizationHeader::Bearer(BearerToken::OAuth2(token)) => token,
            _ => return Ok(None),
        };

        let session = match self
            .oauth_user_session_store
            .get_session(token)
            .map_err(|err| InternalError::from_source(err.into()))?
        {
            Some(session) => session,
            None => return Ok(None),
        };

        let user_id = session.user().user_id().to_string();

        // `SystemTime` is not monotonic. If the system clock was moved backwards, the stored
        // last-authenticated time appears to be in the future and `elapsed()` fails. Rather than
        // failing every request until the clock catches up, treat the session as due for
        // reauthentication; a successful re-check stores a fresh timestamp.
        let reauthentication_due = session
            .last_authenticated()
            .elapsed()
            .map(|elapsed| elapsed >= self.reauthentication_interval)
            .unwrap_or(true);
        if !reauthentication_due {
            return Ok(Some(Identity::User(user_id)));
        }

        // Deletes this session; used whenever it can no longer be re-validated.
        let remove_session = || {
            self.oauth_user_session_store
                .remove_session(token)
                .map_err(|err| InternalError::from_source(err.into()))
        };

        match self.oauth_client.get_subject(session.oauth_access_token()) {
            Ok(Some(_)) => {
                // The access token is still accepted; record the successful reauthentication.
                self.oauth_user_session_store
                    .update_session(session.into_update_builder().build())
                    .map_err(|err| InternalError::from_source(err.into()))?;
                return Ok(Some(Identity::User(user_id)));
            }
            // The access token was rejected; fall through and try the refresh token.
            Ok(None) => {}
            Err(err) => {
                // The OAuth provider returned an error; the session is no longer trusted.
                remove_session()?;
                return Err(err);
            }
        }

        let refresh_token = match session.oauth_refresh_token() {
            Some(refresh_token) => refresh_token.to_string(),
            None => {
                // No way to obtain a new access token, so the session is dead.
                remove_session()?;
                return Ok(None);
            }
        };

        let access_token = match self.oauth_client.exchange_refresh_token(refresh_token) {
            Ok(access_token) => access_token,
            Err(err) => {
                debug!("Failed to exchange refresh token: {}", err);
                remove_session()?;
                return Ok(None);
            }
        };

        let updated_session = session
            .into_update_builder()
            .with_oauth_access_token(access_token.clone())
            .build();
        self.oauth_user_session_store
            .update_session(updated_session)
            .map_err(|err| InternalError::from_source(err.into()))?;

        // A token the provider just issued should validate. If it does not (`Ok(None)` or
        // `Err(_)`), something is wrong that cannot be handled here.
        match self.oauth_client.get_subject(&access_token)? {
            Some(_) => Ok(Some(Identity::User(user_id))),
            None => Err(InternalError::with_message(
                "failed to authenticate user with new access token".into(),
            )),
        }
    }

    fn clone_box(&self) -> Box<dyn IdentityProvider> {
        Box::new(self.clone())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::mpsc::channel;
    use std::thread::JoinHandle;

    use actix::System;
    use actix_web::{dev::Server, web, App, HttpResponse, HttpServer};
    use futures::Future;

    use crate::biome::oauth::store::InsertableOAuthUserSessionBuilder;
    use crate::biome::MemoryOAuthUserSessionStore;
    use crate::oauth::{
        store::MemoryInflightOAuthRequestStore, OAuthClientBuilder, SubjectProvider,
    };
    use crate::oauth::{Profile, ProfileProvider};

    /// Path of the token endpoint served by the mock OAuth server.
    const TOKEN_ENDPOINT: &str = "/token";
    /// The only refresh token the mock OAuth server accepts.
    const REFRESH_TOKEN: &str = "refresh_token";
    /// Access token the mock OAuth server issues in exchange for `REFRESH_TOKEN`.
    const NEW_OAUTH_ACCESS_TOKEN: &str = "new_oauth_access_token";

    /// A session inside the (default) reauthentication interval yields its identity without
    /// consulting the OAuth provider; the client used here would error if it were called.
    #[test]
    fn get_identity_cached() {
        let session_store = Box::new(MemoryOAuthUserSessionStore::new());
        let splinter_access_token = "splinter_access_token";
        let session = InsertableOAuthUserSessionBuilder::new()
            .with_splinter_access_token(splinter_access_token.into())
            .with_subject("subject".into())
            .with_oauth_access_token("oauth_access_token".into())
            .build()
            .expect("Failed to build session");
        session_store
            .add_session(session)
            .expect("Failed to add session");
        let user_id = session_store
            .get_session(splinter_access_token)
            .expect("Failed to get inserted session")
            .expect("Inserted session not found")
            .user()
            .user_id()
            .to_string();

        let identity_provider =
            OAuthUserIdentityProvider::new(always_err_client(), session_store, None);
        let authorization_header =
            AuthorizationHeader::Bearer(BearerToken::OAuth2(splinter_access_token.into()));

        let identity = identity_provider
            .get_identity(&authorization_header)
            .expect("Failed to get identity")
            .expect("Identity not found");
        assert_eq!(identity, Identity::User(user_id));
    }

    /// An unknown token resolves to no identity.
    #[test]
    fn get_identity_no_session() {
        let identity_provider = OAuthUserIdentityProvider::new(
            always_some_client(),
            Box::new(MemoryOAuthUserSessionStore::new()),
            None,
        );
        let authorization_header =
            AuthorizationHeader::Bearer(BearerToken::OAuth2("splinter_access_token".into()));

        assert!(identity_provider
            .get_identity(&authorization_header)
            .expect("Failed to get identity")
            .is_none());
    }

    /// When reauthentication succeeds, the identity is returned and the session's
    /// last-authenticated time moves forward.
    #[test]
    fn get_identity_reauthentication_successful() {
        let session_store = Box::new(MemoryOAuthUserSessionStore::new());
        let splinter_access_token = "splinter_access_token";
        let session = InsertableOAuthUserSessionBuilder::new()
            .with_splinter_access_token(splinter_access_token.into())
            .with_subject("subject".into())
            .with_oauth_access_token("oauth_access_token".into())
            .build()
            .expect("Failed to build session");
        session_store
            .add_session(session)
            .expect("Failed to add session");
        let original_session = session_store
            .get_session(splinter_access_token)
            .expect("Failed to get inserted session")
            .expect("Inserted session not found");

        let identity_provider = OAuthUserIdentityProvider::new(
            always_some_client(),
            session_store.clone(),
            Some(Duration::from_secs(0)),
        );
        let authorization_header =
            AuthorizationHeader::Bearer(BearerToken::OAuth2(splinter_access_token.into()));

        let identity = identity_provider
            .get_identity(&authorization_header)
            .expect("Failed to get identity")
            .expect("Identity not found");
        assert_eq!(
            identity,
            Identity::User(original_session.user().user_id().into())
        );

        let new_session = session_store
            .get_session(splinter_access_token)
            .expect("Failed to get updated session")
            .expect("Updated session not found");
        assert!(new_session.last_authenticated() > original_session.last_authenticated());
    }

    /// When the access token is rejected and there is no refresh token, the session is
    /// removed and no identity is returned.
    #[test]
    fn get_identity_reauthentication_unauthorized() {
        let session_store = Box::new(MemoryOAuthUserSessionStore::new());
        let splinter_access_token = "splinter_access_token";
        let session = InsertableOAuthUserSessionBuilder::new()
            .with_splinter_access_token(splinter_access_token.into())
            .with_subject("subject".into())
            .with_oauth_access_token("oauth_access_token".into())
            .build()
            .expect("Failed to build session");
        session_store
            .add_session(session)
            .expect("Failed to add session");

        let identity_provider = OAuthUserIdentityProvider::new(
            always_none_client(),
            session_store.clone(),
            Some(Duration::from_secs(0)),
        );
        let authorization_header =
            AuthorizationHeader::Bearer(BearerToken::OAuth2(splinter_access_token.into()));

        assert!(identity_provider
            .get_identity(&authorization_header)
            .expect("Failed to get identity")
            .is_none());
        assert!(session_store
            .get_session(splinter_access_token)
            .expect("Failed to get session")
            .is_none());
    }

    /// When the OAuth provider errors during reauthentication, the error is returned and the
    /// session is removed.
    #[test]
    fn get_identity_reauthentication_failed() {
        let session_store = Box::new(MemoryOAuthUserSessionStore::new());
        let splinter_access_token = "splinter_access_token";
        let session = InsertableOAuthUserSessionBuilder::new()
            .with_splinter_access_token(splinter_access_token.into())
            .with_subject("subject".into())
            .with_oauth_access_token("oauth_access_token".into())
            .build()
            .expect("Failed to build session");
        session_store
            .add_session(session)
            .expect("Failed to add session");

        let identity_provider = OAuthUserIdentityProvider::new(
            always_err_client(),
            session_store.clone(),
            Some(Duration::from_secs(0)),
        );
        let authorization_header =
            AuthorizationHeader::Bearer(BearerToken::OAuth2(splinter_access_token.into()));

        assert!(identity_provider
            .get_identity(&authorization_header)
            .is_err());
        assert!(session_store
            .get_session(splinter_access_token)
            .expect("Failed to get session")
            .is_none());
    }

    /// A rejected access token is replaced through the refresh token, and the session is
    /// updated with the new access token.
    #[test]
    fn get_identity_refresh_successful() {
        let (shutdown_handle, address) = run_mock_oauth_server("get_identity_refresh_successful");

        let session_store = Box::new(MemoryOAuthUserSessionStore::new());
        let splinter_access_token = "splinter_access_token";
        let session = InsertableOAuthUserSessionBuilder::new()
            .with_splinter_access_token(splinter_access_token.into())
            .with_subject("subject".into())
            .with_oauth_access_token("oauth_access_token".into())
            .with_oauth_refresh_token(Some(REFRESH_TOKEN.into()))
            .build()
            .expect("Failed to build session");
        session_store
            .add_session(session)
            .expect("Failed to add session");
        let original_session = session_store
            .get_session(splinter_access_token)
            .expect("Failed to get inserted session")
            .expect("Inserted session not found");

        let identity_provider = OAuthUserIdentityProvider::new(
            refreshed_token_client(&address),
            session_store.clone(),
            Some(Duration::from_secs(0)),
        );
        let authorization_header =
            AuthorizationHeader::Bearer(BearerToken::OAuth2(splinter_access_token.into()));

        let identity = identity_provider
            .get_identity(&authorization_header)
            .expect("Failed to get identity")
            .expect("Identity not found");
        assert_eq!(
            identity,
            Identity::User(original_session.user().user_id().into())
        );

        let new_session = session_store
            .get_session(splinter_access_token)
            .expect("Failed to get updated session")
            .expect("Updated session not found");
        assert!(new_session.last_authenticated() > original_session.last_authenticated());
        assert_eq!(new_session.oauth_access_token(), NEW_OAUTH_ACCESS_TOKEN);

        shutdown_handle.shutdown();
    }

    /// When the refresh token is rejected, the session is removed and no identity is
    /// returned.
    #[test]
    fn get_identity_refresh_failed() {
        let (shutdown_handle, address) = run_mock_oauth_server("get_identity_refresh_failed");

        let session_store = Box::new(MemoryOAuthUserSessionStore::new());
        let splinter_access_token = "splinter_access_token";
        let session = InsertableOAuthUserSessionBuilder::new()
            .with_splinter_access_token(splinter_access_token.into())
            .with_subject("subject".into())
            .with_oauth_access_token("oauth_access_token".into())
            .with_oauth_refresh_token(Some("unknown_refresh_token".into()))
            .build()
            .expect("Failed to build session");
        session_store
            .add_session(session)
            .expect("Failed to add session");

        let identity_provider = OAuthUserIdentityProvider::new(
            refreshed_token_client(&address),
            session_store.clone(),
            Some(Duration::from_secs(0)),
        );
        let authorization_header =
            AuthorizationHeader::Bearer(BearerToken::OAuth2(splinter_access_token.into()));

        assert!(identity_provider
            .get_identity(&authorization_header)
            .expect("Failed to get identity")
            .is_none());
        assert!(session_store
            .get_session(splinter_access_token)
            .expect("Failed to get session")
            .is_none());

        shutdown_handle.shutdown();
    }

    /// Client whose subject/profile providers accept every access token.
    fn always_some_client() -> OAuthClient {
        OAuthClientBuilder::new()
            .with_client_id("client_id".into())
            .with_client_secret("client_secret".into())
            .with_auth_url("http://test.com/auth".into())
            .with_redirect_url("http://test.com/redirect".into())
            .with_token_url("http://test.com/token".into())
            .with_subject_provider(Box::new(AlwaysSomeSubjectProvider))
            .with_inflight_request_store(Box::new(MemoryInflightOAuthRequestStore::new()))
            .with_profile_provider(Box::new(AlwaysSomeProfileProvider))
            .build()
            .expect("Failed to build OAuth client")
    }

    #[derive(Clone)]
    struct AlwaysSomeSubjectProvider;

    impl SubjectProvider for AlwaysSomeSubjectProvider {
        fn get_subject(&self, _access_token: &str) -> Result<Option<String>, InternalError> {
            Ok(Some("subject".into()))
        }

        fn clone_box(&self) -> Box<dyn SubjectProvider> {
            Box::new(self.clone())
        }
    }

    #[derive(Clone)]
    struct AlwaysSomeProfileProvider;

    impl ProfileProvider for AlwaysSomeProfileProvider {
        fn get_profile(&self, _access_token: &str) -> Result<Option<Profile>, InternalError> {
            let profile = Profile {
                subject: "subject".to_string(),
                name: None,
                given_name: None,
                family_name: None,
                email: None,
                picture: None,
            };
            Ok(Some(profile))
        }

        fn clone_box(&self) -> Box<dyn ProfileProvider> {
            Box::new(self.clone())
        }
    }

    /// Client whose subject/profile providers reject every access token.
    fn always_none_client() -> OAuthClient {
        OAuthClientBuilder::new()
            .with_client_id("client_id".into())
            .with_client_secret("client_secret".into())
            .with_auth_url("http://test.com/auth".into())
            .with_redirect_url("http://test.com/redirect".into())
            .with_token_url("http://test.com/token".into())
            .with_subject_provider(Box::new(AlwaysNoneSubjectProvider))
            .with_inflight_request_store(Box::new(MemoryInflightOAuthRequestStore::new()))
            .with_profile_provider(Box::new(AlwaysNoneProfileProvider))
            .build()
            .expect("Failed to build OAuth client")
    }

    #[derive(Clone)]
    struct AlwaysNoneSubjectProvider;

    impl SubjectProvider for AlwaysNoneSubjectProvider {
        fn get_subject(&self, _access_token: &str) -> Result<Option<String>, InternalError> {
            Ok(None)
        }

        fn clone_box(&self) -> Box<dyn SubjectProvider> {
            Box::new(self.clone())
        }
    }

    #[derive(Clone)]
    struct AlwaysNoneProfileProvider;

    impl ProfileProvider for AlwaysNoneProfileProvider {
        fn get_profile(&self, _access_token: &str) -> Result<Option<Profile>, InternalError> {
            Ok(None)
        }

        fn clone_box(&self) -> Box<dyn ProfileProvider> {
            Box::new(self.clone())
        }
    }

    /// Client whose subject/profile providers always return an error.
    fn always_err_client() -> OAuthClient {
        OAuthClientBuilder::new()
            .with_client_id("client_id".into())
            .with_client_secret("client_secret".into())
            .with_auth_url("http://test.com/auth".into())
            .with_redirect_url("http://test.com/redirect".into())
            .with_token_url("http://test.com/token".into())
            .with_subject_provider(Box::new(AlwaysErrSubjectProvider))
            .with_inflight_request_store(Box::new(MemoryInflightOAuthRequestStore::new()))
            .with_profile_provider(Box::new(AlwaysErrProfileProvider))
            .build()
            .expect("Failed to build OAuth client")
    }

    #[derive(Clone)]
    struct AlwaysErrSubjectProvider;

    impl SubjectProvider for AlwaysErrSubjectProvider {
        fn get_subject(&self, _access_token: &str) -> Result<Option<String>, InternalError> {
            Err(InternalError::with_message("error".into()))
        }

        fn clone_box(&self) -> Box<dyn SubjectProvider> {
            Box::new(self.clone())
        }
    }

    #[derive(Clone)]
    struct AlwaysErrProfileProvider;

    impl ProfileProvider for AlwaysErrProfileProvider {
        fn get_profile(&self, _access_token: &str) -> Result<Option<Profile>, InternalError> {
            Err(InternalError::with_message("error".into()))
        }

        fn clone_box(&self) -> Box<dyn ProfileProvider> {
            Box::new(self.clone())
        }
    }

    /// Client that exchanges refresh tokens at the mock OAuth server at `address`, and whose
    /// subject/profile providers accept only `NEW_OAUTH_ACCESS_TOKEN`.
    fn refreshed_token_client(address: &str) -> OAuthClient {
        OAuthClientBuilder::new()
            .with_client_id("client_id".into())
            .with_client_secret("client_secret".into())
            .with_auth_url("http://test.com/auth".into())
            .with_redirect_url("http://test.com/redirect".into())
            .with_token_url(format!("{}{}", address, TOKEN_ENDPOINT))
            .with_subject_provider(Box::new(RefreshedTokenSubjectProvider))
            .with_inflight_request_store(Box::new(MemoryInflightOAuthRequestStore::new()))
            .with_profile_provider(Box::new(RefreshedTokenProfileProvider))
            .build()
            .expect("Failed to build OAuth client")
    }

    #[derive(Clone)]
    struct RefreshedTokenSubjectProvider;

    impl SubjectProvider for RefreshedTokenSubjectProvider {
        fn get_subject(&self, access_token: &str) -> Result<Option<String>, InternalError> {
            if access_token == NEW_OAUTH_ACCESS_TOKEN {
                Ok(Some("subject".into()))
            } else {
                Ok(None)
            }
        }

        fn clone_box(&self) -> Box<dyn SubjectProvider> {
            Box::new(self.clone())
        }
    }

    #[derive(Clone)]
    struct RefreshedTokenProfileProvider;

    impl ProfileProvider for RefreshedTokenProfileProvider {
        fn get_profile(&self, access_token: &str) -> Result<Option<Profile>, InternalError> {
            if access_token == NEW_OAUTH_ACCESS_TOKEN {
                let profile = Profile {
                    subject: "subject".to_string(),
                    name: None,
                    given_name: None,
                    family_name: None,
                    email: None,
                    picture: None,
                };
                Ok(Some(profile))
            } else {
                Ok(None)
            }
        }

        fn clone_box(&self) -> Box<dyn ProfileProvider> {
            Box::new(self.clone())
        }
    }

    /// Starts a mock OAuth server on its own thread, bound to an ephemeral localhost port.
    /// The server only serves `TOKEN_ENDPOINT`.
    ///
    /// `test_name` is used to name the server thread and actix system. Returns a handle for
    /// shutting the server down, plus the server's base URL.
    fn run_mock_oauth_server(test_name: &str) -> (OAuthServerShutdownHandle, String) {
        let (tx, rx) = channel();

        let instance_name = format!("OAuth-Server-{}", test_name);
        let join_handle = std::thread::Builder::new()
            .name(instance_name.clone())
            .spawn(move || {
                let sys = System::new(instance_name);
                let server = HttpServer::new(|| {
                    App::new().service(web::resource(TOKEN_ENDPOINT).to(token_endpoint))
                })
                .bind("127.0.0.1:0")
                .expect("Failed to bind OAuth server");
                let address = format!("http://127.0.0.1:{}", server.addrs()[0].port());
                let server = server.disable_signals().system_exit().start();
                tx.send((server, address)).expect("Failed to send server");
                sys.run().expect("OAuth server runtime failed");
            })
            .expect("Failed to spawn OAuth server thread");

        let (server, address) = rx.recv().expect("Failed to receive server");

        (OAuthServerShutdownHandle(server, join_handle), address)
    }

    /// Mock token endpoint. It expects only the `refresh_token` grant: it returns
    /// `NEW_OAUTH_ACCESS_TOKEN` for `REFRESH_TOKEN` and `401 Unauthorized` for any other
    /// refresh token.
    fn token_endpoint(form: web::Form<TokenRequestForm>) -> HttpResponse {
        assert_eq!(&form.grant_type, "refresh_token");
        if &form.refresh_token == REFRESH_TOKEN {
            HttpResponse::Ok()
                .content_type("application/json")
                .json(json!({
                    "token_type": "bearer",
                    "access_token": NEW_OAUTH_ACCESS_TOKEN,
                }))
        } else {
            HttpResponse::Unauthorized().finish()
        }
    }

    /// Form body posted to the mock token endpoint.
    #[derive(Deserialize)]
    struct TokenRequestForm {
        grant_type: String,
        refresh_token: String,
    }

    /// The running mock server together with the thread hosting it.
    struct OAuthServerShutdownHandle(Server, JoinHandle<()>);

    impl OAuthServerShutdownHandle {
        /// Stops the mock server (non-graceful stop) and joins its thread.
        pub fn shutdown(self) {
            self.0
                .stop(false)
                .wait()
                .expect("Failed to stop OAuth server");
            self.1.join().expect("OAuth server thread failed");
        }
    }
}