use crate::error::{CrvError, CrvResult};
use crate::types::{CrossReference, CrvConfig, SignalLineProbe, StageVData};
use ruvector_gnn::search::{cosine_similarity, differentiable_search};
/// Stage V engine: probes candidate embeddings with a differentiable top-k
/// search, cross-references entries between stages by cosine similarity, and
/// pools probe results into a single fixed-length embedding.
#[derive(Clone, Debug)]
pub struct StageVEngine {
    /// Length of the embedding produced by `encode` (copied from `CrvConfig::dimensions`).
    dim: usize,
    /// Temperature handed to the differentiable search in `probe`
    /// (copied from `CrvConfig::search_temperature`).
    temperature: f32,
}
impl StageVEngine {
pub fn new(config: &CrvConfig) -> Self {
Self {
dim: config.dimensions,
temperature: config.search_temperature,
}
}
pub fn probe(
&self,
query_embedding: &[f32],
candidates: &[Vec<f32>],
k: usize,
) -> CrvResult<SignalLineProbe> {
if candidates.is_empty() {
return Err(CrvError::EmptyInput(
"No candidates for probing".to_string(),
));
}
let (top_candidates, attention_weights) =
differentiable_search(query_embedding, candidates, k, self.temperature);
Ok(SignalLineProbe {
query: String::new(), target_stage: 0, attention_weights,
top_candidates,
})
}
pub fn cross_reference(
&self,
from_stage: u8,
from_entries: &[Vec<f32>],
to_stage: u8,
to_entries: &[Vec<f32>],
threshold: f32,
) -> Vec<CrossReference> {
let mut refs = Vec::new();
for (from_idx, from_emb) in from_entries.iter().enumerate() {
for (to_idx, to_emb) in to_entries.iter().enumerate() {
if from_emb.len() == to_emb.len() {
let score = cosine_similarity(from_emb, to_emb);
if score >= threshold {
refs.push(CrossReference {
from_stage,
from_entry: from_idx,
to_stage,
to_entry: to_idx,
score,
});
}
}
}
}
refs.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
refs
}
pub fn encode(&self, data: &StageVData, all_embeddings: &[Vec<f32>]) -> CrvResult<Vec<f32>> {
if data.probes.is_empty() {
return Err(CrvError::EmptyInput("No probes in Stage V data".to_string()));
}
let mut embedding = vec![0.0f32; self.dim];
for probe in &data.probes {
for (&candidate_idx, &weight) in probe
.top_candidates
.iter()
.zip(probe.attention_weights.iter())
{
if candidate_idx < all_embeddings.len() {
let emb = &all_embeddings[candidate_idx];
for (i, &v) in emb.iter().enumerate() {
if i < self.dim {
embedding[i] += v * weight;
}
}
}
}
}
let num_probes = data.probes.len() as f32;
for v in &mut embedding {
*v /= num_probes;
}
Ok(embedding)
}
pub fn signal_strength(&self, embedding: &[f32]) -> f32 {
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
norm
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 8-dimensional config with unit temperature, shared by all tests.
    fn test_config() -> CrvConfig {
        CrvConfig {
            dimensions: 8,
            search_temperature: 1.0,
            ..CrvConfig::default()
        }
    }

    #[test]
    fn test_engine_creation() {
        let config = test_config();
        let engine = StageVEngine::new(&config);
        assert_eq!(engine.dim, 8);
    }

    #[test]
    fn test_probe() {
        let config = test_config();
        let engine = StageVEngine::new(&config);
        let query = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            vec![0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ];
        let probe = engine.probe(&query, &candidates, 2).unwrap();
        assert_eq!(probe.top_candidates.len(), 2);
        assert_eq!(probe.attention_weights.len(), 2);
        assert_eq!(probe.top_candidates[0], 0);
    }

    #[test]
    fn test_cross_reference() {
        let config = test_config();
        let engine = StageVEngine::new(&config);
        let from = vec![
            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ];
        let to = vec![
            vec![0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ];
        let refs = engine.cross_reference(1, &from, 2, &to, 0.5);
        // Only from[0] ~ to[0] (cosine ~0.99) clears the 0.5 threshold.
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].from_stage, 1);
        assert_eq!(refs[0].to_stage, 2);
        assert_eq!(refs[0].from_entry, 0);
        assert_eq!(refs[0].to_entry, 0);
        assert!(refs[0].score > 0.5);
    }

    #[test]
    fn test_cross_reference_sorted_descending() {
        let engine = StageVEngine::new(&test_config());
        let from = vec![vec![1.0, 0.0]];
        // to[0] has cosine ~0.707 with from[0]; to[1] is identical (cosine 1.0).
        let to = vec![vec![1.0, 1.0], vec![1.0, 0.0]];
        let refs = engine.cross_reference(0, &from, 1, &to, 0.0);
        assert_eq!(refs.len(), 2);
        assert_eq!(refs[0].to_entry, 1);
        assert!(refs.windows(2).all(|w| w[0].score >= w[1].score));
    }

    #[test]
    fn test_cross_reference_skips_dimension_mismatch() {
        let engine = StageVEngine::new(&test_config());
        let from = vec![vec![1.0, 0.0]];
        let to = vec![vec![1.0, 0.0, 0.0]];
        assert!(engine.cross_reference(0, &from, 1, &to, -1.0).is_empty());
    }

    #[test]
    fn test_signal_strength() {
        let engine = StageVEngine::new(&test_config());
        assert_eq!(engine.signal_strength(&[3.0, 4.0]), 5.0);
        assert_eq!(engine.signal_strength(&[]), 0.0);
    }

    #[test]
    fn test_empty_probe() {
        let config = test_config();
        let engine = StageVEngine::new(&config);
        let query = vec![1.0; 8];
        let candidates: Vec<Vec<f32>> = vec![];
        assert!(engine.probe(&query, &candidates, 5).is_err());
    }
}