use tower::BoxError;
use crate::integration::common::IntegrationTest;
use crate::integration::common::graph_os_enabled;
use crate::integration::subscriptions::CALLBACK_CONFIG;
use crate::integration::subscriptions::CallbackTestState;
use crate::integration::subscriptions::start_callback_server;
use crate::integration::subscriptions::start_callback_subgraph_server;
use crate::integration::subscriptions::start_callback_subgraph_server_with_payloads;
/// End-to-end check of callback-mode subscriptions.
///
/// Starts the test callback server and a mock subgraph (configured with `nb_events` and
/// `interval_ms`, and handed the callback URL), boots the router with `CALLBACK_CONFIG`,
/// sends a multipart subscription request, waits long enough for every event to be
/// emitted, then checks the callbacks recorded by the callback server.
///
/// Skipped (and reported as passing) when GraphOS is not enabled.
#[tokio::test(flavor = "multi_thread")]
async fn test_subscription_callback() -> Result<(), BoxError> {
    if !graph_os_enabled() {
        eprintln!("test skipped");
        return Ok(());
    }
    let nb_events = 3;
    let interval_ms = 100;
    let (callback_addr, callback_state) = start_callback_server().await;
    let callback_url = format!("http://{}/callback", callback_addr);
    // `callback_url` is not used again, so it is moved rather than cloned.
    let subgraph_server = start_callback_subgraph_server(nb_events, interval_ms, callback_url).await;
    let mut router = IntegrationTest::builder()
        .supergraph("tests/integration/subscriptions/fixtures/supergraph.graphql")
        .config(CALLBACK_CONFIG)
        .build()
        .await;
    let callback_receiver_port = callback_addr.port();
    // NOTE(review): the reserved port value is unused here; presumably the call exists to
    // fill the `CALLBACK_LISTENER_PORT` config placeholder — confirm.
    let _callback_listener_port = router.reserve_address("CALLBACK_LISTENER_PORT");
    router.set_address("CALLBACK_RECEIVER_PORT", callback_receiver_port);
    router.set_address_from_uri("SUBGRAPH_PORT", &subgraph_server.uri());
    router.start().await;
    router.assert_started().await;

    // Build the query arguments from the same values that configure the mock subgraph so
    // they cannot drift apart.
    let subscription_query = format!(
        "subscription {{ userWasCreated(intervalMs: {interval_ms}, nbEvents: {nb_events}) {{ name reviews {{ body }} }} }}"
    );
    let mut headers = std::collections::HashMap::new();
    headers.insert(
        "Accept".to_string(),
        "multipart/mixed;subscriptionSpec=1.0".to_string(),
    );
    let query = crate::integration::common::Query::builder()
        .body(serde_json::json!({
            "query": subscription_query
        }))
        .headers(headers)
        .build();
    let (_trace_id, response) = router.execute_query(query).await;
    assert!(
        response.status().is_success(),
        "Subscription request failed: {}",
        response.status()
    );

    // Wait for every event interval to elapse, plus a fixed one-second margin for delivery.
    tokio::time::sleep(tokio::time::Duration::from_millis(
        (nb_events as u64 * interval_ms) + 1000,
    ))
    .await;

    // NOTE(review): assumes the mock subgraph names its i-th event "User i" with a single
    // review "Review i from user i" (true for the 3-event fixture) — confirm if that changes.
    let expected_user_events: Vec<serde_json::Value> = (1..=nb_events)
        .map(|i| {
            serde_json::json!({
                "name": format!("User {i}"),
                "reviews": [{
                    "body": format!("Review {i} from user {i}")
                }]
            })
        })
        .collect();
    verify_callback_events(&callback_state, expected_user_events).await?;
    router.assert_no_error_logs();
    Ok(())
}
/// Asserts on the callbacks recorded by the test callback server.
///
/// Takes a snapshot of `callback_state.received_callbacks` and then checks two things:
/// - exactly one recorded callback has action `"complete"`;
/// - the `data.userWasCreated` values carried by the `"next"` callbacks, in recorded order,
///   equal `expected_user_events`. `"next"` callbacks without a payload or without that
///   field are skipped.
///
/// # Panics
/// Panics if either check fails, or if the callbacks mutex is poisoned.
async fn verify_callback_events(
    callback_state: &CallbackTestState,
    expected_user_events: Vec<serde_json::Value>,
) -> Result<(), BoxError> {
    use pretty_assertions::assert_eq;

    // Copy the recorded callbacks out; the temporary guard is released at the end of
    // the statement.
    let received = callback_state.received_callbacks.lock().unwrap().clone();

    let complete_count = received
        .iter()
        .filter(|cb| cb.action == "complete")
        .count();
    assert_eq!(
        complete_count,
        1,
        "Expected 1 'complete' callback, got {}. All callbacks: {:?}",
        complete_count,
        received
    );

    let actual_user_events: Vec<_> = received
        .iter()
        .filter(|cb| cb.action == "next")
        .filter_map(|cb| cb.payload.as_ref())
        .filter_map(|payload| payload.get("data"))
        .filter_map(|data| data.get("userWasCreated"))
        .cloned()
        .collect();
    assert_eq!(
        actual_user_events, expected_user_events,
        "Callback user events do not match expected events"
    );
    Ok(())
}
/// Drives the test callback server's request validation directly over HTTP, with no
/// router involved.
///
/// Before `test-id` is registered, it expects:
/// - a payload lacking `id`/`verifier` to get 422;
/// - an `id` that differs from the URL to get 400;
/// - an unknown subscription to get 404.
///
/// After registering `test-id`, it expects:
/// - a `check` to get 204;
/// - a heartbeat listing an unknown id to get 404;
/// - a valid heartbeat to get 204;
/// - a `complete` to get 202.
///
/// Skipped (and reported as passing) when GraphOS is not enabled.
#[tokio::test(flavor = "multi_thread")]
async fn test_subscription_callback_error_scenarios() -> Result<(), BoxError> {
    /// POSTs `body` as JSON to `url` and returns only the response status.
    async fn post_status(
        client: &reqwest::Client,
        url: &str,
        body: &serde_json::Value,
    ) -> Result<reqwest::StatusCode, BoxError> {
        let response = client.post(url).json(body).send().await?;
        Ok(response.status())
    }

    if !graph_os_enabled() {
        eprintln!("test skipped");
        return Ok(());
    }
    let (callback_addr, callback_state) = start_callback_server().await;
    let http = reqwest::Client::new();
    let url = format!("http://{}/callback/test-id", callback_addr);

    // No `id` / `verifier` fields at all.
    let missing_fields = serde_json::json!({
        "kind": "subscription",
        "action": "next"
    });
    let status = post_status(&http, &url, &missing_fields).await?;
    assert_eq!(status, 422, "Invalid payload should return 422");

    // `id` in the body does not match the `test-id` path segment.
    let wrong_id = serde_json::json!({
        "kind": "subscription",
        "action": "next",
        "id": "different-id",
        "verifier": "test-verifier"
    });
    let status = post_status(&http, &url, &wrong_id).await?;
    assert_eq!(status, 400, "ID mismatch should return 400");

    // Well-formed `check`, sent once before and once after `test-id` is registered.
    let check = serde_json::json!({
        "kind": "subscription",
        "action": "check",
        "id": "test-id",
        "verifier": "test-verifier"
    });
    let status = post_status(&http, &url, &check).await?;
    assert_eq!(status, 404, "Unknown subscription should return 404");

    // Register the subscription; the temporary guard is dropped at the end of this statement.
    callback_state
        .subscription_ids
        .lock()
        .unwrap()
        .push("test-id".to_string());

    let status = post_status(&http, &url, &check).await?;
    assert_eq!(status, 204, "Valid check should return 204");

    let heartbeat_with_unknown_id = serde_json::json!({
        "kind": "subscription",
        "action": "heartbeat",
        "id": "test-id",
        "ids": ["test-id", "invalid-id"],
        "verifier": "test-verifier"
    });
    let status = post_status(&http, &url, &heartbeat_with_unknown_id).await?;
    assert_eq!(status, 404, "Heartbeat with invalid IDs should return 404");

    let heartbeat = serde_json::json!({
        "kind": "subscription",
        "action": "heartbeat",
        "id": "test-id",
        "ids": ["test-id"],
        "verifier": "test-verifier"
    });
    let status = post_status(&http, &url, &heartbeat).await?;
    assert_eq!(status, 204, "Valid heartbeat should return 204");

    let complete = serde_json::json!({
        "kind": "subscription",
        "action": "complete",
        "id": "test-id",
        "verifier": "test-verifier"
    });
    let status = post_status(&http, &url, &complete).await?;
    assert_eq!(status, 202, "Valid completion should return 202");
    Ok(())
}
/// Callback-mode subscription where the mock subgraph emits two custom payloads.
///
/// The first payload is a full `userWasCreated` event. The second has `userWasCreated`
/// without `reviews`, together with an empty `errors` array. Both `userWasCreated` values
/// must be recorded by the callback server, and exactly one `complete` callback must be
/// recorded.
///
/// Skipped (and reported as passing) when GraphOS is not enabled.
#[tokio::test(flavor = "multi_thread")]
async fn test_subscription_callback_error_payload() -> Result<(), BoxError> {
    if !graph_os_enabled() {
        eprintln!("test skipped");
        return Ok(());
    }
    let interval_ms = 100;
    let custom_payloads = vec![
        serde_json::json!({
            "data": {
                "userWasCreated": {
                    "name": "User 1",
                    "reviews": [{
                        "body": "Review 1 from user 1"
                    }]
                }
            }
        }),
        serde_json::json!({
            "data": {
                "userWasCreated": {
                    "name": "User 2"
                }
            },
            "errors": []
        }),
    ];
    // Capture the count up front so `custom_payloads` can be moved into the subgraph.
    let nb_events = custom_payloads.len();
    let (callback_addr, callback_state) = start_callback_server().await;
    let callback_url = format!("http://{}/callback", callback_addr);
    let subgraph_server =
        start_callback_subgraph_server_with_payloads(custom_payloads, interval_ms, callback_url)
            .await;
    let mut router = IntegrationTest::builder()
        .supergraph("tests/integration/subscriptions/fixtures/supergraph.graphql")
        .config(CALLBACK_CONFIG)
        .build()
        .await;
    let callback_receiver_port = callback_addr.port();
    // NOTE(review): the reserved port value is unused here; presumably the call exists to
    // fill the `CALLBACK_LISTENER_PORT` config placeholder — confirm.
    let _callback_listener_port = router.reserve_address("CALLBACK_LISTENER_PORT");
    router.set_address("CALLBACK_RECEIVER_PORT", callback_receiver_port);
    router.set_address_from_uri("SUBGRAPH_PORT", &subgraph_server.uri());
    router.start().await;
    router.assert_started().await;

    // Keep the query arguments in sync with the payload count and interval given to the subgraph.
    let subscription_query = format!(
        "subscription {{ userWasCreated(intervalMs: {interval_ms}, nbEvents: {nb_events}) {{ name reviews {{ body }} }} }}"
    );
    let mut headers = std::collections::HashMap::new();
    headers.insert(
        "Accept".to_string(),
        "multipart/mixed;subscriptionSpec=1.0".to_string(),
    );
    let query = crate::integration::common::Query::builder()
        .body(serde_json::json!({
            "query": subscription_query
        }))
        .headers(headers)
        .build();
    let (_trace_id, response) = router.execute_query(query).await;
    assert!(
        response.status().is_success(),
        "Subscription request failed: {}",
        response.status()
    );

    // Wait for every payload interval to elapse, plus a fixed one-second margin for delivery.
    tokio::time::sleep(tokio::time::Duration::from_millis(
        (nb_events as u64 * interval_ms) + 1000,
    ))
    .await;

    let expected_user_events = vec![
        serde_json::json!({
            "name": "User 1",
            "reviews": [{
                "body": "Review 1 from user 1"
            }]
        }),
        serde_json::json!({
            "name": "User 2"
        }),
    ];
    verify_callback_events(&callback_state, expected_user_events).await?;
    router.assert_no_error_logs();
    Ok(())
}
/// Callback-mode subscription where the mock subgraph emits a full `userWasCreated` event
/// followed by a payload containing only an empty `errors` array (no `data`).
///
/// Only the first payload's `userWasCreated` is expected among the recorded `next`
/// callbacks, and exactly one `complete` callback must be recorded.
///
/// Skipped (and reported as passing) when GraphOS is not enabled.
#[tokio::test(flavor = "multi_thread")]
async fn test_subscription_callback_pure_error_payload() -> Result<(), BoxError> {
    if !graph_os_enabled() {
        eprintln!("test skipped");
        return Ok(());
    }
    let interval_ms = 100;
    let custom_payloads = vec![
        serde_json::json!({
            "data": {
                "userWasCreated": {
                    "name": "User 1",
                    "reviews": [{
                        "body": "Review 1 from user 1"
                    }]
                }
            }
        }),
        serde_json::json!({
            "errors": []
        }),
    ];
    // Capture the count up front so `custom_payloads` can be moved into the subgraph.
    let nb_events = custom_payloads.len();
    let (callback_addr, callback_state) = start_callback_server().await;
    let callback_url = format!("http://{}/callback", callback_addr);
    let subgraph_server =
        start_callback_subgraph_server_with_payloads(custom_payloads, interval_ms, callback_url)
            .await;
    let mut router = IntegrationTest::builder()
        .supergraph("tests/integration/subscriptions/fixtures/supergraph.graphql")
        .config(CALLBACK_CONFIG)
        .build()
        .await;
    let callback_receiver_port = callback_addr.port();
    // NOTE(review): the reserved port value is unused here; presumably the call exists to
    // fill the `CALLBACK_LISTENER_PORT` config placeholder — confirm.
    let _callback_listener_port = router.reserve_address("CALLBACK_LISTENER_PORT");
    router.set_address("CALLBACK_RECEIVER_PORT", callback_receiver_port);
    router.set_address_from_uri("SUBGRAPH_PORT", &subgraph_server.uri());
    router.start().await;
    router.assert_started().await;

    // Keep the query arguments in sync with the payload count and interval given to the subgraph.
    let subscription_query = format!(
        "subscription {{ userWasCreated(intervalMs: {interval_ms}, nbEvents: {nb_events}) {{ name reviews {{ body }} }} }}"
    );
    let mut headers = std::collections::HashMap::new();
    headers.insert(
        "Accept".to_string(),
        "multipart/mixed;subscriptionSpec=1.0".to_string(),
    );
    let query = crate::integration::common::Query::builder()
        .body(serde_json::json!({
            "query": subscription_query
        }))
        .headers(headers)
        .build();
    let (_trace_id, response) = router.execute_query(query).await;
    assert!(
        response.status().is_success(),
        "Subscription request failed: {}",
        response.status()
    );

    // Wait for every payload interval to elapse, plus a fixed one-second margin for delivery.
    tokio::time::sleep(tokio::time::Duration::from_millis(
        (nb_events as u64 * interval_ms) + 1000,
    ))
    .await;

    let expected_user_events = vec![serde_json::json!({
        "name": "User 1",
        "reviews": [{
            "body": "Review 1 from user 1"
        }]
    })];
    verify_callback_events(&callback_state, expected_user_events).await?;
    router.assert_no_error_logs();
    Ok(())
}