use std::fmt::{Debug, Formatter};
use std::net::{IpAddr, SocketAddr, TcpListener};
use std::ops::Range;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use log::debug;
use oauth2::{
AuthorizationCode, CsrfToken, ErrorResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
RevocableToken, Scope, TokenIntrospectionResponse, TokenResponse, TokenType,
};
use tokio::runtime::Handle;
use url::Url;
pub use crate::builder::CliOAuthBuilder;
pub use crate::error::{AuthError, ConfigError, ServerError};
use crate::server::launch;
use crate::ConfigError::CannotBindAddress;
mod builder;
mod error;
mod server;
/// Half-open range (`start..end`, `end` excluded) of candidate TCP ports searched by
/// `find_available_port`.
pub(crate) type PortRange = Range<u16>;
/// Result type for configuration operations that can fail with a [`ConfigError`].
pub type ConfigResult<T> = Result<T, ConfigError>;
/// Shared, mutex-protected, optionally-filled slot holding an [`AuthorizationResult`].
// NOTE(review): not referenced in this file — presumably used by the `server` module to hand
// the callback result back; confirm.
type AuthorizationResultHolder = Arc<Mutex<Option<AuthorizationResult>>>;
/// Drives an OAuth 2.0 authorization-code flow with PKCE from a command-line program: the
/// authorization URL is opened with the system opener and the redirect is received by a
/// local HTTP server listening on `address`.
///
/// Construct with [`CliOAuth::builder`], run [`CliOAuth::authorize`], then call
/// [`CliOAuth::validate`] to obtain the [`AuthContext`] (code + PKCE verifier).
#[derive(Debug)]
pub struct CliOAuth {
    /// Address the local callback server is launched on; also used to build the redirect URL.
    address: SocketAddr,
    /// Callback timeout in seconds (converted with `Duration::from_secs` in `authorize`).
    timeout: u64,
    /// Scopes added to the authorization URL.
    scopes: Vec<Scope>,
    /// Code, CSRF state and PKCE verifier stored by `authorize` on success.
    auth_context: Option<AuthContext>,
    /// Result returned by the callback server, stored by `authorize` on success.
    auth_result: Option<AuthorizationResult>,
}
impl CliOAuth {
    /// Returns a builder used to configure and construct a [`CliOAuth`].
    pub fn builder() -> CliOAuthBuilder {
        CliOAuthBuilder::new()
    }

    /// Returns the loopback redirect URL `http://<address>/` served by the callback server.
    ///
    /// NOTE(review): callers are presumably expected to configure their OAuth client with this
    /// URL so the browser is redirected back to the local server; this type does not do it.
    pub fn redirect_url(&self) -> RedirectUrl {
        let url = format!("http://{}", self.address);
        // A formatted `SocketAddr` (IPv6 is bracketed) is always a valid URL authority.
        RedirectUrl::from_url(Url::parse(&url).expect("socket address always forms a valid URL"))
    }

    /// Runs the interactive authorization-code + PKCE flow.
    ///
    /// Generates a PKCE challenge and CSRF token and builds the authorization URL with the
    /// configured scopes. It then spawns the callback server on the current Tokio runtime,
    /// passing it the configured timeout in seconds, and opens the URL with the system opener.
    /// Finally it waits for the server's result. On success, the raw result and the
    /// [`AuthContext`] are stored for [`Self::validate`]. Any state from a previous run is
    /// cleared first.
    ///
    /// # Errors
    /// Returns a [`ServerError`] when:
    /// * no Tokio runtime is current;
    /// * the URL cannot be opened (the server task is aborted first);
    /// * the server task fails to complete;
    /// * the server reports an error.
    #[cfg(not(tarpaulin_include))]
    pub async fn authorize<TE, TR, TT, TIR, RT, TRE>(
        &mut self,
        oauth_client: &oauth2::Client<TE, TR, TT, TIR, RT, TRE>,
    ) -> Result<(), ServerError>
    where
        TE: ErrorResponse + 'static,
        TR: TokenResponse<TT>,
        TT: TokenType,
        TIR: TokenIntrospectionResponse<TT>,
        RT: RevocableToken,
        TRE: ErrorResponse + 'static,
    {
        // A failed run must not leave an older result/context behind for `validate`.
        self.auth_result = None;
        self.auth_context = None;

        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
        let (auth_url, state) = oauth_client
            .authorize_url(CsrfToken::new_random)
            .add_scopes(self.scopes.iter().cloned())
            .set_pkce_challenge(pkce_challenge)
            .url();

        let handle = Handle::try_current()?;
        // Start the callback server before opening the URL so the redirect has a listener.
        let server = handle.spawn(launch(self.address, Duration::from_secs(self.timeout)));
        debug!("🔑 authorization URL: {}", auth_url);
        if let Err(e) = open::that(auth_url.as_str()) {
            // Dropping a JoinHandle only detaches the task; abort it so the server does not
            // keep holding the port until its timeout.
            server.abort();
            return Err(e.into());
        }

        // Outer `?`: the task itself failed (panic/cancel); inner `?`: the server's own error.
        let auth_result = server.await??;
        let auth_code = AuthorizationCode::new(auth_result.auth_code.clone());
        self.auth_result = Some(auth_result);
        self.auth_context = Some(AuthContext {
            auth_code,
            state,
            pkce_verifier,
        });
        Ok(())
    }

    /// Checks that the `state` received by the callback server matches the CSRF token sent
    /// in the authorization URL, and returns the stored [`AuthContext`] on success.
    ///
    /// Both stored values are taken on every call, whatever the outcome, so validation is
    /// one-shot. Call [`Self::authorize`] again before validating again.
    ///
    /// # Errors
    /// * [`AuthError::InvalidAuthState`] if either the result or the context is missing.
    /// * [`AuthError::CsrfMismatch`] if the returned state differs from the one sent.
    pub fn validate(&mut self) -> Result<AuthContext, AuthError> {
        let auth_result = self.auth_result.take();
        let auth_ctx = self.auth_context.take();
        match (auth_result, auth_ctx) {
            (Some(result), Some(ctx)) if ctx.state.secret() == &result.state => Ok(ctx),
            (Some(_), Some(_)) => Err(AuthError::CsrfMismatch),
            _ => Err(AuthError::InvalidAuthState),
        }
    }
}
/// Values captured by [`CliOAuth::authorize`] and handed out by [`CliOAuth::validate`] once
/// the CSRF state has been checked.
#[derive(Debug)]
pub struct AuthContext {
    /// Authorization code returned to the local callback server.
    pub auth_code: AuthorizationCode,
    /// CSRF token generated for, and embedded in, the authorization URL.
    pub state: CsrfToken,
    /// PKCE verifier matching the S256 challenge sent in the authorization URL.
    pub pkce_verifier: PkceCodeVerifier,
}
/// Raw values produced by the local callback server for one authorization attempt.
///
/// `Debug` is written by hand so that diagnostics only reveal a short prefix of each value.
#[derive(Clone)]
struct AuthorizationResult {
    /// Authorization code, as received by the callback server.
    pub auth_code: String,
    /// `state` value received by the callback server; compared against the CSRF token.
    pub state: String,
}

impl Debug for AuthorizationResult {
    /// Prints at most the first three characters of each value, each followed by `*****`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let prefix = |secret: &str| -> String { secret.chars().take(3).collect() };
        write!(
            f,
            "auth code={}*****, state={}*****",
            prefix(&self.auth_code),
            prefix(&self.state)
        )
    }
}
// NOTE(review): none of these constants is referenced in this file; the descriptions below are
// inferred from their names and presumably describe the builder's defaults — confirm there.
/// Presumed lowest port accepted for the callback server (1024 is the first non-privileged
/// port on Unix-like systems).
const PORT_MIN: u16 = 1024;
/// Presumed first port of the default search range.
const DEFAULT_PORT_MIN: u16 = 3456;
/// Presumed end of the default search range, ten ports above `DEFAULT_PORT_MIN`.
const DEFAULT_PORT_MAX: u16 = DEFAULT_PORT_MIN + 10;
/// Presumed default callback timeout; `CliOAuth::timeout` is interpreted as seconds.
const DEFAULT_TIMEOUT: u64 = 60;
/// Returns the first `ip_addr:port` address, probing `port_range` in ascending order, that can
/// currently be bound.
///
/// Each probe releases its socket immediately, so another process may still take the port
/// before the caller binds it.
///
/// # Errors
/// [`CannotBindAddress`] when no port in the half-open range is free, including when the range
/// is empty.
fn find_available_port(ip_addr: IpAddr, port_range: PortRange) -> ConfigResult<SocketAddr> {
    port_range
        .clone()
        .map(|port| SocketAddr::new(ip_addr, port))
        .find(|&candidate| is_address_available(candidate))
        .ok_or_else(|| CannotBindAddress {
            addr: ip_addr,
            port_range,
        })
}
/// Reports whether a TCP listener can be bound to `socket_addr` right now.
///
/// The probe listener is dropped before returning, so the answer can be stale by the time
/// the caller acts on it.
fn is_address_available(socket_addr: SocketAddr) -> bool {
    match TcpListener::bind(socket_addr) {
        Ok(_) => true,
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
    use std::sync::atomic::AtomicU16;
    use std::sync::atomic::Ordering::AcqRel;

    use rstest::{fixture, rstest};

    use crate::{find_available_port, is_address_available, PortRange};

    /// Loopback address used by the port tests.
    pub(crate) static LOCALHOST: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    /// Hands out disjoint blocks of ports so tests running in parallel never share one.
    pub(crate) static PORT_GENERATOR: AtomicU16 = AtomicU16::new(8000);

    /// Reserves `count` consecutive ports and returns them as an inclusive `(first, last)` pair.
    pub(crate) fn next_ports(count: u16) -> (u16, u16) {
        let start = PORT_GENERATOR.fetch_add(count, AcqRel);
        let end = start + count - 1;
        (start, end)
    }

    /// Reserves `count` ports and returns `first..last` built from [`next_ports`].
    ///
    /// NOTE(review): `next_ports` is inclusive but `Range` is half-open, so this range holds
    /// only `count - 1` ports and is empty for `count == 1`. It is left unchanged because it is
    /// `pub(crate)` and other modules may rely on it. The fixtures in this module use
    /// [`exact_port_range`] instead.
    pub(crate) fn port_range(count: u16) -> PortRange {
        let (start, end) = next_ports(count);
        start..end
    }

    /// Reserves `count` ports and returns a half-open range containing exactly those ports.
    fn exact_port_range(count: u16) -> PortRange {
        let (first, last) = next_ports(count);
        // Step past the inclusive `last` to form the half-open range.
        first..last + 1
    }

    #[fixture]
    fn one_port() -> PortRange {
        exact_port_range(1)
    }

    #[fixture]
    fn two_ports() -> PortRange {
        exact_port_range(2)
    }

    #[fixture]
    fn three_ports() -> PortRange {
        exact_port_range(3)
    }

    #[rstest]
    fn find_available_port_with_open_port(three_ports: PortRange) {
        let res = find_available_port(LOCALHOST, three_ports.clone());
        match res {
            Ok(addr) => assert!(three_ports.contains(&addr.port())),
            Err(e) => panic!("error finding available port: {:?}", e),
        }
    }

    #[rstest]
    fn find_available_port_with_no_open_port(two_ports: PortRange) {
        // Occupy every port in the range; the listeners must outlive the search.
        let _guards: Vec<TcpListener> = two_ports
            .clone()
            .map(|port| {
                TcpListener::bind(SocketAddr::new(LOCALHOST, port))
                    .unwrap_or_else(|e| panic!("cannot occupy test port {}: {}", port, e))
            })
            .collect();
        let res = find_available_port(LOCALHOST, two_ports);
        res.expect_err("ports should not be available");
    }

    #[rstest]
    fn check_address_is_available_when_port_is_open(two_ports: PortRange) {
        // Occupy the other port in the range as a control; it must not affect `start`.
        let control = two_ports.end - 1;
        let _sock = TcpListener::bind(SocketAddr::new(LOCALHOST, control))
            .expect("control port should be free");
        let address = SocketAddr::new(LOCALHOST, two_ports.start);
        assert!(is_address_available(address));
    }

    #[rstest]
    fn check_address_is_not_available_when_port_is_used(one_port: PortRange) {
        let address = SocketAddr::new(LOCALHOST, one_port.start);
        let _socket = TcpListener::bind(address).expect("test port should be free");
        assert!(!is_address_available(address));
    }

    mod cli_oauth {
        use crate::{AuthContext, AuthError, AuthorizationResult, CliOAuth};
        use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier};
        use rstest::{fixture, rstest};

        #[fixture]
        fn auth() -> CliOAuth {
            CliOAuth {
                address: ([127, 0, 0, 1], 8080).into(),
                timeout: 30,
                scopes: vec![],
                auth_context: None,
                auth_result: None,
            }
        }

        #[fixture]
        fn auth_context() -> AuthContext {
            AuthContext {
                state: CsrfToken::new(String::from("state")),
                auth_code: AuthorizationCode::new(String::from("code")),
                pkce_verifier: PkceCodeVerifier::new(String::from("pkce")),
            }
        }

        #[fixture]
        fn auth_result() -> AuthorizationResult {
            AuthorizationResult {
                auth_code: String::from("code"),
                state: String::from("state"),
            }
        }

        #[rstest]
        fn redirect_url_valid(auth: CliOAuth) {
            let url = auth.redirect_url();
            assert_eq!("http://127.0.0.1:8080/", url.as_str());
        }

        #[rstest]
        fn validate_with_no_context(mut auth: CliOAuth, auth_result: AuthorizationResult) {
            auth.auth_result = Some(auth_result);
            assert!(matches!(auth.validate(), Err(AuthError::InvalidAuthState)));
        }

        #[rstest]
        fn validate_with_no_result(mut auth: CliOAuth, auth_context: AuthContext) {
            auth.auth_context = Some(auth_context);
            assert!(matches!(auth.validate(), Err(AuthError::InvalidAuthState)));
        }

        #[rstest]
        fn validate_with_matching_state(
            mut auth: CliOAuth,
            auth_result: AuthorizationResult,
            auth_context: AuthContext,
        ) {
            auth.auth_result = Some(auth_result);
            auth.auth_context = Some(auth_context);
            let ctx = auth.validate().expect("matching state should validate");
            assert_eq!(ctx.auth_code.secret(), "code");
            // The stored result has been consumed, so a second validation must fail.
            assert!(auth.validate().is_err());
        }

        #[rstest]
        fn validate_state_mismatch(
            mut auth: CliOAuth,
            mut auth_result: AuthorizationResult,
            auth_context: AuthContext,
        ) {
            auth_result.state = String::from("other_state");
            auth.auth_result = Some(auth_result);
            auth.auth_context = Some(auth_context);
            match auth.validate() {
                Err(AuthError::CsrfMismatch) => (),
                Err(e) => panic!("CsrfMismatch error should be raised, but was {:?}", e),
                Ok(_) => panic!("Validation should fail"),
            };
        }
    }
}