use std::fmt::Display;
use std::str::FromStr;
use base64::Engine;
use chrono::Utc;
use cronback_api_model::admin::CreateAPIkeyRequest;
use lib::database::models::api_keys;
use lib::database::DatabaseError;
use lib::prelude::ValidShardedId;
use lib::types::ProjectId;
use sha2::{Digest, Sha512};
use thiserror::Error;
use tracing::error;
use uuid::Uuid;
use crate::auth_store::AuthStore;
use crate::errors::ApiError;
pub static API_KEY_PREFIX: &str = "sk_";
#[derive(Error, Debug)]
pub enum AuthError {
#[error("database error: {0}")]
Database(#[from] DatabaseError),
#[error("auth failed: {0}")]
AuthFailed(String),
#[error("internal error: {0}")]
Internal(String),
}
impl From<AuthError> for ApiError {
fn from(value: AuthError) -> Self {
match value {
| AuthError::Database(e) => {
error!("{}", e);
ApiError::ServiceUnavailable
}
| AuthError::Internal(e) => {
error!("{}", e);
ApiError::ServiceUnavailable
}
| AuthError::AuthFailed(_) => ApiError::Unauthorized,
}
}
}
pub struct Authenticator {
store: Box<dyn AuthStore + Send + Sync>,
}
impl Authenticator {
pub fn new(store: Box<dyn AuthStore + Send + Sync>) -> Self {
Self { store }
}
pub async fn gen_key(
&self,
req: CreateAPIkeyRequest,
project: &ValidShardedId<ProjectId>,
) -> Result<SecretApiKey, AuthError> {
let key = SecretApiKey::generate();
let hashed = key.hash(HashVersion::default());
let model = api_keys::Model {
key_id: hashed.key_id,
hash: hashed.hash,
hash_version: hashed.hash_version.to_string(),
project_id: project.clone(),
name: req.key_name,
created_at: Utc::now(),
metadata: api_keys::Metadata {
creator_user_id: req.metadata.creator_user_id,
},
};
self.store.save_key(model).await?;
Ok(key)
}
pub async fn authenticate(
&self,
user_provided_secret: &SecretApiKey,
) -> Result<ValidShardedId<ProjectId>, AuthError> {
let key_model =
self.store.get_key(user_provided_secret.key_id()).await?;
let Some(key_model) = key_model else {
return Err(AuthError::AuthFailed(
"key_id doesn't exist".to_string(),
));
};
let hash_version = key_model.hash_version;
let hash_version: HashVersion = hash_version.parse().map_err(|_| {
AuthError::Internal(format!("Unknown version: {hash_version}"))
})?;
let user_provided_hash = user_provided_secret.hash(hash_version);
let stored_hash = key_model.hash;
if user_provided_hash.hash != stored_hash {
return Err(AuthError::AuthFailed(
"Mismatched secret key".to_string(),
));
}
Ok(key_model.project_id)
}
pub async fn revoke_key(
&self,
key_id: &str,
project: &ValidShardedId<ProjectId>,
) -> Result<bool, AuthError> {
let res = self.store.delete_key(key_id, project).await?;
Ok(res)
}
pub async fn list_keys(
&self,
project: &ValidShardedId<ProjectId>,
) -> Result<Vec<api_keys::Model>, AuthError> {
let res = self.store.list_keys(project).await?;
Ok(res)
}
}
#[derive(Default)]
enum HashVersion {
#[default]
V1,
}
impl Display for HashVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
| HashVersion::V1 => write!(f, "v1"),
}
}
}
impl FromStr for HashVersion {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
| "v1" => Ok(HashVersion::V1),
| v => Err(format!("Invalid version: {v}")),
}
}
}
struct HashedApiKey {
pub key_id: String,
pub hash: String,
pub hash_version: HashVersion,
}
#[cfg_attr(test, derive(Clone))]
pub struct SecretApiKey {
key_id: String,
plain_secret: String,
}
impl FromStr for SecretApiKey {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let Some(s) = s.strip_prefix(API_KEY_PREFIX) else {
return Err(format!(
"API key doesn't start with {}",
API_KEY_PREFIX
));
};
match s.split_once('_') {
| Some((id, secret)) => {
Ok(Self {
key_id: id.to_string(),
plain_secret: secret.to_string(),
})
}
| None => Err("Malformed API key".to_string()),
}
}
}
impl SecretApiKey {
fn generate() -> Self {
Self {
key_id: Uuid::new_v4().simple().to_string(),
plain_secret: Uuid::new_v4().simple().to_string(),
}
}
fn hash(&self, version: HashVersion) -> HashedApiKey {
match version {
| HashVersion::V1 => {
let hash =
Sha512::digest(&self.plain_secret).as_slice().to_vec();
HashedApiKey {
key_id: self.key_id.clone(),
hash: base64::engine::general_purpose::STANDARD
.encode(hash),
hash_version: HashVersion::V1,
}
}
}
}
pub fn key_id(&self) -> &String {
&self.key_id
}
pub fn unsafe_to_string(&self) -> String {
format!("{}{}_{}", API_KEY_PREFIX, self.key_id, self.plain_secret)
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use cronback_api_model::admin::CreateAPIkeyRequest;
use lib::database::Database;
use lib::types::ProjectId;
use super::SecretApiKey;
use crate::auth::{AuthError, Authenticator};
use crate::auth_store::SqlAuthStore;
#[test]
fn test_api_key() {
let api_key = SecretApiKey {
key_id: "key1".to_string(),
plain_secret: "supersecure".to_string(),
};
let serialized = api_key.unsafe_to_string();
assert_eq!(serialized, "sk_key1_supersecure");
let parsed_api_key: SecretApiKey = serialized.parse().unwrap();
assert_eq!(api_key.key_id, parsed_api_key.key_id);
assert_eq!(api_key.plain_secret, parsed_api_key.plain_secret);
}
fn build_create_key_request(name: &str) -> CreateAPIkeyRequest {
CreateAPIkeyRequest {
key_name: name.to_string(),
metadata: cronback_api_model::admin::APIKeyMetaData {
creator_user_id: None,
},
}
}
#[tokio::test]
async fn test_sql_auth_store() -> anyhow::Result<()> {
let db = Database::in_memory().await?;
let store = SqlAuthStore::new(db);
let prj1 = ProjectId::generate();
let prj2 = ProjectId::generate();
let authenticator = Authenticator::new(Box::new(store));
let key1 = authenticator
.gen_key(build_create_key_request("key1"), &prj1)
.await?;
let key2 = authenticator
.gen_key(build_create_key_request("key2"), &prj2)
.await?;
let key3 = authenticator
.gen_key(build_create_key_request("key3"), &prj1)
.await?;
let key4 = authenticator
.gen_key(build_create_key_request("key4"), &prj2)
.await?;
assert_eq!(prj1, authenticator.authenticate(&key1).await?);
assert_eq!(prj2, authenticator.authenticate(&key2).await?);
assert_eq!(prj1, authenticator.authenticate(&key3).await?);
assert_eq!(prj2, authenticator.authenticate(&key4).await?);
let key5 = SecretApiKey::from_str("sk_notfound_secret4").unwrap();
assert!(matches!(
authenticator.authenticate(&key5).await,
Err(AuthError::AuthFailed(_))
));
let key5 = SecretApiKey::from_str("sk_key1_wrongsecret").unwrap();
assert!(matches!(
authenticator.authenticate(&key5).await,
Err(AuthError::AuthFailed(_))
));
assert!(authenticator.revoke_key(key1.key_id(), &prj1).await?);
assert!(matches!(
authenticator.authenticate(&key1).await,
Err(AuthError::AuthFailed(_))
));
assert_eq!(
authenticator
.list_keys(&prj2)
.await?
.into_iter()
.map(|k| k.name)
.collect::<Vec<_>>(),
vec!["key2".to_string(), "key4".to_string()]
);
Ok(())
}
}