mod action;
mod mcp_client;
use std::str::FromStr;
use std::time::Duration;
use action::Action;
use mcp_client::McpHttpClient;
use crate::adapter::AgentAdapter;
use crate::adapter::claude_code::ClaudeCodeAdapter;
use crate::mcp::resolve::MEMORY_MCP_NAME;
use crate::mcp::resolve::{ResolvedKind, resolve_mcps};
/// A validated per-tag memory recall request.
///
/// Built by [`tag_recall_queries`]; `tag` is the active scope tag it was derived
/// from and `keyword` is the search keyword handed to the memory backend
/// (e.g. `llmenv-tag:rust` for the tag `rust`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TagRecallQuery {
    /// The scope tag, exactly as it appears in the active scope set.
    pub tag: String,
    /// Keyword the recall action searches memory for.
    pub keyword: String,
}
/// Builds one [`TagRecallQuery`] per tag, preserving input order.
///
/// # Errors
///
/// Returns the error from `validate_tag` for the first tag that is empty or
/// contains characters other than ASCII alphanumerics, `-` and `_`; no queries
/// are returned in that case.
pub fn tag_recall_queries(tags: &[String]) -> anyhow::Result<Vec<TagRecallQuery>> {
    let mut queries = Vec::with_capacity(tags.len());
    for tag in tags {
        // Validate before building the keyword so malformed tags never reach
        // the memory backend query.
        validate_tag(tag)?;
        queries.push(TagRecallQuery {
            tag: tag.clone(),
            keyword: action::tag_keyword(tag),
        });
    }
    Ok(queries)
}
/// A validated per-bundle memory recall request.
///
/// Built by [`bundle_recall_queries`]; `bundle` is the enabled bundle name and
/// `keyword` is the search keyword handed to the memory backend
/// (e.g. `llmenv-bundle:base` for the bundle `base`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BundleRecallQuery {
    /// The bundle name as enabled by an active scope.
    pub bundle: String,
    /// Keyword the recall action searches memory for.
    pub keyword: String,
}
/// Builds one [`BundleRecallQuery`] per bundle name, preserving input order.
///
/// # Errors
///
/// Returns the error from `validate_bundle` for the first name that is empty or
/// contains characters other than ASCII alphanumerics, `-` and `_`; no queries
/// are returned in that case.
pub fn bundle_recall_queries(bundles: &[String]) -> anyhow::Result<Vec<BundleRecallQuery>> {
    let mut queries = Vec::with_capacity(bundles.len());
    for bundle in bundles {
        // Reject malformed names up front so they never become part of a query.
        validate_bundle(bundle)?;
        queries.push(BundleRecallQuery {
            bundle: bundle.clone(),
            keyword: action::bundle_keyword(bundle),
        });
    }
    Ok(queries)
}
/// Timeout handed to [`McpHttpClient`] when talking to the memory backend
/// (two seconds). Kept short because hooks run inline with the agent session;
/// presumably applied per request — confirm in `mcp_client`.
const HOOK_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Agent-neutral lifecycle events the memory hook reacts to.
///
/// Parsed from the command-line names `session_start`, `turn_start` and
/// `session_end` via [`FromStr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
    /// A new agent session begins; dispatched to a wake-up action.
    SessionStart,
    /// A new user turn begins; dispatched to project, tag and bundle recalls.
    TurnStart,
    /// The agent session ends; dispatched to a store action.
    SessionEnd,
}
impl FromStr for HookEvent {
    type Err = anyhow::Error;

    /// Parses an agent-neutral event name.
    ///
    /// # Errors
    ///
    /// Returns an error listing the accepted names when `s` is not one of
    /// `session_start`, `turn_start` or `session_end`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Single source of truth for the name <-> variant mapping.
        const KNOWN: [(&str, HookEvent); 3] = [
            ("session_start", HookEvent::SessionStart),
            ("turn_start", HookEvent::TurnStart),
            ("session_end", HookEvent::SessionEnd),
        ];
        KNOWN
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, event)| event)
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "unknown hook event '{s}' (expected session_start|turn_start|session_end)"
                )
            })
    }
}
/// Maps a hook event to the ordered list of memory actions to run.
///
/// * `SessionStart` → a single wake-up.
/// * `SessionEnd` → a single store.
/// * `TurnStart` → the project-level recall first, then one tag recall per
///   entry of `tag_queries`, then one bundle recall per entry of
///   `bundle_queries`, each group in the order given.
fn dispatch(
    event: HookEvent,
    tag_queries: &[TagRecallQuery],
    bundle_queries: &[BundleRecallQuery],
) -> Vec<Action> {
    match event {
        HookEvent::SessionStart => vec![Action::WakeUp],
        HookEvent::SessionEnd => vec![Action::Store],
        HookEvent::TurnStart => std::iter::once(Action::Recall)
            .chain(tag_queries.iter().cloned().map(Action::RecallTag))
            .chain(bundle_queries.iter().cloned().map(Action::RecallBundle))
            .collect(),
    }
}
pub fn run(event: &str) -> anyhow::Result<()> {
let parsed = match HookEvent::from_str(event) {
Ok(e) => e,
Err(e) => {
eprintln!("llmenv: {e}");
return Ok(());
}
};
match run_inner(parsed) {
Ok(text) => {
let out = ClaudeCodeAdapter.emit_hook_context(&text);
if !out.is_empty() {
println!("{out}");
}
}
Err(e) => {
eprintln!("llmenv: memory {event} skipped: {e}");
}
}
Ok(())
}
fn run_inner(event: HookEvent) -> anyhow::Result<String> {
let config_path = crate::paths::config_path()?;
let config = crate::config::Config::load(&config_path)?;
let env = crate::scope::matcher::Env::detect();
let active = crate::scope::evaluate(&config, &env);
let config_dir = config_path
.parent()
.ok_or_else(|| anyhow::anyhow!("config path has no parent"))?;
let url = memory_url(&config, config_dir, &active)?
.ok_or_else(|| anyhow::anyhow!("no memory backend active for this scope"))?;
let mut tags = active.tags.iter().cloned().collect::<Vec<_>>();
tags.sort();
let bundles: Vec<String> = {
let mut set = std::collections::BTreeSet::new();
for scope in &active.scopes {
for b in &scope.enable_bundles {
set.insert(b.clone());
}
}
set.into_iter().collect()
};
let tag_queries = tag_recall_queries(&tags)?;
let bundle_queries = bundle_recall_queries(&bundles)?;
let query = tags.join(", ");
let chunk = crate::icm::generate_context_chunk(&active, &bundles);
let client = McpHttpClient::new(url, HOOK_TIMEOUT)
.map_err(|e| anyhow::anyhow!("invalid memory backend URL: {e}"))?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
rt.block_on(async {
let mut out = String::new();
for action in dispatch(event, &tag_queries, &bundle_queries) {
let text = action.run(&client, &query, &chunk).await?;
if !text.is_empty() {
if !out.is_empty() {
out.push_str("\n\n");
}
out.push_str(&text);
}
}
Ok::<String, anyhow::Error>(out)
})
}
/// Resolves the URL of the remote memory MCP server for the active scopes.
///
/// Memory entries come from two sources:
///
/// * the top-level `config.features.memory` list;
/// * the merged capabilities of every bundle that fires for this scope. A
///   bundle fires when one of its `when` tags is active or a scope enables it
///   by name, and its directory must exist under `<config_dir>/bundles`.
///
/// Both lists are concatenated and de-duplicated. Host entries from bundles are
/// overlaid with `config.host`, so top-level entries win on key conflicts. The
/// result is passed to [`resolve_mcps`].
///
/// Returns `Ok(None)` when no remote MCP named [`MEMORY_MCP_NAME`] resolves.
///
/// # Errors
///
/// Fails if MCP resolution fails. A failure to merge bundle capabilities is not
/// fatal. It is reported on stderr and the bundle-provided memory/host settings
/// are ignored, so top-level configuration still applies.
fn memory_url(
    config: &crate::config::Config,
    config_dir: &std::path::Path,
    active: &crate::scope::ActiveScopes,
) -> anyhow::Result<Option<String>> {
    let top_memory = config
        .features
        .as_ref()
        .map(|f| f.memory.as_slice())
        .unwrap_or_default();
    let manually_enabled: std::collections::BTreeSet<&str> = active
        .scopes
        .iter()
        .flat_map(|s| s.enable_bundles.iter().map(String::as_str))
        .collect();
    let firing: Vec<&crate::config::Bundle> = config
        .bundle
        .iter()
        .filter(|b| {
            b.when.iter().any(|bt| active.tags.contains(bt))
                || manually_enabled.contains(b.name.as_str())
        })
        .collect();
    let bundle_refs = build_hook_bundle_refs(config_dir, &firing);
    let (bundle_memory, bundle_host) = if bundle_refs.is_empty() {
        (Vec::new(), std::collections::BTreeMap::new())
    } else {
        // The hook is best-effort, so a broken bundle must not abort it. But
        // swallowing the error silently hides the real cause: the user may only
        // see "no memory backend active". Report it, then fall back to the
        // top-level configuration.
        let merged = crate::merge::merge(&config.capabilities, &config.native, &bundle_refs)
            .unwrap_or_else(|e| {
                eprintln!("llmenv: ignoring bundle capabilities for memory lookup: {e}");
                Default::default()
            });
        let mem = merged
            .capabilities
            .features
            .map(|f| f.memory)
            .unwrap_or_default();
        (mem, merged.capabilities.host)
    };
    let mut all_memory: Vec<crate::config::Memory> = top_memory
        .iter()
        .chain(bundle_memory.iter())
        .cloned()
        .collect();
    crate::util::dedup(&mut all_memory);
    // Top-level host entries are inserted last so they override bundle values.
    let mut all_host = bundle_host;
    for (k, v) in &config.host {
        all_host.insert(k.clone(), v.clone());
    }
    let resolved = resolve_mcps(&config.mcp, &all_memory, &all_host, &active.tags)
        .map_err(|e| anyhow::anyhow!("failed to resolve MCP servers: {e}"))?;
    Ok(resolved.into_iter().find_map(|m| match m.kind {
        ResolvedKind::Remote { url, .. } if m.name == MEMORY_MCP_NAME => Some(url),
        _ => None,
    }))
}
/// Converts firing bundles into merge inputs, keeping only bundles whose
/// directory `<config_dir>/bundles/<name>` exists. Every reference gets
/// precedence `1`.
///
/// NOTE(review): `Path::join` replaces the base when its argument is absolute,
/// so this assumes bundle names are relative, plain names. Presumably config
/// validation enforces that; confirm.
fn build_hook_bundle_refs(
    config_dir: &std::path::Path,
    firing: &[&crate::config::Bundle],
) -> Vec<crate::merge::BundleRef> {
    let bundles_dir = config_dir.join("bundles");
    let mut refs = Vec::new();
    for bundle in firing {
        let path = bundles_dir.join(&bundle.name);
        // A bundle without an on-disk directory has nothing to merge.
        if !path.exists() {
            continue;
        }
        refs.push(crate::merge::BundleRef {
            name: bundle.name.clone(),
            path,
            precedence: 1,
        });
    }
    refs
}
/// Returns `true` when every character of `name` is an ASCII letter, an ASCII
/// digit, `-` or `_`.
///
/// The empty string trivially passes, so callers reject it separately.
fn has_only_name_chars(name: &str) -> bool {
    name.chars()
        .all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_'))
}

/// Checks that `tag` is safe to embed in a recall query.
///
/// # Errors
///
/// Fails when `tag` is empty or contains anything other than ASCII
/// alphanumerics, `-` and `_`. This also rules out `,` (the separator used when
/// tags are joined into a query) and `:` (used inside recall keywords).
fn validate_tag(tag: &str) -> anyhow::Result<()> {
    anyhow::ensure!(!tag.is_empty(), "empty tag in recall query");
    anyhow::ensure!(
        has_only_name_chars(tag),
        "tag '{}' contains invalid characters (only alphanumeric, -, _ allowed)",
        tag
    );
    Ok(())
}

/// Checks that `bundle` is safe to embed in a recall query.
///
/// # Errors
///
/// Fails when `bundle` is empty or contains anything other than ASCII
/// alphanumerics, `-` and `_`.
fn validate_bundle(bundle: &str) -> anyhow::Result<()> {
    anyhow::ensure!(!bundle.is_empty(), "empty bundle name in recall query");
    anyhow::ensure!(
        has_only_name_chars(bundle),
        "bundle '{}' contains invalid characters (only alphanumeric, -, _ allowed)",
        bundle
    );
    Ok(())
}
#[cfg(test)]
// Panicking helpers are the normal way for a test to fail; presumably the crate
// denies these lints outside tests (assumption — check the crate lint config).
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    // Unit and property tests for event-name parsing, event -> action dispatch,
    // and tag/bundle name validation. Nothing here touches config, network or
    // the async runtime.
    use super::*;

    // Every accepted command-line name maps to its `HookEvent` variant.
    #[test]
    fn parses_neutral_event_names() {
        assert_eq!(
            "session_start".parse::<HookEvent>().unwrap(),
            HookEvent::SessionStart
        );
        assert_eq!(
            "turn_start".parse::<HookEvent>().unwrap(),
            HookEvent::TurnStart
        );
        assert_eq!(
            "session_end".parse::<HookEvent>().unwrap(),
            HookEvent::SessionEnd
        );
    }

    // Unrecognised names are a parse error rather than a default variant.
    #[test]
    fn rejects_unknown_event() {
        assert!("nope".parse::<HookEvent>().is_err());
    }

    // With no tag/bundle queries, each event maps to exactly one action.
    #[test]
    fn dispatch_maps_events_to_actions() {
        assert_eq!(
            dispatch(HookEvent::SessionStart, &[], &[]),
            vec![Action::WakeUp]
        );
        assert_eq!(
            dispatch(HookEvent::TurnStart, &[], &[]),
            vec![Action::Recall]
        );
        assert_eq!(
            dispatch(HookEvent::SessionEnd, &[], &[]),
            vec![Action::Store]
        );
    }

    // TurnStart appends one `RecallTag` per tag (input order) after the project
    // recall; also pins the `llmenv-tag:<tag>` keyword format.
    #[test]
    fn turn_start_expands_one_recall_tag_per_active_tag() {
        let tags = vec!["rust".to_string(), "work-vpn".to_string()];
        let queries = tag_recall_queries(&tags).expect("valid tags");
        let actions = dispatch(HookEvent::TurnStart, &queries, &[]);
        assert_eq!(
            actions,
            vec![
                Action::Recall,
                Action::RecallTag(TagRecallQuery {
                    tag: "rust".to_string(),
                    keyword: "llmenv-tag:rust".to_string(),
                }),
                Action::RecallTag(TagRecallQuery {
                    tag: "work-vpn".to_string(),
                    keyword: "llmenv-tag:work-vpn".to_string(),
                }),
            ],
            "TurnStart must run project recall then one tag recall per active tag"
        );
    }

    // Same as above for bundles; pins the `llmenv-bundle:<name>` keyword format.
    #[test]
    fn turn_start_expands_one_recall_bundle_per_active_bundle() {
        let bundles = vec!["base".to_string(), "rust-defaults".to_string()];
        let queries = bundle_recall_queries(&bundles).expect("valid bundles");
        let actions = dispatch(HookEvent::TurnStart, &[], &queries);
        assert_eq!(
            actions,
            vec![
                Action::Recall,
                Action::RecallBundle(BundleRecallQuery {
                    bundle: "base".to_string(),
                    keyword: "llmenv-bundle:base".to_string(),
                }),
                Action::RecallBundle(BundleRecallQuery {
                    bundle: "rust-defaults".to_string(),
                    keyword: "llmenv-bundle:rust-defaults".to_string(),
                }),
            ],
            "TurnStart must emit one bundle recall per active bundle"
        );
    }

    // Despite the name, this checks grouping rather than interleaving: the
    // project recall comes first, then tag recalls, then bundle recalls.
    #[test]
    fn turn_start_interleaves_tag_and_bundle_recalls() {
        let tag_qs = tag_recall_queries(&["rust".to_string()]).expect("valid");
        let bundle_qs = bundle_recall_queries(&["base".to_string()]).expect("valid");
        let actions = dispatch(HookEvent::TurnStart, &tag_qs, &bundle_qs);
        assert_eq!(actions[0], Action::Recall);
        assert!(matches!(actions[1], Action::RecallTag(_)));
        assert!(matches!(actions[2], Action::RecallBundle(_)));
        assert_eq!(actions.len(), 3);
    }

    // ASCII alphanumerics, `-` and `_` in any combination are accepted.
    #[test]
    fn validate_tag_accepts_valid_tags() {
        assert!(validate_tag("base").is_ok());
        assert!(validate_tag("rust-lang").is_ok());
        assert!(validate_tag("work_project").is_ok());
        assert!(validate_tag("tag123").is_ok());
        assert!(validate_tag("my-tag_123").is_ok());
    }

    #[test]
    fn validate_tag_rejects_empty() {
        assert!(validate_tag("").is_err());
    }

    // Punctuation and whitespace are rejected; `:` matters in particular because
    // it separates the prefix from the name in recall keywords.
    #[test]
    fn validate_tag_rejects_special_characters() {
        assert!(validate_tag("tag:space").is_err());
        assert!(validate_tag("tag space").is_err());
        assert!(validate_tag("tag/path").is_err());
        assert!(validate_tag("tag.dot").is_err());
        assert!(validate_tag("tag@at").is_err());
        assert!(validate_tag("tag#hash").is_err());
        assert!(validate_tag("tag$dollar").is_err());
        assert!(validate_tag("tag\"quote").is_err());
    }

    // Commas and spaces would let a tag alter the ", "-joined recall query.
    #[test]
    fn validate_tag_rejects_query_injection_attempts() {
        assert!(validate_tag("tag,malicious").is_err());
        assert!(validate_tag("tag OR other").is_err());
        assert!(validate_tag("tag AND other").is_err());
    }

    // A tag and a bundle sharing a name must still yield distinct keywords.
    #[test]
    fn dispatch_tag_and_bundle_with_same_name_produce_distinct_recalls() {
        let tag_qs = tag_recall_queries(&["foo".to_string()]).expect("valid");
        let bundle_qs = bundle_recall_queries(&["foo".to_string()]).expect("valid");
        let actions = dispatch(HookEvent::TurnStart, &tag_qs, &bundle_qs);
        assert_eq!(actions.len(), 3);
        match &actions[1] {
            Action::RecallTag(q) => assert_eq!(q.keyword, "llmenv-tag:foo"),
            other => panic!("expected RecallTag, got {other:?}"),
        }
        match &actions[2] {
            Action::RecallBundle(q) => assert_eq!(q.keyword, "llmenv-bundle:foo"),
            other => panic!("expected RecallBundle, got {other:?}"),
        }
    }

    use proptest::prelude::*;

    // Generates names that pass validation: 1 to 24 characters drawn from
    // ASCII alphanumerics, `_` and `-`.
    fn valid_name() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9_-]{1,24}"
    }

    proptest! {
        // For any mix of valid tags and bundles (including none), TurnStart
        // yields exactly one project recall, then every tag recall, then every
        // bundle recall.
        #[test]
        fn prop_dispatch_turn_start_ordering(
            tags in proptest::collection::vec(valid_name(), 0..8),
            bundles in proptest::collection::vec(valid_name(), 0..8),
        ) {
            let tag_qs = tag_recall_queries(&tags).expect("valid tags");
            let bundle_qs = bundle_recall_queries(&bundles).expect("valid bundles");
            let actions = dispatch(HookEvent::TurnStart, &tag_qs, &bundle_qs);
            prop_assert_eq!(actions.len(), 1 + tags.len() + bundles.len());
            prop_assert!(matches!(actions[0], Action::Recall));
            for a in &actions[1..=tags.len()] {
                prop_assert!(matches!(a, Action::RecallTag(_)), "expected RecallTag, got {a:?}");
            }
            for a in &actions[1 + tags.len()..] {
                prop_assert!(
                    matches!(a, Action::RecallBundle(_)),
                    "expected RecallBundle, got {a:?}"
                );
            }
        }
    }
}