use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use tokio_tungstenite::client_async;
use tokio_tungstenite::tungstenite::handshake::client::{Request as WsRequest, generate_key};
use crate::config::DaemonServer;
use crate::error::{Error, ProtocolError};
use crate::protocol::{IdAllocator, Request, Response};
use crate::transport::dispatcher::Dispatcher;
use crate::transport::socket::WsTransport;
use crate::transport::tls;
/// Everything produced by a successful [`connect`] call: the running
/// dispatcher plus the session details reported by the server.
pub(crate) struct ConnectedDispatcher {
/// Dispatcher spawned over the WebSocket transport.
pub(crate) dispatcher: Dispatcher,
/// `version` field taken from the server's `Response::Connected` reply.
pub(crate) version: String,
/// `job` field taken from the server's `Response::Connected` reply.
pub(crate) initial_job: String,
/// Allocator that issued the id of the initial connect request.
/// NOTE(review): presumably reused by the caller for later request ids — confirm.
pub(crate) ids: IdAllocator,
/// Counter shared with the dispatcher: `connect` passes a clone of this
/// `Arc` to `Dispatcher::spawn`; it starts at zero.
/// NOTE(review): presumably the number of outstanding requests — confirm.
pub(crate) in_flight: Arc<AtomicU32>,
}
/// Connects to `server` and performs the initial login exchange.
///
/// Steps, in order:
/// 1. open a TLS stream via `tls::connect`;
/// 2. upgrade it to a WebSocket on the `/db2` path;
/// 3. spawn a [`Dispatcher`] over the resulting transport;
/// 4. send a `Request::Connect` carrying the configured user and password
///    and wait for the reply.
///
/// On success the returned [`ConnectedDispatcher`] holds the dispatcher, the
/// `version` and `job` reported by the server, the id allocator used for the
/// connect request, and the in-flight counter shared with the dispatcher.
///
/// # Errors
///
/// * Any error returned by `tls::connect` or by sending the connect request
///   is propagated unchanged.
/// * [`Error::Internal`] if the WebSocket request cannot be built or the
///   upgrade handshake fails.
/// * [`Error::Auth`] if the server answers the connect request with
///   `Response::Error`; the message is the server's text, or
///   `"connect rejected"` when none was supplied.
/// * [`ProtocolError::CorrelationMismatch`] if the server answers with any
///   other response kind.
pub(crate) async fn connect(server: &DaemonServer) -> crate::Result<ConnectedDispatcher> {
    let tls_stream = tls::connect(server).await?;

    // The request URI always carries an explicit port, so the `Host` header
    // must carry the same authority (RFC 9110 §7.2, RFC 6455 §4.1). Sending
    // the bare host name is wrong whenever the port is not the scheme default.
    let authority = format!("{}:{}", server.host, server.port);
    let url = format!("wss://{authority}/db2");
    let ws_request = WsRequest::builder()
        .uri(&url)
        .header("Host", authority.as_str())
        .header("Upgrade", "websocket")
        .header("Connection", "Upgrade")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", generate_key())
        .body(())
        .map_err(|e| Error::Internal(format!("malformed ws request: {e}")))?;

    // The HTTP upgrade response itself is not needed once the handshake succeeds.
    let (ws_stream, _http_response) = client_async(ws_request, tls_stream)
        .await
        .map_err(|e| Error::Internal(format!("websocket upgrade failed: {e}")))?;

    let transport = WsTransport::new(ws_stream);
    let in_flight = Arc::new(AtomicU32::new(0));
    let dispatcher = Dispatcher::spawn(Box::pin(transport), Arc::clone(&in_flight));
    let handle = dispatcher.handle();

    let ids = IdAllocator::new();
    let connect_id = ids.next();
    let request = Request::Connect {
        // Cloned because the id is still needed for the mismatch error below.
        id: connect_id.clone(),
        user: server.user.clone(),
        password: server.password.expose().to_string(),
    };

    let response = handle.send(request).await?;
    let (version, initial_job) = match response {
        Response::Connected { version, job, .. } => (version, job),
        Response::Error(e) => {
            return Err(Error::Auth(
                e.error.unwrap_or_else(|| "connect rejected".into()),
            ));
        }
        // Any other reply to a connect request is a protocol violation.
        other => {
            return Err(Error::from(ProtocolError::CorrelationMismatch {
                expected: connect_id,
                got: format!("{other:?}"),
            }));
        }
    };

    Ok(ConnectedDispatcher {
        dispatcher,
        version,
        initial_job,
        ids,
        in_flight,
    })
}