skyhook 0.5.51

Application server for Ordinary
Documentation
use std::sync::Arc;

use aes_gcm::{
    Aes256Gcm, Key,
    aead::{Aead, AeadCore, OsRng},
};
use axum::body::Bytes;
use axum::extract::{Form, State};
use axum::http::StatusCode;
use axum::http::header::CONTENT_TYPE;
use axum::response::IntoResponse;
use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use bytes::{BufMut, BytesMut};
use ordinary_auth::AuthClient;
use ordinary_config::ClientPasswordHash;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use x25519_dalek::{EphemeralSecret, PublicKey};

/// First round of the registration protocol for the wasm client.
///
/// The raw request body is forwarded unchanged to `state.auth.registration_start`. The server
/// message it produces is returned with `200 OK`. On failure the error is logged and an empty
/// body is returned with `500 Internal Server Error`.
pub async fn start(
    State(state): State<Arc<crate::server::OrdinaryAppServerState>>,
    body: Bytes,
) -> impl IntoResponse {
    let auth_span = tracing::info_span!("auth", flv = %"wasm");

    auth_span.in_scope(|| {
        state
            .auth
            .registration_start(body, None, None)
            .map(|server_message| (StatusCode::OK, server_message))
            .unwrap_or_else(|e| {
                tracing::error!("{e}");
                (StatusCode::INTERNAL_SERVER_ERROR, Bytes::new())
            })
    })
}

pub async fn finish(
    State(state): State<Arc<crate::server::OrdinaryAppServerState>>,
    body: Bytes,
) -> impl IntoResponse {
    let span = tracing::info_span!("auth", flv = %"wasm");

    span.in_scope(|| match state.auth.registration_finish(body, None) {
        Ok((v, account_bytes, _invite_claims)) => {
            if let Ok(account) = std::str::from_utf8(&account_bytes) {
                let mut builder =
                    flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
                let mut vec_builder = builder.start_vector();
                vec_builder.push(account);

                // empty claims
                vec_builder.push(());

                vec_builder.end_vector();

                for action in &state.registration_actions {
                    if let Some(action) = state.actions.get(*action as usize) {
                        let span = tracing::info_span!("action");

                        span.in_scope(|| {
                            if let Err(err) = action.call(builder.view(), &state.actions) {
                                tracing::error!(%err);
                            }
                        });
                    }
                }
            }

            (StatusCode::OK, v)
        }
        Err(e) => {
            tracing::error!("{e}");
            (StatusCode::INTERNAL_SERVER_ERROR, Bytes::new())
        }
    })
}

/// Fields submitted by the `<noscript>` registration form.
///
/// `Debug` is implemented by hand instead of derived so that the plaintext password and the
/// invite token are never written to logs or panic messages. The output still shows whether an
/// invite token was supplied.
#[derive(Deserialize)]
pub struct RegisterForm {
    /// Account identifier as entered by the user.
    account: String,
    /// Plaintext password as entered by the user.
    password: String,
    /// Optional invite token, expected as URL-safe base64 without padding.
    invite_token: Option<String>,
}

impl std::fmt::Debug for RegisterForm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegisterForm")
            .field("account", &self.account)
            .field("password", &"<redacted>")
            .field(
                "invite_token",
                &self.invite_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

#[allow(clippy::too_many_lines)]
/// for <noscript> users
pub async fn form(
    State(state): State<Arc<crate::server::OrdinaryAppServerState>>,
    Form(register_form): Form<RegisterForm>,
) -> impl IntoResponse {
    let span = tracing::info_span!("auth", flv = %"noscript");

    span.in_scope(|| {
        let account = register_form.account.as_bytes();

        if account.len() > 255 {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                [(CONTENT_TYPE, "text/plain".to_string())],
                Bytes::copy_from_slice(b"encoded account cannot be more than 255 bytes!"),
            );
        }

        let password = register_form.password.as_bytes();
        let app_name = state.config.domain.as_bytes();

        let mut input = app_name.to_vec();
        input.extend_from_slice(account);
        input.extend_from_slice(password);

        let mut password = vec![];

        if let Some(auth) = &state.config.auth {
            match &auth.client_hash {
                ClientPasswordHash::Sha256 => {
                    let mut hasher = Sha256::new();
                    hasher.update(&input);
                    password = hasher.finalize().to_vec();
                }
            }
        }

        let mut invite_token = None;

        if state.auth.config.invite.is_some()
            && let Some(token) = register_form.invite_token
        {
            match b64.decode(token) {
                Ok(decoded) => {
                    invite_token = Some(Bytes::copy_from_slice(&decoded[..]));
                }
                Err(err) => {
                    tracing::error!(%err);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        [(CONTENT_TYPE, "text/plain".to_string())],
                        Bytes::copy_from_slice(b"error!"),
                    );
                }
            }
        }

        match AuthClient::registration_start_req(account, &password[..], invite_token) {
            Ok((client_state, req)) => match state.auth.registration_start(req, None, None) {
                Ok(server_message) => match AuthClient::registration_finish_req(
                    account,
                    &password[..],
                    &client_state,
                    &server_message[..],
                ) {
                    Ok((private_key, req)) => match state.auth.registration_finish(req, None) {
                        Ok((res, account_bytes, _invite_claims)) => {
                            if let Ok(account) = std::str::from_utf8(&account_bytes) {
                                let mut builder = flexbuffers::Builder::new(
                                    &flexbuffers::BuilderOptions::SHARE_NONE,
                                );
                                let mut vec_builder = builder.start_vector();
                                vec_builder.push(account);

                                // empty claims
                                vec_builder.push(());

                                vec_builder.end_vector();

                                for action in &state.registration_actions {
                                    if let Some(action) = state.actions.get(*action as usize) {
                                        let span = tracing::info_span!("action");

                                        span.in_scope(|| {
                                            if let Err(err) =
                                                action.call(builder.view(), &state.actions)
                                            {
                                                tracing::error!(%err);
                                            }
                                        });
                                    }
                                }
                            }

                            match AuthClient::decrypt_totp_mfa_to_qr_svg(
                                &res,
                                private_key,
                                state.config.domain.clone(),
                                register_form.account,
                            ) {
                                Ok((qr_code, recovery_codes)) => {
                                    if let Some(idx) = state.mfa_totp_template_idx
                                        && let Some(template) = state.templates.get(idx as usize)
                                    {
                                        match std::str::from_utf8(&account_bytes) {
                                            Ok(account) => {
                                                match template.render(
                                                    "/totp".into(),
                                                    None,
                                                    None,
                                                    Some((
                                                        qr_code.clone(),
                                                        account.to_string(),
                                                        recovery_codes,
                                                    )),
                                                    &None,
                                                ) {
                                                    Ok(res) => {
                                                        return (
                                                            StatusCode::OK,
                                                            [(
                                                                CONTENT_TYPE,
                                                                template.config.mime.clone(),
                                                            )],
                                                            res,
                                                        );
                                                    }
                                                    Err(err) => tracing::error!(%err),
                                                }
                                            }
                                            Err(err) => tracing::error!(%err),
                                        }
                                    }

                                    return (
                                        StatusCode::OK,
                                        [(CONTENT_TYPE, "text/html".to_string())],
                                        Bytes::copy_from_slice(qr_code.as_bytes()),
                                    );
                                }
                                Err(err) => tracing::error!(%err),
                            }
                        }
                        Err(err) => tracing::error!(%err),
                    },
                    Err(err) => tracing::error!(%err),
                },
                Err(err) => tracing::error!(%err),
            },
            Err(err) => tracing::error!(%err),
        }

        (
            StatusCode::INTERNAL_SERVER_ERROR,
            [(CONTENT_TYPE, "text/plain".to_string())],
            Bytes::copy_from_slice(b"error!"),
        )
    })
}

#[allow(clippy::too_many_lines)]
pub async fn hash_only(
    State(state): State<Arc<crate::server::OrdinaryAppServerState>>,
    body: Bytes,
) -> impl IntoResponse {
    let span = tracing::info_span!("auth", flv = %"js");

    span.in_scope(|| {
        if let Some(account_len) = body.first() {
            let account_len = *account_len as usize;

            if body.len() < account_len + 1 + 32 + 32 {
                return (StatusCode::INTERNAL_SERVER_ERROR, Bytes::new());
            }

            let account_bytes = &body[1..=account_len];
            let password = &body[account_len + 1..account_len + 1 + 32];
            let public_key: [u8; 32] =
                match body[account_len + 1 + 32..account_len + 1 + 32 + 32].try_into() {
                    Ok(pk) => pk,
                    Err(err) => {
                        tracing::error!(%err);
                        return (StatusCode::EXPECTATION_FAILED, Bytes::new());
                    }
                };

            let mut invite_token = None;

            if state.auth.config.invite.is_some() {
                if body.len() < account_len + 1 + 32 + 32 + 1 {
                    return (StatusCode::INTERNAL_SERVER_ERROR, Bytes::new());
                }

                let token = &body[account_len + 1 + 32 + 32..];
                invite_token = Some(Bytes::copy_from_slice(token));
            }

            match AuthClient::registration_start_req(account_bytes, password, invite_token) {
                Ok((client_state, req)) => match state.auth.registration_start(req, None, None) {
                    Ok(server_message) => match AuthClient::registration_finish_req(
                        account_bytes,
                        password,
                        &client_state,
                        &server_message[..],
                    ) {
                        Ok((private_key, req)) => {
                            match state.auth.registration_finish(req, None) {
                                Ok((res, account_bytes, _invite_claims)) => {
                                    if let Ok(account) = std::str::from_utf8(&account_bytes) {
                                        let mut builder = flexbuffers::Builder::new(
                                            &flexbuffers::BuilderOptions::SHARE_NONE,
                                        );
                                        let mut vec_builder = builder.start_vector();
                                        vec_builder.push(account);

                                        // empty claims
                                        vec_builder.push(());

                                        vec_builder.end_vector();

                                        for action in &state.registration_actions {
                                            if let Some(action) =
                                                state.actions.get(*action as usize)
                                            {
                                                let span = tracing::info_span!("action");

                                                span.in_scope(|| {
                                                    if let Err(err) =
                                                        action.call(builder.view(), &state.actions)
                                                    {
                                                        tracing::error!(%err);
                                                    }
                                                });
                                            }
                                        }

                                        match AuthClient::decrypt_totp_mfa_to_qr_svg(
                                            &res,
                                            private_key,
                                            state.config.domain.clone(),
                                            account.into(),
                                        ) {
                                            Ok((qr_code, recovery_codes)) => {
                                                use aes_gcm::aead::KeyInit;

                                                let public_key = PublicKey::from(public_key);
                                                let ephemeral_secret =
                                                    EphemeralSecret::random_from_rng(OsRng);
                                                let ephemeral_public_key =
                                                    PublicKey::from(&ephemeral_secret);

                                                let shared_secret =
                                                    ephemeral_secret.diffie_hellman(&public_key);

                                                let key = Key::<Aes256Gcm>::from_slice(
                                                    shared_secret.as_bytes(),
                                                );

                                                let cipher = Aes256Gcm::new(key);
                                                let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 96-bits; unique per message

                                                let ciphertext = match cipher.encrypt(
                                                    &nonce,
                                                    format!("{recovery_codes}__{qr_code}").as_ref(),
                                                ) {
                                                    Ok(v) => v,
                                                    Err(err) => {
                                                        tracing::error!(%err);
                                                        return (
                                                            StatusCode::INTERNAL_SERVER_ERROR,
                                                            Bytes::new(),
                                                        );
                                                    }
                                                };

                                                let mut res = BytesMut::new();

                                                res.put(&ephemeral_public_key.as_bytes()[..]);
                                                res.put(&nonce[..]);
                                                res.put(&ciphertext[..]);

                                                return (StatusCode::OK, res.into());
                                            }
                                            Err(err) => tracing::error!(%err),
                                        }
                                    }
                                }
                                Err(err) => tracing::error!(%err),
                            }
                        }
                        Err(err) => tracing::error!(%err),
                    },
                    Err(err) => tracing::error!(%err),
                },
                Err(err) => tracing::error!(%err),
            }
        }

        (StatusCode::INTERNAL_SERVER_ERROR, Bytes::new())
    })
}