Skip to main content

harness_loop/
recall_layer.rs

//! Cross-session recall wiring for [`crate::AgentLoop`].
//!
//! - [`SessionSearchTool`] — LLM-callable search over the recall store, in
//!   three shapes (discovery / scroll / browse). The owner is read from
//!   `World.profile.extra["recall_owner"]`, so the tool can only see the
//!   caller's own sessions.
//! - [`RecallGuide`] — optional. At session start, searches the store with the
//!   task description and injects the top snippets. Off unless enabled with
//!   `.auto_inject()`.

10use async_trait::async_trait;
11use harness_core::{
12    Block, Context, Execution, Guide, GuideError, GuideId, GuideScope, RecallStore, Tool,
13    ToolError, ToolResult, ToolRisk, ToolSchema, World,
14};
15use serde_json::{json, Value};
16use std::sync::{Arc, OnceLock};
17
18/// Read the recall owner from the world profile (fallback "default").
19pub fn recall_owner(world: &World) -> String {
20    world
21        .profile
22        .extra
23        .get("recall_owner")
24        .and_then(|v| v.as_str())
25        .unwrap_or("default")
26        .to_string()
27}
28
29// ───── session_search tool ────────────────────────────────────────────────
30
31pub struct SessionSearchTool {
32    store: Arc<dyn RecallStore>,
33    schema: ToolSchema,
34}
35
36impl SessionSearchTool {
37    pub fn new(store: Arc<dyn RecallStore>) -> Self {
38        Self {
39            store,
40            schema: ToolSchema {
41                name: "session_search".into(),
42                description: "Search your own past sessions, or scroll inside one. \
43                    Three shapes: (1) pass `query` to find relevant past sessions \
44                    (returns snippet + surrounding messages); (2) pass `session_id` + \
45                    `around` to scroll messages near a point in a session; (3) pass \
46                    nothing to list your most recent sessions."
47                    .into(),
48                input: json!({
49                    "type": "object",
50                    "properties": {
51                        "query": {"type": "string", "description": "Search text. Shape 1 (discovery)."},
52                        "session_id": {"type": "string", "description": "Scroll within this session. Shape 2."},
53                        "around": {"type": "integer", "description": "Anchor message id for scroll. Shape 2."},
54                        "window": {"type": "integer", "default": 5, "description": "± messages around the anchor."},
55                        "limit": {"type": "integer", "default": 3, "minimum": 1, "maximum": 20}
56                    }
57                }),
58            },
59        }
60    }
61}
62
63#[async_trait]
64impl Tool for SessionSearchTool {
65    fn name(&self) -> &str {
66        &self.schema.name
67    }
68    fn schema(&self) -> &ToolSchema {
69        &self.schema
70    }
71    fn risk(&self) -> ToolRisk {
72        ToolRisk::ReadOnly
73    }
74    async fn invoke(&self, args: Value, world: &mut World) -> Result<ToolResult, ToolError> {
75        let owner = recall_owner(world);
76        let limit = args
77            .get("limit")
78            .and_then(|v| v.as_u64())
79            .unwrap_or(3)
80            .min(20) as usize;
81
82        let result = if let Some(q) = args
83            .get("query")
84            .and_then(|v| v.as_str())
85            .filter(|s| !s.is_empty())
86        {
87            match self.store.search(&owner, q, limit).await {
88                Ok(hits) => json!({"mode": "discover", "query": q, "count": hits.len(), "results": hits}),
89                Err(e) => return Ok(err_result(e)),
90            }
91        } else if let Some(sid) = args.get("session_id").and_then(|v| v.as_str()) {
92            let around = args
93                .get("around")
94                .and_then(|v| v.as_i64())
95                .unwrap_or(0);
96            let window = args
97                .get("window")
98                .and_then(|v| v.as_u64())
99                .unwrap_or(5) as usize;
100            match self.store.scroll(&owner, sid, around, window).await {
101                Ok(msgs) => json!({"mode": "scroll", "session_id": sid, "messages": msgs}),
102                Err(e) => return Ok(err_result(e)),
103            }
104        } else {
105            match self.store.recent(&owner, limit).await {
106                Ok(sessions) => json!({"mode": "browse", "sessions": sessions}),
107                Err(e) => return Ok(err_result(e)),
108            }
109        };
110        Ok(ToolResult {
111            ok: true,
112            content: result,
113            trace: None,
114        })
115    }
116}
117
118fn err_result(e: harness_core::RecallError) -> ToolResult {
119    ToolResult {
120        ok: false,
121        content: json!({"error": e.to_string()}),
122        trace: None,
123    }
124}
125
126// ───── RecallGuide (opt-in auto-inject) ───────────────────────────────────
127
128const RECALL_MARKER: &str = "[recall]\n";
129
130pub struct RecallGuide {
131    store: Arc<dyn RecallStore>,
132    top_k: usize,
133}
134
135static RECALL_GUIDE_ID: OnceLock<GuideId> = OnceLock::new();
136static RECALL_GUIDE_SCOPE: OnceLock<GuideScope> = OnceLock::new();
137
138impl RecallGuide {
139    pub fn new(store: Arc<dyn RecallStore>) -> Self {
140        Self { store, top_k: 3 }
141    }
142    pub fn with_top_k(mut self, k: usize) -> Self {
143        self.top_k = k;
144        self
145    }
146}
147
148#[async_trait]
149impl Guide for RecallGuide {
150    fn id(&self) -> &GuideId {
151        RECALL_GUIDE_ID.get_or_init(|| "recall".to_string())
152    }
153    fn kind(&self) -> Execution {
154        Execution::Inferential
155    }
156    fn scope(&self) -> &GuideScope {
157        RECALL_GUIDE_SCOPE.get_or_init(|| GuideScope::Always)
158    }
159    async fn apply(&self, ctx: &mut Context, world: &World) -> Result<(), GuideError> {
160        let owner = recall_owner(world);
161        let query = ctx.task.description.clone();
162        let hits = self
163            .store
164            .search(&owner, &query, self.top_k)
165            .await
166            .unwrap_or_default();
167        if hits.is_empty() {
168            return Ok(());
169        }
170        let mut text = String::from(RECALL_MARKER);
171        text.push_str("Possibly-relevant context from your past sessions:\n");
172        for h in &hits {
173            text.push_str(&format!("- ({}) {}\n", h.session.session_id, h.snippet));
174        }
175        ctx.guides.push(Block::Text(text));
176        Ok(())
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use harness_context::{default_world, FileRecall};
    use harness_core::{RecallMessage, SessionMeta};

    /// Unique scratch directory under the system temp dir.
    ///
    /// It is removed on drop, so it is cleaned up even when an assertion
    /// panics mid-test.
    struct TmpRoot(std::path::PathBuf);

    impl TmpRoot {
        fn new() -> Self {
            use std::sync::atomic::{AtomicU64, Ordering};
            // pid + timestamp + counter keeps concurrently running tests apart.
            static N: AtomicU64 = AtomicU64::new(0);
            let n = N.fetch_add(1, Ordering::SeqCst);
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            Self(std::env::temp_dir().join(format!(
                "harness-recall-tool-{}-{nanos}-{n}",
                std::process::id()
            )))
        }
    }

    impl Drop for TmpRoot {
        fn drop(&mut self) {
            // Best-effort cleanup; the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[tokio::test]
    async fn tool_discovery_scoped_to_owner() {
        // Declared first so it is dropped last, after the store using it.
        let root = TmpRoot::new();
        let store: Arc<dyn RecallStore> = Arc::new(FileRecall::open(&root.0).unwrap());
        store
            .ensure_session("alice", "s1", &SessionMeta::new("s1", 1))
            .await
            .unwrap();
        store
            .append(
                "alice",
                "s1",
                &RecallMessage::new("user", "deploy the payment service", 1),
            )
            .await
            .unwrap();

        let tool = SessionSearchTool::new(store.clone());
        let mut world = default_world(".");
        world
            .profile
            .extra
            .insert("recall_owner".into(), serde_json::json!("alice"));
        let out = tool
            .invoke(serde_json::json!({"query": "payment deploy"}), &mut world)
            .await
            .unwrap();
        assert!(out.ok);
        assert_eq!(out.content["count"], 1);

        // The same query under another owner must not see alice's session.
        let mut bob = default_world(".");
        bob.profile
            .extra
            .insert("recall_owner".into(), serde_json::json!("bob"));
        let out2 = tool
            .invoke(serde_json::json!({"query": "payment deploy"}), &mut bob)
            .await
            .unwrap();
        assert!(out2.ok);
        assert_eq!(out2.content["count"], 0);
    }
}