use serde::{Deserialize, Serialize};
use url::Url;
use crate::error::Error;
use crate::pkce;
use crate::types::{Ppnum, PpnumId};
/// Default authorization endpoint used by `OAuthConfig::new`; override with `OAuthConfig::with_auth_url`.
const DEFAULT_AUTH_URL: &str = "https://accounts.ppoppo.com/oauth/authorize";
/// Default token endpoint used by `OAuthConfig::new`; override with `OAuthConfig::with_token_url`.
const DEFAULT_TOKEN_URL: &str = "https://accounts.ppoppo.com/oauth/token";
/// Default userinfo endpoint used by `OAuthConfig::new`; override with `OAuthConfig::with_userinfo_url`.
const DEFAULT_USERINFO_URL: &str = "https://accounts.ppoppo.com/oauth/userinfo";
/// Client-side settings for the OAuth 2.0 authorization-code + PKCE flow
/// driven by [`AuthClient`].
///
/// Build one with [`OAuthConfig::new`], adjust it with the `with_*` builder
/// methods, and read it back through the accessor methods.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct OAuthConfig {
/// Client identifier, sent as `client_id` in the authorization, token-exchange and refresh requests.
pub(crate) client_id: String,
/// Authorization endpoint; the authorization URL's query string is appended to it.
pub(crate) auth_url: Url,
/// Token endpoint, used for both code exchange and token refresh.
pub(crate) token_url: Url,
/// Userinfo endpoint, queried with a bearer access token.
pub(crate) userinfo_url: Url,
/// Redirect URI sent in both the authorization request and the code exchange.
pub(crate) redirect_uri: Url,
/// Requested scopes, space-joined into the authorization URL's `scope` parameter.
pub(crate) scopes: Vec<String>,
}
impl OAuthConfig {
#[must_use]
#[allow(clippy::expect_used)] pub fn new(client_id: impl Into<String>, redirect_uri: Url) -> Self {
Self {
client_id: client_id.into(),
redirect_uri,
auth_url: DEFAULT_AUTH_URL.parse().expect("valid default URL"),
token_url: DEFAULT_TOKEN_URL.parse().expect("valid default URL"),
userinfo_url: DEFAULT_USERINFO_URL.parse().expect("valid default URL"),
scopes: vec!["profile".into()],
}
}
#[must_use]
pub fn with_auth_url(mut self, url: Url) -> Self {
self.auth_url = url;
self
}
#[must_use]
pub fn with_token_url(mut self, url: Url) -> Self {
self.token_url = url;
self
}
#[must_use]
pub fn with_userinfo_url(mut self, url: Url) -> Self {
self.userinfo_url = url;
self
}
#[must_use]
pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
self.scopes = scopes;
self
}
#[must_use]
pub fn client_id(&self) -> &str {
&self.client_id
}
#[must_use]
pub fn auth_url(&self) -> &Url {
&self.auth_url
}
#[must_use]
pub fn token_url(&self) -> &Url {
&self.token_url
}
#[must_use]
pub fn userinfo_url(&self) -> &Url {
&self.userinfo_url
}
#[must_use]
pub fn redirect_uri(&self) -> &Url {
&self.redirect_uri
}
#[must_use]
pub fn scopes(&self) -> &[String] {
&self.scopes
}
}
/// HTTP client for the OAuth endpoints configured in an [`OAuthConfig`].
///
/// `Debug` and `Clone` are derived so the client can be logged and shared.
/// Cloning is cheap because `reqwest::Client` already keeps its connection
/// pool behind an `Arc`, and `OAuthConfig` holds only a few strings and URLs.
#[derive(Debug, Clone)]
pub struct AuthClient {
    config: OAuthConfig,
    http: reqwest::Client,
}
/// Result of [`AuthClient::authorization_url`]: the URL to send the user to,
/// plus the per-attempt values needed to finish the flow.
///
/// NOTE(review): this type has no `Debug` derive. That may be deliberate, to
/// keep `code_verifier` out of logs; confirm before adding one.
#[non_exhaustive]
pub struct AuthorizationRequest {
/// Authorization-endpoint URL carrying the response type, client id, redirect URI, `state`, PKCE challenge and scopes.
pub url: String,
/// `state` value generated for this attempt and embedded in `url`; presumably the caller checks it against the callback — verify.
pub state: String,
/// PKCE verifier whose S256 challenge is embedded in `url`; pass it to [`AuthClient::exchange_code`].
pub code_verifier: String,
}
/// Successful token-endpoint response, returned by [`AuthClient::exchange_code`]
/// and [`AuthClient::refresh_token`].
///
/// Only the fields below are read. Serde ignores any other JSON keys by
/// default, because `deny_unknown_fields` is not set.
///
/// NOTE(review): the derived `Debug` prints `access_token` and `refresh_token`
/// verbatim. Consider a redacting `Debug` impl if values of this type can reach logs.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenResponse {
/// Access token; [`AuthClient::get_user_info`] sends it as a bearer token.
pub access_token: String,
/// Token type as reported by the server; not validated here.
pub token_type: String,
/// Access-token lifetime if the server sent one (seconds per RFC 6749; not converted here).
#[serde(default)]
pub expires_in: Option<u64>,
/// Refresh token if the server issued one; usable with [`AuthClient::refresh_token`].
#[serde(default)]
pub refresh_token: Option<String>,
}
/// User profile returned by the userinfo endpoint via [`AuthClient::get_user_info`].
///
/// Only `sub` and `ppnum` are required when deserializing; the other fields
/// default to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct UserInfo {
/// Subject identifier of the user.
pub sub: PpnumId,
/// Email address, if the server included one.
#[serde(default)]
pub email: Option<String>,
/// The user's ppnum.
pub ppnum: Ppnum,
/// Whether the email has been verified, if the server reported it.
#[serde(default)]
pub email_verified: Option<bool>,
/// Account creation time, (de)serialized as an optional RFC 3339 timestamp.
#[serde(default, with = "time::serde::rfc3339::option")]
pub created_at: Option<time::OffsetDateTime>,
}
impl UserInfo {
    /// Creates a `UserInfo` holding only the required identifiers.
    ///
    /// The optional `email`, `email_verified` and `created_at` fields all start as `None`.
    #[must_use]
    pub fn new(sub: PpnumId, ppnum: Ppnum) -> Self {
        Self {
            sub,
            email: None,
            ppnum,
            email_verified: None,
            created_at: None,
        }
    }

    /// Returns `self` with `email` set to `Some(email)`.
    #[must_use]
    pub fn with_email(self, email: impl Into<String>) -> Self {
        let email = Some(email.into());
        Self { email, ..self }
    }

    /// Returns `self` with `email_verified` set to `Some(verified)`.
    #[must_use]
    pub fn with_email_verified(self, verified: bool) -> Self {
        Self {
            email_verified: Some(verified),
            ..self
        }
    }
}
impl AuthClient {
pub fn try_new(config: OAuthConfig) -> Result<Self, Error> {
let builder = reqwest::Client::builder();
#[cfg(not(target_arch = "wasm32"))]
let builder = builder
.timeout(std::time::Duration::from_secs(10))
.connect_timeout(std::time::Duration::from_secs(5));
Ok(Self {
config,
http: builder.build()?,
})
}
#[must_use]
pub fn with_http_client(config: OAuthConfig, client: reqwest::Client) -> Self {
Self {
config,
http: client,
}
}
#[must_use]
pub fn authorization_url(&self) -> AuthorizationRequest {
let state = pkce::generate_state();
let code_verifier = pkce::generate_code_verifier();
let code_challenge = pkce::generate_code_challenge(&code_verifier);
let scope = self.config.scopes.join(" ");
let mut url = self.config.auth_url.clone();
url.query_pairs_mut()
.append_pair("response_type", "code")
.append_pair("client_id", &self.config.client_id)
.append_pair("redirect_uri", self.config.redirect_uri.as_str())
.append_pair("state", &state)
.append_pair("code_challenge", &code_challenge)
.append_pair("code_challenge_method", "S256")
.append_pair("scope", &scope);
AuthorizationRequest {
url: url.into(),
state,
code_verifier,
}
}
pub async fn exchange_code(
&self,
code: &str,
code_verifier: &str,
) -> Result<TokenResponse, Error> {
let params = [
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", self.config.redirect_uri.as_str()),
("client_id", self.config.client_id.as_str()),
("code_verifier", code_verifier),
];
self.send_and_deserialize(
self.http.post(self.config.token_url.clone()).form(¶ms),
"token exchange",
)
.await
}
pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse, Error> {
let params = [
("grant_type", "refresh_token"),
("refresh_token", refresh_token),
("client_id", self.config.client_id.as_str()),
];
self.send_and_deserialize(
self.http.post(self.config.token_url.clone()).form(¶ms),
"token refresh",
)
.await
}
pub async fn get_user_info(&self, access_token: &str) -> Result<UserInfo, Error> {
self.send_and_deserialize(
self.http
.get(self.config.userinfo_url.clone())
.bearer_auth(access_token),
"userinfo request",
)
.await
}
async fn send_and_deserialize<T: serde::de::DeserializeOwned>(
&self,
request: reqwest::RequestBuilder,
operation: &'static str,
) -> Result<T, Error> {
let response = request.send().await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().await.unwrap_or_default();
return Err(Error::OAuth {
operation,
status: Some(status),
detail: body,
});
}
response.json::<T>().await.map_err(|e| Error::OAuth {
operation,
status: None,
detail: format!("response deserialization failed: {e}"),
})
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Config for a fake client that redirects back to example.com.
    fn test_config() -> OAuthConfig {
        let callback: Url = "https://example.com/callback".parse().unwrap();
        OAuthConfig::new("test-client", callback)
    }

    #[test]
    fn test_authorization_url_contains_pkce() {
        let client = AuthClient::try_new(test_config()).unwrap();
        let request = client.authorization_url();
        let required_fragments = [
            "code_challenge=",
            "code_challenge_method=S256",
            "state=",
            "response_type=code",
            "client_id=test-client",
        ];
        for fragment in required_fragments {
            assert!(
                request.url.contains(fragment),
                "`{fragment}` missing from {}",
                request.url
            );
        }
        assert!(!request.code_verifier.is_empty());
        assert!(!request.state.is_empty());
    }

    #[test]
    fn test_authorization_url_unique_per_call() {
        let client = AuthClient::try_new(test_config()).unwrap();
        let (first, second) = (client.authorization_url(), client.authorization_url());
        assert_ne!(first.state, second.state);
        assert_ne!(first.code_verifier, second.code_verifier);
    }

    #[test]
    fn test_config_constructor() {
        let callback: Url = "https://my-app.com/callback".parse().unwrap();
        let config = OAuthConfig::new("my-app", callback);
        assert_eq!(config.client_id(), "my-app");
        assert_eq!(config.redirect_uri().as_str(), "https://my-app.com/callback");
        assert_eq!(
            config.auth_url().as_str(),
            "https://accounts.ppoppo.com/oauth/authorize"
        );
    }

    #[test]
    fn test_config_with_overrides() {
        let callback: Url = "https://my-app.com/callback".parse().unwrap();
        let custom_authorize: Url = "https://custom.example.com/authorize".parse().unwrap();
        let config = OAuthConfig::new("my-app", callback)
            .with_auth_url(custom_authorize)
            .with_scopes(vec!["profile".into(), "email".into()]);
        assert_eq!(
            config.auth_url().as_str(),
            "https://custom.example.com/authorize"
        );
        assert_eq!(config.scopes(), &["profile", "email"]);
    }
}