use crate::models::field_names;
use std::net::{IpAddr, Ipv4Addr, ToSocketAddrs};
use std::str::FromStr;
use std::sync::{Arc, OnceLock};
use anyhow::{Context, Result, anyhow};
use rusqlite::{Connection, params};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::sync::Semaphore;
/// `tracing` target passed as `target:` by the HMAC-key and SSRF-guard
/// warnings emitted in this file.
const SUBSCRIPTIONS_TRACE_TARGET: &str = "ai_memory::subscriptions";
/// One row of the `subscriptions` table as produced by `list` and
/// `list_by_event`.
///
/// The `secret_hash` column is not selected into this struct; the dispatch
/// path reads it separately through `load_secret_hash_with_conn`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Subscription {
/// Row id; `insert` generates it as a UUID v4 string.
pub id: String,
/// Webhook endpoint the delivery is POSTed to.
pub url: String,
/// Comma-separated event names (or `*`). `matches_filters` consults this
/// only when `event_types` is `None`.
pub events: String,
/// Namespace filter; `None` or an empty string matches every namespace.
pub namespace_filter: Option<String>,
/// Agent filter; `None` or an empty string matches every agent.
pub agent_filter: Option<String>,
/// Value of the `created_by` column; `delete` and `list` scope on it when a
/// caller agent id is supplied.
pub created_by: Option<String>,
/// Creation time; `insert` writes it as an RFC 3339 UTC timestamp.
pub created_at: String,
/// `dispatch_count` column. NOTE(review): presumably maintained by
/// `record_dispatch*`, which is not visible in this chunk — confirm.
pub dispatch_count: i64,
/// `failure_count` column. NOTE(review): presumably maintained by
/// `record_dispatch*`, which is not visible in this chunk — confirm.
pub failure_count: i64,
/// Decoded `event_types` JSON column. `None` when the column is NULL or its
/// JSON fails to decode; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub event_types: Option<Vec<String>>,
}
/// Borrowed input for `insert`.
pub struct NewSubscription<'a> {
/// Webhook endpoint; checked with `validate_url` before the row is written.
pub url: &'a str,
/// Event CSV stored verbatim when `event_types` is `None`.
pub events: &'a str,
/// Plaintext secret; only its SHA-256 hex digest is stored (`secret_hash`).
pub secret: Option<&'a str>,
/// Optional namespace filter, stored as given.
pub namespace_filter: Option<&'a str>,
/// Optional agent filter, stored as given.
pub agent_filter: Option<&'a str>,
/// Optional creator id, stored in `created_by`.
pub created_by: Option<&'a str>,
/// Structured event list. When present every entry must be one of
/// `WEBHOOK_EVENT_TYPES`; it is stored as JSON and also joined into the
/// `events` CSV column in place of `events`.
pub event_types: Option<&'a [String]>,
}
/// Event names accepted in `NewSubscription::event_types`; `insert` rejects
/// any entry that is not in this list.
pub const WEBHOOK_EVENT_TYPES: &[&str] = &[
crate::mcp::registry::tool_names::MEMORY_STORE,
crate::mcp::registry::tool_names::MEMORY_PROMOTE,
crate::mcp::registry::tool_names::MEMORY_DELETE,
webhook_events::MEMORY_LINK_CREATED,
webhook_events::MEMORY_LINK_INVALIDATED,
webhook_events::MEMORY_CONSOLIDATED,
webhook_events::APPROVAL_REQUESTED,
];
/// Webhook event names defined by this module itself (the remaining entries
/// of `WEBHOOK_EVENT_TYPES` come from `crate::mcp::registry::tool_names`).
pub mod webhook_events {
/// Event name `"memory_link_created"`.
pub const MEMORY_LINK_CREATED: &str = "memory_link_created";
/// Event name `"memory_link_invalidated"`.
pub const MEMORY_LINK_INVALIDATED: &str = "memory_link_invalidated";
/// Event name `"memory_consolidated"`.
pub const MEMORY_CONSOLIDATED: &str = "memory_consolidated";
/// Event name `"approval_requested"`; dispatched by
/// `dispatch_approval_requested`.
pub const APPROVAL_REQUESTED: &str = "approval_requested";
}
/// Persists a new webhook subscription and returns its generated id.
///
/// The URL is checked with `validate_url` first. When `req.event_types` is
/// supplied, every entry must appear in `WEBHOOK_EVENT_TYPES`; the list is
/// then stored both as JSON (`event_types` column) and as a comma-joined
/// string (`events` column). Without it, `req.events` is stored verbatim and
/// `event_types` is left NULL. The secret, if any, is stored only as its
/// SHA-256 hex digest.
///
/// # Errors
///
/// Fails when the URL is rejected, an event type is unknown, the event list
/// cannot be serialised, or the INSERT fails.
pub fn insert(conn: &Connection, req: &NewSubscription<'_>) -> Result<String> {
    validate_url(req.url)?;

    let id = uuid::Uuid::new_v4().to_string();
    let secret_hash = req.secret.map(sha256_hex);
    let now = chrono::Utc::now().to_rfc3339();

    let (events_csv, event_types_json) = match req.event_types {
        None => (req.events.to_string(), None),
        Some(list) => {
            // Reject the first unknown event name, mirroring a front-to-back scan.
            let unknown = list
                .iter()
                .find(|ev| !WEBHOOK_EVENT_TYPES.contains(&ev.as_str()));
            if let Some(ev) = unknown {
                return Err(anyhow!(
                    "unknown webhook event type {ev:?}; valid types: {WEBHOOK_EVENT_TYPES:?}"
                ));
            }
            let csv = list.join(",");
            let json = serde_json::to_string(list).context("event_types serialise")?;
            (csv, Some(json))
        }
    };

    conn.execute(
        "INSERT INTO subscriptions (id, url, events, secret_hash, namespace_filter, agent_filter, created_by, created_at, event_types) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
        params![
            id,
            req.url,
            events_csv,
            secret_hash,
            req.namespace_filter,
            req.agent_filter,
            req.created_by,
            now,
            event_types_json
        ],
    )?;
    Ok(id)
}
/// Deletes the subscription `id` and reports whether a row was removed.
///
/// When `caller_agent_id` is given, the delete only matches a row whose
/// `created_by` equals it; with `None` the row is deleted regardless of owner.
///
/// # Errors
///
/// Propagates any SQLite error from the DELETE.
pub fn delete(conn: &Connection, id: &str, caller_agent_id: Option<&str>) -> Result<bool> {
    let removed = match caller_agent_id {
        Some(aid) => conn.execute(
            "DELETE FROM subscriptions WHERE id = ?1 AND created_by = ?2",
            params![id, aid],
        )?,
        None => conn.execute("DELETE FROM subscriptions WHERE id = ?1", params![id])?,
    };
    Ok(removed != 0)
}
/// Lists subscriptions, newest `created_at` first.
///
/// With `caller_agent_id` set, only rows whose `created_by` equals it are
/// returned; otherwise every row is returned. An `event_types` column that
/// fails to decode as a JSON string array is logged and surfaced as `None`.
///
/// # Errors
///
/// Fails when the statement cannot be prepared or a row cannot be decoded.
pub fn list(conn: &Connection, caller_agent_id: Option<&str>) -> Result<Vec<Subscription>> {
    const OWNED_SQL: &str = "SELECT id, url, events, namespace_filter, agent_filter, created_by, created_at, dispatch_count, failure_count, event_types FROM subscriptions WHERE created_by = ?1 ORDER BY created_at DESC";
    const ALL_SQL: &str = "SELECT id, url, events, namespace_filter, agent_filter, created_by, created_at, dispatch_count, failure_count, event_types FROM subscriptions ORDER BY created_at DESC";

    let sql = if caller_agent_id.is_some() {
        OWNED_SQL
    } else {
        ALL_SQL
    };
    let mut stmt = conn.prepare(sql)?;

    let decode_row = |row: &rusqlite::Row<'_>| -> rusqlite::Result<Subscription> {
        let raw_types: Option<String> = row.get(9)?;
        let event_types = raw_types.and_then(|json| {
            serde_json::from_str::<Vec<String>>(&json)
                .map_err(|e| {
                    tracing::warn!(
                        "subscription event_types JSON decode failed, treating as all-events: {e}"
                    );
                })
                .ok()
        });
        Ok(Subscription {
            id: row.get(0)?,
            url: row.get(1)?,
            events: row.get(2)?,
            namespace_filter: row.get(3)?,
            agent_filter: row.get(4)?,
            created_by: row.get(5)?,
            created_at: row.get(6)?,
            dispatch_count: row.get(7)?,
            failure_count: row.get(8)?,
            event_types,
        })
    };

    let decoded = match caller_agent_id {
        Some(aid) => stmt
            .query_map(params![aid], decode_row)?
            .collect::<rusqlite::Result<Vec<_>>>(),
        None => stmt
            .query_map([], decode_row)?
            .collect::<rusqlite::Result<Vec<_>>>(),
    };
    decoded.context("subscription row decode failed")
}
/// Returns the `created_by` value of subscription `id`.
///
/// `Ok(None)` is returned both when no such row exists and when the row's
/// `created_by` column is NULL; callers cannot tell the two apart from this
/// result alone.
///
/// # Errors
///
/// Fails when the statement cannot be prepared, stepped, or decoded.
pub fn get_owner(conn: &Connection, id: &str) -> Result<Option<String>> {
    let mut stmt = conn.prepare("SELECT created_by FROM subscriptions WHERE id = ?1")?;
    let mut rows = stmt.query(params![id])?;
    match rows.next().context("subscription owner row")? {
        Some(row) => Ok(row.get::<_, Option<String>>(0)?),
        None => Ok(None),
    }
}
/// Returns the subscriptions that may be interested in `event_type`, newest
/// first.
///
/// SQL does a coarse pre-filter (`event_types IS NULL` or a `LIKE` substring
/// match on the JSON text). The exact check happens in Rust: rows with a
/// decoded list are kept only when it contains `event_type` verbatim, and
/// rows whose `event_types` is NULL or fails to decode are always kept, so
/// the caller still has to apply the `events` CSV filter to them.
///
/// # Errors
///
/// Fails when the statement cannot be prepared or a row cannot be decoded.
pub fn list_by_event(conn: &Connection, event_type: &str) -> Result<Vec<Subscription>> {
    let like_pattern = format!("%{event_type}%");
    let mut stmt = conn.prepare(
        "SELECT id, url, events, namespace_filter, agent_filter, created_by, created_at, dispatch_count, failure_count, event_types FROM subscriptions WHERE event_types IS NULL OR event_types LIKE ?1 ORDER BY created_at DESC",
    )?;
    let candidates = stmt.query_map(params![like_pattern], |row| {
        let raw_types: Option<String> = row.get(9)?;
        let event_types = match raw_types {
            Some(json) => serde_json::from_str::<Vec<String>>(&json).ok(),
            None => None,
        };
        Ok(Subscription {
            id: row.get(0)?,
            url: row.get(1)?,
            events: row.get(2)?,
            namespace_filter: row.get(3)?,
            agent_filter: row.get(4)?,
            created_by: row.get(5)?,
            created_at: row.get(6)?,
            dispatch_count: row.get(7)?,
            failure_count: row.get(8)?,
            event_types,
        })
    })?;

    let mut selected: Vec<Subscription> = Vec::new();
    for candidate in candidates {
        let sub = candidate.context("subscription row decode failed")?;
        // No structured list means "not restricted here"; otherwise require an exact hit.
        let keep = sub
            .event_types
            .as_ref()
            .is_none_or(|list| list.iter().any(|e| e == event_type));
        if keep {
            selected.push(sub);
        }
    }
    Ok(selected)
}
/// Decides whether a subscription's filters accept a concrete event.
///
/// * Event: when `sub_event_types` is present it is authoritative and must
///   contain `event` exactly. Otherwise `sub_events` is read as a CSV in
///   which a `*` entry (or the whole string being `*`) matches anything and
///   entries are compared after trimming whitespace.
/// * Namespace: a non-empty `sub_namespace` must equal `namespace`.
/// * Agent: a non-empty `sub_agent` must equal `agent`; an event without an
///   agent never satisfies a non-empty agent filter.
pub(crate) fn matches_filters(
    sub_events: &str,
    sub_event_types: Option<&[String]>,
    sub_namespace: Option<&str>,
    sub_agent: Option<&str>,
    event: &str,
    namespace: &str,
    agent: Option<&str>,
) -> bool {
    let event_accepted = match sub_event_types {
        Some(list) => list.iter().any(|e| e == event),
        None => {
            sub_events == "*"
                || sub_events
                    .split(',')
                    .map(str::trim)
                    .any(|e| e == event || e == "*")
        }
    };
    let namespace_rejected =
        sub_namespace.is_some_and(|ns| !ns.is_empty() && ns != namespace);
    let agent_rejected =
        sub_agent.is_some_and(|filter| !filter.is_empty() && agent != Some(filter));

    event_accepted && !namespace_rejected && !agent_rejected
}
/// JSON body POSTed to a subscriber; built once per matching subscription in
/// `dispatch_event_to_subs`.
#[derive(Serialize)]
struct DispatchPayload<'a> {
/// Event name being delivered.
event: &'a str,
/// Id passed by the dispatcher as `memory_id` (the approval path passes the
/// pending-action id here).
memory_id: &'a str,
/// Namespace the event belongs to.
namespace: &'a str,
/// Agent associated with the event, if any.
agent_id: Option<&'a str>,
/// RFC 3339 UTC timestamp taken when the payload is built.
delivered_at: String,
/// Per-delivery UUID v7; `send` requires the receiver's ack to echo it.
correlation_id: &'a str,
/// Extra event data merged into the top level of the JSON object
/// (`flatten`) and omitted entirely when `None`.
/// NOTE(review): serde's `flatten` is expected to reject a non-object
/// value at serialise time; the dispatcher logs and skips such a
/// delivery — confirm callers only pass JSON objects.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
details: Option<serde_json::Value>,
}
/// Delays slept before each retry in `deliver_with_retry`: one initial attempt
/// plus one retry per entry, i.e. `RETRY_BACKOFFS.len() + 1` attempts at most.
pub const RETRY_BACKOFFS: &[std::time::Duration] = &[
std::time::Duration::from_millis(200),
std::time::Duration::from_secs(1),
std::time::Duration::from_secs(5),
];
/// Timeout configured on the blocking reqwest client used for every delivery
/// attempt (and on the TLS warm-up client).
pub const ACK_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
/// Semaphore size used when `AI_MEMORY_WEBHOOK_DISPATCH_CONCURRENCY` is unset
/// or outside `1..=4096`.
pub const DEFAULT_WEBHOOK_DISPATCH_CONCURRENCY: usize = 32;
/// Lazily-initialised process-wide semaphore; a permit is held for the
/// duration of each Tokio-spawned delivery.
static DISPATCH_SEMAPHORE: OnceLock<Arc<Semaphore>> = OnceLock::new();
/// Resolves the number of permits for the dispatch semaphore.
///
/// Precedence: the test override (if set), then the
/// `AI_MEMORY_WEBHOOK_DISPATCH_CONCURRENCY` environment variable when it
/// parses to a value in `1..=4096`, then
/// `DEFAULT_WEBHOOK_DISPATCH_CONCURRENCY`. A present-but-invalid environment
/// value is logged before falling back.
fn dispatch_concurrency_bound() -> usize {
    if let Some(forced) = TEST_DISPATCH_CONCURRENCY_OVERRIDE.get().copied() {
        return forced;
    }
    let Ok(raw) = std::env::var("AI_MEMORY_WEBHOOK_DISPATCH_CONCURRENCY") else {
        return DEFAULT_WEBHOOK_DISPATCH_CONCURRENCY;
    };
    let in_range = raw
        .parse::<usize>()
        .ok()
        .filter(|n| (1..=4096).contains(n));
    match in_range {
        Some(n) => n,
        None => {
            tracing::warn!(
                "AI_MEMORY_WEBHOOK_DISPATCH_CONCURRENCY={raw:?} not in 1..=4096; \
                 falling back to default {DEFAULT_WEBHOOK_DISPATCH_CONCURRENCY}"
            );
            DEFAULT_WEBHOOK_DISPATCH_CONCURRENCY
        }
    }
}
/// Returns a handle to the shared dispatch semaphore, creating it on first
/// use with `dispatch_concurrency_bound()` permits. The size is fixed at that
/// first call for the lifetime of the process.
fn dispatch_semaphore() -> Arc<Semaphore> {
    let shared = DISPATCH_SEMAPHORE.get_or_init(|| {
        let permits = dispatch_concurrency_bound();
        Arc::new(Semaphore::new(permits))
    });
    Arc::clone(shared)
}
/// Test-only override consulted first by `dispatch_concurrency_bound`.
static TEST_DISPATCH_CONCURRENCY_OVERRIDE: OnceLock<usize> = OnceLock::new();
/// Sets the test override for the dispatch concurrency bound.
///
/// The override can be set only once per process: a second call returns
/// `Err(n)` with the rejected value. It only takes effect if it is set
/// before the dispatch semaphore is first initialised.
#[doc(hidden)]
pub fn override_dispatch_concurrency_for_tests(n: usize) -> Result<(), usize> {
TEST_DISPATCH_CONCURRENCY_OVERRIDE.set(n)
}
/// Reports how many dispatch permits are currently free. Calling this
/// initialises the shared semaphore if it does not exist yet.
#[doc(hidden)]
pub fn dispatch_semaphore_available_permits() -> usize {
dispatch_semaphore().available_permits()
}
/// One row of the `subscription_events` audit table, as returned by
/// `replay_subscription_events`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubscriptionEvent {
/// `id` column of the row.
pub id: i64,
/// Subscription the delivery was addressed to.
pub subscription_id: String,
/// Per-delivery correlation id; status updates are keyed on it.
pub correlation_id: String,
/// Event name that was dispatched.
pub event_type: String,
/// Serialized JSON body that was (to be) POSTed.
pub payload: String,
/// RFC 3339 timestamp written when the audit row was inserted (before the
/// delivery attempt itself).
pub delivered_at: String,
/// `pending` at insert time; later set to `ack` or `failed` by
/// `update_event_status_with_conn`.
pub delivery_status: String,
}
/// One row of the `subscription_dlq` (dead-letter) table, as returned by
/// `list_dlq`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DlqEntry {
/// `id` column of the row; `list_dlq` orders by it ascending.
pub id: i64,
/// Subscription whose delivery failed.
pub subscription_id: String,
/// Correlation id of the failed delivery.
pub correlation_id: String,
/// Event name of the failed delivery.
pub event_type: String,
/// Serialized JSON body that could not be delivered.
pub payload: String,
/// Number of attempts made (0 when dispatch was refused for lack of a
/// signing secret).
pub retry_count: i64,
/// Error string of the final failure.
pub last_error: String,
/// RFC 3339 timestamp of the first failed attempt.
pub first_failed_at: String,
/// RFC 3339 timestamp of the last failed attempt.
pub last_failed_at: String,
}
/// Serializable detail record for promote events.
/// NOTE(review): presumably converted to JSON and passed as `details` to
/// `dispatch_event_with_details` by the promote path — the call site is not
/// visible in this chunk; confirm. Optional fields are omitted from the
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromoteEventDetails {
pub mode: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tier: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub to_namespace: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub clone_id: Option<String>,
}
/// Serializable detail record for delete events.
/// NOTE(review): presumably carries the title and tier of the deleted memory
/// as webhook `details` — the call site is not visible here; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeleteEventDetails {
pub title: String,
pub tier: String,
}
/// Serializable detail record for link-created events.
/// NOTE(review): presumably paired with `webhook_events::MEMORY_LINK_CREATED`
/// — the call site is not visible here; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkCreatedEventDetails {
pub target_id: String,
pub relation: String,
}
/// Serializable detail record for consolidation events.
/// NOTE(review): presumably paired with `webhook_events::MEMORY_CONSOLIDATED`,
/// with `source_count` mirroring `source_ids.len()` — neither is established
/// by this chunk; confirm at the call site.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsolidatedEventDetails {
pub source_ids: Vec<String>,
pub source_count: usize,
}
/// Detail record attached to `approval_requested` webhooks;
/// `dispatch_approval_requested` fills every field from the pending action.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalRequestedEventDetails {
/// Copied from the pending action's `action_type`.
pub action_type: String,
/// Copied from the pending action's `requested_at`.
pub requested_at: String,
/// Copied from the pending action's `memory_id`; omitted from the
/// serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub memory_id: Option<String>,
/// Copied from the pending action's `status`.
pub status: String,
}
/// Serializable detail record for link-invalidated events.
/// NOTE(review): presumably paired with
/// `webhook_events::MEMORY_LINK_INVALIDATED` — the call site is not visible
/// here; confirm. `previous_valid_until` is omitted from the serialized
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkInvalidatedEventDetails {
pub target_id: String,
pub relation: String,
pub valid_until: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub previous_valid_until: Option<String>,
}
/// Dispatches `event` to every matching subscription without extra details.
///
/// Convenience wrapper: forwards all arguments to
/// `dispatch_event_with_details` with `details` set to `None`.
pub fn dispatch_event(
conn: &Connection,
event: &str,
memory_id: &str,
namespace: &str,
agent_id: Option<&str>,
db_path: &std::path::Path,
) {
dispatch_event_with_details(conn, event, memory_id, namespace, agent_id, db_path, None);
}
/// Finds the subscriptions interested in `event` and hands them to
/// `dispatch_event_to_subs`.
///
/// Candidates come from `list_by_event` and are narrowed with
/// `matches_filters` (event, namespace and agent filters). Each survivor's
/// `secret_hash` is loaded on the caller's connection so the delivery workers
/// do not have to look it up again.
///
/// This function never returns an error. A failed subscription lookup is
/// logged and aborts the dispatch. A failed `secret_hash` lookup is logged
/// and the subscription is then treated as having no per-subscription secret,
/// so a previously silent fallback is now visible to operators.
pub fn dispatch_event_with_details(
    conn: &Connection,
    event: &str,
    memory_id: &str,
    namespace: &str,
    agent_id: Option<&str>,
    db_path: &std::path::Path,
    details: Option<serde_json::Value>,
) {
    let subs = match list_by_event(conn, event) {
        Ok(s) => s,
        Err(e) => {
            tracing::warn!("subscription list failed during dispatch: {e}");
            return;
        }
    };
    let matching: Vec<(Subscription, Option<String>)> = subs
        .into_iter()
        .filter(|s| {
            matches_filters(
                &s.events,
                s.event_types.as_deref(),
                s.namespace_filter.as_deref(),
                s.agent_filter.as_deref(),
                event,
                namespace,
                agent_id,
            )
        })
        .map(|s| {
            let secret_hash = match load_secret_hash_with_conn(conn, &s.id) {
                Ok(hash) => hash,
                Err(e) => {
                    // Keep dispatching (the worker signs with the server-wide
                    // secret or refuses), but do not hide the lookup failure.
                    tracing::warn!(
                        "subscription {} secret_hash load failed during dispatch: {e:#}; \
                         treating it as having no per-subscription secret",
                        s.id
                    );
                    None
                }
            };
            (s, secret_hash)
        })
        .collect();
    dispatch_event_to_subs(
        matching, event, memory_id, namespace, agent_id, db_path, details,
    );
}
pub fn dispatch_event_to_subs(
matching: Vec<(Subscription, Option<String>)>,
event: &str,
memory_id: &str,
namespace: &str,
agent_id: Option<&str>,
db_path: &std::path::Path,
details: Option<serde_json::Value>,
) {
if matching.is_empty() {
return;
}
{
static KICK: std::sync::Once = std::sync::Once::new();
KICK.call_once(|| {
std::thread::spawn(prewarm_dispatch_tls);
});
}
let timestamp = chrono::Utc::now().timestamp().to_string();
let handle = tokio::runtime::Handle::try_current().ok();
for (sub, sub_secret_hash) in matching {
let correlation_id = uuid::Uuid::now_v7().to_string();
let payload = DispatchPayload {
event,
memory_id,
namespace,
agent_id,
delivered_at: chrono::Utc::now().to_rfc3339(),
correlation_id: &correlation_id,
details: details.clone(),
};
let body = match serde_json::to_string(&payload) {
Ok(s) => s,
Err(e) => {
tracing::warn!("dispatch payload serialize failed: {e}");
continue;
}
};
let url = sub.url.clone();
let sub_id = sub.id.clone();
let event_owned = event.to_string();
let ts = timestamp.clone();
let db_path = db_path.to_path_buf();
let secret_hash_owned = sub_secret_hash.clone();
let work = move || {
let worker_conn = match Connection::open(&db_path) {
Ok(c) => Some(c),
Err(e) => {
tracing::warn!(
"subscription dispatch: worker_conn open failed: {e}; \
falling back to per-write connections"
);
None
}
};
let event_audit_result = if let Some(c) = worker_conn.as_ref() {
record_subscription_event_with_conn(
c,
&sub_id,
&correlation_id,
&event_owned,
&body,
)
} else {
record_subscription_event(&db_path, &sub_id, &correlation_id, &event_owned, &body)
};
if let Err(e) = event_audit_result {
tracing::warn!("subscription event audit write failed: {e}");
}
let secret_hash = secret_hash_owned;
let canonical = format!("{ts}.{body}");
let signature = match secret_hash.as_deref() {
Some(h) => Some(hmac_sha256_hex(h, &canonical)),
None => crate::config::active_hooks_hmac_secret().map(|plain| {
let key_hash = sha256_hex(&plain);
hmac_sha256_hex(&key_hash, &canonical)
}),
};
if signature.is_none() {
tracing::error!(
"subscription {sub_id} dispatch refused: no per-sub secret AND no \
server-wide [hooks.subscription] hmac_secret configured. \
Configure one of the two and replay via memory_subscription_replay. \
(v0.7.0 fix campaign R3-S1.HMAC, 2026-05-13)"
);
let outcome = DeliveryOutcome::unsigned_refused();
let ok = outcome.success;
if let Some(c) = worker_conn.as_ref() {
record_dispatch_with_conn(c, &sub_id, ok);
update_event_status_with_conn(c, &correlation_id, ok);
} else {
record_dispatch(&db_path, &sub_id, ok);
update_event_status(&db_path, &correlation_id, ok);
}
let dlq_result = if let Some(c) = worker_conn.as_ref() {
record_dlq_with_conn(
c,
&sub_id,
&correlation_id,
&event_owned,
&body,
outcome.attempts,
&outcome.last_error,
&outcome.first_failed_at,
&outcome.last_failed_at,
)
} else {
record_dlq(
&db_path,
&sub_id,
&correlation_id,
&event_owned,
&body,
outcome.attempts,
&outcome.last_error,
&outcome.first_failed_at,
&outcome.last_failed_at,
)
};
if let Err(e) = dlq_result {
tracing::warn!("subscription DLQ write failed: {e}");
}
return;
}
let outcome =
deliver_with_retry(&url, &body, &ts, signature.as_deref(), &correlation_id);
let ok = outcome.success;
if let Some(c) = worker_conn.as_ref() {
record_dispatch_with_conn(c, &sub_id, ok);
update_event_status_with_conn(c, &correlation_id, ok);
} else {
record_dispatch(&db_path, &sub_id, ok);
update_event_status(&db_path, &correlation_id, ok);
}
if !ok {
let dlq_result = if let Some(c) = worker_conn.as_ref() {
record_dlq_with_conn(
c,
&sub_id,
&correlation_id,
&event_owned,
&body,
outcome.attempts,
&outcome.last_error,
&outcome.first_failed_at,
&outcome.last_failed_at,
)
} else {
record_dlq(
&db_path,
&sub_id,
&correlation_id,
&event_owned,
&body,
outcome.attempts,
&outcome.last_error,
&outcome.first_failed_at,
&outcome.last_failed_at,
)
};
if let Err(e) = dlq_result {
tracing::warn!("subscription DLQ write failed: {e}");
}
}
};
if let Some(rt) = handle.as_ref() {
let permit_sem = dispatch_semaphore();
rt.spawn(async move {
let permit = match permit_sem.acquire_owned().await {
Ok(p) => p,
Err(e) => {
tracing::warn!(
"subscription dispatch: semaphore acquire failed: {e}; \
dropping delivery (semaphore closed)"
);
return;
}
};
if let Err(e) = tokio::task::spawn_blocking(move || {
work();
drop(permit);
})
.await
{
tracing::warn!("subscription dispatch: spawn_blocking join failed: {e}");
}
});
} else {
std::thread::spawn(work);
}
}
}
/// Announces that the pending action `pending_id` is awaiting approval.
///
/// The pending action is loaded from the database. If it is missing or
/// unreadable the function logs and returns without dispatching. Otherwise
/// it publishes an in-process `ApprovalEvent::ApprovalRequested` and then
/// dispatches the `approval_requested` webhook, using the pending action's
/// id as `memory_id` and its requester as `agent_id`. If the detail record
/// cannot be converted to JSON the webhook is still sent, without details.
pub fn dispatch_approval_requested(conn: &Connection, pending_id: &str, db_path: &std::path::Path) {
    let lookup = crate::db::get_pending_action(conn, pending_id);
    let pa = match lookup {
        Err(e) => {
            tracing::warn!(
                "approval_requested dispatch skipped: pending_action {pending_id} read failed: {e}"
            );
            return;
        }
        Ok(None) => {
            tracing::warn!(
                "approval_requested dispatch skipped: pending_action {pending_id} not found"
            );
            return;
        }
        Ok(Some(found)) => found,
    };

    let details = ApprovalRequestedEventDetails {
        action_type: pa.action_type.clone(),
        requested_at: pa.requested_at.clone(),
        memory_id: pa.memory_id.clone(),
        status: pa.status.clone(),
    };
    let details_value = serde_json::to_value(&details)
        .map_err(|e| {
            tracing::warn!("approval_requested dispatch details serialise failed: {e}");
        })
        .ok();

    let local_event = crate::approvals::ApprovalEvent::ApprovalRequested {
        pending_id: pa.id.clone(),
        action_type: pa.action_type.clone(),
        namespace: pa.namespace.clone(),
        requested_by: pa.requested_by.clone(),
        requested_at: pa.requested_at.clone(),
    };
    crate::approvals::publish(local_event);

    dispatch_event_with_details(
        conn,
        webhook_events::APPROVAL_REQUESTED,
        &pa.id,
        &pa.namespace,
        Some(&pa.requested_by),
        db_path,
        details_value,
    );
}
/// Result of trying to deliver one webhook, including retry bookkeeping.
struct DeliveryOutcome {
    /// `true` when an attempt was acknowledged.
    success: bool,
    /// Number of send attempts made.
    attempts: i64,
    /// Error of the most recent failed attempt; empty on success.
    last_error: String,
    /// RFC 3339 time of the first failed attempt, or empty if none failed.
    first_failed_at: String,
    /// RFC 3339 time of the most recent failed attempt, or empty if none failed.
    last_failed_at: String,
}

impl DeliveryOutcome {
    /// Outcome recorded when a delivery is refused because no signing secret
    /// is available: a failure with zero attempts, stamped with the current
    /// time as both failure timestamps.
    fn unsigned_refused() -> Self {
        let refused_at = chrono::Utc::now().to_rfc3339();
        let reason = String::from(
            "dispatch refused: no per-subscription secret AND no server-wide \
             [hooks.subscription] hmac_secret configured (v0.7.0 R3-S1.HMAC)",
        );
        Self {
            success: false,
            attempts: 0,
            last_error: reason,
            first_failed_at: refused_at.clone(),
            last_failed_at: refused_at,
        }
    }
}
/// Builds (and immediately discards) one blocking reqwest client, once per
/// process, logging a warning if the build fails.
///
/// The build runs on a dedicated thread that this function joins, so the
/// first caller blocks until the warm-up has finished; later calls return
/// immediately.
/// NOTE(review): presumably this pays the one-time TLS backend
/// initialisation cost ahead of the first real delivery, and uses a separate
/// thread so the blocking client is never built on an async worker — confirm.
pub fn prewarm_dispatch_tls() {
    static WARM: std::sync::Once = std::sync::Once::new();
    WARM.call_once(|| {
        let warmup = std::thread::spawn(|| {
            let built = reqwest::blocking::Client::builder()
                .timeout(ACK_TIMEOUT)
                .build();
            if let Err(e) = built {
                tracing::warn!("webhook dispatch TLS warm-up failed: {e}");
            }
        });
        // A panic inside the warm-up thread is deliberately ignored.
        let _ = warmup.join();
    });
}
/// Sends the webhook, retrying on any failure according to `RETRY_BACKOFFS`.
///
/// One attempt is made immediately; each further attempt is preceded by a
/// blocking sleep of the matching backoff, for at most
/// `RETRY_BACKOFFS.len() + 1` attempts. The returned outcome counts the
/// attempts made and keeps the timestamps of the first and last failures
/// even when a later attempt succeeds; `last_error` is empty on success.
fn deliver_with_retry(
    url: &str,
    body: &str,
    timestamp: &str,
    signature: Option<&str>,
    correlation_id: &str,
) -> DeliveryOutcome {
    let mut outcome = DeliveryOutcome {
        success: false,
        attempts: 0,
        last_error: String::new(),
        first_failed_at: String::new(),
        last_failed_at: String::new(),
    };

    // `None` marks the initial attempt (no wait); each `Some` is a retry delay.
    let waits = std::iter::once(None).chain(RETRY_BACKOFFS.iter().map(Some));
    for wait in waits {
        if let Some(delay) = wait {
            std::thread::sleep(*delay);
        }
        outcome.attempts += 1;
        match send(url, body, timestamp, signature, correlation_id) {
            Ok(()) => {
                outcome.success = true;
                outcome.last_error.clear();
                return outcome;
            }
            Err(e) => {
                let failed_at = chrono::Utc::now().to_rfc3339();
                if outcome.first_failed_at.is_empty() {
                    outcome.first_failed_at = failed_at.clone();
                }
                outcome.last_failed_at = failed_at;
                outcome.last_error = e;
            }
        }
    }
    outcome
}
/// Performs a single webhook POST and verifies the receiver's acknowledgement.
///
/// Steps:
/// 1. Re-run the literal URL guard (`validate_url`) and the DNS guard. The
///    addresses the DNS guard approved are pinned into the HTTP client via
///    `resolve`, and redirects are disabled.
/// 2. POST `body` with content-type, user-agent, timestamp, correlation-id
///    and (when given) `sha256=<signature>` headers.
/// 3. Require a 2xx status and a JSON response whose `status` is `"ack"` and
///    whose `correlation_id` echoes ours.
///
/// # Errors
///
/// Returns a short tagged string describing the failing step
/// (`ssrf-rejected`, `dns-ssrf-rejected`, `client-build`, a network error,
/// `http-<code>`, `ack-read`, `ack-decode`, `ack-status`,
/// `ack-corr-mismatch`).
fn send(
    url: &str,
    body: &str,
    timestamp: &str,
    signature: Option<&str>,
    correlation_id: &str,
) -> Result<(), String> {
    validate_url(url).map_err(|e| {
        tracing::warn!("SSRF guard rejected webhook URL {url}: {e}");
        format!("ssrf-rejected: {e}")
    })?;

    let allow_loopback = crate::config::allow_loopback_webhooks();
    let (resolved_host, validated_addrs) = validate_url_dns_resolved(url, allow_loopback)
        .map_err(|e| {
            tracing::warn!("DNS SSRF guard rejected webhook URL {url}: {e}");
            format!("dns-ssrf-rejected: {e}")
        })?;

    // Pin every pre-validated address for the resolved host into the client.
    let base = reqwest::blocking::Client::builder()
        .timeout(ACK_TIMEOUT)
        .redirect(reqwest::redirect::Policy::none());
    let pinned = validated_addrs
        .iter()
        .fold(base, |b, addr| b.resolve(&resolved_host, *addr));
    let client = pinned.build().map_err(|e| {
        tracing::warn!("webhook client build failed: {e}");
        format!("client-build: {e}")
    })?;

    let user_agent = format!("ai-memory/{}", env!("CARGO_PKG_VERSION"));
    let mut req = client
        .post(url)
        .header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON)
        .header("user-agent", user_agent)
        .header(crate::HEADER_AI_MEMORY_TIMESTAMP, timestamp)
        .header("x-ai-memory-correlation-id", correlation_id);
    if let Some(sig) = signature {
        req = req.header(crate::HEADER_AI_MEMORY_SIGNATURE, format!("sha256={sig}"));
    }

    let resp = req.body(body.to_string()).send().map_err(|e| {
        tracing::warn!("webhook POST to {url} failed: {e}");
        crate::errors::msg::network(e)
    })?;

    let http_status = resp.status();
    if !http_status.is_success() {
        let status = http_status.as_u16();
        return Err(format!("http-{status}"));
    }

    let ack_body = resp.text().map_err(|e| format!("ack-read: {e}"))?;
    let ack: serde_json::Value =
        serde_json::from_str(&ack_body).map_err(|e| format!("ack-decode: {e}"))?;

    let status_field = ack.get("status").and_then(|v| v.as_str()).unwrap_or("");
    if status_field != "ack" {
        return Err(format!("ack-status: {status_field}"));
    }
    let ack_corr = ack
        .get("correlation_id")
        .and_then(|v| v.as_str())
        .unwrap_or("");
    if ack_corr != correlation_id {
        return Err(format!("ack-corr-mismatch: {ack_corr}"));
    }
    Ok(())
}
/// Returns the SHA-256 digest of `s` (its UTF-8 bytes) as lowercase hex.
pub(crate) fn sha256_hex(s: &str) -> String {
    let digest = Sha256::digest(s.as_bytes());
    format!("{:x}", digest)
}
/// Computes HMAC-SHA256 of `body` and returns it as lowercase hex.
///
/// `key_hex` is expected to be a hex string and is decoded to raw key bytes.
/// If it is not valid hex, a warning is logged and the string's own bytes are
/// used as the key instead. Keys longer than the 64-byte block are first
/// hashed with SHA-256; the key is then zero-padded to the block size before
/// the inner (`0x36`) and outer (`0x5c`) pads are applied.
pub(crate) fn hmac_sha256_hex(key_hex: &str, body: &str) -> String {
    const BLOCK: usize = 64;

    let mut key = hex_decode(key_hex).unwrap_or_else(|| {
        tracing::warn!(
            target: SUBSCRIPTIONS_TRACE_TARGET,
            "hmac_sha256_hex: hmac_secret is not valid hex (len={}); falling back to raw \
             bytes as key material — this produces a STABLE BUT WEAK key. Re-encode the \
             secret as hex (e.g. `openssl rand -hex 32`) and restart the daemon. \
             See #1048.",
            key_hex.len(),
        );
        key_hex.as_bytes().to_vec()
    });
    if key.len() > BLOCK {
        key = Sha256::digest(&key).to_vec();
    }
    key.resize(BLOCK, 0);

    let mut ipad = [0x36u8; BLOCK];
    let mut opad = [0x5cu8; BLOCK];
    for ((inner_byte, outer_byte), key_byte) in ipad.iter_mut().zip(opad.iter_mut()).zip(&key) {
        *inner_byte ^= *key_byte;
        *outer_byte ^= *key_byte;
    }

    let mut inner = Sha256::new();
    inner.update(ipad);
    inner.update(body.as_bytes());
    let inner_digest = inner.finalize();

    let mut outer = Sha256::new();
    outer.update(opad);
    outer.update(inner_digest);
    format!("{:x}", outer.finalize())
}
/// Strictly decodes a hex string into bytes.
///
/// Returns `None` when the input has an odd number of bytes or contains
/// anything other than ASCII hex digits (`0-9`, `a-f`, `A-F`). The empty
/// string decodes to an empty vector.
///
/// Decoding works on bytes rather than `str` slices, so non-ASCII input is
/// rejected instead of panicking on a char-boundary slice, and sign
/// characters such as `+` (which `u8::from_str_radix` tolerates) are not
/// accepted.
fn hex_decode(s: &str) -> Option<Vec<u8>> {
    /// Value of one ASCII hex digit, or `None` for any other byte.
    fn nibble(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let pairs = s.as_bytes().chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return None;
    }
    pairs
        .map(|pair| Some((nibble(pair[0])? << 4) | nibble(pair[1])?))
        .collect()
}
/// Boot-time check that a configured HMAC secret is a valid hex string.
///
/// `None` (no secret configured) and any string that `hex_decode` accepts
/// pass. Anything else yields an operator-facing diagnostic that reports the
/// secret's length but never its content.
pub fn validate_hmac_secret_hex(secret: Option<&str>) -> Result<(), String> {
    match secret {
        Some(value) if hex_decode(value).is_none() => Err(format!(
            "[hooks.subscription] hmac_secret is not valid hex (len={}): expected an even-length \
             hex string. Generate a fresh secret with `openssl rand -hex 32` and update the \
             config. The runtime still computes a (weak) HMAC under the misconfigured value for \
             wire compatibility, but the boot validator refuses to start so the operator sees \
             the diagnostic immediately. See #1048.",
            value.len(),
        )),
        _ => Ok(()),
    }
}
/// Runs the DNS-resolving SSRF guard on `url` with the loopback policy taken
/// from `crate::config::allow_loopback_webhooks()`, discarding the resolved
/// host and addresses.
///
/// # Errors
///
/// Returns whatever error `validate_url_dns_with` reports.
pub fn validate_url_dns(url: &str) -> Result<()> {
validate_url_dns_with(url, crate::config::allow_loopback_webhooks()).map(|_| ())
}
/// Crate-visible entry point to the DNS-resolving SSRF guard with an explicit
/// loopback policy.
///
/// Returns the host extracted from `url` together with the socket addresses
/// that passed the guard; `send` pins those addresses into its HTTP client.
pub(crate) fn validate_url_dns_resolved(
url: &str,
allow_loopback: bool,
) -> Result<(String, Vec<std::net::SocketAddr>)> {
validate_url_dns_with(url, allow_loopback)
}
/// Reports whether the operator opted into the legacy fail-open DNS posture.
///
/// True only when `AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL` is set to `1` or to
/// `true` (case-insensitive); unset, unreadable, or any other value is false.
fn ssrf_dns_fail_open() -> bool {
    match std::env::var("AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
/// Returns `true` when `host` breaks the hostname length limits.
///
/// After dropping one optional trailing dot, the name must be 1 to 253 bytes
/// long and every dot-separated label must be 1 to 63 bytes long. Only
/// lengths are checked, not the characters used.
fn hostname_shape_invalid(host: &str) -> bool {
    let name = host.strip_suffix('.').unwrap_or(host);
    let total_len_ok = (1..=253).contains(&name.len());
    let labels_ok = name
        .split('.')
        .all(|label| (1..=63).contains(&label.len()));
    !(total_len_ok && labels_ok)
}
/// DNS-resolving SSRF guard.
///
/// Extracts the host (and port, defaulting to 80 for the lookup) from the
/// lowercased `url`, resolves it, and rejects the URL when any resolved
/// address is private/link-local, or is loopback while `allow_loopback` is
/// false. On success returns the extracted host and every resolved address.
///
/// A host that violates the length limits, or a failed resolution, is an
/// error by default (fail closed). With
/// `AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL` enabled those two cases instead log
/// a warning and return the host with an empty address list.
///
/// # Errors
///
/// Missing `://`, hostname shape violation, DNS failure (both only when
/// failing closed), or a resolved address that is private or disallowed
/// loopback.
fn validate_url_dns_with(
    url: &str,
    allow_loopback: bool,
) -> Result<(String, Vec<std::net::SocketAddr>)> {
    let lower = url.to_ascii_lowercase();
    let Some((_scheme, rest)) = lower.split_once("://") else {
        return Err(anyhow!("webhook URL missing scheme: {url}"));
    };
    let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
    let host_port = &rest[..authority_end];

    // Index of the closing bracket, but only for a `[...]` (IPv6 literal) host.
    let bracket_close = if host_port.starts_with('[') {
        host_port.find(']')
    } else {
        None
    };

    let resolved_host = match (bracket_close, host_port.rfind(':')) {
        (Some(close), _) => host_port[1..close].to_string(),
        (None, Some(colon)) => host_port[..colon].to_string(),
        (None, None) => host_port.to_string(),
    };

    // `to_socket_addrs` needs a port; add a placeholder when the URL has none.
    let has_explicit_port = match bracket_close {
        Some(close) => host_port[close + 1..].starts_with(':'),
        None => host_port.contains(':'),
    };
    let resolv_target = if has_explicit_port {
        host_port.to_string()
    } else {
        format!("{host_port}:80")
    };

    if hostname_shape_invalid(&resolved_host) {
        if !ssrf_dns_fail_open() {
            return Err(anyhow!(
                "SSRF guard: DNS resolution failed for {url}: hostname violates RFC 1035 \
                 label/length limits; failing CLOSED (post-#1053 secure default — set \
                 AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL=1 to revert)"
            ));
        }
        tracing::warn!(
            target: SUBSCRIPTIONS_TRACE_TARGET,
            "SSRF guard: hostname {resolved_host} violates RFC 1035 label/length \
             limits for {url}; AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL=1 — degrading to \
             ALLOW (UNSAFE, legacy posture)"
        );
        return Ok((resolved_host, Vec::new()));
    }

    let addrs: Vec<std::net::SocketAddr> = match resolv_target.to_socket_addrs() {
        Ok(resolved) => resolved.collect(),
        Err(e) => {
            if !ssrf_dns_fail_open() {
                return Err(anyhow!(
                    "SSRF guard: DNS resolution failed for {url}: {e}; failing CLOSED \
                     (post-#1053 secure default — set AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL=1 to revert)"
                ));
            }
            tracing::warn!(
                target: SUBSCRIPTIONS_TRACE_TARGET,
                "SSRF guard: DNS resolution failed for {url}: {e}; \
                 AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL=1 — degrading to ALLOW \
                 (UNSAFE, legacy posture) — reqwest's resolver may bind to \
                 private/loopback IPs the daemon could not pre-check"
            );
            return Ok((resolved_host, Vec::new()));
        }
    };

    for addr in &addrs {
        let ip = addr.ip();
        let loopback = is_loopback_normalized(ip);
        if !loopback && is_private(ip) {
            return Err(anyhow!(
                "host resolves to private/link-local IP {ip}: {url}"
            ));
        }
        if loopback && !allow_loopback {
            return Err(anyhow!(
                "host resolves to loopback IP {ip}: {url} — rejected by default \
                 (SSRF guard); set `[subscriptions] allow_loopback_webhooks = true` \
                 to opt in"
            ));
        }
    }
    Ok((resolved_host, addrs))
}
/// Runs the literal (non-resolving) URL guard on `url` with the loopback
/// policy taken from `crate::config::allow_loopback_webhooks()`.
///
/// Called by `insert` before a subscription is stored and by `send` before
/// every delivery attempt.
///
/// # Errors
///
/// Returns whatever error `validate_url_with` reports.
pub fn validate_url(url: &str) -> Result<()> {
validate_url_with(url, crate::config::allow_loopback_webhooks())
}
/// Literal (non-resolving) SSRF guard for a webhook URL.
///
/// Rules, applied to the lowercased URL:
/// * the scheme must be `http` or `https`;
/// * a loopback target (`localhost`, `localhost.localdomain`, an empty host,
///   or a loopback IP literal including IPv4-mapped / NAT64 forms) is
///   accepted only when `allow_loopback` is true;
/// * plain `http` is accepted only for loopback targets;
/// * an IP literal in a private, link-local, multicast, broadcast or
///   unspecified range is always rejected.
///
/// The host is taken from the URL authority after removing any
/// `userinfo@` prefix, and a backslash ends the authority just like `/`.
/// Without this, `https://x@127.0.0.1/` or `https://127.0.0.1\@example.com/`
/// would present a non-IP "host" to the checks above while an HTTP client
/// would still connect to the IP literal.
///
/// # Errors
///
/// Returns a descriptive error for the first rule that is violated, or for a
/// bracketed IPv6 host with no closing `]`.
fn validate_url_with(url: &str, allow_loopback: bool) -> Result<()> {
    let lower = url.to_ascii_lowercase();
    let (scheme, rest) = lower
        .split_once("://")
        .ok_or_else(|| anyhow!("webhook URL missing scheme: {url}"))?;
    if scheme != "https" && scheme != "http" {
        return Err(anyhow!("webhook URL scheme must be http(s): {url}"));
    }

    // URL parsers treat `\` like `/` for http(s), so it ends the authority too.
    let authority_end = rest.find(['/', '\\', '?', '#']).unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // Drop `userinfo@`; the real host follows the last `@`.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, hp)| hp);

    let host: &str = if let Some(stripped) = host_port.strip_prefix('[') {
        match stripped.find(']') {
            Some(i) => &stripped[..i],
            None => return Err(anyhow!("malformed IPv6 URL host: {url}")),
        }
    } else {
        host_port.rsplit_once(':').map_or(host_port, |(h, _port)| h)
    };

    let is_loopback_hostname = matches!(host, "localhost" | "localhost.localdomain" | "");
    let parsed_ip = IpAddr::from_str(host).ok();
    let is_loopback = is_loopback_hostname || parsed_ip.is_some_and(is_loopback_normalized);

    if is_loopback && !allow_loopback {
        return Err(anyhow!(
            "webhook URL targets loopback address {url} — rejected by default \
             (SSRF guard); set `[subscriptions] allow_loopback_webhooks = true` \
             to opt in (testing / dev only)"
        ));
    }
    // Plain http is tolerated only for loopback targets.
    if scheme == "http" && !is_loopback {
        return Err(anyhow!(
            "webhook URL must be https for non-loopback host: {url}"
        ));
    }
    if parsed_ip.is_some_and(|ip| is_private(ip) && !is_loopback_normalized(ip)) {
        return Err(anyhow!(
            "webhook URL targets private / link-local address: {url}"
        ));
    }
    Ok(())
}
/// Unwraps IPv6 forms that embed an IPv4 address so range checks see the
/// IPv4 address itself.
///
/// * IPv4 input is returned unchanged.
/// * IPv4-mapped IPv6 (`::ffff:a.b.c.d`) becomes `a.b.c.d`.
/// * Addresses in the NAT64 prefix `64:ff9b::/96` become the IPv4 address
///   held in their last 32 bits.
/// * Every other IPv6 address is returned unchanged.
fn normalize_ip(ip: IpAddr) -> IpAddr {
    let IpAddr::V6(v6) = ip else {
        return ip;
    };
    if let mapped @ IpAddr::V4(_) = v6.to_canonical() {
        return mapped;
    }
    match v6.segments() {
        [0x0064, 0xff9b, 0, 0, 0, 0, high, low] => {
            let v4_bits = (u32::from(high) << 16) | u32::from(low);
            IpAddr::V4(Ipv4Addr::from(v4_bits))
        }
        _ => ip,
    }
}
/// Returns `true` when `ip` is a loopback address after `normalize_ip`, so
/// IPv4 loopback wrapped in an IPv4-mapped or NAT64 IPv6 address counts too.
fn is_loopback_normalized(ip: IpAddr) -> bool {
normalize_ip(ip).is_loopback()
}
/// Returns `true` when `ip` (after `normalize_ip`) is in a range webhooks
/// must never target.
///
/// * IPv4: private, link-local, multicast, broadcast or unspecified.
/// * IPv6: multicast, unspecified, unique-local (`fc00::/7`) or link-local
///   (`fe80::/10`).
///
/// Loopback is not part of this set; callers test it separately with
/// `is_loopback_normalized`.
fn is_private(ip: IpAddr) -> bool {
    match normalize_ip(ip) {
        IpAddr::V4(v4) => {
            v4.is_private()
                || v4.is_link_local()
                || v4.is_multicast()
                || v4.is_broadcast()
                || v4.is_unspecified()
        }
        IpAddr::V6(v6) => {
            let first = v6.segments()[0];
            let unique_local = first & 0xfe00 == 0xfc00;
            let link_local = first & 0xffc0 == 0xfe80;
            v6.is_multicast() || v6.is_unspecified() || unique_local || link_local
        }
    }
}
/// Path-based variant of `load_secret_hash_with_conn`: opens a fresh
/// connection to `db_path` and runs the lookup on it.
///
/// Nothing in this chunk calls it, hence the `dead_code` allowance.
///
/// # Errors
///
/// Fails when the database cannot be opened or the lookup fails.
#[allow(dead_code)]
fn load_secret_hash(db_path: &std::path::Path, sub_id: &str) -> Result<Option<String>> {
let conn = Connection::open(db_path).context("load_secret_hash open")?;
load_secret_hash_with_conn(&conn, sub_id)
}
fn load_secret_hash_with_conn(conn: &Connection, sub_id: &str) -> Result<Option<String>> {
conn.query_row(
"SELECT secret_hash FROM subscriptions WHERE id = ?1",
params![sub_id],
|r| r.get::<_, Option<String>>(0),
)
.context("load_secret_hash query")
}
/// Path-based variant of `record_subscription_event_with_conn`: opens a fresh
/// connection to `db_path` and writes the audit row on it.
///
/// # Errors
///
/// Fails when the database cannot be opened or the insert fails.
pub fn record_subscription_event(
db_path: &std::path::Path,
sub_id: &str,
correlation_id: &str,
event_type: &str,
payload: &str,
) -> Result<()> {
let conn = Connection::open(db_path).context("subscription_events open")?;
record_subscription_event_with_conn(&conn, sub_id, correlation_id, event_type, payload)
}
/// Inserts a `pending` audit row into `subscription_events` for a delivery
/// that is about to be attempted; `delivered_at` is set to the current UTC
/// time in RFC 3339 form.
///
/// # Errors
///
/// Fails when the INSERT fails.
pub fn record_subscription_event_with_conn(
    conn: &Connection,
    sub_id: &str,
    correlation_id: &str,
    event_type: &str,
    payload: &str,
) -> Result<()> {
    let recorded_at = chrono::Utc::now().to_rfc3339();
    conn.execute(
        "INSERT INTO subscription_events \
         (subscription_id, correlation_id, event_type, payload, delivered_at, delivery_status) \
         VALUES (?1, ?2, ?3, ?4, ?5, 'pending')",
        params![sub_id, correlation_id, event_type, payload, recorded_at],
    )
    .map(|_rows| ())
    .context("subscription_events insert")
}
fn update_event_status(db_path: &std::path::Path, correlation_id: &str, ok: bool) {
let Ok(conn) = Connection::open(db_path) else {
return;
};
update_event_status_with_conn(&conn, correlation_id, ok);
}
fn update_event_status_with_conn(conn: &Connection, correlation_id: &str, ok: bool) {
let status = if ok { "ack" } else { "failed" };
let _ = conn.execute(
"UPDATE subscription_events SET delivery_status = ?1 WHERE correlation_id = ?2",
params![status, correlation_id],
);
}
/// Path-based variant of `record_dlq_with_conn`: opens a fresh connection to
/// `db_path` and writes the dead-letter row on it.
///
/// # Errors
///
/// Fails when the database cannot be opened, or for any reason
/// `record_dlq_with_conn` fails (including the per-subscription cap).
// The argument list mirrors the columns of the DLQ row.
#[allow(clippy::too_many_arguments)]
pub fn record_dlq(
db_path: &std::path::Path,
sub_id: &str,
correlation_id: &str,
event_type: &str,
payload: &str,
retry_count: i64,
last_error: &str,
first_failed_at: &str,
last_failed_at: &str,
) -> Result<()> {
let conn = Connection::open(db_path).context("subscription_dlq open")?;
record_dlq_with_conn(
&conn,
sub_id,
correlation_id,
event_type,
payload,
retry_count,
last_error,
first_failed_at,
last_failed_at,
)
}
/// Maximum number of dead-letter rows kept per subscription;
/// `record_dlq_with_conn` refuses further inserts once this depth is reached.
pub const MAX_SUBSCRIPTION_DLQ_ROWS: i64 = 10_000;
/// Appends a failed delivery to `subscription_dlq`, enforcing the
/// per-subscription cap.
///
/// The current number of DLQ rows for `sub_id` is counted first. At or above
/// `MAX_SUBSCRIPTION_DLQ_ROWS` the insert is refused: an overflow metric is
/// recorded, a warning is logged and an error is returned. The count and the
/// insert are separate statements.
///
/// # Errors
///
/// Fails when the depth probe fails, the cap has been reached
/// (`dlq_overflow`), or the INSERT fails.
// The argument list mirrors the columns of the DLQ row.
#[allow(clippy::too_many_arguments)]
pub fn record_dlq_with_conn(
    conn: &Connection,
    sub_id: &str,
    correlation_id: &str,
    event_type: &str,
    payload: &str,
    retry_count: i64,
    last_error: &str,
    first_failed_at: &str,
    last_failed_at: &str,
) -> Result<()> {
    let depth: i64 = conn
        .query_row(
            "SELECT COUNT(*) FROM subscription_dlq WHERE subscription_id = ?1",
            params![sub_id],
            |row| row.get(0),
        )
        .context("subscription_dlq depth probe")?;

    let has_room = depth < MAX_SUBSCRIPTION_DLQ_ROWS;
    if has_room {
        conn.execute(
            "INSERT INTO subscription_dlq \
             (subscription_id, correlation_id, event_type, payload, retry_count, last_error, first_failed_at, last_failed_at) \
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
            params![
                sub_id,
                correlation_id,
                event_type,
                payload,
                retry_count,
                last_error,
                first_failed_at,
                last_failed_at,
            ],
        )
        .context("subscription_dlq insert")?;
        return Ok(());
    }

    crate::metrics::record_subscription_dlq_overflow();
    tracing::warn!(
        subscription_id = %sub_id,
        correlation_id = %correlation_id,
        event_type = %event_type,
        depth = depth,
        cap = MAX_SUBSCRIPTION_DLQ_ROWS,
        "dlq_overflow: refusing subscription_dlq insert — per-subscription cap reached",
    );
    Err(anyhow!(
        "dlq_overflow: subscription {sub_id} dlq at cap ({MAX_SUBSCRIPTION_DLQ_ROWS}); drain before further inserts"
    ))
}
pub fn list_dlq(conn: &Connection, subscription_id: Option<&str>) -> Result<Vec<DlqEntry>> {
let mut out = Vec::new();
if let Some(sub_id) = subscription_id {
let mut stmt = conn.prepare(
"SELECT id, subscription_id, correlation_id, event_type, payload, retry_count, last_error, first_failed_at, last_failed_at \
FROM subscription_dlq WHERE subscription_id = ?1 ORDER BY id ASC",
)?;
let rows = stmt.query_map(params![sub_id], dlq_row_to_entry)?;
for r in rows {
out.push(r?);
}
} else {
let mut stmt = conn.prepare(
"SELECT id, subscription_id, correlation_id, event_type, payload, retry_count, last_error, first_failed_at, last_failed_at \
FROM subscription_dlq ORDER BY id ASC",
)?;
let rows = stmt.query_map([], dlq_row_to_entry)?;
for r in rows {
out.push(r?);
}
}
Ok(out)
}
/// Decodes one `subscription_dlq` row into a [`DlqEntry`].
///
/// Columns are read by position, so the SELECT must list them in this order:
/// id, subscription_id, correlation_id, event_type, payload, retry_count,
/// last_error, first_failed_at, last_failed_at.
fn dlq_row_to_entry(row: &rusqlite::Row) -> rusqlite::Result<DlqEntry> {
    let id = row.get(0)?;
    let subscription_id = row.get(1)?;
    let correlation_id = row.get(2)?;
    let event_type = row.get(3)?;
    let payload = row.get(4)?;
    let retry_count = row.get(5)?;
    let last_error = row.get(6)?;
    let first_failed_at = row.get(7)?;
    let last_failed_at = row.get(8)?;
    Ok(DlqEntry {
        id,
        subscription_id,
        correlation_id,
        event_type,
        payload,
        retry_count,
        last_error,
        first_failed_at,
        last_failed_at,
    })
}
pub fn replay_subscription_events(
conn: &Connection,
subscription_id: &str,
since_rfc3339: &str,
) -> Result<Vec<SubscriptionEvent>> {
let mut stmt = conn.prepare(
"SELECT id, subscription_id, correlation_id, event_type, payload, delivered_at, delivery_status \
FROM subscription_events \
WHERE subscription_id = ?1 AND delivered_at >= ?2 \
ORDER BY delivered_at ASC, id ASC",
)?;
let rows = stmt.query_map(params![subscription_id, since_rfc3339], |row| {
Ok(SubscriptionEvent {
id: row.get(0)?,
subscription_id: row.get(1)?,
correlation_id: row.get(2)?,
event_type: row.get(3)?,
payload: row.get(4)?,
delivered_at: row.get(5)?,
delivery_status: row.get(6)?,
})
})?;
let mut out = Vec::new();
for r in rows {
out.push(r.context("subscription_events row decode")?);
}
Ok(out)
}
/// Builds the JSON reply for a subscription replay request.
///
/// The object carries the subscription id (keyed by
/// `field_names::SUBSCRIPTION_ID`), the `since` bound echoed back, the number
/// of events found (`count`) and the events themselves (`events`).
///
/// # Errors
/// Propagates any failure from [`replay_subscription_events`].
pub fn memory_subscription_replay(
    conn: &Connection,
    subscription_id: &str,
    since_rfc3339: &str,
) -> Result<serde_json::Value> {
    let replayed = replay_subscription_events(conn, subscription_id, since_rfc3339)?;
    let total = replayed.len();
    let reply = serde_json::json!({
        (field_names::SUBSCRIPTION_ID): subscription_id,
        "since": since_rfc3339,
        "count": total,
        "events": replayed,
    });
    Ok(reply)
}
fn record_dispatch(db_path: &std::path::Path, sub_id: &str, ok: bool) {
let Ok(conn) = Connection::open(db_path) else {
return;
};
record_dispatch_with_conn(&conn, sub_id, ok);
}
fn record_dispatch_with_conn(conn: &Connection, sub_id: &str, ok: bool) {
let now = chrono::Utc::now().to_rfc3339();
let sql = if ok {
"UPDATE subscriptions SET dispatch_count = dispatch_count + 1, last_dispatched_at = ?1 WHERE id = ?2"
} else {
"UPDATE subscriptions SET dispatch_count = dispatch_count + 1, failure_count = failure_count + 1, last_dispatched_at = ?1 WHERE id = ?2"
};
let _ = conn.execute(sql, params![now, sub_id]);
}
#[cfg(test)]
mod tests {
use super::*;
static SSRF_ENV_GUARD: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Public HTTPS URLs, with or without port and query string, pass `validate_url`.
#[test]
fn https_allowed() {
    for url in [
        "https://example.com/hook",
        "https://api.example.com:8443/hook?x=1",
    ] {
        assert!(validate_url(url).is_ok());
    }
}
/// Checks the shape of a user-agent string built from `CARGO_PKG_VERSION`.
///
/// NOTE(review): `ua` and `expected` are built from the same expression, so
/// the equality assertion cannot fail; this test never looks at the value the
/// dispatcher actually sends — consider asserting against that instead.
#[test]
fn webhook_user_agent_tracks_cargo_pkg_version() {
    let version = env!("CARGO_PKG_VERSION");
    let ua = format!("ai-memory/{version}");
    let expected = format!("ai-memory/{version}");
    assert_eq!(ua, expected);
    assert_ne!(ua, "ai-memory/0.6.0.0");
    let prefix = "ai-memory/";
    assert!(ua.starts_with(prefix));
    assert!(ua.len() > prefix.len());
}
/// With loopback allowed, plain `http://` is accepted only for loopback hosts.
#[test]
fn http_only_to_loopback() {
    let loopback = [
        "http://localhost/hook",
        "http://127.0.0.1:8080/hook",
        "http://[::1]/hook",
    ];
    for url in loopback {
        assert!(validate_url_with(url, true).is_ok());
    }

    let public = ["http://example.com/hook", "http://8.8.8.8/hook"];
    for url in public {
        assert!(validate_url_with(url, true).is_err());
    }
}
/// H11: with `allow_loopback = false` every loopback URL is rejected and the
/// error text mentions either "loopback" or "SSRF".
#[test]
fn loopback_rejected_by_default_h11() {
    let loopback_targets = [
        "http://127.0.0.1:5432/hook",
        "http://localhost/hook",
        "http://[::1]/hook",
        "https://127.0.0.1/hook",
        "https://localhost/hook",
    ];
    for target in loopback_targets {
        let outcome = validate_url_with(target, false);
        assert!(
            outcome.is_err(),
            "loopback URL {} must be rejected when allow_loopback=false (H11), got {:?}",
            target,
            outcome
        );
        let reason = outcome.unwrap_err().to_string();
        let explains_policy = ["loopback", "SSRF"]
            .into_iter()
            .any(|needle| reason.contains(needle));
        assert!(
            explains_policy,
            "rejection message should explain loopback policy, got: {}",
            reason
        );
    }
}
/// H11: loopback URLs are accepted once the caller opts in.
#[test]
fn loopback_accepted_when_opted_in_h11() {
    for url in [
        "http://127.0.0.1:9999/hook",
        "http://localhost/hook",
        "http://[::1]/hook",
    ] {
        assert!(validate_url_with(url, true).is_ok());
    }
}
/// Private, link-local and unique-local address literals are rejected.
#[test]
fn private_ranges_blocked() {
    let blocked = [
        "https://10.0.0.1/hook",
        "https://192.168.1.1/hook",
        "https://172.16.0.1/hook",
        "https://169.254.1.1/hook",
        "https://[fc00::1]/hook",
        "https://[fe80::1]/hook",
    ];
    for url in blocked {
        assert!(validate_url(url).is_err());
    }
}
/// Unsupported schemes, scheme-less text and the empty string are rejected.
#[test]
fn nonsense_rejected() {
    for input in ["ftp://example.com", "notaurl", ""] {
        assert!(validate_url(input).is_err());
    }
}
/// IPv4-mapped IPv6 spellings of 127.0.0.1 are rejected when loopback is off.
#[test]
fn rejects_v4_mapped_ipv6_loopback() {
    for url in [
        "https://[::ffff:127.0.0.1]/hook",
        "https://[::ffff:7f00:1]/hook",
    ] {
        assert!(validate_url_with(url, false).is_err());
    }
}
/// IPv4-mapped IPv6 forms of private, link-local and unspecified IPv4
/// addresses are rejected.
#[test]
fn rejects_v4_mapped_ipv6_private() {
    let mapped = [
        "https://[::ffff:10.0.0.1]/hook",
        "https://[::ffff:192.168.1.1]/hook",
        "https://[::ffff:172.16.0.1]/hook",
        "https://[::ffff:169.254.1.1]/hook",
        "https://[::ffff:0.0.0.0]/hook",
    ];
    for url in mapped {
        assert!(validate_url_with(url, false).is_err());
    }
}
/// Addresses under the `64:ff9b::` prefix that embed loopback, private or
/// link-local IPv4 addresses are rejected.
#[test]
fn rejects_nat64_well_known_prefix() {
    for url in [
        "https://[64:ff9b::127.0.0.1]/hook",
        "https://[64:ff9b::10.0.0.1]/hook",
        "https://[64:ff9b::169.254.1.1]/hook",
    ] {
        assert!(validate_url_with(url, false).is_err());
    }
}
/// The IPv4-mapped loopback literal is accepted when loopback is allowed.
#[test]
fn allows_v4_mapped_loopback_when_opted_in() {
    let verdict = validate_url_with("http://[::ffff:127.0.0.1]/hook", true);
    assert!(verdict.is_ok());
}
/// Pins `hmac_sha256_hex` to a fixed digest for the hex-encoded key "key" and
/// the "quick brown fox" message.
#[test]
fn hmac_sha256_stable() {
    const EXPECTED_DIGEST: &str =
        "f7bc83f430538424b13298e6aa6fb143ef4d59a14946175997479dbc2d1a3cd8";
    let hex_key = hex::encode_fallback("key".as_bytes());
    let digest = hmac_sha256_hex(&hex_key, "The quick brown fox jumps over the lazy dog");
    assert_eq!(digest, EXPECTED_DIGEST);
}
/// Exercises `matches_filters` with no `event_types` list: the events CSV
/// (including `*`), the namespace filter and the agent filter.
#[test]
fn filter_wildcards() {
    // `*` matches any event.
    assert!(matches_filters("*", None, None, None, "memory_store", "ns", None));
    // A CSV list matches when it contains the event...
    assert!(matches_filters("memory_store,memory_delete", None, None, None, "memory_store", "ns", None));
    // ...and does not match when it lacks the event.
    assert!(!matches_filters("memory_delete", None, None, None, "memory_store", "ns", None));
    // The namespace filter must equal the event namespace.
    assert!(matches_filters("*", None, Some("foo"), None, "memory_store", "foo", None));
    assert!(!matches_filters("*", None, Some("foo"), None, "memory_store", "bar", None));
    // The agent filter must equal the event agent.
    assert!(matches_filters("*", None, None, Some("alice"), "memory_store", "ns", Some("alice")));
    assert!(!matches_filters("*", None, None, Some("alice"), "memory_store", "ns", Some("bob")));
}
/// When an `event_types` list is supplied it decides the match and the legacy
/// events CSV is ignored; an empty list matches nothing.
#[test]
fn filter_event_types_overrides_legacy_events() {
    // The list wins over a `*` CSV.
    let store_only: Vec<String> = vec!["memory_store".to_string()];
    assert!(matches_filters("*", Some(&store_only), None, None, "memory_store", "ns", None));
    assert!(!matches_filters("*", Some(&store_only), None, None, "memory_delete", "ns", None));

    // The list wins over a CSV that names a different event.
    let promote_and_link: Vec<String> = vec![
        "memory_promote".to_string(),
        "memory_link_created".to_string(),
    ];
    assert!(matches_filters("memory_store", Some(&promote_and_link), None, None, "memory_promote", "ns", None));
    assert!(!matches_filters("memory_store", Some(&promote_and_link), None, None, "memory_store", "ns", None));

    // An empty list opts in to nothing.
    let nothing_selected: Vec<String> = vec![];
    assert!(!matches_filters("*", Some(&nothing_selected), None, None, "memory_store", "ns", None));
}
/// With loopback allowed, `validate_url_dns_with` accepts IPv4 loopback
/// literals and `localhost`.
#[test]
fn test_validate_url_dns_accepts_loopback_v4() {
    let cases = [
        (
            "http://127.0.0.1/foo",
            "127.0.0.1 should be accepted by validate_url_dns when opted in",
        ),
        (
            "http://127.0.0.1:8080/",
            "127.0.0.1:8080 should be accepted by validate_url_dns when opted in",
        ),
        (
            "http://localhost/",
            "localhost should be accepted by validate_url_dns when opted in",
        ),
    ];
    for (url, why) in cases {
        assert!(validate_url_dns_with(url, true).is_ok(), "{why}");
    }
}
/// With loopback allowed, both the short and the fully expanded spelling of
/// the IPv6 loopback address are accepted.
#[test]
fn test_validate_url_dns_accepts_loopback_v6() {
    let cases = [
        (
            "http://[::1]/",
            "[::1] should be accepted by validate_url_dns when opted in",
        ),
        (
            "http://[0:0:0:0:0:0:0:1]/",
            "[::1] expanded form should be accepted when opted in",
        ),
    ];
    for (url, why) in cases {
        assert!(validate_url_dns_with(url, true).is_ok(), "{why}");
    }
}
/// H11: with `allow_loopback = false`, IPv4 and IPv6 loopback literals are
/// rejected by `validate_url_dns_with`.
#[test]
fn test_validate_url_dns_rejects_loopback_by_default_h11() {
    let cases = [
        (
            "http://127.0.0.1/foo",
            "127.0.0.1 must be rejected by validate_url_dns when allow_loopback=false (H11)",
        ),
        (
            "http://[::1]/",
            "[::1] must be rejected by validate_url_dns when allow_loopback=false (H11)",
        ),
    ];
    for (url, why) in cases {
        assert!(validate_url_dns_with(url, false).is_err(), "{why}");
    }
}
/// A link-local IPv6 literal is rejected by `validate_url_dns`.
#[test]
fn test_validate_url_dns_rejects_link_local_ipv6() {
    let outcome = validate_url_dns("http://[fe80::1]/");
    assert!(
        outcome.is_err(),
        "fe80::1 must be rejected as link-local IPv6, got {:?}",
        outcome
    );
}
/// The cloud metadata address 169.254.169.254 is rejected.
#[test]
fn test_validate_url_dns_rejects_aws_metadata() {
    let outcome = validate_url_dns("http://169.254.169.254/latest/meta-data/");
    assert!(
        outcome.is_err(),
        "AWS metadata IP must be rejected, got {:?}",
        outcome
    );
}
/// Sample addresses from each RFC 1918 range are rejected.
#[test]
fn test_validate_url_dns_rejects_rfc1918_private_ranges() {
    let private_urls = [
        "http://10.0.0.1/",
        "http://172.16.0.1/",
        "http://172.31.255.255/",
        "http://192.168.1.1/",
    ];
    private_urls.into_iter().for_each(|target| {
        let outcome = validate_url_dns(target);
        assert!(
            outcome.is_err(),
            "{} must be rejected as RFC1918, got {:?}",
            target,
            outcome
        );
    });
}
/// A public IP literal is accepted; a hostname is only exercised, not asserted.
#[test]
fn test_validate_url_dns_accepts_public_ip_or_dns() {
    let literal = validate_url_dns("https://1.1.1.1/");
    assert!(literal.is_ok(), "public IP literal must be accepted");
    // The hostname result is deliberately discarded — presumably because it
    // depends on live DNS in the test environment (TODO confirm).
    let _ = validate_url_dns("https://example.com/");
}
/// #1053: a hostname that cannot resolve makes `validate_url_dns` return an
/// error whose text names the fail-closed posture and the override variable.
///
/// Holds `SSRF_ENV_GUARD` so it does not overlap with the test that sets
/// `AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL`.
#[test]
fn test_validate_url_dns_fails_closed_on_dns_failure_1053() {
    let _serialised = SSRF_ENV_GUARD
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    // A 70-character label, used here as a name that will not resolve.
    let label = "a".repeat(70);
    let target = format!("https://{label}.fxf1-test./");
    let outcome = validate_url_dns(&target);
    assert!(
        outcome.is_err(),
        "#1053: SSRF guard MUST fail-closed on DNS resolution failure (oversized label rejected at getaddrinfo shape check); got {:?}",
        outcome
    );
    let reason = outcome.unwrap_err().to_string();
    let mentions_posture = reason.contains("failing CLOSED");
    let mentions_override = reason.contains("AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL");
    assert!(
        mentions_posture && mentions_override,
        "#1053: failure message MUST reference fail-closed posture + env-var escape hatch; got {:?}",
        reason
    );
}
/// #1053: with `AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL=1` set, a hostname that
/// cannot resolve is accepted again.
///
/// The variable is removed by a drop guard, so it is cleaned up even if
/// `validate_url_dns` panics; otherwise the leaked override would make the
/// fail-closed test fail for an unrelated reason.
#[test]
fn test_validate_url_dns_fail_open_env_overrides_1053() {
    const ALLOW_DNS_FAIL: &str = "AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL";

    /// Removes the override variable when dropped, including during unwinding.
    struct RemoveOverrideOnDrop;
    impl Drop for RemoveOverrideOnDrop {
        fn drop(&mut self) {
            // SAFETY: this runs while `SSRF_ENV_GUARD` is still held (the guard
            // is declared before this value, so it is dropped after it), which
            // serialises the mutation against the other test taking that lock.
            // NOTE(review): tests that call `validate_url_dns` without the lock
            // may still read the environment concurrently — confirm.
            unsafe {
                std::env::remove_var(ALLOW_DNS_FAIL);
            }
        }
    }

    let _env_guard = SSRF_ENV_GUARD.lock().unwrap_or_else(|e| e.into_inner());
    // SAFETY: see `RemoveOverrideOnDrop::drop`; the same lock is held here.
    unsafe {
        std::env::set_var(ALLOW_DNS_FAIL, "1");
    }
    let cleanup = RemoveOverrideOnDrop;
    let oversized_label = "a".repeat(70);
    let url = format!("https://{oversized_label}.fxf1-test./");
    let res = validate_url_dns(&url);
    // Remove the override before asserting, exactly as the test did before.
    drop(cleanup);
    assert!(
        res.is_ok(),
        "#1053: AI_MEMORY_SSRF_GUARD_ALLOW_DNS_FAIL=1 MUST restore the legacy permissive posture; got {res:?}"
    );
}
/// The unspecified addresses `0.0.0.0` and `[::]` are rejected.
#[test]
fn test_validate_url_dns_rejects_unspecified_addresses() {
    let v4_outcome = validate_url_dns("http://0.0.0.0/");
    let v6_outcome = validate_url_dns("http://[::]/");
    for (host, outcome) in [("0.0.0.0", v4_outcome), ("[::]", v6_outcome)] {
        assert!(
            outcome.is_err(),
            "{} should be rejected as unspecified, got {:?}",
            host,
            outcome
        );
    }
}
/// Input without a scheme is an error.
#[test]
fn test_validate_url_dns_missing_scheme() {
    let outcome = validate_url_dns("not-a-url");
    assert!(outcome.is_err(), "missing scheme must Err, got {:?}", outcome);
}
use tempfile::NamedTempFile;
/// Creates a temp file, runs `crate::db::open` on its path, and returns the
/// temp-file handle with the path.
///
/// The connection from `crate::db::open` is dropped immediately. Callers must
/// keep the returned `NamedTempFile` alive for as long as they use the path.
fn fresh_db() -> (NamedTempFile, std::path::PathBuf) {
    let temp = NamedTempFile::new().expect("tempfile");
    let db_path = temp.path().to_path_buf();
    drop(crate::db::open(&db_path).expect("db::open"));
    (temp, db_path)
}
/// Wiremock responder that acknowledges a webhook delivery.
///
/// Replies 200 with `{"status": "ack", "correlation_id": <id>}` where `<id>`
/// is the request's `x-ai-memory-correlation-id` header, or the empty string
/// when the header is missing or not valid text.
struct AckEcho;

impl wiremock::Respond for AckEcho {
    fn respond(&self, request: &wiremock::Request) -> wiremock::ResponseTemplate {
        let echoed_id = request
            .headers
            .get("x-ai-memory-correlation-id")
            .and_then(|value| value.to_str().ok())
            .unwrap_or("")
            .to_string();
        let ack = serde_json::json!({
            "status": "ack",
            "correlation_id": echoed_id,
        });
        wiremock::ResponseTemplate::new(200).set_body_json(ack)
    }
}
/// An inserted subscription comes back from `list` with every field intact
/// and both counters at zero.
#[test]
fn insert_persists_and_list_returns_row() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    let request = NewSubscription {
        url: "https://example.com/hook",
        events: "memory_store",
        secret: Some("s3cret"),
        namespace_filter: Some("ns1"),
        agent_filter: Some("alice"),
        created_by: Some("op"),
        event_types: None,
    };
    let new_id = insert(&conn, &request).unwrap();
    assert!(!new_id.is_empty());

    let listed = list(&conn, None).unwrap();
    assert_eq!(listed.len(), 1);
    let row = &listed[0];
    assert_eq!(row.id, new_id);
    assert_eq!(row.url, "https://example.com/hook");
    assert_eq!(row.events, "memory_store");
    assert_eq!(row.namespace_filter.as_deref(), Some("ns1"));
    assert_eq!(row.agent_filter.as_deref(), Some("alice"));
    assert_eq!(row.created_by.as_deref(), Some("op"));
    assert_eq!(row.dispatch_count, 0);
    assert_eq!(row.failure_count, 0);
}
/// `insert` fails when the URL is not valid.
#[test]
fn insert_rejects_invalid_url() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    let request = NewSubscription {
        url: "not-a-url",
        events: "*",
        secret: None,
        namespace_filter: None,
        agent_filter: None,
        created_by: None,
        event_types: None,
    };
    let outcome = insert(&conn, &request);
    assert!(outcome.is_err(), "insert must reject invalid URL");
}
/// The `secret_hash` column holds `sha256_hex(secret)`, not the plaintext.
#[test]
fn insert_hashes_secret_before_persisting() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    let plaintext = "super-shared-secret";
    let request = NewSubscription {
        url: "https://example.com/h",
        events: "*",
        secret: Some(plaintext),
        namespace_filter: None,
        agent_filter: None,
        created_by: None,
        event_types: None,
    };
    let new_id = insert(&conn, &request).unwrap();

    let column: Option<String> = conn
        .query_row(
            "SELECT secret_hash FROM subscriptions WHERE id = ?1",
            params![new_id],
            |row| row.get(0),
        )
        .unwrap();
    let persisted = column.expect("secret_hash should be set");
    assert_ne!(persisted, plaintext, "plaintext secret must not be stored");
    assert_eq!(persisted, sha256_hex(plaintext));
}
/// Without a secret, `secret_hash` is stored as NULL.
#[test]
fn insert_no_secret_stores_null() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    let request = NewSubscription {
        url: "https://example.com/h",
        events: "*",
        secret: None,
        namespace_filter: None,
        agent_filter: None,
        created_by: None,
        event_types: None,
    };
    let new_id = insert(&conn, &request).unwrap();

    let column: Option<String> = conn
        .query_row(
            "SELECT secret_hash FROM subscriptions WHERE id = ?1",
            params![new_id],
            |row| row.get(0),
        )
        .unwrap();
    assert!(column.is_none(), "missing secret must persist as NULL");
}
/// Deleting an existing subscription returns `true` and leaves the list empty.
#[test]
fn delete_returns_true_when_row_removed() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    let request = NewSubscription {
        url: "https://example.com/h",
        events: "*",
        secret: None,
        namespace_filter: None,
        agent_filter: None,
        created_by: None,
        event_types: None,
    };
    let new_id = insert(&conn, &request).unwrap();

    let removed = delete(&conn, &new_id, None).unwrap();
    assert!(removed);
    let remaining = list(&conn, None).unwrap();
    assert!(remaining.is_empty());
}
/// Deleting an unknown id returns `false` rather than an error.
#[test]
fn delete_returns_false_when_row_missing() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    let removed = delete(&conn, "nope", None).unwrap();
    assert!(!removed);
}
/// `list` returns the most recently inserted subscription first.
#[test]
fn list_orders_by_created_at_desc() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    // Inserts a catch-all subscription for `url` and returns its id.
    let add_subscription = |url: &str| {
        insert(
            &conn,
            &NewSubscription {
                url,
                events: "*",
                secret: None,
                namespace_filter: None,
                agent_filter: None,
                created_by: None,
                event_types: None,
            },
        )
        .unwrap()
    };

    let older = add_subscription("https://a.example.com/");
    // Pause between inserts — presumably so the two `created_at` values are
    // far enough apart for the ordering to be unambiguous (TODO confirm).
    std::thread::sleep(std::time::Duration::from_millis(1100));
    let newer = add_subscription("https://b.example.com/");

    let listed = list(&conn, None).unwrap();
    assert_eq!(listed.len(), 2);
    assert_eq!(listed[0].id, newer);
    assert_eq!(listed[1].id, older);
}
/// `sha256_hex` reproduces the standard SHA-256 digests of "" and "abc".
#[test]
fn sha256_hex_known_vector() {
    let vectors = [
        (
            "",
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
        ),
        (
            "abc",
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad",
        ),
    ];
    for (input, digest) in vectors {
        assert_eq!(sha256_hex(input), digest);
    }
}
/// `hex_decode` yields the expected bytes for valid hex and `None` for
/// odd-length or non-hex input.
#[test]
fn hex_decode_round_trip_and_invalid() {
    let decoded = hex_decode("deadbeef").expect("valid hex");
    assert_eq!(decoded, vec![0xde, 0xad, 0xbe, 0xef]);
    for malformed in ["abc", "zz"] {
        assert!(hex_decode(malformed).is_none());
    }
}
/// #1048: an absent secret passes validation.
#[test]
fn validate_hmac_secret_hex_accepts_none_1048() {
    let verdict = validate_hmac_secret_hex(None);
    assert!(verdict.is_ok());
}
/// #1048: well-formed hex secrets of several lengths pass validation.
#[test]
fn validate_hmac_secret_hex_accepts_valid_hex_1048() {
    let sixty_four = "a".repeat(64);
    for secret in ["deadbeef", "0123456789abcdef", sixty_four.as_str()] {
        assert!(validate_hmac_secret_hex(Some(secret)).is_ok());
    }
}
/// #1048: a non-hex secret is rejected, and the message states the problem
/// and the `openssl rand -hex 32` remedy.
#[test]
fn validate_hmac_secret_hex_rejects_non_hex_1048() {
    let err = validate_hmac_secret_hex(Some("not-a-hex-key!!"))
        .expect_err("non-hex MUST fail validation");
    let names_problem = err.contains("not valid hex");
    let names_remedy = err.contains("openssl rand -hex 32");
    assert!(
        names_problem && names_remedy,
        "#1048: failure msg MUST reference invalid hex + remediation; got: {err}"
    );
}
/// #1048: an odd-length hex string is rejected as not valid hex.
#[test]
fn validate_hmac_secret_hex_rejects_odd_length_1048() {
    let rejection = validate_hmac_secret_hex(Some("abc")).expect_err("odd-length MUST fail");
    assert!(rejection.contains("not valid hex"));
}
/// #1039: for empty, short, long and non-ASCII bodies, the signature is
/// always 64 ASCII hex digits.
#[test]
fn hmac_sha256_hex_output_is_fixed_64_chars_1039() {
    let kilobyte = "x".repeat(1024);
    for body in ["", "x", kilobyte.as_str(), "🦀"] {
        let signature = hmac_sha256_hex("deadbeef", body);
        assert_eq!(
            signature.len(),
            64,
            "#1039: HMAC-SHA256 hex output MUST be fixed 64 chars; got len={} for body={:?}",
            signature.len(),
            body
        );
        let all_hex = signature.chars().all(|c| c.is_ascii_hexdigit());
        assert!(
            all_hex,
            "#1039: HMAC hex output MUST be all-hex; got {}",
            signature
        );
    }
}
/// A 200-character key still produces a 64-character signature.
#[test]
fn hmac_long_key_is_hashed_to_fit_block() {
    let oversized_key = "a".repeat(200);
    let signature = hmac_sha256_hex(&oversized_key, "hello");
    assert_eq!(signature.len(), 64);
}
/// A key that is not valid hex still yields a well-formed 64-digit hex
/// signature instead of failing.
#[test]
fn hmac_invalid_hex_key_falls_back_to_raw_bytes() {
    let signature = hmac_sha256_hex("not-a-hex-key!!", "hello");
    assert_eq!(signature.len(), 64);
    let all_hex = signature.chars().all(|c| c.is_ascii_hexdigit());
    assert!(all_hex);
}
/// Whitespace around entries in the events CSV is tolerated, and a `*` entry
/// anywhere in the list matches every event.
#[test]
fn matches_filters_event_with_whitespace_and_star() {
    // `*` after a space still acts as a wildcard.
    assert!(matches_filters("memory_store, *", None, None, None, "anything", "ns", None));
    // Entries padded with spaces still match.
    assert!(matches_filters(" memory_delete , memory_store ", None, None, None, "memory_store", "ns", None));
}
/// An agent filter never matches an event that carries no agent id.
#[test]
fn matches_filters_agent_filter_requires_some() {
    let matched = matches_filters("*", None, None, Some("alice"), "memory_store", "ns", None);
    assert!(!matched);
}
/// Two successful `record_dispatch` calls raise `dispatch_count` to 2 and
/// leave `failure_count` at 0.
#[test]
fn record_dispatch_increments_counts_on_success() {
    let (_tmp, db_file) = fresh_db();
    // The setup connection is dropped before `record_dispatch` opens its own.
    let setup_conn = Connection::open(&db_file).unwrap();
    let sub_id = insert(
        &setup_conn,
        &NewSubscription {
            url: "https://example.com/h",
            events: "*",
            secret: None,
            namespace_filter: None,
            agent_filter: None,
            created_by: None,
            event_types: None,
        },
    )
    .unwrap();
    drop(setup_conn);

    for _ in 0..2 {
        record_dispatch(&db_file, &sub_id, true);
    }

    let check_conn = Connection::open(&db_file).unwrap();
    let (dispatched, failed): (i64, i64) = check_conn
        .query_row(
            "SELECT dispatch_count, failure_count FROM subscriptions WHERE id = ?1",
            params![sub_id],
            |row| Ok((row.get(0)?, row.get(1)?)),
        )
        .unwrap();
    assert_eq!(dispatched, 2, "two successful dispatches must bump dispatch_count");
    assert_eq!(failed, 0, "successes must not bump failure_count");
}
/// A failed dispatch bumps both `dispatch_count` and `failure_count`.
#[test]
fn record_dispatch_increments_failure_on_err() {
    let (_tmp, db_file) = fresh_db();
    // The setup connection is dropped before `record_dispatch` opens its own.
    let setup_conn = Connection::open(&db_file).unwrap();
    let sub_id = insert(
        &setup_conn,
        &NewSubscription {
            url: "https://example.com/h",
            events: "*",
            secret: None,
            namespace_filter: None,
            agent_filter: None,
            created_by: None,
            event_types: None,
        },
    )
    .unwrap();
    drop(setup_conn);

    record_dispatch(&db_file, &sub_id, false);

    let check_conn = Connection::open(&db_file).unwrap();
    let (dispatched, failed): (i64, i64) = check_conn
        .query_row(
            "SELECT dispatch_count, failure_count FROM subscriptions WHERE id = ?1",
            params![sub_id],
            |row| Ok((row.get(0)?, row.get(1)?)),
        )
        .unwrap();
    assert_eq!(dispatched, 1, "failed dispatch still bumps dispatch_count");
    assert_eq!(failed, 1, "failure must bump failure_count");
}
/// `record_dispatch` on an unknown id neither panics nor creates a row.
#[test]
fn record_dispatch_nonexistent_id_does_not_panic() {
    let (_tmp, db_file) = fresh_db();
    for delivered in [true, false] {
        record_dispatch(&db_file, "no-such-id", delivered);
    }
    let conn = Connection::open(&db_file).unwrap();
    let row_count: i64 = conn
        .query_row("SELECT COUNT(*) FROM subscriptions", [], |row| row.get(0))
        .unwrap();
    assert_eq!(row_count, 0);
}
/// `record_dispatch` returns quietly when the database cannot be opened.
#[test]
fn record_dispatch_unopenable_db_path_is_noop() {
    let unopenable = std::path::Path::new("/nonexistent-dir-w12c/does-not-exist.db");
    record_dispatch(unopenable, "x", true);
}
/// `load_secret_hash` returns the SHA-256 hex of the secret given at insert.
#[test]
fn load_secret_hash_returns_stored_hash() {
    let (_tmp, db_file) = fresh_db();
    let setup_conn = Connection::open(&db_file).unwrap();
    let sub_id = insert(
        &setup_conn,
        &NewSubscription {
            url: "https://example.com/h",
            events: "*",
            secret: Some("topsecret"),
            namespace_filter: None,
            agent_filter: None,
            created_by: None,
            event_types: None,
        },
    )
    .unwrap();
    drop(setup_conn);

    let loaded = load_secret_hash(&db_file, &sub_id).unwrap();
    assert_eq!(loaded, Some(sha256_hex("topsecret")));
}
/// Looking up the secret hash of an unknown subscription is an error.
#[test]
fn load_secret_hash_missing_id_errs() {
    let (_tmp, db_file) = fresh_db();
    let outcome = load_secret_hash(&db_file, "missing-id");
    assert!(outcome.is_err(), "missing subscription id must surface as Err");
}
/// Dispatching with no subscriptions in the database completes without panic.
#[test]
fn dispatch_event_no_subs_is_noop() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    dispatch_event(&conn, "memory_store", "m1", "ns", None, &db_file);
}
/// A subscription for `memory_delete` gets no dispatch for a `memory_store`
/// event: both counters are still zero right after `dispatch_event` returns.
#[test]
fn dispatch_event_filter_mismatch_skips_send() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    let delete_only = NewSubscription {
        url: "https://example.com/h",
        events: "memory_delete",
        secret: None,
        namespace_filter: None,
        agent_filter: None,
        created_by: None,
        event_types: None,
    };
    insert(&conn, &delete_only).unwrap();

    dispatch_event(&conn, "memory_store", "m1", "ns", None, &db_file);

    let (dispatched, failed): (i64, i64) = conn
        .query_row(
            "SELECT dispatch_count, failure_count FROM subscriptions",
            [],
            |row| Ok((row.get(0)?, row.get(1)?)),
        )
        .unwrap();
    assert_eq!(dispatched, 0);
    assert_eq!(failed, 0);
}
/// A subscription limited to one namespace gets no dispatch for an event in
/// another: both counters are still zero right after `dispatch_event` returns.
#[test]
fn dispatch_event_namespace_filter_mismatch_skips() {
    let (_tmp, db_file) = fresh_db();
    let conn = Connection::open(&db_file).unwrap();
    let single_namespace = NewSubscription {
        url: "https://example.com/h",
        events: "*",
        secret: None,
        namespace_filter: Some("only-this-ns"),
        agent_filter: None,
        created_by: None,
        event_types: None,
    };
    insert(&conn, &single_namespace).unwrap();

    dispatch_event(&conn, "memory_store", "m1", "other-ns", None, &db_file);

    let (dispatched, failed): (i64, i64) = conn
        .query_row(
            "SELECT dispatch_count, failure_count FROM subscriptions",
            [],
            |row| Ok((row.get(0)?, row.get(1)?)),
        )
        .unwrap();
    assert_eq!(dispatched, 0);
    assert_eq!(failed, 0);
}
/// `send` succeeds when the endpoint replies 200 with an ack that echoes the
/// correlation id.
#[tokio::test(flavor = "multi_thread")]
async fn send_returns_true_on_2xx() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer};

    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/hook"))
        .respond_with(AckEcho)
        .expect(1)
        .mount(&mock_server)
        .await;

    let hook_url = format!("{}/hook", mock_server.uri());
    let correlation = uuid::Uuid::now_v7().to_string();
    // `send` is called on a blocking thread rather than inside the async test body.
    let delivery = tokio::task::spawn_blocking(move || {
        send(
            &hook_url,
            r#"{"event":"x"}"#,
            "1700000000",
            Some("deadbeef"),
            &correlation,
        )
    })
    .await
    .unwrap();
    assert!(
        delivery.is_ok(),
        "2xx + matching ack must succeed: {:?}",
        delivery
    );
}
/// `send` returns an error when the endpoint answers 500.
#[tokio::test(flavor = "multi_thread")]
async fn send_returns_false_on_5xx() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/hook"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&mock_server)
        .await;

    let hook_url = format!("{}/hook", mock_server.uri());
    let correlation = uuid::Uuid::now_v7().to_string();
    let delivery = tokio::task::spawn_blocking(move || {
        send(&hook_url, r#"{"event":"x"}"#, "1700000000", None, &correlation)
    })
    .await
    .unwrap();
    assert!(delivery.is_err(), "5xx must return Err (no retry inside send)");
}
/// `send` returns an error when the endpoint answers 404.
#[tokio::test(flavor = "multi_thread")]
async fn send_returns_false_on_4xx() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/hook"))
        .respond_with(ResponseTemplate::new(404))
        .mount(&mock_server)
        .await;

    let hook_url = format!("{}/hook", mock_server.uri());
    let correlation = uuid::Uuid::now_v7().to_string();
    let delivery = tokio::task::spawn_blocking(move || {
        send(&hook_url, "{}", "1700000000", None, &correlation)
    })
    .await
    .unwrap();
    assert!(delivery.is_err(), "4xx must return Err");
}
/// A 302 from the hook is reported as a failed delivery, and the redirect
/// target is never requested.
#[tokio::test(flavor = "multi_thread")]
async fn send_does_not_follow_redirect_ssrf_pin_bypass() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    let mock_server = MockServer::start().await;
    // The hook answers with a redirect to a second path on the same server.
    let redirecting =
        ResponseTemplate::new(302).insert_header("location", "/internal-rebind-target");
    Mock::given(method("POST"))
        .and(path("/hook"))
        .respond_with(redirecting)
        .expect(1)
        .mount(&mock_server)
        .await;
    // The redirect target is mocked with an expectation of zero hits.
    let forbidden_target = Mock::given(path("/internal-rebind-target"))
        .respond_with(ResponseTemplate::new(200))
        .expect(0)
        .named("redirect target must NOT be requested");
    mock_server.register(forbidden_target).await;

    let hook_url = format!("{}/hook", mock_server.uri());
    let correlation = uuid::Uuid::now_v7().to_string();
    let delivery = tokio::task::spawn_blocking(move || {
        send(&hook_url, "{}", "1700000000", None, &correlation)
    })
    .await
    .unwrap();
    assert!(
        delivery.is_err(),
        "redirect must not be followed; 3xx surfaces as a failed dispatch: {:?}",
        delivery
    );

    let recorded = mock_server.received_requests().await.unwrap_or_default();
    let target_hits = recorded
        .into_iter()
        .filter(|req| req.url.path() == "/internal-rebind-target")
        .count();
    assert_eq!(
        target_hits, 0,
        "redirect target was requested — SSRF pin bypassed"
    );
}
/// With a signature, `send` sets `x-ai-memory-signature: sha256=<sig>` along
/// with the timestamp, correlation-id and JSON content-type headers.
///
/// The mock matches only when all of those headers are present, so a missing
/// header makes the delivery fail the final assertion.
#[tokio::test(flavor = "multi_thread")]
async fn send_signature_header_set_when_provided() {
    use wiremock::matchers::{header, header_exists, method, path};
    use wiremock::{Mock, MockServer};

    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/hook"))
        .and(header("x-ai-memory-signature", "sha256=abc123"))
        .and(header_exists("x-ai-memory-timestamp"))
        .and(header_exists("x-ai-memory-correlation-id"))
        .and(header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON))
        .respond_with(AckEcho)
        .expect(1)
        .mount(&mock_server)
        .await;

    let hook_url = format!("{}/hook", mock_server.uri());
    let correlation = uuid::Uuid::now_v7().to_string();
    let delivery = tokio::task::spawn_blocking(move || {
        send(&hook_url, "{}", "1700000000", Some("abc123"), &correlation)
    })
    .await
    .unwrap();
    assert!(
        delivery.is_ok(),
        "2xx with matched signature header + ack must succeed: {:?}",
        delivery
    );
}
/// Without a signature, `send` omits `x-ai-memory-signature` but still sends
/// `x-ai-memory-timestamp`.
#[tokio::test(flavor = "multi_thread")]
async fn send_no_signature_header_when_secret_absent() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, Request};

    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/hook"))
        .respond_with(AckEcho)
        .mount(&mock_server)
        .await;

    let hook_url = format!("{}/hook", mock_server.uri());
    let correlation = uuid::Uuid::now_v7().to_string();
    let delivery = tokio::task::spawn_blocking(move || {
        send(&hook_url, "{}", "1700000000", None, &correlation)
    })
    .await
    .unwrap();
    assert!(delivery.is_ok(), "ack-echo must succeed: {:?}", delivery);

    let recorded: Vec<Request> = mock_server.received_requests().await.unwrap_or_default();
    assert_eq!(recorded.len(), 1);
    let headers = &recorded[0].headers;
    assert!(
        headers.get("x-ai-memory-signature").is_none(),
        "no signature should be sent when secret absent"
    );
    assert!(
        headers.get("x-ai-memory-timestamp").is_some(),
        "timestamp header must always be set"
    );
}
/// `send` returns an error for a private-range URL.
#[test]
fn send_rejects_ssrf_url_without_network() {
    let private_target = "https://10.0.0.1/hook";
    let delivery = send(private_target, "{}", "1700000000", None, "some-corr");
    assert!(
        delivery.is_err(),
        "send must reject SSRF URL via validate_url guard"
    );
}
/// `send` returns an error for a URL whose scheme is not http(s).
#[test]
fn send_rejects_invalid_scheme_without_network() {
    let delivery = send("ftp://example.com/hook", "{}", "1700000000", None, "x");
    assert!(delivery.is_err(), "send must reject non-http(s) URL");
}
/// End-to-end: a matching subscription pointed at an ack-echo mock server ends
/// up with `dispatch_count == 1` after `dispatch_event`.
///
/// The counter is polled, because it is not guaranteed to be updated by the
/// time `dispatch_event` returns.
#[tokio::test(flavor = "multi_thread")]
async fn dispatch_event_e2e_increments_dispatch_count_on_2xx() {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/hook"))
.respond_with(AckEcho)
.mount(&server)
.await;
let (_keep, db_path) = fresh_db();
// Register a catch-all subscription whose URL is the mock server's /hook.
let id = {
let conn = Connection::open(&db_path).unwrap();
let url = format!("{}/hook", server.uri());
insert(
&conn,
&NewSubscription {
url: &url,
events: "*",
secret: Some("mysecret"),
namespace_filter: None,
agent_filter: None,
created_by: None,
event_types: None,
},
)
.unwrap()
};
// The connection used for dispatch is dropped as soon as the call returns.
{
let conn = Connection::open(&db_path).unwrap();
dispatch_event(&conn, "memory_store", "m1", "ns", None, &db_path);
}
let path_for_poll = db_path.clone();
let id_for_poll = id.clone();
// Poll on a blocking thread: up to 50 tries, 100 ms apart, until the
// counter is non-zero; returns 0 if it never moves.
let dc = tokio::task::spawn_blocking(move || {
for _ in 0..50 {
let conn = Connection::open(&path_for_poll).unwrap();
let dc: i64 = conn
.query_row(
"SELECT dispatch_count FROM subscriptions WHERE id = ?1",
params![id_for_poll],
|r| r.get(0),
)
.unwrap();
if dc > 0 {
return dc;
}
std::thread::sleep(std::time::Duration::from_millis(100));
}
0
})
.await
.unwrap();
assert_eq!(dc, 1, "successful dispatch must increment dispatch_count");
}
/// End-to-end: when the endpoint answers 500, the subscription ends up with
/// `dispatch_count == 1` and `failure_count == 1`.
///
/// The counters are polled, because they are not guaranteed to be updated by
/// the time `dispatch_event` returns.
#[tokio::test(flavor = "multi_thread")]
async fn dispatch_event_e2e_increments_failure_count_on_5xx() {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/hook"))
.respond_with(ResponseTemplate::new(500))
.mount(&server)
.await;
let (_keep, db_path) = fresh_db();
// Register a catch-all subscription (no secret) at the mock server's /hook.
let id = {
let conn = Connection::open(&db_path).unwrap();
let url = format!("{}/hook", server.uri());
insert(
&conn,
&NewSubscription {
url: &url,
events: "*",
secret: None,
namespace_filter: None,
agent_filter: None,
created_by: None,
event_types: None,
},
)
.unwrap()
};
// The connection used for dispatch is dropped as soon as the call returns.
{
let conn = Connection::open(&db_path).unwrap();
dispatch_event(&conn, "memory_store", "m2", "ns", None, &db_path);
}
let path_for_poll = db_path.clone();
let id_for_poll = id.clone();
// Poll on a blocking thread: up to 120 tries, 100 ms apart, until
// dispatch_count is non-zero; returns (0, 0) if it never moves.
let (dc, fc) = tokio::task::spawn_blocking(move || {
for _ in 0..120 {
let conn = Connection::open(&path_for_poll).unwrap();
let row: (i64, i64) = conn
.query_row(
"SELECT dispatch_count, failure_count FROM subscriptions WHERE id = ?1",
params![id_for_poll],
|r| Ok((r.get(0)?, r.get(1)?)),
)
.unwrap();
if row.0 > 0 {
return row;
}
std::thread::sleep(std::time::Duration::from_millis(100));
}
(0, 0)
})
.await
.unwrap();
assert_eq!(dc, 1, "5xx still increments dispatch_count");
assert_eq!(fc, 1, "5xx must increment failure_count");
}
/// End-to-end: a subscription created with a secret receives a POST that
/// carries the `x-ai-memory-signature` header.
///
/// The mock only matches requests that have both the signature and the
/// timestamp header and expects exactly one such request. The test polls the
/// server's request log and panics if nothing arrives.
#[tokio::test(flavor = "multi_thread")]
async fn dispatch_event_e2e_signature_present_when_secret_set() {
use wiremock::matchers::{header_exists, method, path};
use wiremock::{Mock, MockServer};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/hook"))
.and(header_exists("x-ai-memory-signature"))
.and(header_exists("x-ai-memory-timestamp"))
.respond_with(AckEcho)
.expect(1)
.mount(&server)
.await;
let (_keep, db_path) = fresh_db();
// Register a catch-all subscription with a secret at the mock's /hook.
let _id = {
let conn = Connection::open(&db_path).unwrap();
let url = format!("{}/hook", server.uri());
insert(
&conn,
&NewSubscription {
url: &url,
events: "*",
secret: Some("the-secret"),
namespace_filter: None,
agent_filter: None,
created_by: None,
event_types: None,
},
)
.unwrap()
};
// The connection used for dispatch is dropped as soon as the call returns.
{
let conn = Connection::open(&db_path).unwrap();
dispatch_event(&conn, "memory_store", "m3", "ns", None, &db_path);
}
let server_ref = &server;
// Poll the mock's request log: up to 50 tries, 100 ms apart.
for _ in 0..50 {
let received = server_ref.received_requests().await.unwrap_or_default();
if !received.is_empty() {
let req = &received[0];
assert!(
req.headers.get("x-ai-memory-signature").is_some(),
"signature header must be present when secret set"
);
return;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
panic!("dispatch thread never reached the mock server");
}
/// K4: `approval_requested` is one of the canonical webhook event types.
#[test]
fn approval_requested_event_in_canonical_list() {
    let listed = WEBHOOK_EVENT_TYPES
        .iter()
        .any(|event| *event == "approval_requested");
    assert!(
        listed,
        "K4: WEBHOOK_EVENT_TYPES must include approval_requested"
    );
}
/// K4 end-to-end: a subscriber opted in to `approval_requested` receives
/// exactly one POST describing a queued pending action, and its
/// `dispatch_count` reaches 1.
///
/// Flow: start an ack-echo mock, insert the opt-in subscription, queue a
/// pending `Store` action, call `dispatch_approval_requested`, then poll the
/// mock for the request and the database for the counter.
#[tokio::test(flavor = "multi_thread")]
async fn approval_requested_dispatches_to_opt_in_subscriber() {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer};
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/hook"))
.respond_with(AckEcho)
.mount(&server)
.await;
let (_keep, db_path) = fresh_db();
let url = format!("{}/hook", server.uri());
// Opt in through the structured `event_types` list as well as the CSV.
let opt_in: Vec<String> = vec!["approval_requested".to_string()];
let sub_id = {
let conn = Connection::open(&db_path).unwrap();
insert(
&conn,
&NewSubscription {
url: &url,
events: "approval_requested",
secret: Some("test-sub-secret"),
namespace_filter: None,
agent_filter: None,
created_by: None,
event_types: Some(&opt_in),
},
)
.unwrap()
};
// Queue the pending action that the approval event will describe.
let pending_id = {
let conn = Connection::open(&db_path).unwrap();
crate::db::queue_pending_action(
&conn,
crate::models::GovernedAction::Store,
"k4-ns",
None,
"agent-requestor",
&serde_json::json!({"title": "k4 approval routing"}),
)
.unwrap()
};
{
let conn = Connection::open(&db_path).unwrap();
dispatch_approval_requested(&conn, &pending_id, &db_path);
}
// Poll the mock's request log: up to 50 tries, 100 ms apart.
let mut received = Vec::new();
for _ in 0..50 {
received = server.received_requests().await.unwrap_or_default();
if !received.is_empty() {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
assert_eq!(
received.len(),
1,
"K4: opt-in subscriber must receive exactly one approval_requested POST"
);
// The body must describe the pending action that was queued above.
let body: serde_json::Value =
serde_json::from_slice(&received[0].body).expect("dispatch body must be JSON");
assert_eq!(body["event"], "approval_requested");
assert_eq!(body["memory_id"], pending_id);
assert_eq!(body["namespace"], "k4-ns");
assert_eq!(body["agent_id"], "agent-requestor");
assert_eq!(body["action_type"], "store");
assert_eq!(body["status"], "pending");
assert!(
body["requested_at"].is_string(),
"requested_at must round-trip from the row"
);
// Poll the counter: up to 40 tries, 50 ms apart, until it reads 1.
let conn = Connection::open(&db_path).unwrap();
let mut dc: i64 = 0;
for _ in 0..40 {
dc = conn
.query_row(
"SELECT dispatch_count FROM subscriptions WHERE id = ?1",
params![sub_id],
|r| r.get(0),
)
.unwrap();
if dc == 1 {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
assert_eq!(dc, 1, "dispatch_count must be 1 after successful dispatch");
}
/// A subscriber whose `event_types` only lists `memory_store` must be left
/// untouched when an `approval_requested` event is dispatched.
///
/// Both `dispatch_count` and `failure_count` are expected to stay at 0.
#[test]
fn approval_requested_skipped_for_filtered_subscriber() {
    let (_keep, db_path) = fresh_db();
    // Every step opens a fresh connection that is dropped as soon as it is done.
    let open = || Connection::open(&db_path).unwrap();

    let other_events = vec!["memory_store".to_string()];
    let sub_id = {
        let conn = open();
        let subscription = NewSubscription {
            url: "https://example.com/hook",
            events: "memory_store",
            secret: None,
            namespace_filter: None,
            agent_filter: None,
            created_by: None,
            event_types: Some(other_events.as_slice()),
        };
        insert(&conn, &subscription).unwrap()
    };

    let pending_id = {
        let conn = open();
        let request_payload = serde_json::json!({"id": "memory-xyz"});
        crate::db::queue_pending_action(
            &conn,
            crate::models::GovernedAction::Delete,
            "k4-ns-2",
            Some("memory-xyz"),
            "agent-requestor",
            &request_payload,
        )
        .unwrap()
    };

    {
        let conn = open();
        dispatch_approval_requested(&conn, &pending_id, &db_path);
    }

    // Grace period before reading the counters — presumably so that a wrongly
    // started background delivery would have time to bump them.
    std::thread::sleep(std::time::Duration::from_millis(200));

    let conn = open();
    let counters: (i64, i64) = conn
        .query_row(
            "SELECT dispatch_count, failure_count FROM subscriptions WHERE id = ?1",
            params![sub_id],
            |row| Ok((row.get(0)?, row.get(1)?)),
        )
        .unwrap();
    let (dispatch_count, failure_count) = counters;
    assert_eq!(dispatch_count, 0, "filter mismatch must skip dispatch");
    assert_eq!(failure_count, 0);
}
/// Dispatching `approval_requested` for an id that has no pending-action row
/// must not touch a catch-all subscriber's counters.
#[test]
fn approval_requested_missing_pending_row_is_noop() {
    let (_keep, db_path) = fresh_db();

    // Catch-all subscriber, registered on its own short-lived connection.
    let sub_id = {
        let conn = Connection::open(&db_path).unwrap();
        let subscription = NewSubscription {
            url: "https://example.com/hook",
            events: "*",
            secret: None,
            namespace_filter: None,
            agent_filter: None,
            created_by: None,
            event_types: None,
        };
        insert(&conn, &subscription).unwrap()
    };

    // The same connection is used for the dispatch and for reading the counters.
    let conn = Connection::open(&db_path).unwrap();
    dispatch_approval_requested(&conn, "nonexistent-id", &db_path);

    let counters: (i64, i64) = conn
        .query_row(
            "SELECT dispatch_count, failure_count FROM subscriptions WHERE id = ?1",
            params![sub_id],
            |row| Ok((row.get(0)?, row.get(1)?)),
        )
        .unwrap();
    let (dispatch_count, failure_count) = counters;
    assert_eq!(dispatch_count, 0, "missing pending row must not dispatch");
    assert_eq!(failure_count, 0);
}
/// A successful delivery leaves a `subscription_events` row with status `ack`
/// whose `correlation_id` is a version-7 UUID that also appears inside the
/// stored payload.
///
/// The audit table is polled from a blocking task (up to 50 times, 100 ms
/// apart) until the row reaches `ack`.
#[tokio::test(flavor = "multi_thread")]
async fn k6_dispatch_persists_uuidv7_correlation_id() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer};

    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/hook"))
        .respond_with(AckEcho)
        .mount(&server)
        .await;

    let (_keep, db_path) = fresh_db();
    let hook_url = format!("{}/hook", server.uri());
    let sub_id = {
        let conn = Connection::open(&db_path).unwrap();
        let subscription = NewSubscription {
            url: &hook_url,
            events: "*",
            secret: Some("test-sub-secret"),
            namespace_filter: None,
            agent_filter: None,
            created_by: None,
            event_types: None,
        };
        insert(&conn, &subscription).unwrap()
    };

    {
        let conn = Connection::open(&db_path).unwrap();
        dispatch_event(&conn, "memory_store", "k6-mem", "k6-ns", None, &db_path);
    }

    // rusqlite calls and `std::thread::sleep` are blocking, so the polling
    // runs on the blocking pool with owned copies of the path and id.
    let poll_path = db_path.clone();
    let poll_sub = sub_id.clone();
    let acked_row = tokio::task::spawn_blocking(move || {
        for _ in 0..50 {
            let conn = Connection::open(&poll_path).unwrap();
            let found: Option<(String, String, String)> = conn
                .query_row(
                    "SELECT correlation_id, payload, delivery_status \
                     FROM subscription_events WHERE subscription_id = ?1",
                    params![poll_sub],
                    |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
                )
                .ok();
            match found {
                Some(hit) if hit.2 == "ack" => return Some(hit),
                // Row missing or not yet acknowledged: wait and retry.
                _ => std::thread::sleep(std::time::Duration::from_millis(100)),
            }
        }
        None
    })
    .await
    .unwrap();

    let (correlation_id, stored_payload, delivery_status) =
        acked_row.expect("audit row must reach ack status");
    assert_eq!(delivery_status, "ack");

    let uuid = uuid::Uuid::parse_str(&correlation_id).expect("UUIDv7 string");
    assert_eq!(uuid.get_version_num(), 7, "correlation_id must be UUIDv7");

    let payload_json: serde_json::Value = serde_json::from_str(&stored_payload).unwrap();
    assert_eq!(
        payload_json["correlation_id"].as_str(),
        Some(correlation_id.as_str()),
        "payload correlation_id must match audit row"
    );
}
/// A subscriber that always answers HTTP 500 ends up with exactly one DLQ row,
/// and the matching `subscription_events` row is marked `failed`.
///
/// The DLQ is polled from a blocking task (up to 120 times, 100 ms apart)
/// because the row only appears once the retries have been exhausted.
#[tokio::test(flavor = "multi_thread")]
async fn k6_500_after_retries_lands_in_dlq() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/hook"))
        .respond_with(ResponseTemplate::new(500))
        .mount(&server)
        .await;

    let (_keep, db_path) = fresh_db();
    let hook_url = format!("{}/hook", server.uri());
    let sub_id = {
        let conn = Connection::open(&db_path).unwrap();
        let subscription = NewSubscription {
            url: &hook_url,
            events: "*",
            secret: Some("test-sub-secret"),
            namespace_filter: None,
            agent_filter: None,
            created_by: None,
            event_types: None,
        };
        insert(&conn, &subscription).unwrap()
    };

    {
        let conn = Connection::open(&db_path).unwrap();
        dispatch_event(&conn, "memory_store", "k6-fail", "k6-ns", None, &db_path);
    }

    let poll_path = db_path.clone();
    let poll_sub = sub_id.clone();
    let dlq_entries = tokio::task::spawn_blocking(move || {
        for _ in 0..120 {
            let conn = Connection::open(&poll_path).unwrap();
            let found = list_dlq(&conn, Some(&poll_sub)).unwrap();
            if !found.is_empty() {
                return Some(found);
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
        }
        None
    })
    .await
    .unwrap()
    .expect("DLQ row must appear after retry ladder exhaustion");

    assert_eq!(
        dlq_entries.len(),
        1,
        "exactly one DLQ row per failed delivery"
    );
    let entry = &dlq_entries[0];
    assert_eq!(entry.subscription_id, sub_id);
    assert_eq!(entry.event_type, "memory_store");

    // One initial attempt plus one retry per configured backoff step.
    let expected_attempts = (RETRY_BACKOFFS.len() as i64) + 1;
    assert_eq!(
        entry.retry_count, expected_attempts,
        "retry_count = initial attempt + RETRY_BACKOFFS.len() retries"
    );
    assert!(
        entry.last_error.starts_with("http-5"),
        "last_error must record the 5xx status: {}",
        entry.last_error
    );
    assert!(!entry.first_failed_at.is_empty());
    assert!(!entry.last_failed_at.is_empty());

    // The audit row for the same correlation id must carry the failed status.
    let conn = Connection::open(&db_path).unwrap();
    let audit_status: String = conn
        .query_row(
            "SELECT delivery_status FROM subscription_events WHERE correlation_id = ?1",
            params![entry.correlation_id],
            |row| row.get(0),
        )
        .unwrap();
    assert_eq!(audit_status, "failed");
}
/// Replay returns only the audit rows delivered after the given cursor.
///
/// Two `subscription_events` rows are inserted by hand (January and May 2026);
/// a cursor of 2026-03-01 must yield only the May row, both through
/// `replay_subscription_events` and through the `memory_subscription_replay`
/// envelope.
///
/// NOTE(review): the body contains no `.await`; the async/tokio harness is kept
/// as-is because it is not visible here whether the callees rely on a runtime.
#[tokio::test(flavor = "multi_thread")]
async fn k6_replay_subscription_events_returns_rows_since_cursor() {
    let (_keep, db_path) = fresh_db();

    let sub_id = {
        let conn = Connection::open(&db_path).unwrap();
        let subscription = NewSubscription {
            url: "https://example.com/hook",
            events: "*",
            secret: None,
            namespace_filter: None,
            agent_filter: None,
            created_by: None,
            event_types: None,
        };
        insert(&conn, &subscription).unwrap()
    };

    let conn = Connection::open(&db_path).unwrap();

    // Row delivered before the cursor.
    conn.execute(
        "INSERT INTO subscription_events \
         (subscription_id, correlation_id, event_type, payload, delivered_at, delivery_status) \
         VALUES (?1, ?2, 'memory_store', '{}', '2026-01-01T00:00:00Z', 'ack')",
        params![sub_id, "c-old"],
    )
    .unwrap();
    // Row delivered after the cursor.
    conn.execute(
        "INSERT INTO subscription_events \
         (subscription_id, correlation_id, event_type, payload, delivered_at, delivery_status) \
         VALUES (?1, ?2, 'memory_store', '{}', '2026-05-05T00:00:00Z', 'ack')",
        params![sub_id, "c-new"],
    )
    .unwrap();

    let cursor = "2026-03-01T00:00:00Z";

    let replayed = replay_subscription_events(&conn, &sub_id, cursor).expect("replay query");
    assert_eq!(replayed.len(), 1, "cursor must filter to the newer row");
    assert_eq!(replayed[0].correlation_id, "c-new");

    let envelope = memory_subscription_replay(&conn, &sub_id, cursor).unwrap();
    assert_eq!(envelope["count"], 1);
    assert_eq!(envelope["events"][0]["correlation_id"], "c-new");
}
/// Regression test for issue 1253: the per-subscription DLQ is capped.
///
/// Fills `subscription_dlq` with `MAX_SUBSCRIPTION_DLQ_ROWS` rows, then checks
/// that one more insert is refused with an error mentioning `dlq_overflow`,
/// that no extra row was stored, and that the overflow metric advanced.
#[test]
fn issue_1253_dlq_overflow_cap_refuses_past_max() {
    let (_keep, db_path) = fresh_db();
    let conn = Connection::open(&db_path).unwrap();

    let subscription = NewSubscription {
        url: "https://example.com/hook",
        events: "*",
        secret: Some("s"),
        namespace_filter: None,
        agent_filter: None,
        created_by: None,
        event_types: None,
    };
    let sub_id = insert(&conn, &subscription).unwrap();

    // Records one failed delivery; only the correlation id varies between calls.
    let record_failure = |correlation_id: &str| {
        record_dlq_with_conn(
            &conn,
            &sub_id,
            correlation_id,
            "memory_store",
            "{}",
            4,
            "http-500",
            "2026-01-01T00:00:00Z",
            "2026-01-01T00:00:00Z",
        )
    };
    // Current number of DLQ rows held for this subscription.
    let dlq_depth = || -> i64 {
        conn.query_row(
            "SELECT COUNT(*) FROM subscription_dlq WHERE subscription_id = ?1",
            params![&sub_id],
            |row| row.get(0),
        )
        .unwrap()
    };

    for index in 0..MAX_SUBSCRIPTION_DLQ_ROWS {
        let correlation_id = format!("c-{index}");
        record_failure(&correlation_id).expect("inserts below cap must succeed");
    }
    assert_eq!(
        dlq_depth(),
        MAX_SUBSCRIPTION_DLQ_ROWS,
        "DLQ should fill to cap"
    );

    // Snapshot the metric first so the assertion is relative and tolerates
    // increments made elsewhere in the same process.
    let before = crate::metrics::subscription_dlq_overflow_count();

    let err = record_failure("c-overflow").expect_err("over-cap insert must be refused");
    let msg = err.to_string();
    assert!(
        msg.contains("dlq_overflow"),
        "error must carry the dlq_overflow tag for operators: {msg}"
    );

    assert_eq!(
        dlq_depth(),
        MAX_SUBSCRIPTION_DLQ_ROWS,
        "refused insert must not leak a row into subscription_dlq"
    );

    let after = crate::metrics::subscription_dlq_overflow_count();
    assert!(
        after >= before + 1,
        "subscription_dlq_overflow_total did not advance (before={before}, after={after})"
    );
}
}
#[cfg(test)]
mod hex {
    /// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
    ///
    /// Returns an empty string for empty input. The output length is always
    /// `2 * bytes.len()`.
    pub fn encode_fallback(bytes: &[u8]) -> String {
        use std::fmt::Write as _;

        // Build the result in one preallocated buffer instead of allocating a
        // temporary `String` per byte via `format!` + `collect`.
        let mut encoded = String::with_capacity(bytes.len() * 2);
        for byte in bytes {
            // Writing into a `String` cannot fail, so the result is ignored.
            let _ = write!(encoded, "{byte:02x}");
        }
        encoded
    }
}
/// Signing must work when both the payload and the secret contain non-ASCII text.
///
/// The key is the `sha256_hex` of the secret, and the resulting signature is
/// expected to be 64 characters long.
#[test]
fn webhook_signing_with_unicode_payload() {
    let event = serde_json::json!({
        "event": "memory_store",
        "memory_id": "m1",
        "namespace": "café",
        "agent_id": null,
        "delivered_at": "2026-01-01T00:00:00Z"
    });
    let serialized = serde_json::to_string(&event).unwrap();

    let signing_key = sha256_hex("secret-with-café");
    let signature = hmac_sha256_hex(&signing_key, &serialized);

    assert!(!signature.is_empty());
    assert_eq!(signature.len(), 64);
}
/// Placeholder for the "retry on 5xx" behaviour.
///
/// NOTE(review): this only compares two local boolean literals and never calls
/// any dispatch code, so it cannot fail; it should be replaced by a test that
/// exercises the real retry path.
#[test]
fn webhook_retries_on_5xx_response() {
    let (status_2xx, status_5xx) = (true, false);
    assert_ne!(status_2xx, status_5xx);
}
/// Placeholder for the "no retry on 4xx" behaviour.
///
/// NOTE(review): this only compares two local boolean literals and never calls
/// any dispatch code, so it cannot fail; it should be replaced by a test that
/// exercises the real retry decision.
#[test]
fn webhook_does_not_retry_on_4xx_response() {
    let (status_4xx, status_success) = (false, true);
    assert_ne!(status_4xx, status_success);
}
/// Exercises the namespace filter of `matches_filters` for a `memory_store`
/// event on a catch-all (`"*"`) subscription.
///
/// Expectations pinned here: a filter equal to the namespace matches, a
/// different filter does not, and an empty-string filter matches any namespace.
#[test]
fn namespace_pattern_matches_glob_correctly() {
    // (namespace filter, event namespace, expected verdict)
    let cases: [(Option<&str>, &str, bool); 3] = [
        (Some("app"), "app", true),
        (Some("app"), "other", false),
        (Some(""), "any_ns", true),
    ];

    for (namespace_filter, namespace, expected) in cases {
        let verdict = matches_filters(
            "*",
            None,
            namespace_filter,
            None,
            "memory_store",
            namespace,
            None,
        );
        assert_eq!(
            verdict, expected,
            "filter {namespace_filter:?} against namespace {namespace:?}"
        );
    }
}