use axum::http::uri;
use base64::Engine;
use blueprint_core::debug;
use prost::Message;
use std::collections::BTreeMap;
use tracing::instrument;
use crate::{
Error,
api_tokens::{CUSTOM_ENGINE, GeneratedApiToken},
db::{RocksDb, cf},
types::{KeyType, ServiceId},
};
/// Persisted record for an issued API token, stored protobuf-encoded in
/// `cf::TOKENS_OPTS_CF` under the big-endian bytes of `id`.
///
/// NOTE: no field carries an explicit prost `tag`, so prost numbers the wire tags
/// sequentially in declaration order — reordering or inserting fields changes the
/// on-disk encoding of existing records.
#[derive(prost::Message, Clone)]
pub struct ApiTokenModel {
/// Numeric token id; `0` means "not yet persisted" (`save` then allocates one).
#[prost(uint64)]
id: u64,
/// Stored token hash. `is` compares it against the `CUSTOM_ENGINE`-encoded
/// Keccak-256 digest of a presented plaintext token.
#[prost(string)]
token: String,
/// Primary service id (first component of the token's `ServiceId`).
#[prost(uint64)]
service_id: u64,
/// Sub-service id (second component of the token's `ServiceId`).
#[prost(uint64)]
sub_service_id: u64,
/// Expiry in seconds since the UNIX epoch; `0` means the token never expires.
#[prost(uint64)]
pub expires_at: u64,
/// Enabled flag; set to `true` on creation from a generated token.
/// NOTE(review): not consulted by any method in this file — confirm where it is enforced.
#[prost(bool)]
pub is_enabled: bool,
/// JSON-encoded `BTreeMap<String, String>` of extra headers; empty bytes mean "no headers".
#[prost(bytes)]
pub additional_headers: Vec<u8>,
}
/// Per-service TLS configuration, embedded in `ServiceModel::tls_profile`.
///
/// The `encrypted_*` fields are opaque byte blobs; the encryption scheme and the
/// plaintext format are not visible in this file. All of them start empty in `new`.
///
/// NOTE: no field carries an explicit prost `tag`, so field order fixes the wire tags.
#[derive(prost::Message, Clone, serde::Serialize, serde::Deserialize)]
pub struct TlsProfile {
/// TLS on/off flag, surfaced by `ServiceModel::is_tls_enabled`.
#[prost(bool)]
pub tls_enabled: bool,
/// Client-mTLS flag, surfaced by `ServiceModel::requires_client_mtls`.
#[prost(bool)]
pub require_client_mtls: bool,
/// Encrypted server certificate blob.
#[prost(bytes)]
pub encrypted_server_cert: Vec<u8>,
/// Encrypted server private-key blob.
#[prost(bytes)]
pub encrypted_server_key: Vec<u8>,
/// Encrypted client CA bundle blob.
#[prost(bytes)]
pub encrypted_client_ca_bundle: Vec<u8>,
/// Encrypted upstream CA bundle blob.
#[prost(bytes)]
pub encrypted_upstream_ca_bundle: Vec<u8>,
/// Encrypted upstream client certificate blob.
#[prost(bytes)]
pub encrypted_upstream_client_cert: Vec<u8>,
/// Encrypted upstream client private-key blob.
#[prost(bytes)]
pub encrypted_upstream_client_key: Vec<u8>,
/// Client-certificate lifetime; hours per the field name.
/// NOTE(review): the consumer of this value is not visible here.
#[prost(uint32)]
pub client_cert_ttl_hours: u32,
/// Optional SNI value; `None` by default.
#[prost(string, optional)]
pub sni: Option<String>,
/// Optional subject-alt-name template; the template syntax is not defined in this file.
#[prost(string, optional)]
pub subject_alt_name_template: Option<String>,
/// Allowed DNS names; empty by default. Enforcement happens elsewhere (not visible here).
#[prost(string, repeated)]
pub allowed_dns_names: Vec<String>,
}
impl TlsProfile {
    /// Builds a profile from the two TLS flags and the client-certificate TTL.
    ///
    /// Every other field takes its prost-generated default: all `encrypted_*` blobs
    /// are empty, `sni` and `subject_alt_name_template` are `None`, and
    /// `allowed_dns_names` is empty.
    pub fn new(tls_enabled: bool, require_client_mtls: bool, client_cert_ttl_hours: u32) -> Self {
        Self {
            tls_enabled,
            require_client_mtls,
            client_cert_ttl_hours,
            ..Self::default()
        }
    }
}
/// Persisted service configuration, stored protobuf-encoded in
/// `cf::SERVICES_USER_KEYS_CF` under `ServiceId::to_be_bytes()`.
///
/// NOTE: no field carries an explicit prost `tag`, so field order fixes the wire tags.
#[derive(prost::Message, Clone)]
pub struct ServiceModel {
/// Prefix associated with this service's API keys.
#[prost(string)]
pub api_key_prefix: String,
/// Keys recognised as owners of the service (checked by `is_owner`).
#[prost(message, repeated)]
pub owners: Vec<ServiceOwnerModel>,
/// Upstream URL kept as a string; parsed into a `Uri` on demand by `upstream_url()`.
#[prost(string)]
pub upstream_url: String,
/// Optional TLS configuration; when `None`, TLS and client mTLS are reported as disabled.
#[prost(message, optional)]
pub tls_profile: Option<TlsProfile>,
}
/// One owner of a service, identified by key type and raw key bytes.
#[derive(prost::Message, Clone, PartialEq, Eq)]
pub struct ServiceOwnerModel {
/// `KeyType` discriminant stored as `i32` (prost's enumeration encoding).
#[prost(enumeration = "KeyType")]
pub key_type: i32,
/// Raw key bytes; their format depends on `key_type` (not specified in this file).
#[prost(bytes)]
pub key_bytes: Vec<u8>,
}
/// Metadata for an issued TLS certificate, stored protobuf-encoded in
/// `cf::TLS_CERT_METADATA_CF`.
///
/// The storage key is `service_id` as 8 big-endian bytes followed by the UTF-8 bytes
/// of `cert_id` (see `db_key`).
///
/// NOTE: no field carries an explicit prost `tag`, so field order fixes the wire tags.
#[derive(prost::Message, Clone, PartialEq, Eq)]
pub struct TlsCertMetadata {
/// Owning service id; also the leading 8 bytes of the storage key.
#[prost(uint64)]
pub service_id: u64,
/// Certificate identifier; the trailing part of the storage key.
#[prost(string)]
pub cert_id: String,
/// Certificate text (PEM per the field name).
#[prost(string)]
pub certificate_pem: String,
/// Certificate serial number as a string.
#[prost(string)]
pub serial: String,
/// Expiry in seconds since the UNIX epoch (compared against the clock in `is_expired`).
#[prost(uint64)]
pub expires_at: u64,
/// Revocation flag; set by `revoke`.
#[prost(bool)]
pub is_revoked: bool,
/// Free-form usage label; allowed values are not defined in this file.
#[prost(string)]
pub usage: String,
/// Certificate common name.
#[prost(string)]
pub common_name: String,
/// Subject alternative names; appended via `add_subject_alt_name`.
#[prost(string, repeated)]
pub subject_alt_names: Vec<String>,
/// Issue time in seconds since the UNIX epoch; stamped by `new`.
#[prost(uint64)]
pub issued_at: u64,
/// Id of the API token that requested issuance.
#[prost(uint64)]
pub issued_by_api_key_id: u64,
/// Optional tenant id; set via `set_tenant_id`.
#[prost(string, optional)]
pub tenant_id: Option<String>,
}
impl ApiTokenModel {
    /// Loads the token record stored under `id` in `TOKENS_OPTS_CF`.
    ///
    /// Returns `Ok(None)` when no record exists for `id`.
    ///
    /// # Errors
    /// Fails if the column family is missing, the read fails, or the stored bytes do
    /// not decode as an [`ApiTokenModel`].
    #[instrument(skip(db), err)]
    pub fn find_token_id(id: u64, db: &RocksDb) -> Result<Option<Self>, crate::Error> {
        let cf = db
            .cf_handle(cf::TOKENS_OPTS_CF)
            .ok_or(crate::Error::UnknownColumnFamily(cf::TOKENS_OPTS_CF))?;
        let token_opts_bytes = db.get_pinned_cf(&cf, id.to_be_bytes())?;
        token_opts_bytes
            .map(|bytes| ApiTokenModel::decode(bytes.as_ref()))
            .transpose()
            .map_err(Into::into)
    }

    /// Returns `true` if `plaintext` hashes to this record's stored token hash.
    ///
    /// The plaintext is hashed with Keccak-256, encoded with [`CUSTOM_ENGINE`], and
    /// compared to `self.token`.
    ///
    /// `plaintext` is a live credential. The span skips all arguments (`skip_all`) and
    /// the debug event records only the token id and the match outcome. Neither the
    /// plaintext nor any hash is written to traces or logs.
    #[instrument(skip_all, ret)]
    pub fn is(&self, plaintext: &str) -> bool {
        use tiny_keccak::Hasher;

        let mut hasher = tiny_keccak::Keccak::v256();
        hasher.update(plaintext.as_bytes());
        let mut output = [0u8; 32];
        hasher.finalize(&mut output);
        let token_hash = CUSTOM_ENGINE.encode(output);

        let token_match = self.token == token_hash;
        debug!(token_id = self.id, token_match, "Checking token match");
        token_match
    }

    /// Persists this record and returns its id.
    ///
    /// A record with `id == 0` is new: an id is allocated and stored back into `self`
    /// (see `create`). Otherwise the record is overwritten in place under its id.
    ///
    /// # Errors
    /// Fails if `TOKENS_OPTS_CF` (or `SEQ_CF` for new records) is missing, or on any
    /// RocksDB write/transaction error.
    pub fn save(&mut self, db: &RocksDb) -> Result<u64, crate::Error> {
        if self.id == 0 {
            return self.create(db);
        }
        let tokens_cf = db
            .cf_handle(cf::TOKENS_OPTS_CF)
            .ok_or(crate::Error::UnknownColumnFamily(cf::TOKENS_OPTS_CF))?;
        db.put_cf(&tokens_cf, self.id.to_be_bytes(), self.encode_to_vec())?;
        Ok(self.id)
    }

    /// Allocates a fresh id and writes this record in a single transaction.
    ///
    /// Within one transaction, this:
    /// 1. merges `1u64` (big-endian) into the `tokens` counter in `SEQ_CF`,
    /// 2. reads the merged counter back as the new id,
    /// 3. writes the record under that id.
    ///
    /// The merge is retried up to `MAX_RETRIES` times on `Busy`/`TryAgain`. The
    /// increment semantics come from the `SEQ_CF` merge operator, which is not visible
    /// here.
    ///
    /// # Panics
    /// Panics if the stored counter value is not exactly 8 bytes long.
    fn create(&mut self, db: &RocksDb) -> Result<u64, crate::Error> {
        const MAX_RETRIES: u32 = 10;

        let tokens_cf = db
            .cf_handle(cf::TOKENS_OPTS_CF)
            .ok_or(crate::Error::UnknownColumnFamily(cf::TOKENS_OPTS_CF))?;
        let seq_cf = db
            .cf_handle(cf::SEQ_CF)
            .ok_or(crate::Error::UnknownColumnFamily(cf::SEQ_CF))?;
        let txn = db.transaction();

        let mut retry_count = 0;
        loop {
            match txn.merge_cf(&seq_cf, b"tokens", 1u64.to_be_bytes()) {
                Ok(()) => break,
                // Transient contention: retry a bounded number of times.
                Err(e)
                    if matches!(
                        e.kind(),
                        rocksdb::ErrorKind::Busy | rocksdb::ErrorKind::TryAgain
                    ) =>
                {
                    retry_count += 1;
                    if retry_count >= MAX_RETRIES {
                        return Err(crate::Error::RocksDB(e));
                    }
                }
                Err(e) => return Err(crate::Error::RocksDB(e)),
            }
        }

        let next_id = txn
            .get_cf(&seq_cf, b"tokens")?
            .map(|v| {
                let mut id = [0u8; 8];
                id.copy_from_slice(&v);
                u64::from_be_bytes(id)
            })
            .unwrap_or(0u64);
        self.id = next_id;
        txn.put_cf(&tokens_cf, next_id.to_be_bytes(), self.encode_to_vec())?;
        txn.commit()?;
        Ok(next_id)
    }

    /// Expiry in seconds since the UNIX epoch, or `None` if the token never expires
    /// (stored as `0`).
    pub fn expires_at(&self) -> Option<u64> {
        (self.expires_at != 0).then_some(self.expires_at)
    }

    /// Returns `true` once the current time (seconds since the UNIX epoch) is strictly
    /// past `expires_at`.
    ///
    /// A token with `expires_at == 0` never expires. A system clock set before the
    /// epoch is treated as time `0`.
    #[instrument(skip(self), ret)]
    pub fn is_expired(&self) -> bool {
        if self.expires_at == 0 {
            return false;
        }
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        self.expires_at < now
    }

    /// The service (and sub-service) this token belongs to.
    pub fn service_id(&self) -> ServiceId {
        ServiceId::new(self.service_id).with_subservice(self.sub_service_id)
    }

    /// Decodes the JSON-encoded additional headers.
    ///
    /// Returns an empty map when none are stored, and also (deliberately, best-effort)
    /// when the stored bytes are not valid JSON for a `BTreeMap<String, String>`.
    pub fn get_additional_headers(&self) -> BTreeMap<String, String> {
        if self.additional_headers.is_empty() {
            BTreeMap::new()
        } else {
            serde_json::from_slice(&self.additional_headers).unwrap_or_default()
        }
    }

    /// Stores `headers` as JSON in `additional_headers`.
    ///
    /// Should serialization ever fail, empty bytes ("no headers") are stored instead.
    pub fn set_additional_headers(&mut self, headers: &BTreeMap<String, String>) {
        self.additional_headers = serde_json::to_vec(headers).unwrap_or_default();
    }
}
impl From<&GeneratedApiToken> for ApiTokenModel {
fn from(token: &GeneratedApiToken) -> Self {
let mut model = Self {
id: 0,
token: token.token.clone(),
service_id: token.service_id.0,
sub_service_id: token.service_id.1,
expires_at: token.expires_at().unwrap_or(0),
is_enabled: true,
additional_headers: Vec::new(),
};
model.set_additional_headers(token.additional_headers());
model
}
}
impl ServiceModel {
    /// Creates a service with the given API-key prefix and upstream URL, no owners,
    /// and no TLS profile.
    pub fn new(api_key_prefix: impl Into<String>, upstream_url: impl Into<String>) -> Self {
        Self {
            api_key_prefix: api_key_prefix.into(),
            owners: Vec::new(),
            upstream_url: upstream_url.into(),
            tls_profile: None,
        }
    }

    /// Loads the service stored under `id` in `SERVICES_USER_KEYS_CF`, or `Ok(None)`.
    ///
    /// # Errors
    /// Fails if the column family is missing, the read fails, or decoding fails.
    pub fn find_by_id(id: ServiceId, db: &RocksDb) -> Result<Option<Self>, crate::Error> {
        let cf = db
            .cf_handle(cf::SERVICES_USER_KEYS_CF)
            .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
        let service_bytes = db.get_pinned_cf(&cf, id.to_be_bytes())?;
        service_bytes
            .map(|bytes| ServiceModel::decode(bytes.as_ref()))
            .transpose()
            .map_err(Into::into)
    }

    /// Writes (or overwrites) this service under `id`.
    ///
    /// # Errors
    /// Fails if the column family is missing or the write fails.
    pub fn save(&self, id: ServiceId, db: &RocksDb) -> Result<(), crate::Error> {
        let cf = db
            .cf_handle(cf::SERVICES_USER_KEYS_CF)
            .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
        db.put_cf(&cf, id.to_be_bytes(), self.encode_to_vec())?;
        Ok(())
    }

    /// Writes this service under `id` only if no record exists there yet.
    ///
    /// Returns `Ok(true)` when written and `Ok(false)` when a record already existed.
    ///
    /// The existence check uses `get_for_update_cf`, so the key is tracked by the
    /// transaction. A concurrent writer of the same key then blocks on the lock or
    /// makes `commit` fail with a conflict; which one depends on the transaction-DB
    /// flavor behind `RocksDb`, which is not visible here. Either way, it can no longer
    /// silently overwrite.
    ///
    /// # Errors
    /// Fails if the column family is missing or on any read/write/commit error,
    /// including transaction conflicts.
    pub fn save_if_absent(&self, id: ServiceId, db: &RocksDb) -> Result<bool, crate::Error> {
        let cf = db
            .cf_handle(cf::SERVICES_USER_KEYS_CF)
            .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
        let key = id.to_be_bytes();
        let txn = db.transaction();
        if txn.get_for_update_cf(&cf, key, true)?.is_some() {
            // Dropping `txn` without committing rolls it back and releases the key.
            return Ok(false);
        }
        txn.put_cf(&cf, key, self.encode_to_vec())?;
        txn.commit()?;
        Ok(true)
    }

    /// Removes the service stored under `id`. Deleting a missing key is not an error.
    ///
    /// # Errors
    /// Fails if the column family is missing or the delete fails.
    pub fn delete(id: ServiceId, db: &RocksDb) -> Result<(), crate::Error> {
        let cf = db
            .cf_handle(cf::SERVICES_USER_KEYS_CF)
            .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
        db.delete_cf(&cf, id.to_be_bytes())?;
        Ok(())
    }

    /// The API-key prefix configured for this service.
    pub fn api_key_prefix(&self) -> &str {
        &self.api_key_prefix
    }

    /// Returns `true` if an owner with exactly this key type and key bytes exists.
    pub fn is_owner(&self, key_type: KeyType, key_bytes: &[u8]) -> bool {
        self.owners
            .iter()
            .any(|owner| owner.key_type == key_type as i32 && owner.key_bytes == key_bytes)
    }

    /// Appends an owner. Duplicates are not filtered out.
    pub fn add_owner(&mut self, key_type: KeyType, key_bytes: Vec<u8>) {
        self.owners.push(ServiceOwnerModel {
            key_type: key_type as i32,
            key_bytes,
        });
    }

    /// Removes every owner matching both `key_type` and `key_bytes`.
    pub fn remove_owner(&mut self, key_type: KeyType, key_bytes: &[u8]) {
        self.owners
            .retain(|owner| !(owner.key_type == key_type as i32 && owner.key_bytes == key_bytes));
    }

    /// All owners in insertion order.
    pub fn owners(&self) -> &[ServiceOwnerModel] {
        &self.owners
    }

    /// Parses the stored upstream URL.
    ///
    /// # Errors
    /// Fails if the stored string is not a valid URI.
    pub fn upstream_url(&self) -> Result<uri::Uri, Error> {
        self.upstream_url.parse::<uri::Uri>().map_err(Into::into)
    }

    /// The TLS profile, if one is configured.
    pub fn tls_profile(&self) -> Option<&TlsProfile> {
        self.tls_profile.as_ref()
    }

    /// Installs (or replaces) the TLS profile.
    pub fn set_tls_profile(&mut self, tls_profile: TlsProfile) {
        self.tls_profile = Some(tls_profile);
    }

    /// `true` only if a TLS profile exists and has `tls_enabled` set.
    pub fn is_tls_enabled(&self) -> bool {
        self.tls_profile.as_ref().is_some_and(|p| p.tls_enabled)
    }

    /// `true` only if a TLS profile exists and has `require_client_mtls` set.
    pub fn requires_client_mtls(&self) -> bool {
        self.tls_profile
            .as_ref()
            .is_some_and(|p| p.require_client_mtls)
    }
}
/// Inputs for [`TlsCertMetadata::new`]; everything the caller must supply when
/// recording a newly issued certificate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsCertMetadataConfig {
    /// Owning service id.
    pub service_id: u64,
    /// Certificate identifier, unique within the service.
    pub cert_id: String,
    /// Certificate text (PEM per the field name).
    pub certificate_pem: String,
    /// Certificate serial number as a string.
    pub serial: String,
    /// Expiry in seconds since the UNIX epoch.
    pub expires_at: u64,
    /// Free-form usage label.
    pub usage: String,
    /// Certificate common name.
    pub common_name: String,
    /// Id of the API token that requested issuance.
    pub issued_by_api_key_id: u64,
}
impl TlsCertMetadata {
pub fn new(config: TlsCertMetadataConfig) -> Self {
Self {
service_id: config.service_id,
cert_id: config.cert_id,
certificate_pem: config.certificate_pem,
serial: config.serial,
expires_at: config.expires_at,
is_revoked: false,
usage: config.usage,
common_name: config.common_name,
subject_alt_names: Vec::new(),
issued_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
issued_by_api_key_id: config.issued_by_api_key_id,
tenant_id: None,
}
}
pub fn db_key(&self) -> Vec<u8> {
let mut key = Vec::with_capacity(16 + self.cert_id.len());
key.extend_from_slice(&self.service_id.to_be_bytes());
key.extend_from_slice(self.cert_id.as_bytes());
key
}
pub fn save(&self, db: &RocksDb) -> Result<(), crate::Error> {
let cf = db.cf_handle(crate::db::cf::TLS_CERT_METADATA_CF).ok_or(
crate::Error::UnknownColumnFamily(crate::db::cf::TLS_CERT_METADATA_CF),
)?;
let key = self.db_key();
let metadata_bytes = self.encode_to_vec();
db.put_cf(&cf, key, metadata_bytes)?;
Ok(())
}
pub fn find_by_service_and_cert_id(
service_id: u64,
cert_id: &str,
db: &RocksDb,
) -> Result<Option<Self>, crate::Error> {
let cf = db.cf_handle(crate::db::cf::TLS_CERT_METADATA_CF).ok_or(
crate::Error::UnknownColumnFamily(crate::db::cf::TLS_CERT_METADATA_CF),
)?;
let mut key = Vec::with_capacity(16 + cert_id.len());
key.extend_from_slice(&service_id.to_be_bytes());
key.extend_from_slice(cert_id.as_bytes());
let metadata_bytes = db.get_pinned_cf(&cf, key)?;
metadata_bytes
.map(|bytes| TlsCertMetadata::decode(bytes.as_ref()))
.transpose()
.map_err(Into::into)
}
pub fn is_expired(&self) -> bool {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
self.expires_at < now
}
pub fn revoke(&mut self) {
self.is_revoked = true;
}
pub fn add_subject_alt_name(&mut self, san: String) {
self.subject_alt_names.push(san);
}
pub fn set_tenant_id(&mut self, tenant_id: String) {
self.tenant_id = Some(tenant_id);
}
}
pub mod tls_assets {
use super::*;
use crate::db::cf;
pub fn store_tls_asset(
db: &RocksDb,
service_id: u64,
asset_type: &str,
encrypted_data: &[u8],
) -> Result<(), crate::Error> {
let cf = db
.cf_handle(cf::TLS_ASSETS_CF)
.ok_or(crate::Error::UnknownColumnFamily(cf::TLS_ASSETS_CF))?;
let key = format!("{service_id}:{asset_type}");
db.put_cf(&cf, key.as_bytes(), encrypted_data)?;
Ok(())
}
pub fn get_tls_asset(
db: &RocksDb,
service_id: u64,
asset_type: &str,
) -> Result<Option<Vec<u8>>, crate::Error> {
let cf = db
.cf_handle(cf::TLS_ASSETS_CF)
.ok_or(crate::Error::UnknownColumnFamily(cf::TLS_ASSETS_CF))?;
let key = format!("{service_id}:{asset_type}");
let asset_bytes = db.get_pinned_cf(&cf, key.as_bytes())?;
Ok(asset_bytes.map(|bytes| bytes.to_vec()))
}
pub fn delete_tls_asset(
db: &RocksDb,
service_id: u64,
asset_type: &str,
) -> Result<(), crate::Error> {
let cf = db
.cf_handle(cf::TLS_ASSETS_CF)
.ok_or(crate::Error::UnknownColumnFamily(cf::TLS_ASSETS_CF))?;
let key = format!("{service_id}:{asset_type}");
db.delete_cf(&cf, key.as_bytes())?;
Ok(())
}
pub fn log_certificate_issuance(
db: &RocksDb,
service_id: u64,
cert_id: &str,
api_key_id: u64,
tenant_id: Option<&str>,
) -> Result<(), crate::Error> {
let cf = db
.cf_handle(cf::TLS_ISSUANCE_LOG_CF)
.ok_or(crate::Error::UnknownColumnFamily(cf::TLS_ISSUANCE_LOG_CF))?;
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let mut log_entry = Vec::new();
log_entry.extend_from_slice(×tamp.to_be_bytes());
log_entry.extend_from_slice(&service_id.to_be_bytes());
log_entry.extend_from_slice(cert_id.as_bytes());
log_entry.push(0u8); log_entry.extend_from_slice(&api_key_id.to_be_bytes());
if let Some(tenant_id) = tenant_id {
log_entry.extend_from_slice(tenant_id.as_bytes());
}
let log_key = format!("{timestamp}:{cert_id}");
db.put_cf(&cf, log_key.as_bytes(), log_entry)?;
Ok(())
}
pub fn get_certificate_issuance_log(
db: &RocksDb,
service_id: u64,
) -> Result<Vec<TlsCertMetadata>, crate::Error> {
let cf = db
.cf_handle(cf::TLS_CERT_METADATA_CF)
.ok_or(crate::Error::UnknownColumnFamily(cf::TLS_CERT_METADATA_CF))?;
let mut certificates = Vec::new();
let prefix = service_id.to_be_bytes();
let iter = db.prefix_iterator_cf(&cf, prefix);
for item in iter {
let (_key, value) = item?;
let metadata = TlsCertMetadata::decode(&*value)?;
if metadata.service_id == service_id {
certificates.push(metadata);
}
}
Ok(certificates)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{api_tokens::ApiTokenGenerator, types::ServiceId};
    use std::collections::BTreeMap;

    /// Opens a RocksDB instance with default options rooted at `path`.
    fn open_db(path: &std::path::Path) -> RocksDb {
        RocksDb::open(path, &Default::default()).unwrap()
    }

    /// A generated token round-trips through `save` / `find_token_id` and gets id 1.
    #[test]
    fn token_generator() {
        let mut rng = blueprint_std::BlueprintRng::new();
        let dir = tempfile::tempdir().unwrap();
        let db = open_db(dir.path());

        let service = ServiceId::new(1);
        let generated = ApiTokenGenerator::new().generate_token(service, &mut rng);
        let mut stored = ApiTokenModel::from(&generated);

        let id = stored.save(&db).unwrap();
        assert_eq!(id, 1);

        let lookup = ApiTokenModel::find_token_id(id, &db).unwrap();
        assert!(lookup.is_some());
        let loaded = lookup.unwrap();
        assert_eq!(loaded.id, id);
        assert_eq!(loaded.token, stored.token);
        assert_eq!(loaded.expires_at, stored.expires_at);
        assert_eq!(loaded.is_enabled, stored.is_enabled);
    }

    /// Additional headers attached at generation time survive persistence.
    #[test]
    fn token_with_headers() {
        let mut rng = blueprint_std::BlueprintRng::new();
        let dir = tempfile::tempdir().unwrap();
        let db = open_db(dir.path());

        let service = ServiceId::new(1);
        let generator = ApiTokenGenerator::new();
        let expected: BTreeMap<String, String> = [
            ("X-Tenant-Id".to_string(), "tenant123".to_string()),
            ("X-User-Type".to_string(), "premium".to_string()),
        ]
        .into_iter()
        .collect();

        let generated = generator.generate_token_with_expiration_and_headers(
            service,
            0,
            expected.clone(),
            &mut rng,
        );
        let mut stored = ApiTokenModel::from(&generated);
        let id = stored.save(&db).unwrap();

        let loaded = ApiTokenModel::find_token_id(id, &db).unwrap().unwrap();
        assert_eq!(loaded.get_additional_headers(), expected);
    }

    /// `set_additional_headers` / `get_additional_headers` round-trip in memory.
    #[test]
    fn test_additional_headers_methods() {
        let mut model = ApiTokenModel {
            id: 0,
            token: "test".to_string(),
            service_id: 1,
            sub_service_id: 0,
            expires_at: 0,
            is_enabled: true,
            additional_headers: Vec::new(),
        };
        assert!(model.get_additional_headers().is_empty());

        let mut expected = BTreeMap::new();
        expected.insert("X-Test".to_string(), "value".to_string());
        model.set_additional_headers(&expected);

        assert_eq!(model.get_additional_headers(), expected);
    }
}