use metrics_exporter_prometheus::PrometheusBuilder;
use regex::Regex;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Mutex, OnceLock};
use std::time::Instant;
use tracing::Instrument;
use tracing::{error, info, info_span, Span};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
pub const MAX_LOG_ENTRIES: usize = 100_000;
static SAMPLING_RATE_MICRO: AtomicU64 = AtomicU64::new(1_000_000);
static SAMPLE_COUNTER: AtomicU64 = AtomicU64::new(0);
pub fn set_sampling_rate(rate: f64) {
let clamped = rate.clamp(0.0, 1.0);
SAMPLING_RATE_MICRO.store((clamped * 1_000_000.0) as u64, Ordering::Relaxed);
}
pub fn sampling_rate() -> f64 {
SAMPLING_RATE_MICRO.load(Ordering::Relaxed) as f64 / 1_000_000.0
}
pub fn should_sample() -> bool {
let rate_micro = SAMPLING_RATE_MICRO.load(Ordering::Relaxed);
if rate_micro >= 1_000_000 {
return true;
}
if rate_micro == 0 {
return false;
}
let count = SAMPLE_COUNTER.fetch_add(1, Ordering::Relaxed);
(count % 1_000_000) < rate_micro
}
static TRACING_GUARD: OnceLock<Mutex<Option<tracing_appender::non_blocking::WorkerGuard>>> =
OnceLock::new();
static LOG_ENTRY_COUNT: AtomicUsize = AtomicUsize::new(0);
pub fn increment_log_count() -> usize {
LOG_ENTRY_COUNT.fetch_add(1, Ordering::Relaxed) + 1
}
pub fn log_entry_count() -> usize {
LOG_ENTRY_COUNT.load(Ordering::Relaxed)
}
pub fn rotate_if_needed() -> bool {
let mut count = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
loop {
if count >= MAX_LOG_ENTRIES {
match LOG_ENTRY_COUNT.compare_exchange_weak(
count,
count / 2,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => {
info!(
"Telemetry log rotation triggered: {} entries exceeded limit, reset to {}",
count,
count / 2
);
return true;
}
Err(actual) => count = actual,
}
} else {
return false;
}
}
}
pub fn sanitize_for_log(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for c in s.chars() {
match c {
'\n' => out.push_str("\\n"),
'\r' => out.push_str("\\r"),
'\t' => out.push_str("\\t"),
'\x0b' => out.push_str("\\v"),
'\x0c' => out.push_str("\\f"),
'\x1b' => out.push_str("\\e"),
'\x00' => out.push_str("\\0"),
c if c.is_control() => out.push_str(&format!("\\u{:04x}", c as u32)),
_ => out.push(c),
}
}
out
}
static SECRET_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
fn secret_patterns() -> &'static Vec<Regex> {
SECRET_PATTERNS.get_or_init(|| {
vec![
Regex::new(r"(?i)(sk-|key-|token-)[A-Za-z0-9_\-]{8,}").expect("invalid secret regex"),
Regex::new(r"(?i)Bearer\s+[A-Za-z0-9_\-\.]{8,}").expect("invalid bearer regex"),
Regex::new(r"(?i)(password|passwd|pwd)\s*=\s*\S+").expect("invalid password regex"),
]
})
}
pub fn redact_secrets(input: &str) -> String {
let mut result = input.to_string();
for pattern in secret_patterns() {
result = pattern.replace_all(&result, "[REDACTED]").to_string();
}
result
}
pub fn init_tracing() {
let filter = std::env::var("RUST_LOG").or_else(|_| std::env::var("SELFWARE_LOG_LEVEL"));
if let Ok(f) = filter {
init_tracing_with_filter(&f);
}
}
pub fn init_tracing_verbose() {
init_tracing_with_filter("info")
}
pub fn init_tracing_with_filter(filter: &str) {
use std::sync::Once;
static INIT: Once = Once::new();
INIT.call_once(|| {
let filter_layer = EnvFilter::try_new(filter).unwrap_or_else(|_| EnvFilter::new("warn"));
let fmt_layer = tracing_subscriber::fmt::layer()
.with_target(false)
.with_thread_ids(false)
.with_thread_names(false)
.with_file(false)
.with_line_number(false)
.with_level(true)
.compact()
.with_writer(std::io::stderr);
let log_dir = dirs::data_local_dir()
.unwrap_or_else(|| std::path::PathBuf::from("."))
.join("selfware")
.join("logs");
let _ = std::fs::create_dir_all(&log_dir);
let file_appender = tracing_appender::rolling::daily(log_dir, "selfware.log");
let (non_blocking, guard) = tracing_appender::non_blocking(file_appender);
let _ = TRACING_GUARD.set(Mutex::new(Some(guard)));
let file_layer = tracing_subscriber::fmt::layer()
.with_writer(non_blocking)
.with_ansi(false)
.with_file(true)
.with_line_number(true);
let subscriber = tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.with(file_layer);
if let Ok(endpoint) = std::env::var("OTEL_EXPORTER_OTLP_ENDPOINT") {
use opentelemetry_otlp::WithExportConfig;
if let Ok(tracer) = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(endpoint),
)
.install_batch(opentelemetry_sdk::runtime::Tokio)
{
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
let _ = subscriber.with(telemetry).try_init();
return; }
}
let _ = subscriber.try_init();
});
}
pub fn shutdown_tracing() {
if let Some(guard_slot) = TRACING_GUARD.get() {
if let Ok(mut slot) = guard_slot.lock() {
drop(slot.take()); }
}
}
pub struct Metrics {
pub api_requests: AtomicU64,
pub api_errors: AtomicU64,
pub tool_executions: AtomicU64,
pub tool_errors: AtomicU64,
pub tokens_processed: AtomicU64,
}
static METRICS: Metrics = Metrics {
api_requests: AtomicU64::new(0),
api_errors: AtomicU64::new(0),
tool_executions: AtomicU64::new(0),
tool_errors: AtomicU64::new(0),
tokens_processed: AtomicU64::new(0),
};
pub fn increment_api_requests() {
METRICS.api_requests.fetch_add(1, Ordering::Relaxed);
metrics::increment_counter!("selfware_api_requests_total");
}
pub fn increment_api_errors() {
METRICS.api_errors.fetch_add(1, Ordering::Relaxed);
metrics::increment_counter!("selfware_api_errors_total");
}
pub fn increment_tool_executions() {
METRICS.tool_executions.fetch_add(1, Ordering::Relaxed);
metrics::increment_counter!("selfware_tool_executions_total");
}
pub fn increment_tool_errors() {
METRICS.tool_errors.fetch_add(1, Ordering::Relaxed);
metrics::increment_counter!("selfware_tool_errors_total");
}
pub fn add_tokens_processed(count: u64) {
METRICS.tokens_processed.fetch_add(count, Ordering::Relaxed);
metrics::counter!("selfware_tokens_processed_total", count);
}
pub fn get_metrics() -> &'static Metrics {
&METRICS
}
pub fn start_prometheus_exporter(bind_addr: std::net::SocketAddr) -> anyhow::Result<()> {
PrometheusBuilder::new()
.with_http_listener(bind_addr)
.install()
.map_err(|e| anyhow::anyhow!("Failed to start Prometheus exporter: {}", e))?;
metrics::describe_counter!(
"selfware_api_requests_total",
"Total number of LLM API requests made"
);
metrics::describe_counter!(
"selfware_api_errors_total",
"Total number of LLM API errors"
);
metrics::describe_counter!(
"selfware_tool_executions_total",
"Total number of tool executions"
);
metrics::describe_counter!(
"selfware_tool_errors_total",
"Total number of tool execution errors"
);
metrics::describe_counter!(
"selfware_tokens_processed_total",
"Total number of tokens processed"
);
Ok(())
}
#[macro_export]
macro_rules! tool_span {
($tool_name:expr) => {
tracing::info_span!(
"tool_execution",
tool_name = $tool_name,
duration_ms = tracing::field::Empty,
success = tracing::field::Empty,
error = tracing::field::Empty,
)
};
}
pub async fn track_tool_execution<F, Fut, T, E>(tool_name: &str, f: F) -> Result<T, E>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
let start = Instant::now();
let safe_name = redact_secrets(&sanitize_for_log(tool_name));
let span = info_span!(
"tool.execute",
tool_name = safe_name.as_str(),
duration_ms = tracing::field::Empty,
success = tracing::field::Empty,
error = tracing::field::Empty,
);
increment_tool_executions();
async {
info!("Starting tool execution");
match f().await {
Ok(result) => {
let duration = start.elapsed().as_millis() as u64;
span.record("duration_ms", duration);
span.record("success", true);
info!(
duration_ms = duration,
"Tool execution completed successfully"
);
Ok(result)
}
Err(e) => {
increment_tool_errors();
let duration = start.elapsed().as_millis() as u64;
let safe_err = redact_secrets(&sanitize_for_log(&e.to_string()));
span.record("duration_ms", duration);
span.record("success", false);
span.record("error", safe_err.as_str());
error!(
duration_ms = duration,
error = safe_err.as_str(),
"Tool execution failed"
);
Err(e)
}
}
}
.instrument(span.clone())
.await
}
pub fn record_success() {
Span::current().record("success", true);
if should_sample() {
info!("Operation completed successfully");
}
increment_log_count();
}
pub fn record_failure(error: &str) {
let safe_err = redact_secrets(&sanitize_for_log(error));
Span::current().record("success", false);
Span::current().record("error", safe_err.as_str());
error!(error = safe_err.as_str(), "Operation failed");
}
pub fn enter_agent_step(state: &str, step: usize) -> tracing::span::Span {
let safe_state = sanitize_for_log(state);
let span = info_span!("agent.step", state = safe_state.as_str(), step = step,);
span
}
pub fn record_state_transition(from: &str, to: &str) {
let safe_from = sanitize_for_log(from);
let safe_to = sanitize_for_log(to);
if should_sample() {
info!(
from = safe_from.as_str(),
to = safe_to.as_str(),
"Agent state transition"
);
}
increment_log_count();
}
#[cfg(test)]
pub fn init_test_tracing() {
let _ = tracing_subscriber::fmt()
.with_test_writer()
.with_max_level(tracing::Level::DEBUG)
.try_init();
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Mutex, OnceLock};
fn sampling_test_guard() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("sampling test lock poisoned")
}
fn rotation_test_guard() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn test_record_state_transition_does_not_panic() {
record_state_transition("Planning", "Executing");
}
#[test]
fn test_enter_agent_step_returns_span() {
let span = enter_agent_step("Executing", 1);
let _guard = span.enter();
}
#[test]
fn test_record_success_does_not_panic() {
record_success();
}
#[test]
fn test_record_failure_does_not_panic() {
record_failure("test error");
}
#[tokio::test]
async fn test_track_tool_execution_success() {
let result: Result<i32, &str> =
track_tool_execution("test_tool", || async { Ok(42) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
}
#[tokio::test]
async fn test_track_tool_execution_failure() {
let result: Result<i32, &str> =
track_tool_execution("test_tool", || async { Err("test error") }).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "test error");
}
#[test]
fn test_init_test_tracing_does_not_panic() {
init_test_tracing();
}
#[test]
fn test_tool_span_macro() {
let span = tool_span!("my_tool");
let _guard = span.enter();
}
#[test]
fn test_enter_agent_step_different_states() {
let states = ["Planning", "Executing", "WaitingForTool", "Completed"];
for (i, state) in states.iter().enumerate() {
let span = enter_agent_step(state, i);
let _guard = span.enter();
}
}
#[test]
fn test_record_state_transition_various() {
let transitions = [
("Idle", "Planning"),
("Planning", "Executing"),
("Executing", "WaitingForTool"),
("WaitingForTool", "Executing"),
("Executing", "Completed"),
];
for (from, to) in transitions {
record_state_transition(from, to);
}
}
#[tokio::test]
async fn test_track_tool_execution_with_delay() {
let result: Result<u64, &str> = track_tool_execution("slow_tool", || async {
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
Ok(100u64)
})
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 100);
}
#[test]
fn test_nested_spans() {
let outer = enter_agent_step("Outer", 0);
let _outer_guard = outer.enter();
let inner = enter_agent_step("Inner", 1);
let _inner_guard = inner.enter();
record_success();
}
#[test]
fn test_record_failure_with_various_errors() {
let errors = [
"Connection timeout",
"File not found",
"Permission denied",
"",
"Error with special chars: <>&\"'",
];
for error in errors {
record_failure(error);
}
}
#[tokio::test]
async fn test_track_tool_execution_with_string_error() {
let result: Result<(), String> = track_tool_execution("string_error_tool", || async {
Err("Custom error message".to_string())
})
.await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Custom error message");
}
#[test]
fn test_multiple_init_test_tracing_calls() {
init_test_tracing();
init_test_tracing();
init_test_tracing();
}
#[tokio::test]
async fn test_track_tool_execution_returns_correct_value() {
let result: Result<Vec<i32>, &str> =
track_tool_execution("vec_tool", || async { Ok(vec![1, 2, 3, 4, 5]) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), vec![1, 2, 3, 4, 5]);
}
#[tokio::test]
async fn test_track_tool_execution_with_complex_type() {
#[derive(Debug, PartialEq)]
struct ComplexResult {
value: i32,
name: String,
}
let result: Result<ComplexResult, &str> = track_tool_execution("complex_tool", || async {
Ok(ComplexResult {
value: 42,
name: "test".to_string(),
})
})
.await;
assert!(result.is_ok());
let res = result.unwrap();
assert_eq!(res.value, 42);
assert_eq!(res.name, "test");
}
#[tokio::test]
async fn test_track_tool_execution_error_message_preserved() {
let error_msg = "Very specific error: code 123";
let result: Result<(), String> =
track_tool_execution("error_tool", || async { Err(error_msg.to_string()) }).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), error_msg);
}
#[test]
fn test_enter_agent_step_high_step_numbers() {
let span = enter_agent_step("Testing", 999999);
let _guard = span.enter();
record_success();
}
#[test]
fn test_enter_agent_step_zero_step() {
let span = enter_agent_step("Start", 0);
let _guard = span.enter();
record_success();
}
#[test]
fn test_record_state_transition_same_state() {
record_state_transition("Running", "Running");
}
#[test]
fn test_record_state_transition_empty_states() {
record_state_transition("", "");
}
#[test]
fn test_record_failure_unicode() {
record_failure("错误: 连接失败 🔥");
record_failure("Ошибка подключения");
record_failure("エラー: 接続に失敗しました");
}
#[test]
fn test_record_success_multiple_calls() {
for _ in 0..10 {
record_success();
}
}
#[test]
fn test_record_failure_multiple_calls() {
for i in 0..10 {
record_failure(&format!("Error {}", i));
}
}
#[test]
fn test_tool_span_various_names() {
let tool_names = [
"file_read",
"shell_exec",
"cargo_build",
"git_commit",
"http_request",
"database_query",
"cache_lookup",
"",
"tool-with-dashes",
"tool.with.dots",
"tool_with_unicode_名前",
];
for name in tool_names {
let span = tool_span!(name);
let _guard = span.enter();
}
}
#[test]
fn test_enter_agent_step_long_state_name() {
let long_state = "A".repeat(1000);
let span = enter_agent_step(&long_state, 0);
let _guard = span.enter();
}
#[test]
fn test_record_failure_long_error() {
let long_error = "Error: ".to_string() + &"x".repeat(10000);
record_failure(&long_error);
}
#[tokio::test]
async fn test_track_tool_execution_unit_result() {
let result: Result<(), &str> = track_tool_execution("void_tool", || async { Ok(()) }).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_track_tool_execution_bool_result() {
let result: Result<bool, &str> =
track_tool_execution("bool_tool", || async { Ok(true) }).await;
assert!(result.is_ok());
assert!(result.unwrap());
}
#[tokio::test]
async fn test_track_tool_execution_option_in_ok() {
let result: Result<Option<i32>, &str> =
track_tool_execution("option_tool", || async { Ok(Some(42)) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), Some(42));
}
#[tokio::test]
async fn test_track_tool_execution_none_in_ok() {
let result: Result<Option<i32>, &str> =
track_tool_execution("none_tool", || async { Ok(None) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), None);
}
#[test]
fn test_span_hierarchy() {
let span1 = enter_agent_step("Level1", 0);
let _g1 = span1.enter();
let span2 = enter_agent_step("Level2", 1);
let _g2 = span2.enter();
let span3 = enter_agent_step("Level3", 2);
let _g3 = span3.enter();
record_success();
}
#[tokio::test]
async fn test_track_tool_execution_with_computation() {
let result: Result<i32, &str> = track_tool_execution("compute_tool", || async {
let mut sum = 0;
for i in 0..100 {
sum += i;
}
Ok(sum)
})
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 4950);
}
#[test]
fn test_record_state_transition_special_chars() {
record_state_transition("State<A>", "State<B>");
record_state_transition("State[1]", "State[2]");
record_state_transition("State{x}", "State{y}");
}
#[tokio::test]
async fn test_multiple_sequential_tool_executions() {
for i in 0..5 {
let result: Result<i32, &str> =
track_tool_execution(&format!("sequential_tool_{}", i), || async move { Ok(i) })
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), i);
}
}
#[tokio::test]
async fn test_track_tool_execution_string_return() {
let result: Result<String, &str> =
track_tool_execution("string_tool", || async { Ok("Hello, World!".to_string()) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Hello, World!");
}
#[test]
fn test_tool_span_record_fields() {
let span = tool_span!("recordable_tool");
span.record("success", true);
span.record("duration_ms", 100u64);
let _guard = span.enter();
}
#[test]
fn test_enter_agent_step_returned_span() {
let span = enter_agent_step("ValidSpan", 42);
let _guard = span.enter();
}
#[test]
fn test_concurrent_spans() {
use std::thread;
let handles: Vec<_> = (0..4)
.map(|i| {
thread::spawn(move || {
let span = enter_agent_step(&format!("Thread{}", i), i);
let _guard = span.enter();
record_success();
})
})
.collect();
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_sanitize_for_log_basic() {
assert_eq!(sanitize_for_log("hello world"), "hello world");
}
#[test]
fn test_sanitize_for_log_newlines() {
assert_eq!(
sanitize_for_log("line1\nline2\r\nline3"),
"line1\\nline2\\r\\nline3"
);
}
#[test]
fn test_sanitize_for_log_control_chars() {
assert_eq!(sanitize_for_log("a\x00b\x1bc"), "a\\0b\\ec");
}
#[test]
fn test_sanitize_for_log_preserves_unicode() {
assert_eq!(sanitize_for_log("hello 世界"), "hello 世界");
}
#[test]
fn test_redact_secrets_api_key() {
let input = "Using api key sk-abc12345defghijk";
let result = redact_secrets(input);
assert!(!result.contains("sk-abc12345defghijk"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_bearer_token() {
let input = "Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.test";
let result = redact_secrets(input);
assert!(!result.contains("eyJhbGciOiJIUzI1NiJ9"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_password_in_connection() {
let input = "postgres://user:password=mysecret@localhost/db";
let result = redact_secrets(input);
assert!(!result.contains("mysecret"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_preserves_normal_text() {
let input = "This is a normal log message with no secrets";
let result = redact_secrets(input);
assert_eq!(result, input);
}
#[test]
fn test_redact_secrets_token_prefix() {
let input = "token-abcdefghijklmnop is being used";
let result = redact_secrets(input);
assert!(!result.contains("token-abcdefghijklmnop"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_set_and_get_sampling_rate() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(0.5);
let rate = sampling_rate();
assert!((rate - 0.5).abs() < 0.01);
set_sampling_rate(0.0);
assert!((sampling_rate()).abs() < 0.01);
set_sampling_rate(1.0);
assert!((sampling_rate() - 1.0).abs() < 0.01);
set_sampling_rate(2.0);
assert!((sampling_rate() - 1.0).abs() < 0.01);
set_sampling_rate(-1.0);
assert!((sampling_rate()).abs() < 0.01);
set_sampling_rate(original);
}
#[test]
fn test_should_sample_full_rate() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(1.0);
for _ in 0..100 {
assert!(should_sample());
}
set_sampling_rate(original);
}
#[test]
fn test_should_sample_zero_rate() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(0.0);
for _ in 0..100 {
assert!(!should_sample());
}
set_sampling_rate(original);
}
#[test]
fn test_log_entry_count_and_increment() {
let before = log_entry_count();
let after = increment_log_count();
assert_eq!(after, before + 1);
}
#[test]
fn test_rotate_if_needed_below_limit() {
let _guard = rotation_test_guard();
let saved = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
LOG_ENTRY_COUNT.store(0, Ordering::Relaxed);
let rotated = rotate_if_needed();
assert!(!rotated);
LOG_ENTRY_COUNT.store(saved, Ordering::Relaxed);
}
#[test]
fn test_max_log_entries_constant() {
assert_eq!(MAX_LOG_ENTRIES, 100_000);
}
#[test]
fn test_increment_functions_update_metrics() {
let before_api = METRICS.api_requests.load(Ordering::Relaxed);
increment_api_requests();
assert_eq!(METRICS.api_requests.load(Ordering::Relaxed), before_api + 1);
let before_errors = METRICS.api_errors.load(Ordering::Relaxed);
increment_api_errors();
assert_eq!(
METRICS.api_errors.load(Ordering::Relaxed),
before_errors + 1
);
let before_tool = METRICS.tool_executions.load(Ordering::Relaxed);
increment_tool_executions();
assert_eq!(
METRICS.tool_executions.load(Ordering::Relaxed),
before_tool + 1
);
let before_tool_err = METRICS.tool_errors.load(Ordering::Relaxed);
increment_tool_errors();
assert_eq!(
METRICS.tool_errors.load(Ordering::Relaxed),
before_tool_err + 1
);
let before_tokens = METRICS.tokens_processed.load(Ordering::Relaxed);
add_tokens_processed(42);
assert_eq!(
METRICS.tokens_processed.load(Ordering::Relaxed),
before_tokens + 42
);
}
#[test]
fn test_sanitize_for_log_tab() {
assert_eq!(sanitize_for_log("before\tafter"), "before\\tafter");
assert_eq!(sanitize_for_log("\t"), "\\t");
assert_eq!(sanitize_for_log("\t\t\t"), "\\t\\t\\t");
}
#[test]
fn test_sanitize_for_log_vertical_tab() {
assert_eq!(sanitize_for_log("a\x0bb"), "a\\vb");
assert_eq!(sanitize_for_log("\x0b"), "\\v");
}
#[test]
fn test_sanitize_for_log_form_feed() {
assert_eq!(sanitize_for_log("a\x0cb"), "a\\fb");
assert_eq!(sanitize_for_log("\x0c"), "\\f");
}
#[test]
fn test_sanitize_for_log_escape_char() {
assert_eq!(sanitize_for_log("a\x1bb"), "a\\eb");
assert_eq!(sanitize_for_log("\x1b[31m"), "\\e[31m");
}
#[test]
fn test_sanitize_for_log_null() {
assert_eq!(sanitize_for_log("a\x00b"), "a\\0b");
assert_eq!(sanitize_for_log("\x00"), "\\0");
}
#[test]
fn test_sanitize_for_log_generic_control_char() {
assert_eq!(sanitize_for_log("a\x01b"), "a\\u0001b");
assert_eq!(sanitize_for_log("x\x02y"), "x\\u0002y");
assert_eq!(sanitize_for_log("\x03"), "\\u0003");
assert_eq!(sanitize_for_log("\x04"), "\\u0004");
assert_eq!(sanitize_for_log("\x05"), "\\u0005");
assert_eq!(sanitize_for_log("\x06"), "\\u0006");
assert_eq!(sanitize_for_log("\x07"), "\\u0007");
assert_eq!(sanitize_for_log("\x0e"), "\\u000e");
assert_eq!(sanitize_for_log("\x7f"), "\\u007f");
}
#[test]
fn test_sanitize_for_log_all_special_chars_combined() {
let input = "normal\n\r\t\x0b\x0c\x1b\x00\x01text";
let expected = "normal\\n\\r\\t\\v\\f\\e\\0\\u0001text";
assert_eq!(sanitize_for_log(input), expected);
}
#[test]
fn test_sanitize_for_log_empty_string() {
assert_eq!(sanitize_for_log(""), "");
}
#[test]
fn test_sanitize_for_log_only_control_chars() {
assert_eq!(
sanitize_for_log("\x00\x01\x02\x03"),
"\\0\\u0001\\u0002\\u0003"
);
}
#[test]
fn test_redact_secrets_key_prefix() {
let input = "Using key-abcdefghijklmnop for auth";
let result = redact_secrets(input);
assert!(!result.contains("key-abcdefghijklmnop"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_passwd_variant() {
let input = "config passwd=secretvalue123 host=localhost";
let result = redact_secrets(input);
assert!(!result.contains("secretvalue123"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_pwd_variant() {
let input = "connection pwd=mypassword host=db.example.com";
let result = redact_secrets(input);
assert!(!result.contains("mypassword"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_case_insensitive_sk() {
let input = "SK-ABCDEFGHIJKLMNOP";
let result = redact_secrets(input);
assert!(!result.contains("SK-ABCDEFGHIJKLMNOP"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_case_insensitive_bearer() {
let input = "BEARER AbCdEfGhIjKlMnOp";
let result = redact_secrets(input);
assert!(!result.contains("AbCdEfGhIjKlMnOp"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_case_insensitive_password() {
let input = "PASSWORD=SuperSecret123";
let result = redact_secrets(input);
assert!(!result.contains("SuperSecret123"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_multiple_secrets_in_one_string() {
let input = "sk-key123456789 and token-abcdefghij and password=secret";
let result = redact_secrets(input);
assert!(!result.contains("sk-key123456789"));
assert!(!result.contains("token-abcdefghij"));
assert!(result.matches("[REDACTED]").count() >= 2);
}
#[test]
fn test_redact_secrets_short_key_not_redacted() {
let input = "sk-abc is too short";
let result = redact_secrets(input);
assert_eq!(result, input);
}
#[test]
fn test_redact_secrets_password_with_spaces_around_equals() {
let input = "password = mysecretpassword";
let result = redact_secrets(input);
assert!(!result.contains("mysecretpassword"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_redact_secrets_empty_string() {
assert_eq!(redact_secrets(""), "");
}
#[test]
fn test_redact_secrets_key_prefix_with_dashes_and_underscores() {
let input = "key-abc_def-ghi_jkl";
let result = redact_secrets(input);
assert!(!result.contains("key-abc_def-ghi_jkl"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_should_sample_partial_rate_exercises_counter_path() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(0.5);
let mut sampled = 0;
let total = 2_000_000; for _ in 0..total {
if should_sample() {
sampled += 1;
}
}
assert!(
sampled > 0,
"At 50% rate over 2M calls, at least some events should be sampled"
);
assert!(
sampled < total as i64 as usize,
"At 50% rate over 2M calls, not all events should be sampled"
);
set_sampling_rate(original);
}
#[test]
fn test_should_sample_low_rate_exercises_counter_path() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(0.001);
for _ in 0..1000 {
let _ = should_sample();
}
set_sampling_rate(original);
}
#[test]
fn test_rotate_if_needed_triggers_rotation() {
let _guard = rotation_test_guard();
let saved = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
LOG_ENTRY_COUNT.store(MAX_LOG_ENTRIES + 10, Ordering::Relaxed);
let rotated = rotate_if_needed();
assert!(
rotated,
"Should trigger rotation when count >= MAX_LOG_ENTRIES"
);
let after = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
assert_eq!(after, (MAX_LOG_ENTRIES + 10) / 2);
LOG_ENTRY_COUNT.store(saved, Ordering::Relaxed);
}
#[test]
fn test_rotate_if_needed_exactly_at_limit() {
let _guard = rotation_test_guard();
let saved = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
LOG_ENTRY_COUNT.store(MAX_LOG_ENTRIES, Ordering::Relaxed);
let rotated = rotate_if_needed();
assert!(
rotated,
"Should trigger rotation when count == MAX_LOG_ENTRIES"
);
let after = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
assert_eq!(after, MAX_LOG_ENTRIES / 2);
LOG_ENTRY_COUNT.store(saved, Ordering::Relaxed);
}
#[test]
fn test_rotate_if_needed_just_below_limit() {
let _guard = rotation_test_guard();
let saved = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
LOG_ENTRY_COUNT.store(MAX_LOG_ENTRIES - 1, Ordering::Relaxed);
let rotated = rotate_if_needed();
assert!(!rotated, "Should not rotate when count < MAX_LOG_ENTRIES");
LOG_ENTRY_COUNT.store(saved, Ordering::Relaxed);
}
#[test]
fn test_rotate_if_needed_at_zero() {
let _guard = rotation_test_guard();
let saved = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
LOG_ENTRY_COUNT.store(0, Ordering::Relaxed);
let rotated = rotate_if_needed();
assert!(!rotated, "Should not rotate when count is 0");
LOG_ENTRY_COUNT.store(saved, Ordering::Relaxed);
}
#[test]
fn test_rotate_if_needed_double_rotation() {
let _guard = rotation_test_guard();
let saved = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
LOG_ENTRY_COUNT.store(MAX_LOG_ENTRIES * 2, Ordering::Relaxed);
let rotated = rotate_if_needed();
assert!(rotated);
let after_first = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
assert_eq!(after_first, MAX_LOG_ENTRIES);
let rotated2 = rotate_if_needed();
assert!(rotated2);
let after_second = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
assert_eq!(after_second, MAX_LOG_ENTRIES / 2);
let rotated3 = rotate_if_needed();
assert!(!rotated3);
LOG_ENTRY_COUNT.store(saved, Ordering::Relaxed);
}
#[test]
fn test_get_metrics_returns_static_ref() {
let m = get_metrics();
let before = m.api_requests.load(Ordering::Relaxed);
increment_api_requests();
let after = m.api_requests.load(Ordering::Relaxed);
assert_eq!(after, before + 1);
}
#[test]
fn test_get_metrics_all_fields_accessible() {
let m = get_metrics();
let _ = m.api_requests.load(Ordering::Relaxed);
let _ = m.api_errors.load(Ordering::Relaxed);
let _ = m.tool_executions.load(Ordering::Relaxed);
let _ = m.tool_errors.load(Ordering::Relaxed);
let _ = m.tokens_processed.load(Ordering::Relaxed);
}
#[test]
fn test_add_tokens_processed_zero() {
let before = METRICS.tokens_processed.load(Ordering::Relaxed);
add_tokens_processed(0);
let after = METRICS.tokens_processed.load(Ordering::Relaxed);
assert!(after >= before);
}
#[test]
fn test_add_tokens_processed_large_value() {
let before = METRICS.tokens_processed.load(Ordering::Relaxed);
add_tokens_processed(1_000_000);
let after = METRICS.tokens_processed.load(Ordering::Relaxed);
assert!(after >= before + 1_000_000);
}
#[test]
fn test_add_tokens_processed_multiple_adds() {
let before = METRICS.tokens_processed.load(Ordering::Relaxed);
add_tokens_processed(10);
add_tokens_processed(20);
add_tokens_processed(30);
let after = METRICS.tokens_processed.load(Ordering::Relaxed);
assert!(after >= before + 60);
}
#[test]
fn test_increment_log_count_sequential() {
let before = log_entry_count();
let r1 = increment_log_count();
let r2 = increment_log_count();
let r3 = increment_log_count();
assert_eq!(r1, before + 1);
assert_eq!(r2, before + 2);
assert_eq!(r3, before + 3);
}
#[test]
fn test_log_entry_count_matches_after_increments() {
let before = log_entry_count();
for _ in 0..5 {
increment_log_count();
}
let after = log_entry_count();
assert_eq!(after, before + 5);
}
#[test]
fn test_shutdown_tracing_no_panic_when_not_initialized() {
shutdown_tracing();
}
#[test]
fn test_shutdown_tracing_idempotent() {
shutdown_tracing();
shutdown_tracing();
shutdown_tracing();
}
#[test]
fn test_init_tracing_no_rust_log() {
init_tracing();
}
#[test]
fn test_record_failure_redacts_secrets() {
record_failure("Connection failed with sk-supersecretkey123456");
record_failure("Auth error: Bearer eyJhbGciOiJIUzI1NiJ9.payload");
record_failure("DB error: password=mysecretpass");
}
#[test]
fn test_record_state_transition_with_control_chars() {
record_state_transition("State\nA", "State\tB");
record_state_transition("From\x00", "To\x1b");
record_state_transition("From\x0b", "To\x0c");
}
#[test]
fn test_enter_agent_step_with_control_chars() {
let span = enter_agent_step("State\n\r\t\x00\x1b", 0);
let _guard = span.enter();
}
#[tokio::test]
async fn test_track_tool_execution_redacts_tool_name() {
let result: Result<i32, &str> =
track_tool_execution("tool_with_sk-secretkeyvalue123", || async { Ok(1) }).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_track_tool_execution_sanitizes_tool_name() {
let result: Result<i32, &str> =
track_tool_execution("tool\nwith\nnewlines", || async { Ok(1) }).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_track_tool_execution_error_with_secrets() {
let result: Result<i32, String> = track_tool_execution("secure_tool", || async {
Err("Failed with password=secret123".to_string())
})
.await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Failed with password=secret123");
}
#[test]
fn test_secret_patterns_initialized() {
let _ = redact_secrets("test");
let _ = redact_secrets("another test");
}
#[test]
fn test_sampling_rate_precision() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(0.123456);
let rate = sampling_rate();
assert!((rate - 0.123456).abs() < 0.000002);
set_sampling_rate(0.999999);
let rate = sampling_rate();
assert!((rate - 0.999999).abs() < 0.000002);
set_sampling_rate(0.000001);
let rate = sampling_rate();
assert!((rate - 0.000001).abs() < 0.000002);
set_sampling_rate(original);
}
#[test]
fn test_metrics_concurrent_increments() {
use std::thread;
let before_api = METRICS.api_requests.load(Ordering::Relaxed);
let before_tool = METRICS.tool_executions.load(Ordering::Relaxed);
let handles: Vec<_> = (0..10)
.map(|i| {
thread::spawn(move || {
for _ in 0..100 {
if i % 2 == 0 {
increment_api_requests();
} else {
increment_tool_executions();
}
}
})
})
.collect();
for h in handles {
h.join().unwrap();
}
let after_api = METRICS.api_requests.load(Ordering::Relaxed);
let after_tool = METRICS.tool_executions.load(Ordering::Relaxed);
assert_eq!(after_api - before_api, 500);
assert_eq!(after_tool - before_tool, 500);
}
#[test]
fn test_sanitize_for_log_injection_attempt() {
let malicious = "normal input\n[ERROR] Fake admin alert: system compromised";
let sanitized = sanitize_for_log(malicious);
assert!(!sanitized.contains('\n'));
assert!(sanitized.contains("\\n"));
assert!(sanitized.contains("[ERROR] Fake admin alert: system compromised"));
}
#[test]
fn test_sanitize_for_log_carriage_return_injection() {
let malicious = "first line\r[INFO] spoofed log entry";
let sanitized = sanitize_for_log(malicious);
assert!(!sanitized.contains('\r'));
assert!(sanitized.contains("\\r"));
}
#[test]
fn test_redact_secrets_bearer_jwt_style() {
let input = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0";
let result = redact_secrets(input);
assert!(result.contains("[REDACTED]"));
assert!(!result.contains("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"));
}
#[test]
fn test_redact_secrets_token_with_underscores() {
let input = "token-abc_def_ghi_jkl_mno";
let result = redact_secrets(input);
assert!(!result.contains("token-abc_def_ghi_jkl_mno"));
assert!(result.contains("[REDACTED]"));
}
#[test]
fn test_sanitize_then_redact_pipeline() {
let input = "Error: sk-secretkey12345678\nNew line injection";
let sanitized = sanitize_for_log(input);
let redacted = redact_secrets(&sanitized);
assert!(!redacted.contains('\n'));
assert!(!redacted.contains("sk-secretkey12345678"));
assert!(redacted.contains("[REDACTED]"));
assert!(redacted.contains("\\n"));
}
#[test]
fn test_redact_then_sanitize_order_independence() {
let input = "password=secret123\ttab_separated";
let s_then_r = redact_secrets(&sanitize_for_log(input));
assert!(!s_then_r.contains('\t'));
}
#[test]
fn test_get_metrics_is_same_as_static() {
let m = get_metrics();
let before = m.tool_errors.load(Ordering::Relaxed);
increment_tool_errors();
assert_eq!(m.tool_errors.load(Ordering::Relaxed), before + 1);
}
#[test]
fn test_increment_api_requests_multiple() {
let before = METRICS.api_requests.load(Ordering::Relaxed);
for _ in 0..5 {
increment_api_requests();
}
assert_eq!(METRICS.api_requests.load(Ordering::Relaxed), before + 5);
}
#[test]
fn test_increment_api_errors_multiple() {
let before = METRICS.api_errors.load(Ordering::Relaxed);
for _ in 0..3 {
increment_api_errors();
}
assert_eq!(METRICS.api_errors.load(Ordering::Relaxed), before + 3);
}
#[test]
fn test_increment_tool_executions_multiple() {
let before = METRICS.tool_executions.load(Ordering::Relaxed);
for _ in 0..7 {
increment_tool_executions();
}
assert_eq!(METRICS.tool_executions.load(Ordering::Relaxed), before + 7);
}
#[test]
fn test_increment_tool_errors_multiple() {
let before = METRICS.tool_errors.load(Ordering::Relaxed);
for _ in 0..4 {
increment_tool_errors();
}
assert_eq!(METRICS.tool_errors.load(Ordering::Relaxed), before + 4);
}
#[test]
fn test_rotate_if_needed_concurrent() {
use std::thread;
let _guard = rotation_test_guard();
let saved = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
LOG_ENTRY_COUNT.store(MAX_LOG_ENTRIES + 100, Ordering::Relaxed);
let handles: Vec<_> = (0..4).map(|_| thread::spawn(rotate_if_needed)).collect();
let mut any_rotated = false;
for h in handles {
if h.join().unwrap() {
any_rotated = true;
}
}
assert!(
any_rotated,
"At least one thread should have triggered rotation"
);
let final_count = LOG_ENTRY_COUNT.load(Ordering::Relaxed);
assert!(
final_count < MAX_LOG_ENTRIES + 100,
"Count should have been reduced from the original"
);
LOG_ENTRY_COUNT.store(saved, Ordering::Relaxed);
}
#[test]
fn test_init_tracing_verbose_no_panic() {
init_tracing_verbose();
}
#[test]
fn test_init_tracing_with_filter_invalid_fallback() {
init_tracing_with_filter("this is not a valid filter!!!@@@");
}
#[test]
fn test_sampling_rate_boundary_zero() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(0.0);
assert_eq!(sampling_rate(), 0.0);
set_sampling_rate(original);
}
#[test]
fn test_sampling_rate_boundary_one() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(1.0);
assert_eq!(sampling_rate(), 1.0);
set_sampling_rate(original);
}
#[test]
fn test_sampling_rate_clamp_negative() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(-100.0);
assert_eq!(sampling_rate(), 0.0);
set_sampling_rate(original);
}
#[test]
fn test_sampling_rate_clamp_large_positive() {
let _guard = sampling_test_guard();
let original = sampling_rate();
set_sampling_rate(999.0);
assert_eq!(sampling_rate(), 1.0);
set_sampling_rate(original);
}
}