mod action;
mod mcp_client;
use std::str::FromStr;
use std::time::Duration;
use action::Action;
use mcp_client::McpHttpClient;
use crate::adapter::AgentAdapter;
use crate::adapter::claude_code::ClaudeCodeAdapter;
use crate::config::MEMORY_MCP_NAME;
use crate::mcp::resolve::{ResolvedKind, resolve_mcps};
/// One tag-scoped recall request: a tag name paired with the memory keyword
/// derived from it by `action::tag_keyword` (for example `rust` ->
/// `llmenv-tag:rust`). When built by `tag_recall_queries`, the tag has already
/// passed `validate_tag`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TagRecallQuery {
/// The active tag name this recall is for.
pub tag: String,
/// Memory keyword used to look up entries stored under `tag`.
pub keyword: String,
}
/// Validates each tag and pairs it with the memory keyword used to recall
/// entries stored under that tag (e.g. `rust` -> `llmenv-tag:rust`).
///
/// The output preserves the order of `tags`.
///
/// # Errors
///
/// Returns the first validation error from `validate_tag` (empty tag, or a
/// character outside ASCII alphanumerics, `-` and `_`). No queries are
/// returned in that case.
pub fn tag_recall_queries(tags: &[String]) -> anyhow::Result<Vec<TagRecallQuery>> {
    let mut queries = Vec::with_capacity(tags.len());
    for name in tags {
        validate_tag(name)?;
        queries.push(TagRecallQuery {
            tag: name.clone(),
            keyword: action::tag_keyword(name),
        });
    }
    Ok(queries)
}
/// One bundle-scoped recall request: a bundle name paired with the memory
/// keyword derived from it by `action::bundle_keyword` (for example `base` ->
/// `llmenv-bundle:base`). When built by `bundle_recall_queries`, the name has
/// already passed `validate_bundle`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BundleRecallQuery {
/// The enabled bundle name this recall is for.
pub bundle: String,
/// Memory keyword used to look up entries stored under `bundle`.
pub keyword: String,
}
/// Validates each bundle name and pairs it with the memory keyword used to
/// recall entries stored under that bundle (e.g. `base` -> `llmenv-bundle:base`).
///
/// The output preserves the order of `bundles`.
///
/// # Errors
///
/// Returns the first validation error from `validate_bundle` (empty name, or a
/// character outside ASCII alphanumerics, `-` and `_`). No queries are
/// returned in that case.
pub fn bundle_recall_queries(bundles: &[String]) -> anyhow::Result<Vec<BundleRecallQuery>> {
    let mut queries = Vec::with_capacity(bundles.len());
    for name in bundles {
        validate_bundle(name)?;
        queries.push(BundleRecallQuery {
            bundle: name.clone(),
            keyword: action::bundle_keyword(name),
        });
    }
    Ok(queries)
}
/// Timeout handed to `McpHttpClient::new` for talking to the memory backend.
/// NOTE(review): whether this bounds each request or something broader depends
/// on `McpHttpClient` — confirm there. `run_inner` issues one request per
/// action sequentially, so a TurnStart with many tags/bundles may wait up to
/// roughly this value per action.
const HOOK_TIMEOUT: Duration = Duration::from_secs(2);
/// Lifecycle point at which the memory hook is invoked.
///
/// Parsed from the agent-neutral names `session_start`, `turn_start` and
/// `session_end` (see the `FromStr` impl below).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
    /// `session_start`: dispatched to a single wake-up action.
    SessionStart,
    /// `turn_start`: dispatched to a project recall plus per-tag and
    /// per-bundle recalls.
    TurnStart,
    /// `session_end`: dispatched to a single store action.
    SessionEnd,
}

impl FromStr for HookEvent {
    type Err = anyhow::Error;

    /// Maps an exact, case-sensitive event name to its variant; any other
    /// string yields an error listing the accepted names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let event = match s {
            "session_start" => Self::SessionStart,
            "turn_start" => Self::TurnStart,
            "session_end" => Self::SessionEnd,
            other => {
                return Err(anyhow::anyhow!(
                    "unknown hook event '{other}' (expected session_start|turn_start|session_end)"
                ));
            }
        };
        Ok(event)
    }
}
/// Expands a hook event into the ordered list of memory actions to run.
///
/// * `SessionStart` -> `[WakeUp]`
/// * `SessionEnd`   -> `[Store]`
/// * `TurnStart`    -> `[Recall]`, then one `RecallTag` per entry of
///   `tag_queries`, then one `RecallBundle` per entry of `bundle_queries`,
///   each in input order.
fn dispatch(
    event: HookEvent,
    tag_queries: &[TagRecallQuery],
    bundle_queries: &[BundleRecallQuery],
) -> Vec<Action> {
    match event {
        HookEvent::SessionStart => vec![Action::WakeUp],
        HookEvent::SessionEnd => vec![Action::Store],
        HookEvent::TurnStart => std::iter::once(Action::Recall)
            .chain(tag_queries.iter().cloned().map(Action::RecallTag))
            .chain(bundle_queries.iter().cloned().map(Action::RecallBundle))
            .collect(),
    }
}
/// Entry point for a memory-hook invocation named by `event`
/// (`session_start`, `turn_start` or `session_end`).
///
/// The hook is best-effort and always returns `Ok(())`:
/// * an unrecognised event name is reported on stderr;
/// * any failure in the memory pipeline (`run_inner`) is reported on stderr
///   as "memory <event> skipped";
/// * otherwise the combined action output is wrapped by
///   `ClaudeCodeAdapter::emit_hook_context` and, if non-empty, written to
///   stdout. A failure to write stdout is reported on stderr instead of
///   panicking.
pub fn run(event: &str) -> anyhow::Result<()> {
    use std::io::Write as _;

    let parsed = match HookEvent::from_str(event) {
        Ok(e) => e,
        Err(e) => {
            eprintln!("llmenv: {e}");
            return Ok(());
        }
    };
    match run_inner(parsed) {
        Ok(text) => {
            let out = ClaudeCodeAdapter.emit_hook_context(&text);
            if !out.is_empty() {
                // `println!` panics if writing to stdout fails (e.g. EPIPE
                // once the host has stopped reading; Rust ignores SIGPIPE, so
                // the write returns an error). Write explicitly so the hook
                // keeps its never-fail contract.
                let stdout = std::io::stdout();
                let mut handle = stdout.lock();
                let written = writeln!(handle, "{out}").and_then(|()| handle.flush());
                if let Err(e) = written {
                    eprintln!("llmenv: memory {event} output not written: {e}");
                }
            }
        }
        Err(e) => {
            eprintln!("llmenv: memory {event} skipped: {e}");
        }
    }
    Ok(())
}
fn run_inner(event: HookEvent) -> anyhow::Result<String> {
let config_path = crate::paths::config_path()?;
let config = crate::config::Config::load(&config_path)?;
let env = crate::scope::matcher::Env::detect();
let active = crate::scope::evaluate(&config, &env);
let url = memory_url(&config, &active)?
.ok_or_else(|| anyhow::anyhow!("no memory backend active for this scope"))?;
let mut tags = active.tags.iter().cloned().collect::<Vec<_>>();
tags.sort();
let bundles: Vec<String> = {
let mut set = std::collections::BTreeSet::new();
for scope in &active.scopes {
for b in &scope.enable_bundles {
set.insert(b.clone());
}
}
set.into_iter().collect()
};
let tag_queries = tag_recall_queries(&tags)?;
let bundle_queries = bundle_recall_queries(&bundles)?;
let query = tags.join(", ");
let chunk = crate::icm::generate_context_chunk(&active, &bundles);
let client = McpHttpClient::new(url, HOOK_TIMEOUT)
.map_err(|e| anyhow::anyhow!("invalid memory backend URL: {e}"))?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
rt.block_on(async {
let mut out = String::new();
for action in dispatch(event, &tag_queries, &bundle_queries) {
let text = action.run(&client, &query, &chunk).await?;
if !text.is_empty() {
if !out.is_empty() {
out.push_str("\n\n");
}
out.push_str(&text);
}
}
Ok::<String, anyhow::Error>(out)
})
}
/// Finds the URL of the memory MCP server among the servers resolved for the
/// active tags.
///
/// Returns `Ok(Some(url))` for the first server named `MEMORY_MCP_NAME`
/// whose kind is `ResolvedKind::Remote`, and `Ok(None)` if there is none.
///
/// # Errors
///
/// Fails if `resolve_mcps` fails.
fn memory_url(
    config: &crate::config::Config,
    active: &crate::scope::ActiveScopes,
) -> anyhow::Result<Option<String>> {
    let servers = match resolve_mcps(config, &active.tags) {
        Ok(servers) => servers,
        Err(e) => return Err(anyhow::anyhow!("failed to resolve MCP servers: {e}")),
    };
    let url = servers
        .into_iter()
        .filter(|server| server.name == MEMORY_MCP_NAME)
        .find_map(|server| match server.kind {
            ResolvedKind::Remote { url, .. } => Some(url),
            _ => None,
        });
    Ok(url)
}
/// Returns `true` for characters allowed in tag and bundle names: ASCII
/// letters, ASCII digits, `-` and `_`.
///
/// Names are embedded in recall keywords and the comma-joined recall query.
/// Excluding separators such as `,`, spaces and `:` keeps a name from
/// smuggling extra terms into those strings. Both validators share this
/// predicate so the rule cannot drift between them.
fn is_valid_name_char(c: char) -> bool {
    c.is_ascii_alphanumeric() || c == '-' || c == '_'
}

/// Checks that `tag` is non-empty and consists only of characters accepted
/// by `is_valid_name_char`.
///
/// # Errors
///
/// Returns an error naming the problem if the tag is empty or contains any
/// other character.
fn validate_tag(tag: &str) -> anyhow::Result<()> {
    if tag.is_empty() {
        anyhow::bail!("empty tag in recall query");
    }
    if !tag.chars().all(is_valid_name_char) {
        anyhow::bail!(
            "tag '{}' contains invalid characters (only alphanumeric, -, _ allowed)",
            tag
        );
    }
    Ok(())
}

/// Checks that `bundle` is non-empty and consists only of characters
/// accepted by `is_valid_name_char`.
///
/// # Errors
///
/// Returns an error naming the problem if the bundle name is empty or
/// contains any other character.
fn validate_bundle(bundle: &str) -> anyhow::Result<()> {
    if bundle.is_empty() {
        anyhow::bail!("empty bundle name in recall query");
    }
    if !bundle.chars().all(is_valid_name_char) {
        anyhow::bail!(
            "bundle '{}' contains invalid characters (only alphanumeric, -, _ allowed)",
            bundle
        );
    }
    Ok(())
}
// Unit and property tests for hook-event parsing, action dispatch, and
// tag/bundle name validation.
#[cfg(test)]
// unwrap/expect/panic! are used deliberately as assertion shortcuts in tests.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
// Every accepted event name maps to its variant.
#[test]
fn parses_neutral_event_names() {
assert_eq!(
"session_start".parse::<HookEvent>().unwrap(),
HookEvent::SessionStart
);
assert_eq!(
"turn_start".parse::<HookEvent>().unwrap(),
HookEvent::TurnStart
);
assert_eq!(
"session_end".parse::<HookEvent>().unwrap(),
HookEvent::SessionEnd
);
}
// Any other name is a parse error.
#[test]
fn rejects_unknown_event() {
assert!("nope".parse::<HookEvent>().is_err());
}
// With no tag/bundle queries each event yields exactly its base action.
#[test]
fn dispatch_maps_events_to_actions() {
assert_eq!(
dispatch(HookEvent::SessionStart, &[], &[]),
vec![Action::WakeUp]
);
assert_eq!(
dispatch(HookEvent::TurnStart, &[], &[]),
vec![Action::Recall]
);
assert_eq!(
dispatch(HookEvent::SessionEnd, &[], &[]),
vec![Action::Store]
);
}
// TurnStart appends one RecallTag per tag, in input order, with the
// `llmenv-tag:` keyword prefix.
#[test]
fn turn_start_expands_one_recall_tag_per_active_tag() {
let tags = vec!["rust".to_string(), "work-vpn".to_string()];
let queries = tag_recall_queries(&tags).expect("valid tags");
let actions = dispatch(HookEvent::TurnStart, &queries, &[]);
assert_eq!(
actions,
vec![
Action::Recall,
Action::RecallTag(TagRecallQuery {
tag: "rust".to_string(),
keyword: "llmenv-tag:rust".to_string(),
}),
Action::RecallTag(TagRecallQuery {
tag: "work-vpn".to_string(),
keyword: "llmenv-tag:work-vpn".to_string(),
}),
],
"TurnStart must run project recall then one tag recall per active tag"
);
}
// TurnStart appends one RecallBundle per bundle, in input order, with the
// `llmenv-bundle:` keyword prefix.
#[test]
fn turn_start_expands_one_recall_bundle_per_active_bundle() {
let bundles = vec!["base".to_string(), "rust-defaults".to_string()];
let queries = bundle_recall_queries(&bundles).expect("valid bundles");
let actions = dispatch(HookEvent::TurnStart, &[], &queries);
assert_eq!(
actions,
vec![
Action::Recall,
Action::RecallBundle(BundleRecallQuery {
bundle: "base".to_string(),
keyword: "llmenv-bundle:base".to_string(),
}),
Action::RecallBundle(BundleRecallQuery {
bundle: "rust-defaults".to_string(),
keyword: "llmenv-bundle:rust-defaults".to_string(),
}),
],
"TurnStart must emit one bundle recall per active bundle"
);
}
// Ordering is Recall, then all tag recalls, then all bundle recalls.
#[test]
fn turn_start_interleaves_tag_and_bundle_recalls() {
let tag_qs = tag_recall_queries(&["rust".to_string()]).expect("valid");
let bundle_qs = bundle_recall_queries(&["base".to_string()]).expect("valid");
let actions = dispatch(HookEvent::TurnStart, &tag_qs, &bundle_qs);
assert_eq!(actions[0], Action::Recall);
assert!(matches!(actions[1], Action::RecallTag(_)));
assert!(matches!(actions[2], Action::RecallBundle(_)));
assert_eq!(actions.len(), 3);
}
// Letters, digits, '-' and '_' are accepted in any combination.
#[test]
fn validate_tag_accepts_valid_tags() {
assert!(validate_tag("base").is_ok());
assert!(validate_tag("rust-lang").is_ok());
assert!(validate_tag("work_project").is_ok());
assert!(validate_tag("tag123").is_ok());
assert!(validate_tag("my-tag_123").is_ok());
}
// The empty string is never a valid tag.
#[test]
fn validate_tag_rejects_empty() {
assert!(validate_tag("").is_err());
}
// Punctuation, whitespace and quotes are rejected.
#[test]
fn validate_tag_rejects_special_characters() {
assert!(validate_tag("tag:space").is_err());
assert!(validate_tag("tag space").is_err());
assert!(validate_tag("tag/path").is_err());
assert!(validate_tag("tag.dot").is_err());
assert!(validate_tag("tag@at").is_err());
assert!(validate_tag("tag#hash").is_err());
assert!(validate_tag("tag$dollar").is_err());
assert!(validate_tag("tag\"quote").is_err());
}
// Tags cannot inject extra terms into the comma-joined recall query.
#[test]
fn validate_tag_rejects_query_injection_attempts() {
assert!(validate_tag("tag,malicious").is_err());
assert!(validate_tag("tag OR other").is_err());
assert!(validate_tag("tag AND other").is_err());
}
// A tag and a bundle sharing a name still produce distinct keywords.
#[test]
fn dispatch_tag_and_bundle_with_same_name_produce_distinct_recalls() {
let tag_qs = tag_recall_queries(&["foo".to_string()]).expect("valid");
let bundle_qs = bundle_recall_queries(&["foo".to_string()]).expect("valid");
let actions = dispatch(HookEvent::TurnStart, &tag_qs, &bundle_qs);
assert_eq!(actions.len(), 3);
match &actions[1] {
Action::RecallTag(q) => assert_eq!(q.keyword, "llmenv-tag:foo"),
other => panic!("expected RecallTag, got {other:?}"),
}
match &actions[2] {
Action::RecallBundle(q) => assert_eq!(q.keyword, "llmenv-bundle:foo"),
other => panic!("expected RecallBundle, got {other:?}"),
}
}
use proptest::prelude::*;
// Generator for names that satisfy both validators (1..=24 allowed chars).
fn valid_name() -> impl Strategy<Value = String> {
"[a-zA-Z0-9_-]{1,24}"
}
proptest! {
// For any valid tag/bundle lists, TurnStart yields exactly
// 1 + tags + bundles actions in Recall / tags / bundles order.
#[test]
fn prop_dispatch_turn_start_ordering(
tags in proptest::collection::vec(valid_name(), 0..8),
bundles in proptest::collection::vec(valid_name(), 0..8),
) {
let tag_qs = tag_recall_queries(&tags).expect("valid tags");
let bundle_qs = bundle_recall_queries(&bundles).expect("valid bundles");
let actions = dispatch(HookEvent::TurnStart, &tag_qs, &bundle_qs);
prop_assert_eq!(actions.len(), 1 + tags.len() + bundles.len());
prop_assert!(matches!(actions[0], Action::Recall));
// `1..=0` is an empty range, so zero tags is handled here.
for a in &actions[1..=tags.len()] {
prop_assert!(matches!(a, Action::RecallTag(_)), "expected RecallTag, got {a:?}");
}
for a in &actions[1 + tags.len()..] {
prop_assert!(
matches!(a, Action::RecallBundle(_)),
"expected RecallBundle, got {a:?}"
);
}
}
}
}