use reqwest::{Client, Method, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::RwLock;
/// Errors surfaced by the REST connector and its helpers.
#[derive(Error, Debug)]
pub enum RestConnectorError {
/// The server answered with a non-success status code. `message` is taken
/// from the response body's `message` or `error` field when one is a string.
#[error("HTTP error: {status} - {message}")]
HttpError { status: u16, message: String },
/// The request could not be sent or completed at the transport level.
#[error("Request failed: {0}")]
RequestFailed(String),
/// The target URL was not usable.
#[error("Invalid URL: {0}")]
InvalidUrl(String),
/// Credentials could not be obtained, e.g. the OAuth2 token exchange failed.
#[error("Authentication failed: {0}")]
AuthenticationFailed(String),
/// The client-side rate limiter refused the request.
#[error("Rate limit exceeded")]
RateLimitExceeded,
/// The request timed out. `RestConnector::request` treats this variant as retryable.
#[error("Timeout")]
Timeout,
/// A value could not be converted to JSON before sending.
#[error("Serialization error: {0}")]
SerializationError(String),
/// The response body could not be read.
#[error("Response parsing error: {0}")]
ResponseParsingError(String),
}
/// How outgoing requests are authenticated.
///
/// Serialized as an internally tagged enum: the variant is stored in a `type`
/// field, with variant names converted to snake_case.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthConfig {
/// Send requests without credentials.
None,
/// Send `token` as a bearer token.
Bearer { token: String },
/// Send a key/value pair either as a header or as a query parameter.
ApiKey {
/// Header name or query parameter name.
key: String,
/// The secret value.
value: String,
/// `true` sends the pair as a header, `false` as a query parameter.
/// When the field is absent from serialized input it falls back to `false`,
/// whereas the `AuthConfig::api_key` constructor sets it to `true`.
#[serde(default)]
in_header: bool,
},
/// HTTP basic authentication.
Basic { username: String, password: String },
/// OAuth2 client-credentials flow against `token_url`.
// NOTE(review): with `rename_all = "snake_case"` this tag is presumably
// written as `o_auth2` rather than `oauth2` — confirm against stored configs.
OAuth2 {
client_id: String,
client_secret: String,
token_url: String,
/// Scopes requested from the token endpoint; joined with spaces when sent.
#[serde(default)]
scopes: Vec<String>,
},
/// Send an arbitrary header with a fixed value.
Custom { header: String, value: String },
}
impl AuthConfig {
    /// Creates a bearer-token configuration.
    pub fn bearer(token: impl Into<String>) -> Self {
        let token = token.into();
        Self::Bearer { token }
    }

    /// Creates an API-key configuration that sends `key: value` as a request header.
    pub fn api_key(key: impl Into<String>, value: impl Into<String>) -> Self {
        let key = key.into();
        let value = value.into();
        Self::ApiKey {
            key,
            value,
            in_header: true,
        }
    }

    /// Creates an HTTP basic-auth configuration.
    pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
        let username = username.into();
        let password = password.into();
        Self::Basic { username, password }
    }
}
/// Client-side request budget: at most `max_requests` within any `window_secs` span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum number of requests allowed inside one window.
    pub max_requests: u32,
    /// Window length in seconds.
    pub window_secs: u64,
}

impl Default for RateLimitConfig {
    /// Defaults to 100 requests per 60 seconds.
    fn default() -> Self {
        const DEFAULT_MAX_REQUESTS: u32 = 100;
        const DEFAULT_WINDOW_SECS: u64 = 60;

        RateLimitConfig {
            max_requests: DEFAULT_MAX_REQUESTS,
            window_secs: DEFAULT_WINDOW_SECS,
        }
    }
}
/// Retry policy with exponential backoff.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Number of retries after the first attempt.
    pub max_retries: u32,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound for the delay between retries, in milliseconds.
    pub max_delay_ms: u64,
    /// Factor applied to the delay after each retry.
    pub backoff_multiplier: f64,
    /// HTTP status codes that are worth retrying.
    pub retry_status_codes: Vec<u16>,
}

impl Default for RetryConfig {
    /// Three retries, starting at one second, doubling up to thirty seconds,
    /// for the status codes 408, 429, 500, 502, 503 and 504.
    fn default() -> Self {
        let retry_status_codes = [408, 429, 500, 502, 503, 504].to_vec();

        RetryConfig {
            max_retries: 3,
            initial_delay_ms: 1000,
            max_delay_ms: 30000,
            backoff_multiplier: 2.0,
            retry_status_codes,
        }
    }
}
/// Complete configuration for a `RestConnector`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RestConfig {
/// Prefix that every request path is appended to verbatim (no slash is inserted).
pub base_url: String,
/// Authentication scheme applied to every request.
pub auth: AuthConfig,
/// Headers added to every request.
#[serde(default)]
pub default_headers: HashMap<String, String>,
/// Timeout handed to the HTTP client, in seconds (30 when omitted).
#[serde(default = "default_timeout")]
pub timeout_secs: u64,
/// Client-side rate limit; `None` disables rate limiting.
pub rate_limit: Option<RateLimitConfig>,
/// Retry policy. When `None`, `RestConnector::request` falls back to
/// `RetryConfig::default()` rather than disabling retries.
pub retry: Option<RetryConfig>,
/// Enables the in-memory cache, which is consulted for `GET` requests only.
#[serde(default)]
pub enable_cache: bool,
/// Lifetime of a cache entry, in seconds (300 when omitted).
#[serde(default = "default_cache_ttl")]
pub cache_ttl_secs: u64,
}
/// Timeout, in seconds, used when a configuration does not specify one.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
/// Cache entry lifetime, in seconds, used when a configuration does not specify one.
fn default_cache_ttl() -> u64 {
    const DEFAULT_CACHE_TTL_SECS: u64 = 300;
    DEFAULT_CACHE_TTL_SECS
}
impl RestConfig {
pub fn new(base_url: impl Into<String>) -> Self {
Self {
base_url: base_url.into(),
auth: AuthConfig::None,
default_headers: HashMap::new(),
timeout_secs: default_timeout(),
rate_limit: None,
retry: None,
enable_cache: false,
cache_ttl_secs: default_cache_ttl(),
}
}
pub fn with_auth(mut self, auth: AuthConfig) -> Self {
self.auth = auth;
self
}
pub fn with_timeout_secs(mut self, secs: u64) -> Self {
self.timeout_secs = secs;
self
}
pub fn with_retry(mut self, max_retries: u32) -> Self {
self.retry = Some(RetryConfig {
max_retries,
..Default::default()
});
self
}
pub fn with_rate_limit(mut self, max_requests: u32, window_secs: u64) -> Self {
self.rate_limit = Some(RateLimitConfig {
max_requests,
window_secs,
});
self
}
pub fn with_cache(mut self, ttl_secs: u64) -> Self {
self.enable_cache = true;
self.cache_ttl_secs = ttl_secs;
self
}
pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.default_headers.insert(key.into(), value.into());
self
}
}
/// One cached response body together with its expiry time.
struct CacheEntry {
/// The JSON body of the cached response (status and headers are not kept).
response: Value,
/// Instant after which the entry is no longer served.
expires_at: Instant,
}
/// Sliding-window bookkeeping for the client-side rate limiter.
struct RateLimiterState {
    /// Timestamps of the requests recorded so far; old ones are pruned on each check.
    request_times: Vec<Instant>,
}

impl RateLimiterState {
    /// Starts with an empty request history.
    fn new() -> Self {
        let request_times = Vec::new();
        Self { request_times }
    }

    /// Drops timestamps that have left the window and reports whether another
    /// request still fits into the budget described by `config`.
    fn can_make_request(&mut self, config: &RateLimitConfig) -> bool {
        let window = Duration::from_secs(config.window_secs);
        let checked_at = Instant::now();
        self.request_times
            .retain(|sent_at| checked_at.duration_since(*sent_at) < window);

        let used = self.request_times.len();
        let budget = config.max_requests as usize;
        used < budget
    }

    /// Records that a request is being made right now.
    fn record_request(&mut self) {
        let sent_at = Instant::now();
        self.request_times.push(sent_at);
    }
}
/// Result of a request made through the connector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RestResponse {
/// HTTP status code (always 200 for responses served from the cache).
pub status: u16,
/// Response headers; empty for responses served from the cache.
pub headers: HashMap<String, String>,
/// Parsed JSON body. Empty bodies become `null`; bodies that are not valid
/// JSON are wrapped as `{"raw": "<text>"}`.
pub body: Value,
/// Time spent on the HTTP exchange in milliseconds (0 for cache hits).
pub response_time_ms: u64,
/// `true` when the body came from the in-memory cache.
pub from_cache: bool,
}
impl RestResponse {
    /// Returns `true` for 2xx status codes.
    pub fn is_success(&self) -> bool {
        matches!(self.status, 200..=299)
    }

    /// Looks up a nested value in the body using a dot-separated path such as
    /// `"user.name"`. Every segment is used as an object key; the lookup yields
    /// `None` as soon as one segment is missing.
    pub fn get(&self, path: &str) -> Option<&Value> {
        path.split('.')
            .try_fold(&self.body, |node, segment| node.get(segment))
    }
}
/// HTTP client wrapper that layers authentication, caching, rate limiting and
/// retries on top of `reqwest`.
pub struct RestConnector {
/// Settings the connector was created with.
config: RestConfig,
/// Underlying HTTP client, also used for the OAuth2 token exchange.
client: Client,
/// Cached `GET` response bodies keyed by base URL, path and sorted query parameters.
cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
/// Request timestamps used for client-side rate limiting.
rate_limiter: Arc<RwLock<RateLimiterState>>,
/// Cached OAuth2 access token and the instant at which it stops being reused.
oauth_token: Arc<RwLock<Option<(String, Instant)>>>,
}
impl RestConnector {
/// Builds a connector from `config`.
///
/// The HTTP client is created with the configured timeout. If the builder
/// fails, a default client (without that timeout) is used instead.
pub fn new(config: RestConfig) -> Self {
    let timeout = Duration::from_secs(config.timeout_secs);
    let client = match Client::builder().timeout(timeout).build() {
        Ok(built) => built,
        Err(_) => Client::default(),
    };

    let cache = Arc::new(RwLock::new(HashMap::new()));
    let rate_limiter = Arc::new(RwLock::new(RateLimiterState::new()));
    let oauth_token = Arc::new(RwLock::new(None));

    Self {
        config,
        client,
        cache,
        rate_limiter,
        oauth_token,
    }
}
pub async fn get(&self, path: &str) -> Result<RestResponse, RestConnectorError> {
self.request(Method::GET, path, None, None).await
}
/// Sends a `GET` request to `path` with `params` as the query string.
pub async fn get_with_params(
    &self,
    path: &str,
    params: HashMap<String, String>,
) -> Result<RestResponse, RestConnectorError> {
    let query = Some(params);
    self.request(Method::GET, path, None, query).await
}
/// Sends a `POST` request to `path` with `body` as the JSON payload.
pub async fn post(&self, path: &str, body: Value) -> Result<RestResponse, RestConnectorError> {
    let payload = Some(body);
    self.request(Method::POST, path, payload, None).await
}
/// Sends a `PUT` request to `path` with `body` as the JSON payload.
pub async fn put(&self, path: &str, body: Value) -> Result<RestResponse, RestConnectorError> {
    let payload = Some(body);
    self.request(Method::PUT, path, payload, None).await
}
/// Sends a `PATCH` request to `path` with `body` as the JSON payload.
pub async fn patch(&self, path: &str, body: Value) -> Result<RestResponse, RestConnectorError> {
    let payload = Some(body);
    self.request(Method::PATCH, path, payload, None).await
}
pub async fn delete(&self, path: &str) -> Result<RestResponse, RestConnectorError> {
self.request(Method::DELETE, path, None, None).await
}
/// Runs a request through the whole pipeline.
///
/// 1. For `GET` with caching enabled, an unexpired cache entry is returned
///    immediately (status 200, no headers, `from_cache = true`).
/// 2. The rate limiter is consulted once per call (not once per retry);
///    an exhausted budget yields [`RestConnectorError::RateLimitExceeded`].
/// 3. The request is attempted up to `max_retries + 1` times. Only
///    `HttpError`s whose status is listed in `retry_status_codes` and
///    `Timeout` errors are retried; the delay grows by `backoff_multiplier`
///    and is capped at `max_delay_ms`.
/// 4. Successful `GET` responses are written to the cache when it is enabled.
///
/// When no retry policy is configured, `RetryConfig::default()` is used.
///
/// # Errors
/// Returns the error of the last attempt, or the first non-retryable error.
pub async fn request(
    &self,
    method: Method,
    path: &str,
    body: Option<Value>,
    query_params: Option<HashMap<String, String>>,
) -> Result<RestResponse, RestConnectorError> {
    let cacheable = method == Method::GET && self.config.enable_cache;

    if cacheable {
        let lookup_key = self.cache_key(path, &query_params);
        let store = self.cache.read().await;
        let fresh = store
            .get(&lookup_key)
            .filter(|entry| entry.expires_at > Instant::now());
        if let Some(entry) = fresh {
            return Ok(RestResponse {
                status: 200,
                headers: HashMap::new(),
                body: entry.response.clone(),
                response_time_ms: 0,
                from_cache: true,
            });
        }
    }

    if let Some(limits) = self.config.rate_limit.as_ref() {
        let mut limiter = self.rate_limiter.write().await;
        if limiter.can_make_request(limits) {
            limiter.record_request();
        } else {
            return Err(RestConnectorError::RateLimitExceeded);
        }
    }

    let policy = self.config.retry.clone().unwrap_or_default();
    let mut backoff_ms = policy.initial_delay_ms;
    let mut attempt: u32 = 0;

    loop {
        if attempt > 0 {
            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
            let grown = (backoff_ms as f64 * policy.backoff_multiplier) as u64;
            backoff_ms = grown.min(policy.max_delay_ms);
        }

        let outcome = self
            .execute_request(&method, path, body.clone(), query_params.clone())
            .await;

        let failure = match outcome {
            Ok(response) => {
                if cacheable && response.is_success() {
                    let store_key = self.cache_key(path, &query_params);
                    let expires_at =
                        Instant::now() + Duration::from_secs(self.config.cache_ttl_secs);
                    let entry = CacheEntry {
                        response: response.body.clone(),
                        expires_at,
                    };
                    self.cache.write().await.insert(store_key, entry);
                }
                return Ok(response);
            }
            Err(failure) => failure,
        };

        let retryable = match &failure {
            RestConnectorError::HttpError { status, .. } => {
                policy.retry_status_codes.contains(status)
            }
            RestConnectorError::Timeout => true,
            _ => false,
        };

        // The final attempt always returns here, so the loop cannot run past
        // `max_retries` and the counter below cannot overflow.
        if !retryable || attempt == policy.max_retries {
            return Err(failure);
        }
        attempt += 1;
    }
}
/// Performs exactly one HTTP exchange (no caching, rate limiting or retries).
///
/// The URL is `base_url` followed by `path` verbatim. Default headers and
/// authentication are applied, then the optional query parameters and JSON body.
///
/// The response body is parsed as JSON; an empty body becomes `null` and a
/// body that is not valid JSON is wrapped as `{"raw": "<text>"}`.
///
/// # Errors
/// * [`RestConnectorError::Timeout`] when sending the request or reading the
///   body times out, so that the retry logic can recognise it.
/// * [`RestConnectorError::RequestFailed`] for any other transport failure.
/// * [`RestConnectorError::ResponseParsingError`] when the body cannot be read.
/// * [`RestConnectorError::HttpError`] for non-2xx responses; the message is
///   the body's `message` or `error` string, or `"Unknown error"`.
/// * Any error produced while applying authentication.
async fn execute_request(
    &self,
    method: &Method,
    path: &str,
    body: Option<Value>,
    query_params: Option<HashMap<String, String>>,
) -> Result<RestResponse, RestConnectorError> {
    let url = format!("{}{}", self.config.base_url, path);
    let start = Instant::now();

    let mut request = self.client.request(method.clone(), &url);
    for (key, value) in &self.config.default_headers {
        request = request.header(key.as_str(), value.as_str());
    }
    request = self.apply_auth(request).await?;
    if let Some(params) = query_params {
        request = request.query(&params);
    }
    if let Some(body) = body {
        request = request
            .header("Content-Type", "application/json")
            .json(&body);
    }

    let response = request.send().await.map_err(|e| {
        // Surface timeouts as their own variant; otherwise they would be
        // indistinguishable from other transport failures and never retried.
        if e.is_timeout() {
            RestConnectorError::Timeout
        } else {
            RestConnectorError::RequestFailed(e.to_string())
        }
    })?;

    let status = response.status();
    let headers: HashMap<String, String> = response
        .headers()
        .iter()
        // Header values that are not valid visible ASCII are reported as "".
        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
        .collect();

    let body_text = response.text().await.map_err(|e| {
        if e.is_timeout() {
            RestConnectorError::Timeout
        } else {
            RestConnectorError::ResponseParsingError(e.to_string())
        }
    })?;
    let body: Value = if body_text.is_empty() {
        Value::Null
    } else {
        serde_json::from_str(&body_text).unwrap_or_else(|_| json!({"raw": body_text}))
    };
    let response_time_ms = start.elapsed().as_millis() as u64;

    if !status.is_success() {
        let message = body
            .get("message")
            .or_else(|| body.get("error"))
            .and_then(|v| v.as_str())
            .unwrap_or("Unknown error")
            .to_string();
        return Err(RestConnectorError::HttpError {
            status: status.as_u16(),
            message,
        });
    }

    Ok(RestResponse {
        status: status.as_u16(),
        headers,
        body,
        response_time_ms,
        from_cache: false,
    })
}
/// Attaches the configured credentials to `request`.
///
/// For OAuth2 the cached access token is reused while it is still valid;
/// otherwise a new token is fetched from `token_url` with the
/// client-credentials grant and cached. The cached token is considered
/// expired 60 seconds before the server-reported `expires_in` (3600 seconds
/// is assumed when the field is missing).
///
/// # Errors
/// Returns [`RestConnectorError::AuthenticationFailed`] when the token
/// request cannot be sent, its response is not JSON, or it carries no
/// `access_token` string.
async fn apply_auth(
    &self,
    request: reqwest::RequestBuilder,
) -> Result<reqwest::RequestBuilder, RestConnectorError> {
    match &self.config.auth {
        AuthConfig::None => Ok(request),
        AuthConfig::Bearer { token } => Ok(request.bearer_auth(token)),
        AuthConfig::ApiKey {
            key,
            value,
            in_header,
        } => {
            if *in_header {
                Ok(request.header(key.as_str(), value.as_str()))
            } else {
                Ok(request.query(&[(key.as_str(), value.as_str())]))
            }
        }
        AuthConfig::Basic { username, password } => {
            Ok(request.basic_auth(username, Some(password)))
        }
        AuthConfig::OAuth2 {
            client_id,
            client_secret,
            token_url,
            scopes,
        } => {
            // Fast path: reuse the cached token. The read guard is scoped so
            // it is released before any network call.
            {
                let token_guard = self.oauth_token.read().await;
                if let Some((token, expires)) = token_guard.as_ref() {
                    if *expires > Instant::now() {
                        return Ok(request.bearer_auth(token));
                    }
                }
            }

            let scope = scopes.join(" ");
            let token_response = self
                .client
                .post(token_url)
                .form(&[
                    ("grant_type", "client_credentials"),
                    ("client_id", client_id.as_str()),
                    ("client_secret", client_secret.as_str()),
                    ("scope", scope.as_str()),
                ])
                .send()
                .await
                .map_err(|e| {
                    RestConnectorError::AuthenticationFailed(format!(
                        "Failed to get OAuth2 token: {}",
                        e
                    ))
                })?;
            let token_data: Value = token_response.json().await.map_err(|e| {
                RestConnectorError::AuthenticationFailed(format!(
                    "Failed to parse token response: {}",
                    e
                ))
            })?;
            let access_token = token_data
                .get("access_token")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    RestConnectorError::AuthenticationFailed(
                        "No access_token in response".into(),
                    )
                })?
                .to_string();
            let expires_in = token_data
                .get("expires_in")
                .and_then(|v| v.as_u64())
                .unwrap_or(3600);

            // Refresh one minute early. `saturating_sub` keeps short-lived
            // tokens (expires_in < 60) from underflowing the subtraction;
            // such tokens are simply treated as already expired next time.
            let reusable_for = Duration::from_secs(expires_in.saturating_sub(60));
            {
                let mut token_guard = self.oauth_token.write().await;
                *token_guard = Some((access_token.clone(), Instant::now() + reusable_for));
            }
            Ok(request.bearer_auth(access_token))
        }
        AuthConfig::Custom { header, value } => {
            Ok(request.header(header.as_str(), value.as_str()))
        }
    }
}
/// Builds the cache key for a request: `"<base_url>:<path>"` followed by
/// `&name=value` for every query parameter, ordered by parameter name so the
/// key does not depend on map iteration order.
fn cache_key(&self, path: &str, params: &Option<HashMap<String, String>>) -> String {
    let mut pairs: Vec<(&String, &String)> = match params {
        Some(map) => map.iter().collect(),
        None => Vec::new(),
    };
    pairs.sort_by(|left, right| left.0.cmp(right.0));

    let mut key = format!("{}:{}", self.config.base_url, path);
    for (name, value) in pairs {
        key.push('&');
        key.push_str(name);
        key.push('=');
        key.push_str(value);
    }
    key
}
/// Removes every entry from the response cache.
pub async fn clear_cache(&self) {
    self.cache.write().await.clear();
}
/// Returns `(total, valid)`: the number of cache entries and how many of
/// them have not expired yet.
pub async fn cache_stats(&self) -> (usize, usize) {
    let entries = self.cache.read().await;
    let mut live = 0usize;
    for entry in entries.values() {
        if entry.expires_at > Instant::now() {
            live += 1;
        }
    }
    (entries.len(), live)
}
}
/// A reusable request description whose path, body and query values may
/// contain `{{name}}` placeholders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestTemplate {
/// HTTP method name, e.g. `"GET"`.
pub method: String,
/// Request path, possibly containing placeholders.
pub path: String,
/// Extra headers for this request. Not touched by `render`.
#[serde(default)]
pub headers: HashMap<String, String>,
/// Optional JSON body; string values may contain placeholders.
pub body: Option<Value>,
/// Query parameters; values may contain placeholders.
#[serde(default)]
pub query_params: HashMap<String, String>,
}
impl RequestTemplate {
    /// Creates a `GET` template for `path` without headers, body or query parameters.
    pub fn get(path: impl Into<String>) -> Self {
        let path = path.into();
        Self {
            method: String::from("GET"),
            path,
            headers: HashMap::new(),
            body: None,
            query_params: HashMap::new(),
        }
    }

    /// Creates a `POST` template for `path` carrying `body`.
    pub fn post(path: impl Into<String>, body: Value) -> Self {
        let path = path.into();
        Self {
            method: String::from("POST"),
            path,
            headers: HashMap::new(),
            body: Some(body),
            query_params: HashMap::new(),
        }
    }

    /// Adds (or overwrites) a header.
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut headers = self.headers;
        headers.insert(key.into(), value.into());
        Self { headers, ..self }
    }

    /// Adds (or overwrites) a query parameter.
    pub fn with_query(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut query_params = self.query_params;
        query_params.insert(key.into(), value.into());
        Self {
            query_params,
            ..self
        }
    }

    /// Fills in the placeholders and returns `(path, body, query_params)`.
    ///
    /// Query parameter names are copied unchanged; only their values are
    /// substituted. Headers and the method are not part of the result.
    pub fn render(
        &self,
        variables: &HashMap<String, Value>,
    ) -> (String, Option<Value>, HashMap<String, String>) {
        let rendered_path = substitute_variables(&self.path, variables);

        let rendered_body = match &self.body {
            Some(template) => Some(substitute_json_variables(template, variables)),
            None => None,
        };

        let mut rendered_query = HashMap::with_capacity(self.query_params.len());
        for (name, raw) in &self.query_params {
            rendered_query.insert(name.clone(), substitute_variables(raw, variables));
        }

        (rendered_path, rendered_body, rendered_query)
    }
}
/// Replaces every `{{name}}` placeholder in `template` with the matching
/// entry of `variables`.
///
/// String values are inserted as-is; every other JSON value is inserted in
/// its JSON text form. Placeholders without a matching variable are left
/// untouched. Names are matched exactly (no trimming of whitespace).
///
/// The template is scanned once from left to right and inserted values are
/// never re-scanned, so the result does not depend on the iteration order of
/// `variables` and a value that itself contains `{{other}}` is emitted
/// literally.
fn substitute_variables(template: &str, variables: &HashMap<String, Value>) -> String {
    let mut result = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(open) = rest.find("{{") {
        result.push_str(&rest[..open]);
        let after_open = &rest[open + 2..];

        let close = match after_open.find("}}") {
            Some(close) => close,
            None => {
                // No closing marker anywhere ahead: nothing left to replace.
                rest = &rest[open..];
                break;
            }
        };

        match variables.get(&after_open[..close]) {
            Some(Value::String(text)) => {
                result.push_str(text);
                rest = &after_open[close + 2..];
            }
            Some(other) => {
                result.push_str(&other.to_string());
                rest = &after_open[close + 2..];
            }
            None => {
                // Not a known placeholder: keep one brace and retry from the
                // next character so `{{{name}}}` still resolves the inner pair.
                result.push('{');
                rest = &rest[open + 1..];
            }
        }
    }

    result.push_str(rest);
    result
}
/// Applies placeholder substitution to every string inside a JSON value.
///
/// Arrays and objects are walked recursively (object keys are copied
/// unchanged); numbers, booleans and `null` are cloned as they are.
///
/// A string that starts with `{{` and ends with `}}` is, after substitution,
/// re-parsed as JSON; if that succeeds the parsed value replaces the string,
/// so `"{{age}}"` with `age = 30` becomes the number `30`.
// NOTE(review): the re-parse also turns string variables such as "123" or
// "true" into a number / boolean — confirm this is intended.
fn substitute_json_variables(value: &Value, variables: &HashMap<String, Value>) -> Value {
    match value {
        Value::String(raw) => {
            let rendered = substitute_variables(raw, variables);
            let looks_like_placeholder = raw.starts_with("{{") && raw.ends_with("}}");
            let reparsed: Option<Value> = if looks_like_placeholder {
                serde_json::from_str(&rendered).ok()
            } else {
                None
            };
            match reparsed {
                Some(typed) => typed,
                None => Value::String(rendered),
            }
        }
        Value::Array(items) => {
            let mut rendered = Vec::with_capacity(items.len());
            for item in items {
                rendered.push(substitute_json_variables(item, variables));
            }
            Value::Array(rendered)
        }
        Value::Object(fields) => {
            let mut rendered = serde_json::Map::with_capacity(fields.len());
            for (name, field) in fields {
                rendered.insert(name.clone(), substitute_json_variables(field, variables));
            }
            Value::Object(rendered)
        }
        scalar => scalar.clone(),
    }
}
/// Body of a GraphQL request as sent over HTTP.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphQLQuery {
    /// The GraphQL document (query or mutation text).
    pub query: String,
    /// Name of the operation to run when the document contains several.
    ///
    /// Serialized as `operationName`, the key GraphQL servers look for;
    /// the former `operation_name` spelling is still accepted when reading.
    #[serde(
        rename = "operationName",
        alias = "operation_name",
        skip_serializing_if = "Option::is_none"
    )]
    pub operation_name: Option<String>,
    /// Values for the variables declared by the operation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub variables: Option<HashMap<String, Value>>,
}
impl GraphQLQuery {
    /// Creates a query from a GraphQL document with no operation name or variables.
    pub fn new(query: impl Into<String>) -> Self {
        let query = query.into();
        Self {
            query,
            operation_name: None,
            variables: None,
        }
    }

    /// Sets the operation name.
    pub fn with_operation(self, name: impl Into<String>) -> Self {
        Self {
            operation_name: Some(name.into()),
            ..self
        }
    }

    /// Replaces all variables with `variables`.
    pub fn with_variables(self, variables: HashMap<String, Value>) -> Self {
        Self {
            variables: Some(variables),
            ..self
        }
    }

    /// Adds (or overwrites) a single variable, creating the map on first use.
    pub fn with_variable(self, key: impl Into<String>, value: Value) -> Self {
        let mut variables = self.variables.unwrap_or_default();
        variables.insert(key.into(), value);
        Self {
            variables: Some(variables),
            ..self
        }
    }
}
impl RestConnector {
    /// Sends `query` as a JSON `POST` to `endpoint` (a path relative to the base URL).
    ///
    /// The HTTP response is returned as-is; the body is not inspected for a
    /// GraphQL `errors` field.
    ///
    /// # Errors
    /// Returns [`RestConnectorError::SerializationError`] when the query cannot
    /// be converted to JSON, or any error produced by the `POST` itself.
    pub async fn graphql(
        &self,
        endpoint: impl AsRef<str>,
        query: GraphQLQuery,
    ) -> Result<RestResponse, RestConnectorError> {
        let payload = match serde_json::to_value(&query) {
            Ok(payload) => payload,
            Err(err) => return Err(RestConnectorError::SerializationError(err.to_string())),
        };
        let path = endpoint.as_ref();
        self.post(path, payload).await
    }

    /// Sends a mutation. Identical to [`RestConnector::graphql`]; provided for readability.
    pub async fn graphql_mutation(
        &self,
        endpoint: impl AsRef<str>,
        mutation: GraphQLQuery,
    ) -> Result<RestResponse, RestConnectorError> {
        let request = mutation;
        self.graphql(endpoint, request).await
    }
}
/// State of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Requests are allowed; failures are being counted.
Closed,
/// Requests are rejected until the timeout has elapsed.
Open,
/// Requests are allowed again; successes are counted towards closing the
/// circuit and any failure re-opens it.
HalfOpen,
}
/// Thresholds and timings that drive a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures needed to open a closed circuit.
    pub failure_threshold: u32,
    /// Successes needed in the half-open state to close the circuit again.
    pub success_threshold: u32,
    /// Seconds an open circuit waits before moving to half-open.
    pub timeout_secs: u64,
    /// Seconds of quiet after the last failure before the failure count restarts.
    pub window_secs: u64,
}

impl Default for CircuitBreakerConfig {
    /// Five failures to open, two successes to close, 60-second timeout and window.
    fn default() -> Self {
        const ONE_MINUTE_SECS: u64 = 60;

        CircuitBreakerConfig {
            failure_threshold: 5,
            success_threshold: 2,
            timeout_secs: ONE_MINUTE_SECS,
            window_secs: ONE_MINUTE_SECS,
        }
    }
}
/// Circuit breaker that callers drive explicitly: ask `is_request_allowed`
/// before a call and report the outcome with `record_success` / `record_failure`.
#[derive(Debug)]
pub struct CircuitBreaker {
/// Thresholds and timings.
config: CircuitBreakerConfig,
/// Mutable state shared behind an async read/write lock.
state: Arc<RwLock<CircuitBreakerState>>,
}
/// Counters and timestamps behind a `CircuitBreaker`.
#[derive(Debug)]
struct CircuitBreakerState {
/// Current position in the closed / open / half-open cycle.
current_state: CircuitState,
/// Failures counted since the count was last reset.
failure_count: u32,
/// Successes counted while half-open.
success_count: u32,
/// When the most recent failure was recorded, if any.
last_failure_time: Option<Instant>,
/// When `current_state` last changed (or the breaker was created / reset).
last_state_change: Instant,
}
impl CircuitBreaker {
    /// Creates a closed breaker with zeroed counters.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        let initial = CircuitBreakerState {
            current_state: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
            last_state_change: Instant::now(),
        };
        Self {
            config,
            state: Arc::new(RwLock::new(initial)),
        }
    }

    /// Reports whether a request may be sent now.
    ///
    /// Closed and half-open circuits always allow requests. An open circuit
    /// rejects them until `timeout_secs` have passed since it opened; the
    /// first check after that moves it to half-open and allows the request.
    pub async fn is_request_allowed(&self) -> bool {
        let mut guard = self.state.write().await;
        if guard.current_state != CircuitState::Open {
            return true;
        }

        let waited_secs = guard.last_state_change.elapsed().as_secs();
        let cooled_down = waited_secs >= self.config.timeout_secs;
        if cooled_down {
            guard.current_state = CircuitState::HalfOpen;
            guard.success_count = 0;
            guard.last_state_change = Instant::now();
        }
        cooled_down
    }

    /// Records a successful call.
    ///
    /// While closed this clears the failure count. While half-open it counts
    /// towards `success_threshold`, closing the circuit once reached. It has
    /// no effect on an open circuit.
    pub async fn record_success(&self) {
        let mut guard = self.state.write().await;

        if guard.current_state == CircuitState::Closed {
            guard.failure_count = 0;
            return;
        }
        if guard.current_state != CircuitState::HalfOpen {
            return;
        }

        guard.success_count += 1;
        if guard.success_count < self.config.success_threshold {
            return;
        }
        guard.current_state = CircuitState::Closed;
        guard.failure_count = 0;
        guard.success_count = 0;
        guard.last_state_change = Instant::now();
    }

    /// Records a failed call.
    ///
    /// The failure count restarts when more than `window_secs` have passed
    /// since the previous failure. A closed circuit opens once the count
    /// reaches `failure_threshold`; a half-open circuit re-opens immediately.
    pub async fn record_failure(&self) {
        let mut guard = self.state.write().await;
        let now = Instant::now();

        let window_lapsed = match guard.last_failure_time {
            Some(previous) => now.duration_since(previous).as_secs() > self.config.window_secs,
            None => false,
        };
        if window_lapsed {
            guard.failure_count = 0;
        }
        guard.last_failure_time = Some(now);
        guard.failure_count += 1;

        let was_half_open = guard.current_state == CircuitState::HalfOpen;
        let should_open = match guard.current_state {
            CircuitState::Closed => guard.failure_count >= self.config.failure_threshold,
            CircuitState::HalfOpen => true,
            CircuitState::Open => false,
        };
        if should_open {
            if was_half_open {
                guard.success_count = 0;
            }
            guard.current_state = CircuitState::Open;
            guard.last_state_change = now;
        }
    }

    /// Returns the current state without changing it.
    pub async fn get_state(&self) -> CircuitState {
        let guard = self.state.read().await;
        guard.current_state
    }

    /// Forces the breaker back to a freshly created, closed state.
    pub async fn reset(&self) {
        let mut guard = self.state.write().await;
        *guard = CircuitBreakerState {
            current_state: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
            last_state_change: Instant::now(),
        };
    }
}
/// Hook that can inspect or modify a request before it is sent.
///
/// Implementations receive the method and URL read-only and may mutate the
/// headers and the optional JSON body. Returning an error signals failure to
/// the caller.
// NOTE(review): nothing in this file invokes interceptors yet, hence the
// `dead_code` allowance — confirm where they are meant to be wired in.
#[allow(dead_code)]
pub trait RequestInterceptor: Send + Sync {
fn intercept(
&self,
method: &Method,
url: &str,
headers: &mut HashMap<String, String>,
body: &mut Option<Value>,
) -> Result<(), RestConnectorError>;
}
/// Hook that can inspect or modify a response after it has been received.
/// Returning an error signals failure to the caller.
#[allow(dead_code)]
pub trait ResponseInterceptor: Send + Sync {
fn intercept(&self, response: &mut RestResponse) -> Result<(), RestConnectorError>;
}
/// Interceptor that writes requests and/or responses to standard error.
#[derive(Debug, Clone)]
pub struct LoggingInterceptor {
    /// Log outgoing requests (method, URL and body).
    pub log_request: bool,
    /// Log incoming responses (status and duration).
    pub log_response: bool,
}

impl LoggingInterceptor {
    /// Creates an interceptor that logs both requests and responses.
    pub fn new() -> Self {
        LoggingInterceptor {
            log_request: true,
            log_response: true,
        }
    }
}

impl Default for LoggingInterceptor {
    /// Same as [`LoggingInterceptor::new`].
    fn default() -> Self {
        LoggingInterceptor::new()
    }
}
impl RequestInterceptor for LoggingInterceptor {
    /// Prints the method, URL and (if present) the JSON body to standard
    /// error when request logging is enabled. Never modifies the request and
    /// never fails.
    // NOTE(review): the full body is printed, which may expose secrets in logs.
    fn intercept(
        &self,
        method: &Method,
        url: &str,
        _headers: &mut HashMap<String, String>,
        body: &mut Option<Value>,
    ) -> Result<(), RestConnectorError> {
        if !self.log_request {
            return Ok(());
        }

        eprintln!("[REST] {} {}", method, url);
        if let Some(payload) = body.as_ref() {
            eprintln!("[REST] Body: {}", payload);
        }
        Ok(())
    }
}
impl ResponseInterceptor for LoggingInterceptor {
    /// Prints the status code and response time to standard error when
    /// response logging is enabled. Never modifies the response and never fails.
    fn intercept(&self, response: &mut RestResponse) -> Result<(), RestConnectorError> {
        if !self.log_response {
            return Ok(());
        }

        let status = response.status;
        let elapsed_ms = response.response_time_ms;
        eprintln!("[REST] Response: {} ({}ms)", status, elapsed_ms);
        Ok(())
    }
}
/// Interceptor that adds a fixed set of headers to every request.
#[derive(Debug, Clone)]
pub struct HeaderInjectionInterceptor {
    /// Headers to add; they overwrite request headers with the same name.
    headers: HashMap<String, String>,
}

impl HeaderInjectionInterceptor {
    /// Creates an interceptor that injects `headers`.
    pub fn new(headers: HashMap<String, String>) -> Self {
        HeaderInjectionInterceptor { headers }
    }
}
impl RequestInterceptor for HeaderInjectionInterceptor {
    /// Copies every configured header into the request's header map,
    /// replacing entries that already use the same name. Never fails.
    fn intercept(
        &self,
        _method: &Method,
        _url: &str,
        headers: &mut HashMap<String, String>,
        _body: &mut Option<Value>,
    ) -> Result<(), RestConnectorError> {
        for (name, value) in &self.headers {
            headers.insert(name.clone(), value.clone());
        }
        Ok(())
    }
}
// Unit tests for the pieces that work without a network: configuration
// builders, templating, rate limiting, response helpers, cache keys, GraphQL
// query building, the circuit breaker and the interceptors.
#[cfg(test)]
mod tests {
use super::*;
// The builder methods store what they are given and switch the features on.
#[test]
fn test_rest_config_builder() {
let config = RestConfig::new("https://api.example.com")
.with_auth(AuthConfig::bearer("token123"))
.with_timeout_secs(60)
.with_retry(3)
.with_rate_limit(100, 60)
.with_cache(300);
assert_eq!(config.base_url, "https://api.example.com");
assert_eq!(config.timeout_secs, 60);
assert!(config.retry.is_some());
assert!(config.rate_limit.is_some());
assert!(config.enable_cache);
}
// Each convenience constructor produces the matching variant.
#[test]
fn test_auth_config_variants() {
let bearer = AuthConfig::bearer("token");
assert!(matches!(bearer, AuthConfig::Bearer { .. }));
let api_key = AuthConfig::api_key("X-API-Key", "secret");
assert!(matches!(api_key, AuthConfig::ApiKey { .. }));
let basic = AuthConfig::basic("user", "pass");
assert!(matches!(basic, AuthConfig::Basic { .. }));
}
// Rendering a template substitutes placeholders in the path.
#[test]
fn test_request_template() {
let template = RequestTemplate::get("/users/{{user_id}}").with_query("limit", "10");
let mut vars = HashMap::new();
vars.insert("user_id".to_string(), json!("123"));
let (path, _, _) = template.render(&vars);
assert_eq!(path, "/users/123");
}
// Strings are inserted raw, numbers in their textual form.
#[test]
fn test_variable_substitution() {
let mut vars = HashMap::new();
vars.insert("name".to_string(), json!("John"));
vars.insert("age".to_string(), json!(30));
let template = "Hello {{name}}, you are {{age}} years old";
let result = substitute_variables(template, &vars);
assert_eq!(result, "Hello John, you are 30 years old");
}
// Substitution reaches strings nested inside objects and arrays.
#[test]
fn test_json_variable_substitution() {
let mut vars = HashMap::new();
vars.insert("name".to_string(), json!("John"));
vars.insert("age".to_string(), json!(30));
let body = json!({
"name": "{{name}}",
"age": "{{age}}",
"items": ["{{name}}", "test"]
});
let result = substitute_json_variables(&body, &vars);
assert_eq!(result["name"], "John");
assert_eq!(result["items"][0], "John");
}
// With a budget of two, the third request inside the window is refused.
#[test]
fn test_rate_limiter() {
let config = RateLimitConfig {
max_requests: 2,
window_secs: 60,
};
let mut limiter = RateLimiterState::new();
assert!(limiter.can_make_request(&config));
limiter.record_request();
assert!(limiter.can_make_request(&config));
limiter.record_request();
assert!(!limiter.can_make_request(&config));
}
// `is_success` covers 2xx and `get` follows dot-separated paths.
#[test]
fn test_rest_response() {
let response = RestResponse {
status: 200,
headers: HashMap::new(),
body: json!({
"user": {
"name": "John",
"email": "john@example.com"
}
}),
response_time_ms: 100,
from_cache: false,
};
assert!(response.is_success());
assert_eq!(response.get("user.name").unwrap(), &json!("John"));
}
// Cache keys contain the base URL, the path and every query parameter.
#[test]
fn test_cache_key_generation() {
let config = RestConfig::new("https://api.example.com");
let connector = RestConnector::new(config);
let key1 = connector.cache_key("/users", &None);
assert_eq!(key1, "https://api.example.com:/users");
let mut params = HashMap::new();
params.insert("page".to_string(), "1".to_string());
params.insert("limit".to_string(), "10".to_string());
let key2 = connector.cache_key("/users", &Some(params));
assert!(key2.contains("limit=10"));
assert!(key2.contains("page=1"));
}
// The GraphQL builder keeps the document, operation name and single variables.
#[test]
fn test_graphql_query_builder() {
let query = GraphQLQuery::new("query { users { id name } }")
.with_operation("GetUsers")
.with_variable("limit", json!(10));
assert_eq!(query.query, "query { users { id name } }");
assert_eq!(query.operation_name, Some("GetUsers".to_string()));
assert!(query.variables.is_some());
assert_eq!(query.variables.unwrap().get("limit"), Some(&json!(10)));
}
// `with_variables` stores the whole map.
#[test]
fn test_graphql_query_with_variables() {
let mut vars = HashMap::new();
vars.insert("id".to_string(), json!("123"));
vars.insert("status".to_string(), json!("active"));
let query = GraphQLQuery::new(
"query($id: ID!, $status: String) { user(id: $id, status: $status) { name } }",
)
.with_variables(vars.clone());
assert!(query.variables.is_some());
assert_eq!(query.variables.as_ref().unwrap().len(), 2);
}
// A new breaker is closed and lets requests through.
#[tokio::test]
async fn test_circuit_breaker_closed_state() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
success_threshold: 2,
timeout_secs: 60,
window_secs: 60,
};
let breaker = CircuitBreaker::new(config);
assert!(breaker.is_request_allowed().await);
assert_eq!(breaker.get_state().await, CircuitState::Closed);
}
// Reaching the failure threshold opens the circuit and blocks requests.
#[tokio::test]
async fn test_circuit_breaker_opens_after_failures() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
success_threshold: 2,
timeout_secs: 1,
window_secs: 60,
};
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
breaker.record_failure().await;
breaker.record_failure().await;
assert_eq!(breaker.get_state().await, CircuitState::Open);
assert!(!breaker.is_request_allowed().await);
}
// After the timeout the next check moves an open circuit to half-open.
// This test waits two real seconds.
#[tokio::test]
async fn test_circuit_breaker_half_open_transition() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
success_threshold: 2,
timeout_secs: 1, window_secs: 60,
};
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
breaker.record_failure().await;
assert_eq!(breaker.get_state().await, CircuitState::Open);
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert!(breaker.is_request_allowed().await);
assert_eq!(breaker.get_state().await, CircuitState::HalfOpen);
}
// Enough successes while half-open close the circuit again.
// This test waits two real seconds.
#[tokio::test]
async fn test_circuit_breaker_closes_after_successes() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
success_threshold: 2,
timeout_secs: 1,
window_secs: 60,
};
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
breaker.record_failure().await;
assert_eq!(breaker.get_state().await, CircuitState::Open);
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert!(breaker.is_request_allowed().await);
breaker.record_success().await;
breaker.record_success().await;
assert_eq!(breaker.get_state().await, CircuitState::Closed);
}
// `reset` closes an open circuit immediately.
#[tokio::test]
async fn test_circuit_breaker_reset() {
let config = CircuitBreakerConfig::default();
let breaker = CircuitBreaker::new(config);
for _ in 0..5 {
breaker.record_failure().await;
}
assert_eq!(breaker.get_state().await, CircuitState::Open);
breaker.reset().await;
assert_eq!(breaker.get_state().await, CircuitState::Closed);
assert!(breaker.is_request_allowed().await);
}
// The logging interceptor logs both directions by default.
#[test]
fn test_logging_interceptor() {
let interceptor = LoggingInterceptor::new();
assert!(interceptor.log_request);
assert!(interceptor.log_response);
}
// Configured headers end up in the request's header map.
#[test]
fn test_header_injection_interceptor() {
let mut headers_to_inject = HashMap::new();
headers_to_inject.insert("X-Custom-Header".to_string(), "value123".to_string());
let interceptor = HeaderInjectionInterceptor::new(headers_to_inject);
let mut request_headers = HashMap::new();
let mut body = None;
interceptor
.intercept(
&Method::GET,
"https://example.com",
&mut request_headers,
&mut body,
)
.unwrap();
assert_eq!(
request_headers.get("X-Custom-Header"),
Some(&"value123".to_string())
);
}
}