1use std::env;
5use std::fs;
6use std::io::{self, Read};
7use std::path::{Path, PathBuf};
8use std::time::Duration;
9
10use serde_json::Value;
11use zotron_rpc::UreqProviderHttpTransport;
12use zotron_types::{
13 bm25_score_chunks, build_embedding_provider_request, cosine_similarity,
14 diversity_filter, execute_embedding_provider_request, gap_cutoff, max_k_truncate,
15 parse_embedding_provider_response, read_machine_artifact_sidecar, rrf_merge,
16 score_floor_filter, token_budget_filter, ArtifactStorePlatform, EmbeddingChunkInput,
17 EmbeddingRequestInput, EmbeddingVector, MachineArtifactKind, StructureChunk,
18};
19
20use crate::output::{format_json, normalize_list_envelope};
21use crate::rpc::RpcCaller;
22use crate::{
23 collection_items, embedding_provider_spec, find_collection_in_tree, local_path_from_zotero_path,
24 paginate_rpc, resolve_collection, RagCommand, RagSearchOptions,
25};
26
27pub(crate) fn run_rag_command(command: RagCommand, client: &mut impl RpcCaller) -> Result<String, String> {
28 match command {
29 RagCommand::Providers => format_json(
30 &serde_json::json!({
31 "providers": [
32 embedding_provider_spec("volcengine")?,
33 embedding_provider_spec("alibaba")?,
34 embedding_provider_spec("custom")?,
35 ],
36 })),
37 RagCommand::Embed {
38 provider,
39 input,
40 endpoint,
41 model,
42 input_type,
43 api_key_env,
44 } => {
45 let value = run_embedding_provider_json_command(
46 provider,
47 input,
48 endpoint,
49 model,
50 input_type,
51 api_key_env,
52 )?;
53 format_json(&value)
54 }
55 RagCommand::Status { collection, .. } => {
56 let value = rag_status_value(client, &collection)?;
57 format_json(&value)
58 }
59 RagCommand::Search {
60 query,
61 collection,
62 keys,
63 zotero,
64 top_spans_per_item,
65 include_fulltext_spans,
66 top_k,
67 output,
68 ..
69 } => run_rag_search_command(
70 client,
71 RagSearchOptions {
72 query,
73 collection,
74 keys,
75 zotero,
76 top_spans_per_item,
77 include_fulltext_spans,
78 top_k,
79 output,
80 },
81 ),
82 }
83}
84
85pub(crate) fn run_embedding_provider_json_command(
86 provider: String,
87 input: String,
88 endpoint: Option<String>,
89 model: Option<String>,
90 input_type: Option<String>,
91 api_key_env: Option<String>,
92) -> Result<Value, String> {
93 let mut input: EmbeddingRequestInput = read_json_input(&input)?;
94 if endpoint.is_some() {
95 input.url = endpoint;
96 }
97 if model.is_some() {
98 input.model = model;
99 }
100 if input_type.is_some() {
101 input.input_type = input_type;
102 }
103 let mut transport = provider_http_transport(api_key_env.as_deref())?;
104 let vectors = execute_embedding_provider_request(&provider, &input, &mut transport)?;
105
106 Ok(serde_json::json!({
107 "provider": provider,
108 "vectors": vectors,
109 }))
110}
111
112pub(crate) fn provider_http_transport(api_key_env: Option<&str>) -> Result<UreqProviderHttpTransport, String> {
113 provider_http_transport_with_auth(api_key_env, "bearer")
114}
115
116pub(crate) fn provider_http_transport_with_auth(
117 api_key_env: Option<&str>,
118 auth_scheme: &str,
119) -> Result<UreqProviderHttpTransport, String> {
120 let Some(env_name) = api_key_env else {
121 return Ok(UreqProviderHttpTransport::new());
122 };
123 let token = env::var(env_name)
124 .map_err(|_| format!("missing provider credential env var {env_name}"))?;
125 if token.trim().is_empty() {
126 return Err(format!("provider credential env var {env_name} is empty"));
127 }
128 let token = token.trim();
129 match auth_scheme {
130 "token" if token.starts_with("token ") => {
131 Ok(UreqProviderHttpTransport::with_api_key(token.to_string()))
132 }
133 "token" => Ok(UreqProviderHttpTransport::with_api_key(format!(
134 "token {token}"
135 ))),
136 "bearer" if token.starts_with("Bearer ") => {
137 Ok(UreqProviderHttpTransport::with_api_key(token.to_string()))
138 }
139 "bearer" => Ok(UreqProviderHttpTransport::with_bearer_token(token)),
140 "none" => Ok(UreqProviderHttpTransport::new()),
141 other => Err(format!("unsupported provider auth scheme {other}")),
142 }
143}
144
145pub(crate) fn read_json_input<T: serde::de::DeserializeOwned>(path: &str) -> Result<T, String> {
146 let payload = if path == "-" {
147 let mut input = String::new();
148 io::stdin()
149 .read_to_string(&mut input)
150 .map_err(|err| format!("read stdin: {err}"))?;
151 input
152 } else {
153 fs::read_to_string(path).map_err(|err| format!("read {path}: {err}"))?
154 };
155 serde_json::from_str::<T>(&payload)
156 .map_err(|err| format!("INVALID_JSON: Could not parse JSON: {err}"))
157}
158
159pub(crate) fn fetch_embedding_settings(
160 client: &mut impl RpcCaller,
161) -> Result<(String, String, String, String), String> {
162 let settings = client.call("settings.getAll", None)?;
163 let raw = client.call("settings.getRaw", Some(serde_json::json!({"key": "embedding.apiKey"})))?;
164 let api_key = raw
165 .get("embedding.apiKey")
166 .and_then(Value::as_str)
167 .unwrap_or("")
168 .to_string();
169 Ok(parse_embedding_settings(&settings, api_key))
170}
171
172pub(crate) fn parse_embedding_settings(
176 settings: &Value,
177 api_key: String,
178) -> (String, String, String, String) {
179 let provider = settings
180 .get("embedding.provider")
181 .and_then(Value::as_str)
182 .unwrap_or("ollama")
183 .to_string();
184 let model = settings
185 .get("embedding.model")
186 .and_then(Value::as_str)
187 .unwrap_or("")
188 .to_string();
189 let api_url = settings
190 .get("embedding.apiUrl")
191 .and_then(Value::as_str)
192 .unwrap_or("")
193 .to_string();
194 (provider, model, api_url, api_key)
195}
196
/// Resolved configuration for the rerank stage.
///
/// `Debug` is implemented by hand so that the API key never ends up in logs or
/// panic messages: it is rendered as `<redacted>` (or `<unset>` when empty).
pub struct RerankSettings {
    /// Rerank provider id; empty when reranking is not configured.
    pub provider: String,
    /// Model name sent to the provider.
    pub model: String,
    /// Endpoint the rerank request is posted to.
    pub api_url: String,
    /// Secret sent as the bearer token of the rerank request.
    pub api_key: String,
    /// Maximum number of top-ranked candidates submitted for reranking.
    pub candidate_count: usize,
}

impl std::fmt::Debug for RerankSettings {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep "is a key configured?" visible without leaking the key itself.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("RerankSettings")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("api_url", &self.api_url)
            .field("api_key", &api_key)
            .field("candidate_count", &self.candidate_count)
            .finish()
    }
}
205
206pub fn fetch_rerank_settings(
207 client: &mut impl RpcCaller,
208) -> Result<RerankSettings, String> {
209 let settings = client.call("settings.getAll", None)?;
210 let raw = client.call(
211 "settings.getRaw",
212 Some(serde_json::json!({"key": "rerank.apiKey"})),
213 )?;
214 let api_key = raw
215 .get("rerank.apiKey")
216 .and_then(Value::as_str)
217 .unwrap_or("")
218 .to_string();
219 Ok(parse_rerank_settings(&settings, api_key))
220}
221
222pub(crate) fn parse_rerank_settings(settings: &Value, api_key: String) -> RerankSettings {
226 let provider = settings
227 .get("rerank.provider")
228 .and_then(Value::as_str)
229 .unwrap_or("")
230 .to_string();
231 let model = settings
232 .get("rerank.model")
233 .and_then(Value::as_str)
234 .unwrap_or("")
235 .to_string();
236 let api_url = settings
237 .get("rerank.apiUrl")
238 .and_then(Value::as_str)
239 .unwrap_or("")
240 .to_string();
241 let candidate_count = settings
242 .get("rerank.candidateCount")
243 .and_then(Value::as_str)
244 .and_then(|s| s.parse().ok())
245 .unwrap_or(30);
246
247 let specs = zotron_types::builtin_rerank_provider_specs();
248 let spec = specs.iter().find(|s| s.id == provider);
249
250 let api_url = if api_url.is_empty() {
251 spec.map(|s| s.default_url.to_string()).unwrap_or_default()
252 } else {
253 api_url
254 };
255
256 let model = if model.is_empty() {
257 spec.map(|s| s.default_model.to_string()).unwrap_or_default()
258 } else {
259 model
260 };
261
262 RerankSettings {
263 provider,
264 model,
265 api_url,
266 api_key,
267 candidate_count,
268 }
269}
270
/// Tunables for the result cutoff pipeline (see `apply_cutoff_pipeline`).
///
/// Values are produced by `parse_rag_cutoff_settings`, which also defines the
/// defaults noted on each field.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct RagCutoffSettings {
    /// Results are back-filled up to this many when the filters keep fewer
    /// (default 3).
    min_k: usize,
    /// Hard cap applied as the last pipeline step (default 20).
    max_k: usize,
    /// Budget handed to `token_budget_filter` (default 6000).
    token_budget: usize,
    /// Lambda handed to `diversity_filter` (default 0.7).
    mmr_lambda: f64,
    /// Threshold handed to `score_floor_filter`; only used after a successful
    /// rerank (default 0.1).
    score_floor: f64,
    /// Threshold handed to `gap_cutoff`; only used after a successful rerank
    /// (default 0.15).
    gap_threshold: f64,
}
279
280pub(crate) fn parse_rag_cutoff_settings(settings: &Value) -> RagCutoffSettings {
284 let get = |key: &str, default: &str| -> String {
285 settings
286 .get(key)
287 .and_then(|v| v.as_str())
288 .unwrap_or(default)
289 .to_string()
290 };
291 let legacy_top_k: Option<usize> = settings
292 .get("rag.topK")
293 .and_then(|v| v.as_str())
294 .and_then(|s| s.parse().ok());
295 let max_k = get("rag.maxK", "")
296 .parse()
297 .ok()
298 .or(legacy_top_k)
299 .unwrap_or(20);
300 if legacy_top_k.is_some()
301 && settings
302 .get("rag.maxK")
303 .and_then(|v| v.as_str())
304 .unwrap_or("")
305 .is_empty()
306 {
307 eprintln!("warning: rag.topK is deprecated, use rag.maxK instead");
308 }
309 RagCutoffSettings {
310 min_k: get("rag.minK", "3").parse().unwrap_or(3),
311 max_k,
312 token_budget: get("rag.tokenBudget", "6000").parse().unwrap_or(6000),
313 mmr_lambda: get("rag.mmrLambda", "0.7").parse().unwrap_or(0.7),
314 score_floor: get("rerank.scoreFloor", "0.1").parse().unwrap_or(0.1),
315 gap_threshold: get("rerank.gapThreshold", "0.15").parse().unwrap_or(0.15),
316 }
317}
318
319pub(crate) fn rerank_chunks(
320 query: &str,
321 chunks: &[StructureChunk],
322 ranked: &[(usize, f64)],
323 settings: &RerankSettings,
324) -> Result<Vec<(usize, f64)>, String> {
325 let specs = zotron_types::builtin_rerank_provider_specs();
326 let spec = specs
327 .iter()
328 .find(|s| s.id == settings.provider)
329 .ok_or_else(|| format!("unknown rerank provider: {}", settings.provider))?;
330
331 let candidate_count = settings.candidate_count.min(ranked.len());
332 let candidates: Vec<(usize, f64)> = ranked.iter().take(candidate_count).copied().collect();
333 let documents: Vec<&str> = candidates
334 .iter()
335 .map(|(idx, _)| chunks[*idx].text.as_str())
336 .collect();
337
338 let request_body = zotron_types::build_rerank_provider_request(
339 &settings.model,
340 query,
341 &documents,
342 candidate_count,
343 );
344
345 let body_str = serde_json::to_string(&request_body)
346 .map_err(|e| format!("rerank request serialize error: {e}"))?;
347
348 let agent = ureq::AgentBuilder::new()
349 .timeout(Duration::from_secs(10))
350 .build();
351
352 let send = |agent: &ureq::Agent| -> Result<ureq::Response, (bool, String)> {
353 agent
354 .post(&settings.api_url)
355 .set("Content-Type", "application/json")
356 .set("Authorization", &format!("Bearer {}", settings.api_key))
357 .send_string(&body_str)
358 .map_err(|e| {
359 let transient = matches!(&e, ureq::Error::Status(code, _) if *code == 429 || *code >= 500);
360 (transient, e.to_string())
361 })
362 };
363
364 let response = match send(&agent) {
365 Ok(r) => r,
366 Err((true, _)) => {
367 std::thread::sleep(Duration::from_secs(1));
368 send(&agent).map_err(|(_, msg)| format!("rerank API retry failed: {msg}"))?
369 }
370 Err((_, msg)) => return Err(format!("rerank API failed: {msg}")),
371 };
372
373 let payload: serde_json::Value = response
374 .into_json()
375 .map_err(|e| format!("rerank response parse error: {e}"))?;
376
377 let reranked = zotron_types::parse_rerank_provider_response(spec, &payload)?;
378
379 Ok(map_reranked_to_candidates(reranked, &candidates))
380}
381
382pub(crate) fn map_reranked_to_candidates(
389 reranked: Vec<zotron_types::RerankResult>,
390 candidates: &[(usize, f64)],
391) -> Vec<(usize, f64)> {
392 reranked
393 .into_iter()
394 .filter_map(|r| candidates.get(r.index).map(|c| (c.0, r.score)))
395 .collect()
396}
397
398pub(crate) fn parse_retrieval_mode(settings: &Value) -> String {
402 settings
403 .get("rag.retrievalMode")
404 .and_then(Value::as_str)
405 .map(String::from)
406 .unwrap_or_else(|| "hybrid".to_string())
407}
408
409pub(crate) fn resolve_sidecar_paths(
410 client: &mut impl RpcCaller,
411 collection: Option<&str>,
412 keys: &[String],
413) -> Result<Vec<(String, String, PathBuf)>, String> {
414 let items = if !keys.is_empty() {
415 let mut items = Vec::new();
416 for key in keys {
417 let item = client.call("items.get", Some(serde_json::json!({"key": key})))?;
418 items.push(item);
419 }
420 items
421 } else if let Some(col) = collection {
422 let col_key = resolve_collection(client, col)?;
423 let response = client.call(
424 "collections.getItems",
425 Some(serde_json::json!({"key": col_key})),
426 )?;
427 collection_items(&response)
428 } else {
429 return Err("INVALID_ARGS: --collection or --key required".into());
430 };
431
432 let mut results = Vec::new();
433 for item in &items {
434 let item_key = item.get("key").and_then(Value::as_str).unwrap_or_default();
435 let attachments = client.call(
436 "attachments.list",
437 Some(serde_json::json!({"parentKey": item_key})),
438 )?;
439 let att_list = attachments
440 .get("items")
441 .and_then(Value::as_array)
442 .or_else(|| attachments.as_array())
443 .cloned()
444 .unwrap_or_default();
445 for att in &att_list {
446 let content_type = att
447 .get("contentType")
448 .and_then(Value::as_str)
449 .unwrap_or("");
450 if content_type != "application/pdf" {
451 continue;
452 }
453 let att_key = att.get("key").and_then(Value::as_str).unwrap_or_default();
454 let path = att.get("path").and_then(Value::as_str).unwrap_or_default();
455 if path.is_empty() {
456 continue;
457 }
458 let local_path = local_path_from_zotero_path(path);
459 let pdf_path = PathBuf::from(&local_path);
460 if let Some(parent) = pdf_path.parent() {
461 let sidecar_root = parent.join(".zotron");
462 if sidecar_root.exists() {
463 results.push((item_key.to_string(), att_key.to_string(), sidecar_root));
464 }
465 }
466 }
467 }
468 Ok(results)
469}
470
471pub(crate) fn is_chunk_schema_header(line: &str) -> bool {
476 serde_json::from_str::<serde_json::Value>(line)
477 .ok()
478 .and_then(|v| v.get("schema_version").map(serde_json::Value::is_number))
479 .unwrap_or(false)
480}
481
482pub(crate) fn load_sidecar_chunks(sidecar_root: &Path) -> Vec<StructureChunk> {
483 let chunks_path =
484 sidecar_root.join(MachineArtifactKind::Chunks.sidecar_relative_path());
485 let Ok(content) = fs::read_to_string(&chunks_path) else {
486 return Vec::new();
487 };
488 content
489 .lines()
490 .filter(|line| !line.trim().is_empty())
491 .filter(|line| !is_chunk_schema_header(line))
492 .filter_map(|line| serde_json::from_str::<StructureChunk>(line).ok())
493 .collect()
494}
495
/// Derives the per-provider/model vector file name.
///
/// Both parts are trimmed, lowercased and have `/` replaced by `-`, then joined
/// as `<provider>--<model>.jsonl`. When both parts are empty after trimming the
/// generic name `vectors.jsonl` is returned.
pub(crate) fn embedding_vector_filename(provider: &str, model: &str) -> String {
    let slug = |raw: &str| raw.trim().to_lowercase().replace('/', "-");

    let provider_slug = slug(provider);
    let model_slug = slug(model);
    match (provider_slug.is_empty(), model_slug.is_empty()) {
        (true, true) => String::from("vectors.jsonl"),
        _ => format!("{provider_slug}--{model_slug}.jsonl"),
    }
}
504
505pub(crate) fn load_sidecar_vectors(sidecar_root: &Path, provider: &str, model: &str) -> Vec<EmbeddingVector> {
506 let embeddings_dir = sidecar_root.join("embeddings");
507 let target = embedding_vector_filename(provider, model);
508 let target_path = embeddings_dir.join(&target);
509 if let Ok(content) = fs::read_to_string(&target_path) {
510 let vecs: Vec<EmbeddingVector> = content
511 .lines()
512 .filter(|line| !line.trim().is_empty())
513 .filter_map(|line| serde_json::from_str(line).ok())
514 .collect();
515 if !vecs.is_empty() {
516 return vecs;
517 }
518 }
519 for legacy in &["vectors.v1.jsonl", "vectors.jsonl"] {
521 let path = embeddings_dir.join(legacy);
522 if let Ok(content) = fs::read_to_string(&path) {
523 let vecs: Vec<EmbeddingVector> = content
524 .lines()
525 .filter(|line| !line.trim().is_empty())
526 .filter_map(|line| serde_json::from_str::<EmbeddingVector>(line).ok())
527 .filter(|v| v.source_provider == provider || provider.is_empty())
528 .collect();
529 if !vecs.is_empty() {
530 return vecs;
531 }
532 }
533 }
534 Vec::new()
535}
536
537pub(crate) fn embed_query_text(
538 query: &str,
539 provider: &str,
540 model: &str,
541 api_url: &str,
542 api_key: &str,
543) -> Result<Vec<f64>, String> {
544 let input = EmbeddingRequestInput {
545 item_key: "query".to_string(),
546 chunks: vec![EmbeddingChunkInput {
547 chunk_key: "q0".to_string(),
548 text: query.to_string(),
549 }],
550 model: if model.is_empty() {
551 None
552 } else {
553 Some(model.to_string())
554 },
555 url: if api_url.is_empty() {
556 None
557 } else {
558 Some(api_url.to_string())
559 },
560 input_type: Some("query".to_string()),
561 };
562 let request = build_embedding_provider_request(provider, &input)?;
563 let url = request
564 .url
565 .as_deref()
566 .ok_or("no embedding URL configured")?;
567 let mut http = ureq::post(url).set("Content-Type", "application/json");
568 if let Some(auth) = request.auth_header {
569 if !api_key.is_empty() {
570 http = http.set(auth, &format!("Bearer {api_key}"));
571 }
572 }
573 let resp = http
574 .send_json(&request.body)
575 .map_err(|e| format!("embedding request failed: {e}"))?;
576 let payload: Value = resp
577 .into_json()
578 .map_err(|e| format!("embedding response parse: {e}"))?;
579 let vectors =
580 parse_embedding_provider_response(provider, &payload, "query", &input.chunks)?;
581 vectors
582 .into_iter()
583 .next()
584 .map(|v| v.vector)
585 .ok_or_else(|| "no embedding vector returned".to_string())
586}
587
588fn camelize_xpi_hit(hit: &Value) -> Value {
591 let Some(obj) = hit.as_object() else {
592 return hit.clone();
593 };
594 let mut out = serde_json::Map::with_capacity(obj.len());
595 for (key, value) in obj {
596 let mapped = match key.as_str() {
597 "item_key" => "itemKey",
598 "chunk_key" => "chunkKey",
599 "attachment_key" => "attachmentKey",
600 "page_range" => "pageRange",
601 "section_path" => "sectionPath",
602 "score_kind" => "scoreKind",
603 "block_key" => "blockKey",
604 "block_keys" => "blockKeys",
605 "page_idx" => "pageIdx",
606 "evidence_refs" => "evidenceRefs",
607 other => other,
608 };
609 out.insert(mapped.to_string(), value.clone());
610 }
611 Value::Object(out)
612}
613
614pub(crate) fn run_rag_search_xpi_fallback(
615 client: &mut impl RpcCaller,
616 options: &RagSearchOptions,
617) -> Result<String, String> {
618 let mut params = serde_json::json!({
619 "query": options.query,
620 "limit": options.top_k,
621 "top_spans_per_item": options.top_spans_per_item,
622 "include_fulltext_spans": options.include_fulltext_spans,
623 });
624 if let Some(map) = params.as_object_mut() {
625 if let Some(col) = &options.collection {
626 map.insert("collection".into(), Value::String(col.clone()));
627 }
628 if !options.keys.is_empty() {
629 map.insert(
630 "keys".into(),
631 Value::Array(options.keys.iter().map(|k| Value::String(k.clone())).collect()),
632 );
633 }
634 }
635 let payload = client.call("rag.searchHits", Some(params))?;
636 let hits = payload
637 .get("hits")
638 .and_then(Value::as_array)
639 .cloned()
640 .unwrap_or_default()
641 .into_iter()
642 .map(|hit| camelize_xpi_hit(&hit))
643 .collect::<Vec<_>>();
644 if options.output == "jsonl" {
645 let mut out = String::new();
646 for hit in &hits {
647 out.push_str(&serde_json::to_string(hit).map_err(|e| e.to_string())?);
648 out.push('\n');
649 }
650 Ok(out)
651 } else {
652 let total = hits.len() as u64;
653 format_json(
654 &normalize_list_envelope(
655 serde_json::json!({"items": hits, "total": total}),
656 "items",
657 Some(options.top_k),
658 0,
659 ))
660 }
661}
662
663fn score_dense(
670 query: &str,
671 emb_provider: &str,
672 emb_model: &str,
673 emb_url: &str,
674 emb_key: &str,
675 all_chunks: &[StructureChunk],
676 all_vectors: &[EmbeddingVector],
677) -> Vec<(usize, f64)> {
678 match embed_query_text(query, emb_provider, emb_model, emb_url, emb_key) {
679 Ok(query_vec) => {
680 let vec_map: std::collections::HashMap<&str, &[f64]> = all_vectors
681 .iter()
682 .map(|v| (v.chunk_key.as_str(), v.vector.as_slice()))
683 .collect();
684 let mut scores: Vec<(usize, f64)> = all_chunks
685 .iter()
686 .enumerate()
687 .filter_map(|(i, chunk)| {
688 vec_map
689 .get(chunk.chunk_key.as_str())
690 .map(|stored| (i, cosine_similarity(&query_vec, stored)))
691 })
692 .filter(|(_, s)| *s > 0.0)
693 .collect();
694 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
695 scores
696 }
697 Err(e) => {
698 eprintln!("warning: dense retrieval unavailable (query embedding failed): {e}");
700 Vec::new()
701 }
702 }
703}
704
705fn build_diversity_vector_map<'a>(
709 all_chunks: &[StructureChunk],
710 all_vectors: &'a [EmbeddingVector],
711) -> std::collections::HashMap<usize, &'a [f64]> {
712 let chunk_key_index: std::collections::HashMap<&str, usize> = all_chunks
713 .iter()
714 .enumerate()
715 .map(|(i, c)| (c.chunk_key.as_str(), i))
716 .collect();
717 all_vectors
718 .iter()
719 .filter_map(|v| {
720 let &idx = chunk_key_index.get(v.chunk_key.as_str())?;
721 Some((idx, v.vector.as_slice()))
722 })
723 .collect()
724}
725
726fn apply_cutoff_pipeline(
730 mut pipeline_ranked: Vec<(usize, f64)>,
731 rrf_ranked: &[(usize, f64)],
732 full_reranked: &Option<Vec<(usize, f64)>>,
733 all_chunks: &[StructureChunk],
734 all_vectors: &[EmbeddingVector],
735 rag_cutoff: &RagCutoffSettings,
736) -> Vec<(usize, f64)> {
737 if full_reranked.is_some() {
739 pipeline_ranked = score_floor_filter(&pipeline_ranked, rag_cutoff.score_floor);
740 pipeline_ranked = gap_cutoff(&pipeline_ranked, rag_cutoff.gap_threshold);
741 }
742
743 let mmr_input: Vec<(usize, f64)> = if full_reranked.is_some() {
754 pipeline_ranked.clone()
755 } else {
756 let normalized_rel = zotron_types::min_max_normalize(
757 &pipeline_ranked.iter().map(|(_, s)| *s as f32).collect::<Vec<_>>(),
758 );
759 pipeline_ranked
760 .iter()
761 .zip(normalized_rel.iter())
762 .map(|((idx, _), norm)| (*idx, *norm as f64))
763 .collect()
764 };
765 let vector_map = build_diversity_vector_map(all_chunks, all_vectors);
766 let diversity_kept = diversity_filter(&mmr_input, &vector_map, rag_cutoff.mmr_lambda, 0.05);
769 let original_score: std::collections::HashMap<usize, f64> =
770 pipeline_ranked.iter().map(|(idx, s)| (*idx, *s)).collect();
771 pipeline_ranked = diversity_kept
772 .into_iter()
773 .map(|(idx, _norm)| (idx, *original_score.get(&idx).unwrap_or(&0.0)))
774 .collect();
775
776 let char_lens: Vec<usize> = all_chunks.iter().map(|c| c.text.chars().count()).collect();
779 pipeline_ranked = token_budget_filter(&pipeline_ranked, &char_lens, rag_cutoff.token_budget);
780
781 if pipeline_ranked.len() < rag_cutoff.min_k {
783 let source = full_reranked.as_deref().unwrap_or(rrf_ranked);
784 for &(idx, score) in source {
785 if pipeline_ranked.len() >= rag_cutoff.min_k {
786 break;
787 }
788 if !pipeline_ranked.iter().any(|(i, _)| *i == idx) {
789 pipeline_ranked.push((idx, score));
790 }
791 }
792 }
793 max_k_truncate(pipeline_ranked, rag_cutoff.max_k)
794}
795
796fn enrich_hits(
799 client: &mut impl RpcCaller,
800 selected: &[(usize, f64)],
801 all_chunks: &[StructureChunk],
802 score_kind: &str,
803 include_fulltext_spans: bool,
804) -> Vec<Value> {
805 let mut meta_cache: std::collections::HashMap<String, Value> =
806 std::collections::HashMap::new();
807 let mut hits: Vec<Value> = Vec::new();
808 for (idx, score) in selected {
809 let chunk = &all_chunks[*idx];
810 let meta = if let Some(cached) = meta_cache.get(&chunk.item_key) {
811 cached.clone()
812 } else {
813 let fetched = client
814 .call("items.get", Some(serde_json::json!({"key": chunk.item_key})))
815 .unwrap_or(Value::Null);
816 meta_cache.insert(chunk.item_key.clone(), fetched.clone());
817 fetched
818 };
819 let title = meta
820 .get("title")
821 .and_then(Value::as_str)
822 .unwrap_or("")
823 .to_string();
824 let authors = meta
825 .get("creators")
826 .and_then(Value::as_array)
827 .map(|creators| {
828 creators
829 .iter()
830 .filter_map(|c| {
831 let last = c.get("lastName").and_then(Value::as_str).unwrap_or("");
832 let first = c.get("firstName").and_then(Value::as_str).unwrap_or("");
833 if last.is_empty() && first.is_empty() {
834 None
835 } else {
836 Some(format!("{last}{first}"))
837 }
838 })
839 .collect::<Vec<_>>()
840 .join(", ")
841 })
842 .unwrap_or_default();
843 let year = meta.get("date").and_then(Value::as_str).unwrap_or("");
844 let mut hit = serde_json::json!({
845 "itemKey": chunk.item_key,
846 "chunkKey": chunk.chunk_key,
847 "title": title,
848 "authors": authors,
849 "year": year,
850 "text": chunk.text,
851 "pageRange": chunk.page_range,
852 "sectionPath": chunk.section_path,
853 "score": score,
854 "scoreKind": score_kind,
855 });
856 if include_fulltext_spans {
857 hit.as_object_mut().unwrap().insert(
858 "attachmentKey".to_string(),
859 Value::String(chunk.attachment_key.clone()),
860 );
861 }
862 hits.push(hit);
863 }
864 hits
865}
866
867fn format_hits(hits: &[Value], actual_mode: &str, options: &RagSearchOptions) -> Result<String, String> {
872 if options.output == "jsonl" {
873 let mut out = String::new();
874 for hit in hits {
875 out.push_str(&serde_json::to_string(hit).map_err(|e| e.to_string())?);
876 out.push('\n');
877 }
878 Ok(out)
879 } else {
880 let total = hits.len() as u64;
881 format_json(&normalize_list_envelope(
882 serde_json::json!({"items": hits, "total": total, "mode": actual_mode}),
883 "items",
884 Some(options.top_k),
885 0,
886 ))
887 }
888}
889
890pub(crate) fn run_rag_search_command(
891 client: &mut impl RpcCaller,
892 options: RagSearchOptions,
893) -> Result<String, String> {
894 if options.zotero {
896 if options.collection.is_none() && options.keys.is_empty() {
897 return Err(
898 "INVALID_ARGS: --collection or --key is required".to_string(),
899 );
900 }
901 return run_rag_search_xpi_fallback(client, &options);
902 }
903
904 if options.collection.is_none() && options.keys.is_empty() {
906 return Err("INVALID_ARGS: --collection or --key required".to_string());
907 }
908
909 let sidecars = resolve_sidecar_paths(
911 client,
912 options.collection.as_deref(),
913 &options.keys,
914 );
915
916 let sidecars = match sidecars {
919 Ok(ref s) if !s.is_empty() => s,
920 Err(ref e) if e.contains("COLLECTION_NOT_FOUND") => return Err(e.clone()),
921 _ => return run_rag_search_xpi_fallback(client, &options),
922 };
923
924 let settings_blob = client.call("settings.getAll", None)?;
930 let emb_raw = client.call(
931 "settings.getRaw",
932 Some(serde_json::json!({"key": "embedding.apiKey"})),
933 )?;
934 let emb_key = emb_raw
935 .get("embedding.apiKey")
936 .and_then(Value::as_str)
937 .unwrap_or("")
938 .to_string();
939 let (emb_provider, emb_model, emb_url, emb_key) =
940 parse_embedding_settings(&settings_blob, emb_key);
941
942 let mut all_chunks: Vec<StructureChunk> = Vec::new();
943 let mut all_vectors: Vec<EmbeddingVector> = Vec::new();
944 for (_item_key, _att_key, sidecar_root) in sidecars {
945 all_chunks.extend(load_sidecar_chunks(sidecar_root));
946 all_vectors.extend(load_sidecar_vectors(sidecar_root, &emb_provider, &emb_model));
947 }
948
949 if all_chunks.is_empty() {
950 return run_rag_search_xpi_fallback(client, &options);
951 }
952
953 let requested_mode = parse_retrieval_mode(&settings_blob);
955
956 let mut bm25_ranked = if requested_mode != "dense" {
958 bm25_score_chunks(&all_chunks, &options.query, 1.2, 0.75)
959 } else {
960 Vec::new()
961 };
962
963 let dense_ranked = if requested_mode != "lexical" && !all_vectors.is_empty() {
965 score_dense(
966 &options.query,
967 &emb_provider,
968 &emb_model,
969 &emb_url,
970 &emb_key,
971 &all_chunks,
972 &all_vectors,
973 )
974 } else {
975 if requested_mode == "dense" && all_vectors.is_empty() {
976 eprintln!("warning: dense retrieval requested but no embedding vectors found for this scope");
977 }
978 Vec::new()
979 };
980
981 let limit = options.top_k as usize;
984 let actual_mode: &str;
985 let rrf_ranked = if !bm25_ranked.is_empty() && !dense_ranked.is_empty() {
986 actual_mode = "hybrid";
987 rrf_merge(&bm25_ranked, &dense_ranked, 60.0, limit)
988 } else if !dense_ranked.is_empty() {
989 actual_mode = "dense";
990 dense_ranked.into_iter().take(limit).collect()
991 } else if !bm25_ranked.is_empty() {
992 actual_mode = "lexical";
993 bm25_ranked.clone().into_iter().take(limit).collect()
994 } else {
995 if requested_mode == "dense" {
1002 eprintln!("warning: falling back to lexical (BM25) retrieval");
1003 bm25_ranked = bm25_score_chunks(&all_chunks, &options.query, 1.2, 0.75);
1004 }
1005 actual_mode = "lexical";
1006 bm25_ranked.clone().into_iter().take(limit).collect()
1007 };
1008
1009 let rerank_api_key = client
1016 .call(
1017 "settings.getRaw",
1018 Some(serde_json::json!({"key": "rerank.apiKey"})),
1019 )
1020 .ok()
1021 .and_then(|raw| {
1022 raw.get("rerank.apiKey")
1023 .and_then(Value::as_str)
1024 .map(str::to_string)
1025 })
1026 .unwrap_or_default();
1027 let rerank_settings = parse_rerank_settings(&settings_blob, rerank_api_key);
1028 let rag_cutoff = parse_rag_cutoff_settings(&settings_blob);
1029
1030 let mut pipeline_ranked = rrf_ranked.clone();
1031 let mut full_reranked: Option<Vec<(usize, f64)>> = None;
1032
1033 if !rerank_settings.provider.is_empty() && !rerank_settings.api_key.is_empty() {
1035 match rerank_chunks(&options.query, &all_chunks, &pipeline_ranked, &rerank_settings) {
1036 Ok(reranked) => {
1037 full_reranked = Some(reranked.clone());
1038 pipeline_ranked = reranked;
1039 }
1040 Err(e) => {
1041 eprintln!("warning: reranker skipped: {e}");
1042 }
1043 }
1044 }
1045
1046 let score_kind: &str = if full_reranked.is_some() {
1049 "rerank"
1050 } else {
1051 match actual_mode {
1052 "hybrid" => "rrf",
1053 "dense" => "cosine",
1054 _ => "bm25",
1055 }
1056 };
1057
1058 let ranked = apply_cutoff_pipeline(
1060 pipeline_ranked,
1061 &rrf_ranked,
1062 &full_reranked,
1063 &all_chunks,
1064 &all_vectors,
1065 &rag_cutoff,
1066 );
1067
1068 let mut per_item_count: std::collections::HashMap<&str, u64> =
1070 std::collections::HashMap::new();
1071 let mut selected: Vec<(usize, f64)> = Vec::new();
1072 for (idx, score) in &ranked {
1073 let item_key = all_chunks[*idx].item_key.as_str();
1074 let count = per_item_count.entry(item_key).or_insert(0);
1075 if *count < options.top_spans_per_item {
1076 *count += 1;
1077 selected.push((*idx, *score));
1078 }
1079 }
1080
1081 let hits = enrich_hits(
1083 client,
1084 &selected,
1085 &all_chunks,
1086 score_kind,
1087 options.include_fulltext_spans,
1088 );
1089
1090 format_hits(&hits, actual_mode, &options)
1092}
1093
1094pub(crate) fn rag_status_value(client: &mut impl RpcCaller, collection: &str) -> Result<Value, String> {
1095 let raw_store_path = rag_store_path(collection);
1096 if raw_store_path.exists() {
1097 return rag_status_from_store(collection, &raw_store_path);
1098 }
1099
1100 let mut store_candidates = Vec::new();
1101 let collection_match = find_collection_in_tree(client, collection)?;
1102 if let Some(collection_node) = collection_match.as_ref() {
1103 if let Some(name) = collection_node.get("name").and_then(Value::as_str) {
1104 store_candidates.push(rag_store_path(name));
1105 }
1106 if let Some(key) = collection_node.get("key").and_then(Value::as_str) {
1107 store_candidates.push(rag_store_path(key));
1108 }
1109 }
1110 for store_path in unique_paths(store_candidates) {
1111 if store_path.exists() {
1112 return rag_status_from_store(collection, &store_path);
1113 }
1114 }
1115
1116 rag_status_from_zotero_sidecars(client, collection, collection_match)
1117}
1118
/// Removes duplicate paths while keeping the first occurrence of each and the
/// original relative order.
pub(crate) fn unique_paths(paths: Vec<PathBuf>) -> Vec<PathBuf> {
    paths.into_iter().fold(Vec::new(), |mut kept, candidate| {
        if !kept.contains(&candidate) {
            kept.push(candidate);
        }
        kept
    })
}
1128
1129pub(crate) fn rag_status_from_store(collection: &str, store_path: &Path) -> Result<Value, String> {
1130 let raw = fs::read_to_string(store_path)
1131 .map_err(|err| format!("read RAG store {}: {err}", store_path.display()))?;
1132 let store: Value = serde_json::from_str(&raw)
1133 .map_err(|err| format!("parse RAG store {}: {err}", store_path.display()))?;
1134 let chunks = store
1135 .get("chunks")
1136 .and_then(Value::as_array)
1137 .cloned()
1138 .unwrap_or_default();
1139 let mut item_keys = Vec::<Value>::new();
1140 for chunk in &chunks {
1141 let Some(item_key) = chunk.get("item_key") else {
1142 continue;
1143 };
1144 if !item_keys.iter().any(|seen| seen == item_key) {
1145 item_keys.push(item_key.clone());
1146 }
1147 }
1148 Ok(serde_json::json!({
1149 "status": "indexed",
1150 "collection": store.get("collection").and_then(Value::as_str).unwrap_or(collection),
1151 "collectionKey": store.get("collection_key").cloned().unwrap_or(Value::Null),
1152 "model": store.get("model").cloned().unwrap_or(Value::String("unknown".to_string())),
1153 "totalChunks": chunks.len(),
1154 "totalItems": item_keys.len(),
1155 "storePath": store_path.to_string_lossy(),
1156 }))
1157}
1158
1159pub(crate) fn rag_status_from_zotero_sidecars(
1160 client: &mut impl RpcCaller,
1161 collection: &str,
1162 collection_match: Option<Value>,
1163) -> Result<Value, String> {
1164 let collection_key = collection_match
1165 .as_ref()
1166 .and_then(|node| node.get("key").cloned())
1167 .ok_or_else(|| format!("COLLECTION_NOT_FOUND: Collection not found: {collection:?}"))?;
1168 let raw = paginate_rpc(
1169 client,
1170 "collections.getItems",
1171 serde_json::json!({"key": collection_key}),
1172 500,
1173 )?;
1174 let items = raw
1175 .get("items")
1176 .and_then(Value::as_array)
1177 .or_else(|| raw.as_array())
1178 .ok_or_else(|| "collections.getItems returned non-array/non-items result".to_string())?
1179 .clone();
1180
1181 let (emb_provider, emb_model) = if items.is_empty() {
1186 (String::new(), String::new())
1187 } else {
1188 fetch_embedding_settings(client)
1189 .map(|(p, m, _, _)| (p, m))
1190 .unwrap_or_default()
1191 };
1192
1193 let mut indexed_items = 0usize;
1194 let mut total_chunks = 0usize;
1195 let mut total_vectors = 0usize;
1196 for item in &items {
1197 let item_key = item.get("key").cloned().unwrap_or(Value::Null);
1198 let (chunk_count, vector_count) =
1200 sidecar_counts_for_item(client, &item_key, &emb_provider, &emb_model)?;
1201 if chunk_count > 0 {
1202 indexed_items += 1;
1203 total_chunks += chunk_count;
1204 total_vectors += vector_count;
1205 }
1206 }
1207
1208 if indexed_items == 0 {
1209 return Ok(serde_json::json!({
1210 "status": "not indexed",
1211 "collection": collection,
1212 "totalItems": items.len(),
1213 "indexedItems": 0,
1214 }));
1215 }
1216
1217 Ok(serde_json::json!({
1218 "status": "indexed",
1219 "collection": collection,
1220 "totalChunks": total_chunks,
1221 "totalItems": indexed_items,
1222 "collectionItems": items.len(),
1223 "totalVectors": total_vectors,
1227 "embeddingsAvailable": total_vectors > 0,
1228 "embeddingProvider": emb_provider,
1229 "embeddingModel": emb_model,
1230 "source": "zotero-sidecar",
1231 }))
1232}
1233
1234pub(crate) fn sidecar_counts_for_item(
1237 client: &mut impl RpcCaller,
1238 item_key: &Value,
1239 emb_provider: &str,
1240 emb_model: &str,
1241) -> Result<(usize, usize), String> {
1242 let attachments = client.call(
1243 "attachments.list",
1244 Some(serde_json::json!({"parentKey": item_key.clone()})),
1245 )?;
1246 let Some(attachments) = attachments.as_array() else {
1247 return Ok((0, 0));
1248 };
1249
1250 let mut chunk_count = 0usize;
1251 let mut vector_count = 0usize;
1252 for attachment in attachments {
1253 let Some(path) = attachment.get("path").and_then(Value::as_str) else {
1254 continue;
1255 };
1256 let local = local_path_from_zotero_path(path);
1257 let Some(dir) = Path::new(&local).parent() else {
1258 continue;
1259 };
1260 if let Ok(bytes) = read_machine_artifact_sidecar(dir, MachineArtifactKind::Chunks) {
1261 let text = String::from_utf8_lossy(&bytes);
1262 chunk_count += text
1264 .lines()
1265 .filter(|line| !line.trim().is_empty())
1266 .filter(|line| !is_chunk_schema_header(line))
1267 .count();
1268 }
1269 let sidecar_root = dir.join(".zotron");
1270 vector_count += load_sidecar_vectors(&sidecar_root, emb_provider, emb_model).len();
1271 }
1272 Ok((chunk_count, vector_count))
1273}
1274
1275pub(crate) fn rag_store_path(collection: &str) -> PathBuf {
1276 rag_store_root().join(format!("{collection}.json"))
1277}
1278
1279pub(crate) fn rag_store_root() -> PathBuf {
1280 let xdg_data_home = env::var_os("XDG_DATA_HOME")
1281 .filter(|path| !path.is_empty())
1282 .map(PathBuf::from);
1283 let appdata = env::var_os("APPDATA")
1284 .filter(|path| !path.is_empty())
1285 .map(PathBuf::from);
1286 let userprofile = env::var_os("USERPROFILE")
1287 .filter(|path| !path.is_empty())
1288 .map(PathBuf::from);
1289 let home = env::var_os("HOME")
1290 .filter(|path| !path.is_empty())
1291 .map(PathBuf::from);
1292
1293 rag_store_root_for_platform(
1294 ArtifactStorePlatform::current(),
1295 xdg_data_home.as_deref(),
1296 appdata.as_deref(),
1297 userprofile.as_deref(),
1298 home.as_deref(),
1299 )
1300}
1301
1302pub(crate) fn rag_store_root_for_platform(
1303 platform: ArtifactStorePlatform,
1304 xdg_data_home: Option<&Path>,
1305 appdata: Option<&Path>,
1306 userprofile: Option<&Path>,
1307 home: Option<&Path>,
1308) -> PathBuf {
1309 match platform {
1310 ArtifactStorePlatform::Windows => {
1311 if let Some(path) = appdata {
1312 return path.join("Zotron").join("rag");
1313 }
1314 if let Some(path) = userprofile {
1315 return path
1316 .join("AppData")
1317 .join("Roaming")
1318 .join("Zotron")
1319 .join("rag");
1320 }
1321 if let Some(path) = home {
1322 return path
1323 .join("AppData")
1324 .join("Roaming")
1325 .join("Zotron")
1326 .join("rag");
1327 }
1328 PathBuf::from(".zotron").join("rag")
1329 }
1330 ArtifactStorePlatform::Macos => {
1331 if let Some(path) = home {
1332 return path
1333 .join("Library")
1334 .join("Application Support")
1335 .join("Zotron")
1336 .join("rag");
1337 }
1338 if let Some(path) = xdg_data_home {
1339 return path.join("zotron").join("rag");
1340 }
1341 PathBuf::from(".zotron").join("rag")
1342 }
1343 ArtifactStorePlatform::Linux | ArtifactStorePlatform::Other => xdg_data_home
1344 .map(|path| path.join("zotron").join("rag"))
1345 .or_else(|| {
1346 home.map(|path| path.join(".local").join("share").join("zotron").join("rag"))
1347 })
1348 .unwrap_or_else(|| PathBuf::from(".zotron").join("rag")),
1349 }
1350}
1351
#[cfg(test)]
mod rerank_bounds_tests {
    use super::map_reranked_to_candidates;
    use zotron_types::RerankResult;

    /// A rerank index past the end of the candidate list is dropped; the
    /// remaining results keep their rerank order and carry the rerank score.
    #[test]
    fn drops_out_of_range_indices_without_panicking() {
        let ranked = |index, score| RerankResult { index, score };
        let candidates = vec![(10_usize, 0.1_f64), (20, 0.2), (30, 0.3)];
        let reranked = vec![
            ranked(1, 0.9),
            // Index 5 does not exist among the three candidates.
            ranked(5, 0.8),
            ranked(0, 0.7),
        ];

        let mapped = map_reranked_to_candidates(reranked, &candidates);

        assert_eq!(mapped.len(), 2);
        // candidates[1] paired with the first rerank score.
        assert_eq!(mapped[0], (20, 0.9));
        // candidates[0] paired with the last rerank score.
        assert_eq!(mapped[1], (10, 0.7));
    }

    /// When every rerank index is out of range the result is empty.
    #[test]
    fn all_out_of_range_yields_empty() {
        let ranked = |index, score| RerankResult { index, score };
        let candidates = vec![(10_usize, 0.1_f64)];
        let reranked = vec![ranked(1, 0.9), ranked(99, 0.8)];

        let mapped = map_reranked_to_candidates(reranked, &candidates);

        assert!(mapped.is_empty());
    }
}
1387
#[cfg(test)]
mod sidecar_header_tests {
    use super::is_chunk_schema_header;

    /// Header lines are recognised with and without a space after the colon.
    #[test]
    fn detects_schema_version_header_line() {
        for header in [r#"{"schema_version":2}"#, r#"{"schema_version": 1}"#] {
            assert!(is_chunk_schema_header(header));
        }
    }

    /// A real chunk whose text happens to be the word `schema_version` must
    /// still be treated as a chunk, not as a header.
    #[test]
    fn does_not_flag_a_chunk_whose_text_is_the_token() {
        let chunk_line = r#"{"chunk_key":"ATT1:c0","item_key":"ITEM1","attachment_key":"ATT1","block_keys":[],"section_path":[],"text":"schema_version","page_range":[0,0],"evidence_refs":[]}"#;
        assert!(!is_chunk_schema_header(chunk_line));
    }

    /// An ordinary chunk line is not a header.
    #[test]
    fn does_not_flag_a_plain_chunk_line() {
        let plain = r#"{"chunk_key":"x","text":"hello world"}"#;
        assert!(!is_chunk_schema_header(plain));
    }
}