use std::borrow::Cow;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tracing::{error, info, trace, warn};
use thiserror::Error;
use bytes::Bytes;
use futures::io::Error as IoError;
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{body::Incoming as IncomingBody, Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio::signal;
use tokio_util::sync::CancellationToken;
use url::Url;
use dialoguer::Confirm;
use crate::api::identity::v3::auth::os_federation::{
identity_provider::protocol::websso::get as fed_idp_sso_get, websso::get as fed_sso_get,
};
use crate::api::RestEndpoint;
use crate::auth::authtoken::{AuthToken, AuthTokenError};
use crate::config;
/// Errors produced while performing the browser-based WebSSO authentication.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum WebSsoError {
/// The callback server stopped without a token having been stored in the
/// shared state.
#[error("WebSSO callback didn't return a token")]
CallbackNoToken,
/// The authentication was not carried out. In this file it is returned when
/// the user declines the confirmation prompt for opening the browser.
#[error("WebSSO authentication failed")]
CallbackFailed,
/// The interactive confirmation prompt (`dialoguer`) failed.
#[error("error using the dialoguer: {}", source)]
Dialoguer {
#[from]
source: dialoguer::Error,
},
/// Building the identity-provider-specific WebSSO request failed.
#[error("error preparing auth request: {}", source)]
FederationIdpSsoAuth {
#[from]
source: fed_idp_sso_get::RequestBuilderError,
},
/// Building the generic federation WebSSO request failed.
#[error("error preparing auth request: {}", source)]
FederationSsoBuilder {
#[from]
source: fed_sso_get::RequestBuilderError,
},
/// Constructing an HTTP response in the callback server failed.
#[error("http server error: {}", source)]
Http {
#[from]
source: http::Error,
},
/// `hyper` reported an error, e.g. while reading the callback request body.
#[error("hyper (http server) error: {}", source)]
Hyper {
#[from]
source: hyper::Error,
},
/// An I/O operation failed (binding the listener, launching the browser).
#[error("`IO` error: {}", source)]
IO {
#[from]
source: IoError,
},
/// Waiting for the spawned callback server task failed.
#[error("`Join` error: {}", source)]
Join {
#[from]
source: tokio::task::JoinError,
},
/// The cloud configuration has no `auth` section.
#[error("Auth data is missing")]
MissingAuthData,
/// The `auth` section does not name the federation protocol.
#[error("Federation protocol information is missing")]
MissingProtocol,
/// The mutex guarding the shared token state was poisoned; `context`
/// describes which lock operation failed.
#[error("internal error: poisoned lock: {}", context)]
PoisonedLock {
context: String,
},
}
/// Returns the endpoint of the federation WebSSO request matching `config`.
///
/// When `auth.identity_provider` is set, the identity-provider-specific
/// WebSSO request is built from the provider and protocol ids; otherwise the
/// generic federation WebSSO request is built from the protocol id alone.
///
/// # Errors
///
/// * [`WebSsoError::MissingAuthData`] when `config.auth` is `None`.
/// * [`WebSsoError::MissingProtocol`] when `auth.protocol` is `None`.
/// * A request-builder error when the request cannot be built.
pub fn get_auth_url(config: &config::CloudConfig) -> Result<Cow<'static, str>, WebSsoError> {
    let auth = config.auth.as_ref().ok_or(WebSsoError::MissingAuthData)?;
    // Both request flavours need the protocol, so resolve it once up front.
    let protocol = auth
        .protocol
        .as_ref()
        .ok_or(WebSsoError::MissingProtocol)?;

    let endpoint = match &auth.identity_provider {
        Some(idp) => {
            let mut request = fed_idp_sso_get::RequestBuilder::default();
            request.idp_id(idp);
            request.protocol_id(protocol);
            request.build()?.endpoint()
        }
        None => {
            let mut request = fed_sso_get::RequestBuilder::default();
            request.protocol_id(protocol);
            request.build()?.endpoint()
        }
    };
    Ok(endpoint)
}
/// Runs the interactive WebSSO flow for `url` and wraps the obtained token in
/// an [`AuthToken`] without additional auth info.
///
/// `url` is mutated: its query string is replaced by the callback `origin`
/// parameter before the browser is opened.
///
/// # Errors
///
/// Any [`WebSsoError`] raised by the flow, converted into [`AuthTokenError`].
pub async fn get_token_auth(url: &mut Url) -> Result<AuthToken, AuthTokenError> {
    let token = get_token(url, None).await?;
    // `token` is owned and not used afterwards, so it is moved, not cloned.
    Ok(AuthToken {
        token,
        auth_info: None,
    })
}
/// Performs the interactive WebSSO flow and returns the received token.
///
/// The user is asked for confirmation, a local callback server is started on
/// `socket_addr` (default `127.0.0.1:8050`), the default browser is opened at
/// `url`, and the function waits until the server stops (callback received,
/// Ctrl-C, or server timeout).
///
/// `url` is mutated: its query is replaced with `origin=http://localhost:<port>/callback`,
/// where `<port>` is the port the callback server listens on.
///
/// # Errors
///
/// * [`WebSsoError::CallbackFailed`] when the user declines the prompt.
/// * [`WebSsoError::CallbackNoToken`] when the server stopped without a token.
/// * [`WebSsoError::Dialoguer`], [`WebSsoError::IO`], [`WebSsoError::Join`],
///   [`WebSsoError::PoisonedLock`] for the respective failures, including
///   errors returned by the callback server itself (e.g. failing to bind).
async fn get_token(url: &mut Url, socket_addr: Option<SocketAddr>) -> Result<String, WebSsoError> {
    let addr = socket_addr.unwrap_or_else(|| SocketAddr::from(([127, 0, 0, 1], 8050)));

    // The browser is sent back to `origin`, so it must carry the port the
    // callback server really listens on (8050 unless overridden).
    let origin = format!("origin=http://localhost:{}/callback", addr.port());
    url.set_query(Some(origin.as_str()));

    let confirmation = Confirm::new()
        .with_prompt(format!(
            "A default browser is going to be opened at `{}`. Do you want to continue?",
            url.as_str()
        ))
        .interact()?;
    if !confirmation {
        return Err(WebSsoError::CallbackFailed);
    }

    info!("Opening browser at {:?}", url.as_str());
    let cancel_token = CancellationToken::new();
    let state: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));

    // Let Ctrl-C stop the callback server instead of waiting for its timeout.
    let ctrl_c_handle = tokio::spawn({
        let cancel_token = cancel_token.clone();
        async move {
            if let Ok(()) = signal::ctrl_c().await {
                info!("received Ctrl-C, shutting down");
                cancel_token.cancel();
            }
        }
    });

    let websso_handle = tokio::spawn({
        let cancel_token = cancel_token.clone();
        let state = state.clone();
        async move { websso_callback_server(addr, state, cancel_token).await }
    });

    if let Err(err) = open::that(url.as_str()) {
        // Without a browser no callback will ever arrive: stop the helper
        // tasks right away rather than leaving the port occupied.
        cancel_token.cancel();
        ctrl_c_handle.abort();
        return Err(WebSsoError::from(err));
    }

    let server_result = websso_handle.await;
    // The watcher is only useful while the server is running.
    ctrl_c_handle.abort();
    // First `?`: task join failure; second `?`: error reported by the server.
    server_result??;

    let mut guard = state.lock().map_err(|_| WebSsoError::PoisonedLock {
        context: "locking WebSSO authentication state".to_string(),
    })?;
    guard.take().ok_or(WebSsoError::CallbackNoToken)
}
/// Runs a minimal HTTP/1 server on `addr` that waits for the WebSSO callback.
///
/// Every accepted connection is served in its own task by [`handle_request`],
/// which shares `state` and `cancel_token`. The accept loop ends when
/// `cancel_token` is cancelled or when 120 seconds have passed since the
/// server started; on timeout the token is cancelled as well.
///
/// # Errors
///
/// Returns [`WebSsoError::IO`] when the listener cannot be bound to `addr`.
async fn websso_callback_server(
    addr: SocketAddr,
    state: Arc<Mutex<Option<String>>>,
    cancel_token: CancellationToken,
) -> Result<(), WebSsoError> {
    let listener = TcpListener::bind(addr).await?;
    info!("Starting webserver to receive SSO callback");

    let webserver_timeout = Duration::from_secs(120);
    // A single deadline for the whole session. Creating the sleep inside the
    // loop would restart the timer every time a connection is accepted.
    let deadline = tokio::time::sleep(webserver_timeout);
    tokio::pin!(deadline);

    loop {
        tokio::select! {
            accepted = listener.accept() => {
                let stream = match accepted {
                    Ok((stream, _peer)) => stream,
                    Err(err) => {
                        // Keep serving after a failed accept; the short pause
                        // avoids spinning when the error is persistent.
                        warn!("Failed to accept connection: {:?}", err);
                        tokio::time::sleep(Duration::from_millis(100)).await;
                        continue;
                    }
                };
                let io = TokioIo::new(stream);
                let service = service_fn({
                    let state = state.clone();
                    let cancel_token = cancel_token.clone();
                    move |req| handle_request(req, state.clone(), cancel_token.clone())
                });
                let conn_cancel = cancel_token.clone();
                tokio::task::spawn(async move {
                    if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
                        error!("Failed to serve connection: {:?}", err);
                        conn_cancel.cancel();
                    }
                });
            },
            _ = cancel_token.cancelled() => {
                info!("Stopping webserver");
                break;
            },
            _ = &mut deadline => {
                warn!("Timeout of {} sec waiting for authentication expired. Shutting down", webserver_timeout.as_secs());
                // Notify every holder of the token before leaving the loop.
                cancel_token.cancel();
                break;
            }
        }
    }
    Ok(())
}
/// Serves a single request of the WebSSO callback server.
///
/// `POST /callback`: the body is parsed as a URL-encoded form, the value of
/// its `token` field (if present) is stored in `state`, `cancel_token` is
/// cancelled, and the bundled `callback.html` page is returned. The token is
/// cancelled even when the form carries no `token` field.
///
/// Any other method or path gets an empty `404 Not Found` response.
///
/// # Errors
///
/// Fails when the body cannot be read, the state lock is poisoned, or the
/// response cannot be built.
async fn handle_request(
    req: Request<IncomingBody>,
    state: Arc<Mutex<Option<String>>>,
    cancel_token: CancellationToken,
) -> Result<Response<BoxBody<Bytes, Infallible>>, WebSsoError> {
    let is_callback = matches!(
        (req.method(), req.uri().path()),
        (&Method::POST, "/callback")
    );
    if !is_callback {
        let not_found = Response::builder()
            .status(StatusCode::NOT_FOUND)
            .body(Empty::<Bytes>::new().boxed())?;
        return Ok(not_found);
    }

    let body = req.collect().await?.to_bytes();
    trace!("Body is {:?}", body);
    let mut params: HashMap<String, String> = form_urlencoded::parse(body.as_ref())
        .into_owned()
        .collect();
    trace!("Params = {:?}", params);

    // The lock is taken before the token is cancelled, so a poisoned lock is
    // reported without signalling shutdown.
    let mut slot = state.lock().map_err(|_| WebSsoError::PoisonedLock {
        context: "locking WebSSO authentication state".to_string(),
    })?;
    if let Some(token) = params.remove("token") {
        *slot = Some(token);
    }
    cancel_token.cancel();

    let page = Full::new(Bytes::from(include_str!("../../static/callback.html"))).boxed();
    Ok(Response::builder().body(page)?)
}
#[cfg(test)]
mod tests {
    use reserve_port::ReservedSocketAddr;
    use std::net::SocketAddr;
    use std::sync::{Arc, Mutex};
    use tokio::signal;
    use tokio::task::JoinHandle;
    use tokio_util::sync::CancellationToken;

    use super::{websso_callback_server, WebSsoError};

    /// Token slot shared between the test and the callback server.
    type SharedToken = Arc<Mutex<Option<String>>>;

    /// Spawns the callback server on a random local port.
    ///
    /// Returns the address it was asked to listen on, the shared token slot
    /// and the handle of the server task. Must be called inside a tokio
    /// runtime.
    fn start_callback_server() -> (SocketAddr, SharedToken, JoinHandle<Result<(), WebSsoError>>) {
        let addr = ReservedSocketAddr::reserve_random_socket_addr()
            .expect("port available")
            .socket_addr();
        let cancel_token = CancellationToken::new();
        tokio::spawn({
            let cancel_token = cancel_token.clone();
            async move {
                if let Ok(()) = signal::ctrl_c().await {
                    cancel_token.cancel();
                }
            }
        });
        let state: SharedToken = Arc::new(Mutex::new(None));
        let websso_handle = tokio::spawn({
            let state = state.clone();
            async move { websso_callback_server(addr, state, cancel_token).await }
        });
        (addr, state, websso_handle)
    }

    /// A form with a `token` field stores the token and stops the server.
    #[tokio::test]
    async fn test_callback() {
        let (addr, state, websso_handle) = start_callback_server();

        let params = [("token", "foo_bar_baz")];
        let client = reqwest::Client::new();
        client
            .post(format!("http://localhost:{}/callback", addr.port()))
            .form(&params)
            .send()
            .await
            .unwrap();

        websso_handle.await.unwrap().unwrap();
        assert_eq!(*state.lock().unwrap(), Some(params[0].1.into()));
    }

    /// A callback without a token still stops the server and leaves the
    /// state empty.
    #[tokio::test]
    async fn test_callback_no_token() {
        let (addr, state, websso_handle) = start_callback_server();

        let client = reqwest::Client::new();
        client
            .post(format!("http://localhost:{}/callback", addr.port()))
            .send()
            .await
            .unwrap();

        websso_handle.await.unwrap().unwrap();
        assert_eq!(*state.lock().unwrap(), None);
    }
}