use digest::Digest;
use futures::SinkExt;
use futures::io::{AsyncRead, AsyncWrite};
use std::sync::Arc;
use std::time::SystemTime;
use tracing::{debug, instrument, trace};
use safelog::{MaybeSensitive, Sensitive};
use tor_cell::chancell::msg;
use tor_linkspec::{ChannelMethod, OwnedChanTarget};
use tor_rtcompat::{CoarseTimeProvider, SleepProvider, StreamOps};
use crate::ClockSkew;
use crate::Result;
use crate::channel::handshake::{
ChannelBaseHandshake, ChannelInitiatorHandshake, UnverifiedChannel, UnverifiedInitiatorChannel,
VerifiedChannel, unauthenticated_clock_skew,
};
use crate::channel::{Channel, ChannelFrame, ChannelType, Reactor, UniqId, new_frame};
use crate::memquota::ChannelAccount;
use crate::peer::{PeerAddr, PeerInfo};
/// Client-side state for a channel handshake that we initiate.
///
/// Created with [`ClientInitiatorHandshake::new`]; running
/// [`connect`](ClientInitiatorHandshake::connect) consumes it and yields an
/// [`UnverifiedClientChannel`].
pub struct ClientInitiatorHandshake<T, S>
where
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
    S: CoarseTimeProvider + SleepProvider,
{
    /// Time/sleep provider; handed on to the resulting channel.
    sleep_prov: S,
    /// Memory-quota account; handed on to the resulting channel.
    memquota: ChannelAccount,
    /// The underlying transport, wrapped by `new_frame` as a
    /// `ChannelType::ClientInitiator` frame.
    framed_tls: ChannelFrame<T>,
    /// How the peer is being reached, if known; logged when the handshake
    /// starts and handed on to the resulting channel.
    target_method: Option<ChannelMethod>,
    /// Identifier for this handshake, reported as `stream_id` in log events.
    unique_id: UniqId,
}
/// Gives the shared handshake logic access to this handshake's transport and id.
impl<
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
    S: CoarseTimeProvider + SleepProvider,
> ChannelBaseHandshake<T> for ClientInitiatorHandshake<T, S>
{
    /// Borrow the framed transport mutably.
    fn framed_tls(&mut self) -> &mut ChannelFrame<T> {
        let Self { framed_tls, .. } = self;
        framed_tls
    }

    /// Borrow this handshake's unique identifier.
    fn unique_id(&self) -> &UniqId {
        let Self { unique_id, .. } = self;
        unique_id
    }
}
/// Marks `ClientInitiatorHandshake` as an initiator-side handshake.
///
/// The body is empty, so every item of the trait keeps its default.
impl<
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
    S: CoarseTimeProvider + SleepProvider,
> ChannelInitiatorHandshake<T> for ClientInitiatorHandshake<T, S>
{
}
impl<
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
    S: CoarseTimeProvider + SleepProvider,
> ClientInitiatorHandshake<T, S>
{
    /// Create a new client-side handshake over the transport `tls`.
    ///
    /// `tls` is wrapped with `new_frame` as a [`ChannelType::ClientInitiator`]
    /// frame, and a fresh [`UniqId`] is allocated for this handshake.
    /// `target_method`, `sleep_prov` and `memquota` are stored and later handed
    /// on to the channel produced by [`connect`](Self::connect).
    pub(crate) fn new(
        tls: T,
        target_method: Option<ChannelMethod>,
        sleep_prov: S,
        memquota: ChannelAccount,
    ) -> Self {
        Self {
            framed_tls: new_frame(tls, ChannelType::ClientInitiator),
            target_method,
            unique_id: UniqId::new(),
            sleep_prov,
            memquota,
        }
    }

    /// Run the client side of the handshake, stopping before peer verification.
    ///
    /// The steps are:
    /// 1. Send our VERSIONS cell.
    /// 2. Receive the peer's VERSIONS cell and apply the negotiated link protocol.
    /// 3. Read the responder's CERTS and NETINFO cells.
    ///
    /// It then computes an *unauthenticated* clock-skew estimate. The inputs are
    /// the NETINFO cell, when that cell arrived, and when our VERSIONS cell was
    /// flushed.
    ///
    /// `now_fn` is passed to `send_versions_cell`. Presumably it supplies the
    /// wall-clock time used for the flush timestamp; confirm against that method.
    ///
    /// # Errors
    ///
    /// Returns an error if any step of the cell exchange fails or if the
    /// negotiated link protocol cannot be applied.
    #[instrument(skip_all, level = "trace")]
    pub async fn connect<F>(mut self, now_fn: F) -> Result<UnverifiedClientChannel<T, S>>
    where
        F: FnOnce() -> SystemTime,
    {
        match &self.target_method {
            Some(method) => debug!(
                stream_id = %self.unique_id,
                "starting Tor handshake with {:?}",
                method
            ),
            None => debug!(stream_id = %self.unique_id, "starting Tor handshake"),
        }

        let (versions_flushed_at, versions_flushed_wallclock) =
            self.send_versions_cell(now_fn).await?;
        let link_protocol = self.recv_versions_cell().await?;
        self.set_link_protocol(link_protocol)?;

        // NOTE(review): the meaning of this flag is defined by
        // `recv_cells_from_responder`; presumably "we are not a relay" — confirm.
        let (_, certs_cell, (netinfo_cell, netinfo_rcvd_at), _) =
            self.recv_cells_from_responder(false).await?;

        let clock_skew = unauthenticated_clock_skew(
            &netinfo_cell,
            netinfo_rcvd_at,
            versions_flushed_at,
            versions_flushed_wallclock,
        );
        trace!(stream_id = %self.unique_id, "received handshake, ready to verify.");

        // `self` is consumed by this method, so its remaining fields are moved
        // into the new channel state instead of being cloned or `take()`n.
        Ok(UnverifiedClientChannel {
            inner: UnverifiedInitiatorChannel {
                inner: UnverifiedChannel {
                    link_protocol,
                    framed_tls: self.framed_tls,
                    clock_skew,
                    target_method: self.target_method,
                    unique_id: self.unique_id,
                    sleep_prov: self.sleep_prov,
                    memquota: self.memquota,
                },
                certs_cell,
            },
            netinfo_cell,
        })
    }
}
/// A client channel whose handshake cells have been exchanged but whose peer
/// has not yet been verified.
///
/// Produced by `ClientInitiatorHandshake::connect`. Call
/// [`verify`](UnverifiedClientChannel::verify) to continue.
pub struct UnverifiedClientChannel<T, S>
where
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
    S: CoarseTimeProvider + SleepProvider,
{
    /// Shared initiator-side unverified state, including the received CERTS cell.
    inner: UnverifiedInitiatorChannel<T, S>,
    /// The NETINFO cell received from the responder.
    netinfo_cell: msg::Netinfo,
}
impl<
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
    S: CoarseTimeProvider + SleepProvider,
> UnverifiedClientChannel<T, S>
{
    /// Verify the peer, producing a [`VerifiedClientChannel`].
    ///
    /// `peer_cert` is hashed with SHA-256. The digest, `peer` and `now` are then
    /// passed to the inner initiator channel's `verify`. The responder's NETINFO
    /// cell is carried over unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner verification.
    #[instrument(skip_all, level = "trace")]
    pub fn verify(
        self,
        peer: &OwnedChanTarget,
        peer_cert: &[u8],
        now: Option<std::time::SystemTime>,
    ) -> Result<VerifiedClientChannel<T, S>> {
        let Self {
            inner: unverified,
            netinfo_cell,
        } = self;
        let cert_sha256 = tor_llcrypto::d::Sha256::digest(peer_cert).into();
        let verified = unverified.verify(peer, cert_sha256, now)?;
        Ok(VerifiedClientChannel {
            inner: verified,
            netinfo_cell,
        })
    }

    /// Return the clock-skew estimate recorded during the handshake.
    ///
    /// This estimate was taken before the peer was verified.
    pub fn clock_skew(&self) -> ClockSkew {
        let Self { inner, .. } = self;
        inner.inner.clock_skew
    }

    /// Return the negotiated link protocol version (test-only accessor).
    #[cfg(test)]
    pub(crate) fn link_protocol(&self) -> u16 {
        let Self { inner, .. } = self;
        inner.inner.link_protocol
    }
}
/// A client channel whose peer has been verified and which is ready to finish.
///
/// Produced by `UnverifiedClientChannel::verify`. Call
/// [`finish`](VerifiedClientChannel::finish) to obtain the channel and its reactor.
pub struct VerifiedClientChannel<T, S>
where
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
    S: CoarseTimeProvider + SleepProvider,
{
    /// Shared verified-channel state.
    inner: VerifiedChannel<T, S>,
    /// The NETINFO cell received from the responder, still needed by `finish`.
    netinfo_cell: msg::Netinfo,
}
impl<
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
    S: CoarseTimeProvider + SleepProvider,
> VerifiedClientChannel<T, S>
{
    /// Complete the handshake and build the channel.
    ///
    /// Sends our own NETINFO cell, built from `peer_addr`. It then hands three
    /// things to the inner `finish`: the responder's NETINFO cell, an empty
    /// slice, and sensitive-wrapped peer info (the peer address plus the
    /// verified relay identities).
    ///
    /// # Errors
    ///
    /// Fails if sending our NETINFO cell fails, or if the inner `finish` fails.
    #[instrument(skip_all, level = "trace")]
    pub async fn finish(
        self,
        peer_addr: Sensitive<PeerAddr>,
    ) -> Result<(Arc<Channel>, Reactor<S>)> {
        let Self {
            inner: mut verified,
            netinfo_cell: their_netinfo,
        } = self;

        // Our own NETINFO cell, derived from the peer's address.
        let our_netinfo = msg::Netinfo::from_client(peer_addr.netinfo_addr());
        trace!(stream_id = %verified.unique_id, "Sending netinfo cell.");
        verified.framed_tls.send(our_netinfo.into()).await?;

        let addr = peer_addr.into_inner();
        let relay_ids = verified.relay_ids().clone();
        let peer_info = MaybeSensitive::sensitive(PeerInfo::new(addr, relay_ids));

        verified.finish(&their_netinfo, &[], peer_info).await
    }
}