use crate::{
log,
network_service::{self, BitswapEvent, PeerId, SendBitswapMessageError},
platform::PlatformRef,
util,
};
use alloc::{
borrow::ToOwned,
boxed::Box,
collections::{BTreeSet, VecDeque},
format,
string::{String, ToString as _},
sync::Arc,
vec::Vec,
};
use core::{iter, pin::Pin, str::FromStr, time::Duration};
use futures_channel::oneshot;
use futures_lite::FutureExt as _;
use futures_util::{StreamExt as _, future, stream::FuturesUnordered};
use itertools::Itertools;
use rand::RngCore;
use rand_chacha::rand_core::SeedableRng as _;
use smoldot::{
json_rpc::parse,
libp2p::cid::{self, Cid, CidPrefix},
network::codec::{Block, BlockPresence, BlockPresenceType, WantType, build_bitswap_message},
};
/// Initial capacity reserved for the background task's request-tracking maps (`requests` and
/// `requests_by_cid`). Within this file it is only used as a capacity hint; it does not cap the
/// number of concurrent requests.
const PARALLEL_REQUESTS: usize = 50;
/// Configuration passed to [`BitswapService::new`].
pub struct Config<TPlat>
where
    TPlat: PlatformRef,
{
    /// Name appended to `bitswap-service-` to form the log target of the background task.
    pub log_name: String,
    /// Platform used to spawn the background task, read the clock, sleep, and draw random bytes.
    pub platform: TPlat,
    /// Network service of the chain, used to subscribe to Bitswap events and to send or
    /// broadcast Bitswap messages.
    pub network_service: Arc<network_service::NetworkServiceChain<TPlat>>,
}
/// Handle to the Bitswap background task spawned by [`BitswapService::new`].
///
/// Block requests are forwarded to the background task over a bounded channel.
pub struct BitswapService {
/// Sending side of the channel read by the background task.
messages_tx: async_channel::Sender<ToBackground>,
}
impl BitswapService {
pub fn new<TPlat: PlatformRef>(
Config {
log_name,
platform,
network_service,
}: Config<TPlat>,
) -> Self {
let (messages_tx, messages_rx) = async_channel::bounded(32);
let log_target = format!("bitswap-service-{}", log_name);
let task = Box::pin(background_task(BackgroundTask {
log_target: log_target.clone(),
messages_rx: Box::pin(messages_rx),
network_service,
from_network_service: None,
pending_have_broadcast: None,
pending_block_requests: FuturesUnordered::new(),
platform: platform.clone(),
next_request_id_inner: 0,
randomness: rand_chacha::ChaCha20Rng::from_seed({
let mut seed = [0; 32];
platform.fill_random_bytes(&mut seed);
seed
}),
requests: hashbrown::HashMap::with_capacity_and_hasher(
PARALLEL_REQUESTS,
fnv::FnvBuildHasher::default(),
),
requests_by_timeout: BTreeSet::new(),
requests_by_cid: hashbrown::HashMap::with_capacity_and_hasher(
PARALLEL_REQUESTS,
util::SipHasherBuild::new({
let mut seed = [0; 16];
platform.fill_random_bytes(&mut seed);
seed
}),
),
}));
platform.spawn_task(log_target.clone().into(), {
let platform = platform.clone();
async move {
task.await;
log!(&platform, Debug, &log_target, "shutdown");
}
});
BitswapService { messages_tx }
}
pub async fn bitswap_get(&self, cid: String) -> Result<Vec<u8>, BitswapGetError> {
let cid = Cid::from_str(&cid).map_err(BitswapGetError::InvalidCid)?;
let (result_tx, result_rx) = oneshot::channel();
self.messages_tx
.send(ToBackground::BitswapBlock { cid, result_tx })
.await
.unwrap();
result_rx.await.unwrap()
}
}
/// Error returned by [`BitswapService::bitswap_get`].
#[derive(Debug, derive_more::Display, derive_more::Error, Clone)]
pub enum BitswapGetError {
/// The CID string passed by the caller could not be parsed.
#[display("Invalid CID: {_0}")]
InvalidCid(cid::ParseError),
/// Sending failed with `SendBitswapMessageError::NoConnection`, e.g. while broadcasting the
/// "have" request.
#[display("No Bitswap peers connected, can't issue \"have\" request.")]
NoPeers,
/// A peer answered "have", but sending it the "block" request failed with
/// `SendBitswapMessageError::NoConnection`.
#[display("\"Block\" request to selected peer failed after successful \"have\" request.")]
BlockRequestFailed,
/// Sending a Bitswap message failed with `SendBitswapMessageError::QueueFull`.
#[display("Network sending queue is full.")]
QueueFull,
/// Every peer that the "have" request was broadcast to answered "don't have".
#[display("No connected peers have the CID requested.")]
NotFound,
/// The request did not complete before its deadline.
#[display("Request timeout.")]
Timeout,
}
/// Application-defined JSON-RPC error codes emitted by [`BitswapGetError::to_json_rpc_error`].
///
/// NOTE(review): the retry semantics below are inferred from the variant names and are not
/// enforced in this file — confirm against the RPC specification these codes come from.
enum BitswapJsonRpcError {
/// Used for `NotFound`; presumably "do not retry".
Fail = -32810,
/// Used for `BlockRequestFailed` and `Timeout`; presumably "retrying may succeed".
FailRetry = -32811,
/// Used for `QueueFull` and `NoPeers`; presumably "retry after backing off".
FailRetryBackoff = -32812,
}
impl BitswapGetError {
pub fn to_json_rpc_error(&self, request_id_json: &str) -> String {
let message = self.to_string();
let (variant, category) = match self {
BitswapGetError::InvalidCid(_) => ("InvalidCid", None),
BitswapGetError::NotFound => ("NotFound", Some(BitswapJsonRpcError::Fail)),
BitswapGetError::BlockRequestFailed => {
("BlockRequestFailed", Some(BitswapJsonRpcError::FailRetry))
}
BitswapGetError::Timeout => ("Timeout", Some(BitswapJsonRpcError::FailRetry)),
BitswapGetError::QueueFull => {
("QueueFull", Some(BitswapJsonRpcError::FailRetryBackoff))
}
BitswapGetError::NoPeers => ("NoPeers", Some(BitswapJsonRpcError::FailRetryBackoff)),
};
let data = format!("{{\"variant\":\"{variant}\"}}");
let error_response = match category {
None => parse::ErrorResponse::InvalidParams(Some(&message)),
Some(cat) => parse::ErrorResponse::ApplicationDefined(cat as i64, &message),
};
parse::build_error_response(request_id_json, error_response, Some(&data))
}
}
impl From<SendBitswapMessageError> for BitswapGetError {
fn from(error: SendBitswapMessageError) -> BitswapGetError {
match error {
SendBitswapMessageError::NoConnection => BitswapGetError::NoPeers,
SendBitswapMessageError::QueueFull => BitswapGetError::QueueFull,
}
}
}
/// Messages sent from [`BitswapService`] to its background task.
enum ToBackground {
/// Fetch the block identified by `cid`, then send the outcome on `result_tx`.
BitswapBlock {
cid: Cid,
result_tx: oneshot::Sender<Result<Vec<u8>, BitswapGetError>>,
},
}
/// Identifier of an in-flight request inside the background task, handed out sequentially by
/// `BackgroundTask::allocate_request_id`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct RequestId(u64);
impl RequestId {
/// Smallest possible identifier; unused in this file (hence the leading underscore).
const _MIN: RequestId = RequestId(u64::MIN);
/// Largest possible identifier. Used as the inclusive upper bound when selecting every
/// `(Instant, RequestId)` entry of `requests_by_timeout` whose deadline is `<= now`.
const MAX: RequestId = RequestId(u64::MAX);
}
/// Progress of a single in-flight request.
#[derive(Debug)]
enum RequestStage {
/// Waiting for "have" / "don't have" answers. Holds the peers the "have" request was broadcast
/// to that have not yet answered "don't have"; the request fails with `NotFound` once it
/// becomes empty.
Have(hashbrown::HashSet<PeerId, util::SipHasherBuild>),
/// A peer answered "have" and a "block" request for the CID has been issued; further presence
/// answers are ignored while waiting for the block data.
Block,
}
/// Bookkeeping for one in-flight request of the background task.
#[derive(Debug)]
struct Request<TPlat: PlatformRef> {
/// Where the final outcome of the request is sent.
result_tx: oneshot::Sender<Result<Vec<u8>, BitswapGetError>>,
/// Deadline after which the request fails with `BitswapGetError::Timeout`. Together with the
/// request id it forms the request's key in `BackgroundTask::requests_by_timeout`.
timeout: TPlat::Instant,
/// Current progress of the request.
stage: RequestStage,
/// CID of the requested block; the key under which the request is listed in
/// `BackgroundTask::requests_by_cid`.
cid: Cid,
}
/// Output of the pending "have" broadcast future. It contains:
/// - the broadcast outcome (on success, the peers the message was sent to);
/// - the CID that was asked about;
/// - the channel on which the request's final result must be reported.
type HaveBroadcastResult = (
Result<Vec<PeerId>, SendBitswapMessageError>,
Cid,
oneshot::Sender<Result<Vec<u8>, BitswapGetError>>,
);
/// State owned by the Bitswap background task (see `background_task`).
///
/// Every in-flight request is indexed three times: by id in `requests`, by deadline in
/// `requests_by_timeout`, and by CID in `requests_by_cid`. These must be kept in sync.
struct BackgroundTask<TPlat: PlatformRef> {
/// Log target used for this task's log messages.
log_target: String,
/// Requests coming from the [`BitswapService`] handle.
messages_rx: Pin<Box<async_channel::Receiver<ToBackground>>>,
/// Used to subscribe to Bitswap events and to send/broadcast Bitswap messages.
network_service: Arc<network_service::NetworkServiceChain<TPlat>>,
/// Subscription to the network's Bitswap events. `None` until it is (re)subscribed.
from_network_service: Option<Pin<Box<async_channel::Receiver<network_service::BitswapEvent>>>>,
/// The "have" broadcast currently in flight, if any. At most one exists at a time; while it is
/// `Some`, no new messages are read from `messages_rx`.
pending_have_broadcast:
Option<Pin<Box<dyn Future<Output = HaveBroadcastResult> + Send + Sync>>>,
/// In-flight "block" requests; each yields its send result and the CID it asked for.
pending_block_requests: FuturesUnordered<
Pin<Box<dyn Future<Output = (Result<(), SendBitswapMessageError>, Cid)> + Send + Sync>>,
>,
/// Platform used for the clock, sleeping and logging.
platform: TPlat,
/// Value of the next `RequestId` to hand out.
next_request_id_inner: u64,
/// RNG seeded at construction; used to seed the hashers of the per-request peer sets.
randomness: rand_chacha::ChaCha20Rng,
/// All in-flight requests, by id.
requests: hashbrown::HashMap<RequestId, Request<TPlat>, fnv::FnvBuildHasher>,
/// The same requests ordered by deadline, so the earliest one can be waited on.
requests_by_timeout: BTreeSet<(TPlat::Instant, RequestId)>,
/// The same requests grouped by requested CID, in insertion order.
requests_by_cid: hashbrown::HashMap<Cid, VecDeque<RequestId>, util::SipHasherBuild>,
}
impl<TPlat: PlatformRef> BackgroundTask<TPlat> {
    /// Hands out a fresh [`RequestId`] from a monotonically increasing counter.
    fn allocate_request_id(&mut self) -> RequestId {
        let allocated = RequestId(self.next_request_id_inner);
        self.next_request_id_inner = allocated.0 + 1;
        allocated
    }
}
/// Encodes a Bitswap message asking peers whether they have the block identified by `cid`.
///
/// NOTE(review): the two boolean arguments of `build_bitswap_message` are not named here. The
/// `true` passed for "have" requests presumably asks peers to also answer "don't have" (which
/// `background_task` relies on to detect `NotFound`) — confirm against `smoldot::network::codec`.
fn bitswap_have_message(cid: &Cid) -> Vec<u8> {
build_bitswap_message(iter::once(cid), WantType::Have, true, false)
}
/// Encodes a Bitswap message requesting the full block identified by `cid`.
///
/// NOTE(review): both boolean arguments are `false` here, unlike the "have" message; their
/// exact meaning is defined by `build_bitswap_message` and is not visible in this file.
fn bitswap_block_message(cid: &Cid) -> Vec<u8> {
build_bitswap_message(iter::once(cid), WantType::Block, false, false)
}
/// Main loop of the Bitswap background task.
///
/// On each iteration it waits for the first of these events:
/// - a Bitswap network event, or the need to (re)subscribe to them;
/// - a foreground message, read only while no "have" broadcast is pending;
/// - completion of the pending "have" broadcast;
/// - completion of an in-flight "block" request;
/// - the earliest request deadline.
///
/// It then updates the request bookkeeping in `task` and reports final outcomes on each
/// request's `result_tx`. Returns once `messages_rx` yields `None`.
async fn background_task<TPlat: PlatformRef>(mut task: BackgroundTask<TPlat>) {
    /// How long a request may remain in flight, counted from the completion of its "have"
    /// broadcast, before it fails with [`BitswapGetError::Timeout`].
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

    loop {
        // Give other tasks a chance to run between two iterations.
        futures_lite::future::yield_now().await;

        enum WakeUpReason {
            MustSubscribeNetworkEvents,
            NetworkEvent(network_service::BitswapEvent),
            Message(ToBackground),
            HaveBroadcastResult(HaveBroadcastResult),
            BlockRequestResult((Result<(), SendBitswapMessageError>, Cid)),
            RequestTimeout,
            ForegroundClosed,
        }

        let wake_up_reason = {
            // Only one "have" broadcast is in flight at a time. While it is pending, stop reading
            // foreground messages so that the bounded channel applies backpressure.
            let backpressure_messages = task.pending_have_broadcast.is_some();
            async {
                if let Some(from_network_service) = task.from_network_service.as_mut() {
                    match from_network_service.next().await {
                        Some(ev) => WakeUpReason::NetworkEvent(ev),
                        None => {
                            task.from_network_service = None;
                            WakeUpReason::MustSubscribeNetworkEvents
                        }
                    }
                } else {
                    WakeUpReason::MustSubscribeNetworkEvents
                }
            }
            .or(async {
                if !backpressure_messages {
                    task.messages_rx
                        .next()
                        .await
                        .map_or(WakeUpReason::ForegroundClosed, WakeUpReason::Message)
                } else {
                    future::pending().await
                }
            })
            .or(async {
                if let Some(pending_have_broadcast) = &mut task.pending_have_broadcast {
                    let result = pending_have_broadcast.await;
                    task.pending_have_broadcast = None;
                    WakeUpReason::HaveBroadcastResult(result)
                } else {
                    future::pending().await
                }
            })
            .or(async {
                if !task.pending_block_requests.is_empty() {
                    let result = task
                        .pending_block_requests
                        .next()
                        .await
                        .expect("non-empty; qed");
                    WakeUpReason::BlockRequestResult(result)
                } else {
                    future::pending().await
                }
            })
            .or(async {
                // Sleep until the earliest deadline, unless it has already passed.
                if let Some((first_timeout, _request_id)) = task.requests_by_timeout.first() {
                    let now = task.platform.now();
                    if now < *first_timeout {
                        task.platform.sleep(first_timeout.clone() - now).await;
                    }
                    WakeUpReason::RequestTimeout
                } else {
                    future::pending().await
                }
            })
            .await
        };

        match wake_up_reason {
            WakeUpReason::MustSubscribeNetworkEvents => {
                debug_assert!(task.from_network_service.is_none());
                task.from_network_service = Some(Box::pin(
                    task.network_service.subscribe_bitswap().await,
                ));
            }

            WakeUpReason::Message(ToBackground::BitswapBlock { cid, result_tx }) => {
                // Guaranteed by the backpressure above: messages are not read while a
                // broadcast is pending.
                debug_assert!(task.pending_have_broadcast.is_none());
                let message = bitswap_have_message(&cid);
                let network_service = task.network_service.clone();
                task.pending_have_broadcast = Some(Box::pin(async move {
                    let result = network_service.broadcast_bitswap_message(message).await;
                    (result, cid, result_tx)
                }));
            }

            WakeUpReason::HaveBroadcastResult((result, cid, result_tx)) => {
                let broadcast_to = match result {
                    Ok(peers) => peers,
                    Err(err) => {
                        let _ = result_tx.send(Err(err.into()));
                        continue;
                    }
                };

                // If the "have" request reached nobody, no peer can ever answer it. Fail right
                // away rather than letting the request sit until its timeout. This also
                // guarantees that every tracked `RequestStage::Have` set starts out non-empty.
                if broadcast_to.is_empty() {
                    let _ = result_tx.send(Err(BitswapGetError::NoPeers));
                    continue;
                }

                let request_id = task.allocate_request_id();
                let timeout = task.platform.now() + REQUEST_TIMEOUT;
                let have_peers = {
                    let mut have_peers = hashbrown::HashSet::with_capacity_and_hasher(
                        broadcast_to.len(),
                        util::SipHasherBuild::new({
                            let mut seed = [0; 16];
                            task.randomness.fill_bytes(&mut seed);
                            seed
                        }),
                    );
                    have_peers.extend(broadcast_to.into_iter());
                    have_peers
                };

                // Register the request in all three indices.
                task.requests.insert(
                    request_id,
                    Request {
                        result_tx,
                        timeout: timeout.clone(),
                        stage: RequestStage::Have(have_peers),
                        cid: cid.clone(),
                    },
                );
                task.requests_by_timeout.insert((timeout, request_id));
                task.requests_by_cid
                    .entry(cid)
                    .or_default()
                    .push_back(request_id);
            }

            WakeUpReason::NetworkEvent(BitswapEvent::BitswapMessage { peer_id, message }) => {
                let message = message.decode();

                // Presence ("have" / "don't have") answers.
                for BlockPresence { cid, presence_type } in message.block_presences {
                    let cid = match Cid::from_bytes(cid.to_owned()) {
                        Ok(cid) => cid,
                        Err(error) => {
                            log!(
                                &task.platform,
                                Debug,
                                &task.log_target,
                                "error decoding CID",
                                peer_id,
                                error,
                            );
                            continue;
                        }
                    };

                    let hashbrown::hash_map::Entry::Occupied(mut entry) =
                        task.requests_by_cid.entry(cid.clone())
                    else {
                        log!(
                            &task.platform,
                            Trace,
                            &task.log_target,
                            "stale/unsolicited have response",
                            peer_id
                        );
                        continue;
                    };

                    let mut needs_block_request = false;
                    let request_ids = entry.get_mut();
                    // Iterate backwards so that removing index `i` does not shift the indices
                    // still to be visited.
                    for i in (0..request_ids.len()).rev() {
                        let request_id = request_ids[i];
                        let request = task.requests.get_mut(&request_id).unwrap();
                        match (&mut request.stage, presence_type) {
                            (RequestStage::Have(peers), BlockPresenceType::Have) => {
                                // Only trust "have" answers from peers that were asked.
                                if peers.contains(&peer_id) {
                                    request.stage = RequestStage::Block;
                                    needs_block_request = true;
                                }
                            }
                            (RequestStage::Have(peers), BlockPresenceType::DontHave) => {
                                // Only a "don't have" from a peer that was actually asked (and
                                // had not answered yet) counts towards exhausting the set.
                                if peers.remove(&peer_id) && peers.is_empty() {
                                    request_ids.remove(i);
                                    let request = task.requests.remove(&request_id).unwrap();
                                    let _was_in = task
                                        .requests_by_timeout
                                        .remove(&(request.timeout, request_id));
                                    debug_assert!(_was_in);
                                    let _ =
                                        request.result_tx.send(Err(BitswapGetError::NotFound));
                                }
                            }
                            (RequestStage::Block, _) => {}
                        }
                    }
                    if entry.get().is_empty() {
                        entry.remove();
                    }

                    // A single "block" request to the answering peer serves every request for
                    // this CID that just moved to `RequestStage::Block`.
                    if needs_block_request {
                        let message = bitswap_block_message(&cid);
                        let network_service = task.network_service.clone();
                        let peer_id = peer_id.clone();
                        task.pending_block_requests.push(Box::pin(async move {
                            let result =
                                network_service.send_bitswap_message(peer_id, message).await;
                            (result, cid)
                        }));
                    }
                }

                // Block payloads. The CID is recomputed from the data itself, so a payload can
                // only complete requests whose CID matches the digest of the received data.
                for Block { prefix, data } in message.payload {
                    let prefix = match CidPrefix::from_bytes(prefix.to_owned()) {
                        Ok(prefix) => prefix,
                        Err(error) => {
                            log!(
                                &task.platform,
                                Debug,
                                &task.log_target,
                                "error decoding CID prefix",
                                peer_id,
                                error,
                            );
                            continue;
                        }
                    };
                    let cid = prefix.with_digest_of(data);
                    if let Some(request_ids) = task.requests_by_cid.remove(&cid) {
                        for request_id in request_ids {
                            let request = task.requests.remove(&request_id).unwrap();
                            let _was_in = task
                                .requests_by_timeout
                                .remove(&(request.timeout, request_id));
                            debug_assert!(_was_in);
                            let _ = request.result_tx.send(Ok(data.to_owned()));
                        }
                    }
                }
            }

            WakeUpReason::BlockRequestResult((result, cid)) => {
                // On failure, every request currently tracked for this CID is failed.
                if let Err(err) = result {
                    if let Some(request_ids) = task.requests_by_cid.remove(&cid) {
                        let err = match err {
                            SendBitswapMessageError::QueueFull => BitswapGetError::QueueFull,
                            SendBitswapMessageError::NoConnection => {
                                BitswapGetError::BlockRequestFailed
                            }
                        };
                        for request_id in request_ids {
                            let request = task.requests.remove(&request_id).unwrap();
                            let _was_in = task
                                .requests_by_timeout
                                .remove(&(request.timeout, request_id));
                            debug_assert!(_was_in);
                            let _ = request.result_tx.send(Err(err.clone()));
                        }
                    }
                }
            }

            WakeUpReason::RequestTimeout => {
                // Collect every request whose deadline is `<= now`.
                let now = task.platform.now();
                let requests = task
                    .requests_by_timeout
                    .range(..=(now, RequestId::MAX))
                    .cloned()
                    .collect::<Vec<_>>();
                for (timeout, request_id) in requests {
                    task.requests_by_timeout.remove(&(timeout, request_id));
                    let request = task.requests.remove(&request_id).unwrap();
                    match task.requests_by_cid.entry(request.cid) {
                        hashbrown::hash_map::Entry::Occupied(mut entry) => {
                            let (index, _) = entry
                                .get()
                                .iter()
                                .find_position(|id| **id == request_id)
                                .unwrap();
                            entry.get_mut().remove(index);
                            if entry.get().is_empty() {
                                entry.remove();
                            }
                        }
                        hashbrown::hash_map::Entry::Vacant(_) => unreachable!(),
                    }
                    let _ = request.result_tx.send(Err(BitswapGetError::Timeout));
                }
            }

            WakeUpReason::ForegroundClosed => {
                return;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a JSON-RPC response string, panicking if it is not valid JSON.
    fn parse_response(json: &str) -> serde_json::Value {
        serde_json::from_str(json).unwrap()
    }

    /// Returns the numeric `error.code` field of a JSON-RPC error response.
    fn extract_error_code(json: &str) -> i64 {
        parse_response(json)["error"]["code"].as_i64().unwrap()
    }

    /// Returns the `error.data.variant` string of a JSON-RPC error response.
    fn extract_variant(json: &str) -> String {
        let parsed = parse_response(json);
        let variant = parsed["error"]["data"]["variant"].as_str().unwrap();
        variant.to_owned()
    }

    /// Asserts that `err`, serialized for request id `"1"`, carries the given error code and
    /// variant name.
    fn assert_maps_to(err: BitswapGetError, code: i64, variant: &str) {
        let json = err.to_json_rpc_error("\"1\"");
        assert_eq!(extract_error_code(&json), code);
        assert_eq!(extract_variant(&json), variant);
    }

    #[test]
    fn error_invalid_cid_maps_to_invalid_params() {
        let parse_error = Cid::from_str("not-a-cid").unwrap_err();
        assert_maps_to(BitswapGetError::InvalidCid(parse_error), -32602, "InvalidCid");
    }

    #[test]
    fn error_not_found_maps_to_fail() {
        assert_maps_to(BitswapGetError::NotFound, -32810, "NotFound");
    }

    #[test]
    fn error_block_request_failed_maps_to_fail_retry() {
        assert_maps_to(
            BitswapGetError::BlockRequestFailed,
            -32811,
            "BlockRequestFailed",
        );
    }

    #[test]
    fn error_timeout_maps_to_fail_retry() {
        assert_maps_to(BitswapGetError::Timeout, -32811, "Timeout");
    }

    #[test]
    fn error_queue_full_maps_to_fail_retry_backoff() {
        assert_maps_to(BitswapGetError::QueueFull, -32812, "QueueFull");
    }

    #[test]
    fn error_no_peers_maps_to_fail_retry_backoff() {
        assert_maps_to(BitswapGetError::NoPeers, -32812, "NoPeers");
    }

    #[test]
    fn error_response_is_valid_jsonrpc() {
        let parsed = parse_response(&BitswapGetError::NotFound.to_json_rpc_error("42"));
        assert_eq!(parsed["jsonrpc"], "2.0");
        assert_eq!(parsed["id"], 42);
        assert!(parsed["error"]["message"].is_string());
    }

    #[test]
    fn from_send_error_no_connection() {
        let converted = BitswapGetError::from(SendBitswapMessageError::NoConnection);
        assert!(matches!(converted, BitswapGetError::NoPeers));
    }

    #[test]
    fn from_send_error_queue_full() {
        let converted = BitswapGetError::from(SendBitswapMessageError::QueueFull);
        assert!(matches!(converted, BitswapGetError::QueueFull));
    }
}