use std::collections::HashMap;
use super::*;
pub mod json;
pub mod xml;
/// Registry of source configurations used to resolve an input identifier to the single
/// source that handles it (see `get_paper` and `sanitize_identifier`).
#[derive(Default, Debug, Clone)]
pub struct Retriever {
/// Registered configurations keyed by each config's `name`; inserting a config whose name
/// is already present replaces the earlier one.
configs: HashMap<String, RetrieverConfig>,
}
impl Retriever {
  /// Reports whether no source configurations have been registered yet.
  ///
  /// An empty retriever matches no input, so every lookup reports an invalid identifier.
  pub fn is_empty(&self) -> bool {
    let registered = &self.configs;
    registered.is_empty()
  }
}
/// Declarative description of one paper source, deserialized from TOML
/// (see `Retriever::with_config_str`).
#[derive(Debug, Clone, Deserialize)]
pub struct RetrieverConfig {
/// Key under which this config is registered in a `Retriever` (a later config with the same
/// name replaces an earlier one); also reported in ambiguity errors from `get_paper`.
pub name: String,
/// Base URL of the source.
/// NOTE(review): not read anywhere in this file; confirm whether it is still needed.
pub base_url: String,
/// Regex an input must match to be handled by this source. Capture group 1 is taken as the
/// source-specific identifier (see `extract_identifier`). Compiled when the config is
/// deserialized, via `deserialize_regex`.
#[serde(deserialize_with = "deserialize_regex")]
pub pattern: Regex,
/// Source label written into `Paper::source` after retrieval and returned by
/// `Retriever::sanitize_identifier`.
pub source: String,
/// Request URL template; every `{identifier}` occurrence is replaced with the extracted
/// identifier.
pub endpoint_template: String,
/// How the response body is decoded into a `Paper`.
pub response_format: ResponseFormat,
/// Extra HTTP headers attached to every request; empty when omitted from the config.
#[serde(default)]
pub headers: HashMap<String, String>,
}
/// Decoder used for a source's HTTP response body.
///
/// Chosen in configuration by the `type` key (`"xml"` or `"json"`); the remaining keys of the
/// table are deserialized into the matching format-specific configuration.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ResponseFormat {
  /// XML bodies, decoded by [`xml::XmlConfig`].
  Xml(xml::XmlConfig),
  /// JSON bodies, decoded by [`json::JsonConfig`].
  Json(json::JsonConfig),
}
/// Location of one output field inside a source response, plus an optional transform.
#[derive(Debug, Clone, Deserialize)]
pub struct FieldMap {
/// Where the value lives in the response. Its syntax is interpreted by the format-specific
/// processor (presumably an XML/JSON path expression; confirm in the `xml`/`json` modules).
pub path: String,
/// Optional post-processing step, presumably applied to the extracted value via
/// `apply_transform`; `None` when omitted from the config.
#[serde(default)]
pub transform: Option<Transform>,
}
/// A post-processing step applied to a value extracted for a [`FieldMap`].
///
/// Selected by the `type` key. Both the variant name (`"Replace"`, `"Date"`, `"Url"`) and its
/// lowercase spelling are accepted, so transform tables can follow the same lowercase tag
/// convention as [`ResponseFormat`] (`type = "xml"` / `type = "json"`) without breaking
/// existing configs that use the capitalized names.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum Transform {
  /// Regex substitution: every match of `pattern` in the value is replaced by `replacement`
  /// (which may reference capture groups, e.g. `$1`).
  #[serde(alias = "replace")]
  Replace {
    pattern:     String,
    replacement: String,
  },
  /// Date/time reformatting: the value is parsed with `from_format` and re-rendered with
  /// `to_format` (both chrono `strftime`-style format strings).
  #[serde(alias = "date")]
  Date {
    from_format: String,
    to_format:   String,
  },
  /// URL construction: every `{value}` in `base` is replaced by the value, then `suffix`
  /// (if present) is appended.
  #[serde(alias = "url")]
  Url {
    base:   String,
    suffix: Option<String>,
  },
}
/// Converts a raw response body into a [`Paper`].
///
/// Implemented by the format-specific configurations wrapped in `ResponseFormat`, which
/// `RetrieverConfig::retrieve_paper` dispatches to. That caller overwrites the returned
/// paper's `source` and `source_identifier` afterwards, so implementors need not set them.
#[async_trait]
pub trait ResponseProcessor: Send + Sync {
/// Parses `data` (the raw HTTP response body bytes) into a `Paper`.
async fn process_response(&self, data: &[u8]) -> Result<Paper>;
}
impl Retriever {
pub fn new() -> Self { Self::default() }
pub fn with_config(mut self, config: RetrieverConfig) {
self.configs.insert(config.name.clone(), config);
}
pub fn with_config_str(mut self, toml_str: &str) -> Result<Self> {
let config: RetrieverConfig = toml::from_str(toml_str)?;
self.configs.insert(config.name.clone(), config);
Ok(self)
}
pub fn with_config_file(self, path: impl AsRef<Path>) -> Result<Self> {
let content = std::fs::read_to_string(path)?;
self.with_config_str(&content)
}
pub fn with_config_dir(self, dir: impl AsRef<Path>) -> Result<Self> {
let dir = dir.as_ref();
if !dir.is_dir() {
return Err(LearnerError::Path(std::io::Error::new(
std::io::ErrorKind::NotFound,
"Config directory not found",
)));
}
let mut retriever = self;
for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "toml") {
retriever = retriever.with_config_file(path)?;
}
}
Ok(retriever)
}
pub async fn get_paper(&self, input: &str) -> Result<Paper> {
let mut matches = Vec::new();
for config in self.configs.values() {
if config.pattern.is_match(input) {
matches.push(config);
}
}
match matches.len() {
0 => Err(LearnerError::InvalidIdentifier),
1 => matches[0].retrieve_paper(input).await,
_ => Err(LearnerError::AmbiguousIdentifier(
matches.into_iter().map(|c| c.name.clone()).collect(),
)),
}
}
pub fn sanitize_identifier(&self, input: &str) -> Result<(String, String)> {
let mut matches = Vec::new();
for config in self.configs.values() {
if config.pattern.is_match(input) {
matches.push((config.source.clone(), config.extract_identifier(input)?.to_string()));
}
}
match matches.len() {
0 => Err(LearnerError::InvalidIdentifier),
1 => Ok(matches.remove(0)),
_ => Err(LearnerError::AmbiguousIdentifier(
matches.into_iter().map(|(source, _)| source).collect(),
)),
}
}
}
impl RetrieverConfig {
  /// Extracts the source-specific identifier from `input`: the text of capture group 1 of
  /// `pattern`.
  ///
  /// # Errors
  ///
  /// Returns `LearnerError::InvalidIdentifier` if `pattern` does not match `input` or the
  /// match has no capture group 1.
  pub fn extract_identifier<'a>(&self, input: &'a str) -> Result<&'a str> {
    self
      .pattern
      .captures(input)
      .and_then(|cap| cap.get(1))
      .map(|m| m.as_str())
      .ok_or(LearnerError::InvalidIdentifier)
  }

  /// Fetches and parses the paper identified by `input` from this source.
  ///
  /// The identifier is extracted with [`Self::extract_identifier`] and substituted for every
  /// `{identifier}` in `endpoint_template`. A GET request carrying `headers` is then sent, and
  /// the body is decoded according to `response_format`. Finally, the paper's `source` and
  /// `source_identifier` are overwritten with this config's values.
  ///
  /// # Errors
  ///
  /// Fails if the identifier cannot be extracted, the request fails, the server answers with
  /// a 4xx/5xx status, or the body cannot be processed.
  pub async fn retrieve_paper(&self, input: &str) -> Result<Paper> {
    let identifier = self.extract_identifier(input)?;
    let url = self.endpoint_template.replace("{identifier}", identifier);
    debug!("Fetching from {} via: {}", self.name, url);

    // NOTE(review): a fresh client per call forgoes connection pooling. A shared client would
    // be cheaper, but its pooled connections are tied to the async runtime that created them.
    let client = reqwest::Client::new();
    let request = self
      .headers
      .iter()
      .fold(client.get(&url), |request, (key, value)| request.header(key, value));

    // Reject HTTP error statuses up front. Otherwise an error page (e.g. 404 for an unknown
    // identifier, 429, 5xx) reaches the XML/JSON processor and surfaces as a misleading
    // parse error instead of a request error.
    let response = request.send().await?.error_for_status()?;
    let data = response.bytes().await?;
    trace!("{} response: {}", self.name, String::from_utf8_lossy(&data));

    let response_processor = match &self.response_format {
      ResponseFormat::Xml(config) => config as &dyn ResponseProcessor,
      ResponseFormat::Json(config) => config as &dyn ResponseProcessor,
    };
    let mut paper = response_processor.process_response(&data).await?;
    paper.source = self.source.clone();
    paper.source_identifier = identifier.to_string();
    Ok(paper)
  }
}
fn deserialize_regex<'de, D>(deserializer: D) -> std::result::Result<Regex, D::Error>
where D: serde::Deserializer<'de> {
let s: String = String::deserialize(deserializer)?;
Regex::new(&s).map_err(serde::de::Error::custom)
}
/// Applies a configured [`Transform`] to a single extracted field value.
///
/// - `Replace`: regex substitution of every match of `pattern` with `replacement`.
/// - `Date`: parse with `from_format`, render with `to_format`. Date-only input formats are
///   accepted and treated as midnight.
/// - `Url`: substitute the value for `{value}` in `base`, then append `suffix` if present.
///
/// # Errors
///
/// Returns `LearnerError::ApiError` if a `Replace` pattern is not a valid regex, a `Date`
/// value does not match `from_format`, or `to_format` cannot be rendered.
fn apply_transform(value: &str, transform: &Transform) -> Result<String> {
  match transform {
    Transform::Replace { pattern, replacement } => Regex::new(pattern)
      .map_err(|e| LearnerError::ApiError(format!("Invalid regex: {}", e)))
      .map(|re| re.replace_all(value, replacement.as_str()).into_owned()),
    Transform::Date { from_format, to_format } => {
      let parsed = parse_transform_date(value, from_format)?;
      format_transform_date(&parsed, to_format)
    },
    Transform::Url { base, suffix } =>
      Ok(format!("{}{}", base.replace("{value}", value), suffix.as_deref().unwrap_or(""))),
  }
}

/// Parses `value` with the strftime-style `format` into a naive date-time.
///
/// `NaiveDateTime::parse_from_str` rejects formats that carry no time fields
/// (e.g. `%Y-%m-%d`). In that case a date-only parse is tried and the result placed at midnight.
fn parse_transform_date(value: &str, format: &str) -> Result<chrono::NaiveDateTime> {
  chrono::NaiveDateTime::parse_from_str(value, format).or_else(|err| {
    chrono::NaiveDate::parse_from_str(value, format)
      .ok()
      .and_then(|date| date.and_hms_opt(0, 0, 0))
      // Keep reporting the date-time parse error, as before the fallback existed.
      .ok_or_else(|| LearnerError::ApiError(format!("Invalid date: {}", err)))
  })
}

/// Renders `date_time` with the strftime-style `format`.
///
/// Rendering goes through `write!` rather than `to_string()`. chrono reports an invalid or
/// unsupported specifier (e.g. `%z` on a naive value) as `fmt::Error`, and `to_string()`
/// turns that into a panic.
fn format_transform_date(date_time: &chrono::NaiveDateTime, format: &str) -> Result<String> {
  use std::fmt::Write as _;
  let mut rendered = String::new();
  write!(rendered, "{}", date_time.format(format))
    .map_err(|_| LearnerError::ApiError(format!("Invalid date format: {}", format)))?;
  Ok(rendered)
}