use super::GroundedDocument;
use super::super::confidence::Confidence;
use super::super::types::{SignalId, TrackId, TypeLabel};
use super::identity::IdentityId;
use super::signal::SignalRef;
use serde::{Deserialize, Serialize};
/// Handle that names a track by pairing a document id with a track id.
///
/// Derives `Eq` and `Hash`, so it can be used as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TrackRef {
/// Identifier of the document (presumably the one that owns the track — confirm with callers).
pub doc_id: String,
/// Identifier of the track.
pub track_id: TrackId,
}
/// A group of signal references sharing one canonical surface form, with an
/// optional type label, optional identity link and optional embedding.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Track {
/// Identifier of this track.
pub id: TrackId,
/// References to the member signals, in the order `add_signal` appended them.
pub signals: Vec<SignalRef>,
/// Type label of the track; `None` until one is assigned.
pub entity_type: Option<TypeLabel>,
/// Canonical surface string supplied when the track was created.
pub canonical_surface: String,
/// Identity this track is linked to; `None` when unlinked.
pub identity_id: Option<IdentityId>,
/// Confidence attached to the track (presumably of the clustering decision);
/// `Track::new` initialises it to `Confidence::ONE`.
pub cluster_confidence: Confidence,
/// Optional embedding vector for the track.
pub embedding: Option<Vec<f32>>,
}
impl Track {
    /// Creates a track with the given id and canonical surface.
    ///
    /// The new track has no signals, no type label, no identity link and no
    /// embedding; its cluster confidence starts at `Confidence::ONE`.
    #[must_use]
    pub fn new(id: impl Into<TrackId>, canonical_surface: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            signals: Vec::new(),
            entity_type: None,
            canonical_surface: canonical_surface.into(),
            identity_id: None,
            cluster_confidence: Confidence::ONE,
            embedding: None,
        }
    }

    /// Appends a reference to `signal_id`, tagged with `position`, to this track.
    pub fn add_signal(&mut self, signal_id: impl Into<SignalId>, position: u32) {
        self.signals.push(SignalRef {
            signal_id: signal_id.into(),
            position,
        });
    }

    /// Number of signal references in this track.
    #[must_use]
    pub fn len(&self) -> usize {
        self.signals.len()
    }

    /// Returns `true` when the track holds no signal references.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.signals.is_empty()
    }

    /// Returns `true` when the track holds exactly one signal reference.
    #[must_use]
    pub fn is_singleton(&self) -> bool {
        self.signals.len() == 1
    }

    /// Identifier of this track.
    #[must_use]
    pub const fn id(&self) -> TrackId {
        self.id
    }

    /// The signal references of this track, in insertion order.
    #[must_use]
    pub fn signals(&self) -> &[SignalRef] {
        &self.signals
    }

    /// The canonical surface string of this track.
    #[must_use]
    pub fn canonical_surface(&self) -> &str {
        &self.canonical_surface
    }

    /// The identity this track is linked to, if any.
    #[must_use]
    pub const fn identity_id(&self) -> Option<IdentityId> {
        self.identity_id
    }

    /// The confidence currently stored on this track.
    #[must_use]
    pub const fn cluster_confidence(&self) -> Confidence {
        self.cluster_confidence
    }

    /// Replaces the stored cluster confidence with `Confidence::new(confidence)`.
    pub fn set_cluster_confidence(&mut self, confidence: f32) {
        // `f64::from` is the lossless widening conversion; it cannot truncate.
        self.cluster_confidence = Confidence::new(f64::from(confidence));
    }

    /// Links this track to `identity_id`, replacing any previous link.
    pub fn set_identity_id(&mut self, identity_id: IdentityId) {
        self.identity_id = Some(identity_id);
    }

    /// Removes the identity link, if any.
    pub fn clear_identity_id(&mut self) {
        self.identity_id = None;
    }

    /// Builder-style variant of [`Track::set_identity_id`].
    #[must_use]
    pub fn with_identity(mut self, identity_id: IdentityId) -> Self {
        self.identity_id = Some(identity_id);
        self
    }

    /// Builder-style setter that converts `entity_type` into a [`TypeLabel`]
    /// via `TypeLabel::from(&str)` and stores it.
    #[must_use]
    pub fn with_type(mut self, entity_type: impl Into<String>) -> Self {
        let name = entity_type.into();
        self.entity_type = Some(TypeLabel::from(name.as_str()));
        self
    }

    /// Builder-style setter that stores an already-constructed [`TypeLabel`].
    #[must_use]
    pub fn with_type_label(mut self, label: TypeLabel) -> Self {
        self.entity_type = Some(label);
        self
    }

    /// Returns a clone of the stored type label, if any.
    #[must_use]
    pub fn type_label(&self) -> Option<TypeLabel> {
        self.entity_type.clone()
    }

    /// Builder-style setter that attaches an embedding vector.
    #[must_use]
    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
        self.embedding = Some(embedding);
        self
    }

    /// Smallest and largest start offset among this track's signals.
    ///
    /// A signal contributes only when its id is found in `doc.signals` and its
    /// location yields text offsets. Returns `None` when no signal contributes.
    fn position_bounds(&self, doc: &GroundedDocument) -> Option<(usize, usize)> {
        self.signals
            .iter()
            .filter_map(|sr| {
                doc.signals
                    .iter()
                    .find(|s| s.id == sr.signal_id)
                    .and_then(|s| s.location.text_offsets())
                    .map(|(start, _)| start)
            })
            .fold(None, |bounds, start| match bounds {
                None => Some((start, start)),
                Some((lo, hi)) => Some((lo.min(start), hi.max(start))),
            })
    }

    /// Distance between the smallest and largest start offset of this track's
    /// signals in `doc`.
    ///
    /// Returns `Some(0)` for a track without signals, and `None` when the track
    /// has signals but none of them resolves to text offsets in `doc`.
    pub fn compute_spread(&self, doc: &GroundedDocument) -> Option<usize> {
        if self.signals.is_empty() {
            return Some(0);
        }
        // `position_bounds` guarantees first <= last, so this cannot underflow.
        self.position_bounds(doc).map(|(first, last)| last - first)
    }

    /// Distinct surface strings of this track's signals that are found in `doc`.
    ///
    /// The result is sorted lexicographically so that repeated calls (and the
    /// `TrackStats` built from them) are deterministic and comparable.
    pub fn collect_variations(&self, doc: &GroundedDocument) -> Vec<String> {
        // Deduplicate on borrowed strings; only the unique survivors are cloned.
        let unique: std::collections::BTreeSet<&str> = self
            .signals
            .iter()
            .filter_map(|sr| doc.signals.iter().find(|s| s.id == sr.signal_id))
            .map(|s| s.surface.as_str())
            .collect();
        unique.into_iter().map(str::to_owned).collect()
    }

    /// `(min, max, mean)` of the confidence values of this track's signals that
    /// are found in `doc`, or `None` when none is found.
    pub fn confidence_stats(&self, doc: &GroundedDocument) -> Option<(f32, f32, f32)> {
        // Single pass: running min, max, sum and count.
        let (min, max, sum, count) = self
            .signals
            .iter()
            .filter_map(|sr| {
                doc.signals
                    .iter()
                    .find(|s| s.id == sr.signal_id)
                    .map(|s| s.confidence.value() as f32)
            })
            .fold(
                (f32::INFINITY, f32::NEG_INFINITY, 0.0_f32, 0_usize),
                |(min, max, sum, count), c| (min.min(c), max.max(c), sum + c, count + 1),
            );
        if count == 0 {
            return None;
        }
        Some((min, max, sum / count as f32))
    }

    /// Gathers summary statistics for this track against `doc`.
    ///
    /// `text_len` is the denominator for `relative_spread`; when it is `0` the
    /// relative spread is reported as `0.0`. Positions, spread and confidences
    /// fall back to `0` when no signal of the track can be resolved in `doc`.
    pub fn compute_stats(&self, doc: &GroundedDocument, text_len: usize) -> TrackStats {
        // Resolve start offsets once; spread is derived from the same bounds.
        let (first_position, last_position) = self.position_bounds(doc).unwrap_or((0, 0));
        let spread = last_position - first_position;
        let variations = self.collect_variations(doc);
        let (min_conf, max_conf, mean_conf) =
            self.confidence_stats(doc).unwrap_or((0.0, 0.0, 0.0));
        let relative_spread = if text_len == 0 {
            0.0
        } else {
            spread as f64 / text_len as f64
        };
        TrackStats {
            chain_length: self.signals.len(),
            variation_count: variations.len(),
            variations,
            spread,
            relative_spread,
            first_position,
            last_position,
            min_confidence: Confidence::new(f64::from(min_conf)),
            max_confidence: Confidence::new(f64::from(max_conf)),
            mean_confidence: Confidence::new(f64::from(mean_conf)),
            has_embedding: self.embedding.is_some(),
        }
    }
}
/// Summary statistics for a single track, as produced by `Track::compute_stats`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrackStats {
/// Number of signal references in the track.
pub chain_length: usize,
/// Number of entries in `variations`.
pub variation_count: usize,
/// Distinct surface strings of the track's signals that were found in the document.
pub variations: Vec<String>,
/// Largest start offset minus smallest start offset; `0` when no offsets are available.
pub spread: usize,
/// `spread` divided by the supplied text length; `0.0` when that length is `0`.
pub relative_spread: f64,
/// Smallest start offset among the track's signals; `0` when none is available.
pub first_position: usize,
/// Largest start offset among the track's signals; `0` when none is available.
pub last_position: usize,
/// Lowest signal confidence; built from `0.0` when no signal was found.
pub min_confidence: Confidence,
/// Highest signal confidence; built from `0.0` when no signal was found.
pub max_confidence: Confidence,
/// Mean signal confidence; built from `0.0` when no signal was found.
pub mean_confidence: Confidence,
/// Whether the track carries an embedding.
pub has_embedding: bool,
}