use snow::{Builder, TransportState};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
#[derive(Debug, thiserror::Error)]
pub enum ScallopError {
#[error("failed to init builder")]
InitFailed(#[source] snow::Error),
#[error("transport error")]
TransportError(#[from] tokio::io::Error),
#[error("noise error")]
NoiseError(#[from] snow::Error),
#[error("protocol error")]
ProtocolError(String),
#[error("auth error")]
AuthError(String),
}
/// Phase of the record-reading state machine driven by `ScallopStream::poll_read`.
#[derive(Debug, PartialEq)]
enum ReadMode {
/// Collecting the 2-byte big-endian length prefix of the next record.
Length,
/// Collecting the encrypted record body whose length the prefix announced.
Body,
/// Handing the decrypted bytes `rbuf[read_start..read_end]` to callers.
Read,
}
/// A raw 32-byte key; in this file it carries a peer's Noise static public key.
pub type Key = [u8; 32];
/// Result of looking up a peer's static key in a [`ScallopAuthStore`].
#[derive(PartialEq)]
pub enum ContainsResponse<State: PartialEq> {
/// Key is known and accepted; carries the state stored in `ScallopStream::state`.
Approved(State),
/// Key is unknown; the handshake asks the peer for an attestation and calls `verify`.
NotFound,
/// Key is refused; the handshake aborts with a `ProtocolError`.
Rejected,
}
/// Policy hook deciding whether a peer's static key is trusted.
///
/// During a handshake `contains` is consulted first. If it returns
/// [`ContainsResponse::NotFound`], the peer is asked for an attestation, which is
/// then passed to `verify`.
pub trait ScallopAuthStore {
/// Per-peer state attached to the resulting stream (`ScallopStream::state`).
type State: PartialEq;
/// Looks up a peer key. The default treats every key as unknown, so every peer
/// must present an attestation.
fn contains(&mut self, _key: &Key) -> ContainsResponse<Self::State> {
ContainsResponse::NotFound
}
/// Checks `attestation` from the peer whose static key is `key`.
/// `Some(state)` accepts the peer; `None` makes the handshake fail.
fn verify(&mut self, attestation: &[u8], key: Key) -> Option<Self::State>;
}
/// Lets a store be lent out by `&mut` instead of moved into the handshake.
///
/// Both methods, including `contains`, are forwarded so the inner store's own
/// implementation is used rather than the trait default.
impl<T: ScallopAuthStore> ScallopAuthStore for &mut T {
    type State = T::State;

    fn contains(&mut self, key: &Key) -> ContainsResponse<Self::State> {
        T::contains(&mut **self, key)
    }

    fn verify(&mut self, attestation: &[u8], key: Key) -> Option<Self::State> {
        T::verify(&mut **self, attestation, key)
    }
}
/// Placeholder store type for callers that pass `None` as the auth store.
///
/// The handshake functions only call into a store that is `Some`, so these
/// methods are not expected to run; they panic via `unimplemented!` if they do.
impl ScallopAuthStore for () {
type State = ();
fn contains(&mut self, _key: &Key) -> ContainsResponse<Self::State> {
unimplemented!()
}
fn verify(&mut self, _attestation: &[u8], _key: Key) -> Option<Self::State> {
unimplemented!()
}
}
/// Source of this side's attestation, sent when the peer requests authentication.
pub trait ScallopAuther: Send {
/// Failure type; the handshakes report it as `ScallopError::AuthError` using its `Debug` form.
type Error: std::fmt::Debug;
/// Produces an attestation payload. The handshakes call it only when the peer
/// asks for one, and reject payloads longer than 60000 bytes.
fn new_auth(
&mut self,
) -> impl std::future::Future<Output = Result<Box<[u8]>, Self::Error>> + Send;
}
/// Lets an auther be lent out by `&mut` instead of moved into the handshake.
impl<T: ScallopAuther> ScallopAuther for &mut T {
    type Error = T::Error;

    async fn new_auth(&mut self) -> Result<Box<[u8]>, T::Error> {
        // Explicitly target the inner implementation.
        T::new_auth(&mut **self).await
    }
}
/// Placeholder auther type for callers that pass `None` as the auther.
///
/// The handshakes fail with a `ProtocolError` instead of calling `new_auth` when
/// attestation is requested but no auther is available, so this body is not
/// expected to run; it panics via `unimplemented!` if it does.
impl ScallopAuther for () {
type Error = ();
async fn new_auth(&mut self) -> Result<Box<[u8]>, ()> {
unimplemented!();
}
}
/// An encrypted, framed duplex stream produced by the Noise IX handshake functions.
///
/// Each record on the wire is a big-endian `u16` length followed by a Noise
/// transport message. Implements tokio's `AsyncRead` and `AsyncWrite` when `State: Unpin`.
#[derive(Debug)]
pub struct ScallopStream<Stream: AsyncWrite + AsyncRead + Unpin, State = ()> {
/// Noise session that encrypts outgoing and decrypts incoming records.
noise: TransportState,
/// Underlying byte stream.
stream: Stream,
/// Read buffer: holds the length prefix, then the ciphertext, then the plaintext of the current record.
rbuf: Box<[u8]>,
/// Bytes still to be read from `stream` to finish filling `rbuf`.
pending: usize,
/// Current phase of the read state machine.
mode: ReadMode,
/// End (exclusive) of undelivered plaintext in `rbuf`.
read_end: usize,
/// Start of undelivered plaintext in `rbuf`.
read_start: usize,
/// Encrypted, length-prefixed record waiting to be written to `stream`.
wbuf: Box<[u8]>,
/// Offset of the first unwritten byte in `wbuf`.
write_start: usize,
/// End (exclusive) of the queued data in `wbuf`.
write_end: usize,
/// Per-peer state from the auth store when it approved or verified the peer; otherwise `None`.
pub state: Option<State>,
}
trait Noiser {
fn read_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error>;
fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error>;
}
impl Noiser for snow::HandshakeState {
fn read_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error> {
self.read_message(payload, message)
}
fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error> {
self.write_message(payload, message)
}
}
impl Noiser for snow::TransportState {
fn read_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error> {
snow::TransportState::read_message(self, payload, message)
}
fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error> {
snow::TransportState::write_message(self, payload, message)
}
}
/// Reads one frame (big-endian `u16` length, then that many bytes) from `stream` and
/// feeds it to `noise`.
///
/// `src` is scratch space for the raw frame. The processed payload is written into
/// `dst` and its length is returned.
///
/// # Errors
/// - `ProtocolError("message too big")` if the announced length exceeds `src.len()`.
/// - `TransportError` if reading from `stream` fails.
/// - `NoiseError` if the Noise state rejects the frame.
async fn noise_read(
    noise: &mut impl Noiser,
    stream: &mut (impl AsyncRead + Unpin),
    src: &mut [u8],
    dst: &mut [u8],
) -> Result<usize, ScallopError> {
    let frame_len = usize::from(stream.read_u16().await?);
    // `get_mut` yields `None` exactly when the frame would not fit in `src`.
    let Some(frame) = src.get_mut(..frame_len) else {
        return Err(ScallopError::ProtocolError("message too big".into()));
    };
    stream.read_exact(frame).await?;
    noise.read_message(frame, dst).map_err(ScallopError::from)
}
async fn noise_write(
noise: &mut impl Noiser,
stream: &mut (impl AsyncWrite + Unpin),
src: &[u8],
dst: &mut [u8],
dst_offset: usize,
) -> Result<(), ScallopError> {
let len = noise
.write_message(src, &mut dst[dst_offset + 2..])
.map_err(std::io::Error::other)?;
dst[dst_offset..dst_offset + 2].copy_from_slice(&(len as u16).to_be_bytes());
stream.write_all(&dst[0..dst_offset + len + 2]).await?;
stream.flush().await?;
Ok(())
}
/// Runs the initiator side of a `Noise_IX_25519_ChaChaPoly_BLAKE2b` handshake over
/// `stream` and returns an encrypted [`ScallopStream`].
///
/// Message flow as implemented here:
/// 1. Send a zero `u16` negotiation length, then IX message 1 (empty payload) framed
///    with a `u16` length.
/// 2. Read the server's zero negotiation length and IX message 2. Its payload must be
///    `[0, 1, ask]` with `ask` in `{0, 1}`; `ask == 1` means the server wants our
///    attestation.
/// 3. Send CLIENTFIN over the transport: `[we_ask, u16 len, attestation-or-empty]`.
/// 4. If we asked (our `auth_store` returned `NotFound`), read SERVERFIN
///    `[u16 len, attestation]` and pass it to `auth_store.verify`.
///
/// * `secret` – local static private key.
/// * `auth_store` – optional policy for the server's static key. With `None` any key is
///   accepted and the returned `state` is `None`.
/// * `auther` – produces our attestation when the server requests one.
///
/// # Errors
/// Returns:
/// - `ProtocolError` on a framing or negotiation violation, a rejected key, or a failed
///   attestation.
/// - `AuthError` if `auther` fails.
/// - `TransportError`/`NoiseError` on I/O or Noise failures.
/// - `InitFailed` if the Noise state cannot be built.
///
/// # Panics
/// Panics if the Noise state has no 32-byte remote static key after message 2.
#[allow(non_snake_case)]
pub async fn new_client_async_Noise_IX_25519_ChaChaPoly_BLAKE2b<
Base: AsyncWrite + AsyncRead + Unpin,
AS: ScallopAuthStore,
>(
mut stream: Base,
secret: &[u8; 32],
mut auth_store: Option<AS>,
auther: Option<impl ScallopAuther>,
) -> Result<ScallopStream<Base, AS::State>, ScallopError> {
// Scratch buffers: `buf` holds raw frames, `noise_buf` holds processed payloads.
let mut buf = vec![0u8; 65000].into_boxed_slice();
let mut noise_buf = vec![0u8; 65000].into_boxed_slice();
// The prologue is bound into the handshake hash, so both peers must use identical bytes.
// NOTE(review): looks like NoiseSocket's "NoiseSocketInit1" + 2-byte (zero) negotiation
// length convention — confirm against the peer implementation.
let prologue = b"NoiseSocketInit1\x00\x00";
let mut noise = Builder::new(
"Noise_IX_25519_ChaChaPoly_BLAKE2b"
.parse()
.map_err(ScallopError::InitFailed)?,
)
.local_private_key(secret)
.prologue(prologue)
.build_initiator()
.map_err(ScallopError::InitFailed)?;
// Message 1 with an empty payload. Offset 2 leaves `buf[0..2]` (still zero from the
// allocation above) in front of the frame as a zero negotiation-length prefix.
noise_write(&mut noise, &mut stream, &[], &mut buf, 2).await?;
// The server must answer with a zero negotiation length too.
let len = stream.read_u16().await?;
if len != 0 {
return Err(ScallopError::ProtocolError(
"non zero second negotiation length".into(),
));
}
// Message 2: payload must be exactly `[0, 1, ask]`.
// NOTE(review): `[0, 1]` is presumably a protocol version marker — confirm.
let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;
if len != 3 || noise_buf[0] != 0 || noise_buf[1] != 1 {
return Err(ScallopError::ProtocolError(
"invalid second payload length".into(),
));
}
if noise_buf[2] > 1 {
return Err(ScallopError::ProtocolError(
"invalid auth request in second payload".into(),
));
}
// `ask == 1`: the server wants our attestation in CLIENTFIN.
let should_send_auth = noise_buf[2] == 1;
if should_send_auth && auther.is_none() {
return Err(ScallopError::ProtocolError(
"auth requested but no auther available".into(),
));
}
// IX sends the responder's static key in message 2, so it is expected to be present here.
let remote_static: [u8; 32] = noise.get_remote_static().unwrap().try_into().unwrap();
// Without an auth store no lookup happens: the key is accepted and `state` stays `None`.
let contains = auth_store.as_mut().map(|x| x.contains(&remote_static));
if contains == Some(ContainsResponse::Rejected) {
return Err(ScallopError::ProtocolError(
"remote static key rejected".into(),
));
}
let mut state = None::<AS::State>;
// An unknown key means we ask the server for an attestation (read after CLIENTFIN).
let should_ask_auth = contains == Some(ContainsResponse::NotFound);
if let Some(ContainsResponse::Approved(_state)) = contains {
state = Some(_state);
}
// Handshake complete; everything after this is encrypted transport traffic.
let mut noise = noise.into_transport_mode()?;
// Sends CLIENTFIN as one transport frame: `[ask flag, u16 big-endian payload length, payload]`.
async fn send_CLIENTFIN(
noise: &mut impl Noiser,
stream: &mut (impl AsyncWrite + Unpin),
buf: &mut [u8],
noise_buf: &mut [u8],
payload: &[u8],
should_ask_auth: bool,
) -> Result<(), ScallopError> {
noise_buf[0] = if !should_ask_auth { 0 } else { 1 };
noise_buf[1..3].copy_from_slice(&(payload.len() as u16).to_be_bytes());
noise_buf[3..3 + payload.len()].copy_from_slice(payload);
noise_write(noise, stream, &noise_buf[0..payload.len() + 3], buf, 0).await?;
Ok(())
}
if should_send_auth {
// Attestation requested: produce it, rejecting payloads over 60000 bytes.
let payload = auther
.unwrap()
.new_auth()
.await
.map_err(|e| ScallopError::AuthError(format!("{e:?}")))?;
if payload.len() > 60000 {
return Err(ScallopError::ProtocolError("auth payload too big".into()));
}
send_CLIENTFIN(
&mut noise,
&mut stream,
&mut buf,
&mut noise_buf,
&payload,
should_ask_auth,
)
.await?;
} else {
// Not requested: CLIENTFIN carries an empty payload.
send_CLIENTFIN(
&mut noise,
&mut stream,
&mut buf,
&mut noise_buf,
&[],
should_ask_auth,
)
.await?;
}
if should_ask_auth {
// We asked: read SERVERFIN `[u16 big-endian length, attestation]` and verify it.
let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;
if len < 2 {
return Err(ScallopError::ProtocolError(
"invalid SERVERFIN length".into(),
));
}
if u16::from_be_bytes([noise_buf[0], noise_buf[1]]) as usize != len - 2 {
return Err(ScallopError::ProtocolError(
"invalid SERVERFIN payload length".into(),
));
}
let Some(_state) = auth_store
.as_mut()
.unwrap()
.verify(&noise_buf[2..len], remote_static)
else {
return Err(ScallopError::ProtocolError("invalid attestation".into()));
};
state = Some(_state)
}
// The reader starts by waiting for a 2-byte record length; no write is queued.
Ok(ScallopStream {
noise,
stream,
rbuf: vec![0u8; 2].into_boxed_slice(),
pending: 2,
mode: ReadMode::Length,
read_start: 0,
read_end: 0,
wbuf: vec![].into_boxed_slice(),
write_start: 0,
write_end: 0,
state,
})
}
/// Runs the responder side of a `Noise_IX_25519_ChaChaPoly_BLAKE2b` handshake over
/// `stream` and returns an encrypted [`ScallopStream`].
///
/// Message flow as implemented here:
/// 1. Read a zero `u16` negotiation length and IX message 1, whose payload must be empty.
/// 2. Consult `auth_store` about the client's static key. Then send a zero negotiation
///    length plus IX message 2 with payload `[0, 1, ask]`; `ask == 1` requests the
///    client's attestation.
/// 3. Read CLIENTFIN `[client_asks, u16 len, attestation]`. If we asked, pass the
///    attestation to `auth_store.verify`.
/// 4. If the client asked, send SERVERFIN `[u16 len, attestation]` produced by `auther`.
///
/// * `secret` – local static private key.
/// * `auth_store` – optional policy for the client's static key. With `None` any key is
///   accepted and the returned `state` is `None`.
/// * `auther` – produces our attestation when the client requests one.
///
/// # Errors
/// Returns:
/// - `ProtocolError` on a framing or negotiation violation, a rejected key, or a failed
///   attestation.
/// - `AuthError` if `auther` fails.
/// - `TransportError`/`NoiseError` on I/O or Noise failures.
/// - `InitFailed` if the Noise state cannot be built.
///
/// # Panics
/// Panics if the Noise state has no 32-byte remote static key after message 1.
#[allow(non_snake_case)]
pub async fn new_server_async_Noise_IX_25519_ChaChaPoly_BLAKE2b<
Base: AsyncWrite + AsyncRead + Unpin,
AS: ScallopAuthStore,
>(
mut stream: Base,
secret: &[u8; 32],
mut auth_store: Option<AS>,
auther: Option<impl ScallopAuther>,
) -> Result<ScallopStream<Base, AS::State>, ScallopError> {
// Scratch buffers: `buf` holds raw frames, `noise_buf` holds processed payloads.
let mut buf = vec![0u8; 65000].into_boxed_slice();
let mut noise_buf = vec![0u8; 65000].into_boxed_slice();
// Must match the client's prologue byte-for-byte (it is bound into the handshake hash).
let prologue = b"NoiseSocketInit1\x00\x00";
let mut noise = Builder::new(
"Noise_IX_25519_ChaChaPoly_BLAKE2b"
.parse()
.map_err(ScallopError::InitFailed)?,
)
.local_private_key(secret)
.prologue(prologue)
.build_responder()
.map_err(ScallopError::InitFailed)?;
// The client must open with a zero negotiation length.
let len = stream.read_u16().await?;
if len != 0 {
return Err(ScallopError::ProtocolError(
"non zero first negotiation length".into(),
));
}
// Message 1: the client's payload must be empty.
let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;
if len != 0 {
return Err(ScallopError::ProtocolError(
"non zero first handshake payload".into(),
));
}
// `buf` was just used as the read scratch area; reset its first two bytes to the
// zero negotiation-length prefix sent in front of message 2.
buf[0..2].copy_from_slice(&0u16.to_be_bytes());
let remote_static: [u8; 32] = noise
.get_remote_static()
.expect("handshake should have static key by now")
.try_into()
.expect("expected 32 byte key");
// Without an auth store no lookup happens: the key is accepted and `state` stays `None`.
let contains = auth_store.as_mut().map(|x| x.contains(&remote_static));
if contains == Some(ContainsResponse::Rejected) {
return Err(ScallopError::ProtocolError(
"remote static key rejected".into(),
));
}
let mut state = None::<AS::State>;
// An unknown key means we ask the client for an attestation in CLIENTFIN.
let should_ask_auth = contains == Some(ContainsResponse::NotFound);
if let Some(ContainsResponse::Approved(_state)) = contains {
state = Some(_state);
}
// Message 2 payload `[0, 1, ask]`. NOTE(review): `[0, 1]` presumably a version marker — confirm.
let payload = &[0u8, 1u8, if !should_ask_auth { 0u8 } else { 1u8 }];
noise_write(&mut noise, &mut stream, payload, &mut buf, 2).await?;
// Handshake complete; everything after this is encrypted transport traffic.
let mut noise = noise.into_transport_mode()?;
// CLIENTFIN: `[client_asks, u16 big-endian length, attestation]`.
let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;
if len < 3 {
return Err(ScallopError::ProtocolError(
"invalid CLIENTFIN length".into(),
));
}
if u16::from_be_bytes([noise_buf[1], noise_buf[2]]) as usize != len - 3 {
return Err(ScallopError::ProtocolError(
"invalid CLIENTFIN payload length".into(),
));
}
if should_ask_auth {
// We asked, so the CLIENTFIN payload is the client's attestation.
let Some(_state) = auth_store
.as_mut()
.unwrap()
.verify(&noise_buf[3..len], remote_static)
else {
return Err(ScallopError::ProtocolError("invalid attestation".into()));
};
state = Some(_state)
}
if noise_buf[0] > 1 {
return Err(ScallopError::ProtocolError(
"invalid auth request in third payload".into(),
));
}
// First CLIENTFIN byte == 1: the client wants our attestation in SERVERFIN.
let should_send_auth = noise_buf[0] == 1;
if should_send_auth && auther.is_none() {
return Err(ScallopError::ProtocolError(
"auth requested but no auther available".into(),
));
}
if should_send_auth {
// SERVERFIN: `[u16 big-endian length, attestation]`, attestation capped at 60000 bytes.
let payload = auther
.unwrap()
.new_auth()
.await
.map_err(|e| ScallopError::AuthError(format!("{e:?}")))?;
if payload.len() > 60000 {
return Err(ScallopError::ProtocolError("auth payload too big".into()));
}
noise_buf[0..2].copy_from_slice(&(payload.len() as u16).to_be_bytes());
noise_buf[2..2 + payload.len()].copy_from_slice(&payload);
noise_write(
&mut noise,
&mut stream,
&noise_buf[0..payload.len() + 2],
&mut buf,
0,
)
.await?;
}
// The reader starts by waiting for a 2-byte record length; no write is queued.
Ok(ScallopStream {
noise,
stream,
rbuf: vec![0u8; 2].into_boxed_slice(),
pending: 2,
mode: ReadMode::Length,
read_start: 0,
read_end: 0,
wbuf: vec![].into_boxed_slice(),
write_start: 0,
write_end: 0,
state,
})
}
impl<Base: AsyncWrite + AsyncRead + Unpin, State> ScallopStream<Base, State> {
    /// Returns the peer's static public key as recorded by the Noise session, if any.
    ///
    /// # Panics
    /// Panics if the recorded key is not exactly 32 bytes long.
    pub fn get_remote_static(&self) -> Option<[u8; 32]> {
        let key = self.noise.get_remote_static()?;
        Some(<[u8; 32]>::try_from(key).expect("expected 32 byte key"))
    }
}
impl<Base: AsyncWrite + AsyncRead + Unpin, State: Unpin> AsyncRead for ScallopStream<Base, State> {
    /// Reads decrypted application data.
    ///
    /// Records are reassembled from the underlying stream (2-byte big-endian length,
    /// then a Noise transport message), decrypted, and handed out over as many calls as
    /// `buf`'s capacity requires.
    ///
    /// End of stream (a successful read that fills nothing) is reported only when the
    /// underlying stream ends exactly on a record boundary. An EOF inside a record is an
    /// `UnexpectedEof` error. Records whose plaintext is empty are skipped rather than
    /// being mistaken for end of stream.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let stream = self.get_mut();
        // Nothing can be delivered into a full buffer; don't touch the socket.
        if buf.remaining() == 0 {
            return std::task::Poll::Ready(Ok(()));
        }
        loop {
            // Fill the tail of `rbuf` until the current phase has all of its bytes.
            while stream.pending != 0 {
                let offset = stream.rbuf.len() - stream.pending;
                let received = {
                    let mut chunk = ReadBuf::new(&mut stream.rbuf[offset..]);
                    std::task::ready!(
                        std::pin::Pin::new(&mut stream.stream).poll_read(cx, &mut chunk)
                    )?;
                    chunk.filled().len()
                };
                if received == 0 {
                    let at_record_boundary = stream.mode == ReadMode::Length
                        && stream.pending == stream.rbuf.len();
                    return std::task::Poll::Ready(if at_record_boundary {
                        Ok(())
                    } else {
                        Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "stream ended in the middle of a record",
                        ))
                    });
                }
                stream.pending -= received;
            }
            match stream.mode {
                ReadMode::Length => {
                    let record_len =
                        usize::from(u16::from_be_bytes([stream.rbuf[0], stream.rbuf[1]]));
                    stream.rbuf = vec![0u8; record_len].into_boxed_slice();
                    stream.pending = record_len;
                    stream.mode = ReadMode::Body;
                }
                ReadMode::Body => {
                    // Input and output must be distinct slices, so decrypt into a fresh
                    // buffer and swap it in; this avoids cloning the ciphertext.
                    let mut plaintext = vec![0u8; stream.rbuf.len()].into_boxed_slice();
                    let len = stream
                        .noise
                        .read_message(&stream.rbuf, &mut plaintext)
                        .map_err(std::io::Error::other)?;
                    stream.rbuf = plaintext;
                    stream.read_start = 0;
                    stream.read_end = len;
                    stream.mode = ReadMode::Read;
                }
                ReadMode::Read => {
                    let available = stream.read_end - stream.read_start;
                    let n = available.min(buf.remaining());
                    let start = stream.read_start;
                    buf.put_slice(&stream.rbuf[start..start + n]);
                    stream.read_start += n;
                    if stream.read_start == stream.read_end {
                        // Record fully consumed: arm the reader for the next length prefix.
                        stream.rbuf = vec![0u8; 2].into_boxed_slice();
                        stream.pending = 2;
                        stream.mode = ReadMode::Length;
                    }
                    // `buf` had room, so `n == 0` only for an empty record. Keep reading
                    // instead of returning a zero-byte read, which callers treat as EOF.
                    if n != 0 {
                        return std::task::Poll::Ready(Ok(()));
                    }
                }
            }
        }
    }
}
impl<Base: AsyncWrite + AsyncRead + Unpin, State: Unpin> AsyncWrite for ScallopStream<Base, State> {
    /// Encrypts up to 64000 bytes of `buf` as one record and queues it for sending.
    ///
    /// Any previously queued record is written out (and the underlying stream flushed)
    /// first. The new record is only buffered; it reaches the underlying stream on the
    /// next `poll_write`, `poll_flush` or `poll_shutdown`.
    ///
    /// An empty `buf` returns `Ok(0)` without producing a record. The peer's reader
    /// would otherwise take an empty record for end of stream, and a nonce would be
    /// wasted.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        // Largest plaintext per record; keeps ciphertext + tag below the u16 frame limit.
        const MAX_RECORD_PLAINTEXT: usize = 64000;
        if buf.is_empty() {
            return std::task::Poll::Ready(Ok(0));
        }
        std::task::ready!(self.as_mut().poll_flush(cx))?;
        let stream = self.get_mut();
        let len = buf.len().min(MAX_RECORD_PLAINTEXT);
        // 2-byte length prefix followed by the ciphertext; the slack covers the AEAD tag.
        let mut record = vec![0u8; len + 1000].into_boxed_slice();
        let noise_len = stream
            .noise
            .write_message(&buf[..len], &mut record[2..])
            .map_err(std::io::Error::other)?;
        let prefix = u16::try_from(noise_len).map_err(std::io::Error::other)?;
        record[..2].copy_from_slice(&prefix.to_be_bytes());
        stream.wbuf = record;
        stream.write_start = 0;
        stream.write_end = noise_len + 2;
        std::task::Poll::Ready(Ok(len))
    }

    /// Writes any queued record to the underlying stream, then flushes it.
    ///
    /// # Errors
    /// Returns `WriteZero` if the underlying stream accepts no bytes. Previously this
    /// case spun forever in a busy loop.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let stream = self.get_mut();
        while stream.write_start < stream.write_end {
            let queued = &stream.wbuf[stream.write_start..stream.write_end];
            let written = std::task::ready!(
                std::pin::Pin::new(&mut stream.stream).poll_write(cx, queued)
            )?;
            if written == 0 {
                return std::task::Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write encrypted record to the underlying stream",
                )));
            }
            stream.write_start += written;
        }
        std::pin::Pin::new(&mut stream.stream).poll_flush(cx)
    }

    /// Flushes any queued record, then shuts down the underlying stream.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        std::task::ready!(self.as_mut().poll_flush(cx))?;
        let stream = self.get_mut();
        std::pin::Pin::new(&mut stream.stream).poll_shutdown(cx)
    }
}