use std::collections::HashMap;
use std::fmt;
use std::mem;
use std::sync::atomic::AtomicUsize;
use crossbeam_channel as channel;
use orchard::note_encryption::{CompactAction, OrchardDomain};
use sapling_crypto::note_encryption::{CompactOutputDescription, SaplingDomain};
use zcash_client_backend::proto::compact_formats::CompactBlock;
use zcash_note_encryption::{BatchDomain, COMPACT_NOTE_SIZE, Domain, ShieldedOutput, batch};
use zcash_primitives::{
block::BlockHash, transaction::TxId, transaction::components::sapling::zip212_enforcement,
};
use zcash_protocol::{ShieldedProtocol, consensus};
use memuse::DynamicUsage;
use crate::keys::KeyId;
use crate::keys::ScanningKeyOps as _;
use crate::keys::ScanningKeys;
use crate::wallet::OutputId;
type TaggedSaplingBatch = Batch<
SaplingDomain,
sapling_crypto::note_encryption::CompactOutputDescription,
CompactDecryptor,
>;
type TaggedSaplingBatchRunner<Tasks> = BatchRunner<
SaplingDomain,
sapling_crypto::note_encryption::CompactOutputDescription,
CompactDecryptor,
Tasks,
>;
type TaggedOrchardBatch =
Batch<OrchardDomain, orchard::note_encryption::CompactAction, CompactDecryptor>;
type TaggedOrchardBatchRunner<Tasks> =
BatchRunner<OrchardDomain, orchard::note_encryption::CompactAction, CompactDecryptor, Tasks>;
/// Shorthand for any task strategy able to run Sapling batches.
pub(crate) trait SaplingTasks: Tasks<TaggedSaplingBatch> {}
impl<T> SaplingTasks for T where T: Tasks<TaggedSaplingBatch> {}
/// Shorthand for any task strategy able to run Orchard batches.
pub(crate) trait OrchardTasks: Tasks<TaggedOrchardBatch> {}
impl<T> OrchardTasks for T where T: Tasks<TaggedOrchardBatch> {}
/// One batch runner per shielded pool: Sapling outputs and Orchard actions.
pub(crate) struct BatchRunners<TS, TO>
where
    TS: SaplingTasks,
    TO: OrchardTasks,
{
    /// Runner for Sapling compact outputs.
    pub(crate) sapling: TaggedSaplingBatchRunner<TS>,
    /// Runner for Orchard compact actions.
    pub(crate) orchard: TaggedOrchardBatchRunner<TO>,
}
impl<TS, TO> BatchRunners<TS, TO>
where
TS: SaplingTasks,
TO: OrchardTasks,
{
pub(crate) fn for_keys(batch_size_threshold: usize, scanning_keys: &ScanningKeys) -> Self {
BatchRunners {
sapling: BatchRunner::new(
batch_size_threshold,
scanning_keys
.sapling
.iter()
.map(|(id, key)| (*id, key.prepare())),
),
orchard: BatchRunner::new(
batch_size_threshold,
scanning_keys
.orchard
.iter()
.map(|(id, key)| (*id, key.prepare())),
),
}
}
pub(crate) fn flush(&mut self) {
self.sapling.flush();
self.orchard.flush();
}
#[tracing::instrument(skip_all, fields(height = block.height))]
pub(crate) fn add_block<P>(
&mut self,
params: &P,
block: CompactBlock,
) -> Result<(), zcash_client_backend::scanning::ScanError>
where
P: consensus::Parameters + Send + 'static,
{
let block_hash = block.hash();
let block_height = block.height();
let zip212_enforcement = zip212_enforcement(params, block_height);
for tx in block.vtx.into_iter() {
let txid = tx.txid();
self.sapling.add_outputs(
block_hash,
txid,
|_| SaplingDomain::new(zip212_enforcement),
&tx.outputs
.iter()
.enumerate()
.map(|(i, output)| {
CompactOutputDescription::try_from(output).map_err(|_| {
zcash_client_backend::scanning::ScanError::EncodingInvalid {
at_height: block_height,
txid,
pool_type: ShieldedProtocol::Sapling,
index: i,
}
})
})
.collect::<Result<Vec<_>, _>>()?,
);
self.orchard.add_outputs(
block_hash,
txid,
OrchardDomain::for_compact_action,
&tx.actions
.iter()
.enumerate()
.map(|(i, action)| {
CompactAction::try_from(action).map_err(|_| {
zcash_client_backend::scanning::ScanError::EncodingInvalid {
at_height: block_height,
txid,
pool_type: ShieldedProtocol::Orchard,
index: i,
}
})
})
.collect::<Result<Vec<_>, _>>()?,
);
}
Ok(())
}
}
/// A note recovered by trial-decrypting one shielded output.
pub(crate) struct DecryptedOutput<D, M>
where
    D: Domain,
{
    /// Tag of the incoming viewing key that decrypted the output.
    pub(crate) ivk_tag: KeyId,
    /// Recipient recovered from the note plaintext.
    pub(crate) recipient: D::Recipient,
    /// The decrypted note.
    pub(crate) note: D::Note,
    /// Memo payload; `()` when produced by `CompactDecryptor`.
    pub(crate) memo: M,
}
/// Formats every stored field of a decrypted output.
///
/// Only the types actually stored in the struct need to be `Debug`. The domain's
/// incoming viewing key type is not stored here, so it is not required to be `Debug`.
impl<D: Domain, M> fmt::Debug for DecryptedOutput<D, M>
where
    D::Recipient: fmt::Debug,
    D::Note: fmt::Debug,
    M: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DecryptedOutput")
            .field("ivk_tag", &self.ivk_tag)
            .field("recipient", &self.recipient)
            .field("note", &self.note)
            .field("memo", &self.memo)
            .finish()
    }
}
/// Strategy for trial-decrypting a batch of outputs against a set of viewing keys.
pub(crate) trait Decryptor<D, Output>
where
    D: BatchDomain,
{
    /// Memo type produced alongside each decrypted note.
    type Memo;

    /// Trial-decrypts every entry of `outputs` against `ivks`.
    ///
    /// `tags[i]` identifies `ivks[i]`. The returned vector holds one entry per
    /// output, which is `None` when no key decrypted it.
    fn batch_decrypt(
        tags: &[KeyId],
        ivks: &[D::IncomingViewingKey],
        outputs: &[(D, Output)],
    ) -> Vec<Option<DecryptedOutput<D, Self::Memo>>>;
}
/// Decryptor for compact outputs. It recovers notes but no memos.
pub(crate) struct CompactDecryptor;

impl<D, Output> Decryptor<D, Output> for CompactDecryptor
where
    D: BatchDomain,
    Output: ShieldedOutput<D, COMPACT_NOTE_SIZE>,
{
    type Memo = ();

    /// Runs `batch::try_compact_note_decryption`. Each successful result is tagged with
    /// the `KeyId` of the key (by index into `ivks`) that decrypted it.
    fn batch_decrypt(
        tags: &[KeyId],
        ivks: &[D::IncomingViewingKey],
        outputs: &[(D, Output)],
    ) -> Vec<Option<DecryptedOutput<D, Self::Memo>>> {
        let raw_results = batch::try_compact_note_decryption(ivks, outputs);
        let mut decrypted = Vec::with_capacity(raw_results.len());
        for attempt in raw_results {
            decrypted.push(attempt.map(|((note, recipient), key_index)| DecryptedOutput {
                ivk_tag: tags[key_index],
                recipient,
                note,
                memo: (),
            }));
        }
        decrypted
    }
}
/// A value paired with the position of the output it belongs to.
struct OutputIndex<V> {
    /// Position of the output within the slice passed to `Batch::add_outputs`.
    output_index: usize,
    value: V,
}

/// A decrypted output labelled with its position.
type OutputItem<D, M> = OutputIndex<DecryptedOutput<D, M>>;

/// Sender through which a batch reports the decryption of one specific output.
struct OutputReplier<D, M>(OutputIndex<channel::Sender<OutputItem<D, M>>>)
where
    D: Domain;
// Repliers report no heap usage; queued items are counted on the receiving side
// (`BatchReceiver`) instead.
impl<D: Domain, M> DynamicUsage for OutputReplier<D, M> {
    #[inline(always)]
    fn dynamic_usage(&self) -> usize {
        0
    }

    #[inline(always)]
    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let usage = self.dynamic_usage();
        (usage, Some(usage))
    }
}
/// Receiving end of one transaction's result channel, held until collected.
struct BatchReceiver<D, M>(channel::Receiver<OutputItem<D, M>>)
where
    D: Domain;

impl<D: Domain, M> DynamicUsage for BatchReceiver<D, M> {
    /// Estimates the heap held by items currently queued in the channel.
    fn dynamic_usage(&self) -> usize {
        // NOTE(review): this models crossbeam's unbounded channel as a linked list of
        // blocks. Each block is a next-pointer plus 31 slots, and each slot is the item plus
        // an `AtomicUsize` state word. Padding inside a slot is not counted.
        const SLOTS_PER_BLOCK: usize = 31;
        const NEXT_PTR_BYTES: usize = std::mem::size_of::<usize>();
        const SLOT_STATE_BYTES: usize = std::mem::size_of::<AtomicUsize>();

        let slot_bytes = std::mem::size_of::<OutputItem<D, M>>() + SLOT_STATE_BYTES;
        let bytes_per_block = NEXT_PTR_BYTES + SLOTS_PER_BLOCK * slot_bytes;
        let blocks_in_use = self.0.len().div_ceil(SLOTS_PER_BLOCK);
        blocks_in_use * bytes_per_block
    }

    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let estimate = self.dynamic_usage();
        (estimate, Some(estimate))
    }
}
/// Strategy for turning work items into tasks and spawning them.
pub(crate) trait Tasks<Item> {
    /// The runnable form of an `Item`.
    type Task: Task;

    /// Creates the task strategy.
    fn new() -> Self;

    /// Converts `item` into the task that will be spawned.
    fn add_task(&self, item: Item) -> Self::Task;

    /// Converts `item` with [`Tasks::add_task`] and spawns the task via
    /// `rayon::spawn_fifo`.
    fn run_task(&self, item: Item) {
        let spawned = self.add_task(item);
        rayon::spawn_fifo(move || spawned.run());
    }
}
/// A self-contained unit of work that can be sent to another thread.
pub(crate) trait Task: Send + 'static {
    /// Executes the work, consuming the task.
    fn run(self);
}

/// The trivial strategy: each item is already its own task and nothing is tracked.
impl<Item> Tasks<Item> for ()
where
    Item: Task,
{
    type Task = Item;

    fn new() -> Self {}

    fn add_task(&self, item: Item) -> Self::Task {
        item
    }
}
/// Outputs accumulated for one trial-decryption pass against a fixed key set.
pub(crate) struct Batch<D, Output, Dec>
where
    D: BatchDomain,
    Dec: Decryptor<D, Output>,
{
    /// `tags[i]` identifies `ivks[i]`.
    tags: Vec<KeyId>,
    /// Prepared incoming viewing keys to try on every output.
    ivks: Vec<D::IncomingViewingKey>,
    /// Outputs paired with their decryption domain.
    outputs: Vec<(D, Output)>,
    /// One replier per entry of `outputs`, in the same order.
    repliers: Vec<OutputReplier<D, Dec::Memo>>,
}
impl<D, Output, Dec> DynamicUsage for Batch<D, Output, Dec>
where
    D: BatchDomain + DynamicUsage,
    D::IncomingViewingKey: DynamicUsage,
    Output: DynamicUsage,
    Dec: Decryptor<D, Output>,
{
    /// Sum of the heap usage of all four vectors.
    fn dynamic_usage(&self) -> usize {
        [
            self.tags.dynamic_usage(),
            self.ivks.dynamic_usage(),
            self.outputs.dynamic_usage(),
            self.repliers.dynamic_usage(),
        ]
        .iter()
        .sum::<usize>()
    }

    /// Component-wise sum of the vectors' bounds. The upper bound is known only when
    /// every component's upper bound is known.
    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let tags = self.tags.dynamic_usage_bounds();
        let ivks = self.ivks.dynamic_usage_bounds();
        let outputs = self.outputs.dynamic_usage_bounds();
        let repliers = self.repliers.dynamic_usage_bounds();

        let lower = tags.0 + ivks.0 + outputs.0 + repliers.0;
        let upper = match (tags.1, ivks.1, outputs.1, repliers.1) {
            (Some(t), Some(i), Some(o), Some(r)) => Some(t + i + o + r),
            _ => None,
        };
        (lower, upper)
    }
}
impl<D, Output, Dec> Batch<D, Output, Dec>
where
    D: BatchDomain,
    Dec: Decryptor<D, Output>,
{
    /// Creates a batch with no outputs that will trial-decrypt against `ivks`. Each
    /// `tags[i]` identifies `ivks[i]`.
    ///
    /// # Panics
    ///
    /// Panics if `tags` and `ivks` differ in length.
    fn new(tags: Vec<KeyId>, ivks: Vec<D::IncomingViewingKey>) -> Self {
        assert_eq!(tags.len(), ivks.len());
        Self {
            outputs: Vec::new(),
            repliers: Vec::new(),
            tags,
            ivks,
        }
    }

    /// Returns `true` if no outputs have been added yet.
    fn is_empty(&self) -> bool {
        self.outputs.is_empty()
    }
}
impl<D, Output, Dec> Task for Batch<D, Output, Dec>
where
    D: BatchDomain + Send + 'static,
    D::IncomingViewingKey: Send,
    D::Memo: Send,
    D::Note: Send,
    D::Recipient: Send,
    Output: Send + 'static,
    Dec: Decryptor<D, Output> + 'static,
    Dec::Memo: Send,
{
    /// Trial-decrypts every output in the batch. Each successful result is sent through
    /// that output's replier.
    ///
    /// An output that does not decrypt sends nothing; its replier is simply dropped. If a
    /// send fails because the receiver is gone, the remaining results are abandoned.
    ///
    /// # Panics
    ///
    /// Panics if the number of outputs differs from the number of repliers.
    fn run(self) {
        let Self {
            tags,
            ivks,
            outputs,
            repliers,
        } = self;
        assert_eq!(outputs.len(), repliers.len());

        let results = Dec::batch_decrypt(&tags, &ivks, &outputs);
        let paired = results.into_iter().zip(repliers);
        for (result, OutputReplier(OutputIndex { output_index, value: sender })) in paired {
            let Some(decrypted) = result else {
                continue;
            };
            let message = OutputIndex {
                output_index,
                value: decrypted,
            };
            if sender.send(message).is_err() {
                tracing::debug!("BatchRunner was dropped before batch finished");
                break;
            }
        }
    }
}
impl<D, Output, Dec> Batch<D, Output, Dec>
where
    D: BatchDomain,
    Output: Clone,
    Dec: Decryptor<D, Output>,
{
    /// Appends clones of `outputs` to the batch, each paired with the domain `domain`
    /// derives for it.
    ///
    /// It also registers one clone of `replier` per output, so the result for
    /// `outputs[i]` is reported with `output_index == i`.
    fn add_outputs(
        &mut self,
        domain: impl Fn(&Output) -> D,
        outputs: &[Output],
        replier: channel::Sender<OutputItem<D, Dec::Memo>>,
    ) {
        let with_domains = outputs
            .iter()
            .cloned()
            .map(|output| (domain(&output), output));
        self.outputs.extend(with_domains);

        let per_output_repliers = (0..outputs.len()).map(|output_index| {
            OutputReplier(OutputIndex {
                output_index,
                value: replier.clone(),
            })
        });
        self.repliers.extend(per_output_repliers);
        // `replier` itself is dropped on return; only the per-output clones remain as
        // senders for this channel.
    }
}
/// Key under which a transaction's pending results are stored: (block hash, txid).
#[derive(PartialEq, Eq, Hash)]
struct ResultKey(BlockHash, TxId);

// Reported as owning no heap memory; presumably both components are fixed-size values.
impl DynamicUsage for ResultKey {
    #[inline(always)]
    fn dynamic_usage(&self) -> usize {
        0
    }

    #[inline(always)]
    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let none = self.dynamic_usage();
        (none, Some(none))
    }
}
/// Accumulates outputs into batches and dispatches them through the task strategy `T`.
/// It also keeps one result channel per (block hash, txid) until those results are
/// collected.
pub(crate) struct BatchRunner<
    D: BatchDomain,
    Output,
    Dec: Decryptor<D, Output>,
    T: Tasks<Batch<D, Output, Dec>>,
> {
    /// Accumulated output count at which the current batch is dispatched.
    batch_size_threshold: usize,
    /// The batch currently being filled.
    acc: Batch<D, Output, Dec>,
    /// Strategy used to run dispatched batches.
    running_tasks: T,
    /// Result channels for transactions that have not been collected yet.
    pending_results: HashMap<ResultKey, BatchReceiver<D, Dec::Memo>>,
}
impl<D, Output, Dec, T> DynamicUsage for BatchRunner<D, Output, Dec, T>
where
    D: BatchDomain + DynamicUsage,
    D::IncomingViewingKey: DynamicUsage,
    Output: DynamicUsage,
    Dec: Decryptor<D, Output>,
    T: Tasks<Batch<D, Output, Dec>> + DynamicUsage,
{
    /// Heap usage of the accumulator, the task strategy and the pending result channels.
    fn dynamic_usage(&self) -> usize {
        self.acc.dynamic_usage()
            + self.running_tasks.dynamic_usage()
            + self.pending_results.dynamic_usage()
    }

    /// Bounds of the accumulator and pending results, with the task strategy's point
    /// estimate added to both ends.
    fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
        let tasks = self.running_tasks.dynamic_usage();
        let (acc_lower, acc_upper) = self.acc.dynamic_usage_bounds();
        let (pending_lower, pending_upper) = self.pending_results.dynamic_usage_bounds();

        let lower = acc_lower + tasks + pending_lower;
        let upper = match (acc_upper, pending_upper) {
            (Some(acc), Some(pending)) => Some(acc + tasks + pending),
            _ => None,
        };
        (lower, upper)
    }
}
impl<D, Output, Dec, T> BatchRunner<D, Output, Dec, T>
where
    D: BatchDomain,
    Dec: Decryptor<D, Output>,
    T: Tasks<Batch<D, Output, Dec>>,
{
    /// Creates a runner that will try each `(tag, ivk)` pair yielded by `ivks`.
    ///
    /// The current batch is dispatched once it holds at least `batch_size_threshold`
    /// outputs.
    pub(crate) fn new(
        batch_size_threshold: usize,
        ivks: impl Iterator<Item = (KeyId, D::IncomingViewingKey)>,
    ) -> Self {
        let (key_tags, prepared_keys): (Vec<KeyId>, Vec<D::IncomingViewingKey>) = ivks.unzip();
        Self {
            batch_size_threshold,
            acc: Batch::new(key_tags, prepared_keys),
            running_tasks: T::new(),
            pending_results: HashMap::new(),
        }
    }
}
impl<D, Output, Dec, T> BatchRunner<D, Output, Dec, T>
where
D: BatchDomain + Send + 'static,
D::IncomingViewingKey: Clone + Send,
D::Memo: Send,
D::Note: Send,
D::Recipient: Send,
Output: Clone + Send + 'static,
Dec: Decryptor<D, Output>,
T: Tasks<Batch<D, Output, Dec>>,
{
pub(crate) fn add_outputs(
&mut self,
block_tag: BlockHash,
txid: TxId,
domain: impl Fn(&Output) -> D,
outputs: &[Output],
) {
let (tx, rx) = channel::unbounded();
self.acc.add_outputs(domain, outputs, tx);
self.pending_results
.insert(ResultKey(block_tag, txid), BatchReceiver(rx));
if self.acc.outputs.len() >= self.batch_size_threshold {
self.flush();
}
}
pub(crate) fn flush(&mut self) {
if !self.acc.is_empty() {
let mut batch = Batch::new(self.acc.tags.clone(), self.acc.ivks.clone());
mem::swap(&mut batch, &mut self.acc);
self.running_tasks.run_task(batch);
}
}
pub(crate) fn collect_results(
&mut self,
block_tag: BlockHash,
txid: TxId,
) -> HashMap<OutputId, DecryptedOutput<D, Dec::Memo>> {
self.pending_results
.remove(&ResultKey(block_tag, txid))
.map(|BatchReceiver(rx)| {
rx.into_iter()
.map(
|OutputIndex {
output_index,
value,
}| {
(OutputId::new(txid, output_index as u16), value)
},
)
.collect()
})
.unwrap_or_default()
}
}