use crate::socket::NoiseSocket;
use crate::transport::{Transport, TransportEvent};
use log::{debug, info, warn};
use prost::Message;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use wacore::handshake::{
HandshakeError as CoreHandshakeError, HandshakeState, build_handshake_header,
};
use wacore::runtime::{Runtime, timeout as rt_timeout};
use wacore_binary::consts::{NOISE_START_PATTERN, WA_CONN_HEADER};
/// How long `do_handshake` waits for a transport event while awaiting the
/// server's handshake response (20 seconds).
const NOISE_HANDSHAKE_RESPONSE_TIMEOUT: Duration = Duration::from_millis(20_000);
#[derive(Debug, Error)]
pub enum HandshakeError {
#[error("Transport error: {0}")]
Transport(#[from] anyhow::Error),
#[error("Core handshake error: {0}")]
Core(#[from] CoreHandshakeError),
#[error("Timed out waiting for handshake response")]
Timeout,
#[error("Unexpected event during handshake: {0}")]
UnexpectedEvent(String),
}
type Result<T> = std::result::Result<T, HandshakeError>;
/// Runs the Noise handshake over `transport` and returns the encrypted socket.
///
/// Steps performed:
/// 1. Build a [`HandshakeState`] from the device's noise key and its encoded
///    client payload.
/// 2. Send the ClientHello, framed with the header from
///    [`build_handshake_header`] (which may carry edge-routing info).
/// 3. Wait on `transport_events` until one complete response frame is decoded.
/// 4. Send the ClientFinish, derive the write/read keys and wrap them in a
///    [`NoiseSocket`].
///
/// # Errors
///
/// - [`HandshakeError::Core`] if the handshake state machine rejects a step.
/// - [`HandshakeError::Transport`] if framing or sending a message fails.
/// - [`HandshakeError::Timeout`] if no transport event arrives within
///   `NOISE_HANDSHAKE_RESPONSE_TIMEOUT` while waiting for the response.
/// - [`HandshakeError::UnexpectedEvent`] if the transport reports a disconnect,
///   or the event channel is closed, before a full response frame arrives.
pub async fn do_handshake(
    runtime: Arc<dyn Runtime>,
    device: &crate::store::Device,
    transport: Arc<dyn Transport>,
    transport_events: &mut async_channel::Receiver<TransportEvent>,
) -> Result<Arc<NoiseSocket>> {
    let client_payload = device.core.get_client_payload().encode_to_vec();
    let mut handshake_state = HandshakeState::new(
        device.core.noise_key.clone(),
        client_payload,
        NOISE_START_PATTERN,
        &WA_CONN_HEADER,
    )?;
    let mut frame_decoder = wacore::framing::FrameDecoder::new();

    debug!("--> Sending ClientHello");
    let client_hello_bytes = handshake_state.build_client_hello()?;
    let (header, used_edge_routing) =
        build_handshake_header(device.core.edge_routing_info.as_deref());
    if used_edge_routing {
        debug!("Sending edge routing pre-intro for optimized reconnection");
    } else if device.core.edge_routing_info.is_some() {
        warn!("Edge routing info provided but not used (possibly too large)");
    }
    let framed = wacore::framing::encode_frame(&client_hello_bytes, Some(&header))
        .map_err(HandshakeError::Transport)?;
    transport.send(framed).await?;

    // The timeout bounds each individual wait for an event, not the response
    // as a whole: every received event restarts the clock.
    let resp_frame = loop {
        let event = match rt_timeout(
            &*runtime,
            NOISE_HANDSHAKE_RESPONSE_TIMEOUT,
            transport_events.recv(),
        )
        .await
        {
            Ok(Ok(event)) => event,
            // `recv` only fails when the channel is empty and closed; that is
            // not a timeout, so report it like the other transport-loss case.
            Ok(Err(_)) => {
                return Err(HandshakeError::UnexpectedEvent(
                    "Transport event channel closed during handshake".to_string(),
                ));
            }
            Err(_) => return Err(HandshakeError::Timeout),
        };

        match event {
            TransportEvent::DataReceived(data) => {
                // Data may arrive in pieces; keep feeding until a whole frame decodes.
                frame_decoder.feed(&data);
                if let Some(frame) = frame_decoder.decode_frame() {
                    break frame;
                }
            }
            TransportEvent::Connected => {}
            TransportEvent::Disconnected => {
                return Err(HandshakeError::UnexpectedEvent(
                    "Disconnected during handshake".to_string(),
                ));
            }
        }
    };
    // NOTE(review): any bytes still buffered in `frame_decoder` after the first
    // frame are dropped with it; this assumes the server sends nothing else
    // before it receives ClientFinish — confirm.

    debug!("<-- Received handshake response, building ClientFinish");
    let client_finish_bytes =
        handshake_state.read_server_hello_and_build_client_finish(&resp_frame)?;

    debug!("--> Sending ClientFinish");
    let framed = wacore::framing::encode_frame(&client_finish_bytes, None)
        .map_err(HandshakeError::Transport)?;
    transport.send(framed).await?;

    let (write_key, read_key) = handshake_state.finish()?;
    info!("Handshake complete, switching to encrypted communication");
    Ok(Arc::new(NoiseSocket::new(
        runtime, transport, write_key, read_key,
    )))
}