use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use thiserror::Error;
use rig_compose::{KernelError, Tool, ToolSchema};
#[derive(Debug, Error)]
pub enum MemoryLookupError {
#[error("memory lookup backend error: {0}")]
Backend(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryLookupHit {
pub score: f32,
pub summary: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub key: Option<String>,
#[serde(default, skip_serializing_if = "Value::is_null")]
pub metadata: Value,
}
impl MemoryLookupHit {
pub fn new(score: f32, summary: impl Into<String>) -> Self {
Self {
score,
summary: summary.into(),
key: None,
metadata: Value::Null,
}
}
pub fn with_key(mut self, key: impl Into<String>) -> Self {
self.key = Some(key.into());
self
}
pub fn with_metadata(mut self, metadata: Value) -> Self {
self.metadata = metadata;
self
}
}
#[async_trait]
pub trait MemoryLookupStore: Send + Sync {
async fn lookup(
&self,
query: &str,
k: usize,
) -> Result<Vec<MemoryLookupHit>, MemoryLookupError>;
}
pub struct MemoryLookupTool {
store: Arc<dyn MemoryLookupStore>,
}
impl MemoryLookupTool {
pub const NAME: &'static str = "memory.lookup";
pub fn new(store: Arc<dyn MemoryLookupStore>) -> Self {
Self { store }
}
pub fn arc(store: Arc<dyn MemoryLookupStore>) -> Arc<dyn Tool> {
Arc::new(Self::new(store))
}
}
#[derive(Deserialize)]
struct LookupArgs {
query: String,
#[serde(default = "default_k")]
k: usize,
}
fn default_k() -> usize {
3
}
#[async_trait]
impl Tool for MemoryLookupTool {
fn schema(&self) -> ToolSchema {
ToolSchema {
name: Self::NAME.into(),
description: "Retrieve up to k similar memory episodes for a query.".into(),
args_schema: json!({
"type": "object",
"required": ["query"],
"properties": {
"query": {"type": "string"},
"k": {"type": "integer", "minimum": 1, "default": 3}
}
}),
result_schema: json!({
"type": "object",
"properties": {
"hits": {
"type": "array",
"items": {
"type": "object",
"properties": {
"score": {"type": "number"},
"summary": {"type": "string"},
"key": {"type": "string"},
"metadata": {"type": "object"}
}
}
}
}
}),
}
}
fn name(&self) -> rig_compose::tool::ToolName {
Self::NAME.to_string()
}
async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
let parsed: LookupArgs = serde_json::from_value(args)?;
if parsed.k == 0 {
return Err(KernelError::InvalidArgument(
"memory.lookup requires k >= 1".into(),
));
}
let hits = self
.store
.lookup(&parsed.query, parsed.k)
.await
.map_err(|err| KernelError::ToolFailed(err.to_string()))?;
Ok(json!({ "hits": hits }))
}
}
#[cfg(test)]
mod tests {
use super::*;
struct StubMemory;
#[async_trait]
impl MemoryLookupStore for StubMemory {
async fn lookup(
&self,
query: &str,
k: usize,
) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
Ok(vec![
MemoryLookupHit::new(0.9, format!("matched {query}"))
.with_key("ep-1")
.with_metadata(json!({"rank": 1})),
]
.into_iter()
.take(k)
.collect())
}
}
#[tokio::test]
async fn lookup_tool_returns_hits() {
let tool = MemoryLookupTool::new(Arc::new(StubMemory));
let out = tool
.invoke(json!({"query": "beacon", "k": 1}))
.await
.unwrap();
let score = out["hits"][0]["score"].as_f64().unwrap();
assert!((score - 0.9).abs() < 1e-6);
assert_eq!(out["hits"][0]["key"], "ep-1");
}
#[tokio::test]
async fn lookup_tool_rejects_zero_k() {
let tool = MemoryLookupTool::new(Arc::new(StubMemory));
let err = tool
.invoke(json!({"query": "beacon", "k": 0}))
.await
.unwrap_err();
assert!(matches!(err, KernelError::InvalidArgument(_)));
}
}