Skip to main content

nostr_bbs_search_worker/
lib.rs

//! nostr-bbs Search Worker (Rust)
//!
//! Cloudflare Workers-based vector search with:
//! - In-memory cosine k-NN over 384-dim embeddings
//! - RVF binary format persistence to R2
//! - id↔label mapping in KV
//! - NIP-98 authenticated ingest
//! - Hash-based fallback embedding generation
//!
//! ## Architecture
//!
//! - `lib.rs`   -- HTTP router, CORS, entry point
//! - `store.rs` -- In-memory vector store, RVF serialization
//! - `embed.rs` -- Hash-based embedding generator
//! - `auth.rs`  -- NIP-98 admin verification
16
// Worker entry points are invoked via wasm-bindgen and appear unused in native builds.
18#![allow(dead_code)]
19
20mod auth;
21mod embed;
22mod store;
23
24use embed::DIM;
25use serde::Deserialize;
26use store::VectorStore;
27use worker::*;
28
29// ---------------------------------------------------------------------------
30// CORS
31// ---------------------------------------------------------------------------
32
33/// Build allowed origins list from `ALLOWED_ORIGINS` env var (comma-separated)
34/// or fall back to the production domain.
35fn allowed_origins(env: &Env) -> Vec<String> {
36    env.var("ALLOWED_ORIGINS")
37        .map(|v| v.to_string())
38        .unwrap_or_else(|_| "https://example.com".to_string())
39        .split(',')
40        .map(|s| s.trim().to_string())
41        .collect()
42}
43
44fn cors_origin(req: &Request, env: &Env) -> String {
45    let origins = allowed_origins(env);
46    let origin = req
47        .headers()
48        .get("Origin")
49        .ok()
50        .flatten()
51        .unwrap_or_default();
52    if origins.iter().any(|o| o == &origin) {
53        origin
54    } else {
55        origins
56            .into_iter()
57            .next()
58            .unwrap_or_else(|| "https://example.com".to_string())
59    }
60}
61
62fn cors_headers(req: &Request, env: &Env) -> Headers {
63    let headers = Headers::new();
64    headers
65        .set("Access-Control-Allow-Origin", &cors_origin(req, env))
66        .ok();
67    headers
68        .set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
69        .ok();
70    headers
71        .set(
72            "Access-Control-Allow-Headers",
73            "Content-Type, Authorization",
74        )
75        .ok();
76    headers.set("Access-Control-Max-Age", "86400").ok();
77    headers.set("Vary", "Origin").ok();
78    headers
79}
80
81fn json_response(
82    req: &Request,
83    env: &Env,
84    body: &serde_json::Value,
85    status: u16,
86) -> Result<Response> {
87    let json_str = serde_json::to_string(body).map_err(|e| Error::RustError(e.to_string()))?;
88    let headers = cors_headers(req, env);
89    headers.set("Content-Type", "application/json").ok();
90    Ok(Response::ok(json_str)?
91        .with_status(status)
92        .with_headers(headers))
93}
94
95// ---------------------------------------------------------------------------
96// Request/Response types
97// ---------------------------------------------------------------------------
98
99#[derive(Deserialize)]
100struct SearchRequest {
101    embedding: Vec<f32>,
102    #[serde(default = "default_k")]
103    k: usize,
104    #[serde(default, rename = "minScore")]
105    min_score: f32,
106}
107
108fn default_k() -> usize {
109    10
110}
111
112#[derive(Deserialize)]
113struct IngestEntry {
114    id: String,
115    embedding: Vec<f32>,
116}
117
118#[derive(Deserialize)]
119struct IngestRequest {
120    entries: Vec<IngestEntry>,
121}
122
123#[derive(Deserialize)]
124struct EmbedRequest {
125    text: Option<String>,
126    texts: Option<Vec<String>>,
127}
128
129// ---------------------------------------------------------------------------
130// Store lifecycle (R2 + KV)
131// ---------------------------------------------------------------------------
132
133/// Load the vector store from R2, or create empty if none exists.
134async fn load_store(env: &Env) -> Result<VectorStore> {
135    let store_key = env
136        .var("RVF_STORE_KEY")
137        .map(|v| v.to_string())
138        .unwrap_or_else(|_| "nostr-bbs.rvf".to_string());
139
140    let bucket = env.bucket("VECTORS")?;
141    let obj = bucket.get(&store_key).execute().await?;
142
143    if let Some(obj) = obj {
144        // Sprint v9 D5: never panic on a missing body. R2 can in principle
145        // return an object with no body (e.g. zero-length write race or
146        // bucket inconsistency); surface a typed worker::Error instead so
147        // the caller returns a 5xx rather than crashing the isolate.
148        let body = obj
149            .body()
150            .ok_or_else(|| worker::Error::RustError("R2 object missing body".into()))?;
151        let bytes = body.bytes().await?;
152        if let Some(store) = VectorStore::from_rvf_bytes(&bytes) {
153            return Ok(store);
154        }
155    }
156
157    Ok(VectorStore::new())
158}
159
160/// Persist the vector store to R2 as RVF binary + mapping to KV.
161async fn persist_store(
162    store: &VectorStore,
163    id_to_label: &std::collections::HashMap<String, u64>,
164    next_label: u64,
165    env: &Env,
166) -> Result<()> {
167    let store_key = env
168        .var("RVF_STORE_KEY")
169        .map(|v| v.to_string())
170        .unwrap_or_else(|_| "nostr-bbs.rvf".to_string());
171
172    // Persist RVF bytes to R2
173    let rvf_bytes = store.to_rvf_bytes();
174    let bucket = env.bucket("VECTORS")?;
175    bucket.put(&store_key, rvf_bytes).execute().await?;
176
177    // Persist id↔label mapping to KV
178    let kv = env.kv("SEARCH_CONFIG")?;
179    let pairs: Vec<(&str, u64)> = id_to_label.iter().map(|(k, v)| (k.as_str(), *v)).collect();
180    let mapping = serde_json::json!({
181        "pairs": pairs,
182        "next": next_label,
183    });
184    kv.put(
185        &format!("{store_key}:mapping"),
186        serde_json::to_string(&mapping).map_err(|e| Error::RustError(e.to_string()))?,
187    )?
188    .execute()
189    .await?;
190
191    Ok(())
192}
193
194/// Load id↔label mapping from KV.
195async fn load_mapping(
196    env: &Env,
197) -> Result<(
198    std::collections::HashMap<String, u64>,
199    std::collections::HashMap<u64, String>,
200    u64,
201)> {
202    let store_key = env
203        .var("RVF_STORE_KEY")
204        .map(|v| v.to_string())
205        .unwrap_or_else(|_| "nostr-bbs.rvf".to_string());
206
207    let kv = env.kv("SEARCH_CONFIG")?;
208    let mapping_key = format!("{store_key}:mapping");
209
210    if let Some(json_str) = kv.get(&mapping_key).text().await? {
211        if let Ok(val) = serde_json::from_str::<serde_json::Value>(&json_str) {
212            let next = val["next"].as_u64().unwrap_or(1);
213            let mut id_to_label = std::collections::HashMap::new();
214            let mut label_to_id = std::collections::HashMap::new();
215
216            if let Some(pairs) = val["pairs"].as_array() {
217                for pair in pairs {
218                    if let (Some(id), Some(label)) = (pair[0].as_str(), pair[1].as_u64()) {
219                        id_to_label.insert(id.to_string(), label);
220                        label_to_id.insert(label, id.to_string());
221                    }
222                }
223            }
224
225            return Ok((id_to_label, label_to_id, next));
226        }
227    }
228
229    Ok((
230        std::collections::HashMap::new(),
231        std::collections::HashMap::new(),
232        1,
233    ))
234}
235
236// ---------------------------------------------------------------------------
237// Handlers
238// ---------------------------------------------------------------------------
239
240async fn handle_search(req: &Request, env: &Env) -> Result<Response> {
241    let mut req_clone = req.clone()?;
242    let body: SearchRequest = req_clone.json().await?;
243
244    if body.embedding.len() != DIM {
245        return json_response(
246            req,
247            env,
248            &serde_json::json!({ "error": format!("Expected {DIM}-dim embedding") }),
249            400,
250        );
251    }
252
253    let store = load_store(env).await?;
254    let k = body.k.clamp(1, 100);
255
256    if store.count() == 0 {
257        return json_response(
258            req,
259            env,
260            &serde_json::json!({ "results": [], "totalVectors": 0 }),
261            200,
262        );
263    }
264
265    let (_, label_to_id, _) = load_mapping(env).await?;
266    let results = store.search(&body.embedding, k, body.min_score);
267
268    let results_json: Vec<serde_json::Value> = results
269        .iter()
270        .map(|(label, score)| {
271            let id = label_to_id
272                .get(label)
273                .cloned()
274                .unwrap_or_else(|| label.to_string());
275            serde_json::json!({
276                "id": id,
277                "distance": 1.0 - score,
278                "score": score,
279            })
280        })
281        .collect();
282
283    json_response(
284        req,
285        env,
286        &serde_json::json!({
287            "results": results_json,
288            "totalVectors": store.count(),
289            "engine": "rvf-rust",
290            "dimensions": DIM,
291        }),
292        200,
293    )
294}
295
296async fn handle_embed(req: &Request, env: &Env) -> Result<Response> {
297    let mut req_clone = req.clone()?;
298    let body: EmbedRequest = req_clone.json().await?;
299
300    let texts: Vec<String> = match (body.texts, body.text) {
301        (Some(texts), _) => texts,
302        (None, Some(text)) => vec![text],
303        (None, None) => {
304            return json_response(
305                req,
306                env,
307                &serde_json::json!({ "error": "Missing text or texts field" }),
308                400,
309            );
310        }
311    };
312
313    if texts.is_empty() {
314        return json_response(
315            req,
316            env,
317            &serde_json::json!({ "error": "Missing text or texts field" }),
318            400,
319        );
320    }
321    if texts.len() > 100 {
322        return json_response(
323            req,
324            env,
325            &serde_json::json!({ "error": "Maximum 100 texts per request" }),
326            400,
327        );
328    }
329
330    let embeddings: Vec<Vec<f32>> = texts.iter().map(|t| embed::generate_embedding(t)).collect();
331
332    json_response(
333        req,
334        env,
335        &serde_json::json!({
336            "embeddings": embeddings,
337            "dimensions": DIM,
338            "model": "hash-fallback-v1",
339            "note": "Hash-based fallback embedding. Replace with ONNX WASM model for semantic quality.",
340        }),
341        200,
342    )
343}
344
345async fn handle_ingest(req: &Request, env: &Env) -> Result<Response> {
346    // NIP-98 admin auth
347    let url = req.url()?;
348    let request_url = format!("{}{}", url.origin().ascii_serialization(), url.path());
349    let auth_header = req.headers().get("Authorization")?;
350    let mut req_clone = req.clone()?;
351    let raw_body = req_clone.bytes().await?;
352
353    if let Err((err_body, status)) = auth::require_nip98_admin(
354        auth_header.as_deref(),
355        &request_url,
356        "POST",
357        Some(&raw_body),
358        env,
359    )
360    .await
361    {
362        return json_response(req, env, &err_body, status);
363    }
364
365    let body: IngestRequest =
366        serde_json::from_slice(&raw_body).map_err(|e| Error::RustError(e.to_string()))?;
367
368    if body.entries.is_empty() {
369        return json_response(
370            req,
371            env,
372            &serde_json::json!({ "error": "Missing entries array" }),
373            400,
374        );
375    }
376
377    let mut store = load_store(env).await?;
378    let (mut id_to_label, _, mut next_label) = load_mapping(env).await?;
379
380    let mut accepted = 0u32;
381    let mut rejected = 0u32;
382
383    for entry in &body.entries {
384        if entry.id.is_empty() || entry.embedding.len() != DIM {
385            rejected += 1;
386            continue;
387        }
388
389        let label = *id_to_label.entry(entry.id.clone()).or_insert_with(|| {
390            let l = next_label;
391            next_label += 1;
392            l
393        });
394
395        store.insert(label, &entry.embedding);
396        accepted += 1;
397    }
398
399    // Persist to R2 + KV
400    persist_store(&store, &id_to_label, next_label, env).await?;
401
402    json_response(
403        req,
404        env,
405        &serde_json::json!({
406            "accepted": accepted,
407            "rejected": rejected,
408            "totalVectors": store.count(),
409            "engine": "rvf-rust",
410        }),
411        200,
412    )
413}
414
415async fn handle_status(req: &Request, env: &Env) -> Result<Response> {
416    let store = load_store(env).await?;
417
418    json_response(
419        req,
420        env,
421        &serde_json::json!({
422            "status": "healthy",
423            "totalVectors": store.count(),
424            "dimensions": DIM,
425            "metric": "cosine",
426            "model": "all-MiniLM-L6-v2",
427            "engine": "rvf-rust",
428            "runtime": "workers-rs",
429            "format": "rvf-v1",
430        }),
431        200,
432    )
433}
434
435// ---------------------------------------------------------------------------
436// Entry point
437// ---------------------------------------------------------------------------
438
439#[event(fetch)]
440async fn fetch(req: Request, env: Env, _ctx: Context) -> Result<Response> {
441    nostr_bbs_rate_limit::ensure_replay_schema(&env, "REPLAY_DB").await;
442
443    // CORS preflight
444    if req.method() == Method::Options {
445        return Ok(Response::empty()?
446            .with_status(204)
447            .with_headers(cors_headers(&req, &env)));
448    }
449
450    // Rate limit: 100 requests per 60 seconds per IP
451    let ip = nostr_bbs_rate_limit::client_ip(&req);
452    if !nostr_bbs_rate_limit::check_rate_limit(&env, "SEARCH_CONFIG", &ip, 100, 60).await {
453        return json_response(
454            &req,
455            &env,
456            &serde_json::json!({ "error": "Too many requests" }),
457            429,
458        );
459    }
460
461    let url = req.url()?;
462    let path = url.path();
463
464    let result = route(&req, &env, path).await;
465    match result {
466        Ok(resp) => Ok(resp),
467        Err(e) => {
468            console_error!("Search worker error: {e}");
469            let msg = e.to_string();
470            if msg.contains("JSON") || msg.contains("json") || msg.contains("Syntax") {
471                json_response(
472                    &req,
473                    &env,
474                    &serde_json::json!({ "error": "Invalid JSON body" }),
475                    400,
476                )
477            } else {
478                json_response(
479                    &req,
480                    &env,
481                    &serde_json::json!({ "error": "Internal error" }),
482                    500,
483                )
484            }
485        }
486    }
487}
488
489async fn route(req: &Request, env: &Env, path: &str) -> Result<Response> {
490    let method = req.method();
491
492    // Health / status
493    if (path == "/health" || path == "/status" || path == "/") && method == Method::Get {
494        return handle_status(req, env).await;
495    }
496
497    // Search
498    if path == "/search" && method == Method::Post {
499        return handle_search(req, env).await;
500    }
501
502    // Embed
503    if path == "/embed" && method == Method::Post {
504        return handle_embed(req, env).await;
505    }
506
507    // Ingest (NIP-98 admin only)
508    if path == "/ingest" && method == Method::Post {
509        return handle_ingest(req, env).await;
510    }
511
512    json_response(req, env, &serde_json::json!({ "error": "Not found" }), 404)
513}
514
515// ---------------------------------------------------------------------------
516// Cron keep-warm
517// ---------------------------------------------------------------------------
518
519#[event(scheduled)]
520async fn scheduled(_event: ScheduledEvent, env: Env, _ctx: ScheduleContext) {
521    // Touch R2 to keep the connection warm
522    let _ = load_store(&env).await;
523}