Skip to main content

nexus_memory_web/
lib.rs

//! Nexus Web Dashboard - Axum-based web interface for the Nexus Memory System
//!
//! This crate provides:
//! - REST API endpoints for memory CRUD operations, plus namespace, stats,
//!   agent, and cognition-observability endpoints
//! - WebSocket real-time updates
//! - Static file serving for the dashboard UI
//! - CORS (local-origin only) and request-tracing middleware
//!
//! ## Example
//!
//! ```rust,ignore
//! use nexus_memory_web::WebDashboard;
//! use std::net::SocketAddr;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // `storage` (a `nexus_storage::StorageManager`) and `orchestrator`
//!     // (a `nexus_orchestrator::Orchestrator`) are constructed by the caller.
//!     let dashboard = WebDashboard::new(storage, orchestrator).await?;
//!     let addr = SocketAddr::from(([0, 0, 0, 0], 8768));
//!     dashboard.serve(addr).await?;
//!     Ok(())
//! }
//! ```
24
25pub mod api;
26pub mod error;
27pub mod models;
28pub mod state;
29pub mod websocket;
30
31use axum::{
32    routing::{get, post},
33    Router,
34};
35use http::HeaderValue;
36use std::net::SocketAddr;
37use std::sync::Arc;
38use tokio::sync::RwLock;
39use tower_http::cors::{AllowOrigin, CorsLayer};
40use tower_http::services::ServeDir;
41use tower_http::trace::TraceLayer;
42use tracing::info;
43use url::Url;
44
45pub use error::{Result, WebError};
46pub use models::*;
47pub use state::AppState;
48
49use api::{
50    agent_consolidate, agent_ingest, agent_query, agent_status, cognition_overview, create_memory,
51    create_namespace, dashboard, delete_memory, get_agent_stats, get_memory, get_namespace,
52    get_stats, health_check, job_summary, list_digests, list_jobs, list_memories, list_namespaces,
53    query_introspection, reflection_state, runtime_health, search_memories, update_memory,
54};
55use websocket::websocket_handler;
56
57/// Web Dashboard for Nexus Memory System
58pub struct WebDashboard {
59    router: Router,
60    state: Arc<RwLock<AppState>>,
61}
62
63impl WebDashboard {
64    /// Create a new web dashboard instance
65    pub async fn new(
66        storage: nexus_storage::StorageManager,
67        orchestrator: nexus_orchestrator::Orchestrator,
68    ) -> Result<Self> {
69        let state = Arc::new(RwLock::new(AppState::new(storage, orchestrator).await?));
70        let router = Self::build_router(state.clone());
71
72        Ok(Self { router, state })
73    }
74
75    /// Build the Axum router with all routes
76    fn build_router(state: Arc<RwLock<AppState>>) -> Router {
77        // CORS Layer — restrict to exact local-only origins.
78        // Parses the Origin header as a URL and compares host exactly
79        // to prevent prefix-spoofing attacks (e.g. http://localhost.evil.com).
80        let cors = CorsLayer::new()
81            .allow_origin(AllowOrigin::predicate(
82                |origin: &HeaderValue, _request: &http::request::Parts| {
83                    let origin_str = origin.to_str().unwrap_or("");
84                    match Url::parse(origin_str) {
85                        Ok(url) => {
86                            let host = url.host_str().unwrap_or("");
87                            let scheme = url.scheme();
88                            // Only allow exact localhost / 127.0.0.1 origins
89                            (scheme == "http" || scheme == "https")
90                                && (host == "127.0.0.1" || host == "localhost")
91                        }
92                        Err(_) => false, // Malformed origins are rejected
93                    }
94                },
95            ))
96            .allow_methods([
97                axum::http::Method::GET,
98                axum::http::Method::POST,
99                axum::http::Method::PUT,
100                axum::http::Method::DELETE,
101            ])
102            .allow_headers([
103                axum::http::header::CONTENT_TYPE,
104                axum::http::header::ACCEPT,
105                axum::http::header::ORIGIN,
106            ]);
107
108        // API routes
109        let api_routes = Router::new()
110            // Memory endpoints
111            .route("/memories", get(list_memories).post(create_memory))
112            .route(
113                "/memories/{id}",
114                get(get_memory).put(update_memory).delete(delete_memory),
115            )
116            .route("/memories/search", post(search_memories))
117            // Namespace endpoints
118            .route("/namespaces", get(list_namespaces).post(create_namespace))
119            .route("/namespaces/{id}", get(get_namespace))
120            // Stats endpoints
121            .route("/stats", get(get_stats))
122            .route("/stats/{agent}", get(get_agent_stats))
123            // Health check
124            .route("/health", get(health_check))
125            // Agent endpoints
126            .route("/agent/ingest", post(agent_ingest))
127            .route("/agent/query", post(agent_query))
128            .route("/agent/consolidate", post(agent_consolidate))
129            .route("/agent/status", get(agent_status))
130            // Cognition observability endpoints
131            .route("/cognition/jobs", get(list_jobs))
132            .route("/cognition/jobs/summary", get(job_summary))
133            .route("/cognition/digests", get(list_digests))
134            .route("/cognition/overview", get(cognition_overview))
135            .route("/cognition/reflection", get(reflection_state))
136            .route("/cognition/runtime", get(runtime_health))
137            .route("/cognition/query-introspection", get(query_introspection))
138            .route("/cognition/dashboard", get(dashboard));
139
140        // WebSocket route
141        let ws_route = Router::new().route("/ws", get(websocket_handler));
142
143        // Combine all routes
144        Router::new()
145            .nest("/api", api_routes)
146            .merge(ws_route)
147            // Serve static files from the static directory
148            .fallback_service(ServeDir::new("src/static").append_index_html_on_directories(true))
149            .layer(cors)
150            .layer(TraceLayer::new_for_http())
151            .with_state(state)
152    }
153
154    /// Serve the web dashboard on the specified address
155    pub async fn serve(self, addr: SocketAddr) -> Result<()> {
156        info!("Starting Nexus Web Dashboard on {}", addr);
157
158        let listener = tokio::net::TcpListener::bind(addr)
159            .await
160            .map_err(|e| WebError::ServerStart(e.to_string()))?;
161
162        info!("Web Dashboard listening on http://{}", addr);
163
164        axum::serve(listener, self.router)
165            .await
166            .map_err(|e| WebError::ServerStart(e.to_string()))?;
167
168        Ok(())
169    }
170
171    /// Get a clone of the state
172    pub fn state(&self) -> Arc<RwLock<AppState>> {
173        self.state.clone()
174    }
175}
176
177/// Create a new web dashboard with the given storage and orchestrator
178pub async fn create_dashboard(
179    storage: nexus_storage::StorageManager,
180    orchestrator: nexus_orchestrator::Orchestrator,
181) -> Result<WebDashboard> {
182    WebDashboard::new(storage, orchestrator).await
183}
184
185/// Run the web dashboard on the default port (8768)
186pub async fn run_default(
187    storage: nexus_storage::StorageManager,
188    orchestrator: nexus_orchestrator::Orchestrator,
189) -> Result<()> {
190    let dashboard = WebDashboard::new(storage, orchestrator).await?;
191    let addr = SocketAddr::from(([0, 0, 0, 0], 8768));
192    dashboard.serve(addr).await
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use nexus_orchestrator::Orchestrator;
    use serde_json::{json, Value};
    use tower::ServiceExt;

    fn body_to_json(body: axum::body::Bytes) -> Value {
        serde_json::from_slice(&body).expect("valid JSON")
    }

    /// Build a dashboard backed by a freshly migrated in-memory SQLite database.
    ///
    /// The pool is returned as well so tests can inspect rows directly.
    async fn test_dashboard() -> (WebDashboard, sqlx::SqlitePool) {
        let pool = sqlx::SqlitePool::connect("sqlite::memory:")
            .await
            .expect("connect to in-memory db");
        nexus_storage::migrations::run_migrations(&pool)
            .await
            .expect("run migrations");

        let mut storage = nexus_storage::StorageManager::new(pool.clone());
        storage.initialize().await.expect("initialize storage");

        let dashboard = WebDashboard::new(storage, Orchestrator::default())
            .await
            .expect("create dashboard");

        (dashboard, pool)
    }

    /// Send a bodiless GET for `uri` through the dashboard's router and return
    /// the response status.
    async fn get_status(dashboard: WebDashboard, uri: &str) -> StatusCode {
        dashboard
            .router
            .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status()
    }

    #[test]
    fn test_web_error_display() {
        let err = WebError::ServerStart("test error".to_string());
        assert!(err.to_string().contains("test error"));
    }

    #[tokio::test]
    async fn test_production_router_exposes_cognition_runtime_route() {
        let (dashboard, _pool) = test_dashboard().await;

        assert_eq!(
            get_status(dashboard, "/api/cognition/runtime").await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn test_production_router_exposes_cognition_dashboard_route() {
        let (dashboard, _pool) = test_dashboard().await;

        // Dashboard requires a namespace, so this will be 400 (missing namespace).
        assert_eq!(
            get_status(dashboard, "/api/cognition/dashboard").await,
            StatusCode::BAD_REQUEST
        );
    }

    #[tokio::test]
    async fn test_update_memory_persists_native_sql_values() {
        let (dashboard, pool) = test_dashboard().await;

        let memory_id = {
            let state = dashboard.state.read().await;
            let namespace = state
                .namespace_repo
                .get_or_create("web-update-test", "test-agent")
                .await
                .expect("create namespace");
            state
                .memory_repo
                .store(nexus_storage::StoreMemoryParams {
                    namespace_id: namespace.id,
                    content: "original content",
                    category: &nexus_core::MemoryCategory::General,
                    memory_lane_type: None,
                    labels: &["initial".to_string()],
                    metadata: &json!({"before": true}),
                    embedding: None,
                    embedding_model: None,
                })
                .await
                .expect("store memory")
                .id
        };

        let resp = dashboard
            .router
            .oneshot(
                Request::builder()
                    .method("PUT")
                    .uri(format!("/api/memories/{memory_id}"))
                    .header("content-type", "application/json")
                    .body(Body::from(
                        serde_json::to_vec(&json!({
                            "content": "updated content",
                            "category": "facts",
                            "memory_lane_type": "decision",
                            "labels": ["updated", "native-bindings"],
                            "metadata": {"source": "test"},
                            "is_active": true,
                            "is_archived": false
                        }))
                        .expect("serialize request"),
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();

        let status = resp.status();
        let body = to_bytes(resp.into_body(), 1_000_000).await.unwrap();
        assert_eq!(
            status,
            StatusCode::OK,
            "unexpected response body: {}",
            String::from_utf8_lossy(&body)
        );
        let json = body_to_json(body);
        assert_eq!(json["content"], "updated content");
        assert_eq!(json["category"], "facts");
        assert_eq!(json["memory_lane_type"], "decision");
        assert_eq!(json["metadata"]["source"], "test");

        // Verify the values were stored as native SQL values, not re-encoded strings.
        let row: (String, String, String, i64, i64) = sqlx::query_as(
            "SELECT category, memory_lane_type, metadata, is_active, is_archived FROM memories WHERE id = ?",
        )
        .bind(memory_id)
        .fetch_one(&pool)
        .await
        .expect("fetch updated row");

        assert_eq!(row.0, "facts");
        assert_eq!(row.1, "decision");
        assert_eq!(row.2, r#"{"source":"test"}"#);
        assert_eq!(row.3, 1);
        assert_eq!(row.4, 0);
    }
}