1use std::collections::HashMap;
41
42use crate::category::BridgeClassification;
43use crate::config::{PipelineConfig, ProjectionKind};
44use crate::configured_projection::ConfiguredProjection;
45use crate::corpus_quality::{CorpusQuality, CorpusQualityBreakdown};
46use crate::navigator::curvature_analysis;
47use crate::pipeline::{PipelineInput, SphereQLPipeline};
48use crate::projection::PcaProjection;
49use crate::quality_metric::QualityMetric;
50use crate::types::{Embedding, RadialStrategy};
51
/// Lower bound applied to a concept's `quality` before it is used as a
/// weight in the weighted PCA fit inside `build_pipeline`, so a concept whose
/// quality has fallen to zero still contributes a small positive weight.
const QUALITY_WEIGHT_FLOOR: f64 = 0.05;
58
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// One corpus entry whose `quality` weight is adjusted by the self-tuning loop.
pub struct TunableConcept {
    /// Name of the concept. Not read by the tuning code visible in this file.
    pub label: String,
    /// Category name. Used to count concepts per category (PCA weighting and
    /// pruning floors) and to look up the category's curvature z-score.
    pub category: String,
    /// `(usize, f64)` pairs passed verbatim to the caller-supplied embedding
    /// function. Presumably `(axis, weight)` pairs, as the test helper treats
    /// them — confirm against real callers.
    pub features: Vec<(usize, f64)>,
    /// Current weight. Rewritten by reweighting (clamped to `[0, 1]`),
    /// compared against `min_quality_to_keep` when pruning, and used (with a
    /// floor) as the PCA fit weight.
    pub quality: f64,
    /// Carried through unchanged; not read by the tuning code visible here.
    pub axis_coherence: f64,
    /// Carried through unchanged; not read by the tuning code visible here.
    pub bridge_degree: u8,
    /// Blended with `source_confidence_smoothing` into a multiplier on quality.
    pub source_confidence: f64,
    /// Blended with `home_affinity_smoothing` into a multiplier on quality.
    pub home_affinity: f64,
    /// Carried through unchanged; not read by the tuning code visible here.
    pub source: Option<String>,
    /// Carried through unchanged; not read by the tuning code visible here.
    pub openalex_id: Option<String>,
}
77
#[derive(Debug, Clone)]
/// Snapshot of one round of `run_self_tune`.
pub struct SelfTuneIteration {
    /// Zero-based loop index of the round.
    pub iteration: usize,
    /// Corpus size at the start of the round, before any pruning.
    pub n_concepts: usize,
    /// Score returned by `CorpusQuality::score` for the pipeline built from
    /// the corpus as it stood at the start of the round.
    pub composite_score: f64,
    /// Breakdown reported by `CorpusQuality::last_breakdown` right after scoring.
    pub breakdown: CorpusQualityBreakdown,
    /// Number of concepts removed in this round (always 0 on a plateau round).
    pub n_pruned: usize,
    /// Mean concept quality after reweighting and pruning; on a plateau round
    /// this is the mean before the round, because nothing is modified.
    pub mean_quality: f64,
    /// `mean_quality` minus the mean quality at the start of the round
    /// (0.0 on a plateau round).
    pub mean_quality_delta: f64,
}
97
#[derive(Debug, Clone, Copy)]
/// Why `run_self_tune` stopped iterating.
pub enum StopReason {
    /// The composite score moved by less than `plateau_epsilon` relative to
    /// the previous round.
    Plateau,
    /// All `max_iterations` rounds ran without another stop condition firing.
    MaxIterations,
    /// The corpus became empty, or `build_pipeline` could not produce a
    /// pipeline from it (fewer than three concepts, or a fit/construction
    /// error).
    PruneFloorHit,
}
110
#[derive(Debug, Clone)]
/// Summary returned by `run_self_tune` alongside the tuned corpus.
pub struct SelfTuneReport {
    /// One record per round that got as far as scoring, in execution order.
    pub iterations: Vec<SelfTuneIteration>,
    /// Condition that ended the loop.
    pub stopped_reason: StopReason,
    /// Score of a pipeline rebuilt from the final corpus, or `None` when that
    /// final pipeline could not be built.
    pub final_composite: Option<f64>,
}
123
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Tunables for `run_self_tune`. Fields missing from serialized input fall
/// back to the `Default` implementation (`#[serde(default)]`).
#[serde(default)]
pub struct SelfTuneConfig {
    /// Upper bound on the number of rounds; `validate` requires at least 1.
    pub max_iterations: usize,
    /// The loop stops with `StopReason::Plateau` when the composite score
    /// changes by less than this between consecutive rounds.
    pub plateau_epsilon: f64,
    /// Concepts whose quality is strictly below this are pruning candidates.
    pub min_quality_to_keep: f64,
    /// A category is never pruned to fewer than this many concepts; a
    /// category already at or below this size loses nothing.
    pub min_concepts_per_category: usize,
    /// Multiplier applied to concepts classified as `Genuine` bridges;
    /// `validate` requires a finite value of at least 1.0.
    pub bridge_genuine_boost: f64,
    /// Multiplier applied to `OverlapArtifact` and `Weak` bridges; must be in
    /// `[0, 1]`.
    pub bridge_artifact_penalty: f64,
    /// Multiplier applied to every concept in a category whose curvature
    /// z-score magnitude exceeds `curvature_z_threshold`; must be in `[0, 1]`.
    pub curvature_outlier_penalty: f64,
    /// Absolute z-score above which a category counts as a curvature outlier.
    pub curvature_z_threshold: f64,
    /// Blend factor `s` in `s + (1 - s) * home_affinity`; must be in `[0, 1]`.
    pub home_affinity_smoothing: f64,
    /// Blend factor `s` in `s + (1 - s) * source_confidence`; must be in
    /// `[0, 1]`.
    pub source_confidence_smoothing: f64,
}
139
impl Default for SelfTuneConfig {
    /// Baseline tunables. These are also the values used for any field that
    /// is absent when the config is deserialized, because the struct is
    /// marked `#[serde(default)]`.
    fn default() -> Self {
        Self {
            max_iterations: 10,
            plateau_epsilon: 0.001,
            min_quality_to_keep: 0.3,
            min_concepts_per_category: 50,
            // A boost above 1.0 and penalties below 1.0: genuine bridges gain
            // weight, artifact/weak bridges and curvature outliers lose it.
            bridge_genuine_boost: 1.05,
            bridge_artifact_penalty: 0.85,
            curvature_outlier_penalty: 0.9,
            curvature_z_threshold: 1.5,
            // With smoothing `s`, a signal of 0 scales quality by `s` and a
            // signal of 1 leaves it unchanged.
            home_affinity_smoothing: 0.5,
            source_confidence_smoothing: 0.8,
        }
    }
}
156
157impl SelfTuneConfig {
158 pub fn validate(&self) -> Result<(), String> {
161 fn unit(name: &str, v: f64) -> Result<(), String> {
162 if (0.0..=1.0).contains(&v) {
163 Ok(())
164 } else {
165 Err(format!("{name} must be in [0, 1], got {v}"))
166 }
167 }
168 unit("home_affinity_smoothing", self.home_affinity_smoothing)?;
169 unit(
170 "source_confidence_smoothing",
171 self.source_confidence_smoothing,
172 )?;
173 unit("bridge_artifact_penalty", self.bridge_artifact_penalty)?;
174 unit("curvature_outlier_penalty", self.curvature_outlier_penalty)?;
175 if !self.bridge_genuine_boost.is_finite() || self.bridge_genuine_boost < 1.0 {
176 return Err(format!(
177 "bridge_genuine_boost must be >= 1.0, got {}",
178 self.bridge_genuine_boost
179 ));
180 }
181 if !self.plateau_epsilon.is_finite() || self.plateau_epsilon < 0.0 {
182 return Err(format!(
183 "plateau_epsilon must be finite and >= 0.0, got {}",
184 self.plateau_epsilon
185 ));
186 }
187 if self.max_iterations < 1 {
188 return Err("max_iterations must be >= 1".into());
189 }
190 Ok(())
191 }
192}
193
194pub fn run_self_tune<F>(
207 mut corpus: Vec<TunableConcept>,
208 embed_fn: F,
209 base_pipeline_config: PipelineConfig,
210 quality: &CorpusQuality,
211 cfg: &SelfTuneConfig,
212) -> Result<(Vec<TunableConcept>, SelfTuneReport), String>
213where
214 F: Fn(&[(usize, f64)]) -> Vec<f64>,
215{
216 cfg.validate()?;
217
218 let mut bases: Vec<f64> = corpus.iter().map(|c| c.quality).collect();
223
224 let mut iterations: Vec<SelfTuneIteration> = Vec::new();
225 let mut stopped = StopReason::MaxIterations;
226
227 for iter in 0..cfg.max_iterations {
228 if corpus.is_empty() {
229 stopped = StopReason::PruneFloorHit;
230 break;
231 }
232
233 let pipeline = match build_pipeline(&corpus, &embed_fn, &base_pipeline_config) {
234 Some(p) => p,
235 None => {
236 stopped = StopReason::PruneFloorHit;
241 break;
242 }
243 };
244
245 let composite = quality.score(&pipeline);
246 let breakdown = quality
247 .last_breakdown()
248 .expect("CorpusQuality::score populates last_breakdown");
249
250 let n_before = corpus.len();
251 let pre_mean_q: f64 = if n_before == 0 {
252 0.0
253 } else {
254 corpus.iter().map(|c| c.quality).sum::<f64>() / n_before as f64
255 };
256
257 if iter >= 1 {
262 let prev = iterations[iter - 1].composite_score;
263 if (composite - prev).abs() < cfg.plateau_epsilon {
264 iterations.push(SelfTuneIteration {
265 iteration: iter,
266 n_concepts: n_before,
267 composite_score: composite,
268 breakdown,
269 n_pruned: 0,
270 mean_quality: pre_mean_q,
271 mean_quality_delta: 0.0,
272 });
273 stopped = StopReason::Plateau;
274 break;
275 }
276 }
277
278 reweight_from_base(&mut corpus, &bases, &pipeline, cfg);
279 let n_pruned = prune_below_floor_synced(&mut corpus, &mut bases, cfg);
280
281 let n_after = corpus.len().max(1) as f64;
282 let post_mean_q: f64 = corpus.iter().map(|c| c.quality).sum::<f64>() / n_after;
283
284 iterations.push(SelfTuneIteration {
285 iteration: iter,
286 n_concepts: n_before,
287 composite_score: composite,
288 breakdown,
289 n_pruned,
290 mean_quality: post_mean_q,
291 mean_quality_delta: post_mean_q - pre_mean_q,
292 });
293 }
294
295 let final_composite =
299 build_pipeline(&corpus, &embed_fn, &base_pipeline_config).map(|p| quality.score(&p));
300
301 Ok((
302 corpus,
303 SelfTuneReport {
304 iterations,
305 stopped_reason: stopped,
306 final_composite,
307 },
308 ))
309}
310
311fn build_pipeline<F>(
314 corpus: &[TunableConcept],
315 embed_fn: &F,
316 config: &PipelineConfig,
317) -> Option<SphereQLPipeline>
318where
319 F: Fn(&[(usize, f64)]) -> Vec<f64>,
320{
321 if corpus.len() < 3 {
322 return None;
323 }
324 let categories: Vec<String> = corpus.iter().map(|c| c.category.clone()).collect();
325 let embeddings: Vec<Embedding> = corpus
326 .iter()
327 .map(|c| Embedding::new(embed_fn(&c.features)))
328 .collect();
329
330 if config.projection_kind == ProjectionKind::Pca {
331 let mut cat_counts: HashMap<&str, usize> = HashMap::new();
339 for c in corpus {
340 *cat_counts.entry(c.category.as_str()).or_default() += 1;
341 }
342 let weights: Vec<f64> = corpus
343 .iter()
344 .map(|c| {
345 c.quality.max(QUALITY_WEIGHT_FLOOR)
346 / (cat_counts[c.category.as_str()] as f64).sqrt()
347 })
348 .collect();
349 let pca = PcaProjection::fit_weighted(&embeddings, &weights, RadialStrategy::Magnitude)
350 .ok()?
351 .with_volumetric(true);
352 SphereQLPipeline::with_configured_projection_and_config(
353 categories,
354 embeddings,
355 ConfiguredProjection::Pca(pca),
356 config.clone(),
357 )
358 .ok()
359 } else {
360 let raw: Vec<Vec<f64>> = embeddings.into_iter().map(|e| e.values).collect();
363 SphereQLPipeline::new_with_config(
364 PipelineInput {
365 categories,
366 embeddings: raw,
367 },
368 config.clone(),
369 )
370 .ok()
371 }
372}
373
374pub fn reweight_in_place(
381 corpus: &mut [TunableConcept],
382 pipeline: &SphereQLPipeline,
383 cfg: &SelfTuneConfig,
384) {
385 let bases: Vec<f64> = corpus.iter().map(|c| c.quality).collect();
386 reweight_from_base(corpus, &bases, pipeline, cfg);
387}
388
389fn reweight_from_base(
394 corpus: &mut [TunableConcept],
395 bases: &[f64],
396 pipeline: &SphereQLPipeline,
397 cfg: &SelfTuneConfig,
398) {
399 debug_assert_eq!(
400 corpus.len(),
401 bases.len(),
402 "bases must stay index-parallel to corpus"
403 );
404 let bridge_map = build_bridge_map(pipeline);
405 let curvature_map = build_curvature_map(pipeline);
406
407 for (i, concept) in corpus.iter_mut().enumerate() {
408 let mut q = bases[i];
409
410 if let Some(cls) = bridge_map.get(&i) {
412 match cls {
413 BridgeClassification::Genuine => q *= cfg.bridge_genuine_boost,
414 BridgeClassification::OverlapArtifact | BridgeClassification::Weak => {
415 q *= cfg.bridge_artifact_penalty;
416 }
417 }
418 }
419
420 if let Some(z) = curvature_map.get(concept.category.as_str())
424 && z.abs() > cfg.curvature_z_threshold
425 {
426 q *= cfg.curvature_outlier_penalty;
427 }
428
429 q *= cfg.home_affinity_smoothing
431 + (1.0 - cfg.home_affinity_smoothing) * concept.home_affinity;
432
433 q *= cfg.source_confidence_smoothing
435 + (1.0 - cfg.source_confidence_smoothing) * concept.source_confidence;
436
437 concept.quality = q.clamp(0.0, 1.0);
438 }
439}
440
441fn build_bridge_map(pipeline: &SphereQLPipeline) -> HashMap<usize, BridgeClassification> {
442 let layer = pipeline.category_layer();
443 let mut out = HashMap::new();
444 for bridges in layer.graph.bridges.values() {
445 for b in bridges {
446 out.insert(b.item_index, b.classification);
447 }
448 }
449 out
450}
451
452fn build_curvature_map(pipeline: &SphereQLPipeline) -> HashMap<String, f64> {
453 let layer = pipeline.category_layer();
454 if layer.num_categories() < 3 {
455 return HashMap::new();
456 }
457 let report = curvature_analysis(layer, 0);
458 report
459 .signatures
460 .into_iter()
461 .map(|s| (s.category_name, s.mean_excess_z))
462 .collect()
463}
464
465fn prune_mask(corpus: &[TunableConcept], cfg: &SelfTuneConfig) -> (Vec<bool>, usize) {
469 let mut indices: Vec<usize> = (0..corpus.len()).collect();
470 indices.sort_by(|a, b| corpus[*a].quality.total_cmp(&corpus[*b].quality));
471
472 let mut counts: HashMap<String, usize> = HashMap::new();
473 for c in corpus.iter() {
474 *counts.entry(c.category.clone()).or_insert(0) += 1;
475 }
476
477 let mut to_remove: Vec<bool> = vec![false; corpus.len()];
478 let mut removed = 0usize;
479 for i in indices {
480 let c = &corpus[i];
481 if c.quality >= cfg.min_quality_to_keep {
482 break;
483 }
484 let count = *counts.get(c.category.as_str()).unwrap_or(&0);
485 if count <= cfg.min_concepts_per_category {
486 continue;
487 }
488 to_remove[i] = true;
489 counts.insert(c.category.clone(), count - 1);
490 removed += 1;
491 }
492 (to_remove, removed)
493}
494
/// Removes from `v` every element whose position is flagged `true` in `mask`,
/// keeping the survivors in their original order.
///
/// # Panics
/// Panics if `mask.len() != v.len()`. A mask that is too long used to be
/// accepted silently, which would let two vectors filtered with the same
/// mask drift out of alignment without any signal.
fn apply_mask<T>(v: &mut Vec<T>, mask: &[bool]) {
    assert_eq!(
        v.len(),
        mask.len(),
        "mask must stay index-parallel to the vector it filters"
    );
    // `Vec::retain` visits each element exactly once, in original order, so
    // advancing the mask iterator in lockstep pairs every element with its
    // own flag.
    let mut flags = mask.iter().copied();
    v.retain(|_| !flags.next().unwrap_or(false));
}
504
505pub fn prune_below_floor(corpus: &mut Vec<TunableConcept>, cfg: &SelfTuneConfig) -> usize {
509 if corpus.is_empty() {
510 return 0;
511 }
512 let (mask, removed) = prune_mask(corpus, cfg);
513 if removed == 0 {
514 return 0;
515 }
516 apply_mask(corpus, &mask);
517 removed
518}
519
520fn prune_below_floor_synced(
523 corpus: &mut Vec<TunableConcept>,
524 bases: &mut Vec<f64>,
525 cfg: &SelfTuneConfig,
526) -> usize {
527 if corpus.is_empty() {
528 return 0;
529 }
530 let (mask, removed) = prune_mask(corpus, cfg);
531 if removed == 0 {
532 return 0;
533 }
534 apply_mask(corpus, &mask);
535 apply_mask(bases, &mask);
536 removed
537}
538
// Unit tests for pruning, reweighting, config validation and the full
// self-tune loop, all driven by small synthetic corpora.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a concept with fixed features and metadata; only the label,
    /// category, quality, home affinity and source confidence vary.
    fn synthetic_concept(
        label: &str,
        category: &str,
        quality: f64,
        home_affinity: f64,
        source_confidence: f64,
    ) -> TunableConcept {
        TunableConcept {
            label: label.into(),
            category: category.into(),
            features: vec![(0, 1.0), (1, 0.5)],
            quality,
            axis_coherence: 0.7,
            bridge_degree: 1,
            source_confidence,
            home_affinity,
            source: Some("synthetic".into()),
            openalex_id: None,
        }
    }

    /// Builds `n_cats * n_per` concepts. Every concept in category `c` has a
    /// single feature on axis `c % dim`, quality 0.8, home affinity 0.8 and
    /// source confidence 0.6.
    fn synthetic_corpus(n_cats: usize, n_per: usize, dim: usize) -> Vec<TunableConcept> {
        let mut corpus = Vec::with_capacity(n_per * n_cats);
        for c in 0..n_cats {
            for r in 0..n_per {
                corpus.push(TunableConcept {
                    label: format!("c{c}_r{r}"),
                    category: format!("cat_{c}"),
                    features: vec![(c % dim, 1.0)],
                    quality: 0.8,
                    axis_coherence: 0.7,
                    bridge_degree: 1,
                    source_confidence: 0.6,
                    home_affinity: 0.8,
                    source: Some("synthetic".into()),
                    openalex_id: None,
                });
            }
        }
        corpus
    }

    /// Returns an embedding function that scatters `(axis, weight)` pairs into
    /// a zero vector of length `dim`, ignoring axes outside the vector.
    fn dense_embed(dim: usize) -> impl Fn(&[(usize, f64)]) -> Vec<f64> {
        move |feats: &[(usize, f64)]| -> Vec<f64> {
            let mut v = vec![0.0_f64; dim];
            for &(axis, w) in feats {
                if axis < dim {
                    v[axis] = w;
                }
            }
            v
        }
    }

    // Category "x" is entirely below the quality floor but may only shrink to
    // the 50-concept category minimum; category "y" is above the floor.
    #[test]
    fn prune_respects_category_floor() {
        let mut corpus: Vec<TunableConcept> = (0..60)
            .map(|i| synthetic_concept(&format!("a{i}"), "x", 0.1, 0.5, 0.5))
            .collect();
        corpus.extend((0..60).map(|i| synthetic_concept(&format!("b{i}"), "y", 0.9, 0.9, 0.9)));
        let cfg = SelfTuneConfig {
            min_quality_to_keep: 0.5,
            min_concepts_per_category: 50,
            ..Default::default()
        };
        let pruned = prune_below_floor(&mut corpus, &cfg);
        let counts: HashMap<String, usize> = corpus.iter().fold(HashMap::new(), |mut acc, c| {
            *acc.entry(c.category.clone()).or_insert(0) += 1;
            acc
        });
        assert_eq!(counts["x"], 50);
        assert_eq!(counts["y"], 60);
        assert_eq!(pruned, 10);
    }

    // With every quality above the default floor nothing is removed.
    #[test]
    fn prune_skips_when_quality_above_floor() {
        let mut corpus: Vec<TunableConcept> = (0..100)
            .map(|i| synthetic_concept(&format!("a{i}"), "x", 0.9, 0.9, 0.9))
            .collect();
        let cfg = SelfTuneConfig::default();
        let pruned = prune_below_floor(&mut corpus, &cfg);
        assert_eq!(pruned, 0);
        assert_eq!(corpus.len(), 100);
    }

    // `bases[i]` starts out equal to the concept's original index, so after
    // pruning a surviving "y" concept must still carry a base in 60..120.
    #[test]
    fn prune_synced_keeps_bases_aligned() {
        let mut corpus: Vec<TunableConcept> = (0..60)
            .map(|i| synthetic_concept(&format!("a{i}"), "x", 0.1, 0.5, 0.5))
            .collect();
        corpus.extend((0..60).map(|i| synthetic_concept(&format!("b{i}"), "y", 0.9, 0.9, 0.9)));
        let mut bases: Vec<f64> = (0..corpus.len()).map(|i| i as f64).collect();
        let cfg = SelfTuneConfig {
            min_quality_to_keep: 0.5,
            min_concepts_per_category: 50,
            ..Default::default()
        };
        let pruned = prune_below_floor_synced(&mut corpus, &mut bases, &cfg);
        assert_eq!(pruned, 10);
        assert_eq!(corpus.len(), bases.len());
        for (c, &b) in corpus.iter().zip(bases.iter()) {
            if c.category == "y" {
                assert!((60.0..120.0).contains(&b), "base {b} misaligned");
            }
        }
    }

    // NOTE(review): this test recomputes the smoothing formula inline rather
    // than calling `reweight_from_base`, so it pins the arithmetic of the
    // formula, not the production code path.
    #[test]
    fn home_affinity_zero_halves_quality() {
        let cfg = SelfTuneConfig::default();
        let pre = 1.0_f64;
        let post =
            pre * (cfg.home_affinity_smoothing + (1.0 - cfg.home_affinity_smoothing) * 0.0_f64);
        assert!((post - cfg.home_affinity_smoothing).abs() < 1e-12);
    }

    // NOTE(review): as above, the formula is recomputed inline; no production
    // function is exercised.
    #[test]
    fn source_confidence_zero_attenuates_to_smoothing() {
        let cfg = SelfTuneConfig::default();
        let pre = 1.0_f64;
        let post = pre
            * (cfg.source_confidence_smoothing + (1.0 - cfg.source_confidence_smoothing) * 0.0_f64);
        assert!((post - cfg.source_confidence_smoothing).abs() < 1e-12);
    }

    // Reweighting twice from the same bases and pipeline must give the same
    // qualities as reweighting once.
    #[test]
    fn reweight_from_base_is_idempotent() {
        let dim = 16usize;
        let mut corpus = synthetic_corpus(6, 8, dim);
        let embed_fn = dense_embed(dim);
        let pipeline = build_pipeline(&corpus, &embed_fn, &PipelineConfig::default())
            .expect("pipeline should build");
        let cfg = SelfTuneConfig::default();
        let bases: Vec<f64> = corpus.iter().map(|c| c.quality).collect();

        reweight_from_base(&mut corpus, &bases, &pipeline, &cfg);
        let after_once: Vec<f64> = corpus.iter().map(|c| c.quality).collect();
        reweight_from_base(&mut corpus, &bases, &pipeline, &cfg);
        let after_twice: Vec<f64> = corpus.iter().map(|c| c.quality).collect();

        assert_eq!(after_once, after_twice, "reweight must be idempotent");
        // Guards against a vacuous pass: at least one quality must have
        // dropped below its base.
        assert!(after_once.iter().zip(bases.iter()).any(|(a, b)| a < b));
    }

    // All qualities are zero, so every fit weight comes from
    // `QUALITY_WEIGHT_FLOOR`; the pipeline must still build.
    #[test]
    fn build_pipeline_handles_zero_quality_floor() {
        let dim = 16usize;
        let mut corpus = synthetic_corpus(6, 8, dim);
        for c in corpus.iter_mut() {
            c.quality = 0.0;
        }
        let embed_fn = dense_embed(dim);
        let pipeline = build_pipeline(&corpus, &embed_fn, &PipelineConfig::default());
        assert!(pipeline.is_some(), "floored weights must keep fit viable");
    }

    // End-to-end smoke test: with a quality floor of 0.0 nothing is pruned,
    // and every reported score and mean stays inside [0, 1].
    #[test]
    fn run_self_tune_returns_mutated_corpus_and_report() {
        let dim = 16usize;
        let corpus = synthetic_corpus(6, 8, dim);
        let n_total = corpus.len();
        let cfg = SelfTuneConfig {
            max_iterations: 3,
            min_quality_to_keep: 0.0,
            min_concepts_per_category: 1,
            ..Default::default()
        };
        let metric = CorpusQuality::default();
        let embed_fn = dense_embed(dim);

        let (out, report) =
            run_self_tune(corpus, embed_fn, PipelineConfig::default(), &metric, &cfg)
                .expect("default-derived config is valid");

        assert!(!report.iterations.is_empty());
        assert_eq!(out.len(), n_total);
        for it in &report.iterations {
            assert!((0.0..=1.0).contains(&it.composite_score));
            assert!((0.0..=1.0).contains(&it.mean_quality));
            assert!((0.0..=1.0).contains(&it.breakdown.evr));
        }
        let final_score = report.final_composite.expect("final corpus is buildable");
        assert!((0.0..=1.0).contains(&final_score));
    }

    // `plateau_epsilon: 0.0` means the strict `<` plateau check can never
    // fire, so all rounds run. After the first round, the mean quality must
    // barely move, because every reweight starts from the original bases.
    #[test]
    fn run_self_tune_quality_does_not_collapse_across_iterations() {
        let dim = 16usize;
        let corpus = synthetic_corpus(6, 8, dim);
        let cfg = SelfTuneConfig {
            max_iterations: 4,
            min_quality_to_keep: 0.0,
            min_concepts_per_category: 1,
            plateau_epsilon: 0.0,
            ..Default::default()
        };
        let metric = CorpusQuality::default();
        let embed_fn = dense_embed(dim);

        let (_, report) = run_self_tune(corpus, embed_fn, PipelineConfig::default(), &metric, &cfg)
            .expect("default-derived config is valid");

        for it in report.iterations.iter().skip(1) {
            assert!(
                it.mean_quality_delta.abs() < 0.02,
                "iteration {} mean_quality_delta {} suggests compounding",
                it.iteration,
                it.mean_quality_delta
            );
        }
    }

    // Smoothing factors outside [0, 1] are rejected; the defaults pass.
    #[test]
    fn validate_rejects_out_of_range_smoothing() {
        let low = SelfTuneConfig {
            home_affinity_smoothing: -0.1,
            ..Default::default()
        };
        assert!(low.validate().is_err());

        let high = SelfTuneConfig {
            home_affinity_smoothing: 1.5,
            ..Default::default()
        };
        assert!(high.validate().is_err());

        assert!(SelfTuneConfig::default().validate().is_ok());
    }

    // `run_self_tune` must return the validation error instead of running.
    #[test]
    fn run_self_tune_surfaces_invalid_config() {
        let dim = 16usize;
        let corpus = synthetic_corpus(6, 8, dim);
        let cfg = SelfTuneConfig {
            home_affinity_smoothing: 1.5,
            ..Default::default()
        };
        let metric = CorpusQuality::default();
        let err = run_self_tune(
            corpus,
            dense_embed(dim),
            PipelineConfig::default(),
            &metric,
            &cfg,
        )
        .expect_err("out-of-range smoothing must be rejected");
        assert!(err.contains("home_affinity_smoothing"));
    }

    // `plateau_epsilon: 1.0` is meant to make the second round a plateau.
    // The expected qualities are those of a single reweight from the original
    // corpus, i.e. the state left by the first round.
    #[test]
    fn plateau_iteration_does_not_mutate_corpus() {
        let dim = 16usize;
        let corpus = synthetic_corpus(6, 8, dim);
        let embed_fn = dense_embed(dim);
        let cfg = SelfTuneConfig {
            max_iterations: 5,
            min_quality_to_keep: 0.0,
            min_concepts_per_category: 1,
            plateau_epsilon: 1.0,
            ..Default::default()
        };
        let metric = CorpusQuality::default();

        let mut expected = corpus.clone();
        let pipeline = build_pipeline(&expected, &embed_fn, &PipelineConfig::default())
            .expect("pipeline should build");
        let bases: Vec<f64> = expected.iter().map(|c| c.quality).collect();
        reweight_from_base(&mut expected, &bases, &pipeline, &cfg);

        let (out, report) =
            run_self_tune(corpus, embed_fn, PipelineConfig::default(), &metric, &cfg)
                .expect("config is valid");

        assert!(matches!(report.stopped_reason, StopReason::Plateau));
        assert_eq!(report.iterations.len(), 2);
        let plateau_it = report.iterations.last().unwrap();
        assert_eq!(plateau_it.n_pruned, 0);
        assert_eq!(plateau_it.mean_quality_delta, 0.0);

        let got: Vec<f64> = out.iter().map(|c| c.quality).collect();
        let want: Vec<f64> = expected.iter().map(|c| c.quality).collect();
        assert_eq!(
            got, want,
            "plateau iteration must leave qualities untouched"
        );
    }
}