Skip to main content

nexus_memory_web/
lib.rs

//! Nexus Web Dashboard - Axum-based web interface for the Nexus Memory System
//!
//! This crate provides:
//! - REST API endpoints for memory CRUD operations
//! - WebSocket real-time updates
//! - Static file serving for the dashboard UI
//! - CORS and security middleware
//!
//! ## Example
//!
//! ```rust,ignore
//! use nexus_memory_web::WebDashboard;
//! use std::net::SocketAddr;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // `storage` is an initialized `nexus_storage::StorageManager` and
//!     // `orchestrator` is a `nexus_orchestrator::Orchestrator`.
//!     let dashboard = WebDashboard::new(storage, orchestrator).await?;
//!     let addr = SocketAddr::from(([127, 0, 0, 1], 8768));
//!     dashboard.serve(addr).await?;
//!     Ok(())
//! }
//! ```
24
25pub mod api;
26pub mod error;
27pub mod models;
28pub mod state;
29pub mod websocket;
30
31use axum::{
32    routing::{get, post},
33    Router,
34};
35use http::HeaderValue;
36use std::net::SocketAddr;
37use std::sync::Arc;
38use tokio::sync::RwLock;
39use tower_http::cors::{AllowOrigin, CorsLayer};
40use tower_http::services::ServeDir;
41use tower_http::trace::TraceLayer;
42use tracing::info;
43use url::Url;
44
45pub use error::{Result, WebError};
46pub use models::*;
47pub use state::AppState;
48
49use api::{
50    agent_boost, agent_consolidate, agent_ingest, agent_query, agent_status, cognition_overview,
51    create_memory, create_namespace, dashboard, delete_memory, get_agent_stats, get_memory,
52    get_namespace, get_stats, health_check, job_summary, list_digests, list_jobs, list_memories,
53    list_namespaces, query_introspection, reflection_state, runtime_health, search_memories,
54    update_memory,
55};
56use websocket::websocket_handler;
57
58/// Web Dashboard for Nexus Memory System
59pub struct WebDashboard {
60    router: Router,
61    state: Arc<RwLock<AppState>>,
62}
63
64impl WebDashboard {
65    /// Create a new web dashboard instance
66    pub async fn new(
67        storage: nexus_storage::StorageManager,
68        orchestrator: nexus_orchestrator::Orchestrator,
69    ) -> Result<Self> {
70        let state = Arc::new(RwLock::new(AppState::new(storage, orchestrator).await?));
71        let router = Self::build_router(state.clone());
72
73        Ok(Self { router, state })
74    }
75
76    /// Build the Axum router with all routes
77    fn build_router(state: Arc<RwLock<AppState>>) -> Router {
78        // CORS Layer — restrict to exact local-only origins.
79        // Parses the Origin header as a URL and compares host exactly
80        // to prevent prefix-spoofing attacks (e.g. http://localhost.evil.com).
81        let cors = CorsLayer::new()
82            .allow_origin(AllowOrigin::predicate(
83                |origin: &HeaderValue, _request: &http::request::Parts| {
84                    let origin_str = origin.to_str().unwrap_or("");
85                    match Url::parse(origin_str) {
86                        Ok(url) => {
87                            let host = url.host_str().unwrap_or("");
88                            let scheme = url.scheme();
89                            // Only allow exact localhost / 127.0.0.1 origins
90                            (scheme == "http" || scheme == "https")
91                                && (host == "127.0.0.1" || host == "localhost")
92                        }
93                        Err(_) => false, // Malformed origins are rejected
94                    }
95                },
96            ))
97            .allow_methods([
98                axum::http::Method::GET,
99                axum::http::Method::POST,
100                axum::http::Method::PUT,
101                axum::http::Method::DELETE,
102            ])
103            .allow_headers([
104                axum::http::header::CONTENT_TYPE,
105                axum::http::header::ACCEPT,
106                axum::http::header::ORIGIN,
107            ]);
108
109        // API routes
110        let api_routes = Router::new()
111            // Memory endpoints
112            .route("/memories", get(list_memories).post(create_memory))
113            .route(
114                "/memories/{id}",
115                get(get_memory).put(update_memory).delete(delete_memory),
116            )
117            .route("/memories/search", post(search_memories))
118            // Namespace endpoints
119            .route("/namespaces", get(list_namespaces).post(create_namespace))
120            .route("/namespaces/{id}", get(get_namespace))
121            // Stats endpoints
122            .route("/stats", get(get_stats))
123            .route("/stats/{agent}", get(get_agent_stats))
124            // Health check
125            .route("/health", get(health_check))
126            // Agent endpoints
127            .route("/agent/ingest", post(agent_ingest))
128            .route("/agent/query", post(agent_query))
129            .route("/agent/consolidate", post(agent_consolidate))
130            .route("/agent/boost", post(agent_boost))
131            .route("/agent/status", get(agent_status))
132            // Cognition observability endpoints
133            .route("/cognition/jobs", get(list_jobs))
134            .route("/cognition/jobs/summary", get(job_summary))
135            .route("/cognition/digests", get(list_digests))
136            .route("/cognition/overview", get(cognition_overview))
137            .route("/cognition/reflection", get(reflection_state))
138            .route("/cognition/runtime", get(runtime_health))
139            .route("/cognition/query-introspection", get(query_introspection))
140            .route("/cognition/dashboard", get(dashboard));
141
142        // WebSocket route
143        let ws_route = Router::new().route("/ws", get(websocket_handler));
144
145        // Combine all routes
146        Router::new()
147            .nest("/api", api_routes)
148            .merge(ws_route)
149            // Serve static files from the static directory
150            .fallback_service(ServeDir::new("src/static").append_index_html_on_directories(true))
151            .layer(cors)
152            .layer(TraceLayer::new_for_http())
153            .with_state(state)
154    }
155
156    /// Serve the web dashboard on the specified address
157    pub async fn serve(self, addr: SocketAddr) -> Result<()> {
158        info!("Starting Nexus Web Dashboard on {}", addr);
159
160        let listener = tokio::net::TcpListener::bind(addr)
161            .await
162            .map_err(|e| WebError::ServerStart(e.to_string()))?;
163
164        info!("Web Dashboard listening on http://{}", addr);
165
166        axum::serve(listener, self.router)
167            .await
168            .map_err(|e| WebError::ServerStart(e.to_string()))?;
169
170        Ok(())
171    }
172
173    /// Get a clone of the state
174    pub fn state(&self) -> Arc<RwLock<AppState>> {
175        self.state.clone()
176    }
177}
178
179/// Create a new web dashboard with the given storage and orchestrator
180pub async fn create_dashboard(
181    storage: nexus_storage::StorageManager,
182    orchestrator: nexus_orchestrator::Orchestrator,
183) -> Result<WebDashboard> {
184    WebDashboard::new(storage, orchestrator).await
185}
186
187/// Run the web dashboard on the default port (8768)
188pub async fn run_default(
189    storage: nexus_storage::StorageManager,
190    orchestrator: nexus_orchestrator::Orchestrator,
191) -> Result<()> {
192    let dashboard = WebDashboard::new(storage, orchestrator).await?;
193    let addr = SocketAddr::from(([0, 0, 0, 0], 8768));
194    dashboard.serve(addr).await
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use nexus_orchestrator::Orchestrator;
    use serde_json::{json, Value};
    use tower::ServiceExt;

    /// Parse a response body as JSON, panicking if it is not valid JSON.
    fn body_to_json(body: axum::body::Bytes) -> Value {
        serde_json::from_slice(&body).expect("valid JSON")
    }

    /// Build a dashboard backed by a freshly migrated in-memory SQLite database.
    ///
    /// The pool is returned alongside the dashboard so tests can inspect rows directly.
    async fn setup_dashboard() -> (WebDashboard, sqlx::SqlitePool) {
        let pool = sqlx::SqlitePool::connect("sqlite::memory:")
            .await
            .expect("connect to in-memory db");
        nexus_storage::migrations::run_migrations(&pool)
            .await
            .expect("run migrations");

        let mut storage = nexus_storage::StorageManager::new(pool.clone());
        storage.initialize().await.expect("initialize storage");

        let dashboard = WebDashboard::new(storage, Orchestrator::default())
            .await
            .expect("create dashboard");
        (dashboard, pool)
    }

    /// Send a body-less GET to `router`, optionally carrying an `Origin` header.
    async fn get_request(
        router: Router,
        uri: &str,
        origin: Option<&str>,
    ) -> axum::response::Response {
        let mut builder = Request::builder().uri(uri);
        if let Some(origin) = origin {
            builder = builder.header("origin", origin);
        }
        router
            .oneshot(builder.body(Body::empty()).expect("build request"))
            .await
            .expect("router is infallible")
    }

    #[test]
    fn test_web_error_display() {
        let err = WebError::ServerStart("test error".to_string());
        assert!(err.to_string().contains("test error"));
    }

    #[tokio::test]
    async fn test_production_router_exposes_cognition_runtime_route() {
        let (dashboard, _pool) = setup_dashboard().await;

        let resp = get_request(dashboard.router, "/api/cognition/runtime", None).await;

        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_production_router_exposes_cognition_dashboard_route() {
        let (dashboard, _pool) = setup_dashboard().await;

        // Dashboard requires a namespace, so this will be 400 (missing namespace).
        let resp = get_request(dashboard.router, "/api/cognition/dashboard", None).await;

        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_cors_allows_only_exact_local_origins() {
        let (dashboard, _pool) = setup_dashboard().await;

        let cases = [
            ("http://localhost:3000", true),
            ("http://127.0.0.1:8768", true),
            // Prefix spoofing must not pass the exact-host comparison.
            ("http://localhost.evil.com", false),
            ("https://evil.example", false),
            // Opaque origin sent by sandboxed frames / file:// pages.
            ("null", false),
        ];
        for (origin, allowed) in cases {
            let resp =
                get_request(dashboard.router.clone(), "/api/cognition/runtime", Some(origin))
                    .await;
            let echoed = resp
                .headers()
                .get(axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .and_then(|value| value.to_str().ok());
            assert_eq!(echoed, allowed.then_some(origin), "origin {origin}");
        }
    }

    #[tokio::test]
    async fn test_update_memory_persists_native_sql_values() {
        let (dashboard, pool) = setup_dashboard().await;

        let memory_id = {
            let state = dashboard.state.read().await;
            let namespace = state
                .namespace_repo
                .get_or_create("web-update-test", "test-agent")
                .await
                .expect("create namespace");
            state
                .memory_repo
                .store(nexus_storage::StoreMemoryParams {
                    namespace_id: namespace.id,
                    content: "original content",
                    category: &nexus_core::MemoryCategory::General,
                    memory_lane_type: None,
                    labels: &["initial".to_string()],
                    metadata: &json!({"before": true}),
                    embedding: None,
                    embedding_model: None,
                })
                .await
                .expect("store memory")
                .id
        };

        let resp = dashboard
            .router
            .clone()
            .oneshot(
                Request::builder()
                    .method("PUT")
                    .uri(format!("/api/memories/{memory_id}"))
                    .header("content-type", "application/json")
                    .body(Body::from(
                        serde_json::to_vec(&json!({
                            "content": "updated content",
                            "category": "facts",
                            "memory_lane_type": "decision",
                            "labels": ["updated", "native-bindings"],
                            "metadata": {"source": "test"},
                            "is_active": true,
                            "is_archived": false
                        }))
                        .expect("serialize request"),
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();

        let status = resp.status();
        let body = to_bytes(resp.into_body(), 1_000_000).await.unwrap();
        assert_eq!(
            status,
            StatusCode::OK,
            "unexpected response body: {}",
            String::from_utf8_lossy(&body)
        );
        let json = body_to_json(body);
        assert_eq!(json["content"], "updated content");
        assert_eq!(json["category"], "facts");
        assert_eq!(json["memory_lane_type"], "decision");
        assert_eq!(json["metadata"]["source"], "test");

        let row: (String, String, String, i64, i64) = sqlx::query_as(
            "SELECT category, memory_lane_type, metadata, is_active, is_archived FROM memories WHERE id = ?",
        )
        .bind(memory_id)
        .fetch_one(&pool)
        .await
        .expect("fetch updated row");

        assert_eq!(row.0, "facts");
        assert_eq!(row.1, "decision");
        assert_eq!(row.2, r#"{"source":"test"}"#);
        assert_eq!(row.3, 1);
        assert_eq!(row.4, 0);
    }
}