1pub mod cache;
2pub mod embedding;
3pub mod search;
4pub mod text;
5
6use embedding::{EmbeddingError, EmbeddingProvider};
7use noether_core::stage::{Stage, StageId, StageLifecycle};
8use noether_store::StageStore;
9use search::SubIndex;
10use std::collections::BTreeMap;
11use std::collections::HashMap;
12
/// Relative weights used to fuse the three per-stage similarity scores
/// (signature text, description text, examples text) into one ranking score.
///
/// `SemanticIndex::search_filtered` computes
/// `signature_weight * sig + semantic_weight * sem + example_weight * ex`,
/// clamping each component at zero before weighting.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct IndexConfig {
    /// Weight of the similarity against a stage's signature text.
    pub signature_weight: f32,
    /// Weight of the similarity against a stage's description text.
    pub semantic_weight: f32,
    /// Weight of the similarity against a stage's examples text.
    pub example_weight: f32,
}

impl Default for IndexConfig {
    /// Description-heavy weighting: 0.3 signature, 0.5 description,
    /// 0.2 examples (the weights sum to 1.0).
    fn default() -> Self {
        Self {
            signature_weight: 0.3,
            semantic_weight: 0.5,
            example_weight: 0.2,
        }
    }
}
32
33#[derive(Debug, Clone)]
35pub struct SearchResult {
36 pub stage_id: StageId,
37 pub score: f32,
38 pub signature_score: f32,
39 pub semantic_score: f32,
40 pub example_score: f32,
41}
42
43pub struct SemanticIndex {
45 provider: Box<dyn EmbeddingProvider>,
46 signature_index: SubIndex,
47 semantic_index: SubIndex,
48 example_index: SubIndex,
49 config: IndexConfig,
50 tag_map: HashMap<String, Vec<StageId>>,
52}
53
54impl SemanticIndex {
55 pub fn from_stages(
58 stages: Vec<Stage>,
59 provider: Box<dyn EmbeddingProvider>,
60 config: IndexConfig,
61 ) -> Result<Self, EmbeddingError> {
62 let mut index = Self {
63 provider,
64 signature_index: SubIndex::new(),
65 semantic_index: SubIndex::new(),
66 example_index: SubIndex::new(),
67 config,
68 tag_map: HashMap::new(),
69 };
70 for stage in &stages {
71 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
72 continue;
73 }
74 index.add_stage(stage)?;
75 }
76 Ok(index)
77 }
78
79 pub fn build(
81 store: &dyn StageStore,
82 provider: Box<dyn EmbeddingProvider>,
83 config: IndexConfig,
84 ) -> Result<Self, EmbeddingError> {
85 let mut index = Self {
86 provider,
87 signature_index: SubIndex::new(),
88 semantic_index: SubIndex::new(),
89 example_index: SubIndex::new(),
90 config,
91 tag_map: HashMap::new(),
92 };
93 for stage in store.list(None) {
94 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
95 continue;
96 }
97 index.add_stage(stage)?;
98 }
99 Ok(index)
100 }
101
102 pub fn from_stages_batched(
109 stages: Vec<Stage>,
110 cached_provider: cache::CachedEmbeddingProvider,
111 config: IndexConfig,
112 chunk_size: usize,
113 ) -> Result<Self, EmbeddingError> {
114 Self::from_stages_batched_paced(
115 stages,
116 cached_provider,
117 config,
118 chunk_size,
119 std::time::Duration::ZERO,
120 )
121 }
122
123 pub fn from_stages_batched_paced(
128 stages: Vec<Stage>,
129 mut cached_provider: cache::CachedEmbeddingProvider,
130 config: IndexConfig,
131 chunk_size: usize,
132 inter_batch_delay: std::time::Duration,
133 ) -> Result<Self, EmbeddingError> {
134 let active: Vec<&Stage> = stages
136 .iter()
137 .filter(|s| !matches!(s.lifecycle, StageLifecycle::Tombstone))
138 .collect();
139
140 let mut all_texts: Vec<String> = Vec::with_capacity(active.len() * 3);
141 for s in &active {
142 all_texts.push(text::signature_text(s));
143 all_texts.push(text::description_text(s));
144 all_texts.push(text::examples_text(s));
145 }
146 let text_refs: Vec<&str> = all_texts.iter().map(|s| s.as_str()).collect();
147 let embeddings =
148 cached_provider.embed_batch_cached_paced(&text_refs, chunk_size, inter_batch_delay)?;
149 cached_provider.flush();
150
151 let mut signature_index = SubIndex::new();
153 let mut semantic_index = SubIndex::new();
154 let mut example_index = SubIndex::new();
155 let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
156
157 for (i, s) in active.iter().enumerate() {
158 signature_index.add(s.id.clone(), embeddings[i * 3].clone());
159 semantic_index.add(s.id.clone(), embeddings[i * 3 + 1].clone());
160 example_index.add(s.id.clone(), embeddings[i * 3 + 2].clone());
161 for tag in &s.tags {
162 tag_map.entry(tag.clone()).or_default().push(s.id.clone());
163 }
164 }
165
166 Ok(Self {
167 provider: Box::new(cached_provider),
168 signature_index,
169 semantic_index,
170 example_index,
171 config,
172 tag_map,
173 })
174 }
175
176 pub fn build_cached(
178 store: &dyn StageStore,
179 mut cached_provider: cache::CachedEmbeddingProvider,
180 config: IndexConfig,
181 ) -> Result<Self, EmbeddingError> {
182 let mut signature_index = SubIndex::new();
183 let mut semantic_index = SubIndex::new();
184 let mut example_index = SubIndex::new();
185 let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
186
187 for stage in store.list(None) {
188 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
189 continue;
190 }
191 let sig_emb = cached_provider.embed_cached(&text::signature_text(stage))?;
192 let desc_emb = cached_provider.embed_cached(&text::description_text(stage))?;
193 let ex_emb = cached_provider.embed_cached(&text::examples_text(stage))?;
194
195 signature_index.add(stage.id.clone(), sig_emb);
196 semantic_index.add(stage.id.clone(), desc_emb);
197 example_index.add(stage.id.clone(), ex_emb);
198
199 for tag in &stage.tags {
200 tag_map
201 .entry(tag.clone())
202 .or_default()
203 .push(stage.id.clone());
204 }
205 }
206
207 cached_provider.flush();
208
209 let provider: Box<dyn EmbeddingProvider> = Box::new(cached_provider);
211
212 Ok(Self {
213 provider,
214 signature_index,
215 semantic_index,
216 example_index,
217 config,
218 tag_map,
219 })
220 }
221
222 pub fn add_stage(&mut self, stage: &Stage) -> Result<(), EmbeddingError> {
224 let sig_text = text::signature_text(stage);
225 let desc_text = text::description_text(stage);
226 let ex_text = text::examples_text(stage);
227
228 let sig_emb = self.provider.embed(&sig_text)?;
229 let desc_emb = self.provider.embed(&desc_text)?;
230 let ex_emb = self.provider.embed(&ex_text)?;
231
232 self.signature_index.add(stage.id.clone(), sig_emb);
233 self.semantic_index.add(stage.id.clone(), desc_emb);
234 self.example_index.add(stage.id.clone(), ex_emb);
235
236 for tag in &stage.tags {
237 self.tag_map
238 .entry(tag.clone())
239 .or_default()
240 .push(stage.id.clone());
241 }
242
243 Ok(())
244 }
245
246 pub fn remove_stage(&mut self, stage_id: &StageId) {
248 self.signature_index.remove(stage_id);
249 self.semantic_index.remove(stage_id);
250 self.example_index.remove(stage_id);
251
252 for ids in self.tag_map.values_mut() {
253 ids.retain(|id| id != stage_id);
254 }
255 self.tag_map.retain(|_, ids| !ids.is_empty());
256 }
257
258 pub fn len(&self) -> usize {
260 self.signature_index.len()
261 }
262
263 pub fn is_empty(&self) -> bool {
264 self.len() == 0
265 }
266
267 pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>, EmbeddingError> {
269 self.search_filtered(query, top_k, None)
270 }
271
272 pub fn search_filtered(
275 &self,
276 query: &str,
277 top_k: usize,
278 tag: Option<&str>,
279 ) -> Result<Vec<SearchResult>, EmbeddingError> {
280 let query_emb = self.provider.embed(query)?;
281 let fetch_k = top_k * 2;
282
283 let sig_results = self.signature_index.search(&query_emb, fetch_k);
284 let sem_results = self.semantic_index.search(&query_emb, fetch_k);
285 let ex_results = self.example_index.search(&query_emb, fetch_k);
286
287 let allowed: Option<std::collections::BTreeSet<&str>> = tag.map(|t| {
289 self.tag_map
290 .get(t)
291 .map(|ids| ids.iter().map(|id| id.0.as_str()).collect())
292 .unwrap_or_default()
293 });
294
295 let mut scores: BTreeMap<String, (f32, f32, f32)> = BTreeMap::new();
297 for r in &sig_results {
298 scores.entry(r.stage_id.0.clone()).or_default().0 = r.score;
299 }
300 for r in &sem_results {
301 scores.entry(r.stage_id.0.clone()).or_default().1 = r.score;
302 }
303 for r in &ex_results {
304 scores.entry(r.stage_id.0.clone()).or_default().2 = r.score;
305 }
306
307 let mut results: Vec<SearchResult> = scores
309 .into_iter()
310 .filter(|(id, _)| {
311 allowed
312 .as_ref()
313 .map(|a| a.contains(id.as_str()))
314 .unwrap_or(true)
315 })
316 .map(|(id, (sig, sem, ex))| {
317 let fused = self.config.signature_weight * sig.max(0.0)
318 + self.config.semantic_weight * sem.max(0.0)
319 + self.config.example_weight * ex.max(0.0);
320 SearchResult {
321 stage_id: StageId(id),
322 score: fused,
323 signature_score: sig,
324 semantic_score: sem,
325 example_score: ex,
326 }
327 })
328 .collect();
329
330 results.sort_by(|a, b| {
331 b.score
332 .partial_cmp(&a.score)
333 .unwrap_or(std::cmp::Ordering::Equal)
334 });
335 results.truncate(top_k);
336 Ok(results)
337 }
338
339 pub fn search_by_tag(&self, tag: &str) -> Vec<StageId> {
341 self.tag_map.get(tag).cloned().unwrap_or_default()
342 }
343
344 pub fn all_tags(&self) -> Vec<String> {
346 let mut tags: Vec<String> = self.tag_map.keys().cloned().collect();
347 tags.sort();
348 tags
349 }
350
351 pub fn check_duplicate_before_insert(
356 &self,
357 description: &str,
358 threshold: f32,
359 ) -> Result<Option<(StageId, f32)>, EmbeddingError> {
360 let emb = self.provider.embed(description)?;
361 let results = self.semantic_index.search(&emb, 1);
362 if let Some(top) = results.first() {
363 if top.score >= threshold {
364 return Ok(Some((top.stage_id.clone(), top.score)));
365 }
366 }
367 Ok(None)
368 }
369
370 pub fn find_near_duplicates(&self, threshold: f32) -> Vec<(StageId, StageId, f32)> {
375 use search::cosine_similarity;
376
377 let entries = self.semantic_index.entries().to_vec();
378 let mut pairs: Vec<(StageId, StageId, f32)> = Vec::new();
379
380 for i in 0..entries.len() {
381 for j in (i + 1)..entries.len() {
382 let sim = cosine_similarity(&entries[i].embedding, &entries[j].embedding);
383 if sim >= threshold {
384 let (a, b) = if entries[i].stage_id.0 < entries[j].stage_id.0 {
385 (entries[i].stage_id.clone(), entries[j].stage_id.clone())
386 } else {
387 (entries[j].stage_id.clone(), entries[i].stage_id.clone())
388 };
389 pairs.push((a, b, sim));
390 }
391 }
392 }
393
394 pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
396 pairs
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use embedding::MockEmbeddingProvider;
    use noether_core::effects::EffectSet;
    use noether_core::stage::{CostEstimate, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use std::collections::BTreeSet;

    /// Builds an active, pure, untagged stage with no examples.
    fn make_stage(id: &str, desc: &str, input: NType, output: NType) -> Stage {
        let signature = StageSignature {
            input,
            output,
            effects: EffectSet::pure(),
            implementation_hash: format!("impl_{id}"),
        };
        let cost = CostEstimate {
            time_ms_p50: None,
            tokens_est: None,
            memory_mb: None,
        };
        Stage {
            id: StageId(id.into()),
            canonical_id: None,
            signature,
            capabilities: BTreeSet::new(),
            cost,
            description: desc.into(),
            examples: vec![],
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
            name: None,
        }
    }

    /// Store holding three active stages: s1 (text -> number),
    /// s2 (text -> text) and s3 (list -> list).
    fn test_store() -> MemoryStore {
        let fixtures = [
            ("s1", "convert text to number", NType::Text, NType::Number),
            ("s2", "make http request", NType::Text, NType::Text),
            (
                "s3",
                "sort a list of items",
                NType::List(Box::new(NType::Any)),
                NType::List(Box::new(NType::Any)),
            ),
        ];
        let mut store = MemoryStore::new();
        for (id, desc, input, output) in fixtures {
            store.put(make_stage(id, desc, input, output)).unwrap();
        }
        store
    }

    /// Indexes every stage in `store` via `SemanticIndex::build`.
    fn index_store(store: &MemoryStore, provider: MockEmbeddingProvider) -> SemanticIndex {
        SemanticIndex::build(store, Box::new(provider), IndexConfig::default()).unwrap()
    }

    /// Indexes `stages` via `SemanticIndex::from_stages` with
    /// `MockEmbeddingProvider::new(32)`.
    fn index_stages(stages: Vec<Stage>) -> SemanticIndex {
        SemanticIndex::from_stages(
            stages,
            Box::new(MockEmbeddingProvider::new(32)),
            IndexConfig::default(),
        )
        .unwrap()
    }

    #[test]
    fn build_indexes_all_stages() {
        let index = index_store(&test_store(), MockEmbeddingProvider::new(32));
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn add_stage_increments_count() {
        let mut index = index_store(&test_store(), MockEmbeddingProvider::new(32));
        assert_eq!(index.len(), 3);
        let extra = make_stage("s4", "new stage", NType::Bool, NType::Text);
        index.add_stage(&extra).unwrap();
        assert_eq!(index.len(), 4);
    }

    #[test]
    fn remove_stage_decrements_count() {
        let mut index = index_store(&test_store(), MockEmbeddingProvider::new(32));
        index.remove_stage(&StageId("s1".into()));
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn search_returns_results() {
        let index = index_store(&test_store(), MockEmbeddingProvider::new(32));
        let hits = index.search("convert text", 10).unwrap();
        assert!(!hits.is_empty());
    }

    #[test]
    fn search_respects_top_k() {
        let index = index_store(&test_store(), MockEmbeddingProvider::new(32));
        let hits = index.search("anything", 2).unwrap();
        assert!(hits.len() <= 2);
    }

    #[test]
    fn search_self_is_top_result() {
        let index = index_store(&test_store(), MockEmbeddingProvider::new(128));
        let hits = index.search("convert text to number", 3).unwrap();
        assert!(!hits.is_empty());
        let top = &hits[0];
        // The query is the exact description of s1, so its description
        // score should be near-perfect.
        assert!(
            top.semantic_score > 0.9,
            "Expected high semantic score for exact match, got {}",
            top.semantic_score
        );
    }

    #[test]
    fn tombstoned_stages_not_indexed() {
        let live = make_stage("s1", "active stage", NType::Text, NType::Text);
        let dead = Stage {
            id: StageId("s2".into()),
            description: "tombstoned stage".into(),
            lifecycle: StageLifecycle::Tombstone,
            ..live.clone()
        };
        let mut store = MemoryStore::new();
        store.put(live).unwrap();
        store.put(dead).unwrap();

        let index = index_store(&store, MockEmbeddingProvider::new(32));
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn search_by_tag_returns_matching_stages() {
        let network = Stage {
            tags: vec!["network".into(), "io".into()],
            ..make_stage("s1", "http get request", NType::Text, NType::Text)
        };
        let length = Stage {
            tags: vec!["text".into(), "pure".into()],
            ..make_stage("s2", "text length", NType::Text, NType::Number)
        };
        let index = index_stages(vec![network, length]);

        assert_eq!(index.search_by_tag("network"), vec![StageId("s1".into())]);
        assert_eq!(index.search_by_tag("pure"), vec![StageId("s2".into())]);
        assert!(index.search_by_tag("nonexistent").is_empty());
    }

    #[test]
    fn all_tags_returns_sorted_set() {
        let tagged = Stage {
            tags: vec!["zebra".into(), "apple".into()],
            ..make_stage("s1", "a", NType::Text, NType::Text)
        };
        let index = index_stages(vec![tagged]);
        assert_eq!(index.all_tags(), vec!["apple", "zebra"]);
    }

    #[test]
    fn search_filtered_restricts_to_tag() {
        let tagged = Stage {
            tags: vec!["network".into()],
            ..make_stage("s1", "http get request", NType::Text, NType::Text)
        };
        let untagged = make_stage("s2", "sort list", NType::Text, NType::Text);
        let index = index_stages(vec![tagged, untagged]);

        let only_network = index
            .search_filtered("anything", 10, Some("network"))
            .unwrap();
        let s1 = StageId("s1".into());
        assert!(only_network.iter().all(|hit| hit.stage_id == s1));

        let everything = index.search_filtered("anything", 10, None).unwrap();
        assert_eq!(everything.len(), 2);
    }

    #[test]
    fn remove_stage_cleans_tag_map() {
        let tagged = Stage {
            tags: vec!["mytag".into()],
            ..make_stage("s1", "a", NType::Text, NType::Text)
        };
        let mut index = index_stages(vec![tagged]);
        assert_eq!(index.search_by_tag("mytag").len(), 1);

        index.remove_stage(&StageId("s1".into()));
        assert!(index.search_by_tag("mytag").is_empty());
        assert!(index.all_tags().is_empty());
    }
}