1#![allow(dead_code)]
19
20mod auth;
21mod embed;
22mod store;
23
24use embed::DIM;
25use serde::Deserialize;
26use store::VectorStore;
27use worker::*;
28
29fn allowed_origins(env: &Env) -> Vec<String> {
36 env.var("ALLOWED_ORIGINS")
37 .map(|v| v.to_string())
38 .unwrap_or_else(|_| "https://example.com".to_string())
39 .split(',')
40 .map(|s| s.trim().to_string())
41 .collect()
42}
43
44fn cors_origin(req: &Request, env: &Env) -> String {
45 let origins = allowed_origins(env);
46 let origin = req
47 .headers()
48 .get("Origin")
49 .ok()
50 .flatten()
51 .unwrap_or_default();
52 if origins.iter().any(|o| o == &origin) {
53 origin
54 } else {
55 origins
56 .into_iter()
57 .next()
58 .unwrap_or_else(|| "https://example.com".to_string())
59 }
60}
61
62fn cors_headers(req: &Request, env: &Env) -> Headers {
63 let headers = Headers::new();
64 headers
65 .set("Access-Control-Allow-Origin", &cors_origin(req, env))
66 .ok();
67 headers
68 .set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
69 .ok();
70 headers
71 .set(
72 "Access-Control-Allow-Headers",
73 "Content-Type, Authorization",
74 )
75 .ok();
76 headers.set("Access-Control-Max-Age", "86400").ok();
77 headers.set("Vary", "Origin").ok();
78 headers
79}
80
81fn json_response(
82 req: &Request,
83 env: &Env,
84 body: &serde_json::Value,
85 status: u16,
86) -> Result<Response> {
87 let json_str = serde_json::to_string(body).map_err(|e| Error::RustError(e.to_string()))?;
88 let headers = cors_headers(req, env);
89 headers.set("Content-Type", "application/json").ok();
90 Ok(Response::ok(json_str)?
91 .with_status(status)
92 .with_headers(headers))
93}
94
95#[derive(Deserialize)]
100struct SearchRequest {
101 embedding: Vec<f32>,
102 #[serde(default = "default_k")]
103 k: usize,
104 #[serde(default, rename = "minScore")]
105 min_score: f32,
106}
107
/// Default number of search results when a request omits `k`.
fn default_k() -> usize {
    const DEFAULT_RESULT_COUNT: usize = 10;
    DEFAULT_RESULT_COUNT
}
111
112#[derive(Deserialize)]
113struct IngestEntry {
114 id: String,
115 embedding: Vec<f32>,
116}
117
118#[derive(Deserialize)]
119struct IngestRequest {
120 entries: Vec<IngestEntry>,
121}
122
123#[derive(Deserialize)]
124struct EmbedRequest {
125 text: Option<String>,
126 texts: Option<Vec<String>>,
127}
128
129async fn load_store(env: &Env) -> Result<VectorStore> {
135 let store_key = env
136 .var("RVF_STORE_KEY")
137 .map(|v| v.to_string())
138 .unwrap_or_else(|_| "nostr-bbs.rvf".to_string());
139
140 let bucket = env.bucket("VECTORS")?;
141 let obj = bucket.get(&store_key).execute().await?;
142
143 if let Some(obj) = obj {
144 let body = obj
149 .body()
150 .ok_or_else(|| worker::Error::RustError("R2 object missing body".into()))?;
151 let bytes = body.bytes().await?;
152 if let Some(store) = VectorStore::from_rvf_bytes(&bytes) {
153 return Ok(store);
154 }
155 }
156
157 Ok(VectorStore::new())
158}
159
160async fn persist_store(
162 store: &VectorStore,
163 id_to_label: &std::collections::HashMap<String, u64>,
164 next_label: u64,
165 env: &Env,
166) -> Result<()> {
167 let store_key = env
168 .var("RVF_STORE_KEY")
169 .map(|v| v.to_string())
170 .unwrap_or_else(|_| "nostr-bbs.rvf".to_string());
171
172 let rvf_bytes = store.to_rvf_bytes();
174 let bucket = env.bucket("VECTORS")?;
175 bucket.put(&store_key, rvf_bytes).execute().await?;
176
177 let kv = env.kv("SEARCH_CONFIG")?;
179 let pairs: Vec<(&str, u64)> = id_to_label.iter().map(|(k, v)| (k.as_str(), *v)).collect();
180 let mapping = serde_json::json!({
181 "pairs": pairs,
182 "next": next_label,
183 });
184 kv.put(
185 &format!("{store_key}:mapping"),
186 serde_json::to_string(&mapping).map_err(|e| Error::RustError(e.to_string()))?,
187 )?
188 .execute()
189 .await?;
190
191 Ok(())
192}
193
194async fn load_mapping(
196 env: &Env,
197) -> Result<(
198 std::collections::HashMap<String, u64>,
199 std::collections::HashMap<u64, String>,
200 u64,
201)> {
202 let store_key = env
203 .var("RVF_STORE_KEY")
204 .map(|v| v.to_string())
205 .unwrap_or_else(|_| "nostr-bbs.rvf".to_string());
206
207 let kv = env.kv("SEARCH_CONFIG")?;
208 let mapping_key = format!("{store_key}:mapping");
209
210 if let Some(json_str) = kv.get(&mapping_key).text().await? {
211 if let Ok(val) = serde_json::from_str::<serde_json::Value>(&json_str) {
212 let next = val["next"].as_u64().unwrap_or(1);
213 let mut id_to_label = std::collections::HashMap::new();
214 let mut label_to_id = std::collections::HashMap::new();
215
216 if let Some(pairs) = val["pairs"].as_array() {
217 for pair in pairs {
218 if let (Some(id), Some(label)) = (pair[0].as_str(), pair[1].as_u64()) {
219 id_to_label.insert(id.to_string(), label);
220 label_to_id.insert(label, id.to_string());
221 }
222 }
223 }
224
225 return Ok((id_to_label, label_to_id, next));
226 }
227 }
228
229 Ok((
230 std::collections::HashMap::new(),
231 std::collections::HashMap::new(),
232 1,
233 ))
234}
235
236async fn handle_search(req: &Request, env: &Env) -> Result<Response> {
241 let mut req_clone = req.clone()?;
242 let body: SearchRequest = req_clone.json().await?;
243
244 if body.embedding.len() != DIM {
245 return json_response(
246 req,
247 env,
248 &serde_json::json!({ "error": format!("Expected {DIM}-dim embedding") }),
249 400,
250 );
251 }
252
253 let store = load_store(env).await?;
254 let k = body.k.clamp(1, 100);
255
256 if store.count() == 0 {
257 return json_response(
258 req,
259 env,
260 &serde_json::json!({ "results": [], "totalVectors": 0 }),
261 200,
262 );
263 }
264
265 let (_, label_to_id, _) = load_mapping(env).await?;
266 let results = store.search(&body.embedding, k, body.min_score);
267
268 let results_json: Vec<serde_json::Value> = results
269 .iter()
270 .map(|(label, score)| {
271 let id = label_to_id
272 .get(label)
273 .cloned()
274 .unwrap_or_else(|| label.to_string());
275 serde_json::json!({
276 "id": id,
277 "distance": 1.0 - score,
278 "score": score,
279 })
280 })
281 .collect();
282
283 json_response(
284 req,
285 env,
286 &serde_json::json!({
287 "results": results_json,
288 "totalVectors": store.count(),
289 "engine": "rvf-rust",
290 "dimensions": DIM,
291 }),
292 200,
293 )
294}
295
296async fn handle_embed(req: &Request, env: &Env) -> Result<Response> {
297 let mut req_clone = req.clone()?;
298 let body: EmbedRequest = req_clone.json().await?;
299
300 let texts: Vec<String> = match (body.texts, body.text) {
301 (Some(texts), _) => texts,
302 (None, Some(text)) => vec![text],
303 (None, None) => {
304 return json_response(
305 req,
306 env,
307 &serde_json::json!({ "error": "Missing text or texts field" }),
308 400,
309 );
310 }
311 };
312
313 if texts.is_empty() {
314 return json_response(
315 req,
316 env,
317 &serde_json::json!({ "error": "Missing text or texts field" }),
318 400,
319 );
320 }
321 if texts.len() > 100 {
322 return json_response(
323 req,
324 env,
325 &serde_json::json!({ "error": "Maximum 100 texts per request" }),
326 400,
327 );
328 }
329
330 let embeddings: Vec<Vec<f32>> = texts.iter().map(|t| embed::generate_embedding(t)).collect();
331
332 json_response(
333 req,
334 env,
335 &serde_json::json!({
336 "embeddings": embeddings,
337 "dimensions": DIM,
338 "model": "hash-fallback-v1",
339 "note": "Hash-based fallback embedding. Replace with ONNX WASM model for semantic quality.",
340 }),
341 200,
342 )
343}
344
345async fn handle_ingest(req: &Request, env: &Env) -> Result<Response> {
346 let url = req.url()?;
348 let request_url = format!("{}{}", url.origin().ascii_serialization(), url.path());
349 let auth_header = req.headers().get("Authorization")?;
350 let mut req_clone = req.clone()?;
351 let raw_body = req_clone.bytes().await?;
352
353 if let Err((err_body, status)) = auth::require_nip98_admin(
354 auth_header.as_deref(),
355 &request_url,
356 "POST",
357 Some(&raw_body),
358 env,
359 )
360 .await
361 {
362 return json_response(req, env, &err_body, status);
363 }
364
365 let body: IngestRequest =
366 serde_json::from_slice(&raw_body).map_err(|e| Error::RustError(e.to_string()))?;
367
368 if body.entries.is_empty() {
369 return json_response(
370 req,
371 env,
372 &serde_json::json!({ "error": "Missing entries array" }),
373 400,
374 );
375 }
376
377 let mut store = load_store(env).await?;
378 let (mut id_to_label, _, mut next_label) = load_mapping(env).await?;
379
380 let mut accepted = 0u32;
381 let mut rejected = 0u32;
382
383 for entry in &body.entries {
384 if entry.id.is_empty() || entry.embedding.len() != DIM {
385 rejected += 1;
386 continue;
387 }
388
389 let label = *id_to_label.entry(entry.id.clone()).or_insert_with(|| {
390 let l = next_label;
391 next_label += 1;
392 l
393 });
394
395 store.insert(label, &entry.embedding);
396 accepted += 1;
397 }
398
399 persist_store(&store, &id_to_label, next_label, env).await?;
401
402 json_response(
403 req,
404 env,
405 &serde_json::json!({
406 "accepted": accepted,
407 "rejected": rejected,
408 "totalVectors": store.count(),
409 "engine": "rvf-rust",
410 }),
411 200,
412 )
413}
414
415async fn handle_status(req: &Request, env: &Env) -> Result<Response> {
416 let store = load_store(env).await?;
417
418 json_response(
419 req,
420 env,
421 &serde_json::json!({
422 "status": "healthy",
423 "totalVectors": store.count(),
424 "dimensions": DIM,
425 "metric": "cosine",
426 "model": "all-MiniLM-L6-v2",
427 "engine": "rvf-rust",
428 "runtime": "workers-rs",
429 "format": "rvf-v1",
430 }),
431 200,
432 )
433}
434
435#[event(fetch)]
440async fn fetch(req: Request, env: Env, _ctx: Context) -> Result<Response> {
441 nostr_bbs_rate_limit::ensure_replay_schema(&env, "REPLAY_DB").await;
442
443 if req.method() == Method::Options {
445 return Ok(Response::empty()?
446 .with_status(204)
447 .with_headers(cors_headers(&req, &env)));
448 }
449
450 let ip = nostr_bbs_rate_limit::client_ip(&req);
452 if !nostr_bbs_rate_limit::check_rate_limit(&env, "SEARCH_CONFIG", &ip, 100, 60).await {
453 return json_response(
454 &req,
455 &env,
456 &serde_json::json!({ "error": "Too many requests" }),
457 429,
458 );
459 }
460
461 let url = req.url()?;
462 let path = url.path();
463
464 let result = route(&req, &env, path).await;
465 match result {
466 Ok(resp) => Ok(resp),
467 Err(e) => {
468 console_error!("Search worker error: {e}");
469 let msg = e.to_string();
470 if msg.contains("JSON") || msg.contains("json") || msg.contains("Syntax") {
471 json_response(
472 &req,
473 &env,
474 &serde_json::json!({ "error": "Invalid JSON body" }),
475 400,
476 )
477 } else {
478 json_response(
479 &req,
480 &env,
481 &serde_json::json!({ "error": "Internal error" }),
482 500,
483 )
484 }
485 }
486 }
487}
488
489async fn route(req: &Request, env: &Env, path: &str) -> Result<Response> {
490 let method = req.method();
491
492 if (path == "/health" || path == "/status" || path == "/") && method == Method::Get {
494 return handle_status(req, env).await;
495 }
496
497 if path == "/search" && method == Method::Post {
499 return handle_search(req, env).await;
500 }
501
502 if path == "/embed" && method == Method::Post {
504 return handle_embed(req, env).await;
505 }
506
507 if path == "/ingest" && method == Method::Post {
509 return handle_ingest(req, env).await;
510 }
511
512 json_response(req, env, &serde_json::json!({ "error": "Not found" }), 404)
513}
514
515#[event(scheduled)]
520async fn scheduled(_event: ScheduledEvent, env: Env, _ctx: ScheduleContext) {
521 let _ = load_store(&env).await;
523}