use std::sync::Arc;
use std::time::Duration;
use http::{Request, Uri};
use openwire_core::websocket::{
Role, SharedWebSocketEngine, WebSocketChannel, WebSocketEngineConfig, WebSocketError,
WebSocketHandshake,
};
use openwire_core::{BoxConnection, CallContext, RequestBody, WireError};
use crate::auth::AuthAttemptState;
use crate::client::WebSocketCall;
use crate::connection::Address;
use crate::proxy::{ProxyChoice, ProxySelector, SelectedProxy};
use crate::transport::{connect_route_plan, ProxyConnectDeps};
use crate::websocket::handshake::{
validate_handshake_response, ValidatedHandshake, WebSocketRequestMarker,
};
use crate::websocket::native::NativeEngine;
use crate::websocket::public::{WebSocket, WebSocketReceiver, WebSocketSender};
use crate::websocket::writer::{
spawn_session, websocket_error_as_wire_error, HeartbeatConfig, SessionConfig,
};
/// Handshake timeout used when the WebSocket call does not configure one.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);
/// Close timeout handed to the session when the call does not configure one.
const DEFAULT_CLOSE_TIMEOUT: Duration = Duration::from_secs(10);
/// Largest accepted frame (16 MiB) when the call does not configure a limit.
const DEFAULT_MAX_FRAME_SIZE: usize = 16 << 20;
/// Largest accepted reassembled message (16 MiB) when the call does not configure a limit.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 << 20;
/// Capacity of the outgoing send queue when the call does not configure one.
const DEFAULT_SEND_QUEUE_SIZE: usize = 32;
pub(crate) async fn execute(call: WebSocketCall<'_>) -> Result<WebSocket, WebSocketError> {
let WebSocketCall {
client,
mut request,
handshake_timeout,
close_timeout,
max_frame_size,
max_message_size,
send_queue_size,
ping_interval,
pong_timeout,
subprotocols,
deliver_control_frames,
engine,
} = call;
let handshake_timeout = handshake_timeout.unwrap_or(DEFAULT_HANDSHAKE_TIMEOUT);
let close_timeout = close_timeout.unwrap_or(DEFAULT_CLOSE_TIMEOUT);
let max_frame_size = max_frame_size.unwrap_or(DEFAULT_MAX_FRAME_SIZE);
let max_message_size = max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
let send_queue_size = send_queue_size.unwrap_or(DEFAULT_SEND_QUEUE_SIZE);
let heartbeat = validate_runtime_config(send_queue_size, ping_interval, pong_timeout)
.map_err(WebSocketError::Io)?;
let engine: SharedWebSocketEngine = engine.unwrap_or_else(|| Arc::new(NativeEngine::new()));
request
.extensions_mut()
.insert(WebSocketRequestMarker::new(subprotocols.clone()));
crate::bridge::normalize_request(&mut request).map_err(WebSocketError::Io)?;
let expected_accept = request
.extensions()
.get::<WebSocketRequestMarker>()
.map(|marker| marker.expected_accept.clone())
.ok_or_else(|| {
WebSocketError::Io(WireError::internal(
"WebSocketRequestMarker missing after bridge normalization",
std::io::Error::other("missing marker"),
))
})?;
let ctx = CallContext::from_factory(
client.event_listener_factory(),
&request,
Some(handshake_timeout),
);
ctx.listener().call_start(&ctx, &request);
let connect = async {
let address = build_address(client, request.uri())?;
let route_plan = client
.ws_connector()
.route_plan(ctx.clone(), &address)
.await?;
let deps: ProxyConnectDeps = client
.ws_connector()
.proxy_connect_deps(client.ws_connector().connect_timeout);
let io = connect_route_plan(
ctx.clone(),
request.uri().clone(),
route_plan,
AuthAttemptState {
total_attempt: 1,
retry_count: 0,
redirect_count: 0,
auth_count: 0,
},
deps,
)
.await?;
Ok::<BoxConnection, WireError>(io)
};
let io = match tokio::time::timeout(handshake_timeout, connect).await {
Ok(Ok(io)) => io,
Ok(Err(error)) => return Err(fail_before_open(&ctx, WebSocketError::Io(error))),
Err(_) => {
return Err(fail_before_open(
&ctx,
WebSocketError::Timeout(openwire_core::websocket::TimeoutKind::Handshake),
))
}
};
let handshake_future = run_handshake(
io,
request,
expected_accept,
subprotocols,
engine,
max_frame_size,
max_message_size,
);
let (response, channel, validated) =
match tokio::time::timeout(handshake_timeout, handshake_future).await {
Ok(Ok(triple)) => triple,
Ok(Err(error)) => return Err(fail_before_open(&ctx, error)),
Err(_) => {
return Err(fail_before_open(
&ctx,
WebSocketError::Timeout(openwire_core::websocket::TimeoutKind::Handshake),
))
}
};
let handshake = WebSocketHandshake::new(
response.status(),
response.headers().clone(),
validated.subprotocol,
validated.extensions,
);
ctx.listener().websocket_open(&ctx, &handshake);
let session = spawn_session(
channel,
SessionConfig {
queue_size: send_queue_size,
deliver_control_frames,
close_timeout,
heartbeat,
ctx: Some(ctx.clone()),
listener: Some(ctx.listener().clone()),
},
);
let sender = WebSocketSender::new(session.sender_tx);
let receiver = WebSocketReceiver {
rx: session.receiver_rx,
};
Ok(WebSocket {
sender,
receiver,
handshake,
})
}
/// Validates the send-queue and heartbeat settings of a WebSocket call and derives the
/// heartbeat configuration, if any.
///
/// Returns `Ok(None)` when `ping_interval` is unset; `pong_timeout` is not inspected in that
/// case. When pings are enabled and no `pong_timeout` is given, it defaults to twice the ping
/// interval.
///
/// # Errors
///
/// Returns an invalid-request [`WireError`] when `send_queue_size` is zero, when
/// `ping_interval` or an explicit `pong_timeout` is zero, or when doubling the ping interval
/// for the default pong timeout overflows.
fn validate_runtime_config(
    send_queue_size: usize,
    ping_interval: Option<Duration>,
    pong_timeout: Option<Duration>,
) -> Result<Option<HeartbeatConfig>, WireError> {
    if send_queue_size == 0 {
        return Err(invalid_runtime_config(
            "WebSocket send_queue_size must be greater than 0",
        ));
    }

    match ping_interval {
        None => Ok(None),
        Some(period) if period.is_zero() => Err(invalid_runtime_config(
            "WebSocket ping_interval must be greater than 0",
        )),
        Some(period) => {
            let effective_pong_timeout = match pong_timeout {
                None => period.checked_mul(2).ok_or_else(|| {
                    invalid_runtime_config(
                        "WebSocket pong_timeout default would overflow for the configured ping_interval",
                    )
                })?,
                Some(explicit) if explicit.is_zero() => {
                    return Err(invalid_runtime_config(
                        "WebSocket pong_timeout must be greater than 0 when ping_interval is enabled",
                    ));
                }
                Some(explicit) => explicit,
            };
            Ok(Some(HeartbeatConfig {
                interval: period,
                pong_timeout: effective_pong_timeout,
            }))
        }
    }
}
/// Builds the invalid-request [`WireError`] returned when a WebSocket runtime setting is
/// rejected; `message` describes which setting was invalid.
fn invalid_runtime_config(message: &'static str) -> WireError {
    WireError::invalid_request(message)
}
fn fail_before_open(ctx: &CallContext, error: WebSocketError) -> WebSocketError {
ctx.listener().websocket_failed(ctx, &error);
let wire_error = websocket_error_as_wire_error(&error);
ctx.listener().call_failed(ctx, &wire_error);
error
}
/// Performs the upgrade handshake on `io` and turns the upgraded stream into a
/// [`WebSocketChannel`] driven by `engine`.
///
/// The request is sent and the response read via `bind_websocket_handshake`. The response is
/// then validated against `expected_accept` and `offered_subprotocols`, and the transport must
/// have produced an upgraded connection. Finally `engine` upgrades that connection in the
/// client role, using the negotiated subprotocol/extensions and the given frame/message limits.
///
/// # Errors
///
/// - [`WebSocketError::Io`] when the handshake exchange itself fails.
/// - A handshake error (carrying the response status) when validation fails or no upgraded
///   connection is available.
/// - Any error returned by `engine.upgrade`.
async fn run_handshake(
    io: BoxConnection,
    request: Request<RequestBody>,
    expected_accept: String,
    offered_subprotocols: Vec<String>,
    engine: SharedWebSocketEngine,
    max_frame_size: usize,
    max_message_size: usize,
) -> Result<(http::Response<()>, WebSocketChannel, ValidatedHandshake), WebSocketError> {
    let (response, maybe_upgraded) =
        crate::transport::protocol::bind_websocket_handshake(io, request)
            .await
            .map_err(WebSocketError::Io)?;
    let status = response.status();

    let validated =
        validate_handshake_response(&response, &expected_accept, &offered_subprotocols)
            .map_err(|reason| WebSocketError::handshake(reason, Some(status)))?;

    let Some(upgraded) = maybe_upgraded else {
        return Err(WebSocketError::handshake(
            openwire_core::websocket::HandshakeFailure::Other(
                "server returned 101 without an upgradable connection".into(),
            ),
            Some(status),
        ));
    };

    let engine_config = WebSocketEngineConfig {
        role: Role::Client,
        subprotocol: validated.subprotocol.clone(),
        extensions: validated.extensions.clone(),
        max_frame_size,
        max_message_size,
    };
    let channel = engine
        .upgrade(
            crate::transport::protocol::upgraded_into_box_connection(upgraded),
            engine_config,
        )
        .await?;

    Ok((response, channel, validated))
}
/// Builds the connection [`Address`] for `uri`, attaching the first proxy returned by the
/// client's WebSocket proxy selector.
///
/// `ProxyChoice::Direct` entries are skipped; when the selection contains no proxy at all, the
/// address is built without one.
///
/// # Errors
///
/// Propagates failures from the proxy selector and from [`Address::from_uri`].
fn build_address(client: &crate::Client, uri: &Uri) -> Result<Address, WireError> {
    let choices = client.ws_proxy_selector().select(uri)?;
    let first_proxy = choices
        .iter()
        .filter_map(|choice| match choice {
            ProxyChoice::Proxy(proxy) => Some(SelectedProxy::from_proxy(proxy)),
            ProxyChoice::Direct => None,
        })
        .next();
    Address::from_uri(uri, first_proxy.as_ref())
}