#![allow(clippy::unwrap_in_result)]
use std::{
mem,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use color_eyre::{eyre::eyre, Report};
use ed25519_zebra::{batch, Error, Signature, SigningKey, VerificationKeyBytes};
use futures::stream::{FuturesOrdered, StreamExt};
use futures::FutureExt;
use futures_core::Future;
use rand::thread_rng;
use tokio::sync::{oneshot::error::RecvError, watch};
use tower::{Service, ServiceExt};
use tower_batch_control::{Batch, BatchControl, RequestWeight};
use tower_fallback::Fallback;
// Boxed, thread-safe dynamic error; used below as the `Service::Error` type of `Verifier`.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
// Short name for the ed25519 batch verifier that accumulates queued items.
type BatchVerifier = batch::Verifier;
// Outcome of verifying a batch or a single item.
type VerifyResult = Result<(), Error>;
// Sending half of the watch channel that carries a batch's outcome.
// Channels in this file are created holding `None`; a value is sent once a batch is flushed.
type Sender = watch::Sender<Option<VerifyResult>>;
/// Newtype around an ed25519 `batch::Item`; this is the request type handled by the
/// services in this file.
#[derive(Clone, Debug)]
struct Item(batch::Item);
// Empty impl: nothing from `RequestWeight` is overridden, so the trait's defaults apply.
impl RequestWeight for Item {}
impl<'msg, M: AsRef<[u8]> + ?Sized> From<(VerificationKeyBytes, Signature, &'msg M)> for Item {
fn from(tup: (VerificationKeyBytes, Signature, &'msg M)) -> Self {
Self(batch::Item::from(tup))
}
}
impl Item {
fn verify_single(self) -> VerifyResult {
self.0.verify_single()
}
}
/// Batch-verification service state: the items queued so far plus the channel used to
/// publish the result of the batch they belong to.
struct Verifier {
// Items queued by `call` since the last time the batch was taken.
batch: BatchVerifier,
// Sender for the current batch; each queued item's future subscribes to it in `call`.
tx: Sender,
}
impl Default for Verifier {
    /// Creates a verifier with an empty batch and a fresh watch channel holding `None`.
    fn default() -> Self {
        // Only the sender is kept; receivers are created on demand via `subscribe`.
        let (tx, _) = watch::channel(None);
        Self {
            batch: BatchVerifier::default(),
            tx,
        }
    }
}
impl Verifier {
    /// Detaches the pending batch together with its result sender, leaving an empty
    /// batch and a brand-new channel (initial value `None`) in `self`.
    fn take(&mut self) -> (BatchVerifier, Sender) {
        let (fresh_tx, _) = watch::channel(None);
        let pending_tx = mem::replace(&mut self.tx, fresh_tx);
        let pending_batch = mem::take(&mut self.batch);
        (pending_batch, pending_tx)
    }

    /// Verifies `batch` on the calling thread and publishes the outcome on `tx`.
    ///
    /// A failed send is ignored.
    fn verify(batch: BatchVerifier, tx: Sender) {
        let _ = tx.send(Some(batch.verify(thread_rng())));
    }

    /// Takes the pending batch and queues its verification on rayon's FIFO pool,
    /// doing the hand-off inside `tokio::task::block_in_place`.
    fn flush_blocking(&mut self) {
        let (batch, tx) = self.take();
        let job = move || Self::verify(batch, tx);
        tokio::task::block_in_place(|| rayon::spawn_fifo(job));
    }

    /// Verifies `batch` via the async `spawn_fifo` helper and publishes the outcome on `tx`.
    ///
    /// If the helper's response channel reports an error, `None` is sent instead of a
    /// verification result. A failed send is ignored.
    async fn flush_spawning(batch: BatchVerifier, tx: Sender) {
        let outcome = spawn_fifo(move || batch.verify(thread_rng())).await;
        let _ = tx.send(outcome.ok());
    }
}
/// Batch-control service: `Item` requests are queued and resolve once their batch has
/// been verified; `Flush` requests trigger verification of the pending batch.
impl Service<BatchControl<Item>> for Verifier {
    type Response = ();
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<(), BoxError>> + Send + 'static>>;

    /// Always ready: this service reports no backpressure.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Handles one batch-control request.
    ///
    /// * `BatchControl::Item` — queues the item and returns a future that waits for the
    ///   current batch's result on a freshly subscribed receiver.
    /// * `BatchControl::Flush` — takes the pending batch and returns a future that
    ///   verifies it and publishes the result.
    ///
    /// # Panics
    ///
    /// The item future panics if the sender is dropped before any value is published.
    fn call(&mut self, req: BatchControl<Item>) -> Self::Future {
        match req {
            BatchControl::Item(Item(item)) => {
                tracing::trace!("got ed25519 item");
                self.batch.queue(item);
                // Subscribe to the sender that belongs to the batch this item joined.
                let mut rx = self.tx.subscribe();
                Box::pin(async move {
                    if rx.changed().await.is_err() {
                        panic!("ed25519 verifier was dropped without flushing");
                    }
                    // `None` here means the flush published no verification result.
                    let result = rx.borrow()
                        .ok_or("threadpool unexpectedly dropped response channel sender. Is Zebra shutting down?")?;
                    match result {
                        Ok(()) => tracing::trace!(?result, "validated ed25519 signature"),
                        Err(_) => tracing::trace!(?result, "invalid ed25519 signature"),
                    }
                    result.map_err(BoxError::from)
                })
            }
            BatchControl::Flush => {
                tracing::trace!("got ed25519 flush command");
                let (batch, tx) = self.take();
                Box::pin(Self::flush_spawning(batch, tx).map(Ok))
            }
        }
    }
}
impl Drop for Verifier {
fn drop(&mut self) {
self.flush_blocking();
}
}
/// Runs `f` on rayon's FIFO pool and asynchronously waits for its return value.
///
/// The closure's result travels back over a oneshot channel. The outer `Err(RecvError)`
/// is returned when that channel's sender is dropped without sending.
async fn spawn_fifo<E, F>(f: F) -> Result<Result<(), E>, RecvError>
where
    E: 'static + std::error::Error + Sync + Send,
    F: 'static + FnOnce() -> Result<(), E> + Send,
{
    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
    rayon::spawn_fifo(move || {
        let outcome = f();
        // The receiver may already be gone; that is not an error for the worker.
        let _ = reply_tx.send(outcome);
    });
    reply_rx.await
}
async fn sign_and_verify<V>(
mut verifier: V,
n: usize,
bad_index: Option<usize>,
) -> Result<(), V::Error>
where
V: Service<Item, Response = ()>,
{
let mut results = FuturesOrdered::new();
for i in 0..n {
let span = tracing::trace_span!("sig", i);
let sk = SigningKey::new(thread_rng());
let vk_bytes = VerificationKeyBytes::from(&sk);
let msg = b"BatchVerifyTest";
let sig = if Some(i) == bad_index {
sk.sign(b"badmsg")
} else {
sk.sign(&msg[..])
};
verifier.ready().await?;
results.push_back(span.in_scope(|| verifier.call((vk_bytes, sig, msg).into())))
}
let mut numbered_results = results.enumerate();
while let Some((i, result)) = numbered_results.next().await {
if Some(i) == bad_index {
assert!(result.is_err());
} else {
result?;
}
}
Ok(())
}
/// Verifies 100 valid signatures through a `Batch` whose latency is 1000 seconds, and
/// requires completion within one second, so flushing cannot be driven by the latency.
///
/// NOTE(review): the `10` and `5` arguments are presumably the per-batch weight limit
/// and the maximum number of concurrent batches; confirm against `Batch::new`.
#[tokio::test(flavor = "multi_thread")]
async fn batch_flushes_on_max_items_weight() -> Result<(), Report> {
    let _init_guard = zebra_test::init();
    let verifier = Batch::new(Verifier::default(), 10, 5, Duration::from_secs(1000));
    let run = sign_and_verify(verifier, 100, None);
    let finished = tokio::time::timeout(Duration::from_secs(1), run).await;
    // Outer error: the timeout elapsed. Inner error: verification failed.
    let verified = finished.map_err(|elapsed| eyre!(elapsed))?;
    verified.map_err(|err| eyre!(err))
}
/// Verifies 10 valid signatures through a `Batch` with a 500 ms latency and requires
/// completion within one second.
///
/// NOTE(review): with only 10 items against a first limit argument of `100`, the flush
/// is presumably triggered by the latency timer; confirm against `Batch::new`.
#[tokio::test(flavor = "multi_thread")]
async fn batch_flushes_on_max_latency() -> Result<(), Report> {
    let _init_guard = zebra_test::init();
    let verifier = Batch::new(Verifier::default(), 100, 10, Duration::from_millis(500));
    let run = sign_and_verify(verifier, 10, None);
    let finished = tokio::time::timeout(Duration::from_secs(1), run).await;
    // Outer error: the timeout elapsed. Inner error: verification failed.
    let verified = finished.map_err(|elapsed| eyre!(elapsed))?;
    verified.map_err(|err| eyre!(err))
}
/// Runs 100 signatures, with the one at index 39 deliberately invalid, through a
/// `Fallback` that pairs the batch verifier with a per-item `verify_single` service.
///
/// `sign_and_verify` asserts that item 39 fails and that every other item succeeds.
#[tokio::test(flavor = "multi_thread")]
async fn fallback_verification() -> Result<(), Report> {
    let _init_guard = zebra_test::init();
    let batched = Batch::new(Verifier::default(), 10, 1, Duration::from_millis(100));
    let single = tower::service_fn(|item: Item| async move { item.verify_single() });
    let verifier = Fallback::new(batched, single);
    sign_and_verify(verifier, 100, Some(39))
        .await
        .map_err(|err| eyre!(err))
}