use crossbeam_channel as channel;
use std::collections::HashMap;
use std::fmt;
use std::mem;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use memuse::DynamicUsage;
use zcash_note_encryption::{batch, BatchDomain, Domain, ShieldedOutput, COMPACT_NOTE_SIZE};
use zcash_primitives::{block::BlockHash, transaction::TxId};
/// A note that one of a batch's incoming viewing keys successfully decrypted.
pub(crate) struct DecryptedNote<A, D: Domain> {
    /// The caller-supplied tag of the incoming viewing key that decrypted this note.
    pub(crate) ivk_tag: A,
    /// The recipient recovered by decryption.
    pub(crate) recipient: D::Recipient,
    /// The decrypted note.
    pub(crate) note: D::Note,
}

// Only the fields that are actually printed need to implement `Debug`. Requiring it of
// `D::IncomingViewingKey` or `D::Memo` as well (neither is stored here) would needlessly
// hide this impl for domains whose keys or memos are not `Debug`.
impl<A, D: Domain> fmt::Debug for DecryptedNote<A, D>
where
    A: fmt::Debug,
    D::Recipient: fmt::Debug,
    D::Note: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DecryptedNote")
            .field("ivk_tag", &self.ivk_tag)
            .field("recipient", &self.recipient)
            .field("note", &self.note)
            .finish()
    }
}
/// Pairs a value with the index of the output it belongs to.
struct OutputIndex<V> {
    /// Position of the output within the slice handed to `Batch::add_outputs`.
    output_index: usize,
    value: V,
}

/// A decrypted note together with the index of the output it came from.
type OutputItem<A, D> = OutputIndex<DecryptedNote<A, D>>;

/// The sending half of a per-transaction result channel, tagged with the index of the
/// output whose result it will carry.
struct OutputReplier<A, D: Domain>(OutputIndex<channel::Sender<OutputItem<A, D>>>);

// The sender itself reports no heap usage; queued items are counted on the receiving
// side by `BatchReceiver`.
impl<A, D: Domain> DynamicUsage for OutputReplier<A, D> {
    #[inline(always)]
    fn dynamic_usage(&self) -> usize {
        0
    }

    #[inline(always)]
    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let exact = self.dynamic_usage();
        (exact, Some(exact))
    }
}
/// The receiving half of a per-transaction result channel.
struct BatchReceiver<A, D: Domain>(channel::Receiver<OutputItem<A, D>>);

impl<A, D: Domain> DynamicUsage for BatchReceiver<A, D> {
    /// Estimates the heap memory held by the items currently queued in the channel.
    ///
    /// The channel is modelled as a list of fixed-capacity blocks. Each block holds a
    /// next-block pointer plus, for every slot, one item and an `AtomicUsize` slot state.
    /// NOTE(review): this mirrors the internal layout of `crossbeam_channel`'s unbounded
    /// (list) flavour; confirm the constants if the dependency is upgraded.
    fn dynamic_usage(&self) -> usize {
        const SLOTS_PER_BLOCK: usize = 31;
        const NEXT_PTR_BYTES: usize = std::mem::size_of::<usize>();
        const SLOT_STATE_BYTES: usize = std::mem::size_of::<AtomicUsize>();

        let queued = self.0.len();
        // Ceiling division: a partially filled block still occupies a whole block.
        let blocks = (queued + SLOTS_PER_BLOCK - 1) / SLOTS_PER_BLOCK;

        let slot_bytes = std::mem::size_of::<OutputItem<A, D>>() + SLOT_STATE_BYTES;
        let block_bytes = NEXT_PTR_BYTES + SLOTS_PER_BLOCK * slot_bytes;

        blocks * block_bytes
    }

    /// The estimate above is treated as exact, so both bounds equal it.
    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let exact = self.dynamic_usage();
        (exact, Some(exact))
    }
}
/// A strategy for wrapping batch items into [`Task`]s and spawning them.
pub(crate) trait Tasks<Item> {
    /// The task type produced for each item.
    type Task: Task;

    /// Creates a new, empty instance of this strategy.
    fn new() -> Self;

    /// Converts `item` into this strategy's task type.
    fn add_task(&self, item: Item) -> Self::Task;

    /// Wraps `item` via [`Tasks::add_task`] and hands the resulting task to
    /// `rayon::spawn_fifo`.
    fn run_task(&self, item: Item) {
        let spawned = self.add_task(item);
        rayon::spawn_fifo(move || spawned.run());
    }
}
/// A self-contained unit of work that can be moved to another thread and run there.
pub(crate) trait Task
where
    Self: Send + 'static,
{
    /// Consumes the task and performs its work.
    fn run(self);
}
// The untracked strategy: each item is already a task and is returned unchanged, with
// no bookkeeping.
impl<Item> Tasks<Item> for ()
where
    Item: Task,
{
    type Task = Item;

    fn new() -> Self {}

    fn add_task(&self, task: Item) -> Item {
        task
    }
}
/// A task strategy that tracks the approximate memory usage of tasks that have been
/// created by `add_task` but have not yet finished running.
pub(crate) struct WithUsage {
    /// Shared counter; every `WithUsageTask` created by this strategy holds a clone and
    /// subtracts its own share when it finishes.
    running_usage: Arc<AtomicUsize>,
}

impl DynamicUsage for WithUsage {
    fn dynamic_usage(&self) -> usize {
        self.running_usage.load(Ordering::Relaxed)
    }

    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        // Read the counter once so both bounds come from the same sample.
        let current = self.running_usage.load(Ordering::Relaxed);
        (current, Some(current))
    }
}
impl<Item: Task + DynamicUsage> Tasks<Item> for WithUsage {
type Task = WithUsageTask<Item>;
fn new() -> Self {
Self {
running_usage: Arc::new(AtomicUsize::new(0)),
}
}
fn add_task(&self, item: Item) -> Self::Task {
let mut task = WithUsageTask {
item,
own_usage: 0,
running_usage: self.running_usage.clone(),
};
task.own_usage =
mem::size_of::<Arc<()>>() + mem::size_of_val(&task) + task.item.dynamic_usage();
self.running_usage
.fetch_add(task.own_usage, Ordering::SeqCst);
task
}
}
/// A task wrapper that gives back its recorded usage once the wrapped item has run.
pub(crate) struct WithUsageTask<Item> {
    item: Item,
    /// The amount this task added to `running_usage` when it was created.
    own_usage: usize,
    running_usage: Arc<AtomicUsize>,
}

impl<Item: Task> Task for WithUsageTask<Item> {
    fn run(self) {
        let WithUsageTask {
            item,
            own_usage,
            running_usage,
        } = self;

        item.run();
        // NOTE(review): if `item.run()` panics, this decrement is skipped and the counter
        // stays inflated by `own_usage`.
        running_usage.fetch_sub(own_usage, Ordering::SeqCst);
    }
}
/// An accumulating batch of outputs to be trial-decrypted with a fixed set of incoming
/// viewing keys.
pub(crate) struct Batch<A, D: BatchDomain, Output: ShieldedOutput<D, COMPACT_NOTE_SIZE>> {
    /// Caller-supplied tags, index-aligned with `ivks`.
    tags: Vec<A>,
    ivks: Vec<D::IncomingViewingKey>,
    /// Outputs to decrypt, each paired with the domain to decrypt it in.
    outputs: Vec<(D, Output)>,
    /// One replier per entry of `outputs`, in the same order.
    repliers: Vec<OutputReplier<A, D>>,
}

impl<A, D, Output> DynamicUsage for Batch<A, D, Output>
where
    A: DynamicUsage,
    D: BatchDomain + DynamicUsage,
    D::IncomingViewingKey: DynamicUsage,
    Output: ShieldedOutput<D, COMPACT_NOTE_SIZE> + DynamicUsage,
{
    fn dynamic_usage(&self) -> usize {
        [
            self.tags.dynamic_usage(),
            self.ivks.dynamic_usage(),
            self.outputs.dynamic_usage(),
            self.repliers.dynamic_usage(),
        ]
        .iter()
        .sum()
    }

    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let parts = [
            self.tags.dynamic_usage_bounds(),
            self.ivks.dynamic_usage_bounds(),
            self.outputs.dynamic_usage_bounds(),
            self.repliers.dynamic_usage_bounds(),
        ];

        let lower: usize = parts.iter().map(|&(lo, _)| lo).sum();
        // The upper bound is known only if every component's upper bound is known.
        let upper = parts
            .iter()
            .try_fold(0usize, |total, &(_, hi)| hi.map(|hi| total + hi));

        (lower, upper)
    }
}
impl<A, D, Output> Batch<A, D, Output>
where
A: Clone,
D: BatchDomain,
Output: ShieldedOutput<D, COMPACT_NOTE_SIZE>,
{
fn new(tags: Vec<A>, ivks: Vec<D::IncomingViewingKey>) -> Self {
assert_eq!(tags.len(), ivks.len());
Self {
tags,
ivks,
outputs: vec![],
repliers: vec![],
}
}
fn is_empty(&self) -> bool {
self.outputs.is_empty()
}
}
impl<A, D, Output> Task for Batch<A, D, Output>
where
    A: Clone + Send + 'static,
    D: BatchDomain + Send + 'static,
    D::IncomingViewingKey: Send,
    D::Memo: Send,
    D::Note: Send,
    D::Recipient: Send,
    Output: ShieldedOutput<D, COMPACT_NOTE_SIZE> + Send + 'static,
{
    /// Trial-decrypts the batch's outputs with its IVKs (via
    /// `batch::try_compact_note_decryption`). Each successful decryption is sent,
    /// tagged with the matching IVK's tag, on the channel of the output's transaction.
    ///
    /// Outputs that fail to decrypt send nothing. Their replier is simply dropped, so
    /// the receiving side only ever sees successful results.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` and `repliers` have different lengths.
    fn run(self) {
        // Deconstruct self so the pieces can be consumed individually.
        let Self {
            tags,
            ivks,
            outputs,
            repliers,
        } = self;

        assert_eq!(outputs.len(), repliers.len());

        let decryption_results = batch::try_compact_note_decryption(&ivks, &outputs);
        for (decryption_result, OutputReplier(replier)) in
            decryption_results.into_iter().zip(repliers.into_iter())
        {
            if let Some(((note, recipient), ivk_idx)) = decryption_result {
                let result = OutputIndex {
                    output_index: replier.output_index,
                    value: DecryptedNote {
                        ivk_tag: tags[ivk_idx].clone(),
                        recipient,
                        note,
                    },
                };

                // Every transaction has its own channel, so a failed send only means that
                // THIS transaction's receiver is gone. That happens when the runner was
                // dropped, or when a later `add_outputs` call with the same key replaced
                // the pending receiver. Keep delivering to the other transactions in this
                // batch; stopping here would silently drop their notes.
                if replier.value.send(result).is_err() {
                    tracing::debug!("BatchRunner was dropped before batch finished");
                }
            }
        }
    }
}
impl<A, D: BatchDomain, Output: ShieldedOutput<D, COMPACT_NOTE_SIZE> + Clone> Batch<A, D, Output> {
    /// Appends clones of `outputs` to this batch. Each is paired with a domain from
    /// `domain()`, and each gets a replier carrying its index within `outputs` and a
    /// clone of `replier`.
    fn add_outputs(
        &mut self,
        domain: impl Fn() -> D,
        outputs: &[Output],
        replier: channel::Sender<OutputItem<A, D>>,
    ) {
        let with_domain = |output: &Output| {
            let output = output.clone();
            (domain(), output)
        };
        self.outputs.extend(outputs.iter().map(with_domain));

        let reply_to = |output_index: usize| {
            OutputReplier(OutputIndex {
                output_index,
                value: replier.clone(),
            })
        };
        self.repliers.extend((0..outputs.len()).map(reply_to));
    }
}
/// Key under which a transaction's pending results are stored: the block hash passed to
/// `add_outputs` together with the transaction ID.
#[derive(PartialEq, Eq, Hash)]
struct ResultKey(BlockHash, TxId);

// Reports no heap usage. NOTE(review): this assumes `BlockHash` and `TxId` store their
// bytes inline; confirm if either type changes.
impl DynamicUsage for ResultKey {
    #[inline(always)]
    fn dynamic_usage(&self) -> usize {
        0
    }

    #[inline(always)]
    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let exact = self.dynamic_usage();
        (exact, Some(exact))
    }
}
/// Accumulates outputs into batches, hands full batches to a task strategy `T` to be
/// run, and keeps per-transaction channels from which results are collected.
pub(crate) struct BatchRunner<A, D, Output, T>
where
    D: BatchDomain,
    Output: ShieldedOutput<D, COMPACT_NOTE_SIZE>,
    T: Tasks<Batch<A, D, Output>>,
{
    /// `add_outputs` flushes once the accumulator holds at least this many outputs.
    batch_size_threshold: usize,
    /// The batch currently being filled.
    acc: Batch<A, D, Output>,
    /// Strategy that flushed batches are handed to.
    running_tasks: T,
    /// Result receivers for transactions that have been added but not yet collected.
    pending_results: HashMap<ResultKey, BatchReceiver<A, D>>,
}

impl<A, D, Output, T> DynamicUsage for BatchRunner<A, D, Output, T>
where
    A: DynamicUsage,
    D: BatchDomain + DynamicUsage,
    D::IncomingViewingKey: DynamicUsage,
    Output: ShieldedOutput<D, COMPACT_NOTE_SIZE> + DynamicUsage,
    T: Tasks<Batch<A, D, Output>> + DynamicUsage,
{
    fn dynamic_usage(&self) -> usize {
        let accumulating = self.acc.dynamic_usage();
        let running = self.running_tasks.dynamic_usage();
        let pending = self.pending_results.dynamic_usage();
        accumulating + running + pending
    }

    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        // Sample the running-task usage once and use it for both bounds. For `WithUsage`
        // it is an atomic counter that running tasks decrement concurrently.
        let running = self.running_tasks.dynamic_usage();
        let (acc_lower, acc_upper) = self.acc.dynamic_usage_bounds();
        let (pending_lower, pending_upper) = self.pending_results.dynamic_usage_bounds();

        let lower = acc_lower + running + pending_lower;
        let upper = acc_upper
            .zip(pending_upper)
            .map(|(acc, pending)| acc + running + pending);
        (lower, upper)
    }
}
impl<A, D, Output, T> BatchRunner<A, D, Output, T>
where
A: Clone,
D: BatchDomain,
Output: ShieldedOutput<D, COMPACT_NOTE_SIZE>,
T: Tasks<Batch<A, D, Output>>,
{
pub(crate) fn new(
batch_size_threshold: usize,
ivks: impl Iterator<Item = (A, D::IncomingViewingKey)>,
) -> Self {
let (tags, ivks) = ivks.unzip();
Self {
batch_size_threshold,
acc: Batch::new(tags, ivks),
running_tasks: T::new(),
pending_results: HashMap::default(),
}
}
}
impl<A, D, Output, T> BatchRunner<A, D, Output, T>
where
    A: Clone + Send + 'static,
    D: BatchDomain + Send + 'static,
    D::IncomingViewingKey: Clone + Send,
    D::Memo: Send,
    D::Note: Send,
    D::Recipient: Send,
    Output: ShieldedOutput<D, COMPACT_NOTE_SIZE> + Clone + Send + 'static,
    T: Tasks<Batch<A, D, Output>>,
{
    /// Queues the `outputs` of transaction `txid` for decryption under `block_tag`.
    ///
    /// `domain` is called once per output to build its decryption domain. If the
    /// accumulated batch then holds at least `batch_size_threshold` outputs, it is
    /// flushed.
    ///
    /// NOTE(review): a second call with the same `(block_tag, txid)` replaces the pending
    /// receiver from the first call, so the first call's results can no longer be
    /// collected.
    pub(crate) fn add_outputs(
        &mut self,
        block_tag: BlockHash,
        txid: TxId,
        domain: impl Fn() -> D,
        outputs: &[Output],
    ) {
        let (sender, receiver) = channel::unbounded();
        self.acc.add_outputs(domain, outputs, sender);
        self.pending_results
            .insert(ResultKey(block_tag, txid), BatchReceiver(receiver));

        let threshold_reached = self.acc.outputs.len() >= self.batch_size_threshold;
        if threshold_reached {
            self.flush();
        }
    }

    /// Hands the accumulated batch, if it has any outputs, to the task strategy. It is
    /// replaced with a fresh empty batch that uses the same tags and IVKs.
    pub(crate) fn flush(&mut self) {
        if self.acc.is_empty() {
            return;
        }
        let fresh = Batch::new(self.acc.tags.clone(), self.acc.ivks.clone());
        let full = mem::replace(&mut self.acc, fresh);
        self.running_tasks.run_task(full);
    }

    /// Removes and returns the decrypted notes for transaction `txid` queued under
    /// `block_tag`, keyed by `(txid, output_index)`.
    ///
    /// Returns an empty map if nothing is pending under that key.
    ///
    /// This blocks until the channel is empty and every sender for it has been dropped.
    /// Those senders live in the repliers of the batches holding this transaction's
    /// outputs and are released when those batches run. Therefore, if any of the
    /// transaction's outputs are still in the unflushed accumulator, call
    /// [`Self::flush`] first; otherwise this will not return.
    pub(crate) fn collect_results(
        &mut self,
        block_tag: BlockHash,
        txid: TxId,
    ) -> HashMap<(TxId, usize), DecryptedNote<A, D>> {
        let key = ResultKey(block_tag, txid);
        match self.pending_results.remove(&key) {
            None => HashMap::default(),
            Some(BatchReceiver(receiver)) => receiver
                .into_iter()
                .map(|item| ((txid, item.output_index), item.value))
                .collect(),
        }
    }
}