use std::pin::Pin;
use anyhow::Context;
use async_task::Task;
use bipe::{BipeReader, BipeWriter};
use bytes::Bytes;
use chacha20poly1305::{aead::Aead, ChaCha20Poly1305, KeyInit};
use futures_util::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use sillad::Pipe;
use tap::Tap;
use crate::{read_prepend_length, write_prepend_length};
/// Hello message sent by the client, carrying its credentials and its half of the
/// crypto handshake (presumably the first message of the client/exit handshake —
/// confirm against the handshake code).
///
/// Field names and order are part of the serde representation; for
/// non-self-describing formats, reordering fields breaks wire compatibility.
#[derive(Serialize, Deserialize)]
pub struct ClientHello {
/// Opaque credential bytes; their format is not defined in this file.
pub credentials: Bytes,
/// The client's chosen handshake variant and its payload.
pub crypt_hello: ClientCryptHello,
}
/// The client's half of the crypto handshake.
///
/// Variant order is part of the serde representation (derived enums serialize a
/// variant index for many formats); do not reorder variants.
#[derive(Serialize, Deserialize)]
pub enum ClientCryptHello {
/// A 32-byte challenge. Presumably answered by the exit with
/// [`ExitHelloInner::SharedSecretResponse`]; confirm against the handshake code.
SharedSecretChallenge([u8; 32]),
/// The client's X25519 public key, for key agreement with the exit.
X25519(x25519_dalek::PublicKey),
}
/// The exit's hello: a reply body plus an Ed25519 signature.
///
/// NOTE(review): which bytes `signature` covers is not determined here; presumably
/// the serialized `inner`. Confirm against the signing and verifying code.
///
/// Field order is part of the serde representation; do not reorder fields.
#[derive(Serialize, Deserialize)]
pub struct ExitHello {
/// The reply body.
pub inner: ExitHelloInner,
/// Ed25519 signature accompanying `inner`.
pub signature: ed25519_dalek::Signature,
}
/// Body of an [`ExitHello`], carried alongside its signature.
///
/// Variant order is part of the serde representation; do not reorder variants.
#[derive(Serialize, Deserialize, Clone)]
pub enum ExitHelloInner {
/// The exit refused the client; the string is presumably a human-readable reason.
Reject(String),
/// A BLAKE3 hash, presumably answering [`ClientCryptHello::SharedSecretChallenge`]
/// (how it is derived is not shown here; confirm against the handshake code).
SharedSecretResponse(blake3::Hash),
/// The exit's X25519 public key, presumably answering [`ClientCryptHello::X25519`].
X25519(x25519_dalek::PublicKey),
}
/// An encrypted [`Pipe`] layered over another pipe.
///
/// Application plaintext flows through in-memory bipes:
/// - Plaintext written to this value goes into `write_outgoing`. `_write_task` encrypts it
///   and sends it on the underlying pipe.
/// - `_read_task` decrypts frames from the underlying pipe into the bipe that
///   `read_incoming` reads from.
///
/// Dropping an `async_task::Task` cancels it, so both background tasks stop when this
/// value is dropped. Field order determines drop order; keep it stable.
#[pin_project]
pub struct ClientExitCryptPipe {
/// Decrypted plaintext received from the peer, fed by `_read_task`.
#[pin]
read_incoming: BipeReader,
/// Background task that reads and decrypts incoming frames.
_read_task: Task<()>,
/// Plaintext queued for sending; drained by `_write_task`.
#[pin]
write_outgoing: BipeWriter,
/// Background task that encrypts and sends outgoing data.
_write_task: Task<()>,
/// The underlying pipe's remote address, captured at construction.
addr: Option<String>,
}
impl AsyncRead for ClientExitCryptPipe {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> std::task::Poll<std::io::Result<usize>> {
self.project().read_incoming.poll_read(cx, buf)
}
}
impl AsyncWrite for ClientExitCryptPipe {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
self.project().write_outgoing.poll_write(cx, buf)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
self.project().write_outgoing.poll_flush(cx)
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
self.project().write_outgoing.poll_close(cx)
}
}
impl ClientExitCryptPipe {
    /// Wraps `pipe` in ChaCha20-Poly1305 encryption and returns a pipe that reads and
    /// writes plaintext.
    ///
    /// Two background tasks are spawned with `smolscale::spawn`:
    /// - **Read task.** Reads frames from `pipe` with `read_prepend_length`, decrypts
    ///   each with `read_key`, and makes the plaintext readable through [`AsyncRead`].
    /// - **Write task.** Takes up to 8192 bytes of plaintext at a time from the
    ///   [`AsyncWrite`] side, encrypts it with `write_key`, and sends it with
    ///   `write_prepend_length`. When the write side is closed or dropped, it closes the
    ///   underlying pipe's write half and stops.
    ///
    /// Each direction uses 96-bit nonces built from a per-direction message counter that
    /// starts at 0 (see [`Self::counter_nonce`]). Because both counters start at 0,
    /// `read_key` and `write_key` must be different keys. Otherwise the same
    /// (key, nonce) pair would be used in both directions.
    ///
    /// If a task hits an error (I/O failure, or a frame that fails to decrypt), that task
    /// ends and the error is discarded.
    pub fn new(pipe: impl Pipe, read_key: [u8; 32], write_key: [u8; 32]) -> Self {
        let addr = pipe.remote_addr().map(|s| s.to_string());
        let (mut pipe_read, mut pipe_write) = pipe.split();
        let (mut write_incoming, read_incoming) = bipe::bipe(32768);
        let (write_outgoing, mut read_outgoing) = bipe::bipe(32768);

        let _read_task = smolscale::spawn(async move {
            let read_aead = ChaCha20Poly1305::new_from_slice(&read_key)
                .expect("a 32-byte key is always a valid ChaCha20-Poly1305 key");
            let fallible = async {
                for counter in 0u64.. {
                    let msg = read_prepend_length(&mut pipe_read).await?;
                    let plaintext = read_aead
                        .decrypt(&Self::counter_nonce(counter).into(), msg.as_slice())
                        .ok()
                        .context("cannot decrypt")?;
                    write_incoming.write_all(&plaintext).await?;
                }
                anyhow::Ok(())
            };
            // Any error simply ends the task; `write_incoming` is dropped with it.
            // NOTE(review): the error is discarded; consider logging it.
            let _ = fallible.await;
        });

        let _write_task = smolscale::spawn(async move {
            let write_aead = ChaCha20Poly1305::new_from_slice(&write_key)
                .expect("a 32-byte key is always a valid ChaCha20-Poly1305 key");
            let fallible = async {
                let mut buf = [0u8; 8192];
                for counter in 0u64.. {
                    let n = read_outgoing.read(&mut buf).await?;
                    if n == 0 {
                        // EOF: our `write_outgoing` was closed or dropped. Without this
                        // check the loop would spin forever sending empty frames. Propagate
                        // the close to the underlying pipe so the peer observes it.
                        pipe_write.close().await?;
                        break;
                    }
                    let ciphertext = write_aead
                        .encrypt(&Self::counter_nonce(counter).into(), &buf[..n])
                        .expect("ChaCha20-Poly1305 only rejects plaintexts far above 8192 bytes");
                    write_prepend_length(&ciphertext, &mut pipe_write).await?;
                }
                anyhow::Ok(())
            };
            // Any error simply ends the task.
            // NOTE(review): the error is discarded; consider logging it.
            let _ = fallible.await;
        });

        Self {
            read_incoming,
            _read_task,
            write_outgoing,
            _write_task,
            addr,
        }
    }

    /// Builds the 96-bit AEAD nonce for message number `counter`.
    ///
    /// Bytes 0..8 hold the counter in little-endian order and bytes 8..12 are zero.
    fn counter_nonce(counter: u64) -> [u8; 12] {
        [0u8; 12].tap_mut(|nonce| nonce[..8].copy_from_slice(&counter.to_le_bytes()))
    }
}
/// Exposes metadata about the wrapped connection through the [`Pipe`] trait.
impl Pipe for ClientExitCryptPipe {
    /// Protocol label reported for this pipe.
    ///
    /// NOTE(review): reports `"plain"` even though the traffic is encrypted. Confirm this
    /// is intended.
    fn protocol(&self) -> &str {
        let label: &'static str = "plain";
        label
    }

    /// Remote address of the underlying pipe, as captured when this wrapper was built.
    fn remote_addr(&self) -> Option<&str> {
        let addr = self.addr.as_ref()?;
        Some(addr.as_str())
    }
}