use std::future::Future;
use rig::vector_store::{VectorSearchRequest, VectorStoreIndexDyn, request::Filter};
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use super::types::Document;
/// A piece of text produced by a [`PropositionExtractor`], optionally tagged
/// with an identifier.
///
/// When serialized, `id` is omitted if it is `None`; when deserialized, a
/// missing `id` becomes `None`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Proposition {
    /// Optional identifier; absent for propositions built with [`Proposition::new`].
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// The proposition's text.
    pub text: String,
}
impl Proposition {
    /// Builds a proposition from `text` with no identifier.
    pub fn new(text: impl Into<String>) -> Self {
        let text = text.into();
        Self { id: None, text }
    }

    /// Builds a proposition from `text` tagged with `id`.
    pub fn with_id(id: impl Into<String>, text: impl Into<String>) -> Self {
        // Convert `id` before `text`, matching field order.
        let id: String = id.into();
        let text: String = text.into();
        Self { id: Some(id), text }
    }
}
/// Result of a [`RedundancyCheck`]: the decision plus the score behind it.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct RedundancyVerdict {
    /// `true` when the proposition was judged redundant.
    pub is_redundant: bool,
    /// The similarity score the decision was based on.
    pub similarity: f64,
}
/// Turns a [`Document`] into a list of [`Proposition`]s.
///
/// Implementors must be `Send + Sync`, and the returned future must be `Send`
/// so it can be moved across threads.
pub trait PropositionExtractor: Send + Sync {
    /// Extracts the propositions contained in `doc`.
    fn extract(&self, doc: &Document) -> impl Future<Output = Result<Vec<Proposition>>> + Send;
}
/// Judges whether a [`Proposition`] is redundant, producing a
/// [`RedundancyVerdict`].
///
/// Implementors must be `Send + Sync`, and the returned future must be `Send`.
pub trait RedundancyCheck: Send + Sync {
    /// Returns the verdict for `proposition`, including the score it was based on.
    fn check(
        &self,
        proposition: &Proposition,
    ) -> impl Future<Output = Result<RedundancyVerdict>> + Send;
}
/// A dependency-free [`PropositionExtractor`] that splits text on the sentence
/// punctuation characters `.`, `!` and `?`.
///
/// NOTE(review): judging by the name this is a placeholder for a smarter
/// extractor — confirm before relying on its output quality.
#[derive(Debug, Default, Clone, Copy)]
pub struct StubPropositionExtractor;
impl StubPropositionExtractor {
    /// Creates the extractor (it carries no state).
    pub fn new() -> Self {
        Self
    }

    /// Appends the sentence-like pieces of `text` to `out`, trimmed of
    /// surrounding whitespace.
    ///
    /// A piece ends at `.`, `!` or `?` (the terminator is kept). Terminated
    /// pieces are kept only if their trimmed form is longer than one byte, which
    /// drops lone punctuation such as the extra dots of an ellipsis. A trailing
    /// piece with no terminator is kept if it is non-empty after trimming.
    fn split(&self, text: &str, out: &mut Vec<Proposition>) {
        let is_terminator = |c: char| matches!(c, '.' | '!' | '?');
        for piece in text.split_inclusive(is_terminator) {
            let sentence = piece.trim();
            // Only the final piece can lack a terminator.
            let keep = if piece.ends_with(is_terminator) {
                sentence.len() > 1
            } else {
                !sentence.is_empty()
            };
            if keep {
                out.push(Proposition::new(sentence));
            }
        }
    }
}
impl PropositionExtractor for StubPropositionExtractor {
    /// Splits the document's top-level `text`, then each section's `text` in
    /// order, into propositions using the stub sentence splitter.
    ///
    /// All splitting happens synchronously before the future is returned; the
    /// future just yields the collected propositions and never returns `Err`.
    fn extract(&self, doc: &Document) -> impl Future<Output = Result<Vec<Proposition>>> + Send {
        let mut out = Vec::new();
        self.split(&doc.text, &mut out);
        for section in &doc.sections {
            self.split(&section.text, &mut out);
        }
        // `out` is moved into the future, so it holds no borrow of `doc`.
        async move { Ok(out) }
    }
}
/// A [`RedundancyCheck`] backed by a rig vector store index.
///
/// Each proposition's text is used as a search query. The proposition is
/// redundant when the top hit's score reaches `threshold`.
pub struct VectorStoreRedundancyCheck<'s> {
    /// Index queried for neighbours of each proposition (borrowed, not owned).
    store: &'s dyn VectorStoreIndexDyn,
    /// Inclusive minimum score for a "redundant" verdict.
    threshold: f64,
    /// Number of hits requested per query; the constructors keep this >= 1.
    samples: u64,
}
impl<'s> VectorStoreRedundancyCheck<'s> {
pub const DEFAULT_THRESHOLD: f64 = 0.90;
pub fn new(store: &'s dyn VectorStoreIndexDyn, threshold: f64) -> Result<Self> {
if !(0.0..=1.0).contains(&threshold) {
return Err(Error::Config(format!(
"RedundancyCheck threshold must be in [0.0, 1.0], got {threshold}"
)));
}
Ok(Self {
store,
threshold,
samples: 1,
})
}
pub fn with_default_threshold(store: &'s dyn VectorStoreIndexDyn) -> Self {
Self {
store,
threshold: Self::DEFAULT_THRESHOLD,
samples: 1,
}
}
pub fn with_samples(mut self, samples: u64) -> Self {
self.samples = samples.max(1);
self
}
pub fn threshold(&self) -> f64 {
self.threshold
}
}
impl<'s> RedundancyCheck for VectorStoreRedundancyCheck<'s> {
    /// Queries the store with the proposition's text and marks it redundant
    /// when the best returned score is at least the configured threshold.
    ///
    /// Up to `samples` hits are requested. The reported `similarity` is the
    /// highest score among them, or `0.0` if the store returns no hits. Taking
    /// the maximum means the verdict does not depend on the order in which the
    /// backend returns results. Like the `>= threshold` comparison, this assumes
    /// higher scores mean closer matches.
    ///
    /// # Errors
    ///
    /// Propagates any error from `top_n_ids`, converted with `?` into the crate
    /// error type.
    fn check(
        &self,
        proposition: &Proposition,
    ) -> impl Future<Output = Result<RedundancyVerdict>> + Send {
        // Build the request up front so the future owns a copy of the query text.
        let req: VectorSearchRequest<Filter<serde_json::Value>> = VectorSearchRequest::builder()
            .query(proposition.text.clone())
            .samples(self.samples)
            .build();
        let threshold = self.threshold;
        async move {
            let hits = self.store.top_n_ids(req).await?;
            // Use the maximum rather than the first hit, because hit ordering is
            // backend-dependent. `f64::max` also ignores NaN scores.
            let similarity = hits
                .iter()
                .map(|&(score, _)| score)
                .reduce(f64::max)
                .unwrap_or(0.0);
            Ok(RedundancyVerdict {
                is_redundant: similarity >= threshold,
                similarity,
            })
        }
    }
}