Skip to main content

noether_engine/index/
mod.rs

1#![warn(clippy::unwrap_used)]
2#![cfg_attr(test, allow(clippy::unwrap_used))]
3
4pub mod cache;
5pub mod embedding;
6pub mod search;
7pub mod text;
8
9use embedding::{EmbeddingError, EmbeddingProvider};
10use noether_core::stage::{Stage, StageId, StageLifecycle};
11use noether_store::StageStore;
12use search::SubIndex;
13use std::collections::BTreeMap;
14use std::collections::HashMap;
15
/// Configuration for search result fusion weights.
///
/// Each weight multiplies the corresponding sub-index similarity (clamped at
/// zero) and the products are summed into the fused score. The weights are not
/// normalised, so they need not add up to 1.0; the defaults happen to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct IndexConfig {
    /// Weight for type signature similarity (default: 0.3).
    pub signature_weight: f32,
    /// Weight for description similarity (default: 0.5).
    pub semantic_weight: f32,
    /// Weight for example similarity (default: 0.2).
    pub example_weight: f32,
}
25
26impl Default for IndexConfig {
27    fn default() -> Self {
28        Self {
29            signature_weight: 0.3,
30            semantic_weight: 0.5,
31            example_weight: 0.2,
32        }
33    }
34}
35
36/// A search result with fused scores from all three indexes.
37#[derive(Debug, Clone)]
38pub struct SearchResult {
39    pub stage_id: StageId,
40    pub score: f32,
41    pub signature_score: f32,
42    pub semantic_score: f32,
43    pub example_score: f32,
44}
45
46/// Three-index semantic search over the stage store.
47pub struct SemanticIndex {
48    provider: Box<dyn EmbeddingProvider>,
49    signature_index: SubIndex,
50    semantic_index: SubIndex,
51    example_index: SubIndex,
52    config: IndexConfig,
53    /// Exact-match tag → stage IDs lookup for fast `search_filtered` pre-filtering.
54    tag_map: HashMap<String, Vec<StageId>>,
55}
56
57impl SemanticIndex {
58    /// Build the index from an owned list of stages (useful in async contexts
59    /// where holding a `&dyn StageStore` across `.await` is not possible).
60    pub fn from_stages(
61        stages: Vec<Stage>,
62        provider: Box<dyn EmbeddingProvider>,
63        config: IndexConfig,
64    ) -> Result<Self, EmbeddingError> {
65        let mut index = Self {
66            provider,
67            signature_index: SubIndex::new(),
68            semantic_index: SubIndex::new(),
69            example_index: SubIndex::new(),
70            config,
71            tag_map: HashMap::new(),
72        };
73        for stage in &stages {
74            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
75                continue;
76            }
77            index.add_stage(stage)?;
78        }
79        Ok(index)
80    }
81
82    /// Build the index from all non-tombstoned stages in a store.
83    pub fn build(
84        store: &dyn StageStore,
85        provider: Box<dyn EmbeddingProvider>,
86        config: IndexConfig,
87    ) -> Result<Self, EmbeddingError> {
88        let mut index = Self {
89            provider,
90            signature_index: SubIndex::new(),
91            semantic_index: SubIndex::new(),
92            example_index: SubIndex::new(),
93            config,
94            tag_map: HashMap::new(),
95        };
96        for stage in store.list(None) {
97            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
98                continue;
99            }
100            index.add_stage(stage)?;
101        }
102        Ok(index)
103    }
104
105    /// Build the index in a single pass: collect every signature/description/
106    /// example text upfront, dispatch all cache misses through
107    /// `inner.embed_batch` in chunks of `chunk_size`, then assemble the three
108    /// sub-indexes. Used by noether-cloud's registry on cold start so that
109    /// 486 stages × 3 texts = 1458 individual API calls collapse into ~46
110    /// batch calls of 32 texts each — well within typical rate limits.
111    pub fn from_stages_batched(
112        stages: Vec<Stage>,
113        cached_provider: cache::CachedEmbeddingProvider,
114        config: IndexConfig,
115        chunk_size: usize,
116    ) -> Result<Self, EmbeddingError> {
117        Self::from_stages_batched_paced(
118            stages,
119            cached_provider,
120            config,
121            chunk_size,
122            std::time::Duration::ZERO,
123        )
124    }
125
126    /// Like `from_stages_batched`, but waits `inter_batch_delay` between
127    /// successive batch calls and commits cache entries to disk after each
128    /// batch. Use this with rate-limited remote providers (e.g. Mistral
129    /// free tier ≈ 1 req/s → pass ~1100 ms).
130    pub fn from_stages_batched_paced(
131        stages: Vec<Stage>,
132        mut cached_provider: cache::CachedEmbeddingProvider,
133        config: IndexConfig,
134        chunk_size: usize,
135        inter_batch_delay: std::time::Duration,
136    ) -> Result<Self, EmbeddingError> {
137        // Filter active stages once and pre-compute all three texts per stage.
138        let active: Vec<&Stage> = stages
139            .iter()
140            .filter(|s| !matches!(s.lifecycle, StageLifecycle::Tombstone))
141            .collect();
142
143        let mut all_texts: Vec<String> = Vec::with_capacity(active.len() * 3);
144        for s in &active {
145            all_texts.push(text::signature_text(s));
146            all_texts.push(text::description_text(s));
147            all_texts.push(text::examples_text(s));
148        }
149        let text_refs: Vec<&str> = all_texts.iter().map(|s| s.as_str()).collect();
150        let embeddings =
151            cached_provider.embed_batch_cached_paced(&text_refs, chunk_size, inter_batch_delay)?;
152        cached_provider.flush();
153
154        // Distribute back into the three sub-indexes in stride 3.
155        let mut signature_index = SubIndex::new();
156        let mut semantic_index = SubIndex::new();
157        let mut example_index = SubIndex::new();
158        let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
159
160        for (i, s) in active.iter().enumerate() {
161            signature_index.add(s.id.clone(), embeddings[i * 3].clone());
162            semantic_index.add(s.id.clone(), embeddings[i * 3 + 1].clone());
163            example_index.add(s.id.clone(), embeddings[i * 3 + 2].clone());
164            for tag in &s.tags {
165                tag_map.entry(tag.clone()).or_default().push(s.id.clone());
166            }
167        }
168
169        Ok(Self {
170            provider: Box::new(cached_provider),
171            signature_index,
172            semantic_index,
173            example_index,
174            config,
175            tag_map,
176        })
177    }
178
179    /// Build using a CachedEmbeddingProvider for persistent embedding cache.
180    pub fn build_cached(
181        store: &dyn StageStore,
182        mut cached_provider: cache::CachedEmbeddingProvider,
183        config: IndexConfig,
184    ) -> Result<Self, EmbeddingError> {
185        let mut signature_index = SubIndex::new();
186        let mut semantic_index = SubIndex::new();
187        let mut example_index = SubIndex::new();
188        let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
189
190        for stage in store.list(None) {
191            if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
192                continue;
193            }
194            let sig_emb = cached_provider.embed_cached(&text::signature_text(stage))?;
195            let desc_emb = cached_provider.embed_cached(&text::description_text(stage))?;
196            let ex_emb = cached_provider.embed_cached(&text::examples_text(stage))?;
197
198            signature_index.add(stage.id.clone(), sig_emb);
199            semantic_index.add(stage.id.clone(), desc_emb);
200            example_index.add(stage.id.clone(), ex_emb);
201
202            for tag in &stage.tags {
203                tag_map
204                    .entry(tag.clone())
205                    .or_default()
206                    .push(stage.id.clone());
207            }
208        }
209
210        cached_provider.flush();
211
212        // Wrap the inner provider for future queries
213        let provider: Box<dyn EmbeddingProvider> = Box::new(cached_provider);
214
215        Ok(Self {
216            provider,
217            signature_index,
218            semantic_index,
219            example_index,
220            config,
221            tag_map,
222        })
223    }
224
225    /// Add a single stage to all three indexes.
226    pub fn add_stage(&mut self, stage: &Stage) -> Result<(), EmbeddingError> {
227        let sig_text = text::signature_text(stage);
228        let desc_text = text::description_text(stage);
229        let ex_text = text::examples_text(stage);
230
231        let sig_emb = self.provider.embed(&sig_text)?;
232        let desc_emb = self.provider.embed(&desc_text)?;
233        let ex_emb = self.provider.embed(&ex_text)?;
234
235        self.signature_index.add(stage.id.clone(), sig_emb);
236        self.semantic_index.add(stage.id.clone(), desc_emb);
237        self.example_index.add(stage.id.clone(), ex_emb);
238
239        for tag in &stage.tags {
240            self.tag_map
241                .entry(tag.clone())
242                .or_default()
243                .push(stage.id.clone());
244        }
245
246        Ok(())
247    }
248
249    /// Remove a stage from all three indexes.
250    pub fn remove_stage(&mut self, stage_id: &StageId) {
251        self.signature_index.remove(stage_id);
252        self.semantic_index.remove(stage_id);
253        self.example_index.remove(stage_id);
254
255        for ids in self.tag_map.values_mut() {
256            ids.retain(|id| id != stage_id);
257        }
258        self.tag_map.retain(|_, ids| !ids.is_empty());
259    }
260
261    /// Number of stages indexed.
262    pub fn len(&self) -> usize {
263        self.signature_index.len()
264    }
265
266    pub fn is_empty(&self) -> bool {
267        self.len() == 0
268    }
269
270    /// Search across all three indexes and return ranked results.
271    pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>, EmbeddingError> {
272        self.search_filtered(query, top_k, None)
273    }
274
275    /// Like `search`, but restricts candidates to stages carrying `tag` (exact match).
276    /// Passing `tag: None` is equivalent to `search`.
277    pub fn search_filtered(
278        &self,
279        query: &str,
280        top_k: usize,
281        tag: Option<&str>,
282    ) -> Result<Vec<SearchResult>, EmbeddingError> {
283        let query_emb = self.provider.embed(query)?;
284        let fetch_k = top_k * 2;
285
286        let sig_results = self.signature_index.search(&query_emb, fetch_k);
287        let sem_results = self.semantic_index.search(&query_emb, fetch_k);
288        let ex_results = self.example_index.search(&query_emb, fetch_k);
289
290        // Optional tag allow-list for filtering
291        let allowed: Option<std::collections::BTreeSet<&str>> = tag.map(|t| {
292            self.tag_map
293                .get(t)
294                .map(|ids| ids.iter().map(|id| id.0.as_str()).collect())
295                .unwrap_or_default()
296        });
297
298        // Collect scores per stage_id
299        let mut scores: BTreeMap<String, (f32, f32, f32)> = BTreeMap::new();
300        for r in &sig_results {
301            scores.entry(r.stage_id.0.clone()).or_default().0 = r.score;
302        }
303        for r in &sem_results {
304            scores.entry(r.stage_id.0.clone()).or_default().1 = r.score;
305        }
306        for r in &ex_results {
307            scores.entry(r.stage_id.0.clone()).or_default().2 = r.score;
308        }
309
310        // Fuse scores
311        let mut results: Vec<SearchResult> = scores
312            .into_iter()
313            .filter(|(id, _)| {
314                allowed
315                    .as_ref()
316                    .map(|a| a.contains(id.as_str()))
317                    .unwrap_or(true)
318            })
319            .map(|(id, (sig, sem, ex))| {
320                let fused = self.config.signature_weight * sig.max(0.0)
321                    + self.config.semantic_weight * sem.max(0.0)
322                    + self.config.example_weight * ex.max(0.0);
323                SearchResult {
324                    stage_id: StageId(id),
325                    score: fused,
326                    signature_score: sig,
327                    semantic_score: sem,
328                    example_score: ex,
329                }
330            })
331            .collect();
332
333        results.sort_by(|a, b| {
334            b.score
335                .partial_cmp(&a.score)
336                .unwrap_or(std::cmp::Ordering::Equal)
337        });
338        results.truncate(top_k);
339        Ok(results)
340    }
341
342    /// Return all stage IDs that carry `tag` (exact match).
343    pub fn search_by_tag(&self, tag: &str) -> Vec<StageId> {
344        self.tag_map.get(tag).cloned().unwrap_or_default()
345    }
346
347    /// Return the set of all known tags across indexed stages.
348    pub fn all_tags(&self) -> Vec<String> {
349        let mut tags: Vec<String> = self.tag_map.keys().cloned().collect();
350        tags.sort();
351        tags
352    }
353
354    /// Check whether a candidate description is a near-duplicate of an existing stage.
355    ///
356    /// Returns `Some((stage_id, similarity))` if any existing stage's semantic embedding
357    /// exceeds `threshold` (default 0.92). Returns `None` if the description is novel enough.
358    pub fn check_duplicate_before_insert(
359        &self,
360        description: &str,
361        threshold: f32,
362    ) -> Result<Option<(StageId, f32)>, EmbeddingError> {
363        let emb = self.provider.embed(description)?;
364        let results = self.semantic_index.search(&emb, 1);
365        if let Some(top) = results.first() {
366            if top.score >= threshold {
367                return Ok(Some((top.stage_id.clone(), top.score)));
368            }
369        }
370        Ok(None)
371    }
372
373    /// Scan all active stages for near-duplicate pairs.
374    ///
375    /// Returns pairs `(id_a, id_b, similarity)` where semantic similarity >= `threshold`.
376    /// Each pair appears only once (id_a < id_b lexicographically).
377    pub fn find_near_duplicates(&self, threshold: f32) -> Vec<(StageId, StageId, f32)> {
378        use search::cosine_similarity;
379
380        let entries = self.semantic_index.entries().to_vec();
381        let mut pairs: Vec<(StageId, StageId, f32)> = Vec::new();
382
383        for i in 0..entries.len() {
384            for j in (i + 1)..entries.len() {
385                let sim = cosine_similarity(&entries[i].embedding, &entries[j].embedding);
386                if sim >= threshold {
387                    let (a, b) = if entries[i].stage_id.0 < entries[j].stage_id.0 {
388                        (entries[i].stage_id.clone(), entries[j].stage_id.clone())
389                    } else {
390                        (entries[j].stage_id.clone(), entries[i].stage_id.clone())
391                    };
392                    pairs.push((a, b, sim));
393                }
394            }
395        }
396
397        // Sort by similarity descending
398        pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
399        pairs
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use embedding::MockEmbeddingProvider;
    use noether_core::effects::EffectSet;
    use noether_core::stage::{CostEstimate, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use std::collections::BTreeSet;

    /// Minimal active, pure stage with no examples, tags, or aliases.
    fn make_stage(id: &str, desc: &str, input: NType, output: NType) -> Stage {
        Stage {
            id: StageId(id.into()),
            signature_id: None,
            signature: StageSignature {
                input,
                output,
                effects: EffectSet::pure(),
                implementation_hash: format!("impl_{id}"),
            },
            capabilities: BTreeSet::new(),
            cost: CostEstimate {
                time_ms_p50: None,
                tokens_est: None,
                memory_mb: None,
            },
            description: desc.into(),
            examples: vec![],
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
            name: None,
            properties: Vec::new(),
        }
    }

    /// A store holding three unrelated active stages: `s1`, `s2`, `s3`.
    fn test_store() -> MemoryStore {
        let mut store = MemoryStore::new();
        store
            .put(make_stage(
                "s1",
                "convert text to number",
                NType::Text,
                NType::Number,
            ))
            .unwrap();
        store
            .put(make_stage(
                "s2",
                "make http request",
                NType::Text,
                NType::Text,
            ))
            .unwrap();
        store
            .put(make_stage(
                "s3",
                "sort a list of items",
                NType::List(Box::new(NType::Any)),
                NType::List(Box::new(NType::Any)),
            ))
            .unwrap();
        store
    }

    /// Index every stage in `store` using `provider` and the default weights.
    fn index_store_with(store: &MemoryStore, provider: MockEmbeddingProvider) -> SemanticIndex {
        SemanticIndex::build(store, Box::new(provider), IndexConfig::default()).unwrap()
    }

    /// Index every stage in `store` with a 32-dimensional mock provider.
    fn index_store(store: &MemoryStore) -> SemanticIndex {
        index_store_with(store, MockEmbeddingProvider::new(32))
    }

    /// Index an owned list of stages with a 32-dimensional mock provider.
    fn index_stages(stages: Vec<Stage>) -> SemanticIndex {
        SemanticIndex::from_stages(
            stages,
            Box::new(MockEmbeddingProvider::new(32)),
            IndexConfig::default(),
        )
        .unwrap()
    }

    #[test]
    fn build_indexes_all_stages() {
        let index = index_store(&test_store());
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn add_stage_increments_count() {
        let mut index = index_store(&test_store());
        assert_eq!(index.len(), 3);
        index
            .add_stage(&make_stage("s4", "new stage", NType::Bool, NType::Text))
            .unwrap();
        assert_eq!(index.len(), 4);
    }

    #[test]
    fn remove_stage_decrements_count() {
        let mut index = index_store(&test_store());
        index.remove_stage(&StageId("s1".into()));
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn search_returns_results() {
        let index = index_store(&test_store());
        let results = index.search("convert text", 10).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn search_respects_top_k() {
        let index = index_store(&test_store());
        let results = index.search("anything", 2).unwrap();
        assert!(results.len() <= 2);
    }

    #[test]
    fn search_with_zero_top_k_is_empty() {
        let index = index_store(&test_store());
        assert!(index.search("anything", 0).unwrap().is_empty());
    }

    #[test]
    fn search_self_is_top_result() {
        let index = index_store_with(&test_store(), MockEmbeddingProvider::new(128));
        // Searching with exact description should return that stage highly ranked
        let results = index.search("convert text to number", 3).unwrap();
        assert!(!results.is_empty());
        // With mock embeddings, the exact description match should have the highest
        // semantic score (identical hash → identical embedding → cosine sim = 1.0)
        let top = &results[0];
        assert!(
            top.semantic_score > 0.9,
            "Expected high semantic score for exact match, got {}",
            top.semantic_score
        );
    }

    #[test]
    fn check_duplicate_detects_exact_description() {
        let index = index_store_with(&test_store(), MockEmbeddingProvider::new(128));
        let hit = index
            .check_duplicate_before_insert("convert text to number", 0.9)
            .unwrap();
        let (_, similarity) = hit.expect("exact description should be flagged");
        assert!(similarity >= 0.9);

        // No finite similarity can reach an infinite threshold.
        let miss = index
            .check_duplicate_before_insert("convert text to number", f32::INFINITY)
            .unwrap();
        assert!(miss.is_none());
    }

    #[test]
    fn tombstoned_stages_not_indexed() {
        let mut store = MemoryStore::new();
        let mut s = make_stage("s1", "active stage", NType::Text, NType::Text);
        store.put(s.clone()).unwrap();
        s.id = StageId("s2".into());
        s.description = "tombstoned stage".into();
        s.lifecycle = StageLifecycle::Tombstone;
        store.put(s).unwrap();

        let index = index_store(&store);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn search_by_tag_returns_matching_stages() {
        let mut s1 = make_stage("s1", "http get request", NType::Text, NType::Text);
        s1.tags = vec!["network".into(), "io".into()];
        let mut s2 = make_stage("s2", "text length", NType::Text, NType::Number);
        s2.tags = vec!["text".into(), "pure".into()];

        let index = index_stages(vec![s1, s2]);

        let network_ids = index.search_by_tag("network");
        assert_eq!(network_ids.len(), 1);
        assert_eq!(network_ids[0], StageId("s1".into()));

        let pure_ids = index.search_by_tag("pure");
        assert_eq!(pure_ids.len(), 1);
        assert_eq!(pure_ids[0], StageId("s2".into()));

        let missing = index.search_by_tag("nonexistent");
        assert!(missing.is_empty());
    }

    #[test]
    fn all_tags_returns_sorted_set() {
        let mut s1 = make_stage("s1", "a", NType::Text, NType::Text);
        s1.tags = vec!["zebra".into(), "apple".into()];
        let index = index_stages(vec![s1]);
        let tags = index.all_tags();
        assert_eq!(tags, vec!["apple", "zebra"]);
    }

    #[test]
    fn search_filtered_restricts_to_tag() {
        let mut s1 = make_stage("s1", "http get request", NType::Text, NType::Text);
        s1.tags = vec!["network".into()];
        let s2 = make_stage("s2", "sort list", NType::Text, NType::Text);

        let index = index_stages(vec![s1, s2]);

        let filtered = index
            .search_filtered("anything", 10, Some("network"))
            .unwrap();
        // Pin the count as well: an empty result would satisfy the id check vacuously.
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].stage_id, StageId("s1".into()));

        let all = index.search_filtered("anything", 10, None).unwrap();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn remove_stage_cleans_tag_map() {
        let mut s1 = make_stage("s1", "a", NType::Text, NType::Text);
        s1.tags = vec!["mytag".into()];
        let mut index = index_stages(vec![s1]);
        assert_eq!(index.search_by_tag("mytag").len(), 1);
        index.remove_stage(&StageId("s1".into()));
        assert!(index.search_by_tag("mytag").is_empty());
        assert!(index.all_tags().is_empty());
    }
}