use std::{convert::TryFrom, sync::Arc};
use futures::{Stream, StreamExt};
use log::*;
use tari_comms_dht::Dht;
use tari_p2p::{
comms_connector::{PeerMessage, SubscriptionFactory},
domain_message::DomainMessage,
tari_message::TariMessageType,
};
use tari_service_framework::{
ServiceInitializationError,
ServiceInitializer,
ServiceInitializerContext,
async_trait,
reply_channel,
};
use tari_transaction_components::transaction_components::Transaction;
use tokio::sync::mpsc;
use crate::{
base_node::comms_interface::LocalNodeCommsInterface,
mempool::{
mempool::Mempool,
service::{
MempoolHandle,
inbound_handlers::MempoolInboundHandlers,
local_service::LocalMempoolService,
outbound_interface::OutboundMempoolServiceInterface,
service::{MempoolService, MempoolStreams},
},
},
proto,
};
/// Log target used by the log statements emitted from this module.
const LOG_TARGET: &str = "c::bn::mempool_service::initializer";
/// Label passed to `SubscriptionFactory::get_subscription` when subscribing to
/// `TariMessageType::NewTransaction` messages.
const SUBSCRIPTION_LABEL: &str = "Mempool";
/// Service initializer that wires the mempool service into the service framework.
///
/// It owns the [`Mempool`] to be served and the factory from which the inbound
/// `NewTransaction` peer-message subscription is obtained.
pub struct MempoolServiceInitializer {
    /// The mempool instance; cloned into the inbound handlers during initialization.
    mempool: Mempool,
    /// Source of inbound peer-message subscriptions.
    inbound_message_subscription_factory: Arc<SubscriptionFactory>,
}

impl MempoolServiceInitializer {
    /// Creates an initializer for `mempool` that listens for inbound transactions
    /// via `inbound_message_subscription_factory`.
    pub fn new(mempool: Mempool, inbound_message_subscription_factory: Arc<SubscriptionFactory>) -> Self {
        Self {
            inbound_message_subscription_factory,
            mempool,
        }
    }

    /// Subscribes to `NewTransaction` peer messages and maps each one to a decoded
    /// [`DomainMessage<Transaction>`], dropping messages that fail to decode.
    ///
    /// The `use<>` bound means the returned stream does not borrow `self`.
    fn inbound_transaction_stream(&self) -> impl Stream<Item = DomainMessage<Transaction>> + use<> {
        let new_tx_subscription = self
            .inbound_message_subscription_factory
            .get_subscription(TariMessageType::NewTransaction, SUBSCRIPTION_LABEL);
        new_tx_subscription.filter_map(extract_transaction)
    }
}
/// Converts an inbound peer message into a [`DomainMessage`] carrying a [`Transaction`].
///
/// Returns `None`, after logging a warning that names the sending peer, when either:
/// - the payload cannot be decoded as a `proto::types::Transaction`, or
/// - the decoded protobuf cannot be converted into a [`Transaction`].
///
/// The message's source peer, DHT header and authenticated origin are copied onto the
/// returned domain message unchanged.
async fn extract_transaction(msg: Arc<PeerMessage>) -> Option<DomainMessage<Transaction>> {
    let proto_tx = match msg.decode_message::<proto::types::Transaction>() {
        Ok(proto_tx) => proto_tx,
        Err(e) => {
            // Include the source peer so decode failures can be attributed, matching the
            // conversion-failure warning below.
            warn!(
                target: LOG_TARGET,
                "Could not decode inbound transaction message from {}. {}", msg.source_peer.public_key, e
            );
            return None;
        },
    };

    let tx = match Transaction::try_from(proto_tx) {
        Ok(tx) => tx,
        Err(e) => {
            warn!(
                target: LOG_TARGET,
                "Inbound transaction message from {} was ill-formed. {}", msg.source_peer.public_key, e
            );
            return None;
        },
    };

    Some(DomainMessage {
        source_peer: msg.source_peer.clone(),
        dht_header: msg.dht_header.clone(),
        authenticated_origin: msg.authenticated_origin.clone(),
        inner: tx,
    })
}
#[async_trait]
impl ServiceInitializer for MempoolServiceInitializer {
    /// Builds the mempool service's channels, registers its public handles with
    /// `context`, and spawns the service task via `spawn_until_shutdown`.
    ///
    /// Handles registered here: [`MempoolHandle`], [`OutboundMempoolServiceInterface`]
    /// and [`LocalMempoolService`]. The `Dht` and `LocalNodeCommsInterface` dependencies
    /// are resolved lazily inside the spawned closure via `expect_handle`.
    /// (NOTE(review): `expect_handle` presumably panics if a handle is missing; confirm.)
    ///
    /// Always returns `Ok(())`.
    async fn initialize(&mut self, context: ServiceInitializerContext) -> Result<(), ServiceInitializationError> {
        let tx_stream_in = self.inbound_transaction_stream();

        // Request/reply channel behind the public mempool handle.
        let (handle_requests_tx, handle_requests_rx) = reply_channel::unbounded();
        context.register_handle(MempoolHandle::new(handle_requests_tx));

        // Channel on which outbound transactions are queued for propagation.
        let (outbound_sender, outbound_receiver) = mpsc::unbounded_channel();
        // Request/reply channel behind the local mempool service interface.
        let (local_requests_tx, local_requests_rx) = reply_channel::unbounded();

        let outbound_interface = OutboundMempoolServiceInterface::new(outbound_sender);
        let local_interface = LocalMempoolService::new(local_requests_tx);
        let handlers = MempoolInboundHandlers::new(self.mempool.clone(), outbound_interface.clone());

        context.register_handle(outbound_interface);
        context.register_handle(local_interface);

        context.spawn_until_shutdown(move |handles| {
            let dht_outbound = handles.expect_handle::<Dht>().outbound_requester();
            let node_comms = handles.expect_handle::<LocalNodeCommsInterface>();
            let streams = MempoolStreams {
                outbound_tx_stream: outbound_receiver,
                inbound_transaction_stream: tx_stream_in,
                local_request_stream: local_requests_rx,
                block_event_stream: node_comms.get_block_event_stream(),
                request_receiver: handle_requests_rx,
            };
            debug!(target: LOG_TARGET, "Mempool service started");
            MempoolService::new(dht_outbound, handlers).start(streams)
        });

        Ok(())
    }
}