use std::io::{Error, ErrorKind, Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::cell::RefCell;
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::{Duration, Instant};
use futures::{
future::{FutureExt, LocalBoxFuture},
};
use https::{HeaderMap, StatusCode};
use pi_atom::Atom;
use pi_async_rt::rt::{serial::AsyncRuntimeBuilder, startup_global_time_loop};
use pi_gray::GrayVersion;
use pi_handler::{Args, Handler, SGenType};
use pi_hash::XHashMap;
use tcp::{
connect::TcpSocket,
server::{PortsAdapterFactory, SocketListener},
SocketConfig,
};
use pi_http::{
gateway::GatewayContext,
middleware::MiddlewareChain,
port::HttpPort,
route::HttpRoute,
server::HttpListenerFactory,
response::ResponseHandler,
sse::{
write_sse_accept_headers, write_sse_reject_headers, SseAcceptDecision, SseConfig,
SseEvent, SseHub, SseMiddleware, SSE_PARAM_CONNECTION_ID, SSE_PARAM_NONCE,
},
virtual_host::{VirtualHost, VirtualHostPool, VirtualHostTab},
};
/// Server-side script selected for a test server started by `start_sse_server`.
///
/// `Debug`, `PartialEq` and `Eq` are derived so a scenario can be compared with
/// `==` / `assert_eq!` and printed in failure messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SseNetworkScenario {
    /// Accept the stream, send one `evt-1` event, then close the connection.
    SingleEvent,
    /// Reject the open request from the acceptor with `403 Forbidden`.
    Reject,
    /// Send `evt-1` then `evt-2` from a single sender thread, then close.
    SameThreadOrder,
    /// Send `evt-1` and `evt-2` from two threads; the second thread waits on a
    /// channel until the first has enqueued its event.
    CrossThreadControlledOrder,
    /// Enable a 10 ms heartbeat on a dedicated runtime, send no events, and
    /// close the connection after a short wait.
    Heartbeat,
}
/// Handshake outcome written by the port handler in the port-handshake tests.
///
/// `Debug`, `PartialEq` and `Eq` are derived so a scenario can be compared with
/// `==` / `assert_eq!` and printed in failure messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SsePortNetworkScenario {
    /// The port handler writes the SSE accept handshake headers.
    Accept,
    /// The port handler writes the SSE reject handshake headers (`401`).
    Reject,
}
/// Port handler used by the port-handshake tests: it answers every request by
/// writing either the SSE accept or the SSE reject handshake headers.
struct SsePortDecisionHandler {
    /// Which handshake outcome to write.
    scenario: SsePortNetworkScenario,
}

impl Handler for SsePortDecisionHandler {
    type A = SocketAddr;
    type B = String;
    type C = Arc<HeaderMap>;
    type D = Arc<RefCell<XHashMap<String, SGenType>>>;
    type E = ResponseHandler;
    type F = ();
    type G = ();
    type H = ();
    type HandleResult = ();

    /// Writes the handshake decision for one request and finishes the response.
    ///
    /// Only `Args::FiveArgs` is acted upon; any other argument shape is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the params map lacks a string SSE connection id or nonce, if the
    /// handshake headers cannot be written, or if the response fails to finish.
    fn handle(
        &self,
        _env: Arc<dyn GrayVersion>,
        _topic: Atom,
        args: Args<Self::A, Self::B, Self::C, Self::D, Self::E, Self::F, Self::G, Self::H>,
    ) -> LocalBoxFuture<'static, Self::HandleResult> {
        // Copy the scenario out so the future does not borrow `self`.
        let scenario = self.scenario;
        let decision = async move {
            match args {
                Args::FiveArgs(_addr, _method, _headers, params, response) => {
                    // Each lookup takes its own short-lived borrow, released
                    // before anything is awaited.
                    let connection_id = match params.borrow().get(SSE_PARAM_CONNECTION_ID) {
                        Some(SGenType::Str(value)) => value.clone(),
                        _ => panic!("port handler must receive SSE connection id"),
                    };
                    let nonce = match params.borrow().get(SSE_PARAM_NONCE) {
                        Some(SGenType::Str(value)) => value.clone(),
                        _ => panic!("port handler must receive SSE nonce"),
                    };

                    match scenario {
                        SsePortNetworkScenario::Accept => {
                            write_sse_accept_headers(
                                &response,
                                connection_id.as_str(),
                                &nonce,
                                "port-session-a",
                            )
                            .expect("port accept handshake headers must be valid");
                        }
                        SsePortNetworkScenario::Reject => {
                            write_sse_reject_headers(
                                &response,
                                connection_id.as_str(),
                                &nonce,
                                StatusCode::UNAUTHORIZED,
                                "port rejected sse",
                            )
                            .expect("port reject handshake headers must be valid");
                        }
                    }

                    response
                        .finish()
                        .await
                        .expect("port decision response must finish");
                }
                _ => {}
            }
        };
        decision.boxed_local()
    }
}
/// Builds an SSE event from `event_id`, `event_name` and `data`, sends it to
/// connection `id` through `hub`, and asserts the send report counts exactly
/// one target and one successful send.
///
/// # Panics
///
/// Panics if the event cannot be built or if the report's `total` or `sent`
/// is not `1`.
fn send_ordered_event(
    hub: &SseHub<String, TcpSocket>,
    id: pi_http::sse::SseConnectionId,
    event_id: &str,
    event_name: &str,
    data: &str,
) {
    let builder = SseEvent::builder()
        .id(event_id)
        .event(event_name)
        .data(data);
    let payload = builder
        .build()
        .expect("SSE ordered test event must be valid");

    let outcome = hub.try_send_to_id(id, payload);
    assert_eq!(outcome.total, 1);
    assert_eq!(outcome.sent, 1);
}
/// Asks the OS for a free loopback TCP port by binding `127.0.0.1:0` and
/// returns the port number it assigned.
///
/// The probe listener is dropped before the function returns, so the port is
/// released again; the caller is expected to bind it promptly.
///
/// # Panics
///
/// Panics if the bind fails or the listener cannot report its local address.
fn reserve_local_port() -> u16 {
    let probe = TcpListener::bind("127.0.0.1:0").expect("test must reserve a local TCP port");
    let bound = probe
        .local_addr()
        .expect("test listener must expose local addr");
    bound.port()
}
/// Starts a loopback HTTP server whose `/sse` route is served by an
/// `SseMiddleware` with a direct acceptor, and returns the listener together
/// with the address the client should connect to.
///
/// `scenario` selects what the server does:
/// * `Reject` — the acceptor answers `403` with `"sse rejected by acceptor"`.
/// * every other scenario — the acceptor accepts with key `"client-a"`, and
///   the `on_open` hook spawns a thread that waits 50 ms and then runs the
///   scenario's send/close script against the opened connection.
/// * `Heartbeat` — additionally sets a 10 ms heartbeat interval and installs a
///   dedicated multi-thread heartbeat runtime.
///
/// # Panics
///
/// Panics if any builder step fails or the listener cannot bind.
fn start_sse_server(
    scenario: SseNetworkScenario,
) -> (
    SocketListener<TcpSocket, PortsAdapterFactory<TcpSocket>>,
    SocketAddr,
) {
    let port = reserve_local_port();
    let addr: SocketAddr = format!("127.0.0.1:{}", port)
        .parse()
        .expect("test addr must parse");

    // Heartbeats are only enabled for the heartbeat scenario; 0 is passed otherwise.
    let wants_heartbeat = matches!(scenario, SseNetworkScenario::Heartbeat);
    let heartbeat_interval_ms = if wants_heartbeat { 10 } else { 0 };

    let hub = SseHub::<String, TcpSocket>::builder().build();
    let opened_hub = hub.clone();

    let sse_config = SseConfig::builder()
        .channel_size(8)
        .heartbeat_interval_ms(heartbeat_interval_ms)
        .send_initial_comment(false)
        .build()
        .expect("SSE test config must be valid");

    let mut builder = SseMiddleware::with_acceptor(hub.clone(), move |accept| {
        assert_eq!(accept.request.url().path(), "/sse");
        let decision = match scenario {
            SseNetworkScenario::Reject => {
                SseAcceptDecision::reject(StatusCode::FORBIDDEN, "sse rejected by acceptor")
            }
            _ => SseAcceptDecision::accept("client-a".to_string()),
        };
        Ok(decision)
    })
    .config(sse_config)
    .on_open(move |open| {
        assert_eq!(open.key, "client-a");
        let sender_hub = opened_hub.clone();
        let connection = open.id;
        // Events are pushed from a separate thread after the open hook returns.
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            match scenario {
                SseNetworkScenario::SingleEvent => {
                    send_ordered_event(
                        &sender_hub,
                        connection,
                        "evt-1",
                        "notice",
                        "hello from pi_http",
                    );
                    sender_hub
                        .try_close(connection)
                        .expect("SSE test connection must close explicitly");
                }
                SseNetworkScenario::SameThreadOrder => {
                    for &(event_id, data) in &[("evt-1", "first"), ("evt-2", "second")] {
                        send_ordered_event(&sender_hub, connection, event_id, "notice", data);
                    }
                    sender_hub
                        .try_close(connection)
                        .expect("SSE ordered test connection must close explicitly");
                }
                SseNetworkScenario::CrossThreadControlledOrder => {
                    // The channel forces the second thread to enqueue only
                    // after the first thread has enqueued its event.
                    let (enqueued_tx, enqueued_rx) = mpsc::channel();
                    let leading_hub = sender_hub.clone();
                    let trailing_hub = sender_hub.clone();
                    let leading = thread::spawn(move || {
                        send_ordered_event(&leading_hub, connection, "evt-1", "notice", "first");
                        enqueued_tx
                            .send(())
                            .expect("first sender must notify second sender");
                    });
                    let trailing = thread::spawn(move || {
                        enqueued_rx
                            .recv()
                            .expect("second sender must wait for first enqueue");
                        send_ordered_event(&trailing_hub, connection, "evt-2", "notice", "second");
                        trailing_hub.try_close(connection).expect(
                            "SSE cross-thread ordered test connection must close explicitly",
                        );
                    });
                    leading.join().expect("first SSE sender thread must finish");
                    trailing.join().expect("second SSE sender thread must finish");
                }
                SseNetworkScenario::Heartbeat => {
                    // Keep the stream open for a while before closing it.
                    thread::sleep(Duration::from_millis(80));
                    sender_hub
                        .try_close(connection)
                        .expect("SSE heartbeat test connection must close explicitly");
                }
                SseNetworkScenario::Reject => {}
            }
        });
        Ok(())
    });

    if wants_heartbeat {
        builder = builder.heartbeat_runtime(
            pi_async_rt::rt::AsyncRuntimeBuilder::default_multi_thread(
                Some("sse-real-heartbeat"),
                None,
                Some(1),
                Some(1),
            ),
        );
    }
    let middleware = builder
        .build()
        .expect("SSE real-network middleware must build");

    let mut route = HttpRoute::<TcpSocket, GatewayContext, SseMiddleware<String, TcpSocket>>::new();
    route.at("/sse").get(middleware);
    let mut hosts = VirtualHostTab::<TcpSocket, SseMiddleware<String, TcpSocket>>::new();
    hosts
        .add_default(VirtualHost::with(route))
        .expect("test virtual host must register");

    let mut factory = PortsAdapterFactory::<TcpSocket>::new();
    let http_service = HttpListenerFactory::<TcpSocket, _>::with_hosts(hosts, 5000).new_service();
    factory.bind(port, http_service);

    let mut config = SocketConfig::new("127.0.0.1", &[port]);
    config.set_option(16 * 1024, 16 * 1024, 16 * 1024, 16);
    let runtime = AsyncRuntimeBuilder::default_local_thread(None, None);
    // NOTE(review): the positional numbers below are tuning parameters whose
    // meaning is defined by `SocketListener::try_bind`, which is not visible here.
    let listener = SocketListener::try_bind(
        vec![runtime],
        factory,
        config,
        64,
        1024 * 1024,
        128,
        8,
        16 * 1024,
        16 * 1024,
        Some(10),
    )
    .expect("test SSE server must bind");

    (listener, addr)
}
/// Starts a loopback HTTP server whose `/sse` route runs a middleware chain of
/// an `SseMiddleware` in port-handshake mode followed by an `HttpPort` backed
/// by `SsePortDecisionHandler`, and returns the listener with its address.
///
/// The direct acceptor must never be called in this mode (it is
/// `unreachable!`). On open, the hook asserts the key is `"port-session-a"`
/// and spawns a thread that waits 50 ms, sends one `evt-port-1` event and
/// closes the connection.
///
/// # Panics
///
/// Panics if any builder step fails or the listener cannot bind.
fn start_sse_port_server(
    scenario: SsePortNetworkScenario,
) -> (
    SocketListener<TcpSocket, PortsAdapterFactory<TcpSocket>>,
    SocketAddr,
) {
    let port = reserve_local_port();
    let addr: SocketAddr = format!("127.0.0.1:{}", port)
        .parse()
        .expect("test addr must parse");

    let hub = SseHub::<String, TcpSocket>::builder().build();
    let opened_hub = hub.clone();

    let sse_config = SseConfig::builder()
        .channel_size(8)
        .heartbeat_interval_ms(0)
        .send_initial_comment(false)
        .build()
        .expect("SSE port test config must be valid");

    let middleware = SseMiddleware::with_acceptor(hub.clone(), |_accept| {
        unreachable!("port/params handshake must not call direct acceptor")
    })
    .config(sse_config)
    .port_handshake_string_key()
    .on_open(move |open| {
        assert_eq!(open.key, "port-session-a");
        let sender_hub = opened_hub.clone();
        let connection = open.id;
        // The event is pushed from a separate thread after the open hook returns.
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            send_ordered_event(
                &sender_hub,
                connection,
                "evt-port-1",
                "session",
                "port session ready",
            );
            sender_hub
                .try_close(connection)
                .expect("SSE port test connection must close explicitly");
        });
        Ok(())
    })
    .build()
    .expect("SSE port middleware must build");

    let decision_handler = Arc::new(SsePortDecisionHandler { scenario });
    let port_handler = HttpPort::with_handler(None, decision_handler);

    // The SSE middleware is pushed first, the port handler second.
    let mut pipeline = MiddlewareChain::<TcpSocket, GatewayContext>::new();
    pipeline.push_back(Arc::new(middleware));
    pipeline.push_back(Arc::new(port_handler));
    pipeline.finish();
    let pipeline = Arc::new(pipeline);

    let mut route =
        HttpRoute::<TcpSocket, GatewayContext, Arc<MiddlewareChain<TcpSocket, GatewayContext>>>::new();
    route.at("/sse").get(pipeline);
    let mut hosts =
        VirtualHostTab::<TcpSocket, Arc<MiddlewareChain<TcpSocket, GatewayContext>>>::new();
    hosts
        .add_default(VirtualHost::with(route))
        .expect("test virtual host must register");

    let mut factory = PortsAdapterFactory::<TcpSocket>::new();
    let http_service = HttpListenerFactory::<TcpSocket, _>::with_hosts(hosts, 5000).new_service();
    factory.bind(port, http_service);

    let mut config = SocketConfig::new("127.0.0.1", &[port]);
    config.set_option(16 * 1024, 16 * 1024, 16 * 1024, 16);
    let runtime = AsyncRuntimeBuilder::default_local_thread(None, None);
    // NOTE(review): the positional numbers below are tuning parameters whose
    // meaning is defined by `SocketListener::try_bind`, which is not visible here.
    let listener = SocketListener::try_bind(
        vec![runtime],
        factory,
        config,
        64,
        1024 * 1024,
        128,
        8,
        16 * 1024,
        16 * 1024,
        Some(10),
    )
    .expect("test SSE port server must bind");

    (listener, addr)
}
/// Asserts that `first` occurs before `second` in `response`, comparing
/// ASCII-case-insensitively.
///
/// The response is decoded lossily as UTF-8 and lower-cased. The markers are
/// lower-cased as well, so a marker written with uppercase letters still
/// matches. Only the first occurrence of each marker is considered.
///
/// # Panics
///
/// Panics if either marker is missing, or if the first occurrence of `first`
/// is not strictly before the first occurrence of `second`.
fn assert_response_text_order(response: &[u8], first: &str, second: &str) {
    let text = String::from_utf8_lossy(response).to_ascii_lowercase();
    // Normalise the needles the same way as the haystack.
    let first_needle = first.to_ascii_lowercase();
    let second_needle = second.to_ascii_lowercase();

    let first_index = text
        .find(&first_needle)
        .unwrap_or_else(|| panic!("response must contain first marker `{}`: {}", first, text));
    let second_index = text
        .find(&second_needle)
        .unwrap_or_else(|| panic!("response must contain second marker `{}`: {}", second, text));
    assert!(
        first_index < second_index,
        "`{}` must appear before `{}` in response: {}",
        first,
        second,
        text
    );
}
/// Connects to `addr`, issues `GET /sse` with `Accept: text/event-stream` and
/// `Connection: close`, and returns every byte received.
///
/// Reading stops at the first of:
/// * the chunked-transfer finish frame (`0\r\n\r\n` on its own line),
/// * the peer closing the connection, or
/// * five seconds having elapsed since the request was written.
///
/// The finish frame is matched together with its preceding CRLF. Without that
/// anchor, `Content-Length: 10\r\n\r\n` (or any chunk-size line ending in `0`
/// followed by a blank line) would look like the terminator and the body would
/// be cut off.
///
/// # Panics
///
/// Panics if connecting, configuring timeouts, or writing the request fails,
/// or if a read fails with anything other than a timeout or an interruption.
fn read_sse_response(addr: SocketAddr) -> Vec<u8> {
    const CHUNKED_FINISH_FRAME: &[u8] = b"\r\n0\r\n\r\n";
    const OVERALL_DEADLINE: Duration = Duration::from_secs(5);

    let mut stream = TcpStream::connect(addr).expect("test client must connect to SSE server");
    stream
        .set_read_timeout(Some(Duration::from_millis(500)))
        .expect("test client must set read timeout");
    stream
        .set_write_timeout(Some(Duration::from_secs(2)))
        .expect("test client must set write timeout");
    let req = format!(
        "GET /sse HTTP/1.1\r\nHost: {}\r\nAccept: text/event-stream\r\nConnection: close\r\n\r\n",
        addr
    );
    stream
        .write_all(req.as_bytes())
        .expect("test client must write HTTP request");

    let started = Instant::now();
    let mut response = Vec::new();
    let mut buf = [0u8; 1024];
    while started.elapsed() < OVERALL_DEADLINE {
        match stream.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => {
                // A new match must include at least one new byte, so only the
                // tail that can straddle the previous read is rescanned.
                let scan_from = response
                    .len()
                    .saturating_sub(CHUNKED_FINISH_FRAME.len() - 1);
                response.extend_from_slice(&buf[..n]);
                let finished = response[scan_from..]
                    .windows(CHUNKED_FINISH_FRAME.len())
                    .any(|w| w == CHUNKED_FINISH_FRAME);
                if finished {
                    return response;
                }
            }
            // Read timeouts and interrupted reads are transient: retry until
            // the overall deadline expires.
            Err(e)
                if matches!(
                    e.kind(),
                    ErrorKind::WouldBlock | ErrorKind::TimedOut | ErrorKind::Interrupted
                ) =>
            {
                continue;
            }
            Err(e) => panic!("test client read failed: {:?}", e),
        }
    }
    response
}
/// A plain `GET /sse` against the single-event server yields a `200` chunked
/// `text/event-stream` response carrying the `evt-1` event and ending with the
/// chunked finish frame.
#[test]
fn sse_real_network_get_stream_receives_chunked_event() {
    const FINISH_FRAME: &[u8] = b"0\r\n\r\n";

    let _ = env_logger::builder().is_test(true).try_init();
    let _timer = startup_global_time_loop(10);
    let (listener, addr) = start_sse_server(SseNetworkScenario::SingleEvent);
    thread::sleep(Duration::from_millis(100));
    let response = read_sse_response(addr);
    // Shut the listener down before asserting.
    let shutdown = Error::new(
        ErrorKind::Interrupted,
        "close SSE real-network test listener",
    );
    listener.close(Err(shutdown));

    let raw = String::from_utf8_lossy(&response);
    let lowered = raw.to_ascii_lowercase();
    assert!(
        lowered.contains("http/1.1 200"),
        "response must contain HTTP 200 status, got: {}",
        raw
    );
    assert!(
        lowered.contains("content-type:text/event-stream; charset=utf-8"),
        "response must contain SSE content type, got: {}",
        raw
    );
    assert!(
        lowered.contains("transfer-encoding:chunked"),
        "response must contain chunked transfer encoding, got: {}",
        raw
    );
    assert!(
        lowered.contains("id: evt-1"),
        "response must contain SSE id field, got: {}",
        raw
    );
    assert!(
        lowered.contains("event: notice"),
        "response must contain SSE event field, got: {}",
        raw
    );
    assert!(
        lowered.contains("data: hello from pi_http"),
        "response must contain SSE data field, got: {}",
        raw
    );
    let has_finish_frame = response
        .windows(FINISH_FRAME.len())
        .any(|w| w == FINISH_FRAME);
    assert!(
        has_finish_frame,
        "response must contain chunked finish frame, got: {}",
        raw
    );
}
/// When the acceptor rejects, the client receives a `403` carrying the
/// rejection message, and the response is neither an event stream nor
/// terminated by a chunked finish frame.
#[test]
fn sse_real_network_acceptor_can_reject_open_request() {
    const FINISH_FRAME: &[u8] = b"0\r\n\r\n";

    let _ = env_logger::builder().is_test(true).try_init();
    let _timer = startup_global_time_loop(10);
    let (listener, addr) = start_sse_server(SseNetworkScenario::Reject);
    thread::sleep(Duration::from_millis(100));
    let response = read_sse_response(addr);
    let shutdown = Error::new(ErrorKind::Interrupted, "close SSE reject test listener");
    listener.close(Err(shutdown));

    let raw = String::from_utf8_lossy(&response);
    let lowered = raw.to_ascii_lowercase();
    assert!(
        lowered.contains("http/1.1 403"),
        "response must contain HTTP 403 status, got: {}",
        raw
    );
    assert!(
        lowered.contains("sse rejected by acceptor"),
        "response must contain acceptor rejection message, got: {}",
        raw
    );
    let is_event_stream = lowered.contains("content-type:text/event-stream");
    assert!(
        !is_event_stream,
        "reject response must not be SSE stream, got: {}",
        raw
    );
    let has_finish_frame = response
        .windows(FINISH_FRAME.len())
        .any(|w| w == FINISH_FRAME);
    assert!(
        !has_finish_frame,
        "reject response must not contain chunked finish frame, got: {}",
        raw
    );
}
/// When the port handler accepts, the client receives a `200` event stream
/// containing the `port session ready` event, and no `x-pi-http-sse-` header
/// appears in the response.
#[test]
fn sse_real_network_port_handshake_accepts_and_binds_key() {
    let _ = env_logger::builder().is_test(true).try_init();
    let _timer = startup_global_time_loop(10);
    let (listener, addr) = start_sse_port_server(SsePortNetworkScenario::Accept);
    thread::sleep(Duration::from_millis(100));
    let response = read_sse_response(addr);
    let shutdown = Error::new(
        ErrorKind::Interrupted,
        "close SSE port accept test listener",
    );
    listener.close(Err(shutdown));

    let raw = String::from_utf8_lossy(&response);
    let lowered = raw.to_ascii_lowercase();
    assert!(
        lowered.contains("http/1.1 200"),
        "port accept response must contain HTTP 200 status, got: {}",
        raw
    );
    assert!(
        lowered.contains("content-type:text/event-stream; charset=utf-8"),
        "port accept response must be SSE stream, got: {}",
        raw
    );
    assert!(
        lowered.contains("data: port session ready"),
        "port accept response must contain initialized SSE event, got: {}",
        raw
    );
    let leaks_handshake_header = lowered.contains("x-pi-http-sse-");
    assert!(
        !leaks_handshake_header,
        "internal SSE handshake headers must not leak, got: {}",
        raw
    );
}
/// When the port handler rejects, the client receives a `401` carrying the
/// reject reason, the response is not an event stream, and no
/// `x-pi-http-sse-` header appears in it.
#[test]
fn sse_real_network_port_handshake_rejects_open_request() {
    let _ = env_logger::builder().is_test(true).try_init();
    let _timer = startup_global_time_loop(10);
    let (listener, addr) = start_sse_port_server(SsePortNetworkScenario::Reject);
    thread::sleep(Duration::from_millis(100));
    let response = read_sse_response(addr);
    let shutdown = Error::new(
        ErrorKind::Interrupted,
        "close SSE port reject test listener",
    );
    listener.close(Err(shutdown));

    let raw = String::from_utf8_lossy(&response);
    let lowered = raw.to_ascii_lowercase();
    assert!(
        lowered.contains("http/1.1 401"),
        "port reject response must contain HTTP 401 status, got: {}",
        raw
    );
    assert!(
        lowered.contains("port rejected sse"),
        "port reject response must contain reject reason, got: {}",
        raw
    );
    let is_event_stream = lowered.contains("content-type:text/event-stream");
    assert!(
        !is_event_stream,
        "port reject response must not be SSE stream, got: {}",
        raw
    );
    let leaks_handshake_header = lowered.contains("x-pi-http-sse-");
    assert!(
        !leaks_handshake_header,
        "internal SSE handshake headers must not leak, got: {}",
        raw
    );
}
/// With the heartbeat scenario, the `200` event-stream response contains at
/// least one SSE comment frame (`:` followed by a blank line).
#[test]
fn sse_real_network_heartbeat_runtime_emits_comment() {
    let _ = env_logger::builder().is_test(true).try_init();
    let _timer = startup_global_time_loop(10);
    let (listener, addr) = start_sse_server(SseNetworkScenario::Heartbeat);
    thread::sleep(Duration::from_millis(100));
    let response = read_sse_response(addr);
    let shutdown = Error::new(ErrorKind::Interrupted, "close SSE heartbeat test listener");
    listener.close(Err(shutdown));

    let raw = String::from_utf8_lossy(&response);
    let lowered = raw.to_ascii_lowercase();
    assert!(
        lowered.contains("http/1.1 200"),
        "heartbeat response must contain HTTP 200 status, got: {}",
        raw
    );
    assert!(
        lowered.contains("content-type:text/event-stream; charset=utf-8"),
        "heartbeat response must be SSE stream, got: {}",
        raw
    );
    // Any text containing "\r\n:\n\n\r\n" also contains ":\n\n", so the bare
    // comment frame is the only needle needed.
    let has_comment_frame = lowered.contains(":\n\n");
    assert!(
        has_comment_frame,
        "heartbeat response must contain SSE comment frame, got: {}",
        raw
    );
}
/// Two events sent back-to-back from one thread arrive at the client in the
/// order they were sent.
#[test]
fn sse_real_network_same_thread_order_matches_call_order() {
    let _ = env_logger::builder().is_test(true).try_init();
    let _timer = startup_global_time_loop(10);
    let (listener, addr) = start_sse_server(SseNetworkScenario::SameThreadOrder);
    thread::sleep(Duration::from_millis(100));
    let response = read_sse_response(addr);
    let shutdown = Error::new(
        ErrorKind::Interrupted,
        "close SSE same-thread order test listener",
    );
    listener.close(Err(shutdown));

    // Check ids first, then payloads.
    let expected_order = [("id: evt-1", "id: evt-2"), ("data: first", "data: second")];
    for &(earlier, later) in &expected_order {
        assert_response_text_order(&response, earlier, later);
    }
}
/// Two events enqueued from different threads, with the second thread gated on
/// the first, arrive at the client in enqueue order.
#[test]
fn sse_real_network_cross_thread_controlled_enqueue_order_matches_output_order() {
    let _ = env_logger::builder().is_test(true).try_init();
    let _timer = startup_global_time_loop(10);
    let (listener, addr) = start_sse_server(SseNetworkScenario::CrossThreadControlledOrder);
    thread::sleep(Duration::from_millis(100));
    let response = read_sse_response(addr);
    let shutdown = Error::new(
        ErrorKind::Interrupted,
        "close SSE cross-thread order test listener",
    );
    listener.close(Err(shutdown));

    // Check ids first, then payloads.
    let expected_order = [("id: evt-1", "id: evt-2"), ("data: first", "data: second")];
    for &(earlier, later) in &expected_order {
        assert_response_text_order(&response, earlier, later);
    }
}