1use std::path::Path;
26use std::path::PathBuf;
27use std::sync::Arc;
28use std::time::Duration;
29
30use anyhow::Result;
31use rmcp::model::{ServerCapabilities, ServerInfo};
32use rmcp::{tool_handler, ServerHandler, ServiceExt};
33use serde::{Deserialize, Serialize};
34use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
35use tokio::net::UnixStream;
36
37use crate::graph::edges::EdgeKind;
38use crate::graph::Graph;
39
40use super::tools::MatiServer;
41use super::types::{MemBootstrapParams, MemGetParams, MemQueryParams, MemSetParams};
42
/// Outcome of forwarding one command from the MCP proxy to the daemon socket.
///
/// The variant meanings below describe how `proxy_daemon_send_v2` in this
/// module produces them; other constructors may exist elsewhere.
#[derive(Debug)]
pub(crate) enum ProxyDaemonResult {
    /// The daemon replied. The value is a JSON envelope of the form
    /// `{"ok": true, "v": 2, "data": ...}` or
    /// `{"ok": false, "v": 2, "error": ..., "code": ...}`.
    Ok(serde_json::Value),
    /// No usable socket: the socket file is absent, its path is too long, or
    /// connecting failed with an error other than `ConnectionRefused`.
    NotRunning,
    /// The socket refused connections and the stale-socket check did not
    /// report a live daemon (the socket file may have been removed).
    StaleSocket,
    /// The daemon could not be talked to usefully: refused while a live
    /// daemon is recorded, write/read failure or timeout, an unparseable or
    /// unrecognised reply, or a request that could not be serialized.
    Unresponsive,
}
50
51#[tool_handler(router = self.tool_router)]
52impl ServerHandler for MatiServer {
53 fn get_info(&self) -> ServerInfo {
54 ServerInfo::new(
55 ServerCapabilities::builder()
56 .enable_tools()
57 .enable_tool_list_changed()
58 .build(),
59 )
60 .with_instructions(
61 "mati is a persistent engineering knowledge store for the current \
62 codebase. Use mem_get for direct record lookup, mem_query for \
63 search and graph traversal, mem_bootstrap for session context, \
64 and mem_set for writing knowledge records.",
65 )
66 }
67}
68
69pub async fn serve(repo_root: &Path) -> Result<()> {
91 let startup_t0 = std::time::Instant::now();
92
93 let mati_root: PathBuf = dirs::home_dir()
96 .map(|h| h.join(".mati").join(crate::store::derive_slug(repo_root)))
97 .ok_or_else(|| anyhow::anyhow!("cannot resolve home directory for mati_root"))?;
98
99 super::metadata::record_lifecycle_event(&mati_root, "startup", "phase=ensure_daemon");
100
101 if !super::daemon_lifecycle::ensure_daemon(&mati_root).await {
105 super::metadata::record_lifecycle_event(
106 &mati_root,
107 "serve_failed",
108 "daemon unreachable after auto-spawn",
109 );
110 anyhow::bail!(
111 "mati serve: daemon unreachable. \
112 Run `mati daemon start` manually and check the lifecycle.log."
113 );
114 }
115
116 super::metadata::record_lifecycle_event(
117 &mati_root,
118 "serve_start",
119 &format!("pid={} owner=proxy", std::process::id()),
120 );
121
122 super::metrics::init();
125
126 super::metadata::record_lifecycle_event(
127 &mati_root,
128 "startup",
129 &format!(
130 "phase=ready elapsed_ms={}",
131 startup_t0.elapsed().as_millis()
132 ),
133 );
134
135 let transport = rmcp::transport::io::stdio();
137 let service = MatiServer::with_socket_root(mati_root.clone())
138 .serve(transport)
139 .await
140 .map_err(|e| anyhow::anyhow!("MCP proxy initialization failed: {e}"))
141 .inspect_err(|e| {
142 super::metadata::record_lifecycle_event(
143 &mati_root,
144 "serve_failed",
145 &format!("proxy init: {e:#}"),
146 )
147 })?;
148
149 let shutdown_reason: &'static str = match service.waiting().await {
150 Ok(_) => "client_disconnect",
151 Err(e) => {
152 super::metadata::record_lifecycle_event(
153 &mati_root,
154 "serve_failed",
155 &format!("proxy waiting: {e}"),
156 );
157 "mcp_waiting_error"
158 }
159 };
160 super::metadata::record_lifecycle_event(
161 &mati_root,
162 "serve_shutdown",
163 &format!("reason={shutdown_reason}"),
164 );
165 Ok(())
166}
167
168pub(crate) async fn proxy_daemon_result(
169 root: &Path,
170 cmd: &str,
171 args: serde_json::Value,
172) -> ProxyDaemonResult {
173 let result = proxy_daemon_result_no_spawn(root, cmd, &args).await;
192
193 if matches!(
204 &result,
205 ProxyDaemonResult::NotRunning | ProxyDaemonResult::StaleSocket
206 ) && super::daemon_lifecycle::ensure_daemon(root).await
207 {
208 match proxy_daemon_result_once(root, cmd, &args).await {
209 AttemptOutcome::Final(r) | AttemptOutcome::Retryable(r) => return r,
210 }
211 }
212
213 result
214}
215
216pub(crate) async fn proxy_daemon_result_no_spawn(
220 root: &Path,
221 cmd: &str,
222 args: &serde_json::Value,
223) -> ProxyDaemonResult {
224 match proxy_daemon_result_once(root, cmd, args).await {
225 AttemptOutcome::Final(result) => result,
226 AttemptOutcome::Retryable(_) => {
227 tokio::time::sleep(Duration::from_millis(100)).await;
230 match proxy_daemon_result_once(root, cmd, args).await {
231 AttemptOutcome::Final(result) | AttemptOutcome::Retryable(result) => result,
232 }
233 }
234 }
235}
236
/// Result of a single socket round-trip to the daemon, tagged with whether
/// the caller may usefully try again.
enum AttemptOutcome {
    /// Definitive answer: a parsed reply (other than `session_mismatch`), an
    /// over-long socket path, a request that failed to serialize, or a reply
    /// that was not valid JSON.
    Final(ProxyDaemonResult),
    /// Possibly transient: missing/stale/refused socket, write or read
    /// failure, read timeout/EOF, unknown reply status, or a
    /// `session_mismatch` error reply. Callers in this module retry once.
    Retryable(ProxyDaemonResult),
}
247
248async fn proxy_daemon_result_once(
249 root: &Path,
250 cmd: &str,
251 args: &serde_json::Value,
252) -> AttemptOutcome {
253 let v2_cmd = super::protocol::v1_to_v2_command(cmd, args);
257 proxy_daemon_send_v2(root, v2_cmd).await
258}
259
260pub(crate) async fn proxy_daemon_v2(
270 root: &Path,
271 cmd: super::protocol::Command,
272) -> ProxyDaemonResult {
273 let v2_cmd = match serde_json::to_value(&cmd) {
275 Ok(v) => v,
276 Err(_) => return ProxyDaemonResult::Unresponsive,
277 };
278
279 let result = match proxy_daemon_send_v2(root, v2_cmd.clone()).await {
280 AttemptOutcome::Final(result) => result,
281 AttemptOutcome::Retryable(_) => {
282 tokio::time::sleep(Duration::from_millis(100)).await;
283 match proxy_daemon_send_v2(root, v2_cmd.clone()).await {
284 AttemptOutcome::Final(result) | AttemptOutcome::Retryable(result) => result,
285 }
286 }
287 };
288
289 if matches!(
293 &result,
294 ProxyDaemonResult::NotRunning | ProxyDaemonResult::StaleSocket
295 ) && super::daemon_lifecycle::ensure_daemon(root).await
296 {
297 match proxy_daemon_send_v2(root, v2_cmd).await {
298 AttemptOutcome::Final(r) | AttemptOutcome::Retryable(r) => return r,
299 }
300 }
301
302 result
303}
304
/// Perform one request/response exchange with the daemon over `root/mati.sock`.
///
/// Steps:
/// 1. Wrap `v2_cmd` in a v2 envelope containing the protocol version, a fresh
///    request id, and the daemon session from the metadata file (nil UUID if
///    it cannot be read).
/// 2. Write the envelope as one newline-terminated JSON line and half-close
///    the write side.
/// 3. Wait up to 2 seconds for one reply line.
///
/// Outcomes:
/// - `"status": "ok"` → `Final(Ok({"ok": true, "v": 2, "data": ...}))`.
/// - `"status": "err"` → `Ok({"ok": false, "v": 2, "error", "code"})`. This
///   is `Retryable` when `code == "session_mismatch"` and `Final` otherwise.
/// - Socket path longer than `UNIX_SOCK_PATH_MAX` → `Final(NotRunning)`.
/// - Request serialization failure or a non-JSON reply → `Final(Unresponsive)`.
/// - Everything else (missing socket, failed/refused connect, write failure,
///   read error/timeout/EOF, unknown status) → `Retryable(..)`.
async fn proxy_daemon_send_v2(root: &Path, v2_cmd: serde_json::Value) -> AttemptOutcome {
    let sock_path = root.join("mati.sock");

    // Path too long for a Unix socket address: retrying cannot help.
    if sock_path.as_os_str().len() > UNIX_SOCK_PATH_MAX {
        tracing::warn!(
            path = %sock_path.display(),
            "mcp proxy: socket path exceeds Unix limit"
        );
        return AttemptOutcome::Final(ProxyDaemonResult::NotRunning);
    }

    // No socket file yet: the daemon may not be started (or still starting).
    if !sock_path.exists() {
        return AttemptOutcome::Retryable(ProxyDaemonResult::NotRunning);
    }

    let stream = match UnixStream::connect(&sock_path).await {
        Ok(s) => s,
        Err(e) => {
            let is_refused = e.kind() == std::io::ErrorKind::ConnectionRefused;
            if is_refused {
                // The socket file exists but nobody is accepting on it; ask the
                // metadata layer to classify (and possibly clean up) the state.
                use super::metadata::{self as meta, StaleCheckResult};
                match meta::check_and_cleanup_stale(root) {
                    StaleCheckResult::StaleRemoved | StaleCheckResult::Clean => {
                        return AttemptOutcome::Retryable(ProxyDaemonResult::StaleSocket);
                    }
                    StaleCheckResult::OrphanSocket => {
                        // Best-effort removal; presumably a socket file with no
                        // recorded owner — confirm against check_and_cleanup_stale.
                        let _ = std::fs::remove_file(&sock_path);
                        return AttemptOutcome::Retryable(ProxyDaemonResult::StaleSocket);
                    }
                    StaleCheckResult::LiveDaemon { .. } => {
                        // A daemon is reported alive yet refused the connection:
                        // report it as unresponsive rather than absent, so callers
                        // do not try to spawn another one.
                        return AttemptOutcome::Retryable(ProxyDaemonResult::Unresponsive);
                    }
                }
            }
            return AttemptOutcome::Retryable(ProxyDaemonResult::NotRunning);
        }
    };

    // The session is re-read on every attempt (after connecting); if the
    // metadata cannot be read the nil UUID is sent instead.
    let daemon_session = super::metadata::read_metadata(root)
        .map(|m| m.session)
        .unwrap_or_else(uuid::Uuid::nil);
    let request = serde_json::json!({
        "v": super::protocol::PROTOCOL_VERSION,
        "id": uuid::Uuid::new_v4(),
        "session": daemon_session,
        "cmd": v2_cmd,
    });

    let (reader, mut writer) = stream.into_split();
    let mut bytes = match serde_json::to_vec(&request) {
        Ok(b) => b,
        Err(_) => return AttemptOutcome::Final(ProxyDaemonResult::Unresponsive),
    };
    // One request per connection, framed by a trailing newline.
    bytes.push(b'\n');

    if writer.write_all(&bytes).await.is_err() {
        return AttemptOutcome::Retryable(ProxyDaemonResult::Unresponsive);
    }
    // Half-close the write side so the daemon sees end-of-input after the request.
    if writer.shutdown().await.is_err() {
        return AttemptOutcome::Retryable(ProxyDaemonResult::Unresponsive);
    }

    // Exactly one reply line is expected; EOF (0 bytes), I/O errors and the
    // 2-second timeout are all treated as a retryable unresponsive daemon.
    let mut buf_reader = BufReader::new(reader);
    let mut line = String::new();
    match tokio::time::timeout(Duration::from_secs(2), buf_reader.read_line(&mut line)).await {
        Ok(Ok(n)) if n > 0 => {}
        _ => return AttemptOutcome::Retryable(ProxyDaemonResult::Unresponsive),
    }

    let resp: serde_json::Value = match serde_json::from_str(line.trim()) {
        Ok(v) => v,
        Err(_) => return AttemptOutcome::Final(ProxyDaemonResult::Unresponsive),
    };

    // Re-shape the daemon reply into the envelope expected by the MCP tools.
    match resp.get("status").and_then(|s| s.as_str()) {
        Some("ok") => {
            let data = resp.get("data").cloned().unwrap_or(serde_json::Value::Null);
            AttemptOutcome::Final(ProxyDaemonResult::Ok(
                serde_json::json!({"ok": true, "v": 2, "data": data}),
            ))
        }
        Some("err") => {
            let code = resp
                .get("code")
                .and_then(|c| c.as_str())
                .unwrap_or("internal");
            let message = resp
                .get("message")
                .and_then(|m| m.as_str())
                .unwrap_or("unknown error");
            let envelope = serde_json::json!({
                "ok": false, "v": 2, "error": message, "code": code
            });
            if code == "session_mismatch" {
                // The session we sent did not match the daemon's. Because the
                // session is re-read from metadata on each attempt, a retry
                // can succeed.
                tracing::debug!(
                    "mcp proxy: session mismatch — daemon may have restarted, will retry"
                );
                AttemptOutcome::Retryable(ProxyDaemonResult::Ok(envelope))
            } else {
                AttemptOutcome::Final(ProxyDaemonResult::Ok(envelope))
            }
        }
        _ => AttemptOutcome::Retryable(ProxyDaemonResult::Unresponsive),
    }
}
426
/// Longest socket path (in bytes) the proxy will try to connect to; longer
/// paths are rejected up front by `proxy_daemon_send_v2`.
/// NOTE(review): 104 appears to match the BSD/macOS `sun_path` size — confirm
/// whether the NUL terminator is meant to be included.
pub const UNIX_SOCK_PATH_MAX: usize = 104;

/// How long `socket_handle_connection` waits for a client's request line.
const READ_TIMEOUT: Duration = Duration::from_secs(3);

/// NOTE(review): not referenced in this section; presumably caps the number of
/// socket connections the daemon serves at once — confirm at the use site.
pub const MAX_CONCURRENT_CONNECTIONS: usize = 64;

/// NOTE(review): not referenced in this section; presumably bounds how long an
/// automatic drain may take before giving up — confirm at the use site.
pub const AUTO_DRAIN_TIMEOUT: Duration = Duration::from_secs(10);
456
457#[derive(Default)]
468pub struct Shutdown {
469 flag: std::sync::atomic::AtomicBool,
470 notify: tokio::sync::Notify,
471}
472
473impl Shutdown {
474 pub fn new() -> Self {
475 Self::default()
476 }
477
478 pub fn signal(&self) {
480 self.flag.store(true, std::sync::atomic::Ordering::SeqCst);
481 self.notify.notify_waiters();
482 }
483
484 pub fn is_set(&self) -> bool {
485 self.flag.load(std::sync::atomic::Ordering::SeqCst)
493 }
494
495 pub async fn wait(&self) {
498 let notified = self.notify.notified();
499 tokio::pin!(notified);
500 notified.as_mut().enable();
503 if self.is_set() {
504 return;
505 }
506 notified.await;
507 }
508}
509
/// Wire version stamped into the `v` field of every v1 `SocketResponse`.
/// (Distinct from `super::protocol::PROTOCOL_VERSION`, which the v2 proxy
/// path puts in its request envelopes.)
const PROTOCOL_VERSION: u32 = 1;
512
/// A v1 socket request: a command name plus free-form JSON arguments.
#[derive(Debug, Deserialize)]
pub(crate) struct SocketRequest {
    /// Command name that `socket_dispatch` matches on (e.g. `"ping"`, `"get"`).
    pub cmd: String,
    // Accepted on the wire as `v` but not read by the dispatch code in view,
    // hence `dead_code` is allowed. NOTE(review): confirm no reader elsewhere.
    #[allow(dead_code)] #[serde(default, rename = "v")]
    pub version: Option<u32>,
    /// Command-specific arguments; `Value::Null` when the field is omitted.
    #[serde(default)]
    pub args: serde_json::Value,
}
522
523#[derive(Debug, Serialize)]
524pub(crate) struct SocketResponse {
525 pub(crate) ok: bool,
526 #[serde(rename = "v")]
527 version: u32,
528 #[serde(skip_serializing_if = "Option::is_none")]
529 pub(crate) data: Option<serde_json::Value>,
530 #[serde(skip_serializing_if = "Option::is_none")]
531 pub(crate) error: Option<String>,
532}
533
534impl SocketResponse {
535 pub(crate) fn ok(data: serde_json::Value) -> Self {
536 Self {
537 ok: true,
538 version: PROTOCOL_VERSION,
539 data: Some(data),
540 error: None,
541 }
542 }
543 pub(crate) fn err(msg: impl Into<String>) -> Self {
544 Self {
545 ok: false,
546 version: PROTOCOL_VERSION,
547 data: None,
548 error: Some(msg.into()),
549 }
550 }
551}
552
553pub async fn socket_handle_connection(
554 graph: Arc<tokio::sync::RwLock<Graph>>,
555 repo_root: &Path,
556 stream: UnixStream,
557 peer: super::metadata::PeerContext,
558 daemon_session: uuid::Uuid,
559) -> Result<()> {
560 use super::protocol::MAX_FRAME_SIZE;
561 use tokio::io::AsyncReadExt;
562
563 let (reader, mut writer) = stream.into_split();
564 let mut buf = String::new();
565
566 let limited = reader.take(MAX_FRAME_SIZE as u64 + 1);
571 let mut buf_reader = BufReader::new(limited);
572 match tokio::time::timeout(READ_TIMEOUT, buf_reader.read_line(&mut buf)).await {
573 Ok(Ok(0)) => return Ok(()),
574 Ok(Ok(_)) => {}
575 Ok(Err(e)) => anyhow::bail!("read error: {e}"),
576 Err(_) => anyhow::bail!("read timeout"),
577 }
578
579 if buf.len() > MAX_FRAME_SIZE {
580 let resp = super::protocol::Response::err(
581 uuid::Uuid::nil(),
582 super::protocol::ErrorCode::FrameTooLarge,
583 format!("request exceeds {MAX_FRAME_SIZE} byte limit"),
584 );
585 let json = serde_json::to_string(&resp)?;
586 writer.write_all(json.as_bytes()).await?;
587 writer.write_all(b"\n").await?;
588 writer.flush().await?;
589 return Ok(());
590 }
591
592 let trimmed = buf.trim();
593
594 let v2_req = match serde_json::from_str::<super::protocol::Request>(trimmed) {
599 Ok(r) => r,
600 Err(e) => {
601 let resp = super::protocol::Response::err(
604 uuid::Uuid::nil(),
605 super::protocol::ErrorCode::MalformedRequest,
606 format!("invalid v2 request: {e}"),
607 );
608 let json = serde_json::to_string(&resp)?;
609 writer.write_all(json.as_bytes()).await?;
610 writer.write_all(b"\n").await?;
611 writer.flush().await?;
612 return Ok(());
613 }
614 };
615
616 let ctx = super::dispatch_v2::RequestContext {
617 peer,
618 daemon_session,
619 repo_root: repo_root.to_path_buf(),
620 };
621 let resp = super::dispatch_v2::dispatch_v2(&graph, &ctx, v2_req).await;
622 let json = serde_json::to_string(&resp)?;
623 writer.write_all(json.as_bytes()).await?;
624 writer.write_all(b"\n").await?;
625 writer.flush().await?;
626 Ok(())
627}
628
629fn build_v1_dispatch_ctx(repo_root: &Path) -> super::dispatch_v2::RequestContext {
636 super::dispatch_v2::RequestContext {
637 peer: super::metadata::PeerContext {
638 uid: super::metadata::current_euid(),
639 pid: Some(std::process::id()),
640 },
641 daemon_session: uuid::Uuid::nil(),
642 repo_root: repo_root.to_path_buf(),
643 }
644}
645
646pub(crate) async fn socket_dispatch(
647 graph: &Arc<tokio::sync::RwLock<Graph>>,
648 repo_root: &Path,
649 req: &SocketRequest,
650) -> SocketResponse {
651 use crate::store::session as sess;
652
653 match req.cmd.as_str() {
654 "ping" => SocketResponse::ok(serde_json::Value::String("pong".into())),
655
656 "metrics" => match super::metrics::snapshot() {
661 Some(snap) => match serde_json::to_value(&snap) {
662 Ok(v) => SocketResponse::ok(v),
663 Err(e) => SocketResponse::err(format!("metrics serialize: {e}")),
664 },
665 None => SocketResponse::ok(serde_json::Value::Null),
666 },
667
668 "mem_get" => {
677 let params = match serde_json::from_value::<MemGetParams>(req.args.clone()) {
678 Ok(p) => p,
679 Err(e) => return SocketResponse::err(format!("invalid mem_get args: {e}")),
680 };
681 let input = super::protocol::MemGetInput { key: params.key };
682 let ctx = build_v1_dispatch_ctx(repo_root);
683 let g = graph.read().await;
684 match super::handlers::handle_mem_get(
685 g.store(),
686 graph,
687 &ctx,
688 uuid::Uuid::new_v4(),
689 &input,
690 )
691 .await
692 {
693 Ok(v) => SocketResponse::ok(serde_json::Value::String(
694 serde_json::to_string_pretty(&v).unwrap_or_else(|_| "{}".into()),
695 )),
696 Err((_code, msg)) => SocketResponse::err(msg),
697 }
698 }
699
700 "mem_query" => {
701 let params = match serde_json::from_value::<MemQueryParams>(req.args.clone()) {
702 Ok(p) => p,
703 Err(e) => return SocketResponse::err(format!("invalid mem_query args: {e}")),
704 };
705 let mode = match params.mode.as_str() {
706 "text" => super::protocol::QueryMode::Text,
707 "tag" => super::protocol::QueryMode::Tag,
708 "graph" => super::protocol::QueryMode::Graph,
709 "semantic" => super::protocol::QueryMode::Semantic,
710 other => {
711 return SocketResponse::err(format!(
712 "unknown mode: {other}. Valid modes: text, tag, graph, semantic"
713 ));
714 }
715 };
716 let input = super::protocol::MemQueryInput {
717 query: params.query,
718 mode,
719 limit: params.limit as u32,
720 };
721 let g = graph.read().await;
722 match super::handlers::handle_mem_query(g.store(), &g, &input).await {
723 Ok(v) => SocketResponse::ok(serde_json::Value::String(
724 serde_json::to_string_pretty(&v).unwrap_or_else(|_| "{}".into()),
725 )),
726 Err((_code, msg)) => SocketResponse::err(msg),
727 }
728 }
729
730 "mem_bootstrap" => {
731 let params = match serde_json::from_value::<MemBootstrapParams>(req.args.clone()) {
732 Ok(p) => p,
733 Err(e) => return SocketResponse::err(format!("invalid mem_bootstrap args: {e}")),
734 };
735 let input = super::protocol::MemBootstrapInput {
736 context_files: params.context_files,
737 };
738 let ctx = build_v1_dispatch_ctx(repo_root);
739 let g = graph.read().await;
740 match super::handlers::handle_mem_bootstrap(
741 g.store(),
742 &g,
743 graph,
744 &ctx,
745 uuid::Uuid::new_v4(),
746 &input,
747 )
748 .await
749 {
750 Ok(s) => SocketResponse::ok(serde_json::Value::String(s)),
751 Err((_code, msg)) => SocketResponse::err(msg),
752 }
753 }
754
755 "mem_set" => {
756 let params = match serde_json::from_value::<MemSetParams>(req.args.clone()) {
757 Ok(p) => p,
758 Err(e) => return SocketResponse::err(format!("invalid mem_set args: {e}")),
759 };
760 let ctx = build_v1_dispatch_ctx(repo_root);
761 let response =
762 super::handlers::handle_mem_set(graph, &ctx, uuid::Uuid::new_v4(), ¶ms).await;
763 SocketResponse::ok(serde_json::Value::String(response))
764 }
765
766 "get" => {
770 let key = match req.args.get("key").and_then(|v| v.as_str()) {
771 Some(k) => k,
772 None => return SocketResponse::err("missing args.key"),
773 };
774 let g = graph.read().await;
775 let store = g.store();
776 match store.get(key).await {
777 Ok(Some(record)) => {
778 let confirmed = record
779 .payload_as::<crate::store::GotchaRecord>()
780 .map(|g| g.confirmed)
781 .unwrap_or(false);
782 match serde_json::to_value(&record) {
783 Ok(mut val) => {
784 if let Some(obj) = val.as_object_mut() {
785 obj.insert(
786 "confirmed".to_string(),
787 serde_json::Value::Bool(confirmed),
788 );
789 }
790 SocketResponse::ok(val)
791 }
792 Err(e) => SocketResponse::err(format!("serialize: {e}")),
793 }
794 }
795 Ok(None) => SocketResponse::ok(serde_json::Value::Null),
796 Err(e) => SocketResponse::err(format!("store: {e}")),
797 }
798 }
799
800 "hook_evaluate" => {
804 let file_key = match req.args.get("file_key").and_then(|v| v.as_str()) {
805 Some(k) => k,
806 None => return SocketResponse::err("missing args.file_key"),
807 };
808 let include_recent = req
809 .args
810 .get("include_recent")
811 .and_then(|v| v.as_bool())
812 .unwrap_or(false);
813
814 let g = graph.read().await;
815 let store = g.store();
816
817 let (file_record, store_error) = match store.get(file_key).await {
819 Ok(Some(r)) => (serde_json::to_value(&r).ok(), false),
820 Ok(None) => (None, false),
821 Err(e) => {
822 tracing::warn!("hook_evaluate: store.get({file_key}) failed: {e}");
823 (None, true)
824 }
825 };
826
827 let mut gotcha_records = serde_json::Map::new();
844 let mut gotcha_error = false;
845 let mut linked_keys: std::collections::BTreeSet<String> =
846 std::collections::BTreeSet::new();
847
848 if let Some(ref fr) = file_record {
849 if let Some(keys) = fr
850 .pointer("/payload/gotcha_keys")
851 .and_then(|v| v.as_array())
852 {
853 for gk in keys {
854 if let Some(key_str) = gk.as_str() {
855 linked_keys.insert(key_str.to_string());
856 }
857 }
858 }
859 }
860
861 for nkey in g.neighbors(file_key, &crate::graph::EdgeKind::HasGotcha) {
864 linked_keys.insert(nkey);
865 }
866
867 if linked_keys.is_empty() && file_record.is_some() {
873 let rel_path = file_key.strip_prefix("file:").unwrap_or(file_key);
874 if let Ok(all_gotchas) = store.scan_prefix("gotcha:").await {
875 for r in all_gotchas {
876 if !matches!(r.lifecycle, crate::store::RecordLifecycle::Active) {
877 continue;
878 }
879 if let Some(g) = r.payload_as::<crate::store::GotchaRecord>() {
880 if g.affected_files.iter().any(|af| af == rel_path) {
881 linked_keys.insert(r.key.clone());
882 }
883 }
884 }
885 }
886 }
887
888 for key_str in &linked_keys {
889 match store.get(key_str).await {
890 Ok(Some(grec)) => {
891 if !matches!(grec.lifecycle, crate::store::RecordLifecycle::Active) {
893 continue;
894 }
895 let confirmed = grec
897 .payload_as::<crate::store::GotchaRecord>()
898 .map(|g| g.confirmed)
899 .unwrap_or(false);
900 if let Ok(mut val) = serde_json::to_value(&grec) {
901 if let Some(obj) = val.as_object_mut() {
902 obj.insert(
903 "confirmed".to_string(),
904 serde_json::Value::Bool(confirmed),
905 );
906 }
907 gotcha_records.insert(key_str.clone(), val);
908 }
909 }
910 Ok(None) => {}
911 Err(e) => {
912 tracing::warn!("hook_evaluate: store.get({key_str}) failed: {e}");
913 gotcha_error = true;
914 }
915 }
916 }
917
918 let file_record = if let Some(mut fr) = file_record {
923 if !gotcha_records.is_empty() {
924 if let Some(payload) = fr.pointer_mut("/payload") {
925 if let Some(obj) = payload.as_object_mut() {
926 let keys: Vec<serde_json::Value> = gotcha_records
927 .keys()
928 .map(|k| serde_json::Value::String(k.clone()))
929 .collect();
930 obj.insert("gotcha_keys".to_string(), serde_json::Value::Array(keys));
931 }
932 }
933 }
934 Some(fr)
935 } else {
936 None
937 };
938
939 let consulted = sess::check_consulted(store, file_key)
941 .await
942 .unwrap_or(false);
943 let consulted_recent = if include_recent {
944 sess::check_consulted_recent(store, file_key, 900)
945 .await
946 .unwrap_or(false)
947 } else {
948 false
949 };
950
951 SocketResponse::ok(serde_json::json!({
952 "file_key": file_key,
953 "file_record": file_record,
954 "gotcha_records": gotcha_records,
955 "consulted": consulted,
956 "consulted_recent": consulted_recent,
957 "store_error": store_error,
958 "gotcha_error": gotcha_error,
959 }))
960 }
961
962 "log_hit" => {
963 let key = match req.args.get("key").and_then(|v| v.as_str()) {
964 Some(k) => k,
965 None => return SocketResponse::err("missing args.key"),
966 };
967 let g = graph.read().await;
968 if let Err(e) = sess::log_hit(g.store(), key).await {
969 tracing::warn!("daemon socket log_hit: {e}");
970 }
971 SocketResponse::ok(serde_json::Value::Null)
972 }
973
974 "log_miss" => {
975 let key = match req.args.get("key").and_then(|v| v.as_str()) {
976 Some(k) => k,
977 None => return SocketResponse::err("missing args.key"),
978 };
979 let g = graph.read().await;
980 if let Err(e) = sess::log_miss(g.store(), key).await {
981 tracing::warn!("daemon socket log_miss: {e}");
982 }
983 SocketResponse::ok(serde_json::Value::Null)
984 }
985
986 "log_compliance_miss" => {
987 let key = match req.args.get("key").and_then(|v| v.as_str()) {
988 Some(k) => k,
989 None => return SocketResponse::err("missing args.key"),
990 };
991 let g = graph.read().await;
992 let store = g.store();
993 if let Err(e) = sess::log_compliance_miss(store, key).await {
994 tracing::warn!("daemon socket log_compliance_miss: {e}");
995 }
996 let _ = crate::store::enforcement::record_event(
998 store,
999 crate::store::enforcement::EnforcementEventType::Deny,
1000 crate::store::enforcement::SubjectKind::File,
1001 key.to_string(),
1002 "claude".to_string(),
1003 None,
1004 "gotcha_above_threshold".to_string(),
1005 None,
1006 )
1007 .await;
1008 SocketResponse::ok(serde_json::Value::Null)
1009 }
1010
1011 "log_compliance_hit" => {
1012 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1013 Some(k) => k,
1014 None => return SocketResponse::err("missing args.key"),
1015 };
1016 let g = graph.read().await;
1017 let store = g.store();
1018 if let Err(e) = sess::log_compliance_hit(store, key).await {
1019 tracing::warn!("daemon socket log_compliance_hit: {e}");
1020 }
1021 let _ = crate::store::enforcement::record_event(
1023 store,
1024 crate::store::enforcement::EnforcementEventType::AllowAfterReceipt,
1025 crate::store::enforcement::SubjectKind::File,
1026 key.to_string(),
1027 "claude".to_string(),
1028 None,
1029 "receipt_valid".to_string(),
1030 None,
1031 )
1032 .await;
1033 SocketResponse::ok(serde_json::Value::Null)
1034 }
1035
1036 "log_codex_shell_miss" => {
1037 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1038 Some(k) => k,
1039 None => return SocketResponse::err("missing args.key"),
1040 };
1041 let g = graph.read().await;
1042 if let Err(e) = sess::log_codex_shell_miss(g.store(), key).await {
1043 tracing::warn!("daemon socket log_codex_shell_miss: {e}");
1044 }
1045 SocketResponse::ok(serde_json::Value::Null)
1046 }
1047
1048 "log_bootstrap" => {
1049 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1050 Some(k) => k,
1051 None => return SocketResponse::err("missing args.key"),
1052 };
1053 let g = graph.read().await;
1054 if let Err(e) = sess::log_bootstrap(g.store(), key).await {
1055 tracing::warn!("daemon socket log_bootstrap: {e}");
1056 }
1057 SocketResponse::ok(serde_json::Value::Null)
1058 }
1059
1060 "log_prompt_nudge" => {
1061 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1062 Some(k) => k,
1063 None => return SocketResponse::err("missing args.key"),
1064 };
1065 let g = graph.read().await;
1066 if let Err(e) = sess::log_prompt_nudge(g.store(), key).await {
1067 tracing::warn!("daemon socket log_prompt_nudge: {e}");
1068 }
1069 SocketResponse::ok(serde_json::Value::Null)
1070 }
1071
1072 "session_check_consulted" => {
1073 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1074 Some(k) => k,
1075 None => return SocketResponse::err("missing args.key"),
1076 };
1077 let g = graph.read().await;
1078 match sess::check_consulted(g.store(), key).await {
1079 Ok(found) => SocketResponse::ok(serde_json::Value::Bool(found)),
1080 Err(e) => SocketResponse::err(format!("store: {e}")),
1081 }
1082 }
1083
1084 "session_check_consulted_recent" => {
1085 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1086 Some(k) => k,
1087 None => return SocketResponse::err("missing args.key"),
1088 };
1089 let ttl_secs = req
1090 .args
1091 .get("ttl_secs")
1092 .and_then(|v| v.as_u64())
1093 .unwrap_or(900);
1094 let g = graph.read().await;
1095 match sess::check_consulted_recent(g.store(), key, ttl_secs).await {
1096 Ok(found) => SocketResponse::ok(serde_json::Value::Bool(found)),
1097 Err(e) => SocketResponse::err(format!("store: {e}")),
1098 }
1099 }
1100
1101 "session_flush" => {
1102 let g = graph.read().await;
1103 if let Err(e) = sess::session_flush(g.store()).await {
1104 tracing::warn!("daemon socket session_flush: {e}");
1105 }
1106 SocketResponse::ok(serde_json::Value::Null)
1107 }
1108
1109 "session_harvest" => {
1110 let g = graph.read().await;
1113 if let Err(e) = sess::session_harvest_no_staleness(g.store()).await {
1114 tracing::warn!("daemon socket session_harvest: {e}");
1115 }
1116 SocketResponse::ok(serde_json::Value::Null)
1117 }
1118
1119 "reparse" => {
1120 let path = match req.args.get("path").and_then(|v| v.as_str()) {
1121 Some(p) => p,
1122 None => return SocketResponse::err("missing args.path"),
1123 };
1124 let g = graph.read().await;
1125 if let Err(e) = crate::analysis::reparse::reparse_impl(g.store(), repo_root, path).await
1126 {
1127 tracing::warn!("daemon socket reparse: {e}");
1128 }
1129 SocketResponse::ok(serde_json::Value::Null)
1130 }
1131
1132 "edit_hook" => {
1133 let path = match req.args.get("path").and_then(|v| v.as_str()) {
1134 Some(p) => p,
1135 None => return SocketResponse::err("missing args.path"),
1136 };
1137 let file_key = format!("file:{path}");
1138 let g = graph.read().await;
1139 let store = g.store();
1140 if let Err(e) = sess::log_hit(store, &file_key).await {
1141 tracing::warn!("daemon socket edit_hook: log_hit failed: {e}");
1142 }
1143 if let Err(e) = crate::analysis::reparse::reparse_impl(store, repo_root, path).await {
1144 tracing::warn!("daemon socket edit_hook: reparse failed (non-fatal): {e}");
1145 }
1146
1147 {
1150 use crate::analysis::blast_radius::BlastRadius;
1151 use crate::graph::edges::EdgeKind;
1152
1153 let mut keys_to_update = vec![file_key.clone()];
1154 keys_to_update.extend(g.neighbors_incoming(&file_key, &EdgeKind::Imports));
1157 keys_to_update.extend(g.neighbors(&file_key, &EdgeKind::Imports));
1159
1160 for key in keys_to_update {
1161 let br = BlastRadius::compute(&key, &g);
1162 if let Ok(Some(mut rec)) = store.get(&key).await {
1163 if let Some(mut fr) = rec.payload_as::<crate::store::record::FileRecord>() {
1164 fr.blast_radius = Some(br);
1165 rec.payload = serde_json::to_value(&fr).ok();
1166 let _ = store.put(&key, &rec).await;
1167 }
1168 }
1169 }
1170 }
1171
1172 {
1176 let mut affected_keys = vec![file_key.clone()];
1177 let d1 = g.neighbors_incoming(&file_key, &EdgeKind::Imports);
1178 for d1k in &d1 {
1179 affected_keys.push(d1k.clone());
1180 affected_keys.extend(g.neighbors_incoming(d1k, &EdgeKind::Imports));
1181 }
1182 let mut neighborhood_recs = Vec::new();
1184 for key in &affected_keys {
1185 if let Ok(Some(rec)) = store.get(key).await {
1186 neighborhood_recs.push(rec);
1187 }
1188 }
1189 if let Ok(Some(rec)) = store.get(&file_key).await {
1191 if !neighborhood_recs.iter().any(|r| r.key == file_key) {
1192 neighborhood_recs.push(rec);
1193 }
1194 }
1195 let propagation =
1196 crate::analysis::propagation::compute_propagation(&neighborhood_recs, &g);
1197 for (key, prop) in &propagation {
1198 if let Ok(Some(mut rec)) = store.get(key).await {
1199 if let Some(mut fr) = rec.payload_as::<crate::store::record::FileRecord>() {
1200 fr.propagated_staleness = Some(prop.clone());
1201 rec.payload = serde_json::to_value(&fr).ok();
1202 let _ = store.put(key, &rec).await;
1203 }
1204 }
1205 }
1206 }
1207
1208 SocketResponse::ok(serde_json::Value::Null)
1209 }
1210
1211 "doc_capture" => {
1212 let path = match req.args.get("path").and_then(|v| v.as_str()) {
1213 Some(p) => p,
1214 None => return SocketResponse::err("missing args.path"),
1215 };
1216 let content = req
1217 .args
1218 .get("content")
1219 .and_then(|v| v.as_str())
1220 .unwrap_or("");
1221 let g = graph.read().await;
1222 if let Err(e) = sess::doc_capture(g.store(), path, content).await {
1223 tracing::warn!("daemon socket doc_capture: {e}");
1224 }
1225 SocketResponse::ok(serde_json::Value::Null)
1226 }
1227
1228 "scan_prefix" => {
1229 let prefix = match req.args.get("prefix").and_then(|v| v.as_str()) {
1230 Some(p) => p,
1231 None => return SocketResponse::err("missing args.prefix"),
1232 };
1233 let g = graph.read().await;
1234 match g.store().scan_prefix(prefix).await {
1235 Ok(records) => match serde_json::to_value(&records) {
1236 Ok(val) => SocketResponse::ok(val),
1237 Err(e) => SocketResponse::err(format!("serialize: {e}")),
1238 },
1239 Err(e) => SocketResponse::err(format!("store: {e}")),
1240 }
1241 }
1242
1243 "scan_enforcement_events" => {
1244 let since_seq = req
1245 .args
1246 .get("since_seq")
1247 .and_then(|v| v.as_u64())
1248 .unwrap_or(0);
1249 let until_seq = req
1250 .args
1251 .get("until_seq")
1252 .and_then(|v| v.as_u64())
1253 .unwrap_or(u64::MAX);
1254 let g = graph.read().await;
1255 match crate::store::enforcement::scan_enforcement_events(
1256 g.store(),
1257 since_seq,
1258 until_seq,
1259 )
1260 .await
1261 {
1262 Ok(events) => match serde_json::to_value(&events) {
1263 Ok(val) => SocketResponse::ok(val),
1264 Err(e) => SocketResponse::err(format!("serialize: {e}")),
1265 },
1266 Err(e) => SocketResponse::err(format!("store: {e}")),
1267 }
1268 }
1269
1270 "put" => {
1271 use crate::store::Record;
1272 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1273 Some(k) => k,
1274 None => return SocketResponse::err("missing args.key"),
1275 };
1276 let record: Record = match req
1277 .args
1278 .get("record")
1279 .and_then(|v| serde_json::from_value(v.clone()).ok())
1280 {
1281 Some(r) => r,
1282 None => return SocketResponse::err("put: invalid record"),
1283 };
1284 let g = graph.read().await;
1285 match g.store().put(key, &record).await {
1286 Ok(()) => SocketResponse::ok(serde_json::Value::Null),
1287 Err(e) => SocketResponse::err(format!("store put: {e}")),
1288 }
1289 }
1290
1291 "delete" => {
1292 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1293 Some(k) => k,
1294 None => return SocketResponse::err("missing args.key"),
1295 };
1296 let g = graph.read().await;
1297 match g.store().delete(key).await {
1298 Ok(()) => SocketResponse::ok(serde_json::Value::Null),
1299 Err(e) => SocketResponse::err(format!("delete: {e}")),
1300 }
1301 }
1302
1303 "history" => {
1304 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1305 Some(k) => k,
1306 None => return SocketResponse::err("missing args.key"),
1307 };
1308 let limit = req.args.get("limit").and_then(|v| v.as_u64()).unwrap_or(50) as usize;
1309 let g = graph.read().await;
1310 match g.store().history(key, limit) {
1311 Ok(entries) => match serde_json::to_value(&entries) {
1312 Ok(val) => SocketResponse::ok(val),
1313 Err(e) => SocketResponse::err(format!("serialize: {e}")),
1314 },
1315 Err(e) => SocketResponse::err(format!("history: {e}")),
1316 }
1317 }
1318
1319 "history_since" => {
1320 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1321 Some(k) => k,
1322 None => return SocketResponse::err("missing args.key"),
1323 };
1324 let since_ts = req
1325 .args
1326 .get("since_ts")
1327 .and_then(|v| v.as_u64())
1328 .unwrap_or(0);
1329 let limit = req.args.get("limit").and_then(|v| v.as_u64()).unwrap_or(50) as usize;
1330 let g = graph.read().await;
1331 match g.store().history_since(key, since_ts, limit) {
1332 Ok(entries) => match serde_json::to_value(&entries) {
1333 Ok(val) => SocketResponse::ok(val),
1334 Err(e) => SocketResponse::err(format!("serialize: {e}")),
1335 },
1336 Err(e) => SocketResponse::err(format!("history_since: {e}")),
1337 }
1338 }
1339
1340 "gotcha_write" => {
1341 use crate::store::gotcha_ops::apply_gotcha_write;
1342 use crate::store::Record;
1343
1344 let record: Record = match req
1345 .args
1346 .get("record")
1347 .and_then(|v| serde_json::from_value(v.clone()).ok())
1348 {
1349 Some(r) => r,
1350 None => return SocketResponse::err("missing or invalid args.record"),
1351 };
1352 let new_files: Vec<String> = req
1353 .args
1354 .get("new_files")
1355 .and_then(|v| serde_json::from_value(v.clone()).ok())
1356 .unwrap_or_default();
1357 let old_files: Vec<String> = req
1358 .args
1359 .get("old_files")
1360 .and_then(|v| serde_json::from_value(v.clone()).ok())
1361 .unwrap_or_default();
1362 let is_new = req
1363 .args
1364 .get("is_new")
1365 .and_then(|v| v.as_bool())
1366 .unwrap_or(false);
1367
1368 {
1369 let g = graph.read().await;
1370 match apply_gotcha_write(g.store(), &record, &old_files, &new_files, is_new).await {
1371 Ok(()) => {}
1372 Err(e) => return SocketResponse::err(format!("{e}")),
1373 }
1374 }
1375
1376 let record_key = record.key.clone();
1382 let old_set: std::collections::HashSet<&str> =
1383 old_files.iter().map(String::as_str).collect();
1384 let new_set: std::collections::HashSet<&str> =
1385 new_files.iter().map(String::as_str).collect();
1386 {
1387 let mut g = graph.write().await;
1388 for file_path in new_set.difference(&old_set) {
1389 let file_key = format!("file:{file_path}");
1390 let _ = g
1391 .add_edge(&file_key, EdgeKind::HasGotcha, &record_key)
1392 .await;
1393 }
1394 for file_path in old_set.difference(&new_set) {
1395 let file_key = format!("file:{file_path}");
1396 let _ = g
1397 .remove_edge(&file_key, &EdgeKind::HasGotcha, &record_key)
1398 .await;
1399 }
1400 }
1401
1402 SocketResponse::ok(serde_json::Value::String("written".into()))
1403 }
1404
1405 "gotcha_tombstone" => {
1406 use crate::store::gotcha_ops::apply_gotcha_tombstone;
1407
1408 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1409 Some(k) => k,
1410 None => return SocketResponse::err("missing args.key"),
1411 };
1412 if !key.starts_with("gotcha:") {
1413 return SocketResponse::err("delete action only applies to gotcha: keys");
1414 }
1415 let mut affected_files: Vec<String> = req
1418 .args
1419 .get("affected_files")
1420 .and_then(|v| serde_json::from_value(v.clone()).ok())
1421 .unwrap_or_default();
1422
1423 let g = graph.read().await;
1424 if affected_files.is_empty() {
1425 if let Ok(Some(record)) = g.store().get(key).await {
1426 if let Some(gotcha) = record.payload_as::<crate::store::GotchaRecord>() {
1427 affected_files = gotcha.affected_files;
1428 }
1429 }
1430 }
1431 match apply_gotcha_tombstone(g.store(), key, &affected_files).await {
1432 Ok(()) => SocketResponse::ok(serde_json::Value::String("tombstoned".into())),
1433 Err(e) => SocketResponse::err(format!("{e}")),
1434 }
1435 }
1436
1437 "gotcha_confirm" => {
1438 let key = match req.args.get("key").and_then(|v| v.as_str()) {
1439 Some(k) => k,
1440 None => return SocketResponse::err("missing args.key"),
1441 };
1442
1443 let g = graph.read().await;
1445 let store = g.store();
1446 let mut record = match store.get(key).await {
1447 Ok(Some(r)) => r,
1448 Ok(None) => return SocketResponse::err(format!("record not found: {key}")),
1449 Err(e) => return SocketResponse::err(format!("store get: {e}")),
1450 };
1451
1452 if record.category != crate::store::record::Category::Gotcha {
1453 return SocketResponse::err(format!("{key} is not a gotcha record"));
1454 }
1455
1456 if !matches!(
1457 record.lifecycle,
1458 crate::store::record::RecordLifecycle::Active
1459 ) {
1460 return SocketResponse::err(format!(
1461 "{key} is tombstoned — cannot confirm a deleted record"
1462 ));
1463 }
1464
1465 if let Some(ref mut payload) = record.payload {
1467 if let Some(obj) = payload.as_object_mut() {
1468 if let Some(sev) = obj
1469 .get("severity")
1470 .and_then(|v| v.as_str())
1471 .map(|s| s.to_lowercase())
1472 {
1473 obj.insert("severity".to_string(), serde_json::Value::String(sev));
1474 }
1475 obj.insert("confirmed".to_string(), serde_json::Value::Bool(true));
1476 }
1477 }
1478
1479 record.source = crate::store::record::RecordSource::DeveloperManual;
1480 record.confidence.value = crate::store::record::ConfidenceScore::base_for_source(
1481 &crate::store::record::RecordSource::DeveloperManual,
1482 );
1483 record.confidence.confirmation_count += 1;
1484 record.quality = crate::health::quality::analyze(&record);
1485
1486 let now = std::time::SystemTime::now()
1487 .duration_since(std::time::UNIX_EPOCH)
1488 .unwrap_or_default()
1489 .as_secs();
1490 record.updated_at = now;
1491 record.version.logical_clock += 1;
1492 record.version.wall_clock = now;
1493
1494 let affected_files: Vec<String> = record
1496 .payload_as::<crate::store::record::GotchaRecord>()
1497 .map(|g| g.affected_files)
1498 .unwrap_or_default();
1499
1500 if let Err(e) = store.put(key, &record).await {
1501 return SocketResponse::err(format!("store put: {e}"));
1502 }
1503
1504 for file_path in &affected_files {
1506 let file_key = format!("file:{file_path}");
1507 if let Ok(Some(mut file_record)) = store.get(&file_key).await {
1508 let needs_link = file_record
1509 .payload
1510 .as_ref()
1511 .and_then(|p| p.get("gotcha_keys"))
1512 .and_then(|v| v.as_array())
1513 .map(|arr| !arr.iter().any(|v| v.as_str() == Some(key)))
1514 .unwrap_or(true);
1515 if needs_link {
1516 if let Some(ref mut payload) = file_record.payload {
1517 if let Some(obj) = payload.as_object_mut() {
1518 let arr = obj.entry("gotcha_keys").or_insert(serde_json::json!([]));
1519 if let Some(arr) = arr.as_array_mut() {
1520 arr.push(serde_json::Value::String(key.to_string()));
1521 }
1522 }
1523 }
1524 let _ = store.put(&file_key, &file_record).await;
1525 }
1526 }
1527 }
1528
1529 crate::store::gotcha_ops::propagate_confirmation_to_files(store, &affected_files).await;
1531
1532 let _ = crate::store::enforcement::record_event(
1534 store,
1535 crate::store::enforcement::EnforcementEventType::ControlChanged {
1536 change_kind: crate::store::enforcement::ControlChangeKind::Confirmed,
1537 },
1538 crate::store::enforcement::SubjectKind::Control,
1539 key.to_string(),
1540 "developer".to_string(),
1541 None,
1542 "control_confirmed".to_string(),
1543 None,
1544 )
1545 .await;
1546
1547 SocketResponse::ok(serde_json::json!({"confirmed": true, "key": key}))
1548 }
1549
1550 other => SocketResponse::err(format!("unknown command: {other}")),
1551 }
1552}
1553
/// Idle threshold, in seconds (30 minutes).
///
/// NOTE(review): the name suggests the daemon shuts itself down after this
/// much inactivity; confirm against the code that reads this constant.
pub const IDLE_SHUTDOWN_SECS: u64 = 30 * 60;
/// Idle-check period, in seconds (5 minutes).
///
/// NOTE(review): the name suggests this is how often the idle condition is
/// re-evaluated; confirm against the code that reads this constant.
pub const IDLE_CHECK_INTERVAL_SECS: u64 = 5 * 60;
#[cfg(test)]
mod shutdown_tests {
    // Tests for the `Shutdown` primitive (`signal` / `wait` / `is_set`) and for
    // the `tokio::task::JoinSet` abort/drain and panic-reporting behavior.

    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    /// A signal raised before anyone waits must not be lost: `wait()` returns
    /// within 100 ms and `is_set()` reports true afterwards.
    #[tokio::test]
    async fn shutdown_signal_before_wait_returns_immediately() {
        let s = Shutdown::new();
        s.signal();
        tokio::time::timeout(Duration::from_millis(100), s.wait())
            .await
            .expect("wait must return immediately when already signaled");
        assert!(s.is_set());
    }

    /// A task already inside `wait()` is woken by a later `signal()` within
    /// 200 ms, and the flag reads as set afterwards.
    #[tokio::test]
    async fn shutdown_wait_then_signal_wakes_waiter() {
        let s = Arc::new(Shutdown::new());
        let s_clone = Arc::clone(&s);
        let waiter = tokio::spawn(async move { s_clone.wait().await });

        // Give the spawned waiter time to start waiting before signaling.
        tokio::time::sleep(Duration::from_millis(20)).await;
        // Nothing has signaled yet, so the flag must still be clear.
        assert!(!s.is_set());

        s.signal();

        tokio::time::timeout(Duration::from_millis(200), waiter)
            .await
            .expect("waiter must wake within timeout")
            .expect("waiter task should not panic");
        assert!(s.is_set());
    }

    /// A single `signal()` wakes every concurrent waiter (16 here), each
    /// within 200 ms.
    #[tokio::test]
    async fn shutdown_multiple_concurrent_waiters_all_wake() {
        let s = Arc::new(Shutdown::new());
        let mut handles = Vec::new();
        for _ in 0..16 {
            let s = Arc::clone(&s);
            handles.push(tokio::spawn(async move { s.wait().await }));
        }
        // Give the spawned waiters time to start waiting before signaling.
        tokio::time::sleep(Duration::from_millis(20)).await;

        s.signal();

        for h in handles {
            tokio::time::timeout(Duration::from_millis(200), h)
                .await
                .expect("each waiter must wake within timeout")
                .expect("waiter task should not panic");
        }
    }

    /// Repeated `signal()` calls do not panic, and `wait()` still returns
    /// within 100 ms.
    #[tokio::test]
    async fn shutdown_signal_is_idempotent() {
        let s = Shutdown::new();
        s.signal();
        s.signal();
        s.signal();
        tokio::time::timeout(Duration::from_millis(100), s.wait())
            .await
            .expect("wait must still return on idempotent re-signal");
    }

    /// A `JoinSet` holding a 60 s sleeper cannot be drained within 100 ms.
    /// After `abort_all()`, the drain finishes within 500 ms and leaves the
    /// set empty.
    #[tokio::test]
    async fn joinset_abort_all_makes_drain_finite() {
        let mut set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
        set.spawn(async {
            tokio::time::sleep(Duration::from_secs(60)).await;
        });

        // First drain attempt: expected to hit the timeout.
        let primary = tokio::time::timeout(Duration::from_millis(100), async {
            while set.join_next().await.is_some() {}
        })
        .await;
        assert!(
            primary.is_err(),
            "primary drain should time out while task is still sleeping"
        );

        set.abort_all();
        // Second drain attempt: aborted tasks resolve, so this must finish.
        let secondary = tokio::time::timeout(Duration::from_millis(500), async {
            while set.join_next().await.is_some() {}
        })
        .await;
        assert!(
            secondary.is_ok(),
            "drain after abort_all must complete quickly"
        );
        assert!(set.is_empty(), "JoinSet should be empty after drain");
    }

    /// A panicking task surfaces through `try_join_next()` as a `JoinError`
    /// with `is_panic() == true`. Polls every 10 ms for up to 500 ms.
    #[tokio::test]
    async fn joinset_panics_are_observable_via_try_join_next() {
        let mut set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
        set.spawn(async {
            panic!("simulated handler panic");
        });

        let deadline = std::time::Instant::now() + Duration::from_millis(500);
        loop {
            if let Some(res) = set.try_join_next() {
                let err = res.expect_err("panicked task should yield Err");
                assert!(
                    err.is_panic(),
                    "JoinError must report is_panic for panicking task; got: {err:?}"
                );
                return;
            }
            if std::time::Instant::now() >= deadline {
                panic!("try_join_next never reported the panic within 500ms");
            }
            // `try_join_next` does not block, so yield to the runtime here to
            // let the spawned task run to completion.
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    /// Lost-wakeup check: a `signal()` issued right after spawning a waiter
    /// (after a single `yield_now`) must still wake that waiter within 500 ms.
    /// Runs 50 independent trials, each with a fresh `Shutdown`.
    #[tokio::test]
    async fn shutdown_no_lost_signal_under_race() {
        for trial in 0..50 {
            let s = Arc::new(Shutdown::new());
            let s_waiter = Arc::clone(&s);
            let s_signaler = Arc::clone(&s);

            let waiter = tokio::spawn(async move { s_waiter.wait().await });

            // Only one yield: the waiter may or may not have begun waiting
            // when the signal below is sent.
            tokio::task::yield_now().await;

            s_signaler.signal();

            tokio::time::timeout(Duration::from_millis(500), waiter)
                .await
                .unwrap_or_else(|_| panic!("trial {trial}: waiter stranded by lost signal"))
                .expect("waiter task should not panic");
        }
    }
}
1735
#[cfg(test)]
mod tests {
    // Tests for the gotcha commands of `socket_dispatch`, for frame-size
    // handling in `socket_handle_connection`, and for the client-side proxy
    // helpers `proxy_daemon_result` / `proxy_daemon_v2`.

    use super::*;
    use crate::store::record::{
        Category, ConfidenceScore, FileRecord, GotchaRecord, Priority, QualityScore, Record,
        RecordLifecycle, RecordSource, RecordVersion, StalenessScore,
    };
    use crate::store::Store;

    /// Builds an `Active`, `DeveloperManual`, `Category::Gotcha` record keyed
    /// `key`. Its payload is a serialized `GotchaRecord` listing `files` as
    /// affected files. Timestamps are fixed at `1_000_000` and the logical
    /// clock at `1`.
    fn make_gotcha_record(key: &str, files: &[&str]) -> Record {
        let gotcha = GotchaRecord {
            rule: "test rule".into(),
            reason: "test reason".into(),
            severity: Priority::High,
            affected_files: files.iter().map(|s| s.to_string()).collect(),
            ref_url: None,
            discovered_session: 1_000_000,
            confirmed: true,
        };
        Record {
            key: key.to_string(),
            value: "test rule because test reason".into(),
            payload: serde_json::to_value(&gotcha).ok(),
            category: Category::Gotcha,
            priority: Priority::High,
            tags: vec![],
            created_at: 1_000_000,
            updated_at: 1_000_000,
            ref_url: None,
            staleness: StalenessScore::fresh(),
            lifecycle: RecordLifecycle::Active,
            version: RecordVersion {
                device_id: uuid::Uuid::new_v4(),
                logical_clock: 1,
                wall_clock: 1_000_000,
            },
            quality: QualityScore::layer0_default(),
            access_count: 0,
            last_accessed: 0,
            source: RecordSource::DeveloperManual,
            confidence: ConfidenceScore::for_new_record(&RecordSource::DeveloperManual),
            gap_analysis_score: 0.0,
        }
    }

    /// Builds an `Active`, `StaticAnalysis`, `Category::File` record keyed
    /// `file:{path}`. Its payload is a serialized `FileRecord` with every list
    /// empty; in particular `gotcha_keys` starts as `[]`.
    fn make_file_record(path: &str) -> Record {
        let file = FileRecord {
            path: path.to_string(),
            purpose: String::new(),
            entry_points: vec![],
            imports: vec![],
            gotcha_keys: vec![],
            decision_keys: vec![],
            todos: vec![],
            unsafe_count: 0,
            unwrap_count: 0,
            change_frequency: 0,
            last_author: None,
            is_hotspot: false,
            token_cost_estimate: 0,
            last_modified_session: 0,
            content_hash: None,
            line_count: 0,
            blast_radius: None,
            propagated_staleness: None,
        };
        Record {
            key: format!("file:{path}"),
            value: String::new(),
            payload: serde_json::to_value(&file).ok(),
            category: Category::File,
            priority: Priority::Normal,
            tags: vec![],
            created_at: 1_000_000,
            updated_at: 1_000_000,
            ref_url: None,
            staleness: StalenessScore::fresh(),
            lifecycle: RecordLifecycle::Active,
            version: RecordVersion {
                device_id: uuid::Uuid::new_v4(),
                logical_clock: 1,
                wall_clock: 1_000_000,
            },
            quality: QualityScore::layer0_default(),
            access_count: 0,
            last_accessed: 0,
            source: RecordSource::StaticAnalysis,
            confidence: ConfidenceScore::for_new_record(&RecordSource::StaticAnalysis),
            gap_analysis_score: 0.0,
        }
    }

    /// Returns the string entries of `payload["gotcha_keys"]`. Non-string
    /// entries are skipped. Returns an empty vec when the payload or the
    /// array is absent.
    fn file_gotcha_keys(record: &Record) -> Vec<String> {
        record
            .payload
            .as_ref()
            .and_then(|p| p.get("gotcha_keys"))
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Loads a `Graph` over `store` and wraps it in the `Arc<RwLock<_>>`
    /// shape that `socket_dispatch` takes. Panics if loading fails.
    async fn make_test_graph(store: Store) -> Arc<tokio::sync::RwLock<Graph>> {
        let graph = Graph::load(store).await.expect("failed to load test graph");
        Arc::new(tokio::sync::RwLock::new(graph))
    }

    /// Builds a `SocketRequest` stamped with `PROTOCOL_VERSION` and runs it
    /// through `socket_dispatch`.
    ///
    /// NOTE(review): passes a fixed `/tmp/mati-test` root rather than a
    /// per-test tempdir; confirm `socket_dispatch` never writes under it.
    async fn dispatch_with_graph(
        graph: &Arc<tokio::sync::RwLock<Graph>>,
        cmd: &str,
        args: serde_json::Value,
    ) -> SocketResponse {
        let req = SocketRequest {
            cmd: cmd.to_string(),
            version: Some(PROTOCOL_VERSION),
            args,
        };
        socket_dispatch(graph, Path::new("/tmp/mati-test"), &req).await
    }

    /// A new gotcha written via `gotcha_write` appears in the `gotcha_keys`
    /// of every listed file record.
    #[tokio::test]
    async fn socket_gotcha_write_adds_keys_to_file_records() {
        let dir = tempfile::TempDir::new().unwrap();
        let store = Store::open(dir.path()).await.unwrap();
        store
            .put("file:src/a.rs", &make_file_record("src/a.rs"))
            .await
            .unwrap();
        store
            .put("file:src/b.rs", &make_file_record("src/b.rs"))
            .await
            .unwrap();
        let graph = make_test_graph(store).await;

        let record = make_gotcha_record("gotcha:socket-test", &["src/a.rs", "src/b.rs"]);
        let resp = dispatch_with_graph(&graph, "gotcha_write", serde_json::json!({
            "record": record, "new_files": ["src/a.rs", "src/b.rs"], "old_files": [], "is_new": true,
        })).await;
        assert!(resp.ok, "gotcha_write failed: {:?}", resp.error);

        let g = graph.read().await;
        let a = g.store().get("file:src/a.rs").await.unwrap().unwrap();
        let b = g.store().get("file:src/b.rs").await.unwrap().unwrap();
        assert!(file_gotcha_keys(&a).contains(&"gotcha:socket-test".into()));
        assert!(file_gotcha_keys(&b).contains(&"gotcha:socket-test".into()));
    }

    /// Editing a gotcha (`is_new: false`) to move it from `src/a.rs` to
    /// `src/b.rs` removes the key from the old file record and adds it to the
    /// new one.
    #[tokio::test]
    async fn socket_gotcha_write_edit_removes_key_from_old_file() {
        let dir = tempfile::TempDir::new().unwrap();
        let store = Store::open(dir.path()).await.unwrap();
        store
            .put("file:src/a.rs", &make_file_record("src/a.rs"))
            .await
            .unwrap();
        store
            .put("file:src/b.rs", &make_file_record("src/b.rs"))
            .await
            .unwrap();
        let graph = make_test_graph(store).await;

        // Initial write: attach to src/a.rs only.
        let record = make_gotcha_record("gotcha:edit-socket", &["src/a.rs"]);
        let resp = dispatch_with_graph(
            &graph,
            "gotcha_write",
            serde_json::json!({
                "record": record, "new_files": ["src/a.rs"], "old_files": [], "is_new": true,
            }),
        )
        .await;
        assert!(resp.ok);

        // Edit: move from src/a.rs (old_files) to src/b.rs (new_files).
        let record2 = make_gotcha_record("gotcha:edit-socket", &["src/b.rs"]);
        let resp2 = dispatch_with_graph(&graph, "gotcha_write", serde_json::json!({
            "record": record2, "new_files": ["src/b.rs"], "old_files": ["src/a.rs"], "is_new": false,
        })).await;
        assert!(resp2.ok);

        let g = graph.read().await;
        let a = g.store().get("file:src/a.rs").await.unwrap().unwrap();
        let b = g.store().get("file:src/b.rs").await.unwrap().unwrap();
        assert!(!file_gotcha_keys(&a).contains(&"gotcha:edit-socket".into()));
        assert!(file_gotcha_keys(&b).contains(&"gotcha:edit-socket".into()));
    }

    /// `gotcha_tombstone` marks the gotcha record `Tombstoned` and clears its
    /// key from the `gotcha_keys` of the listed file records.
    #[tokio::test]
    async fn socket_gotcha_tombstone_removes_keys_from_file_records() {
        let dir = tempfile::TempDir::new().unwrap();
        let store = Store::open(dir.path()).await.unwrap();
        store
            .put("file:src/a.rs", &make_file_record("src/a.rs"))
            .await
            .unwrap();
        store
            .put("file:src/b.rs", &make_file_record("src/b.rs"))
            .await
            .unwrap();
        let graph = make_test_graph(store).await;

        let record = make_gotcha_record("gotcha:tomb-socket", &["src/a.rs", "src/b.rs"]);
        let resp = dispatch_with_graph(&graph, "gotcha_write", serde_json::json!({
            "record": record, "new_files": ["src/a.rs", "src/b.rs"], "old_files": [], "is_new": true,
        })).await;
        assert!(resp.ok);

        let resp2 = dispatch_with_graph(
            &graph,
            "gotcha_tombstone",
            serde_json::json!({
                "key": "gotcha:tomb-socket", "affected_files": ["src/a.rs", "src/b.rs"],
            }),
        )
        .await;
        assert!(resp2.ok, "gotcha_tombstone failed: {:?}", resp2.error);

        let g = graph.read().await;
        let rec = g.store().get("gotcha:tomb-socket").await.unwrap().unwrap();
        assert!(matches!(rec.lifecycle, RecordLifecycle::Tombstoned { .. }));
        let a = g.store().get("file:src/a.rs").await.unwrap().unwrap();
        let b = g.store().get("file:src/b.rs").await.unwrap().unwrap();
        assert!(file_gotcha_keys(&a).is_empty());
        assert!(file_gotcha_keys(&b).is_empty());
    }

    /// With `is_new: true`, `gotcha_write` on an existing key fails with an
    /// error containing "already exists" and leaves the stored record (its
    /// `affected_files`) unchanged.
    #[tokio::test]
    async fn socket_gotcha_write_rejects_duplicate_key() {
        let dir = tempfile::TempDir::new().unwrap();
        let store = Store::open(dir.path()).await.unwrap();
        let record1 = make_gotcha_record("gotcha:dup-socket", &["src/a.rs"]);
        store.put("gotcha:dup-socket", &record1).await.unwrap();
        let graph = make_test_graph(store).await;

        let record2 = make_gotcha_record("gotcha:dup-socket", &["src/b.rs"]);
        let resp = dispatch_with_graph(
            &graph,
            "gotcha_write",
            serde_json::json!({
                "record": record2, "new_files": ["src/b.rs"], "old_files": [], "is_new": true,
            }),
        )
        .await;
        assert!(!resp.ok, "duplicate key should be rejected");
        assert!(resp
            .error
            .as_deref()
            .unwrap_or("")
            .contains("already exists"));

        let g = graph.read().await;
        let original = g.store().get("gotcha:dup-socket").await.unwrap().unwrap();
        let payload = original.payload_as::<GotchaRecord>().unwrap();
        assert_eq!(payload.affected_files, vec!["src/a.rs"]);
    }

    /// A single line longer than `MAX_FRAME_SIZE` gets a structured error
    /// response: status "err", code "frame_too_large", and a message that
    /// mentions the limit. The connection handler itself still returns `Ok`.
    #[tokio::test]
    async fn oversized_request_returns_frame_too_large_with_response() {
        use super::super::protocol::MAX_FRAME_SIZE;
        use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

        let dir = tempfile::TempDir::new().unwrap();
        let store = Store::open(dir.path()).await.unwrap();
        let graph = make_test_graph(store).await;

        let (client, server) = UnixStream::pair().unwrap();
        let peer = super::super::metadata::PeerContext {
            uid: 501,
            pid: None,
        };

        // 100 bytes over the limit, terminated by a newline as one frame.
        let oversized = "x".repeat(MAX_FRAME_SIZE + 100);
        let payload = format!("{oversized}\n");

        let (client_read, client_write) = client.into_split();

        // Write from a separate task. Presumably this is because the payload
        // can exceed the socket buffer, so the write can only finish while
        // the handler below is reading. TODO confirm.
        let write_handle = tokio::spawn(async move {
            let mut w = client_write;
            w.write_all(payload.as_bytes()).await.unwrap();
            w.shutdown().await.unwrap();
        });

        let handle_result =
            socket_handle_connection(graph, dir.path(), server, peer, uuid::Uuid::nil()).await;
        assert!(handle_result.is_ok());

        write_handle.await.unwrap();

        // The handler has returned; read the single response line it wrote.
        let mut reader = tokio::io::BufReader::new(client_read);
        let mut line = String::new();
        reader.read_line(&mut line).await.unwrap();
        let resp: serde_json::Value = serde_json::from_str(line.trim()).unwrap();

        assert_eq!(resp["status"], "err");
        assert_eq!(resp["code"], "frame_too_large");
        assert!(
            resp["message"]
                .as_str()
                .unwrap()
                .contains(&MAX_FRAME_SIZE.to_string()),
            "error message should mention the size limit"
        );
    }

    /// Control case for the size check: a small v2 `ping` envelope gets
    /// status "ok".
    #[tokio::test]
    async fn normal_sized_request_is_not_rejected_by_size_check() {
        use super::super::protocol::MAX_FRAME_SIZE;
        use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

        let dir = tempfile::TempDir::new().unwrap();
        let store = Store::open(dir.path()).await.unwrap();
        let graph = make_test_graph(store).await;

        let (client, server) = UnixStream::pair().unwrap();
        let peer = super::super::metadata::PeerContext {
            uid: 501,
            pid: None,
        };

        // v2 envelope; "session" is the nil UUID, matching the nil session
        // passed to socket_handle_connection below.
        let request = serde_json::json!({
            "v": 2,
            "id": uuid::Uuid::new_v4(),
            "session": uuid::Uuid::nil(),
            "cmd": { "type": "ping" }
        });
        let payload = format!("{}\n", serde_json::to_string(&request).unwrap());
        assert!(
            payload.len() < MAX_FRAME_SIZE,
            "test payload should be small"
        );

        let (client_read, client_write) = client.into_split();

        let write_handle = tokio::spawn(async move {
            let mut w = client_write;
            w.write_all(payload.as_bytes()).await.unwrap();
            w.shutdown().await.unwrap();
        });

        let handle_result =
            socket_handle_connection(graph, dir.path(), server, peer, uuid::Uuid::nil()).await;
        assert!(handle_result.is_ok());

        write_handle.await.unwrap();

        let mut reader = tokio::io::BufReader::new(client_read);
        let mut line = String::new();
        reader.read_line(&mut line).await.unwrap();
        let resp: serde_json::Value = serde_json::from_str(line.trim()).unwrap();

        assert_eq!(resp["status"], "ok", "ping should succeed, got: {resp}");
    }

    /// Binds a `UnixListener` at `sock_path` and spawns a fake daemon. The
    /// fake accepts one connection per entry in `responses`. For each
    /// connection it reads (and ignores) one request line, writes the canned
    /// JSON response plus `\n`, then shuts down its write half. The task
    /// exits early if `accept` fails; write errors are ignored.
    async fn spawn_canned_responder(
        sock_path: std::path::PathBuf,
        responses: Vec<serde_json::Value>,
    ) -> tokio::task::JoinHandle<()> {
        let listener = tokio::net::UnixListener::bind(&sock_path).expect("bind responder socket");
        tokio::spawn(async move {
            for resp in responses {
                let (stream, _) = match listener.accept().await {
                    Ok(s) => s,
                    Err(_) => return,
                };
                let (reader, mut writer) = stream.into_split();
                let mut buf_reader = tokio::io::BufReader::new(reader);
                let mut line = String::new();
                let _ = tokio::io::AsyncBufReadExt::read_line(&mut buf_reader, &mut line).await;
                let mut bytes = serde_json::to_vec(&resp).unwrap();
                bytes.push(b'\n');
                let _ = tokio::io::AsyncWriteExt::write_all(&mut writer, &bytes).await;
                let _ = tokio::io::AsyncWriteExt::shutdown(&mut writer).await;
            }
        })
    }

    /// Retry path after a daemon restart.
    ///
    /// The fake daemon first answers `session_mismatch`, then `pong`. About
    /// 20 ms in, the published metadata is rewritten with a new session.
    /// `proxy_daemon_result` must resolve within 5 s. It must return
    /// `Ok(v)` with `v["ok"] == true`, i.e. the second attempt's result
    /// rather than the first attempt's mismatch envelope.
    #[tokio::test]
    async fn mcp_call_after_daemon_restart_does_not_kill_transport() {
        let dir = tempfile::TempDir::new().unwrap();
        let root = dir.path().to_path_buf();
        let sock_path = root.join("mati.sock");

        let session_before = uuid::Uuid::new_v4();
        let session_after = uuid::Uuid::new_v4();

        let meta_before = super::super::metadata::DaemonMetadata {
            pid: std::process::id(),
            session: session_before,
            owner: super::super::metadata::DaemonOwner::Daemon,
        };
        super::super::metadata::publish_metadata(&root, &meta_before).unwrap();

        let responder_handle = spawn_canned_responder(
            sock_path.clone(),
            vec![
                serde_json::json!({
                    "v": 2,
                    "id": uuid::Uuid::new_v4(),
                    "status": "err",
                    "code": "session_mismatch",
                    "message": "session mismatch: re-read daemon metadata and retry",
                }),
                serde_json::json!({
                    "v": 2,
                    "id": uuid::Uuid::new_v4(),
                    "status": "ok",
                    "data": "pong",
                }),
            ],
        )
        .await;

        // Simulate the restart: publish metadata carrying a new session shortly
        // after the call starts.
        let root_for_rotate = root.clone();
        let rotate_handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            let meta_after = super::super::metadata::DaemonMetadata {
                pid: std::process::id(),
                session: session_after,
                owner: super::super::metadata::DaemonOwner::Daemon,
            };
            super::super::metadata::publish_metadata(&root_for_rotate, &meta_after).unwrap();
        });

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            super::proxy_daemon_result(&root, "ping", serde_json::json!({})),
        )
        .await
        .expect("proxy_daemon_result should resolve within 5s — retry path appears wedged");

        rotate_handle.await.unwrap();
        responder_handle.abort();

        match result {
            super::ProxyDaemonResult::Ok(v) => {
                let ok = v.get("ok") == Some(&serde_json::Value::Bool(true));
                let code = v.get("code").and_then(|c| c.as_str()).unwrap_or("");
                assert!(
                    ok,
                    "second attempt should succeed after metadata rotation, \
                     but caller saw the first attempt's session_mismatch envelope: \
                     ok={ok} code={code:?} v={v}"
                );
            }
            other => panic!(
                "expected Ok(true) after auto-reconnect, got {other:?}; \
                 the daemon-restart retry path is not engaging"
            ),
        }
    }

    /// The metadata is never rotated and the fake daemon answers
    /// `session_mismatch` both times. `proxy_daemon_result` must still
    /// resolve within 5 s, returning `Ok(v)` with `v["ok"] == false` and
    /// `v["code"] == "session_mismatch"` rather than a transport-level
    /// failure variant.
    #[tokio::test]
    async fn mcp_call_session_mismatch_no_retry_target_returns_envelope() {
        let dir = tempfile::TempDir::new().unwrap();
        let root = dir.path().to_path_buf();
        let sock_path = root.join("mati.sock");

        let session = uuid::Uuid::new_v4();
        let meta = super::super::metadata::DaemonMetadata {
            pid: std::process::id(),
            session,
            owner: super::super::metadata::DaemonOwner::Daemon,
        };
        super::super::metadata::publish_metadata(&root, &meta).unwrap();

        let responder_handle = spawn_canned_responder(
            sock_path.clone(),
            vec![
                serde_json::json!({
                    "v": 2,
                    "id": uuid::Uuid::new_v4(),
                    "status": "err",
                    "code": "session_mismatch",
                    "message": "session mismatch (1)",
                }),
                serde_json::json!({
                    "v": 2,
                    "id": uuid::Uuid::new_v4(),
                    "status": "err",
                    "code": "session_mismatch",
                    "message": "session mismatch (2)",
                }),
            ],
        )
        .await;

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            super::proxy_daemon_result(&root, "ping", serde_json::json!({})),
        )
        .await
        .expect("proxy_daemon_result must resolve within 5s");
        responder_handle.abort();

        match result {
            super::ProxyDaemonResult::Ok(v) => {
                assert_eq!(v.get("ok"), Some(&serde_json::Value::Bool(false)));
                assert_eq!(
                    v.get("code").and_then(|c| c.as_str()),
                    Some("session_mismatch")
                );
            }
            other => panic!("expected structured Ok envelope, got {other:?}"),
        }
    }

    /// With an empty root directory (no metadata, no socket), a `mem_get`
    /// proxy call returns `NotRunning` instead of panicking.
    #[tokio::test]
    async fn proxy_daemon_result_handles_mem_get_translation_no_panic() {
        let dir = tempfile::TempDir::new().unwrap();
        let result = super::proxy_daemon_result(
            dir.path(),
            "mem_get",
            serde_json::json!({ "key": "file:src/main.rs" }),
        )
        .await;
        assert!(
            matches!(result, super::ProxyDaemonResult::NotRunning),
            "mem_get without daemon must return NotRunning, got {result:?}"
        );
    }

    /// With an empty root directory, a `mem_bootstrap` proxy call returns
    /// `NotRunning` instead of panicking.
    #[tokio::test]
    async fn proxy_daemon_result_handles_mem_bootstrap_translation_no_panic() {
        let dir = tempfile::TempDir::new().unwrap();
        let result = super::proxy_daemon_result(
            dir.path(),
            "mem_bootstrap",
            serde_json::json!({ "context_files": ["src/lib.rs"] }),
        )
        .await;
        assert!(
            matches!(result, super::ProxyDaemonResult::NotRunning),
            "mem_bootstrap without daemon must return NotRunning, got {result:?}"
        );
    }

    /// With an empty root directory, the typed `proxy_daemon_v2` path
    /// returns `NotRunning` for a `GotchaConfirm` command.
    #[tokio::test]
    async fn proxy_daemon_v2_typed_path_handles_mem_set_mutations_no_panic() {
        let dir = tempfile::TempDir::new().unwrap();
        let cmd = super::super::protocol::Command::GotchaConfirm(
            super::super::protocol::GotchaConfirmInput {
                key: "gotcha:test".into(),
            },
        );
        let result = super::proxy_daemon_v2(dir.path(), cmd).await;
        assert!(
            matches!(result, super::ProxyDaemonResult::NotRunning),
            "typed proxy_daemon_v2 must return NotRunning when daemon is absent, got {result:?}"
        );
    }
}
2391}