use anyhow::Result;
use serde_json::{json, Value};
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
use tokio::sync::{broadcast, OnceCell, RwLock};
use trusty_common::bm25_client::Bm25Client;
use trusty_common::mcp::initialize_response;
use trusty_common::memory_core::embed::FastEmbedder;
use trusty_common::memory_core::{store::ChatSessionStore, PalaceRegistry};
use trusty_common::ChatProvider;
#[cfg(feature = "axum-server")]
use tracing::info;
/// Coarse startup state of the daemon.
///
/// Stored in `AppState::daemon_readiness` as a raw `u8` (via `as u8`) and
/// decoded again with [`DaemonReadiness::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DaemonReadiness {
    /// Initial state written by `AppState::new`.
    Warming = 0,
    /// State written by `AppState::set_ready`.
    Ready = 1,
}
impl DaemonReadiness {
    /// Decodes a raw atomic value.
    ///
    /// Only `0` means [`DaemonReadiness::Warming`]; every other value (not just
    /// `1`) is treated as [`DaemonReadiness::Ready`].
    pub fn from_u8(v: u8) -> Self {
        match v {
            0 => Self::Warming,
            _ => Self::Ready,
        }
    }
}
pub mod activity;
pub mod attribution;
pub mod bm25_supervisor;
pub mod bootstrap;
pub mod dream_scheduler;
pub mod fd_metrics;
#[cfg(feature = "axum-server")]
pub mod chat;
pub mod commands;
pub mod console_metrics;
pub mod discovery;
pub mod foreground;
pub mod hook_emit;
pub mod kg_extract;
pub mod mcp_service;
pub mod messaging;
pub mod openrpc;
pub mod palace_id_derive;
pub mod project_root;
pub mod prompt_facts;
pub mod prompt_log;
pub mod service;
pub mod startup_scan;
pub mod tools;
pub mod transport;
#[cfg(feature = "axum-server")]
pub mod web;
pub use activity::{ActivityEntry, ActivityFilter, ActivityLog, ActivitySource};
pub use attribution::{CreatorInfo, CreatorSource};
/// Maximum length, in `char`s, of the excerpt produced by [`hook_prompt_excerpt`]
/// (including the trailing ellipsis when truncation happens).
pub const HOOK_PROMPT_EXCERPT_CHARS: usize = 80;
/// Builds a short single-line excerpt of a prompt for hook events.
///
/// Every run of whitespace (including newlines) collapses to one space and
/// leading/trailing whitespace is dropped. If the result is longer than
/// [`HOOK_PROMPT_EXCERPT_CHARS`] characters, it is cut to one character less
/// than the limit and `…` is appended, so the excerpt never exceeds the limit.
/// Lengths are counted in `char`s, so multi-byte text is never split mid-char.
pub fn hook_prompt_excerpt(prompt: &str) -> String {
    let mut collapsed = String::with_capacity(prompt.len());
    for word in prompt.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    // No char at index LIMIT means the text already fits.
    if collapsed.char_indices().nth(HOOK_PROMPT_EXCERPT_CHARS).is_none() {
        return collapsed;
    }
    let keep_chars = HOOK_PROMPT_EXCERPT_CHARS.saturating_sub(1);
    let cut_at = collapsed
        .char_indices()
        .nth(keep_chars)
        .map_or(collapsed.len(), |(byte_idx, _)| byte_idx);
    collapsed.truncate(cut_at);
    collapsed.push('…');
    collapsed
}
pub use mcp_service::MemoryMcpService;
pub use tools::MemoryMcpServer;
/// Returns the palace registry directory derived from `data_dir`.
///
/// Thin wrapper over `trusty_common::palace_alias::palace_registry_dir_from`;
/// all of the resolution logic lives there.
pub fn resolve_palace_registry_dir(data_dir: PathBuf) -> PathBuf {
    trusty_common::palace_alias::palace_registry_dir_from(data_dir)
}
/// Hook event that triggered a [`DaemonEvent::HookFired`].
///
/// No `rename_all` is applied, so serde uses the variant names verbatim
/// (`"UserPromptSubmit"`, `"SessionStart"`), matching [`HookType::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum HookType {
    UserPromptSubmit,
    SessionStart,
}
impl HookType {
pub fn as_str(&self) -> &'static str {
match self {
Self::UserPromptSubmit => "UserPromptSubmit",
Self::SessionStart => "SessionStart",
}
}
}
/// Kind of content a hook injected, reported in [`DaemonEvent::HookFired`].
///
/// Serialized in kebab-case (`"prompt-context"`, `"inbox-check"`), matching
/// [`InjectionKind::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum InjectionKind {
    PromptContext,
    InboxCheck,
}
impl InjectionKind {
pub fn as_str(&self) -> &'static str {
match self {
Self::PromptContext => "prompt-context",
Self::InboxCheck => "inbox-check",
}
}
}
/// Event broadcast to daemon subscribers (e.g. the `/sse` stream) via
/// `AppState::emit`.
///
/// Serialized as an internally tagged object: the variant name in snake_case
/// goes in a `"type"` field (see [`DaemonEvent::type_str`]), alongside the
/// variant's fields. Field declaration order is the JSON field order.
///
/// NOTE(review): the `#[serde(default)]` attributes below only affect
/// deserialization, and only `Serialize` is derived here, so they are
/// currently inert.
#[derive(Clone, Debug, serde::Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DaemonEvent {
    /// A palace was created.
    PalaceCreated {
        id: String,
        name: String,
        source: ActivitySource,
    },
    /// A drawer was added to a palace; `drawer_count` is reported by the emitter.
    DrawerAdded {
        palace_id: String,
        #[serde(default)]
        palace_name: String,
        drawer_count: usize,
        timestamp: chrono::DateTime<chrono::Utc>,
        #[serde(default)]
        content_preview: String,
        source: ActivitySource,
    },
    /// A drawer was deleted from a palace.
    DrawerDeleted {
        palace_id: String,
        drawer_count: usize,
        source: ActivitySource,
    },
    /// A dream pass finished; `palace_id` is `None` when it was not scoped to
    /// a single palace (presumably a global pass — confirm with the emitter).
    DreamCompleted {
        palace_id: Option<String>,
        merged: usize,
        pruned: usize,
        compacted: usize,
        closets_updated: usize,
        duration_ms: u64,
        source: ActivitySource,
    },
    /// Periodic aggregate counters. Carries no `source`, so `AppState::emit`
    /// only broadcasts it and never writes it to the activity log.
    StatusChanged {
        total_drawers: usize,
        total_vectors: usize,
        total_kg_triples: usize,
    },
    /// A hook injected content; `trigger_prompt_excerpt` is expected to come
    /// from `hook_prompt_excerpt` (assumption — confirm at the emit site).
    HookFired {
        #[serde(default)]
        palace_id: Option<String>,
        #[serde(default)]
        palace_name: Option<String>,
        hook_type: HookType,
        injection_kind: InjectionKind,
        injection_length: u64,
        #[serde(default)]
        trigger_prompt_excerpt: String,
        timestamp: chrono::DateTime<chrono::Utc>,
        duration_ms: u64,
        source: ActivitySource,
    },
}
/// Opens the activity log, degrading instead of failing.
///
/// Attempt order:
/// 1. `data_root` itself;
/// 2. a per-process directory under the system temp dir
///    (`trusty-memory-activity-<pid>`);
/// 3. [`ActivityLog::discard`], a no-op log, so the daemon still starts.
///
/// Each fallback step is logged at `warn` level with the error that caused it.
fn open_activity_log_with_fallback(data_root: &Path) -> Arc<ActivityLog> {
    let primary_err = match ActivityLog::open(data_root) {
        Ok(log) => return Arc::new(log),
        Err(err) => err,
    };
    tracing::warn!(
        "could not open activity log at {}: {primary_err:#}; falling back to per-process tempdir",
        data_root.display()
    );

    let fallback =
        std::env::temp_dir().join(format!("trusty-memory-activity-{}", std::process::id()));
    let log = ActivityLog::open(&fallback).unwrap_or_else(|fallback_err| {
        tracing::warn!(
            "activity log tempdir fallback at {} also failed: {fallback_err:#}; \
             activity feed disabled for this process (no-op log)",
            fallback.display()
        );
        ActivityLog::discard()
    });
    Arc::new(log)
}
impl DaemonEvent {
pub fn type_str(&self) -> &'static str {
match self {
Self::PalaceCreated { .. } => "palace_created",
Self::DrawerAdded { .. } => "drawer_added",
Self::DrawerDeleted { .. } => "drawer_deleted",
Self::DreamCompleted { .. } => "dream_completed",
Self::StatusChanged { .. } => "status_changed",
Self::HookFired { .. } => "hook_fired",
}
}
pub fn palace_id(&self) -> Option<&str> {
match self {
Self::PalaceCreated { id, .. } => Some(id),
Self::DrawerAdded { palace_id, .. } | Self::DrawerDeleted { palace_id, .. } => {
Some(palace_id)
}
Self::DreamCompleted { palace_id, .. } => palace_id.as_deref(),
Self::HookFired { palace_id, .. } => palace_id.as_deref(),
Self::StatusChanged { .. } => None,
}
}
pub fn source(&self) -> Option<ActivitySource> {
match self {
Self::PalaceCreated { source, .. }
| Self::DrawerAdded { source, .. }
| Self::DrawerDeleted { source, .. }
| Self::DreamCompleted { source, .. }
| Self::HookFired { source, .. } => Some(*source),
Self::StatusChanged { .. } => None,
}
}
}
/// Shared state of the trusty-memory daemon.
///
/// Cloned freely (e.g. into spawned tasks and the axum router). Most fields
/// are wrapped in `Arc`, so clones share the same registry, caches, channels
/// and counters; plain fields (`version`, `data_root`, `default_palace`,
/// `started_at`) are copied per clone.
#[derive(Clone)]
pub struct AppState {
    /// Crate version (`CARGO_PKG_VERSION`), reported by `initialize`.
    pub version: String,
    /// Registry of open palaces, populated by `load_palaces_from_disk`.
    pub registry: Arc<PalaceRegistry>,
    /// Root data directory: palaces, per-palace chat DBs and the activity log live here.
    pub data_root: PathBuf,
    /// Lazily initialised embedder; access through `AppState::embedder`.
    pub embedder: Arc<OnceCell<Arc<FastEmbedder>>>,
    /// Default palace name, if configured via `with_default_palace`.
    pub default_palace: Option<String>,
    /// Chat provider resolved once on first `chat_provider()` call; `None` is cached too.
    pub chat_provider: Arc<OnceCell<Option<Arc<dyn ChatProvider>>>>,
    /// Per-palace chat session stores cached by `session_store`.
    pub session_stores: Arc<dashmap::DashMap<String, Arc<ChatSessionStore>>>,
    /// Broadcast channel for [`DaemonEvent`]s; the `/sse` handler subscribes here.
    pub events: Arc<broadcast::Sender<DaemonEvent>>,
    /// Instant at which this state was constructed.
    pub started_at: std::time::Instant,
    /// In-memory log buffer (default capacity unless replaced via `with_log_buffer`).
    pub log_buffer: trusty_common::log_buffer::LogBuffer,
    /// Optional error store, set via `with_error_store`.
    pub error_store: Option<trusty_common::error_capture::ErrorStore>,
    /// Size of `data_root` in bytes, refreshed periodically by the disk-size ticker.
    pub disk_bytes: Arc<std::sync::atomic::AtomicU64>,
    /// System metrics state, behind an async mutex.
    pub sys_metrics: Arc<tokio::sync::Mutex<trusty_common::sys_metrics::SysMetrics>>,
    /// Address the HTTP server bound to; set once by `run_http_on`.
    pub bound_addr: Arc<OnceLock<SocketAddr>>,
    /// Cache of prompt facts (`prompt_facts::PromptFactsCache`), behind an async RwLock.
    pub prompt_context_cache: Arc<RwLock<prompt_facts::PromptFactsCache>>,
    /// Persistent activity log that `emit` appends sourced events to.
    pub activity_log: Arc<ActivityLog>,
    /// BM25 daemon client; `Some` only when `TRUSTY_BM25_DAEMON=1` (see `with_bm25_client_from_env`).
    pub bm25_client: Option<Arc<Bm25Client>>,
    /// BM25 spawn supervisor; enabled together with `bm25_client`, shut down by `run_http_on`.
    pub bm25_supervisor: Option<Arc<bm25_supervisor::Bm25Supervisor>>,
    /// Per-palace async write mutexes handed out by `palace_write_lock`.
    pub palace_write_locks: Arc<dashmap::DashMap<String, Arc<tokio::sync::Mutex<()>>>>,
    /// Number of activity-log appends spawned by `emit` that have not finished yet.
    pub pending_activity_writes: Arc<AtomicUsize>,
    /// Palace id -> display name, filled while loading palaces from disk.
    pub palace_names: Arc<dashmap::DashMap<String, String>>,
    /// Palace id -> pinned project path, read via `pinned_project_path`.
    pub pin_project_map: Arc<dashmap::DashMap<String, PathBuf>>,
    /// Sender feeding the background BM25 index worker.
    pub bm25_index_tx: tokio::sync::mpsc::Sender<tools::Bm25IndexRequest>,
    /// NOTE(review): presumably holds a newer version string when an update is detected — confirm.
    pub update_available: Arc<std::sync::Mutex<Option<String>>>,
    /// [`DaemonReadiness`] encoded as `u8`; read via `readiness`, flipped by `set_ready`.
    pub daemon_readiness: Arc<AtomicU8>,
}
impl AppState {
    /// Creates a fresh daemon state rooted at `data_root`.
    ///
    /// Opens the activity log (falling back to a per-process tempdir and then
    /// to a no-op log), creates the daemon event broadcast channel (capacity
    /// 128) and starts a BM25 index worker with no client or supervisor.
    /// Readiness starts at [`DaemonReadiness::Warming`].
    ///
    /// NOTE(review): `tools::spawn_bm25_index_worker` presumably spawns onto the
    /// Tokio runtime, so this likely has to be called from within one — confirm.
    pub fn new(data_root: PathBuf) -> Self {
        let (events_tx, _) = broadcast::channel::<DaemonEvent>(128);
        let activity_log = open_activity_log_with_fallback(&data_root);
        let (bm25_index_tx, bm25_index_rx) =
            tokio::sync::mpsc::channel::<tools::Bm25IndexRequest>(tools::BM25_INDEX_QUEUE_CAPACITY);
        tools::spawn_bm25_index_worker(bm25_index_rx, None, None);
        Self {
            version: env!("CARGO_PKG_VERSION").to_string(),
            registry: Arc::new(PalaceRegistry::new()),
            data_root,
            embedder: Arc::new(OnceCell::new()),
            default_palace: None,
            chat_provider: Arc::new(OnceCell::new()),
            session_stores: Arc::new(dashmap::DashMap::new()),
            events: Arc::new(events_tx),
            started_at: std::time::Instant::now(),
            log_buffer: trusty_common::log_buffer::LogBuffer::new(
                trusty_common::log_buffer::DEFAULT_LOG_CAPACITY,
            ),
            error_store: None,
            disk_bytes: Arc::new(std::sync::atomic::AtomicU64::new(0)),
            sys_metrics: Arc::new(tokio::sync::Mutex::new(
                trusty_common::sys_metrics::SysMetrics::new(),
            )),
            bound_addr: Arc::new(OnceLock::new()),
            prompt_context_cache: Arc::new(RwLock::new(prompt_facts::PromptFactsCache::default())),
            activity_log,
            bm25_client: None,
            bm25_supervisor: None,
            palace_write_locks: Arc::new(dashmap::DashMap::new()),
            pending_activity_writes: Arc::new(AtomicUsize::new(0)),
            palace_names: Arc::new(dashmap::DashMap::new()),
            pin_project_map: Arc::new(dashmap::DashMap::new()),
            bm25_index_tx,
            update_available: Arc::new(std::sync::Mutex::new(None)),
            daemon_readiness: Arc::new(AtomicU8::new(DaemonReadiness::Warming as u8)),
        }
    }

    /// Returns the async write mutex for `palace_id`, creating it on first use.
    ///
    /// The fast path is a plain lookup; the `entry` fallback guarantees that
    /// concurrent first callers all receive the same mutex.
    pub fn palace_write_lock(&self, palace_id: &str) -> Arc<tokio::sync::Mutex<()>> {
        if let Some(existing) = self.palace_write_locks.get(palace_id) {
            return existing.clone();
        }
        self.palace_write_locks
            .entry(palace_id.to_string())
            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
            .clone()
    }

    /// Returns the project path pinned to `palace_id`, if any.
    pub fn pinned_project_path(&self, palace_id: &str) -> Option<PathBuf> {
        self.pin_project_map.get(palace_id).map(|e| e.clone())
    }

    /// Enables the external BM25 daemon when `TRUSTY_BM25_DAEMON=1`.
    ///
    /// Attaches a client for the default palace (or `"default"`), a spawn
    /// supervisor, and a fresh index worker wired to both, replacing
    /// `bm25_index_tx`. The default palace is read at call time, so call this
    /// after [`AppState::with_default_palace`]. Otherwise a no-op.
    #[must_use]
    pub fn with_bm25_client_from_env(mut self) -> Self {
        if std::env::var("TRUSTY_BM25_DAEMON").as_deref() == Ok("1") {
            let default_palace = self.default_palace.as_deref().unwrap_or("default");
            self.bm25_client = Some(Arc::new(Bm25Client::for_palace(default_palace)));
            self.bm25_supervisor = Some(Arc::new(bm25_supervisor::Bm25Supervisor::new()));
            let (tx, rx) = tokio::sync::mpsc::channel::<tools::Bm25IndexRequest>(
                tools::BM25_INDEX_QUEUE_CAPACITY,
            );
            tools::spawn_bm25_index_worker(
                rx,
                self.bm25_client.clone(),
                self.bm25_supervisor.clone(),
            );
            self.bm25_index_tx = tx;
            tracing::info!(
                palace = default_palace,
                "BM25 daemon client + spawn supervisor enabled (TRUSTY_BM25_DAEMON=1)"
            );
        }
        self
    }

    /// Opens every palace listed under `data_root` and registers it.
    ///
    /// Runs on a blocking thread. Palaces that fail to open are logged and
    /// skipped rather than aborting startup. Returns how many were loaded.
    ///
    /// # Errors
    /// Fails if the palace list cannot be read or the blocking task fails to join.
    pub async fn load_palaces_from_disk(&self) -> Result<usize> {
        let registry_dir = self.data_root.clone();
        let registry = self.registry.clone();
        let palace_names = self.palace_names.clone();
        let count = tokio::task::spawn_blocking(move || -> Result<usize> {
            let palaces = PalaceRegistry::list_palaces(&registry_dir)?;
            let total = palaces.len();
            let mut loaded = 0usize;
            let mut skipped = 0usize;
            for palace in palaces {
                match trusty_common::memory_core::PalaceHandle::open(&palace) {
                    Ok(handle) => {
                        tracing::debug!(
                            palace = %palace.id,
                            data_dir = %palace.data_dir.display(),
                            "loaded palace from disk"
                        );
                        palace_names.insert(palace.id.0.clone(), palace.name.clone());
                        registry.register_arc(handle);
                        loaded += 1;
                    }
                    Err(e) => {
                        tracing::warn!(
                            palace = %palace.id,
                            data_dir = %palace.data_dir.display(),
                            "skipping palace during startup hydration: {e:#}; \
                             will retry lazily on first access"
                        );
                        skipped += 1;
                    }
                }
            }
            tracing::info!(
                "palace hydration summary: loaded {loaded}/{total} ({skipped} skipped due to errors)"
            );
            Ok(loaded)
        })
        .await
        .map_err(|e| anyhow::anyhow!("join load_palaces_from_disk: {e}"))??;
        Ok(count)
    }

    /// Replaces the log buffer.
    #[must_use]
    pub fn with_log_buffer(mut self, buffer: trusty_common::log_buffer::LogBuffer) -> Self {
        self.log_buffer = buffer;
        self
    }

    /// Swaps in a registry built with writer intent.
    ///
    /// Must run before any palace is registered or the registry is shared
    /// (checked by a debug assertion).
    #[must_use]
    pub fn with_writer_intent(mut self) -> Self {
        debug_assert!(self.registry.is_empty() && Arc::strong_count(&self.registry) == 1);
        self.registry = Arc::new(PalaceRegistry::new().with_writer_intent());
        self
    }

    /// Attaches an error store.
    #[must_use]
    pub fn with_error_store(mut self, store: trusty_common::error_capture::ErrorStore) -> Self {
        self.error_store = Some(store);
        self
    }

    /// Publishes `event` to subscribers and, when it has a source, persists it.
    ///
    /// The activity-log append runs on a blocking thread. Its id is allocated
    /// up front, so ids follow emit order. `pending_activity_writes` tracks
    /// the in-flight appends for [`AppState::flush_activity_writes`].
    /// Broadcast send errors (no live subscribers) are ignored.
    pub fn emit(&self, event: DaemonEvent) {
        // Decrements the pending-write counter when dropped. This keeps the
        // counter balanced even if the append panics or the blocking task is
        // dropped unrun (e.g. during runtime shutdown). Without it,
        // `flush_activity_writes` would wait forever.
        struct PendingWriteGuard(Arc<AtomicUsize>);
        impl Drop for PendingWriteGuard {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::SeqCst);
            }
        }

        if let Some(source) = event.source() {
            let event_type = event.type_str();
            let palace_id = event.palace_id().map(|s| s.to_string());
            let log = Arc::clone(&self.activity_log);
            let event_for_log = event.clone();
            let id = log.alloc_id();
            self.pending_activity_writes.fetch_add(1, Ordering::SeqCst);
            let guard = PendingWriteGuard(Arc::clone(&self.pending_activity_writes));
            tokio::task::spawn_blocking(move || {
                let _guard = guard;
                let result = log.append_with_id(id, source, palace_id, event_type, &event_for_log);
                if let Err(e) = result {
                    tracing::warn!("activity_log.append failed for {event_type}: {e:#}");
                }
            });
        }
        let _ = self.events.send(event);
    }

    /// Waits until every activity-log append started by [`AppState::emit`]
    /// has finished, polling once per millisecond.
    pub async fn flush_activity_writes(&self) {
        while self.pending_activity_writes.load(Ordering::SeqCst) > 0 {
            tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        }
    }

    /// Returns the chat session store for `palace_id`.
    ///
    /// On first use this creates `<data_root>/<palace_id>/` and opens
    /// `chat_sessions.db` inside it. Concurrent first calls may each open the
    /// database, but only the first store inserted is kept and returned to
    /// every caller, so a palace never ends up with two live stores.
    ///
    /// # Errors
    /// Fails if the palace directory cannot be created or the store cannot be opened.
    pub fn session_store(&self, palace_id: &str) -> Result<Arc<ChatSessionStore>> {
        if let Some(entry) = self.session_stores.get(palace_id) {
            return Ok(entry.clone());
        }
        // NOTE(review): `palace_id` is joined onto `data_root` unchecked. If it
        // can come from untrusted input, `..` segments or an absolute path
        // would escape `data_root` — confirm callers validate it.
        let dir = self.data_root.join(palace_id);
        std::fs::create_dir_all(&dir)
            .map_err(|e| anyhow::anyhow!("create palace dir {}: {e}", dir.display()))?;
        let store = Arc::new(ChatSessionStore::open(&dir.join("chat_sessions.db"))?);
        // `or_insert` keeps a store another caller may have inserted meanwhile
        // (ours is then dropped). A plain `insert` would overwrite it.
        Ok(self
            .session_stores
            .entry(palace_id.to_string())
            .or_insert(store)
            .clone())
    }

    /// Sets (or clears) the default palace name.
    #[must_use]
    pub fn with_default_palace(mut self, name: Option<String>) -> Self {
        self.default_palace = name;
        self
    }

    /// Returns the configured chat provider, resolving it once.
    ///
    /// Preference order: an auto-detected local provider (when
    /// `local_model.enabled`), then OpenRouter (when an API key is set). The
    /// result, including `None`, is cached for the lifetime of the state. A
    /// user config that fails to load is treated as the default config.
    pub async fn chat_provider(&self) -> Option<Arc<dyn ChatProvider>> {
        self.chat_provider
            .get_or_init(|| async {
                let cfg = crate::service::load_user_config().unwrap_or_default();
                if cfg.local_model.enabled {
                    if let Some(mut p) =
                        trusty_common::auto_detect_local_provider(&cfg.local_model.base_url).await
                    {
                        p.model = cfg.local_model.model.clone();
                        return Some(Arc::new(p) as Arc<dyn ChatProvider>);
                    }
                }
                if !cfg.openrouter_api_key.is_empty() {
                    return Some(Arc::new(trusty_common::OpenRouterProvider::new(
                        cfg.openrouter_api_key,
                        cfg.openrouter_model,
                    )) as Arc<dyn ChatProvider>);
                }
                None
            })
            .await
            .clone()
    }

    /// Runs the `discover_aliases` tool for `palace` in a background task and
    /// logs the outcome. Does not wait for completion.
    pub fn spawn_alias_discovery(&self, palace: String, project_root: PathBuf) {
        let state = self.clone();
        tokio::spawn(async move {
            let args = serde_json::json!({
                "palace": palace,
                "project_root": project_root.to_string_lossy(),
            });
            match tools::dispatch_tool(&state, "discover_aliases", args).await {
                Ok(result) => tracing::info!(
                    new = ?result.get("new"),
                    already_known = ?result.get("already_known"),
                    "alias discovery complete"
                ),
                Err(e) => tracing::warn!("alias discovery failed: {e:#}"),
            }
        });
    }

    /// Current readiness (Acquire load, pairs with [`AppState::set_ready`]).
    pub fn readiness(&self) -> DaemonReadiness {
        DaemonReadiness::from_u8(self.daemon_readiness.load(Ordering::Acquire))
    }

    /// Marks the daemon ready (Release store).
    pub fn set_ready(&self) {
        self.daemon_readiness
            .store(DaemonReadiness::Ready as u8, Ordering::Release);
    }

    /// Returns the shared embedder, initialising it on first use.
    ///
    /// Initialisation is bounded by `timeouts::embedder_init_timeout()`. On
    /// timeout the in-flight attempt is cancelled and the cell stays empty, so
    /// a later call retries.
    ///
    /// # Errors
    /// Returns an error on timeout or if `FastEmbedder::new` fails.
    pub async fn embedder(&self) -> Result<Arc<FastEmbedder>> {
        use trusty_common::memory_core::timeouts;
        let cell = self.embedder.clone();
        let timeout = timeouts::embedder_init_timeout();
        let embedder = tokio::time::timeout(
            timeout,
            cell.get_or_try_init(|| async {
                let e = FastEmbedder::new().await?;
                Ok::<Arc<FastEmbedder>, anyhow::Error>(Arc::new(e))
            }),
        )
        .await
        .map_err(|_| {
            anyhow::anyhow!(
                "AppState::embedder() timed out after {:?}; \
                 the CoreML/CUDA model is taking unusually long to compile — \
                 increase TRUSTY_EMBEDDER_INIT_TIMEOUT_SECS if needed",
                timeout
            )
        })??
        .clone();
        Ok(embedder)
    }
}
impl std::fmt::Debug for AppState {
    /// Compact debug view: version, data root and the number of registered
    /// palaces. Handles, caches and channels are deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registry_len = self.registry.len();
        let mut out = f.debug_struct("AppState");
        out.field("version", &self.version);
        out.field("data_root", &self.data_root);
        out.field("registry_len", &registry_len);
        out.finish()
    }
}
/// Dispatches one JSON-RPC 2.0 message and returns the response to send.
///
/// Handles `initialize`, `tools/list`, `rpc.discover`, `tools/call` and
/// `ping`. Returns `Value::Null`, meaning "send nothing", for every
/// `notifications/*` method, because JSON-RPC notifications must never be
/// answered. Unknown methods yield a `-32601` error; tool failures yield a
/// `-32603` error carrying the full error chain.
pub async fn handle_message(state: &AppState, msg: Value) -> Value {
    let id = msg.get("id").cloned().unwrap_or(Value::Null);
    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
    match method {
        "initialize" => {
            let extra = state
                .default_palace
                .as_ref()
                .map(|dp| json!({ "default_palace": dp }));
            let result = initialize_response("trusty-memory", &state.version, extra);
            json!({
                "jsonrpc": "2.0",
                "id": id,
                "result": result,
            })
        }
        "tools/list" => json!({
            "jsonrpc": "2.0",
            "id": id,
            "result": tools::tool_definitions_with(state.default_palace.is_some())
        }),
        "rpc.discover" => json!({
            "jsonrpc": "2.0",
            "id": id,
            "result": openrpc::build_discover_response(
                &state.version,
                state.default_palace.is_some(),
            ),
        }),
        "tools/call" => {
            let params = msg.get("params").cloned().unwrap_or_default();
            let tool_name = params
                .get("name")
                .and_then(|n| n.as_str())
                .unwrap_or("")
                .to_string();
            let args = params.get("arguments").cloned().unwrap_or_default();
            match tools::dispatch_tool(state, &tool_name, args).await {
                Ok(content) => {
                    // Strings are passed through as-is; anything else is sent as JSON text.
                    let text = match &content {
                        Value::String(s) => s.clone(),
                        other => other.to_string(),
                    };
                    json!({
                        "jsonrpc": "2.0",
                        "id": id,
                        "result": {
                            "content": [{"type": "text", "text": text}]
                        }
                    })
                }
                Err(e) => json!({
                    "jsonrpc": "2.0",
                    "id": id,
                    "error": {"code": -32603, "message": format!("{e:#}")}
                }),
            }
        }
        "ping" => json!({"jsonrpc": "2.0", "id": id, "result": {}}),
        // Any notification (`notifications/initialized`, `notifications/cancelled`,
        // `notifications/roots/list_changed`, ...) must not get a reply,
        // not even an error.
        m if m.starts_with("notifications/") => Value::Null,
        _ => json!({
            "jsonrpc": "2.0",
            "id": id,
            "error": {
                "code": -32601,
                "message": format!("Method not found: {method}")
            }
        }),
    }
}
/// Preferred loopback port for the HTTP server; the first port `bind_dynamic_port` tries.
pub const DEFAULT_HTTP_PORT: u16 = 7070;
/// Size of the preferred port window starting at [`DEFAULT_HTTP_PORT`].
/// `bind_dynamic_port` passes `DYNAMIC_PORT_RANGE - 1` to
/// `trusty_common::bind_with_auto_port` (presumably "extra ports to try",
/// i.e. 7070–7079 — confirm) before falling back to an OS-assigned port.
const DYNAMIC_PORT_RANGE: u16 = 10;
/// Location of the `http_addr` discovery file inside the trusty-memory data
/// directory, or `None` when the data directory cannot be resolved.
pub fn http_addr_path() -> Option<PathBuf> {
    let data_dir = trusty_common::resolve_data_dir("trusty-memory").ok()?;
    Some(data_dir.join("http_addr"))
}
/// Binds the HTTP listener on `127.0.0.1`, preferring [`DEFAULT_HTTP_PORT`].
///
/// First tries the preferred port window via
/// `trusty_common::bind_with_auto_port`. If that fails, logs why (including
/// the underlying error) and asks the OS for any free loopback port.
///
/// # Errors
/// Returns the error from the final OS-assigned bind attempt.
pub async fn bind_dynamic_port() -> Result<tokio::net::TcpListener> {
    let preferred = SocketAddr::from(([127, 0, 0, 1], DEFAULT_HTTP_PORT));
    let range_err =
        match trusty_common::bind_with_auto_port(preferred, DYNAMIC_PORT_RANGE - 1).await {
            Ok(listener) => return Ok(listener),
            Err(e) => e,
        };
    // Log the real cause. The window can fail for reasons other than every
    // port being taken (e.g. permissions), and the old message hid that.
    tracing::warn!(
        "could not bind any port in {DEFAULT_HTTP_PORT}..={}: {range_err:#}; requesting OS-assigned port",
        DEFAULT_HTTP_PORT + DYNAMIC_PORT_RANGE - 1
    );
    let any = SocketAddr::from(([127, 0, 0, 1], 0));
    trusty_common::bind_with_auto_port(any, 0).await
}
/// Atomically publishes `addr` (plus a trailing newline) at `path`.
///
/// Creates the parent directory if needed, writes and fsyncs a sibling temp
/// file (`<path>.addr.tmp`), then renames it over `path`, so readers never
/// see a partial file. If any step fails, the temp file is removed
/// (best effort) before the error is returned.
///
/// # Errors
/// Any I/O error from directory creation, writing, syncing or renaming.
#[cfg(feature = "axum-server")]
fn write_http_addr_file(path: &Path, addr: &SocketAddr) -> std::io::Result<()> {
    /// Writes and syncs `tmp`, closes it, then renames it over `path`.
    fn write_and_rename(tmp: &Path, path: &Path, addr: &SocketAddr) -> std::io::Result<()> {
        use std::io::Write;
        let mut f = std::fs::File::create(tmp)?;
        writeln!(f, "{addr}")?;
        f.sync_all()?;
        // Close before renaming (required on some platforms, e.g. Windows).
        drop(f);
        std::fs::rename(tmp, path)
    }

    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let tmp = path.with_extension("addr.tmp");
    let result = write_and_rename(&tmp, path, addr);
    if result.is_err() {
        // Best-effort cleanup so a failed publish does not leave a stray temp file.
        let _ = std::fs::remove_file(&tmp);
    }
    result
}
/// True when the data-dir override environment variable
/// (`trusty_common::DATA_DIR_OVERRIDE_ENV`) is set to something other than
/// whitespace. Unset or non-UTF-8 values count as "not active".
#[inline]
pub fn is_data_dir_override_active() -> bool {
    std::env::var(trusty_common::DATA_DIR_OVERRIDE_ENV)
        .is_ok_and(|value| !value.trim().is_empty())
}
/// Legacy `~/.trusty-memory/http_addr` discovery file.
///
/// Suppressed (returns `None`) while a data-dir override is active, so an
/// overridden instance does not clobber the home-directory file. Also `None`
/// when the home directory cannot be determined.
#[cfg(feature = "axum-server")]
fn dotfile_http_addr_path() -> Option<PathBuf> {
    if is_data_dir_override_active() {
        None
    } else {
        let home = dirs::home_dir()?;
        Some(home.join(".trusty-memory").join("http_addr"))
    }
}
/// Writes `addr` to the discovery file at `path` and logs the outcome.
///
/// `label` is spliced into the log messages (`""` for the primary file,
/// `"dotfile "` for the home-directory dotfile). Returns `Some(path)` only if
/// the write succeeded, so the caller knows which files to clean up later.
#[cfg(feature = "axum-server")]
fn publish_http_addr_file(path: PathBuf, addr: &SocketAddr, label: &str) -> Option<PathBuf> {
    match write_http_addr_file(&path, addr) {
        Ok(()) => {
            info!("wrote daemon address to {label}{}", path.display());
            Some(path)
        }
        Err(e) => {
            tracing::warn!("could not write {label}{}: {e}", path.display());
            None
        }
    }
}

/// Removes the discovery file at `path` only if it still contains `addr`.
///
/// If something else (e.g. another daemon instance) rewrote the file after
/// us, deleting it would make that instance undiscoverable. Read-then-remove
/// is not atomic, so a narrow race remains.
#[cfg(feature = "axum-server")]
fn remove_http_addr_file_if_ours(path: &Path, addr: &SocketAddr) {
    let still_ours = std::fs::read_to_string(path)
        .map(|contents| contents.trim() == addr.to_string())
        .unwrap_or(false);
    if still_ours {
        let _ = std::fs::remove_file(path);
    }
}

/// Serves the HTTP API plus the `/sse` event stream on `listener` until
/// `trusty_common::shutdown_signal()` resolves.
///
/// Before serving, this:
/// - starts the disk-size and status-event tickers;
/// - records the bound address in `state.bound_addr`;
/// - publishes that address to the discovery file(s).
///
/// After shutdown, this:
/// - removes the discovery files this call wrote, if they still hold our address;
/// - shuts down the BM25 supervisor;
/// - returns the server's result.
///
/// # Errors
/// Returns the error from `axum::serve`, after cleanup has run.
#[cfg(feature = "axum-server")]
pub async fn run_http_on(state: AppState, listener: tokio::net::TcpListener) -> Result<()> {
    use axum::routing::get;
    spawn_disk_size_ticker(state.clone());
    spawn_status_event_ticker(state.clone());
    let local = listener.local_addr().ok();
    let mut written: Vec<PathBuf> = Vec::new();
    if let Some(a) = local {
        let _ = state.bound_addr.set(a);
        info!("HTTP server listening on http://{a}");
        eprintln!("HTTP server listening on http://{a}");
        match http_addr_path() {
            Some(p) => written.extend(publish_http_addr_file(p, &a, "")),
            None => tracing::warn!("no $HOME — skipping http_addr discovery file"),
        }
        if let Some(p) = dotfile_http_addr_path() {
            written.extend(publish_http_addr_file(p, &a, "dotfile "));
        }
    }
    let bm25_supervisor = state.bm25_supervisor.clone();
    let app = web::router()
        .route("/sse", get(sse_handler))
        .with_state(state);
    let serve_result = axum::serve(listener, app)
        .with_graceful_shutdown(trusty_common::shutdown_signal())
        .await;
    // Clean up before propagating a serve error so stale discovery files never outlive us.
    if let Some(a) = local {
        for p in &written {
            remove_http_addr_file_if_ours(p, &a);
        }
    }
    if let Some(supervisor) = bm25_supervisor {
        supervisor.shutdown().await;
    }
    serve_result?;
    Ok(())
}
/// Binds `addr` exactly (no port fallback) and serves via [`run_http_on`].
///
/// # Errors
/// Bind failures, or any error returned by [`run_http_on`].
#[cfg(feature = "axum-server")]
pub async fn run_http(state: AppState, addr: std::net::SocketAddr) -> Result<()> {
    run_http_on(state, tokio::net::TcpListener::bind(addr).await?).await
}
/// Binds via [`bind_dynamic_port`] (preferred port, then OS-assigned) and
/// serves via [`run_http_on`].
///
/// # Errors
/// Bind failures, or any error returned by [`run_http_on`].
#[cfg(feature = "axum-server")]
pub async fn run_http_dynamic(state: AppState) -> Result<()> {
    run_http_on(state, bind_dynamic_port().await?).await
}
/// Spawns a background task that refreshes `state.disk_bytes` with the size
/// of `state.data_root` every 10 seconds (first scan immediately).
///
/// Each directory walk runs on a blocking thread. If a scan fails to join
/// (e.g. it panicked), the previous value is kept and a warning is logged,
/// rather than reporting 0 bytes.
#[cfg(feature = "axum-server")]
fn spawn_disk_size_ticker(state: AppState) {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(10));
        // A walk of a large data dir can outlast the period. `Delay` stops
        // missed ticks from firing back-to-back, which would turn the ticker
        // into a continuous scan.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            interval.tick().await;
            let dir = state.data_root.clone();
            let scan = tokio::task::spawn_blocking(move || {
                trusty_common::sys_metrics::dir_size_bytes(&dir)
            })
            .await;
            match scan {
                Ok(bytes) => state.disk_bytes.store(bytes, Ordering::Relaxed),
                Err(e) => tracing::warn!("disk size scan of data dir failed: {e}"),
            }
        }
    });
}
/// Period, in seconds, between aggregate `StatusChanged` events.
// Only referenced from the feature-gated HTTP server (and possibly tests), so
// builds without `axum-server` would otherwise flag it as dead code.
#[allow(dead_code)]
const STATUS_EVENT_TICK_SECS: u64 = 30;
/// Spawns a background task that emits an aggregate status event every
/// [`STATUS_EVENT_TICK_SECS`] seconds, the first one immediately.
// Same reason as above for allowing dead code.
#[allow(dead_code)]
fn spawn_status_event_ticker(state: AppState) {
    let period = std::time::Duration::from_secs(STATUS_EVENT_TICK_SECS);
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(period);
        loop {
            ticker.tick().await;
            let status = service::MemoryService::new(state.clone()).aggregate_status_event();
            state.emit(status);
        }
    });
}
/// Server-Sent Events endpoint that streams every [`DaemonEvent`] broadcast
/// on `state.events` as `data: <json>\n\n` frames.
///
/// The stream opens with a `{"type":"connected"}` frame. If this subscriber
/// falls behind the broadcast buffer, a `{"type":"lag","skipped":N}` frame
/// reports how many events were dropped, and streaming continues.
#[cfg(feature = "axum-server")]
pub(crate) async fn sse_handler(
    axum::extract::State(state): axum::extract::State<AppState>,
) -> impl axum::response::IntoResponse {
    use futures::StreamExt;
    use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
    use tokio_stream::wrappers::BroadcastStream;
    let rx = state.events.subscribe();
    let initial = futures::stream::once(async {
        Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(
            "data: {\"type\":\"connected\"}\n\n",
        ))
    });
    let events = BroadcastStream::new(rx).map(|res| {
        let frame = match res {
            Ok(event) => match serde_json::to_string(&event) {
                Ok(json) => format!("data: {json}\n\n"),
                Err(e) => {
                    // Encode the message as a JSON string literal so quotes,
                    // backslashes or newlines in it cannot corrupt the frame.
                    let message = Value::String(e.to_string());
                    format!("data: {{\"type\":\"error\",\"message\":{message}}}\n\n")
                }
            },
            Err(BroadcastStreamRecvError::Lagged(n)) => {
                format!("data: {{\"type\":\"lag\",\"skipped\":{n}}}\n\n")
            }
        };
        Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(frame))
    });
    let stream = initial.chain(events);
    // Static, valid header names/values: the builder cannot fail here.
    axum::response::Response::builder()
        .header("Content-Type", "text/event-stream")
        .header("Cache-Control", "no-cache")
        .header("X-Accel-Buffering", "no")
        .body(axum::body::Body::from_stream(stream))
        .expect("valid SSE response")
}
#[cfg(test)]
mod lib_tests;