cronback-api-srv 0.1.0

API server for cronback backend
Documentation
use std::fmt::Display;
use std::str::FromStr;

use base64::Engine;
use chrono::Utc;
use cronback_api_model::admin::CreateAPIkeyRequest;
use lib::database::models::api_keys;
use lib::database::DatabaseError;
use lib::prelude::ValidShardedId;
use lib::types::ProjectId;
use sha2::{Digest, Sha512};
use thiserror::Error;
use tracing::error;
use uuid::Uuid;

use crate::auth_store::AuthStore;
use crate::errors::ApiError;

/// Literal prefix carried by every serialized API key. Parsing strips it
/// (and rejects keys that lack it); `unsafe_to_string` prepends it.
pub static API_KEY_PREFIX: &str = "sk_";

/// Failures that can occur while creating, checking, revoking or listing
/// API keys.
#[derive(Error, Debug)]
pub enum AuthError {
    /// The underlying key store reported a database failure.
    #[error("database error: {0}")]
    Database(#[from] DatabaseError),
    /// The presented key was rejected (unknown key id or mismatched secret).
    /// The payload is a human-readable reason.
    #[error("auth failed: {0}")]
    AuthFailed(String),
    /// Something unexpected on our side, e.g. a stored hash version that
    /// cannot be parsed.
    #[error("internal error: {0}")]
    Internal(String),
}

impl From<AuthError> for ApiError {
    fn from(value: AuthError) -> Self {
        match value {
            | AuthError::Database(e) => {
                error!("{}", e);
                ApiError::ServiceUnavailable
            }
            | AuthError::Internal(e) => {
                error!("{}", e);
                ApiError::ServiceUnavailable
            }
            | AuthError::AuthFailed(_) => ApiError::Unauthorized,
        }
    }
}

/// Issues, verifies, revokes and lists API keys on top of an `AuthStore`.
pub struct Authenticator {
    // Backing storage for key records; boxed trait object so different
    // store implementations can be plugged in.
    store: Box<dyn AuthStore + Send + Sync>,
}

impl Authenticator {
    pub fn new(store: Box<dyn AuthStore + Send + Sync>) -> Self {
        Self { store }
    }

    pub async fn gen_key(
        &self,
        req: CreateAPIkeyRequest,
        project: &ValidShardedId<ProjectId>,
    ) -> Result<SecretApiKey, AuthError> {
        let key = SecretApiKey::generate();
        let hashed = key.hash(HashVersion::default());

        let model = api_keys::Model {
            key_id: hashed.key_id,
            hash: hashed.hash,
            hash_version: hashed.hash_version.to_string(),
            project_id: project.clone(),
            name: req.key_name,
            created_at: Utc::now(),
            metadata: api_keys::Metadata {
                creator_user_id: req.metadata.creator_user_id,
            },
        };

        self.store.save_key(model).await?;

        Ok(key)
    }

    pub async fn authenticate(
        &self,
        user_provided_secret: &SecretApiKey,
    ) -> Result<ValidShardedId<ProjectId>, AuthError> {
        let key_model =
            self.store.get_key(user_provided_secret.key_id()).await?;

        let Some(key_model) = key_model else {
            // key_id doesn't exist in the database
            return Err(AuthError::AuthFailed(
                "key_id doesn't exist".to_string(),
            ));
        };

        let hash_version = key_model.hash_version;
        let hash_version: HashVersion = hash_version.parse().map_err(|_| {
            AuthError::Internal(format!("Unknown version: {hash_version}"))
        })?;

        let user_provided_hash = user_provided_secret.hash(hash_version);
        let stored_hash = key_model.hash;

        if user_provided_hash.hash != stored_hash {
            return Err(AuthError::AuthFailed(
                "Mismatched secret key".to_string(),
            ));
        }

        Ok(key_model.project_id)
    }

    pub async fn revoke_key(
        &self,
        key_id: &str,
        project: &ValidShardedId<ProjectId>,
    ) -> Result<bool, AuthError> {
        let res = self.store.delete_key(key_id, project).await?;
        Ok(res)
    }

    pub async fn list_keys(
        &self,
        project: &ValidShardedId<ProjectId>,
    ) -> Result<Vec<api_keys::Model>, AuthError> {
        let res = self.store.list_keys(project).await?;
        Ok(res)
    }
}

/// Identifies the hashing scheme applied to an API key secret. The textual
/// form (see `Display`/`FromStr`) is what gets persisted next to the hash.
#[derive(Default)]
enum HashVersion {
    /// The only scheme so far, and the default for new keys.
    #[default]
    V1,
}

impl HashVersion {
    /// Textual tag of this version.
    const fn as_str(&self) -> &'static str {
        match self {
            | Self::V1 => "v1",
        }
    }
}

impl Display for HashVersion {
    /// Writes the version tag (e.g. `v1`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl FromStr for HashVersion {
    type Err = String;

    /// Parses a version tag; anything other than an exact, case-sensitive
    /// match yields an `Invalid version: …` message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            | "v1" => Ok(Self::V1),
            | unknown => Err(format!("Invalid version: {unknown}")),
        }
    }
}

/// Output of `SecretApiKey::hash`: the non-secret parts of a key that are
/// compared against, or written to, the key store.
struct HashedApiKey {
    /// Identifier of the key, copied from the secret key it was derived from.
    pub key_id: String,
    /// Base64-encoded digest of the plaintext secret.
    pub hash: String,
    /// Scheme that produced `hash`.
    pub hash_version: HashVersion,
}

/// A plaintext API key as presented by (or handed to) a user.
///
/// To avoid leaking the plaintext key anywhere, this struct keeps its fields
/// private and doesn't implement Debug/Display. The only way to render the
/// full key is the explicitly named `unsafe_to_string`.
#[cfg_attr(test, derive(Clone))]
pub struct SecretApiKey {
    // Identifier used to look the key up in the store.
    key_id: String,
    // The secret part; only ever stored in hashed form.
    plain_secret: String,
}

impl FromStr for SecretApiKey {
    type Err = String;

    /// Parses a serialized key of the form `<prefix><key_id>_<secret>`.
    ///
    /// The key id is everything between the prefix and the first `_` that
    /// follows it; the secret is the remainder and may itself contain `_`.
    ///
    /// # Errors
    /// Returns a message when the prefix is missing or when no `_`
    /// separates the key id from the secret.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let unprefixed = s.strip_prefix(API_KEY_PREFIX).ok_or_else(|| {
            format!("API key doesn't start with {}", API_KEY_PREFIX)
        })?;

        let (key_id, plain_secret) = unprefixed
            .split_once('_')
            .ok_or_else(|| "Malformed API key".to_string())?;

        Ok(Self {
            key_id: key_id.to_string(),
            plain_secret: plain_secret.to_string(),
        })
    }
}

impl SecretApiKey {
    /// Creates a brand-new key: both the id and the secret are separately
    /// generated v4 UUIDs rendered through `simple()`.
    fn generate() -> Self {
        let random_token = || Uuid::new_v4().simple().to_string();

        let key_id = random_token();
        let plain_secret = random_token();

        Self {
            key_id,
            plain_secret,
        }
    }

    /// Hashes the plaintext secret with the requested scheme.
    ///
    /// For `V1` the hash is the SHA-512 digest of the secret, base64-encoded
    /// with the standard engine. The key id is carried over unchanged.
    fn hash(&self, version: HashVersion) -> HashedApiKey {
        let encoded = match version {
            | HashVersion::V1 => {
                let digest = Sha512::digest(&self.plain_secret);
                base64::engine::general_purpose::STANDARD
                    .encode(digest.as_slice())
            }
        };

        HashedApiKey {
            key_id: self.key_id.clone(),
            hash: encoded,
            hash_version: version,
        }
    }

    /// Identifier of this key (the non-secret part).
    pub fn key_id(&self) -> &String {
        &self.key_id
    }

    /// Renders the complete key, *including the plaintext secret*, in the
    /// format accepted by `FromStr`. Named "unsafe" because the result must
    /// not end up in logs or other places where it could leak.
    pub fn unsafe_to_string(&self) -> String {
        let Self {
            key_id,
            plain_secret,
        } = self;
        format!("{API_KEY_PREFIX}{key_id}_{plain_secret}")
    }
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use cronback_api_model::admin::CreateAPIkeyRequest;
    use lib::database::Database;
    use lib::types::ProjectId;

    use super::SecretApiKey;
    use crate::auth::{AuthError, Authenticator};
    use crate::auth_store::SqlAuthStore;

    /// A key survives a serialize -> parse round trip unchanged.
    #[test]
    fn test_api_key() {
        let api_key = SecretApiKey {
            key_id: "key1".to_string(),
            plain_secret: "supersecure".to_string(),
        };

        let serialized = api_key.unsafe_to_string();

        assert_eq!(serialized, "sk_key1_supersecure");

        let parsed_api_key: SecretApiKey = serialized.parse().unwrap();

        assert_eq!(api_key.key_id, parsed_api_key.key_id);
        assert_eq!(api_key.plain_secret, parsed_api_key.plain_secret);
    }

    /// Builds a key-creation request with the given name and no creator.
    fn build_create_key_request(name: &str) -> CreateAPIkeyRequest {
        CreateAPIkeyRequest {
            key_name: name.to_string(),
            metadata: cronback_api_model::admin::APIKeyMetaData {
                creator_user_id: None,
            },
        }
    }

    /// End-to-end check of the authenticator against an in-memory SQL
    /// store: generate, authenticate, reject, revoke and list.
    #[tokio::test]
    async fn test_sql_auth_store() -> anyhow::Result<()> {
        let db = Database::in_memory().await?;
        let store = SqlAuthStore::new(db);

        let prj1 = ProjectId::generate();
        let prj2 = ProjectId::generate();

        let authenticator = Authenticator::new(Box::new(store));

        let key1 = authenticator
            .gen_key(build_create_key_request("key1"), &prj1)
            .await?;
        let key2 = authenticator
            .gen_key(build_create_key_request("key2"), &prj2)
            .await?;
        let key3 = authenticator
            .gen_key(build_create_key_request("key3"), &prj1)
            .await?;
        let key4 = authenticator
            .gen_key(build_create_key_request("key4"), &prj2)
            .await?;

        // Test authenticate
        assert_eq!(prj1, authenticator.authenticate(&key1).await?);
        assert_eq!(prj2, authenticator.authenticate(&key2).await?);
        assert_eq!(prj1, authenticator.authenticate(&key3).await?);
        assert_eq!(prj2, authenticator.authenticate(&key4).await?);

        // Unknown key id
        let unknown_key =
            SecretApiKey::from_str("sk_notfound_secret4").unwrap();
        assert!(matches!(
            authenticator.authenticate(&unknown_key).await,
            Err(AuthError::AuthFailed(_))
        ));

        // Wrong secret: must reuse a key id that really exists in the store
        // ("key1" is only the key's *name*; generated ids are random), so
        // that the hash-mismatch branch is the one being exercised.
        let wrong_secret_key = SecretApiKey {
            key_id: key1.key_id().clone(),
            plain_secret: "wrongsecret".to_string(),
        };
        assert!(matches!(
            authenticator.authenticate(&wrong_secret_key).await,
            Err(AuthError::AuthFailed(_))
        ));

        // Test delete key
        assert!(authenticator.revoke_key(key1.key_id(), &prj1).await?);
        assert!(matches!(
            authenticator.authenticate(&key1).await,
            Err(AuthError::AuthFailed(_))
        ));

        // Test List keys
        assert_eq!(
            authenticator
                .list_keys(&prj2)
                .await?
                .into_iter()
                .map(|k| k.name)
                .collect::<Vec<_>>(),
            vec!["key2".to_string(), "key4".to_string()]
        );

        Ok(())
    }
}