use std::collections::HashMap;
use std::time::Instant;
use chrono::Utc;
use sha2::{Digest, Sha256};
use thiserror::Error;
use tokio::sync::Mutex;
use tracing::{debug, warn};
use super::cache::{ActivityCache, ActivityState, ActivityVerdict, CheckMetrics, CostTally};
/// Errors produced while classifying a terminal session's activity.
#[derive(Debug, Error)]
pub enum ActivityError {
/// The LLM classification call failed; carries a human-readable description.
#[error("LLM classification error: {0}")]
Llm(String),
/// The classifier's reply could not be decoded (e.g. the model did not return
/// parseable JSON).
#[error("serialization error: {0}")]
Serialization(String),
/// `OPENROUTER_API_KEY` is not available to the classifier.
///
/// [`ActivityMonitor::check`] does not propagate this variant; it degrades to an
/// `Unknown` verdict with confidence `0.0` instead.
#[error("OPENROUTER_API_KEY is not configured")]
MissingApiKey,
}
/// Pluggable backend that turns terminal pane text into an [`ActivityVerdict`].
///
/// Implementors must be `Send + Sync`, and the future returned by
/// [`classify`](LlmClassifier::classify) must be `Send`.
pub trait LlmClassifier: Send + Sync {
/// Classifies `pane_text` (the tail of a terminal pane, as passed by
/// [`ActivityMonitor::check`]).
///
/// On success returns `(verdict, input_tokens, output_tokens)`; the two counts
/// are copied into the check's [`CheckMetrics`]. Returning
/// [`ActivityError::MissingApiKey`] makes the monitor report an `Unknown`
/// verdict rather than failing; any other error is propagated to the caller.
fn classify(
&self,
pane_text: &str,
) -> impl Future<Output = Result<(ActivityVerdict, u32, u32), ActivityError>> + Send;
}
/// Outcome of a single [`ActivityMonitor::check`] call.
#[derive(Debug)]
pub struct ActivityCheckResult {
/// The classification, either fresh from the LLM or reused from the cache.
pub verdict: ActivityVerdict,
/// Metrics recorded for this check (model, token counts, latency, cache flag,
/// verdict state).
pub cost: CheckMetrics,
/// `true` when the verdict was served from the per-session cache without an
/// LLM call.
pub cache_hit: bool,
/// Clone of the session cache's [`CostTally`] taken during this check.
pub tally: CostTally,
}
/// Classifies terminal-session activity with an [`LlmClassifier`], caching the
/// last verdict per session so that unchanged panes skip the LLM call.
pub struct ActivityMonitor<C: LlmClassifier> {
/// Per-session cache state, keyed by session id.
cache: Mutex<HashMap<String, ActivityCache>>,
/// Backend invoked on cache misses.
llm: C,
/// Model label written into every `CheckMetrics` and used to create cache
/// entries. It is not passed to `llm`.
model: String,
}
impl<C: LlmClassifier> std::fmt::Debug for ActivityMonitor<C> {
    /// Renders only the model name.
    ///
    /// The cache and the classifier are elided (shown as `..`), because `C` is
    /// only bounded by `LlmClassifier` and need not implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ActivityMonitor");
        builder.field("model", &self.model);
        builder.finish_non_exhaustive()
    }
}
impl<C: LlmClassifier> ActivityMonitor<C> {
pub fn new(llm: C, model: impl Into<String>) -> Self {
Self {
cache: Mutex::new(HashMap::new()),
llm,
model: model.into(),
}
}
pub async fn check(
&self,
session_id: &str,
pane_text: &str,
) -> Result<ActivityCheckResult, ActivityError> {
let pane_tail = last_n_lines(pane_text, 60);
let hash = sha256_hex(pane_tail.as_bytes());
let start = Instant::now();
let mut caches = self.cache.lock().await;
let entry = caches
.entry(session_id.to_owned())
.or_insert_with(|| ActivityCache::new(&self.model));
if entry.check_unchanged(&hash) {
let verdict = entry
.last_verdict()
.cloned()
.unwrap_or_else(|| ActivityVerdict {
state: ActivityState::Unknown,
summary: "no prior verdict in cache".into(),
confidence: 0.0,
});
let metrics = CheckMetrics {
session_id: session_id.to_owned(),
at: Utc::now(),
model: self.model.clone(),
input_tokens: 0,
output_tokens: 0,
latency_ms: start.elapsed().as_millis() as u64,
cache_hit: true,
verdict_state: verdict.state.clone(),
};
let tally = entry.tally().clone();
entry.update_cache_hit(&hash, metrics.clone());
debug!(session = %session_id, "activity check: cache hit");
return Ok(ActivityCheckResult {
verdict,
cost: metrics,
cache_hit: true,
tally,
});
}
drop(caches);
let classify_result = self.llm.classify(&pane_tail).await;
let (verdict, input_tokens, output_tokens) = match classify_result {
Ok(r) => r,
Err(ActivityError::MissingApiKey) => {
warn!("activity monitor: OPENROUTER_API_KEY not configured; returning Unknown");
(
ActivityVerdict {
state: ActivityState::Unknown,
summary: "OPENROUTER_API_KEY not configured".into(),
confidence: 0.0,
},
0,
0,
)
}
Err(e) => return Err(e),
};
let latency_ms = start.elapsed().as_millis() as u64;
let metrics = CheckMetrics {
session_id: session_id.to_owned(),
at: Utc::now(),
model: self.model.clone(),
input_tokens,
output_tokens,
latency_ms,
cache_hit: false,
verdict_state: verdict.state.clone(),
};
let mut caches = self.cache.lock().await;
let entry = caches
.entry(session_id.to_owned())
.or_insert_with(|| ActivityCache::new(&self.model));
entry.update_llm_hit(&hash, verdict.clone(), metrics.clone());
let tally = entry.tally().clone();
debug!(session = %session_id, state = ?verdict.state, "activity check: LLM verdict");
Ok(ActivityCheckResult {
verdict,
cost: metrics,
cache_hit: false,
tally,
})
}
}
/// Returns the SHA-256 digest of `bytes` as a lowercase hex string (two
/// characters per digest byte).
fn sha256_hex(bytes: &[u8]) -> String {
    use std::fmt::Write as _;

    let digest = Sha256::digest(bytes);
    digest
        .iter()
        .fold(String::with_capacity(digest.len() * 2), |mut out, byte| {
            // Writing into a String cannot fail; the Result is ignored on purpose.
            let _ = write!(out, "{byte:02x}");
            out
        })
}
/// Returns the final `n` lines of `text`, joined with `\n`.
///
/// Splitting follows [`str::lines`]: `\r\n` endings are normalised and a
/// trailing newline does not produce an extra empty line. If `text` has `n` or
/// fewer lines, all of them are returned; `n == 0` yields an empty string.
fn last_n_lines(text: &str, n: usize) -> String {
    let mut tail: Vec<&str> = text.lines().rev().take(n).collect();
    tail.reverse();
    tail.join("\n")
}
/// `LlmClassifier` backed by OpenRouter's chat API.
///
/// The model id is taken from `TRUSTY_LLM_MODEL` (trimmed). It falls back to
/// `openai/gpt-4o-mini` when that variable is unset, not valid Unicode, or blank.
pub struct OpenRouterClassifier {
    model: String,
}

impl OpenRouterClassifier {
    /// Model used when `TRUSTY_LLM_MODEL` provides no usable value.
    const DEFAULT_MODEL: &'static str = "openai/gpt-4o-mini";

    /// Builds a classifier, resolving the model id from the environment.
    pub fn new() -> Self {
        // An exported-but-empty variable would otherwise yield an empty model id,
        // making every classification request fail.
        let model = std::env::var("TRUSTY_LLM_MODEL")
            .ok()
            .map(|m| m.trim().to_owned())
            .filter(|m| !m.is_empty())
            .unwrap_or_else(|| Self::DEFAULT_MODEL.to_owned());
        Self { model }
    }
}

impl Default for OpenRouterClassifier {
    /// Same as [`OpenRouterClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl LlmClassifier for OpenRouterClassifier {
    /// Asks the configured OpenRouter model to classify `pane_text`.
    ///
    /// The streamed reply is concatenated. Then the span from its first `{` to its
    /// last `}` (or the whole reply if there is no such span) is parsed as JSON
    /// with `state`, `summary` and `confidence` fields. Missing fields fall back to
    /// `unknown`, `"no summary"` and `0.5`. `confidence` is clamped into
    /// `0.0..=1.0` because it comes straight from model output.
    ///
    /// The token counts in the returned tuple are always `0`: only
    /// `ChatEvent::Delta` events are consumed here.
    ///
    /// # Errors
    /// * [`ActivityError::MissingApiKey`] — `OPENROUTER_API_KEY` is unset, not
    ///   valid Unicode, or blank.
    /// * [`ActivityError::Llm`] — the streaming chat request failed.
    /// * [`ActivityError::Serialization`] — the reply was not parseable JSON.
    async fn classify(
        &self,
        pane_text: &str,
    ) -> Result<(ActivityVerdict, u32, u32), ActivityError> {
        use tokio::sync::mpsc;
        use trusty_common::ChatMessage;
        use trusty_common::chat::{ChatEvent, ChatProvider, OpenRouterProvider};

        // Treat a blank key like a missing one, so callers get the graceful
        // `MissingApiKey` degradation instead of a request that cannot succeed.
        let api_key = std::env::var("OPENROUTER_API_KEY")
            .ok()
            .filter(|key| !key.trim().is_empty())
            .ok_or(ActivityError::MissingApiKey)?;
        let prompt = format!(
            "Classify the activity state of this Claude Code terminal session.\n\
             Respond ONLY with valid JSON: {{\"state\": \"<state>\", \"summary\": \"<summary>\", \"confidence\": <0.0-1.0>}}\n\
             Valid states: working, idle, blocked_on_permission, errored, done, unknown\n\n\
             Terminal output (last 60 lines):\n```\n{pane_text}\n```"
        );
        let messages = vec![ChatMessage {
            role: "user".into(),
            content: prompt,
            tool_call_id: None,
            tool_calls: None,
        }];
        let provider = OpenRouterProvider::new(api_key, self.model.clone());
        let (tx, mut rx) = mpsc::channel::<ChatEvent>(64);
        // `tx` is moved into the request; the receiver loop below ends once the
        // request future completes and the sender is dropped.
        let send_fut = provider.chat_stream(messages, vec![], tx);
        let mut full_text = String::new();
        let (send_result, ()) = tokio::join!(send_fut, async {
            while let Some(event) = rx.recv().await {
                if let ChatEvent::Delta(d) = event {
                    full_text.push_str(&d);
                }
            }
        });
        send_result.map_err(|e| ActivityError::Llm(e.to_string()))?;

        // Models often wrap the JSON in prose or code fences; strip that first.
        let json_str = extract_json(&full_text).unwrap_or(&full_text);
        let parsed: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
            ActivityError::Serialization(format!("parse failed: {e} — raw: {full_text}"))
        })?;
        let state = parse_state(parsed["state"].as_str().unwrap_or("unknown"));
        let summary = parsed["summary"]
            .as_str()
            .unwrap_or("no summary")
            .to_owned();
        let confidence = parsed["confidence"]
            .as_f64()
            .unwrap_or(0.5)
            .clamp(0.0, 1.0) as f32;
        Ok((
            ActivityVerdict {
                state,
                summary,
                confidence,
            },
            0,
            0,
        ))
    }
}
/// Returns the slice of `text` running from its first `{` to its last `}`,
/// inclusive.
///
/// Returns `None` if either brace is absent or if the last `}` does not come
/// after the first `{`. No JSON validation is performed.
fn extract_json(text: &str) -> Option<&str> {
    let open = text.find('{')?;
    let close = text.rfind('}')?;
    (close > open).then(|| &text[open..=close])
}
/// Maps a model-reported state string onto [`ActivityState`].
///
/// Matching ignores ASCII case and surrounding whitespace. `-` and spaces are
/// treated as `_`, so near-miss spellings such as `"Blocked on permission"` or
/// `"blocked-on-permission"` still resolve. Anything unrecognised becomes
/// [`ActivityState::Unknown`].
fn parse_state(s: &str) -> ActivityState {
    let normalized: String = s
        .trim()
        .chars()
        .map(|c| match c {
            '-' | ' ' => '_',
            other => other.to_ascii_lowercase(),
        })
        .collect();
    match normalized.as_str() {
        "working" => ActivityState::Working,
        "idle" => ActivityState::Idle,
        "blocked_on_permission" => ActivityState::BlockedOnPermission,
        "errored" => ActivityState::Errored,
        "done" => ActivityState::Done,
        _ => ActivityState::Unknown,
    }
}
use std::future::Future;
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifier that always succeeds with a fixed verdict and counts its calls.
    struct MockClassifier {
        verdict: ActivityVerdict,
        call_count: Mutex<u32>,
    }

    impl MockClassifier {
        fn new(state: ActivityState) -> Self {
            Self {
                verdict: ActivityVerdict {
                    state,
                    summary: "mock".into(),
                    confidence: 1.0,
                },
                call_count: Mutex::new(0),
            }
        }

        async fn calls(&self) -> u32 {
            *self.call_count.lock().await
        }
    }

    impl LlmClassifier for MockClassifier {
        async fn classify(
            &self,
            _pane_text: &str,
        ) -> Result<(ActivityVerdict, u32, u32), ActivityError> {
            *self.call_count.lock().await += 1;
            Ok((self.verdict.clone(), 50, 10))
        }
    }

    /// Classifier that always fails. `ActivityError` is not `Clone`, so each
    /// call builds a fresh error from the stored function pointer.
    struct ErrClassifier(fn() -> ActivityError);

    impl LlmClassifier for ErrClassifier {
        async fn classify(
            &self,
            _pane_text: &str,
        ) -> Result<(ActivityVerdict, u32, u32), ActivityError> {
            Err((self.0)())
        }
    }

    #[tokio::test]
    async fn monitor_cache_miss_calls_llm() {
        let classifier = MockClassifier::new(ActivityState::Working);
        let monitor = ActivityMonitor::new(classifier, "test-model");
        let result = monitor.check("s1", "some pane content").await.unwrap();
        assert_eq!(result.verdict.state, ActivityState::Working);
        assert!(!result.cache_hit);
        assert_eq!(monitor.llm.calls().await, 1);
    }

    #[tokio::test]
    async fn monitor_cache_hit_skips_llm() {
        let classifier = MockClassifier::new(ActivityState::Idle);
        let monitor = ActivityMonitor::new(classifier, "test-model");
        let pane = "unchanged content";
        let r1 = monitor.check("s1", pane).await.unwrap();
        assert!(!r1.cache_hit);
        let r2 = monitor.check("s1", pane).await.unwrap();
        assert!(r2.cache_hit);
        assert_eq!(monitor.llm.calls().await, 1);
    }

    #[tokio::test]
    async fn monitor_different_sessions_independent_caches() {
        let classifier = MockClassifier::new(ActivityState::Working);
        let monitor = ActivityMonitor::new(classifier, "test-model");
        monitor.check("s1", "content A").await.unwrap();
        monitor.check("s2", "content B").await.unwrap();
        assert_eq!(monitor.llm.calls().await, 2);
    }

    #[tokio::test]
    async fn monitor_degrades_to_unknown_without_api_key() {
        let monitor =
            ActivityMonitor::new(ErrClassifier(|| ActivityError::MissingApiKey), "test-model");
        let result = monitor.check("s1", "pane").await.unwrap();
        assert_eq!(result.verdict.state, ActivityState::Unknown);
        assert_eq!(result.verdict.confidence, 0.0);
        assert!(!result.cache_hit);
    }

    #[tokio::test]
    async fn monitor_propagates_other_classifier_errors() {
        let monitor = ActivityMonitor::new(
            ErrClassifier(|| ActivityError::Llm("boom".into())),
            "test-model",
        );
        let err = monitor.check("s1", "pane").await.unwrap_err();
        assert!(matches!(err, ActivityError::Llm(ref msg) if msg == "boom"));
    }

    #[test]
    fn sha256_hex_is_deterministic() {
        let h1 = sha256_hex(b"hello");
        let h2 = sha256_hex(b"hello");
        assert_eq!(h1, h2);
        assert_ne!(h1, sha256_hex(b"world"));
    }

    #[test]
    fn last_n_lines_short() {
        let text = "a\nb\nc";
        assert_eq!(last_n_lines(text, 10), "a\nb\nc");
    }

    #[test]
    fn last_n_lines_long() {
        let text = (0..100)
            .map(|i| i.to_string())
            .collect::<Vec<_>>()
            .join("\n");
        let tail = last_n_lines(&text, 60);
        let lines: Vec<&str> = tail.lines().collect();
        assert_eq!(lines.len(), 60);
        assert_eq!(lines[0], "40");
    }

    #[test]
    fn missing_api_key_error_mentions_env_var() {
        // Deliberately leaves the process environment alone: `env::remove_var`
        // races with other threads under the parallel test harness.
        let e = ActivityError::MissingApiKey;
        assert!(e.to_string().contains("OPENROUTER_API_KEY"));
    }

    #[test]
    fn extract_json_strips_surrounding_text() {
        assert_eq!(
            extract_json("```json\n{\"state\":\"idle\"}\n```"),
            Some("{\"state\":\"idle\"}")
        );
        assert_eq!(extract_json("no json here"), None);
    }

    #[test]
    fn parse_state_is_case_insensitive_and_defaults_to_unknown() {
        assert_eq!(parse_state("WORKING"), ActivityState::Working);
        assert_eq!(
            parse_state("blocked_on_permission"),
            ActivityState::BlockedOnPermission
        );
        assert_eq!(parse_state("sleeping"), ActivityState::Unknown);
    }
}