1use async_trait::async_trait;
11use harness_core::{
12 Block, Context, Execution, Guide, GuideError, GuideId, GuideScope, RecallStore, Tool,
13 ToolError, ToolResult, ToolRisk, ToolSchema, World,
14};
15use serde_json::{json, Value};
16use std::sync::{Arc, OnceLock};
17
18pub fn recall_owner(world: &World) -> String {
20 world
21 .profile
22 .extra
23 .get("recall_owner")
24 .and_then(|v| v.as_str())
25 .unwrap_or("default")
26 .to_string()
27}
28
29pub struct SessionSearchTool {
32 store: Arc<dyn RecallStore>,
33 schema: ToolSchema,
34}
35
36impl SessionSearchTool {
37 pub fn new(store: Arc<dyn RecallStore>) -> Self {
38 Self {
39 store,
40 schema: ToolSchema {
41 name: "session_search".into(),
42 description: "Search your own past sessions, or scroll inside one. \
43 Three shapes: (1) pass `query` to find relevant past sessions \
44 (returns snippet + surrounding messages); (2) pass `session_id` + \
45 `around` to scroll messages near a point in a session; (3) pass \
46 nothing to list your most recent sessions."
47 .into(),
48 input: json!({
49 "type": "object",
50 "properties": {
51 "query": {"type": "string", "description": "Search text. Shape 1 (discovery)."},
52 "session_id": {"type": "string", "description": "Scroll within this session. Shape 2."},
53 "around": {"type": "integer", "description": "Anchor message id for scroll. Shape 2."},
54 "window": {"type": "integer", "default": 5, "description": "± messages around the anchor."},
55 "limit": {"type": "integer", "default": 3, "minimum": 1, "maximum": 20}
56 }
57 }),
58 },
59 }
60 }
61}
62
63#[async_trait]
64impl Tool for SessionSearchTool {
65 fn name(&self) -> &str {
66 &self.schema.name
67 }
68 fn schema(&self) -> &ToolSchema {
69 &self.schema
70 }
71 fn risk(&self) -> ToolRisk {
72 ToolRisk::ReadOnly
73 }
74 async fn invoke(&self, args: Value, world: &mut World) -> Result<ToolResult, ToolError> {
75 let owner = recall_owner(world);
76 let limit = args
77 .get("limit")
78 .and_then(|v| v.as_u64())
79 .unwrap_or(3)
80 .min(20) as usize;
81
82 let result = if let Some(q) = args
83 .get("query")
84 .and_then(|v| v.as_str())
85 .filter(|s| !s.is_empty())
86 {
87 match self.store.search(&owner, q, limit).await {
88 Ok(hits) => json!({"mode": "discover", "query": q, "count": hits.len(), "results": hits}),
89 Err(e) => return Ok(err_result(e)),
90 }
91 } else if let Some(sid) = args.get("session_id").and_then(|v| v.as_str()) {
92 let around = args
93 .get("around")
94 .and_then(|v| v.as_i64())
95 .unwrap_or(0);
96 let window = args
97 .get("window")
98 .and_then(|v| v.as_u64())
99 .unwrap_or(5) as usize;
100 match self.store.scroll(&owner, sid, around, window).await {
101 Ok(msgs) => json!({"mode": "scroll", "session_id": sid, "messages": msgs}),
102 Err(e) => return Ok(err_result(e)),
103 }
104 } else {
105 match self.store.recent(&owner, limit).await {
106 Ok(sessions) => json!({"mode": "browse", "sessions": sessions}),
107 Err(e) => return Ok(err_result(e)),
108 }
109 };
110 Ok(ToolResult {
111 ok: true,
112 content: result,
113 trace: None,
114 })
115 }
116}
117
118fn err_result(e: harness_core::RecallError) -> ToolResult {
119 ToolResult {
120 ok: false,
121 content: json!({"error": e.to_string()}),
122 trace: None,
123 }
124}
125
126const RECALL_MARKER: &str = "[recall]\n";
129
130pub struct RecallGuide {
131 store: Arc<dyn RecallStore>,
132 top_k: usize,
133}
134
135static RECALL_GUIDE_ID: OnceLock<GuideId> = OnceLock::new();
136static RECALL_GUIDE_SCOPE: OnceLock<GuideScope> = OnceLock::new();
137
138impl RecallGuide {
139 pub fn new(store: Arc<dyn RecallStore>) -> Self {
140 Self { store, top_k: 3 }
141 }
142 pub fn with_top_k(mut self, k: usize) -> Self {
143 self.top_k = k;
144 self
145 }
146}
147
148#[async_trait]
149impl Guide for RecallGuide {
150 fn id(&self) -> &GuideId {
151 RECALL_GUIDE_ID.get_or_init(|| "recall".to_string())
152 }
153 fn kind(&self) -> Execution {
154 Execution::Inferential
155 }
156 fn scope(&self) -> &GuideScope {
157 RECALL_GUIDE_SCOPE.get_or_init(|| GuideScope::Always)
158 }
159 async fn apply(&self, ctx: &mut Context, world: &World) -> Result<(), GuideError> {
160 let owner = recall_owner(world);
161 let query = ctx.task.description.clone();
162 let hits = self
163 .store
164 .search(&owner, &query, self.top_k)
165 .await
166 .unwrap_or_default();
167 if hits.is_empty() {
168 return Ok(());
169 }
170 let mut text = String::from(RECALL_MARKER);
171 text.push_str("Possibly-relevant context from your past sessions:\n");
172 for h in &hits {
173 text.push_str(&format!("- ({}) {}\n", h.session.session_id, h.snippet));
174 }
175 ctx.guides.push(Block::Text(text));
176 Ok(())
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use harness_context::{default_world, FileRecall};
    use harness_core::{RecallMessage, SessionMeta};

    /// Scratch directory that is deleted when the guard goes out of scope, so a
    /// failing assertion does not leave store files behind in the temp dir.
    struct TempRoot(std::path::PathBuf);

    impl Drop for TempRoot {
        fn drop(&mut self) {
            // Best-effort cleanup; the directory may not exist.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Builds a temp-dir path unique to this process, instant and call.
    fn tmp_root() -> std::path::PathBuf {
        use std::sync::atomic::{AtomicU64, Ordering};
        static N: AtomicU64 = AtomicU64::new(0);
        let n = N.fetch_add(1, Ordering::SeqCst);
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!(
            "harness-recall-tool-{}-{nanos}-{n}",
            std::process::id()
        ))
    }

    /// Default world whose profile names `owner` as the recall owner.
    fn world_for(owner: &str) -> World {
        let mut world = default_world(".");
        world
            .profile
            .extra
            .insert("recall_owner".into(), serde_json::json!(owner));
        world
    }

    /// A session stored for "alice" is found by alice's query and is invisible
    /// to the same query issued as "bob".
    #[tokio::test]
    async fn tool_discovery_scoped_to_owner() {
        // Declared first so it is dropped last, after the store and the tool.
        let root = TempRoot(tmp_root());
        let store: Arc<dyn RecallStore> = Arc::new(FileRecall::open(&root.0).unwrap());
        store
            .ensure_session("alice", "s1", &SessionMeta::new("s1", 1))
            .await
            .unwrap();
        store
            .append(
                "alice",
                "s1",
                &RecallMessage::new("user", "deploy the payment service", 1),
            )
            .await
            .unwrap();

        let tool = SessionSearchTool::new(store.clone());

        let mut alice = world_for("alice");
        let out = tool
            .invoke(serde_json::json!({"query": "payment deploy"}), &mut alice)
            .await
            .unwrap();
        assert!(out.ok);
        assert_eq!(out.content["count"], 1);

        let mut bob = world_for("bob");
        let out2 = tool
            .invoke(serde_json::json!({"query": "payment deploy"}), &mut bob)
            .await
            .unwrap();
        assert_eq!(out2.content["count"], 0);
    }
}