use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use hmac::{Hmac, Mac};
use sha2::Sha256;
/// Widget authorization policy for a single agent.
///
/// Every field is optional on the wire: a missing `allowed_origins`
/// deserializes to an empty list and a missing `public_key` to `None`.
#[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)]
#[serde(default)]
pub struct AgentWidgetAuth {
    /// Origin patterns for this agent — presumably fed to `origin_allowed`
    /// by the caller (TODO confirm).
    pub allowed_origins: Vec<String>,
    /// Key material for this agent — presumably the key handed to
    /// `verify_auth_context` (TODO confirm).
    pub public_key: Option<String>,
}
/// Source of per-agent widget authorization policies.
///
/// Implementations in this file return `None` when they hold no policy for
/// the agent; the HTTP-backed provider also returns `None` when the lookup
/// itself fails.
/// NOTE(review): confirm whether callers treat `None` as "no restrictions"
/// or as "deny" — the two meanings are not distinguished by this signature.
#[async_trait]
pub trait WidgetAuthProvider: Send + Sync {
    /// Look up the widget policy for `agent_id`, or `None` if none is available.
    async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth>;
}
/// Provider that never returns a policy for any agent.
///
/// NOTE(review): the name suggests callers interpret `None` as "no
/// restrictions"; confirm against the call site.
#[derive(Debug, Default)]
pub struct PermissiveWidgetAuth;

#[async_trait]
impl WidgetAuthProvider for PermissiveWidgetAuth {
    /// Always returns `None`; the agent id is ignored.
    async fn agent_widget_auth(&self, _agent_id: &str) -> Option<AgentWidgetAuth> {
        None
    }
}
/// Provider backed by a fixed, in-memory map from agent id to policy.
#[derive(Debug, Default)]
pub struct StaticWidgetAuth {
    /// Policies keyed by agent id.
    rows: HashMap<String, AgentWidgetAuth>,
}

impl StaticWidgetAuth {
    /// Wrap an existing agent-id → policy map.
    #[must_use]
    pub fn new(rows: HashMap<String, AgentWidgetAuth>) -> Self {
        Self { rows }
    }

    /// Parse a JSON object that maps agent ids to policies, for example
    /// `{ "agent-1": { "allowed_origins": ["https://*.example.com"] } }`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error when `json` is not an object whose
    /// values deserialize as [`AgentWidgetAuth`].
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<HashMap<String, AgentWidgetAuth>>(json).map(Self::new)
    }
}

#[async_trait]
impl WidgetAuthProvider for StaticWidgetAuth {
    /// Return a clone of the stored policy for `agent_id`, or `None` when the
    /// id is not present in the map.
    async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth> {
        let policy = self.rows.get(agent_id)?;
        Some(policy.clone())
    }
}
/// One cached policy lookup made by `HttpWidgetAuth`.
struct CacheEntry {
    /// The looked-up policy; `None` records a 404 ("no policy") response.
    value: Option<AgentWidgetAuth>,
    /// When the entry was stored; its age is compared against the provider's TTL.
    fetched: Instant,
}
/// Provider that fetches policies over HTTP from `{base_url}/{agent_id}` and
/// caches successful lookups and 404s for a configurable TTL.
pub struct HttpWidgetAuth {
    /// HTTP client used for policy requests.
    client: reqwest::Client,
    /// Policy service base URL, stored with trailing slashes trimmed.
    base_url: String,
    /// Optional bearer token attached to each request via `bearer_auth`.
    bearer: Option<String>,
    /// How long a cached lookup is served before it is re-fetched.
    ttl: Duration,
    /// Cached lookups keyed by agent id.
    cache: RwLock<HashMap<String, CacheEntry>>,
}
impl HttpWidgetAuth {
#[must_use]
pub fn new(base_url: impl Into<String>) -> Self {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(5))
.build()
.unwrap_or_default();
Self::with_client(base_url, client)
}
#[must_use]
pub fn with_client(base_url: impl Into<String>, client: reqwest::Client) -> Self {
Self {
client,
base_url: base_url.into().trim_end_matches('/').to_string(),
bearer: None,
ttl: Duration::from_secs(60),
cache: RwLock::new(HashMap::new()),
}
}
#[must_use]
pub fn with_bearer(mut self, token: impl Into<String>) -> Self {
self.bearer = Some(token.into());
self
}
#[must_use]
pub fn with_ttl(mut self, ttl: Duration) -> Self {
self.ttl = ttl;
self
}
fn cached(&self, agent_id: &str) -> Option<Option<AgentWidgetAuth>> {
let cache = self.cache.read().ok()?;
let entry = cache.get(agent_id)?;
if entry.fetched.elapsed() < self.ttl {
Some(entry.value.clone())
} else {
None
}
}
fn store(&self, agent_id: &str, value: Option<AgentWidgetAuth>) {
if let Ok(mut cache) = self.cache.write() {
cache.insert(
agent_id.to_string(),
CacheEntry {
value,
fetched: Instant::now(),
},
);
}
}
}
#[async_trait]
impl WidgetAuthProvider for HttpWidgetAuth {
    /// Resolve the policy for `agent_id`, consulting the TTL cache first.
    ///
    /// On a cache miss, issues `GET {base_url}/{agent_id}`. The agent id is
    /// percent-encoded as a single path segment, and the bearer token is sent
    /// if one is configured.
    ///
    /// - `2xx` with a JSON body that parses as [`AgentWidgetAuth`]: returns
    ///   `Some(policy)` and caches it.
    /// - `404`: returns `None` and caches the miss.
    /// - Any of the following returns `None`, is logged, and is **not** cached:
    ///   - an empty, `"."` or `".."` agent id;
    ///   - an invalid or cannot-be-a-base `base_url`;
    ///   - a transport error;
    ///   - a malformed body;
    ///   - any other status.
    ///
    /// NOTE(review): fetch failures and "no policy" both yield `None`; confirm
    /// callers do not treat `None` as allow-all.
    async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth> {
        // `PathSegmentsMut::push` ignores "." and "..", and "" adds no
        // per-agent segment. Such ids would fetch the base path (or a sibling)
        // instead of a per-agent resource, and that response would then be
        // cached under this id.
        if matches!(agent_id, "" | "." | "..") {
            tracing::warn!(agent_id, "widget-auth: agent id is not a usable path segment");
            return None;
        }
        if let Some(cached) = self.cached(agent_id) {
            return cached;
        }
        let mut url = match reqwest::Url::parse(&self.base_url) {
            Ok(u) => u,
            Err(e) => {
                tracing::warn!(error = %e, base_url = %self.base_url, "widget-auth: invalid base_url");
                return None;
            }
        };
        match url.path_segments_mut() {
            // `push` percent-encodes the id (including `/` and `%`), so it is
            // appended as exactly one path segment.
            Ok(mut segs) => {
                segs.push(agent_id);
            }
            Err(()) => {
                tracing::warn!(base_url = %self.base_url, "widget-auth: base_url cannot be a base");
                return None;
            }
        }
        let mut req = self.client.get(url);
        if let Some(bearer) = &self.bearer {
            req = req.bearer_auth(bearer);
        }
        let resp = match req.send().await {
            Ok(r) => r,
            Err(e) => {
                tracing::warn!(error = %e, agent_id, "widget-auth: policy fetch failed");
                return None;
            }
        };
        let status = resp.status();
        if status.is_success() {
            match resp.json::<AgentWidgetAuth>().await {
                Ok(policy) => {
                    let value = Some(policy);
                    self.store(agent_id, value.clone());
                    value
                }
                Err(e) => {
                    tracing::warn!(error = %e, agent_id, "widget-auth: malformed policy body");
                    None
                }
            }
        } else if status == reqwest::StatusCode::NOT_FOUND {
            // A definitive "no policy" answer is cached like a hit.
            self.store(agent_id, None);
            None
        } else {
            // Transient service errors are not cached so the next call retries.
            tracing::warn!(%status, agent_id, "widget-auth: policy service error");
            None
        }
    }
}
/// Returns `true` when `origin` matches at least one pattern in `allowed`.
///
/// An empty allow-list matches nothing. Each entry is tested with
/// `origin_matches` (`"*"`, an exact origin, or `scheme://*.suffix`).
#[must_use]
pub fn origin_allowed(allowed: &[String], origin: &str) -> bool {
    for pattern in allowed {
        if origin_matches(pattern, origin) {
            return true;
        }
    }
    false
}
/// Test whether one allow-list `pattern` admits `origin`.
///
/// Supported pattern forms:
/// - `"*"` matches every origin.
/// - An exact origin such as `"https://app.example.com"`. It is compared
///   ASCII case-insensitively, because scheme and host are case-insensitive.
/// - `"scheme://*.suffix"` matches `scheme://suffix` itself and any
///   `scheme://<label>.suffix`. The origin's authority must be a plain
///   `host[:port]` (no path, query, fragment, userinfo or empty labels).
///   Ports are part of the suffix, so `*.a.com` does not match `x.a.com:8443`.
///
/// Any other pattern matches only by the exact comparison above.
fn origin_matches(pattern: &str, origin: &str) -> bool {
    if pattern == "*" || pattern.eq_ignore_ascii_case(origin) {
        return true;
    }
    let (Some((p_scheme, p_host)), Some((o_scheme, o_host))) =
        (pattern.split_once("://"), origin.split_once("://"))
    else {
        return false;
    };
    if !p_scheme.eq_ignore_ascii_case(o_scheme) {
        return false;
    }
    let Some(suffix) = p_host.strip_prefix("*.") else {
        return false;
    };
    // Without this check, a string such as `https://evil.com/.smoo.ai` (or with
    // `?`, `#`, `\`) would satisfy the suffix test below.
    if !is_plain_host(o_host) {
        return false;
    }
    // `is_plain_host` guarantees `o_host` is ASCII, so every byte index is a
    // char boundary and `split_at` cannot panic.
    let Some(cut) = o_host.len().checked_sub(suffix.len()) else {
        return false;
    };
    let (head, tail) = o_host.split_at(cut);
    // `head` is either empty (the apex, e.g. `smoo.ai` for `*.smoo.ai`) or must
    // end on a label boundary, so `notsmoo.ai` is rejected.
    tail.eq_ignore_ascii_case(suffix) && (head.is_empty() || head.ends_with('.'))
}

/// Returns `true` when `host` is a plain `host[:port]` authority.
///
/// The host part must be non-empty, dot-separated labels of ASCII
/// alphanumerics, `-` or `_`, with no empty labels. The optional port must
/// be a non-empty run of ASCII digits.
fn is_plain_host(host: &str) -> bool {
    let (name, port_ok) = match host.rsplit_once(':') {
        Some((name, port)) => (
            name,
            !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()),
        ),
        None => (host, true),
    };
    port_ok
        && !name.is_empty()
        && name.split('.').all(|label| {
            !label.is_empty()
                && label
                    .bytes()
                    .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
        })
}
/// Verify an HMAC-SHA256 signature over `"{user_id}:{timestamp}"`.
///
/// Returns `true` only when all of the following hold:
/// - `timestamp` is within `max_age_secs` seconds of `now_unix`, in either
///   direction. A negative `max_age_secs` rejects every timestamp.
/// - `signature_hex` decodes as hex.
/// - The decoded bytes equal the HMAC-SHA256 of `"{user_id}:{timestamp}"`
///   keyed with the bytes of `public_key`, checked via `Mac::verify_slice`.
///
/// Despite its name, `public_key` is used as a symmetric HMAC key: anyone
/// holding it can produce signatures this function accepts.
#[must_use]
pub fn verify_auth_context(
    public_key: &str,
    user_id: &str,
    signature_hex: &str,
    timestamp: i64,
    now_unix: i64,
    max_age_secs: i64,
) -> bool {
    // `abs_diff` cannot overflow. `(now_unix - timestamp).abs()` can, for
    // extreme caller-supplied timestamps: it panics in debug builds, and in
    // release builds it can wrap to `i64::MIN`, which slips past a `>` check.
    if max_age_secs < 0 || now_unix.abs_diff(timestamp) > max_age_secs.unsigned_abs() {
        return false;
    }
    let Ok(sig) = hex::decode(signature_hex) else {
        return false;
    };
    let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(public_key.as_bytes()) else {
        return false;
    };
    mac.update(format!("{user_id}:{timestamp}").as_bytes());
    mac.verify_slice(&sig).is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn origin_exact_and_wildcard() {
        let allowed = [
            "https://app.example.com".to_string(),
            "https://*.smoo.ai".to_string(),
        ];
        // Exact match, subdomain wildcard, and the bare apex of the wildcard.
        for ok in [
            "https://app.example.com",
            "https://dash.smoo.ai",
            "https://smoo.ai",
        ] {
            assert!(origin_allowed(&allowed, ok), "{ok} should be allowed");
        }
        // Unknown host, scheme mismatch, and a suffix without a label boundary.
        for bad in [
            "https://evil.com",
            "http://dash.smoo.ai",
            "https://notsmoo.ai",
        ] {
            assert!(!origin_allowed(&allowed, bad), "{bad} should be rejected");
        }
    }

    #[test]
    fn origin_star_allows_all_but_empty_denies() {
        let star = ["*".to_string()];
        assert!(origin_allowed(&star, "https://anything.test"));
        assert!(!origin_allowed(&[], "https://anything.test"));
    }

    /// Test helper: hex-encoded HMAC-SHA256 of `"{user}:{ts}"` keyed with `key`.
    fn sign(key: &str, user: &str, ts: i64) -> String {
        let mut mac = Hmac::<Sha256>::new_from_slice(key.as_bytes()).unwrap();
        mac.update(format!("{user}:{ts}").as_bytes());
        hex::encode(mac.finalize().into_bytes())
    }

    #[test]
    fn auth_context_valid_and_invalid() {
        const KEY: &str = "super-secret-public-key";
        const USER: &str = "user-123";
        let now = 1_000_000;
        let fresh = sign(KEY, USER, now);
        let recent_ts = now - 30;
        let stale_ts = now - 600;

        // Accepted: signed now, and signed 30s ago inside a 60s window.
        assert!(verify_auth_context(KEY, USER, &fresh, now, now, 60));
        assert!(verify_auth_context(
            KEY,
            USER,
            &sign(KEY, USER, recent_ts),
            recent_ts,
            now,
            60
        ));

        // Rejected: wrong key, wrong user, too old, not hex.
        assert!(!verify_auth_context("other-key", USER, &fresh, now, now, 60));
        assert!(!verify_auth_context(KEY, "user-999", &fresh, now, now, 60));
        assert!(!verify_auth_context(
            KEY,
            USER,
            &sign(KEY, USER, stale_ts),
            stale_ts,
            now,
            60
        ));
        assert!(!verify_auth_context(KEY, USER, "not-hex", now, now, 60));
    }

    #[tokio::test]
    async fn static_provider_resolves_known_agents() {
        let provider = StaticWidgetAuth::from_json(
            r#"{ "agent-1": { "allowed_origins": ["https://*.smoo.ai"], "public_key": "k" } }"#,
        )
        .unwrap();
        let policy = provider.agent_widget_auth("agent-1").await.unwrap();
        assert_eq!(policy.allowed_origins, vec!["https://*.smoo.ai".to_string()]);
        assert_eq!(policy.public_key.as_deref(), Some("k"));
        assert!(provider.agent_widget_auth("unknown").await.is_none());
    }

    #[tokio::test]
    async fn permissive_provider_returns_none() {
        let outcome = PermissiveWidgetAuth.agent_widget_auth("anything").await;
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn http_provider_fetches_then_serves_from_cache() {
        use wiremock::matchers::{header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        let body = serde_json::json!({
            "allowed_origins": ["https://app.smoo.ai"],
            "public_key": "secret"
        });
        // `expect(1)`: the second lookup must be answered from the cache.
        Mock::given(method("GET"))
            .and(path("/agent-9"))
            .and(header("authorization", "Bearer m2m-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .expect(1)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri()).with_bearer("m2m-token");

        let first = provider.agent_widget_auth("agent-9").await.expect("policy");
        assert_eq!(first.allowed_origins, vec!["https://app.smoo.ai".to_string()]);
        assert_eq!(first.public_key.as_deref(), Some("secret"));

        let second = provider.agent_widget_auth("agent-9").await.expect("cached");
        assert_eq!(second.public_key.as_deref(), Some("secret"));
    }

    #[tokio::test]
    async fn http_provider_404_is_none_and_cached() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        // `expect(1)`: the negative answer is cached, so only one request lands.
        Mock::given(method("GET"))
            .and(path("/ghost"))
            .respond_with(ResponseTemplate::new(404))
            .expect(1)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri());
        for _ in 0..2 {
            assert!(provider.agent_widget_auth("ghost").await.is_none());
        }
    }

    #[tokio::test]
    async fn http_provider_server_error_is_none_and_not_cached() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        // `expect(2)`: server errors are not cached, so every lookup refetches.
        Mock::given(method("GET"))
            .and(path("/flaky"))
            .respond_with(ResponseTemplate::new(500))
            .expect(2)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri());
        for _ in 0..2 {
            assert!(provider.agent_widget_auth("flaky").await.is_none());
        }
    }
}