use crate::onion;
use crate::onion::circuit::Circuit;
use crate::onion::crypto::{self, EphemeralPrivateKey, SessionKey};
use crate::onion::protocol::{
CircuitOpaque, CircuitOpaqueBytes, TryFromBytesExt, TunnelRequest, VerifyKey,
};
use crate::onion::socket::{OnionSocket, OnionSocketError, SocketResult};
use crate::{Peer, PeerProvider, Result};
use anyhow::{anyhow, Context};
use bytes::Bytes;
use log::{trace, warn};
use std::fmt;
use std::mem;
use std::sync::Arc;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
const MAX_PEER_FAILURES: usize = 10;
pub type TunnelId = u32;
#[derive(Error, Debug)]
pub(crate) enum TunnelError {
#[error("Tunnel operation could not be completed")]
Incomplete,
#[error("Tunnel operation caused the tunnel to break")]
Broken(Option<OnionSocketError>),
}
impl From<OnionSocketError> for TunnelError {
fn from(e: OnionSocketError) -> Self {
match e {
OnionSocketError::Peer => TunnelError::Incomplete,
e => TunnelError::Broken(Some(e)),
}
}
}
pub(crate) type TunnelResult<T> = std::result::Result<T, TunnelError>;
#[derive(Copy, Clone, Debug, PartialEq)]
pub(crate) enum Event {
Switchover,
Destroy,
KeepAlive,
}
pub(crate) struct Tunnel {
pub(crate) id: TunnelId,
out_circuit: Circuit,
session_keys: Vec<SessionKey>,
}
impl Tunnel {
pub(crate) async fn init(id: TunnelId, peer: &Peer) -> Result<Self> {
trace!("Creating tunnel {} to peer {}", id, &peer.addr);
let (private_key, key) = crypto::generate_ephemeral_keypair();
let circuit_id = Circuit::random_id();
let stream = TcpStream::connect(peer.addr)
.await
.context("Could not connect to peer")?;
let mut socket = OnionSocket::new(stream);
let peer_key = socket
.initiate_handshake(circuit_id, key)
.await
.context("Handshake failed while initializing new tunnel")?;
let secret = Tunnel::derive_secret(&peer, private_key, peer_key)
.context("SessionKey derivation failed")?;
Ok(Self {
id,
out_circuit: Circuit::new(circuit_id, socket),
session_keys: vec![secret],
})
}
fn derive_secret(
peer: &&Peer,
private_key: EphemeralPrivateKey,
peer_key: VerifyKey,
) -> Result<SessionKey> {
let peer_key = peer_key
.verify(&peer.hostkey)
.context("Could not verify peer public key")?;
let secret = SessionKey::from_key_exchange(private_key, &peer_key)?;
Ok(secret)
}
pub(crate) fn len(&self) -> usize {
self.session_keys.len()
}
pub(crate) async fn extend(&mut self, peer: &Peer) -> TunnelResult<()> {
trace!("Extending tunnel {} to peer {}", self.id, &peer.addr);
let (private_key, key) = crypto::generate_ephemeral_keypair();
let peer_key = self
.out_circuit
.socket
.initiate_tunnel_handshake(self.out_circuit.id, peer.addr, key, &self.session_keys)
.await?;
if let Ok(secret) = Tunnel::derive_secret(&peer, private_key, peer_key) {
self.session_keys.insert(0, secret);
Ok(())
} else {
self.truncate(0)
.await
.map_err(|_| TunnelError::Broken(None))?;
Err(TunnelError::Incomplete)
}
}
pub(crate) async fn truncate(&mut self, n: usize) -> TunnelResult<()> {
if n >= self.session_keys.len() {
return Err(TunnelError::Incomplete);
}
self.out_circuit
.socket
.truncate_tunnel(self.out_circuit.id, &self.session_keys[n..])
.await?;
for _ in 0..n {
self.session_keys.remove(0);
}
Ok(())
}
pub(crate) async fn begin(&mut self) -> TunnelResult<()> {
self.out_circuit
.socket
.begin(self.out_circuit.id, self.id, &self.session_keys)
.await?;
Ok(())
}
pub(crate) async fn end(&mut self) -> TunnelResult<()> {
self.out_circuit
.socket
.end(self.out_circuit.id, self.id, &self.session_keys)
.await?;
Ok(())
}
pub(crate) async fn keep_alive(&mut self) -> TunnelResult<()> {
self.out_circuit
.socket
.send_keep_alive(self.out_circuit.id, &self.session_keys)
.await?;
Ok(())
}
pub(crate) async fn _truncate_to_length(&mut self, n_hops: usize) -> TunnelResult<()> {
let mut num_fails = 0;
while self.session_keys.len() > n_hops + 1 {
match self.truncate(1).await {
Err(TunnelError::Broken(e)) => {
return Err(TunnelError::Broken(e));
}
Err(TunnelError::Incomplete) => {
num_fails += 1;
if num_fails >= MAX_PEER_FAILURES {
return Err(TunnelError::Incomplete);
}
}
Ok(_) => {}
}
}
Ok(())
}
async fn unbuild(&mut self) {
self.teardown().await;
}
async fn teardown(&mut self) {
self.out_circuit.teardown_with_timeout().await;
}
}
pub fn random_id() -> TunnelId {
let mut id_buf = [0u8; 4];
crypto::fill_random(&mut id_buf);
u32::from_le_bytes(id_buf)
}
#[derive(Clone, Debug)]
pub enum Target {
Peer(Peer),
Random,
}
#[derive(Clone)]
pub(crate) struct TunnelBuilder {
tunnel_id: TunnelId,
dest: Target,
n_hops: usize,
peer_provider: PeerProvider,
}
impl TunnelBuilder {
pub(crate) fn new(
tunnel_id: TunnelId,
dest: Target,
n_hops: usize,
peer_provider: PeerProvider,
) -> Self {
TunnelBuilder {
tunnel_id,
dest,
n_hops,
peer_provider,
}
}
pub(crate) async fn build(&mut self) -> Result<Tunnel> {
let mut tunnel = None;
for _ in 0..MAX_PEER_FAILURES {
tunnel = match (tunnel.take(), &self.dest) {
(None, Target::Peer(peer)) if self.n_hops == 0 => {
Tunnel::init(self.tunnel_id, peer)
.await
.map_err(|e| warn!("Error while building tunnel: {:?}", e))
.ok()
}
(None, _) => {
let peer = self
.peer_provider
.random_peer()
.await
.context(anyhow!("Failed to get random peer"))?;
Tunnel::init(self.tunnel_id, &peer)
.await
.map_err(|e| warn!("Error while building tunnel: {:?}", e))
.ok()
}
(Some(mut tunnel), Target::Peer(peer)) if tunnel.len() == self.n_hops => {
match tunnel.extend(peer).await {
Err(TunnelError::Broken(e)) => {
warn!("Error while building tunnel: {:?}", e);
tunnel.teardown().await;
None
}
Err(TunnelError::Incomplete) => Some(tunnel),
Ok(_) => Some(tunnel),
}
}
(Some(mut tunnel), _) if tunnel.len() <= self.n_hops => {
let peer = self
.peer_provider
.random_peer()
.await
.context(anyhow!("Failed to get random peer"))?;
match tunnel.extend(&peer).await {
Err(TunnelError::Broken(e)) => {
warn!("Error while building tunnel: {:?}", e);
tunnel.teardown().await;
None
}
Err(TunnelError::Incomplete) => Some(tunnel),
Ok(_) => Some(tunnel),
}
}
(Some(tunnel), _) => return Ok(tunnel),
}
}
Err(anyhow!("failed to build tunnel"))
}
}
pub(crate) struct TunnelHandler {
tunnel: Tunnel,
next_tunnel: Arc<Mutex<Option<Tunnel>>>,
state: State,
events: broadcast::Receiver<Event>,
builder: TunnelBuilder,
}
pub(crate) enum State {
Building {
ready: oneshot::Sender<Result<onion::Tunnel>>,
},
Ready {
data_tx: mpsc::Sender<Bytes>,
data_rx: mpsc::UnboundedReceiver<Bytes>,
},
Destroying,
Destroyed,
}
impl TunnelHandler {
pub(crate) fn new(
first_tunnel: Tunnel,
tunnel_builder: TunnelBuilder,
events: broadcast::Receiver<Event>,
ready: oneshot::Sender<Result<onion::Tunnel>>,
) -> Self {
TunnelHandler {
tunnel: first_tunnel,
next_tunnel: Arc::new(Mutex::new(None)),
state: State::Building { ready },
events,
builder: tunnel_builder,
}
}
pub(crate) async fn handle(&mut self) {
trace!(
"Starting TunnelHandler for tunnel {:?}",
self.builder.tunnel_id
);
if let Err(e) = self.try_handle().await {
warn!("Error in TunnelHandler: {}", e);
self.state = State::Destroyed;
self.tunnel.teardown().await;
}
}
async fn try_handle(&mut self) -> Result<()> {
loop {
match &mut self.state {
State::Building { .. } | State::Destroying => {
tokio::select! {
Ok(evt) = self.events.recv() => {
self.handle_event(evt).await?;
}
}
}
State::Ready { data_rx, .. } => {
tokio::select! {
data = data_rx.recv() => {
self.handle_data(data).await?;
}
msg = self.tunnel.out_circuit.accept_opaque() => {
self.handle_tunnel_message(msg).await?;
}
Ok(evt) = self.events.recv() => {
self.handle_event(evt).await?;
}
}
}
State::Destroyed => return Ok(()),
}
}
}
async fn handle_tunnel_message(
&mut self,
msg: SocketResult<CircuitOpaque<CircuitOpaqueBytes>>,
) -> Result<()> {
let mut msg = msg?;
msg.decrypt(self.tunnel.session_keys.iter().rev())?;
let tunnel_msg = TunnelRequest::read_with_digest_from(&mut msg.payload.bytes);
match tunnel_msg {
Ok(TunnelRequest::Data(tunnel_id, data)) if tunnel_id == self.tunnel.id => {
if let State::Ready { data_tx, .. } = &mut self.state {
let _ = data_tx.send(data).await; }
Ok(())
}
Ok(TunnelRequest::End(_tunnel_id)) => {
Err(anyhow!("Tunnel broke due to unexpected End"))
}
_ => {
Err(anyhow!(
"Tunnel broke due to invalid request or broken digest"
))
}
}
}
async fn handle_data(&mut self, data: Option<Bytes>) -> Result<()> {
debug_assert!(matches!(&self.state, State::Ready { .. }));
match data {
Some(data) => {
let circuit_id = self.tunnel.out_circuit.id;
let tunnel_id = self.tunnel.id;
self.tunnel
.out_circuit
.socket
.send_data(circuit_id, tunnel_id, data, &self.tunnel.session_keys)
.await?;
}
None => self.state = State::Destroying,
}
Ok(())
}
async fn handle_event(&mut self, evt: Event) -> Result<()> {
trace!(
"TunnelHandler: handling event {:?} (state = {:?})",
evt,
self.state
);
let mut state = State::Destroyed; std::mem::swap(&mut self.state, &mut state);
self.state = match (evt, state) {
(Event::Switchover, State::Building { ready }) => {
self.tunnel.begin().await?;
let (tunnel, data_tx, data_rx) = onion::Tunnel::new(self.tunnel.id, true);
let _ = ready.send(Ok(tunnel)); self.spawn_next_tunnel_task();
State::Ready { data_tx, data_rx }
}
(Event::Switchover, State::Ready { data_tx, data_rx }) => {
let mut new_tunnel = self
.next_tunnel
.lock()
.await
.take()
.ok_or_else(|| anyhow!("Switchover failed: no next tunnel"))?;
mem::swap(&mut self.tunnel, &mut new_tunnel);
let mut old_tunnel = new_tunnel;
self.tunnel.begin().await?;
old_tunnel.end().await?;
self.spawn_next_tunnel_task();
tokio::spawn(async move {
old_tunnel.unbuild().await;
});
State::Ready { data_tx, data_rx }
}
(Event::Switchover, State::Destroying) => {
self.tunnel.end().await?;
self.tunnel.unbuild().await;
if let Some(mut next_tunnel) = self.next_tunnel.lock().await.take() {
next_tunnel.unbuild().await;
}
State::Destroyed
}
(Event::Destroy, State::Ready { .. }) => State::Destroying,
(Event::KeepAlive, State::Destroyed) => State::Destroyed, (Event::KeepAlive, state) => {
self.tunnel.keep_alive().await?;
if let Some(next_tunnel) = self.next_tunnel.lock().await.as_mut() {
next_tunnel.keep_alive().await?;
}
state
} _ => return Err(anyhow!("Illegal TunnelHandler state")),
};
Ok(())
}
fn spawn_next_tunnel_task(&self) {
tokio::spawn({
let next_tunnel = self.next_tunnel.clone();
let mut builder = self.builder.clone();
async move {
match builder.build().await {
Ok(new_tunnel) => {
next_tunnel.lock().await.replace(new_tunnel);
}
Err(e) => warn!("Rebuilding of a tunnel failed: {}", e),
};
}
});
}
}
impl fmt::Debug for Tunnel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Tunnel")
.field("id", &self.id)
.field("out_circuit", &self.out_circuit)
.field("len", &self.len())
.finish()
}
}
impl fmt::Debug for TunnelBuilder {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TunnelBuilder")
.field("tunnel_id", &self.tunnel_id)
.field("dest", &self.dest)
.field("n_hops", &self.n_hops)
.finish()
}
}
impl fmt::Debug for TunnelHandler {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TunnelHandler")
.field("tunnel", &self.tunnel)
.field("state", &self.state)
.field("builder", &self.builder)
.finish()
}
}
impl fmt::Debug for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
State::Building { .. } => f.debug_struct("Building").finish(),
State::Ready { .. } => f.debug_struct("Ready").finish(),
State::Destroying => f.debug_struct("Destroying").finish(),
State::Destroyed => f.debug_struct("Destroyed").finish(),
}
}
}