use std::convert::Infallible;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use axum::extract::{FromRequestParts, Path, Query, State};
use axum::http::request::Parts;
use axum::http::{HeaderValue, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use futures::Stream;
use serde::{Deserialize, Serialize};
use solo_core::{
Confidence, DocumentId, EncodingContext, Episode, InvalidateEvent, MemoryId, TenantId,
Tier,
};
use solo_storage::{TenantHandle, TenantRegistry};
use tokio::sync::broadcast;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::trace::TraceLayer;
use crate::auth::{AuthConfig, AuthenticatedPrincipal, middleware::AuthValidator};
/// Shared state handed to every HTTP route in this module.
///
/// The struct is `Clone`; the routers below clone it once per sub-router
/// (`with_state(state.clone())`).
#[derive(Clone)]
pub struct SoloHttpState {
/// Tenant registry. `TenantExtractor` calls `get_or_open` on it to obtain
/// the per-tenant handle for the current request.
pub registry: Arc<TenantRegistry>,
/// Tenant used when a request carries neither an authenticated tenant
/// claim nor a tenant header. Also passed to `AuthValidator::from_config`
/// when auth is enabled.
pub default_tenant: TenantId,
/// NOTE(review): not referenced in the visible part of this file —
/// presumably alternate names that refer to the user; confirm at the use site.
pub user_aliases: Arc<Vec<String>>,
/// Session store handed to `mcp_session_middleware`, which is layered on
/// the `/mcp` route.
pub mcp_sessions: crate::mcp_session::SessionStore,
}
/// Name of the request header `TenantExtractor` reads to select a tenant
/// when the authenticated principal (if any) carries no tenant claim.
pub const TENANT_HEADER: &str = "x-solo-tenant";
pub struct TenantExtractor(pub Arc<TenantHandle>);
impl<S> FromRequestParts<S> for TenantExtractor
where
SoloHttpState: FromRef<S>,
S: Send + Sync,
{
type Rejection = ApiError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let state = SoloHttpState::from_ref(state);
let resolved = if let Some(principal) = parts.extensions.get::<AuthenticatedPrincipal>()
&& let Some(claim) = principal.tenant_claim.clone()
{
claim
} else {
match parts.headers.get(TENANT_HEADER) {
None => state.default_tenant.clone(),
Some(raw) => {
let s = raw.to_str().map_err(|e| {
ApiError::bad_request(format!(
"{TENANT_HEADER}: header value must be ASCII ({e})"
))
})?;
TenantId::new(s.to_string()).map_err(|e| {
ApiError::bad_request(format!("{TENANT_HEADER}: invalid tenant id: {e}"))
})?
}
}
};
let handle = state.registry.get_or_open(&resolved).await.map_err(|e| {
use solo_core::Error;
match &e {
Error::NotFound(_) => ApiError::not_found(e.to_string()),
Error::InvalidInput(_) => ApiError::bad_request(e.to_string()),
_ => ApiError::internal(e.to_string()),
}
})?;
Ok(TenantExtractor(handle))
}
}
use axum::extract::FromRef;
pub struct AuditPrincipal(pub Option<String>);
impl<S> FromRequestParts<S> for AuditPrincipal
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
Ok(AuditPrincipal(
parts
.extensions
.get::<AuthenticatedPrincipal>()
.map(|p| p.subject.clone()),
))
}
}
pub struct MaybePrincipal(pub Option<AuthenticatedPrincipal>);
impl<S> FromRequestParts<S> for MaybePrincipal
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
Ok(MaybePrincipal(
parts
.extensions
.get::<AuthenticatedPrincipal>()
.cloned(),
))
}
}
/// Builds the HTTP router with optional bearer-token auth.
///
/// `Some(token)` is wrapped in `AuthConfig::Bearer`; `None` yields a router
/// without the auth middleware. Delegates to [`router_with_auth_config`].
pub fn router_with_auth(state: SoloHttpState, bearer_token: Option<String>) -> Router {
    let auth = match bearer_token {
        Some(token) => Some(AuthConfig::Bearer { token }),
        None => None,
    };
    router_with_auth_config(state, auth)
}
/// Assembles the complete HTTP router.
///
/// * `/health` and `/openapi.json` are registered on a separate router that
///   never receives the auth middleware.
/// * Every other route (memory, documents, graph, tenants, `/mcp`) is merged
///   into one router; when `auth` is `Some`, the auth middleware is layered
///   on that router only.
/// * `/mcp` additionally gets the MCP session middleware.
/// * The CORS layer and then the trace layer wrap the merged result, so they
///   sit outside the auth middleware.
pub fn router_with_auth_config(state: SoloHttpState, auth: Option<AuthConfig>) -> Router {
    // Routes that stay reachable without credentials.
    let open_routes = Router::new()
        .route("/health", get(|| async { "ok" }))
        .route("/openapi.json", get(openapi_handler));

    let api_routes = Router::new()
        .route("/memory", post(remember_handler))
        .route("/memory/search", post(recall_handler))
        .route("/memory/consolidate", post(consolidate_handler))
        .route("/memory/{id}", get(inspect_handler).delete(forget_handler))
        .route("/backup", post(backup_handler))
        .route("/memory/themes", get(themes_handler))
        .route("/memory/facts_about", get(facts_about_handler))
        .route("/memory/contradictions", get(contradictions_handler))
        .route("/memory/clusters/{cluster_id}", get(inspect_cluster_handler))
        .route("/memory/documents/search", post(search_docs_handler))
        .route(
            "/memory/documents",
            post(ingest_document_handler).get(list_documents_handler),
        )
        .route(
            "/memory/documents/{id}",
            get(inspect_document_handler).delete(forget_document_handler),
        )
        .route("/v1/graph/expand", get(graph_expand_handler))
        .route("/v1/graph/nodes", get(graph_nodes_handler))
        .route("/v1/graph/edges", get(graph_edges_handler))
        .route("/v1/graph/inspect/{id}", get(graph_inspect_handler))
        .route("/v1/graph/neighbors/{id}", get(graph_neighbors_handler))
        .route("/v1/graph/stream", get(graph_stream_handler))
        .route("/v1/tenants", get(tenants_list_handler))
        .with_state(state.clone());

    // The session middleware is scoped to `/mcp` only.
    let mcp_routes: Router<SoloHttpState> = Router::new()
        .route("/mcp", post(mcp_http_post_handler).get(mcp_http_get_handler))
        .layer(axum::middleware::from_fn_with_state(
            state.mcp_sessions.clone(),
            crate::mcp_session::mcp_session_middleware,
        ));

    let protected = api_routes.merge(mcp_routes.with_state(state.clone()));
    let protected = match auth {
        None => protected,
        Some(cfg) => {
            let validator = Arc::new(AuthValidator::from_config(
                &cfg,
                state.default_tenant.clone(),
            ));
            protected.layer(axum::middleware::from_fn_with_state(
                validator,
                crate::auth::middleware::auth_middleware,
            ))
        }
    };

    open_routes
        .merge(protected)
        .layer(build_cors_layer())
        .layer(TraceLayer::new_for_http())
}
pub fn router(state: SoloHttpState) -> Router {
router_with_auth_config(state, None)
}
fn build_cors_layer() -> CorsLayer {
CorsLayer::new()
.allow_origin(AllowOrigin::predicate(|origin: &HeaderValue, _req| {
origin
.to_str()
.map(is_localhost_origin)
.unwrap_or(false)
}))
.allow_methods([Method::GET, Method::POST, Method::DELETE, Method::OPTIONS])
.allow_headers([
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
axum::http::HeaderName::from_static("x-solo-tenant"),
axum::http::HeaderName::from_static("mcp-session-id"),
axum::http::HeaderName::from_static(crate::mcp_session::MCP_LAST_EVENT_ID_HEADER),
])
}
/// Returns `true` when `origin` is an `http://` or `https://` origin whose
/// host is exactly `localhost`, `127.0.0.1`, or `[::1]`, optionally followed
/// by `:port`.
///
/// The check is deliberately strict, because it gates CORS access:
/// * an authority containing userinfo (`@`) is rejected — in
///   `http://localhost:1@evil.com` the real host is `evil.com`;
/// * a port, when present, must consist of ASCII digits only;
/// * a bracketed IPv6 literal may be followed only by an optional `:port`,
///   so `http://[::1].evil.com` is rejected.
///
/// Anything after the first `/` following the scheme is ignored. Matching is
/// case-sensitive (browsers serialise origins in lowercase).
fn is_localhost_origin(origin: &str) -> bool {
    let rest = match origin
        .strip_prefix("http://")
        .or_else(|| origin.strip_prefix("https://"))
    {
        Some(rest) => rest,
        None => return false,
    };

    // Keep only the authority component (`host[:port]`).
    let authority = rest.split('/').next().unwrap_or(rest);

    // A serialised origin never carries userinfo; with one present, the text
    // before `@` is not the host, so refuse outright.
    if authority.contains('@') {
        return false;
    }

    let (host, port) = if authority.starts_with('[') {
        // Bracketed IPv6 literal: the host runs through the closing bracket.
        let end = match authority.find(']') {
            Some(end) => end,
            None => return false,
        };
        let (host, tail) = authority.split_at(end + 1);
        match tail.strip_prefix(':') {
            Some(port) => (host, port),
            None if tail.is_empty() => (host, ""),
            // Trailing junk such as `[::1].evil.com`.
            None => return false,
        }
    } else {
        authority.split_once(':').unwrap_or((authority, ""))
    };

    // An empty port (`localhost` or `localhost:`) passes vacuously.
    if !port.bytes().all(|b| b.is_ascii_digit()) {
        return false;
    }

    matches!(host, "localhost" | "127.0.0.1" | "[::1]")
}
/// Serves the HTTP API on `addr` with optional bearer-token auth.
///
/// `Some(token)` is wrapped in `AuthConfig::Bearer`; `None` serves without
/// the auth middleware. Delegates to [`serve_http_with_auth_config`] and
/// returns whatever it returns.
pub async fn serve_http(
    addr: SocketAddr,
    state: SoloHttpState,
    bearer_token: Option<String>,
    shutdown: impl std::future::Future<Output = ()> + Send + 'static,
) -> std::io::Result<()> {
    let auth = match bearer_token {
        Some(token) => Some(AuthConfig::Bearer { token }),
        None => None,
    };
    serve_http_with_auth_config(addr, state, auth, shutdown).await
}
/// Binds a TCP listener on `addr` and serves the router built by
/// [`router_with_auth_config`], with `shutdown` passed to axum's
/// graceful-shutdown hook.
///
/// Logs one `info` event (address + auth kind) after the bind succeeds.
///
/// # Errors
///
/// Returns the I/O error from binding the listener or from the server itself.
pub async fn serve_http_with_auth_config(
    addr: SocketAddr,
    state: SoloHttpState,
    auth: Option<AuthConfig>,
    shutdown: impl std::future::Future<Output = ()> + Send + 'static,
) -> std::io::Result<()> {
    // Capture the label first: `auth` is moved into the router below.
    let auth_label = match auth.as_ref() {
        None => "none",
        Some(AuthConfig::Bearer { .. }) => "bearer",
        Some(AuthConfig::Oidc { .. }) => "oidc",
    };

    let router = router_with_auth_config(state, auth);
    let socket = tokio::net::TcpListener::bind(addr).await?;
    tracing::info!(%addr, auth = auth_label, "solo http: listening");

    axum::serve(socket, router)
        .with_graceful_shutdown(shutdown)
        .await
}
/// `GET /openapi.json`: returns the document produced by [`openapi_spec`] as JSON.
async fn openapi_handler() -> Json<serde_json::Value> {
    let spec = openapi_spec();
    Json(spec)
}
/// Builds the OpenAPI 3.1 document served at `GET /openapi.json`.
///
/// The document is a hand-written `serde_json::json!` literal; the only
/// computed value is `info.version`, taken from `CARGO_PKG_VERSION` at
/// compile time. It describes the `/health`, `/openapi.json`, `/memory*`
/// and `/backup` routes.
///
/// Because the document declares `openapi: 3.1.0`, nullable fields are
/// expressed with a JSON-Schema type array (`["integer", "null"]`); the
/// OpenAPI 3.0 `nullable` keyword does not exist in 3.1.
pub fn openapi_spec() -> serde_json::Value {
    serde_json::json!({
        "openapi": "3.1.0",
        "info": {
            "title": "Solo HTTP API",
            "description":
                "Local-first personal memory daemon. The HTTP transport \
                 mirrors the four MCP tools (memory_remember / recall / \
                 inspect / forget). Default deployment is loopback-only \
                 (127.0.0.1); LAN-bound deployments require a bearer \
                 token via `solo http-serve --bind <ip> --bearer-token-file <path>`.",
            "version": env!("CARGO_PKG_VERSION"),
            "license": { "name": "Apache-2.0" }
        },
        "servers": [
            { "url": "http://127.0.0.1:7437", "description": "Default loopback (replace port with your --http-port)" }
        ],
        "components": {
            "securitySchemes": {
                "bearerAuth": {
                    "type": "http",
                    "scheme": "bearer",
                    "description":
                        "Bearer-token auth. Required only on LAN-bound deployments \
                         (`solo http-serve --bind <non-loopback> --bearer-token-file <path>`); \
                         the default `127.0.0.1` deployment is unauthenticated. \
                         `GET /health` and `GET /openapi.json` are exempt from auth even \
                         on bearer-protected instances."
                }
            },
            "schemas": {
                "RememberRequest": {
                    "type": "object",
                    "required": ["content"],
                    "properties": {
                        "content": { "type": "string", "minLength": 1, "description": "Episode content to embed + store." },
                        "source_type": { "type": "string", "description": "Free-form source tag (e.g. `user_message`, `tool_output`). Defaults to `user_message`." },
                        "source_id": { "type": "string", "description": "Optional upstream ID for traceability." }
                    },
                    "additionalProperties": false
                },
                "RememberResponse": {
                    "type": "object",
                    "required": ["memory_id"],
                    "properties": {
                        "memory_id": { "type": "string", "format": "uuid", "description": "UUID v7 assigned to the new episode." }
                    }
                },
                "RecallRequest": {
                    "type": "object",
                    "required": ["query"],
                    "properties": {
                        "query": { "type": "string", "minLength": 1, "description": "Natural-language query; embedded by the same model as stored episodes." },
                        "limit": { "type": "integer", "minimum": 1, "maximum": 50, "default": 5, "description": "Max number of hits to return." }
                    },
                    "additionalProperties": false
                },
                "RecallResult": {
                    "type": "object",
                    "description":
                        "Recall response. Fields are stable across v0.1 but not exhaustively documented here — \
                         see `solo_query::RecallResult` in the source for the canonical shape. \
                         Treat as a forward-compatible JSON object.",
                    "additionalProperties": true
                },
                "ConsolidationScope": {
                    "type": "object",
                    "description": "Filter + flags for consolidation. All fields optional; empty body = unbounded defaults.",
                    "properties": {
                        // OpenAPI 3.1: nullability is a type-array member, not `nullable: true`.
                        "window_days": { "type": ["integer", "null"], "description": "Restrict to memories with ts_ms >= now - window_days * 86400000. Null/omitted = unbounded." },
                        "force_merge": { "type": "boolean", "default": false, "description": "Run the existing-vs-existing merge + abstraction-regen passes even with zero unclustered candidates. Drift catch-up on quiet corpora. Added in 0.3.1." }
                    },
                    "additionalProperties": false
                },
                "ConsolidationReport": {
                    "type": "object",
                    "required": [
                        "episodes_seen", "clusters_built", "clusters_merged",
                        "clusters_absorbed", "existing_clusters_merged",
                        "episodes_clustered", "abstractions_built",
                        "abstractions_regenerated", "triples_built",
                        "contradictions_found"
                    ],
                    "properties": {
                        "episodes_seen": { "type": "integer", "minimum": 0 },
                        "clusters_built": { "type": "integer", "minimum": 0, "description": "Brand-new clusters that survived to be persisted (post in-run-merge, post cross-run-absorb)." },
                        "clusters_merged": { "type": "integer", "minimum": 0, "description": "In-run merge: clusters absorbed into a sibling within this consolidate run (cross-UTC-bucket case). Counts losers." },
                        "clusters_absorbed": { "type": "integer", "minimum": 0, "description": "Cross-run absorb: freshly-built clusters folded into a pre-existing DB cluster with a similar centroid. Counts new-side clusters." },
                        "existing_clusters_merged": { "type": "integer", "minimum": 0, "description": "Existing-vs-existing merge: pre-existing DB clusters that drifted toward each other and now coalesce. Counts losers." },
                        "episodes_clustered": { "type": "integer", "minimum": 0 },
                        "abstractions_built": { "type": "integer", "minimum": 0, "description": "Fresh abstractions persisted for newly-built clusters. 0 when no LlmClient is wired." },
                        "abstractions_regenerated": { "type": "integer", "minimum": 0, "description": "Existing clusters whose stale abstractions were dropped and rebuilt because absorb or existing-merge changed their episode set. 0 without an LlmClient." },
                        "triples_built": { "type": "integer", "minimum": 0 },
                        "contradictions_found": { "type": "integer", "minimum": 0 }
                    }
                },
                "EpisodeRecord": {
                    "type": "object",
                    "description":
                        "Inspect response: full episode record. Fields are stable across v0.1 but not \
                         exhaustively documented here — see `solo_query::EpisodeRecord` in the source. \
                         Treat as a forward-compatible JSON object.",
                    "additionalProperties": true
                },
                "ThemeHit": {
                    "type": "object",
                    "description":
                        "One cluster + its (optional) abstraction. Returned by GET /memory/themes. \
                         See `solo_query::ThemeHit` for the canonical shape: cluster_id, \
                         abstraction_id?, abstraction_text?, episode_count, coherence, created_at_ms.",
                    "additionalProperties": true
                },
                "FactHit": {
                    "type": "object",
                    "description":
                        "One Steward-extracted SPO triple. Returned by GET /memory/facts_about. \
                         See `solo_query::FactHit` for fields: triple_id, subject_id, predicate, \
                         object_id, object_kind, valid_from_ms, valid_to_ms?, confidence, cluster_id?.",
                    "additionalProperties": true
                },
                "ContradictionHit": {
                    "type": "object",
                    "description":
                        "One Steward-flagged contradiction with each side's triple LEFT JOIN'd in. \
                         Returned by GET /memory/contradictions. See `solo_query::ContradictionHit`: \
                         a_id, b_id, kind, explanation, detected_at_ms, a_triple?, b_triple?.",
                    "additionalProperties": true
                },
                "ClusterRecord": {
                    "type": "object",
                    "description":
                        "Snapshot of one cluster — its row, optional abstraction, and source episodes \
                         (content truncated to 200 chars unless ?full_content=true). Returned by \
                         GET /memory/clusters/{cluster_id}. See `solo_query::ClusterRecord`.",
                    "additionalProperties": true
                },
                "IngestDocumentRequest": {
                    "type": "object",
                    "required": ["path"],
                    "properties": {
                        "path": {
                            "type": "string",
                            "minLength": 1,
                            "description":
                                "Server-side absolute path to the file to ingest. The file must be \
                                 readable by the Solo process. Supported formats: plaintext / \
                                 markdown / code, HTML, PDF."
                        }
                    },
                    "additionalProperties": false
                },
                "IngestReport": {
                    "type": "object",
                    "description":
                        "Returned by POST /memory/documents. Reports the document id assigned, \
                         the number of chunks persisted + embedded, the total byte size, and a \
                         `deduped` flag (true when the same content_hash was already present and \
                         the existing doc_id was returned unchanged). See `solo_storage::IngestReport`.",
                    "required": ["doc_id", "chunks_persisted", "bytes_ingested", "deduped"],
                    "properties": {
                        "doc_id": { "type": "string", "format": "uuid" },
                        "chunks_persisted": { "type": "integer", "minimum": 0 },
                        "bytes_ingested": { "type": "integer", "minimum": 0, "format": "int64" },
                        "deduped": { "type": "boolean" }
                    },
                    "additionalProperties": false
                },
                "ForgetDocumentReport": {
                    "type": "object",
                    "description":
                        "Returned by DELETE /memory/documents/{id}. Reports the doc_id soft-deleted \
                         and how many chunk rowids were tombstoned in the HNSW index. The chunk rows \
                         themselves survive in SQL for forensic value. See `solo_storage::ForgetDocumentReport`.",
                    "required": ["doc_id", "chunks_tombstoned"],
                    "properties": {
                        "doc_id": { "type": "string", "format": "uuid" },
                        "chunks_tombstoned": { "type": "integer", "minimum": 0 }
                    },
                    "additionalProperties": false
                },
                "SearchDocsRequest": {
                    "type": "object",
                    "required": ["query"],
                    "properties": {
                        "query": { "type": "string", "minLength": 1 },
                        "limit": { "type": "integer", "minimum": 1, "maximum": 100, "default": 5 }
                    },
                    "additionalProperties": false
                },
                "DocSearchHit": {
                    "type": "object",
                    "description":
                        "One chunk hit + parent-doc context. Fields per `solo_query::DocSearchHit`: \
                         chunk_id, doc_id, doc_title?, doc_source?, doc_mime_type?, chunk_index, \
                         content, cos_distance, start_offset, end_offset.",
                    "additionalProperties": true
                },
                "DocumentInspectResult": {
                    "type": "object",
                    "description":
                        "Returned by GET /memory/documents/{id}. A `document` record (full metadata) \
                         plus an ordered list of chunk summaries (each preview truncated to 200 \
                         chars). See `solo_query::DocumentInspectResult`.",
                    "additionalProperties": true
                },
                "DocumentSummary": {
                    "type": "object",
                    "description":
                        "One row from GET /memory/documents. Fields per `solo_query::DocumentSummary`: \
                         doc_id, title?, source?, mime_type?, ingested_at_ms, chunk_count, status.",
                    "additionalProperties": true
                },
                "ApiError": {
                    "type": "object",
                    "required": ["error", "status"],
                    "properties": {
                        "error": { "type": "string" },
                        "status": { "type": "integer", "minimum": 400, "maximum": 599 }
                    }
                }
            }
        },
        "paths": {
            "/health": {
                "get": {
                    "summary": "Liveness probe",
                    "description": "Returns plain text `ok`. Always unauthenticated.",
                    "responses": {
                        "200": {
                            "description": "Server is up.",
                            "content": { "text/plain": { "schema": { "type": "string", "example": "ok" } } }
                        }
                    }
                }
            },
            "/openapi.json": {
                "get": {
                    "summary": "Self-describing OpenAPI 3.1 spec",
                    "description": "Returns this document. Always unauthenticated.",
                    "responses": {
                        "200": {
                            "description": "OpenAPI 3.1 document.",
                            "content": { "application/json": { "schema": { "type": "object" } } }
                        }
                    }
                }
            },
            "/memory": {
                "post": {
                    "summary": "Remember (store an episode)",
                    "description": "Equivalent to MCP tool `memory_remember`.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "requestBody": {
                        "required": true,
                        "content": { "application/json": { "schema": { "$ref": "#/components/schemas/RememberRequest" } } }
                    },
                    "responses": {
                        "200": {
                            "description": "Memory stored; returns the new MemoryId.",
                            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/RememberResponse" } } }
                        },
                        "400": { "description": "Bad request (e.g. empty content).", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/search": {
                "post": {
                    "summary": "Recall (vector search)",
                    "description": "Equivalent to MCP tool `memory_recall`. Embeds the query, runs HNSW search, returns the top-K hits in cosine-distance order.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "requestBody": {
                        "required": true,
                        "content": { "application/json": { "schema": { "$ref": "#/components/schemas/RecallRequest" } } }
                    },
                    "responses": {
                        "200": {
                            "description": "Search results.",
                            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/RecallResult" } } }
                        },
                        "400": { "description": "Bad request (e.g. empty query).", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/consolidate": {
                "post": {
                    "summary": "Run a consolidation pass (clustering + abstraction)",
                    "description":
                        "Idempotent. Triggers the SWS-equivalent clustering pass; if a `Steward` LLM is wired \
                         on the server, also runs the REM-equivalent abstraction pass that populates \
                         `semantic_abstractions` and `triples`. Empty request body = default scope (unbounded \
                         window). Equivalent to the `solo consolidate` CLI.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "requestBody": {
                        "required": false,
                        "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ConsolidationScope" } } }
                    },
                    "responses": {
                        "200": {
                            "description": "Consolidation complete; report counts the work done.",
                            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ConsolidationReport" } } }
                        },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/backup": {
                "post": {
                    "summary": "Online encrypted backup",
                    "description":
                        "Run an online SQLCipher backup of the live data dir to a server-side path. \
                         The destination file is encrypted with the same Argon2id-derived raw key as \
                         the source, so it restores under the same passphrase + a copy of the source's \
                         `solo.config.toml`. Hot — the backup runs against the writer's existing \
                         connection without taking the lockfile, so the daemon keeps serving reads + \
                         writes during the operation. v0.3.2+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "requestBody": {
                        "required": true,
                        "content": { "application/json": { "schema": {
                            "type": "object",
                            "properties": {
                                "to": { "type": "string", "description": "Server-side absolute path for the backup file." },
                                "force": { "type": "boolean", "description": "Overwrite an existing destination file. Default false.", "default": false }
                            },
                            "required": ["to"]
                        } } }
                    },
                    "responses": {
                        "200": {
                            "description": "Backup complete; reports the destination path + elapsed milliseconds.",
                            "content": { "application/json": { "schema": {
                                "type": "object",
                                "properties": {
                                    "path": { "type": "string" },
                                    "elapsed_ms": { "type": "integer", "format": "int64" }
                                }
                            } } }
                        },
                        "400": { "description": "Destination invalid, exists without force, or its parent doesn't exist." },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." },
                        "500": { "description": "Backup failed (disk full, permission denied, etc.)." }
                    }
                }
            },
            "/memory/{id}": {
                "get": {
                    "summary": "Inspect a memory by ID",
                    "description": "Equivalent to MCP tool `memory_inspect`.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [{
                        "name": "id",
                        "in": "path",
                        "required": true,
                        "schema": { "type": "string", "format": "uuid" },
                        "description": "MemoryId (UUID v7)."
                    }],
                    "responses": {
                        "200": {
                            "description": "Episode record.",
                            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/EpisodeRecord" } } }
                        },
                        "400": { "description": "Malformed ID.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "404": { "description": "No such memory.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                },
                "delete": {
                    "summary": "Forget (soft-delete) a memory by ID",
                    "description":
                        "Equivalent to MCP tool `memory_forget`. Soft-delete: flips `episodes.status = 'forgotten'` \
                         and tombstones the HNSW vector. The row + embedding are preserved for forensics; \
                         re-running `solo reembed` after this does NOT restore visibility.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [
                        { "name": "id", "in": "path", "required": true, "schema": { "type": "string", "format": "uuid" } },
                        { "name": "reason", "in": "query", "required": false, "schema": { "type": "string" }, "description": "Free-form reason logged via tracing (not yet persisted to the DB)." }
                    ],
                    "responses": {
                        "204": { "description": "Forgotten (or already forgotten — idempotent)." },
                        "400": { "description": "Malformed ID.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "404": { "description": "No such memory.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/themes": {
                "get": {
                    "summary": "List recent cluster themes",
                    "description":
                        "Equivalent to MCP tool `memory_themes`. List cluster abstractions ordered by \
                         most-recent first. Use to surface 'what has the user been thinking about lately' \
                         without paging through individual episodes. v0.4.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [
                        { "name": "window_days", "in": "query", "required": false, "schema": { "type": "integer", "minimum": 1 }, "description": "Optional time window. Omit for unfiltered (all-time, most-recent first)." },
                        { "name": "limit", "in": "query", "required": false, "schema": { "type": "integer", "minimum": 1, "maximum": 100, "default": 5 } }
                    ],
                    "responses": {
                        "200": {
                            "description": "Array of ThemeHits (possibly empty).",
                            "content": { "application/json": { "schema": { "type": "array", "items": { "$ref": "#/components/schemas/ThemeHit" } } } }
                        },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/facts_about": {
                "get": {
                    "summary": "Query the SPO knowledge graph by subject",
                    "description":
                        "Equivalent to MCP tool `memory_facts_about`. Query Steward-extracted triples by \
                         subject + optional predicate + optional time window. Subject is required \
                         (predicate-only scans not supported). Pass `include_as_object=true` (v0.5.1+) \
                         to also surface rows where `subject` appears as the object. v0.4.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [
                        { "name": "subject", "in": "query", "required": true, "schema": { "type": "string", "minLength": 1 }, "description": "Subject id to query (e.g. `Sam`)." },
                        { "name": "predicate", "in": "query", "required": false, "schema": { "type": "string" }, "description": "Optional predicate filter (e.g. `works_at`)." },
                        { "name": "since_ms", "in": "query", "required": false, "schema": { "type": "integer" }, "description": "Optional valid_from_ms lower bound (epoch ms)." },
                        { "name": "until_ms", "in": "query", "required": false, "schema": { "type": "integer" }, "description": "Optional valid_to_ms upper bound (epoch ms). NULL upper bounds (still-valid facts) pass through." },
                        { "name": "include_as_object", "in": "query", "required": false, "schema": { "type": "boolean", "default": false }, "description": "If true, also match rows where `subject` appears as the object (e.g. surface 'Sam pushes back on PRs about Maya' under subject='Maya'). Default false. v0.5.1+." },
                        { "name": "limit", "in": "query", "required": false, "schema": { "type": "integer", "minimum": 1, "maximum": 100, "default": 5 } }
                    ],
                    "responses": {
                        "200": {
                            "description": "Array of FactHits (possibly empty).",
                            "content": { "application/json": { "schema": { "type": "array", "items": { "$ref": "#/components/schemas/FactHit" } } } }
                        },
                        "400": { "description": "Bad request (e.g. empty subject).", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/contradictions": {
                "get": {
                    "summary": "List Steward-flagged contradictions",
                    "description":
                        "Equivalent to MCP tool `memory_contradictions`. Each result includes both \
                         sides' triple SPO via LEFT JOIN for context. v0.4.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [
                        { "name": "limit", "in": "query", "required": false, "schema": { "type": "integer", "minimum": 1, "maximum": 100, "default": 5 } }
                    ],
                    "responses": {
                        "200": {
                            "description": "Array of ContradictionHits (possibly empty).",
                            "content": { "application/json": { "schema": { "type": "array", "items": { "$ref": "#/components/schemas/ContradictionHit" } } } }
                        },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/clusters/{cluster_id}": {
                "get": {
                    "summary": "Inspect a single cluster",
                    "description":
                        "Equivalent to MCP tool `memory_inspect_cluster`. Returns the cluster row, \
                         its (optional) abstraction, and its source episodes. By default each \
                         episode's `content` is truncated to 200 chars with a trailing `…`. Pass \
                         `?full_content=true` to get verbatim episode content. v0.5.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [
                        { "name": "cluster_id", "in": "path", "required": true, "schema": { "type": "string", "minLength": 1 }, "description": "Cluster id (from a previous GET /memory/themes response)." },
                        { "name": "full_content", "in": "query", "required": false, "schema": { "type": "boolean", "default": false }, "description": "If true, return episode content verbatim. Default false (truncate to 200 chars + ellipsis)." }
                    ],
                    "responses": {
                        "200": {
                            "description": "Cluster snapshot.",
                            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ClusterRecord" } } }
                        },
                        "400": { "description": "Bad request (e.g. empty cluster_id).", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "404": { "description": "No such cluster.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/documents": {
                "post": {
                    "summary": "Ingest a document",
                    "description":
                        "Equivalent to MCP tool `memory_ingest_document`. Reads the file at the \
                         supplied server-side path, parses + chunks + embeds, and persists under \
                         `documents` + `document_chunks`. Returns the new doc_id, chunk count, and \
                         a `deduped` flag (true when an existing document with the same content_hash \
                         was returned without re-embedding). v0.7.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "requestBody": {
                        "required": true,
                        "content": { "application/json": { "schema": { "$ref": "#/components/schemas/IngestDocumentRequest" } } }
                    },
                    "responses": {
                        "200": {
                            "description": "Document ingested (or deduplicated).",
                            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/IngestReport" } } }
                        },
                        "400": { "description": "Bad request (e.g. empty path, file unreadable, parse error).", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                },
                "get": {
                    "summary": "List ingested documents (paginated)",
                    "description":
                        "Equivalent to MCP tool `memory_list_documents`. Returns a paginated index, \
                         newest first. Forgotten documents are hidden by default; pass \
                         `?include_forgotten=true` to see them too. v0.7.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [
                        { "name": "limit", "in": "query", "required": false, "schema": { "type": "integer", "minimum": 1, "maximum": 100, "default": 20 } },
                        { "name": "offset", "in": "query", "required": false, "schema": { "type": "integer", "minimum": 0, "default": 0 } },
                        { "name": "include_forgotten", "in": "query", "required": false, "schema": { "type": "boolean", "default": false } }
                    ],
                    "responses": {
                        "200": {
                            "description": "Array of DocumentSummary (possibly empty).",
                            "content": { "application/json": { "schema": { "type": "array", "items": { "$ref": "#/components/schemas/DocumentSummary" } } } }
                        },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/documents/search": {
                "post": {
                    "summary": "Vector search across document chunks",
                    "description":
                        "Equivalent to MCP tool `memory_search_docs`. Embeds the query and returns \
                         up to `limit` matching chunks, best match first, each annotated with the \
                         parent document's title + source path. Forgotten documents are excluded. \
                         v0.7.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "requestBody": {
                        "required": true,
                        "content": { "application/json": { "schema": { "$ref": "#/components/schemas/SearchDocsRequest" } } }
                    },
                    "responses": {
                        "200": {
                            "description": "Array of DocSearchHits (possibly empty).",
                            "content": { "application/json": { "schema": { "type": "array", "items": { "$ref": "#/components/schemas/DocSearchHit" } } } }
                        },
                        "400": { "description": "Bad request (e.g. empty query).", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            },
            "/memory/documents/{id}": {
                "get": {
                    "summary": "Inspect one document",
                    "description":
                        "Equivalent to MCP tool `memory_inspect_document`. Returns the document's \
                         metadata plus a preview of every chunk (truncated to 200 chars). v0.7.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [
                        { "name": "id", "in": "path", "required": true, "schema": { "type": "string", "format": "uuid" }, "description": "DocumentId (UUID v7)." }
                    ],
                    "responses": {
                        "200": {
                            "description": "Document inspection result.",
                            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/DocumentInspectResult" } } }
                        },
                        "400": { "description": "Malformed id.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "404": { "description": "No such document.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                },
                "delete": {
                    "summary": "Forget (soft-delete) one document",
                    "description":
                        "Equivalent to MCP tool `memory_forget_document`. Flips `documents.status` \
                         to `forgotten` and tombstones every chunk's HNSW rowid. The chunk rows \
                         survive in SQL for forensic value. v0.7.0+.",
                    "security": [{ "bearerAuth": [] }, {}],
                    "parameters": [
                        { "name": "id", "in": "path", "required": true, "schema": { "type": "string", "format": "uuid" } }
                    ],
                    "responses": {
                        "200": {
                            "description": "Document soft-deleted; report counts chunks tombstoned.",
                            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ForgetDocumentReport" } } }
                        },
                        "400": { "description": "Malformed id.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "404": { "description": "No such document.", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/ApiError" } } } },
                        "401": { "description": "Missing or invalid bearer token (LAN-bound deployments only)." }
                    }
                }
            }
        }
    })
}
/// JSON request body accepted by `remember_handler` (`POST /memory`).
#[derive(Debug, Deserialize)]
struct RememberBody {
/// Episode text. The handler strips trailing whitespace and rejects the
/// request if nothing remains.
content: String,
/// Optional source tag; the handler substitutes `"user_message"` when absent.
#[serde(default)]
source_type: Option<String>,
/// Optional source identifier, stored on the episode unchanged.
#[serde(default)]
source_id: Option<String>,
}
/// JSON response produced by `remember_handler`.
#[derive(Debug, Serialize)]
struct RememberResponse {
/// String rendering of the id returned by the write path.
memory_id: String,
}
async fn remember_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Json(body): Json<RememberBody>,
) -> Result<Json<RememberResponse>, ApiError> {
let content = body.content.trim_end().to_string();
if content.is_empty() {
return Err(ApiError::bad_request("content must not be empty"));
}
let embedding = tenant.embedder().embed(&content).await.map_err(ApiError::from)?;
let episode = Episode {
memory_id: MemoryId::new(),
ts_ms: chrono::Utc::now().timestamp_millis(),
source_type: body.source_type.unwrap_or_else(|| "user_message".into()),
source_id: body.source_id,
content,
encoding_context: EncodingContext::default(),
provenance: None,
confidence: Confidence::new(0.9).unwrap(),
strength: 0.5,
salience: 0.5,
tier: Tier::Hot,
};
let mid = tenant
.write()
.remember_as(principal, episode, embedding)
.await
.map_err(ApiError::from)?;
Ok(Json(RememberResponse {
memory_id: mid.to_string(),
}))
}
/// JSON request body accepted by `recall_handler`.
#[derive(Debug, Deserialize)]
struct RecallBody {
/// Query text forwarded to `solo_query::run_recall`.
query: String,
/// Result-count argument forwarded to the recall call; `default_limit()`
/// supplies the value when the field is omitted.
#[serde(default = "default_limit")]
limit: usize,
}
/// Serde default for the `limit` field of the request/query structs that
/// reference it via `#[serde(default = "default_limit")]`.
fn default_limit() -> usize {
    const FALLBACK_LIMIT: usize = 5;
    FALLBACK_LIMIT
}
async fn recall_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Json(body): Json<RecallBody>,
) -> Result<Json<solo_query::RecallResult>, ApiError> {
let result = solo_query::run_recall(tenant.as_ref(), principal, &body.query, body.limit)
.await
.map_err(ApiError::from)?;
Ok(Json(result))
}
async fn inspect_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Path(id): Path<String>,
) -> Result<Json<solo_query::EpisodeRecord>, ApiError> {
let mid = MemoryId::from_str(&id)
.map_err(|e| ApiError::bad_request(format!("invalid id: {e}")))?;
let row = solo_query::inspect_one(tenant.read(), tenant.audit(), principal, mid)
.await
.map_err(ApiError::from)?;
Ok(Json(row))
}
/// Query-string parameters accepted by `themes_handler`.
#[derive(Debug, Deserialize)]
struct ThemesQuery {
/// Optional window, passed through to `solo_query::themes` unchanged.
#[serde(default)]
window_days: Option<i64>,
/// Result-count argument; `default_limit()` supplies the value when omitted.
#[serde(default = "default_limit")]
limit: usize,
}
async fn themes_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Query(q): Query<ThemesQuery>,
) -> Result<Json<Vec<solo_query::ThemeHit>>, ApiError> {
let hits = solo_query::themes(
tenant.read(),
tenant.audit(),
principal,
q.window_days,
q.limit,
)
.await
.map_err(ApiError::from)?;
Ok(Json(hits))
}
/// Query-string parameters accepted by `facts_about_handler`.
///
/// Apart from the non-empty check on `subject`, every field is forwarded to
/// `solo_query::facts_about` as-is.
#[derive(Debug, Deserialize)]
struct FactsAboutQuery {
/// Subject to look up; the handler rejects a blank (whitespace-only) value.
subject: String,
/// Optional predicate filter.
#[serde(default)]
predicate: Option<String>,
/// Optional lower time bound (presumably epoch milliseconds, going by the
/// `_ms` suffix — confirm against `solo_query::facts_about`).
#[serde(default)]
since_ms: Option<i64>,
/// Optional upper time bound (same unit caveat as `since_ms`).
#[serde(default)]
until_ms: Option<i64>,
/// Flag forwarded to the query layer; defaults to `false` when omitted.
#[serde(default)]
include_as_object: bool,
/// Result-count argument; `default_limit()` supplies the value when omitted.
#[serde(default = "default_limit")]
limit: usize,
}
async fn facts_about_handler(
State(s): State<SoloHttpState>,
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Query(q): Query<FactsAboutQuery>,
) -> Result<Json<Vec<solo_query::FactHit>>, ApiError> {
if q.subject.trim().is_empty() {
return Err(ApiError::bad_request("subject must not be empty"));
}
let hits = solo_query::facts_about(
tenant.read(),
tenant.audit(),
principal,
&q.subject,
&s.user_aliases,
q.include_as_object,
q.predicate.as_deref(),
q.since_ms,
q.until_ms,
q.limit,
)
.await
.map_err(ApiError::from)?;
Ok(Json(hits))
}
/// Query-string parameters accepted by `contradictions_handler`.
#[derive(Debug, Deserialize)]
struct ContradictionsQuery {
/// Result-count argument; `default_limit()` supplies the value when omitted.
#[serde(default = "default_limit")]
limit: usize,
}
async fn contradictions_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Query(q): Query<ContradictionsQuery>,
) -> Result<Json<Vec<solo_query::ContradictionHit>>, ApiError> {
let hits = solo_query::contradictions(tenant.read(), tenant.audit(), principal, q.limit)
.await
.map_err(ApiError::from)?;
Ok(Json(hits))
}
/// Query-string parameters accepted by `inspect_cluster_handler`.
#[derive(Debug, Deserialize, Default)]
struct InspectClusterQuery {
/// Flag forwarded to `solo_query::inspect_cluster`; defaults to `false`.
#[serde(default)]
full_content: bool,
}
async fn inspect_cluster_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Path(cluster_id): Path<String>,
Query(q): Query<InspectClusterQuery>,
) -> Result<Json<solo_query::ClusterRecord>, ApiError> {
if cluster_id.trim().is_empty() {
return Err(ApiError::bad_request("cluster_id must not be empty"));
}
let record = solo_query::inspect_cluster(
tenant.read(),
tenant.audit(),
principal,
&cluster_id,
q.full_content,
)
.await
.map_err(ApiError::from)?;
Ok(Json(record))
}
/// JSON request body accepted by `ingest_document_handler`.
#[derive(Debug, Deserialize)]
struct IngestDocumentBody {
/// Path of the document to ingest; the handler rejects a blank value and
/// converts the rest into a `PathBuf`.
path: String,
}
async fn ingest_document_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Json(body): Json<IngestDocumentBody>,
) -> Result<Json<solo_storage::IngestReport>, ApiError> {
if body.path.trim().is_empty() {
return Err(ApiError::bad_request("path must not be empty"));
}
let path = std::path::PathBuf::from(body.path);
let chunk_config = solo_storage::document::ChunkConfig::default();
let report = tenant
.write()
.ingest_document_as(principal, path, chunk_config)
.await
.map_err(ApiError::from)?;
Ok(Json(report))
}
/// JSON request body accepted by `search_docs_handler`.
#[derive(Debug, Deserialize)]
struct SearchDocsBody {
/// Query text forwarded to `solo_query::run_doc_search`.
query: String,
/// Result-count argument; `default_limit()` supplies the value when omitted.
#[serde(default = "default_limit")]
limit: usize,
}
async fn search_docs_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Json(body): Json<SearchDocsBody>,
) -> Result<Json<Vec<solo_query::DocSearchHit>>, ApiError> {
let hits = solo_query::run_doc_search(tenant.as_ref(), principal, &body.query, body.limit)
.await
.map_err(ApiError::from)?;
Ok(Json(hits))
}
async fn inspect_document_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Path(id): Path<String>,
) -> Result<Json<solo_query::DocumentInspectResult>, ApiError> {
let doc_id = DocumentId::from_str(&id)
.map_err(|e| ApiError::bad_request(format!("invalid id: {e}")))?;
let result_opt =
solo_query::inspect_document(tenant.read(), tenant.audit(), principal, &doc_id)
.await
.map_err(ApiError::from)?;
match result_opt {
Some(record) => Ok(Json(record)),
None => Err(ApiError::not_found(format!("document {doc_id} not found"))),
}
}
/// Query-string parameters accepted by `list_documents_handler`; all three
/// are forwarded to `solo_query::list_documents` unchanged.
#[derive(Debug, Deserialize)]
struct ListDocumentsQuery {
/// Page size; `default_list_documents_limit()` supplies the value when omitted.
#[serde(default = "default_list_documents_limit")]
limit: usize,
/// Offset argument; defaults to `0` when omitted.
#[serde(default)]
offset: usize,
/// Flag forwarded to the query layer; defaults to `false` when omitted.
#[serde(default)]
include_forgotten: bool,
}
/// Serde default for `ListDocumentsQuery::limit`.
fn default_list_documents_limit() -> usize {
    const FALLBACK_PAGE_SIZE: usize = 20;
    FALLBACK_PAGE_SIZE
}
async fn list_documents_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Query(q): Query<ListDocumentsQuery>,
) -> Result<Json<Vec<solo_query::DocumentSummary>>, ApiError> {
let rows = solo_query::list_documents(
tenant.read(),
tenant.audit(),
principal,
q.limit,
q.offset,
q.include_forgotten,
)
.await
.map_err(ApiError::from)?;
Ok(Json(rows))
}
async fn forget_document_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Path(id): Path<String>,
) -> Result<Json<solo_storage::ForgetDocumentReport>, ApiError> {
let doc_id = DocumentId::from_str(&id)
.map_err(|e| ApiError::bad_request(format!("invalid id: {e}")))?;
let report = tenant
.write()
.forget_document_as(principal, doc_id)
.await
.map_err(ApiError::from)?;
Ok(Json(report))
}
/// Query-string parameters accepted by `forget_handler`.
#[derive(Debug, Deserialize)]
struct ForgetQuery {
/// Optional reason passed to the writer; the handler substitutes `"http"`
/// when it is absent.
#[serde(default)]
reason: Option<String>,
}
async fn forget_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
Path(id): Path<String>,
Query(q): Query<ForgetQuery>,
) -> Result<StatusCode, ApiError> {
let mid = MemoryId::from_str(&id).map_err(|e| ApiError::bad_request(format!("invalid id: {e}")))?;
let reason = q.reason.unwrap_or_else(|| "http".into());
tenant
.write()
.forget_as(principal, mid, reason)
.await
.map_err(ApiError::from)?;
Ok(StatusCode::NO_CONTENT)
}
async fn consolidate_handler(
TenantExtractor(tenant): TenantExtractor,
AuditPrincipal(principal): AuditPrincipal,
body: axum::body::Bytes,
) -> Result<Json<solo_storage::ConsolidationReport>, ApiError> {
let scope = if body.is_empty() {
solo_storage::ConsolidationScope::default()
} else {
serde_json::from_slice(&body)
.map_err(|e| ApiError::bad_request(format!("invalid JSON: {e}")))?
};
let report = tenant
.write()
.consolidate_as(principal, scope)
.await
.map_err(ApiError::from)?;
Ok(Json(report))
}
/// JSON request body accepted by `backup_handler`.
#[derive(Debug, Deserialize)]
struct BackupBody {
/// Destination path for the backup; must not be empty.
to: String,
/// When `true`, an existing destination file is removed before the backup
/// runs; when `false` (the default) an existing destination is an error.
#[serde(default)]
force: bool,
}
/// JSON response produced by `backup_handler`.
#[derive(Debug, Serialize)]
struct BackupResponse {
/// Destination path as rendered by `Path::display`.
path: String,
/// Wall-clock duration of the writer's `backup` call, in milliseconds.
elapsed_ms: u64,
}
/// HTTP handler that writes a backup of the tenant database to `body.to`.
///
/// Checks, in order:
/// 1. `to` must not be empty (bad request).
/// 2. `to` must not refer to the same file as the live database (bad request).
/// 3. If the destination exists, `force` must be set (bad request otherwise);
///    with `force` the existing file is removed, and a removal failure is
///    reported as an internal error.
/// 4. A non-empty parent path must be an existing directory (bad request).
///
/// On success returns the destination path and the elapsed time of the
/// backup call in milliseconds.
///
/// NOTE(review): `exists`, `remove_file` and `is_dir` are synchronous
/// filesystem calls made from an async handler — confirm that is acceptable.
/// NOTE(review): with `force`, the old destination is deleted before the
/// backup runs, so a failed backup leaves no file at `dest`.
async fn backup_handler(
TenantExtractor(tenant): TenantExtractor,
Json(body): Json<BackupBody>,
) -> Result<Json<BackupResponse>, ApiError> {
use std::path::PathBuf;
let dest = PathBuf::from(&body.to);
if dest.as_os_str().is_empty() {
return Err(ApiError::bad_request("`to` must not be empty"));
}
// Must run before the removal below: deleting `dest` when it is the live
// database would destroy the source.
if solo_storage::paths_refer_to_same_file(tenant.db_path(), &dest) {
return Err(ApiError::bad_request(format!(
"destination {} is the same file as the source database; \
refusing to run (would corrupt the live database)",
dest.display()
)));
}
if dest.exists() {
if !body.force {
return Err(ApiError::bad_request(format!(
"destination {} exists; pass force=true to overwrite",
dest.display()
)));
}
// `force` was requested: clear the way for the new backup file.
std::fs::remove_file(&dest).map_err(|e| {
ApiError::internal(format!(
"remove existing destination {}: {e}",
dest.display()
))
})?;
}
// A bare file name has an empty parent, which is allowed and skipped here.
if let Some(parent) = dest.parent() {
if !parent.as_os_str().is_empty() && !parent.is_dir() {
return Err(ApiError::bad_request(format!(
"destination parent directory {} does not exist",
parent.display()
)));
}
}
// Only the backup call itself is timed, not the validation above.
let started = std::time::Instant::now();
tenant.write().backup(dest.clone()).await.map_err(ApiError::from)?;
let elapsed_ms = started.elapsed().as_millis() as u64;
Ok(Json(BackupResponse {
path: dest.display().to_string(),
elapsed_ms,
}))
}
/// `limit` applied by `graph_expand_handler` when the query omits it.
const GRAPH_EXPAND_DEFAULT_LIMIT: u32 = 25;
/// Upper bound that `graph_expand_handler` clamps the requested `limit` to.
const GRAPH_EXPAND_MAX_LIMIT: u32 = 100;
/// Edge kind requested from the graph-expand endpoint. Deserialized from the
/// snake_case names `cluster_member`, `document_chunk`, `triple`, `semantic`.
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum GraphExpandKind {
/// Dispatches to `expand_cluster_member`.
ClusterMember,
/// Dispatches to `expand_document_chunk`.
DocumentChunk,
/// Dispatches to `expand_triple`.
Triple,
/// Dispatches to `expand_semantic`.
Semantic,
}
/// Query-string parameters accepted by `graph_expand_handler`.
#[derive(Debug, Deserialize)]
struct GraphExpandQuery {
/// Prefixed node id (`<prefix>:<value>`), parsed by `parse_node_id`.
node_id: String,
/// Which kind of edge to expand from the node.
kind: GraphExpandKind,
/// Optional result cap; defaults to `GRAPH_EXPAND_DEFAULT_LIMIT` and is
/// clamped to `1..=GRAPH_EXPAND_MAX_LIMIT` by the handler.
#[serde(default)]
limit: Option<u32>,
}
/// Kind of node that can appear in the graph API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeKind {
    /// Node id prefix `ep:`.
    Episode,
    /// Node id prefix `doc:`.
    Document,
    /// Node id prefix `chunk:`.
    Chunk,
    /// Node id prefix `cl:`.
    Cluster,
    /// Node id prefix `ent:`.
    Entity,
}

impl NodeKind {
    /// Lower-case name used for this kind in serialized graph nodes and in
    /// error messages.
    fn as_wire_str(self) -> &'static str {
        match self {
            NodeKind::Episode => "episode",
            NodeKind::Document => "document",
            NodeKind::Chunk => "chunk",
            NodeKind::Cluster => "cluster",
            NodeKind::Entity => "entity",
        }
    }
}
/// Splits a prefixed node id of the form `<prefix>:<value>` at the first `:`.
///
/// Returns the kind selected by the prefix (`ep`, `doc`, `chunk`, `cl`,
/// `ent`) and the value borrowed from `raw`.
///
/// # Errors
///
/// Returns a bad-request error when `raw` contains no `:`, when the value
/// after the prefix is empty, or when the prefix is not one of the known ones.
/// The emptiness check runs before the prefix check.
fn parse_node_id(raw: &str) -> Result<(NodeKind, &str), ApiError> {
    let Some((prefix, value)) = raw.split_once(':') else {
        return Err(ApiError::bad_request(format!(
            "node_id must be `<prefix>:<value>` (one of ep:/doc:/chunk:/cl:/ent:); got {raw:?}"
        )));
    };

    if value.is_empty() {
        return Err(ApiError::bad_request(format!(
            "node_id value is empty after prefix: {raw:?}"
        )));
    }

    match prefix {
        "ep" => Ok((NodeKind::Episode, value)),
        "doc" => Ok((NodeKind::Document, value)),
        "chunk" => Ok((NodeKind::Chunk, value)),
        "cl" => Ok((NodeKind::Cluster, value)),
        "ent" => Ok((NodeKind::Entity, value)),
        other => Err(ApiError::bad_request(format!(
            "unknown node_id prefix {other:?}; expected one of ep:/doc:/chunk:/cl:/ent:"
        ))),
    }
}
/// One node in a graph API response.
#[derive(Debug, Serialize)]
struct GraphNode {
/// Prefixed node id, e.g. `ep:<memory_id>` or `cl:<cluster_id>`.
id: String,
/// Wire name of the node kind (see `NodeKind::as_wire_str`).
kind: &'static str,
/// Short display label.
label: String,
/// Timestamp for the node, omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
ts_ms: Option<i64>,
/// Tenant the node was read from.
tenant_id: String,
/// Longer preview text, omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
preview: Option<String>,
}
/// One edge in a graph API response.
#[derive(Debug, Serialize)]
struct GraphEdge {
/// Edge id built by `edge_id` from source, kind and target.
id: String,
/// Prefixed id of the source node.
source: String,
/// Prefixed id of the target node.
target: String,
/// Edge kind, e.g. `"cluster_member"`, `"document_chunk"`, `"triple"`,
/// `"semantic"`.
kind: &'static str,
/// Triple predicate; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
predicate: Option<String>,
/// Edge weight (triple confidence or semantic similarity in the expanders
/// visible here); omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
weight: Option<f32>,
}
/// JSON response of the graph-expand endpoint: the neighbouring nodes and the
/// edges that connect them.
#[derive(Debug, Serialize)]
struct GraphExpandResponse {
/// Neighbour nodes discovered by the expansion.
nodes: Vec<GraphNode>,
/// Edges produced by the expansion.
edges: Vec<GraphEdge>,
}
/// Builds an edge id by joining source node id, edge kind and target node id
/// with `--`.
fn edge_id(source: &str, kind: &str, target: &str) -> String {
    [source, kind, target].join("--")
}
/// Episode columns read by the graph expanders (`memory_id`, `ts_ms`,
/// `content`) and turned into a node by `graph_node_for_episode`.
#[derive(Debug)]
struct ExpandedEpisode {
/// Value of `episodes.memory_id`.
memory_id: String,
/// Value of `episodes.ts_ms`.
ts_ms: i64,
/// Value of `episodes.content`.
content: String,
}
/// Document columns read by `expand_document_chunk_from_chunk` and turned
/// into a node by `graph_node_for_document`.
#[derive(Debug)]
struct ExpandedDocument {
/// Value of `documents.doc_id`.
doc_id: String,
/// Value of `documents.title`; preferred label when present.
title: Option<String>,
/// Value of `documents.source`; fallback label and node preview.
source: Option<String>,
/// Value of `documents.ingested_at_ms`.
ingested_at_ms: i64,
}
/// Chunk columns read by `expand_document_chunk_from_document` and turned
/// into a node by `graph_node_for_chunk`.
#[derive(Debug)]
struct ExpandedChunk {
/// Value of `document_chunks.chunk_id`.
chunk_id: String,
/// Value of `document_chunks.chunk_index`.
chunk_index: i64,
/// Value of `document_chunks.content`.
content: String,
}
/// Shortens `s` to at most `max` characters (Unicode scalar values).
///
/// A string that already fits is returned unchanged. A longer string keeps
/// its first `max - 1` characters and gets a trailing `…`, so the result is
/// exactly `max` characters long. `max == 0` yields an empty string instead
/// of underflowing on `max - 1`.
fn truncate_preview(s: &str, max: usize) -> String {
    if max == 0 {
        return String::new();
    }
    // Only walk as far as needed: if there is no character at index `max`,
    // the string fits and no truncation is required.
    if s.char_indices().nth(max).is_none() {
        return s.to_string();
    }
    // Byte offset where the character at index `max - 1` starts; everything
    // before it is the `max - 1` characters we keep. The index exists because
    // the string has more than `max` characters.
    let keep_end = s
        .char_indices()
        .nth(max - 1)
        .map_or(s.len(), |(offset, _)| offset);
    let mut out = String::with_capacity(keep_end + '…'.len_utf8());
    out.push_str(&s[..keep_end]);
    out.push('…');
    out
}
/// Maximum number of characters in a graph node `label`.
const GRAPH_LABEL_CHARS: usize = 80;
/// Maximum number of characters in a graph node `preview` built from content.
const GRAPH_PREVIEW_CHARS: usize = 200;
/// Derives a node label from content: the first line (as split by
/// `str::lines`), truncated to `GRAPH_LABEL_CHARS`. Content with no lines at
/// all falls back to the whole string.
fn episode_label(content: &str) -> String {
    let headline = match content.lines().next() {
        Some(line) => line,
        None => content,
    };
    truncate_preview(headline, GRAPH_LABEL_CHARS)
}
/// Builds the `ep:`-prefixed graph node for an episode: label from the first
/// content line, preview from the truncated content, timestamp from `ts_ms`.
fn graph_node_for_episode(tenant_id: &str, ep: &ExpandedEpisode) -> GraphNode {
    let id = format!("ep:{}", ep.memory_id);
    let label = episode_label(&ep.content);
    let preview = truncate_preview(&ep.content, GRAPH_PREVIEW_CHARS);
    GraphNode {
        id,
        kind: NodeKind::Episode.as_wire_str(),
        label,
        ts_ms: Some(ep.ts_ms),
        tenant_id: tenant_id.to_string(),
        preview: Some(preview),
    }
}
/// Builds the `doc:`-prefixed graph node for a document.
///
/// The label is the title, else the source, else the document id, truncated
/// to `GRAPH_LABEL_CHARS`. The preview is the untruncated source, if any.
fn graph_node_for_document(tenant_id: &str, d: &ExpandedDocument) -> GraphNode {
    let label_source: &str = match (d.title.as_deref(), d.source.as_deref()) {
        (Some(title), _) => title,
        (None, Some(source)) => source,
        (None, None) => &d.doc_id,
    };
    GraphNode {
        id: format!("doc:{}", d.doc_id),
        kind: NodeKind::Document.as_wire_str(),
        label: truncate_preview(label_source, GRAPH_LABEL_CHARS),
        ts_ms: Some(d.ingested_at_ms),
        tenant_id: tenant_id.to_string(),
        preview: d.source.clone(),
    }
}
/// Builds the `chunk:`-prefixed graph node for a document chunk. The label
/// combines the chunk index with the first content line; chunks carry no
/// timestamp.
fn graph_node_for_chunk(tenant_id: &str, c: &ExpandedChunk) -> GraphNode {
    let headline = episode_label(&c.content);
    let label = format!("chunk #{}: {}", c.chunk_index, headline);
    let preview = truncate_preview(&c.content, GRAPH_PREVIEW_CHARS);
    GraphNode {
        id: format!("chunk:{}", c.chunk_id),
        kind: NodeKind::Chunk.as_wire_str(),
        label,
        ts_ms: None,
        tenant_id: tenant_id.to_string(),
        preview: Some(preview),
    }
}
/// Builds the `cl:`-prefixed graph node for a cluster.
///
/// With an abstraction text, label and preview are that text truncated to
/// `GRAPH_LABEL_CHARS` and `GRAPH_PREVIEW_CHARS` respectively. Without one,
/// the label is `cluster <id>` and the preview is absent.
fn graph_node_for_cluster(
    tenant_id: &str,
    cluster_id: &str,
    abstraction: Option<&str>,
    created_at_ms: i64,
) -> GraphNode {
    let (label, preview) = match abstraction {
        Some(text) => (
            truncate_preview(text, GRAPH_LABEL_CHARS),
            Some(truncate_preview(text, GRAPH_PREVIEW_CHARS)),
        ),
        None => (format!("cluster {cluster_id}"), None),
    };
    GraphNode {
        id: format!("cl:{cluster_id}"),
        kind: NodeKind::Cluster.as_wire_str(),
        label,
        ts_ms: Some(created_at_ms),
        tenant_id: tenant_id.to_string(),
        preview,
    }
}
/// Builds the `ent:`-prefixed graph node for an entity value. Entities have
/// neither a timestamp nor a preview; the label is the truncated value.
fn graph_node_for_entity(tenant_id: &str, value: &str) -> GraphNode {
    let id = format!("ent:{value}");
    let label = truncate_preview(value, GRAPH_LABEL_CHARS);
    GraphNode {
        id,
        kind: NodeKind::Entity.as_wire_str(),
        label,
        ts_ms: None,
        tenant_id: tenant_id.to_string(),
        preview: None,
    }
}
/// HTTP handler for graph expansion: returns the neighbours of one node along
/// one edge kind.
///
/// `limit` defaults to `GRAPH_EXPAND_DEFAULT_LIMIT` and is clamped to
/// `1..=GRAPH_EXPAND_MAX_LIMIT`. `node_id` is parsed by `parse_node_id`
/// (bad-request error on malformed input), then the request is dispatched on
/// `kind` to the matching `expand_*` function.
async fn graph_expand_handler(
    TenantExtractor(tenant): TenantExtractor,
    Query(q): Query<GraphExpandQuery>,
) -> Result<Json<GraphExpandResponse>, ApiError> {
    // u32 -> i64 is lossless, so `i64::from` replaces the `as` cast.
    let limit = i64::from(
        q.limit
            .unwrap_or(GRAPH_EXPAND_DEFAULT_LIMIT)
            .clamp(1, GRAPH_EXPAND_MAX_LIMIT),
    );

    // The expanders only need borrowed strings, and `q` outlives every call
    // below, so the id and its value are borrowed rather than cloned.
    let node_id_full = q.node_id.as_str();
    let (node_kind, value) = parse_node_id(node_id_full)?;
    let tenant_id = tenant.tenant_id().to_string();

    let expanded = match q.kind {
        GraphExpandKind::ClusterMember => {
            expand_cluster_member(&tenant, &tenant_id, node_kind, value, node_id_full, limit)
                .await
        }
        GraphExpandKind::DocumentChunk => {
            expand_document_chunk(&tenant, &tenant_id, node_kind, value, node_id_full, limit)
                .await
        }
        GraphExpandKind::Triple => {
            expand_triple(&tenant, &tenant_id, node_kind, value, node_id_full, limit).await
        }
        GraphExpandKind::Semantic => {
            expand_semantic(&tenant, &tenant_id, node_kind, value, node_id_full, limit).await
        }
    };
    expanded.map(Json)
}
/// Dispatches a `cluster_member` expansion on the source node kind.
///
/// Episode sources list the clusters they belong to; cluster sources list
/// their member episodes. Any other source kind is a bad-request error.
async fn expand_cluster_member(
    tenant: &TenantHandle,
    tenant_id: &str,
    node_kind: NodeKind,
    value: &str,
    node_id_full: &str,
    limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
    match node_kind {
        NodeKind::Episode => {
            let memory_id = value.to_string();
            let source_node = node_id_full.to_string();
            expand_cluster_member_from_episode(tenant, tenant_id, memory_id, source_node, limit)
                .await
        }
        NodeKind::Cluster => {
            let cluster_id = value.to_string();
            let source_node = node_id_full.to_string();
            expand_cluster_member_from_cluster(tenant, tenant_id, cluster_id, source_node, limit)
                .await
        }
        NodeKind::Document | NodeKind::Chunk | NodeKind::Entity => {
            Err(ApiError::bad_request(format!(
                "kind=cluster_member only valid for episode or cluster source nodes; got {}",
                node_kind.as_wire_str()
            )))
        }
    }
}
/// Lists the clusters that contain the episode `memory_id`.
///
/// Rows are ordered by cluster creation time, newest first, and capped at
/// `limit`. Each row becomes a `cl:` node plus a `cluster_member` edge from
/// `node_id_full` to that node; the cluster's abstraction text (if any) feeds
/// the node label and preview.
///
/// When no rows come back, `ensure_episode_exists` decides the outcome: its
/// not-found error is propagated for an unknown episode, otherwise the
/// response is empty.
async fn expand_cluster_member_from_episode(
tenant: &TenantHandle,
tenant_id: &str,
memory_id: String,
node_id_full: String,
limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
// `memory_id` moves into the closure; keep a copy for the existence check.
let memory_id_for_err = memory_id.clone();
let rows: Vec<(String, Option<String>, i64)> = tenant
.read()
.interact(move |conn| {
// Skip the join entirely when the episode row is absent; the caller-side
// check below turns that case into the not-found error.
let exists: i64 = conn.query_row(
"SELECT COUNT(*) FROM episodes WHERE memory_id = ?1",
rusqlite::params![&memory_id],
|r| r.get(0),
)?;
if exists == 0 {
return Ok(Vec::new());
}
let mut stmt = conn.prepare(
"SELECT c.cluster_id, sa.content, c.created_at_ms
FROM cluster_episodes ce
JOIN clusters c ON c.cluster_id = ce.cluster_id
LEFT JOIN semantic_abstractions sa ON sa.cluster_id = c.cluster_id
WHERE ce.memory_id = ?1
ORDER BY c.created_at_ms DESC
LIMIT ?2",
)?;
let mapped = stmt
.query_map(rusqlite::params![&memory_id, limit], |r| {
Ok((
r.get::<_, String>(0)?,
r.get::<_, Option<String>>(1)?,
r.get::<_, i64>(2)?,
))
})?
.collect::<rusqlite::Result<Vec<_>>>()?;
Ok::<_, rusqlite::Error>(mapped)
})
.await
.map_err(ApiError::from)?;
// Empty result: either the episode is unknown (error) or it simply has no
// clusters (empty response).
if rows.is_empty() {
ensure_episode_exists(tenant, &memory_id_for_err, &node_id_full).await?;
return Ok(GraphExpandResponse {
nodes: Vec::new(),
edges: Vec::new(),
});
}
let mut nodes = Vec::with_capacity(rows.len());
let mut edges = Vec::with_capacity(rows.len());
for (cluster_id, abstraction, created_at_ms) in rows {
let target_id = format!("cl:{cluster_id}");
edges.push(GraphEdge {
id: edge_id(&node_id_full, "cluster_member", &target_id),
source: node_id_full.clone(),
target: target_id,
kind: "cluster_member",
predicate: None,
weight: None,
});
nodes.push(graph_node_for_cluster(
tenant_id,
&cluster_id,
abstraction.as_deref(),
created_at_ms,
));
}
Ok(GraphExpandResponse { nodes, edges })
}
/// Lists the active member episodes of the cluster `cluster_id`.
///
/// Rows are ordered by episode timestamp, newest first, and capped at
/// `limit`; only episodes with `status = 'active'` are returned. Each row
/// becomes an `ep:` node plus a `cluster_member` edge from `node_id_full`.
///
/// When no rows come back, `ensure_cluster_exists` decides the outcome: its
/// not-found error is propagated for an unknown cluster, otherwise the
/// response is empty.
async fn expand_cluster_member_from_cluster(
tenant: &TenantHandle,
tenant_id: &str,
cluster_id: String,
node_id_full: String,
limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
// `cluster_id` moves into the closure; keep a copy for the existence check.
let cluster_id_for_err = cluster_id.clone();
let rows: Vec<ExpandedEpisode> = tenant
.read()
.interact(move |conn| {
// Skip the join when the cluster row is absent; the caller-side check
// below turns that case into the not-found error.
let exists: i64 = conn.query_row(
"SELECT COUNT(*) FROM clusters WHERE cluster_id = ?1",
rusqlite::params![&cluster_id],
|r| r.get(0),
)?;
if exists == 0 {
return Ok(Vec::new());
}
let mut stmt = conn.prepare(
"SELECT e.memory_id, e.ts_ms, e.content
FROM cluster_episodes ce
JOIN episodes e ON e.memory_id = ce.memory_id
WHERE ce.cluster_id = ?1
AND e.status = 'active'
ORDER BY e.ts_ms DESC
LIMIT ?2",
)?;
let mapped = stmt
.query_map(rusqlite::params![&cluster_id, limit], |r| {
Ok(ExpandedEpisode {
memory_id: r.get(0)?,
ts_ms: r.get(1)?,
content: r.get(2)?,
})
})?
.collect::<rusqlite::Result<Vec<_>>>()?;
Ok::<_, rusqlite::Error>(mapped)
})
.await
.map_err(ApiError::from)?;
// Empty result: either the cluster is unknown (error) or it has no active
// members (empty response).
if rows.is_empty() {
ensure_cluster_exists(tenant, &cluster_id_for_err, &node_id_full).await?;
return Ok(GraphExpandResponse {
nodes: Vec::new(),
edges: Vec::new(),
});
}
let mut nodes = Vec::with_capacity(rows.len());
let mut edges = Vec::with_capacity(rows.len());
for ep in rows {
let target_id = format!("ep:{}", ep.memory_id);
edges.push(GraphEdge {
id: edge_id(&node_id_full, "cluster_member", &target_id),
source: node_id_full.clone(),
target: target_id,
kind: "cluster_member",
predicate: None,
weight: None,
});
nodes.push(graph_node_for_episode(tenant_id, &ep));
}
Ok(GraphExpandResponse { nodes, edges })
}
/// Dispatches a `document_chunk` expansion on the source node kind.
///
/// Document sources list their chunks (capped at `limit`); chunk sources
/// resolve their single parent document, so `limit` is not used for them. Any
/// other source kind is a bad-request error.
async fn expand_document_chunk(
    tenant: &TenantHandle,
    tenant_id: &str,
    node_kind: NodeKind,
    value: &str,
    node_id_full: &str,
    limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
    match node_kind {
        NodeKind::Document => {
            let doc_id = value.to_string();
            let source_node = node_id_full.to_string();
            expand_document_chunk_from_document(tenant, tenant_id, doc_id, source_node, limit)
                .await
        }
        NodeKind::Chunk => {
            let chunk_id = value.to_string();
            let source_node = node_id_full.to_string();
            expand_document_chunk_from_chunk(tenant, tenant_id, chunk_id, source_node).await
        }
        NodeKind::Episode | NodeKind::Cluster | NodeKind::Entity => {
            Err(ApiError::bad_request(format!(
                "kind=document_chunk only valid for document or chunk source nodes; got {}",
                node_kind.as_wire_str()
            )))
        }
    }
}
/// Lists the chunks of the document `doc_id`.
///
/// Rows are ordered by `chunk_index` ascending and capped at `limit`. Each
/// row becomes a `chunk:` node plus a `document_chunk` edge from
/// `node_id_full`.
///
/// When no rows come back, `ensure_document_exists` decides the outcome: its
/// not-found error is propagated for an unknown document, otherwise the
/// response is empty.
async fn expand_document_chunk_from_document(
tenant: &TenantHandle,
tenant_id: &str,
doc_id: String,
node_id_full: String,
limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
// `doc_id` moves into the closure; keep a copy for the existence check.
let doc_id_for_err = doc_id.clone();
let rows: Vec<ExpandedChunk> = tenant
.read()
.interact(move |conn| {
// Skip the chunk query when the document row is absent; the caller-side
// check below turns that case into the not-found error.
let exists: i64 = conn.query_row(
"SELECT COUNT(*) FROM documents WHERE doc_id = ?1",
rusqlite::params![&doc_id],
|r| r.get(0),
)?;
if exists == 0 {
return Ok(Vec::new());
}
let mut stmt = conn.prepare(
"SELECT chunk_id, chunk_index, content
FROM document_chunks
WHERE doc_id = ?1
ORDER BY chunk_index ASC
LIMIT ?2",
)?;
let mapped = stmt
.query_map(rusqlite::params![&doc_id, limit], |r| {
Ok(ExpandedChunk {
chunk_id: r.get(0)?,
chunk_index: r.get(1)?,
content: r.get(2)?,
})
})?
.collect::<rusqlite::Result<Vec<_>>>()?;
Ok::<_, rusqlite::Error>(mapped)
})
.await
.map_err(ApiError::from)?;
// Empty result: either the document is unknown (error) or it has no chunks
// (empty response).
if rows.is_empty() {
ensure_document_exists(tenant, &doc_id_for_err, &node_id_full).await?;
return Ok(GraphExpandResponse {
nodes: Vec::new(),
edges: Vec::new(),
});
}
let mut nodes = Vec::with_capacity(rows.len());
let mut edges = Vec::with_capacity(rows.len());
for c in rows {
let target_id = format!("chunk:{}", c.chunk_id);
edges.push(GraphEdge {
id: edge_id(&node_id_full, "document_chunk", &target_id),
source: node_id_full.clone(),
target: target_id,
kind: "document_chunk",
predicate: None,
weight: None,
});
nodes.push(graph_node_for_chunk(tenant_id, &c));
}
Ok(GraphExpandResponse { nodes, edges })
}
/// Resolves the parent document of the chunk `chunk_id`.
///
/// On success the response holds exactly one `doc:` node and one
/// `document_chunk` edge from `node_id_full` to it. When the join finds no
/// row, a not-found error naming both the node id and the chunk id is
/// returned.
async fn expand_document_chunk_from_chunk(
tenant: &TenantHandle,
tenant_id: &str,
chunk_id: String,
node_id_full: String,
) -> Result<GraphExpandResponse, ApiError> {
// `chunk_id` moves into the closure; keep a copy for the error message.
let chunk_id_for_err = chunk_id.clone();
let row: Option<ExpandedDocument> = tenant
.read()
.interact(move |conn| {
conn.query_row(
"SELECT d.doc_id, d.title, d.source, d.ingested_at_ms
FROM document_chunks c
JOIN documents d ON d.doc_id = c.doc_id
WHERE c.chunk_id = ?1",
rusqlite::params![&chunk_id],
|r| {
Ok(ExpandedDocument {
doc_id: r.get(0)?,
title: r.get(1)?,
source: r.get(2)?,
ingested_at_ms: r.get(3)?,
})
},
)
// Map "no rows" to `None`; every other database error is propagated.
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})
})
.await
.map_err(ApiError::from)?;
let d = row.ok_or_else(|| {
ApiError::not_found(format!(
"node_id {node_id_full:?} (chunk_id {chunk_id_for_err}) not found in current tenant"
))
})?;
let target_id = format!("doc:{}", d.doc_id);
let edge = GraphEdge {
id: edge_id(&node_id_full, "document_chunk", &target_id),
source: node_id_full.clone(),
target: target_id,
kind: "document_chunk",
predicate: None,
weight: None,
};
let node = graph_node_for_document(tenant_id, &d);
Ok(GraphExpandResponse {
nodes: vec![node],
edges: vec![edge],
})
}
/// Dispatches a `triple` expansion on the source node kind.
///
/// Episode sources list the triples extracted from them; entity sources list
/// the episodes whose triples mention the entity. Any other source kind is a
/// bad-request error.
async fn expand_triple(
    tenant: &TenantHandle,
    tenant_id: &str,
    node_kind: NodeKind,
    value: &str,
    node_id_full: &str,
    limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
    match node_kind {
        NodeKind::Episode => {
            let memory_id = value.to_string();
            let source_node = node_id_full.to_string();
            expand_triple_from_episode(tenant, tenant_id, memory_id, source_node, limit).await
        }
        NodeKind::Entity => {
            let entity_value = value.to_string();
            let source_node = node_id_full.to_string();
            expand_triple_from_entity(tenant, tenant_id, entity_value, source_node, limit).await
        }
        NodeKind::Document | NodeKind::Chunk | NodeKind::Cluster => {
            Err(ApiError::bad_request(format!(
                "kind=triple only valid for episode or entity source nodes; got {}",
                node_kind.as_wire_str()
            )))
        }
    }
}
/// Triple columns read by `expand_triple_from_episode`.
#[derive(Debug)]
struct TripleRow {
/// Value of `triples.subject_id`; becomes an `ent:` node.
subject_id: String,
/// Value of `triples.predicate`; becomes the edge predicate.
predicate: String,
/// Value of `triples.object_id`; becomes an `ent:` node.
object_id: String,
/// Value of `triples.confidence`; becomes the edge weight.
confidence: f32,
}
/// Lists the active triples whose source is the episode `memory_id`.
///
/// The episode's `rowid` is looked up first because `triples` references it
/// through `source_episode_id`. Rows are ordered by `valid_from_ms`
/// descending and capped at `limit`. Each triple yields an edge between two
/// `ent:` nodes (subject -> object) carrying the predicate and confidence;
/// entity nodes are emitted once each.
///
/// When no rows come back, `ensure_episode_exists` decides the outcome: its
/// not-found error is propagated for an unknown episode, otherwise the
/// response is empty.
///
/// NOTE(review): edge ids are derived from subject and object only, so two
/// triples between the same pair with different predicates share an id —
/// confirm whether consumers rely on edge-id uniqueness.
async fn expand_triple_from_episode(
tenant: &TenantHandle,
tenant_id: &str,
memory_id: String,
node_id_full: String,
limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
// `memory_id` moves into the closure; keep a copy for the existence check.
let memory_id_for_err = memory_id.clone();
let rows: Vec<TripleRow> = tenant
.read()
.interact(move |conn| {
// "No rows" becomes `None`; every other database error is propagated.
let rowid_opt: Option<i64> = conn
.query_row(
"SELECT rowid FROM episodes WHERE memory_id = ?1",
rusqlite::params![&memory_id],
|r| r.get(0),
)
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})?;
let Some(rowid) = rowid_opt else {
return Ok(Vec::new());
};
let mut stmt = conn.prepare(
"SELECT subject_id, predicate, object_id, confidence
FROM triples
WHERE source_episode_id = ?1
AND status = 'active'
ORDER BY valid_from_ms DESC
LIMIT ?2",
)?;
let mapped = stmt
.query_map(rusqlite::params![rowid, limit], |r| {
Ok(TripleRow {
subject_id: r.get(0)?,
predicate: r.get(1)?,
object_id: r.get(2)?,
confidence: r.get(3)?,
})
})?
.collect::<rusqlite::Result<Vec<_>>>()?;
Ok::<_, rusqlite::Error>(mapped)
})
.await
.map_err(ApiError::from)?;
// Empty result: either the episode is unknown (error) or it has no active
// triples (empty response).
if rows.is_empty() {
ensure_episode_exists(tenant, &memory_id_for_err, &node_id_full).await?;
return Ok(GraphExpandResponse {
nodes: Vec::new(),
edges: Vec::new(),
});
}
let mut nodes = Vec::new();
let mut edges = Vec::new();
// Tracks entity values already emitted so each entity node appears once.
let mut seen_entities: std::collections::HashSet<String> = Default::default();
for t in rows {
let subj_id = format!("ent:{}", t.subject_id);
let obj_id = format!("ent:{}", t.object_id);
if seen_entities.insert(t.subject_id.clone()) {
nodes.push(graph_node_for_entity(tenant_id, &t.subject_id));
}
if seen_entities.insert(t.object_id.clone()) {
nodes.push(graph_node_for_entity(tenant_id, &t.object_id));
}
edges.push(GraphEdge {
id: edge_id(&subj_id, "triple", &obj_id),
source: subj_id,
target: obj_id,
kind: "triple",
predicate: Some(t.predicate),
weight: Some(t.confidence),
});
}
Ok(GraphExpandResponse { nodes, edges })
}
/// Lists the active episodes whose active triples mention `entity_value` as
/// subject or object.
///
/// Rows are distinct, ordered by episode timestamp (newest first) and capped
/// at `limit`. Each row becomes an `ep:` node plus a `triple` edge from
/// `node_id_full`, without predicate or weight. An entity that matches
/// nothing yields an empty response; this function performs no existence
/// check for the entity.
async fn expand_triple_from_entity(
    tenant: &TenantHandle,
    tenant_id: &str,
    entity_value: String,
    node_id_full: String,
    limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
    // `entity_value` is owned and not needed after the query, so it moves
    // straight into the closure instead of being cloned first.
    let rows: Vec<ExpandedEpisode> = tenant
        .read()
        .interact(move |conn| {
            let mut stmt = conn.prepare(
                "SELECT DISTINCT e.memory_id, e.ts_ms, e.content
FROM triples t
JOIN episodes e ON e.rowid = t.source_episode_id
WHERE (t.subject_id = ?1 OR t.object_id = ?1)
AND t.status = 'active'
AND t.source_episode_id IS NOT NULL
AND e.status = 'active'
ORDER BY e.ts_ms DESC
LIMIT ?2",
            )?;
            let mapped = stmt
                .query_map(rusqlite::params![&entity_value, limit], |r| {
                    Ok(ExpandedEpisode {
                        memory_id: r.get(0)?,
                        ts_ms: r.get(1)?,
                        content: r.get(2)?,
                    })
                })?
                .collect::<rusqlite::Result<Vec<_>>>()?;
            Ok::<_, rusqlite::Error>(mapped)
        })
        .await
        .map_err(ApiError::from)?;

    let mut nodes = Vec::with_capacity(rows.len());
    let mut edges = Vec::with_capacity(rows.len());
    for ep in rows {
        let target_id = format!("ep:{}", ep.memory_id);
        edges.push(GraphEdge {
            id: edge_id(&node_id_full, "triple", &target_id),
            source: node_id_full.clone(),
            target: target_id,
            kind: "triple",
            predicate: None,
            weight: None,
        });
        nodes.push(graph_node_for_episode(tenant_id, &ep));
    }
    Ok(GraphExpandResponse { nodes, edges })
}
/// Finds episodes semantically similar to the source episode.
///
/// Only episode source nodes are accepted (bad-request error otherwise). The
/// source episode's content is loaded (it must exist with
/// `status = 'active'`, else not-found) and used as the query text for
/// `solo_query::recall::run_recall_inner`. Each hit other than the source
/// itself becomes an `ep:` node plus a `semantic` edge whose weight is
/// `1 - cos_distance`, floored at `0`. At most `limit` neighbours are
/// returned.
///
/// NOTE(review): the recall size is `min(limit + 1, 100)` with a hard-coded
/// `100`. When `limit` is already 100 and the source episode is among the
/// hits, at most 99 neighbours come back — confirm whether the cap reflects a
/// constraint of `run_recall_inner` or should be `limit + 1`.
async fn expand_semantic(
tenant: &TenantHandle,
tenant_id: &str,
node_kind: NodeKind,
value: &str,
node_id_full: &str,
limit: i64,
) -> Result<GraphExpandResponse, ApiError> {
if node_kind != NodeKind::Episode {
return Err(ApiError::bad_request(format!(
"kind=semantic only valid for episode source nodes; got {}",
node_kind.as_wire_str()
)));
}
let memory_id = value.to_string();
// Separate copy for the closure; `memory_id` is still needed below.
let memory_id_q = memory_id.clone();
let content: Option<String> = tenant
.read()
.interact(move |conn| {
conn.query_row(
"SELECT content FROM episodes WHERE memory_id = ?1 AND status = 'active'",
rusqlite::params![&memory_id_q],
|r| r.get::<_, String>(0),
)
// Map "no rows" to `None`; every other database error is propagated.
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})
})
.await
.map_err(ApiError::from)?;
let content = content.ok_or_else(|| {
ApiError::not_found(format!(
"node_id {node_id_full:?} (memory_id {memory_id}) not found in current tenant"
))
})?;
// Ask for one extra hit so the source episode can be dropped from the
// results without shrinking the page.
let widened = (limit as usize).saturating_add(1).min(100);
let result = solo_query::recall::run_recall_inner(
tenant.embedder(),
tenant.hnsw(),
tenant.read(),
&content,
widened,
)
.await
.map_err(ApiError::from)?;
let mut nodes = Vec::new();
let mut edges = Vec::new();
for hit in result.hits.into_iter() {
// The source episode matches itself; skip it.
if hit.memory_id == memory_id {
continue;
}
if nodes.len() as i64 >= limit {
break;
}
// Convert distance to a similarity-style weight, never negative.
let weight = (1.0 - hit.cos_distance).max(0.0);
let target_id = format!("ep:{}", hit.memory_id);
edges.push(GraphEdge {
id: edge_id(node_id_full, "semantic", &target_id),
source: node_id_full.to_string(),
target: target_id,
kind: "semantic",
predicate: None,
weight: Some(weight),
});
// Built inline rather than via `graph_node_for_episode`: no timestamp is
// set for recall hits here.
nodes.push(GraphNode {
id: format!("ep:{}", hit.memory_id),
kind: NodeKind::Episode.as_wire_str(),
label: episode_label(&hit.content),
ts_ms: None,
tenant_id: tenant_id.to_string(),
preview: Some(truncate_preview(&hit.content, GRAPH_PREVIEW_CHARS)),
});
}
Ok(GraphExpandResponse { nodes, edges })
}
/// Verifies that an episode row with `memory_id` exists (any status).
///
/// # Errors
///
/// Returns a not-found error quoting `node_id_full` when the count is zero,
/// or the converted database error when the query fails.
async fn ensure_episode_exists(
    tenant: &TenantHandle,
    memory_id: &str,
    node_id_full: &str,
) -> Result<(), ApiError> {
    let lookup_id = memory_id.to_string();
    let row_count: i64 = tenant
        .read()
        .interact(move |conn| {
            conn.query_row(
                "SELECT COUNT(*) FROM episodes WHERE memory_id = ?1",
                rusqlite::params![&lookup_id],
                |row| row.get(0),
            )
        })
        .await
        .map_err(ApiError::from)?;

    match row_count {
        0 => Err(ApiError::not_found(format!(
            "node_id {node_id_full:?} not found in current tenant"
        ))),
        _ => Ok(()),
    }
}
/// Verifies that a cluster row with `cluster_id` exists.
///
/// # Errors
///
/// Returns a not-found error quoting `node_id_full` when the count is zero,
/// or the converted database error when the query fails.
async fn ensure_cluster_exists(
    tenant: &TenantHandle,
    cluster_id: &str,
    node_id_full: &str,
) -> Result<(), ApiError> {
    let lookup_id = cluster_id.to_string();
    let row_count: i64 = tenant
        .read()
        .interact(move |conn| {
            conn.query_row(
                "SELECT COUNT(*) FROM clusters WHERE cluster_id = ?1",
                rusqlite::params![&lookup_id],
                |row| row.get(0),
            )
        })
        .await
        .map_err(ApiError::from)?;

    match row_count {
        0 => Err(ApiError::not_found(format!(
            "node_id {node_id_full:?} not found in current tenant"
        ))),
        _ => Ok(()),
    }
}
/// Verifies that a document row with `doc_id` exists.
///
/// # Errors
///
/// Returns a not-found error quoting `node_id_full` when the count is zero,
/// or the converted database error when the query fails.
async fn ensure_document_exists(
    tenant: &TenantHandle,
    doc_id: &str,
    node_id_full: &str,
) -> Result<(), ApiError> {
    let lookup_id = doc_id.to_string();
    let row_count: i64 = tenant
        .read()
        .interact(move |conn| {
            conn.query_row(
                "SELECT COUNT(*) FROM documents WHERE doc_id = ?1",
                rusqlite::params![&lookup_id],
                |row| row.get(0),
            )
        })
        .await
        .map_err(ApiError::from)?;

    match row_count {
        0 => Err(ApiError::not_found(format!(
            "node_id {node_id_full:?} not found in current tenant"
        ))),
        _ => Ok(()),
    }
}
// NOTE(review): the code that reads the constants below lies outside this
// part of the file; the descriptions are inferred from the names and from the
// `GRAPH_EXPAND_*` pair — confirm against the graph nodes/edges handlers.

/// Presumably the page size for the graph nodes listing when `limit` is omitted.
const GRAPH_NODES_DEFAULT_LIMIT: u32 = 100;
/// Presumably the largest accepted `limit` for the graph nodes listing.
const GRAPH_NODES_MAX_LIMIT: u32 = 1000;
/// Presumably the page size for the graph edges listing when `limit` is omitted.
const GRAPH_EDGES_DEFAULT_LIMIT: u32 = 200;
/// Presumably the largest accepted `limit` for the graph edges listing.
const GRAPH_EDGES_MAX_LIMIT: u32 = 2000;
/// Presumably a cap on the number of entity nodes returned.
const GRAPH_ENTITY_CAP: usize = 200;
/// Header name, presumably set on responses when the entity cap was hit.
const ENTITY_CAP_HEADER: &str = "x-solo-entity-cap-reached";
/// Query-string parameters for the graph nodes listing. Every field is
/// optional.
///
/// NOTE(review): the consuming handler is outside this part of the file; the
/// field notes below are inferred from the helpers visible here.
#[derive(Debug, Deserialize)]
struct GraphNodesQuery {
/// Presumably a comma-separated node-kind list for `parse_node_kind_filter`.
#[serde(default)]
kind: Option<String>,
/// Optional lower time bound (presumably epoch milliseconds — confirm).
#[serde(default)]
since_ms: Option<i64>,
/// Optional upper time bound (presumably epoch milliseconds — confirm).
#[serde(default)]
until_ms: Option<i64>,
/// Optional page size.
#[serde(default)]
limit: Option<u32>,
/// Opaque pagination cursor, presumably produced by `encode_cursor`.
#[serde(default)]
cursor: Option<String>,
}
/// Query-string parameters for the graph edges listing. Every field is
/// optional.
///
/// NOTE(review): the consuming handler is outside this part of the file; the
/// field notes below are inferred from the helpers visible here.
#[derive(Debug, Deserialize)]
struct GraphEdgesQuery {
/// Optional node id to restrict edges to (presumably a prefixed id — confirm).
#[serde(default)]
node_id: Option<String>,
/// Presumably a comma-separated edge-type list for `parse_edge_kind_filter`.
/// Exposed on the wire as `type`, hence the raw identifier.
#[serde(default)]
r#type: Option<String>,
/// Optional page size.
#[serde(default)]
limit: Option<u32>,
/// Opaque pagination cursor, presumably produced by `encode_cursor`.
#[serde(default)]
cursor: Option<String>,
}
/// JSON response shape for a page of graph nodes.
#[derive(Debug, Serialize)]
struct GraphNodesResponse {
/// Nodes on this page.
nodes: Vec<GraphNode>,
/// Cursor for the next page; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
next_cursor: Option<String>,
}
/// JSON response shape for a page of graph edges.
#[derive(Debug, Serialize)]
struct GraphEdgesResponse {
/// Edges on this page.
edges: Vec<GraphEdge>,
/// Cursor for the next page; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
next_cursor: Option<String>,
}
/// Parses a comma-separated node-kind filter.
///
/// A missing or blank filter selects every kind, in the order episode,
/// document, chunk, cluster, entity. Otherwise tokens are trimmed, empty
/// tokens are skipped, and duplicates are dropped while first-seen order is
/// kept.
///
/// # Errors
///
/// Returns a bad-request error for an unknown token, or when the filter is
/// non-blank but contains no tokens at all (e.g. only commas).
fn parse_node_kind_filter(raw: Option<&str>) -> Result<Vec<NodeKind>, ApiError> {
    let spec = match raw {
        Some(text) => text.trim(),
        None => "",
    };
    if spec.is_empty() {
        return Ok(vec![
            NodeKind::Episode,
            NodeKind::Document,
            NodeKind::Chunk,
            NodeKind::Cluster,
            NodeKind::Entity,
        ]);
    }

    let mut selected: Vec<NodeKind> = Vec::new();
    let tokens = spec.split(',').map(str::trim).filter(|t| !t.is_empty());
    for token in tokens {
        let kind = if token == "episode" {
            NodeKind::Episode
        } else if token == "document" {
            NodeKind::Document
        } else if token == "chunk" {
            NodeKind::Chunk
        } else if token == "cluster" {
            NodeKind::Cluster
        } else if token == "entity" {
            NodeKind::Entity
        } else {
            let other = token;
            return Err(ApiError::bad_request(format!(
                "unknown node kind {other:?}; expected one of episode/document/chunk/cluster/entity"
            )));
        };
        let already_listed = selected.iter().any(|seen| *seen == kind);
        if !already_listed {
            selected.push(kind);
        }
    }

    if selected.is_empty() {
        Err(ApiError::bad_request(
            "kind filter is empty after parsing; either omit or list at least one kind",
        ))
    } else {
        Ok(selected)
    }
}
/// Edge kinds accepted by `parse_edge_kind_filter` (`triple`,
/// `document_chunk`, `cluster_member`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum EdgeKind {
    /// Filter token `triple`.
    Triple,
    /// Filter token `document_chunk`.
    DocumentChunk,
    /// Filter token `cluster_member`.
    ClusterMember,
}

impl EdgeKind {
    /// Fixed ordinal of this kind: triple = 0, document_chunk = 1,
    /// cluster_member = 2.
    fn order_idx(self) -> u8 {
        match self {
            EdgeKind::Triple => 0,
            EdgeKind::DocumentChunk => 1,
            EdgeKind::ClusterMember => 2,
        }
    }
}
/// Turns the comma-separated `type` query value into a de-duplicated list of edge kinds.
///
/// A missing or blank value selects all three kinds. Tokens are trimmed, blank tokens
/// are skipped, and a kind listed more than once is kept only at its first position.
///
/// # Errors
/// Returns a bad-request error for `semantic` (with a pointer to the neighbors
/// endpoint), for any other unrecognised token, or when a non-blank value yields no
/// kinds at all.
fn parse_edge_kind_filter(raw: Option<&str>) -> Result<Vec<EdgeKind>, ApiError> {
    let spec = raw.map(str::trim).unwrap_or_default();
    if spec.is_empty() {
        // No filter supplied: every edge kind is wanted.
        return Ok(vec![
            EdgeKind::Triple,
            EdgeKind::DocumentChunk,
            EdgeKind::ClusterMember,
        ]);
    }

    let mut selected: Vec<EdgeKind> = Vec::new();
    let names = spec.split(',').map(str::trim).filter(|name| !name.is_empty());
    for name in names {
        let parsed = match name {
            "triple" => EdgeKind::Triple,
            "document_chunk" => EdgeKind::DocumentChunk,
            "cluster_member" => EdgeKind::ClusterMember,
            // Semantic edges get a dedicated message instead of the generic one.
            "semantic" => {
                return Err(ApiError::bad_request(
                    "semantic edges are available via /v1/graph/neighbors/:id?kind=semantic, not /v1/graph/edges (semantic edges aren't precomputed; they're query-time HNSW lookups)",
                ));
            }
            other => {
                return Err(ApiError::bad_request(format!(
                    "unknown edge type {other:?}; expected one of triple/document_chunk/cluster_member"
                )));
            }
        };
        let already_listed = selected.contains(&parsed);
        if !already_listed {
            selected.push(parsed);
        }
    }

    if selected.is_empty() {
        Err(ApiError::bad_request(
            "type filter is empty after parsing; either omit or list at least one type",
        ))
    } else {
        Ok(selected)
    }
}
/// Decoded pagination position for the nodes listing: the sort key of the last node
/// served on the previous page.
#[derive(Debug, Serialize, Deserialize)]
struct NodesCursor {
/// Sort timestamp (milliseconds) of the last served node.
ts_ms: i64,
/// Sort id of the last served node (the kind-prefixed id the handler sorts by).
id: String,
}
/// Decoded pagination position for the edges listing: the sort key of the last edge
/// served on the previous page.
#[derive(Debug, Serialize, Deserialize)]
struct EdgesCursor {
/// `EdgeKind::order_idx` value of the last served edge.
kind_idx: u8,
/// Within-kind sort id of the last served edge.
sub_id: String,
}
/// Serializes a cursor value to JSON and wraps it in unpadded URL-safe base64.
///
/// # Errors
/// Returns an internal error when JSON serialization fails.
fn encode_cursor<T: Serialize>(value: &T) -> Result<String, ApiError> {
    use base64::Engine;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    match serde_json::to_vec(value) {
        Ok(payload) => Ok(URL_SAFE_NO_PAD.encode(payload)),
        Err(e) => Err(ApiError::internal(format!("cursor serialize: {e}"))),
    }
}
/// Reverses `encode_cursor`: decodes unpadded URL-safe base64, then parses the JSON payload.
///
/// # Errors
/// Returns a bad-request error when the input is not valid base64 or when the decoded
/// bytes are not valid JSON for `T`.
fn decode_cursor<T: for<'de> Deserialize<'de>>(raw: &str) -> Result<T, ApiError> {
    use base64::Engine;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    let payload = match URL_SAFE_NO_PAD.decode(raw.as_bytes()) {
        Ok(bytes) => bytes,
        Err(e) => return Err(ApiError::bad_request(format!("cursor: bad base64: {e}"))),
    };
    match serde_json::from_slice::<T>(&payload) {
        Ok(cursor) => Ok(cursor),
        Err(e) => Err(ApiError::bad_request(format!("cursor: bad JSON payload: {e}"))),
    }
}
/// A node paired with the sort key the nodes handler uses to merge kinds and paginate.
#[derive(Debug)]
struct StagingNode {
/// The node as it will appear in the response.
node: GraphNode,
/// Timestamp component of the sort key (milliseconds).
sort_ts_ms: i64,
/// Id component of the sort key (kind-prefixed, e.g. `ep:…`).
sort_id: String,
}
/// Orders node sort keys newest-first: descending by timestamp, then ascending by id.
fn cmp_node_sort_keys(a: (i64, &str), b: (i64, &str)) -> std::cmp::Ordering {
    let (a_ts, a_id) = a;
    let (b_ts, b_id) = b;
    // Operands are swapped on the timestamp so larger timestamps sort first.
    b_ts.cmp(&a_ts).then_with(|| a_id.cmp(b_id))
}
/// Returns `true` when the node key `(ts_ms, id)` sorts strictly after the cursor
/// position, i.e. the node has not been served on an earlier page.
fn node_passes_cursor(ts_ms: i64, id: &str, cursor: &NodesCursor) -> bool {
    let candidate = (ts_ms, id);
    let boundary = (cursor.ts_ms, cursor.id.as_str());
    matches!(
        cmp_node_sort_keys(candidate, boundary),
        std::cmp::Ordering::Greater
    )
}
/// One row returned by `fetch_episodes_for_nodes`.
#[derive(Debug)]
struct NodeRowEp {
/// `episodes.memory_id`.
memory_id: String,
/// `episodes.ts_ms`; also used as the node's sort timestamp.
ts_ms: i64,
/// `episodes.content`.
content: String,
}
/// Loads active episodes as candidate graph nodes, newest first.
///
/// Rows are ordered by `ts_ms DESC, memory_id ASC`.
///
/// * `since_ms` / `until_ms` - optional inclusive bounds on `ts_ms`.
/// * `cursor` - when present, only rows strictly after the cursor position are
///   returned: an older timestamp, or the same timestamp with a node id
///   (`ep:` + `memory_id`) greater than `cursor.id`.
/// * `limit` - maximum number of rows returned.
///
/// The id tiebreak is applied in SQL on purpose. Filtering on `ts_ms <= cursor.ts_ms`
/// alone lets rows that share the cursor's timestamp and were already served consume
/// the `LIMIT`; the caller's post-filter then sees a short page and pagination stops
/// while unserved rows remain.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the query.
fn fetch_episodes_for_nodes(
    conn: &rusqlite::Connection,
    since_ms: Option<i64>,
    until_ms: Option<i64>,
    cursor: Option<&NodesCursor>,
    limit: i64,
) -> rusqlite::Result<Vec<NodeRowEp>> {
    let mut sql = String::from(
        "SELECT memory_id, ts_ms, content
FROM episodes
WHERE status = 'active'",
    );
    let mut params: Vec<rusqlite::types::Value> = Vec::new();
    if let Some(s) = since_ms {
        sql.push_str(" AND ts_ms >= ?");
        params.push(s.into());
    }
    if let Some(u) = until_ms {
        sql.push_str(" AND ts_ms <= ?");
        params.push(u.into());
    }
    if let Some(cur) = cursor {
        // Keyset predicate mirroring the handler's (ts DESC, id ASC) order. The `ep:`
        // prefix matches the node id the handler builds, so cursors minted from other
        // node kinds compare consistently. The concatenation is not a bare column, so
        // SQLite compares it with its default BINARY collation (bytewise, like Rust `str`).
        sql.push_str(" AND (ts_ms < ? OR (ts_ms = ? AND ('ep:' || memory_id) > ?))");
        params.push(cur.ts_ms.into());
        params.push(cur.ts_ms.into());
        params.push(cur.id.clone().into());
    }
    sql.push_str(" ORDER BY ts_ms DESC, memory_id ASC LIMIT ?");
    params.push(limit.into());
    let mut stmt = conn.prepare(&sql)?;
    let rows: Vec<NodeRowEp> = stmt
        .query_map(rusqlite::params_from_iter(params), |r| {
            Ok(NodeRowEp {
                memory_id: r.get(0)?,
                ts_ms: r.get(1)?,
                content: r.get(2)?,
            })
        })?
        .collect::<rusqlite::Result<Vec<_>>>()?;
    Ok(rows)
}
/// One row returned by `fetch_documents_for_nodes`.
#[derive(Debug)]
struct NodeRowDoc {
/// `documents.doc_id`.
doc_id: String,
/// `documents.title`, if any.
title: Option<String>,
/// `documents.source`, if any.
source: Option<String>,
/// `documents.ingested_at_ms`; also used as the node's sort timestamp.
ingested_at_ms: i64,
}
/// Loads active documents as candidate graph nodes, newest first.
///
/// Rows are ordered by `ingested_at_ms DESC, doc_id ASC`.
///
/// * `since_ms` / `until_ms` - optional inclusive bounds on `ingested_at_ms`.
/// * `cursor` - when present, only rows strictly after the cursor position are
///   returned: an older timestamp, or the same timestamp with a node id
///   (`doc:` + `doc_id`) greater than `cursor.id`.
/// * `limit` - maximum number of rows returned.
///
/// The id tiebreak is applied in SQL on purpose. Filtering on the timestamp alone lets
/// rows that share the cursor's timestamp and were already served consume the `LIMIT`;
/// the caller's post-filter then sees a short page and pagination stops while unserved
/// rows remain.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the query.
fn fetch_documents_for_nodes(
    conn: &rusqlite::Connection,
    since_ms: Option<i64>,
    until_ms: Option<i64>,
    cursor: Option<&NodesCursor>,
    limit: i64,
) -> rusqlite::Result<Vec<NodeRowDoc>> {
    let mut sql = String::from(
        "SELECT doc_id, title, source, ingested_at_ms
FROM documents
WHERE status = 'active'",
    );
    let mut params: Vec<rusqlite::types::Value> = Vec::new();
    if let Some(s) = since_ms {
        sql.push_str(" AND ingested_at_ms >= ?");
        params.push(s.into());
    }
    if let Some(u) = until_ms {
        sql.push_str(" AND ingested_at_ms <= ?");
        params.push(u.into());
    }
    if let Some(cur) = cursor {
        // Keyset predicate mirroring the handler's (ts DESC, id ASC) order. The `doc:`
        // prefix matches the node id the handler builds, so cursors minted from other
        // node kinds compare consistently. The concatenation is not a bare column, so
        // SQLite compares it with its default BINARY collation (bytewise, like Rust `str`).
        sql.push_str(
            " AND (ingested_at_ms < ? OR (ingested_at_ms = ? AND ('doc:' || doc_id) > ?))",
        );
        params.push(cur.ts_ms.into());
        params.push(cur.ts_ms.into());
        params.push(cur.id.clone().into());
    }
    sql.push_str(" ORDER BY ingested_at_ms DESC, doc_id ASC LIMIT ?");
    params.push(limit.into());
    let mut stmt = conn.prepare(&sql)?;
    let rows: Vec<NodeRowDoc> = stmt
        .query_map(rusqlite::params_from_iter(params), |r| {
            Ok(NodeRowDoc {
                doc_id: r.get(0)?,
                title: r.get(1)?,
                source: r.get(2)?,
                ingested_at_ms: r.get(3)?,
            })
        })?
        .collect::<rusqlite::Result<Vec<_>>>()?;
    Ok(rows)
}
/// One row returned by `fetch_chunks_for_nodes`.
#[derive(Debug)]
struct NodeRowChunk {
/// `document_chunks.chunk_id`.
chunk_id: String,
/// `document_chunks.chunk_index`.
chunk_index: i64,
/// `document_chunks.content`.
content: String,
/// `document_chunks.created_at_ms`; also used as the node's sort timestamp.
created_at_ms: i64,
}
/// Loads chunks of active documents as candidate graph nodes, newest first.
///
/// Rows are ordered by `created_at_ms DESC, chunk_id ASC`.
///
/// * `since_ms` / `until_ms` - optional inclusive bounds on `created_at_ms`.
/// * `cursor` - when present, only rows strictly after the cursor position are
///   returned: an older timestamp, or the same timestamp with a node id
///   (`chunk:` + `chunk_id`) greater than `cursor.id`.
/// * `limit` - maximum number of rows returned.
///
/// The id tiebreak is applied in SQL on purpose. Filtering on the timestamp alone lets
/// rows that share the cursor's timestamp and were already served consume the `LIMIT`;
/// the caller's post-filter then sees a short page and pagination stops while unserved
/// rows remain.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the query.
fn fetch_chunks_for_nodes(
    conn: &rusqlite::Connection,
    since_ms: Option<i64>,
    until_ms: Option<i64>,
    cursor: Option<&NodesCursor>,
    limit: i64,
) -> rusqlite::Result<Vec<NodeRowChunk>> {
    let mut sql = String::from(
        "SELECT c.chunk_id, c.chunk_index, c.content, c.created_at_ms
FROM document_chunks c
JOIN documents d ON d.doc_id = c.doc_id
WHERE d.status = 'active'",
    );
    let mut params: Vec<rusqlite::types::Value> = Vec::new();
    if let Some(s) = since_ms {
        sql.push_str(" AND c.created_at_ms >= ?");
        params.push(s.into());
    }
    if let Some(u) = until_ms {
        sql.push_str(" AND c.created_at_ms <= ?");
        params.push(u.into());
    }
    if let Some(cur) = cursor {
        // Keyset predicate mirroring the handler's (ts DESC, id ASC) order. The `chunk:`
        // prefix matches the node id the handler builds, so cursors minted from other
        // node kinds compare consistently. The concatenation is not a bare column, so
        // SQLite compares it with its default BINARY collation (bytewise, like Rust `str`).
        sql.push_str(
            " AND (c.created_at_ms < ? OR (c.created_at_ms = ? AND ('chunk:' || c.chunk_id) > ?))",
        );
        params.push(cur.ts_ms.into());
        params.push(cur.ts_ms.into());
        params.push(cur.id.clone().into());
    }
    sql.push_str(" ORDER BY c.created_at_ms DESC, c.chunk_id ASC LIMIT ?");
    params.push(limit.into());
    let mut stmt = conn.prepare(&sql)?;
    let rows: Vec<NodeRowChunk> = stmt
        .query_map(rusqlite::params_from_iter(params), |r| {
            Ok(NodeRowChunk {
                chunk_id: r.get(0)?,
                chunk_index: r.get(1)?,
                content: r.get(2)?,
                created_at_ms: r.get(3)?,
            })
        })?
        .collect::<rusqlite::Result<Vec<_>>>()?;
    Ok(rows)
}
/// One row returned by `fetch_clusters_for_nodes`.
#[derive(Debug)]
struct NodeRowCluster {
/// `clusters.cluster_id`.
cluster_id: String,
/// `semantic_abstractions.content` for the cluster; `None` when the LEFT JOIN finds
/// no abstraction row.
abstraction: Option<String>,
/// `clusters.created_at_ms`; also used as the node's sort timestamp.
created_at_ms: i64,
}
/// Loads clusters (with their abstraction text, if any) as candidate graph nodes,
/// newest first.
///
/// Rows are ordered by `created_at_ms DESC, cluster_id ASC`.
///
/// * `since_ms` / `until_ms` - optional inclusive bounds on `created_at_ms`.
/// * `cursor` - when present, only rows strictly after the cursor position are
///   returned: an older timestamp, or the same timestamp with a node id
///   (`cl:` + `cluster_id`) greater than `cursor.id`.
/// * `limit` - maximum number of rows returned.
///
/// The id tiebreak is applied in SQL on purpose. Filtering on the timestamp alone lets
/// rows that share the cursor's timestamp and were already served consume the `LIMIT`;
/// the caller's post-filter then sees a short page and pagination stops while unserved
/// rows remain.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the query.
fn fetch_clusters_for_nodes(
    conn: &rusqlite::Connection,
    since_ms: Option<i64>,
    until_ms: Option<i64>,
    cursor: Option<&NodesCursor>,
    limit: i64,
) -> rusqlite::Result<Vec<NodeRowCluster>> {
    let mut sql = String::from(
        "SELECT c.cluster_id, sa.content, c.created_at_ms
FROM clusters c
LEFT JOIN semantic_abstractions sa ON sa.cluster_id = c.cluster_id
WHERE 1=1",
    );
    let mut params: Vec<rusqlite::types::Value> = Vec::new();
    if let Some(s) = since_ms {
        sql.push_str(" AND c.created_at_ms >= ?");
        params.push(s.into());
    }
    if let Some(u) = until_ms {
        sql.push_str(" AND c.created_at_ms <= ?");
        params.push(u.into());
    }
    if let Some(cur) = cursor {
        // Keyset predicate mirroring the handler's (ts DESC, id ASC) order. The `cl:`
        // prefix matches the node id the handler builds, so cursors minted from other
        // node kinds compare consistently. The concatenation is not a bare column, so
        // SQLite compares it with its default BINARY collation (bytewise, like Rust `str`).
        sql.push_str(
            " AND (c.created_at_ms < ? OR (c.created_at_ms = ? AND ('cl:' || c.cluster_id) > ?))",
        );
        params.push(cur.ts_ms.into());
        params.push(cur.ts_ms.into());
        params.push(cur.id.clone().into());
    }
    sql.push_str(" ORDER BY c.created_at_ms DESC, c.cluster_id ASC LIMIT ?");
    params.push(limit.into());
    let mut stmt = conn.prepare(&sql)?;
    let rows: Vec<NodeRowCluster> = stmt
        .query_map(rusqlite::params_from_iter(params), |r| {
            Ok(NodeRowCluster {
                cluster_id: r.get(0)?,
                abstraction: r.get(1)?,
                created_at_ms: r.get(2)?,
            })
        })?
        .collect::<rusqlite::Result<Vec<_>>>()?;
    Ok(rows)
}
/// One aggregated row returned by `fetch_entities_for_nodes`.
#[derive(Debug)]
struct NodeRowEntity {
/// The entity value (a `subject_id` or `object_id` from active triples).
value: String,
/// Number of subject/object references to the value among the rows considered.
ref_count: i64,
/// Smallest `valid_from_ms` among those references.
first_seen_ms: i64,
}
/// Aggregates entity values referenced by active triples into candidate graph nodes.
///
/// Every subject and object reference counts once. References are filtered by the
/// optional inclusive `since_ms` / `until_ms` bounds before grouping, and a cursor
/// keeps only groups whose earliest reference is at or before `cursor.ts_ms`.
/// Results are ordered by reference count (descending), then value (ascending).
///
/// Returns the rows, truncated to `GRAPH_ENTITY_CAP`, plus a flag telling whether
/// more than `GRAPH_ENTITY_CAP` groups matched.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the query.
fn fetch_entities_for_nodes(
    conn: &rusqlite::Connection,
    since_ms: Option<i64>,
    until_ms: Option<i64>,
    cursor: Option<&NodesCursor>,
) -> rusqlite::Result<(Vec<NodeRowEntity>, bool)> {
    let mut query = String::from(
        "WITH all_refs AS (
SELECT subject_id AS value, valid_from_ms AS ts_ms FROM triples WHERE status = 'active'
UNION ALL
SELECT object_id AS value, valid_from_ms AS ts_ms FROM triples WHERE status = 'active'
)
SELECT value, COUNT(*) AS ref_count, MIN(ts_ms) AS first_seen_ms
FROM all_refs
WHERE 1=1",
    );
    let mut bindings: Vec<rusqlite::types::Value> = Vec::new();

    if let Some(lower) = since_ms {
        query.push_str(" AND ts_ms >= ?");
        bindings.push(lower.into());
    }
    if let Some(upper) = until_ms {
        query.push_str(" AND ts_ms <= ?");
        bindings.push(upper.into());
    }
    query.push_str(" GROUP BY value");
    if let Some(cur) = cursor {
        query.push_str(" HAVING MIN(ts_ms) <= ?");
        bindings.push(cur.ts_ms.into());
    }
    // Ask for one row beyond the cap so that hitting the cap is detectable.
    let fetch_count = GRAPH_ENTITY_CAP as i64 + 1;
    query.push_str(" ORDER BY ref_count DESC, value ASC LIMIT ?");
    bindings.push(fetch_count.into());

    let mut stmt = conn.prepare(&query)?;
    let mut entities: Vec<NodeRowEntity> = Vec::new();
    let mapped = stmt.query_map(rusqlite::params_from_iter(bindings), |r| {
        Ok(NodeRowEntity {
            value: r.get(0)?,
            ref_count: r.get(1)?,
            first_seen_ms: r.get(2)?,
        })
    })?;
    for entity in mapped {
        entities.push(entity?);
    }

    let cap_reached = entities.len() > GRAPH_ENTITY_CAP;
    // `truncate` is a no-op when the vector is already within the cap.
    entities.truncate(GRAPH_ENTITY_CAP);
    Ok((entities, cap_reached))
}
/// Lists graph nodes for the resolved tenant, merged across the requested kinds and
/// paginated with an opaque cursor.
///
/// Flow: validate the query, fetch candidate rows per kind inside one blocking database
/// interaction, drop rows at or before the cursor, sort by `cmp_node_sort_keys`
/// (timestamp descending, id ascending), and cut the page at `limit`. `next_cursor` is
/// emitted only when more rows than `limit` remain. The `ENTITY_CAP_HEADER` response
/// header is set when the entity fetch hit its cap.
///
/// # Errors
/// Bad-request errors for an invalid `kind`, `since_ms > until_ms`, or an undecodable
/// cursor; database failures are converted through `ApiError::from`.
async fn graph_nodes_handler(
TenantExtractor(tenant): TenantExtractor,
Query(q): Query<GraphNodesQuery>,
) -> Result<Response, ApiError> {
// The clamp keeps `limit >= 1`, which the `limit_us - 1` index below relies on.
let limit = q.limit.unwrap_or(GRAPH_NODES_DEFAULT_LIMIT);
let limit = limit.clamp(1, GRAPH_NODES_MAX_LIMIT);
let kinds = parse_node_kind_filter(q.kind.as_deref())?;
let since_ms = q.since_ms;
let until_ms = q.until_ms;
if let (Some(s), Some(u)) = (since_ms, until_ms) {
if s > u {
return Err(ApiError::bad_request(format!(
"since_ms ({s}) must be <= until_ms ({u})"
)));
}
}
// An empty `cursor=` is treated the same as no cursor.
let cursor = match q.cursor.as_deref() {
None => None,
Some("") => None,
Some(raw) => Some(decode_cursor::<NodesCursor>(raw)?),
};
let want_episode = kinds.contains(&NodeKind::Episode);
let want_document = kinds.contains(&NodeKind::Document);
let want_chunk = kinds.contains(&NodeKind::Chunk);
let want_cluster = kinds.contains(&NodeKind::Cluster);
let want_entity = kinds.contains(&NodeKind::Entity);
// Each kind is over-fetched by two rows beyond the page size.
let per_kind_limit = (limit as i64).saturating_add(2);
let tenant_id_for_blocking = tenant.tenant_id().to_string();
// The closure below is `move`, so it gets its own copy of the cursor; the original
// stays available for the post-filter after the await.
let cursor_clone = cursor.as_ref().map(|c| NodesCursor {
ts_ms: c.ts_ms,
id: c.id.clone(),
});
let (mut staged, cap_reached) = tenant
.read()
.interact(move |conn| {
let mut staged: Vec<StagingNode> = Vec::new();
let mut cap_reached = false;
let cursor_ref = cursor_clone.as_ref();
if want_episode {
let eps = fetch_episodes_for_nodes(conn, since_ms, until_ms, cursor_ref, per_kind_limit)?;
for ep in eps {
let id = format!("ep:{}", ep.memory_id);
let exp = ExpandedEpisode {
memory_id: ep.memory_id,
ts_ms: ep.ts_ms,
content: ep.content,
};
let node = graph_node_for_episode(&tenant_id_for_blocking, &exp);
staged.push(StagingNode {
sort_ts_ms: ep.ts_ms,
sort_id: id.clone(),
node,
});
}
}
if want_document {
let docs = fetch_documents_for_nodes(conn, since_ms, until_ms, cursor_ref, per_kind_limit)?;
for d in docs {
let id = format!("doc:{}", d.doc_id);
let exp = ExpandedDocument {
doc_id: d.doc_id,
title: d.title,
source: d.source,
ingested_at_ms: d.ingested_at_ms,
};
let node = graph_node_for_document(&tenant_id_for_blocking, &exp);
staged.push(StagingNode {
sort_ts_ms: d.ingested_at_ms,
sort_id: id.clone(),
node,
});
}
}
if want_chunk {
let chunks = fetch_chunks_for_nodes(conn, since_ms, until_ms, cursor_ref, per_kind_limit)?;
for c in chunks {
let id = format!("chunk:{}", c.chunk_id);
let exp = ExpandedChunk {
chunk_id: c.chunk_id,
chunk_index: c.chunk_index,
content: c.content,
};
let mut node = graph_node_for_chunk(&tenant_id_for_blocking, &exp);
// Overwrite the node timestamp with the chunk's creation time.
node.ts_ms = Some(c.created_at_ms);
staged.push(StagingNode {
sort_ts_ms: c.created_at_ms,
sort_id: id.clone(),
node,
});
}
}
if want_cluster {
let cls = fetch_clusters_for_nodes(conn, since_ms, until_ms, cursor_ref, per_kind_limit)?;
for c in cls {
let id = format!("cl:{}", c.cluster_id);
let node = graph_node_for_cluster(
&tenant_id_for_blocking,
&c.cluster_id,
c.abstraction.as_deref(),
c.created_at_ms,
);
staged.push(StagingNode {
sort_ts_ms: c.created_at_ms,
sort_id: id.clone(),
node,
});
}
}
if want_entity {
// Entities are not bounded by `per_kind_limit`; the fetch applies its own cap
// and reports whether that cap was hit.
let (ents, was_cap_reached) =
fetch_entities_for_nodes(conn, since_ms, until_ms, cursor_ref)?;
cap_reached = was_cap_reached;
for e in ents {
let id = format!("ent:{}", e.value);
let mut node = graph_node_for_entity(&tenant_id_for_blocking, &e.value);
node.ts_ms = Some(e.first_seen_ms);
node.preview =
Some(format!("Referenced in {} triples", e.ref_count));
staged.push(StagingNode {
sort_ts_ms: e.first_seen_ms,
sort_id: id.clone(),
node,
});
}
}
Ok::<_, rusqlite::Error>((staged, cap_reached))
})
.await
.map_err(ApiError::from)?;
// Exact cursor filter: keep only nodes whose (timestamp, id) key sorts strictly
// after the cursor position.
if let Some(cur) = &cursor {
staged.retain(|s| node_passes_cursor(s.sort_ts_ms, &s.sort_id, cur));
}
staged.sort_by(|a, b| {
cmp_node_sort_keys((a.sort_ts_ms, &a.sort_id), (b.sort_ts_ms, &b.sort_id))
});
let limit_us = limit as usize;
// A next cursor exists only when rows remain beyond this page; it records the key
// of the last node that is actually returned.
let next_cursor = if staged.len() > limit_us {
let last = &staged[limit_us - 1];
Some(NodesCursor {
ts_ms: last.sort_ts_ms,
id: last.sort_id.clone(),
})
} else {
None
};
staged.truncate(limit_us);
let next_cursor_str = match next_cursor {
Some(c) => Some(encode_cursor(&c)?),
None => None,
};
let nodes: Vec<GraphNode> = staged.into_iter().map(|s| s.node).collect();
let payload = GraphNodesResponse {
nodes,
next_cursor: next_cursor_str,
};
let mut response = Json(payload).into_response();
if cap_reached {
response
.headers_mut()
.insert(ENTITY_CAP_HEADER, HeaderValue::from_static("true"));
}
Ok(response)
}
/// An edge paired with the sort key the edges handler uses to merge kinds and paginate.
#[derive(Debug)]
struct StagingEdge {
/// The edge as it will appear in the response.
edge: GraphEdge,
/// `EdgeKind::order_idx` of the edge's kind; primary sort component.
kind_idx: u8,
/// Within-kind identifier; secondary sort component.
sub_id: String,
}
/// Orders edge sort keys ascending by kind rank, then ascending by sub id.
fn cmp_edge_sort_keys(a: (u8, &str), b: (u8, &str)) -> std::cmp::Ordering {
    let (a_kind, a_sub) = a;
    let (b_kind, b_sub) = b;
    a_kind.cmp(&b_kind).then_with(|| a_sub.cmp(b_sub))
}
/// Returns `true` when the edge key `(kind_idx, sub_id)` sorts strictly after the cursor
/// position, i.e. the edge has not been served on an earlier page.
fn edge_passes_cursor(kind_idx: u8, sub_id: &str, cursor: &EdgesCursor) -> bool {
    let candidate = (kind_idx, sub_id);
    let boundary = (cursor.kind_idx, cursor.sub_id.as_str());
    matches!(
        cmp_edge_sort_keys(candidate, boundary),
        std::cmp::Ordering::Greater
    )
}
/// Decides whether an edge involves the focused node.
///
/// `src_value` and `tgt_value` are the raw (unprefixed) endpoint values of the edge and
/// `extra_value` is an optional third value that is only consulted for triple edges
/// with an entity focus. Which endpoint is compared depends on the edge kind and on
/// the kind of the focused node; combinations not listed below never match.
fn edge_touches_focus(
    kind: EdgeKind,
    focus_kind: NodeKind,
    focus_value: &str,
    src_value: &str,
    tgt_value: &str,
    extra_value: Option<&str>,
) -> bool {
    let hits_source = src_value == focus_value;
    let hits_target = tgt_value == focus_value;
    match (kind, focus_kind) {
        (EdgeKind::Triple, NodeKind::Episode) => hits_source,
        (EdgeKind::Triple, NodeKind::Entity) => {
            hits_target || extra_value == Some(focus_value) || hits_source
        }
        (EdgeKind::DocumentChunk, NodeKind::Document) => hits_source,
        (EdgeKind::DocumentChunk, NodeKind::Chunk) => hits_target,
        (EdgeKind::ClusterMember, NodeKind::Cluster) => hits_source,
        (EdgeKind::ClusterMember, NodeKind::Episode) => hits_target,
        // Any other (edge kind, focus kind) pairing cannot touch the focus.
        _ => false,
    }
}
/// One row returned by `fetch_triple_edges`.
#[derive(Debug)]
struct EdgeRowTriple {
/// `triples.triple_id`.
triple_id: String,
/// `memory_id` of the episode joined via `source_episode_id`; `None` when the
/// LEFT JOIN finds no episode.
source_memory_id: Option<String>,
/// `triples.object_id`.
object_id: String,
/// `triples.predicate`.
predicate: String,
/// `triples.confidence`.
confidence: f32,
}
/// Loads active triples, each with the memory id of its source episode when one exists.
///
/// Rows are ordered by `triple_id` and the result is capped at four times
/// `GRAPH_EDGES_MAX_LIMIT` rows.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the query.
fn fetch_triple_edges(conn: &rusqlite::Connection) -> rusqlite::Result<Vec<EdgeRowTriple>> {
    /// Maps one result row onto an `EdgeRowTriple`.
    fn to_edge_row(r: &rusqlite::Row<'_>) -> rusqlite::Result<EdgeRowTriple> {
        Ok(EdgeRowTriple {
            triple_id: r.get(0)?,
            source_memory_id: r.get::<_, Option<String>>(1)?,
            object_id: r.get(2)?,
            predicate: r.get(3)?,
            confidence: r.get(4)?,
        })
    }

    let row_cap = i64::from(GRAPH_EDGES_MAX_LIMIT) * 4;
    let mut stmt = conn.prepare(
        "SELECT t.triple_id, e.memory_id, t.object_id, t.predicate, t.confidence
FROM triples t
LEFT JOIN episodes e ON e.rowid = t.source_episode_id
WHERE t.status = 'active'
ORDER BY t.triple_id ASC
LIMIT ?1",
    )?;
    let mut edges_found: Vec<EdgeRowTriple> = Vec::new();
    for row in stmt.query_map(rusqlite::params![row_cap], to_edge_row)? {
        edges_found.push(row?);
    }
    Ok(edges_found)
}
/// One row returned by `fetch_document_chunk_edges`.
#[derive(Debug)]
struct EdgeRowDocChunk {
/// `document_chunks.chunk_id` (edge target).
chunk_id: String,
/// `document_chunks.doc_id` (edge source).
doc_id: String,
}
/// Loads (chunk, document) pairs for chunks that belong to active documents.
///
/// Rows are ordered by `chunk_id` and the result is capped at four times
/// `GRAPH_EDGES_MAX_LIMIT` rows.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the query.
fn fetch_document_chunk_edges(
    conn: &rusqlite::Connection,
) -> rusqlite::Result<Vec<EdgeRowDocChunk>> {
    /// Maps one result row onto an `EdgeRowDocChunk`.
    fn to_edge_row(r: &rusqlite::Row<'_>) -> rusqlite::Result<EdgeRowDocChunk> {
        Ok(EdgeRowDocChunk {
            chunk_id: r.get(0)?,
            doc_id: r.get(1)?,
        })
    }

    let row_cap = i64::from(GRAPH_EDGES_MAX_LIMIT) * 4;
    let mut stmt = conn.prepare(
        "SELECT c.chunk_id, c.doc_id
FROM document_chunks c
JOIN documents d ON d.doc_id = c.doc_id
WHERE d.status = 'active'
ORDER BY c.chunk_id ASC
LIMIT ?1",
    )?;
    let mut edges_found: Vec<EdgeRowDocChunk> = Vec::new();
    for row in stmt.query_map(rusqlite::params![row_cap], to_edge_row)? {
        edges_found.push(row?);
    }
    Ok(edges_found)
}
/// One row returned by `fetch_cluster_member_edges`.
#[derive(Debug)]
struct EdgeRowClusterMember {
/// `cluster_episodes.cluster_id` (edge source).
cluster_id: String,
/// `cluster_episodes.memory_id` (edge target).
memory_id: String,
}
/// Loads (cluster, episode) membership pairs whose episode is active.
///
/// Rows are ordered by `cluster_id`, then `memory_id`, and the result is capped at four
/// times `GRAPH_EDGES_MAX_LIMIT` rows.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the query.
fn fetch_cluster_member_edges(
    conn: &rusqlite::Connection,
) -> rusqlite::Result<Vec<EdgeRowClusterMember>> {
    /// Maps one result row onto an `EdgeRowClusterMember`.
    fn to_edge_row(r: &rusqlite::Row<'_>) -> rusqlite::Result<EdgeRowClusterMember> {
        Ok(EdgeRowClusterMember {
            cluster_id: r.get(0)?,
            memory_id: r.get(1)?,
        })
    }

    let row_cap = i64::from(GRAPH_EDGES_MAX_LIMIT) * 4;
    let mut stmt = conn.prepare(
        "SELECT ce.cluster_id, ce.memory_id
FROM cluster_episodes ce
JOIN episodes e ON e.memory_id = ce.memory_id
WHERE e.status = 'active'
ORDER BY ce.cluster_id ASC, ce.memory_id ASC
LIMIT ?1",
    )?;
    let mut edges_found: Vec<EdgeRowClusterMember> = Vec::new();
    for row in stmt.query_map(rusqlite::params![row_cap], to_edge_row)? {
        edges_found.push(row?);
    }
    Ok(edges_found)
}
/// Lists precomputed graph edges (triple, document-chunk, cluster-member) for the
/// resolved tenant, optionally restricted to edges touching one node, paginated with an
/// opaque cursor.
///
/// Flow: validate the query, fetch rows per requested edge kind inside one blocking
/// database interaction, apply the focus filter, drop edges at or before the cursor,
/// sort by `cmp_edge_sort_keys` (kind rank, then sub id), and cut the page at `limit`.
/// `next_cursor` is emitted only when more edges than `limit` remain.
///
/// NOTE(review): the per-kind fetch helpers take neither the cursor nor the focus and
/// each returns at most `GRAPH_EDGES_MAX_LIMIT * 4` rows from the start of its ordering.
/// Edges beyond that cap therefore look unreachable through pagination or focus
/// filtering - confirm whether that is an accepted limit.
///
/// # Errors
/// Bad-request errors for an invalid `type`, an undecodable cursor, or a `node_id` that
/// `parse_node_id` rejects; database failures are converted through `ApiError::from`.
async fn graph_edges_handler(
TenantExtractor(tenant): TenantExtractor,
Query(q): Query<GraphEdgesQuery>,
) -> Result<Json<GraphEdgesResponse>, ApiError> {
// The clamp keeps `limit >= 1`, which the `limit_us - 1` index below relies on.
let limit = q.limit.unwrap_or(GRAPH_EDGES_DEFAULT_LIMIT);
let limit = limit.clamp(1, GRAPH_EDGES_MAX_LIMIT);
let kinds = parse_edge_kind_filter(q.r#type.as_deref())?;
// An empty `cursor=` is treated the same as no cursor.
let cursor = match q.cursor.as_deref() {
None => None,
Some("") => None,
Some(raw) => Some(decode_cursor::<EdgesCursor>(raw)?),
};
// Optional focus node as (kind, raw value); owned so it can move into the closure.
let focus = match q.node_id.as_deref() {
None => None,
Some(raw) => {
let (kind, value) = parse_node_id(raw)?;
Some((kind, value.to_string()))
}
};
let want_triple = kinds.contains(&EdgeKind::Triple);
let want_doc_chunk = kinds.contains(&EdgeKind::DocumentChunk);
let want_cluster_member = kinds.contains(&EdgeKind::ClusterMember);
let staged: Vec<StagingEdge> = tenant
.read()
.interact(move |conn| {
let mut staged: Vec<StagingEdge> = Vec::new();
if want_triple {
for t in fetch_triple_edges(conn)? {
// Triples without a resolvable source episode produce no edge.
let src_id = match &t.source_memory_id {
Some(mid) => format!("ep:{mid}"),
None => continue, };
let tgt_id = format!("ent:{}", t.object_id);
if let Some((fk, fv)) = &focus {
if !edge_touches_focus(
EdgeKind::Triple,
*fk,
fv,
// Always `Some` here because of the `continue` above.
t.source_memory_id
.as_deref()
.unwrap_or(""),
&t.object_id,
None,
) {
continue;
}
}
let edge = GraphEdge {
id: edge_id(&src_id, "triple", &tgt_id),
source: src_id,
target: tgt_id,
kind: "triple",
predicate: Some(t.predicate),
weight: Some(t.confidence),
};
staged.push(StagingEdge {
edge,
kind_idx: EdgeKind::Triple.order_idx(),
sub_id: t.triple_id,
});
}
}
if want_doc_chunk {
for dc in fetch_document_chunk_edges(conn)? {
let src_id = format!("doc:{}", dc.doc_id);
let tgt_id = format!("chunk:{}", dc.chunk_id);
if let Some((fk, fv)) = &focus {
if !edge_touches_focus(
EdgeKind::DocumentChunk,
*fk,
fv,
&dc.doc_id,
&dc.chunk_id,
None,
) {
continue;
}
}
let edge = GraphEdge {
id: edge_id(&src_id, "document_chunk", &tgt_id),
source: src_id,
target: tgt_id,
kind: "document_chunk",
predicate: None,
weight: None,
};
staged.push(StagingEdge {
edge,
kind_idx: EdgeKind::DocumentChunk.order_idx(),
sub_id: dc.chunk_id,
});
}
}
if want_cluster_member {
for cm in fetch_cluster_member_edges(conn)? {
let src_id = format!("cl:{}", cm.cluster_id);
let tgt_id = format!("ep:{}", cm.memory_id);
if let Some((fk, fv)) = &focus {
if !edge_touches_focus(
EdgeKind::ClusterMember,
*fk,
fv,
&cm.cluster_id,
&cm.memory_id,
None,
) {
continue;
}
}
let edge = GraphEdge {
id: edge_id(&src_id, "cluster_member", &tgt_id),
source: src_id,
target: tgt_id,
kind: "cluster_member",
predicate: None,
weight: None,
};
// Composite sub id: cluster id and memory id joined by U+001F.
let sub_id = format!("{}\u{1f}{}", cm.cluster_id, cm.memory_id);
staged.push(StagingEdge {
edge,
kind_idx: EdgeKind::ClusterMember.order_idx(),
sub_id,
});
}
}
Ok::<_, rusqlite::Error>(staged)
})
.await
.map_err(ApiError::from)?;
let mut staged = staged;
// Keep only edges whose (kind rank, sub id) key sorts strictly after the cursor.
if let Some(cur) = &cursor {
staged.retain(|s| edge_passes_cursor(s.kind_idx, &s.sub_id, cur));
}
staged.sort_by(|a, b| {
cmp_edge_sort_keys((a.kind_idx, &a.sub_id), (b.kind_idx, &b.sub_id))
});
let limit_us = limit as usize;
// A next cursor exists only when edges remain beyond this page; it records the key
// of the last edge that is actually returned.
let next_cursor = if staged.len() > limit_us {
let last = &staged[limit_us - 1];
Some(EdgesCursor {
kind_idx: last.kind_idx,
sub_id: last.sub_id.clone(),
})
} else {
None
};
staged.truncate(limit_us);
let next_cursor_str = match next_cursor {
Some(c) => Some(encode_cursor(&c)?),
None => None,
};
let edges: Vec<GraphEdge> = staged.into_iter().map(|s| s.edge).collect();
Ok(Json(GraphEdgesResponse {
edges,
next_cursor: next_cursor_str,
}))
}
/// Maximum number of triples loaded when inspecting an entity node.
const GRAPH_INSPECT_ENTITY_TRIPLES_CAP: i64 = 50;
/// JSON body returned by the graph inspect handler for a single node.
#[derive(Debug, Serialize)]
struct GraphInspectResponse {
/// The inspected node.
node: GraphNode,
/// Full text for the node where one is available; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
full_text: Option<String>,
/// Triple edges pointing at the node.
triples_in: Vec<GraphEdge>,
/// Triple edges whose source is the node.
triples_out: Vec<GraphEdge>,
}
/// Returns the detail view of a single graph node identified by its prefixed id.
///
/// The id is split by `parse_node_id` into a kind and a raw value, then handed to the
/// kind-specific `inspect_*_node` helper.
///
/// # Errors
/// Whatever `parse_node_id` or the selected helper returns.
async fn graph_inspect_handler(
    TenantExtractor(tenant): TenantExtractor,
    Path(id): Path<String>,
) -> Result<Json<GraphInspectResponse>, ApiError> {
    let (kind, raw_value) = parse_node_id(&id)?;
    let value = raw_value.to_string();
    let tenant_id_str = tenant.tenant_id().to_string();

    let inspected = match kind {
        NodeKind::Entity => inspect_entity_node(&tenant, &tenant_id_str, value, id).await,
        NodeKind::Cluster => inspect_cluster_node(&tenant, &tenant_id_str, value, id).await,
        NodeKind::Chunk => inspect_chunk_node(&tenant, &tenant_id_str, value, id).await,
        NodeKind::Document => inspect_document_node(&tenant, &tenant_id_str, value, id).await,
        NodeKind::Episode => inspect_episode_node(&tenant, &tenant_id_str, value, id).await,
    };
    inspected.map(Json)
}
async fn inspect_episode_node(
tenant: &TenantHandle,
tenant_id: &str,
memory_id: String,
node_id_full: String,
) -> Result<GraphInspectResponse, ApiError> {
let memory_id_for_err = memory_id.clone();
let memory_id_q = memory_id.clone();
let fetched: Option<(ExpandedEpisode, Vec<TripleRow>)> = tenant
.read()
.interact(move |conn| {
let ep_row: Option<(i64, i64, String)> = conn
.query_row(
"SELECT rowid, ts_ms, content
FROM episodes
WHERE memory_id = ?1
AND status = 'active'",
rusqlite::params![&memory_id_q],
|r| {
Ok((
r.get::<_, i64>(0)?,
r.get::<_, i64>(1)?,
r.get::<_, String>(2)?,
))
},
)
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})?;
let Some((rowid, ts_ms, content)) = ep_row else {
return Ok(None);
};
let mut stmt = conn.prepare(
"SELECT subject_id, predicate, object_id, confidence
FROM triples
WHERE source_episode_id = ?1
AND status = 'active'
ORDER BY valid_from_ms DESC",
)?;
let triples = stmt
.query_map(rusqlite::params![rowid], |r| {
Ok(TripleRow {
subject_id: r.get(0)?,
predicate: r.get(1)?,
object_id: r.get(2)?,
confidence: r.get(3)?,
})
})?
.collect::<rusqlite::Result<Vec<_>>>()?;
let ep = ExpandedEpisode {
memory_id: memory_id_q,
ts_ms,
content,
};
Ok::<_, rusqlite::Error>(Some((ep, triples)))
})
.await
.map_err(ApiError::from)?;
let (ep, triples) = fetched.ok_or_else(|| {
ApiError::not_found(format!(
"node_id {node_id_full:?} (memory_id {memory_id_for_err}) not found in current tenant"
))
})?;
let node = graph_node_for_episode(tenant_id, &ep);
let full_text = Some(ep.content.clone());
let mut triples_out = Vec::with_capacity(triples.len());
for t in triples {
let tgt_id = format!("ent:{}", t.object_id);
triples_out.push(GraphEdge {
id: edge_id(&node_id_full, "triple", &tgt_id),
source: node_id_full.clone(),
target: tgt_id,
kind: "triple",
predicate: Some(t.predicate),
weight: Some(t.confidence),
});
}
Ok(GraphInspectResponse {
node,
full_text,
triples_in: Vec::new(),
triples_out,
})
}
async fn inspect_document_node(
tenant: &TenantHandle,
tenant_id: &str,
doc_id: String,
node_id_full: String,
) -> Result<GraphInspectResponse, ApiError> {
let doc_id_for_err = doc_id.clone();
let doc_id_q = doc_id.clone();
let fetched: Option<(ExpandedDocument, Vec<String>)> = tenant
.read()
.interact(move |conn| {
let doc_row: Option<ExpandedDocument> = conn
.query_row(
"SELECT doc_id, title, source, ingested_at_ms
FROM documents
WHERE doc_id = ?1
AND status = 'active'",
rusqlite::params![&doc_id_q],
|r| {
Ok(ExpandedDocument {
doc_id: r.get(0)?,
title: r.get(1)?,
source: r.get(2)?,
ingested_at_ms: r.get(3)?,
})
},
)
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})?;
let Some(doc) = doc_row else {
return Ok(None);
};
let mut stmt = conn.prepare(
"SELECT content
FROM document_chunks
WHERE doc_id = ?1
ORDER BY chunk_index ASC",
)?;
let chunks = stmt
.query_map(rusqlite::params![&doc_id_q], |r| r.get::<_, String>(0))?
.collect::<rusqlite::Result<Vec<_>>>()?;
Ok::<_, rusqlite::Error>(Some((doc, chunks)))
})
.await
.map_err(ApiError::from)?;
let (doc, chunks) = fetched.ok_or_else(|| {
ApiError::not_found(format!(
"node_id {node_id_full:?} (doc_id {doc_id_for_err}) not found in current tenant"
))
})?;
let full_text = if chunks.is_empty() {
None
} else {
Some(chunks.join("\n\n"))
};
Ok(GraphInspectResponse {
node: graph_node_for_document(tenant_id, &doc),
full_text,
triples_in: Vec::new(),
triples_out: Vec::new(),
})
}
async fn inspect_chunk_node(
tenant: &TenantHandle,
tenant_id: &str,
chunk_id: String,
node_id_full: String,
) -> Result<GraphInspectResponse, ApiError> {
let chunk_id_for_err = chunk_id.clone();
let chunk_id_q = chunk_id.clone();
let row: Option<(ExpandedChunk, i64)> = tenant
.read()
.interact(move |conn| {
conn.query_row(
"SELECT c.chunk_id, c.chunk_index, c.content, c.created_at_ms
FROM document_chunks c
JOIN documents d ON d.doc_id = c.doc_id
WHERE c.chunk_id = ?1
AND d.status = 'active'",
rusqlite::params![&chunk_id_q],
|r| {
Ok((
ExpandedChunk {
chunk_id: r.get(0)?,
chunk_index: r.get(1)?,
content: r.get(2)?,
},
r.get::<_, i64>(3)?,
))
},
)
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})
})
.await
.map_err(ApiError::from)?;
let (chunk, created_at_ms) = row.ok_or_else(|| {
ApiError::not_found(format!(
"node_id {node_id_full:?} (chunk_id {chunk_id_for_err}) not found in current tenant"
))
})?;
let full_text = Some(chunk.content.clone());
let mut node = graph_node_for_chunk(tenant_id, &chunk);
node.ts_ms = Some(created_at_ms);
Ok(GraphInspectResponse {
node,
full_text,
triples_in: Vec::new(),
triples_out: Vec::new(),
})
}
async fn inspect_cluster_node(
tenant: &TenantHandle,
tenant_id: &str,
cluster_id: String,
node_id_full: String,
) -> Result<GraphInspectResponse, ApiError> {
let cluster_id_for_err = cluster_id.clone();
let cluster_id_q = cluster_id.clone();
let row: Option<(Option<String>, i64)> = tenant
.read()
.interact(move |conn| {
conn.query_row(
"SELECT sa.content, c.created_at_ms
FROM clusters c
LEFT JOIN semantic_abstractions sa ON sa.cluster_id = c.cluster_id
WHERE c.cluster_id = ?1",
rusqlite::params![&cluster_id_q],
|r| Ok((r.get::<_, Option<String>>(0)?, r.get::<_, i64>(1)?)),
)
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})
})
.await
.map_err(ApiError::from)?;
let (abstraction, created_at_ms) = row.ok_or_else(|| {
ApiError::not_found(format!(
"node_id {node_id_full:?} (cluster_id {cluster_id_for_err}) not found in current tenant"
))
})?;
let full_text = match abstraction.as_deref() {
Some(a) => Some(format!("cluster {cluster_id_for_err}\n\n{a}")),
None => Some(format!("cluster {cluster_id_for_err}")),
};
Ok(GraphInspectResponse {
node: graph_node_for_cluster(
tenant_id,
&cluster_id_for_err,
abstraction.as_deref(),
created_at_ms,
),
full_text,
triples_in: Vec::new(),
triples_out: Vec::new(),
})
}
async fn inspect_entity_node(
tenant: &TenantHandle,
tenant_id: &str,
entity_value: String,
node_id_full: String,
) -> Result<GraphInspectResponse, ApiError> {
let entity_q = entity_value.clone();
let rows: Vec<TripleRow> = tenant
.read()
.interact(move |conn| {
let mut stmt = conn.prepare(
"SELECT subject_id, predicate, object_id, confidence
FROM triples
WHERE (subject_id = ?1 OR object_id = ?1)
AND status = 'active'
ORDER BY valid_from_ms DESC
LIMIT ?2",
)?;
stmt.query_map(
rusqlite::params![&entity_q, GRAPH_INSPECT_ENTITY_TRIPLES_CAP],
|r| {
Ok(TripleRow {
subject_id: r.get(0)?,
predicate: r.get(1)?,
object_id: r.get(2)?,
confidence: r.get(3)?,
})
},
)?
.collect::<rusqlite::Result<Vec<_>>>()
})
.await
.map_err(ApiError::from)?;
if rows.is_empty() {
return Err(ApiError::not_found(format!(
"node_id {node_id_full:?} (entity {entity_value:?}) not found in current tenant -- entities must be referenced by at least one triple to be inspectable"
)));
}
let mut triples_out = Vec::with_capacity(rows.len());
for t in rows {
let other = if t.subject_id == entity_value {
t.object_id
} else {
t.subject_id
};
let tgt_id = format!("ent:{other}");
triples_out.push(GraphEdge {
id: edge_id(&node_id_full, "triple", &tgt_id),
source: node_id_full.clone(),
target: tgt_id,
kind: "triple",
predicate: Some(t.predicate),
weight: Some(t.confidence),
});
}
Ok(GraphInspectResponse {
node: graph_node_for_entity(tenant_id, &entity_value),
full_text: None,
triples_in: Vec::new(),
triples_out,
})
}
/// Neighbor limit used by the neighbors handler when the request omits `limit`.
const GRAPH_NEIGHBORS_DEFAULT_LIMIT: u32 = 25;
/// Upper bound a requested neighbors `limit` is clamped to.
const GRAPH_NEIGHBORS_MAX_LIMIT: u32 = 100;
/// Similarity threshold used for semantic neighbors when the request omits `threshold`.
const GRAPH_NEIGHBORS_DEFAULT_THRESHOLD: f32 = 0.75;
/// Which neighbor families the neighbors handler should return; parsed from the
/// snake_case `kind` query value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
enum GraphNeighborsKind {
/// Only neighbors from `neighbors_explicit`.
Explicit,
/// Only neighbors from `neighbors_semantic`.
Semantic,
/// Explicit and semantic neighbors together; used when `kind` is omitted.
#[default]
Both,
}
/// Query-string parameters accepted by the graph neighbors handler.
///
/// Every field is optional; an absent field deserializes to `None`.
#[derive(Debug, Deserialize)]
struct GraphNeighborsQuery {
/// Neighbor families to return; defaults to `both`.
#[serde(default)]
kind: Option<GraphNeighborsKind>,
/// Similarity threshold for semantic neighbors; must lie in `[0.0, 1.0]`.
#[serde(default)]
threshold: Option<f32>,
/// Requested neighbor limit; the handler clamps it to `1..=GRAPH_NEIGHBORS_MAX_LIMIT`.
#[serde(default)]
limit: Option<u32>,
}
/// Returns the neighbors of one graph node: explicit edges, semantic neighbors, or both.
///
/// After validating `threshold` and the node id, the focal node must exist. Explicit
/// and semantic results are then merged: nodes are de-duplicated by id (explicit ones
/// first), and a semantic edge is dropped when an explicit edge already connects the
/// same `(source, target)` pair.
///
/// When `kind` is `both`, a failure of the semantic lookup is swallowed and only the
/// explicit results are returned; when `kind` is `semantic`, the failure is returned.
///
/// # Errors
/// Bad-request for a `threshold` outside `[0.0, 1.0]` (NaN included) or an id that
/// `parse_node_id` rejects; otherwise whatever the existence check or the neighbor
/// lookups return.
async fn graph_neighbors_handler(
    TenantExtractor(tenant): TenantExtractor,
    Path(id): Path<String>,
    Query(q): Query<GraphNeighborsQuery>,
) -> Result<Json<GraphExpandResponse>, ApiError> {
    let kind = q.kind.unwrap_or_default();

    let threshold = q.threshold.unwrap_or(GRAPH_NEIGHBORS_DEFAULT_THRESHOLD);
    let threshold_in_range = (0.0..=1.0).contains(&threshold);
    if !threshold_in_range {
        return Err(ApiError::bad_request(format!(
            "threshold must be in [0.0, 1.0]; got {threshold}"
        )));
    }

    let limit = q
        .limit
        .unwrap_or(GRAPH_NEIGHBORS_DEFAULT_LIMIT)
        .clamp(1, GRAPH_NEIGHBORS_MAX_LIMIT);

    let (node_kind, raw_value) = parse_node_id(&id)?;
    let value_owned = raw_value.to_string();
    let tenant_id_str = tenant.tenant_id().to_string();
    let node_id_full = id;

    ensure_neighbors_focal_exists(&tenant, node_kind, &value_owned, &node_id_full).await?;

    let wants_explicit = !matches!(kind, GraphNeighborsKind::Semantic);
    let wants_semantic = !matches!(kind, GraphNeighborsKind::Explicit);

    let (explicit_nodes, explicit_edges) = if wants_explicit {
        neighbors_explicit(
            &tenant,
            &tenant_id_str,
            node_kind,
            &value_owned,
            &node_id_full,
            i64::from(limit),
        )
        .await?
    } else {
        (Vec::new(), Vec::new())
    };

    let (semantic_nodes, semantic_edges) = if wants_semantic {
        let outcome = neighbors_semantic(
            &tenant,
            &tenant_id_str,
            node_kind,
            &value_owned,
            &node_id_full,
            limit,
            threshold,
        )
        .await;
        match outcome {
            Ok(parts) => parts,
            // A semantic-only request has nothing to fall back on.
            Err(e) if matches!(kind, GraphNeighborsKind::Semantic) => return Err(e),
            // With `both`, explicit results are still served on their own.
            Err(_) => (Vec::new(), Vec::new()),
        }
    } else {
        (Vec::new(), Vec::new())
    };

    // Endpoint pairs already covered by explicit edges.
    let explicit_endpoints: std::collections::HashSet<(String, String)> = explicit_edges
        .iter()
        .map(|e| (e.source.clone(), e.target.clone()))
        .collect();

    let node_capacity = explicit_nodes.len() + semantic_nodes.len();
    let mut seen_node_ids: std::collections::HashSet<String> =
        std::collections::HashSet::with_capacity(node_capacity);
    let mut nodes: Vec<GraphNode> = Vec::with_capacity(node_capacity);
    for n in explicit_nodes.into_iter().chain(semantic_nodes) {
        if seen_node_ids.insert(n.id.clone()) {
            nodes.push(n);
        }
    }

    let mut edges: Vec<GraphEdge> =
        Vec::with_capacity(explicit_edges.len() + semantic_edges.len());
    edges.extend(explicit_edges);
    edges.extend(
        semantic_edges
            .into_iter()
            .filter(|e| !explicit_endpoints.contains(&(e.source.clone(), e.target.clone()))),
    );

    Ok(Json(GraphExpandResponse { nodes, edges }))
}
/// Checks that the focal node of a neighbors request exists, dispatching to the
/// kind-specific existence check.
///
/// # Errors
/// Whatever the selected `ensure_*` helper returns.
async fn ensure_neighbors_focal_exists(
    tenant: &TenantHandle,
    node_kind: NodeKind,
    value: &str,
    node_id_full: &str,
) -> Result<(), ApiError> {
    let outcome = match node_kind {
        NodeKind::Entity => ensure_entity_referenced(tenant, value, node_id_full).await,
        NodeKind::Chunk => ensure_chunk_exists(tenant, value, node_id_full).await,
        NodeKind::Document => ensure_document_exists(tenant, value, node_id_full).await,
        NodeKind::Cluster => ensure_cluster_exists(tenant, value, node_id_full).await,
        NodeKind::Episode => ensure_episode_exists(tenant, value, node_id_full).await,
    };
    outcome
}
/// Verifies that a chunk with the given id exists under an active document.
///
/// # Errors
/// Not-found when no such chunk exists; database failures are converted through
/// `ApiError::from`.
async fn ensure_chunk_exists(
    tenant: &TenantHandle,
    chunk_id: &str,
    node_id_full: &str,
) -> Result<(), ApiError> {
    let lookup_id = chunk_id.to_string();
    let match_count: i64 = tenant
        .read()
        .interact(move |conn| {
            conn.query_row(
                "SELECT COUNT(*)
FROM document_chunks c
JOIN documents d ON d.doc_id = c.doc_id
WHERE c.chunk_id = ?1
AND d.status = 'active'",
                rusqlite::params![&lookup_id],
                |row| row.get(0),
            )
        })
        .await
        .map_err(ApiError::from)?;

    if match_count != 0 {
        return Ok(());
    }
    Err(ApiError::not_found(format!(
        "node_id {node_id_full:?} not found in current tenant"
    )))
}
/// Checks that `entity_value` appears as the subject or object of at least
/// one active triple.
///
/// # Errors
/// Returns a 404 `ApiError` when no active triple references the entity, or
/// the converted storage error when the read fails.
async fn ensure_entity_referenced(
    tenant: &TenantHandle,
    entity_value: &str,
    node_id_full: &str,
) -> Result<(), ApiError> {
    // Counts active triples that mention the entity on either side.
    const ENTITY_REFERENCE_COUNT_SQL: &str = concat!(
        "SELECT COUNT(*)\n",
        "FROM triples\n",
        "WHERE (subject_id = ?1 OR object_id = ?1)\n",
        "AND status = 'active'",
    );

    // The closure must own its data, so copy the value before moving it in.
    let lookup_key = entity_value.to_owned();
    let reference_count: i64 = tenant
        .read()
        .interact(move |conn| {
            conn.query_row(
                ENTITY_REFERENCE_COUNT_SQL,
                rusqlite::params![&lookup_key],
                |row| row.get(0),
            )
        })
        .await
        .map_err(ApiError::from)?;

    if reference_count != 0 {
        return Ok(());
    }
    Err(ApiError::not_found(format!(
        "node_id {node_id_full:?} (entity {entity_value:?}) not found in current tenant -- entities must be referenced by at least one triple to be neighborable"
    )))
}
/// Collects the explicit (stored-relationship) neighbors of a focal node.
///
/// Which expansions run depends on `node_kind`:
/// - episode: cluster membership, then triples;
/// - document and chunk: document/chunk expansion;
/// - cluster: cluster membership;
/// - entity: triples.
///
/// Returns the accumulated `(nodes, edges)`; the first failing expansion
/// aborts the call with its error.
async fn neighbors_explicit(
    tenant: &TenantHandle,
    tenant_id: &str,
    node_kind: NodeKind,
    value: &str,
    node_id_full: &str,
    limit: i64,
) -> Result<(Vec<GraphNode>, Vec<GraphEdge>), ApiError> {
    let mut collected_nodes: Vec<GraphNode> = Vec::new();
    let mut collected_edges: Vec<GraphEdge> = Vec::new();

    match node_kind {
        NodeKind::Episode => {
            let membership =
                expand_cluster_member(tenant, tenant_id, node_kind, value, node_id_full, limit)
                    .await?;
            collected_nodes.extend(membership.nodes);
            collected_edges.extend(membership.edges);

            let triples =
                expand_triple(tenant, tenant_id, node_kind, value, node_id_full, limit).await?;
            collected_nodes.extend(triples.nodes);
            collected_edges.extend(triples.edges);
        }
        // Documents and chunks share the same expansion.
        NodeKind::Document | NodeKind::Chunk => {
            let expansion =
                expand_document_chunk(tenant, tenant_id, node_kind, value, node_id_full, limit)
                    .await?;
            collected_nodes.extend(expansion.nodes);
            collected_edges.extend(expansion.edges);
        }
        NodeKind::Cluster => {
            let expansion =
                expand_cluster_member(tenant, tenant_id, node_kind, value, node_id_full, limit)
                    .await?;
            collected_nodes.extend(expansion.nodes);
            collected_edges.extend(expansion.edges);
        }
        NodeKind::Entity => {
            let expansion =
                expand_triple(tenant, tenant_id, node_kind, value, node_id_full, limit).await?;
            collected_nodes.extend(expansion.nodes);
            collected_edges.extend(expansion.edges);
        }
    }

    Ok((collected_nodes, collected_edges))
}
/// Collects semantic (embedding-similarity) neighbors of a focal node.
///
/// Only episode and chunk sources are supported; any other kind yields a
/// 400 `ApiError` naming the rejected kind by its wire string.
async fn neighbors_semantic(
    tenant: &TenantHandle,
    tenant_id: &str,
    node_kind: NodeKind,
    value: &str,
    node_id_full: &str,
    limit: u32,
    threshold: f32,
) -> Result<(Vec<GraphNode>, Vec<GraphEdge>), ApiError> {
    match node_kind {
        NodeKind::Episode => {
            neighbors_semantic_from_episode(
                tenant,
                tenant_id,
                value,
                node_id_full,
                limit,
                threshold,
            )
            .await
        }
        NodeKind::Chunk => {
            neighbors_semantic_from_chunk(tenant, tenant_id, value, node_id_full, limit, threshold)
                .await
        }
        NodeKind::Cluster | NodeKind::Document | NodeKind::Entity => {
            let wire_kind = node_kind.as_wire_str();
            Err(ApiError::bad_request(format!(
                "semantic neighbors only valid for episode or chunk source; got {wire_kind}"
            )))
        }
    }
}
async fn neighbors_semantic_from_episode(
tenant: &TenantHandle,
tenant_id: &str,
memory_id: &str,
node_id_full: &str,
limit: u32,
threshold: f32,
) -> Result<(Vec<GraphNode>, Vec<GraphEdge>), ApiError> {
let memory_id_q = memory_id.to_string();
let memory_id_for_self_excl = memory_id.to_string();
let content: Option<String> = tenant
.read()
.interact(move |conn| {
conn.query_row(
"SELECT content FROM episodes WHERE memory_id = ?1 AND status = 'active'",
rusqlite::params![&memory_id_q],
|r| r.get::<_, String>(0),
)
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})
})
.await
.map_err(ApiError::from)?;
let Some(content) = content else {
return Ok((Vec::new(), Vec::new()));
};
let widened = (limit as usize).saturating_add(1).min(100);
let result = solo_query::recall::run_recall_inner(
tenant.embedder(),
tenant.hnsw(),
tenant.read(),
&content,
widened,
)
.await
.map_err(ApiError::from)?;
let mut nodes = Vec::new();
let mut edges = Vec::new();
for hit in result.hits.into_iter() {
if hit.memory_id == memory_id_for_self_excl {
continue;
}
if nodes.len() as u32 >= limit {
break;
}
let weight = (1.0 - hit.cos_distance).max(0.0);
if weight < threshold {
continue;
}
let target_id = format!("ep:{}", hit.memory_id);
edges.push(GraphEdge {
id: edge_id(node_id_full, "semantic", &target_id),
source: node_id_full.to_string(),
target: target_id,
kind: "semantic",
predicate: None,
weight: Some(weight),
});
nodes.push(GraphNode {
id: format!("ep:{}", hit.memory_id),
kind: NodeKind::Episode.as_wire_str(),
label: episode_label(&hit.content),
ts_ms: None,
tenant_id: tenant_id.to_string(),
preview: Some(truncate_preview(&hit.content, GRAPH_PREVIEW_CHARS)),
});
}
Ok((nodes, edges))
}
async fn neighbors_semantic_from_chunk(
tenant: &TenantHandle,
tenant_id: &str,
chunk_id: &str,
node_id_full: &str,
limit: u32,
threshold: f32,
) -> Result<(Vec<GraphNode>, Vec<GraphEdge>), ApiError> {
let chunk_id_q = chunk_id.to_string();
let chunk_id_for_self_excl = chunk_id.to_string();
let content: Option<String> = tenant
.read()
.interact(move |conn| {
conn.query_row(
"SELECT c.content
FROM document_chunks c
JOIN documents d ON d.doc_id = c.doc_id
WHERE c.chunk_id = ?1
AND d.status = 'active'",
rusqlite::params![&chunk_id_q],
|r| r.get::<_, String>(0),
)
.map(Some)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(None),
other => Err(other),
})
})
.await
.map_err(ApiError::from)?;
let Some(content) = content else {
return Ok((Vec::new(), Vec::new()));
};
let widened = (limit as usize).saturating_add(1).min(100);
let hits = solo_query::doc_search::run_doc_search_inner(
tenant.embedder(),
tenant.hnsw(),
tenant.read(),
&content,
widened,
)
.await
.map_err(ApiError::from)?;
let mut nodes = Vec::new();
let mut edges = Vec::new();
for hit in hits.into_iter() {
if hit.chunk_id == chunk_id_for_self_excl {
continue;
}
if nodes.len() as u32 >= limit {
break;
}
let weight = (1.0 - hit.cos_distance).max(0.0);
if weight < threshold {
continue;
}
let target_id = format!("chunk:{}", hit.chunk_id);
edges.push(GraphEdge {
id: edge_id(node_id_full, "semantic", &target_id),
source: node_id_full.to_string(),
target: target_id,
kind: "semantic",
predicate: None,
weight: Some(weight),
});
let exp = ExpandedChunk {
chunk_id: hit.chunk_id.clone(),
chunk_index: hit.chunk_index as i64,
content: hit.content.clone(),
};
nodes.push(graph_node_for_chunk(tenant_id, &exp));
}
Ok((nodes, edges))
}
/// Heartbeat period, in seconds, that `graph_stream_handler` passes to
/// `build_invalidate_stream`.
pub const STREAM_HEARTBEAT_SECS: u64 = 30;
/// SSE event name of the first event emitted on the graph invalidate stream.
const STREAM_EVENT_INIT: &str = "init";
/// SSE event name used when forwarding an `InvalidateEvent`.
const STREAM_EVENT_INVALIDATE: &str = "invalidate";
/// SSE event name of the periodic heartbeat event.
const STREAM_EVENT_HEARTBEAT: &str = "heartbeat";
/// SSE endpoint that streams graph invalidation events for the resolved tenant.
///
/// Subscribes to the tenant's invalidate broadcast and wraps the resulting
/// event stream in an `Sse` response. The stream carries its own heartbeat
/// events every `STREAM_HEARTBEAT_SECS`; the `Sse` keep-alive is set to a
/// one-hour interval.
async fn graph_stream_handler(
    TenantExtractor(tenant): TenantExtractor,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let events = build_invalidate_stream(
        tenant.invalidate_sender().subscribe(),
        tenant.tenant_id().to_string(),
        STREAM_HEARTBEAT_SECS,
    );
    let transport_keep_alive = KeepAlive::new().interval(Duration::from_secs(3600));
    Sse::new(events).keep_alive(transport_keep_alive)
}
/// State threaded through the `unfold` in `build_invalidate_stream`.
struct StreamState {
// Receiver for the tenant's invalidate broadcast.
rx: broadcast::Receiver<InvalidateEvent>,
// Timer that drives the periodic heartbeat event.
heartbeat: tokio::time::Interval,
// Tenant id reported in the `init` event payload.
tenant_id: String,
// True until the one-time `init` event has been emitted.
needs_init: bool,
}
/// Builds the SSE event stream for the graph invalidation endpoint.
///
/// The stream emits, in order:
/// 1. one `init` event carrying `connected`, `tenant_id` and `ts_ms`;
/// 2. then, for as long as the broadcast is open, an `invalidate` event per
///    received `InvalidateEvent` and a `heartbeat` event every
///    `heartbeat_secs` seconds (first heartbeat one full period after
///    creation).
///
/// A lagged receiver is logged and skipped; a closed broadcast ends the
/// stream.
///
/// `heartbeat_secs == 0` is treated as 1 second: `tokio::time::interval_at`
/// panics on a zero period, so the value is clamped instead of crashing.
fn build_invalidate_stream(
    rx: broadcast::Receiver<InvalidateEvent>,
    tenant_id: String,
    heartbeat_secs: u64,
) -> impl Stream<Item = Result<Event, Infallible>> {
    // Guard against a zero period, which would panic inside tokio.
    let period = Duration::from_secs(heartbeat_secs.max(1));
    let start_at = tokio::time::Instant::now() + period;
    let heartbeat = tokio::time::interval_at(start_at, period);
    let state = StreamState {
        rx,
        heartbeat,
        tenant_id,
        needs_init: true,
    };
    futures::stream::unfold(state, move |mut state| async move {
        // First poll: announce the connection before anything else.
        if state.needs_init {
            state.needs_init = false;
            let init_payload = serde_json::json!({
                "connected": true,
                "tenant_id": state.tenant_id,
                "ts_ms": chrono::Utc::now().timestamp_millis(),
            });
            // If the payload fails to serialize, still send a bare named event.
            let ev = Event::default()
                .event(STREAM_EVENT_INIT)
                .json_data(init_payload)
                .unwrap_or_else(|_| Event::default().event(STREAM_EVENT_INIT));
            return Some((Ok::<Event, Infallible>(ev), state));
        }
        // Loop so that a lagged receive does not end the stream.
        loop {
            tokio::select! {
                event = state.rx.recv() => {
                    match event {
                        Ok(ev) => {
                            let sse_event = Event::default()
                                .event(STREAM_EVENT_INVALIDATE)
                                .json_data(&ev)
                                .unwrap_or_else(|_| Event::default()
                                    .event(STREAM_EVENT_INVALIDATE));
                            return Some((Ok::<Event, Infallible>(sse_event), state));
                        }
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            // Nothing is sent to the client for a lag; just log it.
                            tracing::warn!(
                                lagged = n,
                                "graph stream subscriber lagged; client will \
                                 resync on the next real invalidate"
                            );
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            tracing::debug!(
                                "graph stream broadcast closed; ending SSE stream"
                            );
                            return None;
                        }
                    }
                }
                _ = state.heartbeat.tick() => {
                    let hb_payload = serde_json::json!({
                        "ts_ms": chrono::Utc::now().timestamp_millis(),
                    });
                    let sse_event = Event::default()
                        .event(STREAM_EVENT_HEARTBEAT)
                        .json_data(hb_payload)
                        .unwrap_or_else(|_| Event::default()
                            .event(STREAM_EVENT_HEARTBEAT));
                    return Some((Ok::<Event, Infallible>(sse_event), state));
                }
            }
        }
    })
}
/// JSON shape of one tenant entry in the tenants list response.
///
/// `display_name`, `last_accessed_ms` and `quota_bytes` are omitted from
/// the JSON when `None`. `episode_count`, `size_bytes` and `pct_used` carry
/// no skip attribute, so they serialize as `null` when `None`.
#[derive(Debug, Clone, Serialize)]
struct TenantListItem {
// Tenant id rendered as a string.
id: String,
#[serde(skip_serializing_if = "Option::is_none")]
display_name: Option<String>,
created_at_ms: i64,
#[serde(skip_serializing_if = "Option::is_none")]
last_accessed_ms: Option<i64>,
status: TenantStatusJson,
#[serde(skip_serializing_if = "Option::is_none")]
quota_bytes: Option<u64>,
// Cost numbers copied from the hydrated per-tenant cost record.
episode_count: Option<i64>,
size_bytes: Option<u64>,
// size_bytes as a percentage of quota_bytes; `tenants_list_handler` caps it at 100.
pct_used: Option<f64>,
}
/// Tenant status as exposed over the JSON API, serialized in snake_case.
/// Only an active state is representable here.
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
enum TenantStatusJson {
Active,
}
impl From<&solo_storage::TenantStatus> for TenantStatusJson {
fn from(s: &solo_storage::TenantStatus) -> Self {
match s {
solo_storage::TenantStatus::Active => TenantStatusJson::Active,
solo_storage::TenantStatus::PendingMigration
| solo_storage::TenantStatus::PendingDelete => TenantStatusJson::Active,
}
}
}
/// JSON body returned by `tenants_list_handler`: the tenant entries under a
/// single `tenants` key.
#[derive(Debug, Serialize)]
struct TenantsListResponse {
tenants: Vec<TenantListItem>,
}
/// Cap that `tenants_list_handler` passes to `hydrate_tenant_cost_numbers`.
/// The cap-reached response header is set when more tenants than this are listed.
const TENANTS_COUNT_HYDRATION_CAP: usize = 50;
/// Response header set to `true` when the listed tenant count exceeds the cap.
const X_SOLO_TENANTS_COUNT_CAP_HEADER: &str = "x-solo-tenants-count-cap-reached";
/// Lists the active tenants visible to the caller, with cost numbers.
///
/// Steps: load the registry's active records and keep those whose status is
/// `Active`; narrow them to what the optional principal may see; hydrate
/// cost numbers with `TENANTS_COUNT_HYDRATION_CAP`; render one
/// `TenantListItem` per (record, cost) pair. When more tenants are visible
/// than the cap, the `x-solo-tenants-count-cap-reached: true` header is
/// added to the response.
async fn tenants_list_handler(
    State(state): State<SoloHttpState>,
    MaybePrincipal(maybe_principal): MaybePrincipal,
) -> Result<Response, ApiError> {
    let mut records = state.registry.list_active().await.map_err(ApiError::from)?;
    records.retain(|record| matches!(record.status, solo_storage::TenantStatus::Active));
    let visible = filter_tenants_for_principal(records, maybe_principal.as_ref());

    let cap = TENANTS_COUNT_HYDRATION_CAP;
    let costs = state
        .registry
        .hydrate_tenant_cost_numbers(&visible, cap)
        .await;

    let tenants: Vec<TenantListItem> = visible
        .iter()
        .zip(costs.iter())
        .map(|(record, cost)| {
            // Percentage of quota in use, capped at 100. Only defined when
            // both numbers are known and the quota is non-zero.
            let pct_used = cost
                .size_bytes
                .zip(record.quota_bytes)
                .filter(|&(_, quota)| quota > 0)
                .map(|(size, quota)| ((size as f64) * 100.0 / (quota as f64)).min(100.0));
            TenantListItem {
                id: record.tenant_id.to_string(),
                display_name: record.display_name.clone(),
                created_at_ms: record.created_at_ms,
                last_accessed_ms: record.last_accessed_ms,
                status: TenantStatusJson::from(&record.status),
                quota_bytes: record.quota_bytes,
                episode_count: cost.episode_count,
                size_bytes: cost.size_bytes,
                pct_used,
            }
        })
        .collect();

    let mut response = Json(TenantsListResponse { tenants }).into_response();
    if visible.len() > cap {
        response.headers_mut().insert(
            axum::http::HeaderName::from_static(X_SOLO_TENANTS_COUNT_CAP_HEADER),
            axum::http::HeaderValue::from_static("true"),
        );
    }
    Ok(response)
}
/// Narrows `records` to the tenants the caller may see.
///
/// - No principal: every record is returned.
/// - Single-principal bearer (see `is_single_principal_bearer`): every
///   record is returned.
/// - Any other principal: only the record matching its `tenant_claim`, or
///   nothing when it carries no claim.
fn filter_tenants_for_principal(
    records: Vec<solo_storage::TenantRecord>,
    principal: Option<&AuthenticatedPrincipal>,
) -> Vec<solo_storage::TenantRecord> {
    match principal {
        None => records,
        Some(p) if is_single_principal_bearer(p) => records,
        Some(p) => match p.tenant_claim.as_ref() {
            None => Vec::new(),
            Some(claim) => records
                .into_iter()
                .filter(|record| record.tenant_id == *claim)
                .collect(),
        },
    }
}
/// Returns `true` when the principal's subject is exactly `"bearer"`, its
/// claims are JSON null, and it has no scopes.
fn is_single_principal_bearer(principal: &AuthenticatedPrincipal) -> bool {
    if principal.subject != "bearer" {
        return false;
    }
    if !principal.claims.is_null() {
        return false;
    }
    principal.scopes.is_empty()
}
/// `POST /mcp`: runs one JSON-RPC request through the MCP dispatcher.
///
/// Flow:
/// 1. Read the session id and authenticated principal left in the request
///    extensions.
/// 2. Read the body (at most 8 MiB) and parse it as a JSON-RPC request,
///    rejecting unreadable bodies, invalid JSON and any `jsonrpc` other
///    than `"2.0"` with a 400 JSON error.
/// 3. Reuse the existing session, or create one bound to the tenant and
///    principal. For a freshly created session, spawn the invalidate bridge.
/// 4. Dispatch. A dispatcher reply becomes `200` with a JSON body; no reply
///    becomes `202 Accepted`.
/// 5. Stamp the session id header on the response.
async fn mcp_http_post_handler(
    TenantExtractor(tenant): TenantExtractor,
    State(state): State<SoloHttpState>,
    AuditPrincipal(principal): AuditPrincipal,
    request: axum::extract::Request,
) -> Response {
    /// Upper bound on the request body size.
    const MAX_BODY_BYTES: usize = 8 * 1024 * 1024;

    /// Builds the 400 JSON error response shared by all validation failures.
    fn reject(message: String) -> Response {
        let payload = serde_json::json!({
            "error": message,
            "status": 400,
        });
        (StatusCode::BAD_REQUEST, Json(payload)).into_response()
    }

    // Extensions must be read before the body consumes the request.
    let existing_session_id: Option<crate::mcp_session::SessionId> = request
        .extensions()
        .get::<crate::mcp_session::SessionId>()
        .cloned();
    let principal_full = request
        .extensions()
        .get::<crate::auth::AuthenticatedPrincipal>()
        .cloned();

    let body_bytes = match axum::body::to_bytes(request.into_body(), MAX_BODY_BYTES).await {
        Ok(bytes) => bytes,
        Err(e) => return reject(format!("invalid request body: {e}")),
    };
    let rpc_request: crate::mcp_dispatch::JsonRpcRequest =
        match serde_json::from_slice(&body_bytes) {
            Ok(parsed) => parsed,
            Err(e) => return reject(format!("invalid JSON-RPC request: {e}")),
        };
    if rpc_request.jsonrpc != "2.0" {
        return reject(format!(
            "invalid JSON-RPC request: expected jsonrpc=\"2.0\", got {:?}",
            rpc_request.jsonrpc
        ));
    }

    // Reuse the caller's session, or mint a new one for this tenant and principal.
    let freshly_assigned = existing_session_id.is_none();
    let session_id = match existing_session_id {
        Some(id) => id,
        None => state
            .mcp_sessions
            .insert(crate::mcp_session::SessionState::new(
                tenant.tenant_id().clone(),
                principal_full,
            )),
    };
    let session_state: Option<std::sync::Arc<crate::mcp_session::SessionState>> =
        state.mcp_sessions.get(&session_id);

    // Only a brand-new session gets an invalidate bridge.
    if freshly_assigned {
        if let Some(bridge_state) = session_state.clone() {
            drop(crate::mcp_notify::spawn_invalidate_bridge(
                tenant.clone(),
                bridge_state,
            ));
        }
    }

    let dispatcher = crate::mcp_dispatch::McpDispatcher::new(
        state.registry.clone(),
        tenant,
        (*state.user_aliases).clone(),
        principal,
    );
    let mut response = dispatcher
        .dispatch(rpc_request, session_state)
        .await
        .map(|reply| (StatusCode::OK, Json(reply)).into_response())
        .unwrap_or_else(|| StatusCode::ACCEPTED.into_response());

    crate::mcp_session::set_session_id_header(response.headers_mut(), &session_id);
    if freshly_assigned {
        tracing::debug!(
            session_id = %session_id,
            "mcp-http: assigned new session id"
        );
    }
    response
}
/// Heartbeat period, in seconds, that `mcp_http_get_handler` passes to
/// `build_mcp_session_stream`.
pub const MCP_STREAM_HEARTBEAT_SECS: u64 = 30;
/// `GET /mcp`: opens the SSE event stream for an existing MCP session.
///
/// Requires the session id in the request extensions; without it the
/// handler answers 404 with a "re-initialize" hint. A session id without
/// its session state is logged as a middleware bug and answered with 500.
/// The last-event-id header, when it parses as a `u64`, selects where
/// replay resumes; otherwise 0 is used. The session id header is stamped on
/// the response.
async fn mcp_http_get_handler(
    TenantExtractor(tenant): TenantExtractor,
    State(state): State<SoloHttpState>,
    AuditPrincipal(principal): AuditPrincipal,
    request: axum::extract::Request,
) -> Response {
    // Extracted for their side effects only; not otherwise used here.
    let _ = principal;
    let _ = state;

    let Some(session_id) = request
        .extensions()
        .get::<crate::mcp_session::SessionId>()
        .cloned()
    else {
        let payload = serde_json::json!({
            "error": crate::mcp_session::MCP_SESSION_EXPIRED_ERROR,
            "status": 404,
            "message": "GET /mcp requires an `Mcp-Session-Id` header from a prior POST /mcp; open one first",
            "retry": "re-initialize",
        });
        return (StatusCode::NOT_FOUND, Json(payload)).into_response();
    };

    let Some(session_state) = request
        .extensions()
        .get::<std::sync::Arc<crate::mcp_session::SessionState>>()
        .cloned()
    else {
        tracing::error!(
            "mcp_http_get_handler: SessionId extension present but SessionState extension missing — middleware bug"
        );
        return StatusCode::INTERNAL_SERVER_ERROR.into_response();
    };

    // A missing, non-UTF-8 or unparsable header falls back to 0.
    let resume_header = request
        .headers()
        .get(crate::mcp_session::MCP_LAST_EVENT_ID_HEADER)
        .and_then(|raw| raw.to_str().ok());
    let last_event_id: u64 = resume_header
        .and_then(|text| text.trim().parse::<u64>().ok())
        .unwrap_or(0);

    let events = build_mcp_session_stream(
        session_state,
        session_id.clone(),
        tenant.tenant_id().to_string(),
        last_event_id,
        MCP_STREAM_HEARTBEAT_SECS,
    );
    let transport_keep_alive = KeepAlive::new().interval(Duration::from_secs(3600));
    let mut response = Sse::new(events)
        .keep_alive(transport_keep_alive)
        .into_response();
    crate::mcp_session::set_session_id_header(response.headers_mut(), &session_id);
    response
}
/// State threaded through the `unfold` in `build_mcp_session_stream`.
struct McpStreamState {
// Live receiver for the session's event broadcast.
rx: broadcast::Receiver<crate::mcp_session::McpStreamEvent>,
// Timer that drives the periodic heartbeat event.
heartbeat: tokio::time::Interval,
// Buffered events to emit, front first, before switching to live events.
replay_queue: Vec<crate::mcp_session::McpStreamEvent>,
// Id of the last non-lagged event sent; live events at or below it are skipped.
last_emitted_id: Option<u64>,
// True until the one-time init event has been emitted.
needs_init: bool,
// Session id as text, used in the init payload and in log fields.
session_id_str: String,
// Tenant id reported in the init payload.
tenant_id: String,
// Never read; NOTE(review): presumably held so the session state outlives the stream — confirm.
_session_state: std::sync::Arc<crate::mcp_session::SessionState>,
}
/// Builds the SSE event stream for an MCP session's `GET /mcp` connection.
///
/// The stream emits, in order:
/// 1. one init event (id 0) with `connected`, `session_id`, `tenant_id`
///    and `ts_ms`;
/// 2. the replay queue, when `last_event_id > 0` and the replay buffer is
///    not empty:
///    - if events between `last_event_id` and the oldest buffered event are
///      gone, a `Lagged` event (id 0) with the drop count, then the whole
///      buffer;
///    - if the client is already at or past the newest buffered event,
///      nothing;
///    - otherwise the buffered events with an id above `last_event_id`;
/// 3. live events from the session broadcast, skipping ids at or below the
///    last one emitted, interleaved with a heartbeat event (id 0) every
///    `heartbeat_secs` seconds.
///
/// A lagged live receiver yields a `Lagged` event; a closed broadcast ends
/// the stream.
///
/// `last_event_id` comes from a client-supplied header, so the "next
/// expected id" is computed with saturating arithmetic; `u64::MAX` must not
/// overflow. `heartbeat_secs == 0` is treated as 1 second because
/// `tokio::time::interval_at` panics on a zero period.
fn build_mcp_session_stream(
    session_state: std::sync::Arc<crate::mcp_session::SessionState>,
    session_id: crate::mcp_session::SessionId,
    tenant_id: String,
    last_event_id: u64,
    heartbeat_secs: u64,
) -> impl Stream<Item = Result<Event, Infallible>> {
    // Subscribe before taking the snapshot. An event landing in between may
    // then appear in both, and the id check in the live loop skips the copy.
    let rx = session_state.subscribe_events();
    let snapshot = session_state.snapshot_replay_buffer();
    let mut replay_queue: Vec<crate::mcp_session::McpStreamEvent> = Vec::new();
    if last_event_id > 0 {
        let oldest_in_buffer = snapshot.first().map(|e| e.id);
        let newest_in_buffer = snapshot.last().map(|e| e.id);
        if let (Some(oldest), Some(newest)) = (oldest_in_buffer, newest_in_buffer) {
            // Saturating: a client may send u64::MAX as its last event id.
            let next_expected = last_event_id.saturating_add(1);
            if next_expected < oldest {
                // Some events the client has not seen are no longer buffered.
                let dropped = oldest.saturating_sub(next_expected);
                replay_queue.push(crate::mcp_session::McpStreamEvent {
                    id: 0,
                    event: crate::mcp_session::McpEventKind::Lagged,
                    data: serde_json::json!({
                        "dropped": dropped,
                        "last_event_id": last_event_id,
                        "oldest_available": oldest,
                    }),
                });
                replay_queue.extend(snapshot);
            } else if last_event_id < newest {
                replay_queue.extend(snapshot.into_iter().filter(|e| e.id > last_event_id));
            }
            // Otherwise the client is already caught up: nothing to replay.
        }
    }
    // Guard against a zero period, which would panic inside tokio.
    let period = Duration::from_secs(heartbeat_secs.max(1));
    let start_at = tokio::time::Instant::now() + period;
    let heartbeat = tokio::time::interval_at(start_at, period);
    let stream_state = McpStreamState {
        rx,
        heartbeat,
        replay_queue,
        last_emitted_id: None,
        needs_init: true,
        session_id_str: session_id.to_string(),
        tenant_id,
        _session_state: session_state,
    };
    futures::stream::unfold(stream_state, move |mut state| async move {
        // First poll: announce the connection.
        if state.needs_init {
            state.needs_init = false;
            let init_payload = serde_json::json!({
                "connected": true,
                "session_id": state.session_id_str,
                "tenant_id": state.tenant_id,
                "ts_ms": chrono::Utc::now().timestamp_millis(),
            });
            let ev = build_mcp_sse_event(
                0,
                crate::mcp_session::McpEventKind::Init,
                &init_payload,
            );
            return Some((Ok::<Event, Infallible>(ev), state));
        }
        // Drain the replay queue front-first before any live event.
        // NOTE(review): Vec::remove(0) is O(n) per pop; consider a VecDeque
        // if replay buffers can grow large.
        if !state.replay_queue.is_empty() {
            let entry = state.replay_queue.remove(0);
            // The synthetic Lagged marker carries id 0 and must not move the
            // de-duplication watermark.
            if entry.event != crate::mcp_session::McpEventKind::Lagged {
                state.last_emitted_id = Some(entry.id);
            }
            let ev = build_mcp_sse_event(entry.id, entry.event, &entry.data);
            return Some((Ok::<Event, Infallible>(ev), state));
        }
        loop {
            tokio::select! {
                event = state.rx.recv() => {
                    match event {
                        Ok(ev) => {
                            // Skip anything already delivered via replay.
                            if let Some(last) = state.last_emitted_id
                                && ev.id <= last
                            {
                                continue;
                            }
                            state.last_emitted_id = Some(ev.id);
                            let sse = build_mcp_sse_event(ev.id, ev.event, &ev.data);
                            return Some((Ok::<Event, Infallible>(sse), state));
                        }
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            tracing::warn!(
                                lagged = n,
                                session_id = %state.session_id_str,
                                "mcp GET stream subscriber lagged"
                            );
                            let lagged_payload = serde_json::json!({
                                "dropped": n,
                            });
                            let sse = build_mcp_sse_event(
                                0,
                                crate::mcp_session::McpEventKind::Lagged,
                                &lagged_payload,
                            );
                            return Some((Ok::<Event, Infallible>(sse), state));
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            tracing::debug!(
                                session_id = %state.session_id_str,
                                "mcp GET stream broadcast closed; ending SSE stream"
                            );
                            return None;
                        }
                    }
                }
                _ = state.heartbeat.tick() => {
                    let hb_payload = serde_json::json!({
                        "ts_ms": chrono::Utc::now().timestamp_millis(),
                    });
                    let sse = build_mcp_sse_event(
                        0,
                        crate::mcp_session::McpEventKind::Heartbeat,
                        &hb_payload,
                    );
                    return Some((Ok::<Event, Infallible>(sse), state));
                }
            }
        }
    })
}
/// Builds one SSE event with the given id, the event name from `kind`, and
/// `data` as its JSON payload.
///
/// If the payload cannot be attached, the event is still returned with its
/// id and name but no data.
fn build_mcp_sse_event(
    id: u64,
    kind: crate::mcp_session::McpEventKind,
    data: &serde_json::Value,
) -> Event {
    let id_text = id.to_string();
    let event_name = kind.as_str();
    match Event::default()
        .id(id_text.clone())
        .event(event_name)
        .json_data(data)
    {
        Ok(event) => event,
        Err(_) => Event::default().id(id_text).event(event_name),
    }
}
/// Error type returned by the HTTP handlers: an HTTP status plus a message.
/// Its `IntoResponse` impl renders it as a JSON body with `error` and
/// `status` keys.
#[derive(Debug)]
pub struct ApiError {
// HTTP status sent with the response and echoed in the body.
status: StatusCode,
// Text placed under the `error` key of the body.
message: String,
}
impl ApiError {
    /// Creates a `400 Bad Request` error carrying `msg`.
    fn bad_request(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self {
            status: StatusCode::BAD_REQUEST,
            message,
        }
    }

    /// Creates a `404 Not Found` error carrying `msg`.
    fn not_found(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self {
            status: StatusCode::NOT_FOUND,
            message,
        }
    }

    /// Creates a `500 Internal Server Error` error carrying `msg`.
    fn internal(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            message,
        }
    }
}
impl From<solo_core::Error> for ApiError {
fn from(e: solo_core::Error) -> Self {
use solo_core::Error;
match e {
Error::NotFound(msg) => ApiError::not_found(msg),
Error::InvalidInput(msg) => ApiError::bad_request(msg),
Error::Conflict(msg) => Self {
status: StatusCode::CONFLICT,
message: msg,
},
other => ApiError::internal(other.to_string()),
}
}
}
impl IntoResponse for ApiError {
    /// Renders the error as its HTTP status plus a JSON body holding the
    /// message under `error` and the numeric status under `status`.
    fn into_response(self) -> Response {
        let Self { status, message } = self;
        let payload = Json(serde_json::json!({
            "error": message,
            "status": status.as_u16(),
        }));
        (status, payload).into_response()
    }
}
#[cfg(test)]
mod handler_tests {
use super::*;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use http_body_util::BodyExt;
use serde_json::{Value, json};
use solo_storage::test_support::StubVectorIndex;
use solo_storage::{
EmbedderConfig, IdentityConfig, KeyMaterial, ReaderPool, SoloConfig,
StubEmbedder, TenantHandle, TenantRegistry, WriterActor, WriterSpawn,
};
use solo_core::VectorIndex;
use std::sync::Arc as StdArc;
use tower::ServiceExt;
/// Builds a `SoloConfig` for tests: a stub embedder of dimension `dim`, an
/// all-zero salt, no auth and no LLM, and defaults for every other section.
fn fake_config(dim: u32) -> SoloConfig {
    let embedder = EmbedderConfig {
        name: String::from("stub"),
        version: String::from("v1"),
        dim,
        dtype: String::from("f32"),
    };
    SoloConfig {
        schema_version: 1,
        salt_hex: String::from("00000000000000000000000000000000"),
        embedder,
        identity: IdentityConfig::default(),
        documents: solo_storage::DocumentConfig::default(),
        auth: None,
        audit: solo_storage::AuditSettings::default(),
        redaction: solo_storage::RedactionConfig::default(),
        llm: None,
        triples: solo_storage::TriplesConfig::default(),
        sampling: solo_storage::SamplingConfig::default(),
    }
}
/// Test fixture bundling a router with the storage pieces behind it.
struct Harness {
// Router under test; tests clone it for each request.
router: axum::Router,
// Temp directory holding the test database; dropped explicitly in `shutdown`.
_tmp: tempfile::TempDir,
// Path of the SQLite file inside the temp directory.
db_path: std::path::PathBuf,
// Extra clone of the writer handle, dropped during `shutdown`.
write_handle_extra: Option<solo_storage::WriteHandle>,
// Join handle of the writer thread, awaited during `shutdown`.
join: Option<std::thread::JoinHandle<()>>,
// The single tenant registered with the test registry.
tenant_handle: StdArc<TenantHandle>,
registry: StdArc<TenantRegistry>,
// Clone of the session store handed to the router state.
mcp_sessions: crate::mcp_session::SessionStore,
}
impl Harness {
    /// Returns a clone of the tenant's invalidate broadcast sender so a
    /// test can publish `InvalidateEvent`s itself.
    fn invalidate_sender(&self) -> tokio::sync::broadcast::Sender<InvalidateEvent> {
        let sender = self.tenant_handle.invalidate_sender();
        sender.clone()
    }
}
// Constructors and teardown for the test fixture.
impl Harness {
/// Builds a harness with authentication disabled.
fn new(runtime: &tokio::runtime::Runtime) -> Self {
Self::new_with_auth(runtime, None)
}
/// Opens a fresh connection to the harness database file.
fn open_db(&self) -> rusqlite::Connection {
solo_storage::test_support::open_test_db_at(&self.db_path)
}
/// Builds a harness; `Some(token)` enables bearer auth with that token.
fn new_with_auth(
runtime: &tokio::runtime::Runtime,
bearer_token: Option<String>,
) -> Self {
Self::new_with_auth_config(
runtime,
bearer_token.map(|token| crate::auth::AuthConfig::Bearer { token }),
)
}
/// Builds the full fixture: a temp-dir database, a stub embedder and
/// vector index (dimension 16), a writer actor, a reader pool, one default
/// tenant wrapped in a single-tenant registry, a session store, and a
/// router built with the given auth config.
fn new_with_auth_config(
runtime: &tokio::runtime::Runtime,
auth: Option<crate::auth::AuthConfig>,
) -> Self {
use solo_storage::embedder_registry::{EmbedderIdentity, get_or_insert_embedder_id};
let tmp = tempfile::TempDir::new().unwrap();
let dim = 16usize;
let hnsw: StdArc<dyn VectorIndex + Send + Sync> = StdArc::new(StubVectorIndex::new(dim));
let embedder: StdArc<dyn solo_core::Embedder> =
StdArc::new(StubEmbedder::new("stub", "v1", dim));
let path = tmp.path().join("test.db");
// Register the stub embedder on a short-lived connection, closed at the end of this block.
let embedder_id = {
let conn = solo_storage::test_support::open_test_db_at(&path);
get_or_insert_embedder_id(
&conn,
&EmbedderIdentity {
name: "stub".into(),
version: "v1".into(),
dim: dim as u32,
dtype: "f32".into(),
},
)
.unwrap()
};
// A second connection is handed to the writer actor.
let conn = solo_storage::test_support::open_test_db_at(&path);
let WriterSpawn { handle, join } = WriterActor::spawn_full(
conn,
hnsw.clone(),
tmp.path().to_path_buf(),
embedder_id,
);
// The pool is constructed inside the runtime via block_on.
let pool: ReaderPool =
runtime.block_on(async { ReaderPool::new(&path, None, hnsw.clone()).unwrap() });
let tenant_id = solo_core::TenantId::default_tenant();
let tenant_handle = StdArc::new(
TenantHandle::from_parts_for_tests(
tenant_id.clone(),
fake_config(dim as u32),
path.clone(),
tmp.path().to_path_buf(),
embedder_id,
hnsw,
embedder.clone(),
handle.clone(),
// Placeholder no-op thread; the real writer join handle stays on the Harness.
std::thread::spawn(|| {}),
pool,
),
);
let tenant_handle_clone = tenant_handle.clone();
let key = KeyMaterial::from_bytes_for_tests([0u8; 32]);
let registry = StdArc::new(TenantRegistry::for_tests_with_single_tenant(
tmp.path().to_path_buf(),
key,
embedder,
tenant_handle,
));
let registry_clone = registry.clone();
let mcp_sessions = runtime
.block_on(async { crate::mcp_session::SessionStore::new() });
let mcp_sessions_clone = mcp_sessions.clone();
let state = SoloHttpState {
registry,
default_tenant: tenant_id,
user_aliases: Arc::new(Vec::new()),
mcp_sessions,
};
let router = router_with_auth_config(state, auth);
Harness {
router,
_tmp: tmp,
db_path: path,
write_handle_extra: Some(handle),
join: Some(join),
tenant_handle: tenant_handle_clone,
registry: registry_clone,
mcp_sessions: mcp_sessions_clone,
}
}
/// Tears the fixture down and waits for the writer thread to exit.
///
/// Drops the extra write handle, tenant handle, registry, router and temp
/// dir inside the runtime, then waits up to 5 seconds for the writer thread
/// to finish. Panics if it does not exit in time or if it panicked.
fn shutdown(mut self, runtime: &tokio::runtime::Runtime) {
let join = self.join.take();
let extra = self.write_handle_extra.take();
let tenant_handle = self.tenant_handle;
let registry = self.registry;
runtime.block_on(async move {
// NOTE(review): presumably dropping every write-handle owner is what lets the writer thread exit — confirm.
drop(extra);
drop(tenant_handle); drop(registry); drop(self.router); drop(self._tmp);
if let Some(join) = join {
// JoinHandle::join has no timeout, so join on a helper thread and wait on a channel with one.
let (tx, rx) = std::sync::mpsc::channel();
std::thread::spawn(move || {
let _ = tx.send(join.join());
});
tokio::task::spawn_blocking(move || {
rx.recv_timeout(std::time::Duration::from_secs(5))
})
.await
.expect("blocking task")
.expect("writer thread did not exit within 5s")
.expect("writer thread panicked");
}
});
}
}
/// Creates the multi-threaded tokio runtime (two workers, all drivers
/// enabled) used by the handler tests.
fn rt() -> tokio::runtime::Runtime {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.worker_threads(2).enable_all();
    builder.build().unwrap()
}
/// Sends one request without an `authorization` header and returns the
/// status plus the parsed JSON body. Delegates to `call_with_auth`.
async fn call(
    router: axum::Router,
    method: &str,
    uri: &str,
    body: Option<Value>,
) -> (StatusCode, Value) {
    let no_credentials: Option<&str> = None;
    call_with_auth(router, method, uri, body, no_credentials).await
}
/// Sends one request through `router` and returns `(status, json_body)`.
///
/// The request always carries `content-type: application/json`. `auth`,
/// when given, is sent verbatim as the `authorization` header. A `Some`
/// body is serialized as JSON; `None` sends an empty body with
/// `content-length: 0`. An empty or non-JSON response body is reported as
/// `Value::Null`.
async fn call_with_auth(
    router: axum::Router,
    method: &str,
    uri: &str,
    body: Option<Value>,
    auth: Option<&str>,
) -> (StatusCode, Value) {
    let mut builder = Request::builder()
        .method(method)
        .uri(uri)
        .header("content-type", "application/json");
    if let Some(credentials) = auth {
        builder = builder.header("authorization", credentials);
    }

    let request = match body {
        Some(payload) => {
            let encoded = serde_json::to_vec(&payload).unwrap();
            builder.body(Body::from(encoded)).unwrap()
        }
        None => builder
            .header("content-length", "0")
            .body(Body::empty())
            .unwrap(),
    };

    let response = router.oneshot(request).await.expect("oneshot");
    let status = response.status();
    let raw = response.into_body().collect().await.unwrap().to_bytes();
    let parsed = if raw.is_empty() {
        Value::Null
    } else {
        serde_json::from_slice(&raw).unwrap_or(Value::Null)
    };
    (status, parsed)
}
/// `GET /health` answers 200 on an unauthenticated harness.
#[test]
fn health_returns_ok() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let (status, _) = runtime.block_on(call(router, "GET", "/health", None));
    assert_eq!(status, StatusCode::OK);
    harness.shutdown(&runtime);
}
/// `GET /openapi.json` returns an OpenAPI 3.x document that lists the
/// expected paths, methods, component schemas and the bearer security scheme.
#[test]
fn openapi_json_describes_all_endpoints() {
    const EXPECTED_PATHS: [&str; 13] = [
        "/health",
        "/openapi.json",
        "/memory",
        "/memory/search",
        "/memory/consolidate",
        "/memory/{id}",
        "/memory/themes",
        "/memory/facts_about",
        "/memory/contradictions",
        "/memory/clusters/{cluster_id}",
        "/memory/documents",
        "/memory/documents/search",
        "/memory/documents/{id}",
    ];
    const EXPECTED_SCHEMAS: [&str; 19] = [
        "RememberRequest",
        "RememberResponse",
        "RecallRequest",
        "RecallResult",
        "EpisodeRecord",
        "ApiError",
        "ConsolidationScope",
        "ConsolidationReport",
        "ThemeHit",
        "FactHit",
        "ContradictionHit",
        "ClusterRecord",
        "IngestDocumentRequest",
        "IngestReport",
        "ForgetDocumentReport",
        "SearchDocsRequest",
        "DocSearchHit",
        "DocumentInspectResult",
        "DocumentSummary",
    ];

    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let (status, spec) = runtime.block_on(call(router, "GET", "/openapi.json", None));
    assert_eq!(status, StatusCode::OK);

    // Document-level metadata.
    assert!(spec.is_object(), "openapi.json must be a JSON object");
    let version_ok = spec
        .get("openapi")
        .and_then(|v| v.as_str())
        .is_some_and(|s| s.starts_with("3."));
    assert!(version_ok, "missing or wrong openapi version: {spec}");
    assert!(spec.pointer("/info/title").is_some());
    assert!(spec.pointer("/info/version").is_some());

    // Every expected path is present.
    let paths = spec
        .get("paths")
        .and_then(|v| v.as_object())
        .expect("paths must be an object");
    for expected in EXPECTED_PATHS {
        assert!(
            paths.contains_key(expected),
            "openapi paths missing {expected}: {paths:?}"
        );
    }

    // Methods on the multi-method paths.
    let docs = paths.get("/memory/documents").expect("/memory/documents");
    assert!(docs.get("post").is_some(), "POST /memory/documents undocumented");
    assert!(docs.get("get").is_some(), "GET /memory/documents undocumented");
    let docid = paths
        .get("/memory/documents/{id}")
        .expect("/memory/documents/{id}");
    assert!(
        docid.get("get").is_some(),
        "GET /memory/documents/{{id}} undocumented"
    );
    assert!(
        docid.get("delete").is_some(),
        "DELETE /memory/documents/{{id}} undocumented"
    );
    let memid = paths.get("/memory/{id}").expect("memory/{id}");
    assert!(memid.get("get").is_some(), "GET /memory/{{id}} undocumented");
    assert!(
        memid.get("delete").is_some(),
        "DELETE /memory/{{id}} undocumented"
    );

    // Component schemas and the security scheme.
    for schema_name in EXPECTED_SCHEMAS {
        let ptr = format!("/components/schemas/{schema_name}");
        assert!(
            spec.pointer(&ptr).is_some(),
            "component schema {schema_name} missing"
        );
    }
    assert!(
        spec.pointer("/components/securitySchemes/bearerAuth")
            .is_some(),
        "bearerAuth security scheme missing"
    );

    harness.shutdown(&runtime);
}
/// With bearer auth enabled, `GET /openapi.json` still answers 200 without
/// a token.
#[test]
fn openapi_json_is_exempt_from_bearer_auth() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("super-secret".into()));
    let router = harness.router.clone();
    let (status, _) = runtime.block_on(call(router, "GET", "/openapi.json", None));
    assert_eq!(status, StatusCode::OK);
    harness.shutdown(&runtime);
}
/// `POST /memory` with content answers 200 and a 36-character `memory_id`.
#[test]
fn remember_returns_memory_id() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let payload = json!({ "content": "http harness test" });
    let (status, body) = runtime.block_on(call(router, "POST", "/memory", Some(payload)));
    assert_eq!(status, StatusCode::OK);
    let mid = body.get("memory_id").and_then(|v| v.as_str()).unwrap();
    assert_eq!(mid.len(), 36, "uuid length");
    harness.shutdown(&runtime);
}
/// `POST /memory` with empty content answers 400 and an error mentioning
/// "must not be empty".
#[test]
fn empty_content_returns_400() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let payload = json!({ "content": "" });
    let (status, body) = runtime.block_on(call(router, "POST", "/memory", Some(payload)));
    assert_eq!(status, StatusCode::BAD_REQUEST);
    let mentions_empty = body
        .get("error")
        .and_then(|e| e.as_str())
        .is_some_and(|s| s.contains("must not be empty"));
    assert!(mentions_empty, "got: {body}");
    harness.shutdown(&runtime);
}
/// `POST /memory/search` with an empty query answers 400 and an error
/// mentioning "must not be empty".
#[test]
fn empty_query_returns_400() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let payload = json!({ "query": "" });
    let (status, body) =
        runtime.block_on(call(router, "POST", "/memory/search", Some(payload)));
    assert_eq!(status, StatusCode::BAD_REQUEST);
    let mentions_empty = body
        .get("error")
        .and_then(|e| e.as_str())
        .is_some_and(|s| s.contains("must not be empty"));
    assert!(mentions_empty, "got: {body}");
    harness.shutdown(&runtime);
}
/// `GET /memory/{id}` for a well-formed but unknown id answers 404 with an
/// `error` field.
#[test]
fn inspect_unknown_returns_404() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let uri = "/memory/00000000-0000-7000-8000-000000000000";
    let (status, body) = runtime.block_on(call(router, "GET", uri, None));
    assert_eq!(status, StatusCode::NOT_FOUND);
    assert!(body.get("error").is_some(), "got: {body}");
    harness.shutdown(&runtime);
}
/// `GET /memory/{id}` with a malformed id answers 400.
#[test]
fn inspect_invalid_id_returns_400() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let (status, _) = runtime.block_on(call(router, "GET", "/memory/not-a-uuid", None));
    assert_eq!(status, StatusCode::BAD_REQUEST);
    harness.shutdown(&runtime);
}
/// `DELETE /memory/{id}` for a well-formed but unknown id answers 404.
#[test]
fn forget_unknown_returns_404() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let uri = "/memory/00000000-0000-7000-8000-000000000000";
    let (status, _) = runtime.block_on(call(router, "DELETE", uri, None));
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
/// `POST /memory/consolidate` answers 200 with every counter of the report
/// present. On an empty store `episodes_seen` and `clusters_built` are 0.
/// A second call with a `window_days` body also answers 200.
#[test]
fn consolidate_endpoint_returns_report() {
    const REPORT_COUNTERS: [&str; 6] = [
        "episodes_seen",
        "clusters_built",
        "episodes_clustered",
        "abstractions_built",
        "triples_built",
        "contradictions_found",
    ];

    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async move {
        // No body at all.
        let (status, body) = call(router.clone(), "POST", "/memory/consolidate", None).await;
        assert_eq!(status, StatusCode::OK);
        for field in REPORT_COUNTERS {
            assert!(
                body.get(field).and_then(|v| v.as_u64()).is_some(),
                "missing field {field}: {body}"
            );
        }
        assert_eq!(body["episodes_seen"], 0);
        assert_eq!(body["clusters_built"], 0);

        // With an explicit window.
        let windowed = json!({ "window_days": 7 });
        let (status2, _) = call(router, "POST", "/memory/consolidate", Some(windowed)).await;
        assert_eq!(status2, StatusCode::OK);
    });
    harness.shutdown(&runtime);
}
/// With bearer auth configured, `POST /memory` returns 401 without a token
/// and with a wrong token, and 200 (with a `memory_id`) for the right token.
#[test]
fn auth_required_routes_reject_missing_token() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("secret-xyz".into()));
    let router = harness.router.clone();
    runtime.block_on(async move {
        // No Authorization header.
        let anonymous = call(
            router.clone(),
            "POST",
            "/memory",
            Some(json!({ "content": "x" })),
        )
        .await;
        assert_eq!(anonymous.0, StatusCode::UNAUTHORIZED);

        // Wrong bearer token.
        let wrong_token = call_with_auth(
            router.clone(),
            "POST",
            "/memory",
            Some(json!({ "content": "x" })),
            Some("Bearer wrong-token"),
        )
        .await;
        assert_eq!(wrong_token.0, StatusCode::UNAUTHORIZED);

        // Correct bearer token.
        let (status, body) = call_with_auth(
            router.clone(),
            "POST",
            "/memory",
            Some(json!({ "content": "authed" })),
            Some("Bearer secret-xyz"),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.get("memory_id").is_some());
    });
    harness.shutdown(&runtime);
}
/// `GET /health` answers 200 without a token even when bearer auth is configured.
#[test]
fn health_endpoint_does_not_require_auth() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("secret".into()));
    let outcome = runtime.block_on(call(harness.router.clone(), "GET", "/health", None));
    assert_eq!(outcome.0, StatusCode::OK);
    harness.shutdown(&runtime);
}
/// A 401 from an auth-protected route carries a `WWW-Authenticate` header
/// whose value starts with `Bearer`.
#[test]
fn auth_response_includes_www_authenticate_header() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("secret".into()));
    let router = harness.router.clone();
    runtime.block_on(async move {
        let payload = serde_json::to_vec(&json!({ "content": "x" })).unwrap();
        let request = Request::builder()
            .method("POST")
            .uri("/memory")
            .header("content-type", "application/json")
            .body(Body::from(payload))
            .unwrap();
        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        // A missing or non-ASCII header collapses to "" so the assertion
        // below reports it instead of panicking here.
        let www = match response.headers().get("www-authenticate") {
            Some(value) => value.to_str().unwrap_or(""),
            None => "",
        };
        let is_bearer_challenge = www.starts_with("Bearer");
        assert!(
            is_bearer_challenge,
            "expected WWW-Authenticate: Bearer..., got: {www}"
        );
    });
    harness.shutdown(&runtime);
}
/// Encodes `bytes` with the `base64` crate's `URL_SAFE_NO_PAD` engine.
fn base64_url_for_test(bytes: &[u8]) -> String {
    use base64::Engine as _;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    URL_SAFE_NO_PAD.encode(bytes)
}
/// Starts a wiremock server that serves an OpenID discovery document and a
/// JWKS containing one HS256 symmetric key.
///
/// Returns `(server, discovery_url, secret, kid)`.
async fn spin_fake_idp() -> (wiremock::MockServer, String, Vec<u8>, &'static str) {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    let server = MockServer::start().await;
    let base = server.uri();
    let secret = b"http-test-secret-for-hmac-fixture".to_vec();
    let kid = "http-test-kid";

    // Discovery document: issuer plus a pointer to the JWKS route below.
    let discovery_doc = serde_json::json!({
        "issuer": base,
        "jwks_uri": format!("{}/jwks", base),
    });
    Mock::given(method("GET"))
        .and(path("/.well-known/openid-configuration"))
        .respond_with(ResponseTemplate::new(200).set_body_json(discovery_doc))
        .mount(&server)
        .await;

    // JWKS with the single symmetric key, base64url-encoded.
    let key_set = serde_json::json!({
        "keys": [
            {
                "kty": "oct",
                "kid": kid,
                "alg": "HS256",
                "k": base64_url_for_test(&secret),
            }
        ]
    });
    Mock::given(method("GET"))
        .and(path("/jwks"))
        .respond_with(ResponseTemplate::new(200).set_body_json(key_set))
        .mount(&server)
        .await;

    let discovery_url = format!("{}/.well-known/openid-configuration", base);
    (server, discovery_url, secret, kid)
}
/// Signs an HS256 JWT with `secret` and key id `kid`.
///
/// Claims: `iss` = `server_uri`, `sub` = "test-user-1", `aud` = `audience`,
/// `iat` = now, `exp` = now + 600 seconds, `solo_tenant` = `tenant_claim`.
/// Panics if the system clock is before the Unix epoch or encoding fails.
fn mint_idp_token(
    server_uri: &str,
    kid: &str,
    secret: &[u8],
    tenant_claim: &str,
    audience: &str,
) -> String {
    use jsonwebtoken::{Algorithm, EncodingKey, Header};
    use std::time::{SystemTime, UNIX_EPOCH};

    let issued_at = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let expires_at = issued_at + 600;

    let mut jwt_header = Header::new(Algorithm::HS256);
    jwt_header.kid = Some(kid.to_owned());

    let claim_set = serde_json::json!({
        "iss": server_uri,
        "sub": "test-user-1",
        "aud": audience,
        "exp": expires_at,
        "iat": issued_at,
        "solo_tenant": tenant_claim,
    });
    let signing_key = EncodingKey::from_secret(secret);
    jsonwebtoken::encode(&jwt_header, &claim_set, &signing_key).expect("mint token")
}
/// With OIDC auth configured against the fake IdP, a token whose
/// `solo_tenant` claim is "default" is accepted by `POST /memory`, which
/// answers 200 with a `memory_id`.
#[test]
fn http_oidc_accept_resolves_to_tenant_from_claim() {
    let runtime = rt();
    let (idp, discovery_url, secret, kid) = runtime.block_on(spin_fake_idp());
    let server_uri = idp.uri();
    // Hold the mock server until the end of the test.
    let _server_guard = idp;

    let audience = "test-audience";
    let oidc = crate::auth::AuthConfig::Oidc {
        discovery_url,
        audience: String::from(audience),
        tenant_claim_name: String::from("solo_tenant"),
    };
    let harness = Harness::new_with_auth_config(&runtime, Some(oidc));
    let router = harness.router.clone();
    let token = mint_idp_token(&server_uri, kid, &secret, "default", audience);

    runtime.block_on(async move {
        let authorization = format!("Bearer {token}");
        let (status, body) = call_with_auth(
            router.clone(),
            "POST",
            "/memory",
            Some(json!({ "content": "oidc-routed content" })),
            Some(&authorization),
        )
        .await;
        assert_eq!(status, StatusCode::OK, "got body: {body}");
        let has_memory_id = body.get("memory_id").is_some();
        assert!(has_memory_id, "no memory_id in {body}");
    });
    harness.shutdown(&runtime);
}
/// With OIDC auth configured, `POST /memory` answers 401 both without a token
/// and with a bearer value that is not a JWT.
#[test]
fn http_oidc_reject_missing_token_returns_401() {
    let runtime = rt();
    let (idp, discovery_url, _secret, _kid) = runtime.block_on(spin_fake_idp());
    // Hold the mock server until the end of the test.
    let _server_guard = idp;

    let oidc = crate::auth::AuthConfig::Oidc {
        discovery_url,
        audience: String::from("test-audience"),
        tenant_claim_name: String::from("solo_tenant"),
    };
    let harness = Harness::new_with_auth_config(&runtime, Some(oidc));
    let router = harness.router.clone();

    runtime.block_on(async move {
        // No Authorization header.
        let anonymous = call(
            router.clone(),
            "POST",
            "/memory",
            Some(json!({ "content": "x" })),
        )
        .await;
        assert_eq!(anonymous.0, StatusCode::UNAUTHORIZED);

        // Bearer value that is not a JWT.
        let garbage = call_with_auth(
            router.clone(),
            "POST",
            "/memory",
            Some(json!({ "content": "x" })),
            Some("Bearer not-a-real-jwt"),
        )
        .await;
        assert_eq!(garbage.0, StatusCode::UNAUTHORIZED);
    });
    harness.shutdown(&runtime);
}
/// End-to-end lifecycle: remember a memory, find it via search, inspect it as
/// "active", delete it (204), inspect it as "forgotten", and confirm search
/// no longer returns it.
#[test]
fn full_remember_recall_inspect_forget_round_trip() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async move {
        // 1. Remember.
        let (status, body) = call(
            router.clone(),
            "POST",
            "/memory",
            Some(json!({ "content": "round-trip content" })),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
        let mid = body["memory_id"].as_str().unwrap().to_owned();
        let memory_uri = format!("/memory/{mid}");
        let search_request = json!({ "query": "round-trip content", "limit": 5 });

        // 2. Recall: the stored content must be among the hits.
        let (status, body) = call(
            router.clone(),
            "POST",
            "/memory/search",
            Some(search_request.clone()),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
        let hits = body["hits"].as_array().unwrap();
        let content_found = hits
            .iter()
            .any(|hit| hit["content"].as_str() == Some("round-trip content"));
        assert!(content_found, "expected hit with content; got: {body}");

        // 3. Inspect: still active.
        let (status, body) = call(router.clone(), "GET", &memory_uri, None).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["status"].as_str(), Some("active"));

        // 4. Forget.
        let deleted = call(router.clone(), "DELETE", &memory_uri, None).await;
        assert_eq!(deleted.0, StatusCode::NO_CONTENT);

        // 5. Inspect again: the row remains but is marked forgotten.
        let (status, body) = call(router.clone(), "GET", &memory_uri, None).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["status"].as_str(), Some("forgotten"));

        // 6. Recall again: the forgotten id must not be among the hits.
        let (status, body) = call(
            router.clone(),
            "POST",
            "/memory/search",
            Some(search_request),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
        let hits = body["hits"].as_array().unwrap();
        let still_listed = hits
            .iter()
            .any(|hit| hit["memory_id"].as_str() == Some(mid.as_str()));
        assert!(
            !still_listed,
            "forgotten row should be excluded from recall: {body}"
        );
    });
    harness.shutdown(&runtime);
}
/// `GET /memory/themes` on a fresh harness returns 200 with an empty JSON array.
#[test]
fn themes_endpoint_returns_empty_array_on_empty_db() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let (status, body) =
        runtime.block_on(call(harness.router.clone(), "GET", "/memory/themes", None));
    assert_eq!(status, StatusCode::OK);
    let is_array = body.is_array();
    assert!(is_array, "expected array, got {body}");
    let item_count = body.as_array().unwrap().len();
    assert_eq!(item_count, 0);
    harness.shutdown(&runtime);
}
/// `GET /memory/themes` accepts `window_days` and `limit` query parameters
/// and still returns 200 with a JSON array.
#[test]
fn themes_endpoint_passes_through_query_params() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/memory/themes?window_days=7&limit=20";
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::OK);
    let is_array = body.is_array();
    assert!(is_array, "expected array, got {body}");
    harness.shutdown(&runtime);
}
/// `GET /memory/facts_about` without a `subject` parameter is rejected with
/// either 400 or 422.
#[test]
fn facts_about_endpoint_requires_subject() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let (status, _body) = runtime.block_on(call(
        harness.router.clone(),
        "GET",
        "/memory/facts_about",
        None,
    ));
    let acceptable = [StatusCode::BAD_REQUEST, StatusCode::UNPROCESSABLE_ENTITY];
    assert!(
        acceptable.contains(&status),
        "expected 400 or 422 for missing subject, got {status}"
    );
    harness.shutdown(&runtime);
}
/// `GET /memory/facts_about` with a whitespace-only `subject` returns 400
/// with an error message that mentions "subject".
#[test]
fn facts_about_endpoint_rejects_blank_subject() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    // `%20%20` decodes to two spaces.
    let uri = "/memory/facts_about?subject=%20%20";
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::BAD_REQUEST);
    let mentions_subject = body["error"]
        .as_str()
        .is_some_and(|text| text.contains("subject"));
    assert!(
        mentions_subject,
        "expected error mentioning subject, got {body}"
    );
    harness.shutdown(&runtime);
}
/// `GET /memory/facts_about` for a subject with no stored facts returns 200
/// with an empty JSON array.
#[test]
fn facts_about_endpoint_returns_empty_array_for_unknown_subject() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/memory/facts_about?subject=NobodyKnows";
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::OK);
    let fact_count = body.as_array().unwrap().len();
    assert_eq!(fact_count, 0);
    harness.shutdown(&runtime);
}
/// `GET /memory/facts_about` accepts `include_as_object=true` alongside
/// `subject` and returns 200 with a JSON array.
#[test]
fn facts_about_endpoint_parses_include_as_object_query_param() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/memory/facts_about?subject=Maya&include_as_object=true";
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(
        status,
        StatusCode::OK,
        "expected 200 with include_as_object query param, got {status}"
    );
    assert!(body.is_array());
    harness.shutdown(&runtime);
}
/// `GET /memory/clusters/{id}` for an unknown cluster returns 404 and the
/// error message echoes the requested cluster id.
#[test]
fn inspect_cluster_endpoint_unknown_id_returns_404() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/memory/clusters/no-such-cluster";
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::NOT_FOUND);
    let echoes_id = body["error"]
        .as_str()
        .is_some_and(|text| text.contains("no-such-cluster"));
    assert!(echoes_id, "expected error mentioning cluster id, got {body}");
    harness.shutdown(&runtime);
}
/// `GET /memory/clusters/{id}?full_content=true` for a missing cluster
/// returns 404, so the extra query parameter does not turn the request into
/// a different error.
#[test]
fn inspect_cluster_endpoint_passes_full_content_query_param() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/memory/clusters/missing?full_content=true";
    let outcome = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(outcome.0, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
/// `GET /memory/contradictions` on a fresh harness returns 200 with an empty
/// JSON array.
#[test]
fn contradictions_endpoint_returns_empty_array_on_empty_db() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let (status, body) = runtime.block_on(call(
        harness.router.clone(),
        "GET",
        "/memory/contradictions",
        None,
    ));
    assert_eq!(status, StatusCode::OK);
    assert!(body.is_array());
    let item_count = body.as_array().unwrap().len();
    assert_eq!(item_count, 0);
    harness.shutdown(&runtime);
}
/// With bearer auth configured, each derived-data GET route answers 401 when
/// no token is supplied.
#[test]
fn derived_endpoints_require_bearer_when_auth_enabled() {
    const PROTECTED_PATHS: [&str; 4] = [
        "/memory/themes",
        "/memory/facts_about?subject=Sam",
        "/memory/contradictions",
        "/memory/clusters/any-id",
    ];
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("secret-token".to_string()));
    for path in PROTECTED_PATHS {
        let outcome = runtime.block_on(call(harness.router.clone(), "GET", path, None));
        assert_eq!(
            outcome.0,
            StatusCode::UNAUTHORIZED,
            "{path} should 401 without token"
        );
    }
    harness.shutdown(&runtime);
}
/// `GET /memory/documents` on a fresh harness returns 200 with an empty JSON array.
#[test]
fn list_documents_endpoint_returns_empty_array_on_empty_db() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let (status, body) = runtime.block_on(call(
        harness.router.clone(),
        "GET",
        "/memory/documents",
        None,
    ));
    assert_eq!(status, StatusCode::OK);
    let is_array = body.is_array();
    assert!(is_array, "expected array, got {body}");
    let document_count = body.as_array().unwrap().len();
    assert_eq!(document_count, 0);
    harness.shutdown(&runtime);
}
/// `GET /memory/documents` accepts `limit`, `offset` and `include_forgotten`
/// query parameters and returns 200 with a JSON array.
#[test]
fn list_documents_endpoint_parses_query_params() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/memory/documents?limit=5&offset=0&include_forgotten=true";
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::OK);
    assert!(body.is_array());
    harness.shutdown(&runtime);
}
/// `POST /memory/documents` with an empty `path` returns 400 with an error
/// message that mentions "path".
#[test]
fn ingest_document_endpoint_rejects_empty_path() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let request_body = json!({ "path": "" });
    let (status, body) = runtime.block_on(call(
        harness.router.clone(),
        "POST",
        "/memory/documents",
        Some(request_body),
    ));
    assert_eq!(status, StatusCode::BAD_REQUEST);
    let mentions_path = body["error"]
        .as_str()
        .is_some_and(|text| text.contains("path"));
    assert!(mentions_path, "expected error mentioning path, got {body}");
    harness.shutdown(&runtime);
}
/// `POST /memory/documents/search` with a whitespace-only `query` returns
/// 400; the error must contain either "must not be empty" or "doc_search".
#[test]
fn search_docs_endpoint_rejects_empty_query() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let request_body = json!({ "query": " " });
    let (status, body) = runtime.block_on(call(
        harness.router.clone(),
        "POST",
        "/memory/documents/search",
        Some(request_body),
    ));
    assert_eq!(status, StatusCode::BAD_REQUEST);
    let accepted_fragments = ["must not be empty", "doc_search"];
    let mentions_empty_query = body["error"].as_str().is_some_and(|text| {
        accepted_fragments
            .iter()
            .any(|fragment| text.contains(*fragment))
    });
    assert!(
        mentions_empty_query,
        "expected error mentioning empty query, got {body}"
    );
    harness.shutdown(&runtime);
}
/// `GET /memory/documents/{id}` for a well-formed id that was never ingested
/// returns 404 with an `error` field.
#[test]
fn inspect_document_endpoint_unknown_id_returns_404() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/memory/documents/00000000-0000-7000-8000-000000000000";
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::NOT_FOUND);
    let has_error = body.get("error").is_some();
    assert!(has_error, "got: {body}");
    harness.shutdown(&runtime);
}
/// `GET /memory/documents/not-a-uuid` must be rejected with 400.
#[test]
fn inspect_document_endpoint_rejects_malformed_id() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let outcome = runtime.block_on(call(
        harness.router.clone(),
        "GET",
        "/memory/documents/not-a-uuid",
        None,
    ));
    assert_eq!(outcome.0, StatusCode::BAD_REQUEST);
    harness.shutdown(&runtime);
}
/// `DELETE /memory/documents/{id}` for a well-formed id that was never
/// ingested returns 404.
#[test]
fn forget_document_endpoint_unknown_id_returns_404() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/memory/documents/00000000-0000-7000-8000-000000000000";
    let outcome = runtime.block_on(call(harness.router.clone(), "DELETE", uri, None));
    assert_eq!(outcome.0, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
/// `DELETE /memory/documents/not-a-uuid` must be rejected with 400.
#[test]
fn forget_document_endpoint_rejects_malformed_id() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let outcome = runtime.block_on(call(
        harness.router.clone(),
        "DELETE",
        "/memory/documents/not-a-uuid",
        None,
    ));
    assert_eq!(outcome.0, StatusCode::BAD_REQUEST);
    harness.shutdown(&runtime);
}
/// With bearer auth configured, every document route (ingest, list, search,
/// inspect, forget) answers 401 when no token is supplied.
#[test]
fn document_endpoints_require_bearer_when_auth_enabled() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("doc-secret".to_string()));
    let document_uri = "/memory/documents/00000000-0000-7000-8000-000000000000";
    // (method, path, optional JSON body) for each protected route.
    let cases: Vec<(&str, &str, Option<Value>)> = vec![
        ("POST", "/memory/documents", Some(json!({ "path": "/x" }))),
        ("GET", "/memory/documents", None),
        (
            "POST",
            "/memory/documents/search",
            Some(json!({ "query": "x" })),
        ),
        ("GET", document_uri, None),
        ("DELETE", document_uri, None),
    ];
    for (method, path, body) in &cases {
        let outcome =
            runtime.block_on(call(harness.router.clone(), method, path, body.clone()));
        assert_eq!(
            outcome.0,
            StatusCode::UNAUTHORIZED,
            "{method} {path} should 401 without token"
        );
    }
    harness.shutdown(&runtime);
}
/// With the correct bearer token the document routes pass auth: listing
/// returns 200, and inspecting an unknown id returns 404 rather than 401.
#[test]
fn document_endpoints_accept_correct_bearer_token() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("doc-secret".to_string()));
    // (path, expected status), requested in this order.
    let expectations = [
        ("/memory/documents", StatusCode::OK),
        (
            "/memory/documents/00000000-0000-7000-8000-000000000000",
            StatusCode::NOT_FOUND,
        ),
    ];
    runtime.block_on(async {
        for (uri, expected) in expectations {
            let (status, _) = call_with_auth(
                harness.router.clone(),
                "GET",
                uri,
                None,
                Some("Bearer doc-secret"),
            )
            .await;
            assert_eq!(status, expected);
        }
    });
    harness.shutdown(&runtime);
}
/// Sending `x-solo-tenant: default` resolves a tenant: the request reaches
/// the handler and gets 404 for the unknown memory id.
#[test]
fn tenant_header_default_resolves() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let status = runtime.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri("/memory/00000000-0000-7000-8000-000000000000")
            .header("x-solo-tenant", "default")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let code = response.status();
        // Read the body to completion; its content is not inspected.
        let _drained = response.into_body().collect().await.unwrap().to_bytes();
        code
    });
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
/// Sending `x-solo-tenant: UPPER` is rejected with 400, and the error text
/// mentions "tenant" or "invalid" (case-insensitively).
#[test]
fn tenant_header_invalid_returns_400() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let (status, body) = runtime.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri("/memory/00000000-0000-7000-8000-000000000000")
            .header("x-solo-tenant", "UPPER")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let code = response.status();
        let raw = response.into_body().collect().await.unwrap().to_bytes();
        // A body that is not JSON degrades to Null rather than panicking.
        let parsed: Value = serde_json::from_slice(&raw).unwrap_or(Value::Null);
        (code, parsed)
    });
    assert_eq!(status, StatusCode::BAD_REQUEST);
    let msg = body["error"].as_str().unwrap_or("");
    let lowered = msg.to_lowercase();
    let is_descriptive = lowered.contains("tenant") || lowered.contains("invalid");
    assert!(is_descriptive, "error must mention tenant/invalid: {msg}");
    harness.shutdown(&runtime);
}
/// Sending `x-solo-tenant: never-registered` yields 404.
#[test]
fn tenant_header_unknown_returns_404() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let status = runtime.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri("/memory/00000000-0000-7000-8000-000000000000")
            .header("x-solo-tenant", "never-registered")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let code = response.status();
        // Read the body to completion; its content is not inspected.
        let _drained = response.into_body().collect().await.unwrap().to_bytes();
        code
    });
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
/// With no `x-solo-tenant` header, a request for an unknown memory id gets
/// 404, the same status the explicit `default` tenant produces.
#[test]
fn tenant_header_missing_defaults_to_state_default_tenant() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let status = runtime.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri("/memory/00000000-0000-7000-8000-000000000000")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let code = response.status();
        // Read the body to completion; its content is not inspected.
        let _drained = response.into_body().collect().await.unwrap().to_bytes();
        code
    });
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
/// Inserts one active, hot-tier `user_message` row into `episodes` and
/// returns its SQLite rowid.
///
/// `ts_ms` is bound to `ts_ms`, `created_at_ms` and `updated_at_ms` alike.
/// Panics with "seed episode" if the insert fails.
fn seed_episode(
    conn: &rusqlite::Connection,
    memory_id: &str,
    ts_ms: i64,
    content: &str,
) -> i64 {
    let insert_sql = "INSERT INTO episodes
(memory_id, ts_ms, source_type, content,
encoding_context_json, tier, status,
confidence, strength, salience,
created_at_ms, updated_at_ms)
VALUES (?1, ?2, 'user_message', ?3,
'{}', 'hot', 'active',
1.0, 0.5, 0.5, ?2, ?2)";
    let inserted = conn.execute(insert_sql, rusqlite::params![memory_id, ts_ms, content]);
    inserted.expect("seed episode");
    conn.last_insert_rowid()
}
/// Inserts a `clusters` row with coherence 0.5 and the given creation time.
/// Panics with "seed cluster" if the insert fails.
fn seed_cluster_row(conn: &rusqlite::Connection, cluster_id: &str, created_at_ms: i64) {
    let insert_sql = "INSERT INTO clusters (cluster_id, coherence, created_at_ms)
VALUES (?1, 0.5, ?2)";
    let inserted = conn.execute(insert_sql, rusqlite::params![cluster_id, created_at_ms]);
    inserted.expect("seed cluster");
}
/// Links `memory_id` to `cluster_id` via a `cluster_episodes` row.
/// Panics with "seed cluster_episodes" if the insert fails.
fn seed_cluster_member(conn: &rusqlite::Connection, cluster_id: &str, memory_id: &str) {
    let insert_sql = "INSERT INTO cluster_episodes (cluster_id, memory_id) VALUES (?1, ?2)";
    let inserted = conn.execute(insert_sql, rusqlite::params![cluster_id, memory_id]);
    inserted.expect("seed cluster_episodes");
}
/// Inserts an active `text/plain` row into `documents` with zero chunks.
///
/// The source is `/tmp/{title}.txt`, and `doc_id` is bound to both `doc_id`
/// and `content_hash`. Panics with "seed doc" if the insert fails.
fn seed_document_row(conn: &rusqlite::Connection, doc_id: &str, title: &str) {
    let source_path = format!("/tmp/{title}.txt");
    let insert_sql = "INSERT INTO documents
(doc_id, source, title, mime_type, ingested_at_ms,
modified_at_ms, status, chunk_count, content_hash, byte_size)
VALUES (?1, ?2, ?3, 'text/plain', 0, NULL,
'active', 0, ?1, NULL)";
    let inserted = conn.execute(insert_sql, rusqlite::params![doc_id, source_path, title]);
    inserted.expect("seed doc");
}
/// Inserts a `document_chunks` row for `doc_id` at position `chunk_index`.
///
/// `token_count` is fixed at 1, `start_offset` at 0, and `end_offset` is the
/// byte length of `content`. Panics with "seed chunk" if the insert fails.
fn seed_chunk_row(
    conn: &rusqlite::Connection,
    chunk_id: &str,
    doc_id: &str,
    chunk_index: i64,
    content: &str,
) {
    let end_offset = content.len() as i64;
    let insert_sql = "INSERT INTO document_chunks
(chunk_id, doc_id, chunk_index, content,
token_count, start_offset, end_offset, created_at_ms)
VALUES (?1, ?2, ?3, ?4, 1, 0, ?5, 0)";
    let inserted = conn.execute(
        insert_sql,
        rusqlite::params![chunk_id, doc_id, chunk_index, content, end_offset],
    );
    inserted.expect("seed chunk");
}
/// Inserts an active `literal` row into `triples` with confidence 0.9,
/// optionally linked to a source episode rowid.
/// Panics with "seed triple" if the insert fails.
fn seed_triple_row(
    conn: &rusqlite::Connection,
    triple_id: &str,
    subject: &str,
    predicate: &str,
    object: &str,
    source_episode_rowid: Option<i64>,
) {
    let insert_sql = "INSERT INTO triples
(triple_id, subject_id, predicate, object_id, object_kind,
valid_from_ms, valid_to_ms, confidence, provenance_json,
status, created_at_ms, updated_at_ms, source_episode_id)
VALUES (?1, ?2, ?3, ?4, 'literal', 0, NULL, 0.9, '{}',
'active', 0, 0, ?5)";
    let inserted = conn.execute(
        insert_sql,
        rusqlite::params![triple_id, subject, predicate, object, source_episode_rowid],
    );
    inserted.expect("seed triple");
}
/// Inserts a `semantic_abstractions` row tied to `cluster_id` with
/// confidence 0.9. Panics with "seed abstraction" if the insert fails.
fn seed_abstraction_row(
    conn: &rusqlite::Connection,
    abstraction_id: &str,
    cluster_id: &str,
    content: &str,
) {
    let insert_sql = "INSERT INTO semantic_abstractions
(abstraction_id, cluster_id, content, provenance_json,
confidence, created_at_ms)
VALUES (?1, ?2, ?3, '{}', 0.9, 0)";
    let inserted = conn.execute(
        insert_sql,
        rusqlite::params![abstraction_id, cluster_id, content],
    );
    inserted.expect("seed abstraction");
}
/// Percent-encodes `node_id` so it can be used as a query-string value.
///
/// Every byte outside the RFC 3986 "unreserved" set (ASCII letters, digits,
/// `-`, `.`, `_`, `~`) is written as `%XX` with uppercase hex digits. The
/// previous allow-list escaped only `: & + ? #` and space, so ids containing
/// `%`, `=`, `/`, control characters or non-ASCII text produced a malformed
/// or misparsed URI. Output for ids made of letters, digits, `-` and `:` is
/// unchanged.
fn percent_encode_node_id(node_id: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(node_id.len());
    // Work on UTF-8 bytes so multi-byte characters become one escape per byte.
    for &byte in node_id.as_bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
/// Builds the `/v1/graph/expand` request URI for `node_id` and `kind`.
/// Only `node_id` is percent-encoded; `kind` is inserted verbatim.
fn graph_uri(node_id: &str, kind: &str) -> String {
    format!(
        "/v1/graph/expand?node_id={encoded}&kind={kind}",
        encoded = percent_encode_node_id(node_id),
    )
}
/// Builds the `/v1/graph/expand` request URI for `node_id` and `kind` with an
/// explicit `limit`. Only `node_id` is percent-encoded.
fn graph_uri_with_limit(node_id: &str, kind: &str, limit: u32) -> String {
    format!(
        "/v1/graph/expand?node_id={encoded}&kind={kind}&limit={limit}",
        encoded = percent_encode_node_id(node_id),
    )
}
/// Expanding `cluster_member` from an episode node returns the one cluster
/// it belongs to, plus a single edge from the episode to that cluster.
#[test]
fn expand_cluster_member_from_episode_returns_clusters() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let memory_id = "11111111-1111-7000-8000-000000000001";

    // Seed one episode that belongs to one cluster.
    let conn = harness.open_db();
    seed_episode(&conn, memory_id, 100, "ep content");
    seed_cluster_row(&conn, "cl-a", 200);
    seed_cluster_member(&conn, "cl-a", memory_id);
    drop(conn);

    let node_id = format!("ep:{memory_id}");
    let uri = graph_uri(&node_id, "cluster_member");
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");

    let nodes = body["nodes"].as_array().expect("nodes array");
    let edges = body["edges"].as_array().expect("edges array");
    assert_eq!(nodes.len(), 1, "{body}");
    let cluster_node = &nodes[0];
    assert_eq!(cluster_node["id"], "cl:cl-a");
    assert_eq!(cluster_node["kind"], "cluster");
    assert_eq!(edges.len(), 1);
    let membership = &edges[0];
    assert_eq!(membership["source"], node_id);
    assert_eq!(membership["target"], "cl:cl-a");
    assert_eq!(membership["kind"], "cluster_member");
    harness.shutdown(&runtime);
}
/// Expanding `cluster_member` from a cluster with five episodes and
/// `limit=3` returns exactly three episode nodes and three edges.
#[test]
fn expand_cluster_member_from_cluster_returns_episodes() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    {
        let db = harness.open_db();
        seed_cluster_row(&db, "cl-multi", 500);
        for i in 0..5 {
            // The loop index varies the fifth character of the id.
            let mid = format!("2222{i}222-2222-7000-8000-000000000001");
            let text = format!("content {i}");
            seed_episode(&db, &mid, 100 + i as i64, &text);
            seed_cluster_member(&db, "cl-multi", &mid);
        }
    }
    let uri = graph_uri_with_limit("cl:cl-multi", "cluster_member", 3);
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let nodes = body["nodes"].as_array().unwrap();
    let edges = body["edges"].as_array().unwrap();
    assert_eq!(nodes.len(), 3, "limit honored: {body}");
    assert_eq!(edges.len(), 3);
    for node in nodes.iter() {
        assert_eq!(node["kind"], "episode");
    }
    harness.shutdown(&runtime);
}
/// Expanding `document_chunk` from a document returns all four chunks in
/// `chunk_index` order, although they were inserted out of order.
#[test]
fn expand_document_chunk_from_document_returns_chunks() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let doc_id = "33333333-3333-7000-8000-000000000001";
    {
        let db = harness.open_db();
        seed_document_row(&db, doc_id, "doc A");
        // Insert out of order so the result order cannot come from insertion order.
        seed_chunk_row(&db, "c2", doc_id, 2, "chunk 2 text");
        seed_chunk_row(&db, "c0", doc_id, 0, "chunk 0 text");
        seed_chunk_row(&db, "c1", doc_id, 1, "chunk 1 text");
        seed_chunk_row(&db, "c3", doc_id, 3, "chunk 3 text");
    }
    let node_id = format!("doc:{doc_id}");
    let uri = graph_uri(&node_id, "document_chunk");
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let nodes = body["nodes"].as_array().unwrap();
    let edges = body["edges"].as_array().unwrap();
    assert_eq!(nodes.len(), 4);
    assert_eq!(edges.len(), 4);
    let expected_order = ["chunk:c0", "chunk:c1", "chunk:c2", "chunk:c3"];
    for (position, expected_id) in expected_order.into_iter().enumerate() {
        assert_eq!(nodes[position]["id"], expected_id);
    }
    for edge in edges.iter() {
        assert_eq!(edge["kind"], "document_chunk");
    }
    harness.shutdown(&runtime);
}
/// Expanding `document_chunk` from a chunk node returns its parent document
/// and a single edge from the chunk to that document.
#[test]
fn expand_document_chunk_from_chunk_returns_parent_document() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let doc_id = "44444444-4444-7000-8000-000000000001";

    let conn = harness.open_db();
    seed_document_row(&conn, doc_id, "parent doc");
    seed_chunk_row(&conn, "c-orphan", doc_id, 0, "chunk content");
    drop(conn);

    let uri = graph_uri("chunk:c-orphan", "document_chunk");
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let nodes = body["nodes"].as_array().unwrap();
    let edges = body["edges"].as_array().unwrap();
    assert_eq!(nodes.len(), 1);
    assert_eq!(edges.len(), 1);
    assert_eq!(nodes[0]["id"], format!("doc:{doc_id}"));
    let parent_edge = &edges[0];
    assert_eq!(parent_edge["source"], "chunk:c-orphan");
    assert_eq!(parent_edge["target"], format!("doc:{doc_id}"));
    harness.shutdown(&runtime);
}
/// Expanding `triple` from an episode that sources two triples returns the
/// four distinct entities and two edges, each with a string `predicate`.
#[test]
fn expand_triple_from_episode_returns_entities() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let memory_id = "55555555-5555-7000-8000-000000000001";
    {
        let db = harness.open_db();
        let episode_rowid = seed_episode(&db, memory_id, 100, "alice works at anthropic");
        seed_triple_row(&db, "t1", "Alice", "works_at", "Anthropic", Some(episode_rowid));
        seed_triple_row(&db, "t2", "Bob", "lives_in", "NYC", Some(episode_rowid));
    }
    let node_id = format!("ep:{memory_id}");
    let uri = graph_uri(&node_id, "triple");
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let nodes = body["nodes"].as_array().unwrap();
    let edges = body["edges"].as_array().unwrap();
    assert_eq!(nodes.len(), 4, "expected 4 unique entity nodes: {body}");
    assert_eq!(edges.len(), 2);

    // Membership check is order-independent.
    let seen_ids: std::collections::HashSet<&str> = nodes
        .iter()
        .map(|node| node["id"].as_str().unwrap())
        .collect();
    for expected in ["ent:Alice", "ent:Anthropic", "ent:Bob", "ent:NYC"] {
        let present = seen_ids.contains(expected);
        assert!(present, "missing {expected} in {body}");
    }
    for edge in edges.iter() {
        assert_eq!(edge["kind"], "triple");
        assert!(edge["predicate"].is_string(), "predicate set: {body}");
    }
    harness.shutdown(&runtime);
}
/// Expanding `triple` from entity `Alice` returns the three episodes whose
/// triples mention her as subject or object. The triple with no source
/// episode contributes nothing.
#[test]
fn expand_triple_from_entity_returns_episodes() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    {
        let db = harness.open_db();
        let first = seed_episode(
            &db,
            "66666666-6666-7000-8000-000000000001",
            100,
            "alice ep one",
        );
        let second = seed_episode(
            &db,
            "66666666-6666-7000-8000-000000000002",
            200,
            "alice ep two",
        );
        let third = seed_episode(
            &db,
            "66666666-6666-7000-8000-000000000003",
            300,
            "alice ep three",
        );
        // Alice appears as subject (t1, t3) and as object (t2).
        seed_triple_row(&db, "t1", "Alice", "p", "Bob", Some(first));
        seed_triple_row(&db, "t2", "Carol", "p", "Alice", Some(second));
        seed_triple_row(&db, "t3", "Alice", "q", "Dave", Some(third));
        // No source episode is attached to this one.
        seed_triple_row(&db, "t-orphan", "Alice", "p", "Eve", None);
    }
    let uri = graph_uri("ent:Alice", "triple");
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let nodes = body["nodes"].as_array().unwrap();
    let edges = body["edges"].as_array().unwrap();
    assert_eq!(nodes.len(), 3, "expected 3 episodes: {body}");
    assert_eq!(edges.len(), 3);
    for node in nodes.iter() {
        assert_eq!(node["kind"], "episode");
    }
    for edge in edges.iter() {
        assert_eq!(edge["source"], "ent:Alice");
        assert_eq!(edge["kind"], "triple");
    }
    harness.shutdown(&runtime);
}
/// Expanding `semantic` from an episode returns 200, never lists the episode
/// itself among the nodes, and every edge is `semantic` with a numeric
/// `weight`. The number of neighbours is not asserted.
#[test]
fn expand_semantic_from_episode_returns_similar() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    runtime.block_on(async {
        let mid1 = post_remember(harness.router.clone(), "alpha alpha alpha").await;
        let _mid2 = post_remember(harness.router.clone(), "beta beta beta").await;
        let _mid3 = post_remember(harness.router.clone(), "gamma gamma gamma").await;

        let origin = format!("ep:{mid1}");
        let uri = graph_uri_with_limit(&origin, "semantic", 5);
        let (status, body) = call(harness.router.clone(), "GET", &uri, None).await;
        assert_eq!(status, StatusCode::OK, "body: {body}");

        let nodes = body["nodes"].as_array().unwrap();
        let edges = body["edges"].as_array().unwrap();
        for node in nodes.iter() {
            let neighbour_id = node["id"].as_str().unwrap();
            assert_ne!(neighbour_id, origin, "self must be excluded: {body}");
        }
        for edge in edges.iter() {
            assert_eq!(edge["kind"], "semantic");
            assert!(edge["weight"].is_number(), "weight set: {body}");
        }
    });
    harness.shutdown(&runtime);
}
/// Stores `content` via `POST /memory`, asserts a 200 response and returns
/// the `memory_id` from the body. Panics if `memory_id` is not a string.
async fn post_remember(router: axum::Router, content: &str) -> String {
    let request_body = json!({ "content": content });
    let (status, body) = call(router, "POST", "/memory", Some(request_body)).await;
    assert_eq!(status, StatusCode::OK, "post failed: {body}");
    let memory_id = body["memory_id"].as_str().unwrap();
    memory_id.to_owned()
}
/// An unrecognised `kind` value on `/v1/graph/expand` is rejected with 400 or 422.
#[test]
fn expand_400_on_invalid_kind() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/v1/graph/expand?node_id=ep:any&kind=banana";
    let (status, _body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    let acceptable = [StatusCode::BAD_REQUEST, StatusCode::UNPROCESSABLE_ENTITY];
    assert!(
        acceptable.contains(&status),
        "expected 400/422 for bad kind, got {status}"
    );
    harness.shutdown(&runtime);
}
/// Requesting `semantic` expansion from a cluster node returns 400 with an
/// error containing "semantic only valid for episode".
#[test]
fn expand_400_on_invalid_node_for_kind() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = graph_uri("cl:doesnt-matter", "semantic");
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::BAD_REQUEST);
    let explains = body["error"]
        .as_str()
        .is_some_and(|text| text.contains("semantic only valid for episode"));
    assert!(explains, "got: {body}");
    harness.shutdown(&runtime);
}
/// Expanding from a well-formed episode id that does not exist returns 404.
#[test]
fn expand_404_on_missing_node_id() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = graph_uri("ep:99999999-9999-7000-8000-000000000999", "cluster_member");
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::NOT_FOUND, "{body}");
    harness.shutdown(&runtime);
}
/// With 150 cluster members and `limit=999`, the expand endpoint returns
/// exactly 100 nodes.
#[test]
fn expand_limit_clamped_at_100() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    {
        let db = harness.open_db();
        seed_cluster_row(&db, "cl-huge", 1_000);
        for i in 0..150 {
            // Zero-padded index forms the last group of the id.
            let mid = format!("77777777-7777-7000-8000-{:012}", i);
            let text = format!("content {i}");
            seed_episode(&db, &mid, 100 + i as i64, &text);
            seed_cluster_member(&db, "cl-huge", &mid);
        }
    }
    let uri = graph_uri_with_limit("cl:cl-huge", "cluster_member", 999);
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let nodes = body["nodes"].as_array().unwrap();
    let returned = nodes.len();
    assert_eq!(
        returned,
        100,
        "limit must be silently clamped to 100, got {}",
        returned
    );
    harness.shutdown(&runtime);
}
/// The node id `garbage` is rejected with 400 and an error containing
/// "node_id must be".
#[test]
fn expand_bad_node_id_prefix_returns_400() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let uri = "/v1/graph/expand?node_id=garbage&kind=cluster_member";
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::BAD_REQUEST);
    let explains = body["error"]
        .as_str()
        .is_some_and(|text| text.contains("node_id must be"));
    assert!(explains, "got: {body}");
    harness.shutdown(&runtime);
}
/// An episode and cluster are seeded through the harness database, but a
/// request carrying `x-solo-tenant: never-registered-tenant` still gets 404.
#[test]
fn expand_respects_tenant_scoping_via_unknown_tenant_header() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let memory_id = "88888888-8888-7000-8000-000000000001";

    let conn = harness.open_db();
    seed_episode(&conn, memory_id, 100, "scoped");
    seed_cluster_row(&conn, "cl-scoped", 200);
    seed_cluster_member(&conn, "cl-scoped", memory_id);
    drop(conn);

    let node_id = format!("ep:{memory_id}");
    let router = harness.router.clone();
    let status = runtime.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri(graph_uri(&node_id, "cluster_member"))
            .header("x-solo-tenant", "never-registered-tenant")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let code = response.status();
        // Read the body to completion; its content is not inspected.
        let _drained = response.into_body().collect().await.unwrap().to_bytes();
        code
    });
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
/// With the harness built via `new_with_auth` and the secret `graph-secret`,
/// a request without credentials gets 401, while the same endpoint with the
/// bearer token gets 404 for an episode that was never seeded.
#[test]
fn expand_respects_auth_when_enabled() {
    let executor = rt();
    let harness = Harness::new_with_auth(&executor, Some("graph-secret".into()));

    // No credentials supplied.
    let anonymous_uri = graph_uri("ep:any", "cluster_member");
    let (anonymous_status, _) =
        executor.block_on(call(harness.router.clone(), "GET", &anonymous_uri, None));
    assert_eq!(anonymous_status, StatusCode::UNAUTHORIZED);

    // Bearer token supplied; nothing was seeded, so 404 is the expected outcome.
    let authed_uri = graph_uri("ep:99999999-9999-7000-8000-000000000999", "cluster_member");
    let (authed_status, _) = executor.block_on(call_with_auth(
        harness.router.clone(),
        "GET",
        &authed_uri,
        None,
        Some("Bearer graph-secret"),
    ));
    assert_eq!(authed_status, StatusCode::NOT_FOUND);

    harness.shutdown(&executor);
}
/// Using `Harness::new` (no auth secret supplied), a credential-less expansion
/// request for an unseeded episode is answered with 404.
#[test]
fn expand_works_when_auth_none() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let uri = graph_uri("ep:99999999-9999-7000-8000-000000000999", "cluster_member");
    let (status, _) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&executor);
}
/// Sends a body-less request through `router` and returns the response
/// status, a clone of the response headers, and the body parsed as JSON.
///
/// An empty body, or a body that fails to parse as JSON, yields `Value::Null`.
///
/// # Panics
///
/// Panics if the request cannot be built, the router call fails, or the
/// response body cannot be collected.
async fn call_with_headers(
    router: axum::Router,
    method: &str,
    uri: &str,
) -> (StatusCode, axum::http::HeaderMap, Value) {
    let request = Request::builder()
        .method(method)
        .uri(uri)
        .header("content-length", "0")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.expect("oneshot");
    // Capture status and headers before the body is consumed.
    let (status, headers) = (response.status(), response.headers().clone());
    let raw = response.into_body().collect().await.unwrap().to_bytes();
    if raw.is_empty() {
        return (status, headers, Value::Null);
    }
    let json: Value = serde_json::from_slice(&raw).unwrap_or(Value::Null);
    (status, headers, json)
}
/// Seeds one episode, document, chunk, cluster and triple, then checks that an
/// unfiltered `/v1/graph/nodes` response contains a node of each of the kinds
/// `episode`, `document`, `chunk`, `cluster` and `entity`.
#[test]
fn nodes_returns_all_kinds_when_no_filter() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        let episode_rowid = seed_episode(
            &db,
            "aaaaaaaa-0000-7000-8000-000000000001",
            100,
            "episode one",
        );
        seed_document_row(&db, "doc-1", "doc one");
        seed_chunk_row(&db, "chunk-1", "doc-1", 0, "chunk one body");
        seed_cluster_row(&db, "cl-one", 200);
        seed_triple_row(&db, "t-one", "Alice", "knows", "Bob", Some(episode_rowid));
    }
    let (status, body) =
        executor.block_on(call(harness.router.clone(), "GET", "/v1/graph/nodes", None));
    assert_eq!(status, StatusCode::OK, "body: {body}");

    // Collect the distinct node kinds present in the response.
    let mut kinds: std::collections::HashSet<&str> = Default::default();
    for node in body["nodes"].as_array().unwrap() {
        kinds.insert(node["kind"].as_str().unwrap());
    }
    for expected in ["episode", "document", "chunk", "cluster", "entity"] {
        assert!(
            kinds.contains(expected),
            "expected {expected} kind in response: {body}"
        );
    }
    harness.shutdown(&executor);
}
/// With `kind=episode`, the response must be non-empty and every returned
/// node must have kind `episode`, even though a document and a cluster exist.
#[test]
fn nodes_filter_by_single_kind() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        seed_episode(&db, "bbbbbbbb-0000-7000-8000-000000000001", 100, "ep");
        seed_document_row(&db, "doc-only", "d");
        seed_cluster_row(&db, "cl-only", 300);
    }
    let uri = "/v1/graph/nodes?kind=episode";
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let returned = body["nodes"].as_array().unwrap();
    assert!(!returned.is_empty(), "{body}");
    for node in returned {
        assert_eq!(node["kind"], "episode", "kind filter must be exclusive: {body}");
    }
    harness.shutdown(&executor);
}
/// With `kind=episode,document`, both of those kinds must appear in the
/// response and the seeded cluster must not.
#[test]
fn nodes_filter_by_multiple_kinds() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        seed_episode(&db, "cccccccc-0000-7000-8000-000000000001", 100, "ep");
        seed_document_row(&db, "doc-multi", "d");
        seed_cluster_row(&db, "cl-multi", 300);
    }
    let uri = "/v1/graph/nodes?kind=episode,document";
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");

    let mut kinds: std::collections::HashSet<&str> = Default::default();
    for node in body["nodes"].as_array().unwrap() {
        kinds.insert(node["kind"].as_str().unwrap());
    }
    assert!(kinds.contains("episode"), "{body}");
    assert!(kinds.contains("document"), "{body}");
    assert!(
        !kinds.contains("cluster"),
        "cluster must be filtered out: {body}"
    );
    harness.shutdown(&executor);
}
/// Seeds 250 triples with distinct objects, requests `kind=entity&limit=500`,
/// and expects exactly 200 entity nodes plus an
/// `x-solo-entity-cap-reached: true` response header.
#[test]
fn nodes_entity_synthesis_caps_at_200() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        let host_rowid = seed_episode(&db, "dddddddd-0000-7000-8000-000000000001", 100, "ep");
        for i in 0..250 {
            let triple_id = format!("t-cap-{i:03}");
            let object = format!("Entity{i:03}");
            seed_triple_row(&db, &triple_id, "Alice", "knows", &object, Some(host_rowid));
        }
    }
    let uri = "/v1/graph/nodes?kind=entity&limit=500";
    let (status, headers, body) =
        executor.block_on(call_with_headers(harness.router.clone(), "GET", uri));
    assert_eq!(status, StatusCode::OK, "body: {body}");

    let returned = body["nodes"].as_array().unwrap();
    assert_eq!(
        returned.len(),
        200,
        "entity cap must be enforced at 200, got {}",
        returned.len()
    );

    let cap_header = headers
        .get("x-solo-entity-cap-reached")
        .and_then(|value| value.to_str().ok());
    assert_eq!(
        cap_header,
        Some("true"),
        "cap-reached header missing: headers={headers:?}"
    );

    for node in returned {
        assert_eq!(node["kind"], "entity");
    }
    harness.shutdown(&executor);
}
/// Seeds three episodes with the values 100, 500 and 1000 and checks that
/// `since_ms=400&until_ms=600` returns only the middle one.
#[test]
fn nodes_since_until_filter_works() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        let fixtures = [
            ("eeeeeeee-0000-7000-8000-000000000001", 100, "early"),
            ("eeeeeeee-0000-7000-8000-000000000002", 500, "middle"),
            ("eeeeeeee-0000-7000-8000-000000000003", 1000, "late"),
        ];
        for (memory_id, stamp, content) in fixtures {
            seed_episode(&db, memory_id, stamp, content);
        }
    }
    let uri = "/v1/graph/nodes?kind=episode&since_ms=400&until_ms=600";
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let returned = body["nodes"].as_array().unwrap();
    assert_eq!(returned.len(), 1, "{body}");
    assert_eq!(
        returned[0]["id"],
        "ep:eeeeeeee-0000-7000-8000-000000000002"
    );
    harness.shutdown(&executor);
}
/// Seeds 150 episodes and walks `/v1/graph/nodes` with `limit=50`, following
/// `next_cursor` for at most four pages. Every page must hold at most `limit`
/// nodes, no id may repeat, all 150 ids must be seen, and the final page must
/// carry no cursor.
#[test]
fn nodes_pagination_round_trip() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        for i in 0..150 {
            let member_id = format!("f0000000-0000-7000-8000-{i:012}");
            seed_episode(&db, &member_id, 1_000 + i as i64, "page");
        }
    }

    let limit = 50u32;
    let mut seen = std::collections::HashSet::<String>::new();
    let mut next_cursor: Option<String> = None;
    for page_idx in 0..4 {
        // The first page has no cursor; later pages append the one returned previously.
        let cursor_param = match next_cursor.as_deref() {
            Some(c) => format!("&cursor={c}"),
            None => String::new(),
        };
        let uri = format!(
            "/v1/graph/nodes?kind=episode&limit={limit}{cursor_param}"
        );
        let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
        assert_eq!(status, StatusCode::OK, "page {page_idx}: {body}");

        let page = body["nodes"].as_array().unwrap();
        assert!(
            page.len() <= limit as usize,
            "page {page_idx} over-fetched: {body}"
        );
        for node in page {
            let id = node["id"].as_str().unwrap().to_string();
            assert!(seen.insert(id.clone()), "duplicate id across pages: {id}");
        }

        next_cursor = body
            .get("next_cursor")
            .and_then(|value| value.as_str())
            .map(str::to_owned);
        if next_cursor.is_none() {
            break;
        }
    }

    assert_eq!(
        seen.len(),
        150,
        "expected 150 distinct ids across pages, got {}",
        seen.len()
    );
    assert!(
        next_cursor.is_none(),
        "cursor should be null after last page; got {next_cursor:?}"
    );
    harness.shutdown(&executor);
}
/// Lists nodes while sending an `x-solo-tenant` header naming
/// `never-registered-tenant`; the response status must be 404.
#[test]
fn nodes_respects_tenant_scoping() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        seed_episode(&db, "11110000-0000-7000-8000-000000000001", 100, "tenant scope");
    }
    let router = harness.router.clone();
    let status = executor.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri("/v1/graph/nodes")
            .header("x-solo-tenant", "never-registered-tenant")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let status = response.status();
        // Read the body to completion even though only the status is asserted.
        let _drained = response.into_body().collect().await.unwrap().to_bytes();
        status
    });
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&executor);
}
/// With the harness built via `new_with_auth` and the secret `nodes-secret`,
/// `/v1/graph/nodes` answers 401 without credentials and 200 with the bearer
/// token.
#[test]
fn nodes_respects_auth_when_enabled() {
    let executor = rt();
    let harness = Harness::new_with_auth(&executor, Some("nodes-secret".into()));
    let uri = "/v1/graph/nodes";

    let (anonymous_status, _) =
        executor.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(
        anonymous_status,
        StatusCode::UNAUTHORIZED,
        "must reject unauthenticated request"
    );

    let (authed_status, _) = executor.block_on(call_with_auth(
        harness.router.clone(),
        "GET",
        uri,
        None,
        Some("Bearer nodes-secret"),
    ));
    assert_eq!(authed_status, StatusCode::OK, "must pass through with bearer");

    harness.shutdown(&executor);
}
/// Using `Harness::new` (no auth secret supplied), `/v1/graph/nodes` answers
/// 200 and the body carries a `nodes` key.
#[test]
fn nodes_works_with_auth_none() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let (status, body) =
        executor.block_on(call(harness.router.clone(), "GET", "/v1/graph/nodes", None));
    assert_eq!(status, StatusCode::OK, "{body}");
    let has_nodes_key = body.get("nodes").is_some();
    assert!(has_nodes_key);
    harness.shutdown(&executor);
}
/// Seeds a triple, a document chunk and a cluster membership, then checks that
/// an unfiltered `/v1/graph/edges` response contains `triple`,
/// `document_chunk` and `cluster_member` edges but no `semantic` edge.
#[test]
fn edges_returns_all_default_kinds() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let episode_id = "22220000-0000-7000-8000-000000000001";
    {
        let db = harness.open_db();
        let source_rowid = seed_episode(&db, episode_id, 100, "ep src");
        seed_triple_row(&db, "t-def", "Alice", "knows", "Bob", Some(source_rowid));
        seed_document_row(&db, "doc-e", "doc");
        seed_chunk_row(&db, "c-e", "doc-e", 0, "chunk");
        seed_cluster_row(&db, "cl-e", 200);
        seed_cluster_member(&db, "cl-e", episode_id);
    }
    let (status, body) =
        executor.block_on(call(harness.router.clone(), "GET", "/v1/graph/edges", None));
    assert_eq!(status, StatusCode::OK, "body: {body}");

    let mut kinds: std::collections::HashSet<&str> = Default::default();
    for edge in body["edges"].as_array().unwrap() {
        kinds.insert(edge["kind"].as_str().unwrap());
    }
    for required in ["triple", "document_chunk", "cluster_member"] {
        assert!(kinds.contains(required), "{body}");
    }
    assert!(
        !kinds.contains("semantic"),
        "semantic is NOT in default response: {body}"
    );
    harness.shutdown(&executor);
}
/// Seeds three triples on a focal episode and one on a decoy episode, then
/// checks that `type=triple&node_id=ep:<focal>` returns exactly the three
/// triple edges whose source is the focal episode.
#[test]
fn edges_filter_by_node_id_finds_incident_edges() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let memory_id = "33330000-0000-7000-8000-000000000001";
    {
        let db = harness.open_db();
        let focal_rowid = seed_episode(&db, memory_id, 100, "ep multi-triple");
        for (triple_id, object) in [("t-a", "Bob"), ("t-b", "Carol"), ("t-c", "Dave")] {
            seed_triple_row(&db, triple_id, "Alice", "p", object, Some(focal_rowid));
        }
        // A triple hanging off a different episode; it must not be returned.
        let decoy_rowid = seed_episode(&db, "33330000-0000-7000-8000-000000000999", 200, "decoy");
        seed_triple_row(&db, "t-decoy", "Alice", "p", "Eve", Some(decoy_rowid));
    }
    let focal_node = format!("ep:{memory_id}");
    let uri = format!(
        "/v1/graph/edges?type=triple&node_id={}",
        percent_encode_node_id(&focal_node)
    );
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let incident = body["edges"].as_array().unwrap();
    assert_eq!(incident.len(), 3, "expected 3 incident edges: {body}");
    for edge in incident {
        assert_eq!(edge["source"], focal_node);
        assert_eq!(edge["kind"], "triple");
    }
    harness.shutdown(&executor);
}
/// With `type=triple`, the response must be non-empty and every edge must be
/// of kind `triple`, even though a document chunk was also seeded.
#[test]
fn edges_filter_by_type_works() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        let host_rowid = seed_episode(&db, "44440000-0000-7000-8000-000000000001", 100, "ep");
        seed_triple_row(&db, "t-only", "Alice", "p", "Bob", Some(host_rowid));
        seed_document_row(&db, "doc-skip", "doc");
        seed_chunk_row(&db, "c-skip", "doc-skip", 0, "chunk");
    }
    let uri = "/v1/graph/edges?type=triple";
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::OK, "{body}");
    let returned = body["edges"].as_array().unwrap();
    assert!(!returned.is_empty(), "{body}");
    for edge in returned {
        assert_eq!(edge["kind"], "triple", "{body}");
    }
    harness.shutdown(&executor);
}
/// `type=semantic` on `/v1/graph/edges` must yield 400 with an error message
/// that mentions `/v1/graph/neighbors`.
#[test]
fn edges_rejects_semantic_type_with_400() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let uri = "/v1/graph/edges?type=semantic";
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(status, StatusCode::BAD_REQUEST, "body: {body}");
    // A missing or non-string `error` field is treated as an empty message.
    let message = body["error"].as_str().unwrap_or_default();
    assert!(
        message.contains("/v1/graph/neighbors"),
        "error must point to /v1/graph/neighbors: {body}"
    );
    harness.shutdown(&executor);
}
/// Seeds 60 triples on one episode and walks `/v1/graph/edges?type=triple`
/// with `limit=25`, following `next_cursor` for at most five pages. No edge id
/// may repeat, all 60 must be seen, and the last page must carry no cursor.
#[test]
fn edges_pagination_round_trip() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        let host_rowid = seed_episode(&db, "55550000-0000-7000-8000-000000000001", 100, "ep big");
        for i in 0..60 {
            let triple_id = format!("t-page-{i:03}");
            let object = format!("Obj{i:03}");
            seed_triple_row(&db, &triple_id, "Alice", "p", &object, Some(host_rowid));
        }
    }

    let limit = 25u32;
    let mut seen = std::collections::HashSet::<String>::new();
    let mut next_cursor: Option<String> = None;
    for page_idx in 0..5 {
        // The first page has no cursor; later pages append the one returned previously.
        let cursor_param = match next_cursor.as_deref() {
            Some(c) => format!("&cursor={c}"),
            None => String::new(),
        };
        let uri = format!(
            "/v1/graph/edges?type=triple&limit={limit}{cursor_param}"
        );
        let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
        assert_eq!(status, StatusCode::OK, "page {page_idx}: {body}");

        for edge in body["edges"].as_array().unwrap() {
            let id = edge["id"].as_str().unwrap().to_string();
            assert!(seen.insert(id.clone()), "duplicate edge id: {id}");
        }

        next_cursor = body
            .get("next_cursor")
            .and_then(|value| value.as_str())
            .map(str::to_owned);
        if next_cursor.is_none() {
            break;
        }
    }

    assert_eq!(
        seen.len(),
        60,
        "expected 60 distinct edges, got {}",
        seen.len()
    );
    assert!(next_cursor.is_none(), "expected exhausted cursor");
    harness.shutdown(&executor);
}
/// Lists edges while sending an `x-solo-tenant` header naming
/// `never-registered-tenant`; the response status must be 404.
#[test]
fn edges_respects_tenant_scoping() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        let host_rowid = seed_episode(&db, "66660000-0000-7000-8000-000000000001", 100, "ep");
        seed_triple_row(&db, "t-tenant", "Alice", "p", "Bob", Some(host_rowid));
    }
    let router = harness.router.clone();
    let status = executor.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri("/v1/graph/edges")
            .header("x-solo-tenant", "never-registered-tenant")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let status = response.status();
        // Read the body to completion even though only the status is asserted.
        let _drained = response.into_body().collect().await.unwrap().to_bytes();
        status
    });
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&executor);
}
/// With the harness built via `new_with_auth` and the secret `edges-secret`,
/// `/v1/graph/edges` answers 401 without credentials and 200 with the bearer
/// token.
#[test]
fn edges_respects_auth_when_enabled() {
    let executor = rt();
    let harness = Harness::new_with_auth(&executor, Some("edges-secret".into()));
    let uri = "/v1/graph/edges";

    let (anonymous_status, _) =
        executor.block_on(call(harness.router.clone(), "GET", uri, None));
    assert_eq!(anonymous_status, StatusCode::UNAUTHORIZED);

    let (authed_status, _) = executor.block_on(call_with_auth(
        harness.router.clone(),
        "GET",
        uri,
        None,
        Some("Bearer edges-secret"),
    ));
    assert_eq!(authed_status, StatusCode::OK);

    harness.shutdown(&executor);
}
/// Builds the `/v1/graph/inspect/<node_id>` path, passing `node_id` through
/// `percent_encode_node_id` first.
fn inspect_uri(node_id: &str) -> String {
    let encoded = percent_encode_node_id(node_id);
    format!("/v1/graph/inspect/{}", encoded)
}
/// Inspecting an episode must return its content verbatim as `full_text`, its
/// three seeded triples as `triples_out` (source = the episode, target an
/// `ent:` node, with predicate and numeric weight), and an empty `triples_in`.
#[test]
fn inspect_episode_returns_full_text_plus_triples_out() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let memory_id = "a1110000-0000-7000-8000-000000000001";
    let full_text = "Met Alice for coffee at the new place. She mentioned the project is on track but they're hitting issues with the deploy pipeline.";
    {
        let db = harness.open_db();
        let host_rowid = seed_episode(&db, memory_id, 1_715_625_600_000, full_text);
        let triples = [
            ("t-ep-1", "user", "met_with", "Alice"),
            ("t-ep-2", "user", "discussed", "deploy_pipeline"),
            ("t-ep-3", "Alice", "works_on", "project"),
        ];
        for (triple_id, subject, predicate, object) in triples {
            seed_triple_row(&db, triple_id, subject, predicate, object, Some(host_rowid));
        }
    }
    let uri = inspect_uri(&format!("ep:{memory_id}"));
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    assert_eq!(body["node"]["kind"], "episode");
    assert_eq!(body["node"]["id"], format!("ep:{memory_id}"));
    assert_eq!(
        body["full_text"].as_str().unwrap(),
        full_text,
        "full_text must match episodes.content verbatim, untruncated"
    );

    let outgoing = body["triples_out"].as_array().unwrap();
    assert_eq!(outgoing.len(), 3, "{body}");
    let incoming = body["triples_in"].as_array().unwrap();
    assert!(incoming.is_empty(), "episodes have no triples_in: {body}");

    for edge in outgoing {
        assert_eq!(edge["kind"], "triple");
        assert_eq!(edge["source"], format!("ep:{memory_id}"));
        assert!(edge["target"].as_str().unwrap().starts_with("ent:"));
        assert!(edge["predicate"].as_str().is_some());
        assert!(edge["weight"].as_f64().is_some());
    }
    harness.shutdown(&executor);
}
/// Inspecting a focal episode must report an empty `triples_in` even when a
/// different episode carries triples of its own.
#[test]
fn inspect_episode_triples_in_is_empty_for_v10p1() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let focal = "a2220000-0000-7000-8000-000000000001";
    let other = "a2220000-0000-7000-8000-000000000002";
    {
        let db = harness.open_db();
        seed_episode(&db, focal, 100, "focal episode body");
        // All five triples are attached to the other episode, none to the focal one.
        let other_rowid = seed_episode(&db, other, 200, "another episode");
        for i in 0..5 {
            let triple_id = format!("t-other-{i}");
            seed_triple_row(&db, &triple_id, "user", "did", "thing", Some(other_rowid));
        }
    }
    let uri = inspect_uri(&format!("ep:{focal}"));
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let incoming = body["triples_in"].as_array().unwrap();
    assert!(
        incoming.is_empty(),
        "episode triples_in must be empty regardless of cross-episode entity references: {body}"
    );
    harness.shutdown(&executor);
}
/// Inspecting a document with three chunks must return the chunk bodies joined
/// by blank lines as `full_text`, with both triple lists empty.
#[test]
fn inspect_document_returns_full_text_concatenated_from_chunks() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let doc_id = "d3330000-0000-7000-8000-000000000001";
    {
        let db = harness.open_db();
        seed_document_row(&db, doc_id, "doc-title");
        let chunks = [
            ("ch-doc-1", 0, "First chunk body."),
            ("ch-doc-2", 1, "Second chunk body."),
            ("ch-doc-3", 2, "Third chunk body."),
        ];
        for (chunk_id, position, text) in chunks {
            seed_chunk_row(&db, chunk_id, doc_id, position, text);
        }
    }
    let uri = inspect_uri(&format!("doc:{doc_id}"));
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    assert_eq!(body["node"]["kind"], "document");
    let joined = body["full_text"].as_str().unwrap();
    assert_eq!(
        joined,
        "First chunk body.\n\nSecond chunk body.\n\nThird chunk body."
    );
    assert!(body["triples_in"].as_array().unwrap().is_empty());
    assert!(body["triples_out"].as_array().unwrap().is_empty());
    harness.shutdown(&executor);
}
/// Inspecting a chunk must return its body as `full_text`, with both triple
/// lists empty.
#[test]
fn inspect_chunk_returns_text() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let chunk_body = "This is the body of the chunk being inspected.";
    {
        let db = harness.open_db();
        seed_document_row(&db, "doc-chunk-host", "host");
        seed_chunk_row(&db, "chunk-inspect-target", "doc-chunk-host", 0, chunk_body);
    }
    let uri = inspect_uri("chunk:chunk-inspect-target");
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    assert_eq!(body["node"]["kind"], "chunk");
    assert_eq!(body["full_text"].as_str().unwrap(), chunk_body);
    for list in ["triples_in", "triples_out"] {
        assert!(body[list].as_array().unwrap().is_empty());
    }
    harness.shutdown(&executor);
}
/// Inspecting a cluster must return a `full_text` that contains the cluster
/// id, the seeded abstraction text, and a blank-line separator.
#[test]
fn inspect_cluster_returns_label_and_abstraction() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let cluster_id = "cl-inspect-target";
    let abstraction_text = "Discussions about the deploy pipeline and on-call rotation.";
    {
        let db = harness.open_db();
        seed_cluster_row(&db, cluster_id, 12345);
        seed_abstraction_row(&db, "abs-1", cluster_id, abstraction_text);
    }
    let uri = inspect_uri(&format!("cl:{cluster_id}"));
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    assert_eq!(body["node"]["kind"], "cluster");

    let full_text = body["full_text"].as_str().unwrap();
    let has_label = full_text.contains(cluster_id);
    assert!(has_label, "full_text must include cluster label: {full_text}");
    let has_abstraction = full_text.contains(abstraction_text);
    assert!(
        has_abstraction,
        "full_text must include abstraction text: {full_text}"
    );
    let has_separator = full_text.contains("\n\n");
    assert!(
        has_separator,
        "label and abstraction must be separated: {full_text}"
    );
    harness.shutdown(&executor);
}
/// Inspecting `ent:Alice` after seeding five triples that mention Alice (as
/// subject or object) must return a null `full_text`, an empty `triples_in`,
/// and five `triples_out` edges, each sourced at `ent:Alice` and targeting a
/// different `ent:` node.
#[test]
fn inspect_entity_returns_triples_only() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        let host_rowid = seed_episode(
            &db,
            "e5550000-0000-7000-8000-000000000001",
            100,
            "host episode",
        );
        let triples = [
            ("t-ent-1", "Alice", "knows", "Bob"),
            ("t-ent-2", "Alice", "works_at", "Anthropic"),
            ("t-ent-3", "user", "met", "Alice"),
            ("t-ent-4", "Alice", "owns", "laptop"),
            ("t-ent-5", "Carol", "mentors", "Alice"),
        ];
        for (triple_id, subject, predicate, object) in triples {
            seed_triple_row(&db, triple_id, subject, predicate, object, Some(host_rowid));
        }
    }
    let uri = inspect_uri("ent:Alice");
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    assert_eq!(body["node"]["kind"], "entity");
    assert_eq!(body["node"]["id"], "ent:Alice");
    assert!(
        body["full_text"].is_null(),
        "entity full_text must be null (entities have no body): {body}"
    );

    let outgoing = body["triples_out"].as_array().unwrap();
    assert_eq!(outgoing.len(), 5, "{body}");
    assert!(body["triples_in"].as_array().unwrap().is_empty());
    for edge in outgoing {
        assert_eq!(edge["kind"], "triple");
        assert_eq!(edge["source"], "ent:Alice");
        assert!(edge["target"].as_str().unwrap().starts_with("ent:"));
        assert_ne!(edge["target"], "ent:Alice");
    }
    harness.shutdown(&executor);
}
/// Inspecting an entity that appears in no triple must yield 404 with an error
/// message that contains either "Nonexistent" or "entity".
#[test]
fn inspect_entity_with_zero_triples_returns_404() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        let host_rowid = seed_episode(&db, "e6660000-0000-7000-8000-000000000001", 100, "ep");
        // The only triple involves Bob and Carol, not the entity being inspected.
        seed_triple_row(&db, "t-other", "Bob", "knows", "Carol", Some(host_rowid));
    }
    let uri = inspect_uri("ent:Nonexistent");
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::NOT_FOUND, "body: {body}");
    let message = body["error"].as_str().unwrap_or_default();
    let mentions_entity = message.contains("Nonexistent") || message.contains("entity");
    assert!(mentions_entity, "error must mention entity: {body}");
    harness.shutdown(&executor);
}
/// Inspecting an episode id that was never seeded must yield 404.
#[test]
fn inspect_404_on_missing_node() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let uri = inspect_uri("ep:99999999-9999-7000-8000-000000000999");
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::NOT_FOUND, "body: {body}");
    harness.shutdown(&executor);
}
/// Inspecting `xyz:foo` must yield 400 with an error message that contains
/// either "xyz" or "prefix".
#[test]
fn inspect_400_on_invalid_prefix() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let uri = inspect_uri("xyz:foo");
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::BAD_REQUEST, "body: {body}");
    let message = body["error"].as_str().unwrap_or_default();
    let mentions_prefix = message.contains("xyz") || message.contains("prefix");
    assert!(mentions_prefix, "error must mention bad prefix: {body}");
    harness.shutdown(&executor);
}
/// Inspecting a seeded episode yields 404 when an `x-solo-tenant` header names
/// `never-registered-tenant`, and 200 when no tenant header is sent.
#[test]
fn inspect_respects_tenant_scoping() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let memory_id = "a7770000-0000-7000-8000-000000000001";
    {
        let db = harness.open_db();
        seed_episode(&db, memory_id, 100, "tenant scope");
    }

    // First request: explicit tenant header naming a tenant that was never registered.
    let router = harness.router.clone();
    let foreign_status = executor.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri(inspect_uri(&format!("ep:{memory_id}")))
            .header("x-solo-tenant", "never-registered-tenant")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let status = response.status();
        // Read the body to completion even though only the status is asserted.
        let _drained = response.into_body().collect().await.unwrap().to_bytes();
        status
    });
    assert_eq!(foreign_status, StatusCode::NOT_FOUND);

    // Second request: same node, no tenant header.
    let uri = inspect_uri(&format!("ep:{memory_id}"));
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "default tenant must resolve: {body}");
    harness.shutdown(&executor);
}
/// With the harness built via `new_with_auth` and the secret `inspect-secret`,
/// inspecting an unseeded episode answers 401 without credentials and 404 with
/// the bearer token.
#[test]
fn inspect_respects_auth_when_enabled() {
    let executor = rt();
    let harness = Harness::new_with_auth(&executor, Some("inspect-secret".into()));
    let uri = inspect_uri("ep:99999999-9999-7000-8000-000000000999");

    let (anonymous_status, _) =
        executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(anonymous_status, StatusCode::UNAUTHORIZED);

    let (authed_status, _) = executor.block_on(call_with_auth(
        harness.router.clone(),
        "GET",
        &uri,
        None,
        Some("Bearer inspect-secret"),
    ));
    assert_eq!(authed_status, StatusCode::NOT_FOUND);

    harness.shutdown(&executor);
}
/// Builds a `/v1/graph/neighbors/<node_id>` URI.
///
/// `node_id` is passed through `percent_encode_node_id`. Each of `kind`,
/// `threshold` and `limit` that is `Some` is appended as a query parameter, in
/// that order; when all three are `None` the URI has no query string.
fn neighbors_uri(
    node_id: &str,
    kind: Option<&str>,
    threshold: Option<f32>,
    limit: Option<u32>,
) -> String {
    let params: Vec<String> = [
        kind.map(|k| format!("kind={k}")),
        threshold.map(|t| format!("threshold={t}")),
        limit.map(|l| format!("limit={l}")),
    ]
    .into_iter()
    .flatten()
    .collect();
    let encoded = percent_encode_node_id(node_id);
    if params.is_empty() {
        return format!("/v1/graph/neighbors/{encoded}");
    }
    format!("/v1/graph/neighbors/{encoded}?{}", params.join("&"))
}
/// With `kind=explicit`, the neighbors of an episode carrying two seeded
/// triples must be a non-empty edge list that contains no `semantic` edge.
#[test]
fn neighbors_explicit_only_returns_no_semantic_edges() {
    let executor = rt();
    let harness = Harness::new(&executor);
    executor.block_on(async {
        let focal = post_remember(harness.router.clone(), "alpha alpha alpha").await;
        let _second = post_remember(harness.router.clone(), "beta beta beta").await;
        let _third = post_remember(harness.router.clone(), "gamma gamma gamma").await;
        {
            let db = harness.open_db();
            // Look up the rowid of the focal episode so triples can be attached to it.
            let focal_rowid = db
                .query_row(
                    "SELECT rowid FROM episodes WHERE memory_id = ?1",
                    rusqlite::params![&focal],
                    |row| row.get::<_, i64>(0),
                )
                .unwrap();
            seed_triple_row(&db, "t-exp-1", "Alice", "knows", "Bob", Some(focal_rowid));
            seed_triple_row(&db, "t-exp-2", "Alice", "owns", "laptop", Some(focal_rowid));
        }
        let uri = neighbors_uri(&format!("ep:{focal}"), Some("explicit"), None, None);
        let (status, body) = call(harness.router.clone(), "GET", &uri, None).await;
        assert_eq!(status, StatusCode::OK, "body: {body}");
        let returned = body["edges"].as_array().unwrap();
        assert!(!returned.is_empty(), "expected explicit edges: {body}");
        for edge in returned {
            assert_ne!(
                edge["kind"], "semantic",
                "kind=explicit must drop semantic edges: {body}"
            );
        }
    });
    harness.shutdown(&executor);
}
/// With `kind=semantic&threshold=0`, every returned edge must be of kind
/// `semantic` and carry a numeric `weight`, even though the focal episode also
/// has a seeded triple. The edge list itself is allowed to be empty.
#[test]
fn neighbors_semantic_only_returns_no_explicit_edges() {
    let executor = rt();
    let harness = Harness::new(&executor);
    executor.block_on(async {
        let focal = post_remember(harness.router.clone(), "alpha alpha alpha").await;
        let _second = post_remember(harness.router.clone(), "beta beta beta").await;
        let _third = post_remember(harness.router.clone(), "gamma gamma gamma").await;
        {
            let db = harness.open_db();
            let focal_rowid = db
                .query_row(
                    "SELECT rowid FROM episodes WHERE memory_id = ?1",
                    rusqlite::params![&focal],
                    |row| row.get::<_, i64>(0),
                )
                .unwrap();
            seed_triple_row(&db, "t-exp-1", "Alice", "knows", "Bob", Some(focal_rowid));
        }
        let uri = neighbors_uri(&format!("ep:{focal}"), Some("semantic"), Some(0.0), None);
        let (status, body) = call(harness.router.clone(), "GET", &uri, None).await;
        assert_eq!(status, StatusCode::OK, "body: {body}");
        for edge in body["edges"].as_array().unwrap() {
            assert_eq!(
                edge["kind"], "semantic",
                "kind=semantic must drop explicit edges: {body}"
            );
            assert!(edge["weight"].is_number(), "semantic edges carry weight: {body}");
        }
    });
    harness.shutdown(&executor);
}
/// With no `kind` parameter and `threshold=0`, the neighbors of an episode
/// that has a seeded triple and one sibling episode must include at least one
/// `triple` edge and at least one `semantic` edge.
#[test]
fn neighbors_both_default_returns_combined() {
    let executor = rt();
    let harness = Harness::new(&executor);
    executor.block_on(async {
        let focal = post_remember(harness.router.clone(), "alpha alpha alpha").await;
        let _sibling = post_remember(harness.router.clone(), "beta beta beta").await;
        {
            let db = harness.open_db();
            let focal_rowid = db
                .query_row(
                    "SELECT rowid FROM episodes WHERE memory_id = ?1",
                    rusqlite::params![&focal],
                    |row| row.get::<_, i64>(0),
                )
                .unwrap();
            seed_triple_row(&db, "t-both-1", "Alice", "met", "Bob", Some(focal_rowid));
        }
        let uri = neighbors_uri(&format!("ep:{focal}"), None, Some(0.0), None);
        let (status, body) = call(harness.router.clone(), "GET", &uri, None).await;
        assert_eq!(status, StatusCode::OK, "body: {body}");

        let mut kinds: std::collections::HashSet<&str> = Default::default();
        for edge in body["edges"].as_array().unwrap() {
            kinds.insert(edge["kind"].as_str().unwrap());
        }
        assert!(
            kinds.contains("triple"),
            "expected at least one triple edge: {body}"
        );
        assert!(
            kinds.contains("semantic"),
            "expected at least one semantic edge: {body}"
        );
    });
    harness.shutdown(&executor);
}
/// With `kind=both&threshold=0`, no `(source, target)` pair may appear more
/// than once in the returned edge list.
#[test]
fn neighbors_dedupes_semantic_when_explicit_exists() {
    let executor = rt();
    let harness = Harness::new(&executor);
    executor.block_on(async {
        let focal = post_remember(harness.router.clone(), "alpha alpha alpha").await;
        let _sibling = post_remember(harness.router.clone(), "beta beta beta").await;
        {
            let db = harness.open_db();
            let focal_rowid = db
                .query_row(
                    "SELECT rowid FROM episodes WHERE memory_id = ?1",
                    rusqlite::params![&focal],
                    |row| row.get::<_, i64>(0),
                )
                .unwrap();
            seed_triple_row(&db, "t-dedupe-1", "Alice", "knows", "Bob", Some(focal_rowid));
        }
        let uri = neighbors_uri(&format!("ep:{focal}"), Some("both"), Some(0.0), None);
        let (status, body) = call(harness.router.clone(), "GET", &uri, None).await;
        assert_eq!(status, StatusCode::OK, "body: {body}");

        // Tally how often each (source, target) pair occurs.
        let mut tally: std::collections::HashMap<(String, String), i32> = Default::default();
        for edge in body["edges"].as_array().unwrap() {
            let source = edge["source"].as_str().unwrap().to_string();
            let target = edge["target"].as_str().unwrap().to_string();
            *tally.entry((source, target)).or_default() += 1;
        }
        for (pair, count) in &tally {
            assert_eq!(
                *count, 1,
                "edge pair {pair:?} appears {count} times -- dedupe rule violated: {body}"
            );
        }
    });
    harness.shutdown(&executor);
}
/// Queries semantic neighbors at `threshold=0` and `threshold=0.99`: the
/// high-threshold result must not have more edges than the low-threshold one,
/// and every weighted edge it returns must have weight >= 0.99.
#[test]
fn neighbors_threshold_filters_low_similarity() {
    let executor = rt();
    let harness = Harness::new(&executor);
    executor.block_on(async {
        let focal = post_remember(harness.router.clone(), "alpha alpha alpha").await;
        for content in ["beta one", "beta two", "beta three"] {
            let _sibling = post_remember(harness.router.clone(), content).await;
        }

        let low_uri = neighbors_uri(&format!("ep:{focal}"), Some("semantic"), Some(0.0), None);
        let (status, low_body) = call(harness.router.clone(), "GET", &low_uri, None).await;
        assert_eq!(status, StatusCode::OK, "body: {low_body}");
        let low_edge_count = low_body["edges"].as_array().unwrap().len();

        let high_uri = neighbors_uri(&format!("ep:{focal}"), Some("semantic"), Some(0.99), None);
        let (status, high_body) = call(harness.router.clone(), "GET", &high_uri, None).await;
        assert_eq!(status, StatusCode::OK, "body: {high_body}");
        let high_edges = high_body["edges"].as_array().unwrap();
        let high_edge_count = high_edges.len();

        assert!(
            high_edge_count <= low_edge_count,
            "high-threshold ({high_edge_count}) must not exceed low-threshold ({low_edge_count}): low={low_body}, high={high_body}"
        );
        for e in high_edges {
            // Edges without a numeric weight are not checked.
            let Some(w) = e["weight"].as_f64() else {
                continue;
            };
            assert!(
                w >= 0.99,
                "edge with weight {w} survived threshold=0.99: {e}"
            );
        }
    });
    harness.shutdown(&executor);
}
/// Seeds a cluster with 150 member episodes, requests explicit neighbors with
/// `limit=999`, and expects exactly 100 edges in the response.
#[test]
fn neighbors_limit_clamped_at_100() {
    let executor = rt();
    let harness = Harness::new(&executor);
    {
        let db = harness.open_db();
        seed_cluster_row(&db, "cl-huge-n", 1000);
        for i in 0..150 {
            let member_id = format!("99119911-1111-7000-8000-{:012}", i);
            let content = format!("content {i}");
            seed_episode(&db, &member_id, 100 + i as i64, &content);
            seed_cluster_member(&db, "cl-huge-n", &member_id);
        }
    }
    let uri = neighbors_uri("cl:cl-huge-n", Some("explicit"), None, Some(999));
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "body: {body}");
    let returned = body["edges"].as_array().unwrap();
    assert_eq!(
        returned.len(),
        100,
        "limit must be silently clamped to 100, got {}",
        returned.len()
    );
    harness.shutdown(&executor);
}
/// `kind=semantic` with a `doc:` focal node must yield 400 and an error
/// message that mentions both "episode" and "chunk".
#[test]
fn neighbors_semantic_rejects_document_source() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let doc_id = "d-semrej-0000-7000-8000-000000000001";
    {
        let db = harness.open_db();
        seed_document_row(&db, doc_id, "host");
    }
    let uri = neighbors_uri(&format!("doc:{doc_id}"), Some("semantic"), None, None);
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::BAD_REQUEST, "body: {body}");
    let message = body["error"].as_str().unwrap_or_default();
    let lists_supported = message.contains("episode") && message.contains("chunk");
    assert!(lists_supported, "error must list supported kinds: {body}");
    harness.shutdown(&executor);
}
/// `kind=semantic` with a `cl:` focal node must yield 400.
#[test]
fn neighbors_semantic_rejects_cluster_source() {
    let executor = rt();
    let harness = Harness::new(&executor);
    let cluster_id = "cl-semrej-target";
    {
        let db = harness.open_db();
        seed_cluster_row(&db, cluster_id, 12345);
    }
    let uri = neighbors_uri(&format!("cl:{cluster_id}"), Some("semantic"), None, None);
    let (status, body) = executor.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::BAD_REQUEST, "body: {body}");
    harness.shutdown(&executor);
}
/// With `ent:Alice` as the focal node, no `kind` parameter and `threshold=0`,
/// the edge list must be non-empty and consist solely of `triple` edges.
#[test]
fn neighbors_entity_returns_triples_only() {
    let executor = rt();
    let harness = Harness::new(&executor);
    executor.block_on(async {
        let host_id = post_remember(harness.router.clone(), "Alice and Bob talked").await;
        {
            let db = harness.open_db();
            let host_rowid = db
                .query_row(
                    "SELECT rowid FROM episodes WHERE memory_id = ?1",
                    rusqlite::params![&host_id],
                    |row| row.get::<_, i64>(0),
                )
                .unwrap();
            seed_triple_row(&db, "t-ent-n-1", "Alice", "knows", "Bob", Some(host_rowid));
            seed_triple_row(&db, "t-ent-n-2", "Alice", "works_at", "Acme", Some(host_rowid));
        }
        let uri = neighbors_uri("ent:Alice", None, Some(0.0), None);
        let (status, body) = call(harness.router.clone(), "GET", &uri, None).await;
        assert_eq!(status, StatusCode::OK, "body: {body}");
        let returned = body["edges"].as_array().unwrap();
        assert!(!returned.is_empty(), "expected explicit triples: {body}");
        for edge in returned {
            assert_eq!(
                edge["kind"], "triple",
                "entity focal must produce only triple edges: {body}"
            );
        }
    });
    harness.shutdown(&executor);
}
/// The same `ep:` focal is requested twice: with an `x-solo-tenant` header
/// naming "never-registered-tenant-n" (expects 404) and with no tenant
/// header at all (expects 200).
#[test]
fn neighbors_respects_tenant_scoping() {
    let memory_id = "a8880000-0000-7000-8000-000000000001";
    let runtime = rt();
    let harness = Harness::new(&runtime);
    {
        let conn = harness.open_db();
        seed_episode(&conn, memory_id, 100, "tenant scope");
    }
    let focal = format!("ep:{memory_id}");

    // Request carrying a tenant header for a tenant that was never registered.
    let router = harness.router.clone();
    let foreign_status = runtime.block_on(async {
        let request = Request::builder()
            .method("GET")
            .uri(neighbors_uri(&focal, Some("explicit"), None, None))
            .header("x-solo-tenant", "never-registered-tenant-n")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.expect("oneshot");
        let code = response.status();
        // Read the body to completion even though its content is not inspected.
        let _drained = response.into_body().collect().await.unwrap().to_bytes();
        code
    });
    assert_eq!(foreign_status, StatusCode::NOT_FOUND);

    // Same focal without a tenant header.
    let uri = neighbors_uri(&focal, Some("explicit"), None, None);
    let (status, body) = runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(status, StatusCode::OK, "default tenant must resolve: {body}");
    harness.shutdown(&runtime);
}
/// With the harness built via `new_with_auth(.., Some("neighbors-secret"))`,
/// a request without credentials must get 401, while the same request with
/// `Bearer neighbors-secret` gets past auth and ends in 404 (the focal
/// episode was never seeded).
#[test]
fn neighbors_respects_auth_when_enabled() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("neighbors-secret".into()));
    let uri = neighbors_uri(
        "ep:99999999-9999-7000-8000-000000000999",
        Some("explicit"),
        None,
        None,
    );

    let (anonymous_status, _) =
        runtime.block_on(call(harness.router.clone(), "GET", &uri, None));
    assert_eq!(anonymous_status, StatusCode::UNAUTHORIZED);

    let (authed_status, _) = runtime.block_on(call_with_auth(
        harness.router.clone(),
        "GET",
        &uri,
        None,
        Some("Bearer neighbors-secret"),
    ));
    assert_eq!(authed_status, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
/// One decoded server-sent event, as produced by `parse_sse_block` and
/// returned from `read_one_sse_event`.
#[derive(Debug, Clone)]
struct ParsedSseEvent {
/// Text of the block's `event:` field.
event: String,
/// JSON value parsed from the block's `data:` field.
data: Value,
/// Text of the block's `id:` field, or `None` when the block had none.
id: Option<String>,
}
/// Reads frames from `body` until one complete SSE block (terminated by a
/// blank line, i.e. `"\n\n"`) parses into a [`ParsedSseEvent`].
///
/// Returns `None` when `timeout` elapses first, when the body ends, or when
/// reading a frame fails. Blocks that `parse_sse_block` rejects are skipped.
///
/// Incoming bytes are accumulated undecoded and only complete blocks are
/// converted to text, so a multi-byte UTF-8 character that happens to be
/// split across two frames is decoded correctly.
///
/// NOTE: the accumulation buffer is local to this call. If a single frame
/// carries more than one event, everything after the first parsed event is
/// discarded when this function returns.
async fn read_one_sse_event(
    body: &mut axum::body::Body,
    timeout: std::time::Duration,
) -> Option<ParsedSseEvent> {
    use http_body_util::BodyExt;
    let mut pending: Vec<u8> = Vec::new();
    let start = std::time::Instant::now();
    loop {
        // Time budget left for this call; zero means the deadline has passed.
        let remaining = timeout.saturating_sub(start.elapsed());
        if remaining.is_zero() {
            return None;
        }
        let frame = match tokio::time::timeout(remaining, body.frame()).await {
            Ok(Some(Ok(frame))) => frame,
            // Frame error, end of body, or deadline hit while waiting.
            Ok(Some(Err(_))) | Ok(None) | Err(_) => return None,
        };
        // Non-data frames (e.g. trailers) carry no event text.
        let Ok(chunk) = frame.into_data() else {
            continue;
        };
        pending.extend_from_slice(&chunk);
        while let Some(end) = pending.windows(2).position(|pair| pair == b"\n\n") {
            let block: Vec<u8> = pending.drain(..end + 2).collect();
            if let Some(parsed) = parse_sse_block(&String::from_utf8_lossy(&block)) {
                return Some(parsed);
            }
        }
    }
}
/// Parses one SSE block (the lines of a single event) into a
/// [`ParsedSseEvent`].
///
/// Recognised fields are `event:`, `data:` and `id:`; every other line is
/// ignored. Field values are whitespace-trimmed. When a block contains
/// several `data:` lines they are joined with `'\n'`, as the SSE format
/// prescribes, instead of the last line silently replacing the earlier ones.
///
/// Returns `None` when the block has no `event:` field, no `data:` field, or
/// when the data is not valid JSON.
fn parse_sse_block(block: &str) -> Option<ParsedSseEvent> {
    let mut event: Option<String> = None;
    let mut data_lines: Vec<&str> = Vec::new();
    let mut id: Option<String> = None;
    for line in block.lines() {
        if let Some(rest) = line.strip_prefix("event:") {
            event = Some(rest.trim().to_string());
        } else if let Some(rest) = line.strip_prefix("data:") {
            data_lines.push(rest.trim());
        } else if let Some(rest) = line.strip_prefix("id:") {
            id = Some(rest.trim().to_string());
        }
    }
    let event = event?;
    if data_lines.is_empty() {
        return None;
    }
    let data = serde_json::from_str(&data_lines.join("\n")).ok()?;
    Some(ParsedSseEvent { event, data, id })
}
/// Sends `GET /v1/graph/stream` through `router` and hands back the response
/// status together with the still-unread response body.
///
/// `auth`, when present, is sent as the `authorization` header; `tenant`,
/// when present, as `x-solo-tenant`. A `content-length: 0` header is always
/// added.
async fn open_sse_stream_inner(
    router: axum::Router,
    auth: Option<&str>,
    tenant: Option<&str>,
) -> (StatusCode, axum::body::Body) {
    let mut request = Request::builder()
        .method("GET")
        .uri("/v1/graph/stream");
    // Optional headers, attached in the same order as they are listed here.
    for (name, value) in [("authorization", auth), ("x-solo-tenant", tenant)] {
        if let Some(value) = value {
            request = request.header(name, value);
        }
    }
    let request = request
        .header("content-length", "0")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.expect("oneshot");
    (response.status(), response.into_body())
}
/// The first event on a freshly opened stream must be `init`, carrying
/// `connected: true`, `tenant_id: "default"` and a numeric `ts_ms`.
#[test]
fn stream_emits_init_event_on_connect() {
    const WAIT: std::time::Duration = std::time::Duration::from_secs(2);
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async {
        let (status, mut body) = open_sse_stream_inner(router, None, None).await;
        assert_eq!(status, StatusCode::OK);
        let init = read_one_sse_event(&mut body, WAIT)
            .await
            .expect("must receive init event within 2s");
        assert_eq!(init.event, "init");
        let payload = &init.data;
        assert_eq!(payload["connected"].as_bool(), Some(true));
        assert_eq!(payload["tenant_id"].as_str(), Some("default"));
        assert!(payload["ts_ms"].is_number());
    });
    harness.shutdown(&runtime);
}
/// After the `init` event, an `InvalidateEvent` pushed through the harness's
/// invalidate sender must show up on the stream as an `invalidate` event
/// with the same `reason`, `tenant_id` and `kind`.
#[test]
fn stream_emits_invalidate_after_writer_event() {
    const WAIT: std::time::Duration = std::time::Duration::from_secs(2);
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let sender = harness.invalidate_sender();
    runtime.block_on(async {
        let (status, mut body) = open_sse_stream_inner(router, None, None).await;
        assert_eq!(status, StatusCode::OK);
        let init = read_one_sse_event(&mut body, WAIT).await.unwrap();
        assert_eq!(init.event, "init");

        let published = InvalidateEvent {
            reason: "memory.remember".to_string(),
            tenant_id: "default".to_string(),
            ts_ms: 1_715_625_600_000,
            kind: "episode".to_string(),
        };
        sender
            .send(published)
            .expect("must have at least one subscriber");

        let ev = read_one_sse_event(&mut body, WAIT)
            .await
            .expect("invalidate event must arrive within 2s");
        assert_eq!(ev.event, "invalidate");
        let payload = &ev.data;
        assert_eq!(payload["reason"].as_str(), Some("memory.remember"));
        assert_eq!(payload["tenant_id"].as_str(), Some("default"));
        assert_eq!(payload["kind"].as_str(), Some("episode"));
    });
    harness.shutdown(&runtime);
}
/// Pushes one `InvalidateEvent` per `(reason, kind)` pair in the table below
/// and checks, one at a time, that each arrives on the stream as an
/// `invalidate` event with matching `reason` and `kind`.
#[test]
fn stream_emits_invalidate_for_each_writer_command() {
    const WAIT: std::time::Duration = std::time::Duration::from_secs(2);
    const CASES: [(&str, &str); 8] = [
        ("memory.remember", "episode"),
        ("memory.forget", "episode"),
        ("memory.consolidate", "cluster"),
        ("memory.ingest_document", "document"),
        ("memory.forget_document", "document"),
        ("memory.triples_extract", "cluster"),
        ("memory.reembed", "episode"),
        ("gdpr.forget_user", "tenant"),
    ];
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let sender = harness.invalidate_sender();
    runtime.block_on(async {
        let (status, mut body) = open_sse_stream_inner(router, None, None).await;
        assert_eq!(status, StatusCode::OK);
        // Consume the first event so the loop below only sees later ones.
        let _ = read_one_sse_event(&mut body, WAIT).await.unwrap();
        for (reason, kind) in CASES {
            let published = InvalidateEvent {
                reason: reason.to_string(),
                tenant_id: "default".to_string(),
                ts_ms: 1_715_625_600_000,
                kind: kind.to_string(),
            };
            sender.send(published).unwrap();
            let ev = read_one_sse_event(&mut body, WAIT)
                .await
                .unwrap_or_else(|| panic!("must receive event for {reason}"));
            assert_eq!(ev.event, "invalidate");
            assert_eq!(
                ev.data["reason"].as_str(),
                Some(reason),
                "reason mismatch"
            );
            assert_eq!(ev.data["kind"].as_str(), Some(kind), "kind mismatch");
        }
    });
    harness.shutdown(&runtime);
}
/// Builds the invalidate stream directly (bypassing the router) and, with
/// nothing published, expects `init` followed by a `heartbeat` event whose
/// `ts_ms` is numeric.
#[test]
fn stream_emits_heartbeat_when_no_events() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let sender = harness.invalidate_sender();
    runtime.block_on(async {
        // The trailing `1` is presumably the heartbeat interval in seconds —
        // TODO confirm against `build_invalidate_stream`.
        let stream = build_invalidate_stream(sender.subscribe(), "default".to_string(), 1);
        let mut body = Sse::new(stream).into_response().into_body();

        let init_wait = std::time::Duration::from_secs(2);
        let first = read_one_sse_event(&mut body, init_wait)
            .await
            .expect("init event must arrive");
        assert_eq!(first.event, "init");

        let heartbeat_wait = std::time::Duration::from_secs(3);
        let second = read_one_sse_event(&mut body, heartbeat_wait)
            .await
            .expect("heartbeat event must arrive within 3s");
        assert_eq!(second.event, "heartbeat");
        assert!(second.data["ts_ms"].is_number());
    });
    harness.shutdown(&runtime);
}
/// Opens three streams, checks each yields `init`, verifies the sender sees
/// at least three receivers, then publishes a single event and expects every
/// stream to deliver it as `invalidate`.
#[test]
fn stream_concurrent_subscribers_same_tenant() {
    const WAIT: std::time::Duration = std::time::Duration::from_secs(2);
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let routers = [
        harness.router.clone(),
        harness.router.clone(),
        harness.router.clone(),
    ];
    let sender = harness.invalidate_sender();
    runtime.block_on(async {
        // Open all three streams first, then validate their statuses.
        let mut statuses = Vec::with_capacity(routers.len());
        let mut bodies = Vec::with_capacity(routers.len());
        for router in routers {
            let (status, body) = open_sse_stream_inner(router, None, None).await;
            statuses.push(status);
            bodies.push(body);
        }
        for status in statuses {
            assert_eq!(status, StatusCode::OK);
        }
        for body in bodies.iter_mut() {
            let ev = read_one_sse_event(body, WAIT).await.unwrap();
            assert_eq!(ev.event, "init");
        }
        assert!(
            sender.receiver_count() >= 3,
            "expected ≥3 subscribers, got {}",
            sender.receiver_count()
        );
        let published = InvalidateEvent {
            reason: "memory.remember".to_string(),
            tenant_id: "default".to_string(),
            ts_ms: 1_715_625_600_000,
            kind: "episode".to_string(),
        };
        sender.send(published).expect("send must succeed");
        for body in bodies.iter_mut() {
            let ev = read_one_sse_event(body, WAIT).await.unwrap();
            assert_eq!(ev.event, "invalidate");
            assert_eq!(ev.data["reason"].as_str(), Some("memory.remember"));
        }
    });
    harness.shutdown(&runtime);
}
/// Checks that dropping the SSE response body releases the stream's
/// subscription: `receiver_count()` on the invalidate sender must rise while
/// the stream is open and be back at or below its starting value after the
/// body has been dropped.
#[test]
fn stream_handles_client_disconnect_gracefully() {
let runtime = rt();
let h = Harness::new(&runtime);
let r = h.router.clone();
let sender = h.invalidate_sender();
// Baseline receiver count, sampled before any stream is opened.
let before = sender.receiver_count();
runtime.block_on(async {
let (status, mut body) = open_sse_stream_inner(r, None, None).await;
assert_eq!(status, StatusCode::OK);
// Wait for the first event so the stream is known to be running.
let _ = read_one_sse_event(&mut body, std::time::Duration::from_secs(2))
.await
.unwrap();
let during = sender.receiver_count();
assert!(
during > before,
"subscriber count must increase while stream is live (before={before}, during={during})"
);
// Dropping the body stands in for the client disconnecting.
drop(body);
});
// Brief pause on the runtime before re-sampling the count — presumably to
// let any asynchronous cleanup finish; TODO confirm it is needed.
runtime.block_on(async {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
});
let after = sender.receiver_count();
assert!(
after <= before,
"subscriber count must drop back after disconnect (before={before}, after={after})"
);
h.shutdown(&runtime);
}
/// With auth configured ("stream-secret"), opening the stream without an
/// `authorization` header must be answered with 401.
#[test]
fn stream_respects_auth_when_enabled() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("stream-secret".into()));
    let router = harness.router.clone();
    let (status, _body) = runtime.block_on(open_sse_stream_inner(router, None, None));
    assert_eq!(status, StatusCode::UNAUTHORIZED);
    harness.shutdown(&runtime);
}
/// With the default harness (built by `Harness::new`), the stream opens with
/// 200 and delivers `init` without any credentials.
#[test]
fn stream_works_with_auth_none() {
    const WAIT: std::time::Duration = std::time::Duration::from_secs(2);
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async {
        let (status, mut body) = open_sse_stream_inner(router, None, None).await;
        assert_eq!(status, StatusCode::OK);
        let first = read_one_sse_event(&mut body, WAIT)
            .await
            .expect("must receive init event");
        assert_eq!(first.event, "init");
    });
    harness.shutdown(&runtime);
}
/// With auth configured ("stream-secret"), a request carrying
/// `Bearer stream-secret` opens the stream and receives an `init` event
/// whose `tenant_id` is "default".
#[test]
fn stream_respects_auth_accepts_valid_token() {
    const WAIT: std::time::Duration = std::time::Duration::from_secs(2);
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("stream-secret".into()));
    let router = harness.router.clone();
    runtime.block_on(async {
        let credentials = Some("Bearer stream-secret");
        let (status, mut body) = open_sse_stream_inner(router, credentials, None).await;
        assert_eq!(status, StatusCode::OK);
        let init = read_one_sse_event(&mut body, WAIT)
            .await
            .expect("must receive init event with valid bearer");
        assert_eq!(init.event, "init");
        assert_eq!(init.data["tenant_id"].as_str(), Some("default"));
    });
    harness.shutdown(&runtime);
}
/// Opening the stream with an `x-solo-tenant` header naming
/// "never-registered-tenant-x" must be answered with 404.
#[test]
fn stream_respects_tenant_scoping() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let unknown_tenant = Some("never-registered-tenant-x");
    let (status, _body) =
        runtime.block_on(open_sse_stream_inner(router, None, unknown_tenant));
    assert_eq!(status, StatusCode::NOT_FOUND);
    harness.shutdown(&runtime);
}
async fn seed_three_tenants(registry: &TenantRegistry) -> Vec<String> {
use solo_core::TenantId as TenantIdT;
let ids = ["alice", "bob", "default"];
for id in ids {
let tid = TenantIdT::new(id).unwrap();
registry
.with_index(|idx| {
idx.register(&tid, &format!("{id}.db"), Some(&format!("{id} tenant")))
.unwrap();
})
.await;
tokio::time::sleep(std::time::Duration::from_millis(2)).await;
}
vec!["alice".into(), "bob".into(), "default".into()]
}
/// Without auth, `GET /v1/tenants` must list all three seeded tenants, in
/// the order alice, bob, default.
#[test]
fn tenants_returns_all_when_auth_none() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async {
        let _expected = seed_three_tenants(&harness.registry).await;
        let (status, body) = call(router, "GET", "/v1/tenants", None).await;
        assert_eq!(status, StatusCode::OK);
        let arr = body["tenants"].as_array().expect("tenants array");
        assert_eq!(arr.len(), 3, "got body: {body}");
        // Entries lacking a string `id` are skipped, which would fail the
        // comparison below.
        let mut ids: Vec<&str> = Vec::with_capacity(arr.len());
        ids.extend(arr.iter().filter_map(|tenant| tenant["id"].as_str()));
        assert_eq!(ids, ["alice", "bob", "default"]);
    });
    harness.shutdown(&runtime);
}
/// With bearer auth configured ("tlist-secret"), an authenticated
/// `GET /v1/tenants` must still list all three seeded tenants.
#[test]
fn tenants_returns_all_when_bearer_auth() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("tlist-secret".into()));
    let router = harness.router.clone();
    runtime.block_on(async {
        seed_three_tenants(&harness.registry).await;
        let credentials = Some("Bearer tlist-secret");
        let (status, body) =
            call_with_auth(router, "GET", "/v1/tenants", None, credentials).await;
        assert_eq!(status, StatusCode::OK, "got body: {body}");
        let listed = body["tenants"].as_array().expect("tenants array").len();
        assert_eq!(listed, 3, "bearer must see all tenants");
    });
    harness.shutdown(&runtime);
}
/// Under OIDC auth backed by the fake IdP, a token minted with "alice" must
/// see exactly one tenant in `GET /v1/tenants`, and that tenant is "alice".
#[test]
fn tenants_filters_to_principal_claim_when_oidc() {
    let runtime = rt();
    let (fake_server, discovery_url, secret, kid) =
        runtime.block_on(async { spin_fake_idp().await });
    let server_uri = fake_server.uri();
    // Keep the fake IdP alive until the end of the test.
    let _server_guard = fake_server;
    let oidc = crate::auth::AuthConfig::Oidc {
        discovery_url,
        audience: "tlist-audience".to_string(),
        tenant_claim_name: "solo_tenant".to_string(),
    };
    let harness = Harness::new_with_auth_config(&runtime, Some(oidc));
    let router = harness.router.clone();
    runtime.block_on(async {
        seed_three_tenants(&harness.registry).await;
        let token = mint_idp_token(&server_uri, kid, &secret, "alice", "tlist-audience");
        let authorization = format!("Bearer {token}");
        let (status, body) =
            call_with_auth(router, "GET", "/v1/tenants", None, Some(&authorization)).await;
        assert_eq!(status, StatusCode::OK, "got body: {body}");
        let arr = body["tenants"].as_array().expect("tenants array");
        assert_eq!(arr.len(), 1, "OIDC alice must see exactly one tenant");
        assert_eq!(arr[0]["id"].as_str(), Some("alice"));
    });
    harness.shutdown(&runtime);
}
/// Under OIDC auth, a token minted with "nonexistent" (not one of the seeded
/// tenants) must still get 200 from `GET /v1/tenants`, with an empty list.
#[test]
fn tenants_returns_empty_when_oidc_claim_unmatched() {
    let runtime = rt();
    let (fake_server, discovery_url, secret, kid) =
        runtime.block_on(async { spin_fake_idp().await });
    let server_uri = fake_server.uri();
    // Keep the fake IdP alive until the end of the test.
    let _server_guard = fake_server;
    let oidc = crate::auth::AuthConfig::Oidc {
        discovery_url,
        audience: "tlist-audience".to_string(),
        tenant_claim_name: "solo_tenant".to_string(),
    };
    let harness = Harness::new_with_auth_config(&runtime, Some(oidc));
    let router = harness.router.clone();
    runtime.block_on(async {
        seed_three_tenants(&harness.registry).await;
        let token = mint_idp_token(
            &server_uri,
            kid,
            &secret,
            "nonexistent",
            "tlist-audience",
        );
        let authorization = format!("Bearer {token}");
        let (status, body) =
            call_with_auth(router, "GET", "/v1/tenants", None, Some(&authorization)).await;
        assert_eq!(
            status,
            StatusCode::OK,
            "must be 200 OK, not 404 — don't leak tenant existence: {body}"
        );
        let listed = body["tenants"].as_array().expect("tenants array").len();
        assert_eq!(
            listed, 0,
            "unmatched OIDC claim must produce empty list, got: {body}"
        );
    });
    harness.shutdown(&runtime);
}
/// Registers one tenant with a display name and a 1 MiB quota (no database
/// file is created for it) and pins the JSON shape of its entry in
/// `GET /v1/tenants`: populated index fields plus `null` for
/// `episode_count`, `size_bytes` and `pct_used`.
#[test]
fn tenants_response_shape_matches_solo_web_types() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async {
        let tenant = solo_core::TenantId::new("shaped").unwrap();
        harness
            .registry
            .with_index(|idx| {
                idx.register_with_quota(
                    &tenant,
                    "shaped.db",
                    Some("Shaped tenant"),
                    Some(1_048_576),
                )
                .unwrap();
            })
            .await;
        let (status, body) = call(router, "GET", "/v1/tenants", None).await;
        assert_eq!(status, StatusCode::OK);
        let item = &body["tenants"][0];

        // Fields that come straight from the registration.
        assert_eq!(item["id"].as_str(), Some("shaped"));
        assert_eq!(item["display_name"].as_str(), Some("Shaped tenant"));
        assert!(
            item["created_at_ms"].is_i64(),
            "created_at_ms must be an i64, got {item}"
        );
        assert_eq!(item["status"].as_str(), Some("active"));
        assert_eq!(item["quota_bytes"].as_u64(), Some(1_048_576));

        // Fields that must be JSON null here.
        assert!(
            item["episode_count"].is_null(),
            "episode_count must be JSON null when tenant DB is missing, got {item}"
        );
        assert!(
            item["size_bytes"].is_null(),
            "size_bytes must be JSON null when tenant DB is missing, got {item}"
        );
        assert!(
            item["pct_used"].is_null(),
            "pct_used must be JSON null when size_bytes is null, got {item}"
        );
    });
    harness.shutdown(&runtime);
}
/// With auth configured ("must-auth"), `GET /v1/tenants` without credentials
/// must be answered with 401.
#[test]
fn tenants_respects_auth_when_enabled() {
    let runtime = rt();
    let harness = Harness::new_with_auth(&runtime, Some("must-auth".into()));
    let router = harness.router.clone();
    runtime.block_on(async {
        seed_three_tenants(&harness.registry).await;
        let (status, _body) = call(router, "GET", "/v1/tenants", None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    });
    harness.shutdown(&runtime);
}
/// Registers one tenant via plain `register` plus one each with status
/// `PendingMigration` and `PendingDelete`; `GET /v1/tenants` must list only
/// the first ("keeper").
#[test]
fn tenants_status_filter_excludes_non_active() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async {
        let keeper = solo_core::TenantId::new("keeper").unwrap();
        let migrating = solo_core::TenantId::new("migrating").unwrap();
        let deleting = solo_core::TenantId::new("deleting").unwrap();
        harness
            .registry
            .with_index(|idx| {
                idx.register(&keeper, "keeper.db", None).unwrap();
                let hidden = [
                    (
                        &migrating,
                        "migrating.db",
                        solo_storage::TenantStatus::PendingMigration,
                    ),
                    (
                        &deleting,
                        "deleting.db",
                        solo_storage::TenantStatus::PendingDelete,
                    ),
                ];
                for (tenant, db_filename, tenant_status) in hidden {
                    idx.register_with_status(tenant, db_filename, None, tenant_status)
                        .unwrap();
                }
            })
            .await;
        let (status, body) = call(router, "GET", "/v1/tenants", None).await;
        assert_eq!(status, StatusCode::OK);
        let arr = body["tenants"].as_array().expect("tenants array");
        let mut ids: Vec<&str> = Vec::with_capacity(arr.len());
        ids.extend(arr.iter().filter_map(|tenant| tenant["id"].as_str()));
        assert_eq!(
            ids,
            vec!["keeper"],
            "only Active tenants visible; got: {body}"
        );
    });
    harness.shutdown(&runtime);
}
/// With nothing registered by the test, `GET /v1/tenants` must answer 200
/// with an empty `tenants` array.
#[test]
fn tenants_returns_empty_array_when_no_tenants_registered() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let (status, body) = runtime.block_on(call(router, "GET", "/v1/tenants", None));
    assert_eq!(status, StatusCode::OK);
    let listed = body["tenants"].as_array().expect("tenants array").len();
    assert_eq!(listed, 0, "expected empty array, got: {body}");
    harness.shutdown(&runtime);
}
/// Creates `<data_dir>/<TENANTS_SUBDIR>/<db_filename>`, runs the storage
/// migrations on it, and inserts `n_active` episodes with status `active`
/// (memory ids `a-0`, `a-1`, …) followed by `n_forgotten` episodes with
/// status `forgotten` (memory ids `f-0`, `f-1`, …).
///
/// Returns the path of the database file. The connection is closed before
/// returning. Non-positive counts insert nothing.
///
/// # Panics
///
/// Panics if the directory cannot be created or any database step fails.
fn seed_per_tenant_db_with_episodes(
    data_dir: &std::path::Path,
    db_filename: &str,
    n_active: i64,
    n_forgotten: i64,
) -> std::path::PathBuf {
    let tenants_dir = data_dir.join(solo_storage::TENANTS_SUBDIR);
    std::fs::create_dir_all(&tenants_dir).unwrap();
    let db_path = tenants_dir.join(db_filename);
    let mut conn = rusqlite::Connection::open(&db_path).unwrap();
    solo_storage::run_migrations(&mut conn).unwrap();
    {
        // One prepared statement serves both row kinds; only the memory id
        // and the status differ. The scope ends the statement's borrow of
        // `conn` before the connection is dropped.
        let mut insert = conn
            .prepare(
                "INSERT INTO episodes (memory_id, ts_ms, source_type, content, confidence, \
                 strength, salience, tier, status, created_at_ms, updated_at_ms) \
                 VALUES (?1, 0, 'user_message', 'x', 0.5, 0.5, 0.5, 'hot', ?2, 0, 0)",
            )
            .unwrap();
        let batches = [("a", "active", n_active), ("f", "forgotten", n_forgotten)];
        for (prefix, status, count) in batches {
            for i in 0..count {
                insert
                    .execute(rusqlite::params![format!("{prefix}-{i}"), status])
                    .unwrap();
            }
        }
    }
    // Close the connection explicitly before handing the path back.
    drop(conn);
    db_path
}
/// Seeds a tenant database with 3 active and 2 forgotten episodes and
/// expects `episode_count` in `GET /v1/tenants` to be 3.
#[test]
fn tenants_response_hydrates_episode_count_when_tenant_has_data() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let data_dir = harness._tmp.path().to_path_buf();
    runtime.block_on(async {
        let tenant = solo_core::TenantId::new("counted").unwrap();
        seed_per_tenant_db_with_episodes(&data_dir, "counted.db", 3, 2);
        harness
            .registry
            .with_index(|idx| {
                idx.register(&tenant, "counted.db", Some("Counted tenant"))
                    .unwrap();
            })
            .await;
        let (status, body) = call(router, "GET", "/v1/tenants", None).await;
        assert_eq!(status, StatusCode::OK);
        let item = &body["tenants"][0];
        assert_eq!(item["id"].as_str(), Some("counted"));
        let episode_count = item["episode_count"].as_i64();
        assert_eq!(
            episode_count,
            Some(3),
            "episode_count must be 3 (active rows only, 2 forgotten excluded); got {item}"
        );
    });
    harness.shutdown(&runtime);
}
/// Seeds a tenant database file and expects `size_bytes` in
/// `GET /v1/tenants` to equal the file length reported by `fs::metadata`.
#[test]
fn tenants_response_hydrates_size_bytes_from_db_file() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let data_dir = harness._tmp.path().to_path_buf();
    runtime.block_on(async {
        let tenant = solo_core::TenantId::new("sized").unwrap();
        let db_path = seed_per_tenant_db_with_episodes(&data_dir, "sized.db", 1, 0);
        harness
            .registry
            .with_index(|idx| {
                idx.register(&tenant, "sized.db", None).unwrap();
            })
            .await;
        // Reference size, measured after registration and before the request.
        let on_disk = std::fs::metadata(&db_path).unwrap().len();
        assert!(on_disk > 0, "test setup: db file should be non-empty");
        let (status, body) = call(router, "GET", "/v1/tenants", None).await;
        assert_eq!(status, StatusCode::OK);
        let item = &body["tenants"][0];
        assert_eq!(item["id"].as_str(), Some("sized"));
        let reported = item["size_bytes"].as_u64();
        assert_eq!(
            reported,
            Some(on_disk),
            "size_bytes must match fs::metadata; got {item}"
        );
    });
    harness.shutdown(&runtime);
}
/// Registers a tenant whose quota is four times its database file size and
/// expects `pct_used` to be a number within [0, 100] and within [20, 30].
#[test]
fn tenants_response_computes_pct_used_when_quota_set() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let data_dir = harness._tmp.path().to_path_buf();
    runtime.block_on(async {
        let tenant = solo_core::TenantId::new("quoted").unwrap();
        let db_path = seed_per_tenant_db_with_episodes(&data_dir, "quoted.db", 1, 0);
        let on_disk = std::fs::metadata(&db_path).unwrap().len();
        // Quota chosen so that the file occupies a quarter of it.
        let quota = on_disk * 4;
        harness
            .registry
            .with_index(|idx| {
                idx.register_with_quota(&tenant, "quoted.db", None, Some(quota))
                    .unwrap();
            })
            .await;
        let (status, body) = call(router, "GET", "/v1/tenants", None).await;
        assert_eq!(status, StatusCode::OK);
        let item = &body["tenants"][0];
        let pct = item["pct_used"].as_f64().expect("pct_used must be a number");
        let valid_range = 0.0..=100.0;
        assert!(
            valid_range.contains(&pct),
            "pct_used must be in [0, 100], got {pct}"
        );
        let expected_range = 20.0..=30.0;
        assert!(
            expected_range.contains(&pct),
            "pct_used must be ~25% for size=quota/4, got {pct}"
        );
    });
    harness.shutdown(&runtime);
}
/// Registers a tenant without a quota but with a real database file:
/// `quota_bytes` and `pct_used` must be JSON null while `size_bytes` is
/// still an unsigned integer.
#[test]
fn tenants_response_pct_used_null_when_quota_null() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    let data_dir = harness._tmp.path().to_path_buf();
    runtime.block_on(async {
        let tenant = solo_core::TenantId::new("unlimited").unwrap();
        seed_per_tenant_db_with_episodes(&data_dir, "unlimited.db", 1, 0);
        harness
            .registry
            .with_index(|idx| {
                idx.register(&tenant, "unlimited.db", None).unwrap();
            })
            .await;
        let (status, body) = call(router, "GET", "/v1/tenants", None).await;
        assert_eq!(status, StatusCode::OK);
        let item = &body["tenants"][0];
        assert_eq!(item["id"].as_str(), Some("unlimited"));
        let quota_is_null = item["quota_bytes"].is_null();
        assert!(
            quota_is_null,
            "test setup: quota_bytes must be null, got {item}"
        );
        let pct_is_null = item["pct_used"].is_null();
        assert!(
            pct_is_null,
            "pct_used must be JSON null when quota_bytes is null, got {item}"
        );
        let size_is_present = item["size_bytes"].is_u64();
        assert!(
            size_is_present,
            "size_bytes must still be present when quota_bytes is null, got {item}"
        );
    });
    harness.shutdown(&runtime);
}
/// Registers 51 tenants (`t00` … `t50`) and expects `GET /v1/tenants` to
/// set the `X_SOLO_TENANTS_COUNT_CAP_HEADER` response header to "true",
/// return all 51 entries, and leave `episode_count` null on the last one.
#[test]
fn tenants_response_sets_cap_reached_header_when_over_cap() {
    use axum::body::Body;
    use axum::http::Request;
    use http_body_util::BodyExt;

    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async {
        harness
            .registry
            .with_index(|idx| {
                for i in 0..51 {
                    let id = format!("t{i:02}");
                    let db_filename = format!("{id}.db");
                    let tenant = solo_core::TenantId::new(&id).unwrap();
                    idx.register(&tenant, &db_filename, None).unwrap();
                }
            })
            .await;
        // Built by hand because the response headers are inspected below.
        let request = Request::builder()
            .method("GET")
            .uri("/v1/tenants")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let cap_header = response
            .headers()
            .get(X_SOLO_TENANTS_COUNT_CAP_HEADER)
            .expect("cap-reached header must be present");
        assert_eq!(
            cap_header.to_str().unwrap(),
            "true",
            "cap-reached header value must be 'true' when over cap"
        );
        let raw = response.into_body().collect().await.unwrap().to_bytes();
        let body: serde_json::Value = serde_json::from_slice(&raw).unwrap();
        let arr = body["tenants"].as_array().expect("tenants array");
        assert_eq!(arr.len(), 51, "got {} tenants", arr.len());
        let last = &arr[50];
        assert!(
            last["episode_count"].is_null(),
            "the 51st tenant (beyond cap) must have null episode_count, got {}",
            last
        );
    });
    harness.shutdown(&runtime);
}
/// With only three tenants seeded, `GET /v1/tenants` must not carry the
/// `X_SOLO_TENANTS_COUNT_CAP_HEADER` response header.
#[test]
fn tenants_response_omits_cap_header_when_under_cap() {
    use axum::body::Body;
    use axum::http::Request;

    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async {
        seed_three_tenants(&harness.registry).await;
        let request = Request::builder()
            .method("GET")
            .uri("/v1/tenants")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let cap_header = response.headers().get(X_SOLO_TENANTS_COUNT_CAP_HEADER);
        assert!(
            cap_header.is_none(),
            "cap-reached header must be absent under the cap"
        );
    });
    harness.shutdown(&runtime);
}
/// Builds an `Active` `TenantRecord` for `id` with db file `<id>.db`,
/// creation time 0 and every optional field unset.
///
/// Panics if `id` is rejected by `TenantId::new`.
fn make_record(id: &str) -> solo_storage::TenantRecord {
    let tenant_id = solo_core::TenantId::new(id).unwrap();
    let db_filename = format!("{id}.db");
    solo_storage::TenantRecord {
        tenant_id,
        db_filename,
        display_name: None,
        created_at_ms: 0,
        status: solo_storage::TenantStatus::Active,
        quota_bytes: None,
        last_accessed_ms: None,
    }
}
/// With no principal, `filter_tenants_for_principal` must return every
/// record, in input order.
#[test]
fn filter_no_principal_returns_all() {
    let records = vec![make_record("a"), make_record("b")];
    // `records` is not needed afterwards, so it is moved rather than cloned.
    let out = filter_tenants_for_principal(records, None);
    assert_eq!(out.len(), 2);
    assert_eq!(out[0].tenant_id.as_str(), "a");
    assert_eq!(out[1].tenant_id.as_str(), "b");
}
/// A principal built with `AuthenticatedPrincipal::bearer` must not narrow
/// the list: both records come back.
#[test]
fn filter_bearer_principal_returns_all() {
    let bearer_tenant = solo_core::TenantId::new("a").unwrap();
    let principal = AuthenticatedPrincipal::bearer(bearer_tenant);
    let records = vec![make_record("a"), make_record("b")];
    let visible = filter_tenants_for_principal(records, Some(&principal));
    assert_eq!(visible.len(), 2);
}
/// A principal whose `tenant_claim` is "b" must see only record "b" out of
/// "a", "b" and "c".
#[test]
fn filter_oidc_principal_keeps_only_claim() {
    let principal = AuthenticatedPrincipal {
        subject: "alice@example.com".to_string(),
        tenant_claim: Some(solo_core::TenantId::new("b").unwrap()),
        scopes: vec!["read".to_string()],
        claims: serde_json::json!({ "sub": "alice@example.com" }),
    };
    let records: Vec<_> = ["a", "b", "c"].into_iter().map(make_record).collect();
    let visible = filter_tenants_for_principal(records, Some(&principal));
    assert_eq!(visible.len(), 1);
    assert_eq!(visible[0].tenant_id.as_str(), "b");
}
/// A principal with non-null `claims` but `tenant_claim: None` must see no
/// records at all.
#[test]
fn filter_oidc_principal_with_no_claim_returns_empty() {
    let principal = AuthenticatedPrincipal {
        subject: "alice@example.com".to_string(),
        tenant_claim: None,
        scopes: vec![],
        claims: serde_json::json!({ "sub": "alice@example.com" }),
    };
    let records = vec![make_record("a")];
    let visible = filter_tenants_for_principal(records, Some(&principal));
    assert!(visible.is_empty());
}
/// `is_single_principal_bearer` must accept the principal produced by
/// `AuthenticatedPrincipal::bearer` and reject two hand-built ones: an
/// OIDC-style principal, and one whose subject is "bearer" with the default
/// tenant claim but a non-empty `claims` object.
#[test]
fn is_single_principal_bearer_discriminator() {
    let default_tenant = solo_core::TenantId::new("default").unwrap();
    let genuine = AuthenticatedPrincipal::bearer(default_tenant);
    assert!(is_single_principal_bearer(&genuine));

    let oidc_like = AuthenticatedPrincipal {
        subject: "alice".to_string(),
        tenant_claim: Some(solo_core::TenantId::new("alice").unwrap()),
        scopes: vec![],
        claims: serde_json::json!({ "x": 1 }),
    };
    assert!(!is_single_principal_bearer(&oidc_like));

    // Looks like a bearer principal by subject and tenant, but carries claims.
    let lookalike = AuthenticatedPrincipal {
        subject: "bearer".to_string(),
        tenant_claim: Some(solo_core::TenantId::default_tenant()),
        scopes: vec![],
        claims: serde_json::json!({ "leak": 1 }),
    };
    assert!(!is_single_principal_bearer(&lookalike));
}
/// `tools/list` over `POST /mcp` must answer 200 with a JSON-RPC 2.0
/// envelope echoing id 1 and exactly the fourteen tool names listed below
/// (compared after sorting).
#[test]
fn mcp_http_tools_list_returns_fourteen_canonical_tools() {
    // Kept in sorted order because the returned names are sorted before comparing.
    const EXPECTED: [&str; 14] = [
        "memory_contradictions",
        "memory_facts_about",
        "memory_forget",
        "memory_forget_document",
        "memory_ingest_document",
        "memory_inspect",
        "memory_inspect_cluster",
        "memory_inspect_document",
        "memory_list_documents",
        "memory_recall",
        "memory_remember",
        "memory_remember_batch",
        "memory_search_docs",
        "memory_themes",
    ];
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async move {
        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
        });
        let (status, body) = call(router, "POST", "/mcp", Some(rpc)).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["jsonrpc"].as_str(), Some("2.0"));
        assert_eq!(body["id"].as_i64(), Some(1));
        let tools = body
            .pointer("/result/tools")
            .and_then(|v| v.as_array())
            .unwrap_or_else(|| panic!("missing /result/tools: {body}"));
        let mut names: Vec<String> = tools
            .iter()
            .filter_map(|tool| tool["name"].as_str())
            .map(String::from)
            .collect();
        names.sort();
        assert_eq!(
            names, EXPECTED,
            "mcp_http: tools/list returned unexpected name set"
        );
    });
    harness.shutdown(&runtime);
}
/// Writes an episode through the MCP `memory_remember` tool and expects
/// `GET /v1/graph/nodes?kind=episode&limit=10` to return a node whose
/// `label` or `preview` contains the remembered text.
#[test]
fn mcp_http_remember_writes_episode_visible_via_graph_nodes() {
    const NEEDLE: &str = "mcp-http-cross-surface-smoke";
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async move {
        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {
                "name": "memory_remember",
                "arguments": { "content": NEEDLE },
            },
        });
        let (status, body) = call(router.clone(), "POST", "/mcp", Some(rpc)).await;
        assert_eq!(status, StatusCode::OK);
        let result_text = body
            .pointer("/result/content/0/text")
            .and_then(|v| v.as_str())
            .unwrap_or_else(|| panic!("missing /result/content/0/text: {body}"));
        assert!(
            result_text.starts_with("remembered "),
            "expected `remembered <id>`, got: {result_text}"
        );

        let (status2, nodes_body) =
            call(router, "GET", "/v1/graph/nodes?kind=episode&limit=10", None).await;
        assert_eq!(status2, StatusCode::OK);
        let nodes = nodes_body
            .get("nodes")
            .and_then(|v| v.as_array())
            .unwrap_or_else(|| panic!("missing nodes: {nodes_body}"));
        // True when the given string field of a node contains the needle.
        let field_mentions = |node: &Value, field: &str| {
            node.get(field)
                .and_then(Value::as_str)
                .is_some_and(|text| text.contains(NEEDLE))
        };
        let surfaced = nodes
            .iter()
            .any(|node| field_mentions(node, "label") || field_mentions(node, "preview"));
        assert!(
            surfaced,
            "graph/nodes didn't surface the MCP-written episode: {nodes_body}"
        );
    });
    harness.shutdown(&runtime);
}
/// Remembers a unique string through the MCP `memory_remember` tool, then
/// calls `memory_recall` with the same string as query and expects the
/// recall text to contain it.
#[test]
fn mcp_http_recall_returns_just_remembered_episode() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async move {
        let needle = "mcp-http-recall-needle-deadbeef";

        let remember = json!({
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {
                "name": "memory_remember",
                "arguments": { "content": needle },
            },
        });
        let (remember_status, _body) =
            call(router.clone(), "POST", "/mcp", Some(remember)).await;
        assert_eq!(remember_status, StatusCode::OK);

        let recall = json!({
            "jsonrpc": "2.0",
            "id": 4,
            "method": "tools/call",
            "params": {
                "name": "memory_recall",
                "arguments": { "query": needle, "limit": 5 },
            },
        });
        let (status, body) = call(router, "POST", "/mcp", Some(recall)).await;
        assert_eq!(status, StatusCode::OK);
        let recall_text = body
            .pointer("/result/content/0/text")
            .and_then(|v| v.as_str())
            .unwrap_or_else(|| panic!("missing /result/content/0/text: {body}"));
        assert!(
            recall_text.contains(needle),
            "recall didn't surface needle `{needle}`: {recall_text}"
        );
    });
    harness.shutdown(&runtime);
}
/// A `POST /mcp` whose body is not JSON must be answered with 400 and a JSON
/// body whose `error` string contains "invalid JSON-RPC request".
#[test]
fn mcp_http_malformed_body_returns_400() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async move {
        let request = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            .body(Body::from("not-json-at-all".as_bytes()))
            .unwrap();
        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let raw = response.into_body().collect().await.unwrap().to_bytes();
        let v: Value = serde_json::from_slice(&raw).unwrap();
        let explains_rejection = v
            .get("error")
            .and_then(|e| e.as_str())
            .is_some_and(|message| message.contains("invalid JSON-RPC request"));
        assert!(explains_rejection, "got: {v}");
    });
    harness.shutdown(&runtime);
}
/// A request declaring `"jsonrpc": "1.0"` must be answered with 400.
#[test]
fn mcp_http_wrong_jsonrpc_version_returns_400() {
    let runtime = rt();
    let harness = Harness::new(&runtime);
    let router = harness.router.clone();
    runtime.block_on(async move {
        let rpc = json!({
            "jsonrpc": "1.0",
            "id": 1,
            "method": "tools/list",
        });
        let (status, _body) = call(router, "POST", "/mcp", Some(rpc)).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    });
    harness.shutdown(&runtime);
}
/// An unknown JSON-RPC method is reported inside the response body
/// (`error.code == -32601`) while the HTTP status stays 200.
#[test]
fn mcp_http_unknown_method_returns_in_body_method_not_found() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        let unknown = json!({
            "jsonrpc": "2.0",
            "id": 5,
            "method": "definitely/not/a/method",
        });
        let (status, body) = call(router, "POST", "/mcp", Some(unknown)).await;
        assert_eq!(status, StatusCode::OK);

        let error_code = body.pointer("/error/code").and_then(|v| v.as_i64());
        assert_eq!(
            error_code,
            Some(-32601),
            "expected JSON-RPC METHOD_NOT_FOUND (-32601), got: {body}"
        );
    });
    h.shutdown(&runtime);
}
/// With the harness built via `new_with_auth` and a token, `POST /mcp` must
/// return 401 without credentials and 200 (listing 14 tools) when the matching
/// `Bearer` value is supplied.
#[test]
fn mcp_http_post_respects_bearer_auth() {
    let runtime = rt();
    let h = Harness::new_with_auth(&runtime, Some("secret-mcp-token".into()));
    let router = h.router.clone();
    runtime.block_on(async move {
        let list_tools = json!({
            "jsonrpc": "2.0",
            "id": 6,
            "method": "tools/list",
        });

        // Unauthenticated attempt.
        let (anon_status, _) =
            call(router.clone(), "POST", "/mcp", Some(list_tools.clone())).await;
        assert_eq!(anon_status, StatusCode::UNAUTHORIZED);

        // Same payload, now with the bearer token.
        let (authed_status, body) = call_with_auth(
            router,
            "POST",
            "/mcp",
            Some(list_tools),
            Some("Bearer secret-mcp-token"),
        )
        .await;
        assert_eq!(authed_status, StatusCode::OK);

        let tool_count = body
            .pointer("/result/tools")
            .and_then(|v| v.as_array())
            .map(Vec::len);
        assert_eq!(
            tool_count,
            Some(14),
            "authed tools/list should still return 14 tools: {body}"
        );
    });
    h.shutdown(&runtime);
}
/// CORS preflight (`OPTIONS /mcp`) from `http://localhost:5173` must succeed
/// with 200, list `mcp-session-id` and `x-solo-tenant` among the allowed
/// headers, and echo the origin in `access-control-allow-origin`.
#[test]
fn mcp_http_cors_preflight_allows_mcp_session_id_header() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        let preflight = Request::builder()
            .method("OPTIONS")
            .uri("/mcp")
            .header("origin", "http://localhost:5173")
            .header("access-control-request-method", "POST")
            .header(
                "access-control-request-headers",
                "content-type, mcp-session-id, x-solo-tenant, authorization",
            )
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(preflight).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let headers = resp.headers();

        // Compared in lowercase so header-name casing does not matter.
        let allow_headers = headers
            .get("access-control-allow-headers")
            .and_then(|h| h.to_str().ok())
            .unwrap_or("")
            .to_lowercase();
        assert!(
            allow_headers.contains("mcp-session-id"),
            "preflight allow-headers must include mcp-session-id; got: {allow_headers}"
        );
        assert!(
            allow_headers.contains("x-solo-tenant"),
            "preflight allow-headers must still include x-solo-tenant; got: {allow_headers}"
        );

        let allow_origin = headers
            .get("access-control-allow-origin")
            .and_then(|h| h.to_str().ok())
            .unwrap_or("");
        assert_eq!(allow_origin, "http://localhost:5173");
    });
    h.shutdown(&runtime);
}
/// A JSON-RPC message without an `id` (here `notifications/initialized`) must
/// be answered with HTTP 202 and a body that `call` reports as `Value::Null`.
#[test]
fn mcp_http_notification_returns_202_accepted() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        // Note: no "id" member in this payload.
        let notification = json!({
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {},
        });
        let (status, body) = call(router, "POST", "/mcp", Some(notification)).await;
        assert_eq!(status, StatusCode::ACCEPTED);
        assert_eq!(body, Value::Null);
    });
    h.shutdown(&runtime);
}
/// A `POST /mcp` sent without an `mcp-session-id` request header must come
/// back with a non-empty `mcp-session-id` response header.
#[test]
fn mcp_post_without_session_id_creates_new_session() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        let payload = serde_json::to_vec(&json!({
            "jsonrpc": "2.0",
            "id": 100,
            "method": "tools/list",
        }))
        .unwrap();
        let req = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            .body(Body::from(payload))
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let header_value = resp
            .headers()
            .get("mcp-session-id")
            .and_then(|v| v.to_str().ok());
        let Some(session_id) = header_value else {
            panic!(
                "mcp-session-id response header missing on session-init POST: {:?}",
                resp.headers()
            )
        };
        assert!(
            !session_id.is_empty(),
            "session id must be a non-empty string"
        );
    });
    h.shutdown(&runtime);
}
/// Sends two `POST /mcp` requests: the first without a session id, the second
/// carrying the id the first response assigned. The second response must echo
/// that same id back in `mcp-session-id`.
#[test]
fn mcp_post_with_valid_session_id_continues_session() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        // First request: no session header, so the server assigns one.
        let first_payload = serde_json::to_vec(&json!({
            "jsonrpc": "2.0",
            "id": 101,
            "method": "tools/list",
        }))
        .unwrap();
        let first_req = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            .body(Body::from(first_payload))
            .unwrap();
        let first_resp = router.clone().oneshot(first_req).await.unwrap();
        assert_eq!(first_resp.status(), StatusCode::OK);
        let assigned_id = first_resp
            .headers()
            .get("mcp-session-id")
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned)
            .expect("first response must carry mcp-session-id");

        // Second request: present the assigned id.
        let second_payload = serde_json::to_vec(&json!({
            "jsonrpc": "2.0",
            "id": 102,
            "method": "tools/list",
        }))
        .unwrap();
        let second_req = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            .header("mcp-session-id", &assigned_id)
            .body(Body::from(second_payload))
            .unwrap();
        let second_resp = router.oneshot(second_req).await.unwrap();
        assert_eq!(second_resp.status(), StatusCode::OK);
        let echoed = second_resp
            .headers()
            .get("mcp-session-id")
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned)
            .expect("continuation response must echo mcp-session-id");

        assert_eq!(
            echoed, assigned_id,
            "second response must echo the same session id"
        );
    });
    h.shutdown(&runtime);
}
/// A `POST /mcp` carrying a session id that was never issued by this harness
/// must get HTTP 404 with a JSON body whose `error` equals
/// `MCP_SESSION_EXPIRED_ERROR` and whose `retry` equals `re-initialize`.
#[test]
fn mcp_post_with_unknown_session_id_returns_404() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        let payload = serde_json::to_vec(&json!({
            "jsonrpc": "2.0",
            "id": 103,
            "method": "tools/list",
        }))
        .unwrap();
        let req = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            // Hard-coded id; no session was allocated in this test.
            .header("mcp-session-id", "11111111-2222-3333-4444-555555555555")
            .body(Body::from(payload))
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let raw = resp.into_body().collect().await.unwrap().to_bytes();
        let v: Value = serde_json::from_slice(&raw).unwrap();

        let error_field = v.get("error").and_then(|e| e.as_str());
        assert_eq!(
            error_field,
            Some(crate::mcp_session::MCP_SESSION_EXPIRED_ERROR),
            "404 body must carry the session_expired discriminator: {v}"
        );

        let retry_field = v.get("retry").and_then(|e| e.as_str());
        assert!(
            retry_field == Some("re-initialize"),
            "404 body must instruct re-initialize: {v}"
        );
    });
    h.shutdown(&runtime);
}
/// Allocates a session with a first POST, deletes it from the session store,
/// then verifies a POST presenting the deleted id receives HTTP 404 with
/// `error == MCP_SESSION_EXPIRED_ERROR`.
#[test]
fn mcp_post_with_expired_session_id_returns_404() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    runtime.block_on(async move {
        // Obtain a server-assigned session id.
        let init_payload = serde_json::to_vec(&json!({
            "jsonrpc": "2.0",
            "id": 104,
            "method": "tools/list",
        }))
        .unwrap();
        let init_req = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            .body(Body::from(init_payload))
            .unwrap();
        let init_resp = router.clone().oneshot(init_req).await.unwrap();
        let assigned_id_str = init_resp
            .headers()
            .get("mcp-session-id")
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned)
            .expect("first response must carry mcp-session-id");

        // Remove the session directly from the store to simulate expiry.
        let parsed = crate::mcp_session::SessionId::parse(&assigned_id_str)
            .expect("just-assigned id must parse");
        let was_deleted = store.delete(&parsed);
        assert!(was_deleted, "stored session must be deletable");

        // Present the now-removed id again.
        let retry_payload = serde_json::to_vec(&json!({
            "jsonrpc": "2.0",
            "id": 105,
            "method": "tools/list",
        }))
        .unwrap();
        let retry_req = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            .header("mcp-session-id", &assigned_id_str)
            .body(Body::from(retry_payload))
            .unwrap();
        let retry_resp = router.oneshot(retry_req).await.unwrap();
        assert_eq!(retry_resp.status(), StatusCode::NOT_FOUND);

        let raw = retry_resp.into_body().collect().await.unwrap().to_bytes();
        let v: Value = serde_json::from_slice(&raw).unwrap();
        let error_field = v.get("error").and_then(|e| e.as_str());
        assert_eq!(
            error_field,
            Some(crate::mcp_session::MCP_SESSION_EXPIRED_ERROR),
            "expired-session 404 body must carry session_expired: {v}"
        );
    });
    h.shutdown(&runtime);
}
/// `GET /mcp` without any session id header must return HTTP 404 with
/// `error == MCP_SESSION_EXPIRED_ERROR` and `retry == "re-initialize"`.
#[test]
fn mcp_get_without_session_id_returns_404() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        let req = Request::builder()
            .method("GET")
            .uri("/mcp")
            .header("accept", "text/event-stream")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let raw = resp.into_body().collect().await.unwrap().to_bytes();
        let v: Value = serde_json::from_slice(&raw).unwrap();

        let error_field = v.get("error").and_then(|e| e.as_str());
        assert_eq!(
            error_field,
            Some(crate::mcp_session::MCP_SESSION_EXPIRED_ERROR),
            "GET /mcp without session id must carry session_expired: {v}"
        );

        let retry_field = v.get("retry").and_then(|e| e.as_str());
        assert_eq!(retry_field, Some("re-initialize"));
    });
    h.shutdown(&runtime);
}
/// Test helper: issues `GET /mcp` with `accept: text/event-stream` and the
/// given session id, optionally adding the last-event-id header, and returns
/// the response split into `(status, body, headers)`.
///
/// The body is returned unread so the caller can pull SSE events from it.
/// Panics if the request cannot be built or the router call fails.
async fn open_mcp_get_stream(
    router: axum::Router,
    session_id: &str,
    last_event_id: Option<&str>,
) -> (StatusCode, axum::body::Body, axum::http::HeaderMap) {
    let base = Request::builder()
        .method("GET")
        .uri("/mcp")
        .header("accept", "text/event-stream")
        .header(crate::mcp_session::MCP_SESSION_ID_HEADER, session_id);

    // The resume header is only attached when the caller supplies one.
    let with_resume = match last_event_id {
        Some(leid) => base.header(crate::mcp_session::MCP_LAST_EVENT_ID_HEADER, leid),
        None => base,
    };

    let req = with_resume
        .header("content-length", "0")
        .body(Body::empty())
        .unwrap();

    let (parts, body) = router.oneshot(req).await.expect("oneshot").into_parts();
    (parts.status, body, parts.headers)
}
/// Test helper: performs a session-less `tools/list` POST against `/mcp` and
/// returns the session id found in the response's `MCP_SESSION_ID_HEADER`.
///
/// Panics if the response is not 200 or the header is absent / not a string.
async fn allocate_mcp_session(router: axum::Router) -> String {
    let payload = serde_json::to_vec(&json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/list",
    }))
    .unwrap();
    let req = Request::builder()
        .method("POST")
        .uri("/mcp")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap();

    let resp = router.oneshot(req).await.expect("oneshot");
    assert_eq!(resp.status(), StatusCode::OK, "POST must allocate session");

    let header_value = resp
        .headers()
        .get(crate::mcp_session::MCP_SESSION_ID_HEADER)
        .and_then(|v| v.to_str().ok());
    header_value
        .map(str::to_owned)
        .expect("POST must echo Mcp-Session-Id")
}
/// Test helper: parses `session_id` and looks the session up in `store`.
///
/// Panics if the id does not parse or the store has no entry for it.
fn session_state_for_test(
    store: &crate::mcp_session::SessionStore,
    session_id: &str,
) -> std::sync::Arc<crate::mcp_session::SessionState> {
    let id = crate::mcp_session::SessionId::parse(session_id)
        .expect("test session id must parse");
    let lookup = store.get(&id);
    lookup.expect("session must still be in store")
}
/// Allocates a session, deletes it from the store, then checks that
/// `GET /mcp` with the deleted id returns HTTP 404 and
/// `error == MCP_SESSION_EXPIRED_ERROR`.
#[test]
fn mcp_get_with_expired_session_id_returns_404() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router.clone()).await;

        // Drop the session from the store before the GET is made.
        let parsed = crate::mcp_session::SessionId::parse(&session_id).unwrap();
        let was_deleted = store.delete(&parsed);
        assert!(was_deleted);

        let req = Request::builder()
            .method("GET")
            .uri("/mcp")
            .header("accept", "text/event-stream")
            .header(crate::mcp_session::MCP_SESSION_ID_HEADER, &session_id)
            .body(Body::empty())
            .unwrap();
        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let raw = resp.into_body().collect().await.unwrap().to_bytes();
        let v: Value = serde_json::from_slice(&raw).unwrap();
        let error_field = v.get("error").and_then(|e| e.as_str());
        assert_eq!(
            error_field,
            Some(crate::mcp_session::MCP_SESSION_EXPIRED_ERROR),
        );
    });
    h.shutdown(&runtime);
}
/// `GET /mcp` with a freshly allocated session id must return 200, echo the
/// session id header, and deliver an init event within two seconds whose data
/// has `connected: true`, the session id, and event id `"0"`.
#[test]
fn mcp_get_with_valid_session_id_subscribes() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router.clone()).await;

        let (status, mut body, headers) =
            open_mcp_get_stream(router, &session_id, None).await;
        assert_eq!(status, StatusCode::OK);

        let echoed = headers
            .get(crate::mcp_session::MCP_SESSION_ID_HEADER)
            .and_then(|v| v.to_str().ok())
            .unwrap();
        assert_eq!(echoed, session_id);

        let wait = std::time::Duration::from_secs(2);
        let init = read_one_sse_event(&mut body, wait)
            .await
            .expect("init event must arrive within 2s");
        assert_eq!(init.event, crate::mcp_session::MCP_STREAM_EVENT_INIT_NAME);
        assert_eq!(init.data["connected"].as_bool(), Some(true));
        assert_eq!(init.data["session_id"].as_str(), Some(session_id.as_str()));
        assert_eq!(init.id.as_deref(), Some("0"));
    });
    h.shutdown(&runtime);
}
/// Publishes five message events (`n` = 1..=5) on a session, opens the GET
/// stream with last-event-id `"2"`, and expects the init event followed by
/// message events with ids 3, 4 and 5 whose `n` matches their id.
#[test]
fn mcp_get_resumes_from_last_event_id() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router.clone()).await;
        let state = session_state_for_test(&store, &session_id);

        // Publish before the stream is opened.
        (1..=5).for_each(|i| {
            state.publish_event(
                crate::mcp_session::McpEventKind::Message,
                json!({"n": i}),
            );
        });

        let (status, mut body, _) =
            open_mcp_get_stream(router, &session_id, Some("2")).await;
        assert_eq!(status, StatusCode::OK);

        let wait = std::time::Duration::from_secs(2);
        let init = read_one_sse_event(&mut body, wait).await.unwrap();
        assert_eq!(init.event, crate::mcp_session::MCP_STREAM_EVENT_INIT_NAME);

        // Everything after id 2 must be replayed, in order.
        for expected_id in 3u64..=5u64 {
            let ev = read_one_sse_event(&mut body, wait)
                .await
                .expect("replay event must arrive within 2s");
            assert_eq!(
                ev.event,
                crate::mcp_session::MCP_STREAM_EVENT_MESSAGE_NAME,
                "expected replay of message event id {expected_id}, got {ev:?}",
            );
            let expected_id_text = expected_id.to_string();
            assert_eq!(ev.id.as_deref(), Some(expected_id_text.as_str()));
            assert_eq!(ev.data["n"].as_u64(), Some(expected_id));
        }
    });
    h.shutdown(&runtime);
}
/// Publishes 300 message events, then opens the GET stream with last-event-id
/// `"1"`. After the init event the stream must emit a lagged event with id
/// `"0"` and a positive `dropped` count.
// NOTE(review): 300 is presumably larger than the session's replay buffer —
// confirm against the buffer size in `mcp_session`.
#[test]
fn mcp_get_emits_lagged_when_last_event_id_too_old() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router.clone()).await;
        let state = session_state_for_test(&store, &session_id);

        let mut published = 0;
        while published < 300 {
            state.publish_event(
                crate::mcp_session::McpEventKind::Message,
                json!({}),
            );
            published += 1;
        }

        let (status, mut body, _) =
            open_mcp_get_stream(router, &session_id, Some("1")).await;
        assert_eq!(status, StatusCode::OK);

        let wait = std::time::Duration::from_secs(2);
        let init = read_one_sse_event(&mut body, wait).await.unwrap();
        assert_eq!(init.event, crate::mcp_session::MCP_STREAM_EVENT_INIT_NAME);

        let lagged = read_one_sse_event(&mut body, wait)
            .await
            .expect("lagged event must arrive within 2s");
        assert_eq!(
            lagged.event,
            crate::mcp_session::MCP_STREAM_EVENT_LAGGED_NAME,
            "expected `event: lagged` after Last-Event-ID before buffer",
        );
        assert_eq!(lagged.id.as_deref(), Some("0"));

        // A missing or non-numeric `dropped` is treated as zero and fails.
        let dropped = lagged.data["dropped"].as_u64().unwrap_or(0);
        assert!(
            dropped > 0,
            "lagged event must carry a non-zero `dropped` count: {:?}",
            lagged.data,
        );
    });
    h.shutdown(&runtime);
}
/// CORS preflight (`OPTIONS /mcp`, requested method `GET`) must succeed with a
/// 2xx status and list both `last-event-id` and `mcp-session-id` in
/// `access-control-allow-headers`.
#[test]
fn cors_preflight_allows_last_event_id_header() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let r = h.router.clone();
    runtime.block_on(async move {
        let req = Request::builder()
            .method("OPTIONS")
            .uri("/mcp")
            .header("origin", "http://localhost:5173")
            .header("access-control-request-method", "GET")
            .header(
                "access-control-request-headers",
                "last-event-id,mcp-session-id",
            )
            .body(Body::empty())
            .unwrap();
        let resp = r.oneshot(req).await.unwrap();

        // `is_success()` already covers every 2xx status, including
        // 204 No Content, so no separate NO_CONTENT comparison is needed.
        assert!(
            resp.status().is_success(),
            "preflight must succeed, got: {}",
            resp.status(),
        );

        // Lowercased so the substring checks are case-insensitive; a missing
        // or non-string header yields an empty string and fails the asserts.
        let allow = resp
            .headers()
            .get("access-control-allow-headers")
            .and_then(|h| h.to_str().ok())
            .map(|s| s.to_ascii_lowercase())
            .unwrap_or_default();
        assert!(
            allow.contains("last-event-id"),
            "preflight must allow `last-event-id`; allow-headers = {allow:?}",
        );
        assert!(
            allow.contains("mcp-session-id"),
            "preflight must allow `mcp-session-id` too; allow-headers = {allow:?}",
        );
    });
    h.shutdown(&runtime);
}
/// Drives `build_mcp_session_stream` directly (no HTTP) and checks that a
/// second item follows the first one: the first within 2s, the next within 3s.
/// Neither item's content is inspected.
// NOTE(review): the trailing `0, 1` arguments are presumably the resume
// position and a heartbeat interval in seconds — confirm against
// `build_mcp_session_stream`'s signature.
#[test]
fn mcp_get_heartbeats_after_init() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    runtime.block_on(async move {
        use futures::StreamExt;

        let state = std::sync::Arc::new(crate::mcp_session::SessionState::new(
            solo_core::TenantId::default_tenant(),
            None,
        ));
        let session_id = crate::mcp_session::SessionId::new();
        let events = build_mcp_session_stream(
            state,
            session_id.clone(),
            "default".to_string(),
            0,
            1,
        );
        // Pinned on the stack so `next()` can be called on it.
        let mut events = std::pin::pin!(events);

        let init_deadline = std::time::Duration::from_secs(2);
        let first = tokio::time::timeout(init_deadline, events.next())
            .await
            .expect("init must arrive within 2s")
            .expect("stream must yield init");
        drop(first);

        let heartbeat_deadline = std::time::Duration::from_secs(3);
        let second = tokio::time::timeout(heartbeat_deadline, events.next())
            .await
            .expect("heartbeat must arrive within ~3s")
            .expect("stream must yield heartbeat");
        drop(second);
    });
    h.shutdown(&runtime);
}
/// Calls the `memory_ingest_document` tool with a progress token on an
/// existing session and inspects the events published to that session.
///
/// Expects at least two events, the first two being progress 1 (`parsed`) and
/// progress 2 (`chunked`); every collected event must be a progress
/// notification carrying the token and `total == 4`.
#[test]
fn mcp_http_ingest_document_emits_parsed_and_chunked_progress_events() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router.clone()).await;
        let state = session_state_for_test(&store, &session_id);
        // Subscribe before the tool call so its events are captured.
        let mut rx = state.subscribe_events();

        // Small text file for the tool to ingest; `tmpdir` stays bound until
        // the end of the block so the file exists during the request.
        let tmpdir = tempfile::TempDir::new().unwrap();
        let tmpfile = tmpdir.path().join("ingest-progress.txt");
        std::fs::write(&tmpfile, b"hello world progress test").unwrap();

        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {
                "name": "memory_ingest_document",
                "arguments": { "path": tmpfile.to_string_lossy() },
                "_meta": { "progressToken": "ingest-tok" },
            },
        });
        let req = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            .header(crate::mcp_session::MCP_SESSION_ID_HEADER, &session_id)
            .body(Body::from(serde_json::to_vec(&rpc).unwrap()))
            .unwrap();
        let resp = router.clone().oneshot(req).await.expect("oneshot");
        assert_eq!(resp.status(), StatusCode::OK);
        // Read the response body to completion before looking at events.
        let _ = resp.into_body().collect().await.unwrap().to_bytes();

        // Drain whatever is already queued; stop at the first receive error.
        let mut events = Vec::new();
        loop {
            match rx.try_recv() {
                Ok(ev) => events.push(ev),
                Err(_) => break,
            }
        }
        assert!(
            events.len() >= 2,
            "expected at least 2 progress events (parsed + chunked), got {}: {events:?}",
            events.len()
        );

        // The first two events, in order.
        for (ev, (progress, message)) in events.iter().zip([(1, "parsed"), (2, "chunked")]) {
            assert_eq!(ev.data["params"]["progress"], json!(progress));
            assert_eq!(ev.data["params"]["message"], json!(message));
        }

        // Properties shared by every collected event.
        for ev in &events {
            assert_eq!(
                ev.event,
                crate::mcp_session::McpEventKind::Progress,
            );
            assert_eq!(
                ev.data["method"],
                json!(crate::mcp_progress::MCP_NOTIFICATION_PROGRESS_METHOD)
            );
            assert_eq!(ev.data["params"]["progressToken"], json!("ingest-tok"));
            assert_eq!(ev.data["params"]["total"], json!(4));
        }
    });
    h.shutdown(&runtime);
}
/// End-to-end progress delivery: after a seed message event is published, a
/// `memory_search_docs` tool call with a progress token is made, and then the
/// GET stream opened with last-event-id `"1"` must yield the init event
/// followed by three progress events (progress 1..=3, total 3) carrying the
/// token.
#[test]
fn mcp_http_progress_event_subscribers_receive_via_get_mcp_stream() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router.clone()).await;
        let state = session_state_for_test(&store, &session_id);

        // Seed event published ahead of the tool call.
        // NOTE(review): presumably this takes event id 1 so that resuming
        // from "1" skips it — confirm against the session's id assignment.
        state.publish_event(
            crate::mcp_session::McpEventKind::Message,
            json!({"seed": true}),
        );

        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {
                "name": "memory_search_docs",
                "arguments": { "query": "anything", "limit": 150 },
                "_meta": { "progressToken": "progress-roundtrip" },
            },
        });
        let post = Request::builder()
            .method("POST")
            .uri("/mcp")
            .header("content-type", "application/json")
            .header(crate::mcp_session::MCP_SESSION_ID_HEADER, &session_id)
            .body(Body::from(serde_json::to_vec(&rpc).unwrap()))
            .unwrap();
        let post_resp = router.clone().oneshot(post).await.expect("oneshot");
        assert_eq!(post_resp.status(), StatusCode::OK);
        // Read the response body to completion before opening the stream.
        let _ = post_resp.into_body().collect().await.unwrap().to_bytes();

        let (status, mut stream_body, _) =
            open_mcp_get_stream(router, &session_id, Some("1")).await;
        assert_eq!(status, StatusCode::OK);

        let wait = std::time::Duration::from_secs(2);
        let init = read_one_sse_event(&mut stream_body, wait)
            .await
            .expect("init must arrive within 2s");
        assert_eq!(init.event, crate::mcp_session::MCP_STREAM_EVENT_INIT_NAME);

        for expected_progress in 1u64..=3u64 {
            let ev = read_one_sse_event(&mut stream_body, wait)
                .await
                .expect("progress event must arrive within 2s");
            assert_eq!(
                ev.event,
                crate::mcp_session::MCP_STREAM_EVENT_PROGRESS_NAME,
                "expected progress event #{expected_progress}, got {ev:?}",
            );
            assert_eq!(ev.data["jsonrpc"], json!("2.0"));
            assert_eq!(
                ev.data["method"],
                json!(crate::mcp_progress::MCP_NOTIFICATION_PROGRESS_METHOD)
            );

            let params = &ev.data["params"];
            assert_eq!(params["progressToken"], json!("progress-roundtrip"));
            assert_eq!(params["progress"], json!(expected_progress));
            assert_eq!(params["total"], json!(3));
        }
    });
    h.shutdown(&runtime);
}
/// The `initialize` response must report `serverInfo.name == "solo"` and echo
/// protocol version `2024-11-05`.
#[test]
fn mcp_http_initialize_returns_solo_server_info() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    runtime.block_on(async move {
        let initialize = json!({
            "jsonrpc": "2.0",
            "id": 7,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": { "name": "solo-http-test", "version": "0.0.0" },
            },
        });
        let (status, body) = call(router, "POST", "/mcp", Some(initialize)).await;
        assert_eq!(status, StatusCode::OK);

        let server_name = body
            .pointer("/result/serverInfo/name")
            .and_then(|v| v.as_str());
        assert_eq!(
            server_name,
            Some("solo"),
            "serverInfo.name must be `solo`, not `solo-api` or `rmcp`; got: {body}"
        );

        let protocol_version = body
            .pointer("/result/protocolVersion")
            .and_then(|v| v.as_str());
        assert_eq!(protocol_version, Some("2024-11-05"),);
    });
    h.shutdown(&runtime);
}
/// After a session is allocated, an `InvalidateEvent` sent on the harness's
/// invalidate sender must reach that session's event receiver within two
/// seconds as a `Message` event whose `method` is
/// `MCP_NOTIFICATION_MESSAGE_METHOD`.
#[test]
fn session_subscribes_to_tenant_invalidate_on_creation() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    let sender = h.invalidate_sender();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router).await;
        let state = session_state_for_test(&store, &session_id);
        // Subscribe first so the forwarded event is not missed.
        let mut rx = state.subscribe_events();

        let invalidate = InvalidateEvent {
            reason: "memory.remember".to_string(),
            tenant_id: "default".to_string(),
            ts_ms: 1_715_625_600_000,
            kind: "episode".to_string(),
        };
        sender
            .send(invalidate)
            .expect("at least one subscriber (the bridge)");

        let deadline = std::time::Duration::from_secs(2);
        let received = tokio::time::timeout(deadline, rx.recv())
            .await
            .expect("bridge must forward invalidate within 2s")
            .expect("session receiver must observe published event");

        assert_eq!(received.event, crate::mcp_session::McpEventKind::Message);
        let method = received.data["method"].as_str();
        assert_eq!(
            method,
            Some(crate::mcp_notify::MCP_NOTIFICATION_MESSAGE_METHOD),
        );
    });
    h.shutdown(&runtime);
}
/// Checks the full shape of the notification a session receives for a
/// `document` invalidate event: JSON-RPC envelope, level, logger, the
/// documents-updated data tag, and `details` mirroring the event's `reason`,
/// `kind` and `ts_ms`.
#[test]
fn invalidate_event_translates_to_mcp_notifications_message() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    let sender = h.invalidate_sender();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router).await;
        let state = session_state_for_test(&store, &session_id);
        let mut rx = state.subscribe_events();

        let invalidate = InvalidateEvent {
            reason: "memory.ingest_document".to_string(),
            tenant_id: "default".to_string(),
            ts_ms: 1_715_625_999_999,
            kind: "document".to_string(),
        };
        sender.send(invalidate).expect("at least one subscriber");

        let deadline = std::time::Duration::from_secs(2);
        let received = tokio::time::timeout(deadline, rx.recv())
            .await
            .expect("forward within 2s")
            .expect("session must receive event");

        // Envelope.
        let payload = &received.data;
        assert_eq!(payload["jsonrpc"].as_str(), Some("2.0"));
        assert_eq!(
            payload["method"].as_str(),
            Some(crate::mcp_notify::MCP_NOTIFICATION_MESSAGE_METHOD),
        );

        // Notification parameters.
        let params = &payload["params"];
        assert_eq!(
            params["level"].as_str(),
            Some(crate::mcp_notify::MCP_NOTIFICATION_MESSAGE_LEVEL),
        );
        assert_eq!(
            params["logger"].as_str(),
            Some(crate::mcp_notify::MCP_NOTIFICATION_MESSAGE_LOGGER),
        );
        assert_eq!(
            params["data"].as_str(),
            Some(crate::mcp_notify::MCP_NOTIFICATION_DATA_DOCUMENTS_UPDATED),
        );

        // Details copied from the originating event.
        let details = &params["details"];
        assert_eq!(details["reason"].as_str(), Some("memory.ingest_document"),);
        assert_eq!(details["kind"].as_str(), Some("document"),);
        assert_eq!(details["ts_ms"].as_i64(), Some(1_715_625_999_999),);
    });
    h.shutdown(&runtime);
}
/// Allocates two distinct sessions through the same router and sends one
/// `cluster` invalidate event for tenant `default`; both sessions must
/// receive a `Message` event tagged with the consolidation-updated data
/// value within two seconds.
///
/// As written, this test only asserts delivery to both sessions; it does not
/// include a session that is expected NOT to receive the event.
#[test]
fn invalidate_event_published_to_correct_session_only() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let store = h.mcp_sessions.clone();
    let sender = h.invalidate_sender();
    runtime.block_on(async move {
        let session_id_a = allocate_mcp_session(router.clone()).await;
        let session_id_b = allocate_mcp_session(router).await;
        assert_ne!(session_id_a, session_id_b);

        let state_a = session_state_for_test(&store, &session_id_a);
        let state_b = session_state_for_test(&store, &session_id_b);
        let mut rx_a = state_a.subscribe_events();
        let mut rx_b = state_b.subscribe_events();

        let invalidate = InvalidateEvent {
            reason: "memory.consolidate".to_string(),
            tenant_id: "default".to_string(),
            ts_ms: 1_715_625_600_000,
            kind: "cluster".to_string(),
        };
        sender.send(invalidate).expect("at least one subscriber");

        let deadline = std::time::Duration::from_secs(2);
        let event_a = tokio::time::timeout(deadline, rx_a.recv())
            .await
            .expect("session A receives within 2s")
            .expect("session A receiver alive");
        let event_b = tokio::time::timeout(deadline, rx_b.recv())
            .await
            .expect("session B receives within 2s")
            .expect("session B receiver alive");

        for evt in [&event_a, &event_b] {
            assert_eq!(evt.event, crate::mcp_session::McpEventKind::Message);
            let data_tag = evt.data["params"]["data"].as_str();
            assert_eq!(
                data_tag,
                Some(crate::mcp_notify::MCP_NOTIFICATION_DATA_CONSOLIDATION_UPDATED),
            );
        }
    });
    h.shutdown(&runtime);
}
/// With a GET stream already open on a session, sending a `triple`
/// invalidate event must produce an SSE message event on that stream within
/// two seconds, carrying the graph-updated data tag and the original reason.
#[test]
fn mcp_get_subscriber_receives_notifications_message_event() {
    let runtime = rt();
    let h = Harness::new(&runtime);
    let router = h.router.clone();
    let sender = h.invalidate_sender();
    runtime.block_on(async move {
        let session_id = allocate_mcp_session(router.clone()).await;
        let (status, mut body, _) =
            open_mcp_get_stream(router, &session_id, None).await;
        assert_eq!(status, StatusCode::OK);

        let wait = std::time::Duration::from_secs(2);

        // Consume the init event before triggering the invalidate.
        let init = read_one_sse_event(&mut body, wait)
            .await
            .expect("init event must arrive within 2s");
        assert_eq!(
            init.event,
            crate::mcp_session::MCP_STREAM_EVENT_INIT_NAME,
        );

        let invalidate = InvalidateEvent {
            reason: "memory.triples_extract".to_string(),
            tenant_id: "default".to_string(),
            ts_ms: 1_715_625_600_000,
            kind: "triple".to_string(),
        };
        sender.send(invalidate).expect("send must succeed");

        let ev = read_one_sse_event(&mut body, wait)
            .await
            .expect("message event must arrive within 2s");
        assert_eq!(
            ev.event,
            crate::mcp_session::MCP_STREAM_EVENT_MESSAGE_NAME,
        );

        let payload = &ev.data;
        assert_eq!(payload["jsonrpc"].as_str(), Some("2.0"));
        assert_eq!(
            payload["method"].as_str(),
            Some(crate::mcp_notify::MCP_NOTIFICATION_MESSAGE_METHOD),
        );
        assert_eq!(
            payload["params"]["data"].as_str(),
            Some(crate::mcp_notify::MCP_NOTIFICATION_DATA_GRAPH_UPDATED),
        );
        assert_eq!(
            payload["params"]["details"]["reason"].as_str(),
            Some("memory.triples_extract"),
        );
    });
    h.shutdown(&runtime);
}
}
/// Unit tests for `is_localhost_origin`: which origin strings it accepts and
/// which it rejects. Each test iterates a table of inputs and reports the
/// offending origin on failure.
#[cfg(test)]
mod cors_tests {
    use super::is_localhost_origin;

    /// http/https origins for `localhost`, `127.0.0.1` and `[::1]`, with and
    /// without a port, are accepted.
    #[test]
    fn accepts_canonical_localhost_origins() {
        let accepted = [
            "http://localhost",
            "http://localhost:3000",
            "https://localhost:8443",
            "http://127.0.0.1",
            "http://127.0.0.1:5173",
            "http://[::1]",
            "http://[::1]:8080",
        ];
        for origin in accepted {
            assert!(is_localhost_origin(origin), "should accept {origin:?}");
        }
    }

    /// Public hostnames and private-network IPv4 addresses are rejected.
    #[test]
    fn rejects_remote_origins() {
        let rejected = [
            "http://example.com",
            "https://malicious.example",
            "http://192.168.1.5",
            "http://10.0.0.1",
        ];
        for origin in rejected {
            assert!(!is_localhost_origin(origin), "should reject {origin:?}");
        }
    }

    /// Hostnames that merely contain a loopback name or address as a prefix
    /// or suffix label are rejected.
    #[test]
    fn rejects_dns_rebinding_tricks() {
        let rejected = [
            "http://127.0.0.1.nip.io",
            "http://localhost.evil.com",
            "http://evil.localhost",
        ];
        for origin in rejected {
            assert!(!is_localhost_origin(origin), "should reject {origin:?}");
        }
    }

    /// Schemes other than http/https are rejected even for `localhost`.
    #[test]
    fn rejects_non_http_schemes() {
        let rejected = ["file:///", "ws://localhost:3000", "javascript:alert(1)"];
        for origin in rejected {
            assert!(!is_localhost_origin(origin), "should reject {origin:?}");
        }
    }

    /// Empty, scheme-less and protocol-relative strings are rejected.
    #[test]
    fn rejects_malformed() {
        let rejected = ["", "localhost", "//localhost"];
        for origin in rejected {
            assert!(!is_localhost_origin(origin), "should reject {origin:?}");
        }
    }
}