use crate::protocol::{
bytes_to_hash, create_request, create_response, create_fragment_response,
encode_request, encode_response, hash_to_key, is_fragmented, parse_message,
DataMessage as ProtoMessage, DataResponse, FRAGMENT_SIZE,
};
use crate::types::{
should_forward, ForwardRequest, ForwardTx, PeerId, PeerHTLConfig, PeerState,
SignalingMessage, DATA_CHANNEL_LABEL, MAX_HTL,
};
use bytes::Bytes;
use hashtree_core::{Hash, Store};
use lru::LruCache;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot, RwLock};
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::data_channel::RTCDataChannel;
use webrtc::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit};
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::peer_connection::RTCPeerConnection;
/// Errors returned by the fallible [`Peer`] operations in this file.
#[derive(Debug, Error)]
pub enum PeerError {
/// An error reported by the `webrtc` crate; converted automatically by `?`.
#[error("WebRTC error: {0}")]
WebRTC(#[from] webrtc::Error),
/// A `serde_json` error; converted automatically by `?`.
/// NOTE(review): no code in the visible part of this file produces it.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// The signaling channel rejected a message, or the sender half of a
/// pending request was dropped before a response was delivered.
#[error("Channel closed")]
ChannelClosed,
/// No response arrived within the request timeout.
#[error("Request timeout")]
Timeout,
/// The peer is not in the `Ready` state or has no data channel yet.
#[error("Peer not ready")]
NotReady,
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("Data not found")]
NotFound,
}
/// Capacity of the per-peer LRU cache that remembers requests received from
/// the remote side.
const THEIR_REQUESTS_SIZE: usize = 200;
// NOTE(review): the two constants below are not referenced in the visible part
// of this file (hence `allow(dead_code)`). Judging by their names they are
// millisecond timeouts for stalled / overall fragment reassembly — confirm
// before relying on that.
#[allow(dead_code)]
const FRAGMENT_STALL_TIMEOUT_MS: u64 = 5000;
#[allow(dead_code)]
const FRAGMENT_TOTAL_TIMEOUT_MS: u64 = 120000;
/// A request we sent to the remote peer and are still waiting on.
struct PendingRequest {
// Hash that was requested; stored but never read in this file.
#[allow(dead_code)] hash: Hash,
// Receives `Some(data)` once a response whose SHA-256 matches arrives, or
// `None` when the response fails hash verification.
response_tx: oneshot::Sender<Option<Vec<u8>>>,
}
/// A request received from the remote peer that has been recorded so it can be
/// answered later (see `Peer::send_data`).
#[derive(Debug, Clone)]
struct TheirRequest {
// Hash the remote peer asked for.
hash: Hash,
// Time the request was recorded; stored but never read in this file.
#[allow(dead_code)]
requested_at: std::time::Instant,
}
/// Accumulates the fragments of one fragmented response until all have arrived.
struct PendingReassembly {
// Hash of the payload being reassembled; stored but never read in this file.
#[allow(dead_code)] hash: Hash,
// Received fragment payloads keyed by fragment index.
fragments: HashMap<u32, Vec<u8>>,
// Fragment count announced by the first fragment that created this entry.
total_expected: u32,
// Sum of the lengths of all stored fragments.
received_bytes: usize,
// Arrival time of the first fragment; stored but never read in this file.
#[allow(dead_code)]
first_fragment_at: std::time::Instant,
// Arrival time of the most recent new (non-duplicate) fragment.
last_fragment_at: std::time::Instant,
}
/// Callback used to forward a request we cannot serve locally.
///
/// It is invoked with the requested hash, the string id of the peer the
/// request came from, and the already-decremented HTL value. It resolves to
/// the data if the forward succeeded, or `None` otherwise.
pub type ForwardRequestCallback = Arc<
dyn Fn(Hash, String, u8) -> futures::future::BoxFuture<'static, Option<Vec<u8>>> + Send + Sync,
>;
/// Sends a [`ForwardRequest`] over `forward_tx` and waits for its answer.
///
/// Returns the forwarded payload, or `None` when the forward channel is
/// closed, the reply sender is dropped, or the receiver answered with `None`.
async fn forward_via_channel(
    forward_tx: &ForwardTx,
    hash: Hash,
    exclude_peer_id: String,
    htl: u8,
) -> Option<Vec<u8>> {
    let (reply_tx, reply_rx) = oneshot::channel();
    let outgoing = ForwardRequest {
        hash,
        exclude_peer_id,
        htl,
        response: reply_tx,
    };
    match forward_tx.send(outgoing).await {
        // A dropped reply sender and an explicit `None` are treated alike.
        Ok(_) => reply_rx.await.ok().flatten(),
        Err(_) => None,
    }
}
/// One WebRTC connection to a remote peer, together with the bookkeeping for
/// requests travelling in both directions over its data channel.
pub struct Peer<S: Store> {
/// Identity of the remote peer.
pub remote_id: PeerId,
// Current lifecycle state; shared with the connection callbacks.
state: Arc<RwLock<PeerState>>,
// Underlying WebRTC peer connection.
connection: Arc<RTCPeerConnection>,
// The data channel, once created (outgoing) or received (incoming).
data_channel: Arc<RwLock<Option<Arc<RTCDataChannel>>>>,
// ICE candidates received before a remote description was set.
pending_candidates: Arc<RwLock<Vec<RTCIceCandidateInit>>>,
// Requests we sent and are awaiting, keyed by `hash_to_key` of the hash.
pending_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
// Requests received from the remote peer, keyed by hash key (bounded LRU).
their_requests: Arc<RwLock<LruCache<String, TheirRequest>>>,
// In-progress reassemblies of fragmented responses, keyed by hash key.
pending_reassemblies: Arc<RwLock<HashMap<String, PendingReassembly>>>,
// Outgoing signaling messages (offers, answers, ICE candidates).
signaling_tx: mpsc::Sender<SignalingMessage>,
// Local store consulted first when the remote peer requests data.
local_store: Arc<S>,
// Our own peer id, used as the sender id in signaling messages.
local_peer_id: String,
// When true, diagnostics are printed to stdout.
debug: bool,
// HTL decrement policy for this peer; chosen by `PeerHTLConfig::random()`.
htl_config: PeerHTLConfig,
// Optional channel used to forward requests we cannot serve locally.
forward_tx: Option<ForwardTx>,
// Optional callback used for forwarding when `forward_tx` is absent.
on_forward_request: Option<ForwardRequestCallback>,
}
impl<S: Store + 'static> Peer<S> {
/// Creates a peer that has no forward channel.
///
/// Equivalent to [`Peer::with_forward_channel`] with `forward_tx` set to
/// `None`; any error from that constructor is returned unchanged.
pub async fn new(
    remote_id: PeerId,
    local_peer_id: String,
    signaling_tx: mpsc::Sender<SignalingMessage>,
    local_store: Arc<S>,
    debug: bool,
) -> Result<Self, PeerError> {
    let no_forward_channel = None;
    Self::with_forward_channel(
        remote_id,
        local_peer_id,
        signaling_tx,
        local_store,
        debug,
        no_forward_channel,
    )
    .await
}
/// Creates a peer, optionally wired to a forward channel.
///
/// Builds the WebRTC API (default codecs and interceptors), opens a peer
/// connection configured with three public STUN servers, and installs the
/// connection-level callbacks before returning.
///
/// # Errors
/// Returns [`PeerError::WebRTC`] if codec/interceptor registration or peer
/// connection creation fails.
pub async fn with_forward_channel(
    remote_id: PeerId,
    local_peer_id: String,
    signaling_tx: mpsc::Sender<SignalingMessage>,
    local_store: Arc<S>,
    debug: bool,
    forward_tx: Option<ForwardTx>,
) -> Result<Self, PeerError> {
    let mut media_engine = MediaEngine::default();
    media_engine.register_default_codecs()?;
    let registry = register_default_interceptors(Registry::new(), &mut media_engine)?;
    let api = APIBuilder::new()
        .with_media_engine(media_engine)
        .with_interceptor_registry(registry)
        .build();

    let stun_urls: Vec<String> = [
        "stun:stun.iris.to:3478",
        "stun:stun.l.google.com:19302",
        "stun:stun.cloudflare.com:3478",
    ]
    .iter()
    .map(|url| url.to_string())
    .collect();
    let stun_server = RTCIceServer {
        urls: stun_urls,
        ..Default::default()
    };
    let config = RTCConfiguration {
        ice_servers: vec![stun_server],
        ..Default::default()
    };
    let connection = Arc::new(api.new_peer_connection(config).await?);

    let their_requests_capacity = NonZeroUsize::new(THEIR_REQUESTS_SIZE).unwrap();
    let peer = Self {
        remote_id,
        state: Arc::new(RwLock::new(PeerState::New)),
        connection,
        data_channel: Arc::new(RwLock::new(None)),
        pending_candidates: Arc::new(RwLock::new(Vec::new())),
        pending_requests: Arc::new(RwLock::new(HashMap::new())),
        their_requests: Arc::new(RwLock::new(LruCache::new(their_requests_capacity))),
        pending_reassemblies: Arc::new(RwLock::new(HashMap::new())),
        signaling_tx,
        local_store,
        local_peer_id,
        debug,
        htl_config: PeerHTLConfig::random(),
        forward_tx,
        // No callback yet; one can be supplied via `set_on_forward_request`.
        on_forward_request: None,
    };
    peer.setup_handlers().await?;
    Ok(peer)
}
/// Installs the connection-level callbacks on the peer connection:
///
/// * connection state changes are mirrored into `self.state`;
/// * an incoming data channel with the expected label gets the message
///   handlers attached, is stored, and marks the peer `Ready`;
/// * locally gathered ICE candidates are sent out through `signaling_tx`;
/// * ICE connection / gathering state changes are logged when `debug` is set.
///
/// The forward channel and forward callback are cloned here, i.e. with the
/// values they have when this function runs.
///
/// Always returns `Ok(())`.
async fn setup_handlers(&self) -> Result<(), PeerError> {
    let debug = self.debug;
    let htl_config = self.htl_config;

    // --- Peer connection state -> PeerState -------------------------------
    let state = self.state.clone();
    self.connection
        .on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| {
            let state = state.clone();
            Box::pin(async move {
                if debug {
                    println!("[Peer] Connection state changed: {:?}", s);
                }
                let mut state = state.write().await;
                match s {
                    RTCPeerConnectionState::Connected => {
                        *state = PeerState::Connected;
                        if debug {
                            println!("[Peer] Connection established");
                        }
                    }
                    RTCPeerConnectionState::Disconnected
                    | RTCPeerConnectionState::Failed
                    | RTCPeerConnectionState::Closed => {
                        *state = PeerState::Disconnected;
                        if debug {
                            println!("[Peer] Connection closed: {:?}", s);
                        }
                    }
                    _ => {}
                }
            })
        }));

    // --- Incoming data channel --------------------------------------------
    let data_channel = self.data_channel.clone();
    let pending_requests = self.pending_requests.clone();
    let their_requests = self.their_requests.clone();
    let pending_reassemblies = self.pending_reassemblies.clone();
    let local_store = self.local_store.clone();
    let state = self.state.clone();
    let forward_tx = self.forward_tx.clone();
    let on_forward_request = self.on_forward_request.clone();
    let peer_id = self.remote_id.to_peer_string();
    self.connection.on_data_channel(Box::new(move |dc| {
        // The handler may run more than once, so hand each invocation its
        // own clones of the shared state.
        let data_channel = data_channel.clone();
        let pending_requests = pending_requests.clone();
        let their_requests = their_requests.clone();
        let pending_reassemblies = pending_reassemblies.clone();
        let local_store = local_store.clone();
        let state = state.clone();
        let forward_tx = forward_tx.clone();
        let on_forward = on_forward_request.clone();
        let peer_id = peer_id.clone();
        Box::pin(async move {
            // Channels with any other label are ignored.
            if dc.label() != DATA_CHANNEL_LABEL {
                return;
            }
            Self::setup_data_channel_handlers(
                dc.clone(),
                pending_requests,
                their_requests,
                pending_reassemblies,
                local_store,
                debug,
                htl_config,
                forward_tx,
                on_forward,
                peer_id,
            )
            .await;
            *data_channel.write().await = Some(dc);
            *state.write().await = PeerState::Ready;
            if debug {
                println!("[Peer] Data channel opened (incoming)");
            }
        })
    }));

    // --- Local ICE candidates -> signaling --------------------------------
    let signaling_tx = self.signaling_tx.clone();
    let local_peer_id = self.local_peer_id.clone();
    let remote_id = self.remote_id.to_peer_string();
    self.connection
        .on_ice_candidate(Box::new(move |candidate: Option<RTCIceCandidate>| {
            let signaling_tx = signaling_tx.clone();
            let local_peer_id = local_peer_id.clone();
            let remote_id = remote_id.clone();
            Box::pin(async move {
                // `None` carries no candidate; there is nothing to send.
                let Some(candidate) = candidate else {
                    return;
                };
                // A candidate that cannot be serialized is skipped instead of
                // panicking inside the callback.
                let json = match candidate.to_json() {
                    Ok(json) => json,
                    Err(err) => {
                        if debug {
                            println!("[Peer] Failed to serialize ICE candidate: {:?}", err);
                        }
                        return;
                    }
                };
                let msg = SignalingMessage::Candidate {
                    peer_id: local_peer_id,
                    target_peer_id: remote_id,
                    candidate: json.candidate,
                    sdp_m_line_index: json.sdp_mline_index,
                    sdp_mid: json.sdp_mid,
                };
                // Best effort: a closed signaling channel is not an error here.
                let _ = signaling_tx.send(msg).await;
            })
        }));

    // --- Diagnostics only -------------------------------------------------
    self.connection
        .on_ice_connection_state_change(Box::new(move |s| {
            if debug {
                println!("[Peer] ICE connection state: {:?}", s);
            }
            Box::pin(async {})
        }));
    self.connection
        .on_ice_gathering_state_change(Box::new(move |s| {
            if debug {
                println!("[Peer] ICE gathering state: {:?}", s);
            }
            Box::pin(async {})
        }));
    Ok(())
}
async fn setup_data_channel_handlers(
dc: Arc<RTCDataChannel>,
pending_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
their_requests: Arc<RwLock<LruCache<String, TheirRequest>>>,
pending_reassemblies: Arc<RwLock<HashMap<String, PendingReassembly>>>,
local_store: Arc<S>,
debug: bool,
htl_config: PeerHTLConfig,
forward_tx: Option<ForwardTx>,
on_forward_request: Option<ForwardRequestCallback>,
peer_id: String,
) {
let pending_requests_clone = pending_requests.clone();
let their_requests_clone = their_requests.clone();
let pending_reassemblies_clone = pending_reassemblies.clone();
let local_store_clone = local_store.clone();
let dc_clone = dc.clone();
let forward_tx_clone = forward_tx.clone();
let on_forward_clone = on_forward_request.clone();
let peer_id_clone = peer_id.clone();
dc.on_message(Box::new(move |msg: DataChannelMessage| {
let pending_requests = pending_requests_clone.clone();
let their_requests = their_requests_clone.clone();
let pending_reassemblies = pending_reassemblies_clone.clone();
let local_store = local_store_clone.clone();
let dc = dc_clone.clone();
let forward_tx = forward_tx_clone.clone();
let on_forward = on_forward_clone.clone();
let peer_id = peer_id_clone.clone();
Box::pin(async move {
let data = msg.data.to_vec();
if data.is_empty() {
return;
}
let parsed = match parse_message(&data) {
Some(m) => m,
None => {
if debug {
println!("[Peer] Failed to parse message");
}
return;
}
};
match parsed {
ProtoMessage::Request(req) => {
let htl = req.htl.unwrap_or(MAX_HTL);
let hash_key = hash_to_key(&req.h);
if debug {
println!(
"[Peer] Request: hash={}..., htl={}",
&hash_key[..16.min(hash_key.len())],
htl
);
}
let hash_bytes = match bytes_to_hash(&req.h) {
Some(h) => h,
None => return,
};
let local_result = local_store.get(&hash_bytes).await;
if let Ok(Some(payload)) = local_result {
Self::send_response(&dc, &hash_bytes, payload, debug).await;
return;
}
let can_forward = forward_tx.is_some() || on_forward.is_some();
if can_forward && should_forward(htl) {
{
let mut their_reqs = their_requests.write().await;
their_reqs.put(
hash_key.clone(),
TheirRequest {
hash: hash_bytes,
requested_at: std::time::Instant::now(),
},
);
}
let forward_htl = htl_config.decrement(htl);
if debug {
println!(
"[Peer] Forwarding request htl={}->{}, hash={}...",
htl,
forward_htl,
&hash_key[..16.min(hash_key.len())]
);
}
let forward_result = if let Some(ref tx) = forward_tx {
forward_via_channel(tx, hash_bytes, peer_id.clone(), forward_htl)
.await
} else if let Some(ref forward_cb) = on_forward {
forward_cb(hash_bytes, peer_id.clone(), forward_htl).await
} else {
None
};
if let Some(payload) = forward_result {
their_requests.write().await.pop(&hash_key);
Self::send_response(&dc, &hash_bytes, payload, debug).await;
if debug {
println!(
"[Peer] Forward success for hash={}...",
&hash_key[..16.min(hash_key.len())]
);
}
return;
}
}
{
let mut their_reqs = their_requests.write().await;
their_reqs.put(
hash_key,
TheirRequest {
hash: hash_bytes,
requested_at: std::time::Instant::now(),
},
);
}
}
ProtoMessage::Response(res) => {
let hash_key = hash_to_key(&res.h);
let final_data = if is_fragmented(&res) {
Self::handle_fragment_response(
&res,
&pending_reassemblies,
debug,
)
.await
} else {
Some(res.d)
};
let final_data = match final_data {
Some(d) => d,
None => return, };
if debug {
println!(
"[Peer] Response: hash={}..., size={}",
&hash_key[..16.min(hash_key.len())],
final_data.len()
);
}
let mut requests = pending_requests.write().await;
if let Some(request) = requests.remove(&hash_key) {
let computed_hash = hashtree_core::sha256(&final_data);
if computed_hash.to_vec() == res.h {
let _ = request.response_tx.send(Some(final_data));
} else {
if debug {
println!("[Peer] Hash mismatch for response");
}
let _ = request.response_tx.send(None);
}
}
}
}
})
}));
}
/// Sends `data` for `hash` over `dc`, fragmenting it when it is larger than
/// `FRAGMENT_SIZE`.
///
/// Payloads of at most `FRAGMENT_SIZE` bytes go out as a single response.
/// Larger payloads are split into consecutive `FRAGMENT_SIZE`-byte fragments,
/// each tagged with its index and the total fragment count.
///
/// Sending is best effort and nothing is returned to the caller. A failed
/// send is logged when `debug` is set, and the remaining fragments are
/// skipped.
async fn send_response(dc: &Arc<RTCDataChannel>, hash: &Hash, data: Vec<u8>, debug: bool) {
    if data.len() <= FRAGMENT_SIZE {
        let encoded = encode_response(&create_response(hash, data));
        if let Err(err) = dc.send(&Bytes::from(encoded)).await {
            if debug {
                println!("[Peer] Failed to send response: {:?}", err);
            }
        }
        return;
    }

    // The wire format carries the fragment count as `u32`; refuse payloads
    // whose count does not fit instead of silently truncating it.
    let Ok(total_fragments) = u32::try_from(data.chunks(FRAGMENT_SIZE).len()) else {
        if debug {
            println!("[Peer] Response too large to fragment ({} bytes)", data.len());
        }
        return;
    };
    if debug {
        println!(
            "[Peer] Sending {} fragments for hash",
            total_fragments
        );
    }
    for (i, chunk) in (0..total_fragments).zip(data.chunks(FRAGMENT_SIZE)) {
        let res = create_fragment_response(hash, chunk.to_vec(), i, total_fragments);
        let encoded = encode_response(&res);
        if let Err(err) = dc.send(&Bytes::from(encoded)).await {
            if debug {
                println!(
                    "[Peer] Failed to send fragment {}/{}: {:?}",
                    i, total_fragments, err
                );
            }
            // Stop here rather than keep sending after a failure.
            return;
        }
    }
}
/// Stores one fragment of a fragmented response and returns the complete
/// payload once every fragment has arrived.
///
/// Returns `None` while fragments are still outstanding, and also when the
/// fragment is rejected because:
/// * the index or the total count is missing,
/// * the total is zero or the index is not below the total, or
/// * the total disagrees with the one recorded for this hash.
///
/// Duplicate fragments are ignored. When the last fragment arrives the entry
/// is removed from `pending_reassemblies` and the fragments are concatenated
/// in index order. The result is not hash-verified here.
async fn handle_fragment_response(
    res: &DataResponse,
    pending_reassemblies: &Arc<RwLock<HashMap<String, PendingReassembly>>>,
    debug: bool,
) -> Option<Vec<u8>> {
    // Fragment metadata comes from the remote peer: never unwrap it.
    let index = res.i?;
    let total = res.n?;
    if total == 0 || index >= total {
        if debug {
            println!("[Peer] Dropping invalid fragment {}/{}", index, total);
        }
        return None;
    }

    let hash_key = hash_to_key(&res.h);
    let now = std::time::Instant::now();
    let mut reassemblies = pending_reassemblies.write().await;
    let pending = reassemblies
        .entry(hash_key.clone())
        .or_insert_with(|| PendingReassembly {
            hash: bytes_to_hash(&res.h).unwrap_or([0u8; 32]),
            fragments: HashMap::new(),
            total_expected: total,
            received_bytes: 0,
            first_fragment_at: now,
            last_fragment_at: now,
        });

    // Every fragment of one response must announce the same total; otherwise
    // the completion check below would be meaningless.
    if pending.total_expected != total {
        if debug {
            println!(
                "[Peer] Dropping fragment with mismatched total ({} != {})",
                total, pending.total_expected
            );
        }
        return None;
    }

    if let std::collections::hash_map::Entry::Vacant(slot) = pending.fragments.entry(index) {
        slot.insert(res.d.clone());
        pending.received_bytes += res.d.len();
        pending.last_fragment_at = now;
    }

    if pending.fragments.len() != pending.total_expected as usize {
        return None;
    }

    // All indices are distinct and below `total`, so 0..total is fully covered.
    let complete = reassemblies.remove(&hash_key)?;
    // Release the lock before copying the payload together.
    drop(reassemblies);
    let mut assembled = Vec::with_capacity(complete.received_bytes);
    for i in 0..total {
        if let Some(fragment) = complete.fragments.get(&i) {
            assembled.extend_from_slice(fragment);
        }
    }
    if debug {
        println!(
            "[Peer] Reassembled {} fragments, {} bytes",
            total,
            assembled.len()
        );
    }
    Some(assembled)
}
/// Initiates the connection as the offering side.
///
/// Marks the peer `Connecting`, creates the (unordered) data channel,
/// attaches the message handlers and an `on_open` hook that marks the peer
/// `Ready`, stores the channel, then creates an SDP offer, sets it as the
/// local description and sends it through the signaling channel.
///
/// # Errors
/// * [`PeerError::WebRTC`] if creating the channel or the offer, or setting
///   the local description, fails.
/// * [`PeerError::ChannelClosed`] if the signaling channel is closed.
pub async fn connect(&self) -> Result<(), PeerError> {
    *self.state.write().await = PeerState::Connecting;

    let dc_init = RTCDataChannelInit {
        ordered: Some(false),
        ..Default::default()
    };
    let dc = self
        .connection
        .create_data_channel(DATA_CHANNEL_LABEL, Some(dc_init))
        .await?;
    Self::setup_data_channel_handlers(
        dc.clone(),
        self.pending_requests.clone(),
        self.their_requests.clone(),
        self.pending_reassemblies.clone(),
        self.local_store.clone(),
        self.debug,
        self.htl_config,
        self.forward_tx.clone(),
        self.on_forward_request.clone(),
        self.remote_id.to_peer_string(),
    )
    .await;

    // The peer becomes usable for requests only once the channel is open.
    let state = self.state.clone();
    let debug = self.debug;
    dc.on_open(Box::new(move || {
        let state = state.clone();
        Box::pin(async move {
            *state.write().await = PeerState::Ready;
            if debug {
                println!("[Peer] Data channel opened (outgoing)");
            }
        })
    }));
    *self.data_channel.write().await = Some(dc);

    let offer = self.connection.create_offer(None).await?;
    self.connection.set_local_description(offer.clone()).await?;
    let msg = SignalingMessage::Offer {
        peer_id: self.local_peer_id.clone(),
        target_peer_id: self.remote_id.to_peer_string(),
        sdp: offer.sdp,
    };
    self.signaling_tx
        .send(msg)
        .await
        .map_err(|_| PeerError::ChannelClosed)?;
    Ok(())
}
/// Processes one signaling message addressed to this peer.
///
/// * `Offer`: sets the remote description, applies queued ICE candidates,
///   creates and sends an answer, then marks the peer `Connecting`.
/// * `Answer`: sets the remote description and applies queued candidates.
/// * `Candidate` / `Candidates`: each candidate is added immediately when a
///   remote description exists, otherwise queued until one is set.
/// * Any other message is ignored.
///
/// # Errors
/// * [`PeerError::WebRTC`] if the WebRTC stack rejects a description or a
///   candidate.
/// * [`PeerError::ChannelClosed`] if the answer cannot be sent.
pub async fn handle_signaling(&self, msg: SignalingMessage) -> Result<(), PeerError> {
    match msg {
        SignalingMessage::Offer { sdp, .. } => {
            if self.debug {
                println!("[Peer] Received offer, setting remote description");
            }
            let offer = RTCSessionDescription::offer(sdp)?;
            self.connection.set_remote_description(offer).await?;
            self.flush_pending_candidates("offer").await?;

            let answer = self.connection.create_answer(None).await?;
            self.connection.set_local_description(answer.clone()).await?;
            let msg = SignalingMessage::Answer {
                peer_id: self.local_peer_id.clone(),
                target_peer_id: self.remote_id.to_peer_string(),
                sdp: answer.sdp,
            };
            self.signaling_tx
                .send(msg)
                .await
                .map_err(|_| PeerError::ChannelClosed)?;
            *self.state.write().await = PeerState::Connecting;
        }
        SignalingMessage::Answer { sdp, .. } => {
            if self.debug {
                println!("[Peer] Received answer, setting remote description");
            }
            let answer = RTCSessionDescription::answer(sdp)?;
            self.connection.set_remote_description(answer).await?;
            self.flush_pending_candidates("answer").await?;
        }
        SignalingMessage::Candidate {
            candidate,
            sdp_m_line_index,
            sdp_mid,
            ..
        } => {
            let init = RTCIceCandidateInit {
                candidate: candidate.clone(),
                sdp_mid,
                sdp_mline_index: sdp_m_line_index,
                ..Default::default()
            };
            if self.connection.remote_description().await.is_some() {
                if self.debug {
                    // The candidate string comes from the remote side; cut it
                    // by characters so a multi-byte character can never cause
                    // a slicing panic.
                    let preview: String = candidate.chars().take(50).collect();
                    println!("[Peer] Adding ICE candidate: {}...", preview);
                }
                self.connection.add_ice_candidate(init).await?;
            } else {
                if self.debug {
                    println!("[Peer] Queueing ICE candidate (no remote description yet)");
                }
                self.pending_candidates.write().await.push(init);
            }
        }
        SignalingMessage::Candidates { candidates, .. } => {
            for c in candidates {
                let init = RTCIceCandidateInit {
                    candidate: c.candidate,
                    sdp_mid: c.sdp_mid,
                    sdp_mline_index: c.sdp_m_line_index,
                    ..Default::default()
                };
                if self.connection.remote_description().await.is_some() {
                    self.connection.add_ice_candidate(init).await?;
                } else {
                    self.pending_candidates.write().await.push(init);
                }
            }
        }
        _ => {}
    }
    Ok(())
}

/// Applies every ICE candidate that was queued while no remote description
/// was set. `trigger` names the message that set it ("offer" / "answer") and
/// is used only in the debug log line.
async fn flush_pending_candidates(&self, trigger: &str) -> Result<(), PeerError> {
    // Take the queue out so the lock is not held while adding candidates.
    let queued = std::mem::take(&mut *self.pending_candidates.write().await);
    if self.debug && !queued.is_empty() {
        println!(
            "[Peer] Adding {} pending candidates after {}",
            queued.len(),
            trigger
        );
    }
    for candidate in queued {
        self.connection.add_ice_candidate(candidate).await?;
    }
    Ok(())
}
/// Requests the data for `hash` from the remote peer, starting from `MAX_HTL`.
///
/// Delegates to `request_with_htl`; the result and errors are whatever that
/// call returns.
pub async fn request(&self, hash: &Hash) -> Result<Option<Vec<u8>>, PeerError> {
self.request_with_htl(hash, MAX_HTL).await
}
/// Requests the data for `hash` from the remote peer with an explicit HTL.
///
/// The HTL actually sent is `htl` after this peer's decrement policy has been
/// applied. The call waits up to 10 seconds for a response.
///
/// Returns `Ok(Some(data))` for a response whose hash verified and `Ok(None)`
/// for a response that failed verification.
///
/// A second request for the same hash replaces the pending entry of the first
/// one, whose caller then observes [`PeerError::ChannelClosed`].
///
/// # Errors
/// * [`PeerError::NotReady`] if the peer is not `Ready` or has no data channel.
/// * [`PeerError::WebRTC`] if the request could not be sent.
/// * [`PeerError::ChannelClosed`] if the pending entry was dropped.
/// * [`PeerError::Timeout`] if no response arrived in time.
pub async fn request_with_htl(&self, hash: &Hash, htl: u8) -> Result<Option<Vec<u8>>, PeerError> {
    let state = *self.state.read().await;
    if state != PeerState::Ready {
        return Err(PeerError::NotReady);
    }
    // Clone the channel handle so the lock is not held for the whole wait.
    let dc = self
        .data_channel
        .read()
        .await
        .as_ref()
        .cloned()
        .ok_or(PeerError::NotReady)?;

    let hash_key = hash_to_key(hash);
    let (tx, rx) = oneshot::channel();
    self.pending_requests.write().await.insert(
        hash_key.clone(),
        PendingRequest {
            hash: *hash,
            response_tx: tx,
        },
    );

    let send_htl = self.htl_config.decrement(htl);
    let encoded = encode_request(&create_request(hash, send_htl));
    if let Err(err) = dc.send(&Bytes::from(encoded)).await {
        // Nothing was sent, so nobody will ever answer: drop the waiter.
        self.pending_requests.write().await.remove(&hash_key);
        return Err(err.into());
    }
    if self.debug {
        println!(
            "[Peer] Sent request: htl={}, hash={}...",
            send_htl,
            &hash_key[..16.min(hash_key.len())]
        );
    }

    match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
        Ok(Ok(data)) => Ok(data),
        Ok(Err(_)) => Err(PeerError::ChannelClosed),
        Err(_) => {
            self.pending_requests.write().await.remove(&hash_key);
            // Also discard fragments buffered for this request; they can no
            // longer be delivered to anyone.
            self.pending_reassemblies.write().await.remove(&hash_key);
            Err(PeerError::Timeout)
        }
    }
}
/// Sends `data` as the response for `hash` over the data channel.
///
/// When `data` is `None` nothing is sent and `Ok(())` is returned.
///
/// # Errors
/// [`PeerError::NotReady`] if no data channel has been set up yet.
pub async fn send_response_for_hash(
    &self,
    hash: &Hash,
    data: Option<&[u8]>,
) -> Result<(), PeerError> {
    let guard = self.data_channel.read().await;
    let Some(channel) = guard.as_ref() else {
        return Err(PeerError::NotReady);
    };
    match data {
        Some(payload) => {
            Self::send_response(channel, hash, payload.to_vec(), self.debug).await;
        }
        None => {}
    }
    Ok(())
}
/// Returns a copy of the peer's current state.
pub async fn state(&self) -> PeerState {
*self.state.read().await
}
/// Closes the peer connection and marks the peer `Disconnected`.
///
/// # Errors
/// [`PeerError::WebRTC`] if closing the connection fails; in that case the
/// state is left untouched.
pub async fn close(&self) -> Result<(), PeerError> {
    self.connection.close().await?;
    let mut current = self.state.write().await;
    *current = PeerState::Disconnected;
    Ok(())
}
pub fn set_on_forward_request<F>(&mut self, callback: F)
where
F: Fn(Hash, String, u8) -> futures::future::BoxFuture<'static, Option<Vec<u8>>> + Send + Sync + 'static,
{
self.on_forward_request = Some(Arc::new(callback));
}
/// Returns a copy of this peer's HTL configuration.
pub fn htl_config(&self) -> PeerHTLConfig {
self.htl_config
}
/// Sends `data` to the remote peer if it previously requested `hash_hex`.
///
/// The recorded request is removed from `their_requests` first. Returns
/// `Ok(false)` when no such request was recorded, and `Ok(true)` after the
/// data has been handed to `send_response`.
///
/// # Errors
/// [`PeerError::NotReady`] if no data channel has been set up yet.
pub async fn send_data(&self, hash_hex: &str, data: &[u8]) -> Result<bool, PeerError> {
    // The write guard is a temporary and is released at the end of this line.
    let recorded = self.their_requests.write().await.pop(hash_hex);
    let their_req = match recorded {
        Some(entry) => entry,
        None => return Ok(false),
    };

    let guard = self.data_channel.read().await;
    let channel = guard.as_ref().ok_or(PeerError::NotReady)?;
    Self::send_response(channel, &their_req.hash, data.to_vec(), self.debug).await;
    if self.debug {
        println!("[Peer] Sent data for hash: {}...", &hash_hex[..16.min(hash_hex.len())]);
    }
    Ok(true)
}
/// Returns `true` if a request for `hash_hex` from the remote peer is
/// currently recorded. Uses `peek`, not `get`, on the LRU cache.
pub async fn has_requested(&self, hash_hex: &str) -> bool {
self.their_requests.read().await.peek(hash_hex).is_some()
}
/// Returns the number of requests from the remote peer currently recorded.
pub async fn their_request_count(&self) -> usize {
self.their_requests.read().await.len()
}
/// Returns the number of our requests that are still awaiting a response.
pub async fn our_request_count(&self) -> usize {
self.pending_requests.read().await.len()
}
}