#[cfg(feature = "async-std")]
use async_std::{
io::{BufReadExt, BufReader, WriteExt},
net::TcpListener,
};
use oauth2::{
url::Url, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, RequestTokenError,
Scope, TokenResponse,
};
#[cfg(feature = "tokio")]
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::TcpListener,
};
use super::{Client, Error, Result};
/// Parameters of an OAuth 2.0 authorization code grant.
///
/// Start from [`AuthorizationCodeGrant::new`], optionally add scopes with
/// `with_scope` and PKCE with `with_pkce`, then call `get_redirect_url` followed
/// by `wait_for_redirection`.
#[derive(Debug, Default)]
pub struct AuthorizationCodeGrant {
    /// Scopes added to the authorization URL built by `get_redirect_url`.
    pub scopes: Vec<Scope>,
    /// Optional PKCE pair. When set, the challenge is attached to the
    /// authorization URL and the verifier is sent with the code exchange.
    pub pkce: Option<(PkceCodeChallenge, PkceCodeVerifier)>,
}
impl AuthorizationCodeGrant {
pub fn new() -> Self {
Self::default()
}
pub fn with_scope<T>(mut self, scope: T) -> Self
where
T: ToString,
{
self.scopes.push(Scope::new(scope.to_string()));
self
}
pub fn with_pkce(mut self) -> Self {
self.pkce = Some(PkceCodeChallenge::new_random_sha256());
self
}
pub fn get_redirect_url(&self, client: &Client) -> (Url, CsrfToken) {
let mut redirect = client
.authorize_url(CsrfToken::new_random)
.add_scopes(self.scopes.clone());
if let Some((pkce_challenge, _)) = &self.pkce {
redirect = redirect.set_pkce_challenge(pkce_challenge.clone());
}
redirect.url()
}
pub async fn wait_for_redirection(
self,
client: &Client,
csrf_state: CsrfToken,
) -> Result<(String, Option<String>)> {
let (mut stream, _) =
TcpListener::bind((client.redirect_host.as_str(), client.redirect_port))
.await
.map_err(|err| {
Error::BindRedirectServerError(
client.redirect_host.clone(),
client.redirect_port,
err,
)
})?
.accept()
.await
.map_err(Error::AcceptRedirectServerError)?;
let code = {
let mut reader = BufReader::new(&mut stream);
let mut request_line = String::new();
reader.read_line(&mut request_line).await?;
let redirect_url = request_line
.split_whitespace()
.nth(1)
.ok_or_else(|| Error::MissingRedirectUrlError(request_line.clone()))?;
let redirect_url = format!("http://localhost{redirect_url}");
let redirect_url = Url::parse(&redirect_url)
.map_err(|err| Error::ParseRedirectUrlError(err, redirect_url.clone()))?;
let (_, state) = redirect_url
.query_pairs()
.find(|(key, _)| key == "state")
.ok_or_else(|| Error::FindStateInRedirectUrlError(redirect_url.clone()))?;
let state = CsrfToken::new(state.into_owned());
if state.secret() != csrf_state.secret() {
return Err(Error::InvalidStateError(
state.secret().to_owned(),
csrf_state.secret().to_owned(),
));
}
let (_, code) = redirect_url
.query_pairs()
.find(|(key, _)| key == "code")
.ok_or_else(|| Error::FindCodeInRedirectUrlError(redirect_url.clone()))?;
AuthorizationCode::new(code.into_owned())
};
let res = "Authentication successful!";
let res = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
res.len(),
res
);
stream.write_all(res.as_bytes()).await?;
let mut res = client.exchange_code(code);
if let Some((_, pkce_verifier)) = self.pkce {
res = res.set_pkce_verifier(pkce_verifier);
}
let res = res
.request_async(&Client::send_oauth2_request)
.await
.map_err(|err| match err {
RequestTokenError::Request(req) => Error::ExchangeCodeError(req.to_string()),
RequestTokenError::ServerResponse(res) => Error::ExchangeCodeError(res.to_string()),
RequestTokenError::Parse(err, _) => Error::ExchangeCodeError(err.to_string()),
RequestTokenError::Other(err) => Error::ExchangeCodeError(err),
})?;
let access_token = res.access_token().secret().to_owned();
let refresh_token = res.refresh_token().map(|t| t.secret().clone());
Ok((access_token, refresh_token))
}
}