1use std::sync::Arc;
2
3use crate::config::Config;
4use crate::error::{CrossrefError, Result};
5use crate::models::{SearchQuery, SearchResult, WorkMeta};
6use crate::utils::normalise_doi;
7
8#[derive(Debug, serde::Deserialize)]
10pub struct UnpaywallRecord {
11 pub is_oa: bool,
12 pub oa_status: String,
13 pub best_oa_location: Option<UnpaywallLocation>,
15}
16
17#[derive(Debug, serde::Deserialize)]
18pub struct UnpaywallLocation {
19 pub url_for_pdf: Option<String>,
20}
21
22pub struct CrossrefClient {
27 config: Arc<Config>,
28 http: reqwest::Client,
29 crossref_base_url: Option<String>,
31 unpaywall_base_url: Option<String>,
33}
34
35impl CrossrefClient {
36 pub fn new(config: Arc<Config>) -> Result<Self> {
38 let http = reqwest::Client::builder()
39 .user_agent(build_user_agent(&config))
40 .build()?;
41 Ok(Self { config, http, crossref_base_url: None, unpaywall_base_url: None })
42 }
43
44 pub fn new_with_base_urls(
46 config: Arc<Config>,
47 crossref_url: Option<String>,
48 unpaywall_url: Option<String>,
49 ) -> Result<Self> {
50 let http = reqwest::Client::builder()
51 .user_agent(build_user_agent(&config))
52 .build()?;
53 Ok(Self {
54 config,
55 http,
56 crossref_base_url: crossref_url,
57 unpaywall_base_url: unpaywall_url,
58 })
59 }
60
/// Test-only constructor; forwards all arguments unchanged to
/// `new_with_base_urls`.
///
/// # Errors
///
/// Returns whatever `new_with_base_urls` returns.
#[cfg(test)]
pub fn new_for_test(
    config: Arc<Config>,
    crossref_url: Option<String>,
    unpaywall_url: Option<String>,
) -> Result<Self> {
    CrossrefClient::new_with_base_urls(config, crossref_url, unpaywall_url)
}
70
71 pub async fn fetch_work(&self, doi: &str) -> Result<WorkMeta> {
75 let doi = normalise_doi(doi);
76 let email = self.config.email.clone();
77 let base_url = self.crossref_base_url.clone();
78
79 let work = tokio::task::spawn_blocking(move || {
82 let mut builder = crossref::Crossref::builder();
83 if let Some(ref e) = email {
84 builder = builder.polite(e.as_str());
85 }
86 let mut client = builder
87 .build()
88 .map_err(|e| CrossrefError::Api(e.to_string()))?;
89 if let Some(url) = base_url {
91 client.base_url = url;
92 }
93 client
94 .work(&doi)
95 .map_err(|e| CrossrefError::Api(e.to_string()))
96 })
97 .await
98 .map_err(|e| CrossrefError::Api(e.to_string()))??;
99
100 let mut meta = map_work(work);
101
102 match self.fetch_unpaywall(&meta.doi).await {
104 Ok(oa) => {
105 meta.is_oa = Some(oa.is_oa);
106 meta.oa_status = Some(oa.oa_status);
107 meta.pdf_url = oa.best_oa_location.and_then(|loc| loc.url_for_pdf);
108 }
109 Err(e) => {
110 eprintln!("warning: Unpaywall enrichment failed: {e}");
111 }
112 }
113
114 Ok(meta)
115 }
116
117 pub async fn fetch_works(&self, dois: &[&str]) -> Vec<Result<WorkMeta>> {
119 let mut results = Vec::with_capacity(dois.len());
120 for doi in dois {
121 results.push(self.fetch_work(doi).await);
122 }
123 results
124 }
125
126 pub async fn search(&self, query: &SearchQuery) -> Result<SearchResult> {
128 let query = query.clone();
129 let email = self.config.email.clone();
130 let base_url = self.crossref_base_url.clone();
131
132 let work_list = tokio::task::spawn_blocking(move || {
133 let mut builder = crossref::Crossref::builder();
134 if let Some(ref e) = email {
135 builder = builder.polite(e.as_str());
136 }
137 let mut client = builder
138 .build()
139 .map_err(|e| CrossrefError::Api(e.to_string()))?;
140 if let Some(url) = base_url {
141 client.base_url = url;
142 }
143
144 let wq = build_works_query(&query);
145 let result = client
146 .works(wq)
147 .map_err(|e| CrossrefError::Api(e.to_string()))?;
148 Ok::<_, CrossrefError>(result)
149 })
150 .await
151 .map_err(|e| CrossrefError::Api(e.to_string()))??;
152
153 let total_results = work_list.total_results as u64;
154 let items = work_list.items.into_iter().map(map_work).collect();
155 Ok(SearchResult { items, total_results })
156 }
157
158 pub async fn fetch_unpaywall(&self, doi: &str) -> Result<UnpaywallRecord> {
162 let doi = normalise_doi(doi);
163 let email = self
164 .config
165 .email
166 .as_deref()
167 .unwrap_or("anonymous@example.com")
168 .to_string();
169 let base = self
170 .unpaywall_base_url
171 .as_deref()
172 .unwrap_or("https://api.unpaywall.org/v2");
173 let url = format!("{base}/{doi}?email={email}");
174 let record: UnpaywallRecord = self
175 .http
176 .get(&url)
177 .send()
178 .await?
179 .json()
180 .await
181 .map_err(|e| CrossrefError::Unpaywall(e.to_string()))?;
182 Ok(record)
183 }
184
185 pub async fn download_pdf(
194 &self,
195 doi: &str,
196 dest_dir: &std::path::Path,
197 ) -> Result<std::path::PathBuf> {
198 let norm_doi = normalise_doi(doi);
199
200 let record = self.fetch_unpaywall(&norm_doi).await?;
201 let pdf_url = record
202 .best_oa_location
203 .and_then(|loc| loc.url_for_pdf);
204
205 let safe_doi = norm_doi.replace('/', "_");
206 let dest = dest_dir.join(format!("{safe_doi}.pdf"));
207
208 if let Some(ref url) = pdf_url {
210 if let Ok(resp) = self.http.get(url).send().await {
211 if resp.status().is_success() {
212 if let Ok(bytes) = resp.bytes().await {
213 if is_pdf(&bytes) {
214 std::fs::write(&dest, &bytes)?;
215 return Ok(dest);
216 }
217 }
219 }
220 }
221
222 if let Some(ref proxy) = self.config.proxy {
224 let proxy_url = format!("https://{proxy}/doi/{norm_doi}");
225 if let Ok(resp) = self.http.get(&proxy_url).send().await {
226 if resp.status().is_success() {
227 if let Ok(bytes) = resp.bytes().await {
228 if is_pdf(&bytes) {
229 std::fs::write(&dest, &bytes)?;
230 return Ok(dest);
231 }
232 }
233 }
234 }
235 }
236 } else if let Some(ref proxy) = self.config.proxy {
237 let proxy_url = format!("https://{proxy}/doi/{norm_doi}");
239 if let Ok(resp) = self.http.get(&proxy_url).send().await {
240 if resp.status().is_success() {
241 if let Ok(bytes) = resp.bytes().await {
242 if is_pdf(&bytes) {
243 std::fs::write(&dest, &bytes)?;
244 return Ok(dest);
245 }
246 }
247 }
248 }
249 }
250
251 Err(CrossrefError::PdfDownload(format!(
253 "no downloadable PDF found; try https://doi.org/{norm_doi}"
254 )))
255 }
256}
257
/// Returns `true` when `bytes` begins with the five-byte sequence `%PDF-`.
fn is_pdf(bytes: &[u8]) -> bool {
    matches!(bytes, [b'%', b'P', b'D', b'F', b'-', ..])
}
267
268fn build_user_agent(config: &Config) -> String {
270 let version = env!("CARGO_PKG_VERSION");
271 match &config.email {
272 Some(email) => format!("crossref-rs/{version} (mailto:{email})"),
273 None => format!("crossref-rs/{version}"),
274 }
275}
276
277fn map_work(w: crossref::Work) -> WorkMeta {
279 let title = w.title.into_iter().next();
280
281 let authors: Vec<String> = w
282 .author
283 .unwrap_or_default()
284 .into_iter()
285 .map(|c| match c.given {
286 Some(given) => format!("{}, {}", c.family, given),
287 None => c.family,
288 })
289 .collect();
290
291 let year = w
294 .issued
295 .date_parts
296 .0
297 .first()
298 .and_then(|parts| parts.first())
299 .and_then(|opt_y| *opt_y)
300 .map(|y| y as i32);
301
302 let journal = w
303 .container_title
304 .and_then(|v| v.into_iter().next());
305
306 WorkMeta {
307 doi: w.doi,
308 title,
309 authors,
310 year,
311 journal,
312 volume: w.volume,
313 issue: w.issue,
314 pages: w.page,
315 publisher: Some(w.publisher),
316 work_type: Some(w.type_),
317 is_oa: None,
318 oa_status: None,
319 pdf_url: None,
320 }
321}
322
323fn build_works_query(q: &SearchQuery) -> crossref::WorksQuery {
325 use chrono::NaiveDate;
326
327 let mut term_parts: Vec<&str> = Vec::new();
331 if let Some(ref t) = q.query { term_parts.push(t.as_str()); }
332 if let Some(ref t) = q.title { term_parts.push(t.as_str()); }
333 if let Some(ref a) = q.author { term_parts.push(a.as_str()); }
334 let term = term_parts.join(" ");
335
336 let mut wq = if term.is_empty() {
337 crossref::WorksQuery::empty()
338 } else {
339 crossref::WorksQuery::new(term)
340 };
341
342 wq = wq.result_control(crossref::WorkResultControl::Standard(
343 crossref::query::ResultControl::Rows(q.rows as usize),
344 ));
345
346 if let Some(ref sort) = q.sort {
347 let sort_val = match sort.as_str() {
348 "score" => crossref::Sort::Score,
349 "updated" => crossref::Sort::Updated,
350 "deposited" => crossref::Sort::Deposited,
351 "indexed" => crossref::Sort::Indexed,
352 "published" => crossref::Sort::Published,
353 _ => crossref::Sort::Score,
354 };
355 wq = wq.sort(sort_val);
356 }
357
358 if let Some(year) = q.year_from {
360 if let Some(date) = NaiveDate::from_ymd_opt(year, 1, 1) {
361 wq = wq.filter(crossref::WorksFilter::FromPubDate(date));
362 }
363 }
364 if let Some(year) = q.year_to {
365 if let Some(date) = NaiveDate::from_ymd_opt(year, 12, 31) {
366 wq = wq.filter(crossref::WorksFilter::UntilPubDate(date));
367 }
368 }
369
370 if let Some(ref work_type) = q.work_type {
372 use std::str::FromStr;
373 if let Ok(t) = crossref::Type::from_str(work_type.as_str()) {
374 wq = wq.filter(crossref::WorksFilter::Type(t));
375 } else {
376 wq = wq.filter(crossref::WorksFilter::TypeName(work_type.clone()));
378 }
379 }
380
381 if q.open_access {
383 wq = wq.filter(crossref::WorksFilter::HasLicense);
384 }
385
386 wq
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::SearchQueryBuilder;

    /// Smoke test: a query with every filter set can be translated without
    /// panicking. The resulting `WorksQuery` is not inspected.
    #[test]
    fn test_build_works_query_filters() {
        let search = SearchQueryBuilder::default()
            .query(Some("machine learning".to_string()))
            .year_from(Some(2020))
            .year_to(Some(2023))
            .work_type(Some("journal-article".to_string()))
            .open_access(true)
            .rows(5u32)
            .build()
            .unwrap();

        let _ = build_works_query(&search);
    }

    /// Smoke test: a query with only a search term and a row count can be
    /// translated without panicking.
    #[test]
    fn test_build_works_query_no_filters() {
        let search = SearchQueryBuilder::default()
            .query(Some("test".to_string()))
            .rows(10u32)
            .build()
            .unwrap();

        let _ = build_works_query(&search);
    }
}