use std::fmt;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// A stream wrapper that replays a buffered `prefix` to readers before
/// handing reads over to the wrapped `inner` stream.
///
/// Writes are not affected by the prefix; they are forwarded to `inner`.
pub struct PrefixedStream<S> {
/// Bytes that are served to readers before anything is read from `inner`.
prefix: Vec<u8>,
/// Index into `prefix` of the next byte that has not been handed out yet.
prefix_pos: usize,
/// The wrapped stream that reads fall through to once `prefix` is exhausted.
inner: S,
}
/// Debug output shows only how many prefix bytes are still unread; the prefix
/// contents and the inner stream are omitted, so `S` does not need to be `Debug`.
impl<S> fmt::Debug for PrefixedStream<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Number of buffered bytes that have not been handed to a reader yet.
        let prefix_remaining = self.prefix.len() - self.prefix_pos;

        let mut builder = f.debug_struct("PrefixedStream");
        builder.field("prefix_remaining", &prefix_remaining);
        // `..` in the output signals that fields were left out on purpose.
        builder.finish_non_exhaustive()
    }
}
impl<S> PrefixedStream<S> {
    /// Wraps `inner` so that readers see the bytes of `prefix` first.
    ///
    /// The read cursor starts at the beginning of `prefix`; an empty `prefix`
    /// makes reads go straight to `inner`.
    pub fn new(prefix: Vec<u8>, inner: S) -> Self {
        // Nothing has been replayed yet.
        let prefix_pos = 0;
        Self { prefix, prefix_pos, inner }
    }
}
/// Reads drain the buffered prefix first and only then reach the inner stream.
impl<S: AsyncRead + Unpin> AsyncRead for PrefixedStream<S> {
    /// Copies as much of the unread prefix as fits into `buf` and returns
    /// immediately; once the prefix is exhausted the call is forwarded to the
    /// inner stream unchanged.
    ///
    /// A single call never mixes prefix bytes with bytes from the inner stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let stream = self.get_mut();

        // How many prefix bytes are still waiting to be replayed.
        let unread = stream.prefix.len().saturating_sub(stream.prefix_pos);
        if unread == 0 {
            return Pin::new(&mut stream.inner).poll_read(cx, buf);
        }

        // Never copy more than the caller's buffer can take.
        let count = unread.min(buf.remaining());
        let start = stream.prefix_pos;
        let end = start + count;
        buf.put_slice(&stream.prefix[start..end]);
        stream.prefix_pos = end;
        Poll::Ready(Ok(()))
    }
}
/// Writes bypass the prefix entirely: every operation is forwarded to the
/// inner stream.
///
/// The vectored-write methods are forwarded as well. Without them the trait's
/// default implementations would be used, which hide the inner stream's
/// vectored-write support behind this wrapper.
impl<S: AsyncWrite + Unpin> AsyncWrite for PrefixedStream<S> {
    /// Forwards the write to the inner stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
    }

    /// Forwards the flush to the inner stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }

    /// Forwards the shutdown to the inner stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }

    /// Forwards the vectored write to the inner stream so that a stream with
    /// native scatter/gather support keeps using it.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
    }

    /// Reports whatever the inner stream reports about vectored-write support.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}
/// Sniffs the first bytes of `stream` to decide whether it starts with what
/// looks like a plaintext handshake, and if so extracts the info hash.
///
/// Only the first byte is inspected: it must be `0x13`. The following 47
/// bytes are read without further validation, and bytes 28..48 of the stream
/// are taken as the info hash.
///
/// # Returns
///
/// * `Ok(Some((info_hash, stream)))` - the first byte matched. The returned
///   [`PrefixedStream`] replays the 48 bytes consumed here before continuing
///   with the rest of `stream`.
/// * `Ok(None)` - the first byte was something else. `stream` is dropped and
///   the byte that was read is not recoverable by the caller.
///
/// # Errors
///
/// Propagates any I/O error from reading, including the error `read_exact`
/// reports when the stream ends before the requested bytes arrive.
pub async fn identify_plaintext_connection<S>(
    mut stream: S,
) -> io::Result<Option<(irontide_core::Id20, PrefixedStream<S>)>>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    use tokio::io::AsyncReadExt;

    /// Value the very first byte must have for the stream to be accepted.
    const LEAD_BYTE: u8 = 0x13;
    /// Number of bytes consumed from the stream on the accepting path.
    const HEAD_LEN: usize = 48;
    /// Offset of the 20-byte info hash within the consumed bytes.
    const HASH_START: usize = 28;

    let mut head = [0u8; HEAD_LEN];
    let (lead, tail) = head.split_at_mut(1);

    // Read the discriminating byte on its own so that a non-matching stream
    // is rejected without waiting for further bytes.
    stream.read_exact(lead).await?;
    if lead[0] != LEAD_BYTE {
        return Ok(None);
    }
    stream.read_exact(tail).await?;

    // The hash occupies the last 20 of the 48 consumed bytes.
    let mut hash_bytes = [0u8; 20];
    hash_bytes.copy_from_slice(&head[HASH_START..]);
    let info_hash = irontide_core::Id20::from(hash_bytes);

    // Hand back everything that was consumed so a later reader sees the
    // stream from its first byte.
    Ok(Some((info_hash, PrefixedStream::new(head.to_vec(), stream))))
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Returns one end of an in-memory duplex pipe whose other end is written
    /// with `data` by a background task and then dropped.
    fn tokio_test_stream(data: &[u8]) -> impl AsyncRead + AsyncWrite + Unpin {
        let (mut writer, reader) = tokio::io::duplex(data.len() + 64);
        let data = data.to_vec();
        tokio::spawn(async move {
            writer.write_all(&data).await.unwrap();
        });
        reader
    }

    /// The first read returns exactly the prefix, the next one the inner data.
    #[tokio::test]
    async fn prefixed_stream_yields_prefix_then_inner() {
        let inner = tokio_test_stream(b"world");
        let mut ps = PrefixedStream::new(b"hello ".to_vec(), inner);
        let mut buf = vec![0u8; 11];
        let n = ps.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello ");
        let n2 = ps.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n2], b"world");
    }

    /// With an empty prefix, reads go straight to the inner stream.
    #[tokio::test]
    async fn prefixed_stream_empty_prefix_passthrough() {
        let inner = tokio_test_stream(b"direct");
        let mut ps = PrefixedStream::new(Vec::new(), inner);
        let mut buf = vec![0u8; 10];
        let n = ps.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"direct");
    }

    /// Writes reach the inner stream and never include the prefix.
    #[tokio::test]
    async fn prefixed_stream_write_delegates() {
        let (client, mut server) = tokio::io::duplex(64);
        let mut ps = PrefixedStream::new(b"prefix".to_vec(), client);
        ps.write_all(b"data").await.unwrap();
        ps.flush().await.unwrap();
        let mut buf = vec![0u8; 4];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"data");
    }

    /// Runs the production `identify_plaintext_connection` (not a copy of its
    /// logic) against a full 68-byte handshake and checks that the consumed
    /// bytes are replayed in front of the remaining peer id.
    #[tokio::test]
    async fn identify_plaintext_bt_handshake() {
        let info_hash = [0xABu8; 20];
        let peer_id = [0xCDu8; 20];

        let mut preamble = Vec::with_capacity(68);
        preamble.push(0x13);
        preamble.extend_from_slice(b"BitTorrent protocol");
        preamble.extend_from_slice(&[0u8; 8]);
        preamble.extend_from_slice(&info_hash);
        preamble.extend_from_slice(&peer_id);

        let (client, mut server) = tokio::io::duplex(256);
        // The task hands `server` back so the pipe stays open until the end
        // of the test.
        let write_handle = tokio::spawn(async move {
            server.write_all(&preamble).await.unwrap();
            server
        });

        let (hash, mut ps) = identify_plaintext_connection(client)
            .await
            .unwrap()
            .expect("a 0x13 lead byte must be identified as a handshake");
        assert_eq!(hash.as_bytes(), &info_hash);

        let mut replay = vec![0u8; 68];
        ps.read_exact(&mut replay).await.unwrap();
        assert_eq!(replay[0], 0x13);
        assert_eq!(&replay[28..48], &info_hash);
        assert_eq!(&replay[48..68], &peer_id);

        let _ = write_handle.await;
    }

    /// Any lead byte other than 0x13 yields `Ok(None)` after a single byte.
    #[tokio::test]
    async fn identify_plaintext_rejects_other_lead_byte() {
        let (client, mut server) = tokio::io::duplex(64);
        server.write_all(&[0x00]).await.unwrap();

        let result = identify_plaintext_connection(client).await.unwrap();
        assert!(result.is_none());
    }

    /// A stream that ends after the lead byte but before all 48 bytes have
    /// arrived surfaces as an I/O error rather than `Ok(None)`.
    #[tokio::test]
    async fn identify_plaintext_truncated_handshake_is_error() {
        let (client, mut server) = tokio::io::duplex(64);
        server.write_all(&[0x13, b'B', b'i', b't']).await.unwrap();
        // Closing the writing end makes the reader hit end-of-stream.
        drop(server);

        let err = identify_plaintext_connection(client)
            .await
            .err()
            .expect("a truncated handshake must be reported as an error");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }
}