use crate::error::Result;
use crate::middleware::{BoxedMiddleware, Next};
use crate::response::{Response, ResponseBuilder};
use crate::Context;
use async_trait::async_trait;
use sha2::{Digest, Sha256};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Identity resolved from a successfully validated API key.
///
/// Returned by an [`ApiKeyStore`] and stored on the request context by the
/// [`ApiKey`] middleware when a key validates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ApiKeyIdentity {
    /// Identifier of the key holder.
    pub id: String,
    /// Scopes granted to the key holder; empty when none were configured.
    pub scopes: Vec<String>,
}
/// Backend that decides whether a presented API key is valid.
///
/// Implementors must be `Send + Sync`; the [`ApiKey`] middleware keeps the
/// store in an `Arc` and uses it from every request.
#[async_trait]
pub trait ApiKeyStore: Send + Sync {
    /// Looks up `presented_key` and returns the matching identity, or `None`
    /// when the key is not recognised.
    async fn validate(&self, presented_key: &str) -> Option<ApiKeyIdentity>;
}
/// In-memory [`ApiKeyStore`] backed by a fixed list of keys.
///
/// Only the SHA-256 digest of each key is kept; the plaintext key is not stored.
#[derive(Default)]
pub struct StaticKeys {
    /// `(sha256(key), identity)` pairs, in registration order.
    entries: Vec<([u8; 32], ApiKeyIdentity)>,
}
impl StaticKeys {
    /// Creates an empty store, which rejects every key.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `key` for the identity `id` with no scopes.
    ///
    /// This builder consumes `self` and returns the updated store; dropping the
    /// return value would silently lose the registration.
    #[must_use]
    pub fn insert(self, key: impl AsRef<[u8]>, id: impl Into<String>) -> Self {
        self.push_entry(key.as_ref(), id.into(), Vec::new())
    }

    /// Registers `key` for the identity `id`, granting the given `scopes`.
    ///
    /// Registering the same key more than once keeps every entry; lookups
    /// resolve to the most recently registered one.
    #[must_use]
    pub fn with_scopes(
        self,
        key: impl AsRef<[u8]>,
        id: impl Into<String>,
        scopes: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let scopes: Vec<String> = scopes.into_iter().map(Into::into).collect();
        self.push_entry(key.as_ref(), id.into(), scopes)
    }

    /// Hashes `key` and appends the resulting `(digest, identity)` entry;
    /// shared by `insert` and `with_scopes` so both store keys the same way.
    fn push_entry(mut self, key: &[u8], id: String, scopes: Vec<String>) -> Self {
        self.entries
            .push((hash_key(key), ApiKeyIdentity { id, scopes }));
        self
    }
}
#[async_trait]
impl ApiKeyStore for StaticKeys {
    /// Hashes `presented_key` and compares it against every stored digest.
    ///
    /// All entries are visited (there is no early exit on a match), and when
    /// several entries match, the last one registered wins.
    async fn validate(&self, presented_key: &str) -> Option<ApiKeyIdentity> {
        let wanted = hash_key(presented_key.as_bytes());
        self.entries
            .iter()
            .fold(None, |hit, (digest, identity)| {
                if ct_eq(&wanted, digest) {
                    Some(identity)
                } else {
                    hit
                }
            })
            .cloned()
    }
}
/// Location on the incoming request that the API key is read from.
#[derive(Clone, Debug)]
enum KeySource {
    /// Request header with this name.
    Header(String),
    /// Query-string parameter with this name.
    Query(String),
}
/// Middleware configuration that authenticates requests by API key.
///
/// Created with `ApiKey::new`, adjusted with the builder methods, and turned
/// into middleware with `build`.
pub struct ApiKey<S: ApiKeyStore> {
    /// Store used to validate presented keys; shared across requests.
    store: Arc<S>,
    /// Header or query parameter the key is read from.
    source: KeySource,
    /// When `true`, requests without a valid key are forwarded instead of rejected.
    optional: bool,
}
impl<S: ApiKeyStore + 'static> ApiKey<S> {
    /// Creates middleware that validates keys against `store`.
    ///
    /// Defaults: the key is read from the `x-api-key` header, and requests
    /// without a valid key are rejected.
    pub fn new(store: S) -> Self {
        let source = KeySource::Header("x-api-key".to_string());
        Self {
            store: Arc::new(store),
            source,
            optional: false,
        }
    }

    /// Reads the key from the header called `name`.
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            source: KeySource::Header(name.into()),
            ..self
        }
    }

    /// Reads the key from the query-string parameter called `name`.
    pub fn from_query(self, name: impl Into<String>) -> Self {
        Self {
            source: KeySource::Query(name.into()),
            ..self
        }
    }

    /// Forwards requests whose key is missing or does not validate, instead of
    /// answering `401 Unauthorized`.
    pub fn optional(self) -> Self {
        Self {
            optional: true,
            ..self
        }
    }

    /// Converts this configuration into a [`BoxedMiddleware`].
    ///
    /// For each request, the key is read from the configured source and checked
    /// with [`ApiKeyStore::validate`]:
    ///
    /// * valid key: the identity is stored via `set_api_key`, a `Principal`
    ///   carrying the same id and scopes via `set_principal`, then `next` runs;
    /// * missing or invalid key: `next` runs without an identity when the
    ///   middleware is optional; otherwise a 401 response is returned and
    ///   `next` is not called.
    pub fn build(self) -> BoxedMiddleware {
        let shared = Arc::new(self);
        Arc::new(move |ctx: Context, next: Next| {
            let this = Arc::clone(&shared);
            Box::pin(async move {
                let presented = match &this.source {
                    KeySource::Header(name) => ctx.req.header(name),
                    KeySource::Query(name) => ctx.req.query(name),
                };
                // A missing key and a key rejected by the store end up in the same state.
                let verified = match presented {
                    Some(key) => this.store.validate(&key).await,
                    None => None,
                };
                match (verified, this.optional) {
                    (Some(identity), _) => {
                        // Clone before `identity` is moved into the context.
                        let principal = crate::auth::Principal {
                            id: Some(identity.id.clone()),
                            scopes: identity.scopes.clone(),
                        };
                        ctx.set_api_key(identity).await;
                        ctx.set_principal(principal).await;
                        next(ctx).await
                    }
                    (None, true) => next(ctx).await,
                    (None, false) => Ok(unauthorized()),
                }
            }) as Pin<Box<dyn Future<Output = Result<Response>> + Send>>
        })
    }
}
/// Returns the SHA-256 digest of `key`.
///
/// Keys are registered and looked up only in this hashed form.
fn hash_key(key: &[u8]) -> [u8; 32] {
    Sha256::digest(key).into()
}
/// Compares two 32-byte digests without stopping at the first mismatch.
///
/// Each byte pair is XOR-ed and OR-ed into an accumulator. Only the final
/// accumulator is inspected, so the comparison never exits early.
fn ct_eq(a: &[u8; 32], b: &[u8; 32]) -> bool {
    let folded = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    folded == 0
}
/// Builds the `401 Unauthorized` response used to reject a request.
///
/// If building the response with a text body fails, a body-less response with
/// the same `401` status is built instead. The previous fallback
/// (`helpers::text`) never set the status, so a rejection could have been
/// sent with a non-401 status.
///
/// # Panics
///
/// Panics if even the body-less `401` response cannot be built. For a fixed
/// status with no body, that would indicate a bug in `ResponseBuilder`.
fn unauthorized() -> Response {
    ResponseBuilder::new()
        .status(401)
        .text("Unauthorized")
        .build()
        .or_else(|_| ResponseBuilder::new().status(401).build())
        .expect("a fixed 401 response with no body should always build")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn valid_key_resolves_to_identity_with_scopes() {
        let store = StaticKeys::new()
            .insert("key-abc", "service-a")
            .with_scopes("key-def", "service-b", ["read", "write"]);

        let plain = store.validate("key-abc").await.unwrap();
        assert_eq!(plain.id, "service-a");
        assert!(plain.scopes.is_empty());

        let scoped = store.validate("key-def").await.unwrap();
        assert_eq!(scoped.id, "service-b");
        assert_eq!(scoped.scopes, ["read", "write"]);
    }

    #[tokio::test]
    async fn unknown_key_is_rejected() {
        let populated = StaticKeys::new().insert("key-abc", "service-a");
        assert_eq!(populated.validate("nope").await, None);

        let empty = StaticKeys::new();
        assert_eq!(empty.validate("anything").await, None);
    }

    #[test]
    fn ct_eq_matches_only_identical_digests() {
        let digest = hash_key(b"key-abc");
        assert!(ct_eq(&digest, &hash_key(b"key-abc")));
        assert!(!ct_eq(&digest, &hash_key(b"key-xyz")));
    }
}