use std::sync::Arc;
use tokio::task::JoinSet;
use super::{Provider, ProviderError};
use crate::models::PullRequest;
/// Registry of pull-request providers that can be queried together.
///
/// Providers are stored as shared trait objects, so cloning the registry
/// clones the `Arc` handles rather than the providers themselves.
///
/// `Default` is derived: the default registry has no providers registered.
#[derive(Clone, Default)]
pub struct ProviderRegistry {
    /// Registered providers, in the order they were added.
    providers: Vec<Arc<dyn Provider>>,
}
impl ProviderRegistry {
    /// Creates an empty registry; equivalent to `ProviderRegistry::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `provider` to the end of the registry.
    pub fn register(&mut self, provider: Arc<dyn Provider>) {
        self.providers.push(provider);
    }

    /// Queries every registered provider concurrently and collects the outcomes.
    ///
    /// Each provider's `list_prs` call runs in its own task on a [`JoinSet`].
    /// The returned vector pairs the provider's `name()` with the `Result` of
    /// its `list_prs` call, so one provider failing does not hide the results
    /// of the others.
    ///
    /// Entries are returned in registration order, regardless of which
    /// provider finished first.
    ///
    /// A task that cannot be joined (for example because the provider
    /// panicked) is logged with `tracing::error!` and left out of the result,
    /// so the vector may be shorter than the number of registered providers.
    pub async fn list_all_prs(&self) -> Vec<(&'static str, Result<Vec<PullRequest>, ProviderError>)> {
        // Each task carries its registration index so the original order can
        // be restored after the tasks complete in arbitrary order.
        let mut tasks: JoinSet<(usize, (&'static str, Result<Vec<PullRequest>, ProviderError>))> =
            JoinSet::new();
        for (index, provider) in self.providers.iter().enumerate() {
            let provider = Arc::clone(provider);
            tasks.spawn(async move {
                let name = provider.name();
                let outcome = provider.list_prs().await;
                (index, (name, outcome))
            });
        }

        let mut completed = Vec::with_capacity(self.providers.len());
        while let Some(joined) = tasks.join_next().await {
            match joined {
                Ok(entry) => completed.push(entry),
                Err(e) => {
                    tracing::error!("Provider task panicked: {}", e);
                }
            }
        }

        // Indices are unique, so an unstable sort is sufficient.
        completed.sort_unstable_by_key(|&(index, _)| index);
        completed.into_iter().map(|(_, outcome)| outcome).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{PrIdentifier, PrState, ReviewStatus, User};
    use crate::provider::mock::MockProvider;
    use mockall::mock;

    /// A registry holding only the crate's `MockProvider` yields exactly one
    /// entry, named "mock", whose PR list is non-empty.
    #[tokio::test]
    async fn registry_with_mock() {
        let mut registry = ProviderRegistry::new();
        registry.register(Arc::new(MockProvider));
        let results = registry.list_all_prs().await;
        assert_eq!(results.len(), 1);
        let (name, prs) = &results[0];
        assert_eq!(*name, "mock");
        assert!(!prs.as_ref().unwrap().is_empty());
    }

    // mockall-generated `MockFakeProvider`, used to script per-provider
    // success and failure responses in the tests below.
    mock! {
        pub FakeProvider {}
        #[async_trait::async_trait]
        impl Provider for FakeProvider {
            fn name(&self) -> &'static str;
            fn display_name(&self) -> &'static str;
            async fn check_auth(&self) -> super::super::AuthStatus;
            async fn list_prs(&self) -> Result<Vec<PullRequest>, ProviderError>;
            async fn get_pr_details(&self, pr_id: &PrIdentifier) -> Result<PullRequest, ProviderError>;
            async fn get_pr_diff(&self, pr_id: &PrIdentifier) -> Result<String, ProviderError>;
        }
    }

    /// Builds a `PullRequest` fixture (PR number 1, attributed to
    /// "ok_provider") with every optional field empty.
    fn minimal_pull_request() -> PullRequest {
        PullRequest {
            id: PrIdentifier {
                provider: "ok_provider".to_string(),
                owner: "o".to_string(),
                repo: "r".to_string(),
                number: 1,
            },
            number: 1,
            title: "test PR".to_string(),
            url: "https://github.com/o/r/pull/1".to_string(),
            author: User {
                login: "alice".to_string(),
                display_name: None,
                avatar_url: None,
            },
            reviewers: vec![],
            repo_full_name: "o/r".to_string(),
            provider: "ok_provider".to_string(),
            head_branch: String::new(),
            base_branch: String::new(),
            state: PrState::Open,
            review_status: ReviewStatus::NeedsReview,
            ci_status: None,
            draft: false,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
            labels: vec![],
            comment_count: 0,
            additions: None,
            deletions: None,
        }
    }

    /// One provider succeeding and another failing must both be reported:
    /// the failure must not drop or mask the successful provider's PRs.
    #[tokio::test]
    async fn registry_partial_failure_github() {
        // Provider that returns a single PR.
        let mut provider_ok = MockFakeProvider::new();
        provider_ok.expect_name().returning(|| "ok_provider");
        provider_ok
            .expect_check_auth()
            .returning(|| super::super::AuthStatus::Available);
        provider_ok
            .expect_list_prs()
            .returning(|| Ok(vec![minimal_pull_request()]));

        // Provider that simulates a GitHub API 500.
        let mut provider_fail = MockFakeProvider::new();
        provider_fail.expect_name().returning(|| "github");
        provider_fail
            .expect_check_auth()
            .returning(|| super::super::AuthStatus::Available);
        provider_fail.expect_list_prs().returning(|| {
            Err(ProviderError::ApiError {
                provider: "github".to_string(),
                status: 500,
                message: "simulated GitHub API failure".to_string(),
            })
        });

        let mut registry = ProviderRegistry::new();
        registry.register(Arc::new(provider_ok));
        registry.register(Arc::new(provider_fail));

        let results = registry.list_all_prs().await;
        assert_eq!(results.len(), 2, "Expected 2 results (one per provider)");

        // Entries are looked up by name so the test does not depend on the
        // position of either provider in the result vector.
        let ok_result = results
            .iter()
            .find(|(name, _)| *name == "ok_provider")
            .expect("ok_provider result missing");
        let ok_prs = ok_result.1.as_ref().expect("ok_provider should succeed");
        assert_eq!(ok_prs.len(), 1, "ok_provider should return 1 PR");
        assert_eq!(ok_prs[0].number, 1);

        let fail_result = results
            .iter()
            .find(|(name, _)| *name == "github")
            .expect("github result missing");
        assert!(
            fail_result.1.is_err(),
            "github provider should return an error"
        );
        if let Err(ProviderError::ApiError { provider, status, .. }) = &fail_result.1 {
            assert_eq!(provider, "github");
            assert_eq!(*status, 500);
        } else {
            panic!("Expected ApiError for github provider");
        }
    }
}