use burn::tensor::{backend::Backend, Tensor, Int};
use crate::model::sensorlm::SensorLMModel;
use crate::loss::recall_at_k;
/// One ranked hit returned by a retrieval query.
#[derive(Debug, Clone)]
pub struct RetrievalResult {
    /// Row position of the matched entry within the embedding matrix that was searched.
    pub index: usize,
    /// Similarity value of that entry for the query (an element of the
    /// query-by-database matrix product).
    pub score: f32,
}
/// Cross-modal retrieval helper that pairs a [`SensorLMModel`] with optional
/// pre-computed embedding matrices for sensor and text data.
pub struct RetrievalEngine<B: Backend> {
    /// Model used to encode sensor batches and tokenised text into embeddings.
    model: SensorLMModel<B>,
    /// Sensor embeddings stored by `index_sensor`; `None` until that has been called.
    sensor_embeddings: Option<Tensor<B, 2>>,
    /// Text embeddings stored by `index_text`; `None` until that has been called.
    text_embeddings: Option<Tensor<B, 2>>,
}
impl<B: Backend> RetrievalEngine<B> {
pub fn new(model: SensorLMModel<B>) -> Self {
Self {
model,
sensor_embeddings: None,
text_embeddings: None,
}
}
pub fn index_sensor(&mut self, sensor_batches: impl IntoIterator<Item = Tensor<B, 3>>) {
let embeddings: Vec<Tensor<B, 2>> = sensor_batches
.into_iter()
.map(|batch| self.model.encode_sensor(batch))
.collect();
self.sensor_embeddings = Some(Tensor::cat(embeddings, 0));
}
pub fn index_text(
&mut self,
text_batches: impl IntoIterator<Item = (Tensor<B, 2, Int>, Tensor<B, 2, Int>)>,
) {
let embeddings: Vec<Tensor<B, 2>> = text_batches
.into_iter()
.map(|(ids, mask)| self.model.encode_text(ids, mask))
.collect();
self.text_embeddings = Some(Tensor::cat(embeddings, 0));
}
pub fn sensor_to_text(
&self,
query_sensor: Tensor<B, 3>,
top_k: usize,
) -> Vec<Vec<RetrievalResult>> {
let db = self
.text_embeddings
.as_ref()
.expect("No text embeddings indexed");
let q_emb = self.model.encode_sensor(query_sensor); self.top_k_search(q_emb, db.clone(), top_k)
}
pub fn text_to_sensor(
&self,
query_ids: Tensor<B, 2, Int>,
query_mask: Tensor<B, 2, Int>,
top_k: usize,
) -> Vec<Vec<RetrievalResult>> {
let db = self
.sensor_embeddings
.as_ref()
.expect("No sensor embeddings indexed");
let q_emb = self.model.encode_text(query_ids, query_mask); self.top_k_search(q_emb, db.clone(), top_k)
}
fn top_k_search(
&self,
queries: Tensor<B, 2>,
database: Tensor<B, 2>,
top_k: usize,
) -> Vec<Vec<RetrievalResult>> {
let q = queries.dims()[0];
let n = database.dims()[0];
let sim = queries.matmul(database.transpose());
let data: Vec<f32> = sim.into_data().to_vec::<f32>().unwrap_or_default();
(0..q)
.map(|qi| {
let row = &data[qi * n..(qi + 1) * n];
let mut indexed: Vec<(usize, f32)> =
row.iter().copied().enumerate().collect();
indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
indexed
.into_iter()
.take(top_k)
.map(|(i, s)| RetrievalResult { index: i, score: s })
.collect()
})
.collect()
}
pub fn evaluate_recall(
&self,
sensor: Tensor<B, 3>,
ids: Tensor<B, 2, Int>,
mask: Tensor<B, 2, Int>,
k: usize,
) -> (f32, f32) {
let z_s = self.model.encode_sensor(sensor); let z_t = self.model.encode_text(ids, mask);
let logits = z_s.matmul(z_t.clone().transpose()); let r_s2t = recall_at_k(logits.clone(), k);
let r_t2s = recall_at_k(logits.transpose(), k);
(r_s2t, r_t2s)
}
}