use std::collections::BTreeMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;
use assert_matches::assert_matches;
use miden_protocol::MIN_PROOF_SECURITY_LEVEL;
use miden_protocol::account::auth::AuthScheme;
use miden_protocol::asset::{Asset, FungibleAsset};
use miden_protocol::batch::{ProposedBatch, ProvenBatch};
use miden_protocol::note::NoteType;
use miden_protocol::testing::account_id::{ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET, ACCOUNT_ID_SENDER};
use miden_protocol::transaction::{ExecutedTransaction, ProvenTransaction};
use miden_protocol::utils::serde::{Deserializable, Serializable};
use miden_testing::{Auth, MockChainBuilder};
use miden_tx::{LocalTransactionProver, TransactionVerifier};
use miden_tx_batch_prover::LocalBatchProver;
use serial_test::serial;
use crate::generated::api_client::ApiClient;
use crate::generated::{Proof, ProofRequest, ProofType};
use crate::server::Server;
use crate::server::proof_kind::ProofKind;
/// Minimal test-side wrapper around the generated gRPC [`ApiClient`].
#[derive(Clone)]
struct Client {
    inner: ApiClient<tonic::transport::Channel>,
}

impl Client {
    /// Opens a connection to the server listening on `127.0.0.1:port`.
    ///
    /// # Panics
    /// Panics if the connection cannot be established.
    async fn connect(port: u16) -> Self {
        let endpoint = format!("http://127.0.0.1:{port}");
        let inner = ApiClient::connect(endpoint).await.expect("client should connect");
        Self { inner }
    }

    /// Calls the `prove` RPC with `request`.
    ///
    /// Returns the proof from the response body, or the gRPC status returned by the call.
    async fn submit_request(&mut self, request: ProofRequest) -> Result<Proof, tonic::Status> {
        let response = self.inner.prove(request).await?;
        Ok(response.into_inner())
    }
}
impl ProofRequest {
fn from_tx(tx: &ExecutedTransaction) -> Self {
let tx_inputs = tx.tx_inputs().clone();
Self {
proof_type: ProofType::Transaction as i32,
payload: tx_inputs.to_bytes(),
}
}
fn from_batch(batch: &ProposedBatch) -> Self {
Self {
proof_type: ProofType::Batch as i32,
payload: batch.to_bytes(),
}
}
async fn mock_tx() -> ExecutedTransaction {
let mut mock_chain_builder = MockChainBuilder::new();
let account = mock_chain_builder
.add_existing_wallet(Auth::BasicAuth {
auth_scheme: AuthScheme::Falcon512Poseidon2,
})
.unwrap();
let fungible_asset_1: Asset =
FungibleAsset::new(ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET.try_into().unwrap(), 100)
.unwrap()
.into();
let note_1 = mock_chain_builder
.add_p2id_note(
ACCOUNT_ID_SENDER.try_into().unwrap(),
account.id(),
&[fungible_asset_1],
NoteType::Private,
)
.unwrap();
let mock_chain = mock_chain_builder.build().unwrap();
let tx_context = mock_chain
.build_tx_context(account.id(), &[note_1.id()], &[])
.unwrap()
.disable_debug_mode()
.build()
.unwrap();
Box::pin(tx_context.execute()).await.unwrap()
}
async fn mock_batch() -> ProposedBatch {
let mut mock_chain_builder = MockChainBuilder::new();
let account = mock_chain_builder
.add_existing_wallet(Auth::BasicAuth {
auth_scheme: AuthScheme::Falcon512Poseidon2,
})
.unwrap();
let fungible_asset_1: Asset =
FungibleAsset::new(ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET.try_into().unwrap(), 100)
.unwrap()
.into();
let note_1 = mock_chain_builder
.add_p2id_note(
ACCOUNT_ID_SENDER.try_into().unwrap(),
account.id(),
&[fungible_asset_1],
NoteType::Private,
)
.unwrap();
let mock_chain = mock_chain_builder.build().unwrap();
let tx = mock_chain
.build_tx_context(account.id(), &[note_1.id()], &[])
.unwrap()
.disable_debug_mode()
.build()
.unwrap();
let tx = Box::pin(tx.execute()).await.unwrap();
let tx = LocalTransactionProver::default().prove(tx.tx_inputs().clone()).await.unwrap();
ProposedBatch::new(
vec![Arc::new(tx)],
mock_chain.latest_block_header(),
mock_chain.latest_partial_blockchain(),
BTreeMap::new(),
)
.unwrap()
}
}
impl Server {
    /// Creates a server configuration for `kind` with:
    /// - port `0` (presumably letting the OS pick a free port; the bound port is returned by
    ///   `spawn`);
    /// - a 60 s timeout;
    /// - a capacity of 10.
    fn with_arbitrary_port(kind: ProofKind) -> Self {
        Self {
            port: 0,
            kind,
            timeout: Duration::from_secs(60),
            capacity: NonZeroUsize::new(10).expect("default capacity 10 is non-zero"),
        }
    }

    /// Replaces the configured capacity.
    ///
    /// # Panics
    /// Panics if `capacity` is zero.
    fn with_capacity(mut self, capacity: usize) -> Self {
        self.capacity = NonZeroUsize::new(capacity).expect("server capacity must be non-zero");
        self
    }

    /// Replaces the configured timeout.
    fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}
/// With a capacity of one, two concurrent requests must split into exactly one success and
/// one `ResourceExhausted` rejection.
// Attribute order matches the other serial tests in this file: async runtime first, then
// `#[serial]`.
#[tokio::test(flavor = "multi_thread")]
#[serial]
async fn legacy_behaviour_with_capacity_1() {
    let (server, port) = Server::with_arbitrary_port(ProofKind::Transaction)
        .with_capacity(1)
        .spawn()
        .await
        .expect("server should spawn");
    let request = ProofRequest::from_tx(&ProofRequest::mock_tx().await);

    let mut client_a = Client::connect(port).await;
    let mut client_b = client_a.clone();
    let a = client_a.submit_request(request.clone());
    let b = client_b.submit_request(request);
    let (first, second) = tokio::join!(a, b);

    // Exactly one of the two requests succeeds and exactly one fails.
    assert!(first.is_ok() || second.is_ok());
    assert!(first.is_err() || second.is_err());
    let err = first.err().or(second.err()).unwrap();
    assert_eq!(err.code(), tonic::Code::ResourceExhausted);

    server.abort();
}
/// With a capacity of two, three concurrent requests must yield exactly two successes and one
/// `ResourceExhausted` rejection.
#[ignore = "Proving 3 requests concurrently causes temporary CI resource starvation which results in _sporadic_ timeouts"]
#[tokio::test(flavor = "multi_thread")]
#[serial]
async fn capacity_is_respected() {
    let (server, port) = Server::with_arbitrary_port(ProofKind::Transaction)
        .with_capacity(2)
        .spawn()
        .await
        .expect("server should spawn");
    let request = ProofRequest::from_tx(&ProofRequest::mock_tx().await);

    let mut client_a = Client::connect(port).await;
    let mut client_b = client_a.clone();
    let mut client_c = client_a.clone();
    let (first, second, third) = tokio::join!(
        client_a.submit_request(request.clone()),
        client_b.submit_request(request.clone()),
        client_c.submit_request(request),
    );

    let successes = [first.is_ok(), second.is_ok(), third.is_ok()]
        .into_iter()
        .filter(|ok| *ok)
        .count();
    assert_eq!(successes, 2);

    let rejection = first.err().or(second.err()).or(third.err());
    assert_matches!(rejection, Some(err) => {
        assert_eq!(err.code(), tonic::Code::ResourceExhausted);
    });

    server.abort();
}
/// A 10 ns timeout is expected to elapse before any proof completes. At least one request
/// must therefore come back `Cancelled` with a "Timeout expired" message.
#[tokio::test(flavor = "multi_thread")]
async fn timeout_is_respected() {
    let (server, port) = Server::with_arbitrary_port(ProofKind::Transaction)
        .with_timeout(Duration::from_nanos(10))
        .spawn()
        .await
        .expect("server should spawn");
    let request = ProofRequest::from_tx(&ProofRequest::mock_tx().await);

    // Two separately established connections.
    let mut client_a = Client::connect(port).await;
    let mut client_b = Client::connect(port).await;
    let (outcome_a, outcome_b) = tokio::join!(
        client_a.submit_request(request.clone()),
        client_b.submit_request(request),
    );

    let err = outcome_a.err().or_else(|| outcome_b.err()).unwrap();
    assert_eq!(err.code(), tonic::Code::Cancelled);
    assert!(err.message().contains("Timeout expired"));

    server.abort();
}
#[tokio::test(flavor = "multi_thread")]
async fn invalid_proof_kind_is_rejected() {
let (server, port) = Server::with_arbitrary_port(ProofKind::Transaction)
.spawn()
.await
.expect("server should spawn");
let mut request = ProofRequest::from_tx(&ProofRequest::mock_tx().await);
request.proof_type = i32::MAX;
let mut client = Client::connect(port).await;
let response = client.submit_request(request).await;
let err = response.unwrap_err();
assert_eq!(err.code(), tonic::Code::InvalidArgument);
assert!(err.message().contains("unknown proof_type value"));
server.abort();
}
/// A server configured for batch proofs rejects a transaction-proof request with
/// `InvalidArgument`.
#[tokio::test(flavor = "multi_thread")]
async fn unsupported_proof_kind_is_rejected() {
    let (server, port) = Server::with_arbitrary_port(ProofKind::Batch)
        .spawn()
        .await
        .expect("server should spawn");
    let tx_request = ProofRequest::from_tx(&ProofRequest::mock_tx().await);

    let mut client = Client::connect(port).await;
    let err = client.submit_request(tx_request).await.unwrap_err();
    assert_eq!(err.code(), tonic::Code::InvalidArgument);
    assert!(err.message().contains("unsupported proof type"));

    server.abort();
}
/// The server's transaction proof must:
/// - decode as a `ProvenTransaction`;
/// - carry the executed transaction's id;
/// - pass verification at `MIN_PROOF_SECURITY_LEVEL`.
#[tokio::test(flavor = "multi_thread")]
#[serial]
async fn transaction_proof_is_correct() {
    let (server, port) = Server::with_arbitrary_port(ProofKind::Transaction)
        .spawn()
        .await
        .expect("server should spawn");
    let executed_tx = ProofRequest::mock_tx().await;

    let mut client = Client::connect(port).await;
    let proof = client.submit_request(ProofRequest::from_tx(&executed_tx)).await.unwrap();
    let proven_tx = ProvenTransaction::read_from_bytes(&proof.payload).unwrap();

    assert_eq!(proven_tx.id(), executed_tx.id());
    TransactionVerifier::new(MIN_PROOF_SECURITY_LEVEL).verify(&proven_tx).unwrap();

    server.abort();
}
/// The server's batch proof must equal the proof produced locally by `LocalBatchProver` for
/// the same proposed batch.
#[tokio::test(flavor = "multi_thread")]
#[serial]
async fn batch_proof_is_correct() {
    let (server, port) = Server::with_arbitrary_port(ProofKind::Batch)
        .spawn()
        .await
        .expect("server should spawn");
    let proposed_batch = ProofRequest::mock_batch().await;

    let mut client = Client::connect(port).await;
    let proof = client.submit_request(ProofRequest::from_batch(&proposed_batch)).await.unwrap();
    let remote_batch = ProvenBatch::read_from_bytes(&proof.payload).unwrap();

    // The local prover is synchronous. `block_in_place` runs it on this worker thread and
    // requires the multi-threaded runtime flavor selected above.
    let local_batch = tokio::task::block_in_place(move || {
        LocalBatchProver::new(MIN_PROOF_SECURITY_LEVEL).prove(proposed_batch).unwrap()
    });
    assert_eq!(remote_batch, local_batch);

    server.abort();
}