use async_trait::async_trait;
use std::sync::Arc;
#[cfg(feature = "tee")]
use vti_common::auth::backend::AuthError;
use vti_common::auth::backend::{AttestationOutcome, AuthBackend, RoleResolution};
use vti_common::auth::handlers::KeyspaceSessionStore;
use vti_common::auth::jwt::JwtKeys;
use crate::acl::Role;
use crate::error::AppError;
use crate::server::AppState;
/// `AuthBackend` implementation for the VTA, backed by the shared [`AppState`],
/// a keyspace-backed session store and the configured JWT signing keys.
///
/// The challenge / access-token / refresh-token lifetimes are copied out of
/// `config.auth` once, in [`VtaAuthBackend::from_state`].
pub struct VtaAuthBackend {
/// Shared application state (ACL keyspace, live config, optional TEE state).
state: Arc<AppState>,
/// Session store wrapping `AppState::sessions_ks`.
sessions: KeyspaceSessionStore,
/// Keys used to build and sign access-token claims.
jwt_keys: Arc<JwtKeys>,
/// Copied from `config.auth.challenge_ttl` (presumably seconds — TODO confirm).
challenge_ttl: u64,
/// Copied from `config.auth.access_token_expiry` (presumably seconds — TODO confirm).
access_token_ttl: u64,
/// Copied from `config.auth.refresh_token_expiry` (presumably seconds — TODO confirm).
refresh_token_ttl: u64,
}
impl VtaAuthBackend {
    /// Builds an auth backend from the running application state.
    ///
    /// Clones `state` into a fresh `Arc` and wraps `state.sessions_ks` in a
    /// [`KeyspaceSessionStore`]. It also snapshots the challenge, access-token
    /// and refresh-token lifetimes from `config.auth`. Config edits made after
    /// construction are therefore not reflected in these three TTLs.
    ///
    /// # Errors
    /// Returns `AppError::Internal` when `state.jwt_keys` is `None`.
    pub async fn from_state(state: &AppState) -> Result<Self, AppError> {
        let Some(jwt_keys) = state.jwt_keys.clone() else {
            return Err(AppError::Internal("JWT keys not configured".to_string()));
        };
        let sessions = KeyspaceSessionStore::new(state.sessions_ks.clone());

        // Hold the config read lock only long enough to copy the three TTLs.
        let config_guard = state.config.read().await;
        let challenge_ttl = config_guard.auth.challenge_ttl;
        let access_token_ttl = config_guard.auth.access_token_expiry;
        let refresh_token_ttl = config_guard.auth.refresh_token_expiry;
        drop(config_guard);

        let shared_state = Arc::new(state.clone());
        Ok(Self {
            state: shared_state,
            sessions,
            jwt_keys,
            challenge_ttl,
            access_token_ttl,
            refresh_token_ttl,
        })
    }
}
#[async_trait]
impl AuthBackend for VtaAuthBackend {
    type Store = KeyspaceSessionStore;
    type Error = AppError;
    type Role = Role;

    /// Returns the session store created in `from_state`.
    fn sessions(&self) -> &Self::Store {
        &self.sessions
    }

    /// Mints a signed access token.
    ///
    /// `JwtKeys::new_claims` builds the claims from:
    /// - the subject and session id,
    /// - the role's string form and the allowed contexts,
    /// - `ttl_secs` and the TEE-attestation flag.
    ///
    /// The `amr`/`acr` values are then attached via `with_aal`, and the
    /// result is signed with the configured keys.
    ///
    /// # Errors
    /// Returns `AppError::Internal` if JWT encoding fails.
    async fn mint_access_token(
        &self,
        subject: &str,
        session_id: &str,
        role: &Self::Role,
        contexts: &[String],
        amr: &[String],
        acr: &str,
        tee_attested: bool,
        ttl_secs: u64,
    ) -> Result<String, Self::Error> {
        let claims = self
            .jwt_keys
            .new_claims(
                subject.to_string(),
                session_id.to_string(),
                role.to_string(),
                contexts.to_vec(),
                ttl_secs,
                tee_attested,
            )
            .with_aal(amr.to_vec(), acr.to_string());
        self.jwt_keys
            .encode(&claims)
            .map_err(|e| AppError::Internal(format!("jwt encode failed: {e:?}")))
    }

    /// Resolves the role and allowed contexts for `did` from the ACL keyspace.
    /// It then rejects the DID if its ACL entry carries a wiped or disabled
    /// device binding.
    ///
    /// # Errors
    /// - Propagates errors from the ACL lookups.
    /// - Returns the `device_access_gate` rejection, which is also logged at
    ///   `warn`.
    async fn check_acl(&self, did: &str) -> Result<RoleResolution<Self::Role>, Self::Error> {
        let (role, allowed_contexts) =
            vti_common::acl::check_acl_full(&self.state.acl_ks, did).await?;
        // NOTE(review): the entry is read a second time here. If it changes
        // between the two reads, the role and the device status come from
        // different snapshots. Consider a single lookup if `check_acl_full`'s
        // semantics allow it.
        if let Some(entry) = vti_common::acl::get_acl_entry(&self.state.acl_ks, did).await? {
            device_access_gate(&entry).inspect_err(|e| {
                tracing::warn!(%did, "auth rejected: {e}");
            })?;
        }
        Ok(RoleResolution::with_contexts(role, allowed_contexts))
    }

    /// Checks whether `did` is allowed to authenticate at all.
    ///
    /// With the `tee` feature and a configured `tee.allowed_did_methods` list,
    /// the DID must start with one of the listed prefixes. Otherwise it is
    /// rejected with `AuthError::DidMethodRejected`. Without the feature, or
    /// with no list configured, every DID is accepted.
    async fn validate_did(&self, did: &str) -> Result<(), Self::Error> {
        #[cfg(feature = "tee")]
        {
            let config = self.state.config.read().await;
            if let Some(ref allowed) = config.tee.allowed_did_methods {
                // NOTE(review): plain prefix match. An entry such as "did:web"
                // also admits "did:webvh:…". Confirm that entries carry a
                // trailing ':' or tighten the comparison.
                let did_ok = allowed.iter().any(|prefix| did.starts_with(prefix));
                if !did_ok {
                    tracing::warn!(%did, "auth rejected: DID method not in allowed_did_methods");
                    return Err(AuthError::DidMethodRejected.into());
                }
            }
        }
        // `did` is only read under the `tee` feature; silence the unused
        // warning in other builds.
        let _ = did;
        Ok(())
    }

    /// Produces a TEE attestation report bound to the login challenge.
    ///
    /// The challenge is served unattested when the `tee` feature is disabled
    /// or `AppState::tee` is `None`. Otherwise the provider attests with:
    /// - user data: the VTA DID, or empty bytes when it is unset;
    /// - nonce: the 32 challenge bytes.
    ///
    /// The resulting report is stamped with the VTA DID and serialized to JSON.
    ///
    /// # Errors
    /// - `AppError::Internal` if the report cannot be serialized.
    /// - `AuthError::AttestationFailed` (converted) if attestation fails while
    ///   `tee.mode` is `Required`. In any other mode the failure is logged and
    ///   the challenge is served unattested.
    async fn attest_challenge(
        &self,
        challenge_bytes: &[u8; 32],
    ) -> Result<AttestationOutcome, Self::Error> {
        #[cfg(feature = "tee")]
        {
            let Some(ref tee) = self.state.tee else {
                return Ok(AttestationOutcome::not_attested());
            };
            let config = self.state.config.read().await;
            let vta_did = config.vta_did.clone();
            let tee_mode = config.tee.mode.clone();
            // Release the config lock before calling into the attestation provider.
            drop(config);
            let user_data = vta_did.as_deref().unwrap_or("").as_bytes();
            let nonce_bytes = &challenge_bytes[..];
            match tee.state.provider.attest(user_data, nonce_bytes) {
                Ok(mut report) => {
                    report.vta_did = vta_did;
                    let value = serde_json::to_value(&report).map_err(|e| {
                        AppError::Internal(format!("failed to serialize attestation report: {e}"))
                    })?;
                    Ok(AttestationOutcome::attested(value))
                }
                Err(e) => {
                    if matches!(tee_mode, crate::config::TeeMode::Required) {
                        tracing::error!(
                            "TEE attestation failed in required mode — refusing challenge: {e}"
                        );
                        return Err(AuthError::AttestationFailed(e.to_string()).into());
                    }
                    tracing::warn!(
                        "TEE attestation failed (mode=optional) — challenge served without attestation: {e}"
                    );
                    Ok(AttestationOutcome::not_attested())
                }
            }
        }
        #[cfg(not(feature = "tee"))]
        {
            // The challenge is only consumed by the TEE provider.
            let _ = challenge_bytes;
            Ok(AttestationOutcome::not_attested())
        }
    }

    /// Challenge lifetime captured from `config.auth.challenge_ttl` at construction.
    fn challenge_ttl(&self) -> u64 {
        self.challenge_ttl
    }

    /// Access-token lifetime captured from `config.auth.access_token_expiry` at construction.
    fn access_token_ttl(&self) -> u64 {
        self.access_token_ttl
    }

    /// Refresh-token lifetime captured from `config.auth.refresh_token_expiry` at construction.
    fn refresh_token_ttl(&self) -> u64 {
        self.refresh_token_ttl
    }
}
/// Rejects ACL entries whose device binding has been wiped or disabled.
///
/// Entries without a device binding always pass. When both timestamps are
/// set, the wiped state is reported instead of the disabled one.
///
/// # Errors
/// Returns `AppError::Forbidden` with either "device has been wiped" or
/// "device is disabled".
fn device_access_gate(entry: &vti_common::acl::AclEntry) -> Result<(), AppError> {
    let Some(binding) = &entry.device else {
        return Ok(());
    };
    if binding.wiped_at.is_some() {
        Err(AppError::Forbidden("device has been wiped".into()))
    } else if binding.disabled_at.is_some() {
        Err(AppError::Forbidden("device is disabled".into()))
    } else {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::device_access_gate;
    use crate::error::AppError;
    use vti_common::acl::{AclEntry, DeviceBinding, Role};

    /// Builds a `Reader` ACL entry for `did:key:zAgent`. The entry is bound to
    /// the device described by `binding_json` (camelCase `DeviceBinding` JSON).
    fn agent_entry(binding_json: &str) -> AclEntry {
        let binding = serde_json::from_str::<DeviceBinding>(binding_json).unwrap();
        AclEntry::new("did:key:zAgent", Role::Reader, "test").with_device(Some(binding))
    }

    /// Asserts that the gate rejects `entry` with `AppError::Forbidden`.
    #[track_caller]
    fn assert_forbidden(entry: &AclEntry) {
        let outcome = device_access_gate(entry);
        assert!(matches!(outcome, Err(AppError::Forbidden(_))));
    }

    #[test]
    fn gate_allows_entry_without_device() {
        let unbound = AclEntry::new("did:key:zAdmin", Role::Admin, "test");
        assert!(device_access_gate(&unbound).is_ok());
    }

    #[test]
    fn gate_allows_active_device() {
        let active = agent_entry(
            r#"{"deviceId":"d","displayName":"agent","registeredAt":"2026-01-01T00:00:00Z"}"#,
        );
        assert!(device_access_gate(&active).is_ok());
    }

    #[test]
    fn gate_rejects_disabled_device() {
        assert_forbidden(&agent_entry(
            r#"{"deviceId":"d","displayName":"agent","registeredAt":"2026-01-01T00:00:00Z","disabledAt":"2026-06-01T00:00:00Z"}"#,
        ));
    }

    #[test]
    fn gate_rejects_wiped_device() {
        assert_forbidden(&agent_entry(
            r#"{"deviceId":"d","displayName":"agent","registeredAt":"2026-01-01T00:00:00Z","wipedAt":"2026-06-01T00:00:00Z"}"#,
        ));
    }
}