use crate::db::Database;
use crate::persistence::Persistence;
use crate::providers::ToolDefinition;
use serde_json::json;
/// Returns the `ToolDefinition` for the `RecallContext` tool: its name, a
/// description, and a JSON-schema-style object describing the arguments.
///
/// Neither `query` nor `turn` is listed as required here; `recall_context`
/// itself answers with a hint when a call supplies neither.
pub fn definition() -> ToolDefinition {
    let parameters = json!({
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search term to find in conversation history"
            },
            "turn": {
                "type": "integer",
                "description": "Specific turn number to recall (1-based)"
            }
        }
    });
    let description = String::from(
        "Recall earlier conversation context that may have scrolled \
         out of your current window. Use when you need to remember what was \
         discussed or decided earlier in the session.",
    );
    ToolDefinition {
        name: String::from("RecallContext"),
        description,
        parameters,
    }
}
/// Runs the `RecallContext` tool against the history of `session_id`.
///
/// `args` is the tool-call JSON object:
/// - `turn` (integer, 1-based): return the message at that position in the
///   session history. It takes precedence over `query` when both are given.
///   Content longer than 2000 characters is cut on a character boundary, and
///   the full character count is reported.
/// - `query` (string): case-insensitive substring search over every message.
///   Up to 10 hits are listed with a snippet each, plus a count of the rest.
///   A blank query is treated as absent.
///
/// Always returns human-readable text. Missing arguments, an unknown turn,
/// no matches, and history-load failures are all reported in the returned
/// string rather than as an error.
pub async fn recall_context(db: &Database, session_id: &str, args: &serde_json::Value) -> String {
    // Longest message body (in chars) shown for a single recalled turn.
    const MAX_TURN_CHARS: usize = 2000;
    // Maximum number of search hits listed in full.
    const MAX_MATCHES: usize = 10;
    // Snippet size (in chars) requested from `extract_snippet` per hit.
    const SNIPPET_CHARS: usize = 200;

    // An empty/whitespace query would match every message; treat it as missing.
    let query = args["query"].as_str().filter(|q| !q.trim().is_empty());
    let turn = args["turn"].as_u64();
    if query.is_none() && turn.is_none() {
        return "Provide either 'query' (search term) or 'turn' (number) to recall context."
            .to_string();
    }

    let history = match db.load_all_messages(session_id).await {
        Ok(msgs) => msgs,
        Err(e) => return format!("Failed to load history: {e}"),
    };
    if history.is_empty() {
        return "No conversation history found.".to_string();
    }

    if let Some(turn_num) = turn {
        // Turns are 1-based: 0 and anything past the end are rejected. The
        // `as usize` is lossless because the index was just bounded by `len()`.
        let msg = turn_num
            .checked_sub(1)
            .filter(|&i| i < history.len() as u64)
            .and_then(|i| history.get(i as usize));
        let Some(msg) = msg else {
            return format!(
                "Turn {} does not exist. Session has {} messages.",
                turn_num,
                history.len()
            );
        };
        let content = msg
            .full_content
            .as_deref()
            .or(msg.content.as_deref())
            .unwrap_or("(no content)");
        // Find the cut point by char, not byte: slicing at a fixed byte offset
        // panics when it falls inside a multi-byte UTF-8 character.
        let display = match content.char_indices().nth(MAX_TURN_CHARS) {
            Some((cut, _)) => format!(
                "{}... [truncated, {} chars total]",
                &content[..cut],
                content.chars().count()
            ),
            None => content.to_string(),
        };
        return format!("## Turn {} ({})\n\n{}", turn_num, msg.role, display);
    }

    // The guard above guarantees `query` is set whenever `turn` is not; this
    // fallback only keeps the function total without a panic path.
    let Some(q) = query else {
        return "Provide 'query' or 'turn' parameter.".to_string();
    };

    let q_lower = q.to_lowercase();
    // Gather every hit so the total is known, but defer snippet building to
    // the hits that are actually displayed.
    let hits: Vec<_> = history
        .iter()
        .enumerate()
        .filter_map(|(i, msg)| {
            let text = msg.full_content.as_deref().or(msg.content.as_deref())?;
            text.to_lowercase()
                .contains(&q_lower)
                .then_some((i, msg, text))
        })
        .collect();
    if hits.is_empty() {
        return format!("No matches for '{q}' in conversation history.");
    }

    let total = hits.len();
    let shown: Vec<String> = hits
        .iter()
        .take(MAX_MATCHES)
        .map(|&(i, msg, text)| {
            let snippet = extract_snippet(text, &q_lower, SNIPPET_CHARS);
            format!("**Turn {} ({}):** {}\n", i + 1, msg.role, snippet)
        })
        .collect();

    let mut result = format!("## Found {total} matches for '{q}'\n\n");
    result.push_str(&shown.join("\n"));
    if total > MAX_MATCHES {
        result.push_str(&format!(
            "\n... and {} more matches\n",
            total - MAX_MATCHES
        ));
    }
    result
}
/// Returns a snippet of `text` around the first case-insensitive occurrence
/// of `query`.
///
/// `query` is expected to be lowercase already (the caller lowercases it once).
/// `text` is lowercased here one char at a time for the comparison, while the
/// returned snippet keeps the original casing.
///
/// The snippet body is `max_len` characters long: the match plus surrounding
/// context, split as evenly as possible before and after it. Context that one
/// side cannot use (because the text ends there) is given to the other side.
/// If the match alone is longer than `max_len`, only the match is returned.
/// `"..."` is added on each side where `text` was cut off.
///
/// If `query` does not occur, the first `max_len` characters are returned
/// with no ellipsis.
///
/// All positions are counted in `char`s, so multi-byte UTF-8 text can never
/// be sliced through the middle of a character.
fn extract_snippet(text: &str, query: &str, max_len: usize) -> String {
    // Lowercase char by char, recording for every byte of `lower` the index of
    // the original char that produced it. Lowercasing can change a char's
    // UTF-8 length ('İ' grows, the Kelvin sign 'K' shrinks), so byte offsets
    // found in `lower` cannot be used to slice `text` directly.
    let mut lower = String::with_capacity(text.len());
    let mut owner: Vec<usize> = Vec::with_capacity(text.len());
    let mut total = 0;
    for (idx, ch) in text.chars().enumerate() {
        lower.extend(ch.to_lowercase());
        owner.resize(lower.len(), idx);
        total = idx + 1;
    }

    let Some(pos) = lower.find(query) else {
        return text.chars().take(max_len).collect();
    };

    // Char range [match_start, match_end) of the match within `text`.
    let match_start = owner.get(pos).copied().unwrap_or(total);
    let match_end = if query.is_empty() {
        match_start
    } else {
        // `find` guarantees pos + query.len() <= lower.len() == owner.len().
        owner[pos + query.len() - 1] + 1
    };

    // Split the context budget around the match; whatever the tail cannot use
    // is handed back to the head.
    let context = max_len.saturating_sub(match_end - match_start);
    let mut start = match_start.saturating_sub(context / 2);
    let after = context - (match_start - start);
    let end = (match_end + after).min(total);
    start = start.saturating_sub(after - (end - match_end));

    let body: String = text.chars().skip(start).take(end - start).collect();
    let lead = if start > 0 { "..." } else { "" };
    let trail = if end < total { "..." } else { "" };
    format!("{lead}{body}{trail}")
}
// Tests for the RecallContext tool: the schema returned by `definition`, the
// pure `extract_snippet` helper, and `recall_context` run against a throwaway
// on-disk database.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistence::{Persistence, Role};
    use serde_json::json;

    #[test]
    fn test_definition() {
        let def = definition();
        assert_eq!(def.name, "RecallContext");
    }

    // ---- extract_snippet -------------------------------------------------

    #[test]
    fn test_extract_snippet_found() {
        let text = "The quick brown fox jumps over the lazy dog";
        let snippet = extract_snippet(text, "fox", 100);
        assert!(snippet.contains("fox"));
    }

    // When the query is absent the helper falls back to the leading `max_len`
    // chars; this text is shorter than that, so it comes back whole.
    #[test]
    fn test_extract_snippet_not_found() {
        let text = "hello world";
        let snippet = extract_snippet(text, "xyz", 100);
        assert_eq!(snippet, "hello world");
    }

    // The match sits at offset 0 of a short text, so nothing precedes it and
    // no leading ellipsis should appear.
    #[test]
    fn test_extract_snippet_at_start_no_leading_ellipsis() {
        let text = "match at the start and some more text here";
        let snippet = extract_snippet(text, "match", 100);
        assert!(
            !snippet.starts_with("..."),
            "no leading ellipsis when at start"
        );
        assert!(snippet.contains("match"));
    }

    // The match is padded with 100 chars on each side; the snippet is expected
    // to carry an ellipsis at both ends.
    #[test]
    fn test_extract_snippet_mid_text_has_ellipsis() {
        let text = format!("{}needle{}", "a".repeat(100), "b".repeat(100));
        let snippet = extract_snippet(&text, "needle", 200);
        assert!(
            snippet.starts_with("..."),
            "should have leading ellipsis: {snippet}"
        );
        assert!(
            snippet.ends_with("..."),
            "should have trailing ellipsis: {snippet}"
        );
    }

    // The not-found fallback is capped at `max_len`, counted in chars.
    #[test]
    fn test_extract_snippet_not_found_truncated_at_max_len() {
        let text = "a".repeat(500);
        let snippet = extract_snippet(&text, "nothere", 50);
        assert_eq!(snippet.chars().count(), 50);
    }

    #[test]
    fn test_extract_snippet_empty_text() {
        let snippet = extract_snippet("", "query", 100);
        assert_eq!(snippet, "");
    }

    // `extract_snippet` is given an already-lowercased query (as
    // `recall_context` does) and must still match mixed-case text, returning
    // it with its original casing.
    #[test]
    fn test_extract_snippet_query_is_case_lowered() {
        let text = "Error: file not found at line 42";
        let lower_q = "error";
        let snippet = extract_snippet(text, lower_q, 200);
        assert!(snippet.contains("Error"));
    }

    // Opens a fresh database inside a temporary directory and creates one
    // session in it. The `TempDir` guard is returned so the caller keeps the
    // directory (and the database file in it) alive for the whole test.
    async fn test_db() -> (Database, tempfile::TempDir, String) {
        let dir = tempfile::TempDir::new().unwrap();
        let db = Database::open(&dir.path().join("recall_test.db"))
            .await
            .unwrap();
        let sid = db.create_session("koda", dir.path()).await.unwrap();
        (db, dir, sid)
    }

    // ---- recall_context --------------------------------------------------

    // With neither `query` nor `turn`, the tool answers with a usage hint
    // instead of touching the history.
    #[tokio::test]
    async fn test_recall_no_query_or_turn() {
        let (db, _dir, sid) = test_db().await;
        let result = recall_context(&db, &sid, &json!({})).await;
        assert!(
            result.contains("Provide"),
            "should ask for query or turn: {result}"
        );
    }

    // No messages are inserted, so the session history is empty.
    #[tokio::test]
    async fn test_recall_empty_history() {
        let (db, _dir, sid) = test_db().await;
        let result = recall_context(&db, &sid, &json!({"turn": 1})).await;
        assert!(result.contains("No conversation history"), "got: {result}");
    }

    // Turn numbers are 1-based: turn 1 is the first inserted message.
    #[tokio::test]
    async fn test_recall_by_turn_hit() {
        let (db, _dir, sid) = test_db().await;
        db.insert_message(&sid, &Role::User, Some("hello world"), None, None, None)
            .await
            .unwrap();
        let result = recall_context(&db, &sid, &json!({"turn": 1})).await;
        assert!(result.contains("hello world"), "got: {result}");
        assert!(result.contains("Turn 1"), "got: {result}");
    }

    // Only one message exists, so turn 99 is out of range.
    #[tokio::test]
    async fn test_recall_by_turn_out_of_bounds() {
        let (db, _dir, sid) = test_db().await;
        db.insert_message(&sid, &Role::User, Some("msg1"), None, None, None)
            .await
            .unwrap();
        let result = recall_context(&db, &sid, &json!({"turn": 99})).await;
        assert!(result.contains("does not exist"), "got: {result}");
    }

    #[tokio::test]
    async fn test_recall_by_query_match() {
        let (db, _dir, sid) = test_db().await;
        db.insert_message(
            &sid,
            &Role::Assistant,
            Some("The error was a null pointer exception"),
            None,
            None,
            None,
        )
        .await
        .unwrap();
        let result = recall_context(&db, &sid, &json!({"query": "null pointer"})).await;
        assert!(result.contains("null pointer"), "got: {result}");
        assert!(result.contains("Found"), "got: {result}");
    }

    #[tokio::test]
    async fn test_recall_by_query_no_match() {
        let (db, _dir, sid) = test_db().await;
        db.insert_message(&sid, &Role::User, Some("hello world"), None, None, None)
            .await
            .unwrap();
        let result = recall_context(&db, &sid, &json!({"query": "xyzzy"})).await;
        assert!(result.contains("No matches"), "got: {result}");
    }

    // 3000 chars exceeds the 2000-char display cap for a single recalled turn.
    #[tokio::test]
    async fn test_recall_by_turn_long_content_truncated() {
        let (db, _dir, sid) = test_db().await;
        let long_msg = "z".repeat(3000);
        db.insert_message(&sid, &Role::User, Some(&long_msg), None, None, None)
            .await
            .unwrap();
        let result = recall_context(&db, &sid, &json!({"turn": 1})).await;
        assert!(
            result.contains("[truncated"),
            "long message should be truncated: {result}"
        );
    }
}