use std::collections::HashSet;
use std::sync::Arc;
pub(super) use super::super::needle::AsciiInsensitiveNeedle as ContentNeedle;
use super::state::MemoriesState;
use super::types::{Memory, MemoryId};
/// Result ordering for a memories query.
///
/// Each variant names the field sorted on and the direction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OrderBy {
    /// Smallest `id` first.
    IdAsc,
    /// Largest `id` first.
    IdDesc,
    /// Oldest `created_ns` first.
    CreatedAsc,
    /// Newest `created_ns` first.
    CreatedDesc,
    /// Oldest `updated_ns` first.
    UpdatedAsc,
    /// Newest `updated_ns` first.
    UpdatedDesc,
}
/// Accumulated constraints, ordering and limit for a memories query.
///
/// Every `None` field means "no constraint". `Default` yields a spec that
/// matches every memory, is unordered and is unlimited.
#[derive(Clone, Debug, Default)]
pub(super) struct MemoriesFilterSpec {
    /// Keep only memories whose `id` is in this set.
    pub id_in: Option<HashSet<MemoryId>>,
    /// Keep only memories whose `source` equals this string exactly.
    pub source: Option<String>,
    /// Keep only memories whose `content` is accepted by this needle.
    /// The tests show ASCII case-insensitive matching ("GROCERY" finds "Grocery list").
    pub content_contains: Option<ContentNeedle>,
    /// Keep only memories carrying this tag.
    pub require_tag: Option<String>,
    /// Keep memories carrying at least one of these tags; an empty list imposes no constraint.
    pub require_any_tag: Option<Vec<String>>,
    /// Keep memories carrying every one of these tags; an empty list imposes no constraint.
    pub require_all_tags: Option<Vec<String>>,
    /// Keep only memories whose `pinned` flag equals this value (`false` selects unpinned ones).
    pub only_pinned: Option<bool>,
    /// Inclusive lower bound on `created_ns`.
    pub created_after_ns: Option<u64>,
    /// Inclusive upper bound on `created_ns`.
    pub created_before_ns: Option<u64>,
    /// Inclusive lower bound on `updated_ns`.
    pub updated_after_ns: Option<u64>,
    /// Inclusive upper bound on `updated_ns`.
    pub updated_before_ns: Option<u64>,
    /// Ordering applied to the filtered results, if any.
    pub order_by: Option<OrderBy>,
    /// Maximum number of results, applied after ordering.
    pub limit: Option<usize>,
}
impl MemoriesFilterSpec {
    /// Returns `true` when `m` satisfies every constraint that is set.
    ///
    /// Unset (`None`) constraints accept everything. An empty
    /// `require_any_tag` / `require_all_tags` list also accepts everything.
    /// Time bounds are inclusive: `created_after_ns = Some(t)` keeps
    /// `created_ns >= t`, and `created_before_ns = Some(t)` keeps `created_ns <= t`.
    /// The `updated_*` bounds work the same way.
    pub(super) fn matches(&self, m: &Memory) -> bool {
        /// `None` imposes no constraint; `Some(v)` defers to `check(v)`.
        fn holds<T>(constraint: Option<T>, check: impl FnOnce(T) -> bool) -> bool {
            match constraint {
                Some(value) => check(value),
                None => true,
            }
        }

        let has_tag = |wanted: &String| m.tags.iter().any(|t| t == wanted);

        // Checks run in declaration order and short-circuit on the first failure.
        holds(self.id_in.as_ref(), |ids| ids.contains(&m.id))
            && holds(self.source.as_ref(), |src| m.source == *src)
            && holds(self.content_contains.as_ref(), |needle| {
                needle.matches(&m.content)
            })
            && holds(self.require_tag.as_ref(), &has_tag)
            && holds(self.require_any_tag.as_ref(), |tags| {
                tags.is_empty() || tags.iter().any(&has_tag)
            })
            && holds(self.require_all_tags.as_ref(), |tags| {
                tags.is_empty() || tags.iter().all(&has_tag)
            })
            && holds(self.only_pinned, |want| m.pinned == want)
            && holds(self.created_after_ns, |ns| m.created_ns >= ns)
            && holds(self.created_before_ns, |ns| m.created_ns <= ns)
            && holds(self.updated_after_ns, |ns| m.updated_ns >= ns)
            && holds(self.updated_before_ns, |ns| m.updated_ns <= ns)
    }

    /// Filters every memory in `state`, then applies the ordering and the limit.
    ///
    /// Without `order_by`, results keep the iteration order of
    /// `state.memories`. The limit is applied after sorting, so it keeps the
    /// first `limit` entries of the ordered result.
    pub(super) fn execute(&self, state: &MemoriesState) -> Vec<Arc<Memory>> {
        let mut hits = Vec::new();
        for memory in state.memories.values() {
            if self.matches(memory.as_ref()) {
                hits.push(Arc::clone(memory));
            }
        }
        if let Some(order) = self.order_by {
            sort_memories(&mut hits, order);
        }
        if let Some(limit) = self.limit {
            hits.truncate(limit);
        }
        hits
    }
}
/// Fluent query builder over a [`MemoriesState`], created by [`MemoriesState::query`].
///
/// Each builder method consumes the query and returns it. Nothing is evaluated
/// until a terminal method (`collect`, `count`, `first`, `exists`) is called,
/// so a dropped query is always a mistake. `#[must_use]` turns that into a warning.
#[must_use = "a query does nothing until `collect`, `count`, `first` or `exists` is called"]
pub struct MemoriesQuery<'a> {
    /// State the query reads from.
    state: &'a MemoriesState,
    /// Accumulated filters, ordering and limit.
    spec: MemoriesFilterSpec,
}
impl MemoriesState {
    /// Starts a query over this state with no filters, no ordering and no limit.
    ///
    /// Chain `where_*`, `order_by` and `limit` calls, then finish with
    /// `collect`, `count`, `first` or `exists`.
    pub fn query(&self) -> MemoriesQuery<'_> {
        let spec = MemoriesFilterSpec::default();
        MemoriesQuery { state: self, spec }
    }
}
impl<'a> MemoriesQuery<'a> {
    /// Keeps only memories whose id is in `ids`. Replaces any earlier id filter.
    pub fn where_id_in(mut self, ids: impl IntoIterator<Item = MemoryId>) -> Self {
        self.spec.id_in = Some(ids.into_iter().collect());
        self
    }

    /// Keeps only memories whose `source` equals `source` exactly.
    pub fn where_source(mut self, source: impl Into<String>) -> Self {
        self.spec.source = Some(source.into());
        self
    }

    /// Keeps only memories whose content is accepted by a `ContentNeedle` built from `needle`.
    pub fn content_contains(mut self, needle: impl Into<String>) -> Self {
        self.spec.content_contains = Some(ContentNeedle::new(needle));
        self
    }

    /// Keeps only memories carrying `tag`. Replaces any earlier single-tag filter.
    pub fn where_tag(mut self, tag: impl Into<String>) -> Self {
        self.spec.require_tag = Some(tag.into());
        self
    }

    /// Keeps memories carrying at least one of `tags` (an empty list is no constraint).
    pub fn where_any_tag(mut self, tags: impl IntoIterator<Item = String>) -> Self {
        self.spec.require_any_tag = Some(tags.into_iter().collect());
        self
    }

    /// Keeps memories carrying all of `tags` (an empty list is no constraint).
    pub fn where_all_tags(mut self, tags: impl IntoIterator<Item = String>) -> Self {
        self.spec.require_all_tags = Some(tags.into_iter().collect());
        self
    }

    /// Keeps memories whose `pinned` flag equals `pinned` (`false` selects unpinned ones).
    pub fn where_pinned(mut self, pinned: bool) -> Self {
        self.spec.only_pinned = Some(pinned);
        self
    }

    /// Keeps memories with `created_ns >= ns` (inclusive bound).
    pub fn created_after(mut self, ns: u64) -> Self {
        self.spec.created_after_ns = Some(ns);
        self
    }

    /// Keeps memories with `created_ns <= ns` (inclusive bound).
    pub fn created_before(mut self, ns: u64) -> Self {
        self.spec.created_before_ns = Some(ns);
        self
    }

    /// Keeps memories with `updated_ns >= ns` (inclusive bound).
    pub fn updated_after(mut self, ns: u64) -> Self {
        self.spec.updated_after_ns = Some(ns);
        self
    }

    /// Keeps memories with `updated_ns <= ns` (inclusive bound).
    pub fn updated_before(mut self, ns: u64) -> Self {
        self.spec.updated_before_ns = Some(ns);
        self
    }

    /// Sorts results by `order`.
    ///
    /// Without an ordering, results follow the iteration order of the
    /// state's memory collection.
    pub fn order_by(mut self, order: OrderBy) -> Self {
        self.spec.order_by = Some(order);
        self
    }

    /// Caps the number of results, applied after ordering.
    ///
    /// Every terminal method honours the cap. `collect`, `count` and `first`
    /// never exceed it, and `exists` returns `false` for `limit(0)`.
    pub fn limit(mut self, n: usize) -> Self {
        self.spec.limit = Some(n);
        self
    }

    /// Runs the query and returns the matching memories.
    pub fn collect(self) -> Vec<Arc<Memory>> {
        self.spec.execute(self.state)
    }

    /// Returns the number of results the query would yield.
    ///
    /// This honours `limit`, so it always equals `collect().len()`. Unlike
    /// `collect`, it clones nothing and sorts nothing.
    pub fn count(self) -> usize {
        // Ordering never changes how many rows survive the limit, so just cap the count.
        let cap = self.spec.limit.unwrap_or(usize::MAX);
        self.state
            .memories
            .values()
            .filter(|a| self.spec.matches(a.as_ref()))
            .take(cap)
            .count()
    }

    /// Returns the first result after ordering, or `None` if nothing matches.
    ///
    /// Without `order_by`, which match is returned depends on the iteration
    /// order of the state's memory collection.
    pub fn first(mut self) -> Option<Arc<Memory>> {
        // Narrow to one result but never widen the caller's limit: `limit(0).first()` stays empty.
        self.spec.limit = Some(self.spec.limit.map_or(1, |n| n.min(1)));
        self.collect().into_iter().next()
    }

    /// Returns `true` if the query would yield at least one result.
    ///
    /// This agrees with `!collect().is_empty()`, including under `limit(0)`.
    pub fn exists(self) -> bool {
        self.spec.limit != Some(0)
            && self
                .state
                .memories
                .values()
                .any(|a| self.spec.matches(a.as_ref()))
    }
}
/// Sorts `memories` in place according to `order`.
///
/// Timestamp orderings break ties by ascending `id`. Results are gathered by
/// iterating the state's memory collection, so without a tie-break, memories
/// sharing a timestamp would come back in whatever order that iteration
/// produced. `order_by(..).limit(n)` could then return different rows from
/// run to run. Ascending id matches what the stable sort produced when the
/// input was already id-ordered.
fn sort_memories(memories: &mut [Arc<Memory>], order: OrderBy) {
    use std::cmp::Reverse;
    match order {
        OrderBy::IdAsc => memories.sort_by_key(|m| m.id),
        OrderBy::IdDesc => memories.sort_by_key(|m| Reverse(m.id)),
        OrderBy::CreatedAsc => memories.sort_by_key(|m| (m.created_ns, m.id)),
        OrderBy::CreatedDesc => memories.sort_by_key(|m| (Reverse(m.created_ns), m.id)),
        OrderBy::UpdatedAsc => memories.sort_by_key(|m| (m.updated_ns, m.id)),
        OrderBy::UpdatedDesc => memories.sort_by_key(|m| (Reverse(m.updated_ns), m.id)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixture memory with source `"test"` and `updated_ns == created_ns`.
    fn mk(id: MemoryId, content: &str, tags: &[&str], pinned: bool, created: u64) -> Memory {
        Memory {
            id,
            content: String::from(content),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            source: "test".into(),
            created_ns: created,
            updated_ns: created,
            pinned,
        }
    }

    /// Five memories with ids 1..=5, created at 100..=500 ns; ids 1 and 3 are pinned.
    fn sample() -> MemoriesState {
        let fixtures = [
            mk(1, "Meeting notes", &["work", "notes"], true, 100),
            mk(2, "Grocery list", &["personal", "todo"], false, 200),
            mk(3, "Reading list", &["personal", "reading"], true, 300),
            mk(4, "Sprint plan", &["work", "planning"], false, 400),
            mk(5, "Birthday ideas", &["personal"], false, 500),
        ];
        let mut state = MemoriesState::new();
        for memory in fixtures {
            state.memories.insert(memory.id, Arc::new(memory));
        }
        state
    }

    /// Ids of `found`, in the order the query returned them.
    fn ids_of(found: &[Arc<Memory>]) -> Vec<MemoryId> {
        found.iter().map(|memory| memory.id).collect()
    }

    /// Ids of `found` sorted ascending, for queries whose result order is unspecified.
    fn sorted_ids_of(found: &[Arc<Memory>]) -> Vec<MemoryId> {
        let mut ids = ids_of(found);
        ids.sort();
        ids
    }

    #[test]
    fn test_where_tag_single() {
        let state = sample();
        let found = state.query().where_tag("work").collect();
        assert_eq!(sorted_ids_of(&found), vec![1, 4]);
    }

    #[test]
    fn test_where_any_tag_is_or() {
        let state = sample();
        let found = state
            .query()
            .where_any_tag([String::from("reading"), String::from("planning")])
            .collect();
        assert_eq!(sorted_ids_of(&found), vec![3, 4]);
    }

    #[test]
    fn test_where_all_tags_is_and() {
        let state = sample();
        let both = state
            .query()
            .where_all_tags([String::from("personal"), String::from("reading")])
            .collect();
        assert_eq!(ids_of(&both), vec![3]);
        let none = state
            .query()
            .where_all_tags([String::from("personal"), String::from("work")])
            .collect();
        assert!(none.is_empty());
    }

    #[test]
    fn empty_tag_filters_are_treated_as_no_constraint() {
        let state = sample();
        let total = state.memories.len();
        let any_empty = state.query().where_any_tag(Vec::<String>::new()).collect();
        assert_eq!(
            any_empty.len(),
            total,
            "require_any_tag(empty) must be treated as no constraint \
             (got {}/{}); pre-fix this rejected every memory",
            any_empty.len(),
            total,
        );
        let all_empty = state.query().where_all_tags(Vec::<String>::new()).collect();
        assert_eq!(
            all_empty.len(),
            total,
            "require_all_tags(empty) must return all memories"
        );
    }

    #[test]
    fn test_where_pinned_toggles() {
        let state = sample();
        let pinned = state.query().where_pinned(true).collect();
        assert_eq!(sorted_ids_of(&pinned), vec![1, 3]);
        assert_eq!(state.query().where_pinned(false).count(), 3);
    }

    #[test]
    fn test_content_contains_case_insensitive() {
        let state = sample();
        let found = state.query().content_contains("GROCERY").collect();
        assert_eq!(ids_of(&found), vec![2]);
    }

    #[test]
    fn test_order_by_created_desc_limit() {
        let state = sample();
        let newest_two = state
            .query()
            .order_by(OrderBy::CreatedDesc)
            .limit(2)
            .collect();
        assert_eq!(ids_of(&newest_two), vec![5, 4]);
    }

    #[test]
    fn test_composed_tag_and_pinned() {
        let state = sample();
        let found = state
            .query()
            .where_tag("personal")
            .where_pinned(true)
            .collect();
        assert_eq!(ids_of(&found), vec![3]);
    }

    #[test]
    fn test_where_source() {
        let mut state = sample();
        let target = state.memories.get_mut(&1).unwrap();
        Arc::make_mut(target).source = "llm".into();
        assert_eq!(state.query().where_source("llm").count(), 1);
        assert_eq!(state.query().where_source("test").count(), 4);
    }

    #[test]
    fn test_where_id_in() {
        let state = sample();
        let found = state.query().where_id_in([2, 4, 99]).collect();
        assert_eq!(sorted_ids_of(&found), vec![2, 4]);
    }

    #[test]
    fn test_first_and_exists() {
        let state = sample();
        let newest_pinned = state
            .query()
            .where_pinned(true)
            .order_by(OrderBy::CreatedDesc)
            .first()
            .unwrap();
        assert_eq!(newest_pinned.id, 3);
        assert!(state.query().where_tag("work").exists());
        assert!(!state.query().where_tag("unicorn").exists());
    }

    #[test]
    fn test_empty_state_queries_empty() {
        let state = MemoriesState::new();
        assert_eq!(state.query().count(), 0);
        assert!(state.query().first().is_none());
        assert!(!state.query().exists());
    }
}