mod common;
use std::sync::Arc;
use nodedb::bridge::dispatch::Dispatcher;
use nodedb::control::security::identity::{AuthMethod, Role};
use nodedb::control::security::sessions::{KillReason, SessionParams, SessionRegistry};
use nodedb::control::state::SharedState;
use nodedb::types::TenantId;
use nodedb::wal::WalManager;
/// Builds the shared server state used by the credential tests in this file.
///
/// Delegates to the common pgwire auth helper so every test obtains its
/// state from one place.
fn make_state() -> Arc<SharedState> {
    common::pgwire_auth_helpers::make_state()
}
/// Builds `SessionParams` for a loopback session in tenant 1.
///
/// The session uses protocol `"native"` and auth method `"password"`.
/// `username` is used for both the session user name and the database user
/// name. No credential version, database or token expiry is set.
fn sample_params(user_id: u64, username: &str) -> SessionParams {
    let name = username.to_owned();
    SessionParams {
        user_id,
        db_user: name.clone(),
        username: name,
        peer_addr: String::from("127.0.0.1:5555"),
        protocol: String::from("native"),
        auth_method: String::from("password"),
        tenant_id: 1,
        credential_version: 0,
        current_database: None,
        token_expiry_ms: None,
    }
}
/// Registering a session makes it visible to `count`, both globally and for
/// its user. Unregistering removes it again.
#[test]
fn active_sessions_register_unregister() {
    let reg = SessionRegistry::new();
    let rx = reg.register("s1", &sample_params(42, "alice")).unwrap();
    // `has_changed` errors once the sender side is dropped. Folding that error
    // into `false` (as `unwrap_or(false)` did) let this negated assertion pass
    // even if the registry had discarded the session's kill sender, so
    // surface it explicitly.
    assert!(
        !rx.has_changed()
            .expect("kill channel must stay open while the session is registered"),
        "a freshly registered session must not be signalled"
    );
    assert_eq!(reg.count(None), 1);
    assert_eq!(reg.count(Some(42)), 1);

    reg.unregister("s1");
    assert_eq!(reg.count(None), 0);
    assert_eq!(reg.count(Some(42)), 0);
}
/// A registry capped at two sessions rejects a third registration and keeps
/// its count at the cap.
#[test]
fn max_active_sessions_over_cap_rejects() {
    let reg = SessionRegistry::with_cap(2);
    for (sid, uid, name) in [("s1", 1, "u1"), ("s2", 2, "u2")] {
        reg.register(sid, &sample_params(uid, name)).unwrap();
    }

    let overflow = reg.register("s3", &sample_params(3, "u3"));
    assert!(overflow.is_err(), "over-cap registration must fail");
    assert_eq!(reg.count(None), 2);
}
/// Killing all sessions of a user signals that user's session through its
/// kill channel with a reason other than `KillReason::Alive`.
#[test]
fn session_hard_revoke_close() {
    let reg = SessionRegistry::new();
    let mut rx = reg.register("s1", &sample_params(99, "bob")).unwrap();
    // Propagate a closed-channel error instead of folding it into `false`:
    // with `unwrap_or(false)` this negated check passed even if the sender
    // had already been dropped.
    assert!(
        !rx.has_changed()
            .expect("kill channel must be open before the kill"),
        "session must not be signalled before the kill"
    );

    let killed = reg.kill_sessions_for_user(99, KillReason::AdminKill);
    assert_eq!(killed, 1);

    assert!(
        rx.has_changed()
            .expect("kill channel must remain open to deliver the kill reason"),
        "kill must notify the session's receiver"
    );
    assert_ne!(*rx.borrow_and_update(), KillReason::Alive);
}
/// `list_all` reports a registered session with its session id, user id and
/// protocol.
#[test]
fn show_sessions_lists_active() {
    let reg = SessionRegistry::new();
    reg.register("sess-xyz", &sample_params(7, "carol")).unwrap();

    let sessions = reg.list_all();
    assert_eq!(sessions.len(), 1);

    let entry = &sessions[0];
    assert_eq!(entry.session_id, "sess-xyz");
    assert_eq!(entry.user_id, 7);
    assert_eq!(entry.protocol, "native");
}
/// Each credential mutation advances the credential version. The test checks
/// `create_user` first, then `update_roles` on the same user.
#[tokio::test]
async fn credential_version_bumps_on_mutation() {
    let state = make_state();
    let creds = &state.credentials;

    // Baseline read for user id 0, taken before "dave" exists.
    // NOTE(review): this compares against a different user id than the one
    // created below. It only holds if versions are global, or if user 0's
    // version starts at or below a new user's. Confirm intent.
    let before_create = creds.current_version(0);
    let dave = creds
        .create_user("dave", "pass123", TenantId::new(1), vec![Role::ReadOnly])
        .expect("create_user failed");
    let after_create = creds.current_version(dave);
    assert!(
        after_create > before_create,
        "version must advance after create_user"
    );

    creds
        .update_roles("dave", vec![Role::ReadWrite])
        .expect("update_roles failed");
    let after_update = creds.current_version(dave);
    assert!(
        after_update > after_create,
        "version must advance after update_roles"
    );
}
/// After a role update advances the credential version, an identity rebuilt
/// from the credential store reflects the new role set.
#[tokio::test]
async fn identity_rehydrate_on_version_advance() {
    let state = make_state();
    let creds = &state.credentials;

    let uid = creds
        .create_user("eve", "pass456", TenantId::new(1), vec![Role::ReadOnly])
        .expect("create_user failed");
    let version_at_create = creds.current_version(uid);

    let before = creds
        .to_identity("eve", AuthMethod::Trust)
        .expect("identity must exist");
    assert!(before.roles.contains(&Role::ReadOnly));
    assert!(!before.roles.contains(&Role::ReadWrite));

    creds
        .update_roles("eve", vec![Role::ReadOnly, Role::ReadWrite])
        .expect("update_roles failed");
    assert!(
        creds.current_version(uid) > version_at_create,
        "version must have advanced"
    );

    // Rebuild the identity from the store; it must pick up the added role.
    let after = creds
        .to_identity("eve", AuthMethod::Trust)
        .expect("identity must exist after role update");
    assert!(
        after.roles.contains(&Role::ReadWrite),
        "rehydrated identity must carry new role"
    );
}
/// Granting a role bumps the user's credential version, and an identity
/// built afterwards already includes the granted role.
#[tokio::test]
async fn in_flight_propagation_grant_visible() {
    let state = make_state();
    let creds = &state.credentials;

    let frank = creds
        .create_user("frank", "pw789", TenantId::new(1), vec![Role::ReadOnly])
        .expect("create_user failed");

    let pre_grant = creds.current_version(frank);
    creds
        .add_role("frank", Role::ReadWrite)
        .expect("add_role failed");
    let post_grant = creds.current_version(frank);
    assert!(post_grant > pre_grant, "grant must bump version");

    let identity = creds
        .to_identity("frank", AuthMethod::Trust)
        .expect("identity must exist");
    assert!(
        identity.roles.contains(&Role::ReadWrite),
        "fresh identity must contain newly granted role"
    );
}
/// Dropping a user publishes on both event buses: a `UserChanged` event for
/// that user, and a `SessionInvalidated` event whose reason is a hard revoke.
#[tokio::test]
async fn commit_user_mutation_publishes_both_buses() {
    const RECV_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200);

    // `dir` stays bound for the whole test so the WAL directory outlives `state`.
    let dir = tempfile::tempdir().unwrap();
    let wal_path = dir.path().join("test.wal");
    let wal = Arc::new(WalManager::open_for_testing(&wal_path).unwrap());
    let (dispatcher, _) = Dispatcher::new(1, 64);
    let state = SharedState::new(dispatcher, wal);

    let user_id = state
        .credentials
        .create_user("grace", "pw111", TenantId::new(1), vec![Role::ReadOnly])
        .expect("create_user failed");

    // Subscribe only once the user exists. `create_user` itself publishes a
    // `UserChanged` for the same user id. If we subscribed before creating,
    // that event would satisfy the `user_id` assertion below, and the test
    // would never check that `drop_user` publishes anything.
    // NOTE(review): assumes the buses only deliver events sent after
    // `subscribe_*` (broadcast semantics) and that `create_user` publishes
    // before returning. Confirm.
    let mut uc_rx = state.credentials.subscribe_user_changes();
    let mut si_rx = state.credentials.subscribe_session_invalidation();

    state
        .credentials
        .drop_user("grace")
        .expect("drop_user failed");

    let ev = tokio::time::timeout(RECV_TIMEOUT, uc_rx.recv())
        .await
        .expect("timed out waiting for UserChanged")
        .expect("channel closed");
    assert_eq!(ev.user_id, user_id);

    let si_ev = tokio::time::timeout(RECV_TIMEOUT, si_rx.recv())
        .await
        .expect("timed out waiting for SessionInvalidated")
        .expect("channel closed");
    assert_eq!(si_ev.user_id, user_id);
    assert!(
        si_ev.reason.is_hard_revoke(),
        "UserDropped must be hard-revoke"
    );
}
/// `create_user` publishes a `UserChanged` event for the new user. No
/// `SessionInvalidated` event may arrive within a short quiet window.
#[tokio::test]
async fn commit_user_mutation_no_invalidation_publishes_user_changed_only() {
    const EVENT_WAIT: std::time::Duration = std::time::Duration::from_millis(200);
    const QUIET_WINDOW: std::time::Duration = std::time::Duration::from_millis(50);

    // `dir` stays bound for the whole test so the WAL directory outlives `state`.
    let dir = tempfile::tempdir().unwrap();
    let wal = Arc::new(WalManager::open_for_testing(&dir.path().join("test.wal")).unwrap());
    let (dispatcher, _) = Dispatcher::new(1, 64);
    let state = SharedState::new(dispatcher, wal);
    let creds = &state.credentials;

    let mut user_changes = creds.subscribe_user_changes();
    let mut invalidations = creds.subscribe_session_invalidation();

    let user_id = creds
        .create_user("henry", "pw222", TenantId::new(1), vec![Role::ReadOnly])
        .expect("create_user failed");

    let changed = tokio::time::timeout(EVENT_WAIT, user_changes.recv())
        .await
        .expect("timed out waiting for UserChanged")
        .expect("channel closed");
    assert_eq!(changed.user_id, user_id);

    // The expected outcome here is that the timeout fires, i.e. nothing arrives.
    let quiet = tokio::time::timeout(QUIET_WINDOW, invalidations.recv()).await;
    assert!(
        quiet.is_err(),
        "no SessionInvalidated expected for create_user"
    );
}