use crate::format::safe_truncate;
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;
use std::time::{Duration, Instant};
/// Process-wide switch for tool-call audit logging.
///
/// Starts out `false`; [`enable_audit_log`] sets it and [`is_audit_enabled`]
/// reads it. [`audit_log_tool_call`] writes nothing while it is `false`.
static AUDIT_ENABLED: AtomicBool = AtomicBool::new(false);
/// Converts a count of days since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple, with `month` in `1..=12` and `day` in `1..=31`.
///
/// This is Howard Hinnant's `civil_from_days` algorithm specialised to
/// unsigned input. It shifts the calendar so that each 400-year "era" starts
/// on March 1st, which puts the leap day at the very end of each computed year.
fn days_from_epoch(days: u64) -> (u64, u64, u64) {
    // Re-anchor from 1970-01-01 to 0000-03-01.
    let shifted = days + 719_468;
    let era = shifted / 146_097;
    // Day index within the 400-year era: [0, 146096].
    let day_of_era = shifted - era * 146_097;
    // Year within the era: [0, 399], correcting for the 4/100/400-year leap rules.
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    // Day within the March-based year: [0, 365].
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index with March = 0 ... February = 11.
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 { march_month + 3 } else { march_month - 9 };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + u64::from(month <= 2);
    (year, month, day)
}
pub fn enable_audit_log() {
AUDIT_ENABLED.store(true, Ordering::Relaxed);
}
pub fn is_audit_enabled() -> bool {
AUDIT_ENABLED.load(Ordering::Relaxed)
}
/// Records one tool invocation in the audit log, if auditing is enabled.
///
/// This is a no-op until [`enable_audit_log`] has been called. Logging is
/// best-effort: any I/O error from writing the entry is discarded, so callers
/// never observe audit failures.
pub fn audit_log_tool_call(
    tool_name: &str,
    args: &serde_json::Value,
    duration_ms: u64,
    success: bool,
) {
    if is_audit_enabled() {
        // Deliberately ignored: a failed audit write must not surface to the caller.
        let _ = write_audit_entry(tool_name, args, duration_ms, success);
    }
}
/// Appends one JSON line describing a tool call to `.yoyo/audit.jsonl`.
///
/// Both paths are relative, so they resolve against the current working
/// directory. The `.yoyo` directory is created if it does not exist.
///
/// Each entry is a JSON object with these keys:
/// * `ts` — a UTC timestamp (see [`audit_timestamp`]);
/// * `tool`;
/// * `args` — string fields shortened by [`truncate_audit_args`];
/// * `duration_ms`;
/// * `success`.
///
/// # Errors
///
/// Returns any I/O error from creating the directory, opening the file, or
/// writing the line.
fn write_audit_entry(
    tool_name: &str,
    args: &serde_json::Value,
    duration_ms: u64,
    success: bool,
) -> std::io::Result<()> {
    let dir = std::path::Path::new(".yoyo");
    std::fs::create_dir_all(dir)?;
    let path = dir.join("audit.jsonl");
    let mut file = std::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&path)?;
    let entry = serde_json::json!({
        "ts": audit_timestamp(),
        "tool": tool_name,
        "args": truncate_audit_args(args),
        "duration_ms": duration_ms,
        "success": success,
    });
    // Serialize the whole line up front and emit it with a single `write_all`.
    // `writeln!` on an unbuffered `File` forwards every formatter fragment
    // (each brace, quote, key and value) as its own `write` syscall. That is
    // slow, and it lets concurrent appenders splice output into the middle of
    // an entry.
    let mut line = entry.to_string();
    line.push('\n');
    file.write_all(line.as_bytes())
}

/// Current UTC wall-clock time formatted as `YYYY-MM-DDTHH:MM:SS`.
///
/// Returns `"unknown"` if the system clock reads earlier than the Unix epoch.
/// No timezone designator is appended, even though the value is UTC.
fn audit_timestamp() -> String {
    use std::time::SystemTime;
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => {
            let secs = since_epoch.as_secs();
            let (year, month, day) = days_from_epoch(secs / 86_400);
            let time_of_day = secs % 86_400;
            let hours = time_of_day / 3_600;
            let minutes = (time_of_day % 3_600) / 60;
            let seconds = time_of_day % 60;
            format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}")
        }
        Err(_) => "unknown".to_string(),
    }
}
/// Returns a copy of `args` with over-long string fields shortened for the audit log.
///
/// Only the top level of a JSON object is processed. Each field value is
/// passed through [`truncate_audit_value`], which leaves nested objects and
/// arrays untouched. Any non-object input is returned as a plain clone.
pub fn truncate_audit_args(args: &serde_json::Value) -> serde_json::Value {
    let serde_json::Value::Object(fields) = args else {
        return args.clone();
    };
    let shortened: serde_json::Map<String, serde_json::Value> = fields
        .iter()
        .map(|(key, value)| (key.clone(), truncate_audit_value(value)))
        .collect();
    serde_json::Value::Object(shortened)
}
/// Shortens a single audit-log value if it is an over-long string.
///
/// A string whose byte length exceeds 200 is replaced by
/// `safe_truncate(s, 200)` plus a marker giving the original length in
/// characters. Every other value, including shorter strings, is cloned
/// unchanged.
fn truncate_audit_value(v: &serde_json::Value) -> serde_json::Value {
    // The threshold compares byte length (`str::len`). `safe_truncate`
    // presumably also limits by bytes at a char boundary — confirm in
    // crate::format.
    const MAX_LEN: usize = 200;
    match v {
        serde_json::Value::String(s) if s.len() > MAX_LEN => {
            // The marker says "chars", so count Unicode scalar values. `s.len()`
            // is a byte count, and the two differ for any non-ASCII text.
            let total_chars = s.chars().count();
            serde_json::Value::String(format!(
                "{}... [truncated, {} chars total]",
                safe_truncate(s, MAX_LEN),
                total_chars
            ))
        }
        other => other.clone(),
    }
}
/// Test helper: returns at most the last `n` lines of `.yoyo/audit.jsonl`,
/// oldest first.
///
/// If the file is missing or unreadable (including non-UTF-8 contents), the
/// result is an empty vector rather than an error.
#[cfg(test)]
pub fn read_audit_log(n: usize) -> Vec<String> {
    let log_path = std::path::Path::new(".yoyo").join("audit.jsonl");
    let Ok(contents) = std::fs::read_to_string(&log_path) else {
        return Vec::new();
    };
    let all_lines: Vec<&str> = contents.lines().collect();
    let skip = all_lines.len().saturating_sub(n);
    all_lines
        .into_iter()
        .skip(skip)
        .map(str::to_owned)
        .collect()
}
/// Budget used when `YOYO_SESSION_BUDGET_SECS` is set but does not parse as a
/// `u64`: 2700 seconds (45 minutes).
const DEFAULT_SESSION_BUDGET_SECS: u64 = 2700;
/// Cached, parsed value of `YOYO_SESSION_BUDGET_SECS`. `None` means no budget.
/// It is filled once, on first use, by `configured_session_budget`.
static SESSION_BUDGET_SECS: OnceLock<Option<u64>> = OnceLock::new();
/// Instant the session budget is measured from. It is captured by the first
/// call to `session_budget_remaining` that finds a budget configured.
static SESSION_BUDGET_START: OnceLock<Instant> = OnceLock::new();
/// Returns the session budget in seconds, or `None` when no budget is set.
///
/// `YOYO_SESSION_BUDGET_SECS` is read and parsed on the first call only. The
/// result is cached for the life of the process, so later changes to the
/// environment variable are not observed.
fn configured_session_budget() -> Option<u64> {
    let cached = SESSION_BUDGET_SECS.get_or_init(|| {
        let raw = std::env::var("YOYO_SESSION_BUDGET_SECS").ok();
        parse_session_budget(raw)
    });
    *cached
}
/// Interprets the raw value of `YOYO_SESSION_BUDGET_SECS`.
///
/// * `None` (variable unset) or a blank / whitespace-only value → `None`
///   (no budget).
/// * A value that parses as `u64` once surrounding whitespace is trimmed →
///   that many seconds.
/// * Anything else (e.g. `"-1"`, `"45m"`) → [`DEFAULT_SESSION_BUDGET_SECS`].
///   A malformed value therefore still yields a budget rather than none.
fn parse_session_budget(raw: Option<String>) -> Option<u64> {
    // Trim first. Without it, values with stray surrounding whitespace
    // (e.g. "600 ") fail to parse and silently fall back to the default, and
    // a whitespace-only value would enable a budget instead of meaning "unset".
    let trimmed = raw.as_deref()?.trim();
    if trimmed.is_empty() {
        return None;
    }
    Some(trimmed.parse::<u64>().unwrap_or(DEFAULT_SESSION_BUDGET_SECS))
}
/// Time left in the session budget, or `None` if no budget is configured.
///
/// The clock starts on the first call that finds a budget configured. The
/// result saturates at [`Duration::ZERO`] once the budget is used up.
pub fn session_budget_remaining() -> Option<Duration> {
    configured_session_budget().map(|budget_secs| {
        let started_at = SESSION_BUDGET_START.get_or_init(Instant::now);
        Duration::from_secs(budget_secs).saturating_sub(started_at.elapsed())
    })
}
/// Returns `true` when a budget is configured and at most `grace_secs` whole
/// seconds of it remain.
///
/// The remaining time is truncated to whole seconds before the comparison.
/// With no budget configured this always returns `false`.
pub fn session_budget_exhausted(grace_secs: u64) -> bool {
    session_budget_remaining().is_some_and(|left| left.as_secs() <= grace_secs)
}
// Unit tests for audit-arg truncation, epoch-day conversion, and session-budget
// parsing/accounting. The budget tests share process-wide OnceLock state
// (SESSION_BUDGET_SECS, SESSION_BUDGET_START) and the YOYO_SESSION_BUDGET_SECS
// env var, so their outcome depends on which of them runs first.
#[cfg(test)]
mod tests {
use super::*;
// Values at or under the length limit must come back unchanged.
#[test]
fn test_truncate_audit_args_short_values() {
let args = serde_json::json!({"path": "src/main.rs", "command": "cargo test"});
let truncated = truncate_audit_args(&args);
assert_eq!(
truncated, args,
"Short strings should pass through unchanged"
);
}
// A 500-char field is cut and tagged; sibling short fields are untouched.
#[test]
fn test_truncate_audit_args_long_values() {
let long_content = "x".repeat(500);
let args = serde_json::json!({"path": "test.txt", "content": long_content});
let truncated = truncate_audit_args(&args);
let content_val = truncated.get("content").unwrap().as_str().unwrap();
assert!(content_val.len() < 500, "Long content should be truncated");
assert!(
content_val.contains("... [truncated, 500 chars total]"),
"Should include truncation marker"
);
assert_eq!(truncated.get("path").unwrap().as_str().unwrap(), "test.txt");
}
// Numbers and booleans are never altered.
#[test]
fn test_truncate_audit_args_non_string() {
let args = serde_json::json!({"count": 42, "flag": true, "ratio": 3.15});
let truncated = truncate_audit_args(&args);
assert_eq!(truncated, args, "Non-string values should pass through");
}
// Nested objects are copied as-is (only top-level strings are truncated).
#[test]
fn test_truncate_audit_args_nested_object() {
let args = serde_json::json!({"meta": {"key": "value"}, "name": "test"});
let truncated = truncate_audit_args(&args);
assert_eq!(
truncated.get("meta").unwrap(),
&serde_json::json!({"key": "value"})
);
}
// NOTE(review): this checks a freshly constructed local AtomicBool, not
// AUDIT_ENABLED itself, so it cannot detect a change to the real default.
#[test]
fn test_audit_enabled_default_false() {
let fresh = AtomicBool::new(false);
assert!(!fresh.load(Ordering::Relaxed));
}
// NOTE(review): smoke test only — it asserts nothing, because
// `.yoyo/audit.jsonl` may or may not exist in the working directory.
#[test]
fn test_read_audit_log_missing_file() {
let entries = read_audit_log(10);
let _ = entries;
}
// Boundary: exactly 200 bytes is not over the limit.
#[test]
fn test_truncate_audit_args_exactly_200() {
let exact = "y".repeat(200);
let args = serde_json::json!({"content": exact});
let truncated = truncate_audit_args(&args);
assert_eq!(
truncated.get("content").unwrap().as_str().unwrap(),
exact,
"Exactly 200-char string should not be truncated"
);
}
// Boundary: one byte over the limit triggers truncation.
#[test]
fn test_truncate_audit_args_201() {
let over = "z".repeat(201);
let args = serde_json::json!({"content": over});
let truncated = truncate_audit_args(&args);
let val = truncated.get("content").unwrap().as_str().unwrap();
assert!(
val.contains("... [truncated, 201 chars total]"),
"201-char string should be truncated"
);
}
// Day 0 is the Unix epoch itself.
#[test]
fn test_days_from_epoch_unix_epoch() {
let (y, m, d) = days_from_epoch(0);
assert_eq!((y, m, d), (1970, 1, 1));
}
#[test]
fn test_days_from_epoch_known_date() {
let (y, m, d) = days_from_epoch(19723);
assert_eq!((y, m, d), (2024, 1, 1));
}
// Leap day in a year divisible by 4.
#[test]
fn test_days_from_epoch_leap_year() {
let (y, m, d) = days_from_epoch(19782);
assert_eq!((y, m, d), (2024, 2, 29));
}
#[test]
fn test_days_from_epoch_y2k() {
let (y, m, d) = days_from_epoch(10957);
assert_eq!((y, m, d), (2000, 1, 1));
}
#[test]
fn test_parse_session_budget_unset() {
assert_eq!(parse_session_budget(None), None);
}
#[test]
fn test_parse_session_budget_empty() {
assert_eq!(parse_session_budget(Some(String::new())), None);
}
#[test]
fn test_parse_session_budget_valid() {
assert_eq!(parse_session_budget(Some("2700".to_string())), Some(2700));
assert_eq!(parse_session_budget(Some("0".to_string())), Some(0));
assert_eq!(parse_session_budget(Some("60".to_string())), Some(60));
}
// Unparsable values (including negatives, which u64 rejects) use the default.
#[test]
fn test_parse_session_budget_garbage_falls_back_to_default() {
assert_eq!(
parse_session_budget(Some("forty-five-minutes".to_string())),
Some(DEFAULT_SESSION_BUDGET_SECS)
);
assert_eq!(
parse_session_budget(Some("-1".to_string())),
Some(DEFAULT_SESSION_BUDGET_SECS)
);
}
#[test]
fn test_parse_session_budget_default_is_45_min() {
assert_eq!(DEFAULT_SESSION_BUDGET_SECS, 2700);
}
// Only asserts when the env var is unset. Calling session_budget_remaining()
// here also caches the budget in SESSION_BUDGET_SECS for the rest of the run.
#[test]
#[serial_test::serial]
fn test_session_budget_remaining_unset_returns_none() {
if std::env::var("YOYO_SESSION_BUDGET_SECS").is_err() {
assert!(session_budget_remaining().is_none());
}
}
// NOTE(review): re-derives the saturating_sub arithmetic on local values;
// it does not call session_budget_remaining() itself.
#[test]
fn test_session_budget_remaining_decreases_over_time() {
let budget = Duration::from_secs(60);
let start = Instant::now();
std::thread::sleep(Duration::from_millis(20));
let elapsed = start.elapsed();
let remaining = budget.saturating_sub(elapsed);
assert!(remaining < budget, "remaining should shrink as time passes");
assert!(
remaining > Duration::from_secs(50),
"20ms shouldn't burn most of a 60s budget"
);
}
// NOTE(review): local arithmetic only — confirms saturating_sub clamps to zero.
#[test]
fn test_session_budget_remaining_returns_zero_after_expiry() {
let budget = Duration::from_secs(1);
let elapsed = Duration::from_secs(10);
let remaining = budget.saturating_sub(elapsed);
assert_eq!(remaining, Duration::ZERO);
}
// Same caveat as the unset-returns-none test: it only asserts when the env
// var is unset.
#[test]
#[serial_test::serial]
fn test_session_budget_exhausted_unset_returns_false() {
if std::env::var("YOYO_SESSION_BUDGET_SECS").is_err() {
assert!(!session_budget_exhausted(0));
assert!(!session_budget_exhausted(30));
assert!(!session_budget_exhausted(99_999));
}
}
// NOTE(review): mirrors the comparison in session_budget_exhausted on local
// values; it does not call the function.
#[test]
fn test_session_budget_exhausted_with_headroom_returns_false() {
let budget = Duration::from_secs(9999);
let elapsed = Duration::from_millis(5);
let remaining = budget.saturating_sub(elapsed);
let exhausted = remaining.as_secs() <= 30;
assert!(
!exhausted,
"9999s budget with 5ms elapsed should have headroom"
);
}
// NOTE(review): local arithmetic only; the 10s offset simulates an expired budget.
#[test]
fn test_session_budget_exhausted_after_expiry_returns_true() {
let budget = Duration::from_secs(1);
let start = Instant::now();
std::thread::sleep(Duration::from_millis(20));
let elapsed = start.elapsed() + Duration::from_secs(10);
let remaining = budget.saturating_sub(elapsed);
let exhausted = remaining.as_secs() <= 30;
assert_eq!(remaining, Duration::ZERO);
assert!(exhausted, "expired budget must report exhausted");
}
// End-to-end check through the real OnceLock-backed functions.
//
// The `aaa` prefix and #[serial] appear intended to make this run before the
// other serial budget tests. If one of those runs first with the env var
// unset, SESSION_BUDGET_SECS caches `None` and the `expect` below fails.
// NOTE(review): libtest does not guarantee execution order — confirm this is
// reliable.
//
// The env var is never unset afterwards, so the env-conditional budget tests
// that run after this one skip their assertions.
#[test]
#[serial_test::serial]
fn test_aaa_session_budget_set_path_live_end_to_end() {
// NOTE(review): set_var is unsafe because other threads may read the
// environment concurrently. #[serial] only serialises against other #[serial]
// tests; non-serial tests running in parallel are not excluded.
unsafe {
std::env::set_var("YOYO_SESSION_BUDGET_SECS", "9999");
}
let remaining = session_budget_remaining()
.expect("with env var set, session_budget_remaining() must return Some(_)");
assert!(
remaining > Duration::from_secs(9000),
"fresh 9999s budget should still have most of itself left, got {remaining:?}",
);
assert!(
remaining <= Duration::from_secs(9999),
"remaining should never exceed configured budget, got {remaining:?}",
);
assert!(
!session_budget_exhausted(30),
"fresh 9999s budget must not report exhausted with 30s grace",
);
assert!(
!session_budget_exhausted(0),
"fresh 9999s budget must not report exhausted with 0s grace",
);
assert!(
!session_budget_exhausted(8000),
"fresh 9999s budget must not report exhausted with 8000s grace",
);
assert!(
session_budget_exhausted(20_000),
"9999s budget must report exhausted when grace > budget",
);
}
}