use reqwest::blocking::Client;
use crate::error::{InternalError, InvalidStateError};
use crate::oauth::OpenIdProfileProvider;
use crate::oauth::{
builder::OAuthClientBuilder, error::OAuthClientBuildError, store::InflightOAuthRequestStore,
OAuthClient, OpenIdSubjectProvider,
};
/// Scope added to the inner builder by [`OpenIdOAuthClientBuilder::new_azure`].
const AZURE_SCOPE: &str = "offline_access";
/// Scopes that [`OpenIdOAuthClientBuilder::build`] always passes to the inner builder.
const DEFAULT_SCOPES: &[&str] = &["openid", "profile", "email"];
/// Extra authorization-request parameters preset by [`OpenIdOAuthClientBuilder::new_google`].
const GOOGLE_AUTH_PARAMS: &[(&str, &str)] = &[("access_type", "offline"), ("prompt", "consent")];
/// Discovery document URL preset by [`OpenIdOAuthClientBuilder::new_google`].
const GOOGLE_DISCOVERY_URL: &str = "https://accounts.google.com/.well-known/openid-configuration";
/// Builder for an [`OAuthClient`] whose authorization, token and userinfo endpoints are read
/// from an OpenID discovery document at build time.
pub struct OpenIdOAuthClientBuilder {
    /// URL of the OpenID discovery document; `build` fails if this is unset.
    openid_discovery_url: Option<String>,
    /// Underlying generic builder that accumulates all other client configuration.
    inner: OAuthClientBuilder,
}
impl OpenIdOAuthClientBuilder {
    /// Creates a builder with no discovery URL and no provider-specific presets.
    ///
    /// A discovery URL must be supplied with [`with_discovery_url`](Self::with_discovery_url)
    /// before [`build`](Self::build) can succeed.
    pub fn new() -> Self {
        Self {
            openid_discovery_url: None,
            inner: OAuthClientBuilder::default(),
        }
    }

    /// Creates a builder preset for Azure: the `offline_access` scope is added.
    ///
    /// No discovery URL is preset; it must be supplied with
    /// [`with_discovery_url`](Self::with_discovery_url).
    pub fn new_azure() -> Self {
        Self {
            openid_discovery_url: None,
            inner: OAuthClientBuilder::default().with_scopes(vec![AZURE_SCOPE.into()]),
        }
    }

    /// Creates a builder preset for Google: the discovery URL points at Google's discovery
    /// document and the `access_type=offline` / `prompt=consent` authorization parameters are
    /// added.
    pub fn new_google() -> Self {
        Self {
            openid_discovery_url: Some(GOOGLE_DISCOVERY_URL.into()),
            inner: OAuthClientBuilder::default().with_extra_auth_params(
                GOOGLE_AUTH_PARAMS
                    .iter()
                    .map(|(key, value)| (key.to_string(), value.to_string()))
                    .collect(),
            ),
        }
    }

    /// Sets the OAuth client ID.
    pub fn with_client_id(mut self, client_id: String) -> Self {
        self.inner = self.inner.with_client_id(client_id);
        self
    }

    /// Sets the OAuth client secret.
    pub fn with_client_secret(mut self, client_secret: String) -> Self {
        self.inner = self.inner.with_client_secret(client_secret);
        self
    }

    /// Adds extra key/value parameters to send with authorization requests.
    pub fn with_extra_auth_params(mut self, extra_auth_params: Vec<(String, String)>) -> Self {
        self.inner = self.inner.with_extra_auth_params(extra_auth_params);
        self
    }

    /// Adds scopes to request; `build` also adds the default OpenID scopes.
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.inner = self.inner.with_scopes(scopes);
        self
    }

    /// Sets the store used to track in-flight OAuth requests.
    pub fn with_inflight_request_store(
        mut self,
        inflight_request_store: Box<dyn InflightOAuthRequestStore>,
    ) -> Self {
        self.inner = self
            .inner
            .with_inflight_request_store(inflight_request_store);
        self
    }

    /// Sets the redirect URL passed to the inner builder.
    pub fn with_redirect_url(mut self, redirect_url: String) -> Self {
        self.inner = self.inner.with_redirect_url(redirect_url);
        self
    }

    /// Sets (or overrides) the URL of the OpenID discovery document.
    pub fn with_discovery_url(mut self, discovery_url: String) -> Self {
        self.openid_discovery_url = Some(discovery_url);
        self
    }

    /// Fetches the OpenID discovery document and builds the [`OAuthClient`].
    ///
    /// The document's `authorization_endpoint`, `token_endpoint` and `userinfo_endpoint` become
    /// the client's auth URL, token URL, and the endpoint used by the OpenID subject and profile
    /// providers. [`DEFAULT_SCOPES`] are added to the configured scopes.
    ///
    /// This performs a blocking HTTP GET request.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * no discovery URL has been set;
    /// * the discovery document cannot be retrieved (transport failure or a non-success HTTP
    ///   status);
    /// * the response body cannot be deserialized as a discovery document;
    /// * the inner [`OAuthClientBuilder`] fails to build.
    pub fn build(self) -> Result<OAuthClient, OAuthClientBuildError> {
        let discovery_url = self.openid_discovery_url.ok_or_else(|| {
            InvalidStateError::with_message(
                "An OpenID discovery URL is required to successfully build an OAuthClient".into(),
            )
        })?;

        // `send` only fails on transport errors. `error_for_status` also turns 4xx/5xx responses
        // into errors, so they are reported as a retrieval failure instead of surfacing later as
        // a confusing deserialization failure of the error body.
        let response = Client::new()
            .get(&discovery_url)
            .send()
            .and_then(|response| response.error_for_status())
            .map_err(|err| {
                InternalError::from_source_with_message(
                    Box::new(err),
                    "Unable to retrieve OpenID discovery document".into(),
                )
            })?;

        let discovery_document = response
            .json::<DiscoveryDocumentResponse>()
            .map_err(|err| {
                InternalError::from_source_with_message(
                    Box::new(err),
                    "Unable to deserialize OpenID discovery document".into(),
                )
            })?;

        // The userinfo endpoint is shared by the subject and profile providers.
        let userinfo_endpoint = discovery_document.userinfo_endpoint;
        self.inner
            .with_auth_url(discovery_document.authorization_endpoint)
            .with_token_url(discovery_document.token_endpoint)
            .with_scopes(DEFAULT_SCOPES.iter().map(ToString::to_string).collect())
            .with_subject_provider(Box::new(OpenIdSubjectProvider::new(
                userinfo_endpoint.clone(),
            )))
            .with_profile_provider(Box::new(OpenIdProfileProvider::new(userinfo_endpoint)))
            .build()
    }
}
impl Default for OpenIdOAuthClientBuilder {
fn default() -> Self {
Self::new()
}
}
/// The subset of an OpenID discovery document that the builder reads.
///
/// `deny_unknown_fields` is not set, so serde ignores any other fields in the document.
#[derive(Debug, Deserialize)]
struct DiscoveryDocumentResponse {
    /// Used as the client's authorization URL.
    authorization_endpoint: String,
    /// Used as the client's token URL.
    token_endpoint: String,
    /// Passed to the OpenID subject and profile providers.
    userinfo_endpoint: String,
}
#[cfg(test)]
#[cfg(all(feature = "actix", feature = "actix-web", feature = "futures"))]
mod tests {
    use super::*;

    use std::collections::HashSet;
    use std::sync::mpsc::channel;
    use std::thread::JoinHandle;

    use actix::System;
    use actix_web::{dev::Server, web, App, HttpResponse, HttpServer};
    use futures::Future;
    use url::Url;

    use crate::oauth::store::MemoryInflightOAuthRequestStore;

    const CLIENT_ID: &str = "client_id";
    const CLIENT_SECRET: &str = "client_secret";
    const EXTRA_AUTH_PARAM_KEY: &str = "key";
    const EXTRA_AUTH_PARAM_VAL: &str = "val";
    const EXTRA_SCOPE: &str = "scope";
    const REDIRECT_URL: &str = "http://oauth/callback";
    const DISCOVERY_DOCUMENT_ENDPOINT: &str = "/.well-known/openid-configuration";
    const AUTHORIZATION_ENDPOINT: &str = "http://oauth/auth";
    const TOKEN_ENDPOINT: &str = "http://oauth/token";
    const USERINFO_ENDPOINT: &str = "http://oauth/userinfo";

    /// Verifies that every setter is forwarded to the inner builder, that discovery-derived
    /// values are unset before `build`, and that the built client combines the default scopes
    /// with the configured ones and uses the discovered authorization endpoint.
    #[test]
    fn basic_client() {
        let (shutdown_handle, address) = run_mock_openid_server("basic_client");

        let extra_auth_params = vec![(
            EXTRA_AUTH_PARAM_KEY.to_string(),
            EXTRA_AUTH_PARAM_VAL.to_string(),
        )];
        let extra_scopes = vec![EXTRA_SCOPE.to_string()];
        let discovery_url = format!("{}{}", address, DISCOVERY_DOCUMENT_ENDPOINT);

        let builder = OpenIdOAuthClientBuilder::new()
            .with_client_id(CLIENT_ID.into())
            .with_client_secret(CLIENT_SECRET.into())
            .with_extra_auth_params(extra_auth_params.clone())
            .with_scopes(vec![EXTRA_SCOPE.into()])
            .with_inflight_request_store(Box::new(MemoryInflightOAuthRequestStore::new()))
            .with_redirect_url(REDIRECT_URL.into())
            .with_discovery_url(discovery_url.clone());

        assert_eq!(builder.inner.client_id, Some(CLIENT_ID.into()));
        assert_eq!(builder.inner.client_secret, Some(CLIENT_SECRET.into()));
        assert_eq!(builder.inner.extra_auth_params, extra_auth_params);
        assert_eq!(builder.inner.scopes, extra_scopes);
        assert!(builder.inner.inflight_request_store.is_some());
        assert_eq!(builder.inner.redirect_url, Some(REDIRECT_URL.into()));
        assert_eq!(builder.openid_discovery_url, Some(discovery_url));
        // These are only populated from the discovery document during `build`.
        assert!(builder.inner.auth_url.is_none());
        assert!(builder.inner.token_url.is_none());
        assert!(builder.inner.subject_provider.is_none());
        assert!(builder.inner.profile_provider.is_none());

        let client = builder
            .build()
            .expect("Failed to build OpenID OAuth client");

        assert_eq!(client.extra_auth_params, extra_auth_params);
        assert_eq!(
            client
                .scopes
                .iter()
                .map(|scope| scope.as_str())
                .collect::<HashSet<_>>(),
            DEFAULT_SCOPES
                .iter()
                .cloned()
                .chain(std::iter::once(EXTRA_SCOPE.into()))
                .collect::<HashSet<_>>(),
        );

        let expected_auth_url =
            Url::parse(AUTHORIZATION_ENDPOINT).expect("Failed to parse expected auth URL");
        let generated_auth_url = Url::parse(
            &client
                .get_authorization_url("client_redirect_url".into())
                .expect("Failed to generate auth URL"),
        )
        .expect("Failed to parse generated auth URL");
        assert_eq!(expected_auth_url.origin(), generated_auth_url.origin());

        shutdown_handle.shutdown();
    }

    /// Verifies that the Azure preset adds its scope on top of the default OpenID scopes.
    #[test]
    fn azure_client() {
        let (shutdown_handle, address) = run_mock_openid_server("azure_client");

        let discovery_url = format!("{}{}", address, DISCOVERY_DOCUMENT_ENDPOINT);

        let client = OpenIdOAuthClientBuilder::new_azure()
            .with_client_id(CLIENT_ID.into())
            .with_client_secret(CLIENT_SECRET.into())
            .with_inflight_request_store(Box::new(MemoryInflightOAuthRequestStore::new()))
            .with_redirect_url(REDIRECT_URL.into())
            .with_discovery_url(discovery_url)
            .build()
            .expect("Failed to build Azure client");

        assert_eq!(
            client
                .scopes
                .iter()
                .map(|scope| scope.as_str())
                .collect::<HashSet<_>>(),
            DEFAULT_SCOPES
                .iter()
                .cloned()
                .chain(std::iter::once(AZURE_SCOPE.into()))
                .collect::<HashSet<_>>(),
        );

        shutdown_handle.shutdown();
    }

    /// Verifies that the Google preset sets Google's discovery URL and that its extra auth
    /// params survive `build`.
    ///
    /// The build itself runs against the local mock server, so the test does not depend on
    /// network access to Google.
    #[test]
    fn google_client() {
        let (shutdown_handle, address) = run_mock_openid_server("google_client");

        let builder = OpenIdOAuthClientBuilder::new_google()
            .with_client_id(CLIENT_ID.into())
            .with_client_secret(CLIENT_SECRET.into())
            .with_inflight_request_store(Box::new(MemoryInflightOAuthRequestStore::new()))
            .with_redirect_url(REDIRECT_URL.into());

        assert_eq!(
            builder.openid_discovery_url.as_deref(),
            Some(GOOGLE_DISCOVERY_URL)
        );

        // Override the preset URL so `build` fetches the mock discovery document instead of
        // Google's real one.
        let client = builder
            .with_discovery_url(format!("{}{}", address, DISCOVERY_DOCUMENT_ENDPOINT))
            .build()
            .expect("Failed to build Google client");

        assert_eq!(
            client
                .extra_auth_params
                .iter()
                .map(|(k, v)| (k.as_str(), v.as_str()))
                .collect::<HashSet<_>>(),
            GOOGLE_AUTH_PARAMS.iter().cloned().collect::<HashSet<_>>(),
        );

        shutdown_handle.shutdown();
    }

    /// Starts an actix HTTP server on its own thread that serves a static discovery document.
    ///
    /// Returns a handle for stopping the server and the server's base address
    /// (`http://127.0.0.1:<port>`).
    fn run_mock_openid_server(test_name: &str) -> (OpenIDServerShutdownHandle, String) {
        let (tx, rx) = channel();

        let instance_name = format!("OpenID-Server-{}", test_name);
        let join_handle = std::thread::Builder::new()
            .name(instance_name.clone())
            .spawn(move || {
                let sys = System::new(instance_name);
                let server = HttpServer::new(|| {
                    App::new().service(
                        web::resource(DISCOVERY_DOCUMENT_ENDPOINT).to(discovery_document_endpoint),
                    )
                })
                // Port 0 lets the OS pick a free port; the real port is read back below.
                .bind("127.0.0.1:0")
                .expect("Failed to bind OpenID server");
                let address = format!("http://127.0.0.1:{}", server.addrs()[0].port());
                let server = server.disable_signals().system_exit().start();
                tx.send((server, address)).expect("Failed to send server");
                sys.run().expect("OpenID server runtime failed");
            })
            .expect("Failed to spawn OpenID server thread");

        let (server, address) = rx.recv().expect("Failed to receive server");
        (OpenIDServerShutdownHandle(server, join_handle), address)
    }

    /// Handler that returns a discovery document pointing at the test endpoint constants.
    fn discovery_document_endpoint() -> HttpResponse {
        HttpResponse::Ok()
            .content_type("application/json")
            .json(json!({
                "authorization_endpoint": AUTHORIZATION_ENDPOINT,
                "token_endpoint": TOKEN_ENDPOINT,
                "userinfo_endpoint": USERINFO_ENDPOINT,
            }))
    }

    /// Pairs the running mock server with the thread that hosts it.
    struct OpenIDServerShutdownHandle(Server, JoinHandle<()>);

    impl OpenIDServerShutdownHandle {
        /// Requests a non-graceful server stop, waits for it, then joins the server thread.
        pub fn shutdown(self) {
            self.0
                .stop(false)
                .wait()
                .expect("Failed to stop OpenID server");
            self.1.join().expect("OpenID server thread failed");
        }
    }
}