use std::future::Future;
use crate::error::Result;
use super::graph::{GraphBaseline, TripleExtractor};
use super::ioc::{IocBaseline, IocExtractor};
use super::proposition::{PropositionExtractor, RedundancyCheck};
use super::types::{Document, Dropped, DroppedItem, DroppedReason, IngestionDelta};
/// A pipeline stage that processes one document and records its results in
/// the shared [`IngestionDelta`].
///
/// In this file the trait is implemented by [`NoPropositionTrack`] (does
/// nothing) and [`ActivePropositionTrack`] (extracts propositions and filters
/// redundant ones). Implementors must be `Send + Sync`.
pub trait PropositionTrack: Send + Sync {
/// Processes `doc`, mutating `delta` in place.
///
/// The returned future borrows `self`, `doc` and `delta` for `'a` and is
/// required to be `Send`. It resolves to `Ok(())` on success or to the
/// implementor's error otherwise.
fn run<'a>(
&'a self,
doc: &'a Document,
delta: &'a mut IngestionDelta,
) -> impl Future<Output = Result<()>> + Send + 'a;
}
/// Proposition track that performs no work.
///
/// This is the default `T` parameter of [`DistillationPipeline`] and the
/// track installed by `DistillationPipeline::new`.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoPropositionTrack;
impl PropositionTrack for NoPropositionTrack {
/// Succeeds immediately; neither `_doc` nor `_delta` is used.
async fn run<'a>(&'a self, _doc: &'a Document, _delta: &'a mut IngestionDelta) -> Result<()> {
Ok(())
}
}
/// Proposition track backed by an extractor and a redundancy check.
///
/// The struct itself places no bounds on `P` and `R`; the
/// [`PropositionTrack`] implementation requires `P: PropositionExtractor`
/// and `R: RedundancyCheck`.
#[derive(Debug)]
pub struct ActivePropositionTrack<P, R> {
    /// Produces candidate propositions from a document.
    extractor: P,
    /// Decides whether a candidate proposition is redundant.
    redundancy: R,
}

impl<P, R> ActivePropositionTrack<P, R> {
    /// Borrows the proposition extractor held by this track.
    pub fn extractor(&self) -> &P {
        let Self { extractor, .. } = self;
        extractor
    }

    /// Borrows the redundancy check held by this track.
    pub fn redundancy(&self) -> &R {
        let Self { redundancy, .. } = self;
        redundancy
    }
}
impl<P, R> PropositionTrack for ActivePropositionTrack<P, R>
where
    P: PropositionExtractor,
    R: RedundancyCheck,
{
    /// Track 3: extracts candidate propositions from `doc` and sorts each one
    /// into `delta`.
    ///
    /// Every candidate is passed to the redundancy check, one at a time and in
    /// extraction order. Candidates judged redundant are appended to
    /// `delta.dropped` with [`DroppedReason::Redundant`] carrying the reported
    /// similarity; all others are appended to `delta.propositions`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the extractor or by the redundancy
    /// check. Entries already appended to `delta` before the failure remain.
    async fn run<'a>(&'a self, doc: &'a Document, delta: &'a mut IngestionDelta) -> Result<()> {
        let extracted = self.extractor.extract(doc).await?;
        tracing::debug!(
            target: "rig_retrieval_evals::ingestion::pipeline",
            doc_id = %doc.id,
            candidate_count = extracted.len(),
            "track 3: candidate propositions extracted"
        );

        for proposition in extracted {
            let verdict = self.redundancy.check(&proposition).await?;

            // Novel propositions are kept as-is; only redundant ones fall through.
            if !verdict.is_redundant {
                delta.propositions.push(proposition);
                continue;
            }

            let reason = DroppedReason::Redundant {
                similarity: verdict.similarity,
            };
            delta.dropped.push(Dropped {
                item: DroppedItem::Proposition(proposition),
                reason,
            });
        }

        tracing::debug!(
            target: "rig_retrieval_evals::ingestion::pipeline",
            doc_id = %doc.id,
            new = delta.propositions.len(),
            "track 3: delta updated"
        );
        Ok(())
    }
}
/// Ingestion pipeline combining an IoC extractor and baseline with two
/// optional tracks.
///
/// `T` is the proposition track and `G` is the graph track. Both default to
/// their no-op implementations ([`NoPropositionTrack`], [`NoGraphTrack`]) and
/// are replaced through `with_propositions` / `with_graph`.
#[derive(Debug)]
pub struct DistillationPipeline<E, B, T = NoPropositionTrack, G = NoGraphTrack> {
// Produces candidate IoCs from a document (track 1).
extractor: E,
// Consulted to decide whether a candidate IoC is already known (track 1).
baseline: B,
// Proposition track (logged as track 3).
propositions: T,
// Graph track (logged as track 2).
graph: G,
}
impl<E, B> DistillationPipeline<E, B, NoPropositionTrack, NoGraphTrack> {
    /// Creates a pipeline that runs only the IoC track.
    ///
    /// Both optional tracks start out as their no-op implementations; use
    /// `with_propositions` and `with_graph` to enable them.
    pub fn new(extractor: E, baseline: B) -> Self {
        let propositions = NoPropositionTrack;
        let graph = NoGraphTrack;
        Self {
            extractor,
            baseline,
            propositions,
            graph,
        }
    }
}
impl<E, B, T, G> DistillationPipeline<E, B, T, G> {
    /// Replaces the proposition track with an [`ActivePropositionTrack`] built
    /// from `extractor` and `redundancy`.
    ///
    /// Consumes the pipeline. The IoC extractor, IoC baseline and graph track
    /// are carried over unchanged; the previous proposition track is dropped.
    pub fn with_propositions<P, R>(
        self,
        extractor: P,
        redundancy: R,
    ) -> DistillationPipeline<E, B, ActivePropositionTrack<P, R>, G>
    where
        P: PropositionExtractor,
        R: RedundancyCheck,
    {
        // The field is renamed on the way out so it does not shadow the
        // `extractor` argument, which belongs to the new track.
        let Self {
            extractor: ioc_extractor,
            baseline,
            graph,
            ..
        } = self;
        let propositions = ActivePropositionTrack {
            extractor,
            redundancy,
        };
        DistillationPipeline {
            extractor: ioc_extractor,
            baseline,
            propositions,
            graph,
        }
    }

    /// Replaces the graph track with an [`ActiveGraphTrack`] built from
    /// `extractor` and `baseline`.
    ///
    /// Consumes the pipeline. The IoC extractor, IoC baseline and proposition
    /// track are carried over unchanged; the previous graph track is dropped.
    pub fn with_graph<X, GB>(
        self,
        extractor: X,
        baseline: GB,
    ) -> DistillationPipeline<E, B, T, ActiveGraphTrack<X, GB>>
    where
        X: TripleExtractor,
        GB: GraphBaseline,
    {
        // Both arguments share names with pipeline fields, so the pipeline's
        // own components are bound under distinct names.
        let Self {
            extractor: ioc_extractor,
            baseline: ioc_baseline,
            propositions,
            ..
        } = self;
        let graph = ActiveGraphTrack {
            extractor,
            baseline,
        };
        DistillationPipeline {
            extractor: ioc_extractor,
            baseline: ioc_baseline,
            propositions,
            graph,
        }
    }

    /// Borrows the IoC extractor.
    pub fn extractor(&self) -> &E {
        &self.extractor
    }

    /// Borrows the IoC baseline.
    pub fn baseline(&self) -> &B {
        &self.baseline
    }

    /// Borrows the proposition track.
    pub fn propositions(&self) -> &T {
        &self.propositions
    }

    /// Borrows the graph track.
    pub fn graph(&self) -> &G {
        &self.graph
    }
}
impl<E, B, T, G> DistillationPipeline<E, B, T, G>
where
    E: IocExtractor + Send + Sync,
    B: IocBaseline,
    T: PropositionTrack,
    G: GraphTrack,
{
    /// Runs every track against `doc` and returns the accumulated delta.
    ///
    /// Track 1 runs inline: each candidate IoC from the extractor is looked up
    /// in the baseline, one at a time and in extraction order. Candidates the
    /// baseline already contains are appended to `delta.dropped` with
    /// [`DroppedReason::DuplicateIoc`]; the rest are appended to `delta.iocs`.
    /// The graph track then runs on the same delta, followed by the
    /// proposition track.
    ///
    /// # Errors
    ///
    /// Returns the first error from a baseline lookup or from either track.
    /// The partially built delta is discarded in that case.
    pub async fn ingest(&self, doc: &Document) -> Result<IngestionDelta> {
        let mut delta = IngestionDelta::new();

        let extracted = self.extractor.extract(doc);
        tracing::debug!(
            target: "rig_retrieval_evals::ingestion::pipeline",
            doc_id = %doc.id,
            candidate_count = extracted.len(),
            "track 1: candidate IoCs extracted"
        );

        for ioc in extracted {
            let already_known = self.baseline.contains(&ioc).await?;

            // Unseen IoCs are kept; only duplicates fall through to `dropped`.
            if !already_known {
                delta.iocs.push(ioc);
                continue;
            }

            delta.dropped.push(Dropped {
                item: DroppedItem::Ioc(ioc),
                reason: DroppedReason::DuplicateIoc,
            });
        }

        tracing::debug!(
            target: "rig_retrieval_evals::ingestion::pipeline",
            doc_id = %doc.id,
            new = delta.iocs.len(),
            dropped = delta.dropped.len(),
            "track 1: delta computed"
        );

        // The remaining tracks append to the same delta, graph first.
        self.graph.run(doc, &mut delta).await?;
        self.propositions.run(doc, &mut delta).await?;

        Ok(delta)
    }
}
/// A pipeline stage that processes one document and records its results in
/// the shared [`IngestionDelta`].
///
/// In this file the trait is implemented by [`NoGraphTrack`] (does nothing)
/// and [`ActiveGraphTrack`] (extracts triples and filters those already in a
/// baseline). Implementors must be `Send + Sync`.
pub trait GraphTrack: Send + Sync {
/// Processes `doc`, mutating `delta` in place.
///
/// The returned future borrows `self`, `doc` and `delta` for `'a` and is
/// required to be `Send`. It resolves to `Ok(())` on success or to the
/// implementor's error otherwise.
fn run<'a>(
&'a self,
doc: &'a Document,
delta: &'a mut IngestionDelta,
) -> impl Future<Output = Result<()>> + Send + 'a;
}
/// Graph track that performs no work.
///
/// This is the default `G` parameter of [`DistillationPipeline`] and the
/// track installed by `DistillationPipeline::new`.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoGraphTrack;
impl GraphTrack for NoGraphTrack {
/// Succeeds immediately; neither `_doc` nor `_delta` is used.
async fn run<'a>(&'a self, _doc: &'a Document, _delta: &'a mut IngestionDelta) -> Result<()> {
Ok(())
}
}
/// Graph track backed by a triple extractor and a graph baseline.
///
/// The struct itself places no bounds on `X` and `B`; the [`GraphTrack`]
/// implementation requires `X: TripleExtractor` and `B: GraphBaseline`.
#[derive(Debug)]
pub struct ActiveGraphTrack<X, B> {
    /// Produces candidate triples from a document.
    extractor: X,
    /// Consulted to decide whether a candidate triple is already known.
    baseline: B,
}

impl<X, B> ActiveGraphTrack<X, B> {
    /// Borrows the triple extractor held by this track.
    pub fn extractor(&self) -> &X {
        let Self { extractor, .. } = self;
        extractor
    }

    /// Borrows the graph baseline held by this track.
    pub fn baseline(&self) -> &B {
        let Self { baseline, .. } = self;
        baseline
    }
}
impl<X, B> GraphTrack for ActiveGraphTrack<X, B>
where
    X: TripleExtractor,
    B: GraphBaseline,
{
    /// Track 2: extracts candidate triples from `doc` and sorts each one into
    /// `delta`.
    ///
    /// Every candidate is looked up in the baseline, one at a time and in
    /// extraction order. Triples the baseline already contains are appended to
    /// `delta.dropped` with [`DroppedReason::DuplicateEdge`]; all others are
    /// appended to `delta.triples`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the extractor or by a baseline
    /// lookup. Entries already appended to `delta` before the failure remain.
    async fn run<'a>(&'a self, doc: &'a Document, delta: &'a mut IngestionDelta) -> Result<()> {
        let extracted = self.extractor.extract(doc).await?;
        tracing::debug!(
            target: "rig_retrieval_evals::ingestion::pipeline",
            doc_id = %doc.id,
            candidate_count = extracted.len(),
            "track 2: candidate triples extracted"
        );

        for edge in extracted {
            let already_known = self.baseline.contains(&edge).await?;

            // Unseen triples are kept; only duplicates fall through to `dropped`.
            if !already_known {
                delta.triples.push(edge);
                continue;
            }

            delta.dropped.push(Dropped {
                item: DroppedItem::Triple(edge),
                reason: DroppedReason::DuplicateEdge,
            });
        }

        tracing::debug!(
            target: "rig_retrieval_evals::ingestion::pipeline",
            doc_id = %doc.id,
            new = delta.triples.len(),
            "track 2: delta updated"
        );
        Ok(())
    }
}