research_master/sources/
registry.rs1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use super::{Source, SourceError};
7use crate::config::SourceConfig;
8
9#[cfg(feature = "source-acm")]
11use super::acm::AcmSource;
12#[cfg(feature = "source-arxiv")]
13use super::arxiv::ArxivSource;
14#[cfg(feature = "source-base")]
15use super::base::BaseSource;
16#[cfg(feature = "source-biorxiv")]
17use super::biorxiv::BiorxivSource;
18#[cfg(feature = "source-connected_papers")]
19use super::connected_papers::ConnectedPapersSource;
20#[cfg(feature = "source-core-repo")]
21use super::core::CoreSource;
22#[cfg(feature = "source-crossref")]
23use super::crossref::CrossRefSource;
24#[cfg(feature = "source-dblp")]
25use super::dblp::DblpSource;
26#[cfg(feature = "source-dimensions")]
27use super::dimensions::DimensionsSource;
28#[cfg(feature = "source-doaj")]
29use super::doaj::DoajSource;
30#[cfg(feature = "source-europe_pmc")]
31use super::europe_pmc::EuropePmcSource;
32#[cfg(feature = "source-google_scholar")]
33use super::google_scholar::GoogleScholarSource;
34#[cfg(feature = "source-hal")]
35use super::hal::HalSource;
36#[cfg(feature = "source-iacr")]
37use super::iacr::IacrSource;
38#[cfg(feature = "source-ieee_xplore")]
39use super::ieee_xplore::IeeeXploreSource;
40#[cfg(feature = "source-jstor")]
41use super::jstor::JstorSource;
42#[cfg(feature = "source-mdpi")]
43use super::mdpi::MdpiSource;
44#[cfg(feature = "source-openalex")]
45use super::openalex::OpenAlexSource;
46#[cfg(feature = "source-osf")]
47use super::osf::OsfSource;
48#[cfg(feature = "source-pmc")]
49use super::pmc::PmcSource;
50#[cfg(feature = "source-pubmed")]
51use super::pubmed::PubMedSource;
52#[cfg(feature = "source-scispace")]
53use super::scispace::ScispaceSource;
54#[cfg(feature = "source-semantic")]
55use super::semantic::SemanticScholarSource;
56#[cfg(feature = "source-springer")]
57use super::springer::SpringerSource;
58#[cfg(feature = "source-ssrn")]
59use super::ssrn::SsrnSource;
60#[cfg(feature = "source-unpaywall")]
61use super::unpaywall::UnpaywallSource;
62#[cfg(feature = "source-worldwidescience")]
63use super::worldwidescience::WorldWideScienceSource;
64#[cfg(feature = "source-zenodo")]
65use super::zenodo::ZenodoSource;
66
67#[derive(Debug, Clone, Default)]
69struct SourceFilter {
70 enabled: Option<HashSet<String>>,
72 disabled: Option<HashSet<String>>,
74}
75
76impl SourceFilter {
77 fn from_config(config: &SourceConfig) -> Self {
79 let enabled = config
80 .enabled_sources
81 .as_ref()
82 .filter(|s| !s.is_empty())
83 .map(|value| {
84 value
85 .split(',')
86 .map(|s| s.trim().to_lowercase())
87 .filter(|s| !s.is_empty())
88 .collect::<HashSet<_>>()
89 })
90 .filter(|set| !set.is_empty());
91
92 let disabled = config
93 .disabled_sources
94 .as_ref()
95 .filter(|s| !s.is_empty())
96 .map(|value| {
97 value
98 .split(',')
99 .map(|s| s.trim().to_lowercase())
100 .filter(|s| !s.is_empty())
101 .collect::<HashSet<_>>()
102 })
103 .filter(|set| !set.is_empty());
104
105 Self { enabled, disabled }
106 }
107
108 fn is_enabled(&self, source_id: &str) -> bool {
116 let id_lower = source_id.to_lowercase();
117
118 match (&self.enabled, &self.disabled) {
119 (Some(enabled), Some(disabled)) => {
121 enabled.contains(&id_lower) && !disabled.contains(&id_lower)
122 }
123 (Some(enabled), None) => enabled.contains(&id_lower),
125 (None, Some(disabled)) => !disabled.contains(&id_lower),
127 (None, None) => true,
129 }
130 }
131}
132
133bitflags::bitflags! {
134 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
136 pub struct SourceCapabilities: u32 {
137 const SEARCH = 1 << 0;
138 const DOWNLOAD = 1 << 1;
139 const READ = 1 << 2;
140 const CITATIONS = 1 << 3;
141 const DOI_LOOKUP = 1 << 4;
142 const AUTHOR_SEARCH = 1 << 5;
143 }
144}
145
146#[derive(Debug, Clone)]
151pub struct SourceRegistry {
152 sources: HashMap<String, Arc<dyn Source>>,
153}
154
155impl SourceRegistry {
156 pub fn new() -> Self {
158 Self::try_new().expect("Failed to initialize any sources")
159 }
160
161 pub fn try_new() -> Result<Self, SourceError> {
168 let source_config = crate::config::get_config().sources;
169 let filter = SourceFilter::from_config(&source_config);
170 let mut registry = Self {
171 sources: HashMap::new(),
172 };
173
174 macro_rules! try_register {
176 ($source:expr) => {
177 if let Ok(source) = $source {
178 let source_id = source.id().to_string();
179 if filter.is_enabled(&source_id) {
180 registry.register(Arc::new(source));
181 tracing::info!("Registered source: {}", source_id);
182 } else {
183 tracing::debug!("Source '{}' filtered out by source filter", source_id);
184 }
185 } else {
186 tracing::warn!("Skipping source: initialization failed");
187 }
188 };
189 }
190
191 #[cfg(feature = "source-arxiv")]
194 try_register!(ArxivSource::new());
195
196 #[cfg(feature = "source-pubmed")]
197 try_register!(PubMedSource::new());
198
199 #[cfg(feature = "source-biorxiv")]
200 try_register!(BiorxivSource::new());
201
202 #[cfg(feature = "source-semantic")]
203 try_register!(SemanticScholarSource::new());
204
205 #[cfg(feature = "source-openalex")]
206 try_register!(OpenAlexSource::new());
207
208 #[cfg(feature = "source-crossref")]
209 try_register!(CrossRefSource::new());
210
211 #[cfg(feature = "source-iacr")]
212 try_register!(IacrSource::new());
213
214 #[cfg(feature = "source-pmc")]
215 try_register!(PmcSource::new());
216
217 #[cfg(feature = "source-hal")]
218 try_register!(HalSource::new());
219
220 #[cfg(feature = "source-dblp")]
221 try_register!(DblpSource::new());
222
223 #[cfg(feature = "source-dimensions")]
224 try_register!(DimensionsSource::new());
225
226 #[cfg(feature = "source-ieee_xplore")]
227 try_register!(IeeeXploreSource::new());
228
229 #[cfg(feature = "source-core-repo")]
230 try_register!(CoreSource::new());
231
232 #[cfg(feature = "source-zenodo")]
233 try_register!(ZenodoSource::new());
234
235 #[cfg(feature = "source-unpaywall")]
236 try_register!(UnpaywallSource::new());
237
238 #[cfg(feature = "source-mdpi")]
239 try_register!(MdpiSource::new());
240
241 #[cfg(feature = "source-ssrn")]
242 try_register!(SsrnSource::new());
243
244 #[cfg(feature = "source-jstor")]
245 try_register!(JstorSource::new());
246
247 #[cfg(feature = "source-scispace")]
248 try_register!(ScispaceSource::new());
249
250 #[cfg(feature = "source-acm")]
251 try_register!(AcmSource::new());
252
253 #[cfg(feature = "source-connected_papers")]
254 try_register!(ConnectedPapersSource::new());
255
256 #[cfg(feature = "source-doaj")]
257 try_register!(DoajSource::new());
258
259 #[cfg(feature = "source-europe_pmc")]
260 try_register!(EuropePmcSource::new());
261
262 #[cfg(feature = "source-worldwidescience")]
263 try_register!(WorldWideScienceSource::new());
264
265 #[cfg(feature = "source-osf")]
266 try_register!(OsfSource::new());
267
268 #[cfg(feature = "source-base")]
269 try_register!(BaseSource::new());
270
271 #[cfg(feature = "source-springer")]
272 try_register!(SpringerSource::new());
273
274 #[cfg(feature = "source-google_scholar")]
275 try_register!(GoogleScholarSource::new());
276
277 if registry.is_empty() {
278 return Err(SourceError::Other(
279 "No sources could be initialized. Check configuration and API keys.".to_string(),
280 ));
281 }
282
283 tracing::info!("Initialized {} sources", registry.len());
284
285 Ok(registry)
286 }
287
288 pub fn register(&mut self, source: Arc<dyn Source>) {
290 self.sources.insert(source.id().to_string(), source);
291 }
292
293 pub fn get(&self, id: &str) -> Option<&Arc<dyn Source>> {
295 self.sources.get(id)
296 }
297
298 pub fn get_required(&self, id: &str) -> Result<&Arc<dyn Source>, SourceError> {
300 self.get(id)
301 .ok_or_else(|| SourceError::NotFound(format!("Source '{}' not found", id)))
302 }
303
304 pub fn all(&self) -> impl Iterator<Item = &Arc<dyn Source>> {
306 self.sources.values()
307 }
308
309 pub fn ids(&self) -> impl Iterator<Item = &str> {
311 self.sources.keys().map(|s| s.as_str())
312 }
313
314 pub fn with_capability(&self, capability: SourceCapabilities) -> Vec<&Arc<dyn Source>> {
316 self.all()
317 .filter(|s| s.capabilities().contains(capability))
318 .collect()
319 }
320
321 pub fn searchable(&self) -> Vec<&Arc<dyn Source>> {
323 self.with_capability(SourceCapabilities::SEARCH)
324 }
325
326 pub fn downloadable(&self) -> Vec<&Arc<dyn Source>> {
328 self.with_capability(SourceCapabilities::DOWNLOAD)
329 }
330
331 pub fn with_citations(&self) -> Vec<&Arc<dyn Source>> {
333 self.with_capability(SourceCapabilities::CITATIONS)
334 }
335
336 pub fn has(&self, id: &str) -> bool {
338 self.sources.contains_key(id)
339 }
340
341 pub fn len(&self) -> usize {
343 self.sources.len()
344 }
345
346 pub fn is_empty(&self) -> bool {
348 self.sources.is_empty()
349 }
350}
351
352impl Default for SourceRegistry {
353 fn default() -> Self {
354 Self::new()
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, OnceLock};

    const ENABLED_VAR: &str = "RESEARCH_MASTER_ENABLED_SOURCES";
    const DISABLED_VAR: &str = "RESEARCH_MASTER_DISABLED_SOURCES";

    /// Serializes every test that reads or writes the source env vars.
    fn env_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Acquires the env lock, recovering from poisoning so that one failing
    /// test does not turn every later env test into a "poisoned" failure.
    fn lock_env() -> MutexGuard<'static, ()> {
        env_lock()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets `key` to `value`, or removes it when `value` is `None`.
    fn set_or_remove(key: &str, value: Option<&str>) {
        match value {
            Some(v) => std::env::set_var(key, v),
            None => std::env::remove_var(key),
        }
    }

    /// Restores the source env vars to their captured values on drop, so they
    /// are reset even when the test body panics (e.g. a failed assertion).
    struct EnvRestore {
        enabled: Option<String>,
        disabled: Option<String>,
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            set_or_remove(ENABLED_VAR, self.enabled.as_deref());
            set_or_remove(DISABLED_VAR, self.disabled.as_deref());
        }
    }

    /// Runs `test` with the source env vars set (`Some`) or removed (`None`),
    /// holding the env lock throughout.
    fn with_source_env_vars<F>(enabled: Option<&str>, disabled: Option<&str>, test: F)
    where
        F: FnOnce(),
    {
        let _lock = lock_env();
        // Declared after `_lock`, so it is dropped first: the env is restored
        // while the lock is still held.
        let _restore = EnvRestore {
            enabled: std::env::var(ENABLED_VAR).ok(),
            disabled: std::env::var(DISABLED_VAR).ok(),
        };

        set_or_remove(ENABLED_VAR, enabled);
        set_or_remove(DISABLED_VAR, disabled);

        test();
    }

    #[test]
    fn test_registry_basic() {
        // Registry construction reads the source env vars; don't race with
        // the filter tests that modify them.
        let _lock = lock_env();
        let registry = SourceRegistry::new();

        assert!(!registry.is_empty());
    }

    #[test]
    fn test_get_source() {
        let _lock = lock_env();
        let registry = SourceRegistry::new();

        // arxiv may be compiled out or filtered; only check it when present.
        if let Some(arxiv) = registry.get("arxiv") {
            assert_eq!(arxiv.id(), "arxiv");
        }

        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_source_filter_only_enabled() {
        with_source_env_vars(Some("arxiv,pubmed"), None, || {
            let config = crate::config::get_config().sources;
            let filter = SourceFilter::from_config(&config);
            assert!(filter.is_enabled("arxiv"));
            assert!(filter.is_enabled("pubmed"));
            assert!(!filter.is_enabled("semantic"));
            assert!(filter.is_enabled("ARXIV"));
        });
    }

    #[test]
    fn test_source_filter_only_disabled() {
        with_source_env_vars(None, Some("dblp,jstor"), || {
            let config = crate::config::get_config().sources;
            let filter = SourceFilter::from_config(&config);
            assert!(filter.is_enabled("arxiv"));
            assert!(filter.is_enabled("pubmed"));
            assert!(!filter.is_enabled("dblp"));
            assert!(!filter.is_enabled("jstor"));
            assert!(!filter.is_enabled("DBLP"));
        });
    }

    #[test]
    fn test_source_filter_both_enabled_and_disabled() {
        with_source_env_vars(Some("arxiv,pubmed,semantic,dblp"), Some("dblp"), || {
            let config = crate::config::get_config().sources;
            let filter = SourceFilter::from_config(&config);
            assert!(filter.is_enabled("arxiv"));
            assert!(filter.is_enabled("pubmed"));
            assert!(filter.is_enabled("semantic"));
            assert!(!filter.is_enabled("dblp"));
        });
    }

    #[test]
    fn test_source_filter_neither() {
        with_source_env_vars(None, None, || {
            let config = crate::config::get_config().sources;
            let filter = SourceFilter::from_config(&config);
            assert!(filter.is_enabled("arxiv"));
            assert!(filter.is_enabled("pubmed"));
            assert!(filter.is_enabled("semantic"));
            assert!(filter.is_enabled("dblp"));
        });
    }

    #[test]
    fn test_source_filter_empty_values() {
        with_source_env_vars(Some(""), Some(""), || {
            let config = crate::config::get_config().sources;
            let filter = SourceFilter::from_config(&config);
            assert!(filter.is_enabled("arxiv"));
            assert!(filter.is_enabled("pubmed"));
        });
    }

    #[test]
    fn test_searchable_sources() {
        let _lock = lock_env();
        let registry = SourceRegistry::new();

        let searchable = registry.searchable();
        assert!(!searchable.is_empty());
    }

    #[test]
    fn test_capabilities() {
        let _lock = lock_env();
        let registry = SourceRegistry::new();

        if let Some(arxiv) = registry.get("arxiv") {
            assert!(arxiv.capabilities().contains(SourceCapabilities::SEARCH));
            assert!(arxiv.capabilities().contains(SourceCapabilities::DOWNLOAD));
            assert!(arxiv.capabilities().contains(SourceCapabilities::READ));
        }

        if let Some(semantic) = registry.get("semantic") {
            assert!(semantic
                .capabilities()
                .contains(SourceCapabilities::CITATIONS));
            assert!(semantic
                .capabilities()
                .contains(SourceCapabilities::AUTHOR_SEARCH));
        }

        if let Some(dblp) = registry.get("dblp") {
            assert!(dblp.capabilities().contains(SourceCapabilities::SEARCH));
            assert!(!dblp.capabilities().contains(SourceCapabilities::DOWNLOAD));
        }
    }
}