use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Mutex, OnceLock};
use std::time::Instant;
static SESSION_ID: OnceLock<String> = OnceLock::new();
pub fn session_id() -> &'static str {
SESSION_ID.get_or_init(|| {
std::env::var("TIRITH_SESSION_ID").unwrap_or_else(|_| generate_session_id())
})
}
/// Create a new random (version 4) UUID and render it as a string.
fn generate_session_id() -> String {
    let fresh = uuid::Uuid::new_v4();
    fresh.to_string()
}

/// Public constructor for a brand-new session ID; does not consult the
/// environment or any cache, so every call yields a different value.
pub fn new_session_id() -> String {
    generate_session_id()
}
/// Session ID taken from the `TIRITH_SESSION_ID` environment variable.
///
/// The variable is read once and memoised for the life of the process. An
/// unset, empty, or non-Unicode value produces `None`.
pub fn env_session_id() -> Option<&'static str> {
    static FROM_ENV: OnceLock<Option<String>> = OnceLock::new();
    let memo = FROM_ENV.get_or_init(|| match std::env::var("TIRITH_SESSION_ID") {
        Ok(value) if !value.is_empty() => Some(value),
        _ => None,
    });
    memo.as_deref()
}
/// A memoised fallback session ID for one scope, stamped with when it was stored.
struct FallbackEntry {
    /// The ID most recently loaded from (or written to) the scope's fallback file.
    session_id: String,
    /// When this entry was stored; `fallback_session_id` compares its age
    /// against `FALLBACK_CACHE_REFRESH_SECS`.
    cached_at: Instant,
}

/// Mutex-guarded map from scope string to its cached fallback entry.
type ScopeCache = Mutex<HashMap<String, FallbackEntry>>;

/// In-process cache of fallback session IDs, created lazily on first use.
static FALLBACK_CACHE: OnceLock<ScopeCache> = OnceLock::new();
/// Fallback files whose mtime is at least this old (4 hours) are not reused;
/// a new ID is generated and written instead.
const FALLBACK_FILE_MAX_AGE_SECS: u64 = 14_400;
/// How long (5 minutes) an in-process cache entry is returned before the
/// fallback file is consulted again.
const FALLBACK_CACHE_REFRESH_SECS: u64 = 5 * 60;
/// Scope-keyed session ID used when no `TIRITH_SESSION_ID` is supplied.
///
/// The current scope (see `compute_scope`) is looked up in the in-process
/// cache. An entry younger than `FALLBACK_CACHE_REFRESH_SECS` is returned
/// directly. Otherwise the ID is loaded from (or created in) the scope's
/// fallback file, and the cache entry is refreshed. If the cache mutex is
/// poisoned, the cache is simply bypassed.
pub fn fallback_session_id() -> String {
    let scope = compute_scope();
    let cache = FALLBACK_CACHE.get_or_init(|| Mutex::new(HashMap::new()));

    let still_fresh = match cache.lock() {
        Ok(entries) => entries
            .get(&scope)
            .filter(|entry| entry.cached_at.elapsed().as_secs() < FALLBACK_CACHE_REFRESH_SECS)
            .map(|entry| entry.session_id.clone()),
        Err(_) => None,
    };
    if let Some(hit) = still_fresh {
        return hit;
    }

    // Cache miss or stale entry: go to disk, then remember the result.
    let loaded = load_or_create_fallback_file(&scope);
    if let Ok(mut entries) = cache.lock() {
        let refreshed = FallbackEntry {
            session_id: loaded.clone(),
            cached_at: Instant::now(),
        };
        entries.insert(scope, refreshed);
    }
    loaded
}
pub fn resolve_session_id() -> String {
if let Some(env_id) = env_session_id() {
return env_id.to_string();
}
fallback_session_id()
}
fn compute_scope() -> String {
let integration = std::env::var("TIRITH_INTEGRATION")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or_else(|| "unknown".to_string());
let integration: String = integration
.chars()
.filter(|c| c.is_ascii_alphanumeric() || *c == '-' || *c == '_')
.take(32)
.collect();
let cwd = std::env::current_dir()
.map(|p| p.display().to_string())
.unwrap_or_default();
let cwd_hash = {
use sha2::Digest;
let mut hasher = sha2::Sha256::new();
hasher.update(cwd.as_bytes());
let digest = hasher.finalize();
hex_encode_8(&digest)
};
format!("{integration}-{cwd_hash}")
}
/// Lowercase hex encoding of at most the first 4 bytes of `bytes`.
///
/// The result is at most 8 characters; shorter input yields a shorter string.
fn hex_encode_8(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(8);
    for &byte in bytes.iter().take(4) {
        out.push(DIGITS[usize::from(byte >> 4)] as char);
        out.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    out
}
/// Path of the fallback-ID file for `scope`: `<state dir>/sessions/fallback-<scope>.id`.
///
/// Returns `None` when `crate::policy::state_dir()` provides no state directory.
fn fallback_file_path(scope: &str) -> Option<PathBuf> {
    crate::policy::state_dir()
        .map(|state| state.join("sessions"))
        .map(|sessions| sessions.join(format!("fallback-{scope}.id")))
}
/// Return the persisted fallback session ID for `scope`, creating one if needed.
///
/// A stored ID is reused only if all of the following hold:
/// - the scope's file is a regular file (not a symlink);
/// - it was modified less than `FALLBACK_FILE_MAX_AGE_SECS` ago;
/// - once trimmed, it holds a plausible ID.
///
/// Otherwise a fresh ID is generated and written back (best effort). With no
/// state directory, a fresh, unpersisted ID is returned.
fn load_or_create_fallback_file(scope: &str) -> String {
    let Some(path) = fallback_file_path(scope) else {
        return generate_session_id();
    };
    if let Some(existing) = read_fresh_fallback_id(&path) {
        return existing;
    }
    let new_id = generate_session_id();
    write_fallback_file(&path, &new_id);
    new_id
}

/// Read a still-fresh, well-formed session ID from `path`, or `None` if there is none.
fn read_fresh_fallback_id(path: &std::path::Path) -> Option<String> {
    // `symlink_metadata` does not follow links. Requiring a regular file stops
    // the later `read_to_string` (which does follow links) from reading a
    // symlink target, consistent with `write_fallback_file` refusing to write
    // through symlinks.
    let meta = std::fs::symlink_metadata(path).ok()?;
    if !meta.file_type().is_file() {
        return None;
    }
    // An mtime in the future makes `duration_since` fail; treat that as stale.
    let age = std::time::SystemTime::now()
        .duration_since(meta.modified().ok()?)
        .ok()?;
    if age.as_secs() >= FALLBACK_FILE_MAX_AGE_SECS {
        return None;
    }
    let content = std::fs::read_to_string(path).ok()?;
    let id = content.trim();
    // Empty, oversized, or control-character content (e.g. a corrupted
    // multi-line file) is regenerated rather than reused.
    let plausible = !id.is_empty() && id.len() <= 128 && !id.chars().any(char::is_control);
    plausible.then(|| id.to_string())
}
/// Best-effort persistence of `session_id` (plus a trailing newline) to `path`.
///
/// Steps:
/// 1. Create the parent directory if needed.
/// 2. Write the ID to a uniquely named sibling temp file. On Unix this is
///    created exclusively with mode `0o600` and `O_NOFOLLOW`.
/// 3. Rename the temp file over `path`.
///
/// The rename replaces the file in one step, so a concurrent reader sees
/// either the old contents or the complete new ID. Truncating in place could
/// expose an empty or half-written file, leading that reader to mint a
/// divergent ID.
///
/// On Unix an existing symlink at `path` is refused. Every failure is reported
/// through `crate::audit::audit_diagnostic` and otherwise ignored; the caller
/// keeps using the ID it already holds.
fn write_fallback_file(path: &std::path::Path, session_id: &str) {
    use std::io::Write;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Per-process counter so concurrent writers in this process never share a temp file.
    static TMP_SEQ: AtomicU64 = AtomicU64::new(0);

    if let Some(parent) = path.parent() {
        if let Err(e) = std::fs::create_dir_all(parent) {
            crate::audit::audit_diagnostic(format!(
                "tirith: session: cannot create dir {}: {e}",
                parent.display()
            ));
            return;
        }
    }

    #[cfg(unix)]
    {
        if let Ok(meta) = std::fs::symlink_metadata(path) {
            if meta.file_type().is_symlink() {
                crate::audit::audit_diagnostic(format!(
                    "tirith: session: refusing to follow symlink at {}",
                    path.display()
                ));
                return;
            }
        }
    }

    // Unique per live process (pid) and per call within it (sequence number).
    let tmp_path = {
        let seq = TMP_SEQ.fetch_add(1, Ordering::Relaxed);
        let mut name = path.as_os_str().to_owned();
        name.push(format!(".tmp.{}.{seq}", std::process::id()));
        PathBuf::from(name)
    };
    // Clear a leftover from a crashed process that had the same pid.
    // `remove_file` unlinks a symlink itself, never its target.
    let _ = std::fs::remove_file(&tmp_path);

    let mut open_opts = std::fs::OpenOptions::new();
    open_opts.write(true).create_new(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        open_opts.mode(0o600);
        open_opts.custom_flags(libc::O_NOFOLLOW);
    }
    let mut file = match open_opts.open(&tmp_path) {
        Ok(f) => f,
        Err(e) => {
            crate::audit::audit_diagnostic(format!(
                "tirith: session: cannot write fallback {}: {e} — session ID may be unstable",
                path.display()
            ));
            return;
        }
    };
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        // Ensure exactly 0o600 regardless of the process umask.
        let _ = file.set_permissions(std::fs::Permissions::from_mode(0o600));
    }

    let write_result = file.write_all(format!("{session_id}\n").as_bytes());
    // Close before renaming (required on some platforms).
    drop(file);
    let outcome = write_result.and_then(|()| std::fs::rename(&tmp_path, path));
    if let Err(e) = outcome {
        let _ = std::fs::remove_file(&tmp_path);
        crate::audit::audit_diagnostic(format!(
            "tirith: session: cannot write fallback {}: {e} — session ID may be unstable",
            path.display()
        ));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the guard's lifetime.
    ///
    /// On drop (including when an assertion panics) it restores the previous
    /// value, or removes the variable if it was unset, so a test's
    /// configuration cannot leak into later tests.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: mutating the environment is only sound when no other thread
            // accesses it concurrently. Callers must hold `crate::TEST_ENV_LOCK`,
            // the lock used to serialise environment-mutating tests.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same contract as `EnvVarGuard::set`. Declare the guard after
            // acquiring `TEST_ENV_LOCK` so it is dropped while the lock is held.
            unsafe {
                match &self.previous {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_session_id_stable_within_process() {
        let id1 = session_id();
        let id2 = session_id();
        assert_eq!(id1, id2);
    }

    #[test]
    fn test_generate_session_id_unique() {
        // v4 UUIDs are random, so no delay between calls is needed.
        let a = generate_session_id();
        let b = generate_session_id();
        assert_ne!(a, b);
    }

    #[test]
    fn test_generate_session_id_format() {
        let id = generate_session_id();
        assert_eq!(id.len(), 36);
        assert!(uuid::Uuid::parse_str(&id).is_ok());
    }

    #[test]
    fn test_resolve_session_id_returns_non_empty() {
        let id = resolve_session_id();
        assert!(!id.is_empty());
        assert!(id.len() <= 128);
    }

    #[test]
    fn test_resolve_session_id_stable_on_repeated_calls() {
        let id1 = resolve_session_id();
        let id2 = resolve_session_id();
        assert_eq!(id1, id2);
    }

    #[test]
    fn test_compute_scope_format() {
        let scope = compute_scope();
        let (_, hash) = scope
            .rsplit_once('-')
            .expect("scope must contain a '-' separator");
        assert_eq!(hash.len(), 8);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_hex_encode_8() {
        let bytes = [0xAB, 0xCD, 0xEF, 0x12, 0x34];
        assert_eq!(hex_encode_8(&bytes), "abcdef12");
    }

    #[test]
    fn test_hex_encode_8_short_input() {
        let bytes = [0x01, 0x02];
        assert_eq!(hex_encode_8(&bytes), "0102");
        assert_eq!(hex_encode_8(&[]), "");
    }

    #[cfg(unix)]
    #[test]
    fn test_fallback_file_roundtrip() {
        use std::os::unix::fs::PermissionsExt;

        let _lock = crate::TEST_ENV_LOCK
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        let dir = tempfile::tempdir().unwrap();
        let state_home = dir.path().join("state");
        let _env = EnvVarGuard::set("XDG_STATE_HOME", &state_home);

        let scope = "test-integration-abcd1234";
        let id = load_or_create_fallback_file(scope);
        assert!(!id.is_empty());
        assert!(uuid::Uuid::parse_str(&id).is_ok());

        let id2 = load_or_create_fallback_file(scope);
        assert_eq!(id, id2, "second load must reuse the persisted ID");

        // The IDs can only match if the file was persisted, so the path must
        // resolve; check permissions unconditionally rather than skipping.
        let path = fallback_file_path(scope).expect("fallback path must resolve for a persisted ID");
        let perms = std::fs::metadata(&path).unwrap().permissions();
        assert_eq!(perms.mode() & 0o777, 0o600);
    }

    #[test]
    fn test_env_session_id_priority() {
        let resolved = resolve_session_id();
        assert!(!resolved.is_empty());
        if let Some(env_id) = env_session_id() {
            assert_eq!(
                resolved, env_id,
                "a non-empty TIRITH_SESSION_ID must take precedence over the fallback"
            );
        }
    }
}