use codec::Decode;
use tc_network::config::{IncomingRequest, OutgoingResponse, ProtocolId, RequestResponseConfig};
use tc_client_api::Backend;
use tp_runtime::traits::NumberFor;
use futures::channel::{mpsc, oneshot};
use futures::stream::StreamExt;
use log::debug;
use tp_runtime::traits::Block as BlockT;
use std::time::Duration;
use std::sync::Arc;
use tc_service::{SpawnTaskHandle, config::{Configuration, Role}};
use tc_finality_grandpa::WarpSyncFragmentCache;
/// Builds the GRANDPA warp sync request-response configuration for this node.
///
/// * When `config.role` is `Role::Light`, the configuration produced by
///   `generate_request_response_config` is returned as-is. That configuration has no
///   inbound queue and no handler task is spawned.
/// * For every other role, a `GrandpaWarpSyncRequestHandler` is created over `backend`,
///   its `run` future is spawned on `spawn_handle` under the name
///   `"grandpa_warp_sync_request_handler"`, and the configuration wired to that handler's
///   inbound queue is returned.
pub fn request_response_config_for_chain<TBlock: BlockT, TBackend: Backend<TBlock> + 'static>(
    config: &Configuration,
    spawn_handle: SpawnTaskHandle,
    backend: Arc<TBackend>,
) -> RequestResponseConfig
where
    NumberFor<TBlock>: tc_finality_grandpa::BlockNumberOps,
{
    let protocol_id = config.protocol_id();

    // Light role: hand back the bare configuration, nothing to serve.
    if matches!(config.role, Role::Light) {
        return generate_request_response_config(protocol_id);
    }

    // `protocol_id` and `backend` are owned and used exactly once on this path, so they are
    // moved into the handler instead of being cloned and then dropped.
    let (handler, request_response_config) =
        GrandpaWarpSyncRequestHandler::new(protocol_id, backend);
    spawn_handle.spawn("grandpa_warp_sync_request_handler", handler.run());
    request_response_config
}
/// Log target passed to the `debug!` calls emitted while serving warp sync requests.
const LOG_TARGET: &str = "finality-grandpa-warp-sync-request-handler";
/// Produces the request-response configuration of the warp sync protocol for `protocol_id`.
///
/// The protocol name is derived from `protocol_id` via `generate_protocol_name`. The returned
/// configuration has `inbound_queue` set to `None`; a caller that wants to receive requests
/// has to install its own queue afterwards.
pub fn generate_request_response_config(protocol_id: ProtocolId) -> RequestResponseConfig {
    let protocol_name = generate_protocol_name(protocol_id);

    RequestResponseConfig {
        name: protocol_name.into(),
        request_timeout: Duration::from_secs(10),
        // NOTE(review): a request carries a single block hash, so 32 presumably matches the
        // encoded hash size in bytes; confirm for chains with a different hash width.
        max_request_size: 32,
        max_response_size: 16 * 1024 * 1024,
        inbound_queue: None,
    }
}
/// Returns the warp sync protocol name for `protocol_id`, i.e. `/<protocol_id>/sync/warp`.
fn generate_protocol_name(protocol_id: ProtocolId) -> String {
    // Pin the `AsRef` target to `str`, exactly as `push_str` would.
    let id: &str = protocol_id.as_ref();
    ["/", id, "/sync/warp"].concat()
}
/// Incoming warp sync request, decoded from the raw request payload.
#[derive(codec::Decode)]
struct Request<B: BlockT> {
/// Block hash forwarded to `prove_warp_sync`; presumably the block the requested
/// proof starts from (inferred from the field name, confirm against `prove_warp_sync`).
begin: B::Hash
}
/// Limit handed to `prove_warp_sync` for every request; presumably caps the number of
/// fragments in one response (confirm against `prove_warp_sync`).
const WARP_SYNC_FRAGMENTS_LIMIT: usize = 100;
/// Size argument used to construct the handler's `WarpSyncFragmentCache`.
const WARP_SYNC_CACHE_SIZE: usize = 20;
/// Handler for incoming GRANDPA warp sync requests.
///
/// Created with `new` and driven by `run`, which answers requests arriving on
/// `request_receiver` using data from `backend`.
pub struct GrandpaWarpSyncRequestHandler<TBackend, TBlock: BlockT> {
/// Client backend whose blockchain is passed to `prove_warp_sync`.
backend: Arc<TBackend>,
/// Fragment cache handed to `prove_warp_sync`; write-locked while a request is handled.
cache: Arc<parking_lot::RwLock<WarpSyncFragmentCache<TBlock::Header>>>,
/// Receiving end of the inbound queue installed in the request-response configuration.
request_receiver: mpsc::Receiver<IncomingRequest>,
/// Ties the handler to `TBlock` without storing a value of that type.
_phantom: std::marker::PhantomData<TBlock>
}
impl<TBlock: BlockT, TBackend: Backend<TBlock>> GrandpaWarpSyncRequestHandler<TBackend, TBlock> {
pub fn new(protocol_id: ProtocolId, backend: Arc<TBackend>) -> (Self, RequestResponseConfig) {
let (tx, request_receiver) = mpsc::channel(20);
let mut request_response_config = generate_request_response_config(protocol_id);
request_response_config.inbound_queue = Some(tx);
let cache = Arc::new(parking_lot::RwLock::new(WarpSyncFragmentCache::new(WARP_SYNC_CACHE_SIZE)));
(Self { backend, request_receiver, cache, _phantom: std::marker::PhantomData }, request_response_config)
}
fn handle_request(
&self,
payload: Vec<u8>,
pending_response: oneshot::Sender<OutgoingResponse>
) -> Result<(), HandleRequestError>
where NumberFor<TBlock>: tc_finality_grandpa::BlockNumberOps,
{
let request = Request::<TBlock>::decode(&mut &payload[..])?;
let mut cache = self.cache.write();
let response = tc_finality_grandpa::prove_warp_sync(
self.backend.blockchain(), request.begin, Some(WARP_SYNC_FRAGMENTS_LIMIT), Some(&mut cache)
)?;
pending_response.send(OutgoingResponse {
result: Ok(response),
reputation_changes: Vec::new(),
}).map_err(|_| HandleRequestError::SendResponse)
}
pub async fn run(mut self)
where NumberFor<TBlock>: tc_finality_grandpa::BlockNumberOps,
{
while let Some(request) = self.request_receiver.next().await {
let IncomingRequest { peer, payload, pending_response } = request;
match self.handle_request(payload, pending_response) {
Ok(()) => debug!(target: LOG_TARGET, "Handled grandpa warp sync request from {}.", peer),
Err(e) => debug!(
target: LOG_TARGET,
"Failed to handle grandpa warp sync request from {}: {}",
peer, e,
),
}
}
}
}
/// Errors returned by `handle_request`.
///
/// The `From` derive lets `?` convert each wrapped error type into its variant.
#[derive(derive_more::Display, derive_more::From)]
enum HandleRequestError {
/// Protobuf decoding failure. NOTE(review): not constructed anywhere in the visible code.
#[display(fmt = "Failed to decode request: {}.", _0)]
DecodeProto(prost::DecodeError),
/// Protobuf encoding failure. NOTE(review): not constructed anywhere in the visible code.
#[display(fmt = "Failed to encode response: {}.", _0)]
EncodeProto(prost::EncodeError),
/// The request payload could not be decoded into a `Request`.
#[display(fmt = "Failed to decode block hash: {}.", _0)]
DecodeScale(codec::Error),
/// Blockchain error; presumably what `prove_warp_sync` returns on failure.
/// NOTE(review): has no `display` attribute, so the message presumably comes from the
/// inner error's `Display` (confirm against the `derive_more` version in use).
Client(tp_blockchain::Error),
/// Sending the response on the `pending_response` channel failed.
#[display(fmt = "Failed to send response.")]
SendResponse,
}