use crate::client::ClientManager;
use crate::config::local_commit;
use crate::discovery::{PrimaryCache, normalize_mdns_name};
use local_channel::mpsc::{Receiver, Sender, channel};
use mousehop_ipc::{ClientHandle, DEFAULT_PORT};
use mousehop_proto::{
MAX_CLIPBOARD_SIZE, MAX_EVENT_SIZE, PROTOCOL_MAGIC, ProtoEvent, decode_clipboard_event,
encode_clipboard_event,
};
use std::{
cell::{Cell, RefCell},
collections::{HashMap, HashSet},
hash::{DefaultHasher, Hash, Hasher},
io,
net::{IpAddr, SocketAddr},
rc::Rc,
sync::Arc,
time::{Duration, Instant},
};
use thiserror::Error;
use tokio::{
net::UdpSocket,
sync::Mutex,
task::{JoinSet, spawn_local},
};
use webrtc_dtls::{
config::{Config, ExtendedMasterSecretType},
conn::DTLSConn,
crypto::Certificate,
};
use webrtc_util::Conn;
/// Errors that can occur while establishing or using a connection to a client.
#[derive(Debug, Error)]
pub(crate) enum MousehopConnectionError {
/// I/O failure, e.g. while binding or connecting the local UDP socket.
#[error(transparent)]
Bind(#[from] io::Error),
/// Error reported by the DTLS layer while setting up the session.
#[error(transparent)]
Dtls(#[from] webrtc_dtls::Error),
/// Error reported by the `webrtc_util` connection abstraction.
#[error(transparent)]
Webrtc(#[from] webrtc_util::Error),
/// No established connection to the client is available.
#[error("not connected")]
NotConnected,
/// A connection exists, but the client manager does not report the target as alive.
#[error("emulation is disabled on the target device")]
TargetEmulationDisabled,
/// The DTLS handshake did not complete within the connection timeout.
#[error("Connection timed out")]
Timeout,
}
/// Maximum time a single address gets to complete the DTLS handshake.
const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
/// Delay before the first retry after a failed connection attempt.
const INITIAL_RETRY_BACKOFF: Duration = Duration::from_secs(1);
/// Upper bound for the retry delay, which doubles after every failure.
const MAX_RETRY_BACKOFF: Duration = Duration::from_secs(30);
/// Per-client bookkeeping used to throttle reconnection attempts.
struct RetryState {
/// Earliest instant at which another connection attempt is allowed.
next_attempt_at: Instant,
/// Delay that will be applied after the next recorded failure.
backoff: Duration,
/// Fingerprint (see `signature_of`) of the addresses and primary hint the
/// last failure was recorded for; a changed fingerprint resets the backoff.
signature: u64,
}
/// Fingerprints a client's candidate addresses together with its optional
/// primary-address hint.
///
/// The addresses are sorted before hashing, so the iteration order of the
/// set does not influence the result.
fn signature_of(ips: &HashSet<IpAddr>, primary: Option<IpAddr>) -> u64 {
    let mut ordered = ips.iter().copied().collect::<Vec<IpAddr>>();
    // A set holds no duplicates, so an unstable sort yields the same order.
    ordered.sort_unstable();

    let mut state = DefaultHasher::new();
    ordered.hash(&mut state);
    primary.hash(&mut state);
    state.finish()
}
/// Records a failed connection attempt for `handle` and schedules the next one.
///
/// The next attempt is allowed once the client's current backoff has elapsed;
/// the backoff itself is then doubled, capped at `MAX_RETRY_BACKOFF`. The
/// stored signature is refreshed from `ips` and `primary`.
fn record_retry_failure(
    retry_state: &Rc<RefCell<HashMap<ClientHandle, RetryState>>>,
    handle: ClientHandle,
    ips: &HashSet<IpAddr>,
    primary: Option<IpAddr>,
) {
    let fingerprint = signature_of(ips, primary);
    let mut states = retry_state.borrow_mut();
    // A client failing for the first time starts out with the initial backoff.
    let state = states.entry(handle).or_insert_with(|| RetryState {
        next_attempt_at: Instant::now(),
        backoff: INITIAL_RETRY_BACKOFF,
        signature: fingerprint,
    });

    let delay = state.backoff;
    state.signature = fingerprint;
    state.next_attempt_at = Instant::now() + delay;
    state.backoff = std::cmp::min(delay * 2, MAX_RETRY_BACKOFF);
}
/// Establishes a DTLS client session with `addr`.
///
/// A fresh UDP socket is bound to an ephemeral port on the wildcard address of
/// the same address family as `addr`, connected to `addr`, and then used for
/// the DTLS handshake, which must finish within `DEFAULT_CONNECTION_TIMEOUT`.
///
/// On success the connection is returned together with `addr`; on failure the
/// error is paired with `addr` so the caller can tell which candidate failed.
async fn connect(
    addr: SocketAddr,
    cert: Certificate,
) -> Result<(Arc<dyn Conn + Sync + Send>, SocketAddr), (SocketAddr, MousehopConnectionError)> {
    log::info!("connecting to {addr} ...");
    // A socket bound to an IPv4 address cannot be connected to an IPv6 peer
    // (and vice versa), so bind the wildcard address of the target's family.
    let local_addr = if addr.is_ipv6() {
        "[::]:0"
    } else {
        "0.0.0.0:0"
    };
    let conn = Arc::new(
        UdpSocket::bind(local_addr)
            .await
            .map_err(|e| (addr, e.into()))?,
    );
    conn.connect(addr).await.map_err(|e| (addr, e.into()))?;
    let config = Config {
        certificates: vec![cert],
        server_name: "ignored".to_owned(),
        insecure_skip_verify: true,
        extended_master_secret: ExtendedMasterSecretType::Require,
        ..Default::default()
    };
    let timeout = tokio::time::sleep(DEFAULT_CONNECTION_TIMEOUT);
    // Whichever finishes first wins: the handshake or the timeout.
    tokio::select! {
        _ = timeout => Err((addr, MousehopConnectionError::Timeout)),
        result = DTLSConn::new(conn, config, true, None) => match result {
            Ok(dtls_conn) => Ok((Arc::new(dtls_conn), addr)),
            Err(e) => Err((addr, e.into())),
        }
    }
}
/// How long the preferred address may try to connect on its own before the
/// remaining candidate addresses start competing with it.
const PREFERRED_ADDR_HEAD_START: Duration = Duration::from_millis(200);

/// Connects to the first address that completes a DTLS handshake.
///
/// If `preferred` is given it is tried first and gets a head start of
/// `PREFERRED_ADDR_HEAD_START`. Afterwards all other entries of `addrs` are
/// attempted concurrently and the first success wins. Individual failures are
/// logged; `MousehopConnectionError::NotConnected` is returned once every
/// attempt has failed (or when there was nothing to try).
///
/// # Panics
///
/// Panics if one of the spawned connection tasks cannot be joined.
async fn connect_any(
    addrs: &[SocketAddr],
    preferred: Option<SocketAddr>,
    cert: Certificate,
) -> Result<(Arc<dyn Conn + Send + Sync>, SocketAddr), MousehopConnectionError> {
    let mut joinset = JoinSet::new();
    if let Some(p) = preferred {
        joinset.spawn_local(connect(p, cert.clone()));
        tokio::select! {
            // Head start used up: let the other candidates join the race while
            // the preferred attempt keeps running in the join set.
            _ = tokio::time::sleep(PREFERRED_ADDR_HEAD_START) => {}
            Some(r) = joinset.join_next() => match r.expect("join error") {
                Ok(conn) => return Ok(conn),
                // The preferred address has already failed, so waiting out the
                // rest of the head start would only delay the fallbacks.
                Err((a, e)) => log::warn!("failed to connect to {a}: `{e}`"),
            },
        }
    }
    for &addr in addrs {
        // The preferred address has already been attempted above.
        if Some(addr) == preferred {
            continue;
        }
        joinset.spawn_local(connect(addr, cert.clone()));
    }
    // Returning early drops the join set, which aborts the remaining attempts.
    while let Some(r) = joinset.join_next().await {
        match r.expect("join error") {
            Ok(conn) => return Ok(conn),
            Err((a, e)) => log::warn!("failed to connect to {a}: `{e}`"),
        }
    }
    Err(MousehopConnectionError::NotConnected)
}
/// Manages the outgoing DTLS connections to clients and the events received
/// over them.
pub(crate) struct MousehopConnection {
/// Certificate presented in outgoing DTLS handshakes.
cert: Certificate,
/// Source of per-client information (addresses, port, hostname, liveness).
client_manager: ClientManager,
/// Established connections, keyed by the remote address.
conns: Rc<Mutex<HashMap<SocketAddr, Arc<dyn Conn + Send + Sync>>>>,
/// Clients for which a connection attempt is currently in flight.
connecting: Rc<Mutex<HashSet<ClientHandle>>>,
/// Receiving half of the channel carrying events from the receive loops.
recv_rx: Receiver<(ClientHandle, ProtoEvent)>,
/// Sending half handed to every receive loop.
recv_tx: Sender<(ClientHandle, ProtoEvent)>,
/// Addresses that answered with a `Pong` since the ping task last checked.
ping_response: Rc<RefCell<HashSet<SocketAddr>>>,
/// Cache queried by normalized mDNS host name for a client's preferred address.
primary_hints: PrimaryCache,
/// Backoff state of clients whose last connection attempt failed.
retry_state: Rc<RefCell<HashMap<ClientHandle, RetryState>>>,
}
impl MousehopConnection {
    /// Creates a connection manager without any open connections.
    pub(crate) fn new(
        cert: Certificate,
        client_manager: ClientManager,
        primary_hints: PrimaryCache,
    ) -> Self {
        let (recv_tx, recv_rx) = channel();
        Self {
            cert,
            client_manager,
            conns: Default::default(),
            connecting: Default::default(),
            recv_rx,
            recv_tx,
            ping_response: Default::default(),
            primary_hints,
            retry_state: Default::default(),
        }
    }

    /// Waits for the next event received from any connected client.
    ///
    /// # Panics
    ///
    /// Panics if the event channel has been closed.
    pub(crate) async fn recv(&mut self) -> (ClientHandle, ProtoEvent) {
        self.recv_rx.recv().await.expect("channel closed")
    }

    /// Returns a second handle that shares all connection state with `self`
    /// and is meant for sending only.
    ///
    /// The clone's receiving half belongs to a fresh channel whose sender is
    /// dropped immediately; received events keep flowing to the original.
    pub(crate) fn sender_clone(&self) -> Self {
        let (_, dead_rx) = channel();
        Self {
            cert: self.cert.clone(),
            client_manager: self.client_manager.clone(),
            conns: self.conns.clone(),
            connecting: self.connecting.clone(),
            recv_rx: dead_rx,
            recv_tx: self.recv_tx.clone(),
            ping_response: self.ping_response.clone(),
            primary_hints: self.primary_hints.clone(),
            retry_state: self.retry_state.clone(),
        }
    }

    /// Sends `event` to the client identified by `handle`.
    ///
    /// If no connection to the client is established, a background connection
    /// attempt is started (unless one is already running or the retry backoff
    /// has not elapsed yet) and `NotConnected` is returned.
    ///
    /// # Errors
    ///
    /// * `NotConnected` - there is no established connection to the client.
    /// * `TargetEmulationDisabled` - a connection exists but the client is not
    ///   reported as alive.
    ///
    /// A clipboard event that cannot be encoded is dropped with a warning, and
    /// a failed send tears the connection down; both cases still return `Ok`.
    pub(crate) async fn send(
        &self,
        event: ProtoEvent,
        handle: ClientHandle,
    ) -> Result<(), MousehopConnectionError> {
        // The textual form is only needed for the trace line below; skip the
        // allocation on this hot path unless trace logging is enabled.
        let event_display = log::log_enabled!(log::Level::Trace).then(|| event.to_string());
        // Clipboard events are variable-length and use their own encoder ...
        let bytes_owned: Option<Vec<u8>> = match &event {
            ProtoEvent::Clipboard { .. } => match encode_clipboard_event(&event) {
                Ok(v) => Some(v),
                Err(e) => {
                    log::warn!("dropping oversize clipboard event for client {handle}: {e}");
                    return Ok(());
                }
            },
            _ => None,
        };
        // ... every other event goes through the fixed-size encoding.
        let bytes_fixed: ([u8; MAX_EVENT_SIZE], usize) = if bytes_owned.is_some() {
            ([0u8; MAX_EVENT_SIZE], 0)
        } else {
            event.into()
        };
        let buf: &[u8] = match bytes_owned.as_deref() {
            Some(v) => v,
            None => &bytes_fixed.0[..bytes_fixed.1],
        };

        if let Some(addr) = self.client_manager.active_addr(handle) {
            // Clone the connection out of the map so the lock is not held
            // while sending.
            let conn = {
                let conns = self.conns.lock().await;
                conns.get(&addr).cloned()
            };
            if let Some(conn) = conn {
                if !self.client_manager.alive(handle) {
                    return Err(MousehopConnectionError::TargetEmulationDisabled);
                }
                match conn.send(buf).await {
                    Ok(_) => {
                        // Only report the event as sent when it actually was.
                        if let Some(shown) = &event_display {
                            log::trace!("{shown} >->->->->- {addr}");
                        }
                    }
                    Err(e) => {
                        log::warn!("client {handle} failed to send: {e}");
                        disconnect(&self.client_manager, handle, addr, &self.conns).await;
                    }
                }
                return Ok(());
            }
        }

        // Not connected: kick off a connection attempt in the background.
        let mut connecting = self.connecting.lock().await;
        if !connecting.contains(&handle) && self.should_attempt(handle) {
            connecting.insert(handle);
            spawn_local(connect_to_handle(
                self.client_manager.clone(),
                self.cert.clone(),
                handle,
                self.conns.clone(),
                self.connecting.clone(),
                self.recv_tx.clone(),
                self.ping_response.clone(),
                self.primary_hints.clone(),
                self.retry_state.clone(),
            ));
        }
        Err(MousehopConnectionError::NotConnected)
    }

    /// Decides whether a new connection attempt for `handle` may start now.
    ///
    /// An attempt is allowed when no failure is recorded for the client, when
    /// its address set or primary hint changed since the last failure (which
    /// also resets the backoff), or when the backoff period has elapsed.
    fn should_attempt(&self, handle: ClientHandle) -> bool {
        let ips = self.client_manager.get_ips(handle).unwrap_or_default();
        let primary = self.client_manager.get_hostname(handle).and_then(|h| {
            let key = normalize_mdns_name(&h);
            self.primary_hints.borrow().get(&key).copied()
        });
        let sig = signature_of(&ips, primary);
        let mut state = self.retry_state.borrow_mut();
        match state.get_mut(&handle) {
            None => true,
            Some(s) if s.signature != sig => {
                s.signature = sig;
                s.next_attempt_at = Instant::now();
                s.backoff = INITIAL_RETRY_BACKOFF;
                true
            }
            Some(s) => Instant::now() >= s.next_attempt_at,
        }
    }
}
/// Tries to connect to the client identified by `handle` and, on success,
/// spawns the tasks servicing the new connection (hello handshake, ping/pong
/// keep-alive and receive loop).
///
/// On every path the handle is removed from `connecting` again. A failed
/// attempt is recorded in `retry_state`; a successful one clears that entry.
///
/// # Errors
///
/// Returns `NotConnected` when the client has no addresses to try and
/// otherwise the error reported by `connect_any`.
// All arguments are shared state that has to be moved into the spawned task.
#[allow(clippy::too_many_arguments)]
async fn connect_to_handle(
    client_manager: ClientManager,
    cert: Certificate,
    handle: ClientHandle,
    conns: Rc<Mutex<HashMap<SocketAddr, Arc<dyn Conn + Send + Sync>>>>,
    connecting: Rc<Mutex<HashSet<ClientHandle>>>,
    tx: Sender<(ClientHandle, ProtoEvent)>,
    ping_response: Rc<RefCell<HashSet<SocketAddr>>>,
    primary_hints: PrimaryCache,
    retry_state: Rc<RefCell<HashMap<ClientHandle, RetryState>>>,
) -> Result<(), MousehopConnectionError> {
    log::info!("client {handle} connecting ...");

    let Some(ips_set) = client_manager.get_ips(handle) else {
        connecting.lock().await.remove(&handle);
        return Err(MousehopConnectionError::NotConnected);
    };

    let port = client_manager.get_port(handle).unwrap_or(DEFAULT_PORT);
    let addrs: Vec<SocketAddr> = ips_set
        .iter()
        .map(|&ip| SocketAddr::new(ip, port))
        .collect();
    // Look up the address hint stored for this client's host name, if any.
    let primary_ip = client_manager.get_hostname(handle).and_then(|hostname| {
        let key = normalize_mdns_name(&hostname);
        primary_hints.borrow().get(&key).copied()
    });
    let preferred = primary_ip.map(|ip| SocketAddr::new(ip, port));
    log::info!("client ({handle}) connecting ... (ips: {addrs:?}, preferred: {preferred:?})");

    // Nothing to try at all counts as a failed attempt.
    if addrs.is_empty() && preferred.is_none() {
        record_retry_failure(&retry_state, handle, &ips_set, primary_ip);
        connecting.lock().await.remove(&handle);
        return Err(MousehopConnectionError::NotConnected);
    }

    let (conn, addr) = match connect_any(&addrs, preferred, cert).await {
        Ok(established) => established,
        Err(e) => {
            record_retry_failure(&retry_state, handle, &ips_set, primary_ip);
            connecting.lock().await.remove(&handle);
            return Err(e);
        }
    };

    log::info!("client ({handle}) connected @ {addr}");
    client_manager.set_active_addr(handle, Some(addr));
    conns.lock().await.insert(addr, conn.clone());
    connecting.lock().await.remove(&handle);
    retry_state.borrow_mut().remove(&handle);

    // Shared flag the receive loop sets once a valid Hello has arrived.
    let hello_ok = Rc::new(Cell::new(false));
    spawn_local(hello_handshake(addr, conn.clone(), hello_ok.clone()));
    spawn_local(ping_pong(addr, conn.clone(), ping_response.clone()));
    spawn_local(receive_loop(
        client_manager,
        handle,
        addr,
        conn,
        conns,
        tx,
        ping_response,
        hello_ok,
    ));
    Ok(())
}
/// Number of times the `Hello` message is sent before giving up on the peer.
const HELLO_MAX_ATTEMPTS: u32 = 8;
/// Pause between two consecutive `Hello` messages.
const HELLO_RETRY_INTERVAL: Duration = Duration::from_millis(750);

/// Repeatedly announces this instance to the peer at `addr` until `hello_ok`
/// is set or `HELLO_MAX_ATTEMPTS` messages have been sent.
///
/// If the flag is still unset after the last attempt, the connection is
/// closed. Send errors are only logged and do not end the retries.
async fn hello_handshake(
    addr: SocketAddr,
    conn: Arc<dyn Conn + Send + Sync>,
    hello_ok: Rc<Cell<bool>>,
) {
    let (frame, frame_len): ([u8; MAX_EVENT_SIZE], usize) =
        ProtoEvent::hello(local_commit()).into();

    let mut attempts_left = HELLO_MAX_ATTEMPTS;
    while attempts_left > 0 {
        // The flag is flipped elsewhere once the peer's Hello was accepted.
        if hello_ok.get() {
            return;
        }
        match conn.send(&frame[..frame_len]).await {
            Ok(_) => {}
            Err(e) => log::debug!("hello send to {addr} failed: {e}"),
        }
        tokio::time::sleep(HELLO_RETRY_INTERVAL).await;
        attempts_left -= 1;
    }

    // The answer may have arrived during the final pause.
    if hello_ok.get() {
        return;
    }
    log::warn!(
        "refusing {addr}: peer did not complete the mousehop handshake \
         (no valid Hello) — closing connection"
    );
    let _ = conn.close().await;
}
async fn ping_pong(
addr: SocketAddr,
conn: Arc<dyn Conn + Send + Sync>,
ping_response: Rc<RefCell<HashSet<SocketAddr>>>,
) {
loop {
let (buf, len) = ProtoEvent::Ping.into();
for _ in 0..4 {
if let Err(e) = conn.send(&buf[..len]).await {
log::warn!("{addr}: send error `{e}`, closing connection");
let _ = conn.close().await;
break;
}
log::trace!("PING >->->->->- {addr}");
tokio::time::sleep(Duration::from_millis(500)).await;
}
if !ping_response.borrow_mut().remove(&addr) {
log::warn!("{addr} did not respond, closing connection");
let _ = conn.close().await;
return;
}
}
}
#[allow(clippy::too_many_arguments)]
async fn receive_loop(
client_manager: ClientManager,
handle: ClientHandle,
addr: SocketAddr,
conn: Arc<dyn Conn + Send + Sync>,
conns: Rc<Mutex<HashMap<SocketAddr, Arc<dyn Conn + Send + Sync>>>>,
tx: Sender<(ClientHandle, ProtoEvent)>,
ping_response: Rc<RefCell<HashSet<SocketAddr>>>,
hello_ok: Rc<Cell<bool>>,
) {
let mut buf = [0u8; MAX_CLIPBOARD_SIZE];
while let Ok(n) = conn.recv(&mut buf).await {
if n == 0 {
continue;
}
let datagram = &buf[..n];
let event = match decode_proto_datagram(datagram) {
Some(event) => event,
None => {
log::debug!("ignoring undecodable {n}-byte event from {addr}");
continue;
}
};
log::trace!("{addr} <==<==<== {event}");
match event {
ProtoEvent::Pong(b) => {
client_manager.set_active_addr(handle, Some(addr));
client_manager.set_alive(handle, b);
ping_response.borrow_mut().insert(addr);
}
ProtoEvent::Hello { magic, commit } => {
if magic != PROTOCOL_MAGIC {
log::warn!(
"refusing {addr}: peer presented a foreign protocol \
handshake (not mousehop) — closing connection"
);
let _ = conn.close().await;
break;
}
hello_ok.set(true);
client_manager.set_peer_commit(handle, Some(commit));
tx.send((handle, ProtoEvent::hello(commit)))
.expect("channel closed");
}
event => tx.send((handle, event)).expect("channel closed"),
}
}
log::debug!("{addr}: receive loop ended");
disconnect(&client_manager, handle, addr, &conns).await;
}
fn decode_proto_datagram(bytes: &[u8]) -> Option<ProtoEvent> {
use mousehop_proto::EventType;
let tag = *bytes.first()?;
if tag == EventType::Clipboard as u8 {
return decode_clipboard_event(bytes).ok();
}
let mut fixed = [0u8; MAX_EVENT_SIZE];
let copy_len = bytes.len().min(MAX_EVENT_SIZE);
fixed[..copy_len].copy_from_slice(&bytes[..copy_len]);
fixed.try_into().ok()
}
/// Forgets the connection to `addr` and clears the client's active address
/// and peer commit, then logs the addresses that remain connected.
async fn disconnect(
    client_manager: &ClientManager,
    handle: ClientHandle,
    addr: SocketAddr,
    conns: &Mutex<HashMap<SocketAddr, Arc<dyn Conn + Send + Sync>>>,
) {
    log::warn!("client ({handle}) @ {addr} connection closed");
    {
        let mut open = conns.lock().await;
        open.remove(&addr);
    }
    client_manager.set_active_addr(handle, None);
    client_manager.set_peer_commit(handle, None);

    let active = {
        let open = conns.lock().await;
        open.keys().copied().collect::<Vec<SocketAddr>>()
    };
    log::info!("active connections: {active:?}");
}