#![cfg(feature = "citation")]
use std::collections::{HashSet, VecDeque};
use serde::Serialize;
use url::Url;
use crate::provenance::{Capability, LogError, LogEvent, LogResult, RowInput};
use crate::source::{FetchContext, FetchError, Source};
use crate::sources::openalex::OpenalexSource;
use crate::{CapabilityProfile, Doi, Ref};
/// Upper bounds on a citation-graph expansion.
///
/// Build one directly or via [`Default`] (which picks the hard ceilings), then hand it to
/// [`expand`], which normalises it with [`GraphCaps::clamped`] before walking.
#[derive(Debug, Clone, Copy)]
pub struct GraphCaps {
    /// Traversal depth limit, in hops from the seed (the seed itself is depth 0).
    pub depth: usize,
    /// Node budget for the whole graph, seed included.
    pub total: usize,
    /// Maximum number of references followed from any single work.
    pub per_paper: usize,
}

impl GraphCaps {
    /// Hard ceiling applied to [`GraphCaps::depth`].
    pub const MAX_DEPTH: usize = 3;
    /// Hard ceiling applied to [`GraphCaps::total`].
    pub const MAX_TOTAL: usize = 100;
    /// Hard ceiling applied to [`GraphCaps::per_paper`].
    pub const MAX_PER_PAPER: usize = 20;

    /// Return a copy in which every field above its `MAX_*` ceiling is lowered to it.
    ///
    /// Fields already within bounds (including 0) are returned unchanged; no minimum is
    /// enforced.
    #[must_use]
    pub fn clamped(self) -> Self {
        let Self {
            depth,
            total,
            per_paper,
        } = self;
        Self {
            depth: std::cmp::min(depth, Self::MAX_DEPTH),
            total: std::cmp::min(total, Self::MAX_TOTAL),
            per_paper: std::cmp::min(per_paper, Self::MAX_PER_PAPER),
        }
    }
}

impl Default for GraphCaps {
    /// The most permissive caps allowed: every field set to its `MAX_*` ceiling.
    fn default() -> Self {
        GraphCaps {
            depth: GraphCaps::MAX_DEPTH,
            total: GraphCaps::MAX_TOTAL,
            per_paper: GraphCaps::MAX_PER_PAPER,
        }
    }
}
/// One work in the citation graph returned by [`expand`].
#[derive(Debug, Clone, Serialize)]
pub struct GraphNode {
/// Short OpenAlex work ID: the last `/`-separated segment of the work's OpenAlex URL
/// (e.g. `W0001` for `https://openalex.org/W0001`).
pub id: String,
/// Hop distance from the seed; the seed itself is depth 0.
pub depth: usize,
}
/// A citation edge: the work `from` lists the work `to` in its OpenAlex `referenced_works`.
#[derive(Debug, Clone, Serialize)]
pub struct GraphEdge {
/// Short OpenAlex ID of the citing work.
pub from: String,
/// Short OpenAlex ID of the referenced (cited) work.
pub to: String,
}
/// Serializable outcome of a citation-graph expansion produced by [`expand`].
#[derive(Debug, Clone, Serialize)]
pub struct GraphResult {
/// Short OpenAlex ID of the seed work; also the `id` of `nodes[0]`.
pub seed_work_id: String,
/// Works in breadth-first discovery order, seed first.
pub nodes: Vec<GraphNode>,
/// Citation edges; both endpoints of every edge are present in `nodes`.
pub edges: Vec<GraphEdge>,
/// `true` when the graph may be incomplete: a cap was reached, or some hop could not be
/// fetched or parsed. This flag is conservative and may be set even when nothing was lost.
pub truncated: bool,
/// Number of entries in `nodes` (seed included).
pub total_visited: usize,
}
/// Failures that abort [`expand`].
///
/// Per-hop URL, fetch, and JSON problems after the seed are not reported here; `expand`
/// absorbs them by setting [`GraphResult::truncated`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum GraphError {
/// A [`FetchError`] propagated with `?`, e.g. from the seed fetch through the OpenAlex source.
#[error("openalex source error: {0}")]
Source(#[from] FetchError),
/// A [`LogError`] propagated with `?` while appending a provenance row.
#[error("provenance log error during graph expansion: {0}")]
Log(#[from] LogError),
/// The seed response carried no metadata JSON, or no usable `id` could be extracted from it.
#[error("seed DOI not indexed by OpenAlex (no `id` field in response)")]
SeedNotIndexed,
/// The capability profile does not enable OpenAlex metadata access
/// (`profile.metadata.openalex` is false).
#[error("citation graph requires DOIGET_ENABLE_OPENALEX + --features metadata")]
CapabilityDenied,
}
/// Expand the OpenAlex citation graph outward from `seed_doi`, breadth-first.
///
/// The seed is resolved through `source`. Every further hop is fetched from the OpenAlex
/// `/works/{id}` endpoint through `ctx.http` and recorded in `ctx.log`. An edge points from a
/// work to each entry of its `referenced_works`, so the walk follows outgoing references.
///
/// `caps` is first clamped with [`GraphCaps::clamped`]. Within those caps:
/// - no node is deeper than `caps.depth` hops (`caps.depth == 0` yields only the seed);
/// - at most `max(caps.total, 1)` nodes are collected, seed included;
/// - at most `caps.per_paper` references are followed from any single work.
///
/// Reaching a cap does not fail the call. Neither does a hop whose URL cannot be built, whose
/// fetch fails, or whose body is not valid JSON. Each of these sets [`GraphResult::truncated`]
/// instead. The flag is conservative: it is also set by nodes left unexpanded at the depth
/// limit, and by nodes still queued when the node budget ran out, even if they would have
/// contributed nothing.
///
/// # Errors
///
/// - [`GraphError::CapabilityDenied`] if `profile.metadata.openalex` is false.
/// - [`GraphError::Source`] if the seed fetch through `source` fails.
/// - [`GraphError::SeedNotIndexed`] if the seed response has no metadata or no usable `id`.
/// - [`GraphError::Log`] if recording a successful hop fetch in the provenance log fails.
pub async fn expand(
    seed_doi: &Doi,
    caps: GraphCaps,
    source: &OpenalexSource,
    profile: &CapabilityProfile,
    ctx: &FetchContext,
) -> Result<GraphResult, GraphError> {
    if !profile.metadata.openalex {
        return Err(GraphError::CapabilityDenied);
    }
    let caps = caps.clamped();

    // Resolve the seed by DOI through the OpenAlex source.
    let ref_ = Ref::Doi(seed_doi.clone());
    let seed_result = source.fetch(&ref_, profile, ctx).await?;
    let seed_work = seed_result
        .metadata_json
        .ok_or(GraphError::SeedNotIndexed)?;
    let seed_work_id = extract_work_id(&seed_work).ok_or(GraphError::SeedNotIndexed)?;
    let seed_refs = extract_referenced_works(&seed_work);

    let mut nodes = vec![GraphNode {
        id: seed_work_id.clone(),
        depth: 0,
    }];
    let mut edges: Vec<GraphEdge> = Vec::new();
    let mut visited: HashSet<String> = HashSet::new();
    visited.insert(seed_work_id.clone());
    let mut truncated = false;
    let mut queue: VecDeque<(String, usize)> = VecDeque::new();

    if caps.depth == 0 {
        // The loop below never expands a node at `depth >= caps.depth`. Apply the same rule to
        // the seed so that depth 0 means "seed only" instead of behaving like depth 1.
        truncated = !seed_refs.is_empty();
    } else {
        enqueue_children(
            &seed_work_id,
            &seed_refs,
            1,
            &caps,
            &mut nodes,
            &mut edges,
            &mut visited,
            &mut queue,
            &mut truncated,
        );
    }

    while let Some((work_id, depth)) = queue.pop_front() {
        // Nodes at the depth limit stay in the graph as leaves; their references are unknown.
        if depth >= caps.depth {
            truncated = true;
            continue;
        }
        // The node budget is spent, so no further hop could add a node; stop issuing requests.
        if nodes.len() >= caps.total {
            truncated = true;
            break;
        }
        // NOTE(review): hops bypass `source`; `build_openalex_work_url` takes the base URL from
        // DOIGET_OPENALEX_BASE. `ctx.rate_limiter` is also not consulted here. Confirm whether
        // `HttpClient::fetch_bytes` throttles on its own.
        let work_url = match build_openalex_work_url(&work_id) {
            Ok(u) => u,
            Err(e) => {
                tracing::warn!(
                    work_id = %work_id,
                    error = %e,
                    "citation-graph step skipped: invalid work URL"
                );
                truncated = true;
                continue;
            }
        };
        let body = match ctx.http.fetch_bytes("openalex", work_url).await {
            Ok((b, _final)) => b,
            Err(e) => {
                // Best-effort: the fetch failure is already absorbed as truncation, so a
                // provenance-log failure here is reported instead of aborting the walk.
                let logged = ctx.log.append(RowInput {
                    event: LogEvent::Fetch,
                    result: LogResult::Err,
                    capability: Capability::Metadata,
                    ref_: Some(work_id.as_str()),
                    source: Some("openalex"),
                    error_code: None,
                    size_bytes: None,
                    license: None,
                    store_path: None,
                    canonical_digest: None,
                });
                if let Err(log_err) = logged {
                    tracing::warn!(
                        work_id = %work_id,
                        error = %log_err,
                        "failed to record citation-graph fetch error"
                    );
                }
                tracing::warn!(work_id = %work_id, error = %e, "citation-graph step failed");
                truncated = true;
                continue;
            }
        };
        ctx.log.append(RowInput {
            event: LogEvent::Fetch,
            result: LogResult::Ok,
            capability: Capability::Metadata,
            ref_: Some(work_id.as_str()),
            source: Some("openalex"),
            error_code: None,
            // usize -> u64 is lossless on every target with pointers of 64 bits or fewer.
            size_bytes: Some(body.len() as u64),
            license: None,
            store_path: None,
            canonical_digest: None,
        })?;
        let work: serde_json::Value = match serde_json::from_slice(&body) {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!(
                    work_id = %work_id,
                    error = %e,
                    "citation-graph step skipped: response is not valid JSON"
                );
                truncated = true;
                continue;
            }
        };
        let refs = extract_referenced_works(&work);
        enqueue_children(
            &work_id,
            &refs,
            depth + 1,
            &caps,
            &mut nodes,
            &mut edges,
            &mut visited,
            &mut queue,
            &mut truncated,
        );
    }

    Ok(GraphResult {
        seed_work_id,
        total_visited: nodes.len(),
        nodes,
        edges,
        truncated,
    })
}
/// Record `parent_id`'s outgoing citation edges and enqueue newly discovered works.
///
/// Only the first `caps.per_paper` entries of `refs` are considered. `truncated` is set when
/// `refs` is longer than that. It is also set when a new work is skipped because `nodes`
/// already holds `caps.total` entries.
///
/// Works already in `visited` get an edge but are not re-queued, which keeps the walk
/// cycle-safe. New works are appended to `nodes` at `child_depth` and pushed onto `queue`.
/// Repeated entries within one `refs` list yield a single edge.
// Private BFS helper: the arguments are the walk's shared state; bundling them into a struct
// would only move the noise.
#[allow(clippy::too_many_arguments)]
fn enqueue_children(
    parent_id: &str,
    refs: &[String],
    child_depth: usize,
    caps: &GraphCaps,
    nodes: &mut Vec<GraphNode>,
    edges: &mut Vec<GraphEdge>,
    visited: &mut HashSet<String>,
    queue: &mut VecDeque<(String, usize)>,
    truncated: &mut bool,
) {
    if refs.len() > caps.per_paper {
        *truncated = true;
    }
    // Prevents a reference list that repeats an ID from producing duplicate edges.
    let mut seen_here: HashSet<&str> = HashSet::new();
    for child in refs.iter().take(caps.per_paper) {
        if !seen_here.insert(child.as_str()) {
            continue;
        }
        if !visited.contains(child) {
            if nodes.len() >= caps.total {
                // Node budget exhausted: skip this new work, but keep scanning so that edges to
                // works already in the graph are still recorded.
                *truncated = true;
                continue;
            }
            visited.insert(child.clone());
            nodes.push(GraphNode {
                id: child.clone(),
                depth: child_depth,
            });
            queue.push_back((child.clone(), child_depth));
        }
        edges.push(GraphEdge {
            from: parent_id.to_string(),
            to: child.clone(),
        });
    }
}
/// Extract the short OpenAlex work ID from a Work object's `id` field.
///
/// The short ID is the last `/`-separated segment, e.g. `W0001` from
/// `https://openalex.org/W0001`. Trailing slashes are ignored.
///
/// Returns `None` when `id` is missing, is not a string, or has no non-empty final segment.
fn extract_work_id(work: &serde_json::Value) -> Option<String> {
    let full = work.get("id")?.as_str()?;
    // `rsplit` always yields at least one item, so `?` only guards the type, never fires.
    let short = full.trim_end_matches('/').rsplit('/').next()?;
    (!short.is_empty()).then(|| short.to_string())
}
/// Collect the short OpenAlex IDs listed in a Work object's `referenced_works` array.
///
/// Each string entry is reduced to its last `/`-separated segment, ignoring trailing slashes
/// (e.g. `https://openalex.org/W0002` becomes `W0002`). Non-string entries and entries with an
/// empty final segment are skipped. Input order is preserved. A missing or non-array field
/// yields an empty `Vec`.
fn extract_referenced_works(work: &serde_json::Value) -> Vec<String> {
    work.get("referenced_works")
        .and_then(|v| v.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|s| s.as_str())
                .filter_map(|s| {
                    let short = s.trim_end_matches('/').rsplit('/').next()?;
                    // An empty ID would otherwise become a bogus graph node.
                    (!short.is_empty()).then(|| short.to_string())
                })
                .collect()
        })
        .unwrap_or_default()
}
/// Build the OpenAlex Work endpoint URL `{base}/works/{work_id}`.
///
/// The base is read from `DOIGET_OPENALEX_BASE` and defaults to `https://api.openalex.org`.
/// The path is joined as an absolute path (`/works/...`), so any path component on the base
/// URL is replaced rather than extended.
///
/// # Errors
///
/// Returns [`FetchError::SourceSchema`] in any of these cases:
/// - `work_id` is empty or not purely ASCII-alphanumeric;
/// - the base URL does not parse;
/// - joining the path fails.
fn build_openalex_work_url(work_id: &str) -> Result<Url, FetchError> {
    // `work_id` comes from an untrusted `referenced_works` response. Reject anything that
    // could change the URL's structure (`..`, `?`, `#`, `%`, ...) instead of splicing it into
    // the path. OpenAlex work IDs are of the form `W` followed by digits.
    if work_id.is_empty() || !work_id.bytes().all(|b| b.is_ascii_alphanumeric()) {
        return Err(FetchError::SourceSchema {
            hint: format!("openalex Work ID is not a bare alphanumeric ID: {work_id:?}"),
        });
    }
    let base_str = std::env::var("DOIGET_OPENALEX_BASE")
        .unwrap_or_else(|_| "https://api.openalex.org".into());
    let base = Url::parse(&base_str).map_err(|e| FetchError::SourceSchema {
        hint: format!("openalex base URL invalid: {e}"),
    })?;
    base.join(&format!("/works/{work_id}"))
        .map_err(|e| FetchError::SourceSchema {
            hint: format!("openalex Work URL construction failed: {e}"),
        })
}
#[cfg(test)]
// Setup failures in tests should panic loudly; these lints are aimed at library code.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use camino::Utf8PathBuf;
    use tempfile::TempDir;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use crate::http::HttpClient;
    use crate::provenance::ProvenanceLog;
    use crate::rate_limiter::RateLimiter;
    use crate::{CapabilityProfile, Doi, MetadataAccess, RateLimits};

    /// Seed work W0001, which references W0002 and W0003.
    const SEED_WORK: &str = r#"{
        "id": "https://openalex.org/W0001",
        "doi": "https://doi.org/10.1234/seed",
        "display_name": "Seed Paper",
        "referenced_works": [
            "https://openalex.org/W0002",
            "https://openalex.org/W0003"
        ]
    }"#;
    /// Depth-1 work W0002, which references W0004.
    const HOP_W0002: &str = r#"{
        "id": "https://openalex.org/W0002",
        "referenced_works": ["https://openalex.org/W0004"]
    }"#;
    /// Depth-1 work W0003, with no references.
    const HOP_W0003: &str = r#"{
        "id": "https://openalex.org/W0003",
        "referenced_works": []
    }"#;
    /// Depth-2 work W0004, with no references.
    const HOP_W0004: &str = r#"{
        "id": "https://openalex.org/W0004",
        "referenced_works": []
    }"#;

    /// Build a `FetchContext` whose HTTP client may use plain HTTP to `wiremock_host`.
    ///
    /// The provenance log lives in a fresh temp dir. The returned `TempDir` owns that
    /// directory, so keep it alive for the test's duration.
    fn build_test_context(wiremock_host: &str) -> (TempDir, FetchContext) {
        let td = TempDir::new().expect("tempdir");
        let log_dir =
            Utf8PathBuf::try_from(td.path().to_path_buf()).expect("temp dir path must be UTF-8");
        let log_path = log_dir.join("test.jsonl");
        let http = Arc::new(HttpClient::new_for_tests_allow_http(
            "openalex",
            wiremock_host,
        ));
        let rate_limiter = Arc::new(RateLimiter::new(RateLimits::HARD_CODED));
        let session_id = "01J0000000000000000000TEST".to_string();
        let log = Arc::new(
            ProvenanceLog::open(log_path, session_id.clone()).expect("provenance log opens"),
        );
        let ctx = FetchContext {
            http,
            rate_limiter,
            log,
            session_id,
        };
        (td, ctx)
    }

    /// Overrides an environment variable and restores its prior value when dropped.
    ///
    /// Restoring on drop means a panic after `set` cannot leak the override into later tests.
    struct EnvVarGuard {
        key: &'static str,
        prev: Option<String>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let prev = std::env::var(key).ok();
            std::env::set_var(key, value);
            Self { key, prev }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.prev.take() {
                Some(v) => std::env::set_var(self.key, v),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// A capability profile with only OpenAlex metadata access enabled.
    fn profile_with_openalex_enabled() -> CapabilityProfile {
        let mut p = CapabilityProfile::from_env().expect("clean env never errors");
        p.metadata = MetadataAccess {
            openalex: true,
            semantic_scholar: false,
            doaj: false,
        };
        p
    }

    // Pure computation: no async runtime needed.
    #[test]
    fn caps_clamps_to_adr_0010_maxima() {
        let huge = GraphCaps {
            depth: 99,
            total: 99999,
            per_paper: 999,
        };
        let c = huge.clamped();
        assert_eq!(c.depth, GraphCaps::MAX_DEPTH);
        assert_eq!(c.total, GraphCaps::MAX_TOTAL);
        assert_eq!(c.per_paper, GraphCaps::MAX_PER_PAPER);
        let small = GraphCaps {
            depth: 1,
            total: 5,
            per_paper: 2,
        };
        let c2 = small.clamped();
        assert_eq!(c2.depth, 1);
        assert_eq!(c2.total, 5);
        assert_eq!(c2.per_paper, 2);
    }

    #[tokio::test]
    #[serial_test::serial] // mutates the process-wide DOIGET_OPENALEX_BASE variable
    async fn expand_walks_depth_2_graph() {
        let server = MockServer::start().await;
        for (route, body) in [
            ("/works/10.1234/seed", SEED_WORK),
            ("/works/W0002", HOP_W0002),
            ("/works/W0003", HOP_W0003),
            ("/works/W0004", HOP_W0004),
        ] {
            Mock::given(method("GET"))
                .and(path(route))
                .respond_with(ResponseTemplate::new(200).set_body_string(body))
                .mount(&server)
                .await;
        }
        let env = EnvVarGuard::set("DOIGET_OPENALEX_BASE", &server.uri());
        let (_td, ctx) = build_test_context(&server.uri());
        let src = OpenalexSource::with_base(
            Url::parse(&server.uri()).expect("wiremock URI parses"),
            "doiget@localhost".to_string(),
        );
        let profile = profile_with_openalex_enabled();
        let seed = Doi::parse("10.1234/seed").expect("DOI parses");
        let result = expand(&seed, GraphCaps::default(), &src, &profile, &ctx).await;
        // Restore the environment as soon as it is no longer needed; the guard also covers
        // the panic path above.
        drop(env);
        let result = result.expect("expand ok");
        assert_eq!(result.seed_work_id, "W0001");
        assert_eq!(result.total_visited, 4, "nodes: {:?}", result.nodes);
        assert_eq!(result.nodes[0].id, "W0001");
        assert_eq!(result.nodes[0].depth, 0);
        assert_eq!(result.edges.len(), 3);
        assert!(!result.truncated, "graph should be complete");
    }

    #[tokio::test]
    async fn expand_without_capability_flag_errors() {
        let (_td, ctx) = build_test_context("http://127.0.0.1:1");
        let src = OpenalexSource::with_base(
            Url::parse("http://127.0.0.1:1").expect("URI parses"),
            "doiget@localhost".to_string(),
        );
        let profile = CapabilityProfile::from_env().expect("clean env never errors");
        let seed = Doi::parse("10.1234/seed").expect("DOI parses");
        let err = expand(&seed, GraphCaps::default(), &src, &profile, &ctx)
            .await
            .expect_err("missing openalex capability must error");
        assert!(matches!(err, GraphError::CapabilityDenied));
    }
}