1#![warn(clippy::unwrap_used)]
2#![cfg_attr(test, allow(clippy::unwrap_used))]
3
4pub mod cache;
5pub mod embedding;
6pub mod search;
7pub mod text;
8
9use embedding::{EmbeddingError, EmbeddingProvider};
10use noether_core::stage::{Stage, StageId, StageLifecycle};
11use noether_store::StageStore;
12use search::SubIndex;
13use std::collections::BTreeMap;
14use std::collections::HashMap;
15
/// Relative weights used to fuse the three per-view similarity scores of a
/// stage into one ranking score.
///
/// The fused score is the weighted sum of the signature, semantic and example
/// similarities, with each similarity clamped to be non-negative before it is
/// weighted.
#[derive(Debug, Clone, PartialEq)]
pub struct IndexConfig {
    /// Weight of the similarity against the stage's signature text.
    pub signature_weight: f32,
    /// Weight of the similarity against the stage's description text.
    pub semantic_weight: f32,
    /// Weight of the similarity against the stage's examples text.
    pub example_weight: f32,
}

impl Default for IndexConfig {
    /// Returns the default weighting: 0.3 signature, 0.5 semantic, 0.2 example.
    fn default() -> Self {
        Self {
            signature_weight: 0.3,
            semantic_weight: 0.5,
            example_weight: 0.2,
        }
    }
}
35
36#[derive(Debug, Clone)]
38pub struct SearchResult {
39 pub stage_id: StageId,
40 pub score: f32,
41 pub signature_score: f32,
42 pub semantic_score: f32,
43 pub example_score: f32,
44}
45
46pub struct SemanticIndex {
48 provider: Box<dyn EmbeddingProvider>,
49 signature_index: SubIndex,
50 semantic_index: SubIndex,
51 example_index: SubIndex,
52 config: IndexConfig,
53 tag_map: HashMap<String, Vec<StageId>>,
55}
56
57impl SemanticIndex {
58 pub fn from_stages(
61 stages: Vec<Stage>,
62 provider: Box<dyn EmbeddingProvider>,
63 config: IndexConfig,
64 ) -> Result<Self, EmbeddingError> {
65 let mut index = Self {
66 provider,
67 signature_index: SubIndex::new(),
68 semantic_index: SubIndex::new(),
69 example_index: SubIndex::new(),
70 config,
71 tag_map: HashMap::new(),
72 };
73 for stage in &stages {
74 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
75 continue;
76 }
77 index.add_stage(stage)?;
78 }
79 Ok(index)
80 }
81
82 pub fn build(
84 store: &dyn StageStore,
85 provider: Box<dyn EmbeddingProvider>,
86 config: IndexConfig,
87 ) -> Result<Self, EmbeddingError> {
88 let mut index = Self {
89 provider,
90 signature_index: SubIndex::new(),
91 semantic_index: SubIndex::new(),
92 example_index: SubIndex::new(),
93 config,
94 tag_map: HashMap::new(),
95 };
96 for stage in store.list(None) {
97 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
98 continue;
99 }
100 index.add_stage(stage)?;
101 }
102 Ok(index)
103 }
104
105 pub fn from_stages_batched(
112 stages: Vec<Stage>,
113 cached_provider: cache::CachedEmbeddingProvider,
114 config: IndexConfig,
115 chunk_size: usize,
116 ) -> Result<Self, EmbeddingError> {
117 Self::from_stages_batched_paced(
118 stages,
119 cached_provider,
120 config,
121 chunk_size,
122 std::time::Duration::ZERO,
123 )
124 }
125
126 pub fn from_stages_batched_paced(
131 stages: Vec<Stage>,
132 mut cached_provider: cache::CachedEmbeddingProvider,
133 config: IndexConfig,
134 chunk_size: usize,
135 inter_batch_delay: std::time::Duration,
136 ) -> Result<Self, EmbeddingError> {
137 let active: Vec<&Stage> = stages
139 .iter()
140 .filter(|s| !matches!(s.lifecycle, StageLifecycle::Tombstone))
141 .collect();
142
143 let mut all_texts: Vec<String> = Vec::with_capacity(active.len() * 3);
144 for s in &active {
145 all_texts.push(text::signature_text(s));
146 all_texts.push(text::description_text(s));
147 all_texts.push(text::examples_text(s));
148 }
149 let text_refs: Vec<&str> = all_texts.iter().map(|s| s.as_str()).collect();
150 let embeddings =
151 cached_provider.embed_batch_cached_paced(&text_refs, chunk_size, inter_batch_delay)?;
152 cached_provider.flush();
153
154 let mut signature_index = SubIndex::new();
156 let mut semantic_index = SubIndex::new();
157 let mut example_index = SubIndex::new();
158 let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
159
160 for (i, s) in active.iter().enumerate() {
161 signature_index.add(s.id.clone(), embeddings[i * 3].clone());
162 semantic_index.add(s.id.clone(), embeddings[i * 3 + 1].clone());
163 example_index.add(s.id.clone(), embeddings[i * 3 + 2].clone());
164 for tag in &s.tags {
165 tag_map.entry(tag.clone()).or_default().push(s.id.clone());
166 }
167 }
168
169 Ok(Self {
170 provider: Box::new(cached_provider),
171 signature_index,
172 semantic_index,
173 example_index,
174 config,
175 tag_map,
176 })
177 }
178
179 pub fn build_cached(
181 store: &dyn StageStore,
182 mut cached_provider: cache::CachedEmbeddingProvider,
183 config: IndexConfig,
184 ) -> Result<Self, EmbeddingError> {
185 let mut signature_index = SubIndex::new();
186 let mut semantic_index = SubIndex::new();
187 let mut example_index = SubIndex::new();
188 let mut tag_map: HashMap<String, Vec<StageId>> = HashMap::new();
189
190 for stage in store.list(None) {
191 if matches!(stage.lifecycle, StageLifecycle::Tombstone) {
192 continue;
193 }
194 let sig_emb = cached_provider.embed_cached(&text::signature_text(stage))?;
195 let desc_emb = cached_provider.embed_cached(&text::description_text(stage))?;
196 let ex_emb = cached_provider.embed_cached(&text::examples_text(stage))?;
197
198 signature_index.add(stage.id.clone(), sig_emb);
199 semantic_index.add(stage.id.clone(), desc_emb);
200 example_index.add(stage.id.clone(), ex_emb);
201
202 for tag in &stage.tags {
203 tag_map
204 .entry(tag.clone())
205 .or_default()
206 .push(stage.id.clone());
207 }
208 }
209
210 cached_provider.flush();
211
212 let provider: Box<dyn EmbeddingProvider> = Box::new(cached_provider);
214
215 Ok(Self {
216 provider,
217 signature_index,
218 semantic_index,
219 example_index,
220 config,
221 tag_map,
222 })
223 }
224
225 pub fn add_stage(&mut self, stage: &Stage) -> Result<(), EmbeddingError> {
227 let sig_text = text::signature_text(stage);
228 let desc_text = text::description_text(stage);
229 let ex_text = text::examples_text(stage);
230
231 let sig_emb = self.provider.embed(&sig_text)?;
232 let desc_emb = self.provider.embed(&desc_text)?;
233 let ex_emb = self.provider.embed(&ex_text)?;
234
235 self.signature_index.add(stage.id.clone(), sig_emb);
236 self.semantic_index.add(stage.id.clone(), desc_emb);
237 self.example_index.add(stage.id.clone(), ex_emb);
238
239 for tag in &stage.tags {
240 self.tag_map
241 .entry(tag.clone())
242 .or_default()
243 .push(stage.id.clone());
244 }
245
246 Ok(())
247 }
248
249 pub fn remove_stage(&mut self, stage_id: &StageId) {
251 self.signature_index.remove(stage_id);
252 self.semantic_index.remove(stage_id);
253 self.example_index.remove(stage_id);
254
255 for ids in self.tag_map.values_mut() {
256 ids.retain(|id| id != stage_id);
257 }
258 self.tag_map.retain(|_, ids| !ids.is_empty());
259 }
260
261 pub fn len(&self) -> usize {
263 self.signature_index.len()
264 }
265
266 pub fn is_empty(&self) -> bool {
267 self.len() == 0
268 }
269
270 pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>, EmbeddingError> {
272 self.search_filtered(query, top_k, None)
273 }
274
275 pub fn search_filtered(
278 &self,
279 query: &str,
280 top_k: usize,
281 tag: Option<&str>,
282 ) -> Result<Vec<SearchResult>, EmbeddingError> {
283 let query_emb = self.provider.embed(query)?;
284 let fetch_k = top_k * 2;
285
286 let sig_results = self.signature_index.search(&query_emb, fetch_k);
287 let sem_results = self.semantic_index.search(&query_emb, fetch_k);
288 let ex_results = self.example_index.search(&query_emb, fetch_k);
289
290 let allowed: Option<std::collections::BTreeSet<&str>> = tag.map(|t| {
292 self.tag_map
293 .get(t)
294 .map(|ids| ids.iter().map(|id| id.0.as_str()).collect())
295 .unwrap_or_default()
296 });
297
298 let mut scores: BTreeMap<String, (f32, f32, f32)> = BTreeMap::new();
300 for r in &sig_results {
301 scores.entry(r.stage_id.0.clone()).or_default().0 = r.score;
302 }
303 for r in &sem_results {
304 scores.entry(r.stage_id.0.clone()).or_default().1 = r.score;
305 }
306 for r in &ex_results {
307 scores.entry(r.stage_id.0.clone()).or_default().2 = r.score;
308 }
309
310 let mut results: Vec<SearchResult> = scores
312 .into_iter()
313 .filter(|(id, _)| {
314 allowed
315 .as_ref()
316 .map(|a| a.contains(id.as_str()))
317 .unwrap_or(true)
318 })
319 .map(|(id, (sig, sem, ex))| {
320 let fused = self.config.signature_weight * sig.max(0.0)
321 + self.config.semantic_weight * sem.max(0.0)
322 + self.config.example_weight * ex.max(0.0);
323 SearchResult {
324 stage_id: StageId(id),
325 score: fused,
326 signature_score: sig,
327 semantic_score: sem,
328 example_score: ex,
329 }
330 })
331 .collect();
332
333 results.sort_by(|a, b| {
334 b.score
335 .partial_cmp(&a.score)
336 .unwrap_or(std::cmp::Ordering::Equal)
337 });
338 results.truncate(top_k);
339 Ok(results)
340 }
341
342 pub fn search_by_tag(&self, tag: &str) -> Vec<StageId> {
344 self.tag_map.get(tag).cloned().unwrap_or_default()
345 }
346
347 pub fn all_tags(&self) -> Vec<String> {
349 let mut tags: Vec<String> = self.tag_map.keys().cloned().collect();
350 tags.sort();
351 tags
352 }
353
354 pub fn check_duplicate_before_insert(
359 &self,
360 description: &str,
361 threshold: f32,
362 ) -> Result<Option<(StageId, f32)>, EmbeddingError> {
363 let emb = self.provider.embed(description)?;
364 let results = self.semantic_index.search(&emb, 1);
365 if let Some(top) = results.first() {
366 if top.score >= threshold {
367 return Ok(Some((top.stage_id.clone(), top.score)));
368 }
369 }
370 Ok(None)
371 }
372
373 pub fn find_near_duplicates(&self, threshold: f32) -> Vec<(StageId, StageId, f32)> {
378 use search::cosine_similarity;
379
380 let entries = self.semantic_index.entries().to_vec();
381 let mut pairs: Vec<(StageId, StageId, f32)> = Vec::new();
382
383 for i in 0..entries.len() {
384 for j in (i + 1)..entries.len() {
385 let sim = cosine_similarity(&entries[i].embedding, &entries[j].embedding);
386 if sim >= threshold {
387 let (a, b) = if entries[i].stage_id.0 < entries[j].stage_id.0 {
388 (entries[i].stage_id.clone(), entries[j].stage_id.clone())
389 } else {
390 (entries[j].stage_id.clone(), entries[i].stage_id.clone())
391 };
392 pairs.push((a, b, sim));
393 }
394 }
395 }
396
397 pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
399 pairs
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use embedding::MockEmbeddingProvider;
    use noether_core::effects::EffectSet;
    use noether_core::stage::{CostEstimate, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use std::collections::BTreeSet;

    /// Creates an active, untagged stage with no examples and a pure signature.
    fn make_stage(id: &str, desc: &str, input: NType, output: NType) -> Stage {
        let signature = StageSignature {
            input,
            output,
            effects: EffectSet::pure(),
            implementation_hash: format!("impl_{id}"),
        };
        let cost = CostEstimate {
            time_ms_p50: None,
            tokens_est: None,
            memory_mb: None,
        };
        Stage {
            id: StageId(id.to_owned()),
            signature_id: None,
            signature,
            capabilities: BTreeSet::new(),
            cost,
            description: desc.into(),
            examples: Vec::new(),
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: Vec::new(),
            aliases: Vec::new(),
            name: None,
            properties: Vec::new(),
        }
    }

    /// Returns `stage` with its tags replaced by `tags`.
    fn tagged(mut stage: Stage, tags: &[&str]) -> Stage {
        stage.tags = tags.iter().map(|t| String::from(*t)).collect();
        stage
    }

    /// Store holding the three active fixture stages `s1`, `s2` and `s3`.
    fn test_store() -> MemoryStore {
        let any_list = || NType::List(Box::new(NType::Any));
        let fixtures = vec![
            make_stage("s1", "convert text to number", NType::Text, NType::Number),
            make_stage("s2", "make http request", NType::Text, NType::Text),
            make_stage("s3", "sort a list of items", any_list(), any_list()),
        ];

        let mut store = MemoryStore::new();
        for stage in fixtures {
            store.put(stage).unwrap();
        }
        store
    }

    /// Index built from `store` with a 32-dimensional mock provider and the
    /// default weights.
    fn index_from_store(store: &MemoryStore) -> SemanticIndex {
        SemanticIndex::build(
            store,
            Box::new(MockEmbeddingProvider::new(32)),
            IndexConfig::default(),
        )
        .unwrap()
    }

    /// Index built from `stages` with a 32-dimensional mock provider and the
    /// default weights.
    fn index_from_list(stages: Vec<Stage>) -> SemanticIndex {
        SemanticIndex::from_stages(
            stages,
            Box::new(MockEmbeddingProvider::new(32)),
            IndexConfig::default(),
        )
        .unwrap()
    }

    #[test]
    fn build_indexes_all_stages() {
        let store = test_store();
        let index = index_from_store(&store);
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn add_stage_increments_count() {
        let store = test_store();
        let mut index = index_from_store(&store);
        assert_eq!(index.len(), 3);

        let extra = make_stage("s4", "new stage", NType::Bool, NType::Text);
        index.add_stage(&extra).unwrap();
        assert_eq!(index.len(), 4);
    }

    #[test]
    fn remove_stage_decrements_count() {
        let store = test_store();
        let mut index = index_from_store(&store);
        index.remove_stage(&StageId("s1".into()));
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn search_returns_results() {
        let store = test_store();
        let index = index_from_store(&store);
        let results = index.search("convert text", 10).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn search_respects_top_k() {
        let store = test_store();
        let index = index_from_store(&store);
        let results = index.search("anything", 2).unwrap();
        assert!(results.len() <= 2);
    }

    #[test]
    fn search_self_is_top_result() {
        let store = test_store();
        // Uses a wider embedding than the other tests.
        let index = SemanticIndex::build(
            &store,
            Box::new(MockEmbeddingProvider::new(128)),
            IndexConfig::default(),
        )
        .unwrap();

        let results = index.search("convert text to number", 3).unwrap();
        assert!(!results.is_empty());

        let top = &results[0];
        assert!(
            top.semantic_score > 0.9,
            "Expected high semantic score for exact match, got {}",
            top.semantic_score
        );
    }

    #[test]
    fn tombstoned_stages_not_indexed() {
        let live = make_stage("s1", "active stage", NType::Text, NType::Text);
        // Same stage data under a new id, marked as tombstoned.
        let dead = Stage {
            id: StageId("s2".into()),
            description: "tombstoned stage".into(),
            lifecycle: StageLifecycle::Tombstone,
            ..live.clone()
        };

        let mut store = MemoryStore::new();
        store.put(live).unwrap();
        store.put(dead).unwrap();

        let index = index_from_store(&store);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn search_by_tag_returns_matching_stages() {
        let s1 = tagged(
            make_stage("s1", "http get request", NType::Text, NType::Text),
            &["network", "io"],
        );
        let s2 = tagged(
            make_stage("s2", "text length", NType::Text, NType::Number),
            &["text", "pure"],
        );
        let index = index_from_list(vec![s1, s2]);

        assert_eq!(index.search_by_tag("network"), vec![StageId("s1".into())]);
        assert_eq!(index.search_by_tag("pure"), vec![StageId("s2".into())]);
        assert!(index.search_by_tag("nonexistent").is_empty());
    }

    #[test]
    fn all_tags_returns_sorted_set() {
        let s1 = tagged(
            make_stage("s1", "a", NType::Text, NType::Text),
            &["zebra", "apple"],
        );
        let index = index_from_list(vec![s1]);

        let tags = index.all_tags();
        assert_eq!(tags, vec!["apple", "zebra"]);
    }

    #[test]
    fn search_filtered_restricts_to_tag() {
        let s1 = tagged(
            make_stage("s1", "http get request", NType::Text, NType::Text),
            &["network"],
        );
        let s2 = make_stage("s2", "sort list", NType::Text, NType::Text);
        let index = index_from_list(vec![s1, s2]);

        let filtered = index
            .search_filtered("anything", 10, Some("network"))
            .unwrap();
        let expected = StageId("s1".into());
        assert!(filtered.iter().all(|hit| hit.stage_id == expected));

        let unfiltered = index.search_filtered("anything", 10, None).unwrap();
        assert_eq!(unfiltered.len(), 2);
    }

    #[test]
    fn remove_stage_cleans_tag_map() {
        let s1 = tagged(make_stage("s1", "a", NType::Text, NType::Text), &["mytag"]);
        let mut index = index_from_list(vec![s1]);
        assert_eq!(index.search_by_tag("mytag").len(), 1);

        index.remove_stage(&StageId("s1".into()));
        assert!(index.search_by_tag("mytag").is_empty());
        assert!(index.all_tags().is_empty());
    }
}