use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
/// A remote peer tracked by [`MeshRouter`], from the moment an offer is sent to
/// (or accepted from) it.
pub struct PeerEntry {
    /// Pool the peer was classified into when this entry was created; the
    /// router counts entries per pool when enforcing pool capacity.
    pub pool: PeerPool,
    /// Link to the remote peer, as produced by the [`PeerLinkFactory`].
    pub channel: Arc<dyn PeerLink>,
}
/// Signaling state machine that turns `Hello` / `Offer` / `Answer` messages into
/// peer links while keeping the number of peers in each [`PeerPool`] within the
/// limits held in [`PoolSettings`].
pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
/// This node's own peer id. Hellos carrying it are ignored, and it is the
/// `peer_id` on every message this router publishes.
peer_id: String,
// Stored by `new` but not read by any method in this file's `impl`, hence the
// `dead_code` allowance.
#[allow(dead_code)]
pubkey: String,
/// Transport used to publish outgoing signaling messages.
transport: Arc<R>,
/// Factory that creates offers, accepts offers and applies answers for peer links.
conn_factory: Arc<F>,
/// Remote peers keyed by peer id. This also contains peers we have sent an
/// offer to whose answer has not arrived yet.
peers: RwLock<HashMap<String, PeerEntry>>,
/// Peer ids we have sent an offer to, consulted when an offer arrives to detect
/// offer collisions. Only the keys are meaningful (the value is `()`).
pending_offers: RwLock<HashMap<String, ()>>,
/// Per-pool capacity settings.
pools: PoolSettings,
/// Roots announced in each peer's most recent `Hello` that passed the pool
/// capacity check.
peer_roots: RwLock<HashMap<String, Vec<String>>>,
/// Optional channel to an external classifier. `None` means every remote peer
/// is classified as `PeerPool::Other`.
classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
/// When true, signaling decisions are printed to stdout with a `[Signaling]` prefix.
debug: bool,
}
/// Alias for [`MeshRouter`].
/// NOTE(review): presumably kept so callers using this older name still compile — confirm before removing.
pub type SignalingManager<R, F> = MeshRouter<R, F>;
/// Alias for [`MeshRouter`].
/// NOTE(review): presumably kept so callers using this older name still compile — confirm before removing.
pub type PeerRouter<R, F> = MeshRouter<R, F>;
impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
    /// Creates a router for the local peer `peer_id`.
    ///
    /// The router starts with no peers, no pending offers, no recorded roots and
    /// no classifier (see [`MeshRouter::set_classifier`]). `pubkey` is stored but
    /// not read by any method in this `impl`. When `debug` is true, signaling
    /// decisions are printed to stdout with a `[Signaling]` prefix.
    pub fn new(
        peer_id: String,
        pubkey: String,
        transport: Arc<R>,
        conn_factory: Arc<F>,
        pools: PoolSettings,
        debug: bool,
    ) -> Self {
        Self {
            peer_id,
            pubkey,
            transport,
            conn_factory,
            peers: RwLock::new(HashMap::new()),
            pending_offers: RwLock::new(HashMap::new()),
            pools,
            peer_roots: RwLock::new(HashMap::new()),
            classifier_tx: None,
            debug,
        }
    }

    /// Installs the channel used to classify remote peers into pools.
    ///
    /// Without a classifier, or if sending the request fails or the reply is
    /// dropped, every remote peer is placed in [`PeerPool::Other`].
    pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
        self.classifier_tx = Some(tx);
    }

    /// Returns this router's own peer id.
    pub fn peer_id(&self) -> &str {
        &self.peer_id
    }

    /// Publishes a `Hello` announcing this peer and the given `roots`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the transport's `publish` reports.
    pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
        let msg = SignalingMessage::Hello {
            peer_id: self.peer_id.clone(),
            roots,
        };
        self.transport.publish(msg).await
    }

    /// Returns the pubkey part of a peer id.
    ///
    /// This is the text before the first `:`, or the whole id when it contains no `:`.
    fn pubkey_of(peer_id: &str) -> &str {
        peer_id
            .split_once(':')
            .map_or(peer_id, |(pubkey, _)| pubkey)
    }

    /// Counts the entries in `peers` per pool and returns `(follows, other)`.
    ///
    /// Entries for offers whose answer has not arrived yet are included, so an
    /// outstanding offer already occupies its pool slot.
    async fn count_pools(&self) -> (usize, usize) {
        let peers = self.peers.read().await;
        let mut follows = 0;
        let mut other = 0;
        for entry in peers.values() {
            match entry.pool {
                PeerPool::Follows => follows += 1,
                PeerPool::Other => other += 1,
            }
        }
        (follows, other)
    }

    /// Asks the installed classifier which pool `pubkey` belongs to.
    ///
    /// Falls back to [`PeerPool::Other`] when there is no classifier, the
    /// request cannot be sent, or the classifier drops the reply channel.
    async fn classify_peer(&self, pubkey: &str) -> PeerPool {
        if let Some(ref tx) = self.classifier_tx {
            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
            let request = ClassifyRequest {
                pubkey: pubkey.to_string(),
                response: response_tx,
            };
            if tx.send(request).await.is_ok() {
                if let Ok(pool) = response_rx.await {
                    return pool;
                }
            }
        }
        PeerPool::Other
    }

    /// Whether `pool`, currently holding the given counts, may take another peer.
    fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
        match pool {
            PeerPool::Follows => self.pools.follows.can_accept(follows),
            PeerPool::Other => self.pools.other.can_accept(other),
        }
    }

    /// Whether `pool`, currently holding the given counts, is below its
    /// `needs_peers` threshold, so that we should actively offer to new peers.
    fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
        match pool {
            PeerPool::Follows => self.pools.follows.needs_peers(follows),
            PeerPool::Other => self.pools.other.needs_peers(other),
        }
    }

    /// Dispatches one incoming signaling message.
    ///
    /// Offers and answers addressed to another peer are ignored, as are
    /// `Candidate` / `Candidates` messages.
    ///
    /// # Errors
    ///
    /// Propagates errors from the per-message handlers.
    pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
        match &msg {
            SignalingMessage::Hello { peer_id, roots } => self.handle_hello(peer_id, roots).await,
            SignalingMessage::Offer {
                peer_id,
                target_peer_id,
                sdp,
            } => {
                if target_peer_id == &self.peer_id {
                    self.handle_offer(peer_id, sdp).await
                } else {
                    Ok(())
                }
            }
            SignalingMessage::Answer {
                peer_id,
                target_peer_id,
                sdp,
            } => {
                if target_peer_id == &self.peer_id {
                    self.handle_answer(peer_id, sdp).await
                } else {
                    Ok(())
                }
            }
            SignalingMessage::Candidate { .. } | SignalingMessage::Candidates { .. } => Ok(()),
        }
    }

    /// Handles a `Hello` from another peer.
    ///
    /// Hellos from ourselves, or from peers whose pool is full, are ignored.
    /// Otherwise the announced `roots` are recorded. If the peer's pool still
    /// needs peers, and we neither know the peer nor already have an offer
    /// outstanding to it, an offer is created and published.
    ///
    /// # Errors
    ///
    /// Propagates errors from the factory's `create_offer` and from publishing
    /// the offer. In both cases the state added for this offer (the pending-offer
    /// reservation and, if already inserted, the peer entry) is removed again.
    /// Otherwise every later hello from this peer would be ignored as "already
    /// pending", and the peer would keep holding a pool slot.
    async fn handle_hello(
        &self,
        from_peer_id: &str,
        roots: &[String],
    ) -> Result<(), TransportError> {
        if from_peer_id == self.peer_id {
            return Ok(());
        }
        let pool = self.classify_peer(Self::pubkey_of(from_peer_id)).await;
        let (follows_count, other_count) = self.count_pools().await;
        if !self.can_accept_peer(pool, follows_count, other_count) {
            if self.debug {
                println!(
                    "[Signaling] Ignoring hello from {} - {:?} pool full",
                    from_peer_id, pool
                );
            }
            return Ok(());
        }
        self.peer_roots
            .write()
            .await
            .insert(from_peer_id.to_string(), roots.to_vec());
        if !self.pool_needs_peers(pool, follows_count, other_count) {
            return Ok(());
        }
        if self.peers.read().await.contains_key(from_peer_id) {
            return Ok(());
        }
        {
            // Check and reserve under a single write guard so that two concurrent
            // hellos from the same peer cannot both start an offer.
            let mut pending = self.pending_offers.write().await;
            if pending.contains_key(from_peer_id) {
                return Ok(());
            }
            pending.insert(from_peer_id.to_string(), ());
        }
        if self.debug {
            println!(
                "[Signaling] Sending offer to {} (pool: {:?})",
                from_peer_id, pool
            );
        }
        let (channel, sdp) = match self.conn_factory.create_offer(from_peer_id).await {
            Ok(created) => created,
            Err(err) => {
                // No offer exists, so release the reservation made above.
                self.pending_offers.write().await.remove(from_peer_id);
                return Err(err);
            }
        };
        let entry = PeerEntry { channel, pool };
        let our_channel = Arc::clone(&entry.channel);
        self.peers
            .write()
            .await
            .insert(from_peer_id.to_string(), entry);
        let offer_msg = SignalingMessage::Offer {
            peer_id: self.peer_id.clone(),
            target_peer_id: from_peer_id.to_string(),
            sdp,
        };
        if let Err(err) = self.transport.publish(offer_msg).await {
            // The offer never left this node, so no answer can arrive for it.
            self.pending_offers.write().await.remove(from_peer_id);
            self.remove_peer_if_same(from_peer_id, &our_channel).await;
            return Err(err);
        }
        Ok(())
    }

    /// Handles an offer addressed to this peer.
    ///
    /// Offer collisions are resolved first. If we also have an offer outstanding
    /// to `from_peer_id`, [`is_polite_peer`] decides the winner:
    /// - The impolite side ignores the incoming offer.
    /// - The polite side drops its own pending offer and peer entry and answers
    ///   the incoming offer instead.
    ///
    /// This happens before the capacity check. Our own outstanding offer already
    /// occupies a slot in `peers`, so checking capacity first could make the pool
    /// look full and stall the connection on both sides.
    ///
    /// Offers from peers already in `peers`, or whose pool is full, are ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from the factory's `accept_offer` and from publishing
    /// the answer. If publishing fails, the peer entry added for this offer is
    /// removed again so it does not hold a pool slot for a link that cannot complete.
    async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
        let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
        if have_pending {
            if !is_polite_peer(&self.peer_id, from_peer_id) {
                if self.debug {
                    println!(
                        "[Signaling] Collision with {} - we're impolite, ignoring their offer",
                        from_peer_id
                    );
                }
                return Ok(());
            }
            self.pending_offers.write().await.remove(from_peer_id);
            self.peers.write().await.remove(from_peer_id);
            if self.debug {
                println!(
                    "[Signaling] Collision with {} - we're polite, accepting their offer",
                    from_peer_id
                );
            }
        }
        if self.peers.read().await.contains_key(from_peer_id) {
            return Ok(());
        }
        let pool = self.classify_peer(Self::pubkey_of(from_peer_id)).await;
        let (follows_count, other_count) = self.count_pools().await;
        if !self.can_accept_peer(pool, follows_count, other_count) {
            if self.debug {
                println!(
                    "[Signaling] Ignoring offer from {} - {:?} pool full",
                    from_peer_id, pool
                );
            }
            return Ok(());
        }
        if self.debug {
            println!("[Signaling] Accepting offer from {}", from_peer_id);
        }
        let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
        let entry = PeerEntry { channel, pool };
        let our_channel = Arc::clone(&entry.channel);
        self.peers
            .write()
            .await
            .insert(from_peer_id.to_string(), entry);
        let answer_msg = SignalingMessage::Answer {
            peer_id: self.peer_id.clone(),
            target_peer_id: from_peer_id.to_string(),
            sdp: answer_sdp,
        };
        if let Err(err) = self.transport.publish(answer_msg).await {
            // The remote will never receive our answer, so free the slot.
            self.remove_peer_if_same(from_peer_id, &our_channel).await;
            return Err(err);
        }
        Ok(())
    }

    /// Handles an answer to an offer we sent, passing it to the factory.
    ///
    /// # Errors
    ///
    /// Propagates the factory's `handle_answer` error.
    async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
        if self.debug {
            println!("[Signaling] Received answer from {}", from_peer_id);
        }
        // Once answered, our offer is no longer outstanding. A stale entry would
        // make every later offer from this peer look like a collision in
        // `handle_offer`.
        self.pending_offers.write().await.remove(from_peer_id);
        let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
        Ok(())
    }

    /// Removes `peer_id` from `peers` only if its entry still holds `channel`.
    ///
    /// A rollback therefore never discards a link that replaced ours in the
    /// meantime, for example one accepted by a concurrent polite collision.
    async fn remove_peer_if_same(&self, peer_id: &str, channel: &Arc<dyn PeerLink>) {
        let mut peers = self.peers.write().await;
        let is_ours = matches!(
            peers.get(peer_id),
            Some(entry) if Arc::ptr_eq(&entry.channel, channel)
        );
        if is_ours {
            peers.remove(peer_id);
        }
    }

    /// Number of tracked peers, including peers whose answer is still outstanding.
    pub async fn peer_count(&self) -> usize {
        self.peers.read().await.len()
    }

    /// Ids of all tracked peers (in arbitrary `HashMap` order), including peers
    /// whose answer is still outstanding.
    pub async fn peer_ids(&self) -> Vec<String> {
        self.peers.read().await.keys().cloned().collect()
    }

    /// Returns a clone of the link to `peer_id`, or `None` if the peer is not tracked.
    pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
        self.peers
            .read()
            .await
            .get(peer_id)
            .map(|e| e.channel.clone())
    }

    /// Whether either pool reports that it needs more peers.
    pub async fn needs_peers(&self) -> bool {
        let (follows, other) = self.count_pools().await;
        self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
    }

    /// Whether either pool can still accept another peer.
    pub async fn can_accept(&self) -> bool {
        let (follows, other) = self.count_pools().await;
        self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
    }
}