use activitystreams_vocabulary::{field_access, impl_default, impl_display};
use chrono::Utc;
use oauth::primitives::scope::Scope as OAuthScope;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use crate::app::oauth::{OAuthGrantType, Scope, ScopeList};
use crate::crypto::{Password, SymmetricKey};
use crate::db::{DateTime, Db, Iri, IriList, Key, OAuthGrantTypeList, Transaction, Uuid, UuidList};
use crate::{Error, Result, impl_sql_list_field, impl_sql_record, util};
mod auth_method;
mod response;
mod secret;
pub use auth_method::*;
pub use response::*;
pub use secret::*;
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize, sqlx::Type)]
#[sqlx(type_name = "oauth_client_type", rename_all = "lowercase")]
pub enum OAuthClientType {
Private,
Public,
}
impl OAuthClientType {
    /// String form of [`OAuthClientType::Private`].
    pub const PRIVATE: &'static str = "private";
    /// String form of [`OAuthClientType::Public`].
    pub const PUBLIC: &'static str = "public";

    /// Returns the default client type, [`OAuthClientType::Private`].
    #[inline]
    pub const fn new() -> Self {
        Self::Private
    }

    /// Returns the lowercase name of this client type (`PRIVATE` or `PUBLIC`).
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Public => Self::PUBLIC,
            Self::Private => Self::PRIVATE,
        }
    }
}

// `Default` delegates to `new`; `Display` uses the string form (macro `str` mode).
impl_default!(OAuthClientType);
impl_display!(OAuthClientType, str);
/// Client registration request.
///
/// NOTE(review): the field names match the client metadata of RFC 7591 dynamic client
/// registration (`redirect_uris`, `scope`, `jwks`); presumably this is the parsed
/// registration request body — confirm against the HTTP handler.
///
/// Every field is optional, and unset fields are omitted when serialized.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct OAuthClientRegister {
/// Redirect URIs to store on the new client.
#[serde(skip_serializing_if = "Option::is_none")]
redirect_uris: Option<Vec<Iri>>,
/// Requested scope; converted into the client's `ScopeList` during registration.
#[serde(skip_serializing_if = "Option::is_none")]
scope: Option<OAuthScope>,
/// JSON Web Key Set; each key that parses as a `Key` is stored and linked to the client.
#[serde(skip_serializing_if = "Option::is_none")]
jwks: Option<jwt::jwk::JwkSet>,
}
impl OAuthClientRegister {
    /// Creates a registration request in which no metadata has been supplied.
    #[inline]
    pub const fn new() -> Self {
        Self {
            jwks: None,
            scope: None,
            redirect_uris: None,
        }
    }
}

// `redirect_uris` is exposed as a borrowed slice (`Option<&[Iri]>`) rather than `&Vec<Iri>`.
field_access! {
    OAuthClientRegister {
        redirect_uris: option_deref { &[Iri], Vec<Iri> },
    }
}

// `scope` and `jwks` are exposed as optional references to the inner values.
field_access! {
    OAuthClientRegister {
        scope: option_ref { OAuthScope },
        jwks: option_ref { jwt::jwk::JwkSet },
    }
}

// `Default` delegates to `new`; `Display` renders the request as JSON.
impl_default!(OAuthClientRegister);
impl_display!(OAuthClientRegister, json);
/// Owned variant of the registration conversion.
///
/// Delegates to the `&OAuthClientRegister` implementation; the request is dropped afterwards.
impl TryFrom<OAuthClientRegister> for OAuthClient {
    type Error = Error;

    fn try_from(val: OAuthClientRegister) -> Result<Self> {
        Self::try_from(&val)
    }
}
/// Builds a new, not-yet-persisted client from a registration request.
///
/// The client gets a random UUID and the request's redirect URIs (when present). When a
/// scope was requested, it also gets the scope list parsed from that scope. Every other
/// field keeps the value set by `OAuthClient::new`.
///
/// # Errors
///
/// Fails if the redirect URIs cannot be stored, or if the requested scope cannot be
/// converted into a `ScopeList` or applied to the client.
impl TryFrom<&OAuthClientRegister> for OAuthClient {
    type Error = Error;

    fn try_from(val: &OAuthClientRegister) -> Result<Self> {
        let mut registered = Self::new().with_uuid(util::rand_uuid());

        if let Some(redirect_uris) = val.redirect_uris() {
            registered.set_redirect_uris(redirect_uris)?;
        }

        match val.scope() {
            Some(requested) => registered.with_scopes(ScopeList::try_from(requested)?),
            None => Ok(registered),
        }
    }
}
/// A registered OAuth client.
///
/// The same struct is decoded from database rows (`FromRow`) and converted to/from JSON via
/// serde. The password is skipped by serde and never appears in the JSON form.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize, FromRow)]
// NOTE(review): `type_name` is normally a `sqlx::Type` attribute; it may have no effect on
// the `FromRow` derive — confirm before relying on it.
#[sqlx(type_name = "oauth_client")]
pub struct OAuthClient {
/// Client UUID; serialized under the OAuth name `client_id`.
#[serde(
rename = "client_id",
serialize_with = "util::ser_uuid",
deserialize_with = "util::de_uuid"
)]
uuid: Uuid,
/// UUID of the client's owner (the `owner_id` passed at registration).
#[serde(serialize_with = "util::ser_uuid", deserialize_with = "util::de_uuid")]
owner_id: Uuid,
/// Private or public client.
client_type: OAuthClientType,
/// Scopes granted to the client; must be non-empty to pass `check_db`.
scopes: ScopeList,
/// Password derived from the client secret at registration; excluded from JSON.
#[serde(skip)]
password: Password,
/// Time the client was created (set to "now" by `OAuthClient::new`).
issued_at: DateTime,
/// Redirect URIs registered for the client.
redirect_uris: Vec<Iri>,
/// Allowed grant types; `OAuthClient::new` sets authorization code and refresh token.
grant_types: Vec<OAuthGrantType>,
/// UUIDs of the keys stored for this client from the registration JWKS.
#[serde(
serialize_with = "util::ser_uuid_list",
deserialize_with = "util::de_uuid_list"
)]
key_ids: Vec<Uuid>,
}
impl OAuthClient {
    /// Creates a client with placeholder values.
    ///
    /// - `uuid` and `owner_id` are nil.
    /// - The client type is [`OAuthClientType::new`].
    /// - The scope list and password start from their `new` constructors.
    /// - `issued_at` is the current time.
    /// - The grant types are authorization code and refresh token.
    pub fn new() -> Self {
        Self {
            uuid: Uuid::nil(),
            owner_id: Uuid::nil(),
            client_type: OAuthClientType::new(),
            scopes: ScopeList::new(),
            password: Password::new(),
            issued_at: Utc::now().into(),
            redirect_uris: Vec::new(),
            grant_types: vec![
                OAuthGrantType::AuthorizationCode,
                OAuthGrantType::RefreshToken,
            ],
            key_ids: Vec::new(),
        }
    }

    /// Returns the OAuth `client_id`, which is the client's UUID.
    #[inline]
    pub const fn client_id(&self) -> Uuid {
        self.uuid
    }

    /// Converts the stored scope list into an [`OAuthScope`].
    ///
    /// # Errors
    ///
    /// Returns an HTTP error if the scope list cannot be converted.
    pub fn oauth_scope(&self) -> Result<OAuthScope> {
        (&self.scopes)
            .try_into()
            .map_err(|err| Error::http(format!("oauth: client: {err}")))
    }

    /// Checks that the password and the scope list are both non-empty.
    ///
    /// # Errors
    ///
    /// Returns a database error naming the first empty field found (password first).
    pub fn check_db(&self) -> Result<()> {
        if self.password.is_empty() {
            Err(Error::db("oauth: client: empty password"))
        } else if self.scopes.is_empty() {
            Err(Error::db("oauth: client: empty scopes"))
        } else {
            Ok(())
        }
    }

    /// Registers a new client in its own transaction and commits it.
    ///
    /// See [`OAuthClient::try_from_register_tx`] for what is created. If any step fails,
    /// the transaction is dropped without being committed.
    ///
    /// # Errors
    ///
    /// Propagates pool/key/transaction-start errors and errors from `try_from_register_tx`.
    /// Returns a database error if the commit fails.
    pub async fn try_from_register(
        db: &Db,
        uri: &Iri,
        owner_id: Uuid,
        client_secret: &ClientSecret,
        val: &OAuthClientRegister,
    ) -> Result<Self> {
        let pool = db.pool()?;
        let db_key = db.key()?;
        let mut dbtx = pool.begin().await?;
        let client =
            Self::try_from_register_tx(&mut dbtx, &db_key, uri, owner_id, client_secret, val)
                .await?;
        dbtx.commit()
            .await
            .map(|_| client)
            .map_err(|err| Error::db(format!("oauth: client: {err}")))
    }

    /// Registers a new client inside an existing transaction.
    ///
    /// Steps:
    /// 1. Converts `val` into a client (random UUID) and sets its `owner_id`.
    /// 2. Derives the password from the client UUID and `client_secret`.
    /// 3. Handles each JWK in `val.jwks` that converts into a `Key`: it gets a fresh random
    ///    id built from `uri`, is linked to this client as its actor, and is inserted.
    ///    JWKs that fail to convert are logged and skipped.
    /// 4. Records the inserted key UUIDs on the client and inserts the client.
    ///
    /// # Errors
    ///
    /// Propagates errors from the conversion, password derivation, id construction and
    /// client insertion. Key-insertion failures are wrapped as database errors.
    pub async fn try_from_register_tx(
        dbtx: &mut Transaction<'_>,
        db_key: &SymmetricKey,
        uri: &Iri,
        owner_id: Uuid,
        client_secret: &ClientSecret,
        val: &OAuthClientRegister,
    ) -> Result<Self> {
        let mut client = Self::try_from(val).map(|c| c.with_owner_id(owner_id))?;
        let password = Password::derive(
            &client.uuid.to_string(),
            client_secret.to_string().as_bytes(),
        )?;
        // Move the derived password into the client: it is not needed afterwards, so
        // there is no reason to keep a second copy of the secret material alive.
        client.set_password(password);

        let client_id = Self::TABLE.id_from_uuid(uri, client.uuid)?;
        let client_entry = client.table_entry();
        let mut key_ids = Vec::new();
        if let Some(jwks) = val.jwks() {
            // Best-effort: a JWK that cannot be converted into a `Key` is logged and
            // skipped instead of failing the whole registration.
            let keys = jwks.keys.iter().filter_map(|jwk| {
                Key::try_from(jwk)
                    .map(|k| k.with_actor_id(client_id.clone()).with_actor(client_entry))
                    .map_err(|err| log::warn!("oauth: client: error parsing JWK: {err}"))
                    .ok()
            });
            for mut key in keys {
                let key_id = Key::TABLE.id_from_uuid(uri, util::rand_uuid())?;
                key.set_id(key_id);
                let stored_key_uuid = key
                    .insert_tx(dbtx, db_key)
                    .await
                    .map_err(|err| Error::db(format!("oauth: client: {err}")))?;
                key_ids.push(stored_key_uuid);
            }
        }
        client.set_key_ids(key_ids)?;
        client.insert_tx(dbtx).await.map(|_| client)
    }

    /// Loads this client's keys in a dedicated transaction, which is committed after the lookup.
    ///
    /// # Errors
    ///
    /// Propagates pool/key/transaction-start and lookup errors. Returns a database error if
    /// the commit fails.
    pub async fn keys(&self, db: &Db) -> Result<Vec<Key>> {
        let pool = db.pool()?;
        let db_key = db.key()?;
        let mut dbtx = pool.begin().await?;
        let keys = self.keys_tx(&mut dbtx, &db_key).await?;
        dbtx.commit()
            .await
            .map(|_| keys)
            .map_err(|err| Error::db(format!("oauth: client: {err}")))
    }

    /// Loads the keys whose actor is this client's table entry, within an existing transaction.
    pub async fn keys_tx(
        &self,
        dbtx: &mut Transaction<'_>,
        db_key: &SymmetricKey,
    ) -> Result<Vec<Key>> {
        Key::find_by_actor_tx(dbtx, db_key, self.table_entry()).await
    }
}
// Accessors for the scalar fields. This file relies on the generated `with_uuid` and
// `with_owner_id` builders.
field_access! {
OAuthClient {
uuid: Uuid,
owner_id: Uuid,
client_type: OAuthClientType,
issued_at: DateTime,
}
}
// Password accessors; the field is borrowed (`as_ref`). `set_password` is used during
// registration.
field_access! {
OAuthClient {
password: as_ref { Password },
}
}
// SQL record mapping: each field is bound to its column name and SQL type.
// NOTE(review): presumably this also provides `TABLE`, `table_entry` and `insert_tx`, as
// used in `impl OAuthClient`; `uuid` is not listed here, so it is likely handled by the
// macro itself — confirm against `impl_sql_record!`.
impl_sql_record! {
OAuthClient {
owner_id: { "owner_id" Uuid },
client_type: { "client_type" OAuthClientType },
scopes: { "scopes" ScopeList },
password: { "password" Password },
issued_at: { "issued_at" DateTime },
redirect_uris: { "redirect_uris" IriList },
grant_types: { "grant_types" OAuthGrantType },
key_ids: { "key_ids" UuidList },
}
}
// List-field helpers (singular item name, plural field name). These supply the
// `set_redirect_uris`, `with_scopes` and `set_key_ids` methods used in this file, all of
// which return a `Result`.
impl_sql_list_field! {
OAuthClient {
scope, scopes: { "scopes" Scope },
redirect_uri, redirect_uris: { "redirect_uris" Iri },
grant_type, grant_types: { "grant_types" OAuthGrantType },
key_id, key_ids: { "key_ids" Uuid },
}
}
// `Default` delegates to `OAuthClient::new`; `Display` renders the client as JSON.
impl_default!(OAuthClient);
impl_display!(OAuthClient, json);