use crate::{
get_quorum_value, Error, GetRecordCfg, GetRecordError, Result, SwarmDriver, CLOSE_GROUP_SIZE,
};
use libp2p::{
kad::{self, PeerRecord, ProgressStep, QueryId, QueryResult, QueryStats, Record},
PeerId,
};
use sn_protocol::PrettyPrintRecordKey;
use std::collections::{hash_map::Entry, HashMap, HashSet};
use tokio::sync::oneshot;
use xor_name::XorName;
/// Copies of a record received so far for one query, keyed by the `XorName` of the
/// record's value. Each entry holds the first copy seen with that content together with
/// the set of peers that returned it.
type GetRecordResultMap = HashMap<XorName, (Record, HashSet<PeerId>)>;

/// One-shot channel on which the final outcome of a `GetRecord` query is delivered.
type GetRecordSender = oneshot::Sender<std::result::Result<Record, GetRecordError>>;

/// In-flight `GetRecord` queries, keyed by kad `QueryId`. Each value holds the requester's
/// channel, the copies accumulated so far, and the configuration of the query.
pub(crate) type PendingGetRecord =
    HashMap<QueryId, (GetRecordSender, GetRecordResultMap, GetRecordCfg)>;
impl SwarmDriver {
pub(crate) fn accumulate_get_record_found(
&mut self,
query_id: QueryId,
peer_record: PeerRecord,
stats: QueryStats,
step: ProgressStep,
) -> Result<()> {
let peer_id = if let Some(peer_id) = peer_record.peer {
peer_id
} else {
self.self_peer_id
};
if let Entry::Occupied(mut entry) = self.pending_get_record.entry(query_id) {
let (_sender, result_map, cfg) = entry.get_mut();
let pretty_key = PrettyPrintRecordKey::from(&peer_record.record.key).into_owned();
if !cfg.expected_holders.is_empty() {
if cfg.expected_holders.remove(&peer_id) {
debug!("For record {pretty_key:?} task {query_id:?}, received a copy from an expected holder {peer_id:?}");
} else {
debug!("For record {pretty_key:?} task {query_id:?}, received a copy from an unexpected holder {peer_id:?}");
}
}
let record_content_hash = XorName::from_content(&peer_record.record.value);
let responded_peers =
if let Entry::Occupied(mut entry) = result_map.entry(record_content_hash) {
let (_, peer_list) = entry.get_mut();
let _ = peer_list.insert(peer_id);
peer_list.len()
} else {
let mut peer_list = HashSet::new();
let _ = peer_list.insert(peer_id);
result_map.insert(record_content_hash, (peer_record.record.clone(), peer_list));
1
};
let expected_answers = get_quorum_value(&cfg.get_quorum);
trace!("Expecting {expected_answers:?} answers for record {pretty_key:?} task {query_id:?}, received {responded_peers} so far");
if responded_peers >= expected_answers {
if !cfg.expected_holders.is_empty() {
debug!("For record {pretty_key:?} task {query_id:?}, fetch completed with non-responded expected holders {:?}", cfg.expected_holders);
}
let cfg = cfg.clone();
let (sender, result_map, _) = entry.remove();
if result_map.len() == 1 {
Self::send_record_after_checking_target(sender, peer_record.record, &cfg)?;
} else {
debug!("For record {pretty_key:?} task {query_id:?}, fetch completed with split record");
sender
.send(Err(GetRecordError::SplitRecord { result_map }))
.map_err(|_| Error::InternalMsgChannelDropped)?;
}
if let Some(mut query) = self.swarm.behaviour_mut().kademlia.query_mut(&query_id) {
query.finish();
}
} else if usize::from(step.count) >= CLOSE_GROUP_SIZE {
debug!("For record {pretty_key:?} task {query_id:?}, got {:?} with {} versions so far.",
step.count, result_map.len());
}
} else {
return Err(Error::ReceivedKademliaEventDropped(
kad::Event::OutboundQueryProgressed {
id: query_id,
result: QueryResult::GetRecord(Ok(kad::GetRecordOk::FoundRecord(peer_record))),
stats,
step,
},
));
}
Ok(())
}
pub(crate) fn handle_get_record_finished(
&mut self,
query_id: QueryId,
step: ProgressStep,
) -> Result<()> {
if let Some((sender, result_map, cfg)) = self.pending_get_record.remove(&query_id) {
let num_of_versions = result_map.len();
let (result, log_string) = if let Some((record, from_peers)) =
result_map.values().next()
{
let result = if num_of_versions == 1 {
Err(GetRecordError::NotEnoughCopies {
record: record.clone(),
expected: get_quorum_value(&cfg.get_quorum),
got: from_peers.len(),
})
} else {
Err(GetRecordError::SplitRecord {
result_map: result_map.clone(),
})
};
(
result,
format!("Getting record {:?} completed with only {:?} copies received, and {num_of_versions} versions.",
PrettyPrintRecordKey::from(&record.key), usize::from(step.count) - 1)
)
} else {
(
Err(GetRecordError::RecordNotFound),
format!("Getting record task {query_id:?} completed with step count {:?}, but no copy found.", step.count),
)
};
if cfg.expected_holders.is_empty() {
debug!("{log_string}");
} else {
debug!(
"{log_string}, and {:?} expected holders not responded",
cfg.expected_holders
);
}
sender
.send(result)
.map_err(|_| Error::InternalMsgChannelDropped)?;
} else {
trace!("Can't locate query task {query_id:?} during GetRecord finished. We might have already returned the result to the sender.");
}
Ok(())
}
pub(crate) fn handle_get_record_error(
&mut self,
query_id: QueryId,
get_record_err: kad::GetRecordError,
stats: QueryStats,
step: ProgressStep,
) -> Result<()> {
match &get_record_err {
kad::GetRecordError::NotFound { .. } | kad::GetRecordError::QuorumFailed { .. } => {
let (sender, _, cfg) =
self.pending_get_record.remove(&query_id).ok_or_else(|| {
trace!("Can't locate query task {query_id:?}, it has likely been completed already.");
Error::ReceivedKademliaEventDropped( kad::Event::OutboundQueryProgressed {
id: query_id,
result: QueryResult::GetRecord(Err(get_record_err.clone())),
stats,
step,
})
})?;
if cfg.expected_holders.is_empty() {
info!("Get record task {query_id:?} failed with error {get_record_err:?}");
} else {
debug!("Get record task {query_id:?} failed with {:?} expected holders not responded, error {get_record_err:?}", cfg.expected_holders);
}
sender
.send(Err(GetRecordError::RecordNotFound))
.map_err(|_| Error::InternalMsgChannelDropped)?;
}
kad::GetRecordError::Timeout { key } => {
let pretty_key = PrettyPrintRecordKey::from(key);
let (sender, result_map, cfg) =
self.pending_get_record.remove(&query_id).ok_or_else(|| {
trace!(
"Can't locate query task {query_id:?} for {pretty_key:?}, it has likely been completed already."
);
Error::ReceivedKademliaEventDropped( kad::Event::OutboundQueryProgressed {
id: query_id,
result: QueryResult::GetRecord(Err(get_record_err.clone())),
stats,
step,
})
})?;
let required_response_count = get_quorum_value(&cfg.get_quorum);
if result_map.len() > 1 {
warn!(
"Get record task {query_id:?} for {pretty_key:?} timed out with split result map"
);
sender
.send(Err(GetRecordError::QueryTimeout))
.map_err(|_| Error::InternalMsgChannelDropped)?;
return Ok(());
}
if let Some((record, peers)) = result_map.values().next() {
if peers.len() >= required_response_count {
Self::send_record_after_checking_target(sender, record.clone(), &cfg)?;
return Ok(());
}
}
warn!("Get record task {query_id:?} for {pretty_key:?} returned insufficient responses. {:?} did not return record", cfg.expected_holders);
sender
.send(Err(GetRecordError::QueryTimeout))
.map_err(|_| Error::InternalMsgChannelDropped)?;
}
}
Ok(())
}
fn send_record_after_checking_target(
sender: oneshot::Sender<std::result::Result<Record, GetRecordError>>,
record: Record,
cfg: &GetRecordCfg,
) -> Result<()> {
if cfg.target_record.is_none() || cfg.does_target_match(&record) {
sender
.send(Ok(record))
.map_err(|_| Error::InternalMsgChannelDropped)
} else {
sender
.send(Err(GetRecordError::RecordDoesNotMatch(record)))
.map_err(|_| Error::InternalMsgChannelDropped)
}
}
}