Skip to main content

crossref_lib/
client.rs

1use std::sync::Arc;
2
3use crate::config::Config;
4use crate::error::{CrossrefError, Result};
5use crate::models::{SearchQuery, SearchResult, WorkMeta};
6use crate::utils::normalise_doi;
7
8/// Unpaywall OA record for a single DOI.
9#[derive(Debug, serde::Deserialize)]
10pub struct UnpaywallRecord {
11    pub is_oa: bool,
12    pub oa_status: String,
13    /// Best available open-access PDF URL, if any.
14    pub best_oa_location: Option<UnpaywallLocation>,
15}
16
17#[derive(Debug, serde::Deserialize)]
18pub struct UnpaywallLocation {
19    pub url_for_pdf: Option<String>,
20}
21
22/// High-level API client.
23///
24/// Wraps the synchronous [`crossref::Crossref`] client (created fresh per
25/// blocking call) and a shared `reqwest::Client` for async Unpaywall queries.
26pub struct CrossrefClient {
27    config: Arc<Config>,
28    http: reqwest::Client,
29    /// Optional override for the Crossref API base URL (used in tests).
30    crossref_base_url: Option<String>,
31    /// Optional override for the Unpaywall API base URL (used in tests).
32    unpaywall_base_url: Option<String>,
33}
34
35impl CrossrefClient {
36    /// Construct a new client from the resolved configuration.
37    pub fn new(config: Arc<Config>) -> Result<Self> {
38        let http = reqwest::Client::builder()
39            .user_agent(build_user_agent(&config))
40            .build()?;
41        Ok(Self { config, http, crossref_base_url: None, unpaywall_base_url: None })
42    }
43
44    /// Construct a client with custom base URLs (for testing and integration environments).
45    pub fn new_with_base_urls(
46        config: Arc<Config>,
47        crossref_url: Option<String>,
48        unpaywall_url: Option<String>,
49    ) -> Result<Self> {
50        let http = reqwest::Client::builder()
51            .user_agent(build_user_agent(&config))
52            .build()?;
53        Ok(Self {
54            config,
55            http,
56            crossref_base_url: crossref_url,
57            unpaywall_base_url: unpaywall_url,
58        })
59    }
60
61    /// Alias for tests.
62    #[cfg(test)]
63    pub fn new_for_test(
64        config: Arc<Config>,
65        crossref_url: Option<String>,
66        unpaywall_url: Option<String>,
67    ) -> Result<Self> {
68        Self::new_with_base_urls(config, crossref_url, unpaywall_url)
69    }
70
71    // ─── Crossref API ───────────────────────────────────────────────────────
72
73    /// Fetch metadata for a single DOI, then enrich with Unpaywall OA data.
74    pub async fn fetch_work(&self, doi: &str) -> Result<WorkMeta> {
75        let doi = normalise_doi(doi);
76        let email = self.config.email.clone();
77        let base_url = self.crossref_base_url.clone();
78
79        // `crossref::Crossref` is !Send (uses Rc<Client>), so build it inside
80        // the blocking thread each time.
81        let work = tokio::task::spawn_blocking(move || {
82            let mut builder = crossref::Crossref::builder();
83            if let Some(ref e) = email {
84                builder = builder.polite(e.as_str());
85            }
86            let mut client = builder
87                .build()
88                .map_err(|e| CrossrefError::Api(e.to_string()))?;
89            // Allow overriding the base URL for tests
90            if let Some(url) = base_url {
91                client.base_url = url;
92            }
93            client
94                .work(&doi)
95                .map_err(|e| CrossrefError::Api(e.to_string()))
96        })
97        .await
98        .map_err(|e| CrossrefError::Api(e.to_string()))??;
99
100        let mut meta = map_work(work);
101
102        // Auto-enrich with Unpaywall OA data; failures are non-fatal
103        match self.fetch_unpaywall(&meta.doi).await {
104            Ok(oa) => {
105                meta.is_oa = Some(oa.is_oa);
106                meta.oa_status = Some(oa.oa_status);
107                meta.pdf_url = oa.best_oa_location.and_then(|loc| loc.url_for_pdf);
108            }
109            Err(e) => {
110                eprintln!("warning: Unpaywall enrichment failed: {e}");
111            }
112        }
113
114        Ok(meta)
115    }
116
117    /// Fetch metadata for multiple DOIs, returning results in order.
118    pub async fn fetch_works(&self, dois: &[&str]) -> Vec<Result<WorkMeta>> {
119        let mut results = Vec::with_capacity(dois.len());
120        for doi in dois {
121            results.push(self.fetch_work(doi).await);
122        }
123        results
124    }
125
126    /// Execute a search query and return a page of results.
127    pub async fn search(&self, query: &SearchQuery) -> Result<SearchResult> {
128        let query = query.clone();
129        let email = self.config.email.clone();
130        let base_url = self.crossref_base_url.clone();
131
132        let work_list = tokio::task::spawn_blocking(move || {
133            let mut builder = crossref::Crossref::builder();
134            if let Some(ref e) = email {
135                builder = builder.polite(e.as_str());
136            }
137            let mut client = builder
138                .build()
139                .map_err(|e| CrossrefError::Api(e.to_string()))?;
140            if let Some(url) = base_url {
141                client.base_url = url;
142            }
143
144            let wq = build_works_query(&query);
145            let result = client
146                .works(wq)
147                .map_err(|e| CrossrefError::Api(e.to_string()))?;
148            Ok::<_, CrossrefError>(result)
149        })
150        .await
151        .map_err(|e| CrossrefError::Api(e.to_string()))??;
152
153        let total_results = work_list.total_results as u64;
154        let items = work_list.items.into_iter().map(map_work).collect();
155        Ok(SearchResult { items, total_results })
156    }
157
158    // ─── Unpaywall API ──────────────────────────────────────────────────────
159
160    /// Query Unpaywall for OA information about a DOI.
161    pub async fn fetch_unpaywall(&self, doi: &str) -> Result<UnpaywallRecord> {
162        let doi = normalise_doi(doi);
163        let email = self
164            .config
165            .email
166            .as_deref()
167            .unwrap_or("anonymous@example.com")
168            .to_string();
169        let base = self
170            .unpaywall_base_url
171            .as_deref()
172            .unwrap_or("https://api.unpaywall.org/v2");
173        let url = format!("{base}/{doi}?email={email}");
174        let record: UnpaywallRecord = self
175            .http
176            .get(&url)
177            .send()
178            .await?
179            .json()
180            .await
181            .map_err(|e| CrossrefError::Unpaywall(e.to_string()))?;
182        Ok(record)
183    }
184
185    /// Download the best OA PDF to `dest_dir` / `<DOI>.pdf`.
186    ///
187    /// Falls back to EZproxy if the direct URL returns a non-200 status,
188    /// and finally returns a `https://doi.org/{doi}` link if no PDF is
189    /// accessible.
190    ///
191    /// Returns the path where the file was written, or the best-effort URL
192    /// if the PDF was not downloaded.
193    pub async fn download_pdf(
194        &self,
195        doi: &str,
196        dest_dir: &std::path::Path,
197    ) -> Result<std::path::PathBuf> {
198        let norm_doi = normalise_doi(doi);
199
200        let record = self.fetch_unpaywall(&norm_doi).await?;
201        let pdf_url = record
202            .best_oa_location
203            .and_then(|loc| loc.url_for_pdf);
204
205        let safe_doi = norm_doi.replace('/', "_");
206        let dest = dest_dir.join(format!("{safe_doi}.pdf"));
207
208        // Try direct PDF URL
209        if let Some(ref url) = pdf_url {
210            if let Ok(resp) = self.http.get(url).send().await {
211                if resp.status().is_success() {
212                    if let Ok(bytes) = resp.bytes().await {
213                        if is_pdf(&bytes) {
214                            std::fs::write(&dest, &bytes)?;
215                            return Ok(dest);
216                        }
217                        // Response was HTML (landing page / paywall) — try fallbacks
218                    }
219                }
220            }
221
222            // EZproxy fallback
223            if let Some(ref proxy) = self.config.proxy {
224                let proxy_url = format!("https://{proxy}/doi/{norm_doi}");
225                if let Ok(resp) = self.http.get(&proxy_url).send().await {
226                    if resp.status().is_success() {
227                        if let Ok(bytes) = resp.bytes().await {
228                            if is_pdf(&bytes) {
229                                std::fs::write(&dest, &bytes)?;
230                                return Ok(dest);
231                            }
232                        }
233                    }
234                }
235            }
236        } else if let Some(ref proxy) = self.config.proxy {
237            // No direct URL but proxy is configured — try proxy
238            let proxy_url = format!("https://{proxy}/doi/{norm_doi}");
239            if let Ok(resp) = self.http.get(&proxy_url).send().await {
240                if resp.status().is_success() {
241                    if let Ok(bytes) = resp.bytes().await {
242                        if is_pdf(&bytes) {
243                            std::fs::write(&dest, &bytes)?;
244                            return Ok(dest);
245                        }
246                    }
247                }
248            }
249        }
250
251        // Best-effort: return doi.org link as a path placeholder
252        Err(CrossrefError::PdfDownload(format!(
253            "no downloadable PDF found; try https://doi.org/{norm_doi}"
254        )))
255    }
256}
257
258// ─── Helper functions ────────────────────────────────────────────────────────
259
/// Return `true` if `bytes` looks like a PDF document.
///
/// Publishers often serve HTML landing pages (with 200 OK) at a URL that
/// is labelled as a PDF link; checking for the `%PDF-` magic number avoids
/// saving those. Following Adobe's implementation note to the PDF
/// specification, the header is accepted anywhere within the first 1024
/// bytes, which tolerates servers that prepend a BOM or whitespace.
fn is_pdf(bytes: &[u8]) -> bool {
    const MAGIC: &[u8] = b"%PDF-";
    const HEADER_WINDOW: usize = 1024;

    let head = &bytes[..bytes.len().min(HEADER_WINDOW)];
    head.windows(MAGIC.len()).any(|w| w == MAGIC)
}
267
268/// Build an appropriate `User-Agent` string.
269fn build_user_agent(config: &Config) -> String {
270    let version = env!("CARGO_PKG_VERSION");
271    match &config.email {
272        Some(email) => format!("crossref-rs/{version} (mailto:{email})"),
273        None => format!("crossref-rs/{version}"),
274    }
275}
276
277/// Map a `crossref::Work` into our `WorkMeta` model.
278fn map_work(w: crossref::Work) -> WorkMeta {
279    let title = w.title.into_iter().next();
280
281    let authors: Vec<String> = w
282        .author
283        .unwrap_or_default()
284        .into_iter()
285        .map(|c| match c.given {
286            Some(given) => format!("{}, {}", c.family, given),
287            None => c.family,
288        })
289        .collect();
290
291    // DateField is not in crossref's public API; extract year directly from the
292    // raw date-parts nested array: [[year, month?, day?], …]
293    let year = w
294        .issued
295        .date_parts
296        .0
297        .first()
298        .and_then(|parts| parts.first())
299        .and_then(|opt_y| *opt_y)
300        .map(|y| y as i32);
301
302    let journal = w
303        .container_title
304        .and_then(|v| v.into_iter().next());
305
306    WorkMeta {
307        doi: w.doi,
308        title,
309        authors,
310        year,
311        journal,
312        volume: w.volume,
313        issue: w.issue,
314        pages: w.page,
315        publisher: Some(w.publisher),
316        work_type: Some(w.type_),
317        is_oa: None,
318        oa_status: None,
319        pdf_url: None,
320    }
321}
322
323/// Translate our `SearchQuery` into a `crossref::WorksQuery`.
324fn build_works_query(q: &SearchQuery) -> crossref::WorksQuery {
325    use chrono::NaiveDate;
326
327    // The crossref crate's FieldQuery serialises as "title=value" instead of the
328    // correct "query.title=value", causing a validation-failure from the REST API.
329    // Combine all text inputs into the free-form query= parameter instead.
330    let mut term_parts: Vec<&str> = Vec::new();
331    if let Some(ref t) = q.query  { term_parts.push(t.as_str()); }
332    if let Some(ref t) = q.title  { term_parts.push(t.as_str()); }
333    if let Some(ref a) = q.author { term_parts.push(a.as_str()); }
334    let term = term_parts.join(" ");
335
336    let mut wq = if term.is_empty() {
337        crossref::WorksQuery::empty()
338    } else {
339        crossref::WorksQuery::new(term)
340    };
341
342    wq = wq.result_control(crossref::WorkResultControl::Standard(
343        crossref::query::ResultControl::Rows(q.rows as usize),
344    ));
345
346    if let Some(ref sort) = q.sort {
347        let sort_val = match sort.as_str() {
348            "score" => crossref::Sort::Score,
349            "updated" => crossref::Sort::Updated,
350            "deposited" => crossref::Sort::Deposited,
351            "indexed" => crossref::Sort::Indexed,
352            "published" => crossref::Sort::Published,
353            _ => crossref::Sort::Score,
354        };
355        wq = wq.sort(sort_val);
356    }
357
358    // Date range filters
359    if let Some(year) = q.year_from {
360        if let Some(date) = NaiveDate::from_ymd_opt(year, 1, 1) {
361            wq = wq.filter(crossref::WorksFilter::FromPubDate(date));
362        }
363    }
364    if let Some(year) = q.year_to {
365        if let Some(date) = NaiveDate::from_ymd_opt(year, 12, 31) {
366            wq = wq.filter(crossref::WorksFilter::UntilPubDate(date));
367        }
368    }
369
370    // Work type filter — parse to the Type enum; fall back to TypeName on unknown strings
371    if let Some(ref work_type) = q.work_type {
372        use std::str::FromStr;
373        if let Ok(t) = crossref::Type::from_str(work_type.as_str()) {
374            wq = wq.filter(crossref::WorksFilter::Type(t));
375        } else {
376            // Unknown type string: pass as-is via TypeName (best-effort)
377            wq = wq.filter(crossref::WorksFilter::TypeName(work_type.clone()));
378        }
379    }
380
381    // Open-access filter (proxy: has-license)
382    if q.open_access {
383        wq = wq.filter(crossref::WorksFilter::HasLicense);
384    }
385
386    wq
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::SearchQueryBuilder;

    #[test]
    fn test_build_works_query_filters() {
        // Ensure that filter methods on SearchQuery are forwarded to WorksQuery.
        // We verify this by building a query and checking no panic occurs.
        let q = SearchQueryBuilder::default()
            .query(Some("machine learning".to_string()))
            .year_from(Some(2020))
            .year_to(Some(2023))
            .work_type(Some("journal-article".to_string()))
            .open_access(true)
            .rows(5u32)
            .build()
            .unwrap();

        // build_works_query should not panic
        let _wq = build_works_query(&q);
    }

    #[test]
    fn test_build_works_query_no_filters() {
        let q = SearchQueryBuilder::default()
            .query(Some("test".to_string()))
            .rows(10u32)
            .build()
            .unwrap();
        let _wq = build_works_query(&q);
    }

    /// A body starting with the PDF magic number is accepted.
    #[test]
    fn test_is_pdf_accepts_pdf_header() {
        assert!(is_pdf(b"%PDF-1.7\n%\xE2\xE3\xCF\xD3\n"));
    }

    /// HTML landing pages, empty bodies and a truncated magic number are rejected.
    #[test]
    fn test_is_pdf_rejects_non_pdf() {
        assert!(!is_pdf(b"<!DOCTYPE html><html><body>Paywall</body></html>"));
        assert!(!is_pdf(b""));
        assert!(!is_pdf(b"%PDF"));
    }
}