use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use serde::Deserialize;
use sqlx::SqlitePool;
use tokio::sync::RwLock;
use tracing::{debug, warn};
use crate::db::tenants;
use crate::tenant::TenantContext;
#[async_trait]
pub trait TenantResolver: Send + Sync {
async fn resolve(&self, peer_addr: &SocketAddr) -> TenantContext;
}
struct CacheEntry {
context: TenantContext,
expires_at: Instant,
}
pub struct IpMappingResolver {
pool: SqlitePool,
cache: RwLock<HashMap<IpAddr, CacheEntry>>,
ttl: Duration,
}
impl IpMappingResolver {
pub fn new(pool: SqlitePool) -> Self {
Self {
pool,
cache: RwLock::new(HashMap::new()),
ttl: Duration::from_secs(60),
}
}
pub fn invalidate_cache(&self) {
if let Ok(mut cache) = self.cache.try_write() {
cache.clear();
}
}
}
#[async_trait]
impl TenantResolver for IpMappingResolver {
async fn resolve(&self, peer_addr: &SocketAddr) -> TenantContext {
let ip = peer_addr.ip();
{
let cache = self.cache.read().await;
if let Some(entry) = cache.get(&ip)
&& entry.expires_at > Instant::now() {
return entry.context.clone();
}
}
let ip_str = ip.to_string();
let ctx = match tenants::tenant_by_ip(&self.pool, &ip_str).await {
Ok(Some(tenant)) if tenant.enabled => TenantContext::new(tenant.id),
Ok(_) => TenantContext::anonymous(),
Err(e) => {
warn!("IP mapping lookup failed for {}: {}", ip, e);
TenantContext::anonymous()
}
};
{
let mut cache = self.cache.write().await;
cache.insert(
ip,
CacheEntry {
context: ctx.clone(),
expires_at: Instant::now() + self.ttl,
},
);
}
ctx
}
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct TailscaleWhoisResponse {
user_profile: Option<TailscaleUserProfile>,
node: Option<TailscaleNode>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct TailscaleUserProfile {
login_name: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct TailscaleNode {
computed_name: Option<String>,
}
pub struct TailscaleResolver {
pool: SqlitePool,
cache: RwLock<HashMap<IpAddr, CacheEntry>>,
ttl: Duration,
http_client: reqwest::Client,
}
impl TailscaleResolver {
pub fn new(pool: SqlitePool) -> Self {
Self {
pool,
cache: RwLock::new(HashMap::new()),
ttl: Duration::from_secs(300),
http_client: reqwest::Client::new(),
}
}
}
#[async_trait]
impl TenantResolver for TailscaleResolver {
async fn resolve(&self, peer_addr: &SocketAddr) -> TenantContext {
let ip = peer_addr.ip();
{
let cache = self.cache.read().await;
if let Some(entry) = cache.get(&ip)
&& entry.expires_at > Instant::now() {
return entry.context.clone();
}
}
let url = format!("http://127.0.0.1/localapi/v0/whois?addr={}", peer_addr);
let ctx = match self.http_client.get(&url).send().await {
Ok(resp) if resp.status().is_success() => {
match resp.json::<TailscaleWhoisResponse>().await {
Ok(whois) => {
let login = whois
.user_profile
.as_ref()
.map(|u| u.login_name.clone())
.unwrap_or_default();
let node_name = whois
.node
.as_ref()
.and_then(|n| n.computed_name.clone())
.unwrap_or_default();
let tenant_id = if !login.is_empty() {
format!("ts:{}:{}", login, node_name)
} else {
String::new()
};
if tenant_id.is_empty() {
debug!("Tailscale whois returned no identity for {}", peer_addr);
TenantContext::anonymous()
} else {
let ip_str = ip.to_string();
if let Ok(None) = tenants::tenant_by_ip(&self.pool, &ip_str).await {
let display = format!("{} ({})", node_name, login);
match tenants::Tenant::create(
&self.pool, &tenant_id, &display, None, None, None, None,
)
.await
{
Ok(_) => {
let _ = tenants::add_ip_mapping(
&self.pool, &tenant_id, &ip_str,
)
.await;
debug!("Auto-created tenant {} for {}", tenant_id, ip);
}
Err(e) => {
debug!(
"Tenant creation skipped (may already exist): {}",
e
);
let _ = tenants::add_ip_mapping(
&self.pool, &tenant_id, &ip_str,
)
.await;
}
}
}
TenantContext::new(tenant_id)
}
}
Err(e) => {
warn!("Failed to parse Tailscale whois response: {}", e);
TenantContext::anonymous()
}
}
}
Ok(resp) => {
debug!(
"Tailscale whois returned non-success status {} for {}",
resp.status(),
peer_addr
);
TenantContext::anonymous()
}
Err(e) => {
debug!("Tailscale whois API unreachable for {}: {}", peer_addr, e);
TenantContext::anonymous()
}
};
{
let mut cache = self.cache.write().await;
cache.insert(
ip,
CacheEntry {
context: ctx.clone(),
expires_at: Instant::now() + self.ttl,
},
);
}
ctx
}
}
pub struct HeaderResolver {
header_name: String,
}
impl HeaderResolver {
pub fn new(header_name: String) -> Self {
Self { header_name }
}
pub fn resolve_from_header(&self, headers: &hyper::HeaderMap) -> TenantContext {
match headers.get(&self.header_name) {
Some(value) => match value.to_str() {
Ok(tenant_id) if !tenant_id.is_empty() => TenantContext::new(tenant_id.to_string()),
_ => TenantContext::anonymous(),
},
None => TenantContext::anonymous(),
}
}
}
#[async_trait]
impl TenantResolver for HeaderResolver {
async fn resolve(&self, _peer_addr: &SocketAddr) -> TenantContext {
TenantContext::anonymous()
}
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TenantResolverKind {
#[default]
IpMapping,
Tailscale,
Header,
}
impl std::fmt::Display for TenantResolverKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TenantResolverKind::IpMapping => write!(f, "ip-mapping"),
TenantResolverKind::Tailscale => write!(f, "tailscale"),
TenantResolverKind::Header => write!(f, "header"),
}
}
}
pub fn build_resolver(
kind: &TenantResolverKind,
pool: SqlitePool,
header_name: Option<String>,
) -> Arc<dyn TenantResolver> {
match kind {
TenantResolverKind::IpMapping => Arc::new(IpMappingResolver::new(pool)),
TenantResolverKind::Tailscale => Arc::new(TailscaleResolver::new(pool)),
TenantResolverKind::Header => {
let header = header_name.unwrap_or_else(|| "X-Tenant-Id".to_string());
Arc::new(HeaderResolver::new(header))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, SocketAddrV4};
#[tokio::test]
async fn test_ip_mapping_resolver_unmapped_returns_anonymous() {
let (db, _tmp) = crate::test_utils::create_db().await;
let resolver = IpMappingResolver::new(db.pool);
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 12345));
let ctx = resolver.resolve(&addr).await;
assert!(ctx.is_anonymous());
}
#[tokio::test]
async fn test_ip_mapping_resolver_mapped_returns_tenant() {
let (db, _tmp) = crate::test_utils::create_db().await;
tenants::Tenant::create(&db.pool, "t1", "Test User", None, None, None, None)
.await
.unwrap();
tenants::add_ip_mapping(&db.pool, "t1", "10.0.0.1")
.await
.unwrap();
let resolver = IpMappingResolver::new(db.pool);
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 12345));
let ctx = resolver.resolve(&addr).await;
assert_eq!(ctx.tenant_id.as_deref(), Some("t1"));
}
#[tokio::test]
async fn test_ip_mapping_resolver_multiple_ips_same_tenant() {
let (db, _tmp) = crate::test_utils::create_db().await;
tenants::Tenant::create(&db.pool, "t1", "Test User", None, None, None, None)
.await
.unwrap();
tenants::add_ip_mapping(&db.pool, "t1", "10.0.0.1")
.await
.unwrap();
tenants::add_ip_mapping(&db.pool, "t1", "10.0.0.2")
.await
.unwrap();
let resolver = IpMappingResolver::new(db.pool);
let addr1 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 1));
let addr2 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 1));
assert_eq!(
resolver.resolve(&addr1).await.tenant_id.as_deref(),
Some("t1")
);
assert_eq!(
resolver.resolve(&addr2).await.tenant_id.as_deref(),
Some("t1")
);
}
#[test]
fn test_header_resolver_present() {
let resolver = HeaderResolver::new("X-Tenant-Id".to_string());
let mut headers = hyper::HeaderMap::new();
headers.insert("X-Tenant-Id", "tenant-abc".parse().unwrap());
let ctx = resolver.resolve_from_header(&headers);
assert_eq!(ctx.tenant_id.as_deref(), Some("tenant-abc"));
}
#[test]
fn test_header_resolver_missing() {
let resolver = HeaderResolver::new("X-Tenant-Id".to_string());
let headers = hyper::HeaderMap::new();
let ctx = resolver.resolve_from_header(&headers);
assert!(ctx.is_anonymous());
}
}