1pub mod cache;
2pub mod embedding;
3pub mod search;
4pub mod text;
5
6use embedding::{EmbeddingError, EmbeddingProvider};
7use noether_core::stage::{Stage, StageId, StageLifecycle};
8use noether_store::StageStore;
9use search::SubIndex;
10use std::collections::BTreeMap;
11use std::collections::HashMap;
12
/// Relative weights used to fuse the three per-aspect similarity scores of a
/// stage into one ranking score.
///
/// Each weight multiplies the similarity reported by the matching sub-index
/// (after negative similarities have been clamped to zero) when search
/// results are combined.
#[derive(Debug, Clone, PartialEq)]
pub struct IndexConfig {
    /// Weight of the similarity against the stage's signature text.
    pub signature_weight: f32,
    /// Weight of the similarity against the stage's description text.
    pub semantic_weight: f32,
    /// Weight of the similarity against the stage's examples text.
    pub example_weight: f32,
}

impl Default for IndexConfig {
    /// Default weighting: 0.3 signature, 0.5 semantic, 0.2 example.
    fn default() -> Self {
        Self {
            signature_weight: 0.3,
            semantic_weight: 0.5,
            example_weight: 0.2,
        }
    }
}
32
33#[derive(Debug, Clone)]
35pub struct SearchResult {
36 pub stage_id: StageId,
37 pub score: f32,
38 pub signature_score: f32,
39 pub semantic_score: f32,
40 pub example_score: f32,
41}
42
43pub struct SemanticIndex {
45 provider: Box<dyn EmbeddingProvider>,
46 signature_index: SubIndex,
47 semantic_index: SubIndex,
48 example_index: SubIndex,
49 config: IndexConfig,
50 tag_map: HashMap<String, Vec<StageId>>,
52}
53
54impl SemanticIndex {
55 pub fn from_stages(
58 stages: Vec<Stage>,
59 provider: Box<dyn EmbeddingProvider>,
60 config: IndexConfig,
61 ) -> Result<Self, EmbeddingError> {
62 let mut index = Self {
63 provider,
64 signature_index: SubIndex::new(),
65 semantic_index: SubIndex::new(),
66 example_index: SubIndex::new(),
67 config,
68 tag_map: HashMap::new(),
69 };
70 for stage in &stages {
71 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
72 continue;
73 }
74 index.add_stage(stage)?;
75 }
76 Ok(index)
77 }
78
79 pub fn build(
81 store: &dyn StageStore,
82 provider: Box<dyn EmbeddingProvider>,
83 config: IndexConfig,
84 ) -> Result<Self, EmbeddingError> {
85 let mut index = Self {
86 provider,
87 signature_index: SubIndex::new(),
88 semantic_index: SubIndex::new(),
89 example_index: SubIndex::new(),
90 config,
91 tag_map: HashMap::new(),
92 };
93 for stage in store.list(None) {
94 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
95 continue;
96 }
97 index.add_stage(stage)?;
98 }
99 Ok(index)
100 }
101
102 pub fn from_stages_batched(
109 stages: Vec<Stage>,
110 cached_provider: cache::CachedEmbeddingProvider,
111 config: IndexConfig,
112 chunk_size: usize,
113 ) -> Result<Self, EmbeddingError> {
114 Self::from_stages_batched_paced(
115 stages,
116 cached_provider,
117 config,
118 chunk_size,
119 std::time::Duration::ZERO,
120 )
121 }
122
123 pub fn from_stages_batched_paced(
128 stages: Vec<Stage>,
129 mut cached_provider: cache::CachedEmbeddingProvider,
130 config: IndexConfig,
131 chunk_size: usize,
132 inter_batch_delay: std::time::Duration,
133 ) -> Result<Self, EmbeddingError> {
134 let active: Vec<&Stage> = stages
136 .iter()
137 .filter(|s| !matches!(s.lifecycle, StageLifecycle::Tombstone))
138 .collect();
139
140 let mut all_texts: Vec<String> = Vec::with_capacity(active.len() * 3);
141 for s in &active {
142 all_texts.push(text::signature_text(s));
143 all_texts.push(text::description_text(s));
144 all_texts.push(text::examples_text(s));
145 }
146 let text_refs: Vec<&str> = all_texts.iter().map(|s| s.as_str()).collect();
147 let embeddings =
148 cached_provider.embed_batch_cached_paced(&text_refs, chunk_size, inter_batch_delay)?;
149 cached_provider.flush();
150
151 let mut signature_index = SubIndex::new();
153 let mut semantic_index = SubIndex::new();
154 let mut example_index = SubIndex::new();
155 let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
156
157 for (i, s) in active.iter().enumerate() {
158 signature_index.add(s.id.clone(), embeddings[i * 3].clone());
159 semantic_index.add(s.id.clone(), embeddings[i * 3 + 1].clone());
160 example_index.add(s.id.clone(), embeddings[i * 3 + 2].clone());
161 for tag in &s.tags {
162 tag_map.entry(tag.clone()).or_default().push(s.id.clone());
163 }
164 }
165
166 Ok(Self {
167 provider: Box::new(cached_provider),
168 signature_index,
169 semantic_index,
170 example_index,
171 config,
172 tag_map,
173 })
174 }
175
176 pub fn build_cached(
178 store: &dyn StageStore,
179 mut cached_provider: cache::CachedEmbeddingProvider,
180 config: IndexConfig,
181 ) -> Result<Self, EmbeddingError> {
182 let mut signature_index = SubIndex::new();
183 let mut semantic_index = SubIndex::new();
184 let mut example_index = SubIndex::new();
185 let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
186
187 for stage in store.list(None) {
188 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
189 continue;
190 }
191 let sig_emb = cached_provider.embed_cached(&text::signature_text(stage))?;
192 let desc_emb = cached_provider.embed_cached(&text::description_text(stage))?;
193 let ex_emb = cached_provider.embed_cached(&text::examples_text(stage))?;
194
195 signature_index.add(stage.id.clone(), sig_emb);
196 semantic_index.add(stage.id.clone(), desc_emb);
197 example_index.add(stage.id.clone(), ex_emb);
198
199 for tag in &stage.tags {
200 tag_map
201 .entry(tag.clone())
202 .or_default()
203 .push(stage.id.clone());
204 }
205 }
206
207 cached_provider.flush();
208
209 let provider: Box<dyn EmbeddingProvider> = Box::new(cached_provider);
211
212 Ok(Self {
213 provider,
214 signature_index,
215 semantic_index,
216 example_index,
217 config,
218 tag_map,
219 })
220 }
221
222 pub fn add_stage(&mut self, stage: &Stage) -> Result<(), EmbeddingError> {
224 let sig_text = text::signature_text(stage);
225 let desc_text = text::description_text(stage);
226 let ex_text = text::examples_text(stage);
227
228 let sig_emb = self.provider.embed(&sig_text)?;
229 let desc_emb = self.provider.embed(&desc_text)?;
230 let ex_emb = self.provider.embed(&ex_text)?;
231
232 self.signature_index.add(stage.id.clone(), sig_emb);
233 self.semantic_index.add(stage.id.clone(), desc_emb);
234 self.example_index.add(stage.id.clone(), ex_emb);
235
236 for tag in &stage.tags {
237 self.tag_map
238 .entry(tag.clone())
239 .or_default()
240 .push(stage.id.clone());
241 }
242
243 Ok(())
244 }
245
246 pub fn remove_stage(&mut self, stage_id: &StageId) {
248 self.signature_index.remove(stage_id);
249 self.semantic_index.remove(stage_id);
250 self.example_index.remove(stage_id);
251
252 for ids in self.tag_map.values_mut() {
253 ids.retain(|id| id != stage_id);
254 }
255 self.tag_map.retain(|_, ids| !ids.is_empty());
256 }
257
258 pub fn len(&self) -> usize {
260 self.signature_index.len()
261 }
262
263 pub fn is_empty(&self) -> bool {
264 self.len() == 0
265 }
266
267 pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>, EmbeddingError> {
269 self.search_filtered(query, top_k, None)
270 }
271
272 pub fn search_filtered(
275 &self,
276 query: &str,
277 top_k: usize,
278 tag: Option<&str>,
279 ) -> Result<Vec<SearchResult>, EmbeddingError> {
280 let query_emb = self.provider.embed(query)?;
281 let fetch_k = top_k * 2;
282
283 let sig_results = self.signature_index.search(&query_emb, fetch_k);
284 let sem_results = self.semantic_index.search(&query_emb, fetch_k);
285 let ex_results = self.example_index.search(&query_emb, fetch_k);
286
287 let allowed: Option<std::collections::BTreeSet<&str>> = tag.map(|t| {
289 self.tag_map
290 .get(t)
291 .map(|ids| ids.iter().map(|id| id.0.as_str()).collect())
292 .unwrap_or_default()
293 });
294
295 let mut scores: BTreeMap<String, (f32, f32, f32)> = BTreeMap::new();
297 for r in &sig_results {
298 scores.entry(r.stage_id.0.clone()).or_default().0 = r.score;
299 }
300 for r in &sem_results {
301 scores.entry(r.stage_id.0.clone()).or_default().1 = r.score;
302 }
303 for r in &ex_results {
304 scores.entry(r.stage_id.0.clone()).or_default().2 = r.score;
305 }
306
307 let mut results: Vec<SearchResult> = scores
309 .into_iter()
310 .filter(|(id, _)| {
311 allowed
312 .as_ref()
313 .map(|a| a.contains(id.as_str()))
314 .unwrap_or(true)
315 })
316 .map(|(id, (sig, sem, ex))| {
317 let fused = self.config.signature_weight * sig.max(0.0)
318 + self.config.semantic_weight * sem.max(0.0)
319 + self.config.example_weight * ex.max(0.0);
320 SearchResult {
321 stage_id: StageId(id),
322 score: fused,
323 signature_score: sig,
324 semantic_score: sem,
325 example_score: ex,
326 }
327 })
328 .collect();
329
330 results.sort_by(|a, b| {
331 b.score
332 .partial_cmp(&a.score)
333 .unwrap_or(std::cmp::Ordering::Equal)
334 });
335 results.truncate(top_k);
336 Ok(results)
337 }
338
339 pub fn search_by_tag(&self, tag: &str) -> Vec<StageId> {
341 self.tag_map.get(tag).cloned().unwrap_or_default()
342 }
343
344 pub fn all_tags(&self) -> Vec<String> {
346 let mut tags: Vec<String> = self.tag_map.keys().cloned().collect();
347 tags.sort();
348 tags
349 }
350
351 pub fn check_duplicate_before_insert(
356 &self,
357 description: &str,
358 threshold: f32,
359 ) -> Result<Option<(StageId, f32)>, EmbeddingError> {
360 let emb = self.provider.embed(description)?;
361 let results = self.semantic_index.search(&emb, 1);
362 if let Some(top) = results.first() {
363 if top.score >= threshold {
364 return Ok(Some((top.stage_id.clone(), top.score)));
365 }
366 }
367 Ok(None)
368 }
369
370 pub fn find_near_duplicates(&self, threshold: f32) -> Vec<(StageId, StageId, f32)> {
375 use search::cosine_similarity;
376
377 let entries = self.semantic_index.entries().to_vec();
378 let mut pairs: Vec<(StageId, StageId, f32)> = Vec::new();
379
380 for i in 0..entries.len() {
381 for j in (i + 1)..entries.len() {
382 let sim = cosine_similarity(&entries[i].embedding, &entries[j].embedding);
383 if sim >= threshold {
384 let (a, b) = if entries[i].stage_id.0 < entries[j].stage_id.0 {
385 (entries[i].stage_id.clone(), entries[j].stage_id.clone())
386 } else {
387 (entries[j].stage_id.clone(), entries[i].stage_id.clone())
388 };
389 pairs.push((a, b, sim));
390 }
391 }
392 }
393
394 pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
396 pairs
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use embedding::MockEmbeddingProvider;
    use noether_core::effects::EffectSet;
    use noether_core::stage::{CostEstimate, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use std::collections::BTreeSet;

    /// Creates an active, untagged stage with the given id, description and
    /// input/output types; all optional metadata is left empty.
    fn make_stage(id: &str, desc: &str, input: NType, output: NType) -> Stage {
        let signature = StageSignature {
            input,
            output,
            effects: EffectSet::pure(),
            implementation_hash: format!("impl_{id}"),
        };
        let cost = CostEstimate {
            time_ms_p50: None,
            tokens_est: None,
            memory_mb: None,
        };
        Stage {
            id: StageId(id.to_owned()),
            canonical_id: None,
            signature,
            capabilities: BTreeSet::new(),
            cost,
            description: desc.into(),
            examples: vec![],
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
        }
    }

    /// In-memory store holding the three fixture stages `s1`, `s2`, `s3`.
    fn test_store() -> MemoryStore {
        let fixtures = vec![
            make_stage("s1", "convert text to number", NType::Text, NType::Number),
            make_stage("s2", "make http request", NType::Text, NType::Text),
            make_stage(
                "s3",
                "sort a list of items",
                NType::List(Box::new(NType::Any)),
                NType::List(Box::new(NType::Any)),
            ),
        ];

        let mut store = MemoryStore::new();
        for stage in fixtures {
            store.put(stage).unwrap();
        }
        store
    }

    /// Index over `store` using a 32-dimensional mock provider and default weights.
    fn index_over(store: &MemoryStore) -> SemanticIndex {
        SemanticIndex::build(
            store,
            Box::new(MockEmbeddingProvider::new(32)),
            IndexConfig::default(),
        )
        .unwrap()
    }

    /// Index over `stages` using a 32-dimensional mock provider and default weights.
    fn index_of(stages: Vec<Stage>) -> SemanticIndex {
        SemanticIndex::from_stages(
            stages,
            Box::new(MockEmbeddingProvider::new(32)),
            IndexConfig::default(),
        )
        .unwrap()
    }

    #[test]
    fn build_indexes_all_stages() {
        let index = index_over(&test_store());
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn add_stage_increments_count() {
        let mut index = index_over(&test_store());
        assert_eq!(index.len(), 3);

        let extra = make_stage("s4", "new stage", NType::Bool, NType::Text);
        index.add_stage(&extra).unwrap();
        assert_eq!(index.len(), 4);
    }

    #[test]
    fn remove_stage_decrements_count() {
        let mut index = index_over(&test_store());
        index.remove_stage(&StageId("s1".into()));
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn search_returns_results() {
        let index = index_over(&test_store());
        let results = index.search("convert text", 10).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn search_respects_top_k() {
        let index = index_over(&test_store());
        let results = index.search("anything", 2).unwrap();
        assert!(results.len() <= 2);
    }

    #[test]
    fn search_self_is_top_result() {
        // This test uses a wider (128-dimensional) mock embedding than the others.
        let store = test_store();
        let index = SemanticIndex::build(
            &store,
            Box::new(MockEmbeddingProvider::new(128)),
            IndexConfig::default(),
        )
        .unwrap();

        let results = index.search("convert text to number", 3).unwrap();
        assert!(!results.is_empty());

        let top = &results[0];
        assert!(
            top.semantic_score > 0.9,
            "Expected high semantic score for exact match, got {}",
            top.semantic_score
        );
    }

    #[test]
    fn tombstoned_stages_not_indexed() {
        let active = make_stage("s1", "active stage", NType::Text, NType::Text);

        let mut tombstoned = active.clone();
        tombstoned.id = StageId("s2".into());
        tombstoned.description = "tombstoned stage".into();
        tombstoned.lifecycle = StageLifecycle::Tombstone;

        let mut store = MemoryStore::new();
        store.put(active).unwrap();
        store.put(tombstoned).unwrap();

        let index = index_over(&store);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn search_by_tag_returns_matching_stages() {
        let mut s1 = make_stage("s1", "http get request", NType::Text, NType::Text);
        s1.tags = vec!["network".into(), "io".into()];
        let mut s2 = make_stage("s2", "text length", NType::Text, NType::Number);
        s2.tags = vec!["text".into(), "pure".into()];

        let index = index_of(vec![s1, s2]);

        assert_eq!(index.search_by_tag("network"), vec![StageId("s1".into())]);
        assert_eq!(index.search_by_tag("pure"), vec![StageId("s2".into())]);
        assert!(index.search_by_tag("nonexistent").is_empty());
    }

    #[test]
    fn all_tags_returns_sorted_set() {
        let mut s1 = make_stage("s1", "a", NType::Text, NType::Text);
        s1.tags = vec!["zebra".into(), "apple".into()];

        let index = index_of(vec![s1]);
        assert_eq!(index.all_tags(), vec!["apple", "zebra"]);
    }

    #[test]
    fn search_filtered_restricts_to_tag() {
        let mut s1 = make_stage("s1", "http get request", NType::Text, NType::Text);
        s1.tags = vec!["network".into()];
        let s2 = make_stage("s2", "sort list", NType::Text, NType::Text);

        let index = index_of(vec![s1, s2]);

        let filtered = index
            .search_filtered("anything", 10, Some("network"))
            .unwrap();
        let expected = StageId("s1".into());
        assert!(filtered.iter().all(|hit| hit.stage_id == expected));

        let unfiltered = index.search_filtered("anything", 10, None).unwrap();
        assert_eq!(unfiltered.len(), 2);
    }

    #[test]
    fn remove_stage_cleans_tag_map() {
        let mut s1 = make_stage("s1", "a", NType::Text, NType::Text);
        s1.tags = vec!["mytag".into()];

        let mut index = index_of(vec![s1]);
        assert_eq!(index.search_by_tag("mytag").len(), 1);

        index.remove_stage(&StageId("s1".into()));
        assert!(index.search_by_tag("mytag").is_empty());
        assert!(index.all_tags().is_empty());
    }
}