1use std::net::SocketAddr;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use anyhow::{Context, Result};
12use axum::{
13 extract::{
14 ws::{Message, WebSocket, WebSocketUpgrade},
15 Path, State,
16 },
17 http::StatusCode,
18 response::{IntoResponse, Response},
19 routing::{get, post},
20 Json, Router,
21};
22use futures_util::{SinkExt, StreamExt};
23use rust_embed::RustEmbed;
24use serde::{Deserialize, Serialize};
25use tokio::sync::{broadcast, Mutex};
26
27use memvid_core::Memvid;
28
29#[derive(RustEmbed)]
31#[folder = "web/dist"]
32struct WebAssets;
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct RetrievalConfig {
37 pub mode: String,
38 pub k: usize,
39 pub adaptive: bool,
40 #[serde(rename = "adaptiveStrategy")]
41 pub adaptive_strategy: Option<String>,
42 #[serde(rename = "minRelevancy")]
43 pub min_relevancy: Option<f32>,
44 #[serde(rename = "maxK")]
45 pub max_k: Option<usize>,
46}
47
48impl Default for RetrievalConfig {
49 fn default() -> Self {
50 Self {
51 mode: "hybrid".to_string(),
52 k: 10,
53 adaptive: false,
54 adaptive_strategy: None,
55 min_relevancy: Some(0.5),
56 max_k: Some(20),
57 }
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct DocumentHit {
64 #[serde(rename = "frameId")]
65 pub frame_id: u64,
66 pub title: String,
67 pub snippet: String,
68 pub score: f64,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct SearchResults {
74 pub hits: Vec<DocumentHit>,
75 #[serde(rename = "totalHits")]
76 pub total_hits: usize,
77 #[serde(rename = "filteredCount")]
78 pub filtered_count: usize,
79 #[serde(rename = "elapsedMs")]
80 pub elapsed_ms: u64,
81 pub engine: String,
82 #[serde(rename = "cliffIndex")]
83 pub cliff_index: Option<usize>,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct QueryRecord {
89 pub id: String,
90 pub timestamp: i64,
91 pub text: String,
92 pub config: RetrievalConfig,
93 pub results: SearchResults,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct SessionData {
99 pub id: String,
100 pub name: Option<String>,
101 #[serde(rename = "createdAt")]
102 pub created_at: i64,
103 #[serde(rename = "endedAt")]
104 pub ended_at: Option<i64>,
105 pub queries: Vec<QueryRecord>,
106 #[serde(rename = "originalConfig")]
107 pub original_config: RetrievalConfig,
108 #[serde(rename = "mv2Path")]
109 pub mv2_path: String,
110}
111
112#[derive(Debug, Deserialize)]
114pub struct ReplayRequest {
115 #[serde(rename = "queryId")]
116 pub query_id: String,
117 pub config: RetrievalConfig,
118}
119
120#[derive(Debug, Deserialize)]
122#[serde(tag = "type")]
123pub enum WsClientMessage {
124 #[serde(rename = "select_query")]
125 SelectQuery { query: String },
126 #[serde(rename = "config_change")]
127 ConfigChange { config: RetrievalConfig },
128 #[serde(rename = "start_optimize")]
129 StartOptimize,
130}
131
132#[derive(Debug, Clone, Serialize)]
134#[serde(tag = "type")]
135pub enum WsServerMessage {
136 #[serde(rename = "results")]
137 Results { data: SearchResults },
138 #[serde(rename = "optimize_progress")]
139 OptimizeProgress { data: OptimizeProgress },
140 #[serde(rename = "optimize_complete")]
141 OptimizeComplete { data: OptimizeResult },
142 #[serde(rename = "error")]
143 Error { message: String },
144}
145
146#[derive(Debug, Clone, Serialize)]
148pub struct OptimizeProgress {
149 pub progress: f32,
150 #[serde(rename = "configsTested")]
151 pub configs_tested: usize,
152 #[serde(rename = "totalConfigs")]
153 pub total_configs: usize,
154}
155
156#[derive(Debug, Clone, Serialize)]
158pub struct OptimizeResult {
159 #[serde(rename = "recommendedConfig")]
160 pub recommended_config: RetrievalConfig,
161 pub score: f32,
162 pub coverage: f32,
163 #[serde(rename = "tokenReduction")]
164 pub token_reduction: f32,
165 pub explanation: String,
166}
167
168pub struct AppState {
170 pub session_id: String,
171 pub mv2_path: PathBuf,
172 pub memvid: Mutex<Memvid>,
173 pub broadcast_tx: broadcast::Sender<WsServerMessage>,
174 pub current_query: Mutex<String>,
175 pub current_config: Mutex<RetrievalConfig>,
176}
177
178pub async fn start_web_server(
180 session_id: String,
181 mv2_path: PathBuf,
182 port: u16,
183 open_browser: bool,
184) -> Result<()> {
185 let memvid = Memvid::open_read_only(&mv2_path)
187 .context("Failed to open memory file")?;
188
189 let (broadcast_tx, _) = broadcast::channel::<WsServerMessage>(100);
191
192 let state = Arc::new(AppState {
193 session_id: session_id.clone(),
194 mv2_path: mv2_path.clone(),
195 memvid: Mutex::new(memvid),
196 broadcast_tx,
197 current_query: Mutex::new(String::new()),
198 current_config: Mutex::new(RetrievalConfig::default()),
199 });
200
201 let app = Router::new()
203 .route("/api/session/:id", get(get_session))
205 .route("/api/session/:id/replay", post(replay_query))
206 .route("/api/session/:id/timeline", get(get_timeline))
207 .route("/ws/session/:id", get(ws_handler))
209 .fallback(static_handler)
211 .with_state(state);
212
213 let addr = SocketAddr::from(([127, 0, 0, 1], port));
214 let url = format!("http://localhost:{}", port);
215
216 println!();
217 println!(" ╭──────────────────────────────────────────────────────╮");
218 println!(" │ │");
219 println!(" │ 🕰️ Memvid Time Machine │");
220 println!(" │ │");
221 println!(" │ Session: {:40} │", &session_id[..std::cmp::min(40, session_id.len())]);
222 println!(" │ URL: {:<40} │", url);
223 println!(" │ │");
224 println!(" │ Press Ctrl+C to stop the server │");
225 println!(" │ │");
226 println!(" ╰──────────────────────────────────────────────────────╯");
227 println!();
228
229 if open_browser {
231 let url_with_session = format!("{}?session={}", url, session_id);
232 if let Err(e) = open::that(&url_with_session) {
233 eprintln!("Failed to open browser: {}", e);
234 println!("Please open {} manually", url_with_session);
235 }
236 }
237
238 let listener = tokio::net::TcpListener::bind(addr).await?;
240 axum::serve(listener, app).await?;
241
242 Ok(())
243}
244
245async fn static_handler(uri: axum::http::Uri) -> impl IntoResponse {
247 let path = uri.path().trim_start_matches('/');
248
249 let path = if path.is_empty() || !path.contains('.') {
251 "index.html"
252 } else {
253 path
254 };
255
256 match WebAssets::get(path) {
257 Some(content) => {
258 let mime = mime_guess::from_path(path)
259 .first_or_octet_stream()
260 .to_string();
261
262 Response::builder()
263 .status(StatusCode::OK)
264 .header("Content-Type", mime)
265 .body(axum::body::Body::from(content.data.into_owned()))
266 .unwrap()
267 }
268 None => {
269 match WebAssets::get("index.html") {
271 Some(content) => Response::builder()
272 .status(StatusCode::OK)
273 .header("Content-Type", "text/html")
274 .body(axum::body::Body::from(content.data.into_owned()))
275 .unwrap(),
276 None => Response::builder()
277 .status(StatusCode::NOT_FOUND)
278 .body(axum::body::Body::from("Not Found"))
279 .unwrap(),
280 }
281 }
282 }
283}
284
285async fn get_session(
287 Path(session_id): Path<String>,
288 State(state): State<Arc<AppState>>,
289) -> Result<Json<SessionData>, (StatusCode, String)> {
290 tracing::info!("get_session called with session_id: {}", session_id);
291 let mut memvid = state.memvid.lock().await;
292
293 tracing::info!("Loading replay sessions...");
295 memvid.load_replay_sessions().map_err(|e| {
296 tracing::error!("Failed to load replay sessions: {}", e);
297 (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to load replay sessions: {}", e))
298 })?;
299
300 let uuid = session_id.parse::<uuid::Uuid>().map_err(|e| {
301 tracing::error!("Invalid session ID '{}': {}", session_id, e);
302 (StatusCode::BAD_REQUEST, format!("Invalid session ID: {}", e))
303 })?;
304 tracing::info!("Looking for session with UUID: {}", uuid);
305
306 let session = memvid.get_session(uuid).ok_or_else(|| {
307 tracing::error!("Session {} not found", uuid);
308 (StatusCode::NOT_FOUND, format!("Session {} not found", uuid))
309 })?;
310 tracing::info!("Found session '{}' with {} actions", session.name.as_deref().unwrap_or("unnamed"), session.actions.len());
311
312 let queries: Vec<QueryRecord> = session
314 .actions
315 .iter()
316 .filter_map(|action| {
317 match &action.action_type {
318 memvid_core::replay::ActionType::Find { query, mode, result_count } => {
320 let result_frames = &action.affected_frames;
321 Some(QueryRecord {
322 id: format!("{}", action.sequence),
323 timestamp: action.timestamp_secs,
324 text: query.clone(),
325 config: RetrievalConfig {
326 mode: mode.clone(),
327 k: *result_count,
328 adaptive: false,
329 adaptive_strategy: None,
330 min_relevancy: Some(0.5),
331 max_k: None,
332 },
333 results: SearchResults {
334 hits: result_frames
335 .iter()
336 .enumerate()
337 .map(|(i, &frame_id)| DocumentHit {
338 frame_id,
339 title: format!("Document {}", frame_id),
340 snippet: "...".to_string(),
341 score: 1.0 - (i as f64 * 0.1),
342 })
343 .collect(),
344 total_hits: result_frames.len() * 2,
345 filtered_count: *result_count,
346 elapsed_ms: 10,
347 engine: mode.clone(),
348 cliff_index: Some(result_frames.len().min(5)),
349 },
350 })
351 }
352 memvid_core::replay::ActionType::Ask { query, provider: _, model: _ } => {
354 let result_frames = &action.affected_frames;
355 Some(QueryRecord {
356 id: format!("{}", action.sequence),
357 timestamp: action.timestamp_secs,
358 text: query.clone(),
359 config: RetrievalConfig {
360 mode: "sem".to_string(), k: 5, adaptive: false,
363 adaptive_strategy: None,
364 min_relevancy: Some(0.5),
365 max_k: None,
366 },
367 results: SearchResults {
368 hits: result_frames
369 .iter()
370 .enumerate()
371 .map(|(i, &frame_id)| DocumentHit {
372 frame_id,
373 title: format!("Document {}", frame_id),
374 snippet: action.output_preview.clone(),
375 score: 1.0 - (i as f64 * 0.1),
376 })
377 .collect(),
378 total_hits: result_frames.len(),
379 filtered_count: result_frames.len(),
380 elapsed_ms: action.duration_ms,
381 engine: "semantic".to_string(),
382 cliff_index: None,
383 },
384 })
385 }
386 _ => None,
387 }
388 })
389 .collect();
390
391 let data = SessionData {
392 id: session.session_id.to_string(),
393 name: session.name.clone(),
394 created_at: session.created_secs,
395 ended_at: session.ended_secs,
396 queries,
397 original_config: RetrievalConfig {
398 mode: "sem".to_string(),
399 k: 5,
400 adaptive: false,
401 adaptive_strategy: None,
402 min_relevancy: Some(0.5),
403 max_k: None,
404 },
405 mv2_path: state.mv2_path.display().to_string(),
406 };
407
408 Ok(Json(data))
409}
410
411async fn get_timeline(
413 Path(session_id): Path<String>,
414 State(state): State<Arc<AppState>>,
415) -> Result<Json<Vec<QueryRecord>>, (StatusCode, String)> {
416 let session_data = get_session(Path(session_id), State(state)).await?;
417 Ok(Json(session_data.queries.clone()))
418}
419
420async fn replay_query(
422 Path(_session_id): Path<String>,
423 State(state): State<Arc<AppState>>,
424 Json(request): Json<ReplayRequest>,
425) -> Result<Json<SearchResults>, (StatusCode, String)> {
426 let mut memvid = state.memvid.lock().await;
427
428 let search_request = memvid_core::types::SearchRequest {
430 query: request.query_id.clone(),
431 top_k: request.config.k,
432 snippet_chars: 240,
433 uri: None,
434 scope: None,
435 cursor: None,
436 temporal: None,
437 as_of_frame: None,
438 as_of_ts: None,
439 };
440
441 let response = memvid
443 .search(search_request)
444 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
445
446 let hit_count = response.hits.len();
448 let results = SearchResults {
449 hits: response
450 .hits
451 .into_iter()
452 .map(|hit| DocumentHit {
453 frame_id: hit.frame_id,
454 title: hit.title.unwrap_or_else(|| format!("Frame {}", hit.frame_id)),
455 snippet: hit.text,
456 score: hit.score.unwrap_or(0.0) as f64,
457 })
458 .collect(),
459 total_hits: response.total_hits,
460 filtered_count: hit_count,
461 elapsed_ms: response.elapsed_ms as u64,
462 engine: request.config.mode.clone(),
463 cliff_index: None,
464 };
465
466 Ok(Json(results))
467}
468
469async fn ws_handler(
471 ws: WebSocketUpgrade,
472 Path(_session_id): Path<String>,
473 State(state): State<Arc<AppState>>,
474) -> impl IntoResponse {
475 ws.on_upgrade(move |socket| handle_websocket(socket, state))
476}
477
478async fn handle_websocket(socket: WebSocket, state: Arc<AppState>) {
480 let (mut sender, mut receiver) = socket.split();
481
482 let mut rx = state.broadcast_tx.subscribe();
484
485 let send_task = tokio::spawn(async move {
487 while let Ok(msg) = rx.recv().await {
488 if let Ok(json) = serde_json::to_string(&msg) {
489 if sender.send(Message::Text(json)).await.is_err() {
490 break;
491 }
492 }
493 }
494 });
495
496 while let Some(Ok(msg)) = receiver.next().await {
498 if let Message::Text(text) = msg {
499 if let Ok(client_msg) = serde_json::from_str::<WsClientMessage>(&text) {
500 match client_msg {
501 WsClientMessage::SelectQuery { query } => {
502 {
504 let mut current = state.current_query.lock().await;
505 *current = query;
506 }
507 let config = state.current_config.lock().await.clone();
509 let results = execute_search_with_config(&state, &config).await;
510 match results {
511 Ok(results) => {
512 let _ = state.broadcast_tx.send(WsServerMessage::Results { data: results });
513 }
514 Err(e) => {
515 tracing::debug!("Initial search after query select failed: {}", e);
517 }
518 }
519 }
520 WsClientMessage::ConfigChange { config } => {
521 {
523 let mut current = state.current_config.lock().await;
524 *current = config.clone();
525 }
526 let results = execute_search_with_config(&state, &config).await;
528
529 match results {
530 Ok(results) => {
531 let _ = state.broadcast_tx.send(WsServerMessage::Results { data: results });
532 }
533 Err(e) => {
534 let _ = state.broadcast_tx.send(WsServerMessage::Error { message: e });
535 }
536 }
537 }
538 WsClientMessage::StartOptimize => {
539 let state_clone = state.clone();
541 tokio::spawn(async move {
542 run_optimization(state_clone).await;
543 });
544 }
545 }
546 }
547 }
548 }
549
550 send_task.abort();
551}
552
553async fn execute_search_with_config(
555 state: &Arc<AppState>,
556 config: &RetrievalConfig,
557) -> Result<SearchResults, String> {
558 let mut memvid = state.memvid.lock().await;
559 let query = state.current_query.lock().await.clone();
560
561 if query.is_empty() {
562 return Err("No query selected. Please select a query from the timeline first.".to_string());
563 }
564
565 let search_request = memvid_core::types::SearchRequest {
566 query,
567 top_k: config.k,
568 snippet_chars: 240,
569 uri: None,
570 scope: None,
571 cursor: None,
572 temporal: None,
573 as_of_frame: None,
574 as_of_ts: None,
575 };
576
577 let response = memvid
578 .search(search_request)
579 .map_err(|e| e.to_string())?;
580
581 let hit_count = response.hits.len();
582 Ok(SearchResults {
583 hits: response
584 .hits
585 .into_iter()
586 .map(|hit| DocumentHit {
587 frame_id: hit.frame_id,
588 title: hit.title.unwrap_or_else(|| format!("Frame {}", hit.frame_id)),
589 snippet: hit.text,
590 score: hit.score.unwrap_or(0.0) as f64,
591 })
592 .collect(),
593 total_hits: response.total_hits,
594 filtered_count: hit_count,
595 elapsed_ms: response.elapsed_ms as u64,
596 engine: config.mode.clone(),
597 cliff_index: None,
598 })
599}
600
601async fn run_optimization(state: Arc<AppState>) {
603 let total_configs = 150;
604
605 for i in 0..=total_configs {
606 let _ = state.broadcast_tx.send(WsServerMessage::OptimizeProgress {
608 data: OptimizeProgress {
609 progress: (i as f32 / total_configs as f32) * 100.0,
610 configs_tested: i,
611 total_configs,
612 },
613 });
614
615 tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
617 }
618
619 let _ = state.broadcast_tx.send(WsServerMessage::OptimizeComplete {
621 data: OptimizeResult {
622 recommended_config: RetrievalConfig {
623 mode: "sem".to_string(),
624 k: 12,
625 adaptive: true,
626 adaptive_strategy: Some("combined".to_string()),
627 min_relevancy: Some(0.5),
628 max_k: Some(100),
629 },
630 score: 0.942,
631 coverage: 0.942,
632 token_reduction: 0.47,
633 explanation: "This configuration retrieves 94.2% of relevant documents while filtering 47% of noise.".to_string(),
634 },
635 });
636}