use super::Location;
use super::super::confidence::Confidence;
use super::super::entity::{Entity, HierarchicalConfidence, Provenance};
use super::super::types::{SignalId, TypeLabel};
use serde::{Deserialize, Serialize};
/// Modality tag carried by every [`Signal`] in its `modality` field.
///
/// `Symbolic` is the value returned by `Modality::default()`.
///
/// NOTE(review): the variant names suggest an iconic-vs-symbolic sign
/// distinction, but nothing in this file defines their exact meaning —
/// confirm against the module-level documentation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum Modality {
/// Iconic modality.
Iconic,
/// Symbolic modality; this is the default variant.
#[default]
Symbolic,
/// Hybrid modality (presumably a mix of iconic and symbolic — TODO confirm).
Hybrid,
}
/// A labelled observation anchored at a location of type `L`.
///
/// `L` defaults to [`Location`]. All fields are public; `Signal::new` fills
/// the optional ones with `None` / `false` / `Modality::default()`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Signal<L = Location> {
/// Identifier of this signal.
pub id: SignalId,
/// Where the signal was observed.
pub location: L,
/// Surface text of the signal. For `Signal<Location>`, `validate_against`
/// compares this with the source text found at the location's offsets.
pub surface: String,
/// Type label assigned to the signal.
pub label: TypeLabel,
/// Confidence attached to the signal; `is_confident` compares it with a threshold.
pub confidence: Confidence,
/// Optional hierarchical confidence; `None` when built via `Signal::new`.
pub hierarchical: Option<HierarchicalConfidence>,
/// Optional provenance record; `None` when built via `Signal::new`.
pub provenance: Option<Provenance>,
/// Modality tag; `Modality::default()` when built via `Signal::new`.
pub modality: Modality,
/// Optional normalized form of the surface text; `None` when built via `Signal::new`.
pub normalized: Option<String>,
/// Negation flag; `false` when built via `Signal::new`, set by the `negated()` builder.
pub negated: bool,
/// Optional quantifier; `Option::None` here means "no quantifier recorded",
/// which is distinct from the `Quantifier::None` variant.
pub quantifier: Option<Quantifier>,
}
/// Quantifier tag that can be attached to a [`Signal`] via `with_quantifier`.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names only; this file does not define their semantics — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Quantifier {
/// Universal quantification (presumably "all" / "every").
Universal,
/// Existential quantification (presumably "some" / "a").
Existential,
/// The `None` quantifier variant (presumably "no" / "none").
/// This is an enum variant of its own and is not the same as `Option::None`.
None,
/// Definite reference (presumably "the").
Definite,
/// Approximate quantity (presumably "about" / "roughly").
Approximate,
/// Lower bound (presumably "at least").
MinBound,
/// Upper bound (presumably "at most").
MaxBound,
/// Bare form (presumably no explicit quantifier word).
Bare,
}
impl<L> Signal<L> {
    /// Builds a signal from its required parts.
    ///
    /// `id`, `surface` and `label` are converted with `Into`; `confidence` is
    /// widened to `f64` and handed to `Confidence::new`. Every optional field
    /// starts out empty: no hierarchical confidence, no provenance, no
    /// normalized form, no quantifier, `negated == false`, and the default
    /// modality.
    #[must_use]
    pub fn new(
        id: impl Into<SignalId>,
        location: L,
        surface: impl Into<String>,
        label: impl Into<TypeLabel>,
        confidence: f32,
    ) -> Self {
        // Conversions run in the same order as the field list below.
        let id = id.into();
        let surface = surface.into();
        let label = label.into();
        // f32 -> f64 is a lossless widening.
        let confidence = Confidence::new(f64::from(confidence));

        Self {
            id,
            location,
            surface,
            label,
            confidence,
            hierarchical: None,
            provenance: None,
            modality: Modality::default(),
            normalized: None,
            negated: false,
            quantifier: None,
        }
    }

    /// Returns the type label as a string slice.
    #[must_use]
    pub fn label(&self) -> &str {
        self.label.as_str()
    }

    /// Returns an owned clone of the type label.
    #[must_use]
    pub fn type_label(&self) -> TypeLabel {
        self.label.clone()
    }

    /// Returns the surface text.
    #[must_use]
    pub fn surface(&self) -> &str {
        self.surface.as_str()
    }

    /// Returns `true` when this signal's confidence is at least `threshold`.
    #[must_use]
    pub fn is_confident(&self, threshold: Confidence) -> bool {
        self.confidence >= threshold
    }

    /// Builder: returns the signal with its modality replaced by `modality`.
    #[must_use]
    pub fn with_modality(self, modality: Modality) -> Self {
        Self { modality, ..self }
    }

    /// Builder: returns the signal with the `negated` flag set to `true`.
    #[must_use]
    pub fn negated(self) -> Self {
        Self {
            negated: true,
            ..self
        }
    }

    /// Builder: returns the signal with its quantifier set to `Some(q)`.
    #[must_use]
    pub fn with_quantifier(self, q: Quantifier) -> Self {
        Self {
            quantifier: Some(q),
            ..self
        }
    }

    /// Builder: returns the signal with its provenance set to `Some(p)`.
    #[must_use]
    pub fn with_provenance(self, p: Provenance) -> Self {
        Self {
            provenance: Some(p),
            ..self
        }
    }
}
impl Signal<Location> {
    /// Returns the `(start, end)` offsets reported by the location, or `None`
    /// when the location does not provide text offsets.
    #[must_use]
    pub fn text_offsets(&self) -> Option<(usize, usize)> {
        self.location.text_offsets()
    }

    /// Checks this signal's span and surface text against `source_text`.
    ///
    /// Offsets are interpreted as character offsets (counted with
    /// `str::chars`), not byte offsets.
    ///
    /// Returns `None` when the signal is consistent with `source_text`, or
    /// when the location has no text offsets to check. Otherwise returns the
    /// first failure found, in this order:
    ///
    /// 1. [`SignalValidationError::OutOfBounds`] — `end` exceeds the number of
    ///    characters in `source_text`.
    /// 2. [`SignalValidationError::InvalidSpan`] — `start >= end` (empty or
    ///    reversed span).
    /// 3. [`SignalValidationError::TextMismatch`] — the characters in
    ///    `[start, end)` differ from `self.surface`.
    #[must_use]
    pub fn validate_against(&self, source_text: &str) -> Option<SignalValidationError> {
        // No text offsets means there is nothing to validate.
        let (start, end) = self.location.text_offsets()?;

        let char_count = source_text.chars().count();
        if end > char_count {
            return Some(SignalValidationError::OutOfBounds {
                signal_id: self.id,
                end,
                text_len: char_count,
            });
        }

        if start >= end {
            return Some(SignalValidationError::InvalidSpan {
                signal_id: self.id,
                start,
                end,
            });
        }

        // `start < end` holds here, so `end - start` cannot underflow.
        let actual: String = source_text.chars().skip(start).take(end - start).collect();
        if actual != self.surface {
            return Some(SignalValidationError::TextMismatch {
                signal_id: self.id,
                expected: self.surface.clone(),
                actual,
                start,
                end,
            });
        }

        None
    }

    /// Returns `true` when [`validate_against`](Self::validate_against)
    /// reports no error for `source_text`.
    #[must_use]
    pub fn is_valid(&self, source_text: &str) -> bool {
        self.validate_against(source_text).is_none()
    }

    /// Builds a signal for the first occurrence of `surface` in `source`.
    ///
    /// Shorthand for [`from_text_nth`](Self::from_text_nth) with
    /// `occurrence == 0`; see that method for details.
    #[must_use]
    pub fn from_text(
        source: &str,
        surface: &str,
        label: impl Into<TypeLabel>,
        confidence: f32,
    ) -> Option<Self> {
        Self::from_text_nth(source, surface, label, confidence, 0)
    }

    /// Builds a signal for the `occurrence`-th (zero-based) non-overlapping
    /// occurrence of `surface` in `source`.
    ///
    /// The resulting location uses character offsets, and the signal id is
    /// `SignalId::ZERO`.
    ///
    /// Returns `None` when `surface` is empty or when `source` contains fewer
    /// than `occurrence + 1` matches.
    #[must_use]
    pub fn from_text_nth(
        source: &str,
        surface: &str,
        label: impl Into<TypeLabel>,
        confidence: f32,
        occurrence: usize,
    ) -> Option<Self> {
        // An empty pattern matches at every char boundary, which would yield a
        // zero-width span that `validate_against` rejects as `InvalidSpan`.
        if surface.is_empty() {
            return None;
        }

        let (byte_idx, _) = source.match_indices(surface).nth(occurrence)?;

        // `match_indices` reports byte offsets; convert to character offsets.
        let start = source[..byte_idx].chars().count();
        let end = start + surface.chars().count();

        Some(Self::new(
            SignalId::ZERO,
            Location::text(start, end),
            surface,
            label,
            confidence,
        ))
    }
}
/// Reasons a `Signal<Location>` can fail `validate_against`.
///
/// All offsets and lengths are character counts, matching how
/// `validate_against` measures the source text.
#[derive(Debug, Clone, PartialEq)]
pub enum SignalValidationError {
/// The span's end offset is greater than the number of characters in the text.
OutOfBounds {
/// Id of the offending signal.
signal_id: SignalId,
/// End offset of the signal's span.
end: usize,
/// Number of characters in the text that was validated against.
text_len: usize,
},
/// The span is empty or reversed (`start >= end`).
InvalidSpan {
/// Id of the offending signal.
signal_id: SignalId,
/// Start offset of the signal's span.
start: usize,
/// End offset of the signal's span.
end: usize,
},
/// The text found at `[start, end)` differs from the signal's surface.
TextMismatch {
/// Id of the offending signal.
signal_id: SignalId,
/// The signal's own surface text.
expected: String,
/// The text actually present in the source at the span.
actual: String,
/// Start offset of the signal's span.
start: usize,
/// End offset of the signal's span.
end: usize,
},
}
impl std::fmt::Display for SignalValidationError {
    /// Writes a one-line, human-readable description of the failure.
    ///
    /// Every message is prefixed with `S` followed by the signal id; spans are
    /// rendered in half-open `[start, end)` notation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::OutOfBounds {
                signal_id: id,
                end: span_end,
                text_len: len,
            } => write!(
                f,
                "S{}: end offset {} exceeds text length {}",
                id, span_end, len
            ),
            Self::InvalidSpan {
                signal_id: id,
                start: span_start,
                end: span_end,
            } => write!(f, "S{}: invalid span [{}, {})", id, span_start, span_end),
            Self::TextMismatch {
                signal_id: id,
                expected: want,
                actual: got,
                start: span_start,
                end: span_end,
            } => write!(
                f,
                "S{}: text mismatch at [{}, {}): expected '{}', found '{}'",
                id, span_start, span_end, want, got
            ),
        }
    }
}
// Empty impl: the `Debug` derive and the `Display` impl satisfy the trait's
// requirements, and every `Error` method keeps its default implementation.
impl std::error::Error for SignalValidationError {}
impl From<&Entity> for Signal<Location> {
    /// Converts an entity into a text signal.
    ///
    /// The signal gets id `SignalId::ZERO`, a text location built from the
    /// entity's `start()`/`end()`, the entity's text as surface, the label
    /// from `entity_type.as_label()`, and the entity's confidence (passed
    /// through `f32`). The entity's `normalized`, `provenance` and
    /// `hierarchical_confidence` values are carried over; all remaining
    /// fields keep the defaults from `Signal::new`.
    fn from(entity: &Entity) -> Self {
        let base = Signal::new(
            SignalId::ZERO,
            Location::text(entity.start(), entity.end()),
            &entity.text,
            entity.entity_type.as_label(),
            f32::from(entity.confidence),
        );

        // Overlay the entity-specific optional data on top of the defaults.
        Self {
            normalized: entity.normalized.clone(),
            provenance: entity.provenance.clone(),
            hierarchical: entity.hierarchical_confidence,
            ..base
        }
    }
}