use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use trusty_mpm_core::overseer::{Overseer, OverseerContext, OverseerDecision};
/// OpenRouter chat-completions endpoint that both the overseer verdict request
/// and the chat request are POSTed to.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Timeout (3 s) configured on the blocking client used for tool-use verdicts.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(3);
/// Timeout (30 s) configured on the async client used for chat.
const CHAT_TIMEOUT: Duration = Duration::from_secs(30);
/// Maximum number of messages `cap_history` keeps in a chat history; older
/// messages beyond this count are dropped from the front.
pub const CHAT_HISTORY_LIMIT: usize = 20;
/// System prompt that `build_chat_messages` places ahead of the chat history.
/// The trailing `\` continuations join the lines into one single-line string.
const CHAT_SYSTEM_PROMPT: &str = "You are a helpful assistant integrated with \
trusty-mpm, a Claude Code session manager. You can discuss sessions, projects, \
tmux, and general questions. Be concise.";
/// One turn of a chat conversation, serialisable with serde.
///
/// `build_chat_messages` copies both fields into the `messages` array of the
/// request body.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatMessage {
/// Speaker of the message; the constructors in this file produce `"user"`
/// and `"assistant"`.
pub role: String,
/// Text of the message.
pub content: String,
}
impl ChatMessage {
    /// Builds a message attributed to the `"user"` role.
    pub fn user(content: impl Into<String>) -> Self {
        let role = String::from("user");
        let content = content.into();
        Self { role, content }
    }

    /// Builds a message attributed to the `"assistant"` role.
    pub fn assistant(content: impl Into<String>) -> Self {
        let role = String::from("assistant");
        let content = content.into();
        Self { role, content }
    }
}
/// Failure modes of `LlmOverseer::chat`.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
/// The overseer holds an empty API key, so no request was attempted.
#[error("LLM chat is not configured (no API key)")]
NotConfigured,
/// The HTTP request could not be completed; carries the underlying error
/// rendered as text.
#[error("LLM request failed: {0}")]
Http(String),
/// The response body could not be decoded as JSON; carries the underlying
/// error rendered as text.
#[error("LLM response could not be parsed: {0}")]
Parse(String),
/// The response decoded, but the extracted reply was empty or whitespace.
#[error("LLM returned an empty response")]
EmptyResponse,
}
/// System prompt for the tool-use verdict request. It asks the model for a
/// one-word answer — ALLOW, BLOCK, or FLAG — which `parse_verdict` then maps to
/// an `OverseerDecision`. The trailing `\` continuations join the lines into
/// one single-line string.
const SYSTEM_PROMPT: &str = "You are a security overseer for an AI coding \
assistant. Evaluate the following tool use request and respond with exactly \
one word: ALLOW, BLOCK, or FLAG. BLOCK only for clearly dangerous operations \
(rm -rf, credential exfiltration, etc). FLAG for uncertain cases. ALLOW for \
normal development operations.";
/// Overseer that asks an LLM behind OpenRouter to judge tool-use requests and
/// that can also hold a free-form chat.
pub struct LlmOverseer {
    /// Bearer token sent with every request; empty when no key was resolved.
    api_key: String,
    /// Model identifier placed in the `model` field of each request body.
    model: String,
    /// Blocking HTTP client used for synchronous tool-use verdicts.
    client: reqwest::blocking::Client,
    /// Async HTTP client used for chat.
    chat_client: reqwest::Client,
}

/// Hand-written instead of derived so the API key is never printed: a derived
/// `Debug` would emit the secret verbatim through any `{:?}` formatting.
impl std::fmt::Debug for LlmOverseer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is present, never its value.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("LlmOverseer")
            .field("api_key", &api_key)
            .field("model", &self.model)
            .field("client", &self.client)
            .field("chat_client", &self.chat_client)
            .finish()
    }
}
impl LlmOverseer {
    /// Creates an overseer that talks to OpenRouter using `model`.
    ///
    /// `api_key_env` names the variable holding the API key; it is resolved
    /// once, here, through `resolve_api_key`. An unresolved key is stored as an
    /// empty string, which makes [`LlmOverseer::chat`] return
    /// [`LlmError::NotConfigured`].
    ///
    /// Two HTTP clients are built: a blocking one bounded by `REQUEST_TIMEOUT`
    /// for tool-use verdicts and an async one bounded by `CHAT_TIMEOUT` for
    /// chat. If either builder fails, that client type's `Default` value is
    /// used instead.
    pub fn new(model: impl Into<String>, api_key_env: &str) -> Self {
        let api_key = resolve_api_key(api_key_env);
        let client = reqwest::blocking::Client::builder()
            .timeout(REQUEST_TIMEOUT)
            .build()
            .unwrap_or_default();
        let chat_client = reqwest::Client::builder()
            .timeout(CHAT_TIMEOUT)
            .build()
            .unwrap_or_default();
        Self {
            api_key,
            model: model.into(),
            client,
            chat_client,
        }
    }

    /// Sends `user_msg` to the model together with the prior `history` and
    /// returns the assistant's reply.
    ///
    /// `history` is updated in place: the user message is appended (and the
    /// history capped) before the request is sent, and the assistant reply is
    /// appended (and the history capped again) on success. When the call
    /// fails after that first step, the user message stays in `history`
    /// without a matching reply.
    ///
    /// # Errors
    ///
    /// - [`LlmError::NotConfigured`] when no API key is available; `history`
    ///   is left untouched in that case.
    /// - [`LlmError::Http`] when the request cannot be completed or the server
    ///   answers with a 4xx/5xx status.
    /// - [`LlmError::Parse`] when the response body is not valid JSON.
    /// - [`LlmError::EmptyResponse`] when the reply is empty or whitespace.
    pub async fn chat(
        &self,
        history: &mut Vec<ChatMessage>,
        user_msg: &str,
    ) -> Result<String, LlmError> {
        if self.api_key.is_empty() {
            return Err(LlmError::NotConfigured);
        }
        history.push(ChatMessage::user(user_msg));
        cap_history(history);
        let body = serde_json::json!({
            "model": self.model,
            "messages": build_chat_messages(history),
            "temperature": 0.7,
        });
        let response = self
            .chat_client
            .post(OPENROUTER_URL)
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await
            // Without this check an error status (bad key, rate limit, outage)
            // carrying a JSON error body would be misreported as
            // `EmptyResponse` further down.
            .and_then(|resp| resp.error_for_status())
            .map_err(|e| LlmError::Http(e.to_string()))?;
        let json: Value = response
            .json()
            .await
            .map_err(|e| LlmError::Parse(e.to_string()))?;
        let reply = extract_reply(&json);
        if reply.trim().is_empty() {
            return Err(LlmError::EmptyResponse);
        }
        history.push(ChatMessage::assistant(reply.clone()));
        cap_history(history);
        Ok(reply)
    }

    /// Asks the model for a verdict on one tool invocation.
    ///
    /// The request is synchronous (blocking client) and deterministic
    /// (`temperature` 0, at most 16 tokens). This method fails open: a
    /// transport error, a 4xx/5xx status, or an undecodable body is logged as
    /// a warning and results in [`OverseerDecision::Allow`].
    fn evaluate(&self, tool: &str, input: &str) -> OverseerDecision {
        let user_message = format!("Tool: {tool}\nInput: {input}");
        let body = serde_json::json!({
            "model": self.model,
            "messages": [
                { "role": "system", "content": SYSTEM_PROMPT },
                { "role": "user", "content": user_message },
            ],
            "max_tokens": 16,
            "temperature": 0.0,
        });
        let response = self
            .client
            .post(OPENROUTER_URL)
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            // Treat error statuses like transport failures so they are logged
            // instead of being silently read as an empty (allowing) reply.
            .and_then(|resp| resp.error_for_status());
        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                tracing::warn!("LLM overseer: request failed: {e}; allowing");
                return OverseerDecision::Allow;
            }
        };
        match response.json::<Value>() {
            Ok(json) => parse_verdict(&extract_reply(&json)),
            Err(e) => {
                tracing::warn!("LLM overseer: bad response body: {e}; allowing");
                OverseerDecision::Allow
            }
        }
    }
}
impl Overseer for LlmOverseer {
    /// Judges a tool invocation before it runs.
    ///
    /// A disabled overseer allows everything. Otherwise the tool name
    /// (`"unknown"` when absent) and tool input (empty when absent) are handed
    /// to `evaluate`.
    fn pre_tool_use(&self, ctx: &OverseerContext) -> OverseerDecision {
        if self.is_enabled() {
            let tool_name = ctx.tool_name.as_deref().unwrap_or("unknown");
            let tool_input = ctx.tool_input.as_deref().unwrap_or("");
            self.evaluate(tool_name, tool_input)
        } else {
            OverseerDecision::Allow
        }
    }

    /// Tool output is not inspected: always allows.
    fn post_tool_use(&self, _ctx: &OverseerContext, _output: &str) -> OverseerDecision {
        OverseerDecision::Allow
    }

    /// Session questions are never answered here; each one is flagged for a
    /// human with the question text embedded in the summary.
    fn session_question(&self, _ctx: &OverseerContext, question: &str) -> OverseerDecision {
        let summary = format!("session question needs review: {question}");
        OverseerDecision::FlagForHuman { summary }
    }

    /// The overseer is enabled exactly when it holds a non-empty API key.
    fn is_enabled(&self) -> bool {
        let key_missing = self.api_key.is_empty();
        !key_missing
    }
}
fn extract_reply(json: &Value) -> String {
json.get("choices")
.and_then(Value::as_array)
.and_then(|c| c.first())
.and_then(|c| c.get("message"))
.and_then(|m| m.get("content"))
.and_then(Value::as_str)
.unwrap_or("")
.to_string()
}
/// Maps the model's free-text reply onto an [`OverseerDecision`].
///
/// Matching is case-insensitive and substring-based: a reply containing
/// "BLOCK" blocks, otherwise one containing "FLAG" is flagged for a human, and
/// anything else (including an empty reply) is allowed. "BLOCK" is checked
/// first, so it wins when both words appear.
fn parse_verdict(reply: &str) -> OverseerDecision {
    let normalized = reply.to_uppercase();
    let verbatim = reply.trim();

    if normalized.contains("BLOCK") {
        return OverseerDecision::Block {
            reason: format!("LLM overseer blocked the tool use: {}", verbatim),
        };
    }
    if normalized.contains("FLAG") {
        return OverseerDecision::FlagForHuman {
            summary: format!("LLM overseer flagged the tool use: {}", verbatim),
        };
    }
    OverseerDecision::Allow
}
/// Looks up the API key stored under `var_name`.
///
/// `.env.local` and then `.env` (both relative paths) are consulted first; the
/// first file that yields a non-empty value wins. Only when neither does is
/// the process environment read. Returns an empty string when the key cannot
/// be found anywhere.
fn resolve_api_key(var_name: &str) -> String {
    let from_dotenv = [".env.local", ".env"]
        .iter()
        .find_map(|file| read_dotenv_key(std::path::Path::new(file), var_name));
    match from_dotenv {
        Some(key) => key,
        None => std::env::var(var_name).unwrap_or_default(),
    }
}
/// Reads the value assigned to `var_name` in the dotenv-style file at `path`.
///
/// Blank lines and lines starting with `#` are skipped. Each remaining line is
/// split at its first `=`; the left side must equal `var_name`, optionally
/// preceded by the shell-style `export` keyword (`export KEY=value`). The
/// value is trimmed and stripped of surrounding double and single quotes.
///
/// Returns the first non-empty value found, or `None` when the file cannot be
/// read, the key is absent, or every assignment to it is empty. Inline
/// `# comments` after a value are not recognised and remain part of the value.
fn read_dotenv_key(path: &std::path::Path, var_name: &str) -> Option<String> {
    let contents = std::fs::read_to_string(path).ok()?;
    contents
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        // `split_once` keeps any further `=` (e.g. base64 padding) in the value.
        .filter_map(|line| line.split_once('='))
        .filter(|(key, _)| dotenv_key_matches(key, var_name))
        .map(|(_, value)| value.trim().trim_matches('"').trim_matches('\'').trim())
        // An empty assignment does not end the search; a later line may still
        // provide the value.
        .find(|value| !value.is_empty())
        .map(String::from)
}

/// Returns whether the left-hand side of a dotenv line names `var_name`,
/// either directly or behind an `export` keyword followed by whitespace.
fn dotenv_key_matches(raw_key: &str, var_name: &str) -> bool {
    let key = raw_key.trim();
    if key == var_name {
        return true;
    }
    key.strip_prefix("export")
        // Require whitespace after the keyword so `exportFOO` is not read as `FOO`.
        .filter(|rest| rest.starts_with(char::is_whitespace))
        .is_some_and(|rest| rest.trim_start() == var_name)
}
/// Trims `history` to at most `CHAT_HISTORY_LIMIT` messages by discarding the
/// oldest entries (those at the front); shorter histories are left untouched.
fn cap_history(history: &mut Vec<ChatMessage>) {
    let excess = history.len().saturating_sub(CHAT_HISTORY_LIMIT);
    if excess > 0 {
        history.drain(..excess);
    }
}
/// Assembles the `messages` array for a chat request: the system prompt
/// (`CHAT_SYSTEM_PROMPT`) first, followed by every entry of `history` in
/// order, each rendered as a `{ "role", "content" }` object.
fn build_chat_messages(history: &[ChatMessage]) -> Vec<Value> {
    let system = serde_json::json!({
        "role": "system",
        "content": CHAT_SYSTEM_PROMPT,
    });
    let turns = history.iter().map(|turn| {
        serde_json::json!({
            "role": turn.role,
            "content": turn.content,
        })
    });
    std::iter::once(system).chain(turns).collect()
}
// Unit tests for verdict parsing, reply extraction, API-key resolution,
// history capping and chat message assembly.
#[cfg(test)]
mod tests {
use super::*;
// A bare "BLOCK" reply yields a Block decision.
#[test]
fn parse_verdict_block() {
let decision = parse_verdict("BLOCK");
assert!(matches!(decision, OverseerDecision::Block { .. }));
}
// A bare "FLAG" reply yields a FlagForHuman decision.
#[test]
fn parse_verdict_flag() {
let decision = parse_verdict("FLAG");
assert!(matches!(decision, OverseerDecision::FlagForHuman { .. }));
}
// A bare "ALLOW" reply yields Allow.
#[test]
fn parse_verdict_allow() {
assert_eq!(parse_verdict("ALLOW"), OverseerDecision::Allow);
}
// Lower-case keywords are recognised as well.
#[test]
fn parse_verdict_is_case_insensitive() {
assert!(matches!(
parse_verdict("block"),
OverseerDecision::Block { .. }
));
assert!(matches!(
parse_verdict("flag"),
OverseerDecision::FlagForHuman { .. }
));
}
// The keyword is found even when embedded in a longer sentence.
#[test]
fn parse_verdict_tolerates_surrounding_prose() {
assert!(matches!(
parse_verdict("I would BLOCK this — it deletes the repo."),
OverseerDecision::Block { .. }
));
}
// When both keywords appear, BLOCK takes precedence.
#[test]
fn parse_verdict_block_wins_over_flag() {
assert!(matches!(
parse_verdict("BLOCK, do not FLAG"),
OverseerDecision::Block { .. }
));
}
// Replies without any keyword, including the empty reply, fall back to Allow.
#[test]
fn parse_verdict_empty_is_allow() {
assert_eq!(parse_verdict(""), OverseerDecision::Allow);
assert_eq!(
parse_verdict("something else entirely"),
OverseerDecision::Allow
);
}
// The reply text is read from choices[0].message.content.
#[test]
fn extract_reply_reads_content() {
let json = serde_json::json!({
"choices": [ { "message": { "content": "BLOCK" } } ]
});
assert_eq!(extract_reply(&json), "BLOCK");
}
// A missing or empty `choices` array produces an empty reply.
#[test]
fn extract_reply_handles_missing() {
assert_eq!(extract_reply(&serde_json::json!({})), "");
assert_eq!(extract_reply(&serde_json::json!({ "choices": [] })), "");
}
// Without a resolvable key the overseer is disabled and pre_tool_use allows.
// NOTE(review): assumes neither the environment nor a local `.env` /
// `.env.local` file defines TRUSTY_MPM_NO_SUCH_KEY_VAR.
#[test]
fn disabled_without_key() {
let overseer = LlmOverseer::new("test-model", "TRUSTY_MPM_NO_SUCH_KEY_VAR");
assert!(!overseer.is_enabled());
let ctx = OverseerContext::new(
trusty_mpm_core::session::SessionId::new(),
"tmpm-test",
Some("Bash".into()),
Some("ls".into()),
);
assert_eq!(overseer.pre_tool_use(&ctx), OverseerDecision::Allow);
}
// A key present in the process environment enables the overseer.
#[test]
fn enabled_with_key() {
// SAFETY: NOTE(review) — mutating the process environment is only sound
// while no other thread accesses it. Other tests in this module read the
// environment through `resolve_api_key`, so this relies on them not running
// concurrently with this block — confirm.
unsafe {
std::env::set_var("TRUSTY_MPM_TEST_LLM_KEY", "sk-test-123");
}
let overseer = LlmOverseer::new("test-model", "TRUSTY_MPM_TEST_LLM_KEY");
assert!(overseer.is_enabled());
// SAFETY: same caveat as for `set_var` above; the variable is removed so it
// does not outlive this test.
unsafe {
std::env::remove_var("TRUSTY_MPM_TEST_LLM_KEY");
}
}
// Comments and unrelated keys are skipped; surrounding quotes are stripped.
#[test]
fn read_dotenv_key_parses_value() {
use std::io::Write;
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join(".env");
let mut file = std::fs::File::create(&path).unwrap();
writeln!(file, "# a comment").unwrap();
writeln!(file, "OTHER=ignored").unwrap();
writeln!(file, "OPENROUTER_API_KEY=\"sk-or-v1-abc\"").unwrap();
let value = read_dotenv_key(&path, "OPENROUTER_API_KEY");
assert_eq!(value.as_deref(), Some("sk-or-v1-abc"));
}
// An unreadable file yields None rather than an error.
#[test]
fn read_dotenv_key_missing_file() {
let value = read_dotenv_key(std::path::Path::new("/no/such/.env"), "ANY");
assert!(value.is_none());
}
// Over-long histories are cut to the limit, keeping the newest messages.
#[test]
fn cap_history_keeps_recent() {
let mut history: Vec<ChatMessage> = (0..CHAT_HISTORY_LIMIT + 4)
.map(|i| ChatMessage::user(format!("msg-{i}")))
.collect();
cap_history(&mut history);
assert_eq!(history.len(), CHAT_HISTORY_LIMIT);
assert_eq!(
history.last().unwrap().content,
format!("msg-{}", CHAT_HISTORY_LIMIT + 3)
);
}
// Histories within the limit are left untouched.
#[test]
fn cap_history_leaves_short_history() {
let mut history = vec![ChatMessage::user("a"), ChatMessage::assistant("b")];
cap_history(&mut history);
assert_eq!(history.len(), 2);
}
// The system prompt comes first, followed by the history in order.
#[test]
fn build_chat_messages_includes_history() {
let history = vec![
ChatMessage::user("hello"),
ChatMessage::assistant("hi there"),
];
let messages = build_chat_messages(&history);
assert_eq!(messages.len(), 3);
assert_eq!(messages[0]["role"], "system");
assert_eq!(messages[1]["role"], "user");
assert_eq!(messages[1]["content"], "hello");
assert_eq!(messages[2]["role"], "assistant");
assert_eq!(messages[2]["content"], "hi there");
}
// chat() reports NotConfigured when no key is available.
// The overseer is both built and dropped inside `spawn_blocking` —
// presumably because the blocking reqwest client it owns must not be created
// or dropped on an async runtime thread; confirm against the reqwest docs.
#[tokio::test]
async fn chat_without_key_is_not_configured() {
let overseer = tokio::task::spawn_blocking(|| {
LlmOverseer::new("test-model", "TRUSTY_MPM_NO_SUCH_CHAT_KEY")
})
.await
.expect("build overseer");
let mut history = Vec::new();
let err = overseer.chat(&mut history, "hello").await.unwrap_err();
assert!(matches!(err, LlmError::NotConfigured));
tokio::task::spawn_blocking(move || drop(overseer))
.await
.expect("drop overseer");
}
// Each constructor stamps the matching role string.
#[test]
fn chat_message_constructors_set_role() {
assert_eq!(ChatMessage::user("x").role, "user");
assert_eq!(ChatMessage::assistant("y").role, "assistant");
}
}