use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use async_trait::async_trait;
use openapiv3::{OpenAPI, Operation, Parameter, ReferenceOr};
use reqwest::Client;
use serde_json::{Value, json};
use tokio::sync::RwLock;
use tracing::{debug, info};
use crate::adapters::graphql_rate_limit::{
RateLimitConfig, RateLimitStrategy, RequestRateLimit, rate_limit_acquire,
};
use crate::adapters::rest_api::{RestApiAdapter, RestApiConfig};
use crate::domain::error::{Result, ServiceError, StygianError};
use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
/// Shared cache of parsed OpenAPI documents, keyed by the URL the spec was
/// requested from and guarded by an async read/write lock.
type SpecCache = Arc<RwLock<HashMap<String, Arc<OpenAPI>>>>;
/// Configuration for [`OpenApiAdapter`].
#[derive(Debug, Clone, Default)]
pub struct OpenApiConfig {
/// Configuration handed to the wrapped [`RestApiAdapter`] when the adapter is built.
pub rest: RestApiConfig,
}
/// Scraping service that resolves an operation from an OpenAPI document and
/// delegates the actual HTTP call to a [`RestApiAdapter`].
///
/// The cache and rate limiter are held behind `Arc`, so clones of the adapter
/// share them.
#[derive(Clone)]
pub struct OpenApiAdapter {
// REST adapter that executes the resolved request.
inner: RestApiAdapter,
// HTTP client used only for downloading spec documents.
spec_client: Client,
// Parsed specs keyed by spec URL.
spec_cache: SpecCache,
// Rate limiter, created on the first request that carries `params.rate_limit`.
rate_limit: Arc<OnceLock<RequestRateLimit>>,
}
impl OpenApiAdapter {
pub fn new() -> Self {
Self::with_config(OpenApiConfig::default())
}
pub fn with_config(config: OpenApiConfig) -> Self {
#[allow(clippy::expect_used)]
let spec_client = Client::builder()
.timeout(Duration::from_secs(30))
.use_rustls_tls()
.build()
.expect("TLS backend unavailable");
Self {
inner: RestApiAdapter::with_config(config.rest),
spec_client,
spec_cache: Arc::new(RwLock::new(HashMap::new())),
rate_limit: Arc::new(OnceLock::new()),
}
}
}
impl Default for OpenApiAdapter {
fn default() -> Self {
Self::new()
}
}
/// Wraps `msg` in a [`ServiceError::Unavailable`] and converts it into the
/// crate-wide [`StygianError`].
fn svc_err(msg: impl Into<String>) -> StygianError {
    let unavailable = ServiceError::Unavailable(msg.into());
    unavailable.into()
}
/// Downloads the OpenAPI document at `url` and parses it.
///
/// The body is parsed as JSON first and, if that fails, as YAML.
///
/// # Errors
///
/// Returns a service error when the request cannot be sent, when the server
/// answers with a 4xx/5xx status, when the body cannot be read, or when the
/// body is neither a valid JSON nor a valid YAML OpenAPI document (in which
/// case the YAML parser's message is reported).
async fn fetch_spec(client: &Client, url: &str) -> Result<Arc<OpenAPI>> {
    let response = client
        .get(url)
        .header(
            "Accept",
            "application/json, application/yaml, text/yaml, */*",
        )
        .send()
        .await
        .map_err(|e| svc_err(format!("spec fetch failed: {e}")))?
        // Without this, an HTML/JSON error page from a 404 or 500 would be
        // handed to the parsers and surface as a misleading parse failure.
        .error_for_status()
        .map_err(|e| svc_err(format!("spec fetch failed: {e}")))?;

    let body = response
        .text()
        .await
        .map_err(|e| svc_err(format!("spec read failed: {e}")))?;

    let api: OpenAPI = serde_json::from_str(&body)
        .or_else(|_| serde_yaml::from_str(&body))
        .map_err(|e| svc_err(format!("spec parse failed: {e}")))?;
    Ok(Arc::new(api))
}
/// Returns the parsed spec for `url`, downloading and caching it on a miss.
///
/// The read lock is released before the download so other lookups are not
/// blocked while the network request is in flight. If several callers miss
/// at the same time each performs its own download, but only the first result
/// is stored and every caller gets that stored instance back.
///
/// # Errors
///
/// Propagates any error from [`fetch_spec`].
async fn resolve_spec(cache: &SpecCache, client: &Client, url: &str) -> Result<Arc<OpenAPI>> {
    {
        let guard = cache.read().await;
        if let Some(spec) = guard.get(url) {
            debug!(url, "OpenAPI spec cache hit");
            return Ok(Arc::clone(spec));
        }
    }

    let fetched = fetch_spec(client, url).await?;

    let mut guard = cache.write().await;
    // Another task may have filled the entry while we were downloading; keep
    // and return the existing value so all callers share one instance.
    let cached = guard.entry(url.to_owned()).or_insert(fetched);
    Ok(Arc::clone(cached))
}
/// Locates an operation in `api`.
///
/// `operation_ref` is interpreted as `"<METHOD> <path>"` when the text before
/// its first space is (case-insensitively) one of the eight HTTP methods an
/// OpenAPI path item can hold; the path must then equal a key in the spec's
/// `paths` exactly. Otherwise `operation_ref` is compared with each
/// operation's `operationId`.
///
/// Path items given as `$ref` are skipped.
///
/// Returns the upper-case method, the path template, and the operation.
///
/// # Errors
///
/// Returns a service error when no operation matches.
fn resolve_operation<'a>(
    api: &'a OpenAPI,
    operation_ref: &str,
) -> Result<(String, String, &'a Operation)> {
    const HTTP_METHODS: [&str; 8] = [
        "GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS", "TRACE",
    ];

    // `Some` only when the reference has the "<METHOD> <path>" shape.
    let wanted: Option<(String, &str)> =
        operation_ref.split_once(' ').and_then(|(verb, path)| {
            let verb = verb.to_uppercase();
            HTTP_METHODS.contains(&verb.as_str()).then_some((verb, path))
        });

    api.paths
        .paths
        .iter()
        .filter_map(|(path, entry)| match entry {
            ReferenceOr::Item(item) => Some((path, item)),
            ReferenceOr::Reference { .. } => None,
        })
        .flat_map(|(path, item)| {
            [
                ("GET", item.get.as_ref()),
                ("POST", item.post.as_ref()),
                ("PUT", item.put.as_ref()),
                ("PATCH", item.patch.as_ref()),
                ("DELETE", item.delete.as_ref()),
                ("HEAD", item.head.as_ref()),
                ("OPTIONS", item.options.as_ref()),
                ("TRACE", item.trace.as_ref()),
            ]
            .into_iter()
            .filter_map(move |(method, slot)| slot.map(|op| (method, path, op)))
        })
        .find(|&(method, path, op)| match &wanted {
            Some((want_method, want_path)) => {
                method == want_method.as_str() && path.as_str() == *want_path
            }
            None => op.operation_id.as_deref() == Some(operation_ref),
        })
        .map(|(method, path, op)| (method.to_owned(), path.clone(), op))
        .ok_or_else(|| {
            svc_err(format!(
                "operation '{operation_ref}' not found in spec"
            ))
        })
}
/// Picks the base URL for requests, without trailing slashes.
///
/// A non-empty string in `server_override` wins; otherwise the first entry of
/// the spec's `servers` list is used. If neither is available the result is
/// an empty string.
#[allow(clippy::indexing_slicing)]
fn resolve_server(api: &OpenAPI, server_override: &Value) -> String {
    let chosen = match server_override.as_str() {
        Some(explicit) if !explicit.is_empty() => Some(explicit),
        _ => api.servers.first().map(|server| server.url.as_str()),
    };
    chosen.map_or_else(String::new, |url| url.trim_end_matches('/').to_owned())
}
/// Splits the operation's inline parameters into `(path names, query names)`.
///
/// Header and cookie parameters are ignored, and so are parameters given as
/// `$ref`.
fn classify_params(op: &Operation) -> (Vec<String>, Vec<String>) {
    let mut in_path = Vec::new();
    let mut in_query = Vec::new();

    let inline = op.parameters.iter().filter_map(|entry| match entry {
        ReferenceOr::Item(param) => Some(param),
        ReferenceOr::Reference { .. } => None,
    });

    for param in inline {
        match param {
            Parameter::Path { parameter_data, .. } => in_path.push(parameter_data.name.clone()),
            Parameter::Query { parameter_data, .. } => in_query.push(parameter_data.name.clone()),
            Parameter::Header { .. } | Parameter::Cookie { .. } => {}
        }
    }

    (in_path, in_query)
}
fn build_url(server_url: &str, path_template: &str, args: &HashMap<String, Value>) -> String {
let mut url = format!("{server_url}{path_template}");
for (key, val) in args {
let placeholder = format!("{{{key}}}");
if url.contains(placeholder.as_str()) {
let replacement = val.as_str().map_or_else(|| val.to_string(), str::to_owned);
url = url.replace(placeholder.as_str(), &replacement);
}
}
url
}
/// Assembles the `params` object passed to the REST adapter.
///
/// The result always has `"method"` and `"query"`. `"query"` holds every
/// declared query parameter present in `args`, with its value converted to a
/// string. `"body"` is added only when the operation declares a request body
/// and at least one argument is neither a path nor a query parameter; it then
/// contains exactly those remaining arguments. `"auth"` is added when
/// `auth_override` is not null.
#[allow(clippy::indexing_slicing)]
fn build_rest_params(
    method: &str,
    op: &Operation,
    args: &HashMap<String, Value>,
    path_param_names: &[String],
    query_param_names: &[String],
    auth_override: &Value,
) -> Value {
    let mut query = serde_json::Map::new();
    for name in query_param_names {
        if let Some(val) = args.get(name.as_str()) {
            let text = match val.as_str() {
                Some(s) => s.to_owned(),
                None => val.to_string(),
            };
            query.insert(name.clone(), Value::String(text));
        }
    }

    let body = op.request_body.as_ref().and_then(|_| {
        let mut payload = serde_json::Map::new();
        for (key, val) in args {
            let reserved = path_param_names.contains(key) || query_param_names.contains(key);
            if !reserved {
                payload.insert(key.clone(), val.clone());
            }
        }
        (!payload.is_empty()).then_some(Value::Object(payload))
    });

    let mut params = json!({
        "method": method,
        "query": Value::Object(query),
    });
    if let Some(body) = body {
        params["body"] = body;
    }
    if !auth_override.is_null() {
        params["auth"] = auth_override.clone();
    }
    params
}
#[allow(clippy::indexing_slicing)]
fn parse_rate_limit_config(rl: &Value) -> RateLimitConfig {
let strategy = match rl["strategy"].as_str().unwrap_or("sliding_window") {
"token_bucket" => RateLimitStrategy::TokenBucket,
_ => RateLimitStrategy::SlidingWindow,
};
RateLimitConfig {
max_requests: rl["max_requests"]
.as_u64()
.and_then(|value| u32::try_from(value).ok())
.unwrap_or(100),
window: Duration::from_secs(rl["window_secs"].as_u64().unwrap_or(60)),
max_delay_ms: rl["max_delay_ms"].as_u64().unwrap_or(30_000),
strategy,
}
}
#[async_trait]
impl ScrapingService for OpenApiAdapter {
    /// Executes one OpenAPI operation.
    ///
    /// `input.url` is the location of the spec document. `input.params` is
    /// read as follows:
    ///
    /// * `rate_limit` – when not null, a rate limiter is created from it on
    ///   first use and a slot is acquired before anything else happens;
    /// * `operation` – required; see [`resolve_operation`] for accepted forms;
    /// * `server.url` – optional base URL override;
    /// * `args` – optional object with path, query and body arguments;
    /// * `auth` – passed through to the REST adapter when not null.
    ///
    /// When the REST adapter's metadata is a JSON object, the spec URL,
    /// operation reference, method, path template, server URL and resolved URL
    /// are added to it.
    ///
    /// # Errors
    ///
    /// Fails when the spec cannot be resolved, `operation` is missing or
    /// unknown, or the delegated REST call fails.
    #[allow(clippy::indexing_slicing)]
    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
        let params = &input.params;

        if let Some(rl_cfg) = params.get("rate_limit").filter(|cfg| !cfg.is_null()) {
            let limiter = self
                .rate_limit
                .get_or_init(|| RequestRateLimit::new(parse_rate_limit_config(rl_cfg)));
            rate_limit_acquire(limiter).await;
        }

        info!(url = %input.url, "OpenAPI adapter: execute");
        let api = resolve_spec(&self.spec_cache, &self.spec_client, &input.url).await?;

        let Some(operation_ref) = params["operation"].as_str() else {
            return Err(svc_err("params.operation is required"));
        };
        let (method, path_template, op) = resolve_operation(&api, operation_ref)?;
        let server_url = resolve_server(&api, &params["server"]["url"]);
        let (path_param_names, query_param_names) = classify_params(op);

        let mut args: HashMap<String, Value> = HashMap::new();
        if let Some(supplied) = params["args"].as_object() {
            args.extend(
                supplied
                    .iter()
                    .map(|(name, value)| (name.clone(), value.clone())),
            );
        }

        let final_url = build_url(&server_url, &path_template, &args);
        let rest_params = build_rest_params(
            &method,
            op,
            &args,
            &path_param_names,
            &query_param_names,
            &params["auth"],
        );
        debug!(
            %final_url, %method, path_template, operation_ref,
            "OpenAPI: delegating to RestApiAdapter"
        );

        let delegated = ServiceInput {
            url: final_url.clone(),
            params: rest_params,
        };
        let inner_output = self.inner.execute(delegated).await?;

        let mut metadata = inner_output.metadata;
        if let Some(fields) = metadata.as_object_mut() {
            let extras = [
                ("openapi_spec_url", input.url.clone()),
                ("operation_id", operation_ref.to_owned()),
                ("method", method),
                ("path_template", path_template),
                ("server_url", server_url),
                ("resolved_url", final_url),
            ];
            for (key, text) in extras {
                fields.insert(key.to_owned(), Value::String(text));
            }
        }

        Ok(ServiceOutput {
            data: inner_output.data,
            metadata,
        })
    }

    /// Identifier of this service.
    fn name(&self) -> &'static str {
        "openapi"
    }
}
// Unit tests for the spec helpers and the adapter. None of them performs a
// network request: specs are parsed from an inline fixture or pre-seeded into
// the cache.
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::panic,
clippy::indexing_slicing,
clippy::expect_used
)]
mod tests {
use super::*;
use serde_json::json;
use std::time::Duration;
// Minimal OpenAPI 3.0 document with one server, three paths
// (`/pets`, `/pets/{petId}`, `/pets/findByStatus`) and four operations.
const MINI_SPEC: &str = r#"{
"openapi": "3.0.0",
"info": { "title": "Mini Test API", "version": "1.0" },
"servers": [{ "url": "https://api.example.com/v1" }],
"paths": {
"/pets": {
"get": {
"operationId": "listPets",
"parameters": [
{ "name": "limit", "in": "query", "schema": { "type": "integer" } },
{ "name": "status", "in": "query", "schema": { "type": "string" } }
],
"responses": { "200": { "description": "OK" } }
}
},
"/pets/{petId}": {
"get": {
"operationId": "getPet",
"parameters": [
{ "name": "petId", "in": "path", "required": true, "schema": { "type": "integer" } }
],
"responses": { "200": { "description": "OK" } }
},
"delete": {
"operationId": "deletePet",
"parameters": [
{ "name": "petId", "in": "path", "required": true, "schema": { "type": "integer" } }
],
"responses": { "204": { "description": "No content" } }
}
},
"/pets/findByStatus": {
"get": {
"operationId": "findPetsByStatus",
"parameters": [
{ "name": "status", "in": "query", "schema": { "type": "string" } }
],
"responses": { "200": { "description": "OK" } }
}
}
},
"components": {
"securitySchemes": {
"apiKeyAuth": { "type": "apiKey", "in": "header", "name": "X-Api-Key" }
}
}
}"#;
// Parses `MINI_SPEC` into a shared `OpenAPI` value.
fn parse_mini() -> Arc<OpenAPI> {
Arc::new(serde_json::from_str(MINI_SPEC).expect("MINI_SPEC is valid JSON"))
}
// The fixture deserializes and exposes its three paths and its components.
#[test]
fn parse_petstore_spec() {
let api = parse_mini();
assert_eq!(api.paths.paths.len(), 3, "spec has 3 paths");
assert!(api.components.is_some());
}
// A bare identifier is matched against `operationId`.
#[test]
fn resolve_operation_by_id() {
let api = parse_mini();
let (method, path, op) = resolve_operation(&api, "listPets").unwrap();
assert_eq!(method, "GET");
assert_eq!(path, "/pets");
assert_eq!(op.operation_id.as_deref(), Some("listPets"));
}
// A "<METHOD> <path>" reference is matched against method and path key.
#[test]
fn resolve_operation_by_method_path() {
let api = parse_mini();
let (method, path, op) = resolve_operation(&api, "GET /pets/findByStatus").unwrap();
assert_eq!(method, "GET");
assert_eq!(path, "/pets/findByStatus");
assert_eq!(op.operation_id.as_deref(), Some("findPetsByStatus"));
}
// An unknown operation reference yields an error.
#[test]
fn resolve_operation_not_found() {
let api = parse_mini();
assert!(resolve_operation(&api, "nonExistentOp").is_err());
}
// A numeric argument is rendered into the path placeholder.
#[test]
fn bind_path_params() {
let args: HashMap<String, Value> = HashMap::from([("petId".to_owned(), json!(42))]);
let url = build_url("https://api.example.com/v1", "/pets/{petId}", &args);
assert_eq!(url, "https://api.example.com/v1/pets/42");
}
// A string argument is inserted without surrounding JSON quotes.
#[test]
fn bind_path_params_string() {
let args: HashMap<String, Value> = HashMap::from([("petId".to_owned(), json!("fluffy"))]);
let url = build_url("https://api.example.com/v1", "/pets/{petId}", &args);
assert_eq!(url, "https://api.example.com/v1/pets/fluffy");
}
// Declared query parameters are classified as such and end up under
// `params["query"]`.
#[test]
fn bind_query_params() {
let api = parse_mini();
let (_, _, op) = resolve_operation(&api, "listPets").unwrap();
let (path_names, query_names) = classify_params(op);
assert!(path_names.is_empty());
assert!(query_names.contains(&"status".to_owned()));
assert!(query_names.contains(&"limit".to_owned()));
let args: HashMap<String, Value> = [
("status".to_owned(), json!("available")),
("limit".to_owned(), json!("10")),
]
.into_iter()
.collect();
let params = build_rest_params("GET", op, &args, &path_names, &query_names, &Value::Null);
assert_eq!(params["query"]["status"], json!("available"));
assert_eq!(params["query"]["limit"], json!("10"));
}
// An explicit override wins (trailing slash stripped); a null override falls
// back to the spec's first server.
#[test]
fn server_override() {
let api = parse_mini();
let url = resolve_server(&api, &json!("https://override.example.com/v2/"));
assert_eq!(url, "https://override.example.com/v2");
let default_url = resolve_server(&api, &Value::Null);
assert_eq!(default_url, "https://api.example.com/v1");
}
// A pre-seeded cache entry is returned as the very same `Arc`; the URL is
// never fetched.
#[tokio::test]
async fn spec_cache_hit() {
let cache: SpecCache = Arc::new(RwLock::new(HashMap::new()));
let api = parse_mini();
cache
.write()
.await
.insert("http://test/spec.json".to_owned(), Arc::clone(&api));
#[allow(clippy::expect_used)]
let dummy_client = Client::builder().use_rustls_tls().build().expect("client");
let returned = resolve_spec(&cache, &dummy_client, "http://test/spec.json")
.await
.unwrap();
assert!(Arc::ptr_eq(&api, &returned));
}
// Timing-based check of the rate limiter.
#[tokio::test]
async fn rate_limit_proactive() {
use crate::adapters::graphql_rate_limit::rate_limit_acquire;
use tokio::time::Instant;
let config = RateLimitConfig {
max_requests: 3,
window: Duration::from_secs(10),
max_delay_ms: 5_000,
strategy: RateLimitStrategy::SlidingWindow,
};
let rl = RequestRateLimit::new(config);
// Three acquisitions against a limit of three; nothing is asserted about
// them.
for _ in 0..3 {
rate_limit_acquire(&rl).await;
}
let start = Instant::now();
let config_short = RateLimitConfig {
max_requests: 1,
window: Duration::from_millis(50),
max_delay_ms: 200,
strategy: RateLimitStrategy::SlidingWindow,
};
let rl_short = RequestRateLimit::new(config_short);
// Two back-to-back acquisitions against a limit of one per 50 ms; presumably
// the second one has to wait for the window, which the assertion below
// checks with a 10 ms tolerance.
rate_limit_acquire(&rl_short).await; rate_limit_acquire(&rl_short).await; let elapsed = start.elapsed();
assert!(
elapsed >= Duration::from_millis(40),
"expected ≥40 ms delay but got {elapsed:?}"
);
}
// Explicit values, including the token-bucket strategy, are honoured.
#[test]
fn parse_rate_limit_config_token_bucket() {
let rl = json!({
"max_requests": 50,
"window_secs": 30,
"strategy": "token_bucket",
});
let cfg = parse_rate_limit_config(&rl);
assert_eq!(cfg.max_requests, 50);
assert_eq!(cfg.window, Duration::from_secs(30));
assert_eq!(cfg.strategy, RateLimitStrategy::TokenBucket);
}
// An empty object yields the defaults: 100 requests, 60 s, sliding window.
#[test]
fn parse_rate_limit_config_defaults() {
let cfg = parse_rate_limit_config(&json!({}));
assert_eq!(cfg.max_requests, 100);
assert_eq!(cfg.window, Duration::from_secs(60));
assert_eq!(cfg.strategy, RateLimitStrategy::SlidingWindow);
}
// The adapter reports the service name "openapi".
#[test]
fn adapter_name() {
assert_eq!(OpenApiAdapter::new().name(), "openapi");
}
}