use std::sync::{Arc, Weak};
use solo_core::InvalidateEvent;
use solo_storage::TenantHandle;
use tokio::sync::broadcast;
use crate::mcp_session::{McpEventKind, SessionState};
/// JSON-RPC `method` of every invalidate notification built by
/// [`map_invalidate_to_message`].
pub const MCP_NOTIFICATION_MESSAGE_METHOD: &str = "notifications/message";
/// Value written to `params.level` of bridged invalidate notifications.
pub const MCP_NOTIFICATION_MESSAGE_LEVEL: &str = "info";
/// Value written to `params.logger` of bridged invalidate notifications.
pub const MCP_NOTIFICATION_MESSAGE_LOGGER: &str = "solo";
/// `params.data` discriminator for invalidate events of kind `"episode"`.
pub const MCP_NOTIFICATION_DATA_MEMORIES_UPDATED: &str = "memories_updated";
/// `params.data` discriminator for invalidate events of kind `"document"` or `"chunk"`.
pub const MCP_NOTIFICATION_DATA_DOCUMENTS_UPDATED: &str = "documents_updated";
/// `params.data` discriminator for invalidate events of kind `"cluster"`.
pub const MCP_NOTIFICATION_DATA_CONSOLIDATION_UPDATED: &str = "consolidation_updated";
/// `params.data` discriminator for invalidate events of kind `"triple"`.
pub const MCP_NOTIFICATION_DATA_GRAPH_UPDATED: &str = "graph_updated";
/// `params.data` discriminator for invalidate events of kind `"tenant"`.
pub const MCP_NOTIFICATION_DATA_TENANT_UPDATED: &str = "tenant_updated";
/// Fallback `params.data` discriminator for any invalidate kind that has no
/// dedicated discriminator. Note the singular "memory": this is distinct from
/// `MCP_NOTIFICATION_DATA_MEMORIES_UPDATED`, which is used for `"episode"`.
pub const MCP_NOTIFICATION_DATA_MEMORY_UPDATED: &str = "memory_updated";
/// Builds the JSON-RPC `notifications/message` payload that announces one
/// invalidate event to an MCP client.
///
/// `params.data` carries a short discriminator derived from `event.kind`:
/// `"episode"` → memories, `"document"`/`"chunk"` → documents,
/// `"cluster"` → consolidation, `"triple"` → graph, `"tenant"` → tenant.
/// Any other kind falls back to [`MCP_NOTIFICATION_DATA_MEMORY_UPDATED`], so
/// new storage kinds still produce a notification. The raw event fields
/// (`reason`, `tenant_id`, `ts_ms`, `kind`) are copied under `params.details`.
pub fn map_invalidate_to_message(event: &InvalidateEvent) -> serde_json::Value {
    /// Maps a storage-level invalidate kind onto the client-facing discriminator.
    fn discriminator_for(kind: &str) -> &'static str {
        match kind {
            "episode" => MCP_NOTIFICATION_DATA_MEMORIES_UPDATED,
            "document" | "chunk" => MCP_NOTIFICATION_DATA_DOCUMENTS_UPDATED,
            "cluster" => MCP_NOTIFICATION_DATA_CONSOLIDATION_UPDATED,
            "triple" => MCP_NOTIFICATION_DATA_GRAPH_UPDATED,
            "tenant" => MCP_NOTIFICATION_DATA_TENANT_UPDATED,
            _ => MCP_NOTIFICATION_DATA_MEMORY_UPDATED,
        }
    }

    let discriminator = discriminator_for(&event.kind);
    let details = serde_json::json!({
        "reason": event.reason,
        "tenant_id": event.tenant_id,
        "ts_ms": event.ts_ms,
        "kind": event.kind,
    });

    serde_json::json!({
        "jsonrpc": "2.0",
        "method": MCP_NOTIFICATION_MESSAGE_METHOD,
        "params": {
            "level": MCP_NOTIFICATION_MESSAGE_LEVEL,
            "logger": MCP_NOTIFICATION_MESSAGE_LOGGER,
            "data": discriminator,
            "details": details,
        }
    })
}
/// Spawns a task that forwards the tenant's invalidate events to one MCP
/// session as `notifications/message` events (see [`run_invalidate_bridge`]).
///
/// The subscription to the tenant's invalidate channel is taken *before* the
/// task is spawned. Events sent after this function returns are therefore
/// buffered for the bridge even if the task has not been polled yet.
///
/// Only a [`Weak`] handle to `session` moves into the task, and this function
/// drops its own strong `Arc` before returning. Between events the bridge holds
/// no strong reference, so it does not keep the session alive. The task ends
/// when an event arrives after the session has been dropped, or when the
/// tenant's invalidate channel closes.
pub fn spawn_invalidate_bridge(
    tenant: Arc<TenantHandle>,
    session: Arc<SessionState>,
) -> tokio::task::JoinHandle<()> {
    // Subscribe eagerly so nothing published before the task's first poll is lost.
    let rx = tenant.invalidate_sender().subscribe();
    // Take the log label from the strong reference we already hold. Re-upgrading
    // the weak handle after dropping `session` was a needless round-trip that
    // could label a session that was alive on entry as "<dropped>".
    let session_id_for_log = format!("{:p}", Arc::as_ptr(&session));
    let weak = Arc::downgrade(&session);
    // Release our strong reference: the task must observe the session only
    // through `weak` so the bridge never extends the session's lifetime.
    drop(session);
    tokio::spawn(run_invalidate_bridge(rx, weak, session_id_for_log))
}
async fn run_invalidate_bridge(
mut rx: broadcast::Receiver<InvalidateEvent>,
weak: Weak<SessionState>,
session_id_for_log: String,
) {
loop {
match rx.recv().await {
Ok(event) => {
let Some(state) = weak.upgrade() else {
tracing::debug!(
session = %session_id_for_log,
"mcp invalidate bridge: session dropped; exiting"
);
return;
};
let payload = map_invalidate_to_message(&event);
state.publish_event(McpEventKind::Message, payload);
}
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!(
lagged = n,
session = %session_id_for_log,
"mcp invalidate bridge subscriber lagged; client will \
resync on the next real invalidate"
);
}
Err(broadcast::error::RecvError::Closed) => {
tracing::debug!(
session = %session_id_for_log,
"mcp invalidate bridge: tenant invalidate channel closed; exiting"
);
return;
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use solo_core::TenantId;

    /// Builds an `InvalidateEvent` with fixed reason/tenant/timestamp and the given kind.
    fn fake_event(kind: &str) -> InvalidateEvent {
        InvalidateEvent {
            reason: "memory.remember".to_string(),
            tenant_id: "default".to_string(),
            ts_ms: 1_715_625_600_000,
            kind: kind.to_string(),
        }
    }

    #[test]
    fn map_invalidate_to_message_routes_each_kind_to_data_discriminator() {
        let cases = [
            ("episode", MCP_NOTIFICATION_DATA_MEMORIES_UPDATED),
            ("document", MCP_NOTIFICATION_DATA_DOCUMENTS_UPDATED),
            ("chunk", MCP_NOTIFICATION_DATA_DOCUMENTS_UPDATED),
            ("cluster", MCP_NOTIFICATION_DATA_CONSOLIDATION_UPDATED),
            ("triple", MCP_NOTIFICATION_DATA_GRAPH_UPDATED),
            ("tenant", MCP_NOTIFICATION_DATA_TENANT_UPDATED),
            // Unknown kinds must still produce a notification (fallback discriminator).
            (
                "hypothetical_new_kind",
                MCP_NOTIFICATION_DATA_MEMORY_UPDATED,
            ),
        ];
        for (kind, expected_data) in cases {
            let msg = map_invalidate_to_message(&fake_event(kind));
            assert_eq!(
                msg["params"]["data"].as_str(),
                Some(expected_data),
                "kind={kind} must map to data={expected_data}",
            );
        }
    }

    #[test]
    fn map_invalidate_to_message_uses_jsonrpc_notifications_message_method() {
        let msg = map_invalidate_to_message(&fake_event("episode"));
        assert_eq!(msg["jsonrpc"].as_str(), Some("2.0"));
        assert_eq!(
            msg["method"].as_str(),
            Some(MCP_NOTIFICATION_MESSAGE_METHOD),
        );
        let params = &msg["params"];
        assert_eq!(
            params["level"].as_str(),
            Some(MCP_NOTIFICATION_MESSAGE_LEVEL)
        );
        assert_eq!(
            params["logger"].as_str(),
            Some(MCP_NOTIFICATION_MESSAGE_LOGGER),
        );
        assert!(params["data"].is_string(), "data must be present + string");
    }

    #[test]
    fn mcp_message_payload_includes_level_and_data() {
        let msg = map_invalidate_to_message(&fake_event("episode"));
        let params = &msg["params"];
        assert_eq!(
            params["level"].as_str(),
            Some("info"),
            "level must be 'info' for invalidate-bridged messages",
        );
        assert_eq!(
            params["data"].as_str(),
            Some(MCP_NOTIFICATION_DATA_MEMORIES_UPDATED),
        );
        assert_eq!(
            params["details"]["reason"].as_str(),
            Some("memory.remember"),
        );
        assert_eq!(params["details"]["kind"].as_str(), Some("episode"));
        assert!(params["details"]["ts_ms"].is_number());
    }

    /// End-to-end: an invalidate event reaches the session, and the bridge
    /// exits once the session is dropped and another event arrives.
    #[tokio::test]
    async fn run_invalidate_bridge_publishes_mcp_message_to_session() {
        let (tx, _) = broadcast::channel::<InvalidateEvent>(16);
        let state = Arc::new(SessionState::new(TenantId::default_tenant(), None));
        let rx = tx.subscribe();
        let weak = Arc::downgrade(&state);
        let mut session_rx = state.subscribe_events();
        let bridge_handle = tokio::spawn(async move {
            run_invalidate_bridge(rx, weak, "test-session".to_string()).await;
        });
        tx.send(fake_event("document"))
            .expect("at least one subscriber");
        let received = tokio::time::timeout(std::time::Duration::from_secs(2), session_rx.recv())
            .await
            .expect("event must arrive within 2s")
            .expect("session receiver must receive published event");
        assert_eq!(received.event, McpEventKind::Message);
        assert_eq!(
            received.data["method"].as_str(),
            Some(MCP_NOTIFICATION_MESSAGE_METHOD),
        );
        assert_eq!(
            received.data["params"]["data"].as_str(),
            Some(MCP_NOTIFICATION_DATA_DOCUMENTS_UPDATED),
        );
        drop(state);
        // The bridge only notices the dead session when the next event arrives.
        tx.send(fake_event("episode")).ok();
        tokio::time::timeout(std::time::Duration::from_secs(2), bridge_handle)
            .await
            .expect("bridge must exit when session drops")
            .expect("bridge task panic");
    }

    #[tokio::test]
    async fn session_task_exits_when_session_dropped() {
        let (tx, _) = broadcast::channel::<InvalidateEvent>(16);
        let state = Arc::new(SessionState::new(TenantId::default_tenant(), None));
        let rx = tx.subscribe();
        let weak = Arc::downgrade(&state);
        let bridge = tokio::spawn(async move {
            run_invalidate_bridge(rx, weak, "test-session-drop".to_string()).await;
        });
        drop(state);
        tx.send(fake_event("episode")).ok();
        tokio::time::timeout(std::time::Duration::from_secs(2), bridge)
            .await
            .expect("bridge must exit within 2s of session drop")
            .expect("bridge task panic");
    }

    /// Closing the tenant channel ends the bridge without affecting the session.
    #[tokio::test]
    async fn bridge_exits_when_tenant_invalidate_channel_closes() {
        let (tx, _) = broadcast::channel::<InvalidateEvent>(16);
        let state = Arc::new(SessionState::new(TenantId::default_tenant(), None));
        let rx = tx.subscribe();
        let weak = Arc::downgrade(&state);
        let bridge = tokio::spawn(async move {
            run_invalidate_bridge(rx, weak, "test-tenant-drop".to_string()).await;
        });
        drop(tx);
        tokio::time::timeout(std::time::Duration::from_secs(2), bridge)
            .await
            .expect("bridge must exit within 2s of channel close")
            .expect("bridge task panic");
        // The session itself must still be usable after the bridge exits.
        let _id = state.publish_event(McpEventKind::Message, serde_json::json!({"alive": true}));
    }

    /// Overrunning the channel capacity must be logged, not fatal. The bridge
    /// keeps delivering afterwards.
    #[tokio::test]
    async fn bridge_logs_and_continues_on_lagged_subscriber() {
        let (tx, _) = broadcast::channel::<InvalidateEvent>(2);
        let state = Arc::new(SessionState::new(TenantId::default_tenant(), None));
        let rx = tx.subscribe();
        let weak = Arc::downgrade(&state);
        let mut session_rx = state.subscribe_events();
        let bridge = tokio::spawn(async move {
            run_invalidate_bridge(rx, weak, "test-lag".to_string()).await;
        });
        // Five sends into a capacity-2 channel before the test task yields. This
        // presumably forces `Lagged` on the bridge's first recv under the default
        // current-thread test runtime.
        for _ in 0..5 {
            tx.send(fake_event("episode")).ok();
        }
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        tx.send(fake_event("triple")).expect("send must succeed");
        let _ = tokio::time::timeout(std::time::Duration::from_secs(2), session_rx.recv())
            .await
            .expect("bridge must still be alive after lag")
            .expect("session must receive at least one event");
        drop(state);
        drop(tx);
        // Once `tx` is dropped the bridge drains any buffered events and then sees
        // `Closed` (or a dead session), so exit is deterministic. Assert on it
        // rather than silently discarding a hang or a panic in the bridge task.
        tokio::time::timeout(std::time::Duration::from_secs(2), bridge)
            .await
            .expect("bridge must exit within 2s once session and channel are gone")
            .expect("bridge task panic");
    }
}