openauth-core 0.0.4

Core types and primitives for OpenAuth.
Documentation
//! Plugin contracts for OpenAuth extensions.

use std::future::Future;
use std::pin::Pin;

mod db;
mod endpoint;
mod error;
mod hooks;
mod init;
mod password;
mod rate_limit;
mod schema;

pub use db::{
    PluginDatabaseAfterHookHandler, PluginDatabaseAfterInput, PluginDatabaseBeforeAction,
    PluginDatabaseBeforeHookHandler, PluginDatabaseBeforeInput, PluginDatabaseHook,
    PluginDatabaseHookContext, PluginDatabaseOperation, PluginMigration,
};
pub use endpoint::PluginEndpoint;
pub use error::PluginErrorCode;
pub use hooks::{
    PluginAfterHook, PluginAfterHookAction, PluginAfterHookFuture, PluginAfterHookHandler,
    PluginAsyncAfterHook, PluginAsyncAfterHookHandler, PluginAsyncBeforeHook,
    PluginAsyncBeforeHookHandler, PluginBeforeHook, PluginBeforeHookAction, PluginBeforeHookFuture,
    PluginBeforeHookHandler, PluginEndpointHooks, PluginHookMatcher,
};
pub use init::{PluginInitHandler, PluginInitOutput};
pub use password::{
    PluginPasswordValidationInput, PluginPasswordValidationRejection, PluginPasswordValidator,
    PluginPasswordValidatorFuture, PluginPasswordValidatorHandler,
};
pub use rate_limit::PluginRateLimitRule;
pub use schema::PluginSchemaContribution;

use crate::api::AsyncAuthEndpoint;
use crate::context::AuthContext;
use crate::error::OpenAuthError;
use http::{Request, Response};
use openauth_oauth::oauth2::SocialOAuthProvider;
use serde_json::Value;
use std::fmt;
use std::sync::Arc;

/// Raw byte payload carried by every plugin request and response.
pub type PluginBody = Vec<u8>;
/// HTTP request as exposed to plugin hooks and middleware.
pub type PluginRequest = Request<PluginBody>;
/// HTTP response as produced or transformed by plugins.
pub type PluginResponse = Response<PluginBody>;

/// Shared synchronous request hook (see `AuthPlugin::with_on_request`).
///
/// Takes ownership of the request and returns a [`PluginRequestAction`] or an error.
pub type PluginOnRequest = Arc<
    dyn Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError> + Send + Sync,
>;

/// Shared synchronous response hook (see `AuthPlugin::with_on_response`).
///
/// Borrows the originating request, takes ownership of the response, and returns a
/// response or an error.
pub type PluginOnResponse = Arc<
    dyn Fn(&AuthContext, &PluginRequest, PluginResponse) -> Result<PluginResponse, OpenAuthError>
        + Send
        + Sync,
>;

/// Shared synchronous middleware (see `AuthPlugin::with_middleware`).
///
/// Borrows the context and request and yields an optional response or an error.
pub type PluginMiddlewareHandler = Arc<
    dyn Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
        + Send
        + Sync,
>;

/// Boxed `Send` future produced by a [`PluginAsyncMiddlewareHandler`].
///
/// It may borrow its inputs for `'r`. It resolves to the same result type as a
/// [`PluginMiddlewareHandler`]: an optional response or an error.
pub type PluginMiddlewareFuture<'r> =
    Pin<Box<dyn Future<Output = Result<Option<PluginResponse>, OpenAuthError>> + Send + 'r>>;

/// Shared asynchronous middleware (see `AuthPlugin::with_async_middleware`).
///
/// The returned future may borrow both the context and the request.
pub type PluginAsyncMiddlewareHandler = Arc<
    dyn for<'r> Fn(&'r AuthContext, &'r PluginRequest) -> PluginMiddlewareFuture<'r> + Send + Sync,
>;

/// Definition of an OpenAuth extension.
///
/// A plugin bundles the endpoints, middleware, hooks, schema contributions, rate-limit
/// rules, error codes, database hooks, migrations, social providers and password
/// validators it contributes. Build one with [`AuthPlugin::new`] and the `with_*` builder
/// methods.
///
/// Every builder consumes the plugin and returns the updated value. Discarding that value
/// silently drops the whole plugin, hence `#[must_use]`.
#[must_use = "an AuthPlugin does nothing unless it is used; builder methods return the updated plugin"]
#[derive(Clone)]
pub struct AuthPlugin {
    /// Identifier passed to [`AuthPlugin::new`].
    pub id: String,
    /// Optional version string, set by [`AuthPlugin::with_version`].
    pub version: Option<String>,
    /// Optional free-form JSON options, set by [`AuthPlugin::with_options`].
    pub options: Option<Value>,
    /// Endpoints appended by [`AuthPlugin::with_endpoint`].
    pub endpoints: Vec<AsyncAuthEndpoint>,
    /// Path-bound synchronous middlewares appended by [`AuthPlugin::with_middleware`].
    pub middlewares: Vec<PluginMiddleware>,
    /// Path-bound asynchronous middlewares appended by [`AuthPlugin::with_async_middleware`].
    pub async_middlewares: Vec<PluginAsyncMiddleware>,
    /// Optional request hook, set by [`AuthPlugin::with_on_request`].
    pub on_request: Option<PluginOnRequest>,
    /// Optional response hook, set by [`AuthPlugin::with_on_response`].
    pub on_response: Option<PluginOnResponse>,
    /// Optional initialization handler, set by [`AuthPlugin::with_init`].
    pub init: Option<PluginInitHandler>,
    /// Schema contributions appended by [`AuthPlugin::with_schema`].
    pub schema: Vec<PluginSchemaContribution>,
    /// Rate-limit rules appended by [`AuthPlugin::with_rate_limit`].
    pub rate_limit: Vec<PluginRateLimitRule>,
    /// Sync and async before/after endpoint hooks, filled by the `with_*_hook` builders.
    pub hooks: PluginEndpointHooks,
    /// Error codes appended by [`AuthPlugin::with_error_code`].
    pub error_codes: Vec<PluginErrorCode>,
    /// Database hooks appended by [`AuthPlugin::with_database_hook`].
    pub database_hooks: Vec<PluginDatabaseHook>,
    /// Migrations appended by [`AuthPlugin::with_migration`].
    pub migrations: Vec<PluginMigration>,
    /// Social OAuth providers appended by [`AuthPlugin::with_social_provider`].
    pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
    /// Password validators appended by [`AuthPlugin::with_password_validator`].
    pub password_validators: Vec<PluginPasswordValidator>,
}

impl AuthPlugin {
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            version: None,
            options: None,
            endpoints: Vec::new(),
            middlewares: Vec::new(),
            async_middlewares: Vec::new(),
            on_request: None,
            on_response: None,
            init: None,
            schema: Vec::new(),
            rate_limit: Vec::new(),
            hooks: PluginEndpointHooks::default(),
            error_codes: Vec::new(),
            database_hooks: Vec::new(),
            migrations: Vec::new(),
            social_providers: Vec::new(),
            password_validators: Vec::new(),
        }
    }

    pub fn with_version(mut self, version: impl Into<String>) -> Self {
        self.version = Some(version.into());
        self
    }

    pub fn with_options(mut self, options: Value) -> Self {
        self.options = Some(options);
        self
    }

    pub fn with_endpoint(mut self, endpoint: AsyncAuthEndpoint) -> Self {
        self.endpoints.push(endpoint);
        self
    }

    pub fn with_init<F>(mut self, init: F) -> Self
    where
        F: Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync + 'static,
    {
        self.init = Some(Arc::new(init));
        self
    }

    pub fn with_schema(mut self, contribution: PluginSchemaContribution) -> Self {
        self.schema.push(contribution);
        self
    }

    pub fn with_rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
        self.rate_limit.push(rule);
        self
    }

    pub fn with_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
    where
        F: Fn(&AuthContext, PluginRequest) -> Result<PluginBeforeHookAction, OpenAuthError>
            + Send
            + Sync
            + 'static,
    {
        self.hooks.before.push(PluginBeforeHook {
            matcher: PluginHookMatcher::path(path),
            handler: Arc::new(hook),
        });
        self
    }

    pub fn with_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
    where
        F: Fn(
                &AuthContext,
                &PluginRequest,
                PluginResponse,
            ) -> Result<PluginAfterHookAction, OpenAuthError>
            + Send
            + Sync
            + 'static,
    {
        self.hooks.after.push(PluginAfterHook {
            matcher: PluginHookMatcher::path(path),
            handler: Arc::new(hook),
        });
        self
    }

    pub fn with_async_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
    where
        F: for<'a> Fn(&'a AuthContext, PluginRequest) -> PluginBeforeHookFuture<'a>
            + Send
            + Sync
            + 'static,
    {
        self.hooks.async_before.push(PluginAsyncBeforeHook {
            matcher: PluginHookMatcher::path(path),
            handler: Arc::new(hook),
        });
        self
    }

    pub fn with_async_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
    where
        F: for<'a> Fn(
                &'a AuthContext,
                &'a PluginRequest,
                PluginResponse,
            ) -> PluginAfterHookFuture<'a>
            + Send
            + Sync
            + 'static,
    {
        self.hooks.async_after.push(PluginAsyncAfterHook {
            matcher: PluginHookMatcher::path(path),
            handler: Arc::new(hook),
        });
        self
    }

    pub fn with_error_code(mut self, error_code: PluginErrorCode) -> Self {
        self.error_codes.push(error_code);
        self
    }

    pub fn with_database_hook(mut self, hook: PluginDatabaseHook) -> Self {
        self.database_hooks.push(hook);
        self
    }

    pub fn with_migration(mut self, migration: PluginMigration) -> Self {
        self.migrations.push(migration);
        self
    }

    pub fn with_social_provider(
        mut self,
        provider: impl Into<Arc<dyn SocialOAuthProvider>>,
    ) -> Self {
        self.social_providers.push(provider.into());
        self
    }

    pub fn with_password_validator<F>(mut self, validator: F) -> Self
    where
        F: for<'a> Fn(
                &'a AuthContext,
                PluginPasswordValidationInput,
            ) -> PluginPasswordValidatorFuture<'a>
            + Send
            + Sync
            + 'static,
    {
        self.password_validators.push(PluginPasswordValidator {
            handler: Arc::new(validator),
        });
        self
    }

    pub fn with_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
    where
        F: Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
            + Send
            + Sync
            + 'static,
    {
        self.middlewares.push(PluginMiddleware {
            path: path.into(),
            handler: Arc::new(middleware),
        });
        self
    }

    pub fn with_async_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
    where
        F: for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a>
            + Send
            + Sync
            + 'static,
    {
        self.async_middlewares.push(PluginAsyncMiddleware {
            path: path.into(),
            handler: Arc::new(middleware),
        });
        self
    }

    pub fn with_on_request<F>(mut self, hook: F) -> Self
    where
        F: Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError>
            + Send
            + Sync
            + 'static,
    {
        self.on_request = Some(Arc::new(hook));
        self
    }

    pub fn with_on_response<F>(mut self, hook: F) -> Self
    where
        F: Fn(
                &AuthContext,
                &PluginRequest,
                PluginResponse,
            ) -> Result<PluginResponse, OpenAuthError>
            + Send
            + Sync
            + 'static,
    {
        self.on_response = Some(Arc::new(hook));
        self
    }
}

impl fmt::Debug for AuthPlugin {
    /// Formats the plugin for diagnostics.
    ///
    /// The stored handler closures (`Arc<dyn Fn…>`) are not `Debug`, so they are rendered as
    /// placeholders. Endpoints are summarized as a count and social providers by their
    /// `id()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let endpoint_count = self.endpoints.len();
        let on_request = self.on_request.as_ref().map(|_| "<hook>");
        let on_response = self.on_response.as_ref().map(|_| "<hook>");
        let init = self.init.as_ref().map(|_| "<init>");
        let provider_ids: Vec<_> = self
            .social_providers
            .iter()
            .map(|provider| provider.id())
            .collect();

        let mut out = f.debug_struct("AuthPlugin");
        out.field("id", &self.id);
        out.field("version", &self.version);
        out.field("options", &self.options);
        out.field("endpoints", &endpoint_count);
        out.field("middlewares", &self.middlewares);
        out.field("async_middlewares", &self.async_middlewares);
        out.field("on_request", &on_request);
        out.field("on_response", &on_response);
        out.field("init", &init);
        out.field("schema", &self.schema);
        out.field("rate_limit", &self.rate_limit);
        out.field("hooks", &self.hooks);
        out.field("error_codes", &self.error_codes);
        out.field("database_hooks", &self.database_hooks);
        out.field("migrations", &self.migrations);
        out.field("social_providers", &provider_ids);
        out.field("password_validators", &self.password_validators);
        out.finish()
    }
}

/// A synchronous middleware registered by a plugin together with its path.
#[derive(Clone)]
pub struct PluginMiddleware {
    /// Path supplied to `AuthPlugin::with_middleware`.
    pub path: String,
    /// Shared handler taking the auth context and a borrowed request.
    pub handler: PluginMiddlewareHandler,
}

impl fmt::Debug for PluginMiddleware {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The handler is an opaque closure, so a fixed placeholder stands in for it.
        let mut out = f.debug_struct("PluginMiddleware");
        out.field("path", &self.path);
        out.field("handler", &"<middleware>");
        out.finish()
    }
}

/// An asynchronous middleware registered by a plugin together with its path.
#[derive(Clone)]
pub struct PluginAsyncMiddleware {
    /// Path supplied to `AuthPlugin::with_async_middleware`.
    pub path: String,
    /// Shared handler returning a boxed future that may borrow the context and request.
    pub handler: PluginAsyncMiddlewareHandler,
}

impl fmt::Debug for PluginAsyncMiddleware {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The handler is an opaque closure, so a fixed placeholder stands in for it.
        let mut out = f.debug_struct("PluginAsyncMiddleware");
        out.field("path", &self.path);
        out.field("handler", &"<async middleware>");
        out.finish()
    }
}

/// Result of a [`PluginOnRequest`] hook.
///
/// The hook returns either a request to continue with or a response to answer with.
/// Both payloads are `http` types that implement `Debug`, so this enum does too.
#[derive(Debug)]
pub enum PluginRequestAction {
    /// Continue with the contained request, which the hook may have rewritten.
    Continue(PluginRequest),
    /// Answer with the contained response.
    Respond(PluginResponse),
}