Skip to main content

noether_engine/index/
mod.rs

1pub mod cache;
2pub mod embedding;
3pub mod search;
4pub mod text;
5
6use embedding::{EmbeddingError, EmbeddingProvider};
7use noether_core::stage::{Stage, StageId, StageLifecycle};
8use noether_store::StageStore;
9use search::SubIndex;
10use std::collections::BTreeMap;
11use std::collections::HashMap;
12
/// Weights used to fuse the three per-index similarity scores into one ranking score.
///
/// `SemanticIndex::search_filtered` multiplies each sub-index score (clamped to be
/// non-negative) by the matching weight and sums the three products.
#[derive(Debug, Clone, PartialEq)]
pub struct IndexConfig {
    /// Weight for type signature similarity (default: 0.3).
    pub signature_weight: f32,
    /// Weight for description similarity (default: 0.5).
    pub semantic_weight: f32,
    /// Weight for example similarity (default: 0.2).
    pub example_weight: f32,
}

impl Default for IndexConfig {
    /// Returns the default weighting: 0.3 signature, 0.5 description, 0.2 examples.
    fn default() -> Self {
        Self {
            signature_weight: 0.3,
            semantic_weight: 0.5,
            example_weight: 0.2,
        }
    }
}
32
33/// A search result with fused scores from all three indexes.
34#[derive(Debug, Clone)]
35pub struct SearchResult {
36    pub stage_id: StageId,
37    pub score: f32,
38    pub signature_score: f32,
39    pub semantic_score: f32,
40    pub example_score: f32,
41}
42
43/// Three-index semantic search over the stage store.
44pub struct SemanticIndex {
45    provider: Box<dyn EmbeddingProvider>,
46    signature_index: SubIndex,
47    semantic_index: SubIndex,
48    example_index: SubIndex,
49    config: IndexConfig,
50    /// Exact-match tag → stage IDs lookup for fast `search_filtered` pre-filtering.
51    tag_map: HashMap<String, Vec<StageId>>,
52}
53
54impl SemanticIndex {
55    /// Build the index from an owned list of stages (useful in async contexts
56    /// where holding a `&dyn StageStore` across `.await` is not possible).
57    pub fn from_stages(
58        stages: Vec<Stage>,
59        provider: Box<dyn EmbeddingProvider>,
60        config: IndexConfig,
61    ) -> Result<Self, EmbeddingError> {
62        let mut index = Self {
63            provider,
64            signature_index: SubIndex::new(),
65            semantic_index: SubIndex::new(),
66            example_index: SubIndex::new(),
67            config,
68            tag_map: HashMap::new(),
69        };
70        for stage in &stages {
71            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
72                continue;
73            }
74            index.add_stage(stage)?;
75        }
76        Ok(index)
77    }
78
79    /// Build the index from all non-tombstoned stages in a store.
80    pub fn build(
81        store: &dyn StageStore,
82        provider: Box<dyn EmbeddingProvider>,
83        config: IndexConfig,
84    ) -> Result<Self, EmbeddingError> {
85        let mut index = Self {
86            provider,
87            signature_index: SubIndex::new(),
88            semantic_index: SubIndex::new(),
89            example_index: SubIndex::new(),
90            config,
91            tag_map: HashMap::new(),
92        };
93        for stage in store.list(None) {
94            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
95                continue;
96            }
97            index.add_stage(stage)?;
98        }
99        Ok(index)
100    }
101
102    /// Build the index in a single pass: collect every signature/description/
103    /// example text upfront, dispatch all cache misses through
104    /// `inner.embed_batch` in chunks of `chunk_size`, then assemble the three
105    /// sub-indexes. Used by noether-cloud's registry on cold start so that
106    /// 486 stages × 3 texts = 1458 individual API calls collapse into ~46
107    /// batch calls of 32 texts each — well within typical rate limits.
108    pub fn from_stages_batched(
109        stages: Vec<Stage>,
110        cached_provider: cache::CachedEmbeddingProvider,
111        config: IndexConfig,
112        chunk_size: usize,
113    ) -> Result<Self, EmbeddingError> {
114        Self::from_stages_batched_paced(
115            stages,
116            cached_provider,
117            config,
118            chunk_size,
119            std::time::Duration::ZERO,
120        )
121    }
122
123    /// Like `from_stages_batched`, but waits `inter_batch_delay` between
124    /// successive batch calls and commits cache entries to disk after each
125    /// batch. Use this with rate-limited remote providers (e.g. Mistral
126    /// free tier ≈ 1 req/s → pass ~1100 ms).
127    pub fn from_stages_batched_paced(
128        stages: Vec<Stage>,
129        mut cached_provider: cache::CachedEmbeddingProvider,
130        config: IndexConfig,
131        chunk_size: usize,
132        inter_batch_delay: std::time::Duration,
133    ) -> Result<Self, EmbeddingError> {
134        // Filter active stages once and pre-compute all three texts per stage.
135        let active: Vec<&Stage> = stages
136            .iter()
137            .filter(|s| !matches!(s.lifecycle, StageLifecycle::Tombstone))
138            .collect();
139
140        let mut all_texts: Vec<String> = Vec::with_capacity(active.len() * 3);
141        for s in &active {
142            all_texts.push(text::signature_text(s));
143            all_texts.push(text::description_text(s));
144            all_texts.push(text::examples_text(s));
145        }
146        let text_refs: Vec<&str> = all_texts.iter().map(|s| s.as_str()).collect();
147        let embeddings =
148            cached_provider.embed_batch_cached_paced(&text_refs, chunk_size, inter_batch_delay)?;
149        cached_provider.flush();
150
151        // Distribute back into the three sub-indexes in stride 3.
152        let mut signature_index = SubIndex::new();
153        let mut semantic_index = SubIndex::new();
154        let mut example_index = SubIndex::new();
155        let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
156
157        for (i, s) in active.iter().enumerate() {
158            signature_index.add(s.id.clone(), embeddings[i * 3].clone());
159            semantic_index.add(s.id.clone(), embeddings[i * 3 + 1].clone());
160            example_index.add(s.id.clone(), embeddings[i * 3 + 2].clone());
161            for tag in &s.tags {
162                tag_map.entry(tag.clone()).or_default().push(s.id.clone());
163            }
164        }
165
166        Ok(Self {
167            provider: Box::new(cached_provider),
168            signature_index,
169            semantic_index,
170            example_index,
171            config,
172            tag_map,
173        })
174    }
175
176    /// Build using a CachedEmbeddingProvider for persistent embedding cache.
177    pub fn build_cached(
178        store: &dyn StageStore,
179        mut cached_provider: cache::CachedEmbeddingProvider,
180        config: IndexConfig,
181    ) -> Result<Self, EmbeddingError> {
182        let mut signature_index = SubIndex::new();
183        let mut semantic_index = SubIndex::new();
184        let mut example_index = SubIndex::new();
185        let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
186
187        for stage in store.list(None) {
188            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
189                continue;
190            }
191            let sig_emb = cached_provider.embed_cached(&text::signature_text(stage))?;
192            let desc_emb = cached_provider.embed_cached(&text::description_text(stage))?;
193            let ex_emb = cached_provider.embed_cached(&text::examples_text(stage))?;
194
195            signature_index.add(stage.id.clone(), sig_emb);
196            semantic_index.add(stage.id.clone(), desc_emb);
197            example_index.add(stage.id.clone(), ex_emb);
198
199            for tag in &stage.tags {
200                tag_map
201                    .entry(tag.clone())
202                    .or_default()
203                    .push(stage.id.clone());
204            }
205        }
206
207        cached_provider.flush();
208
209        // Wrap the inner provider for future queries
210        let provider: Box<dyn EmbeddingProvider> = Box::new(cached_provider);
211
212        Ok(Self {
213            provider,
214            signature_index,
215            semantic_index,
216            example_index,
217            config,
218            tag_map,
219        })
220    }
221
222    /// Add a single stage to all three indexes.
223    pub fn add_stage(&mut self, stage: &Stage) -> Result<(), EmbeddingError> {
224        let sig_text = text::signature_text(stage);
225        let desc_text = text::description_text(stage);
226        let ex_text = text::examples_text(stage);
227
228        let sig_emb = self.provider.embed(&sig_text)?;
229        let desc_emb = self.provider.embed(&desc_text)?;
230        let ex_emb = self.provider.embed(&ex_text)?;
231
232        self.signature_index.add(stage.id.clone(), sig_emb);
233        self.semantic_index.add(stage.id.clone(), desc_emb);
234        self.example_index.add(stage.id.clone(), ex_emb);
235
236        for tag in &stage.tags {
237            self.tag_map
238                .entry(tag.clone())
239                .or_default()
240                .push(stage.id.clone());
241        }
242
243        Ok(())
244    }
245
246    /// Remove a stage from all three indexes.
247    pub fn remove_stage(&mut self, stage_id: &StageId) {
248        self.signature_index.remove(stage_id);
249        self.semantic_index.remove(stage_id);
250        self.example_index.remove(stage_id);
251
252        for ids in self.tag_map.values_mut() {
253            ids.retain(|id| id != stage_id);
254        }
255        self.tag_map.retain(|_, ids| !ids.is_empty());
256    }
257
258    /// Number of stages indexed.
259    pub fn len(&self) -> usize {
260        self.signature_index.len()
261    }
262
263    pub fn is_empty(&self) -> bool {
264        self.len() == 0
265    }
266
267    /// Search across all three indexes and return ranked results.
268    pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>, EmbeddingError> {
269        self.search_filtered(query, top_k, None)
270    }
271
272    /// Like `search`, but restricts candidates to stages carrying `tag` (exact match).
273    /// Passing `tag: None` is equivalent to `search`.
274    pub fn search_filtered(
275        &self,
276        query: &str,
277        top_k: usize,
278        tag: Option<&str>,
279    ) -> Result<Vec<SearchResult>, EmbeddingError> {
280        let query_emb = self.provider.embed(query)?;
281        let fetch_k = top_k * 2;
282
283        let sig_results = self.signature_index.search(&query_emb, fetch_k);
284        let sem_results = self.semantic_index.search(&query_emb, fetch_k);
285        let ex_results = self.example_index.search(&query_emb, fetch_k);
286
287        // Optional tag allow-list for filtering
288        let allowed: Option<std::collections::BTreeSet<&str>> = tag.map(|t| {
289            self.tag_map
290                .get(t)
291                .map(|ids| ids.iter().map(|id| id.0.as_str()).collect())
292                .unwrap_or_default()
293        });
294
295        // Collect scores per stage_id
296        let mut scores: BTreeMap<String, (f32, f32, f32)> = BTreeMap::new();
297        for r in &sig_results {
298            scores.entry(r.stage_id.0.clone()).or_default().0 = r.score;
299        }
300        for r in &sem_results {
301            scores.entry(r.stage_id.0.clone()).or_default().1 = r.score;
302        }
303        for r in &ex_results {
304            scores.entry(r.stage_id.0.clone()).or_default().2 = r.score;
305        }
306
307        // Fuse scores
308        let mut results: Vec<SearchResult> = scores
309            .into_iter()
310            .filter(|(id, _)| {
311                allowed
312                    .as_ref()
313                    .map(|a| a.contains(id.as_str()))
314                    .unwrap_or(true)
315            })
316            .map(|(id, (sig, sem, ex))| {
317                let fused = self.config.signature_weight * sig.max(0.0)
318                    + self.config.semantic_weight * sem.max(0.0)
319                    + self.config.example_weight * ex.max(0.0);
320                SearchResult {
321                    stage_id: StageId(id),
322                    score: fused,
323                    signature_score: sig,
324                    semantic_score: sem,
325                    example_score: ex,
326                }
327            })
328            .collect();
329
330        results.sort_by(|a, b| {
331            b.score
332                .partial_cmp(&a.score)
333                .unwrap_or(std::cmp::Ordering::Equal)
334        });
335        results.truncate(top_k);
336        Ok(results)
337    }
338
339    /// Return all stage IDs that carry `tag` (exact match).
340    pub fn search_by_tag(&self, tag: &str) -> Vec<StageId> {
341        self.tag_map.get(tag).cloned().unwrap_or_default()
342    }
343
344    /// Return the set of all known tags across indexed stages.
345    pub fn all_tags(&self) -> Vec<String> {
346        let mut tags: Vec<String> = self.tag_map.keys().cloned().collect();
347        tags.sort();
348        tags
349    }
350
351    /// Check whether a candidate description is a near-duplicate of an existing stage.
352    ///
353    /// Returns `Some((stage_id, similarity))` if any existing stage's semantic embedding
354    /// exceeds `threshold` (default 0.92). Returns `None` if the description is novel enough.
355    pub fn check_duplicate_before_insert(
356        &self,
357        description: &str,
358        threshold: f32,
359    ) -> Result<Option<(StageId, f32)>, EmbeddingError> {
360        let emb = self.provider.embed(description)?;
361        let results = self.semantic_index.search(&emb, 1);
362        if let Some(top) = results.first() {
363            if top.score >= threshold {
364                return Ok(Some((top.stage_id.clone(), top.score)));
365            }
366        }
367        Ok(None)
368    }
369
370    /// Scan all active stages for near-duplicate pairs.
371    ///
372    /// Returns pairs `(id_a, id_b, similarity)` where semantic similarity >= `threshold`.
373    /// Each pair appears only once (id_a < id_b lexicographically).
374    pub fn find_near_duplicates(&self, threshold: f32) -> Vec<(StageId, StageId, f32)> {
375        use search::cosine_similarity;
376
377        let entries = self.semantic_index.entries().to_vec();
378        let mut pairs: Vec<(StageId, StageId, f32)> = Vec::new();
379
380        for i in 0..entries.len() {
381            for j in (i + 1)..entries.len() {
382                let sim = cosine_similarity(&entries[i].embedding, &entries[j].embedding);
383                if sim >= threshold {
384                    let (a, b) = if entries[i].stage_id.0 < entries[j].stage_id.0 {
385                        (entries[i].stage_id.clone(), entries[j].stage_id.clone())
386                    } else {
387                        (entries[j].stage_id.clone(), entries[i].stage_id.clone())
388                    };
389                    pairs.push((a, b, sim));
390                }
391            }
392        }
393
394        // Sort by similarity descending
395        pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
396        pairs
397    }
398}
399
400#[cfg(test)]
401mod tests {
402    use super::*;
403    use embedding::MockEmbeddingProvider;
404    use noether_core::effects::EffectSet;
405    use noether_core::stage::{CostEstimate, StageSignature};
406    use noether_core::types::NType;
407    use noether_store::MemoryStore;
408    use std::collections::BTreeSet;
409
410    fn make_stage(id: &str, desc: &str, input: NType, output: NType) -> Stage {
411        Stage {
412            id: StageId(id.into()),
413            signature_id: None,
414            signature: StageSignature {
415                input,
416                output,
417                effects: EffectSet::pure(),
418                implementation_hash: format!("impl_{id}"),
419            },
420            capabilities: BTreeSet::new(),
421            cost: CostEstimate {
422                time_ms_p50: None,
423                tokens_est: None,
424                memory_mb: None,
425            },
426            description: desc.into(),
427            examples: vec![],
428            lifecycle: StageLifecycle::Active,
429            ed25519_signature: None,
430            signer_public_key: None,
431            implementation_code: None,
432            implementation_language: None,
433            ui_style: None,
434            tags: vec![],
435            aliases: vec![],
436            name: None,
437            properties: Vec::new(),
438        }
439    }
440
441    fn test_store() -> MemoryStore {
442        let mut store = MemoryStore::new();
443        store
444            .put(make_stage(
445                "s1",
446                "convert text to number",
447                NType::Text,
448                NType::Number,
449            ))
450            .unwrap();
451        store
452            .put(make_stage(
453                "s2",
454                "make http request",
455                NType::Text,
456                NType::Text,
457            ))
458            .unwrap();
459        store
460            .put(make_stage(
461                "s3",
462                "sort a list of items",
463                NType::List(Box::new(NType::Any)),
464                NType::List(Box::new(NType::Any)),
465            ))
466            .unwrap();
467        store
468    }
469
470    #[test]
471    fn build_indexes_all_stages() {
472        let store = test_store();
473        let index = SemanticIndex::build(
474            &store,
475            Box::new(MockEmbeddingProvider::new(32)),
476            IndexConfig::default(),
477        )
478        .unwrap();
479        assert_eq!(index.len(), 3);
480    }
481
482    #[test]
483    fn add_stage_increments_count() {
484        let store = test_store();
485        let mut index = SemanticIndex::build(
486            &store,
487            Box::new(MockEmbeddingProvider::new(32)),
488            IndexConfig::default(),
489        )
490        .unwrap();
491        assert_eq!(index.len(), 3);
492        index
493            .add_stage(&make_stage("s4", "new stage", NType::Bool, NType::Text))
494            .unwrap();
495        assert_eq!(index.len(), 4);
496    }
497
498    #[test]
499    fn remove_stage_decrements_count() {
500        let store = test_store();
501        let mut index = SemanticIndex::build(
502            &store,
503            Box::new(MockEmbeddingProvider::new(32)),
504            IndexConfig::default(),
505        )
506        .unwrap();
507        index.remove_stage(&StageId("s1".into()));
508        assert_eq!(index.len(), 2);
509    }
510
511    #[test]
512    fn search_returns_results() {
513        let store = test_store();
514        let index = SemanticIndex::build(
515            &store,
516            Box::new(MockEmbeddingProvider::new(32)),
517            IndexConfig::default(),
518        )
519        .unwrap();
520        let results = index.search("convert text", 10).unwrap();
521        assert!(!results.is_empty());
522    }
523
524    #[test]
525    fn search_respects_top_k() {
526        let store = test_store();
527        let index = SemanticIndex::build(
528            &store,
529            Box::new(MockEmbeddingProvider::new(32)),
530            IndexConfig::default(),
531        )
532        .unwrap();
533        let results = index.search("anything", 2).unwrap();
534        assert!(results.len() <= 2);
535    }
536
537    #[test]
538    fn search_self_is_top_result() {
539        let store = test_store();
540        let index = SemanticIndex::build(
541            &store,
542            Box::new(MockEmbeddingProvider::new(128)),
543            IndexConfig::default(),
544        )
545        .unwrap();
546        // Searching with exact description should return that stage highly ranked
547        let results = index.search("convert text to number", 3).unwrap();
548        assert!(!results.is_empty());
549        // With mock embeddings, the exact description match should have the highest
550        // semantic score (identical hash → identical embedding → cosine sim = 1.0)
551        let top = &results[0];
552        assert!(
553            top.semantic_score > 0.9,
554            "Expected high semantic score for exact match, got {}",
555            top.semantic_score
556        );
557    }
558
559    #[test]
560    fn tombstoned_stages_not_indexed() {
561        let mut store = MemoryStore::new();
562        let mut s = make_stage("s1", "active stage", NType::Text, NType::Text);
563        store.put(s.clone()).unwrap();
564        s.id = StageId("s2".into());
565        s.description = "tombstoned stage".into();
566        s.lifecycle = StageLifecycle::Tombstone;
567        store.put(s).unwrap();
568
569        let index = SemanticIndex::build(
570            &store,
571            Box::new(MockEmbeddingProvider::new(32)),
572            IndexConfig::default(),
573        )
574        .unwrap();
575        assert_eq!(index.len(), 1);
576    }
577
578    #[test]
579    fn search_by_tag_returns_matching_stages() {
580        let mut s1 = make_stage("s1", "http get request", NType::Text, NType::Text);
581        s1.tags = vec!["network".into(), "io".into()];
582        let mut s2 = make_stage("s2", "text length", NType::Text, NType::Number);
583        s2.tags = vec!["text".into(), "pure".into()];
584
585        let stages = vec![s1, s2];
586        let index = SemanticIndex::from_stages(
587            stages,
588            Box::new(MockEmbeddingProvider::new(32)),
589            IndexConfig::default(),
590        )
591        .unwrap();
592
593        let network_ids = index.search_by_tag("network");
594        assert_eq!(network_ids.len(), 1);
595        assert_eq!(network_ids[0], StageId("s1".into()));
596
597        let pure_ids = index.search_by_tag("pure");
598        assert_eq!(pure_ids.len(), 1);
599        assert_eq!(pure_ids[0], StageId("s2".into()));
600
601        let missing = index.search_by_tag("nonexistent");
602        assert!(missing.is_empty());
603    }
604
605    #[test]
606    fn all_tags_returns_sorted_set() {
607        let mut s1 = make_stage("s1", "a", NType::Text, NType::Text);
608        s1.tags = vec!["zebra".into(), "apple".into()];
609        let index = SemanticIndex::from_stages(
610            vec![s1],
611            Box::new(MockEmbeddingProvider::new(32)),
612            IndexConfig::default(),
613        )
614        .unwrap();
615        let tags = index.all_tags();
616        assert_eq!(tags, vec!["apple", "zebra"]);
617    }
618
619    #[test]
620    fn search_filtered_restricts_to_tag() {
621        let mut s1 = make_stage("s1", "http get request", NType::Text, NType::Text);
622        s1.tags = vec!["network".into()];
623        let s2 = make_stage("s2", "sort list", NType::Text, NType::Text);
624
625        let stages = vec![s1, s2];
626        let index = SemanticIndex::from_stages(
627            stages,
628            Box::new(MockEmbeddingProvider::new(32)),
629            IndexConfig::default(),
630        )
631        .unwrap();
632
633        let filtered = index
634            .search_filtered("anything", 10, Some("network"))
635            .unwrap();
636        assert!(filtered.iter().all(|r| r.stage_id == StageId("s1".into())));
637
638        let all = index.search_filtered("anything", 10, None).unwrap();
639        assert_eq!(all.len(), 2);
640    }
641
642    #[test]
643    fn remove_stage_cleans_tag_map() {
644        let mut s1 = make_stage("s1", "a", NType::Text, NType::Text);
645        s1.tags = vec!["mytag".into()];
646        let mut index = SemanticIndex::from_stages(
647            vec![s1],
648            Box::new(MockEmbeddingProvider::new(32)),
649            IndexConfig::default(),
650        )
651        .unwrap();
652        assert_eq!(index.search_by_tag("mytag").len(), 1);
653        index.remove_stage(&StageId("s1".into()));
654        assert!(index.search_by_tag("mytag").is_empty());
655        assert!(index.all_tags().is_empty());
656    }
657}