mod error;
pub use error::PubChemError;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use moka::future::Cache;
use serde::Deserialize;
use urlencoding::encode;
use crate::error::{HsPredictError, Result};
use crate::types::SubstanceIdentifier;
/// Root of the PubChem PUG REST API; request paths such as `/compound/...` are appended to it.
const BASE_URL: &str = "https://pubchem.ncbi.nlm.nih.gov/rest/pug";
/// Comma-separated list of properties requested for every compound. Each name corresponds
/// to a `#[serde(rename = ...)]` key on `CompoundProperty`, which decodes the response.
const PROPERTIES: &str =
"IUPACName,CanonicalSMILES,InChIKey,InChI,MolecularFormula,MolecularWeight";
/// A compound record as returned by a PubChem property lookup.
#[derive(Debug, Clone)]
pub struct PubChemCompound {
    /// PubChem compound identifier; always present in a successful response.
    pub cid: u64,
    /// IUPAC name, if PubChem has one.
    pub iupac_name: Option<String>,
    /// SMILES string returned for the `CanonicalSMILES` property.
    pub canonical_smiles: Option<String>,
    /// Standard InChI string.
    pub inchi: Option<String>,
    /// Hashed InChIKey.
    pub inchi_key: Option<String>,
    /// Molecular formula, e.g. `HNaO`.
    pub molecular_formula: Option<String>,
    /// Molecular weight; `None` if PubChem omitted it or it could not be parsed.
    pub molecular_weight: Option<f64>,
}

impl PubChemCompound {
    /// Copies this record's identifiers into `id` without clobbering what the caller
    /// already knows.
    ///
    /// The CID is always overwritten with PubChem's. SMILES, IUPAC name, InChI and
    /// InChIKey are written only where `id` currently holds `None`. Molecular formula
    /// and weight are not transferred.
    pub fn apply_to(&self, id: &mut SubstanceIdentifier) {
        // Store `candidate` into `slot` only when the slot is still empty.
        fn fill<T: Clone>(slot: &mut Option<T>, candidate: &Option<T>) {
            if slot.is_none() {
                *slot = candidate.clone();
            }
        }

        id.cid = Some(self.cid);
        fill(&mut id.smiles, &self.canonical_smiles);
        fill(&mut id.iupac_name, &self.iupac_name);
        fill(&mut id.inchi, &self.inchi);
        fill(&mut id.inchi_key, &self.inchi_key);
    }
}
/// Asynchronous PubChem PUG REST client with request rate limiting and an in-memory
/// compound cache.
///
/// The rate limiter is held in an `Arc`, so all clones of a client draw from the same
/// request budget.
/// NOTE(review): clones presumably share the moka cache as well (moka caches are
/// handle types) — confirm against the moka documentation.
#[derive(Clone)]
pub struct PubChemClient {
    // HTTP client; configured with the builder's User-Agent in `PubChemClientBuilder::build`.
    http: reqwest::Client,
    // Compound records keyed by PubChem CID (filled after each successful fetch).
    cache: Cache<u64, Arc<PubChemCompound>>,
    // Awaited before every outgoing request to stay within the configured rate.
    limiter: Arc<DefaultDirectRateLimiter>,
    // PUG REST root; `/compound/...` request paths are appended to it.
    base_url: String,
}
impl std::fmt::Debug for PubChemClient {
    /// Prints only `base_url`; the HTTP client, cache and limiter are elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PubChemClient");
        out.field("base_url", &self.base_url);
        out.finish_non_exhaustive()
    }
}
impl Default for PubChemClient {
fn default() -> Self {
Self::new()
}
}
impl PubChemClient {
pub fn new() -> Self {
Self::builder().build()
}
pub fn builder() -> PubChemClientBuilder {
PubChemClientBuilder::default()
}
pub async fn lookup(&self, id: &SubstanceIdentifier) -> Result<PubChemCompound> {
if let Some(cid) = id.cid {
if let Some(cached) = self.cache.get(&cid).await {
return Ok((*cached).clone());
}
}
let (namespace, input) = Self::pick_namespace(id)
.ok_or(PubChemError::NoUsableIdentifier)?;
self.fetch(namespace, &input).await
}
pub async fn enrich(&self, id: &mut SubstanceIdentifier) -> Result<()> {
match self.lookup(id).await {
Ok(compound) => {
compound.apply_to(id);
Ok(())
}
Err(HsPredictError::PubChem(PubChemError::NotFound { .. }))
| Err(HsPredictError::PubChem(PubChemError::NoUsableIdentifier)) => Ok(()),
Err(e) => Err(e),
}
}
fn pick_namespace(id: &SubstanceIdentifier) -> Option<(&'static str, String)> {
if let Some(ref cas) = id.cas {
return Some(("name", cas.clone()));
}
if let Some(ref key) = id.inchi_key {
return Some(("inchikey", key.clone()));
}
if let Some(ref inchi) = id.inchi {
return Some(("inchi", inchi.clone()));
}
if let Some(ref smiles) = id.smiles {
return Some(("smiles", smiles.clone()));
}
if let Some(ref name) = id.iupac_name {
return Some(("name", name.clone()));
}
None
}
async fn fetch(&self, namespace: &str, input: &str) -> Result<PubChemCompound> {
self.limiter.until_ready().await;
let url = format!(
"{base}/compound/{ns}/{enc}/property/{props}/JSON",
base = self.base_url,
ns = namespace,
enc = encode(input),
props = PROPERTIES,
);
let resp = self
.http
.get(&url)
.send()
.await
.map_err(|e| PubChemError::Http(e.to_string()))?;
match resp.status().as_u16() {
200 => {}
404 => return Err(PubChemError::NotFound { input: input.to_string() }.into()),
429 => return Err(PubChemError::RateLimitExceeded.into()),
code => {
return Err(PubChemError::Http(format!("HTTP {code}")).into());
}
}
let body: PugPropertyResponse = resp
.json()
.await
.map_err(|e| PubChemError::Parse(e.to_string()))?;
let props = body
.property_table
.properties
.into_iter()
.next()
.ok_or_else(|| PubChemError::NotFound { input: input.to_string() })?;
let compound = Arc::new(PubChemCompound {
cid: props.cid,
iupac_name: props.iupac_name,
canonical_smiles: props.canonical_smiles,
inchi: props.in_chi,
inchi_key: props.in_chi_key,
molecular_formula: props.molecular_formula,
molecular_weight: props.molecular_weight.as_deref().and_then(|s| s.parse().ok()),
});
self.cache.insert(compound.cid, Arc::clone(&compound)).await;
Ok((*compound).clone())
}
}
/// Configuration for a [`PubChemClient`]; obtain one with `PubChemClient::builder()`.
#[derive(Debug, Clone)]
pub struct PubChemClientBuilder {
    // Maximum request rate handed to the rate limiter; kept at 1 or more by the setter.
    requests_per_second: u32,
    // Maximum capacity passed to the compound cache.
    cache_capacity: u64,
    // Time-to-live for cached compounds.
    cache_ttl: Duration,
    // PUG REST root URL.
    base_url: String,
    // `User-Agent` header sent with every request.
    user_agent: String,
}
impl Default for PubChemClientBuilder {
    /// Defaults: 5 requests per second, a cache capacity of 1 000 with a 24-hour TTL,
    /// the public PubChem endpoint, and an `hs-predict/<version> (<repository>)`
    /// User-Agent taken from the Cargo package metadata.
    fn default() -> Self {
        const ONE_DAY: Duration = Duration::from_secs(86_400);
        const USER_AGENT: &str = concat!(
            "hs-predict/",
            env!("CARGO_PKG_VERSION"),
            " (",
            env!("CARGO_PKG_REPOSITORY"),
            ")"
        );

        Self {
            base_url: BASE_URL.to_owned(),
            user_agent: USER_AGENT.to_owned(),
            requests_per_second: 5,
            cache_capacity: 1_000,
            cache_ttl: ONE_DAY,
        }
    }
}
impl PubChemClientBuilder {
    /// Sets the maximum request rate. Values below 1 are raised to 1.
    #[must_use]
    pub fn requests_per_second(mut self, n: u32) -> Self {
        self.requests_per_second = n.max(1);
        self
    }

    /// Sets the maximum capacity of the in-memory compound cache.
    #[must_use]
    pub fn cache_capacity(mut self, n: u64) -> Self {
        self.cache_capacity = n;
        self
    }

    /// Sets the time-to-live applied to cached compounds.
    #[must_use]
    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
        self.cache_ttl = ttl;
        self
    }

    /// Overrides the PUG REST root URL (for example to target a local mock server).
    ///
    /// Trailing `/` characters are stripped so request paths are not doubled
    /// (`.../pug//compound/...`).
    #[must_use]
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        let url = url.into();
        self.base_url = url.trim_end_matches('/').to_owned();
        self
    }

    /// Overrides the `User-Agent` header sent with every request.
    #[must_use]
    pub fn user_agent(mut self, agent: impl Into<String>) -> Self {
        self.user_agent = agent.into();
        self
    }

    /// Builds the configured [`PubChemClient`].
    ///
    /// # Panics
    ///
    /// Panics if `reqwest` fails to build its HTTP client, which includes the case of a
    /// `User-Agent` that is not a valid HTTP header value.
    #[must_use]
    pub fn build(self) -> PubChemClient {
        // Cannot fail: the setter clamps to >= 1 and the default is 5.
        let quota = Quota::per_second(
            NonZeroU32::new(self.requests_per_second)
                .expect("requests_per_second must be ≥ 1"),
        );
        PubChemClient {
            http: reqwest::Client::builder()
                .user_agent(self.user_agent)
                .build()
                .expect("failed to build reqwest::Client"),
            cache: Cache::builder()
                .max_capacity(self.cache_capacity)
                .time_to_live(self.cache_ttl)
                .build(),
            limiter: Arc::new(RateLimiter::direct(quota)),
            base_url: self.base_url,
        }
    }
}
/// Top-level envelope of a PUG REST property response: `{"PropertyTable": {...}}`.
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct PugPropertyResponse {
    property_table: PropertyTable,
}
/// Body of `PropertyTable`: the `Properties` array of per-compound rows
/// (the client consumes only the first row).
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct PropertyTable {
    properties: Vec<CompoundProperty>,
}
#[derive(Deserialize)]
struct CompoundProperty {
#[serde(rename = "CID")]
cid: u64,
#[serde(rename = "IUPACName")]
iupac_name: Option<String>,
#[serde(rename = "CanonicalSMILES")]
canonical_smiles: Option<String>,
#[serde(rename = "InChI")]
in_chi: Option<String>,
#[serde(rename = "InChIKey")]
in_chi_key: Option<String>,
#[serde(rename = "MolecularFormula")]
molecular_formula: Option<String>,
#[serde(rename = "MolecularWeight")]
molecular_weight: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default client targets the public PubChem endpoint.
    #[test]
    fn client_builds_with_defaults() {
        let client = PubChemClient::new();
        assert_eq!(client.base_url, BASE_URL);
    }

    /// `base_url` on the builder replaces the default endpoint.
    #[test]
    fn builder_overrides_base_url() {
        let client = PubChemClient::builder()
            .base_url("http://localhost:8080")
            .build();
        assert_eq!(client.base_url, "http://localhost:8080");
    }

    /// A zero request rate is clamped rather than panicking in `build`.
    #[test]
    fn builder_clamps_zero_requests_per_second() {
        let client = PubChemClient::builder().requests_per_second(0).build();
        assert_eq!(client.base_url, BASE_URL);
    }

    /// A CAS number takes precedence over SMILES.
    #[test]
    fn pick_namespace_cas_first() {
        let id = SubstanceIdentifier {
            cas: Some("1310-73-2".to_string()),
            smiles: Some("[Na+].[OH-]".to_string()),
            ..Default::default()
        };
        let (ns, inp) = PubChemClient::pick_namespace(&id).unwrap();
        assert_eq!(ns, "name");
        assert_eq!(inp, "1310-73-2");
    }

    /// Without a CAS number, the InChIKey is used.
    #[test]
    fn pick_namespace_inchikey_when_no_cas() {
        let id = SubstanceIdentifier {
            inchi_key: Some("HEMHJVSKTPXQMS-UHFFFAOYSA-M".to_string()),
            ..Default::default()
        };
        let (ns, inp) = PubChemClient::pick_namespace(&id).unwrap();
        assert_eq!(ns, "inchikey");
        assert_eq!(inp, "HEMHJVSKTPXQMS-UHFFFAOYSA-M");
    }

    /// SMILES is preferred over the IUPAC name.
    #[test]
    fn pick_namespace_smiles_before_iupac_name() {
        let id = SubstanceIdentifier {
            smiles: Some("CCO".to_string()),
            iupac_name: Some("ethanol".to_string()),
            ..Default::default()
        };
        let (ns, inp) = PubChemClient::pick_namespace(&id).unwrap();
        assert_eq!(ns, "smiles");
        assert_eq!(inp, "CCO");
    }

    /// An identifier with nothing set yields no namespace.
    #[test]
    fn pick_namespace_returns_none_for_empty_id() {
        let id = SubstanceIdentifier::default();
        assert!(PubChemClient::pick_namespace(&id).is_none());
    }

    /// `apply_to` always sets the CID but never overwrites existing identifiers.
    #[test]
    fn apply_to_fills_missing_fields_only() {
        let compound = PubChemCompound {
            cid: 14798,
            iupac_name: Some("sodium hydroxide".to_string()),
            canonical_smiles: Some("[Na+].[OH-]".to_string()),
            inchi: Some("InChI=1S/Na.H2O/h;1H/q+1;/p-1".to_string()),
            inchi_key: Some("HEMHJVSKTPXQMS-UHFFFAOYSA-M".to_string()),
            molecular_formula: Some("HNaO".to_string()),
            molecular_weight: Some(39.997),
        };
        let mut id = SubstanceIdentifier {
            cas: Some("1310-73-2".to_string()),
            smiles: Some("existing".to_string()),
            ..Default::default()
        };
        compound.apply_to(&mut id);
        assert_eq!(id.cid, Some(14798));
        assert_eq!(id.smiles.as_deref(), Some("existing"));
        assert_eq!(id.iupac_name.as_deref(), Some("sodium hydroxide"));
        assert_eq!(id.inchi.as_deref(), Some("InChI=1S/Na.H2O/h;1H/q+1;/p-1"));
        assert_eq!(id.inchi_key.as_deref(), Some("HEMHJVSKTPXQMS-UHFFFAOYSA-M"));
    }

    #[tokio::test]
    #[ignore = "requires internet access"]
    async fn integration_lookup_naoh_by_cas() {
        let client = PubChemClient::new();
        let id = SubstanceIdentifier::from_cas("1310-73-2");
        let compound = client.lookup(&id).await.unwrap();
        assert_eq!(compound.cid, 14798);
        assert_eq!(compound.canonical_smiles.as_deref(), Some("[Na+].[OH-]"));
        assert_eq!(compound.iupac_name.as_deref(), Some("sodium hydroxide"));
    }

    #[tokio::test]
    #[ignore = "requires internet access"]
    async fn integration_enrich_fills_smiles() {
        let client = PubChemClient::new();
        let mut id = SubstanceIdentifier::from_cas("67-64-1");
        client.enrich(&mut id).await.unwrap();
        assert!(id.smiles.is_some());
        assert!(id.cid.is_some());
        assert!(id.iupac_name.is_some());
    }
}