#![cfg(feature = "integration")]
#![allow(dead_code)]
use std::ops::Deref;
use std::sync::Once;
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use memoir_core::client::{Client, WorkerHandle};
use memoir_core::llm::LlmConfig;
use memoir_core::memory::{MemoryKind, Scope};
use qdrant_client::Qdrant;
use qdrant_client::qdrant::DeleteCollectionBuilder;
use sea_orm::{ConnectionTrait, Database, DatabaseConnection};
/// Alphabet used with `nanoid!` to generate random suffixes for test resources.
///
/// Only ASCII digits and lowercase ASCII letters, so suffixes contain no
/// punctuation or uppercase characters when spliced into `test_{suffix}` names
/// and scope ids.
const TEST_ID_ALPHABET: [char; 36] = [
    // ASCII digits.
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    // Lowercase ASCII letters.
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
];
/// Creates an isolated [`TestClient`] with no extraction LLM and no knowledge graph.
///
/// # Errors
/// Propagates every failure from `build_test_client` (missing env vars,
/// connection, migration or worker start-up errors).
pub async fn fresh_client() -> Result<TestClient> {
    let (extraction, graph) = (None, None);
    build_test_client(extraction, graph).await
}
/// Creates an isolated [`TestClient`] whose extraction LLM is an Ollama model.
///
/// Reads `OLLAMA_URL` and `OLLAMA_MODEL` (in that order) from the environment.
///
/// # Errors
/// Fails if either env var is missing, or if `build_test_client` fails.
pub async fn fresh_client_with_extraction() -> Result<TestClient> {
    let url = std::env::var("OLLAMA_URL")
        .context("OLLAMA_URL env var must be set for extraction tests")?;
    let model = std::env::var("OLLAMA_MODEL")
        .context("OLLAMA_MODEL env var must be set for extraction tests")?;
    let extraction = LlmConfig::ollama(url, model);
    build_test_client(Some(extraction), None).await
}
/// Creates an isolated [`TestClient`] with both extraction and the knowledge graph
/// enabled.
///
/// Reads `FALKOR_URL`, `OLLAMA_URL` and `OLLAMA_MODEL` (in that order). The same
/// Ollama config is used as the extraction LLM and as the graph's relational LLM.
///
/// # Errors
/// Fails if any env var is missing, or if `build_test_client` fails.
#[cfg(feature = "knowledge-graph")]
pub async fn fresh_graph_client() -> Result<TestClient> {
    let falkor_url = std::env::var("FALKOR_URL")
        .context("FALKOR_URL env var must be set for graph integration tests")?;
    // Arguments are evaluated left to right, so OLLAMA_URL is still read first.
    let llm = LlmConfig::ollama(
        std::env::var("OLLAMA_URL")
            .context("OLLAMA_URL env var must be set for graph integration tests")?,
        std::env::var("OLLAMA_MODEL")
            .context("OLLAMA_MODEL env var must be set for graph integration tests")?,
    );
    let graph = GraphConfig {
        falkor_url,
        llm: llm.clone(),
    };
    build_test_client(Some(llm), Some(graph)).await
}
/// Knowledge-graph settings handed from `fresh_graph_client` to `build_test_client`.
///
/// The fields are only read when the `knowledge-graph` feature is enabled.
struct GraphConfig {
    /// LLM passed to the client builder as the relational LLM.
    llm: LlmConfig,
    /// FalkorDB URL (taken from `FALKOR_URL` by `fresh_graph_client`).
    falkor_url: String,
}
/// Guards the one-time installation of the test-wide tracing subscriber.
static TRACING_INIT: Once = Once::new();

/// Installs a `tracing_subscriber` fmt subscriber once per test binary.
///
/// The filter comes from the default env var (`RUST_LOG`) when it parses.
/// Otherwise `info` is used, with `sqlx`, `sea_orm` and `hyper` turned down to
/// `warn`. Output goes through the test writer so it is captured per test.
fn init_tracing() {
    TRACING_INIT.call_once(|| {
        let env_filter = match tracing_subscriber::EnvFilter::try_from_default_env() {
            Ok(from_env) => from_env,
            Err(_) => {
                tracing_subscriber::EnvFilter::new("info,sqlx=warn,sea_orm=warn,hyper=warn")
            }
        };
        // `try_init` errors if some other global subscriber is already set;
        // that is harmless for tests, so the result is deliberately ignored.
        let _ = tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .with_test_writer()
            .try_init();
    });
}
/// Builds an isolated [`TestClient`] backed by a fresh Postgres schema and Qdrant
/// collection. With the `knowledge-graph` feature and `graph` set, it also gets a
/// FalkorDB graph. All of these are named `test_{suffix}` with a random
/// 8-character suffix.
///
/// Reads `DATABASE_URL` and `QDRANT_URL`, applies memoir migrations and starts a
/// background worker with short poll/drain intervals.
///
/// * `extraction` - optional extraction LLM.
/// * `graph` - FalkorDB URL and relational LLM; ignored unless the
///   `knowledge-graph` feature is enabled.
///
/// # Errors
/// Fails when a required env var is missing or any connection, migration or
/// worker start-up step fails. Once the memoir `Client` is built it is wrapped in
/// its cleanup guard. A failure in `migrate` or `spawn_worker` therefore still
/// drops the schema and deletes the collection instead of leaking them.
#[cfg_attr(not(feature = "knowledge-graph"), allow(unused_variables))]
async fn build_test_client(
    extraction: Option<LlmConfig>,
    graph: Option<GraphConfig>,
) -> Result<TestClient> {
    init_tracing();
    let database_url = std::env::var("DATABASE_URL")
        .context("DATABASE_URL env var must be set for integration tests")?;
    let qdrant_url = std::env::var("QDRANT_URL")
        .context("QDRANT_URL env var must be set for integration tests")?;

    let suffix = nanoid::nanoid!(8, &TEST_ID_ALPHABET);
    let schema = format!("test_{suffix}");
    let collection = format!("test_{suffix}");

    // Dedicated handles that `TestClient::drop` uses to tear the resources down.
    let db = Database::connect(&database_url)
        .await
        .context("connect to Postgres (cleanup pool)")?;
    let qdrant = Qdrant::from_url(&qdrant_url).build().context("build Qdrant cleanup client")?;

    // The URLs are not needed after this point, so they are moved rather than cloned.
    let builder = Client::builder()
        .database_url(database_url)
        .qdrant(qdrant_url)
        .schema(schema.clone())
        .collection(collection.clone())
        .maybe_extraction_llm(extraction);
    #[cfg(feature = "knowledge-graph")]
    let builder = builder
        .maybe_falkor(graph.as_ref().map(|cfg| cfg.falkor_url.clone()))
        .maybe_graph_name(graph.as_ref().map(|_| format!("test_{suffix}")))
        .maybe_relational_llm(graph.as_ref().map(|cfg| cfg.llm.clone()));
    let client = builder.build().await.context("build memoir Client")?;

    // Wrap the client in its cleanup guard *before* migrating and spawning the
    // worker. If either step fails, `?` drops `test_client`, and
    // `TestClient::drop` removes the schema and collection.
    let mut test_client = TestClient {
        client,
        worker: None,
        cleanup_db: Some(db),
        cleanup_qdrant: Some(qdrant),
        cleanup_scopes: Vec::new(),
        schema,
        collection,
    };
    test_client
        .client
        .migrate()
        .await
        .context("apply memoir migrations")?;
    let worker = test_client
        .client
        .spawn_worker()
        .poll_interval(Duration::from_millis(50))
        .lease_duration(Duration::from_secs(60))
        .drain_timeout(Duration::from_secs(5))
        .start()
        .await
        .context("spawn worker")?;
    test_client.worker = Some(worker);
    Ok(test_client)
}
/// A memoir [`Client`] bound to throwaway storage, plus the handles needed to
/// tear that storage down again when the value is dropped.
///
/// Dereferences to [`Client`], so tests call client methods directly on it.
/// Field declaration order also fixes the drop order of anything `Drop::drop`
/// leaves in place, so keep it stable.
pub struct TestClient {
    // The client under test; exposed through `Deref`.
    client: Client,
    // Background worker; taken and shut down in `Drop`.
    worker: Option<WorkerHandle>,
    // Separate Postgres connection used by `Drop` to drop `schema`; `None` once
    // teardown has started.
    cleanup_db: Option<DatabaseConnection>,
    // Separate Qdrant client used by `Drop` to delete `collection`; `None` once
    // teardown has started.
    cleanup_qdrant: Option<Qdrant>,
    // Scopes handed out by the `fresh_scope*` methods; forgotten in `Drop` when
    // the `knowledge-graph` feature is enabled.
    cleanup_scopes: Vec<Scope>,
    // Postgres schema name (`test_{suffix}`).
    pub schema: String,
    // Qdrant collection name (`test_{suffix}`).
    pub collection: String,
}
impl Deref for TestClient {
type Target = Client;
fn deref(&self) -> &Self::Target {
&self.client
}
}
impl Drop for TestClient {
    /// Best-effort teardown of every resource this test created.
    ///
    /// Runs, in order:
    /// 1. worker shutdown;
    /// 2. `forget` of every scope registered through the `fresh_scope*` methods
    ///    (only with the `knowledge-graph` feature);
    /// 3. deletion of the Qdrant collection;
    /// 4. `DROP SCHEMA ... CASCADE` on the Postgres schema.
    ///
    /// A failing step is reported on stderr and does not stop the later steps.
    ///
    /// `Drop` is synchronous, so the async work is driven through
    /// `tokio::task::block_in_place` + `Handle::current().block_on`. Those panic
    /// outside a Tokio runtime or on a current-thread runtime. The panic is caught
    /// so it does not escape `drop`, and its message is printed.
    fn drop(&mut self) {
        let schema = self.schema.clone();
        let collection = self.collection.clone();
        // The cleanup handles are `Some` only until teardown starts; without them
        // there is nothing to clean up.
        let Some(db) = self.cleanup_db.take() else { return };
        let Some(qdrant) = self.cleanup_qdrant.take() else { return };
        let worker = self.worker.take();
        #[cfg(feature = "knowledge-graph")]
        let client = self.client.clone();
        #[cfg(feature = "knowledge-graph")]
        let scopes = std::mem::take(&mut self.cleanup_scopes);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tokio::task::block_in_place(|| {
                tokio::runtime::Handle::current().block_on(async {
                    // Stop the worker before removing the storage it polls.
                    if let Some(worker) = worker {
                        worker.shutdown().await;
                    }
                    #[cfg(feature = "knowledge-graph")]
                    for scope in scopes {
                        if let Err(err) = client
                            .forget(memoir_core::memory::ForgetTarget::Scope(scope.clone()))
                            .await
                        {
                            eprintln!("[TestClient::drop] forget graph scope {scope:?} failed: {err}");
                        }
                    }
                    if let Err(err) = qdrant
                        .delete_collection(DeleteCollectionBuilder::new(&collection))
                        .await
                    {
                        eprintln!("[TestClient::drop] qdrant delete_collection({collection}) failed: {err}");
                    }
                    // `schema` is `test_` + alphanumerics from TEST_ID_ALPHABET, so
                    // quoting it is sufficient here.
                    let sql = format!("DROP SCHEMA IF EXISTS \"{schema}\" CASCADE");
                    if let Err(err) = db.execute_unprepared(&sql).await {
                        eprintln!("[TestClient::drop] drop schema {schema} failed: {err}");
                    }
                });
            });
        }));
        if let Err(payload) = result {
            eprintln!(
                "[TestClient::drop] cleanup panicked (schema={schema} collection={collection}): {}",
                panic_payload_message(&*payload)
            );
        }
    }
}

/// Extracts the human-readable message from a caught panic payload.
///
/// Payloads of `panic!` are `&'static str` for literal messages and `String` for
/// formatted ones. Any other payload type maps to a placeholder. Formatting the
/// boxed payload with `{:?}` would only print `Any { .. }`.
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> &str {
    payload
        .downcast_ref::<&str>()
        .copied()
        .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
        .unwrap_or("<non-string panic payload>")
}
impl TestClient {
    /// Returns a new random [`Scope`] and records it so `Drop` forgets its graph
    /// data.
    #[cfg(feature = "knowledge-graph")]
    pub fn fresh_scope(&mut self) -> Scope {
        // Resolves to the module-level `fresh_scope`, not this method.
        let tracked = fresh_scope();
        self.cleanup_scopes.push(tracked.clone());
        tracked
    }

    /// Like [`TestClient::fresh_scope`], but pins the scope to `org_id`. Agent and
    /// user ids share one random suffix.
    #[cfg(feature = "knowledge-graph")]
    pub fn fresh_scope_in_org(&mut self, org_id: &str) -> Scope {
        let id = nanoid::nanoid!(8, &TEST_ID_ALPHABET);
        let tracked = Scope {
            agent_id: format!("agent_{id}"),
            org_id: org_id.to_owned(),
            user_id: format!("user_{id}"),
        };
        self.cleanup_scopes.push(tracked.clone());
        tracked
    }

    /// Opens a new, independent Postgres connection whose search path is this
    /// test's schema followed by `public`.
    ///
    /// # Errors
    /// Fails if `DATABASE_URL` is unset or the connection cannot be established.
    pub async fn raw_db(&self) -> Result<DatabaseConnection> {
        let url = std::env::var("DATABASE_URL").context("DATABASE_URL env var must be set")?;
        let mut options = sea_orm::ConnectOptions::new(url);
        options.set_schema_search_path(format!("{},public", self.schema));
        Database::connect(options).await.context("connect raw_db")
    }

    /// Builds a new, independent Qdrant client from `QDRANT_URL`.
    ///
    /// # Errors
    /// Fails if `QDRANT_URL` is unset or the client cannot be built.
    pub fn raw_qdrant(&self) -> Result<Qdrant> {
        let url = std::env::var("QDRANT_URL").context("QDRANT_URL env var must be set")?;
        Qdrant::from_url(&url).build().context("build raw_qdrant client")
    }
}
/// Builds a random [`Scope`] whose agent, org and user ids share one 8-character
/// suffix drawn from [`TEST_ID_ALPHABET`].
///
/// The scope is not registered with any `TestClient`, so nothing forgets it on
/// drop.
pub fn fresh_scope() -> Scope {
    let id = nanoid::nanoid!(8, &TEST_ID_ALPHABET);
    let (agent_id, org_id, user_id) = (
        format!("agent_{id}"),
        format!("org_{id}"),
        format!("user_{id}"),
    );
    Scope {
        agent_id,
        org_id,
        user_id,
    }
}
/// Polls `client.search(query, scope)` until a hit with `pid` appears among the
/// first 50 results.
///
/// It probes immediately, then backs off exponentially: 50 ms, doubling, capped
/// at 500 ms. Each sleep is clamped to the time left, and one last probe is made
/// at the deadline. The whole `timeout` is therefore used, and a zero timeout
/// still probes once.
///
/// # Errors
/// Returns an error if a search probe fails, or if `pid` is not found before
/// `timeout` elapses.
pub async fn wait_until_indexed(
    client: &Client,
    pid: &str,
    scope: &Scope,
    query: &str,
    timeout: Duration,
) -> Result<()> {
    let deadline = Instant::now() + timeout;
    let mut delay = Duration::from_millis(50);
    loop {
        let hits = client
            .search(query, scope.clone())
            .limit(50)
            .await
            .context("search probe failed")?;
        if hits.list().iter().any(|m| m.pid == pid) {
            return Ok(());
        }
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            break;
        }
        tokio::time::sleep(delay.min(remaining)).await;
        delay = (delay * 2).min(Duration::from_millis(500));
    }
    anyhow::bail!("pid {pid} did not become searchable within {timeout:?}")
}
/// Polls `client.search(query, scope)` until it returns at least one hit, and
/// returns the pid of the first hit in the returned list.
///
/// Backoff matches `wait_until_indexed`: 50 ms, doubling, capped at 500 ms. Sleeps
/// are clamped to the time left, with a final probe at the deadline.
///
/// # Errors
/// Returns an error if a search probe fails, or if no hit appears before
/// `timeout` elapses.
pub async fn wait_for_first_pid(
    client: &Client,
    scope: &Scope,
    query: &str,
    timeout: Duration,
) -> Result<String> {
    let deadline = Instant::now() + timeout;
    let mut delay = Duration::from_millis(50);
    loop {
        let hits = client
            .search(query, scope.clone())
            .limit(50)
            .await
            .context("search probe failed")?;
        if let Some(first) = hits.list().first() {
            return Ok(first.pid.clone());
        }
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            break;
        }
        tokio::time::sleep(delay.min(remaining)).await;
        delay = (delay * 2).min(Duration::from_millis(500));
    }
    anyhow::bail!("no indexed row appeared in scope within {timeout:?}")
}
/// Polls the store until at least one `MemoryKind::Semantic` memory whose
/// `source_pid` equals `source_pid` exists among the pids returned by
/// `indexed_pids_in_scope(scope)`. Returns all such memories.
///
/// It probes immediately, then backs off: 200 ms, doubling, capped at 2 s. Sleeps
/// are clamped to the time left, with a final probe at the deadline.
///
/// # Errors
/// Returns an error if either store probe fails, or if no matching semantic
/// memory appears before `timeout` elapses.
pub async fn wait_until_extracted(
    client: &Client,
    scope: &Scope,
    source_pid: &str,
    timeout: Duration,
) -> Result<Vec<memoir_core::memory::Memory>> {
    use memoir_core::store::MemoryStore;
    let deadline = Instant::now() + timeout;
    let mut delay = Duration::from_millis(200);
    loop {
        let pids = client
            .store()
            .indexed_pids_in_scope(scope)
            .await
            .context("indexed_pids_in_scope probe failed")?;
        let pid_refs: Vec<&str> = pids.iter().map(String::as_str).collect();
        let rows = client
            .store()
            .find_by_pids(&pid_refs)
            .await
            .context("find_by_pids probe failed")?;
        let semantics: Vec<_> = rows
            .into_iter()
            .filter(|m| m.kind == MemoryKind::Semantic && m.source_pid.as_deref() == Some(source_pid))
            .collect();
        if !semantics.is_empty() {
            return Ok(semantics);
        }
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            break;
        }
        tokio::time::sleep(delay.min(remaining)).await;
        delay = (delay * 2).min(Duration::from_secs(2));
    }
    anyhow::bail!("no semantic memories observed for source_pid {source_pid} within {timeout:?}")
}
/// Polls `client.inspect_graph()` for `scope` until `ready` accepts the snapshot,
/// then returns that snapshot.
///
/// It probes immediately, then backs off: 200 ms, doubling, capped at 2 s. Sleeps
/// are clamped to the time left, with a final probe at the deadline.
///
/// # Errors
/// Returns an error if an `inspect_graph` probe fails, or if `ready` never returns
/// `true` before `timeout` elapses.
#[cfg(feature = "knowledge-graph")]
pub async fn wait_until_graph_committed(
    client: &Client,
    scope: &Scope,
    timeout: Duration,
    ready: impl Fn(&memoir_core::graph::GraphSnapshot) -> bool,
) -> Result<memoir_core::graph::GraphSnapshot> {
    let deadline = Instant::now() + timeout;
    let mut delay = Duration::from_millis(200);
    loop {
        let snapshot = client
            .inspect_graph()
            .agent(scope.agent_id.clone())
            .org(scope.org_id.clone())
            .user(scope.user_id.clone())
            .await
            .context("inspect_graph probe failed")?;
        if ready(&snapshot) {
            return Ok(snapshot);
        }
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            break;
        }
        tokio::time::sleep(delay.min(remaining)).await;
        delay = (delay * 2).min(Duration::from_secs(2));
    }
    anyhow::bail!("graph did not reach the expected state for scope {scope:?} within {timeout:?}")
}