use async_trait::async_trait;
use nexo_tool_meta::marketing::EnrichmentResult;
/// Result of running a [`FallbackChain`] over a single input.
#[derive(Clone, Debug, PartialEq)]
pub enum FallbackOutcome {
    /// A source produced a result whose `confidence` met the chain's threshold.
    Hit {
        /// The accepted result.
        result: EnrichmentResult,
        /// Zero-based position of the winning source in the chain's source list.
        source_index: usize,
    },
    /// No source met the threshold (this includes a chain with no sources).
    AllExhausted {
        /// Results that came back below the threshold, in the order their sources
        /// ran. Sources that returned `Ok(None)` or an error contribute nothing here.
        attempts: Vec<EnrichmentResult>,
    },
}
/// One strategy for inferring enrichment data from an inbound message.
///
/// Implementations are composed by [`FallbackChain`], which calls them in order.
#[async_trait]
pub trait EnrichmentSource: Send + Sync {
    /// Identifier for this source; `FallbackChain` includes it in warning logs and
    /// reports it from `source_names`.
    fn name(&self) -> &str;

    /// Coarse cost tier of calling this source.
    ///
    /// NOTE(review): `FallbackChain` in this file never reads this value; presumably
    /// callers use it when deciding source order — confirm.
    fn cost_estimate(&self) -> SourceCost;

    /// Attempts to extract enrichment data from `input`.
    ///
    /// Returns `Ok(Some(_))` with a candidate result, or `Ok(None)` when the source
    /// has nothing to offer for this input.
    ///
    /// # Errors
    ///
    /// Returns an [`EnrichmentSourceError`] when the source itself fails;
    /// `FallbackChain::run` logs such errors and moves on to the next source.
    async fn extract(
        &self,
        input: &EnrichmentInput<'_>,
    ) -> Result<Option<EnrichmentResult>, EnrichmentSourceError>;
}
/// Coarse cost tier reported by [`EnrichmentSource::cost_estimate`].
///
/// Variants are declared from cheapest to most expensive. The derived
/// `PartialOrd`/`Ord` therefore rank them by cost (`Free < Cheap < Moderate < Paid`),
/// so callers can sort or filter sources by cost. `Hash` allows use as a map/set key.
/// Keep the declaration order when adding variants, since it defines the ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SourceCost {
    /// Cheapest tier.
    Free,
    /// Second-cheapest tier.
    Cheap,
    /// Second-most-expensive tier.
    Moderate,
    /// Most expensive tier.
    Paid,
}
/// Borrowed view of the message fields an [`EnrichmentSource`] may inspect.
///
/// Every field borrows from the caller for `'a`, so the struct is `Copy`.
/// `PartialEq`/`Eq` are derived so two inputs can be compared directly
/// (for example in assertions or when de-duplicating work).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EnrichmentInput<'a> {
    /// Sender email address.
    pub from_email: &'a str,
    /// Sender display name, if one was present.
    pub from_display_name: Option<&'a str>,
    /// Message subject.
    pub subject: &'a str,
    /// Excerpt of the message body; how much is included is decided by the caller.
    pub body_excerpt: &'a str,
    /// Reply-To address, if present.
    pub reply_to: Option<&'a str>,
}
#[derive(Debug, thiserror::Error)]
pub enum EnrichmentSourceError {
#[error("source {source_name:?} unavailable: {reason}")]
SourceUnavailable { source_name: String, reason: String },
#[error("invalid input: {0}")]
InvalidInput(String),
}
/// Ordered list of [`EnrichmentSource`]s that are tried one after another until one
/// returns a sufficiently confident result.
pub struct FallbackChain {
    /// Sources, in the order they are tried.
    sources: Vec<Box<dyn EnrichmentSource>>,
    /// Minimum `confidence` (inclusive) a result needs in order to be accepted.
    confidence_threshold: f32,
}
impl FallbackChain {
    /// Builds a chain that tries `sources` in the given order and accepts the first
    /// result whose `confidence` is at least `confidence_threshold`.
    ///
    /// The threshold is stored as-is with no range or NaN check. With a NaN
    /// threshold no result can ever be accepted, because `>=` against NaN is false.
    pub fn new(sources: Vec<Box<dyn EnrichmentSource>>, confidence_threshold: f32) -> Self {
        Self {
            confidence_threshold,
            sources,
        }
    }

    /// Returns the minimum confidence (inclusive) a result needs to be accepted.
    pub fn confidence_threshold(&self) -> f32 {
        self.confidence_threshold
    }

    /// Returns the names of the configured sources, in the order they are tried.
    pub fn source_names(&self) -> Vec<&str> {
        self.sources
            .iter()
            .map(|source| source.name())
            .collect()
    }

    /// Runs the sources in order, awaiting each one before starting the next.
    ///
    /// Each source's outcome is handled as follows:
    /// * `Ok(Some(result))` with `result.confidence >= threshold` ends the run with
    ///   [`FallbackOutcome::Hit`].
    /// * `Ok(Some(result))` below the threshold is kept, and the run continues.
    /// * `Ok(None)` is skipped.
    /// * `Err(_)` is logged at `warn` level and skipped, so a failing source never
    ///   aborts the chain.
    ///
    /// If no source clears the threshold (including when the chain is empty), returns
    /// [`FallbackOutcome::AllExhausted`] with the kept below-threshold results in
    /// source order.
    pub async fn run(&self, input: &EnrichmentInput<'_>) -> FallbackOutcome {
        let threshold = self.confidence_threshold;
        let mut below_threshold = Vec::new();

        for (source_index, source) in self.sources.iter().enumerate() {
            let result = match source.extract(input).await {
                Ok(Some(result)) => result,
                Ok(None) => continue,
                Err(err) => {
                    // Best-effort: one broken source must not prevent later ones
                    // from being tried.
                    tracing::warn!(
                        source = source.name(),
                        error = %err,
                        "enrichment source failed; continuing"
                    );
                    continue;
                }
            };

            // Keep the `>=` form: a NaN confidence compares false here, so it is
            // recorded as an attempt rather than accepted.
            if result.confidence >= threshold {
                return FallbackOutcome::Hit {
                    result,
                    source_index,
                };
            }
            below_threshold.push(result);
        }

        FallbackOutcome::AllExhausted {
            attempts: below_threshold,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Canned [`EnrichmentSource`] whose `extract` returns a fixed result or error.
    struct StubSource {
        name: &'static str,
        result: Option<EnrichmentResult>,
        error: Option<String>,
    }

    impl StubSource {
        /// A source that returns a result with confidence `conf`.
        fn hit(name: &'static str, conf: f32) -> Self {
            Self {
                name,
                result: Some(EnrichmentResult {
                    source: name.into(),
                    confidence: conf,
                    person_inferred: None,
                    company_inferred: None,
                    note: Some(format!("from {name}")),
                }),
                error: None,
            }
        }

        /// A source that returns `Ok(None)`.
        fn miss(name: &'static str) -> Self {
            Self {
                name,
                result: None,
                error: None,
            }
        }

        /// A source that always fails with `SourceUnavailable { reason }`.
        fn errs(name: &'static str, reason: &'static str) -> Self {
            Self {
                name,
                result: None,
                error: Some(reason.to_string()),
            }
        }
    }

    #[async_trait]
    impl EnrichmentSource for StubSource {
        fn name(&self) -> &str {
            self.name
        }

        fn cost_estimate(&self) -> SourceCost {
            SourceCost::Free
        }

        async fn extract(
            &self,
            _input: &EnrichmentInput<'_>,
        ) -> Result<Option<EnrichmentResult>, EnrichmentSourceError> {
            if let Some(reason) = &self.error {
                return Err(EnrichmentSourceError::SourceUnavailable {
                    source_name: self.name.to_string(),
                    reason: reason.clone(),
                });
            }
            Ok(self.result.clone())
        }
    }

    fn input<'a>() -> EnrichmentInput<'a> {
        EnrichmentInput {
            from_email: "juan@gmail.com",
            from_display_name: Some("Juan García (Acme)"),
            subject: "Hi",
            body_excerpt: "Hola",
            reply_to: None,
        }
    }

    #[tokio::test]
    async fn first_hit_above_threshold_wins() {
        // The later, more confident "llm" source must never be consulted.
        let chain = FallbackChain::new(
            vec![
                Box::new(StubSource::miss("display_name")),
                Box::new(StubSource::hit("signature", 0.85)),
                Box::new(StubSource::hit("llm", 0.95)),
            ],
            0.7,
        );
        let out = chain.run(&input()).await;
        match out {
            FallbackOutcome::Hit {
                result,
                source_index,
            } => {
                assert_eq!(result.source, "signature");
                assert_eq!(source_index, 1);
            }
            _ => panic!("expected Hit"),
        }
    }

    #[tokio::test]
    async fn below_threshold_continues_to_next_source() {
        let chain = FallbackChain::new(
            vec![
                Box::new(StubSource::hit("signature", 0.50)),
                Box::new(StubSource::hit("llm", 0.85)),
            ],
            0.7,
        );
        let out = chain.run(&input()).await;
        match out {
            FallbackOutcome::Hit {
                result,
                source_index,
            } => {
                assert_eq!(result.source, "llm");
                assert_eq!(source_index, 1);
            }
            _ => panic!("expected Hit"),
        }
    }

    #[tokio::test]
    async fn confidence_equal_to_threshold_is_a_hit() {
        // The threshold is inclusive (`>=`).
        let chain = FallbackChain::new(vec![Box::new(StubSource::hit("exact", 0.7))], 0.7);
        let out = chain.run(&input()).await;
        assert!(matches!(
            out,
            FallbackOutcome::Hit {
                source_index: 0,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn all_misses_returns_exhausted() {
        let chain = FallbackChain::new(
            vec![
                Box::new(StubSource::miss("a")),
                Box::new(StubSource::miss("b")),
            ],
            0.7,
        );
        let out = chain.run(&input()).await;
        assert!(
            matches!(out, FallbackOutcome::AllExhausted { ref attempts } if attempts.is_empty())
        );
    }

    #[tokio::test]
    async fn below_threshold_attempts_collected_in_exhausted() {
        let chain = FallbackChain::new(
            vec![
                Box::new(StubSource::hit("a", 0.40)),
                Box::new(StubSource::hit("b", 0.55)),
            ],
            0.7,
        );
        let out = chain.run(&input()).await;
        match out {
            FallbackOutcome::AllExhausted { attempts } => {
                assert_eq!(attempts.len(), 2);
                assert_eq!(attempts[0].source, "a");
                assert_eq!(attempts[1].source, "b");
            }
            _ => panic!("expected AllExhausted"),
        }
    }

    #[tokio::test]
    async fn errors_and_misses_are_not_recorded_as_attempts() {
        let chain = FallbackChain::new(
            vec![
                Box::new(StubSource::errs("flaky", "timeout")),
                Box::new(StubSource::miss("empty")),
                Box::new(StubSource::hit("weak", 0.30)),
            ],
            0.7,
        );
        let out = chain.run(&input()).await;
        match out {
            FallbackOutcome::AllExhausted { attempts } => {
                assert_eq!(attempts.len(), 1);
                assert_eq!(attempts[0].source, "weak");
            }
            other => panic!("expected AllExhausted, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn hard_error_skipped_chain_continues() {
        let chain = FallbackChain::new(
            vec![
                Box::new(StubSource::errs("flaky-llm", "timeout")),
                Box::new(StubSource::hit("signature", 0.85)),
            ],
            0.7,
        );
        let out = chain.run(&input()).await;
        assert!(matches!(
            out,
            FallbackOutcome::Hit {
                source_index: 1,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn empty_chain_returns_exhausted() {
        let chain = FallbackChain::new(vec![], 0.7);
        let out = chain.run(&input()).await;
        assert!(matches!(out, FallbackOutcome::AllExhausted { .. }));
    }

    #[test]
    fn accessors_reflect_constructor_arguments() {
        let chain = FallbackChain::new(
            vec![
                Box::new(StubSource::miss("a")),
                Box::new(StubSource::hit("b", 0.9)),
            ],
            0.6,
        );
        assert_eq!(chain.source_names(), vec!["a", "b"]);
        assert_eq!(chain.confidence_threshold(), 0.6);
    }
}