Skip to main content

nexus_memory_web/
state.rs

1//! Application state for the web dashboard
2
3use crate::error::Result;
4use crate::WebError;
5use nexus_agent::AgentSupervisor;
6use nexus_core::traits::EmbeddingService;
7use nexus_orchestrator::{Event, EventType, Orchestrator};
8use nexus_storage::{MemoryRepository, NamespaceRepository, StorageManager};
9use sqlx::SqlitePool;
10use std::sync::Arc;
11use tokio::sync::{broadcast, RwLock};
12use tracing::{error, info};
13
14/// Shared application state
15pub struct AppState {
16    /// Storage manager for database operations
17    pub storage: StorageManager,
18    /// Orchestrator for event handling
19    pub orchestrator: Orchestrator,
20    /// Memory repository
21    pub memory_repo: MemoryRepository,
22    /// Namespace repository
23    pub namespace_repo: NamespaceRepository,
24    /// WebSocket broadcaster
25    pub ws_sender: broadcast::Sender<crate::models::WebSocketMessage>,
26    /// Server start time for uptime calculation
27    pub start_time: std::time::Instant,
28    /// Optional agent supervisor (set when --agent flag is used)
29    pub agent_supervisor: Option<AgentSupervisor>,
30}
31
32impl AppState {
33    /// Create a new application state
34    pub async fn new(storage: StorageManager, orchestrator: Orchestrator) -> Result<Self> {
35        let pool = storage.pool().clone();
36        let memory_repo = MemoryRepository::new(pool.clone());
37        let namespace_repo = NamespaceRepository::new(pool.clone());
38
39        // Create WebSocket broadcast channel
40        let (ws_sender, _) = broadcast::channel(1000);
41
42        // Initialize agent supervisor if enabled
43        let agent_supervisor = match Self::create_agent_supervisor(&pool, &namespace_repo).await {
44            Ok(Some(supervisor)) => {
45                info!("Agent supervisor initialized");
46                Some(supervisor)
47            }
48            Ok(None) => None,
49            Err(e) => {
50                error!("Failed to initialize agent supervisor: {}", e);
51                None
52            }
53        };
54
55        let state = Self {
56            storage,
57            orchestrator,
58            memory_repo,
59            namespace_repo,
60            ws_sender,
61            start_time: std::time::Instant::now(),
62            agent_supervisor,
63        };
64
65        // Start event forwarding from orchestrator to WebSocket
66        state.start_event_forwarding().await?;
67
68        Ok(state)
69    }
70
71    /// Start forwarding events from orchestrator to WebSocket clients
72    async fn start_event_forwarding(&self) -> Result<()> {
73        let mut rx = self.orchestrator.subscribe_events();
74        let ws_sender = self.ws_sender.clone();
75
76        tokio::spawn(async move {
77            loop {
78                match rx.recv().await {
79                    Ok(event) => {
80                        if let Some(msg) = Self::convert_event_to_ws_message(&event) {
81                            let _ = ws_sender.send(msg);
82                        }
83                    }
84                    Err(e) => {
85                        error!("Event receive error: {}", e);
86                        break;
87                    }
88                }
89            }
90        });
91
92        Ok(())
93    }
94
95    /// Convert orchestrator event to WebSocket message
96    fn convert_event_to_ws_message(event: &Event) -> Option<crate::models::WebSocketMessage> {
97        use crate::models::{WebSocketMessage, WebSocketMessageType};
98
99        match event.event_type {
100            EventType::MemoryStored => {
101                let memory_id = event.get::<i64>("memory_id").unwrap_or(0);
102                let agent_type = event.get::<String>("agent_type").unwrap_or_default();
103                // Note: We can't construct full MemoryResponse here without DB lookup
104                // The client may need to fetch the full memory
105                let data = serde_json::json!({
106                    "memory_id": memory_id,
107                    "agent_type": agent_type,
108                });
109                Some(WebSocketMessage::new(
110                    WebSocketMessageType::MemoryStored,
111                    data,
112                ))
113            }
114            EventType::MemoryUpdated => {
115                let memory_id = event.get::<i64>("memory_id").unwrap_or(0);
116                Some(WebSocketMessage::memory_updated(memory_id))
117            }
118            EventType::MemoryDeleted => {
119                let memory_id = event.get::<i64>("memory_id").unwrap_or(0);
120                Some(WebSocketMessage::memory_deleted(memory_id))
121            }
122            EventType::SessionStarted => {
123                let session_id = event.get::<String>("session_id").unwrap_or_default();
124                let data = serde_json::json!({
125                    "session_id": session_id,
126                });
127                Some(WebSocketMessage::new(
128                    WebSocketMessageType::SessionStarted,
129                    data,
130                ))
131            }
132            EventType::SessionEnded => {
133                let session_id = event.get::<String>("session_id").unwrap_or_default();
134                let data = serde_json::json!({
135                    "session_id": session_id,
136                });
137                Some(WebSocketMessage::new(
138                    WebSocketMessageType::SessionEnded,
139                    data,
140                ))
141            }
142            _ => None,
143        }
144    }
145
146    /// Get the database pool
147    pub fn pool(&self) -> &SqlitePool {
148        self.storage.pool()
149    }
150
151    /// Create an agent supervisor if agent mode is enabled in the config.
152    async fn create_agent_supervisor(
153        pool: &SqlitePool,
154        namespace_repo: &NamespaceRepository,
155    ) -> Result<Option<AgentSupervisor>> {
156        let config = nexus_core::Config::from_env().map_err(|e| WebError::Config(e.to_string()))?;
157
158        if !config.agent.enabled {
159            return Ok(None);
160        }
161
162        let llm = nexus_llm::create_client_auto_with_fallback()
163            .map_err(|e| WebError::Config(format!("Failed to create LLM client: {}", e)))?;
164
165        let query_embedder: Option<Arc<dyn EmbeddingService>> =
166            match nexus_embeddings::create_service(&config).await {
167                Ok(service) => service,
168                Err(error) => {
169                    error!("Failed to initialize query embedding service: {}", error);
170                    None
171                }
172            };
173
174        // Ensure the agent namespace exists
175        let namespace = namespace_repo
176            .get_or_create(&config.agent.namespace, "nexus-agent")
177            .await
178            .map_err(|e| WebError::Storage(e.to_string()))?;
179
180        let mut supervisor = AgentSupervisor::new(config.agent, llm, pool.clone(), namespace.id);
181        if let Some(embedder) = query_embedder {
182            supervisor = supervisor.with_query_embedder(embedder);
183        }
184        supervisor
185            .start()
186            .await
187            .map_err(|e| WebError::Config(format!("Failed to start agent supervisor: {}", e)))?;
188
189        Ok(Some(supervisor))
190    }
191
192    /// Get uptime in seconds
193    pub fn uptime_seconds(&self) -> u64 {
194        self.start_time.elapsed().as_secs()
195    }
196
197    /// Subscribe to WebSocket messages
198    pub fn subscribe_ws(&self) -> broadcast::Receiver<crate::models::WebSocketMessage> {
199        self.ws_sender.subscribe()
200    }
201
202    /// Broadcast a message to all WebSocket clients
203    pub fn broadcast_ws(&self, msg: crate::models::WebSocketMessage) -> Result<()> {
204        let _ = self.ws_sender.send(msg);
205        Ok(())
206    }
207}
208
209/// Shared state type alias
210pub type SharedState = Arc<RwLock<AppState>>;