use super::{CodeReference, MiningError, MiningQuery, ReferenceMiner};
use async_trait::async_trait;
use std::collections::HashSet;
use std::sync::Arc;
/// A [`ReferenceMiner`] that forwards each query to a list of underlying miners.
pub struct MultiSource {
    /// Name reported by this miner's `name()` implementation.
    name: String,
    /// The underlying miners; every search is sent to each of them.
    sources: Vec<Arc<dyn ReferenceMiner>>,
    /// When `true`, a search in which every source fails returns an error
    /// rather than an empty result list.
    fail_if_all_fail: bool,
}
impl MultiSource {
    /// Creates a miner that fans out over `sources`.
    ///
    /// The new value is named `"multi"` and treats "every source failed" as an
    /// error. Both defaults can be changed with the builder methods below.
    pub fn new(sources: Vec<Arc<dyn ReferenceMiner>>) -> Self {
        let name = String::from("multi");
        Self {
            name,
            sources,
            fail_if_all_fail: true,
        }
    }

    /// Builder: replaces the miner's name and hands the value back.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..self
        }
    }

    /// Builder: clears the "fail when all sources fail" flag, so a search in
    /// which every source errors is no longer reported as an error.
    pub fn tolerate_total_failure(self) -> Self {
        Self {
            fail_if_all_fail: false,
            ..self
        }
    }
}
#[async_trait]
impl ReferenceMiner for MultiSource {
    /// Returns the name this miner was configured with.
    fn name(&self) -> &str {
        &self.name
    }

    /// Sends `query` to every source and returns one merged, ranked list.
    ///
    /// Steps, in order:
    /// 1. With no sources configured, returns an empty list immediately.
    /// 2. All sources are searched via `futures::future::join_all`; each
    ///    outcome is logged (`debug` on success, `warn` on failure).
    /// 3. Hits from successful sources are concatenated and sorted by `score`,
    ///    highest first. Hits whose score is NaN are ranked after all others.
    /// 4. Duplicates, identified by `(repo, path, commit)`, are removed keeping
    ///    the first (i.e. best-ranked) occurrence.
    /// 5. The list is truncated to `query.filters.max_results`; a value of `0`
    ///    leaves it untruncated.
    ///
    /// # Errors
    ///
    /// If no source succeeded and total failure is not tolerated, returns the
    /// error of the first failing source (in source order). Failures of
    /// individual sources are otherwise only logged.
    async fn search(&self, query: &MiningQuery) -> Result<Vec<CodeReference>, MiningError> {
        if self.sources.is_empty() {
            return Ok(Vec::new());
        }

        // Each future owns its own `Arc` handle and its own copy of the query,
        // so none of them borrows from `self` or `query`.
        let pending = self.sources.iter().map(|s| {
            let s = Arc::clone(s);
            let q = query.clone();
            async move {
                let name = s.name().to_string();
                let result = s.search(&q).await;
                (name, result)
            }
        });
        let outcomes = futures::future::join_all(pending).await;

        let mut merged: Vec<CodeReference> = Vec::new();
        // Only the first failure is ever returned; later ones are just logged.
        let mut first_error: Option<MiningError> = None;
        let mut any_success = false;
        for (name, outcome) in outcomes {
            match outcome {
                Ok(hits) => {
                    any_success = true;
                    tracing::debug!(source = %name, hits = hits.len(), "reference_miner source returned");
                    merged.extend(hits);
                }
                Err(e) => {
                    tracing::warn!(source = %name, error = %e, "reference_miner source failed");
                    if first_error.is_none() {
                        first_error = Some(e);
                    }
                }
            }
        }

        if !any_success && self.fail_if_all_fail {
            // `sources` is non-empty here, so an error has been recorded; the
            // fallback is built lazily and only guards against that invariant
            // being broken.
            return Err(
                first_error.unwrap_or_else(|| MiningError::Unavailable("no sources".into()))
            );
        }

        // Descending by score. NaN is handled explicitly: treating it as
        // "equal to everything" is not a total order, and `sort_by` may then
        // produce an unspecified order or panic. NaN scores are ranked last.
        merged.sort_by(|a, b| match (a.score.is_nan(), b.score.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (false, false) => b
                .score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal),
        });

        // The list is already ranked, so keeping the first occurrence of each
        // key keeps the highest-scoring duplicate.
        let mut seen: HashSet<(String, String, String)> = HashSet::new();
        merged.retain(|r| seen.insert((r.repo.clone(), r.path.clone(), r.commit.clone())));

        // `0` means "no limit".
        let max = query.filters.max_results;
        if max > 0 {
            merged.truncate(max);
        }
        Ok(merged)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reference_miner::{MiningFilters, MiningScope};

    /// Canned miner: returns `hits`, or an `Unavailable` error carrying the
    /// display text of `err` when one is set.
    struct FakeSource {
        name: &'static str,
        hits: Vec<CodeReference>,
        err: Option<MiningError>,
    }

    #[async_trait]
    impl ReferenceMiner for FakeSource {
        fn name(&self) -> &str {
            self.name
        }

        async fn search(&self, _q: &MiningQuery) -> Result<Vec<CodeReference>, MiningError> {
            match &self.err {
                Some(e) => Err(MiningError::Unavailable(e.to_string())),
                None => Ok(self.hits.clone()),
            }
        }
    }

    /// A source that succeeds with the given hits.
    fn ok_source(name: &'static str, hits: Vec<CodeReference>) -> Arc<dyn ReferenceMiner> {
        Arc::new(FakeSource {
            name,
            hits,
            err: None,
        })
    }

    /// A source that has no hits and always fails.
    fn failing_source(name: &'static str, err: MiningError) -> Arc<dyn ReferenceMiner> {
        Arc::new(FakeSource {
            name,
            hits: vec![],
            err: Some(err),
        })
    }

    /// Builds a reference with the given identity and score; other fields are fixed.
    fn cr(repo: &str, path: &str, commit: &str, score: f32) -> CodeReference {
        CodeReference {
            repo: repo.into(),
            commit: commit.into(),
            path: path.into(),
            snippet: String::new(),
            score,
            license: Some("MIT".into()),
            why_relevant: String::new(),
        }
    }

    /// A query over all scopes capped at `max_results` hits.
    fn query_with_limit(max_results: usize) -> MiningQuery {
        MiningQuery {
            query: "q".into(),
            scope: MiningScope::All,
            filters: MiningFilters {
                max_results,
                ..Default::default()
            },
        }
    }

    /// The default test query (limit of 10).
    fn query() -> MiningQuery {
        query_with_limit(10)
    }

    #[tokio::test]
    async fn merges_and_dedupes() {
        let first = ok_source("a", vec![cr("r", "p", "c", 0.9), cr("r", "p2", "c", 0.5)]);
        let second = ok_source("b", vec![cr("r", "p", "c", 0.1), cr("r2", "p", "c", 0.7)]);
        let multi = MultiSource::new(vec![first, second]);

        let merged = multi.search(&query()).await.unwrap();

        // ("r", "p", "c") appears in both sources and must collapse to one hit.
        assert_eq!(merged.len(), 3);
        let top = &merged[0];
        assert!((top.score - 0.9).abs() < 1e-6);
        assert_eq!(top.repo, "r");
    }

    #[tokio::test]
    async fn partial_failure_returns_successes() {
        let healthy = ok_source("a", vec![cr("r", "p", "c", 0.9)]);
        let broken = failing_source("b", MiningError::RateLimited("fake".into()));
        let multi = MultiSource::new(vec![healthy, broken]);

        let hits = multi.search(&query()).await.unwrap();

        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn total_failure_propagates() {
        let broken = failing_source("a", MiningError::Unavailable("x".into()));
        let multi = MultiSource::new(vec![broken]);

        let outcome = multi.search(&query()).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn total_failure_tolerated_returns_empty() {
        let broken = failing_source("a", MiningError::Unavailable("x".into()));
        let multi = MultiSource::new(vec![broken]).tolerate_total_failure();

        let hits = multi.search(&query()).await.unwrap();

        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn truncates_to_max_results() {
        let five_hits: Vec<CodeReference> = (0..5)
            .map(|i| cr("r", &format!("p{i}"), "c", 0.1 * i as f32))
            .collect();
        let multi = MultiSource::new(vec![ok_source("a", five_hits)]);

        let hits = multi.search(&query_with_limit(2)).await.unwrap();

        assert_eq!(hits.len(), 2);
    }
}