use axum::{
Json,
extract::{FromRef, FromRequest, Request, State, rejection::JsonRejection},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use serde::de::DeserializeOwned;
use serde_json::json;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use crate::config::{ResolvedTtl, TierConfig};
use crate::db;
use crate::embeddings::{Embed, Embedder};
use crate::hnsw::VectorIndex;
use crate::profile::Family;
/// Shared handle to the SQLite connection bundle.
///
/// The tuple carries the `rusqlite::Connection`, a `PathBuf` that is handed to
/// `db::stats` together with the connection (presumably the database file
/// path — verify), the resolved TTL settings, and a `bool` whose meaning is
/// not visible from this module. It lives behind a `tokio::sync::Mutex` so
/// that [`db_op`] can lock it from a blocking worker via `blocking_lock`.
pub type Db = Arc<Mutex<(rusqlite::Connection, std::path::PathBuf, ResolvedTtl, bool)>>;

/// Runs `op` against the locked [`Db`] tuple on tokio's blocking thread pool
/// and returns whatever `op` returns.
///
/// rusqlite calls are synchronous, so they are executed through
/// `spawn_blocking` instead of on an async worker thread. The mutex guard is
/// held for exactly the duration of `op` and released when the job finishes.
///
/// # Panics
///
/// Panics if the blocking task does not complete normally — i.e. `op`
/// panicked, or the runtime shut down before the task finished.
pub async fn db_op<T, F>(db: Db, op: F) -> T
where
    T: Send + 'static,
    F: FnOnce(&mut (rusqlite::Connection, std::path::PathBuf, ResolvedTtl, bool)) -> T
        + Send
        + 'static,
{
    let job = move || {
        let mut locked = db.blocking_lock();
        op(&mut locked)
    };
    tokio::task::spawn_blocking(job)
        .await
        .expect("PERF-1: db_op spawn_blocking worker panicked or runtime shut down")
}
/// Which persistence engine backs the running server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend {
    Sqlite,
    Postgres,
}

impl StorageBackend {
    /// Stable lowercase identifier for this backend (`"sqlite"` / `"postgres"`).
    #[must_use]
    pub fn as_str(self) -> &'static str {
        let label = match self {
            StorageBackend::Postgres => "postgres",
            StorageBackend::Sqlite => "sqlite",
        };
        label
    }
}
/// Shared application state for the axum handlers (extracted as
/// `State<AppState>`, e.g. by `health`).
///
/// Cloning copies the plain fields (`bool`, `usize`, `Duration`,
/// `StorageBackend`) and bumps reference counts on the `Arc` fields; no
/// underlying resource is duplicated.
#[derive(Clone)]
pub struct AppState {
    /// SQLite connection bundle. Handlers can also extract it on its own as
    /// `State<Db>` through the `FromRef<AppState> for Db` impl.
    pub db: Db,
    /// Optional embedding model. `health` reports `embedder_ready` from its
    /// presence, and `best_family_match` needs it to embed the intent text.
    pub embedder: Arc<Option<Embedder>>,
    /// Optional vector index (`crate::hnsw::VectorIndex`); not read in this file.
    pub vector_index: Arc<Mutex<Option<VectorIndex>>>,
    /// Federation settings; `health` reports `federation_enabled` when `Some`.
    pub federation: Arc<Option<crate::federation::FederationConfig>>,
    pub tier_config: Arc<TierConfig>,
    pub scoring: Arc<crate::config::ResolvedScoring>,
    pub profile: Arc<crate::profile::Profile>,
    pub mcp_config: Arc<Option<crate::config::McpConfig>>,
    pub active_keypair: Arc<Option<crate::identity::keypair::AgentKeypair>>,
    /// Cached `(Family, descriptor embedding)` pairs read by
    /// `best_family_match`; `None` or an empty vector disables family matching.
    /// NOTE(review): presumably populated from `precompute_family_embeddings`
    /// — confirm at the construction site.
    pub family_embeddings: Arc<RwLock<Option<Vec<(Family, Vec<f32>)>>>>,
    /// Active storage engine. `health` consults it (only when built with
    /// `sal-postgres`) to choose between probing `store` and probing SQLite.
    pub storage_backend: StorageBackend,
    /// Storage-abstraction handle; `health` calls `health_check` on it when
    /// the Postgres backend is active.
    #[cfg(feature = "sal")]
    pub store: Arc<dyn crate::store::MemoryStore>,
    pub llm: Arc<Option<crate::llm::OllamaClient>>,
    pub auto_tag_model: Arc<Option<String>>,
    pub llm_call_timeout: std::time::Duration,
    pub replay_cache: Arc<crate::identity::replay::ReplayCache>,
    pub verify_require_nonce: bool,
    pub federation_nonce_cache: Arc<crate::identity::replay::FederationNonceCache>,
    pub autonomous_hooks: bool,
    pub recall_scope: Arc<Option<crate::config::RecallScope>>,
    pub deferred_audit_queue: Arc<Option<crate::governance::deferred_audit::DeferredAuditQueue>>,
    pub admin_agent_ids: Arc<Vec<String>>,
    pub rule_cache: Arc<crate::governance::rule_cache::RuleCache>,
    pub resolved_models: Arc<crate::config::ResolvedModels>,
    pub runtime: Arc<crate::runtime_context::RuntimeContext>,
    pub max_page_size: usize,
}
/// Returns the natural-language descriptor for every tool [`Family`].
///
/// These are the strings `AppState::precompute_family_embeddings` embeds, so
/// their wording determines how intents are matched to families. The slice
/// has one entry per family (eight in total).
#[must_use]
pub fn family_descriptors() -> &'static [(Family, &'static str)] {
    const DESCRIPTORS: [(Family, &str); 8] = [
        (
            Family::Core,
            "Store, recall, list, get, and search memories. The basic \
             read and write operations for saving facts and looking \
             them up later.",
        ),
        (
            Family::Lifecycle,
            "Update, delete, forget, garbage-collect, and promote \
             memories. Operations that change a memory's state, tier, \
             or visibility over time.",
        ),
        (
            Family::Graph,
            "Knowledge-graph queries, timelines, links between \
             memories, entity registration, taxonomy lookup, and \
             replay or verification of stored relationships.",
        ),
        (
            Family::Governance,
            "Approval workflows, namespace standards, and \
             subscriptions. Operations that gate or shape what other \
             agents are allowed to write or see.",
        ),
        (
            Family::Power,
            "Advanced reasoning helpers: consolidate duplicates, \
             detect contradictions, check for duplicates, auto-tag, \
             expand a query, and inspect the inbox.",
        ),
        (
            Family::Meta,
            "Server capabilities, agent registration and listing, \
             session bootstrap, and aggregate stats. Operations that \
             describe the memory system itself rather than its \
             contents.",
        ),
        (
            Family::Archive,
            "List, restore, purge, and report stats on archived \
             memories. The cold-storage tier where forgotten or aged-out \
             memories live until they are pruned.",
        ),
        (
            Family::Other,
            "Subscription listing and out-of-band notifications. \
             Auxiliary operations that don't fit the other families.",
        ),
    ];
    &DESCRIPTORS
}
impl AppState {
    /// Embeds every [`family_descriptors`] entry with `embedder`.
    ///
    /// Returns an empty vector when `embedder` is `None`. The result is
    /// all-or-nothing: if any single descriptor fails to embed, a warning is
    /// logged and an empty vector is returned, so callers never see a partial
    /// family table.
    #[must_use]
    pub fn precompute_family_embeddings(embedder: Option<&dyn Embed>) -> Vec<(Family, Vec<f32>)> {
        let Some(embedder) = embedder else {
            return Vec::new();
        };
        let descriptors = family_descriptors();
        let mut out: Vec<(Family, Vec<f32>)> = Vec::with_capacity(descriptors.len());
        for (family, descriptor) in descriptors {
            match embedder.embed(descriptor) {
                Ok(v) => out.push((*family, v)),
                Err(e) => {
                    tracing::warn!(
                        family = family.name(),
                        error = %e,
                        "B3: failed to embed family descriptor; \
                         family_embeddings will be empty",
                    );
                    return Vec::new();
                }
            }
        }
        out
    }

    /// Finds the family whose cached descriptor embedding is most
    /// cosine-similar to `intent`, returning it together with its score.
    ///
    /// Returns `None` when:
    /// - the cache lock is currently write-held (the lock is taken with the
    ///   non-blocking `try_read`);
    /// - the cache is unset or empty;
    /// - no embedder is configured;
    /// - embedding `intent` fails;
    /// - every similarity is NaN.
    ///
    /// On ties the earliest family in cache order wins.
    #[must_use]
    pub fn best_family_match(&self, intent: &str) -> Option<(Family, f32)> {
        let guard = self.family_embeddings.try_read().ok()?;
        let cache = guard.as_ref()?;
        if cache.is_empty() {
            return None;
        }
        let embedder = self.embedder.as_ref().as_ref()?;
        let intent_vec = embedder.embed_query(intent).ok()?;
        let mut best: Option<(Family, f32)> = None;
        for (family, descriptor_vec) in cache.iter() {
            let score = Embedder::cosine_similarity(&intent_vec, descriptor_vec);
            // `prev >= NaN` is always false, so without this guard a NaN
            // similarity (e.g. from a degenerate vector) would displace a
            // genuine best match.
            if score.is_nan() {
                continue;
            }
            match best {
                Some((_, prev)) if prev >= score => {}
                _ => best = Some((*family, score)),
            }
        }
        best
    }
}
impl FromRef<AppState> for Db {
fn from_ref(app: &AppState) -> Self {
app.db.clone()
}
}
/// Maximum number of items accepted in one bulk request.
/// NOTE(review): not referenced in this file — confirm enforcement at the bulk endpoints.
pub const MAX_BULK_SIZE: usize = 1_000;
/// Drop-in replacement for the [`Json`] extractor that converts every
/// deserialization rejection into a `400 Bad Request` JSON body produced by
/// `json_rejection_to_400`.
pub struct JsonOrBadRequest<T>(pub T);

impl<S, T> FromRequest<S> for JsonOrBadRequest<T>
where
    S: Send + Sync,
    T: DeserializeOwned,
    Json<T>: FromRequest<S, Rejection = JsonRejection>,
{
    type Rejection = Response;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let parsed = Json::<T>::from_request(req, state).await;
        parsed
            .map(|Json(inner)| Self(inner))
            .map_err(|rejection| json_rejection_to_400(&rejection))
    }
}
/// Maps an axum [`JsonRejection`] to a `400 Bad Request` JSON response of the
/// form `{"error": <message>, "fields": [<missing field names>]}`.
///
/// When the rejection text names missing fields (see
/// `extract_missing_fields`), the first one is used in the message.
/// Otherwise the message depends on the rejection kind.
fn json_rejection_to_400(rej: &JsonRejection) -> Response {
    let fields = extract_missing_fields(&rej.body_text());
    let error_msg = match (fields.first(), rej) {
        (Some(first), _) => format!("missing required field: {first}"),
        (None, JsonRejection::JsonSyntaxError(_)) => "malformed JSON body".to_string(),
        (None, JsonRejection::MissingJsonContentType(_)) => {
            "expected Content-Type: application/json".to_string()
        }
        (None, _) => "invalid request body".to_string(),
    };
    let body = json!({
        "error": error_msg,
        "fields": fields,
    });
    (StatusCode::BAD_REQUEST, Json(body)).into_response()
}
/// Pulls field names out of serde "missing field `name`" messages.
///
/// Only names made of ASCII alphanumerics, `_` or `-` are kept. This keeps
/// arbitrary text from being echoed back to the client. Duplicates are
/// dropped, first-seen order is preserved, and at most 16 names are
/// returned. An unterminated backtick stops the scan.
fn extract_missing_fields(msg: &str) -> Vec<String> {
    const MAX_MISSING_FIELDS: usize = 16;
    const NEEDLE: &str = "missing field `";
    let mut found: Vec<String> = Vec::new();
    let mut cursor = msg;
    while found.len() < MAX_MISSING_FIELDS {
        let tail = match cursor.find(NEEDLE) {
            Some(pos) => &cursor[pos + NEEDLE.len()..],
            None => break,
        };
        let (name, remainder) = match tail.split_once('`') {
            Some(parts) => parts,
            None => break,
        };
        let looks_like_ident = !name.is_empty()
            && name
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-');
        if looks_like_ident && !found.iter().any(|seen| seen == name) {
            found.push(name.to_owned());
        }
        cursor = remainder;
    }
    found
}
/// Concurrency limit for bulk fan-out work.
/// NOTE(review): not referenced in this file — confirm the consumer.
pub(crate) const BULK_FANOUT_CONCURRENCY: usize = 8;

/// Configuration consumed by the `api_key_auth` middleware.
#[derive(Default, Clone)]
pub struct ApiKeyState {
    /// Expected API key; `None` disables API-key authentication entirely.
    pub key: Option<String>,
    /// When `true`, `/api/v1/sync/` paths bypass the API-key check.
    pub mtls_enforced: bool,
}
/// Decodes `%XX` escapes in `input`, replacing invalid UTF-8 with U+FFFD.
///
/// A `%` not followed by two hex digits is copied through verbatim. `+` is
/// left as-is; it is not treated as a form-encoded space.
#[inline]
pub(crate) fn percent_decode_lossy(input: &str) -> String {
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut remaining = input.as_bytes();
    while let Some((&byte, tail)) = remaining.split_first() {
        if byte == b'%' {
            if let [hi, lo, rest @ ..] = tail {
                let hi = (*hi as char).to_digit(16);
                let lo = (*lo as char).to_digit(16);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    decoded.push(u8::try_from((hi << 4) | lo).unwrap_or(0));
                    remaining = rest;
                    continue;
                }
            }
        }
        decoded.push(byte);
        remaining = tail;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Compares two byte strings without stopping at the first differing byte.
///
/// The shorter input is treated as zero-padded, and a length mismatch seeds
/// the accumulator with a non-zero bit. Unequal lengths therefore never
/// compare equal, yet every position of the longer input is still visited.
#[inline]
pub(crate) fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let seed = u8::from(a.len() != b.len());
    let span = a.len().max(b.len());
    let pick = |bytes: &[u8], idx: usize| bytes.get(idx).copied().unwrap_or(0);
    let mismatch = (0..span).fold(seed, |acc, idx| acc | (pick(a, idx) ^ pick(b, idx)));
    mismatch == 0
}
pub async fn api_key_auth(
State(auth): State<ApiKeyState>,
req: Request,
next: Next,
) -> impl IntoResponse {
let Some(ref expected) = auth.key else {
return next.run(req).await.into_response();
};
if req.uri().path() == super::routes::HEALTH {
return next.run(req).await.into_response();
}
let path = req.uri().path();
if auth.mtls_enforced && path.starts_with("/api/v1/sync/") {
return next.run(req).await.into_response();
}
if let Some(header_val) = req.headers().get(crate::HEADER_API_KEY)
&& let Ok(val) = header_val.to_str()
&& constant_time_eq(val.as_bytes(), expected.as_bytes())
{
return next.run(req).await.into_response();
}
if let Some(query) = req.uri().query() {
for pair in query.split('&') {
if let Some(val) = pair.strip_prefix("api_key=") {
let decoded = percent_decode_lossy(val);
if constant_time_eq(decoded.as_bytes(), expected.as_bytes()) {
static QUERY_KEY_WARN_ONCE: std::sync::Once = std::sync::Once::new();
QUERY_KEY_WARN_ONCE.call_once(|| {
tracing::warn!(
target: "http::auth",
"a request authenticated via the `?api_key=` query \
parameter; URL-embedded credentials leak into access \
logs, Referer headers, and proxy logs. Migrate callers \
to the `x-api-key` request header — the `?api_key=` \
query form is DEPRECATED and will be removed in a \
future release (still accepted for the v0.7.0 \
back-compat contract)."
);
});
return next.run(req).await.into_response();
}
}
}
}
(
StatusCode::UNAUTHORIZED,
Json(json!({"error": "missing or invalid API key"})),
)
.into_response()
}
pub async fn health(State(app): State<AppState>) -> impl IntoResponse {
#[cfg(feature = "sal-postgres")]
let ok = if matches!(app.storage_backend, StorageBackend::Postgres) {
app.store.health_check().await.unwrap_or(false)
} else {
db_op(app.db.clone(), |guard| {
db::health_check(&guard.0).unwrap_or(false)
})
.await
};
#[cfg(not(feature = "sal-postgres"))]
let ok = db_op(app.db.clone(), |guard| {
db::health_check(&guard.0).unwrap_or(false)
})
.await;
let embedder_ready = app.embedder.as_ref().is_some();
let federation_enabled = app.federation.as_ref().is_some();
let code = if ok {
StatusCode::OK
} else {
StatusCode::SERVICE_UNAVAILABLE
};
(
code,
Json(json!({
"status": if ok { "ok" } else { "error" },
"service": "ai-memory",
"version": crate::PKG_VERSION,
"embedder_ready": embedder_ready,
"federation_enabled": federation_enabled,
})),
)
.into_response()
}
/// Prometheus scrape endpoint: refreshes the memories gauge, then renders the
/// metrics registry in the text exposition format.
pub async fn prometheus_metrics(State(handle): State<Db>) -> impl IntoResponse {
    // Refresh the gauge from `db::stats`' total right before rendering. If the
    // stats query fails, the gauge keeps its previous value. Totals that do
    // not fit in an i64 saturate at i64::MAX.
    db_op(handle, |guard| {
        let (conn, path, _, _) = &*guard;
        if let Ok(stats) = db::stats(conn, path) {
            let total = i64::try_from(stats.total).unwrap_or(i64::MAX);
            crate::metrics::registry().memories_gauge.set(total);
        }
    })
    .await;

    let content_type = [(
        axum::http::header::CONTENT_TYPE,
        "text/plain; version=0.0.4; charset=utf-8",
    )];
    (StatusCode::OK, content_type, crate::metrics::render()).into_response()
}
/// Unit tests for the pure transport helpers in this module.
#[cfg(test)]
mod transport_helpers_tests {
    use super::*;

    #[test]
    fn percent_decode_handles_typical_keys() {
        assert_eq!(percent_decode_lossy("abc"), "abc");
        assert_eq!(percent_decode_lossy("a%2Bb"), "a+b");
        assert_eq!(percent_decode_lossy("hello%20world"), "hello world");
        assert_eq!(percent_decode_lossy("%2F%3D%3F"), "/=?");
    }

    #[test]
    fn percent_decode_passes_through_invalid_escapes() {
        assert_eq!(percent_decode_lossy("a%ZZb"), "a%ZZb");
        assert_eq!(percent_decode_lossy("a%2"), "a%2");
    }

    #[test]
    fn percent_decode_edge_cases() {
        // A lone or truncated `%` is copied through untouched.
        assert_eq!(percent_decode_lossy("%"), "%");
        assert_eq!(percent_decode_lossy("%4"), "%4");
        // An invalid escape must not swallow the valid escape that follows.
        assert_eq!(percent_decode_lossy("%%41"), "%A");
        // `+` is NOT decoded as a form-encoded space.
        assert_eq!(percent_decode_lossy("a+b"), "a+b");
        // Multi-byte UTF-8 sequences are reassembled.
        assert_eq!(percent_decode_lossy("%E2%82%AC"), "\u{20ac}");
        // Invalid UTF-8 is replaced rather than rejected.
        assert_eq!(percent_decode_lossy("%FF"), "\u{fffd}");
    }

    #[test]
    fn constant_time_eq_handles_equal_and_diff_inputs() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_no_length_short_circuit_1060() {
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(!constant_time_eq(b"abcd", b"abc"));
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"x"));
        assert!(!constant_time_eq(b"xxxx", b"yy"));
        assert!(!constant_time_eq(b"aa", b"aaaa"));
    }

    #[test]
    fn constant_time_eq_zero_padding_does_not_mask_length() {
        // The shorter input is zero-padded internally; a trailing NUL must
        // still be detected as a length mismatch.
        assert!(!constant_time_eq(b"ab\0", b"ab"));
        assert!(!constant_time_eq(b"ab", b"ab\0"));
    }

    #[test]
    fn storage_backend_as_str_round_trip() {
        assert_eq!(StorageBackend::Sqlite.as_str(), "sqlite");
        assert_eq!(StorageBackend::Postgres.as_str(), "postgres");
    }

    #[test]
    fn family_descriptors_returns_eight_entries() {
        let d = family_descriptors();
        assert_eq!(d.len(), 8, "expected 8 family descriptors, got {}", d.len());
        for (family, text) in d {
            assert!(!text.is_empty(), "descriptor for {family:?} is empty");
            assert!(
                text.len() > 20,
                "descriptor for {family:?} too short: {text}"
            );
        }
    }

    #[test]
    fn precompute_family_embeddings_no_embedder_returns_empty() {
        let out = AppState::precompute_family_embeddings(None);
        assert!(out.is_empty());
    }

    #[test]
    fn extract_missing_fields_finds_single_field() {
        let msg =
            "Failed to deserialize the JSON body: missing field `content` at line 1 column 14";
        let fields = extract_missing_fields(msg);
        assert_eq!(fields, vec!["content".to_string()]);
    }

    #[test]
    fn extract_missing_fields_finds_multiple_fields() {
        let msg = "missing field `title` and missing field `content`";
        let fields = extract_missing_fields(msg);
        assert_eq!(fields, vec!["title".to_string(), "content".to_string()]);
    }

    #[test]
    fn extract_missing_fields_dedups_repeats() {
        let msg = "missing field `name` ... missing field `name` again";
        let fields = extract_missing_fields(msg);
        assert_eq!(fields, vec!["name".to_string()]);
    }

    #[test]
    fn extract_missing_fields_returns_empty_for_clean_message() {
        assert!(extract_missing_fields("no missing fields here").is_empty());
    }

    #[test]
    fn extract_missing_fields_rejects_non_identifier_content() {
        let msg = "missing field `<script>` injection attempt";
        let fields = extract_missing_fields(msg);
        assert!(fields.is_empty(), "non-ident content must be rejected");
    }

    #[test]
    fn extract_missing_fields_rejects_non_ascii_names() {
        assert!(extract_missing_fields("missing field `caf\u{e9}`").is_empty());
    }

    #[test]
    fn extract_missing_fields_accepts_underscores_and_dashes() {
        let msg = "missing field `agent_id-x` here";
        let fields = extract_missing_fields(msg);
        assert_eq!(fields, vec!["agent_id-x".to_string()]);
    }

    #[test]
    fn extract_missing_fields_handles_unterminated_backtick() {
        let msg = "missing field `unterminated";
        let fields = extract_missing_fields(msg);
        assert!(fields.is_empty());
    }

    #[test]
    fn extract_missing_fields_skips_empty_name_and_continues() {
        let fields = extract_missing_fields("missing field `` missing field `x`");
        assert_eq!(fields, vec!["x".to_string()]);
    }

    #[test]
    fn extract_missing_fields_caps_at_sixteen() {
        let msg: String = (0..20).map(|i| format!("missing field `f{i}` ")).collect();
        let fields = extract_missing_fields(&msg);
        assert_eq!(fields.len(), 16);
        assert_eq!(fields.first().map(String::as_str), Some("f0"));
        assert_eq!(fields.last().map(String::as_str), Some("f15"));
    }
}