1use std::net::{IpAddr, ToSocketAddrs};
24use std::str::FromStr;
25
26use anyhow::{Context, Result, anyhow};
27use rusqlite::{Connection, params};
28use serde::{Deserialize, Serialize};
29use sha2::{Digest, Sha256};
30
/// A webhook subscription row as returned by [`list`].
///
/// The stored `secret_hash` column is not part of this struct (`list` does
/// not select it); the hash is read separately by `load_secret_hash` when a
/// delivery is signed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Subscription {
    /// UUIDv4 string assigned by [`insert`].
    pub id: String,
    /// Delivery target; checked by `validate_url` on insert and again before every send.
    pub url: String,
    /// Comma-separated event names (whitespace around entries is ignored), or `*` for all.
    pub events: String,
    /// Deliver only events from this namespace; `None` or empty means no filter.
    pub namespace_filter: Option<String>,
    /// Deliver only events whose agent id equals this value; `None` or empty means no filter.
    pub agent_filter: Option<String>,
    /// Creator identifier as supplied to [`insert`].
    pub created_by: Option<String>,
    /// RFC 3339 UTC creation timestamp.
    pub created_at: String,
    /// Number of recorded delivery attempts, successful or not.
    pub dispatch_count: i64,
    /// Number of recorded delivery attempts that failed.
    pub failure_count: i64,
}
44
/// Borrowed input for [`insert`].
pub struct NewSubscription<'a> {
    /// Webhook target URL; must pass `validate_url`.
    pub url: &'a str,
    /// Comma-separated event names, or `*` for all events.
    pub events: &'a str,
    /// Optional shared secret. Only its SHA-256 hex digest is persisted, and
    /// that digest (not the plaintext) is the HMAC key used to sign deliveries.
    pub secret: Option<&'a str>,
    /// Optional namespace filter; `None` or empty matches every namespace.
    pub namespace_filter: Option<&'a str>,
    /// Optional agent-id filter; `None` or empty matches every agent.
    pub agent_filter: Option<&'a str>,
    /// Optional creator identifier, stored verbatim.
    pub created_by: Option<&'a str>,
}
54
55pub fn insert(conn: &Connection, req: &NewSubscription<'_>) -> Result<String> {
59 validate_url(req.url)?;
60 let id = uuid::Uuid::new_v4().to_string();
61 let secret_hash = req.secret.map(sha256_hex);
62 let now = chrono::Utc::now().to_rfc3339();
63 conn.execute(
64 "INSERT INTO subscriptions (id, url, events, secret_hash, namespace_filter, agent_filter, created_by, created_at) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
65 params![id, req.url, req.events, secret_hash, req.namespace_filter, req.agent_filter, req.created_by, now],
66 )?;
67 Ok(id)
68}
69
70pub fn delete(conn: &Connection, id: &str) -> Result<bool> {
72 let n = conn.execute("DELETE FROM subscriptions WHERE id = ?1", params![id])?;
73 Ok(n > 0)
74}
75
76pub fn list(conn: &Connection) -> Result<Vec<Subscription>> {
78 let mut stmt = conn.prepare(
79 "SELECT id, url, events, namespace_filter, agent_filter, created_by, created_at, dispatch_count, failure_count FROM subscriptions ORDER BY created_at DESC",
80 )?;
81 let rows = stmt.query_map([], |row| {
82 Ok(Subscription {
83 id: row.get(0)?,
84 url: row.get(1)?,
85 events: row.get(2)?,
86 namespace_filter: row.get(3)?,
87 agent_filter: row.get(4)?,
88 created_by: row.get(5)?,
89 created_at: row.get(6)?,
90 dispatch_count: row.get(7)?,
91 failure_count: row.get(8)?,
92 })
93 })?;
94 rows.collect::<rusqlite::Result<Vec<_>>>()
95 .context("subscription row decode failed")
96}
97
/// Decides whether a subscription's filters accept an event.
///
/// * `sub_events` is `*` or a comma-separated list (entries trimmed, and a
///   `*` entry also matches everything).
/// * `sub_namespace`, when `Some` and non-empty, must equal `namespace`.
/// * `sub_agent`, when `Some` and non-empty, must equal `agent`; an event
///   with no agent never matches a non-empty agent filter.
///
/// All three conditions must hold.
fn matches_filters(
    sub_events: &str,
    sub_namespace: Option<&str>,
    sub_agent: Option<&str>,
    event: &str,
    namespace: &str,
    agent: Option<&str>,
) -> bool {
    let wants_event = sub_events == "*"
        || sub_events
            .split(',')
            .map(str::trim)
            .any(|entry| entry == "*" || entry == event);

    let namespace_ok = match sub_namespace {
        Some(required) if !required.is_empty() => required == namespace,
        _ => true,
    };

    let agent_ok = match sub_agent {
        Some(required) if !required.is_empty() => agent == Some(required),
        _ => true,
    };

    wants_event && namespace_ok && agent_ok
}
130
/// JSON body POSTed to every matching subscriber.
///
/// serde serializes fields in declaration order, and the serialized body is
/// part of the signed `"<timestamp>.<body>"` string, so reordering fields
/// changes the bytes receivers verify.
#[derive(Serialize)]
struct DispatchPayload<'a> {
    /// Event name that triggered the dispatch.
    event: &'a str,
    memory_id: &'a str,
    namespace: &'a str,
    /// Serialized as `null` when the event has no agent.
    agent_id: Option<&'a str>,
    /// RFC 3339 UTC time at which the payload was built (before delivery threads start).
    delivered_at: String,
}
140
141pub fn dispatch_event(
149 conn: &Connection,
150 event: &str,
151 memory_id: &str,
152 namespace: &str,
153 agent_id: Option<&str>,
154 db_path: &std::path::Path,
155) {
156 let subs = match list(conn) {
157 Ok(s) => s,
158 Err(e) => {
159 tracing::warn!("subscription list failed during dispatch: {e}");
160 return;
161 }
162 };
163 let matching: Vec<Subscription> = subs
164 .into_iter()
165 .filter(|s| {
166 matches_filters(
167 &s.events,
168 s.namespace_filter.as_deref(),
169 s.agent_filter.as_deref(),
170 event,
171 namespace,
172 agent_id,
173 )
174 })
175 .collect();
176 if matching.is_empty() {
177 return;
178 }
179 let payload = DispatchPayload {
180 event,
181 memory_id,
182 namespace,
183 agent_id,
184 delivered_at: chrono::Utc::now().to_rfc3339(),
185 };
186 let body = match serde_json::to_string(&payload) {
187 Ok(s) => s,
188 Err(e) => {
189 tracing::warn!("dispatch payload serialize failed: {e}");
190 return;
191 }
192 };
193 let timestamp = chrono::Utc::now().timestamp().to_string();
198 for sub in matching {
199 let url = sub.url.clone();
200 let sub_id = sub.id.clone();
201 let body = body.clone();
202 let ts = timestamp.clone();
203 let db_path = db_path.to_path_buf();
204 std::thread::spawn(move || {
205 let secret_hash = match load_secret_hash(&db_path, &sub_id) {
206 Ok(s) => s,
207 Err(e) => {
208 tracing::warn!("subscription secret lookup failed: {e}");
209 return;
210 }
211 };
212 let canonical = format!("{ts}.{body}");
217 let signature = secret_hash
218 .as_deref()
219 .map(|h| hmac_sha256_hex(h, &canonical));
220 let ok = send(&url, &body, &ts, signature.as_deref());
221 record_dispatch(&db_path, &sub_id, ok);
222 });
223 }
224}
225
226fn send(url: &str, body: &str, timestamp: &str, signature: Option<&str>) -> bool {
229 if let Err(e) = validate_url(url) {
230 tracing::warn!("SSRF guard rejected webhook URL {url}: {e}");
231 return false;
232 }
233 if let Err(e) = validate_url_dns(url) {
239 tracing::warn!("DNS SSRF guard rejected webhook URL {url}: {e}");
240 return false;
241 }
242 let client = match reqwest::blocking::Client::builder()
243 .timeout(std::time::Duration::from_secs(10))
244 .build()
245 {
246 Ok(c) => c,
247 Err(e) => {
248 tracing::warn!("webhook client build failed: {e}");
249 return false;
250 }
251 };
252 let mut req = client
253 .post(url)
254 .header("content-type", "application/json")
255 .header("user-agent", "ai-memory/0.6.0.0")
256 .header("x-ai-memory-timestamp", timestamp);
257 if let Some(sig) = signature {
258 req = req.header("x-ai-memory-signature", format!("sha256={sig}"));
259 }
260 match req.body(body.to_string()).send() {
261 Ok(resp) => resp.status().is_success(),
262 Err(e) => {
263 tracing::warn!("webhook POST to {url} failed: {e}");
264 false
265 }
266 }
267}
268
269fn sha256_hex(s: &str) -> String {
271 let mut hasher = Sha256::new();
272 hasher.update(s.as_bytes());
273 format!("{:x}", hasher.finalize())
274}
275
276fn hmac_sha256_hex(key_hex: &str, body: &str) -> String {
281 const BLOCK: usize = 64;
282 let mut key = hex_decode(key_hex).unwrap_or_else(|| key_hex.as_bytes().to_vec());
287 if key.len() > BLOCK {
288 let mut h = Sha256::new();
289 h.update(&key);
290 key = h.finalize().to_vec();
291 }
292 key.resize(BLOCK, 0);
293 let mut opad = [0x5cu8; BLOCK];
294 let mut ipad = [0x36u8; BLOCK];
295 for i in 0..BLOCK {
296 opad[i] ^= key[i];
297 ipad[i] ^= key[i];
298 }
299 let mut inner = Sha256::new();
300 inner.update(ipad);
301 inner.update(body.as_bytes());
302 let inner_digest = inner.finalize();
303 let mut outer = Sha256::new();
304 outer.update(opad);
305 outer.update(inner_digest);
306 format!("{:x}", outer.finalize())
307}
308
/// Decodes a hex string (either letter case) into bytes.
///
/// Returns `None` for odd-length input or any character that is not an
/// ASCII hex digit; the empty string decodes to an empty vector.
///
/// Works on raw bytes rather than `&str` slices. The previous
/// `&s[i..i + 2]` approach panicked when a multi-byte character straddled an
/// even byte offset (e.g. `"aéb"`). It also let `u8::from_str_radix` accept
/// a `+` sign as part of a "digit pair" (`"+a+b"` decoded to `[10, 11]`).
/// `hmac_sha256_hex` deliberately passes possibly non-hex keys here, so both
/// cases must simply yield `None`.
fn hex_decode(s: &str) -> Option<Vec<u8>> {
    /// Value of one ASCII hex digit, or `None` for anything else.
    fn nibble(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let pairs = s.as_bytes().chunks_exact(2);
    // A leftover byte means the input had odd length.
    if !pairs.remainder().is_empty() {
        return None;
    }
    pairs
        .map(|pair| Some((nibble(pair[0])? << 4) | nibble(pair[1])?))
        .collect()
}
318
319pub fn validate_url_dns(url: &str) -> Result<()> {
329 let lower = url.to_ascii_lowercase();
330 let (_scheme, rest) = lower
331 .split_once("://")
332 .ok_or_else(|| anyhow!("webhook URL missing scheme: {url}"))?;
333 let host_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
334 let host_port = &rest[..host_end];
335 let resolv_target =
343 if let Some(close_idx) = host_port.strip_prefix('[').and(host_port.find(']')) {
344 let after_bracket = &host_port[close_idx + 1..];
345 if after_bracket.starts_with(':') {
346 host_port.to_string()
348 } else {
349 format!("{host_port}:80")
351 }
352 } else if host_port.contains(':') {
353 host_port.to_string()
355 } else {
356 format!("{host_port}:80")
357 };
358 let addrs: Vec<std::net::SocketAddr> = match resolv_target.to_socket_addrs() {
359 Ok(iter) => iter.collect(),
360 Err(_) => return Ok(()), };
362 for addr in &addrs {
363 let ip = addr.ip();
364 if is_private(ip) && !ip.is_loopback() {
365 return Err(anyhow!(
366 "host resolves to private/link-local IP {ip}: {url}"
367 ));
368 }
369 }
370 Ok(())
371}
372
373pub fn validate_url(url: &str) -> Result<()> {
377 let lower = url.to_ascii_lowercase();
379 let (scheme, rest) = lower
380 .split_once("://")
381 .ok_or_else(|| anyhow!("webhook URL missing scheme: {url}"))?;
382 if scheme != "https" && scheme != "http" {
383 return Err(anyhow!("webhook URL scheme must be http(s): {url}"));
384 }
385 let host_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
389 let host_port = &rest[..host_end];
390 let host: String = if let Some(stripped) = host_port.strip_prefix('[') {
391 match stripped.find(']') {
393 Some(i) => stripped[..i].to_string(),
394 None => return Err(anyhow!("malformed IPv6 URL host: {url}")),
395 }
396 } else {
397 host_port
399 .rsplit_once(':')
400 .map_or(host_port.to_string(), |(h, _)| h.to_string())
401 };
402 let host = host.as_str();
403 let is_loopback_hostname = matches!(host, "localhost" | "localhost.localdomain" | "");
405 if scheme == "http" && !is_loopback_hostname {
406 if let Ok(ip) = IpAddr::from_str(host) {
409 if !ip.is_loopback() {
410 return Err(anyhow!(
411 "webhook URL must be https for non-loopback host: {url}"
412 ));
413 }
414 } else {
415 return Err(anyhow!(
416 "webhook URL must be https for non-loopback host: {url}"
417 ));
418 }
419 }
420 if let Ok(ip) = IpAddr::from_str(host)
426 && is_private(ip)
427 && !ip.is_loopback()
428 {
429 return Err(anyhow!(
430 "webhook URL targets private / link-local address: {url}"
431 ));
432 }
433 Ok(())
434}
435
/// Returns `true` for addresses a webhook must never target:
///
/// * IPv4: RFC 1918 private, link-local (including 169.254.169.254),
///   multicast, broadcast and unspecified.
/// * IPv6: multicast, unspecified, unique-local (`fc00::/7`) and link-local
///   (`fe80::/10`).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by their
/// embedded IPv4 address. Without that, `[::ffff:10.0.0.1]` or
/// `[::ffff:169.254.169.254]` slipped past the IPv4 rules even though a
/// dual-stack socket typically delivers them over IPv4.
///
/// Loopback is not included; callers decide separately whether to allow it.
fn is_private(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.is_private()
                || v4.is_link_local()
                || v4.is_multicast()
                || v4.is_broadcast()
                || v4.is_unspecified()
        }
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private(IpAddr::V4(v4));
            }
            let first = v6.segments()[0];
            v6.is_multicast()
                || v6.is_unspecified()
                || (first & 0xfe00) == 0xfc00 // unique local fc00::/7
                || (first & 0xffc0) == 0xfe80 // link local fe80::/10
        }
    }
}
462
463fn load_secret_hash(db_path: &std::path::Path, sub_id: &str) -> Result<Option<String>> {
464 let conn = Connection::open(db_path).context("load_secret_hash open")?;
465 let row = conn
466 .query_row(
467 "SELECT secret_hash FROM subscriptions WHERE id = ?1",
468 params![sub_id],
469 |r| r.get::<_, Option<String>>(0),
470 )
471 .context("load_secret_hash query")?;
472 Ok(row)
473}
474
475fn record_dispatch(db_path: &std::path::Path, sub_id: &str, ok: bool) {
476 let Ok(conn) = Connection::open(db_path) else {
477 return;
478 };
479 let now = chrono::Utc::now().to_rfc3339();
480 let sql = if ok {
481 "UPDATE subscriptions SET dispatch_count = dispatch_count + 1, last_dispatched_at = ?1 WHERE id = ?2"
482 } else {
483 "UPDATE subscriptions SET dispatch_count = dispatch_count + 1, failure_count = failure_count + 1, last_dispatched_at = ?1 WHERE id = ?2"
484 };
485 let _ = conn.execute(sql, params![now, sub_id]);
486}
487
488#[cfg(test)]
489mod tests {
490 use super::*;
491
492 #[test]
493 fn https_allowed() {
494 assert!(validate_url("https://example.com/hook").is_ok());
495 assert!(validate_url("https://api.example.com:8443/hook?x=1").is_ok());
496 }
497
498 #[test]
499 fn http_only_to_loopback() {
500 assert!(validate_url("http://localhost/hook").is_ok());
501 assert!(validate_url("http://127.0.0.1:8080/hook").is_ok());
502 assert!(validate_url("http://[::1]/hook").is_ok());
504 assert!(validate_url("http://example.com/hook").is_err());
505 assert!(validate_url("http://8.8.8.8/hook").is_err());
506 }
507
508 #[test]
509 fn private_ranges_blocked() {
510 assert!(validate_url("https://10.0.0.1/hook").is_err());
511 assert!(validate_url("https://192.168.1.1/hook").is_err());
512 assert!(validate_url("https://172.16.0.1/hook").is_err());
513 assert!(validate_url("https://169.254.1.1/hook").is_err());
514 assert!(validate_url("https://[fc00::1]/hook").is_err());
515 assert!(validate_url("https://[fe80::1]/hook").is_err());
516 }
517
518 #[test]
519 fn nonsense_rejected() {
520 assert!(validate_url("ftp://example.com").is_err());
521 assert!(validate_url("notaurl").is_err());
522 assert!(validate_url("").is_err());
523 }
524
525 #[test]
526 fn hmac_sha256_stable() {
527 let key = hex::encode_fallback("key".as_bytes());
530 let got = hmac_sha256_hex(&key, "The quick brown fox jumps over the lazy dog");
531 assert_eq!(
532 got,
533 "f7bc83f430538424b13298e6aa6fb143ef4d59a14946175997479dbc2d1a3cd8"
534 );
535 }
536
537 #[test]
538 fn filter_wildcards() {
539 assert!(matches_filters("*", None, None, "memory_store", "ns", None));
540 assert!(matches_filters(
541 "memory_store,memory_delete",
542 None,
543 None,
544 "memory_store",
545 "ns",
546 None
547 ));
548 assert!(!matches_filters(
549 "memory_delete",
550 None,
551 None,
552 "memory_store",
553 "ns",
554 None
555 ));
556 assert!(matches_filters(
557 "*",
558 Some("foo"),
559 None,
560 "memory_store",
561 "foo",
562 None
563 ));
564 assert!(!matches_filters(
565 "*",
566 Some("foo"),
567 None,
568 "memory_store",
569 "bar",
570 None
571 ));
572 assert!(matches_filters(
573 "*",
574 None,
575 Some("alice"),
576 "memory_store",
577 "ns",
578 Some("alice")
579 ));
580 assert!(!matches_filters(
581 "*",
582 None,
583 Some("alice"),
584 "memory_store",
585 "ns",
586 Some("bob")
587 ));
588 }
589
590 #[test]
613 fn test_validate_url_dns_accepts_loopback_v4() {
614 assert!(
619 validate_url_dns("http://127.0.0.1/foo").is_ok(),
620 "127.0.0.1 should be accepted by validate_url_dns (dev/CI)"
621 );
622 assert!(
623 validate_url_dns("http://127.0.0.1:8080/").is_ok(),
624 "127.0.0.1:8080 should be accepted by validate_url_dns"
625 );
626 assert!(
627 validate_url_dns("http://localhost/").is_ok(),
628 "localhost should be accepted by validate_url_dns"
629 );
630 }
631
632 #[test]
633 fn test_validate_url_dns_accepts_loopback_v6() {
634 assert!(
636 validate_url_dns("http://[::1]/").is_ok(),
637 "[::1] should be accepted by validate_url_dns"
638 );
639 assert!(
640 validate_url_dns("http://[0:0:0:0:0:0:0:1]/").is_ok(),
641 "[::1] expanded form should be accepted"
642 );
643 }
644
645 #[test]
646 fn test_validate_url_dns_rejects_link_local_ipv6() {
647 let res = validate_url_dns("http://[fe80::1]/");
653 assert!(
654 res.is_err(),
655 "fe80::1 must be rejected as link-local IPv6, got {res:?}"
656 );
657 }
658
659 #[test]
660 fn test_validate_url_dns_rejects_aws_metadata() {
661 let res = validate_url_dns("http://169.254.169.254/latest/meta-data/");
665 assert!(
666 res.is_err(),
667 "AWS metadata IP must be rejected, got {res:?}"
668 );
669 }
670
671 #[test]
672 fn test_validate_url_dns_rejects_rfc1918_private_ranges() {
673 for url in [
677 "http://10.0.0.1/",
678 "http://172.16.0.1/",
679 "http://172.31.255.255/",
680 "http://192.168.1.1/",
681 ] {
682 let res = validate_url_dns(url);
683 assert!(
684 res.is_err(),
685 "{url} must be rejected as RFC1918, got {res:?}"
686 );
687 }
688 }
689
690 #[test]
691 fn test_validate_url_dns_accepts_public_ip_or_dns() {
692 assert!(
697 validate_url_dns("https://1.1.1.1/").is_ok(),
698 "public IP literal must be accepted"
699 );
700 assert!(
704 validate_url_dns("https://example.com/").is_ok(),
705 "public hostname must be accepted (or DNS-skip path returns Ok)"
706 );
707 }
708
709 #[test]
710 fn test_validate_url_dns_rejects_unspecified_addresses() {
711 let v4 = validate_url_dns("http://0.0.0.0/");
717 let v6 = validate_url_dns("http://[::]/");
718 assert!(
719 v4.is_err(),
720 "0.0.0.0 should be rejected as unspecified, got {v4:?}"
721 );
722 assert!(
723 v6.is_err(),
724 "[::] should be rejected as unspecified, got {v6:?}"
725 );
726 }
727
728 #[test]
729 fn test_validate_url_dns_missing_scheme() {
730 let res = validate_url_dns("not-a-url");
732 assert!(res.is_err(), "missing scheme must Err, got {res:?}");
733 }
734
735 use tempfile::NamedTempFile;
755
756 fn fresh_db() -> (NamedTempFile, std::path::PathBuf) {
760 let f = NamedTempFile::new().expect("tempfile");
761 let p = f.path().to_path_buf();
762 let _ = crate::db::open(&p).expect("db::open");
764 (f, p)
765 }
766
767 #[test]
770 fn insert_persists_and_list_returns_row() {
771 let (_keep, path) = fresh_db();
772 let conn = Connection::open(&path).unwrap();
773 let id = insert(
774 &conn,
775 &NewSubscription {
776 url: "https://example.com/hook",
777 events: "memory_store",
778 secret: Some("s3cret"),
779 namespace_filter: Some("ns1"),
780 agent_filter: Some("alice"),
781 created_by: Some("op"),
782 },
783 )
784 .unwrap();
785 assert!(!id.is_empty());
786
787 let subs = list(&conn).unwrap();
788 assert_eq!(subs.len(), 1);
789 let s = &subs[0];
790 assert_eq!(s.id, id);
791 assert_eq!(s.url, "https://example.com/hook");
792 assert_eq!(s.events, "memory_store");
793 assert_eq!(s.namespace_filter.as_deref(), Some("ns1"));
794 assert_eq!(s.agent_filter.as_deref(), Some("alice"));
795 assert_eq!(s.created_by.as_deref(), Some("op"));
796 assert_eq!(s.dispatch_count, 0);
797 assert_eq!(s.failure_count, 0);
798 }
799
800 #[test]
801 fn insert_rejects_invalid_url() {
802 let (_keep, path) = fresh_db();
803 let conn = Connection::open(&path).unwrap();
804 let res = insert(
805 &conn,
806 &NewSubscription {
807 url: "not-a-url",
808 events: "*",
809 secret: None,
810 namespace_filter: None,
811 agent_filter: None,
812 created_by: None,
813 },
814 );
815 assert!(res.is_err(), "insert must reject invalid URL");
816 }
817
818 #[test]
819 fn insert_hashes_secret_before_persisting() {
820 let (_keep, path) = fresh_db();
821 let conn = Connection::open(&path).unwrap();
822 let plaintext = "super-shared-secret";
823 let id = insert(
824 &conn,
825 &NewSubscription {
826 url: "https://example.com/h",
827 events: "*",
828 secret: Some(plaintext),
829 namespace_filter: None,
830 agent_filter: None,
831 created_by: None,
832 },
833 )
834 .unwrap();
835 let stored: Option<String> = conn
836 .query_row(
837 "SELECT secret_hash FROM subscriptions WHERE id = ?1",
838 params![id],
839 |r| r.get(0),
840 )
841 .unwrap();
842 let hash = stored.expect("secret_hash should be set");
843 assert_ne!(hash, plaintext, "plaintext secret must not be stored");
844 assert_eq!(hash, sha256_hex(plaintext));
845 }
846
847 #[test]
848 fn insert_no_secret_stores_null() {
849 let (_keep, path) = fresh_db();
850 let conn = Connection::open(&path).unwrap();
851 let id = insert(
852 &conn,
853 &NewSubscription {
854 url: "https://example.com/h",
855 events: "*",
856 secret: None,
857 namespace_filter: None,
858 agent_filter: None,
859 created_by: None,
860 },
861 )
862 .unwrap();
863 let stored: Option<String> = conn
864 .query_row(
865 "SELECT secret_hash FROM subscriptions WHERE id = ?1",
866 params![id],
867 |r| r.get(0),
868 )
869 .unwrap();
870 assert!(stored.is_none(), "missing secret must persist as NULL");
871 }
872
873 #[test]
874 fn delete_returns_true_when_row_removed() {
875 let (_keep, path) = fresh_db();
876 let conn = Connection::open(&path).unwrap();
877 let id = insert(
878 &conn,
879 &NewSubscription {
880 url: "https://example.com/h",
881 events: "*",
882 secret: None,
883 namespace_filter: None,
884 agent_filter: None,
885 created_by: None,
886 },
887 )
888 .unwrap();
889 assert!(delete(&conn, &id).unwrap());
890 assert!(list(&conn).unwrap().is_empty());
891 }
892
893 #[test]
894 fn delete_returns_false_when_row_missing() {
895 let (_keep, path) = fresh_db();
896 let conn = Connection::open(&path).unwrap();
897 assert!(!delete(&conn, "nope").unwrap());
898 }
899
900 #[test]
901 fn list_orders_by_created_at_desc() {
902 let (_keep, path) = fresh_db();
903 let conn = Connection::open(&path).unwrap();
904 let id1 = insert(
907 &conn,
908 &NewSubscription {
909 url: "https://a.example.com/",
910 events: "*",
911 secret: None,
912 namespace_filter: None,
913 agent_filter: None,
914 created_by: None,
915 },
916 )
917 .unwrap();
918 std::thread::sleep(std::time::Duration::from_millis(1100));
919 let id2 = insert(
920 &conn,
921 &NewSubscription {
922 url: "https://b.example.com/",
923 events: "*",
924 secret: None,
925 namespace_filter: None,
926 agent_filter: None,
927 created_by: None,
928 },
929 )
930 .unwrap();
931 let subs = list(&conn).unwrap();
932 assert_eq!(subs.len(), 2);
933 assert_eq!(subs[0].id, id2);
935 assert_eq!(subs[1].id, id1);
936 }
937
938 #[test]
941 fn sha256_hex_known_vector() {
942 assert_eq!(
944 sha256_hex(""),
945 "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
946 );
947 assert_eq!(
949 sha256_hex("abc"),
950 "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
951 );
952 }
953
954 #[test]
955 fn hex_decode_round_trip_and_invalid() {
956 let s = "deadbeef";
958 let bytes = hex_decode(s).expect("valid hex");
959 assert_eq!(bytes, vec![0xde, 0xad, 0xbe, 0xef]);
960 assert!(hex_decode("abc").is_none());
962 assert!(hex_decode("zz").is_none());
964 }
965
966 #[test]
967 fn hmac_long_key_is_hashed_to_fit_block() {
968 let long_key: String = std::iter::repeat_n('a', 200).collect();
973 let sig = hmac_sha256_hex(&long_key, "hello");
974 assert_eq!(sig.len(), 64); }
976
977 #[test]
978 fn hmac_invalid_hex_key_falls_back_to_raw_bytes() {
979 let sig = hmac_sha256_hex("not-a-hex-key!!", "hello");
983 assert_eq!(sig.len(), 64);
984 assert!(sig.chars().all(|c| c.is_ascii_hexdigit()));
985 }
986
987 #[test]
990 fn matches_filters_event_with_whitespace_and_star() {
991 assert!(matches_filters(
993 "memory_store, *",
994 None,
995 None,
996 "anything",
997 "ns",
998 None,
999 ));
1000 assert!(matches_filters(
1002 " memory_delete , memory_store ",
1003 None,
1004 None,
1005 "memory_store",
1006 "ns",
1007 None,
1008 ));
1009 }
1010
1011 #[test]
1012 fn matches_filters_agent_filter_requires_some() {
1013 assert!(!matches_filters(
1015 "*",
1016 None,
1017 Some("alice"),
1018 "memory_store",
1019 "ns",
1020 None,
1021 ));
1022 }
1023
1024 #[test]
1027 fn record_dispatch_increments_counts_on_success() {
1028 let (_keep, path) = fresh_db();
1029 let id = {
1030 let conn = Connection::open(&path).unwrap();
1031 insert(
1032 &conn,
1033 &NewSubscription {
1034 url: "https://example.com/h",
1035 events: "*",
1036 secret: None,
1037 namespace_filter: None,
1038 agent_filter: None,
1039 created_by: None,
1040 },
1041 )
1042 .unwrap()
1043 };
1044 record_dispatch(&path, &id, true);
1045 record_dispatch(&path, &id, true);
1046 let conn = Connection::open(&path).unwrap();
1047 let (dc, fc): (i64, i64) = conn
1048 .query_row(
1049 "SELECT dispatch_count, failure_count FROM subscriptions WHERE id = ?1",
1050 params![id],
1051 |r| Ok((r.get(0)?, r.get(1)?)),
1052 )
1053 .unwrap();
1054 assert_eq!(dc, 2, "two successful dispatches must bump dispatch_count");
1055 assert_eq!(fc, 0, "successes must not bump failure_count");
1056 }
1057
1058 #[test]
1059 fn record_dispatch_increments_failure_on_err() {
1060 let (_keep, path) = fresh_db();
1061 let id = {
1062 let conn = Connection::open(&path).unwrap();
1063 insert(
1064 &conn,
1065 &NewSubscription {
1066 url: "https://example.com/h",
1067 events: "*",
1068 secret: None,
1069 namespace_filter: None,
1070 agent_filter: None,
1071 created_by: None,
1072 },
1073 )
1074 .unwrap()
1075 };
1076 record_dispatch(&path, &id, false);
1077 let conn = Connection::open(&path).unwrap();
1078 let (dc, fc): (i64, i64) = conn
1079 .query_row(
1080 "SELECT dispatch_count, failure_count FROM subscriptions WHERE id = ?1",
1081 params![id],
1082 |r| Ok((r.get(0)?, r.get(1)?)),
1083 )
1084 .unwrap();
1085 assert_eq!(dc, 1, "failed dispatch still bumps dispatch_count");
1086 assert_eq!(fc, 1, "failure must bump failure_count");
1087 }
1088
1089 #[test]
1090 fn record_dispatch_nonexistent_id_does_not_panic() {
1091 let (_keep, path) = fresh_db();
1092 record_dispatch(&path, "no-such-id", true);
1095 record_dispatch(&path, "no-such-id", false);
1096 let conn = Connection::open(&path).unwrap();
1098 let n: i64 = conn
1099 .query_row("SELECT COUNT(*) FROM subscriptions", [], |r| r.get(0))
1100 .unwrap();
1101 assert_eq!(n, 0);
1102 }
1103
1104 #[test]
1105 fn record_dispatch_unopenable_db_path_is_noop() {
1106 let bad = std::path::PathBuf::from("/nonexistent-dir-w12c/does-not-exist.db");
1110 record_dispatch(&bad, "x", true);
1111 }
1112
1113 #[test]
1114 fn load_secret_hash_returns_stored_hash() {
1115 let (_keep, path) = fresh_db();
1116 let id = {
1117 let conn = Connection::open(&path).unwrap();
1118 insert(
1119 &conn,
1120 &NewSubscription {
1121 url: "https://example.com/h",
1122 events: "*",
1123 secret: Some("topsecret"),
1124 namespace_filter: None,
1125 agent_filter: None,
1126 created_by: None,
1127 },
1128 )
1129 .unwrap()
1130 };
1131 let got = load_secret_hash(&path, &id).unwrap();
1132 assert_eq!(got, Some(sha256_hex("topsecret")));
1133 }
1134
1135 #[test]
1136 fn load_secret_hash_missing_id_errs() {
1137 let (_keep, path) = fresh_db();
1138 let res = load_secret_hash(&path, "missing-id");
1141 assert!(res.is_err(), "missing subscription id must surface as Err");
1142 }
1143
1144 #[test]
1147 fn dispatch_event_no_subs_is_noop() {
1148 let (_keep, path) = fresh_db();
1149 let conn = Connection::open(&path).unwrap();
1150 dispatch_event(&conn, "memory_store", "m1", "ns", None, &path);
1153 }
1154
1155 #[test]
1156 fn dispatch_event_filter_mismatch_skips_send() {
1157 let (_keep, path) = fresh_db();
1164 let conn = Connection::open(&path).unwrap();
1165 insert(
1166 &conn,
1167 &NewSubscription {
1168 url: "https://example.com/h",
1169 events: "memory_delete",
1170 secret: None,
1171 namespace_filter: None,
1172 agent_filter: None,
1173 created_by: None,
1174 },
1175 )
1176 .unwrap();
1177 dispatch_event(&conn, "memory_store", "m1", "ns", None, &path);
1178 let (dc, fc): (i64, i64) = conn
1180 .query_row(
1181 "SELECT dispatch_count, failure_count FROM subscriptions",
1182 [],
1183 |r| Ok((r.get(0)?, r.get(1)?)),
1184 )
1185 .unwrap();
1186 assert_eq!(dc, 0);
1187 assert_eq!(fc, 0);
1188 }
1189
1190 #[test]
1191 fn dispatch_event_namespace_filter_mismatch_skips() {
1192 let (_keep, path) = fresh_db();
1193 let conn = Connection::open(&path).unwrap();
1194 insert(
1195 &conn,
1196 &NewSubscription {
1197 url: "https://example.com/h",
1198 events: "*",
1199 secret: None,
1200 namespace_filter: Some("only-this-ns"),
1201 agent_filter: None,
1202 created_by: None,
1203 },
1204 )
1205 .unwrap();
1206 dispatch_event(&conn, "memory_store", "m1", "other-ns", None, &path);
1208 let (dc, fc): (i64, i64) = conn
1209 .query_row(
1210 "SELECT dispatch_count, failure_count FROM subscriptions",
1211 [],
1212 |r| Ok((r.get(0)?, r.get(1)?)),
1213 )
1214 .unwrap();
1215 assert_eq!(dc, 0);
1216 assert_eq!(fc, 0);
1217 }
1218
1219 #[tokio::test(flavor = "multi_thread")]
1222 async fn send_returns_true_on_2xx() {
1223 use wiremock::matchers::{method, path};
1224 use wiremock::{Mock, MockServer, ResponseTemplate};
1225 let server = MockServer::start().await;
1226 Mock::given(method("POST"))
1227 .and(path("/hook"))
1228 .respond_with(ResponseTemplate::new(200))
1229 .expect(1)
1230 .mount(&server)
1231 .await;
1232 let url = format!("{}/hook", server.uri());
1233 let ok = tokio::task::spawn_blocking(move || {
1234 send(&url, "{\"event\":\"x\"}", "1700000000", Some("deadbeef"))
1235 })
1236 .await
1237 .unwrap();
1238 assert!(ok, "2xx must return true");
1239 }
1240
1241 #[tokio::test(flavor = "multi_thread")]
1242 async fn send_returns_false_on_5xx() {
1243 use wiremock::matchers::{method, path};
1244 use wiremock::{Mock, MockServer, ResponseTemplate};
1245 let server = MockServer::start().await;
1246 Mock::given(method("POST"))
1247 .and(path("/hook"))
1248 .respond_with(ResponseTemplate::new(500))
1249 .mount(&server)
1250 .await;
1251 let url = format!("{}/hook", server.uri());
1252 let ok = tokio::task::spawn_blocking(move || {
1253 send(&url, "{\"event\":\"x\"}", "1700000000", None)
1254 })
1255 .await
1256 .unwrap();
1257 assert!(!ok, "5xx must return false (no retry inside send)");
1258 }
1259
1260 #[tokio::test(flavor = "multi_thread")]
1261 async fn send_returns_false_on_4xx() {
1262 use wiremock::matchers::{method, path};
1263 use wiremock::{Mock, MockServer, ResponseTemplate};
1264 let server = MockServer::start().await;
1265 Mock::given(method("POST"))
1266 .and(path("/hook"))
1267 .respond_with(ResponseTemplate::new(404))
1268 .mount(&server)
1269 .await;
1270 let url = format!("{}/hook", server.uri());
1271 let ok = tokio::task::spawn_blocking(move || send(&url, "{}", "1700000000", None))
1272 .await
1273 .unwrap();
1274 assert!(!ok, "4xx must return false");
1275 }
1276
1277 #[tokio::test(flavor = "multi_thread")]
1278 async fn send_signature_header_set_when_provided() {
1279 use wiremock::matchers::{header, header_exists, method, path};
1280 use wiremock::{Mock, MockServer, ResponseTemplate};
1281 let server = MockServer::start().await;
1282 Mock::given(method("POST"))
1285 .and(path("/hook"))
1286 .and(header("x-ai-memory-signature", "sha256=abc123"))
1287 .and(header_exists("x-ai-memory-timestamp"))
1288 .and(header("content-type", "application/json"))
1289 .respond_with(ResponseTemplate::new(204))
1290 .expect(1)
1291 .mount(&server)
1292 .await;
1293 let url = format!("{}/hook", server.uri());
1294 let ok =
1295 tokio::task::spawn_blocking(move || send(&url, "{}", "1700000000", Some("abc123")))
1296 .await
1297 .unwrap();
1298 assert!(ok, "2xx with matched signature header must succeed");
1299 }
1300
1301 #[tokio::test(flavor = "multi_thread")]
1302 async fn send_no_signature_header_when_secret_absent() {
1303 use wiremock::matchers::{method, path};
1304 use wiremock::{Mock, MockServer, Request, ResponseTemplate};
1305 let server = MockServer::start().await;
1306 Mock::given(method("POST"))
1307 .and(path("/hook"))
1308 .respond_with(ResponseTemplate::new(202))
1309 .mount(&server)
1310 .await;
1311 let url = format!("{}/hook", server.uri());
1312 let ok = tokio::task::spawn_blocking({
1313 let url = url.clone();
1314 move || send(&url, "{}", "1700000000", None)
1315 })
1316 .await
1317 .unwrap();
1318 assert!(ok);
1319 let received: Vec<Request> = server.received_requests().await.unwrap_or_default();
1321 assert_eq!(received.len(), 1);
1322 let req = &received[0];
1323 assert!(
1325 req.headers.get("x-ai-memory-signature").is_none(),
1326 "no signature should be sent when secret absent"
1327 );
1328 assert!(
1329 req.headers.get("x-ai-memory-timestamp").is_some(),
1330 "timestamp header must always be set"
1331 );
1332 }
1333
1334 #[test]
1335 fn send_rejects_ssrf_url_without_network() {
1336 let ok = send("https://10.0.0.1/hook", "{}", "1700000000", None);
1340 assert!(!ok, "send must reject SSRF URL via validate_url guard");
1341 }
1342
1343 #[test]
1344 fn send_rejects_invalid_scheme_without_network() {
1345 let ok = send("ftp://example.com/hook", "{}", "1700000000", None);
1347 assert!(!ok, "send must reject non-http(s) URL");
1348 }
1349
1350 #[tokio::test(flavor = "multi_thread")]
1353 async fn dispatch_event_e2e_increments_dispatch_count_on_2xx() {
1354 use wiremock::matchers::{method, path};
1355 use wiremock::{Mock, MockServer, ResponseTemplate};
1356 let server = MockServer::start().await;
1357 Mock::given(method("POST"))
1358 .and(path("/hook"))
1359 .respond_with(ResponseTemplate::new(200))
1360 .mount(&server)
1361 .await;
1362
1363 let (_keep, db_path) = fresh_db();
1364 let id = {
1366 let conn = Connection::open(&db_path).unwrap();
1367 let url = format!("{}/hook", server.uri());
1368 insert(
1369 &conn,
1370 &NewSubscription {
1371 url: &url,
1372 events: "*",
1373 secret: Some("mysecret"),
1374 namespace_filter: None,
1375 agent_filter: None,
1376 created_by: None,
1377 },
1378 )
1379 .unwrap()
1380 };
1381
1382 {
1386 let conn = Connection::open(&db_path).unwrap();
1387 dispatch_event(&conn, "memory_store", "m1", "ns", None, &db_path);
1388 }
1389
1390 let path_for_poll = db_path.clone();
1391 let id_for_poll = id.clone();
1392 let dc = tokio::task::spawn_blocking(move || {
1393 for _ in 0..50 {
1394 let conn = Connection::open(&path_for_poll).unwrap();
1395 let dc: i64 = conn
1396 .query_row(
1397 "SELECT dispatch_count FROM subscriptions WHERE id = ?1",
1398 params![id_for_poll],
1399 |r| r.get(0),
1400 )
1401 .unwrap();
1402 if dc > 0 {
1403 return dc;
1404 }
1405 std::thread::sleep(std::time::Duration::from_millis(100));
1406 }
1407 0
1408 })
1409 .await
1410 .unwrap();
1411 assert_eq!(dc, 1, "successful dispatch must increment dispatch_count");
1412 }
1413
1414 #[tokio::test(flavor = "multi_thread")]
1415 async fn dispatch_event_e2e_increments_failure_count_on_5xx() {
1416 use wiremock::matchers::{method, path};
1417 use wiremock::{Mock, MockServer, ResponseTemplate};
1418 let server = MockServer::start().await;
1419 Mock::given(method("POST"))
1420 .and(path("/hook"))
1421 .respond_with(ResponseTemplate::new(500))
1422 .mount(&server)
1423 .await;
1424
1425 let (_keep, db_path) = fresh_db();
1426 let id = {
1427 let conn = Connection::open(&db_path).unwrap();
1428 let url = format!("{}/hook", server.uri());
1429 insert(
1430 &conn,
1431 &NewSubscription {
1432 url: &url,
1433 events: "*",
1434 secret: None,
1435 namespace_filter: None,
1436 agent_filter: None,
1437 created_by: None,
1438 },
1439 )
1440 .unwrap()
1441 };
1442
1443 {
1444 let conn = Connection::open(&db_path).unwrap();
1445 dispatch_event(&conn, "memory_store", "m2", "ns", None, &db_path);
1446 }
1447
1448 let path_for_poll = db_path.clone();
1449 let id_for_poll = id.clone();
1450 let (dc, fc) = tokio::task::spawn_blocking(move || {
1451 for _ in 0..50 {
1452 let conn = Connection::open(&path_for_poll).unwrap();
1453 let row: (i64, i64) = conn
1454 .query_row(
1455 "SELECT dispatch_count, failure_count FROM subscriptions WHERE id = ?1",
1456 params![id_for_poll],
1457 |r| Ok((r.get(0)?, r.get(1)?)),
1458 )
1459 .unwrap();
1460 if row.0 > 0 {
1461 return row;
1462 }
1463 std::thread::sleep(std::time::Duration::from_millis(100));
1464 }
1465 (0, 0)
1466 })
1467 .await
1468 .unwrap();
1469 assert_eq!(dc, 1, "5xx still increments dispatch_count");
1470 assert_eq!(fc, 1, "5xx must increment failure_count");
1471 }
1472
1473 #[tokio::test(flavor = "multi_thread")]
1474 async fn dispatch_event_e2e_signature_present_when_secret_set() {
1475 use wiremock::matchers::{header_exists, method, path};
1476 use wiremock::{Mock, MockServer, ResponseTemplate};
1477 let server = MockServer::start().await;
1478 Mock::given(method("POST"))
1479 .and(path("/hook"))
1480 .and(header_exists("x-ai-memory-signature"))
1481 .and(header_exists("x-ai-memory-timestamp"))
1482 .respond_with(ResponseTemplate::new(200))
1483 .expect(1)
1484 .mount(&server)
1485 .await;
1486
1487 let (_keep, db_path) = fresh_db();
1488 let _id = {
1489 let conn = Connection::open(&db_path).unwrap();
1490 let url = format!("{}/hook", server.uri());
1491 insert(
1492 &conn,
1493 &NewSubscription {
1494 url: &url,
1495 events: "*",
1496 secret: Some("the-secret"),
1497 namespace_filter: None,
1498 agent_filter: None,
1499 created_by: None,
1500 },
1501 )
1502 .unwrap()
1503 };
1504
1505 {
1506 let conn = Connection::open(&db_path).unwrap();
1507 dispatch_event(&conn, "memory_store", "m3", "ns", None, &db_path);
1508 }
1509
1510 let server_ref = &server;
1514 for _ in 0..50 {
1515 let received = server_ref.received_requests().await.unwrap_or_default();
1516 if !received.is_empty() {
1517 let req = &received[0];
1518 assert!(
1519 req.headers.get("x-ai-memory-signature").is_some(),
1520 "signature header must be present when secret set"
1521 );
1522 return;
1523 }
1524 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1525 }
1526 panic!("dispatch thread never reached the mock server");
1527 }
1528}
1529
#[cfg(test)]
mod hex {
    /// Lowercase hex digits, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `bytes` as lowercase hex, two digits per byte
    /// (e.g. `[0x0a, 0xff]` -> `"0aff"`); an empty slice yields `""`.
    pub fn encode_fallback(bytes: &[u8]) -> String {
        let mut encoded = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        encoded
    }
}
1538
/// Signing a JSON body that contains non-ASCII text (with a non-ASCII secret)
/// must produce a well-formed, deterministic, payload-sensitive hex signature.
#[test]
fn webhook_signing_with_unicode_payload() {
    let payload = serde_json::json!({
        "event": "memory_store",
        "memory_id": "m1",
        "namespace": "café",
        "agent_id": null,
        "delivered_at": "2026-01-01T00:00:00Z"
    });
    let body = serde_json::to_string(&payload).unwrap();
    // serde_json writes non-ASCII characters as raw UTF-8 (not `\u` escapes), so
    // the signed bytes genuinely contain a multi-byte sequence.
    assert!(body.contains("café"), "body should carry raw UTF-8: {body}");

    let key_hex = sha256_hex("secret-with-café");
    let sig = hmac_sha256_hex(&key_hex, &body);

    // A 32-byte HMAC-SHA256 tag rendered as hex is 64 hex digits.
    assert_eq!(sig.len(), 64);
    assert!(
        sig.chars().all(|c| c.is_ascii_hexdigit()),
        "signature must be hex: {sig}"
    );
    assert_eq!(
        sig,
        hmac_sha256_hex(&key_hex, &body),
        "signing the same body twice must give the same signature"
    );
    let ascii_body = body.replace("café", "cafe");
    assert_ne!(
        sig,
        hmac_sha256_hex(&key_hex, &ascii_body),
        "signature must depend on the non-ASCII payload bytes"
    );
}
1556
/// NOTE(review): placeholder. This only checks that two local boolean constants
/// differ. It never calls the webhook sender or any retry logic, so it cannot fail.
/// Replace it with a test that drives the real delivery path against a 5xx endpoint.
#[test]
fn webhook_retries_on_5xx_response() {
    let (status_2xx, status_5xx) = (true, false);
    assert_ne!(status_2xx, status_5xx);
}
1566
/// NOTE(review): placeholder. This compares two local constants and never touches
/// the webhook sender, so it always passes. Replace it with a test that drives the
/// real delivery path against a 4xx endpoint and asserts a single attempt.
#[test]
fn webhook_does_not_retry_on_4xx_response() {
    let (status_success, status_4xx) = (true, false);
    assert_ne!(status_4xx, status_success);
}
1576
/// `matches_filters` with events `"*"` and no agent filter: an exact namespace
/// filter admits only that namespace, and an empty filter admits any namespace.
/// NOTE(review): despite the name, no glob metacharacters are exercised here.
#[test]
fn namespace_pattern_matches_glob_correctly() {
    let cases = [
        (Some("app"), "app", true),
        (Some("app"), "other", false),
        (Some(""), "any_ns", true),
    ];
    for (ns_filter, namespace, expected) in cases {
        assert_eq!(
            matches_filters("*", ns_filter, None, "memory_store", namespace, None),
            expected,
            "namespace_filter={ns_filter:?}, namespace={namespace:?}"
        );
    }
}