use std::collections::HashMap;
use std::sync::Arc;
use tokio::test;
use mockforge_core::chain_execution::ChainExecutionEngine;
use mockforge_core::request_chaining::{
ChainConfig, ChainDefinition, ChainLink, ChainRequest, RequestBody, RequestChainRegistry,
};
/// Builds the two-link authentication chain shared by several tests below.
///
/// * `login` POSTs JSON credentials to httpbin, accepts status 200 or 201,
///   extracts `token` and `user_id`, and stores its response as `login_response`.
/// * `get_profile` depends on `login` and fills its `Authorization` and
///   `X-User-ID` headers from `{{chain.login_response...}}` placeholders.
fn create_auth_chain() -> ChainDefinition {
    /// Converts borrowed `(key, value)` pairs into an owned `HashMap`.
    fn owned_pairs(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    let login = ChainLink {
        request: ChainRequest {
            id: "login".to_owned(),
            method: "POST".to_owned(),
            url: "https://httpbin.org/post".to_owned(),
            headers: owned_pairs(&[
                ("Content-Type", "application/json"),
                ("User-Agent", "MockForge-Test"),
            ]),
            body: Some(RequestBody::Json(serde_json::json!({
                "username": "testuser",
                "password": "testpass"
            }))),
            depends_on: Vec::new(),
            timeout_secs: Some(10),
            expected_status: Some(vec![200, 201]),
            scripting: None,
        },
        extract: owned_pairs(&[("token", "json.access_token"), ("user_id", "json.user.id")]),
        store_as: Some("login_response".to_owned()),
    };

    let get_profile = ChainLink {
        request: ChainRequest {
            id: "get_profile".to_owned(),
            method: "GET".to_owned(),
            url: "https://httpbin.org/get".to_owned(),
            headers: owned_pairs(&[
                ("Authorization", "Bearer {{chain.login_response.json.access_token}}"),
                ("X-User-ID", "{{chain.login_response.json.user.id}}"),
            ]),
            body: None,
            depends_on: vec!["login".to_owned()],
            timeout_secs: Some(10),
            expected_status: Some(vec![200]),
            scripting: None,
        },
        extract: owned_pairs(&[("profile_name", "json.name")]),
        store_as: Some("profile_response".to_owned()),
    };

    ChainDefinition {
        id: "auth-chain-test".to_owned(),
        name: "Authentication Chain Test".to_owned(),
        description: Some("Test chain for authentication flow".to_owned()),
        config: ChainConfig {
            enabled: true,
            max_chain_length: 10,
            global_timeout_secs: 30,
            enable_parallel_execution: false,
        },
        links: vec![login, get_profile],
        variables: HashMap::new(),
        tags: vec!["test".to_owned(), "integration".to_owned()],
    }
}
/// A fresh registry starts empty, accepts the auth fixture registered from
/// YAML, then lists it and returns it by id.
#[test]
async fn test_registry_basic_functionality() {
    let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
    // The engine is constructed over the shared registry but not used further here.
    let _engine = Arc::new(ChainExecutionEngine::new(Arc::clone(&registry), ChainConfig::default()));

    assert_eq!(registry.list_chains().await.len(), 0);

    let yaml = serde_yaml::to_string(&create_auth_chain()).unwrap();
    let registered_id = registry.register_from_yaml(&yaml).await.unwrap();
    assert_eq!(registered_id, "auth-chain-test");

    let listed = registry.list_chains().await;
    assert_eq!(listed.len(), 1);
    assert_eq!(listed[0], "auth-chain-test");

    let fetched = registry.get_chain(&registered_id).await;
    assert!(fetched.is_some());
    assert_eq!(fetched.unwrap().name, "Authentication Chain Test");
}
/// Validation accepts the auth fixture. It rejects a chain whose two links
/// depend on each other, and a chain with no links at all.
#[test]
async fn test_chain_validation() {
    let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));

    let fixture = create_auth_chain();
    assert!(registry.validate_chain(&fixture).await.is_ok());

    // login -> get_profile -> login forms a two-node dependency cycle.
    let mut cyclic = fixture.clone();
    cyclic.id = "invalid-chain".to_string();
    cyclic.links[0].request.depends_on = vec!["get_profile".to_string()];
    cyclic.links[1].request.depends_on = vec!["login".to_string()];
    let cyclic_outcome = registry.validate_chain(&cyclic).await;
    assert!(cyclic_outcome.is_err());
    let cycle_message = cyclic_outcome.unwrap_err().to_string();
    assert!(
        cycle_message.contains("Circular dependency")
            || cycle_message.contains("circular dependency")
    );

    let linkless = ChainDefinition {
        id: "empty-chain".to_string(),
        name: "Empty Chain".to_string(),
        description: None,
        config: ChainConfig::default(),
        links: Vec::new(),
        variables: HashMap::new(),
        tags: Vec::new(),
    };
    let linkless_outcome = registry.validate_chain(&linkless).await;
    assert!(linkless_outcome.is_err());
    assert!(linkless_outcome.unwrap_err().to_string().contains("must have at least one link"));
}
/// Validating an 8-link chain against a registry capped at 5 links is
/// expected to fail. This holds even though the chain's own config sets
/// `max_chain_length` to 10.
#[test]
async fn test_chain_with_too_many_links() {
    let registry = Arc::new(RequestChainRegistry::new(ChainConfig {
        enabled: true,
        max_chain_length: 5,
        global_timeout_secs: 300,
        enable_parallel_execution: false,
    }));

    let mut oversized_chain = create_auth_chain();
    oversized_chain.id = "oversized-chain".to_string();
    oversized_chain.config.max_chain_length = 10;

    // Pad the two-link fixture with six independent GET links (8 in total).
    for n in 0..6 {
        oversized_chain.links.push(ChainLink {
            request: ChainRequest {
                id: format!("extra_link_{}", n),
                method: "GET".to_string(),
                url: "https://httpbin.org/get".to_string(),
                headers: HashMap::new(),
                body: None,
                depends_on: Vec::new(),
                timeout_secs: None,
                expected_status: None,
                scripting: None,
            },
            extract: HashMap::new(),
            store_as: Some(format!("response_{}", n)),
        });
    }

    let outcome = registry.validate_chain(&oversized_chain).await;
    assert!(outcome.is_err());
    let reason = outcome.unwrap_err().to_string();
    assert!(reason.contains("exceeds maximum") || reason.contains("chain length"));
}
/// The auth fixture has exactly two links. Only the second declares a
/// dependency (on `login`), and every link id is unique.
#[test]
async fn test_chain_dependency_resolution() {
    let chain = create_auth_chain();
    assert_eq!(chain.links.len(), 2);

    let (first, second) = (&chain.links[0], &chain.links[1]);
    assert!(first.request.depends_on.is_empty());
    assert_eq!(second.request.depends_on, vec!["login"]);

    // Borrowed ids are enough for a uniqueness check; duplicates would shrink the set.
    let distinct_ids: std::collections::HashSet<&str> =
        chain.links.iter().map(|link| link.request.id.as_str()).collect();
    assert_eq!(distinct_ids.len(), chain.links.len());
}
/// Serializing the auth fixture to JSON and back preserves its id, its name,
/// and each link's id, method and URL.
#[test]
async fn test_chain_json_round_trip() {
    let before = create_auth_chain();
    let encoded = serde_json::to_string(&before).unwrap();
    let after: ChainDefinition = serde_json::from_str(&encoded).unwrap();

    assert_eq!(after.id, before.id);
    assert_eq!(after.name, before.name);
    assert_eq!(after.links.len(), before.links.len());

    for (expected, actual) in before.links.iter().zip(&after.links) {
        let (want, got) = (&expected.request, &actual.request);
        assert_eq!(want.id, got.id);
        assert_eq!(want.method, got.method);
        assert_eq!(want.url, got.url);
    }
}
/// Serializing the auth fixture to YAML and back preserves its id, its name,
/// its link count and its tags.
#[test]
async fn test_chain_yaml_round_trip() {
    let source = create_auth_chain();
    let yaml_text = serde_yaml::to_string(&source).unwrap();
    let reparsed: ChainDefinition = serde_yaml::from_str(&yaml_text).unwrap();

    assert_eq!(reparsed.id, source.id);
    assert_eq!(reparsed.name, source.name);
    assert_eq!(reparsed.links.len(), source.links.len());
    assert_eq!(reparsed.tags, ["test", "integration"]);
}
/// Registers a chain, reads it back, then removes it. After removal the chain
/// is no longer retrievable and the registry lists nothing.
#[test]
async fn test_chain_crud_operations() {
    let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));

    // Create
    let yaml = serde_yaml::to_string(&create_auth_chain()).unwrap();
    let id = registry.register_from_yaml(&yaml).await.unwrap();

    // Read
    let stored = registry.get_chain(&id).await.unwrap();
    assert_eq!(stored.id, id);

    // Delete
    registry.remove_chain(&id).await.unwrap();
    assert!(registry.get_chain(&id).await.is_none());
    assert!(registry.list_chains().await.is_empty());
}
/// A chain with parallel execution enabled must pass validation. It has two
/// independent links plus a third link that depends on the first.
#[test]
async fn test_chain_with_parallel_execution() {
    /// Builds a bare GET link against httpbin. Only the id, the dependencies
    /// and the store key vary between the links in this test.
    fn get_link(id: &str, depends_on: &[&str], store_as: &str) -> ChainLink {
        ChainLink {
            request: ChainRequest {
                id: id.to_string(),
                method: "GET".to_string(),
                url: "https://httpbin.org/get".to_string(),
                headers: HashMap::new(),
                body: None,
                depends_on: depends_on.iter().map(|dep| dep.to_string()).collect(),
                timeout_secs: None,
                expected_status: None,
                scripting: None,
            },
            extract: HashMap::new(),
            store_as: Some(store_as.to_string()),
        }
    }

    /// Parallel-enabled settings, shared by the registry and the chain itself.
    fn parallel_config() -> ChainConfig {
        ChainConfig {
            enabled: true,
            max_chain_length: 20,
            global_timeout_secs: 300,
            enable_parallel_execution: true,
        }
    }

    let registry = Arc::new(RequestChainRegistry::new(parallel_config()));
    let parallel_chain = ChainDefinition {
        id: "parallel-chain".to_string(),
        name: "Parallel Test Chain".to_string(),
        description: Some("Chain with parallel execution".to_string()),
        config: parallel_config(),
        links: vec![
            get_link("independent1", &[], "response1"),
            get_link("independent2", &[], "response2"),
            get_link("dependent", &["independent1"], "response3"),
        ],
        variables: HashMap::new(),
        tags: vec!["parallel".to_string()],
    };

    let result = registry.validate_chain(&parallel_chain).await;
    // Surface the validation error text on failure instead of a bare "not ok".
    assert!(
        result.is_ok(),
        "Parallel chain should be valid: {}",
        result.err().map(|e| e.to_string()).unwrap_or_default()
    );
}
/// Builds a single-link chain whose JSON body holds template-placeholder
/// strings (`{{faker.uuid}}`, `{{now}}`, ...) and which defines two
/// chain-level variables. It then checks that the in-memory body has the
/// expected object keys.
#[test]
async fn test_chain_with_complex_variables() {
    let chain = ChainDefinition {
        id: "complex-variables-chain".to_string(),
        name: "Complex Variables Chain".to_string(),
        description: None,
        config: ChainConfig::default(),
        links: vec![ChainLink {
            request: ChainRequest {
                id: "complex_request".to_string(),
                method: "POST".to_string(),
                url: "https://httpbin.org/post".to_string(),
                headers: HashMap::from([
                    ("Content-Type".to_string(), "application/json".to_string()),
                    ("X-Custom".to_string(), "custom-value".to_string()),
                ]),
                body: Some(RequestBody::Json(serde_json::json!({
                    "nested": {
                        "value": "{{faker.uuid}}",
                        "list": [1, 2, "{{faker.name}}"],
                        "timestamp": "{{now}}"
                    },
                    "int_value": "{{randInt 10 100}}",
                    "float_value": "{{rand.float}}"
                }))),
                depends_on: vec![],
                timeout_secs: None,
                expected_status: None,
                scripting: None,
            },
            extract: HashMap::from([
                ("request_id".to_string(), "json.nested.value".to_string()),
                ("server_time".to_string(), "headers.Date".to_string()),
            ]),
            store_as: Some("complex_response".to_string()),
        }],
        variables: HashMap::from([
            ("api_version".to_string(), serde_json::Value::String("v1".to_string())),
            (
                "base_url".to_string(),
                serde_json::Value::String("https://httpbin.org".to_string()),
            ),
        ]),
        tags: vec![],
    };

    assert_eq!(chain.links.len(), 1);
    assert_eq!(chain.variables.len(), 2);

    // `as_ref()` already yields `Option<&RequestBody>`; one let-else covers both
    // "no body" and "non-JSON body".
    let Some(RequestBody::Json(json_value)) = chain.links[0].request.body.as_ref() else {
        panic!("Expected JSON body");
    };
    let Some(top_level) = json_value.as_object() else {
        panic!("Expected JSON body to be an object");
    };
    for key in ["nested", "int_value", "float_value"] {
        assert!(top_level.contains_key(key), "missing top-level key {key:?}");
    }

    let Some(nested) = top_level.get("nested").and_then(|value| value.as_object()) else {
        panic!("Expected \"nested\" to be a JSON object");
    };
    for key in ["value", "list", "timestamp"] {
        assert!(nested.contains_key(key), "missing nested key {key:?}");
    }
}
/// A link can declare several extraction rules at once, using dotted paths,
/// an indexed path (`[0]`) and a wildcard (`*`). All three are kept on the link.
#[test]
async fn test_chain_with_multiple_extraction_patterns() {
    let rules: HashMap<String, String> = [
        ("slideshow_title", "slideshow.title"),
        ("first_slide_title", "slideshow.slides.[0].title"),
        ("total_slides", "slideshow.slides.*"),
    ]
    .iter()
    .map(|&(name, path)| (name.to_string(), path.to_string()))
    .collect();

    let request = ChainRequest {
        id: "extraction_request".to_string(),
        method: "GET".to_string(),
        url: "https://httpbin.org/json".to_string(),
        headers: HashMap::new(),
        body: None,
        depends_on: Vec::new(),
        timeout_secs: None,
        expected_status: None,
        scripting: None,
    };

    let chain = ChainDefinition {
        id: "extraction-test-chain".to_string(),
        name: "Extraction Test Chain".to_string(),
        description: None,
        config: ChainConfig::default(),
        links: vec![ChainLink {
            request,
            extract: rules,
            store_as: Some("extraction_response".to_string()),
        }],
        variables: HashMap::new(),
        tags: vec!["extraction".to_string()],
    };

    let extract = &chain.links[0].extract;
    assert_eq!(extract.len(), 3);
    for name in ["slideshow_title", "first_slide_title", "total_slides"] {
        assert!(extract.contains_key(name));
    }
}
/// An execution engine can be constructed from a shared registry together
/// with a custom config that enables parallel execution.
#[test]
async fn test_chain_engine_creation() {
    let engine_config = ChainConfig {
        enabled: true,
        max_chain_length: 10,
        global_timeout_secs: 30,
        enable_parallel_execution: true,
    };
    let shared_registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
    let _engine = Arc::new(ChainExecutionEngine::new(shared_registry, engine_config));
}