use crate::{Error, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, error, info, instrument, warn};
/// HTTP verb used by a [`RestEndpoint`].
///
/// Serialised in upper case (`"GET"`, `"POST"`, ...); the default is `GET`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum HttpMethod {
    #[default]
    GET,
    POST,
    PUT,
    PATCH,
    DELETE,
}

impl std::fmt::Display for HttpMethod {
    /// Writes the upper-case verb, e.g. `GET`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let verb = match self {
            Self::GET => "GET",
            Self::POST => "POST",
            Self::PUT => "PUT",
            Self::PATCH => "PATCH",
            Self::DELETE => "DELETE",
        };
        f.write_str(verb)
    }
}
/// Type of a field in a declared REST response schema.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RestFieldType {
    String,
    Int,
    Float,
    Boolean,
    /// A named object type.
    Object(String),
    /// A list whose items have the boxed type.
    List(Box<RestFieldType>),
}

impl RestFieldType {
    /// Renders the GraphQL-style type reference.
    ///
    /// Scalars map to their names, objects to the object type name, and lists
    /// to `[Inner]`, recursively.
    pub fn to_type_ref(&self) -> String {
        let scalar = match self {
            Self::String => "String",
            Self::Int => "Int",
            Self::Float => "Float",
            Self::Boolean => "Boolean",
            Self::Object(name) => return name.clone(),
            Self::List(inner) => return format!("[{}]", inner.to_type_ref()),
        };
        scalar.to_string()
    }
}
/// One field of a [`RestResponseSchema`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RestResponseField {
    pub name: String,
    pub field_type: RestFieldType,
    pub nullable: bool,
    pub description: Option<String>,
}

impl RestResponseField {
    /// Shared constructor: a non-nullable field with no description.
    fn of_type(name: impl Into<String>, field_type: RestFieldType) -> Self {
        Self {
            name: name.into(),
            field_type,
            nullable: false,
            description: None,
        }
    }

    /// A non-nullable `String` field.
    pub fn string(name: impl Into<String>) -> Self {
        Self::of_type(name, RestFieldType::String)
    }

    /// A non-nullable `Int` field.
    pub fn int(name: impl Into<String>) -> Self {
        Self::of_type(name, RestFieldType::Int)
    }

    /// A non-nullable `Float` field.
    pub fn float(name: impl Into<String>) -> Self {
        Self::of_type(name, RestFieldType::Float)
    }

    /// A non-nullable `Boolean` field.
    pub fn boolean(name: impl Into<String>) -> Self {
        Self::of_type(name, RestFieldType::Boolean)
    }

    /// A non-nullable field of the named object type.
    pub fn object(name: impl Into<String>, type_name: impl Into<String>) -> Self {
        Self::of_type(name, RestFieldType::Object(type_name.into()))
    }

    /// A non-nullable list field whose items are `item_type`.
    pub fn list(name: impl Into<String>, item_type: RestFieldType) -> Self {
        Self::of_type(name, RestFieldType::List(Box::new(item_type)))
    }

    /// Marks the field as nullable.
    pub fn nullable(self) -> Self {
        Self {
            nullable: true,
            ..self
        }
    }

    /// Attaches a human-readable description.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }
}
/// Declared shape of an endpoint's response: a named type plus its fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RestResponseSchema {
    pub type_name: String,
    pub fields: Vec<RestResponseField>,
    pub description: Option<String>,
}

impl RestResponseSchema {
    /// Creates an empty schema for `type_name`.
    pub fn new(type_name: impl Into<String>) -> Self {
        let type_name: String = type_name.into();
        Self {
            type_name,
            fields: Vec::new(),
            description: None,
        }
    }

    /// Appends a field (builder style).
    pub fn field(mut self, field: RestResponseField) -> Self {
        self.fields.push(field);
        self
    }

    /// Sets the schema description (builder style).
    pub fn description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }
}
/// Declarative description of one REST call.
///
/// `path`, query-param values and `body_template` may contain `{name}`
/// placeholders that are filled from call arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RestEndpoint {
    pub name: String,
    pub path: String,
    pub method: HttpMethod,
    pub body_template: Option<String>,
    pub response_path: Option<String>,
    pub headers: HashMap<String, String>,
    pub query_params: HashMap<String, String>,
    pub timeout: Option<Duration>,
    /// Explicit query/mutation classification; `None` means "derive from method".
    pub is_mutation: Option<bool>,
    pub description: Option<String>,
    pub return_type: Option<String>,
    pub response_schema: Option<RestResponseSchema>,
}

impl RestEndpoint {
    /// Creates a `GET` endpoint called `name` at `path` with nothing else set.
    pub fn new(name: impl Into<String>, path: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            path: path.into(),
            method: HttpMethod::GET,
            body_template: None,
            response_path: None,
            headers: HashMap::new(),
            query_params: HashMap::new(),
            timeout: None,
            is_mutation: None,
            description: None,
            return_type: None,
            response_schema: None,
        }
    }

    /// Sets the HTTP method.
    pub fn method(self, method: HttpMethod) -> Self {
        Self { method, ..self }
    }

    /// Sets the JSON body template.
    pub fn body_template(self, template: impl Into<String>) -> Self {
        Self {
            body_template: Some(template.into()),
            ..self
        }
    }

    /// Sets the JSON path of interest inside the response.
    pub fn response_path(self, path: impl Into<String>) -> Self {
        Self {
            response_path: Some(path.into()),
            ..self
        }
    }

    /// Adds (or overwrites) an endpoint-specific header.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Adds (or overwrites) a query parameter; `value` may contain `{placeholders}`.
    pub fn query_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.query_params.insert(key.into(), value.into());
        self
    }

    /// Overrides the connector-wide request timeout for this endpoint.
    pub fn timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Forces this endpoint to be classified as a mutation.
    pub fn as_mutation(self) -> Self {
        Self {
            is_mutation: Some(true),
            ..self
        }
    }

    /// Forces this endpoint to be classified as a query.
    pub fn as_query(self) -> Self {
        Self {
            is_mutation: Some(false),
            ..self
        }
    }

    /// Sets the human-readable description.
    pub fn description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Sets the declared return type name.
    pub fn return_type(self, type_name: impl Into<String>) -> Self {
        Self {
            return_type: Some(type_name.into()),
            ..self
        }
    }

    /// Attaches a response schema.
    pub fn with_response_schema(self, schema: RestResponseSchema) -> Self {
        Self {
            response_schema: Some(schema),
            ..self
        }
    }

    /// Whether this endpoint is a mutation.
    ///
    /// An explicit `as_mutation`/`as_query` wins; otherwise POST, PUT, PATCH
    /// and DELETE count as mutations.
    pub fn is_mutation(&self) -> bool {
        match self.is_mutation {
            Some(explicit) => explicit,
            None => matches!(
                self.method,
                HttpMethod::POST | HttpMethod::PUT | HttpMethod::PATCH | HttpMethod::DELETE
            ),
        }
    }
}
/// Connector-wide settings.
#[derive(Debug, Clone)]
pub struct RestConnectorConfig {
    /// Prefix joined with each endpoint path.
    pub base_url: String,
    /// Per-request timeout used when an endpoint does not set its own.
    pub timeout: Duration,
    /// Headers sent on every request; endpoint headers with the same name win.
    pub default_headers: HashMap<String, String>,
    pub retry: RetryConfig,
    /// When true, request and response bodies are written to `debug!` logs.
    pub log_bodies: bool,
}

impl Default for RestConnectorConfig {
    /// Empty base URL, 30 s timeout, no default headers, default retry policy,
    /// body logging off.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        Self {
            base_url: String::default(),
            timeout,
            default_headers: HashMap::default(),
            retry: RetryConfig::default(),
            log_bodies: false,
        }
    }
}
/// Exponential-backoff retry policy applied to every request.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Retries allowed after the first attempt (0 disables retrying).
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: Duration,
    /// Upper bound on any single delay.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by after each retry.
    pub multiplier: f64,
    /// HTTP statuses that trigger a retry.
    pub retry_statuses: Vec<u16>,
}

impl Default for RetryConfig {
    /// 3 retries, 100 ms doubling up to 10 s, on 429 and common 5xx statuses.
    fn default() -> Self {
        let retry_statuses = [429, 500, 502, 503, 504].to_vec();
        Self {
            max_retries: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            retry_statuses,
        }
    }
}

impl RetryConfig {
    /// The default policy with retrying switched off.
    pub fn disabled() -> Self {
        let mut config = Self::default();
        config.max_retries = 0;
        config
    }

    /// 5 retries, 50 ms doubling up to 30 s, additionally retrying on 408.
    pub fn aggressive() -> Self {
        Self {
            max_retries: 5,
            initial_backoff: Duration::from_millis(50),
            max_backoff: Duration::from_secs(30),
            multiplier: 2.0,
            retry_statuses: [408, 429, 500, 502, 503, 504].to_vec(),
        }
    }
}
/// A received HTTP response with its body decoded to JSON.
#[derive(Debug, Clone)]
pub struct RestResponse {
    pub status: u16,
    /// Response headers; names are as produced by the HTTP client, values that
    /// are not valid UTF-8 text are dropped.
    pub headers: HashMap<String, String>,
    pub body: JsonValue,
    /// Wall-clock time from sending the request to decoding the body.
    pub duration: Duration,
}

impl RestResponse {
    /// True for any 2xx status.
    pub fn is_success(&self) -> bool {
        matches!(self.status, 200..=299)
    }

    /// Extracts a sub-value of the body using a `$.a.b[0]`-style path.
    pub fn extract(&self, path: &str) -> Option<JsonValue> {
        let body = &self.body;
        extract_json_path(body, path)
    }
}
#[async_trait]
pub trait ResponseTransformer: Send + Sync {
async fn transform(&self, endpoint: &str, response: RestResponse) -> Result<JsonValue>;
}
#[derive(Default)]
pub struct DefaultTransformer;
#[async_trait]
impl ResponseTransformer for DefaultTransformer {
async fn transform(&self, _endpoint: &str, response: RestResponse) -> Result<JsonValue> {
Ok(response.body)
}
}
#[async_trait]
pub trait RequestInterceptor: Send + Sync {
async fn intercept(&self, request: &mut RestRequest) -> Result<()>;
}
#[derive(Debug, Clone)]
pub struct RestRequest {
pub url: String,
pub method: HttpMethod,
pub headers: HashMap<String, String>,
pub body: Option<String>,
pub timeout: Duration,
}
/// Client for one REST API: a base URL plus a set of named endpoints.
///
/// Cloning is cheap for the client, transformer, interceptors and cache
/// (shared handles); configuration and endpoints are copied.
#[derive(Clone)]
pub struct RestConnector {
    config: RestConnectorConfig,
    endpoints: HashMap<String, RestEndpoint>,
    client: reqwest::Client,
    transformer: Arc<dyn ResponseTransformer>,
    interceptors: Vec<Arc<dyn RequestInterceptor>>,
    cache: Option<Arc<RwLock<RestCache>>>,
}

impl std::fmt::Debug for RestConnector {
    /// Shows the config and endpoint names only; the client, transformer,
    /// interceptors and cache are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let endpoint_names: Vec<&String> = self.endpoints.keys().collect();
        f.debug_struct("RestConnector")
            .field("config", &self.config)
            .field("endpoints", &endpoint_names)
            .finish()
    }
}
impl RestConnector {
pub fn builder() -> RestConnectorBuilder {
RestConnectorBuilder::default()
}
pub fn base_url(&self) -> &str {
&self.config.base_url
}
pub fn endpoints(&self) -> &HashMap<String, RestEndpoint> {
&self.endpoints
}
pub fn get_endpoint(&self, name: &str) -> Option<&RestEndpoint> {
self.endpoints.get(name)
}
#[instrument(skip(self, args), fields(endpoint = %endpoint_name))]
pub async fn execute(
&self,
endpoint_name: &str,
args: HashMap<String, JsonValue>,
) -> Result<JsonValue> {
let endpoint = self
.endpoints
.get(endpoint_name)
.ok_or_else(|| Error::Schema(format!("Unknown REST endpoint: {}", endpoint_name)))?;
let mut request = self.build_request(endpoint, &args)?;
for interceptor in &self.interceptors {
interceptor.intercept(&mut request).await?;
}
if endpoint.method == HttpMethod::GET {
if let Some(cache) = &self.cache {
let cache_key = format!("{}:{}", endpoint_name, serde_json::to_string(&args)?);
let cache_read = cache.read().await;
if let Some(cached) = cache_read.get(&cache_key) {
debug!("REST cache hit for {}", endpoint_name);
return Ok(cached.clone());
}
}
}
let response = self.execute_with_retry(&request, endpoint).await?;
let data = self.transformer.transform(endpoint_name, response).await?;
if endpoint.method == HttpMethod::GET {
if let Some(cache) = &self.cache {
let cache_key = format!("{}:{}", endpoint_name, serde_json::to_string(&args)?);
let mut cache_write = cache.write().await;
cache_write.insert(cache_key, data.clone());
}
}
Ok(data)
}
fn build_request(
&self,
endpoint: &RestEndpoint,
args: &HashMap<String, JsonValue>,
) -> Result<RestRequest> {
let mut path = endpoint.path.clone();
for (key, value) in args {
let placeholder = format!("{{{}}}", key);
if !path.contains(&placeholder) {
continue;
}
let value_str = json_value_to_string(value);
if value_str.contains("..")
|| value_str.contains("://")
|| value_str.contains('@')
|| value_str.contains('\0')
|| value_str.contains('\n')
|| value_str.contains('\r')
|| value_str.contains("%00")
|| value_str.contains("/../")
|| value_str.contains("/./")
{
return Err(Error::InvalidRequest(format!(
"Invalid characters in path parameter '{}': potential path traversal or URL injection",
key
)));
}
let safe_value = urlencoding::encode(&value_str);
path = path.replace(&placeholder, &safe_value);
}
let mut query_parts = Vec::new();
for (key, template) in &endpoint.query_params {
let mut value = template.clone();
for (arg_key, arg_value) in args {
let placeholder = format!("{{{}}}", arg_key);
let value_str = json_value_to_string(arg_value);
value = value.replace(&placeholder, &value_str);
}
if !value.contains('{') {
query_parts.push(format!("{}={}", key, urlencoding::encode(&value)));
}
}
let url = if query_parts.is_empty() {
format!("{}{}", self.config.base_url, path)
} else {
format!("{}{}?{}", self.config.base_url, path, query_parts.join("&"))
};
if !url.starts_with(&self.config.base_url) {
return Err(Error::InvalidRequest(
"URL manipulation detected: final URL does not match base URL".to_string(),
));
}
let mut headers = self.config.default_headers.clone();
headers.extend(endpoint.headers.clone());
let body = if let Some(ref template) = endpoint.body_template {
let mut body = template.clone();
for (key, value) in args {
let quoted_placeholder = format!("\"{{{}}}\"", key);
let json_repr = serde_json::to_string(value)
.unwrap_or_else(|_| serde_json::json!(null).to_string());
if body.contains("ed_placeholder) {
body = body.replace("ed_placeholder, &json_repr);
} else {
let bare_placeholder = format!("{{{}}}", key);
body = body.replace(&bare_placeholder, &json_repr);
}
}
Some(body)
} else if matches!(
endpoint.method,
HttpMethod::POST | HttpMethod::PUT | HttpMethod::PATCH
) {
Some(serde_json::to_string(args)?)
} else {
None
};
let timeout = endpoint.timeout.unwrap_or(self.config.timeout);
Ok(RestRequest {
url,
method: endpoint.method,
headers,
body,
timeout,
})
}
async fn execute_with_retry(
&self,
request: &RestRequest,
_endpoint: &RestEndpoint,
) -> Result<RestResponse> {
let retry_config = &self.config.retry;
let mut attempts = 0;
let mut backoff = retry_config.initial_backoff;
loop {
attempts += 1;
match self.execute_request(request).await {
Ok(response) => {
if response.is_success() {
return Ok(response);
}
if retry_config.retry_statuses.contains(&response.status)
&& attempts <= retry_config.max_retries
{
warn!(
"REST {} {} returned {}, retrying in {:?} (attempt {}/{})",
request.method,
request.url,
response.status,
backoff,
attempts,
retry_config.max_retries
);
tokio::time::sleep(backoff).await;
backoff = std::cmp::min(
Duration::from_secs_f64(
backoff.as_secs_f64() * retry_config.multiplier,
),
retry_config.max_backoff,
);
continue;
}
return Err(Error::Schema(format!(
"REST {} {} failed with status {}: {}",
request.method, request.url, response.status, response.body
)));
}
Err(e) => {
if attempts <= retry_config.max_retries {
warn!(
"REST {} {} failed: {}, retrying in {:?} (attempt {}/{})",
request.method,
request.url,
e,
backoff,
attempts,
retry_config.max_retries
);
tokio::time::sleep(backoff).await;
backoff = std::cmp::min(
Duration::from_secs_f64(
backoff.as_secs_f64() * retry_config.multiplier,
),
retry_config.max_backoff,
);
continue;
}
return Err(e);
}
}
}
}
async fn execute_request(&self, request: &RestRequest) -> Result<RestResponse> {
let start = std::time::Instant::now();
let mut req_builder = match request.method {
HttpMethod::GET => self.client.get(&request.url),
HttpMethod::POST => self.client.post(&request.url),
HttpMethod::PUT => self.client.put(&request.url),
HttpMethod::PATCH => self.client.patch(&request.url),
HttpMethod::DELETE => self.client.delete(&request.url),
};
for (key, value) in &request.headers {
req_builder = req_builder.header(key, value);
}
if let Some(ref body) = request.body {
req_builder = req_builder
.header("Content-Type", "application/json")
.body(body.clone());
}
req_builder = req_builder.timeout(request.timeout);
if self.config.log_bodies {
debug!(
"REST {} {} body={:?}",
request.method, request.url, request.body
);
}
let response = req_builder.send().await.map_err(|e| {
error!("REST request failed: {}", e);
Error::Schema(format!("REST request failed: {}", e))
})?;
let status = response.status().as_u16();
let headers: HashMap<String, String> = response
.headers()
.iter()
.filter_map(|(k, v)| Some((k.to_string(), v.to_str().ok()?.to_string())))
.collect();
let is_gbp = headers
.get("content-type")
.map(|v| v.contains("application/x-gbp"))
.unwrap_or(false);
let body: JsonValue = if is_gbp {
let bytes = response.bytes().await.map_err(|e| {
error!("Failed to read REST response bytes: {}", e);
Error::Schema(format!("Failed to read REST response: {}", e))
})?;
let mut decoder = crate::gbp::GbpDecoder::new();
decoder.decode_lz4(&bytes).map_err(|e| {
error!("Failed to decode GBP response: {}", e);
Error::Schema(format!("Failed to decode GBP response: {}", e))
})?
} else {
let body_text = response.text().await.map_err(|e| {
error!("Failed to read REST response: {}", e);
Error::Schema(format!("Failed to read REST response: {}", e))
})?;
serde_json::from_str(&body_text).unwrap_or(JsonValue::String(body_text))
};
let duration = start.elapsed();
if self.config.log_bodies {
debug!(
"REST {} {} -> {} ({:?}) body={:?}",
request.method, request.url, status, duration, body
);
} else {
debug!(
"REST {} {} -> {} ({:?})",
request.method, request.url, status, duration
);
}
Ok(RestResponse {
status,
headers,
body,
duration,
})
}
pub async fn clear_cache(&self) {
if let Some(cache) = &self.cache {
cache.write().await.clear();
}
}
}
#[derive(Default)]
pub struct RestConnectorBuilder {
config: RestConnectorConfig,
endpoints: HashMap<String, RestEndpoint>,
transformer: Option<Arc<dyn ResponseTransformer>>,
interceptors: Vec<Arc<dyn RequestInterceptor>>,
enable_cache: bool,
cache_size: usize,
custom_client: Option<reqwest::Client>,
}
impl RestConnectorBuilder {
pub fn base_url(mut self, url: impl Into<String>) -> Self {
self.config.base_url = url.into().trim_end_matches('/').to_string();
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.config.timeout = timeout;
self
}
pub fn default_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.config.default_headers.insert(key.into(), value.into());
self
}
pub fn retry(mut self, retry: RetryConfig) -> Self {
self.config.retry = retry;
self
}
pub fn no_retry(mut self) -> Self {
self.config.retry = RetryConfig::disabled();
self
}
pub fn log_bodies(mut self, enabled: bool) -> Self {
self.config.log_bodies = enabled;
self
}
pub fn add_endpoint(mut self, endpoint: RestEndpoint) -> Self {
self.endpoints.insert(endpoint.name.clone(), endpoint);
self
}
pub fn add_endpoints(mut self, endpoints: Vec<RestEndpoint>) -> Self {
for endpoint in endpoints {
self.endpoints.insert(endpoint.name.clone(), endpoint);
}
self
}
pub fn transformer(mut self, transformer: Arc<dyn ResponseTransformer>) -> Self {
self.transformer = Some(transformer);
self
}
pub fn interceptor(mut self, interceptor: Arc<dyn RequestInterceptor>) -> Self {
self.interceptors.push(interceptor);
self
}
pub fn with_cache(mut self, max_entries: usize) -> Self {
self.enable_cache = true;
self.cache_size = max_entries;
self
}
pub fn with_client(mut self, client: reqwest::Client) -> Self {
self.custom_client = Some(client);
self
}
pub fn build(self) -> Result<RestConnector> {
if self.config.base_url.is_empty() {
return Err(Error::Schema("REST connector requires a base_url".into()));
}
let client = if let Some(client) = self.custom_client {
info!("REST connector using pre-configured HTTP client (mTLS)",);
client
} else {
reqwest::Client::builder()
.timeout(self.config.timeout)
.build()
.map_err(|e| Error::Schema(format!("Failed to create HTTP client: {}", e)))?
};
let cache = if self.enable_cache {
Some(Arc::new(RwLock::new(RestCache::new(self.cache_size))))
} else {
None
};
info!(
"REST connector configured with {} endpoints at {}",
self.endpoints.len(),
self.config.base_url
);
Ok(RestConnector {
config: self.config,
endpoints: self.endpoints,
client,
transformer: self
.transformer
.unwrap_or_else(|| Arc::new(DefaultTransformer)),
interceptors: self.interceptors,
cache,
})
}
}
/// Bounded cache of transformed GET responses.
///
/// Eviction is first-in-first-out by insertion time. Re-inserting an existing
/// key moves it to the back of the queue; `get` does not affect eviction order.
/// A cache created with `max_size == 0` stores nothing.
struct RestCache {
    entries: HashMap<String, JsonValue>,
    /// Keys from oldest to newest insertion; `VecDeque` gives O(1) eviction.
    order: std::collections::VecDeque<String>,
    max_size: usize,
}

impl RestCache {
    fn new(max_size: usize) -> Self {
        Self {
            entries: HashMap::new(),
            order: std::collections::VecDeque::new(),
            max_size,
        }
    }

    fn get(&self, key: &str) -> Option<&JsonValue> {
        self.entries.get(key)
    }

    /// Inserts or refreshes `key`, evicting the oldest entries to stay within
    /// `max_size`.
    fn insert(&mut self, key: String, value: JsonValue) {
        if self.max_size == 0 {
            // Previously a zero-capacity cache still kept one entry.
            return;
        }
        if self.entries.contains_key(&key) {
            self.order.retain(|k| k != &key);
        } else {
            while self.entries.len() >= self.max_size {
                match self.order.pop_front() {
                    Some(oldest) => {
                        self.entries.remove(&oldest);
                    }
                    None => break,
                }
            }
        }
        self.order.push_back(key.clone());
        self.entries.insert(key, value);
    }

    fn clear(&mut self) {
        self.entries.clear();
        self.order.clear();
    }
}
/// Extracts a sub-value from `value` using a small JSONPath subset.
///
/// Supported forms: `""`/`"$"` (whole value), dotted field access
/// (`$.a.b`), and array indices, including chained ones (`$.a[0][1]`,
/// `$[2]`). The `$.`/`$` prefix is optional. Empty segments (`a..b`) are
/// skipped.
///
/// Returns `None` in these cases:
/// - a field or index is missing;
/// - an index is not a non-negative integer;
/// - brackets are malformed. Previously an unparsable index was silently
///   ignored and the parent value returned.
fn extract_json_path(value: &JsonValue, path: &str) -> Option<JsonValue> {
    if path.is_empty() || path == "$" {
        return Some(value.clone());
    }
    let path = path
        .strip_prefix("$.")
        .or_else(|| path.strip_prefix('$'))
        .unwrap_or(path);

    // Walk by reference and clone only the final result.
    let mut current = value;
    for segment in path.split('.').filter(|s| !s.is_empty()) {
        let (field, mut indices) = match segment.find('[') {
            Some(pos) => (&segment[..pos], &segment[pos..]),
            None => (segment, ""),
        };
        if !field.is_empty() {
            current = current.get(field)?;
        }
        while !indices.is_empty() {
            let inner = indices.strip_prefix('[')?;
            let close = inner.find(']')?;
            let index: usize = inner[..close].parse().ok()?;
            current = current.get(index)?;
            indices = &inner[close + 1..];
        }
    }
    Some(current.clone())
}
/// Renders a JSON value as plain text for URL path/query substitution.
///
/// Strings are emitted without surrounding quotes. Numbers and booleans use
/// their `Display` form, and null becomes `null`. Arrays and objects use their
/// compact JSON text.
fn json_value_to_string(value: &JsonValue) -> String {
    match value {
        JsonValue::String(text) => text.clone(),
        JsonValue::Number(number) => number.to_string(),
        JsonValue::Bool(flag) => flag.to_string(),
        JsonValue::Null => String::from("null"),
        JsonValue::Array(_) | JsonValue::Object(_) => value.to_string(),
    }
}
/// Interceptor that sets `Authorization: Bearer <token>` on every request.
///
/// The token is behind an async lock, so it can be replaced with
/// [`set_token`](Self::set_token) through a shared reference.
pub struct BearerAuthInterceptor {
    token: Arc<RwLock<String>>,
}

impl BearerAuthInterceptor {
    /// Creates the interceptor with an initial token.
    pub fn new(token: impl Into<String>) -> Self {
        let token: String = token.into();
        Self {
            token: Arc::new(RwLock::new(token)),
        }
    }

    /// Replaces the token used for subsequent requests.
    pub async fn set_token(&self, token: impl Into<String>) {
        let replacement: String = token.into();
        let mut slot = self.token.write().await;
        *slot = replacement;
    }
}

#[async_trait]
impl RequestInterceptor for BearerAuthInterceptor {
    /// Inserts (overwriting) the `Authorization` header.
    async fn intercept(&self, request: &mut RestRequest) -> Result<()> {
        let header_value = format!("Bearer {}", self.token.read().await.as_str());
        request
            .headers
            .insert(String::from("Authorization"), header_value);
        Ok(())
    }
}
/// Interceptor that sends a static API key in a configurable header.
pub struct ApiKeyInterceptor {
    header_name: String,
    api_key: String,
}

impl ApiKeyInterceptor {
    /// Sends `api_key` in the header called `header_name`.
    pub fn new(header_name: impl Into<String>, api_key: impl Into<String>) -> Self {
        let header_name: String = header_name.into();
        let api_key: String = api_key.into();
        Self {
            header_name,
            api_key,
        }
    }

    /// Shorthand for an `X-API-Key` header.
    pub fn x_api_key(api_key: impl Into<String>) -> Self {
        Self::new("X-API-Key", api_key)
    }
}

#[async_trait]
impl RequestInterceptor for ApiKeyInterceptor {
    /// Inserts (overwriting) the configured header.
    async fn intercept(&self, request: &mut RestRequest) -> Result<()> {
        let (name, key) = (self.header_name.clone(), self.api_key.clone());
        request.headers.insert(name, key);
        Ok(())
    }
}
/// Named collection of [`RestConnector`]s.
#[derive(Default)]
pub struct RestConnectorRegistry {
    connectors: HashMap<String, Arc<RestConnector>>,
}

impl RestConnectorRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `connector` under `name`, replacing any previous one.
    pub fn register(&mut self, name: impl Into<String>, connector: RestConnector) {
        self.connectors.insert(name.into(), Arc::new(connector));
    }

    /// Returns a shared handle to the named connector.
    pub fn get(&self, name: &str) -> Option<Arc<RestConnector>> {
        self.connectors.get(name).cloned()
    }

    /// Names of all registered connectors (unordered).
    pub fn names(&self) -> Vec<&String> {
        self.connectors.keys().collect()
    }

    /// All registered connectors, keyed by name.
    pub fn connectors(&self) -> &HashMap<String, Arc<RestConnector>> {
        &self.connectors
    }

    /// True when no connector is registered.
    pub fn is_empty(&self) -> bool {
        self.connectors.is_empty()
    }

    /// Executes `endpoint_name` on the connector named `connector_name`.
    ///
    /// # Errors
    /// Fails for an unknown connector, or with whatever
    /// [`RestConnector::execute`] returns.
    pub async fn execute(
        &self,
        connector_name: &str,
        endpoint_name: &str,
        args: HashMap<String, JsonValue>,
    ) -> Result<JsonValue> {
        let connector = self
            .get(connector_name)
            .ok_or_else(|| Error::Schema(format!("Unknown REST connector: {}", connector_name)))?;
        connector.execute(endpoint_name, args).await
    }

    /// Exposes every endpoint of every connector as a GraphQL field.
    ///
    /// Returns `(query_fields, mutation_fields)`, split by
    /// [`RestEndpoint::is_mutation`]. A field's parameters are the argument
    /// names `build_request` actually substitutes, deduplicated, in this order:
    /// - path placeholders;
    /// - placeholders inside query-value templates, in query-key order;
    /// - body-template placeholders.
    ///
    /// Query *keys* are not parameters: only `{name}` placeholders are read
    /// from the call arguments.
    pub fn build_graphql_fields(&self) -> (Vec<RestGraphQLField>, Vec<RestGraphQLField>) {
        // Appends `param` unless it is already listed.
        fn push_unique(params: &mut Vec<String>, param: String) {
            if !params.contains(&param) {
                params.push(param);
            }
        }

        let mut query_fields = Vec::new();
        let mut mutation_fields = Vec::new();
        for (connector_name, connector) in &self.connectors {
            for (endpoint_name, endpoint) in connector.endpoints() {
                let description = endpoint.description.clone().unwrap_or_else(|| {
                    format!(
                        "REST {} {}{}",
                        endpoint.method,
                        connector.base_url(),
                        endpoint.path
                    )
                });

                let mut all_params = Vec::new();
                for param in extract_path_params(&endpoint.path) {
                    push_unique(&mut all_params, param);
                }
                // Sort by key so parameter order does not depend on HashMap order.
                let mut query_templates: Vec<(&String, &String)> =
                    endpoint.query_params.iter().collect();
                query_templates.sort();
                for (_, template) in query_templates {
                    for param in extract_template_params(template) {
                        push_unique(&mut all_params, param);
                    }
                }
                if let Some(ref body_template) = endpoint.body_template {
                    for param in extract_template_params(body_template) {
                        push_unique(&mut all_params, param);
                    }
                }

                let field = RestGraphQLField {
                    name: endpoint_name.clone(),
                    description,
                    parameters: all_params,
                    connector_name: connector_name.clone(),
                    endpoint_name: endpoint_name.clone(),
                    connector: connector.clone(),
                    response_schema: endpoint.response_schema.clone(),
                };
                if endpoint.is_mutation() {
                    mutation_fields.push(field);
                } else {
                    query_fields.push(field);
                }
            }
        }
        (query_fields, mutation_fields)
    }
}
/// A REST endpoint exposed as a GraphQL query or mutation field.
#[derive(Clone)]
pub struct RestGraphQLField {
    /// GraphQL field name (the endpoint name).
    pub name: String,
    pub description: String,
    /// Argument names the field accepts.
    pub parameters: Vec<String>,
    pub connector_name: String,
    pub endpoint_name: String,
    /// Connector used to execute the call.
    pub connector: Arc<RestConnector>,
    pub response_schema: Option<RestResponseSchema>,
}

impl std::fmt::Debug for RestGraphQLField {
    /// Omits `connector` and `response_schema` from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            name,
            description,
            parameters,
            connector_name,
            endpoint_name,
            ..
        } = self;
        f.debug_struct("RestGraphQLField")
            .field("name", name)
            .field("description", description)
            .field("parameters", parameters)
            .field("connector_name", connector_name)
            .field("endpoint_name", endpoint_name)
            .finish()
    }
}

impl RestGraphQLField {
    /// Executes the underlying endpoint with `args`.
    pub async fn execute(&self, args: HashMap<String, JsonValue>) -> Result<JsonValue> {
        let connector = &self.connector;
        connector.execute(&self.endpoint_name, args).await
    }
}
/// Returns the names of `{placeholder}`s in `path`, in order of appearance.
///
/// Empty `{}` pairs are ignored, and a second `{` before a closing `}` restarts
/// the name. Duplicates are kept.
fn extract_path_params(path: &str) -> Vec<String> {
    let mut found = Vec::new();
    // `Some(name)` while inside braces, `None` outside.
    let mut current: Option<String> = None;
    for ch in path.chars() {
        match ch {
            '{' => current = Some(String::new()),
            '}' => {
                if let Some(name) = current.take() {
                    if !name.is_empty() {
                        found.push(name);
                    }
                }
            }
            other => {
                if let Some(name) = current.as_mut() {
                    name.push(other);
                }
            }
        }
    }
    found
}
/// Returns the `{placeholder}` names referenced by a body or query template.
///
/// Templates use the same `{name}` syntax as endpoint paths, so this delegates
/// to `extract_path_params` (duplicates are preserved).
fn extract_template_params(template: &str) -> Vec<String> {
    extract_path_params(template)
}
/// Unit tests for path extraction, endpoint/retry builders, the response cache
/// and request rendering (`build_request`, which needs no network access).
#[cfg(test)]
mod tests {
    use super::*;
    // Dotted field access through nested objects.
    #[test]
    fn test_extract_json_path_simple() {
        let json: JsonValue = serde_json::json!({
            "data": {
                "user": {
                    "id": "123",
                    "name": "Alice"
                }
            }
        });
        assert_eq!(
            extract_json_path(&json, "$.data.user.id"),
            Some(JsonValue::String("123".into()))
        );
        assert_eq!(
            extract_json_path(&json, "$.data.user"),
            Some(serde_json::json!({"id": "123", "name": "Alice"}))
        );
    }
    // `field[index]` segments index into arrays.
    #[test]
    fn test_extract_json_path_array() {
        let json: JsonValue = serde_json::json!({
            "users": [
                {"id": "1", "name": "Alice"},
                {"id": "2", "name": "Bob"}
            ]
        });
        assert_eq!(
            extract_json_path(&json, "$.users[0].name"),
            Some(JsonValue::String("Alice".into()))
        );
        assert_eq!(
            extract_json_path(&json, "$.users[1].id"),
            Some(JsonValue::String("2".into()))
        );
    }
    // Builder setters are stored and a GET endpoint is classified as a query.
    #[test]
    fn test_endpoint_builder() {
        let endpoint = RestEndpoint::new("getUser", "/users/{id}")
            .method(HttpMethod::GET)
            .header("Accept", "application/json")
            .query_param("include", "profile")
            .response_path("$.data")
            .description("Get a user by ID");
        assert_eq!(endpoint.name, "getUser");
        assert_eq!(endpoint.path, "/users/{id}");
        assert_eq!(endpoint.method, HttpMethod::GET);
        assert!(!endpoint.is_mutation());
    }
    // Without an override, POST is classified as a mutation.
    #[test]
    fn test_post_endpoint_is_mutation() {
        let endpoint = RestEndpoint::new("createUser", "/users").method(HttpMethod::POST);
        assert!(endpoint.is_mutation());
    }
    // An explicit `as_query` overrides the method-based classification.
    #[test]
    fn test_explicit_query_override() {
        let endpoint = RestEndpoint::new("deleteUser", "/users/{id}")
            .method(HttpMethod::DELETE)
            .as_query();
        assert!(!endpoint.is_mutation());
    }
    // Inserting past capacity evicts the oldest entry.
    #[tokio::test]
    async fn test_rest_cache() {
        let mut cache = RestCache::new(2);
        cache.insert("key1".to_string(), JsonValue::String("value1".into()));
        cache.insert("key2".to_string(), JsonValue::String("value2".into()));
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_some());
        cache.insert("key3".to_string(), JsonValue::String("value3".into()));
        assert!(cache.get("key1").is_none());
        assert!(cache.get("key2").is_some());
        assert!(cache.get("key3").is_some());
    }
    #[test]
    fn test_retry_config_defaults() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.multiplier, 2.0);
        assert!(config.retry_statuses.contains(&503));
    }
    #[test]
    fn test_retry_config_disabled() {
        let config = RetryConfig::disabled();
        assert_eq!(config.max_retries, 0);
    }
    #[test]
    fn test_retry_config_aggressive() {
        let config = RetryConfig::aggressive();
        assert_eq!(config.max_retries, 5);
        assert!(config.retry_statuses.contains(&429));
    }
    #[test]
    fn test_rest_response_field_creation() {
        let f = RestResponseField::string("name");
        assert_eq!(f.name, "name");
        assert_eq!(f.field_type, RestFieldType::String);
        assert!(!f.nullable);
        let f = RestResponseField::int("age").nullable();
        assert_eq!(f.field_type, RestFieldType::Int);
        assert!(f.nullable);
    }
    #[test]
    fn test_rest_response_schema_builder() {
        let schema = RestResponseSchema::new("User")
            .field(RestResponseField::string("name"))
            .field(RestResponseField::int("age"))
            .description("A user");
        assert_eq!(schema.type_name, "User");
        assert_eq!(schema.fields.len(), 2);
        assert_eq!(schema.description, Some("A user".to_string()));
    }
    // Path placeholders are replaced by the argument value.
    #[test]
    fn test_build_request_path_substitution() {
        let connector = RestConnector::builder()
            .base_url("http://api.com")
            .build()
            .unwrap();
        let endpoint = RestEndpoint::new("test", "/users/{id}");
        let mut args = HashMap::new();
        args.insert("id".to_string(), serde_json::json!("123"));
        let req = connector
            .build_request(&endpoint, &args)
            .expect("Should build");
        assert_eq!(req.url, "http://api.com/users/123");
    }
    // Query templates are rendered and percent-encoded; literal values pass through.
    #[test]
    fn test_build_request_query_params() {
        let connector = RestConnector::builder()
            .base_url("http://api.com")
            .build()
            .unwrap();
        let endpoint = RestEndpoint::new("test", "/search")
            .query_param("q", "{query}")
            .query_param("limit", "10");
        let mut args = HashMap::new();
        args.insert("query".to_string(), serde_json::json!("hello world"));
        let req = connector
            .build_request(&endpoint, &args)
            .expect("Should build");
        assert!(req.url.contains("q=hello%20world"));
        assert!(req.url.contains("limit=10"));
    }
    // A quoted `"{name}"` placeholder is replaced by the JSON value, quotes included.
    #[test]
    fn test_build_request_body_template() {
        let connector = RestConnector::builder()
            .base_url("http://api.com")
            .build()
            .unwrap();
        let endpoint = RestEndpoint::new("create", "/users")
            .method(HttpMethod::POST)
            .body_template(r#"{"name": "{name}"}"#);
        let mut args = HashMap::new();
        args.insert("name".to_string(), serde_json::json!("Alice"));
        let req = connector
            .build_request(&endpoint, &args)
            .expect("Should build");
        assert_eq!(req.body, Some(r#"{"name": "Alice"}"#.to_string()));
    }
    // `..` in a path argument is rejected.
    #[test]
    fn test_build_request_security_path_traversal() {
        let connector = RestConnector::builder()
            .base_url("http://api.com")
            .build()
            .unwrap();
        let endpoint = RestEndpoint::new("test", "/files/{path}");
        let mut args = HashMap::new();
        args.insert("path".to_string(), serde_json::json!("../etc/passwd"));
        let err = connector.build_request(&endpoint, &args).unwrap_err();
        assert!(err.to_string().contains("Invalid characters"));
    }
    // A full URL (`://`) in a path argument is rejected.
    #[test]
    fn test_build_request_security_url_injection() {
        let connector = RestConnector::builder()
            .base_url("http://api.com")
            .build()
            .unwrap();
        let endpoint = RestEndpoint::new("test", "/redirect/{url}");
        let mut args = HashMap::new();
        args.insert("url".to_string(), serde_json::json!("http://evil.com"));
        let err = connector.build_request(&endpoint, &args).unwrap_err();
        assert!(err.to_string().contains("Invalid characters"));
    }
    // Connector default headers and endpoint headers are both present.
    #[test]
    fn test_build_request_headers() {
        let connector = RestConnector::builder()
            .base_url("http://api.com")
            .default_header("X-Global", "global")
            .build()
            .unwrap();
        let endpoint = RestEndpoint::new("test", "/test").header("X-Endpoint", "local");
        let args = HashMap::new();
        let req = connector.build_request(&endpoint, &args).unwrap();
        assert_eq!(
            req.headers.get("X-Global").map(|s| s.as_str()),
            Some("global")
        );
        assert_eq!(
            req.headers.get("X-Endpoint").map(|s| s.as_str()),
            Some("local")
        );
    }
}