use anyhow::Result;
use serde_json::Value;
use std::time::{Duration, Instant};
use crate::hook_emit::{post_hook_event, HookEventPayload};
use crate::prompt_log::{PromptLogEntry, PromptLogger};
use crate::{hook_prompt_excerpt, HookType, InjectionKind};
/// Daemon endpoint whose text body holds the global prompt facts.
const PROMPT_CONTEXT_PATH: &str = "/api/v1/kg/prompt-context";
/// Per-palace recall endpoint; `{slug}` is replaced with the palace slug.
const PALACE_RECALL_PATH: &str = "/api/v1/palaces/{slug}/recall";
/// Per-palace endpoint listing knowledge-graph triples; `{slug}` is replaced
/// with the palace slug.
const PALACE_KG_ALL_PATH: &str = "/api/v1/palaces/{slug}/kg/all";
/// Used as both the request timeout and the connect timeout of the HTTP
/// client that talks to the daemon.
const HTTP_TIMEOUT: Duration = Duration::from_millis(2500);
/// How many drawers / KG triples to surface when `ENV_TOP_K` is unset or
/// unparsable.
const DEFAULT_TOP_K: usize = 5;
/// Maximum size in bytes of the composed injection; longer output is cut at
/// a char boundary and terminated with `…`.
const INJECTION_BYTE_CAP: usize = 4 * 1024;
/// Maximum length in characters (not bytes) of one drawer preview, counting
/// the trailing `…` added when the preview is shortened.
const DRAWER_PREVIEW_CHARS: usize = 220;
/// Env var overriding the top-k; parsed as an integer and clamped to 1..=20.
pub const ENV_TOP_K: &str = "TRUSTY_MEMORY_PROMPT_TOP_K";
/// Env var holding a comma-separated deny list of drawer tags. When it yields
/// at least one non-empty tag it *replaces* `DEFAULT_DENY_TAGS` (it does not
/// extend it).
pub const ENV_RECALL_DENY_TAGS: &str = "TRUSTY_MEMORY_PROMPT_RECALL_DENY_TAGS";
/// Drawer tags excluded from recall when `ENV_RECALL_DENY_TAGS` is not set.
const DEFAULT_DENY_TAGS: &[&str] = &["claude-session", "user-prompt"];
/// Body injected when nothing was found. A global-facts response equal to this
/// text is treated as empty.
const EMPTY_PLACEHOLDER: &str = "No prompt facts stored yet.";
/// Entry point of the `UserPromptSubmit` prompt-context hook.
///
/// Reads the hook payload from stdin and prints the injection body to stdout.
/// The output is always terminated by exactly one trailing newline that the
/// body did not already supply. A hook telemetry event is posted afterwards.
/// Every failure inside degrades to an empty or placeholder body, so this
/// always returns `Ok(())`.
pub async fn handle_prompt_context() -> Result<()> {
    let started = Instant::now();
    let payload = read_stdin_best_effort();
    let injection = build_injection_body(&payload).await;

    // Append a newline only when the body does not already end with one.
    let terminator = if injection.ends_with('\n') { "" } else { "\n" };
    print!("{injection}{terminator}");

    emit_hook_event(&payload, &injection, started).await;
    Ok(())
}
/// Posts a `UserPromptSubmit` / `PromptContext` hook event describing this
/// invocation.
///
/// `trigger_payload` is the raw stdin payload. The palace slug resolved from it
/// is used as both the palace id and the palace name. `injection` contributes
/// only its byte length, and the duration is measured from `start`.
async fn emit_hook_event(trigger_payload: &str, injection: &str, start: Instant) {
    let prompt_text = parse_user_prompt(trigger_payload);
    let slug = resolve_palace_slug(trigger_payload);
    let event = HookEventPayload {
        palace_name: slug.clone(),
        palace_id: slug,
        hook_type: HookType::UserPromptSubmit,
        injection_kind: InjectionKind::PromptContext,
        injection_length: injection.len() as u64,
        trigger_prompt_excerpt: hook_prompt_excerpt(&prompt_text),
        duration_ms: start.elapsed().as_millis() as u64,
    };
    post_hook_event(event).await;
}
pub(crate) async fn build_injection_body(trigger_payload: &str) -> String {
let start = Instant::now();
let user_prompt = parse_user_prompt(trigger_payload);
let addr = match trusty_common::read_daemon_addr("trusty-memory") {
Ok(Some(addr)) => addr,
Ok(None) | Err(_) => {
log_entry(trigger_payload, "", 0, start);
return String::new();
}
};
let base = if addr.starts_with("http://") || addr.starts_with("https://") {
addr
} else {
format!("http://{addr}")
};
let client = match reqwest::Client::builder()
.timeout(HTTP_TIMEOUT)
.connect_timeout(HTTP_TIMEOUT)
.build()
{
Ok(c) => c,
Err(_) => {
log_entry(trigger_payload, "", 0, start);
return String::new();
}
};
let palace_slug = resolve_palace_slug(trigger_payload);
let global_facts = fetch_global_prompt_context(&client, &base).await;
let (drawers, kg_triples) = match &palace_slug {
Some(slug) => {
let top_k = configured_top_k();
let drawers_fut = fetch_palace_recall(&client, &base, slug, &user_prompt, top_k);
let kg_fut = fetch_palace_kg_triples(&client, &base, slug);
let (drawers, kg_all) = tokio::join!(drawers_fut, kg_fut);
let deny_tags = configured_deny_tags();
let drawers = filter_drawers_by_deny_tags(drawers, &deny_tags);
let kg_filtered = select_relevant_triples(&kg_all, &user_prompt, top_k);
(drawers, kg_filtered)
}
None => (Vec::new(), Vec::new()),
};
let composed = compose_injection(
global_facts.as_deref(),
&drawers,
&kg_triples,
palace_slug.as_deref(),
);
let body = if composed.is_empty() {
EMPTY_PLACEHOLDER.to_string()
} else {
composed
};
let facts_count = count_facts(&body);
log_entry(trigger_payload, &body, facts_count, start);
body
}
/// Reads the hook payload from stdin, capped at 64 KiB, without ever failing.
///
/// Returns an empty string when stdin is an interactive terminal (nothing
/// piped in). Read errors are ignored and the bytes received so far are kept.
///
/// Decoding is lenient:
/// - a multi-byte character cut in half at the end (e.g. by the cap) is dropped;
/// - any other invalid UTF-8 becomes U+FFFD.
///
/// `read_to_string`, by contrast, would discard the whole payload on the first
/// bad byte.
fn read_stdin_best_effort() -> String {
    use std::io::Read;
    const STDIN_CAP_BYTES: usize = 64 * 1024;
    let stdin = std::io::stdin();
    if std::io::IsTerminal::is_terminal(&stdin) {
        return String::new();
    }
    let mut bytes = Vec::new();
    // Best-effort: on an I/O error `read_to_end` keeps the bytes read so far.
    let _ = stdin
        .lock()
        .take(STDIN_CAP_BYTES as u64)
        .read_to_end(&mut bytes);
    match String::from_utf8(bytes) {
        Ok(text) => text,
        Err(err) => {
            let utf8_err = err.utf8_error();
            let mut bytes = err.into_bytes();
            if utf8_err.error_len().is_none() {
                // The only defect is an incomplete character at the very end
                // (typically the byte cap landing mid-character): drop it.
                bytes.truncate(utf8_err.valid_up_to());
            }
            String::from_utf8_lossy(&bytes).into_owned()
        }
    }
}
fn parse_user_prompt(stdin_payload: &str) -> String {
if stdin_payload.trim().is_empty() {
return String::new();
}
if let Ok(value) = serde_json::from_str::<Value>(stdin_payload) {
if let Some(p) = value.get("prompt").and_then(|v| v.as_str()) {
return p.trim().to_string();
}
}
stdin_payload.trim().to_string()
}
/// Returns the number of drawers / triples to surface.
///
/// Reads `ENV_TOP_K` as an integer clamped to 1..=20. Falls back to
/// `DEFAULT_TOP_K` when the variable is unset, not valid Unicode, or
/// unparsable.
fn configured_top_k() -> usize {
    match std::env::var(ENV_TOP_K) {
        Ok(raw) => raw
            .trim()
            .parse::<usize>()
            .map_or(DEFAULT_TOP_K, |k| k.clamp(1, 20)),
        Err(_) => DEFAULT_TOP_K,
    }
}
/// Returns the lowercased list of drawer tags to exclude from recall.
///
/// A comma-separated `ENV_RECALL_DENY_TAGS` that yields at least one
/// non-blank tag replaces the defaults entirely. Otherwise `DEFAULT_DENY_TAGS`
/// is used.
fn configured_deny_tags() -> Vec<String> {
    let from_env: Vec<String> = std::env::var(ENV_RECALL_DENY_TAGS)
        .map(|raw| {
            raw.split(',')
                .map(|tag| tag.trim().to_lowercase())
                .filter(|tag| !tag.is_empty())
                .collect()
        })
        .unwrap_or_default();
    if from_env.is_empty() {
        DEFAULT_DENY_TAGS
            .iter()
            .map(|tag| tag.to_lowercase())
            .collect()
    } else {
        from_env
    }
}
/// Drops every drawer that carries at least one tag from `deny_tags`.
///
/// Matching is case-insensitive, with both sides folded by `str::to_lowercase`.
/// That is the same normalisation `configured_deny_tags` applies, so non-ASCII
/// tags fold identically on both sides. Drawers without tags are always kept,
/// and an empty deny list returns `drawers` unchanged.
fn filter_drawers_by_deny_tags(
    drawers: Vec<RecalledDrawer>,
    deny_tags: &[String],
) -> Vec<RecalledDrawer> {
    if deny_tags.is_empty() {
        return drawers;
    }
    let denied: Vec<String> = deny_tags.iter().map(|tag| tag.to_lowercase()).collect();
    drawers
        .into_iter()
        .filter(|drawer| {
            !drawer
                .tags
                .iter()
                .any(|tag| denied.contains(&tag.to_lowercase()))
        })
        .collect()
}
/// Fetches the global prompt-facts text from the daemon.
///
/// Returns `None` on any transport error or non-2xx status, or when the body
/// is blank or equals `EMPTY_PLACEHOLDER`. Otherwise returns the body exactly
/// as received (untrimmed).
async fn fetch_global_prompt_context(client: &reqwest::Client, base: &str) -> Option<String> {
    let endpoint = format!("{}{}", base, PROMPT_CONTEXT_PATH);
    let response = client.get(&endpoint).send().await.ok()?;
    if !response.status().is_success() {
        return None;
    }
    let text = response.text().await.ok()?;
    let is_blank = {
        let core = text.trim();
        core.is_empty() || core == EMPTY_PLACEHOLDER
    };
    (!is_blank).then_some(text)
}
/// Runs a recall query for `prompt` against palace `palace`.
///
/// Returns at most `top_k` drawers, keeping only entries whose `layer` is
/// present and greater than zero. Any failure yields an empty vector: a blank
/// prompt, a transport error, a non-2xx status, or a body that is not a JSON
/// array.
async fn fetch_palace_recall(
    client: &reqwest::Client,
    base: &str,
    palace: &str,
    prompt: &str,
    top_k: usize,
) -> Vec<RecalledDrawer> {
    if prompt.is_empty() {
        return Vec::new();
    }
    let endpoint = format!("{}{}", base, PALACE_RECALL_PATH.replace("{slug}", palace));
    let params = [("q", prompt.to_string()), ("top_k", top_k.to_string())];
    let Ok(response) = client.get(&endpoint).query(&params).send().await else {
        return Vec::new();
    };
    if !response.status().is_success() {
        return Vec::new();
    }
    let Ok(payload) = response.json::<Value>().await else {
        return Vec::new();
    };
    match payload.as_array() {
        Some(entries) => entries
            .iter()
            .filter_map(RecalledDrawer::from_recall_entry)
            .filter(|drawer| drawer.layer.unwrap_or(0) > 0)
            .take(top_k)
            .collect(),
        None => Vec::new(),
    }
}
/// Fetches up to 200 KG triples for palace `palace`.
///
/// Entries missing a string `subject`, `predicate`, or `object` are skipped.
/// Any transport, status, or decoding failure yields an empty vector.
async fn fetch_palace_kg_triples(
    client: &reqwest::Client,
    base: &str,
    palace: &str,
) -> Vec<RawTriple> {
    let endpoint = format!("{}{}", base, PALACE_KG_ALL_PATH.replace("{slug}", palace));
    let response = match client.get(&endpoint).query(&[("limit", "200")]).send().await {
        Ok(resp) if resp.status().is_success() => resp,
        _ => return Vec::new(),
    };
    let Ok(payload) = response.json::<Value>().await else {
        return Vec::new();
    };
    payload
        .as_array()
        .map(|entries| entries.iter().filter_map(RawTriple::from_value).collect())
        .unwrap_or_default()
}
/// Lowercases `text` and splits it into keyword tokens: maximal runs of
/// alphanumeric characters that are at least three bytes long.
///
/// Both the prompt and the triples are tokenised with this one function.
/// Otherwise a prompt naming a hyphenated entity (`trusty-git-analytics`)
/// would stay one token and never match the same entity split apart on the
/// triple side.
fn keyword_tokens(text: &str) -> Vec<String> {
    const MIN_TOKEN_BYTES: usize = 3;
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| token.len() >= MIN_TOKEN_BYTES)
        .map(str::to_owned)
        .collect()
}

/// Picks, in input order, up to `top_k` triples whose subject or object shares
/// a keyword token with `prompt`.
///
/// Returns an empty vector when the prompt has no keyword tokens or when
/// `top_k` is zero.
fn select_relevant_triples(triples: &[RawTriple], prompt: &str, top_k: usize) -> Vec<RawTriple> {
    use std::collections::HashSet;
    let words: HashSet<String> = keyword_tokens(prompt).into_iter().collect();
    if words.is_empty() {
        return Vec::new();
    }
    triples
        .iter()
        .filter(|t| triple_overlaps(t, &words))
        .take(top_k)
        .cloned()
        .collect()
}

/// Reports whether any keyword token of the triple's subject or object is in
/// `prompt_words`. The predicate is not considered.
fn triple_overlaps(t: &RawTriple, prompt_words: &std::collections::HashSet<String>) -> bool {
    [t.subject.as_str(), t.object.as_str()]
        .into_iter()
        .flat_map(keyword_tokens)
        .any(|token| prompt_words.contains(&token))
}
/// Assembles the Markdown injection from its three sources, in this order:
/// 1. the global facts (trailing whitespace trimmed);
/// 2. a "Relevant memories" bullet list of drawer previews with their tags;
/// 3. a "Relevant KG facts" bullet list.
///
/// Sections are separated by a blank line, and empty sources are skipped. If
/// the result exceeds `INJECTION_BYTE_CAP` bytes, it is cut at a char boundary
/// and ends with `…`, keeping the total within the cap.
fn compose_injection(
    global_facts: Option<&str>,
    drawers: &[RecalledDrawer],
    triples: &[RawTriple],
    palace_slug: Option<&str>,
) -> String {
    let mut injection = String::new();

    if let Some(facts) = global_facts {
        push_section(&mut injection, facts.trim_end());
    }

    if !drawers.is_empty() {
        let heading = match palace_slug {
            Some(slug) => format!("## Relevant memories from palace `{slug}`\n"),
            None => "## Relevant memories\n".to_string(),
        };
        let bullets: String = drawers
            .iter()
            .map(|drawer| {
                let preview = drawer_preview(&drawer.content);
                if drawer.tags.is_empty() {
                    format!("- {preview}\n")
                } else {
                    let tag_list = drawer
                        .tags
                        .iter()
                        .map(|tag| format!("`{tag}`"))
                        .collect::<Vec<_>>()
                        .join(", ");
                    format!("- {preview} _(tags: {tag_list})_\n")
                }
            })
            .collect();
        let memories = heading + &bullets;
        push_section(&mut injection, memories.trim_end());
    }

    if !triples.is_empty() {
        let bullets: String = triples
            .iter()
            .map(|t| format!("- {} **{}** {}\n", t.subject, t.predicate, t.object))
            .collect();
        let kg_section = format!("## Relevant KG facts\n{bullets}");
        push_section(&mut injection, kg_section.trim_end());
    }

    if injection.len() > INJECTION_BYTE_CAP {
        let marker = '…';
        let budget = INJECTION_BYTE_CAP.saturating_sub(marker.len_utf8());
        // Largest char boundary not past the budget (0 is always a boundary).
        let cut = (0..=budget)
            .rev()
            .find(|&idx| injection.is_char_boundary(idx))
            .unwrap_or(0);
        injection.truncate(cut);
        injection.push(marker);
    }
    injection
}
/// Appends `section` to `out`, separated from prior content by one blank line.
///
/// An empty `section` is ignored. When `out` already ends with `\n`, only one
/// more newline is added; otherwise two are added.
fn push_section(out: &mut String, section: &str) {
    if section.is_empty() {
        return;
    }
    let separator = match (out.is_empty(), out.ends_with('\n')) {
        (true, _) => "",
        (false, true) => "\n",
        (false, false) => "\n\n",
    };
    out.push_str(separator);
    out.push_str(section);
}
/// Collapses all whitespace runs in `content` to single spaces and limits the
/// result to `DRAWER_PREVIEW_CHARS` characters.
///
/// A longer preview keeps `DRAWER_PREVIEW_CHARS - 1` characters followed by
/// `…`.
fn drawer_preview(content: &str) -> String {
    let mut collapsed = String::with_capacity(content.len());
    for word in content.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    if collapsed.chars().nth(DRAWER_PREVIEW_CHARS).is_none() {
        return collapsed;
    }
    let kept_chars = DRAWER_PREVIEW_CHARS.saturating_sub(1);
    let cut = collapsed
        .char_indices()
        .nth(kept_chars)
        .map_or(collapsed.len(), |(byte_idx, _)| byte_idx);
    format!("{}…", &collapsed[..cut])
}
/// Counts Markdown bullet lines (`- ` after optional leading whitespace) in
/// `body`.
fn count_facts(body: &str) -> usize {
    let mut bullets = 0;
    for line in body.lines() {
        if line.trim_start().starts_with("- ") {
            bullets += 1;
        }
    }
    bullets
}
/// Resolves the palace slug for this hook invocation.
///
/// Prefers the `cwd` field of the stdin payload. Falls back to
/// `crate::messaging::cwd_palace_slug()`, and returns `None` if both fail.
fn resolve_palace_slug(stdin_payload: &str) -> Option<String> {
    palace_slug_from_stdin_cwd(stdin_payload)
        .or_else(|| crate::messaging::cwd_palace_slug().ok())
}
/// Like `resolve_palace_slug`, but returns `"<unknown>"` when no slug can be
/// determined, for use in log entries.
fn resolve_palace_for_log(stdin_payload: &str) -> String {
    match resolve_palace_slug(stdin_payload) {
        Some(slug) => slug,
        None => String::from("<unknown>"),
    }
}
fn palace_slug_from_stdin_cwd(stdin_payload: &str) -> Option<String> {
if stdin_payload.trim().is_empty() {
return None;
}
let value: Value = serde_json::from_str(stdin_payload).ok()?;
let cwd = value.get("cwd")?.as_str()?;
if cwd.is_empty() {
return None;
}
crate::messaging::cwd_palace_slug_at(std::path::Path::new(cwd)).ok()
}
/// Records one prompt-context attempt through the env-configured
/// `PromptLogger`.
///
/// `trigger_prompt` is the raw stdin payload: it is logged as the trigger and
/// also used to resolve the palace. `facts_count` and the elapsed time since
/// `start` are attached to the entry.
fn log_entry(trigger_prompt: &str, injection: &str, facts_count: usize, start: Instant) {
    let sink = PromptLogger::from_env();
    let record = PromptLogEntry::new(
        "UserPromptSubmit",
        "prompt-context-facts",
        resolve_palace_for_log(trigger_prompt),
        trigger_prompt,
        injection,
    )
    .with_palace_facts_count(facts_count)
    .with_duration_ms(start.elapsed().as_millis() as u64);
    sink.log(record);
}
/// One drawer returned by the palace recall endpoint, reduced to the fields
/// the injection renders or filters on.
#[derive(Clone, Debug)]
struct RecalledDrawer {
    /// Drawer text; rendered through `drawer_preview`.
    content: String,
    /// Drawer tags; compared against the deny list and rendered after the
    /// preview.
    tags: Vec<String>,
    /// Layer reported by the daemon; recall keeps only drawers with a layer
    /// greater than zero.
    layer: Option<u8>,
}
impl RecalledDrawer {
fn from_recall_entry(v: &Value) -> Option<Self> {
let content = v.get("content")?.as_str()?.to_string();
let tags = v
.get("tags")
.and_then(|t| t.as_array())
.map(|arr| {
arr.iter()
.filter_map(|t| t.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let layer = v.get("layer").and_then(|l| l.as_u64()).map(|n| n as u8);
if content.trim().is_empty() {
return None;
}
Some(Self {
content,
tags,
layer,
})
}
}
/// A knowledge-graph triple as returned by the palace KG endpoint.
#[derive(Clone, Debug)]
struct RawTriple {
    /// Subject text; tokenised for prompt relevance.
    subject: String,
    /// Predicate text; rendered in bold, not used for relevance.
    predicate: String,
    /// Object text; tokenised for prompt relevance.
    object: String,
}
impl RawTriple {
fn from_value(v: &Value) -> Option<Self> {
let subject = v.get("subject")?.as_str()?.to_string();
let predicate = v.get("predicate")?.as_str()?.to_string();
let object = v.get("object")?.as_str()?.to_string();
Some(Self {
subject,
predicate,
object,
})
}
}
// Tests for the prompt-context hook. The `axum-server`-gated tests spin up an
// in-process daemon. Tests that mutate process-wide env vars first take
// `crate::commands::env_test_lock()` so they run one at a time.
#[cfg(test)]
mod tests {
use super::*;
// `json!` is only needed by the daemon-backed tests below.
#[cfg(feature = "axum-server")]
use serde_json::json;
// A JSON payload yields its `prompt` field. Anything else (including JSON
// without `prompt`) is returned verbatim after trimming.
#[test]
fn parse_user_prompt_prefers_prompt_field() {
let json_with_prompt = serde_json::json!({
"prompt": "what is rust?",
"cwd": "/tmp/example",
})
.to_string();
assert_eq!(parse_user_prompt(&json_with_prompt), "what is rust?");
let json_without_prompt = serde_json::json!({"cwd": "/tmp/example"}).to_string();
assert_eq!(parse_user_prompt(&json_without_prompt), json_without_prompt);
assert_eq!(parse_user_prompt("plain text query"), "plain text query");
assert_eq!(parse_user_prompt(""), "");
}
// Covers an empty deny list, case-insensitive matching, tagless drawers, and
// several deny tags at once.
#[test]
fn filter_drawers_by_deny_tags_handles_edge_cases() {
let make = |tags: &[&str]| RecalledDrawer {
content: "irrelevant".into(),
tags: tags.iter().map(|s| s.to_string()).collect(),
layer: Some(2),
};
let drawers = vec![make(&["claude-session"]), make(&["rust"])];
let out = filter_drawers_by_deny_tags(drawers.clone(), &[]);
assert_eq!(out.len(), 2, "empty deny list must pass everything");
let drawers = vec![make(&["Claude-Session"]), make(&["rust"])];
let out = filter_drawers_by_deny_tags(drawers, &["claude-session".to_string()]);
assert_eq!(out.len(), 1);
assert!(out[0].tags.iter().any(|t| t == "rust"));
let drawers = vec![make(&[]), make(&["user-prompt"])];
let out = filter_drawers_by_deny_tags(drawers, &["user-prompt".to_string()]);
assert_eq!(out.len(), 1, "tagless drawers must survive the filter");
assert!(out[0].tags.is_empty());
let drawers = vec![
make(&["claude-session"]),
make(&["user-prompt"]),
make(&["signal"]),
];
let out = filter_drawers_by_deny_tags(
drawers,
&["claude-session".to_string(), "user-prompt".to_string()],
);
assert_eq!(out.len(), 1);
assert_eq!(out[0].tags, vec!["signal".to_string()]);
}
// Only triples whose subject/object shares a keyword with the prompt are kept.
// A prompt with no overlapping keyword selects nothing.
#[test]
fn select_relevant_triples_filters_by_prompt_overlap() {
let triples = vec![
RawTriple {
subject: "tga".into(),
predicate: "is_alias_for".into(),
object: "trusty-git-analytics".into(),
},
RawTriple {
subject: "python".into(),
predicate: "is-a".into(),
object: "language".into(),
},
RawTriple {
subject: "rust".into(),
predicate: "is-a".into(),
object: "language".into(),
},
];
let chosen = select_relevant_triples(&triples, "tell me about rust integration", 5);
assert_eq!(chosen.len(), 1, "only the rust triple should match");
assert_eq!(chosen[0].subject, "rust");
let none = select_relevant_triples(&triples, "weather forecast next week", 5);
assert!(none.is_empty());
}
// The global block alone (500 × 12-byte lines) exceeds INJECTION_BYTE_CAP.
// The composed output must fit within the cap and end with the `…` marker.
#[test]
fn compose_injection_truncates_at_cap() {
let big_global = "## Big block\n".to_string() + &"- fact line\n".repeat(500);
let drawers: Vec<RecalledDrawer> = (0..5)
.map(|i| RecalledDrawer {
content: format!("drawer {i} content"),
tags: vec!["tag1".into()],
layer: Some(2),
})
.collect();
let triples: Vec<RawTriple> = (0..5)
.map(|i| RawTriple {
subject: format!("subject{i}"),
predicate: "p".into(),
object: "object".into(),
})
.collect();
let out = compose_injection(Some(&big_global), &drawers, &triples, Some("alpha"));
assert!(
out.len() <= INJECTION_BYTE_CAP,
"expected len <= cap; got {}",
out.len()
);
assert!(
out.ends_with('…'),
"expected `…` truncation marker; got tail: {}",
&out[out.len().saturating_sub(20)..]
);
}
// With no sources, compose_injection returns an empty string. The caller
// substitutes the placeholder.
#[test]
fn compose_injection_empty_inputs_yields_empty() {
let out = compose_injection(None, &[], &[], Some("alpha"));
assert!(out.is_empty(), "got: {out:?}");
}
// The `cwd` in the stdin payload takes precedence over the process cwd when
// resolving the palace slug for logs.
#[test]
fn resolve_palace_for_log_prefers_stdin_cwd() {
let tmp = tempfile::tempdir().expect("tempdir");
let project = tmp.path().join("stdin-driven-project");
std::fs::create_dir_all(&project).expect("create project dir");
let payload = serde_json::json!({
"hook_event_name": "UserPromptSubmit",
"cwd": project.to_string_lossy(),
"prompt": "hello"
})
.to_string();
let expected =
crate::messaging::cwd_palace_slug_at(&project).expect("derive slug from stdin cwd");
let got = resolve_palace_for_log(&payload);
assert_eq!(
got, expected,
"stdin `cwd` must override the process cwd for the log palace slug"
);
assert!(
got.contains("stdin-driven-project"),
"expected slug derived from stdin path, got {got:?}"
);
}
// Empty and non-JSON payloads both take the `cwd_palace_slug()` fallback, so
// they resolve identically. This assumes that fallback succeeds in the test
// environment rather than yielding "<unknown>".
#[test]
fn resolve_palace_for_log_falls_back_to_process_cwd() {
let from_empty = resolve_palace_for_log("");
let from_garbage = resolve_palace_for_log("not json at all");
assert_eq!(from_empty, from_garbage);
assert_ne!(from_empty, "<unknown>");
}
// Points the data dir at an empty tempdir, so presumably no daemon address is
// found. The hook must still return Ok(()).
#[tokio::test]
async fn prompt_context_returns_ok_without_daemon() {
let _guard = crate::commands::env_test_lock().lock().await;
let tmp = tempfile::tempdir().expect("tempdir");
// SAFETY: env mutation is serialized with other env-touching tests via
// `env_test_lock` (held in `_guard`). NOTE(review): tests that do not take
// the lock could still read the environment concurrently — confirm.
unsafe {
std::env::set_var(trusty_common::DATA_DIR_OVERRIDE_ENV, tmp.path());
}
let res = handle_prompt_context().await;
// SAFETY: same lock discipline as the set_var above.
unsafe {
std::env::remove_var(trusty_common::DATA_DIR_OVERRIDE_ENV);
}
assert!(
res.is_ok(),
"missing daemon lockfile must degrade to Ok(()), got {res:?}"
);
}
// End-to-end: three drawers are stored in a live palace, and a rust-related
// prompt must surface the rust drawer under a "Relevant memories" heading
// within 5 s.
// NOTE(review): if an assertion fails, `shutdown` is skipped, so the server
// task and the data-dir override outlive this test.
#[cfg(feature = "axum-server")]
#[tokio::test]
async fn prompt_context_recalls_palace_drawers() {
let _guard = crate::commands::env_test_lock().lock().await;
let (state, _data_dir_tmp, _project_dir_tmp, project_dir, slug, addr_handle) =
spin_up_test_daemon_with_palace("prompt-ctx-recall-pop").await;
for (text, tags) in [
(
"Rust integration uses tokio for async tasks and serde for JSON",
vec!["rust", "tokio"],
),
(
"Python bindings ship via PyO3 with custom ABI shims",
vec!["python", "pyo3"],
),
(
"Knowledge graph stores triples in redb with valid_from intervals",
vec!["kg", "redb"],
),
] {
let tags_json: Vec<Value> = tags.iter().map(|t| json!(t)).collect();
let _ = crate::tools::dispatch_tool(
&state,
"memory_remember",
json!({
"palace": slug,
"text": text,
"room": "General",
"tags": tags_json,
}),
)
.await
.expect("memory_remember");
}
let payload = json!({
"hook_event_name": "UserPromptSubmit",
"cwd": project_dir.to_string_lossy(),
"prompt": "how does rust integration work?"
})
.to_string();
let start = std::time::Instant::now();
let body = build_injection_body(&payload).await;
let elapsed_ms = start.elapsed().as_millis();
eprintln!("prompt_context_recalls_palace_drawers latency: {elapsed_ms}ms");
assert_ne!(
body, EMPTY_PLACEHOLDER,
"populated palace must return real content, not the placeholder"
);
assert!(
body.to_lowercase().contains("rust") && body.to_lowercase().contains("integration"),
"expected rust integration drawer in injection; got:\n{body}"
);
assert!(
body.contains("Relevant memories") || body.contains("memories from palace"),
"expected a `Relevant memories` section; got:\n{body}"
);
assert!(
elapsed_ms < 5_000,
"prompt-context too slow ({elapsed_ms}ms) — investigate"
);
addr_handle.shutdown().await;
}
// A freshly created palace with no drawers and no global facts must yield
// exactly the placeholder text.
#[cfg(feature = "axum-server")]
#[tokio::test]
async fn prompt_context_empty_palace_falls_back_to_global() {
let _guard = crate::commands::env_test_lock().lock().await;
let (_state, _data_dir_tmp, _project_dir_tmp, project_dir, _slug, addr_handle) =
spin_up_test_daemon_with_palace("prompt-ctx-recall-empty").await;
let payload = json!({
"hook_event_name": "UserPromptSubmit",
"cwd": project_dir.to_string_lossy(),
"prompt": "no drawers exist here"
})
.to_string();
let body = build_injection_body(&payload).await;
assert_eq!(
body, EMPTY_PLACEHOLDER,
"empty palace + empty prompt-facts must fall back to the placeholder"
);
addr_handle.shutdown().await;
}
// Test fixture. Returns, in order:
// - the AppState;
// - the data and project TempDirs (kept alive by the caller);
// - the project dir path and palace slug;
// - a handle that stops the in-process daemon.
// The data-dir override is pointed at the new data tempdir, and prompt-log env
// vars are cleared so logging uses its defaults.
#[cfg(feature = "axum-server")]
async fn spin_up_test_daemon_with_palace(
palace_slug: &str,
) -> (
crate::AppState,
tempfile::TempDir,
tempfile::TempDir,
std::path::PathBuf,
String,
DaemonHandle,
) {
let data_tmp = tempfile::tempdir().expect("data tempdir");
let project_tmp = tempfile::tempdir().expect("project tempdir");
let project_dir = project_tmp.path().join(palace_slug);
std::fs::create_dir_all(&project_dir).expect("project dir");
// SAFETY: callers hold `env_test_lock`, and the daemon task has not been
// spawned yet. NOTE(review): tests that skip the lock could still observe
// these writes — confirm.
unsafe {
std::env::set_var(trusty_common::DATA_DIR_OVERRIDE_ENV, data_tmp.path());
std::env::remove_var(crate::prompt_log::ENV_ENABLED);
std::env::remove_var(crate::prompt_log::ENV_DIR);
std::env::remove_var(crate::prompt_log::ENV_HASH_PROMPTS);
}
let data_root = trusty_common::resolve_data_dir("trusty-memory")
.expect("resolve data dir under override");
let state = crate::AppState::new(data_root.clone());
let _ = crate::tools::dispatch_tool(&state, "palace_create", json!({"name": palace_slug}))
.await
.expect("palace_create");
// Bind an ephemeral localhost port and serve HTTP on a background task.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind 127.0.0.1:0");
let addr = listener.local_addr().expect("local_addr");
let state_for_server = state.clone();
let handle = tokio::spawn(async move {
let _ = crate::run_http_on(state_for_server, listener).await;
});
// Poll up to 500 × 20 ms (~10 s) for the daemon to write `http_addr` under
// the data root. This appears to be the file `read_daemon_addr` consults.
let addr_file = data_root.join("http_addr");
let mut attempts = 0;
while !addr_file.exists() && attempts < 500 {
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
attempts += 1;
}
assert!(
addr_file.exists(),
"daemon never wrote http_addr at {} (attempts={attempts})",
addr_file.display()
);
(
state,
data_tmp,
project_tmp,
project_dir,
palace_slug.to_string(),
DaemonHandle {
addr,
join: Some(handle),
},
)
}
// Owns the background server task. `addr` is recorded for debugging only.
#[cfg(feature = "axum-server")]
struct DaemonHandle {
#[allow(dead_code)]
addr: std::net::SocketAddr,
join: Option<tokio::task::JoinHandle<()>>,
}
#[cfg(feature = "axum-server")]
impl DaemonHandle {
// Aborts the server task, waits for it to finish, then clears the data-dir
// override set by `spin_up_test_daemon_with_palace`.
async fn shutdown(mut self) {
if let Some(h) = self.join.take() {
h.abort();
let _ = h.await;
}
// SAFETY: the server task has been aborted and awaited, and the caller
// holds `env_test_lock`.
unsafe {
std::env::remove_var(trusty_common::DATA_DIR_OVERRIDE_ENV);
}
}
}
// With the default deny list, drawers tagged `claude-session` or `user-prompt`
// must not appear, while an untagged-by-deny "signal" drawer must.
#[cfg(feature = "axum-server")]
#[tokio::test]
async fn prompt_context_recall_filters_deny_tags() {
let _guard = crate::commands::env_test_lock().lock().await;
// SAFETY: guarded by `env_test_lock`; clears any override so defaults apply.
unsafe {
std::env::remove_var(ENV_RECALL_DENY_TAGS);
}
let (state, _data_dir_tmp, _project_dir_tmp, project_dir, slug, addr_handle) =
spin_up_test_daemon_with_palace("prompt-ctx-deny-tags").await;
for (text, tags) in [
(
"user: how do I use rust async tokio runtime and serde derive macros in this project to glue an http handler to a kafka producer",
vec!["claude-session", "user-prompt", "rust"],
),
(
"user: yes please go ahead and refactor the rust async producer module, this captured prompt fragment should never be surfaced",
vec!["user-prompt", "rust"],
),
(
"Rust integration uses tokio for async tasks and serde for JSON",
vec!["rust", "tokio"],
),
] {
let tags_json: Vec<Value> = tags.iter().map(|t| json!(t)).collect();
let _ = crate::tools::dispatch_tool(
&state,
"memory_remember",
json!({
"palace": slug,
"text": text,
"room": "General",
"tags": tags_json,
}),
)
.await
.expect("memory_remember");
}
let payload = json!({
"hook_event_name": "UserPromptSubmit",
"cwd": project_dir.to_string_lossy(),
"prompt": "how does rust integration work?"
})
.to_string();
let body = build_injection_body(&payload).await;
assert!(
body.contains("tokio") && body.contains("serde"),
"signal drawer must survive deny filter; got:\n{body}"
);
assert!(
!body.contains("kafka producer"),
"claude-session-tagged drawer must be filtered out; got:\n{body}"
);
assert!(
!body.contains("captured prompt fragment"),
"user-prompt-tagged drawer must be filtered out; got:\n{body}"
);
addr_handle.shutdown().await;
}
// Setting ENV_RECALL_DENY_TAGS must filter drawers carrying the configured tag.
// NOTE(review): `configured_deny_tags` *replaces* the default list when this
// env var is set, so despite the test name the override does not extend the
// defaults. Also, if the assertion fails, the env var is not removed and
// leaks into later tests.
#[cfg(feature = "axum-server")]
#[tokio::test]
async fn prompt_context_recall_env_override_extends_deny_list() {
let _guard = crate::commands::env_test_lock().lock().await;
// SAFETY: guarded by `env_test_lock`; set before the daemon task exists.
unsafe {
std::env::set_var(ENV_RECALL_DENY_TAGS, "noise-tag");
}
let (state, _data_dir_tmp, _project_dir_tmp, project_dir, slug, addr_handle) =
spin_up_test_daemon_with_palace("prompt-ctx-env-deny").await;
let _ = crate::tools::dispatch_tool(
&state,
"memory_remember",
json!({
"palace": slug,
"text": "Rust integration uses tokio and serde for the async layer",
"room": "General",
"tags": ["noise-tag", "rust"],
}),
)
.await
.expect("memory_remember");
let payload = json!({
"hook_event_name": "UserPromptSubmit",
"cwd": project_dir.to_string_lossy(),
"prompt": "how does rust integration work?"
})
.to_string();
let body = build_injection_body(&payload).await;
assert!(
!body.contains("tokio and serde"),
"noise-tag drawer must be filtered when env override targets it; got:\n{body}"
);
// SAFETY: guarded by `env_test_lock`.
unsafe {
std::env::remove_var(ENV_RECALL_DENY_TAGS);
}
addr_handle.shutdown().await;
}
// When every recalled drawer is denied, no drawer text may leak and no
// "Relevant memories" section may render.
#[cfg(feature = "axum-server")]
#[tokio::test]
async fn prompt_context_recall_all_filtered_falls_back_to_global() {
let _guard = crate::commands::env_test_lock().lock().await;
// SAFETY: guarded by `env_test_lock`; clears any override so defaults apply.
unsafe {
std::env::remove_var(ENV_RECALL_DENY_TAGS);
}
let (state, _data_dir_tmp, _project_dir_tmp, project_dir, slug, addr_handle) =
spin_up_test_daemon_with_palace("prompt-ctx-all-filtered").await;
for (text, tags) in [
(
"user: status update on the rust async rewrite, the kafka consumer should not surface in any prompt-context injection",
vec!["claude-session", "user-prompt", "rust"],
),
(
"user: yes please continue with the rust refactor on the producer side, this prompt fragment must be filtered out of recall",
vec!["claude-session", "rust"],
),
] {
let tags_json: Vec<Value> = tags.iter().map(|t| json!(t)).collect();
let _ = crate::tools::dispatch_tool(
&state,
"memory_remember",
json!({
"palace": slug,
"text": text,
"room": "General",
"tags": tags_json,
}),
)
.await
.expect("memory_remember");
}
let payload = json!({
"hook_event_name": "UserPromptSubmit",
"cwd": project_dir.to_string_lossy(),
"prompt": "tell me about rust"
})
.to_string();
let body = build_injection_body(&payload).await;
assert!(
!body.contains("kafka consumer") && !body.contains("producer side"),
"filtered drawer content must not leak; got:\n{body}"
);
assert!(
!body.contains("Relevant memories"),
"no `Relevant memories` section should render when every drawer is filtered; got:\n{body}"
);
addr_handle.shutdown().await;
}
// Without a daemon, the attempt must still be logged. Expects exactly one
// `enriched-prompts.*` file under `<data dir>/logs` whose first JSONL line
// records the UserPromptSubmit / prompt-context-facts entry.
#[tokio::test]
async fn prompt_context_logs_attempt_without_daemon() {
let _guard = crate::commands::env_test_lock().lock().await;
let tmp = tempfile::tempdir().expect("tempdir");
// SAFETY: guarded by `env_test_lock`. NOTE(review): tests that skip the lock
// could still read the environment concurrently — confirm.
unsafe {
std::env::set_var(trusty_common::DATA_DIR_OVERRIDE_ENV, tmp.path());
std::env::remove_var(crate::prompt_log::ENV_ENABLED);
std::env::remove_var(crate::prompt_log::ENV_DIR);
std::env::remove_var(crate::prompt_log::ENV_HASH_PROMPTS);
}
let res = handle_prompt_context().await;
// Resolve the logs dir while the override is still in effect.
let logs_dir = trusty_common::resolve_data_dir("trusty-memory")
.expect("resolve data dir")
.join("logs");
// SAFETY: guarded by `env_test_lock`.
unsafe {
std::env::remove_var(trusty_common::DATA_DIR_OVERRIDE_ENV);
}
assert!(res.is_ok());
let files: Vec<_> = std::fs::read_dir(&logs_dir)
.expect("logs dir should be created")
.flatten()
.map(|e| e.path())
.filter(|p| {
p.file_name()
.and_then(|n| n.to_str())
.is_some_and(|n| n.starts_with("enriched-prompts."))
})
.collect();
assert_eq!(
files.len(),
1,
"expected one enriched-prompts log file, got {files:?}"
);
let content = std::fs::read_to_string(&files[0]).expect("read log");
let line = content.lines().next().expect("at least one line");
let parsed: crate::prompt_log::PromptLogEntry =
serde_json::from_str(line).expect("parse JSONL");
assert_eq!(parsed.hook_type, "UserPromptSubmit");
assert_eq!(parsed.injection_kind, "prompt-context-facts");
}
}