Skip to main content

mii_memory/
explorer.rs

1use std::collections::HashMap;
2use std::io::{BufRead, BufReader, Write};
3use std::net::{TcpListener, TcpStream};
4use std::path::{Path, PathBuf};
5use std::str::FromStr;
6use std::sync::Arc;
7use std::thread;
8use std::time::Duration;
9
10use anyhow::{Context, Result, bail};
11use serde::Serialize;
12
13use crate::model::MemoryMode;
14use crate::store::{BrowseOptions, MemoryStore, StoreSignature};
15
16const INDEX_HTML: &str = include_str!("explorer/index.html");
17const POLL_INTERVAL: Duration = Duration::from_millis(750);
18const MAX_REQUEST_BYTES: usize = 16 * 1024;
19const READ_TIMEOUT: Duration = Duration::from_secs(15);
20
21pub fn serve(database_path: PathBuf, host: &str, port: u16) -> Result<()> {
22    let bind = format!("{host}:{port}");
23    let listener =
24        TcpListener::bind(&bind).with_context(|| format!("failed to bind explorer to {bind}"))?;
25    let address = listener.local_addr()?;
26    eprintln!("mii-memory explorer listening on http://{address}");
27    serve_with_listener(listener, database_path)
28}
29
30pub fn serve_with_listener(listener: TcpListener, database_path: PathBuf) -> Result<()> {
31    let database_path = Arc::new(database_path);
32    for stream in listener.incoming() {
33        let stream = match stream {
34            Ok(stream) => stream,
35            Err(error) => {
36                eprintln!("explorer accept failed: {error}");
37                continue;
38            }
39        };
40
41        let database_path = Arc::clone(&database_path);
42        thread::spawn(move || {
43            if let Err(error) = handle_connection(stream, &database_path) {
44                eprintln!("explorer connection error: {error}");
45            }
46        });
47    }
48
49    Ok(())
50}
51
/// The parts of an incoming HTTP request that the explorer routes on.
#[derive(Debug)]
struct Request {
    /// First token of the request line, e.g. `GET`.
    method: String,
    /// Request target with any `?query` suffix removed.
    path: String,
    /// Decoded query parameters; repeated keys keep all values in arrival order.
    query: HashMap<String, Vec<String>>,
}
58
59fn handle_connection(mut stream: TcpStream, database_path: &Path) -> Result<()> {
60    stream.set_read_timeout(Some(READ_TIMEOUT))?;
61    stream.set_write_timeout(Some(READ_TIMEOUT))?;
62
63    let request = match read_request(&mut stream) {
64        Ok(request) => request,
65        Err(error) => {
66            write_response(
67                &mut stream,
68                400,
69                "text/plain; charset=utf-8",
70                error.to_string().as_bytes(),
71            )?;
72            return Ok(());
73        }
74    };
75
76    match (request.method.as_str(), request.path.as_str()) {
77        ("GET", "/") | ("GET", "/index.html") => write_response(
78            &mut stream,
79            200,
80            "text/html; charset=utf-8",
81            INDEX_HTML.as_bytes(),
82        ),
83        ("GET", "/api/memories") => serve_memories(&mut stream, database_path, &request.query),
84        ("GET", "/api/tags") => serve_tags(&mut stream, database_path, &request.query),
85        ("GET", "/api/events") => serve_events(stream, database_path),
86        ("GET", _) => write_response(&mut stream, 404, "text/plain; charset=utf-8", b"not found"),
87        _ => write_response(
88            &mut stream,
89            405,
90            "text/plain; charset=utf-8",
91            b"method not allowed",
92        ),
93    }
94}
95
96fn read_request(stream: &mut TcpStream) -> Result<Request> {
97    let mut reader = BufReader::new(stream);
98    let mut request_line = String::new();
99    let read = reader.read_line(&mut request_line)?;
100    if read == 0 {
101        bail!("empty request");
102    }
103
104    let mut parts = request_line.split_whitespace();
105    let method = parts.next().context("missing HTTP method")?.to_string();
106    let target = parts.next().context("missing HTTP target")?.to_string();
107
108    let mut total = read;
109    loop {
110        let mut line = String::new();
111        let read = reader.read_line(&mut line)?;
112        total += read;
113        if total > MAX_REQUEST_BYTES {
114            bail!("request headers too large");
115        }
116        if read == 0 || line == "\r\n" || line == "\n" {
117            break;
118        }
119    }
120
121    let (path, query) = parse_target(&target);
122    Ok(Request {
123        method,
124        path,
125        query,
126    })
127}
128
129fn parse_target(target: &str) -> (String, HashMap<String, Vec<String>>) {
130    let mut parts = target.splitn(2, '?');
131    let path = parts.next().unwrap_or("/").to_string();
132    let mut query: HashMap<String, Vec<String>> = HashMap::new();
133
134    if let Some(raw_query) = parts.next() {
135        for pair in raw_query.split('&').filter(|pair| !pair.is_empty()) {
136            let mut kv = pair.splitn(2, '=');
137            let key = decode_form(kv.next().unwrap_or(""));
138            let value = decode_form(kv.next().unwrap_or(""));
139            query.entry(key).or_default().push(value);
140        }
141    }
142
143    (path, query)
144}
145
/// Decodes a form-encoded query component: `+` becomes a space and each
/// `%XX` escape becomes the byte it names.
///
/// A `%` that is not followed by two ASCII hex digits is kept literally, and
/// the characters after it are preserved. The decoded bytes are interpreted as
/// UTF-8, with invalid sequences replaced by U+FFFD.
fn decode_form(value: &str) -> String {
    /// Value of a single ASCII hex digit. `u8::from_str_radix` is not used
    /// because it accepts a leading `+` (so `%+5` would decode to 0x05).
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = value.as_bytes();
    let mut output = Vec::with_capacity(bytes.len());
    let mut index = 0;
    while let Some(&byte) = bytes.get(index) {
        match byte {
            b'+' => {
                output.push(b' ');
                index += 1;
            }
            b'%' => {
                let high = bytes.get(index + 1).copied().and_then(hex_value);
                let low = bytes.get(index + 2).copied().and_then(hex_value);
                if let Some((high, low)) = high.zip(low) {
                    output.push((high << 4) | low);
                    index += 3;
                } else {
                    // Not a valid escape: keep the '%' and continue with the
                    // next byte, so the following characters are not lost.
                    output.push(b'%');
                    index += 1;
                }
            }
            other => {
                output.push(other);
                index += 1;
            }
        }
    }

    String::from_utf8_lossy(&output).into_owned()
}
173
174#[derive(Debug, Serialize)]
175struct MemoriesResponse {
176    memories: Vec<crate::store::MemoryEntry>,
177    signature: StoreSignature,
178}
179
180#[derive(Debug, Serialize)]
181struct TagsResponse {
182    tags: Vec<crate::store::TagSummary>,
183    signature: StoreSignature,
184}
185
186fn serve_memories(
187    stream: &mut TcpStream,
188    database_path: &Path,
189    query: &HashMap<String, Vec<String>>,
190) -> Result<()> {
191    let store = MemoryStore::open(database_path)?;
192    let text = query.get("text").and_then(|values| values.first().cloned());
193    let mode = query
194        .get("mode")
195        .and_then(|values| values.first())
196        .map(|value| MemoryMode::from_str(value))
197        .transpose()
198        .ok()
199        .flatten();
200    let tags = query.get("tag").cloned().unwrap_or_default();
201    let limit = query
202        .get("limit")
203        .and_then(|values| values.first())
204        .and_then(|value| value.parse::<usize>().ok())
205        .unwrap_or(100);
206    let offset = query
207        .get("offset")
208        .and_then(|values| values.first())
209        .and_then(|value| value.parse::<usize>().ok())
210        .unwrap_or(0);
211
212    let memories = store.browse(BrowseOptions {
213        text,
214        tags,
215        mode,
216        limit,
217        offset,
218    })?;
219    let signature = store.signature()?;
220    let response = MemoriesResponse {
221        memories,
222        signature,
223    };
224    let body = serde_json::to_vec(&response)?;
225    write_response(stream, 200, "application/json", &body)
226}
227
228fn serve_tags(
229    stream: &mut TcpStream,
230    database_path: &Path,
231    query: &HashMap<String, Vec<String>>,
232) -> Result<()> {
233    let store = MemoryStore::open(database_path)?;
234    let filter = query
235        .get("filter")
236        .and_then(|values| values.first())
237        .cloned();
238    let tags = store.list_tags(filter.as_deref())?;
239    let signature = store.signature()?;
240    let response = TagsResponse { tags, signature };
241    let body = serde_json::to_vec(&response)?;
242    write_response(stream, 200, "application/json", &body)
243}
244
245fn serve_events(mut stream: TcpStream, database_path: &Path) -> Result<()> {
246    let headers = "HTTP/1.1 200 OK\r\n\
247        Content-Type: text/event-stream\r\n\
248        Cache-Control: no-cache\r\n\
249        Connection: keep-alive\r\n\
250        X-Accel-Buffering: no\r\n\r\n";
251    stream.write_all(headers.as_bytes())?;
252    stream.write_all(b"event: ready\ndata: {}\n\n")?;
253    stream.flush()?;
254
255    let store = MemoryStore::open(database_path)?;
256    let mut signature = store.signature().ok();
257
258    loop {
259        thread::sleep(POLL_INTERVAL);
260        let current = match store.signature() {
261            Ok(value) => Some(value),
262            Err(_) => continue,
263        };
264        if current != signature {
265            signature = current.clone();
266            let payload = serde_json::to_string(&current)?;
267            if stream
268                .write_all(format!("event: update\ndata: {payload}\n\n").as_bytes())
269                .is_err()
270            {
271                break;
272            }
273            if stream.flush().is_err() {
274                break;
275            }
276        } else if stream.write_all(b": keep-alive\n\n").is_err() || stream.flush().is_err() {
277            break;
278        }
279    }
280
281    Ok(())
282}
283
284fn write_response(
285    stream: &mut TcpStream,
286    status: u16,
287    content_type: &str,
288    body: &[u8],
289) -> Result<()> {
290    let reason = match status {
291        200 => "OK",
292        400 => "Bad Request",
293        404 => "Not Found",
294        405 => "Method Not Allowed",
295        _ => "OK",
296    };
297    let header = format!(
298        "HTTP/1.1 {status} {reason}\r\n\
299         Content-Type: {content_type}\r\n\
300         Content-Length: {len}\r\n\
301         Cache-Control: no-store\r\n\
302         Connection: close\r\n\r\n",
303        len = body.len()
304    );
305    stream.write_all(header.as_bytes())?;
306    stream.write_all(body)?;
307    stream.flush()?;
308    Ok(())
309}
310
#[cfg(all(test, has_embedded_embeddings))]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpStream;
    use std::time::Instant;

    use crate::store::SetMemory;
    use tempfile::tempdir;

    /// Reads from `stream` until EOF, a read error, or about five seconds.
    /// It then splits the bytes into `(status, head, body)`; `status` is 0 when
    /// the status line cannot be parsed.
    fn read_http_response(stream: &mut TcpStream) -> (u16, String, Vec<u8>) {
        let mut buffer = Vec::new();
        let mut chunk = [0_u8; 4096];
        let started = Instant::now();
        loop {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(read) => buffer.extend_from_slice(&chunk[..read]),
                Err(_) => break,
            }
            if started.elapsed() > Duration::from_secs(5) {
                break;
            }
        }
        let text = String::from_utf8_lossy(&buffer).into_owned();
        let mut split = text.splitn(2, "\r\n\r\n");
        let head = split.next().unwrap_or("").to_string();
        let body = split.next().unwrap_or("").as_bytes().to_vec();
        let status = head
            .lines()
            .next()
            .and_then(|line| line.split_whitespace().nth(1))
            .and_then(|code| code.parse::<u16>().ok())
            .unwrap_or(0);
        (status, head, body)
    }

    /// Starts the explorer on an ephemeral localhost port in a background thread.
    fn spawn_explorer(database_path: PathBuf) -> std::net::SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind explorer");
        let address = listener.local_addr().expect("local address");
        thread::spawn(move || {
            let _ = serve_with_listener(listener, database_path);
        });
        address
    }

    /// Sends a one-shot `GET path` with `Connection: close` and reads the reply.
    fn http_get(address: std::net::SocketAddr, path: &str) -> (u16, String, Vec<u8>) {
        let mut stream = TcpStream::connect(address).expect("connect");
        stream
            .set_read_timeout(Some(Duration::from_secs(5)))
            .expect("set timeout");
        let request =
            format!("GET {path} HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n");
        stream.write_all(request.as_bytes()).expect("write");
        read_http_response(&mut stream)
    }

    /// Appends reads from `stream` to `buffer` until it contains `needle`.
    ///
    /// Returns `false` if `timeout` elapses, the peer closes the connection,
    /// or a read fails before `needle` appears. Callers must assert on the
    /// result; otherwise a dropped stream would let a test pass silently.
    fn read_until(
        stream: &mut TcpStream,
        buffer: &mut Vec<u8>,
        needle: &str,
        timeout: Duration,
    ) -> bool {
        let deadline = Instant::now() + timeout;
        let mut chunk = [0_u8; 1024];
        while !String::from_utf8_lossy(buffer).contains(needle) {
            if Instant::now() > deadline {
                return false;
            }
            match stream.read(&mut chunk) {
                Ok(0) | Err(_) => return false,
                Ok(read) => buffer.extend_from_slice(&chunk[..read]),
            }
        }
        true
    }

    #[test]
    fn explorer_serves_index_and_api() -> Result<()> {
        let directory = tempdir()?;
        let database_path = directory.path().join("explorer.db");
        {
            let mut store = MemoryStore::open(&database_path)?;
            store.set(SetMemory {
                content: "Explorer ready".to_string(),
                mode: MemoryMode::Global,
                mode_ref: None,
                tags: vec!["explorer".to_string()],
                expiration_condition: None,
                expiration_value: None,
                metadata: Some("{\"note\":\"hi\"}".to_string()),
            })?;
        }

        let address = spawn_explorer(database_path.clone());

        let (status, _, body) = http_get(address, "/");
        assert_eq!(status, 200);
        assert!(String::from_utf8_lossy(&body).contains("mii-memory explorer"));

        let (status, _, body) = http_get(address, "/api/memories");
        assert_eq!(status, 200);
        let value: serde_json::Value = serde_json::from_slice(&body)?;
        assert_eq!(value["memories"][0]["content"], "Explorer ready");
        assert_eq!(value["memories"][0]["tags"][0], "explorer");

        let (status, _, body) = http_get(address, "/api/tags");
        assert_eq!(status, 200);
        let value: serde_json::Value = serde_json::from_slice(&body)?;
        assert_eq!(value["tags"][0]["tag"], "explorer");

        let (status, _, body) = http_get(address, "/api/memories?tag=missing");
        assert_eq!(status, 200);
        let value: serde_json::Value = serde_json::from_slice(&body)?;
        assert_eq!(value["memories"].as_array().unwrap().len(), 0);

        let (status, _, _) = http_get(address, "/nope");
        assert_eq!(status, 404);

        Ok(())
    }

    #[test]
    fn explorer_events_emit_updates_when_memories_change() -> Result<()> {
        let directory = tempdir()?;
        let database_path = directory.path().join("events.db");
        {
            let _ = MemoryStore::open(&database_path)?;
        }

        let address = spawn_explorer(database_path.clone());

        let mut stream = TcpStream::connect(address).expect("connect events");
        stream.set_read_timeout(Some(Duration::from_secs(5)))?;
        stream.write_all(
            b"GET /api/events HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: keep-alive\r\n\r\n",
        )?;

        let mut buffer = Vec::new();
        assert!(
            read_until(&mut stream, &mut buffer, "event: ready", Duration::from_secs(5)),
            "did not receive ready event; got: {}",
            String::from_utf8_lossy(&buffer)
        );

        // Mutate the DB through a separate store handle, as another process would.
        {
            let mut store = MemoryStore::open(&database_path)?;
            store.set(SetMemory {
                content: "live update".to_string(),
                mode: MemoryMode::Global,
                mode_ref: None,
                tags: vec!["live".to_string()],
                expiration_condition: None,
                expiration_value: None,
                metadata: None,
            })?;
        }

        assert!(
            read_until(&mut stream, &mut buffer, "event: update", Duration::from_secs(5)),
            "did not receive update event; got: {}",
            String::from_utf8_lossy(&buffer)
        );

        Ok(())
    }
}