Skip to main content

noether_engine/index/
mod.rs

1pub mod cache;
2pub mod embedding;
3pub mod search;
4pub mod text;
5
6use embedding::{EmbeddingError, EmbeddingProvider};
7use noether_core::stage::{Stage, StageId, StageLifecycle};
8use noether_store::StageStore;
9use search::SubIndex;
10use std::collections::BTreeMap;
11use std::collections::HashMap;
12
/// Configuration for search result fusion weights.
///
/// Each search hit gets one similarity score per sub-index (signature,
/// description, examples). Negative scores are clamped to 0.0, then each score
/// is multiplied by its weight and the products are summed into the fused score.
/// The defaults (0.3 / 0.5 / 0.2) sum to 1.0.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct IndexConfig {
    /// Weight for type signature similarity (default: 0.3).
    pub signature_weight: f32,
    /// Weight for description similarity (default: 0.5).
    pub semantic_weight: f32,
    /// Weight for example similarity (default: 0.2).
    pub example_weight: f32,
}
22
23impl Default for IndexConfig {
24    fn default() -> Self {
25        Self {
26            signature_weight: 0.3,
27            semantic_weight: 0.5,
28            example_weight: 0.2,
29        }
30    }
31}
32
33/// A search result with fused scores from all three indexes.
34#[derive(Debug, Clone)]
35pub struct SearchResult {
36    pub stage_id: StageId,
37    pub score: f32,
38    pub signature_score: f32,
39    pub semantic_score: f32,
40    pub example_score: f32,
41}
42
43/// Three-index semantic search over the stage store.
44pub struct SemanticIndex {
45    provider: Box<dyn EmbeddingProvider>,
46    signature_index: SubIndex,
47    semantic_index: SubIndex,
48    example_index: SubIndex,
49    config: IndexConfig,
50    /// Exact-match tag → stage IDs lookup for fast `search_filtered` pre-filtering.
51    tag_map: HashMap<String, Vec<StageId>>,
52}
53
54impl SemanticIndex {
55    /// Build the index from an owned list of stages (useful in async contexts
56    /// where holding a `&dyn StageStore` across `.await` is not possible).
57    pub fn from_stages(
58        stages: Vec<Stage>,
59        provider: Box<dyn EmbeddingProvider>,
60        config: IndexConfig,
61    ) -> Result<Self, EmbeddingError> {
62        let mut index = Self {
63            provider,
64            signature_index: SubIndex::new(),
65            semantic_index: SubIndex::new(),
66            example_index: SubIndex::new(),
67            config,
68            tag_map: HashMap::new(),
69        };
70        for stage in &stages {
71            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
72                continue;
73            }
74            index.add_stage(stage)?;
75        }
76        Ok(index)
77    }
78
79    /// Build the index from all non-tombstoned stages in a store.
80    pub fn build(
81        store: &dyn StageStore,
82        provider: Box<dyn EmbeddingProvider>,
83        config: IndexConfig,
84    ) -> Result<Self, EmbeddingError> {
85        let mut index = Self {
86            provider,
87            signature_index: SubIndex::new(),
88            semantic_index: SubIndex::new(),
89            example_index: SubIndex::new(),
90            config,
91            tag_map: HashMap::new(),
92        };
93        for stage in store.list(None) {
94            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
95                continue;
96            }
97            index.add_stage(stage)?;
98        }
99        Ok(index)
100    }
101
102    /// Build the index in a single pass: collect every signature/description/
103    /// example text upfront, dispatch all cache misses through
104    /// `inner.embed_batch` in chunks of `chunk_size`, then assemble the three
105    /// sub-indexes. Used by noether-cloud's registry on cold start so that
106    /// 486 stages × 3 texts = 1458 individual API calls collapse into ~46
107    /// batch calls of 32 texts each — well within typical rate limits.
108    pub fn from_stages_batched(
109        stages: Vec<Stage>,
110        cached_provider: cache::CachedEmbeddingProvider,
111        config: IndexConfig,
112        chunk_size: usize,
113    ) -> Result<Self, EmbeddingError> {
114        Self::from_stages_batched_paced(
115            stages,
116            cached_provider,
117            config,
118            chunk_size,
119            std::time::Duration::ZERO,
120        )
121    }
122
123    /// Like `from_stages_batched`, but waits `inter_batch_delay` between
124    /// successive batch calls and commits cache entries to disk after each
125    /// batch. Use this with rate-limited remote providers (e.g. Mistral
126    /// free tier ≈ 1 req/s → pass ~1100 ms).
127    pub fn from_stages_batched_paced(
128        stages: Vec<Stage>,
129        mut cached_provider: cache::CachedEmbeddingProvider,
130        config: IndexConfig,
131        chunk_size: usize,
132        inter_batch_delay: std::time::Duration,
133    ) -> Result<Self, EmbeddingError> {
134        // Filter active stages once and pre-compute all three texts per stage.
135        let active: Vec<&Stage> = stages
136            .iter()
137            .filter(|s| !matches!(s.lifecycle, StageLifecycle::Tombstone))
138            .collect();
139
140        let mut all_texts: Vec<String> = Vec::with_capacity(active.len() * 3);
141        for s in &active {
142            all_texts.push(text::signature_text(s));
143            all_texts.push(text::description_text(s));
144            all_texts.push(text::examples_text(s));
145        }
146        let text_refs: Vec<&str> = all_texts.iter().map(|s| s.as_str()).collect();
147        let embeddings =
148            cached_provider.embed_batch_cached_paced(&text_refs, chunk_size, inter_batch_delay)?;
149        cached_provider.flush();
150
151        // Distribute back into the three sub-indexes in stride 3.
152        let mut signature_index = SubIndex::new();
153        let mut semantic_index = SubIndex::new();
154        let mut example_index = SubIndex::new();
155        let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
156
157        for (i, s) in active.iter().enumerate() {
158            signature_index.add(s.id.clone(), embeddings[i * 3].clone());
159            semantic_index.add(s.id.clone(), embeddings[i * 3 + 1].clone());
160            example_index.add(s.id.clone(), embeddings[i * 3 + 2].clone());
161            for tag in &s.tags {
162                tag_map.entry(tag.clone()).or_default().push(s.id.clone());
163            }
164        }
165
166        Ok(Self {
167            provider: Box::new(cached_provider),
168            signature_index,
169            semantic_index,
170            example_index,
171            config,
172            tag_map,
173        })
174    }
175
176    /// Build using a CachedEmbeddingProvider for persistent embedding cache.
177    pub fn build_cached(
178        store: &dyn StageStore,
179        mut cached_provider: cache::CachedEmbeddingProvider,
180        config: IndexConfig,
181    ) -> Result<Self, EmbeddingError> {
182        let mut signature_index = SubIndex::new();
183        let mut semantic_index = SubIndex::new();
184        let mut example_index = SubIndex::new();
185        let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
186
187        for stage in store.list(None) {
188            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
189                continue;
190            }
191            let sig_emb = cached_provider.embed_cached(&text::signature_text(stage))?;
192            let desc_emb = cached_provider.embed_cached(&text::description_text(stage))?;
193            let ex_emb = cached_provider.embed_cached(&text::examples_text(stage))?;
194
195            signature_index.add(stage.id.clone(), sig_emb);
196            semantic_index.add(stage.id.clone(), desc_emb);
197            example_index.add(stage.id.clone(), ex_emb);
198
199            for tag in &stage.tags {
200                tag_map
201                    .entry(tag.clone())
202                    .or_default()
203                    .push(stage.id.clone());
204            }
205        }
206
207        cached_provider.flush();
208
209        // Wrap the inner provider for future queries
210        let provider: Box<dyn EmbeddingProvider> = Box::new(cached_provider);
211
212        Ok(Self {
213            provider,
214            signature_index,
215            semantic_index,
216            example_index,
217            config,
218            tag_map,
219        })
220    }
221
222    /// Add a single stage to all three indexes.
223    pub fn add_stage(&mut self, stage: &Stage) -> Result<(), EmbeddingError> {
224        let sig_text = text::signature_text(stage);
225        let desc_text = text::description_text(stage);
226        let ex_text = text::examples_text(stage);
227
228        let sig_emb = self.provider.embed(&sig_text)?;
229        let desc_emb = self.provider.embed(&desc_text)?;
230        let ex_emb = self.provider.embed(&ex_text)?;
231
232        self.signature_index.add(stage.id.clone(), sig_emb);
233        self.semantic_index.add(stage.id.clone(), desc_emb);
234        self.example_index.add(stage.id.clone(), ex_emb);
235
236        for tag in &stage.tags {
237            self.tag_map
238                .entry(tag.clone())
239                .or_default()
240                .push(stage.id.clone());
241        }
242
243        Ok(())
244    }
245
246    /// Remove a stage from all three indexes.
247    pub fn remove_stage(&mut self, stage_id: &StageId) {
248        self.signature_index.remove(stage_id);
249        self.semantic_index.remove(stage_id);
250        self.example_index.remove(stage_id);
251
252        for ids in self.tag_map.values_mut() {
253            ids.retain(|id| id != stage_id);
254        }
255        self.tag_map.retain(|_, ids| !ids.is_empty());
256    }
257
258    /// Number of stages indexed.
259    pub fn len(&self) -> usize {
260        self.signature_index.len()
261    }
262
263    pub fn is_empty(&self) -> bool {
264        self.len() == 0
265    }
266
267    /// Search across all three indexes and return ranked results.
268    pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>, EmbeddingError> {
269        self.search_filtered(query, top_k, None)
270    }
271
272    /// Like `search`, but restricts candidates to stages carrying `tag` (exact match).
273    /// Passing `tag: None` is equivalent to `search`.
274    pub fn search_filtered(
275        &self,
276        query: &str,
277        top_k: usize,
278        tag: Option<&str>,
279    ) -> Result<Vec<SearchResult>, EmbeddingError> {
280        let query_emb = self.provider.embed(query)?;
281        let fetch_k = top_k * 2;
282
283        let sig_results = self.signature_index.search(&query_emb, fetch_k);
284        let sem_results = self.semantic_index.search(&query_emb, fetch_k);
285        let ex_results = self.example_index.search(&query_emb, fetch_k);
286
287        // Optional tag allow-list for filtering
288        let allowed: Option<std::collections::BTreeSet<&str>> = tag.map(|t| {
289            self.tag_map
290                .get(t)
291                .map(|ids| ids.iter().map(|id| id.0.as_str()).collect())
292                .unwrap_or_default()
293        });
294
295        // Collect scores per stage_id
296        let mut scores: BTreeMap<String, (f32, f32, f32)> = BTreeMap::new();
297        for r in &sig_results {
298            scores.entry(r.stage_id.0.clone()).or_default().0 = r.score;
299        }
300        for r in &sem_results {
301            scores.entry(r.stage_id.0.clone()).or_default().1 = r.score;
302        }
303        for r in &ex_results {
304            scores.entry(r.stage_id.0.clone()).or_default().2 = r.score;
305        }
306
307        // Fuse scores
308        let mut results: Vec<SearchResult> = scores
309            .into_iter()
310            .filter(|(id, _)| {
311                allowed
312                    .as_ref()
313                    .map(|a| a.contains(id.as_str()))
314                    .unwrap_or(true)
315            })
316            .map(|(id, (sig, sem, ex))| {
317                let fused = self.config.signature_weight * sig.max(0.0)
318                    + self.config.semantic_weight * sem.max(0.0)
319                    + self.config.example_weight * ex.max(0.0);
320                SearchResult {
321                    stage_id: StageId(id),
322                    score: fused,
323                    signature_score: sig,
324                    semantic_score: sem,
325                    example_score: ex,
326                }
327            })
328            .collect();
329
330        results.sort_by(|a, b| {
331            b.score
332                .partial_cmp(&a.score)
333                .unwrap_or(std::cmp::Ordering::Equal)
334        });
335        results.truncate(top_k);
336        Ok(results)
337    }
338
339    /// Return all stage IDs that carry `tag` (exact match).
340    pub fn search_by_tag(&self, tag: &str) -> Vec<StageId> {
341        self.tag_map.get(tag).cloned().unwrap_or_default()
342    }
343
344    /// Return the set of all known tags across indexed stages.
345    pub fn all_tags(&self) -> Vec<String> {
346        let mut tags: Vec<String> = self.tag_map.keys().cloned().collect();
347        tags.sort();
348        tags
349    }
350
351    /// Check whether a candidate description is a near-duplicate of an existing stage.
352    ///
353    /// Returns `Some((stage_id, similarity))` if any existing stage's semantic embedding
354    /// exceeds `threshold` (default 0.92). Returns `None` if the description is novel enough.
355    pub fn check_duplicate_before_insert(
356        &self,
357        description: &str,
358        threshold: f32,
359    ) -> Result<Option<(StageId, f32)>, EmbeddingError> {
360        let emb = self.provider.embed(description)?;
361        let results = self.semantic_index.search(&emb, 1);
362        if let Some(top) = results.first() {
363            if top.score >= threshold {
364                return Ok(Some((top.stage_id.clone(), top.score)));
365            }
366        }
367        Ok(None)
368    }
369
370    /// Scan all active stages for near-duplicate pairs.
371    ///
372    /// Returns pairs `(id_a, id_b, similarity)` where semantic similarity >= `threshold`.
373    /// Each pair appears only once (id_a < id_b lexicographically).
374    pub fn find_near_duplicates(&self, threshold: f32) -> Vec<(StageId, StageId, f32)> {
375        use search::cosine_similarity;
376
377        let entries = self.semantic_index.entries().to_vec();
378        let mut pairs: Vec<(StageId, StageId, f32)> = Vec::new();
379
380        for i in 0..entries.len() {
381            for j in (i + 1)..entries.len() {
382                let sim = cosine_similarity(&entries[i].embedding, &entries[j].embedding);
383                if sim >= threshold {
384                    let (a, b) = if entries[i].stage_id.0 < entries[j].stage_id.0 {
385                        (entries[i].stage_id.clone(), entries[j].stage_id.clone())
386                    } else {
387                        (entries[j].stage_id.clone(), entries[i].stage_id.clone())
388                    };
389                    pairs.push((a, b, sim));
390                }
391            }
392        }
393
394        // Sort by similarity descending
395        pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
396        pairs
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use embedding::MockEmbeddingProvider;
    use noether_core::effects::EffectSet;
    use noether_core::stage::{CostEstimate, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use std::collections::BTreeSet;

    /// An active, pure stage with the given id, description and types, and no
    /// examples, tags, aliases, signing data or implementation code.
    fn make_stage(id: &str, desc: &str, input: NType, output: NType) -> Stage {
        let signature = StageSignature {
            input,
            output,
            effects: EffectSet::pure(),
            implementation_hash: format!("impl_{id}"),
        };
        let cost = CostEstimate {
            time_ms_p50: None,
            tokens_est: None,
            memory_mb: None,
        };
        Stage {
            id: StageId(id.into()),
            canonical_id: None,
            signature,
            capabilities: BTreeSet::new(),
            cost,
            description: desc.into(),
            examples: vec![],
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
        }
    }

    /// Return `stage` with its tag list replaced by `tags`.
    fn with_tags(mut stage: Stage, tags: &[&str]) -> Stage {
        stage.tags = tags.iter().map(|t| t.to_string()).collect();
        stage
    }

    /// Store with three active stages: s1 (text → number), s2 (http), s3 (list sort).
    fn test_store() -> MemoryStore {
        let list_of_any = || NType::List(Box::new(NType::Any));
        let fixtures = vec![
            make_stage("s1", "convert text to number", NType::Text, NType::Number),
            make_stage("s2", "make http request", NType::Text, NType::Text),
            make_stage("s3", "sort a list of items", list_of_any(), list_of_any()),
        ];
        let mut store = MemoryStore::new();
        for stage in fixtures {
            store.put(stage).unwrap();
        }
        store
    }

    /// Index `store` with `MockEmbeddingProvider::new(32)` and default weights.
    fn index_store(store: &MemoryStore) -> SemanticIndex {
        SemanticIndex::build(
            store,
            Box::new(MockEmbeddingProvider::new(32)),
            IndexConfig::default(),
        )
        .unwrap()
    }

    /// Index an owned stage list with `MockEmbeddingProvider::new(32)` and default weights.
    fn index_stages(stages: Vec<Stage>) -> SemanticIndex {
        SemanticIndex::from_stages(
            stages,
            Box::new(MockEmbeddingProvider::new(32)),
            IndexConfig::default(),
        )
        .unwrap()
    }

    #[test]
    fn build_indexes_all_stages() {
        let index = index_store(&test_store());
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn add_stage_increments_count() {
        let mut index = index_store(&test_store());
        assert_eq!(index.len(), 3);
        let extra = make_stage("s4", "new stage", NType::Bool, NType::Text);
        index.add_stage(&extra).unwrap();
        assert_eq!(index.len(), 4);
    }

    #[test]
    fn remove_stage_decrements_count() {
        let mut index = index_store(&test_store());
        index.remove_stage(&StageId("s1".into()));
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn search_returns_results() {
        let hits = index_store(&test_store()).search("convert text", 10).unwrap();
        assert!(!hits.is_empty());
    }

    #[test]
    fn search_respects_top_k() {
        let hits = index_store(&test_store()).search("anything", 2).unwrap();
        assert!(hits.len() <= 2);
    }

    #[test]
    fn search_self_is_top_result() {
        let index = SemanticIndex::build(
            &test_store(),
            Box::new(MockEmbeddingProvider::new(128)),
            IndexConfig::default(),
        )
        .unwrap();
        // The mock maps identical text to an identical embedding, so querying
        // with a stage's exact description should give a near-1.0 semantic score.
        let results = index.search("convert text to number", 3).unwrap();
        assert!(!results.is_empty());
        let top = &results[0];
        assert!(
            top.semantic_score > 0.9,
            "Expected high semantic score for exact match, got {}",
            top.semantic_score
        );
    }

    #[test]
    fn tombstoned_stages_not_indexed() {
        let active = make_stage("s1", "active stage", NType::Text, NType::Text);
        let tombstoned = Stage {
            id: StageId("s2".into()),
            description: "tombstoned stage".into(),
            lifecycle: StageLifecycle::Tombstone,
            ..active.clone()
        };
        let mut store = MemoryStore::new();
        store.put(active).unwrap();
        store.put(tombstoned).unwrap();

        assert_eq!(index_store(&store).len(), 1);
    }

    #[test]
    fn search_by_tag_returns_matching_stages() {
        let index = index_stages(vec![
            with_tags(
                make_stage("s1", "http get request", NType::Text, NType::Text),
                &["network", "io"],
            ),
            with_tags(
                make_stage("s2", "text length", NType::Text, NType::Number),
                &["text", "pure"],
            ),
        ]);

        assert_eq!(index.search_by_tag("network"), vec![StageId("s1".into())]);
        assert_eq!(index.search_by_tag("pure"), vec![StageId("s2".into())]);
        assert!(index.search_by_tag("nonexistent").is_empty());
    }

    #[test]
    fn all_tags_returns_sorted_set() {
        let index = index_stages(vec![with_tags(
            make_stage("s1", "a", NType::Text, NType::Text),
            &["zebra", "apple"],
        )]);
        assert_eq!(index.all_tags(), vec!["apple", "zebra"]);
    }

    #[test]
    fn search_filtered_restricts_to_tag() {
        let index = index_stages(vec![
            with_tags(
                make_stage("s1", "http get request", NType::Text, NType::Text),
                &["network"],
            ),
            make_stage("s2", "sort list", NType::Text, NType::Text),
        ]);

        let s1 = StageId("s1".into());
        let network_only = index
            .search_filtered("anything", 10, Some("network"))
            .unwrap();
        assert!(network_only.iter().all(|r| r.stage_id == s1));

        let unfiltered = index.search_filtered("anything", 10, None).unwrap();
        assert_eq!(unfiltered.len(), 2);
    }

    #[test]
    fn remove_stage_cleans_tag_map() {
        let mut index = index_stages(vec![with_tags(
            make_stage("s1", "a", NType::Text, NType::Text),
            &["mytag"],
        )]);
        assert_eq!(index.search_by_tag("mytag").len(), 1);

        index.remove_stage(&StageId("s1".into()));
        assert!(index.search_by_tag("mytag").is_empty());
        assert!(index.all_tags().is_empty());
    }
}