use std::future::{Future, IntoFuture};
use std::pin::Pin;
use chrono::{DateTime, FixedOffset};
use crate::embedding::EmbeddingModel;
use crate::memory::{KindSelector, Memories, Scope};
use crate::store::MemoryStore;
use crate::vector::{FilterCondition, MemoryFilter, NumericRange, VectorIndex};
use super::{Client, ClientError};
/// Limit a new [`SearchBuilder`] starts with until `limit(..)` overrides it.
pub const DEFAULT_LIMIT: usize = 10;
/// Builder describing a memory search: the query text, its scope, and the
/// optional filters to apply.
///
/// Nothing runs while the builder is being configured; the search is
/// performed when the builder is awaited (see its `IntoFuture` impl).
#[must_use = "search(..) returns a builder that must be awaited"]
pub struct SearchBuilder<'a> {
/// Client whose shared inner state (embedder, index, store) runs the search.
client: &'a Client,
/// Query text; it is embedded into a vector when the search executes.
query: String,
/// Scope handed to the vector index search (and to the graph lookup when
/// graph enrichment is enabled).
scope: Scope,
/// Limit forwarded to the vector index search; starts at `DEFAULT_LIMIT`.
limit: usize,
/// Set by `episodic()`. If neither kind toggle is set, the default
/// `KindSelector` is used instead of an explicit selection.
episodic: bool,
/// Set by `semantic()`. See `episodic` for the "neither set" case.
semantic: bool,
/// Caller-supplied filter; range and category conditions are appended to
/// its `must` list when the search executes.
metadata_filter: Option<MemoryFilter>,
/// Similarity threshold forwarded to the vector index search.
min_similarity: Option<f32>,
/// Bounds on the `"created_at"` field, stored as epoch milliseconds.
created_at_range: NumericRange,
/// Bounds on the `"event_at"` field, stored as epoch milliseconds.
event_at_range: NumericRange,
/// Bounds on the `"confidence"` field (only a lower bound is settable here).
confidence_range: NumericRange,
/// When set, adds a keyword-equality condition on the `"category"` field.
category: Option<String>,
/// When `Some`, the results are enriched with graph neighbors fetched to
/// this depth (only if the client has a graph store configured).
#[cfg(feature = "knowledge-graph")]
graph_depth: Option<usize>,
}
impl<'a> SearchBuilder<'a> {
    /// Creates a builder for `query` within `scope`.
    ///
    /// The limit starts at [`DEFAULT_LIMIT`], no kind toggle is set, and every
    /// filter is left at its empty/default value.
    pub(super) fn new(client: &'a Client, query: String, scope: Scope) -> Self {
        Self {
            client,
            query,
            scope,
            limit: DEFAULT_LIMIT,
            episodic: false,
            semantic: false,
            category: None,
            metadata_filter: None,
            min_similarity: None,
            created_at_range: NumericRange::default(),
            event_at_range: NumericRange::default(),
            confidence_range: NumericRange::default(),
            #[cfg(feature = "knowledge-graph")]
            graph_depth: None,
        }
    }

    /// Converts a timestamp into epoch milliseconds, the unit stored in the
    /// date ranges of this builder.
    fn epoch_millis(at: impl Into<DateTime<FixedOffset>>) -> f64 {
        at.into().timestamp_millis() as f64
    }

    /// Replaces the limit that is forwarded to the index search.
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Turns on the episodic kind toggle.
    pub fn episodic(self) -> Self {
        Self {
            episodic: true,
            ..self
        }
    }

    /// Turns on the semantic kind toggle.
    pub fn semantic(self) -> Self {
        Self {
            semantic: true,
            ..self
        }
    }

    /// Sets the caller-supplied filter, replacing any previously set one.
    pub fn metadata_filter(self, filter: MemoryFilter) -> Self {
        Self {
            metadata_filter: Some(filter),
            ..self
        }
    }

    /// Sets the similarity threshold forwarded to the index search.
    pub fn min_similarity(self, threshold: f32) -> Self {
        Self {
            min_similarity: Some(threshold),
            ..self
        }
    }

    /// Sets the `gte` bound of the creation-time range to `at`.
    pub fn created_after(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
        let millis = Self::epoch_millis(at);
        self.created_at_range.gte = Some(millis);
        self
    }

    /// Sets the `lt` bound of the creation-time range to `at`.
    pub fn created_before(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
        let millis = Self::epoch_millis(at);
        self.created_at_range.lt = Some(millis);
        self
    }

    /// Sets the `gte` bound of the event-time range to `at`.
    pub fn event_at_after(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
        let millis = Self::epoch_millis(at);
        self.event_at_range.gte = Some(millis);
        self
    }

    /// Sets the `lt` bound of the event-time range to `at`.
    pub fn event_at_before(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
        let millis = Self::epoch_millis(at);
        self.event_at_range.lt = Some(millis);
        self
    }

    /// Sets the `gte` bound of the confidence range.
    ///
    /// `min` is passed through `Confidence::new(..).get()` before it is stored.
    pub fn min_confidence(mut self, min: i8) -> Self {
        let floor = crate::memory::Confidence::new(min).get();
        self.confidence_range.gte = Some(f64::from(floor));
        self
    }

    /// Restricts the search to the given category.
    pub fn category(self, category: impl Into<String>) -> Self {
        Self {
            category: Some(category.into()),
            ..self
        }
    }

    /// Enables graph enrichment at `DEFAULT_ENRICHMENT_DEPTH`.
    #[cfg(feature = "knowledge-graph")]
    pub fn with_graph(self) -> Self {
        Self {
            graph_depth: Some(crate::graph::DEFAULT_ENRICHMENT_DEPTH),
            ..self
        }
    }

    /// Enables graph enrichment at `depth`, clamped to
    /// `1..=MAX_ENRICHMENT_DEPTH`.
    #[cfg(feature = "knowledge-graph")]
    pub fn with_graph_depth(self, depth: usize) -> Self {
        let bounded = depth.clamp(1, crate::graph::MAX_ENRICHMENT_DEPTH);
        Self {
            graph_depth: Some(bounded),
            ..self
        }
    }

    /// Resolves the two kind toggles into the selector used by the search.
    fn kind_selector(&self) -> KindSelector {
        let (episodic, semantic) = (self.episodic, self.semantic);
        kind_selector(episodic, semantic)
    }
}
/// Maps the two kind toggles to a [`KindSelector`].
///
/// If at least one toggle is set, the selector mirrors the toggles exactly.
/// If neither is set, `KindSelector::default()` is returned (the tests in
/// this file expect that default to select both kinds).
fn kind_selector(episodic: bool, semantic: bool) -> KindSelector {
    if episodic || semantic {
        KindSelector { episodic, semantic }
    } else {
        KindSelector::default()
    }
}
/// Merges the caller-supplied filter with the builder's range and category
/// constraints into a single filter.
///
/// Returns `None` when there is no caller filter and no constraint to add.
/// Otherwise starts from the caller filter (or a default one) and appends to
/// its `must` list, in this order: `"created_at"`, `"event_at"` and
/// `"confidence"` range conditions (each only if bounded), then the
/// `"category"` keyword match (if set).
fn combine_filter(
    metadata_filter: Option<MemoryFilter>,
    created_at: NumericRange,
    event_at: NumericRange,
    confidence: NumericRange,
    category: Option<String>,
) -> Option<MemoryFilter> {
    // Field name paired with its range, in the order conditions are appended.
    let ranges = [
        ("created_at", created_at),
        ("event_at", event_at),
        ("confidence", confidence),
    ];

    let has_constraint =
        category.is_some() || ranges.iter().any(|(_, range)| !range.is_unbounded());
    if metadata_filter.is_none() && !has_constraint {
        return None;
    }

    let mut filter = metadata_filter.unwrap_or_default();
    for (field, range) in ranges {
        if range.is_unbounded() {
            continue;
        }
        filter.must.push(FilterCondition::Range {
            field: field.to_string(),
            range,
        });
    }
    if let Some(keyword) = category {
        filter.must.push(FilterCondition::Equals {
            field: "category".to_string(),
            value: crate::vector::MatchValue::Keyword(keyword),
        });
    }
    Some(filter)
}
/// Lets a `SearchBuilder` be `.await`ed directly: awaiting the builder
/// consumes it and runs `execute`.
impl<'a> IntoFuture for SearchBuilder<'a> {
/// The matching memories on success, or a `ClientError` on failure.
type Output = Result<Memories, ClientError>;
/// Boxed, type-erased future; it is `Send` and may borrow the client for `'a`.
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
/// Wraps the `execute` future for this builder in a pinned box.
fn into_future(self) -> Self::IntoFuture {
Box::pin(execute(self))
}
}
/// Runs the search described by `builder` and returns the matching memories.
///
/// Steps, in order:
/// 1. Merge the builder's filters via `combine_filter`.
/// 2. Embed the query text and search the vector index with it.
/// 3. Load the hit rows from the store by pid.
/// 4. Attach each hit's score to its row and order rows by hit position;
///    rows whose pid is not among the hits get no score and sort last.
/// 5. With the `knowledge-graph` feature, if a graph depth was requested and
///    the client has a graph store, attach the neighbors of the result pids.
///
/// # Errors
///
/// Propagates (via `?`) any failure from embedding, the index search, the
/// store lookup, or the graph neighbor lookup.
async fn execute(builder: SearchBuilder<'_>) -> Result<Memories, ClientError> {
    // Read everything that needs the whole builder before it is taken apart.
    let kinds = builder.kind_selector();
    #[cfg(feature = "knowledge-graph")]
    let graph_depth = builder.graph_depth;

    let SearchBuilder {
        client,
        query,
        scope,
        limit,
        metadata_filter,
        min_similarity,
        created_at_range,
        event_at_range,
        confidence_range,
        category,
        ..
    } = builder;

    // `scope` is moved into the index search below, so keep a copy for the graph.
    #[cfg(feature = "knowledge-graph")]
    let graph_scope = scope.clone();

    let inner = client.inner.clone();
    let filter = combine_filter(
        metadata_filter,
        created_at_range,
        event_at_range,
        confidence_range,
        category,
    );

    let embedding = inner.embedder.embed(&query).await?;
    let hits = inner
        .index
        .search(scope, embedding, limit, kinds, filter, min_similarity)
        .await?;

    let hit_pids: Vec<&str> = hits.iter().map(|hit| hit.0.as_str()).collect();
    let mut rows = inner.store.find_by_pids(&hit_pids).await?;

    // pid -> (position in the hit list, similarity score).
    let ranking: std::collections::HashMap<&str, (usize, f32)> = hits
        .iter()
        .enumerate()
        .map(|(position, (pid, score))| (pid.as_str(), (position, *score)))
        .collect();

    for row in rows.iter_mut() {
        row.score = ranking.get(row.pid.as_str()).map(|&(_, score)| score);
    }
    // The store's row order is not relied on; restore the index's hit order.
    rows.sort_by_key(|row| match ranking.get(row.pid.as_str()) {
        Some(&(position, _)) => position,
        None => usize::MAX,
    });

    let memories = Memories::new(rows, inner.system_prompt.clone());

    #[cfg(feature = "knowledge-graph")]
    if let (Some(depth), Some(graph)) = (graph_depth, inner.graph.as_deref()) {
        use crate::graph::GraphStore;
        let seeds: Vec<&str> = memories.list().iter().map(|m| m.pid.as_str()).collect();
        let context = graph.neighbors(&seeds, &graph_scope, depth).await?;
        return Ok(memories.with_graph_context(context));
    }

    Ok(memories)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves the toggles and flattens the selector into
    /// `(episodic, semantic)` so each test is a single comparison.
    fn selected(episodic: bool, semantic: bool) -> (bool, bool) {
        let selector = kind_selector(episodic, semantic);
        (selector.episodic, selector.semantic)
    }

    #[test]
    fn should_select_all_kinds_when_no_kind_toggled() {
        assert_eq!(selected(false, false), (true, true));
    }

    #[test]
    fn should_select_all_kinds_when_both_kinds_toggled() {
        assert_eq!(selected(true, true), (true, true));
    }

    #[test]
    fn should_select_only_episodic_when_only_episodic_toggled() {
        assert_eq!(selected(true, false), (true, false));
    }

    #[test]
    fn should_select_only_semantic_when_only_semantic_toggled() {
        assert_eq!(selected(false, true), (false, true));
    }
}