use crate::buffered_eprintln;
use crate::config::Config;
use crate::github::cache::CacheConfig;
use crate::github::types::PullRequest;
use crate::scoring::{calculate_score, merge_scoring_configs, ScoreResult};
use crate::snooze::{filter_active_prs, filter_snoozed_prs, SnoozeState};
use anyhow::Result;
use futures::stream::{FuturesUnordered, StreamExt};
use std::collections::{HashMap, HashSet};
use std::fmt;
/// Error carrying a human-readable authentication failure message.
///
/// `fetch_and_score_prs` treats this error as fatal: if any search query fails with an
/// `AuthError`, the whole fetch is aborted instead of skipping that one query.
/// NOTE(review): presumably produced by the GitHub search layer when the token is rejected —
/// confirm at the construction site.
#[derive(Debug)]
pub struct AuthError {
    /// Message shown to the user verbatim via `Display`.
    pub message: String,
}

impl fmt::Display for AuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message is emitted as-is, with no prefix or extra formatting.
        f.write_str(&self.message)
    }
}

/// Marker impl so `AuthError` can travel inside `anyhow::Error` and be recovered with a downcast.
impl std::error::Error for AuthError {}
/// Runs every configured search query concurrently, de-duplicates the returned PRs by URL,
/// splits them into active and snoozed sets, and scores and sorts each set.
///
/// Each PR is scored with the global scoring config merged with the overrides of the *first*
/// query (in `config.queries` order) that returned it, independent of which query finished
/// first.
///
/// Returns `(active, snoozed, rate_limit_remaining)`:
/// - `active` / `snoozed`: PRs paired with their score, sorted by descending score and then by
///   ascending `created_at`.
/// - `rate_limit_remaining`: the remaining core rate-limit budget reported by the API, or
///   `None` if that lookup failed.
///
/// # Errors
///
/// - If any query fails with an [`AuthError`], that error is returned immediately.
/// - If at least one query is configured and every query failed, an error is returned.
///   Otherwise individual (non-auth) query failures are reported on stderr and skipped.
pub async fn fetch_and_score_prs(
    client: &octocrab::Octocrab,
    config: &Config,
    snooze_state: &SnoozeState,
    cache_config: &CacheConfig,
    verbose: bool,
    auth_username: Option<&str>,
) -> Result<(
    Vec<(PullRequest, ScoreResult)>,
    Vec<(PullRequest, ScoreResult)>,
    Option<u64>,
)> {
    if verbose {
        let cache_status = if cache_config.enabled {
            "enabled"
        } else {
            "disabled (--no-cache)"
        };
        buffered_eprintln!("Cache: {}", cache_status);
    }

    let global_scoring = config.scoring.clone().unwrap_or_default();
    // Merge each query's overrides into the global config once, instead of once per PR.
    let query_scorings: Vec<_> = config
        .queries
        .iter()
        .map(|query_config| merge_scoring_configs(&global_scoring, query_config.scoring.as_ref()))
        .collect();

    let auth_username_owned = auth_username.map(|s| s.to_string());
    let mut futures = FuturesUnordered::new();
    for (query_index, query_config) in config.queries.iter().enumerate() {
        let client = client.clone();
        let query = query_config.query.clone();
        let query_name = query_config.name.clone();
        let auth_username_clone = auth_username_owned.clone();
        let merged_scoring = merge_scoring_configs(&global_scoring, query_config.scoring.as_ref());
        let exclude_patterns = merged_scoring.size.and_then(|s| s.exclude);
        futures.push(async move {
            let result = crate::github::search_and_enrich_prs(
                &client,
                &query,
                auth_username_clone.as_deref(),
                exclude_patterns,
            )
            .await;
            (query_name, query, query_index, result)
        });
    }

    let mut all_prs = Vec::new();
    let mut any_succeeded = false;
    while let Some((name, query, query_index, result)) = futures.next().await {
        // Prefer the configured display name; fall back to the raw query string.
        let label = name.as_deref().unwrap_or(&query);
        match result {
            Ok(prs) => {
                if verbose {
                    buffered_eprintln!(" Found {} PRs for {}", prs.len(), label);
                }
                all_prs.extend(prs.into_iter().map(|pr| (pr, query_index)));
                any_succeeded = true;
            }
            // Auth failures affect every query, so abort rather than report each one.
            Err(e) if e.downcast_ref::<AuthError>().is_some() => return Err(e),
            Err(e) => {
                buffered_eprintln!("Query failed: {} - {}", label, e);
            }
        }
    }
    if !any_succeeded && !config.queries.is_empty() {
        anyhow::bail!("All queries failed. Check your network connection and GitHub token.");
    }

    // `FuturesUnordered` yields results in completion order, which varies between runs.
    // A *stable* sort by query index (preserving each query's own result order) makes the
    // de-duplication below keep the copy from the first matching query in config order, so
    // the scoring overrides applied to a PR matched by several queries are deterministic.
    all_prs.sort_by_key(|(_, query_index)| *query_index);

    let mut seen_urls = HashSet::new();
    let mut pr_to_query_index = HashMap::new();
    let unique_prs: Vec<_> = all_prs
        .into_iter()
        .filter_map(|(pr, query_index)| {
            if seen_urls.insert(pr.url.clone()) {
                pr_to_query_index.insert(pr.url.clone(), query_index);
                Some(pr)
            } else {
                None
            }
        })
        .collect();
    if verbose {
        buffered_eprintln!("After deduplication: {} unique PRs", unique_prs.len());
    }

    let active_prs = filter_active_prs(unique_prs.clone(), snooze_state);
    let snoozed_prs = filter_snoozed_prs(unique_prs, snooze_state);
    if verbose {
        buffered_eprintln!(
            "After filter: {} active, {} snoozed",
            active_prs.len(),
            snoozed_prs.len()
        );
    }

    // Scores one PR with the merged config of the query it was attributed to above.
    let score_pr = |pr: PullRequest| {
        let query_index = pr_to_query_index.get(&pr.url).copied().unwrap_or(0);
        let result = calculate_score(&pr, &query_scorings[query_index]);
        (pr, result)
    };
    let mut active_scored: Vec<_> = active_prs.into_iter().map(&score_pr).collect();
    let mut snoozed_scored: Vec<_> = snoozed_prs.into_iter().map(&score_pr).collect();

    // Highest score first; incomparable scores count as equal; ties go to the oldest PR.
    let by_score_then_age = |a: &(PullRequest, ScoreResult), b: &(PullRequest, ScoreResult)| {
        b.1.score
            .partial_cmp(&a.1.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.created_at.cmp(&b.0.created_at))
    };
    active_scored.sort_by(by_score_then_age);
    snoozed_scored.sort_by(by_score_then_age);

    // Best-effort: failing to read the rate limit must not discard the results above.
    let rate_limit_remaining = client
        .ratelimit()
        .get()
        .await
        .ok()
        .map(|rate_limit| rate_limit.resources.core.remaining as u64);

    Ok((active_scored, snoozed_scored, rate_limit_remaining))
}