use super::connection::Connection;
use super::webmedia::ConnectOptions;
use crate::crypto::aes::Aes128State;
use anyhow::{anyhow, Result};
use gloo::timers::callback::Interval;
use log::{debug, error, info, warn};
use protobuf::Message;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use videocall_diagnostics::{global_sender, metric, now_ms, DiagEvent};
use videocall_types::protos::media_packet::media_packet::MediaType;
use videocall_types::protos::media_packet::MediaPacket;
use videocall_types::protos::packet_wrapper::packet_wrapper::PacketType;
use videocall_types::protos::packet_wrapper::PacketWrapper;
use videocall_types::Callback;
use wasm_bindgen::JsValue;
/// Connection status handed to the `on_state_changed` callback and returned by
/// `ConnectionManager::get_connection_state`.
#[derive(Debug, Clone, PartialEq)]
pub enum ConnectionState {
/// The election is still running and candidate servers are being measured.
Testing {
/// Elapsed share of the election period, capped at `1.0`.
progress: f32,
/// Number of connections currently held by the manager.
servers_tested: usize,
/// Number of configured WebSocket plus WebTransport URLs.
total_servers: usize,
},
/// A connection has been elected.
Connected {
/// URL of the elected server.
server_url: String,
/// Average RTT recorded for the elected server (`0.0` when none is recorded).
rtt: f64,
/// Whether the elected connection uses WebTransport.
is_webtransport: bool,
},
// NOTE(review): this variant is never constructed in this file; presumably
// produced by callers that drive reconnection — confirm.
Reconnecting {
server_url: String,
attempt: u32,
max_attempts: u32,
},
/// The election failed or the active connection was lost.
Failed {
/// Human-readable failure description.
error: String,
/// URL of the server that was last active, when known.
last_known_server: Option<String>,
},
}
/// RTT bookkeeping for one candidate server connection.
#[derive(Debug, Clone)]
pub struct ServerRttMeasurement {
/// Server URL this record belongs to.
pub url: String,
/// `true` for WebTransport connections, `false` for WebSocket ones.
pub is_webtransport: bool,
/// Recent RTT samples; the manager keeps at most the 10 newest.
pub measurements: Vec<f64>,
/// Mean of `measurements`; `None` until the first sample is recorded.
pub average_rtt: Option<f64>,
/// Manager-assigned id (`ws_<index>` or `wt_<index>`).
pub connection_id: String,
/// Set once this connection wins the election.
pub active: bool,
/// Set the first time a probe is sent while the connection reports connected.
pub connected: bool,
}
/// Internal phase of the server election.
#[derive(Debug)]
pub enum ElectionState {
/// Candidates are being probed.
Testing {
/// `js_sys::Date::now()` value captured when the election started.
start_time: f64,
/// Length of the election period in milliseconds.
duration_ms: u64,
/// Optional probe timer; cancelled when the election completes or the
/// manager is dropped. The election start in this file sets it to `None`.
probe_timer: Option<Interval>,
},
/// A winner was chosen.
Elected {
/// Id of the winning connection.
connection_id: String,
/// `js_sys::Date::now()` value captured at election time.
elected_at: f64,
},
/// No usable connection was found (also the placeholder before the election starts).
Failed {
/// Failure description.
reason: String,
/// `js_sys::Date::now()` value captured when the failure was recorded.
failed_at: f64,
},
}
/// Configuration for a `ConnectionManager`.
#[derive(Clone, Debug)]
pub struct ConnectionManagerOptions {
/// Candidate WebSocket server URLs (ids `ws_<index>`).
pub websocket_urls: Vec<String>,
/// Candidate WebTransport server URLs (ids `wt_<index>`).
pub webtransport_urls: Vec<String>,
/// Local user id; used as the `email` of RTT probes and to recognise their echoes.
pub userid: String,
/// Receives inbound packets that are not consumed by the manager itself.
pub on_inbound_media: Callback<PacketWrapper>,
/// Receives connection state updates.
pub on_state_changed: Callback<ConnectionState>,
/// Passed through unchanged to every created connection.
pub peer_monitor: Callback<()>,
/// Length of the election period in milliseconds.
pub election_period_ms: u64,
}
/// Opens connections to every configured server, measures their RTT during an
/// election period and then routes traffic through the elected one.
#[derive(Debug)]
pub struct ConnectionManager {
/// Live connections keyed by connection id.
connections: HashMap<String, Connection>,
/// Id of the elected connection; shared with the per-connection callbacks.
active_connection_id: Rc<RefCell<Option<String>>>,
/// RTT records keyed by connection id.
rtt_measurements: HashMap<String, ServerRttMeasurement>,
/// Current election phase.
election_state: ElectionState,
// The three timers below are only ever cancelled (on drop) in this file;
// nothing visible here assigns them.
rtt_reporter: Option<Interval>,
rtt_probe_timer: Option<Interval>,
election_timer: Option<Interval>,
// `rtt_responses`: RTT echoes queued by the inbound callbacks as
// (connection id, packet, reception time), drained by
// `process_queued_rtt_responses`. `options`: the configuration passed to `new`.
rtt_responses: Rc<RefCell<Vec<(String, MediaPacket, f64)>>>, options: ConnectionManagerOptions,
/// Cipher used to encrypt RTT probes and decrypt their echoes.
aes: Rc<Aes128State>,
/// Session id applied to this client; inbound packets carrying it are dropped.
own_session_id: Rc<RefCell<Option<u64>>>,
/// Session ids received on connections that were not (yet) the elected one.
pending_session_ids: Rc<RefCell<HashMap<String, u64>>>,
}
impl ConnectionManager {
/// Creates a manager for the configured servers and immediately starts the election.
///
/// # Errors
/// Fails when both URL lists are empty, or when starting the election fails.
pub fn new(options: ConnectionManagerOptions, aes: Rc<Aes128State>) -> Result<Self> {
    let total_servers = options.websocket_urls.len() + options.webtransport_urls.len();
    match total_servers {
        0 => return Err(anyhow!("No servers provided for connection testing")),
        _ => info!("ConnectionManager starting with {total_servers} servers"),
    }

    // Placeholder phase; `start_election` replaces it with `Testing`.
    let not_started = ElectionState::Failed {
        reason: String::from("Not started"),
        failed_at: js_sys::Date::now(),
    };

    let mut manager = Self {
        connections: HashMap::new(),
        active_connection_id: Rc::new(RefCell::new(None)),
        rtt_measurements: HashMap::new(),
        election_state: not_started,
        rtt_reporter: None,
        rtt_probe_timer: None,
        election_timer: None,
        rtt_responses: Rc::new(RefCell::new(Vec::new())),
        options,
        aes,
        own_session_id: Rc::new(RefCell::new(None)),
        pending_session_ids: Rc::new(RefCell::new(HashMap::new())),
    };

    manager.start_election().map(|()| manager)
}
/// Opens all candidate connections and enters the `Testing` phase.
///
/// The start time is captured before the connections are created, then the
/// new state is reported through `on_state_changed`.
fn start_election(&mut self) -> Result<()> {
    let start_time = js_sys::Date::now();
    let election_duration = self.options.election_period_ms;
    info!("Starting connection election for {election_duration}ms");

    self.create_all_connections()?;

    let testing = ElectionState::Testing {
        start_time,
        duration_ms: election_duration,
        probe_timer: None,
    };
    self.election_state = testing;

    self.start_diagnostics_reporting();
    self.report_state();
    Ok(())
}
fn create_all_connections(&mut self) -> Result<()> {
for (i, url) in self.options.websocket_urls.iter().enumerate() {
let conn_id = format!("ws_{i}");
let connect_options = ConnectOptions {
websocket_url: url.clone(),
webtransport_url: String::new(), on_inbound_media: self.create_inbound_media_callback(conn_id.clone()),
on_connected: self.create_connected_callback(conn_id.clone()),
on_connection_lost: self
.create_connection_lost_callback(conn_id.clone(), url.clone()),
peer_monitor: self.options.peer_monitor.clone(),
};
match Connection::connect(false, connect_options, self.aes.clone()) {
Ok(connection) => {
self.connections.insert(conn_id.clone(), connection);
self.rtt_measurements.insert(
conn_id.clone(),
ServerRttMeasurement {
url: url.clone(),
is_webtransport: false,
measurements: Vec::new(),
average_rtt: None,
connection_id: conn_id.clone(),
active: false,
connected: false,
},
);
debug!("Created WebSocket connection {conn_id}: {url}");
}
Err(e) => {
error!("Failed to create WebSocket connection to {url}: {e}");
}
}
}
for (i, url) in self.options.webtransport_urls.iter().enumerate() {
let conn_id = format!("wt_{i}");
let connect_options = ConnectOptions {
websocket_url: String::new(), webtransport_url: url.clone(),
on_inbound_media: self.create_inbound_media_callback(conn_id.clone()),
on_connected: self.create_connected_callback(conn_id.clone()),
on_connection_lost: self
.create_connection_lost_callback(conn_id.clone(), url.clone()),
peer_monitor: self.options.peer_monitor.clone(),
};
match Connection::connect(true, connect_options, self.aes.clone()) {
Ok(connection) => {
self.connections.insert(conn_id.clone(), connection);
self.rtt_measurements.insert(
conn_id.clone(),
ServerRttMeasurement {
url: url.clone(),
is_webtransport: true,
measurements: Vec::new(),
average_rtt: None,
connection_id: conn_id.clone(),
active: false,
connected: false,
},
);
debug!("Created WebTransport connection {conn_id}: {url}");
}
Err(e) => {
error!("Failed to create WebTransport connection to {url}: {e}");
}
}
}
info!("Created {} connections for testing", self.connections.len());
if self.connections.len() == 1 {
info!("Only one connection created, waiting for it to be established before election");
}
Ok(())
}
/// Builds the inbound-packet handler for the connection identified by `connection_id`.
///
/// The handler processes a packet in this order:
/// 1. `SESSION_ASSIGNED`: applied and forwarded right away when this connection
///    is the elected one, otherwise parked in `pending_session_ids`. Never
///    processed further.
/// 2. A packet whose `email` equals our own user id and which decrypts and parses
///    to an RTT `MediaPacket`: queued in `rtt_responses` with its reception time.
/// 3. Anything else: dropped when it carries our own non-zero session id,
///    otherwise forwarded to `on_inbound_media`.
fn create_inbound_media_callback(&self, connection_id: String) -> Callback<PacketWrapper> {
    let own_userid = self.options.userid.clone();
    let cipher = self.aes.clone();
    let forward = self.options.on_inbound_media.clone();
    let rtt_queue = self.rtt_responses.clone();
    let own_session = self.own_session_id.clone();
    let parked_sessions = self.pending_session_ids.clone();
    let elected_id = self.active_connection_id.clone();

    Callback::from(move |packet: PacketWrapper| {
        if packet.packet_type == PacketType::SESSION_ASSIGNED.into() {
            let sid = packet.session_id;
            info!(
                "SESSION_ASSIGNED received on connection {}: {}",
                connection_id, sid
            );
            let is_elected = elected_id.borrow().as_deref() == Some(connection_id.as_str());
            if !is_elected {
                // Not the active connection (yet): remember the id for election time.
                parked_sessions
                    .borrow_mut()
                    .insert(connection_id.clone(), sid);
                return;
            }
            info!("Applying SESSION_ASSIGNED immediately (connection already elected)");
            *own_session.borrow_mut() = Some(sid);
            forward.emit(packet);
            return;
        }

        if packet.email == own_userid {
            let reception_time = js_sys::Date::now();
            let rtt_reply = cipher
                .decrypt(&packet.data)
                .ok()
                .and_then(|plain| MediaPacket::parse_from_bytes(&plain).ok())
                .filter(|reply| reply.media_type == MediaType::RTT.into());
            if let Some(media_packet) = rtt_reply {
                debug!(
                    "RTT response received on connection {} at {}, sent at {}",
                    connection_id, reception_time, media_packet.timestamp
                );
                match rtt_queue.try_borrow_mut() {
                    Ok(mut queue) => {
                        queue.push((connection_id.clone(), media_packet, reception_time));
                    }
                    Err(_) => {
                        warn!("Unable to add RTT response to queue - queue is borrowed");
                    }
                }
                return;
            }
        }

        // Copy the id out so no borrow is held while the packet is forwarded.
        let own_sid = *own_session.borrow();
        if let Some(own_id) = own_sid {
            let from_self = packet.session_id != 0 && packet.session_id == own_id;
            if from_self {
                debug!(
                    "Rejecting packet from same session_id: {}",
                    packet.session_id
                );
                return;
            }
        }
        forward.emit(packet);
    })
}
/// Builds the "connected" handler for `connection_id`; it only logs the event.
fn create_connected_callback(&self, connection_id: String) -> Callback<()> {
    let log_established = move |_| debug!("Connection {connection_id} established");
    Callback::from(log_established)
}
/// Builds the "connection lost" handler for `connection_id`.
///
/// Losing the active connection clears `active_connection_id` and emits a
/// `ConnectionState::Failed` carrying `server_url`; losing any other connection
/// is only logged.
fn create_connection_lost_callback(
    &self,
    connection_id: String,
    server_url: String,
) -> Callback<JsValue> {
    let active_connection_id = self.active_connection_id.clone();
    let on_state_changed = self.options.on_state_changed.clone();

    Callback::from(move |error| {
        warn!("Connection {connection_id} lost: {error:?}");

        let was_active =
            active_connection_id.borrow().as_deref() == Some(connection_id.as_str());
        if !was_active {
            info!(
                "Non-active connection lost: {connection_id}, current active: {:?}",
                active_connection_id.borrow()
            );
            return;
        }

        active_connection_id.borrow_mut().take();
        let failure_state = ConnectionState::Failed {
            error: format!("Active connection {connection_id} lost"),
            last_known_server: Some(server_url.clone()),
        };
        info!("Active connection lost, clearing internal state and emitting Failed state to trigger UI reconnection");
        on_state_changed.emit(failure_state);
    })
}
/// Sends one RTT probe over `connection_id`.
///
/// Does nothing (and succeeds) while the connection is not yet connected.
/// Otherwise marks the measurement record as connected and sends a probe
/// stamped with the current `js_sys::Date::now()` value.
///
/// # Errors
/// Fails when the connection id is unknown or the probe packet cannot be built.
fn send_rtt_probe(&mut self, connection_id: &str) -> Result<()> {
    let connection = match self.connections.get(connection_id) {
        Some(existing) => existing,
        None => return Err(anyhow!("Connection {connection_id} not found")),
    };
    if !connection.is_connected() {
        return Ok(());
    }

    if let Some(record) = self.rtt_measurements.get_mut(connection_id) {
        record.connected = true;
    }

    let timestamp = js_sys::Date::now();
    let probe = self.create_rtt_packet(timestamp)?;
    connection.send_packet(probe);
    debug!("Sent RTT probe to {connection_id} at timestamp {timestamp}");
    Ok(())
}
/// Builds an encrypted RTT probe carrying `timestamp`.
///
/// The inner `MediaPacket` has media type `RTT`; it is serialized, encrypted with
/// the manager's cipher and wrapped in a `MEDIA` `PacketWrapper`. Both layers
/// carry the local user id as `email`.
///
/// # Errors
/// Fails when serialization or encryption fails.
fn create_rtt_packet(&self, timestamp: f64) -> Result<PacketWrapper> {
    let sender = &self.options.userid;

    let probe = MediaPacket {
        media_type: MediaType::RTT.into(),
        email: sender.clone(),
        timestamp,
        ..Default::default()
    };
    let serialized = probe.write_to_bytes()?;
    let data = self.aes.encrypt(&serialized)?;

    let wrapper = PacketWrapper {
        packet_type: PacketType::MEDIA.into(),
        email: sender.clone(),
        data,
        ..Default::default()
    };
    Ok(wrapper)
}
/// Records one RTT sample for `connection_id` and refreshes its average.
///
/// The sample is `reception_time` minus the timestamp carried by the echoed
/// packet. Only the 10 newest samples are kept. Unknown connection ids are ignored.
fn handle_rtt_response(
    &mut self,
    connection_id: &str,
    media_packet: &MediaPacket,
    reception_time: f64,
) {
    const MAX_SAMPLES: usize = 10;

    let rtt = reception_time - media_packet.timestamp;
    if let Some(record) = self.rtt_measurements.get_mut(connection_id) {
        let samples = &mut record.measurements;
        samples.push(rtt);
        if samples.len() > MAX_SAMPLES {
            // Drop the oldest sample to keep the window bounded.
            samples.remove(0);
        }
        let total: f64 = samples.iter().sum();
        let mean = total / samples.len() as f64;
        record.average_rtt = Some(mean);
    }
}
/// Ends the election: picks the best measured connection and promotes it.
///
/// On success the winner becomes the active connection, any `SESSION_ASSIGNED`
/// that arrived on it before the election is applied and re-emitted, the
/// heartbeat is started on it, all other connections are dropped and the new
/// state is reported. When no connection qualifies, the election state becomes
/// `Failed` and that is reported instead.
fn complete_election(&mut self) {
    info!("Completing connection election");
    if let ElectionState::Testing { probe_timer, .. } = &mut self.election_state {
        if let Some(timer) = probe_timer.take() {
            timer.cancel();
        }
    }
    match self.find_best_connection() {
        Ok((connection_id, measurement)) => {
            info!(
                "Elected connection {}: {} (avg RTT: {}ms)",
                connection_id,
                measurement.url,
                measurement.average_rtt.unwrap_or(0.0)
            );
            self.active_connection_id
                .borrow_mut()
                .replace(connection_id.clone());
            if let Some(mut_measurement) = self.rtt_measurements.get_mut(&connection_id) {
                mut_measurement.active = true;
            }
            self.election_state = ElectionState::Elected {
                connection_id: connection_id.clone(),
                elected_at: js_sys::Date::now(),
            };
            // Copy the pending id out in its own statement so the `RefCell` borrow
            // ends here. Keeping the lookup in the `if let` scrutinee would hold the
            // borrow across the external callback below, and a re-entrant
            // `borrow_mut` on `pending_session_ids` would then panic.
            let pending_sid = self
                .pending_session_ids
                .borrow()
                .get(&connection_id)
                .copied();
            if let Some(sid) = pending_sid {
                info!(
                    "Applying pending SESSION_ASSIGNED for elected connection {}: {}",
                    connection_id, sid
                );
                *self.own_session_id.borrow_mut() = Some(sid);
                if let Some(connection) = self.connections.get(&connection_id) {
                    connection.set_session_id(sid);
                }
                // Replay the assignment to the consumer now that its connection is active.
                let wrapper = PacketWrapper {
                    packet_type: PacketType::SESSION_ASSIGNED.into(),
                    session_id: sid,
                    ..Default::default()
                };
                self.options.on_inbound_media.emit(wrapper);
            }
            // Ids parked for losing connections are no longer relevant.
            self.pending_session_ids.borrow_mut().clear();
            if let Some(connection) = self.connections.get_mut(&connection_id) {
                connection.start_heartbeat(self.options.userid.clone());
                info!("Started heartbeat on elected connection {}", connection_id);
            }
            self.close_unused_connections();
            self.report_state();
        }
        Err(e) => {
            error!("Election failed: {e}");
            self.election_state = ElectionState::Failed {
                reason: e.to_string(),
                failed_at: js_sys::Date::now(),
            };
            self.report_state();
        }
    }
}
/// Picks the election winner among connections that are connected and have RTT samples.
///
/// WebTransport is preferred: the lowest-RTT WebTransport candidate wins whenever
/// one exists; otherwise the lowest-RTT WebSocket candidate is returned.
///
/// A measurement whose connection is no longer held by the manager is treated as
/// not connected (matching `report_diagnostics`), so a stale record can never be
/// elected.
///
/// # Errors
/// Fails when no candidate qualifies.
fn find_best_connection(&self) -> Result<(String, ServerRttMeasurement)> {
    let mut best_wt: Option<(&String, &ServerRttMeasurement)> = None;
    let mut best_wt_rtt = f64::INFINITY;
    let mut best_ws: Option<(&String, &ServerRttMeasurement)> = None;
    let mut best_ws_rtt = f64::INFINITY;

    for (connection_id, measurement) in &self.rtt_measurements {
        let connected = self
            .connections
            .get(connection_id)
            .map(|conn| conn.is_connected())
            .unwrap_or(false);
        if !connected {
            continue;
        }

        let avg_rtt = match measurement.average_rtt {
            Some(avg) if !measurement.measurements.is_empty() => avg,
            _ => continue,
        };

        let (best, best_rtt) = if measurement.is_webtransport {
            (&mut best_wt, &mut best_wt_rtt)
        } else {
            (&mut best_ws, &mut best_ws_rtt)
        };
        // Strict `<` against an infinite start value also rejects NaN and infinite averages.
        if avg_rtt < *best_rtt {
            *best_rtt = avg_rtt;
            *best = Some((connection_id, measurement));
        }
    }

    // Only the winner is cloned.
    best_wt
        .or(best_ws)
        .map(|(connection_id, measurement)| (connection_id.clone(), measurement.clone()))
        .ok_or_else(|| anyhow!("No valid connections with RTT measurements found"))
}
/// Removes every connection except the active one from the manager.
fn close_unused_connections(&mut self) {
    let active_guard = self.active_connection_id.borrow();
    let keep = active_guard.as_deref();

    let stale: Vec<String> = self
        .connections
        .keys()
        .filter(|candidate| Some(candidate.as_str()) != keep)
        .cloned()
        .collect();

    for connection_id in stale {
        self.connections.remove(&connection_id);
        info!("Closed unused connection: {connection_id}");
    }
}
/// Placeholder hook: only logs. Reports are produced when a caller invokes
/// `trigger_diagnostics_report`.
fn start_diagnostics_reporting(&mut self) {
debug!("Diagnostics reporting initialized - will be triggered externally");
}
/// Drains the RTT echo queue filled by the inbound callbacks and records each sample.
///
/// When the queue is currently borrowed elsewhere nothing is processed; the
/// entries stay queued for the next call.
fn process_queued_rtt_responses(&mut self) {
    let drained: Vec<(String, MediaPacket, f64)> = self
        .rtt_responses
        .try_borrow_mut()
        .map(|mut queue| queue.drain(..).collect())
        .unwrap_or_default();

    for (connection_id, media_packet, reception_time) in &drained {
        self.handle_rtt_response(connection_id, media_packet, *reception_time);
    }
}
/// Processes queued RTT echoes, then broadcasts a diagnostics report.
pub fn trigger_diagnostics_report(&mut self) {
debug!(
"ConnectionManager::trigger_diagnostics_report called - state: {:?}",
self.election_state
);
// Fold pending samples in first so the report reflects the latest averages.
self.process_queued_rtt_responses();
self.report_diagnostics();
}
/// Broadcasts diagnostics on the global diagnostics channel.
///
/// One event without a `stream_id` describes the election state. It is followed
/// by one event per RTT record (with `stream_id` set to the connection id)
/// describing that server's status. Broadcast failures are logged, not returned.
fn report_diagnostics(&self) {
    debug!(
        "ConnectionManager::report_diagnostics - Active: {:?}, Election State: {:?}",
        self.active_connection_id.borrow(),
        self.election_state
    );

    let transport_label = |is_webtransport: bool| -> &'static str {
        if is_webtransport {
            "webtransport"
        } else {
            "websocket"
        }
    };

    let metrics = match &self.election_state {
        ElectionState::Testing {
            start_time,
            duration_ms,
            ..
        } => {
            let elapsed = js_sys::Date::now() - start_time;
            let progress = (elapsed / *duration_ms as f64).min(1.0) as f32;
            vec![
                metric!("election_state", "testing"),
                metric!("election_progress", progress as f64),
                metric!("servers_total", self.connections.len() as u64),
            ]
        }
        ElectionState::Elected {
            connection_id,
            elected_at,
        } => {
            let mut elected = vec![
                metric!("election_state", "elected"),
                metric!("active_connection_id", connection_id.as_str()),
                metric!("elected_at", *elected_at),
            ];
            // Server details are only reported once an average RTT exists.
            let winner = self
                .rtt_measurements
                .get(connection_id)
                .and_then(|record| record.average_rtt.map(|avg| (record, avg)));
            if let Some((measurement, avg_rtt)) = winner {
                elected.push(metric!("active_server_rtt", avg_rtt));
                elected.push(metric!("active_server_url", measurement.url.as_str()));
                elected.push(metric!(
                    "active_server_type",
                    transport_label(measurement.is_webtransport)
                ));
            }
            elected
        }
        ElectionState::Failed { reason, failed_at } => vec![
            metric!("election_state", "failed"),
            metric!("failure_reason", reason.as_str()),
            metric!("failed_at", *failed_at),
        ],
    };

    debug!(
        "ConnectionManager: Prepared {} metrics for main event: {:?}",
        metrics.len(),
        metrics
    );

    if metrics.is_empty() {
        warn!("ConnectionManager: No metrics to send for main connection manager event - this might be why UI shows 'unknown'");
    } else {
        let event = DiagEvent {
            subsystem: "connection_manager",
            stream_id: None,
            ts_ms: now_ms(),
            metrics,
        };
        debug!(
            "ConnectionManager: Sending main connection manager diagnostics event: {event:?}"
        );
        match global_sender().try_broadcast(event) {
            Ok(_) => {
                debug!("ConnectionManager: Successfully sent main connection manager diagnostics event");
            }
            Err(e) => {
                error!(
                    "ConnectionManager: Failed to send main connection manager diagnostics: {e}"
                );
            }
        }
    }

    for (connection_id, measurement) in &self.rtt_measurements {
        // A record whose connection is gone counts as not connected.
        let connected = match self.connections.get(connection_id) {
            Some(connection) => connection.is_connected(),
            None => false,
        };
        let has_rtt = measurement.average_rtt.is_some();
        let status = match (measurement.active, connected, has_rtt) {
            (true, _, _) => "active",
            (false, true, true) => "testing",
            (false, true, false) => "connected",
            (false, false, _) => "connecting",
        };

        let mut final_metrics = vec![
            metric!("server_url", measurement.url.as_str()),
            metric!("server_type", transport_label(measurement.is_webtransport)),
            metric!("server_status", status),
            metric!("server_active", measurement.active as u64),
            metric!("server_connected", connected as u64),
            metric!("measurement_count", measurement.measurements.len() as u64),
        ];
        if let Some(avg_rtt) = measurement.average_rtt {
            final_metrics.push(metric!("server_rtt", avg_rtt));
        }

        let event = DiagEvent {
            subsystem: "connection_manager",
            stream_id: Some(measurement.connection_id.clone()),
            ts_ms: now_ms(),
            metrics: final_metrics,
        };
        match global_sender().try_broadcast(event) {
            Ok(_) => {
                debug!(
                    "ConnectionManager: Successfully sent server diagnostics for {}",
                    measurement.connection_id
                );
            }
            Err(e) => {
                error!(
                    "ConnectionManager: Failed to send server diagnostics for {}: {}",
                    measurement.connection_id, e
                );
            }
        }
    }
}
fn report_state(&self) {
let state = match &self.election_state {
ElectionState::Testing {
start_time,
duration_ms,
..
} => {
let elapsed = js_sys::Date::now() - start_time;
let progress = (elapsed / *duration_ms as f64).min(1.0) as f32;
ConnectionState::Testing {
progress,
servers_tested: self.connections.len(),
total_servers: self.options.websocket_urls.len()
+ self.options.webtransport_urls.len(),
}
}
ElectionState::Elected { connection_id, .. } => {
if let Some(measurement) = self.rtt_measurements.get(connection_id) {
ConnectionState::Connected {
server_url: measurement.url.clone(),
rtt: measurement.average_rtt.unwrap_or(0.0),
is_webtransport: measurement.is_webtransport,
}
} else {
ConnectionState::Failed {
error: "Elected connection not found in measurements".to_string(),
last_known_server: None,
}
}
}
ElectionState::Failed { reason, .. } => ConnectionState::Failed {
error: reason.clone(),
last_known_server: self
.active_connection_id
.borrow()
.as_deref()
.and_then(|id| self.rtt_measurements.get(id))
.map(|m| m.url.clone()),
},
};
self.options.on_state_changed.emit(state);
}
/// Sends `packet` over the active connection.
///
/// # Errors
/// Fails when no connection is active or the active id has no live connection.
pub fn send_packet(&self, packet: PacketWrapper) -> Result<()> {
    let active = self.active_connection_id.borrow();
    let target = active.as_deref().and_then(|id| self.connections.get(id));
    match target {
        Some(connection) => {
            connection.send_packet(packet);
            Ok(())
        }
        None => Err(anyhow!("No active connection available")),
    }
}
/// Forwards the video-enabled flag to the active connection.
///
/// # Errors
/// Fails when no connection is active or the active id has no live connection.
pub fn set_video_enabled(&self, enabled: bool) -> Result<()> {
    let active = self.active_connection_id.borrow();
    let target = active.as_deref().and_then(|id| self.connections.get(id));
    match target {
        Some(connection) => {
            connection.set_video_enabled(enabled);
            Ok(())
        }
        None => Err(anyhow!("No active connection available")),
    }
}
/// Forwards the audio-enabled flag to the active connection.
///
/// # Errors
/// Fails when no connection is active or the active id has no live connection.
pub fn set_audio_enabled(&self, enabled: bool) -> Result<()> {
    let active = self.active_connection_id.borrow();
    let target = active.as_deref().and_then(|id| self.connections.get(id));
    match target {
        Some(connection) => {
            connection.set_audio_enabled(enabled);
            Ok(())
        }
        None => Err(anyhow!("No active connection available")),
    }
}
/// Forwards the screen-share-enabled flag to the active connection.
///
/// # Errors
/// Fails when no connection is active or the active id has no live connection.
pub fn set_screen_enabled(&self, enabled: bool) -> Result<()> {
    let active = self.active_connection_id.borrow();
    let target = active.as_deref().and_then(|id| self.connections.get(id));
    match target {
        Some(connection) => {
            connection.set_screen_enabled(enabled);
            Ok(())
        }
        None => Err(anyhow!("No active connection available")),
    }
}
/// Records `session_id` as this client's own session id and, when a connection
/// is active, passes it on to that connection.
pub fn set_own_session_id(&self, session_id: u64) {
    self.own_session_id.borrow_mut().replace(session_id);

    let active = self.active_connection_id.borrow();
    let target = active.as_deref().and_then(|id| self.connections.get(id));
    if let Some(connection) = target {
        connection.set_session_id(session_id);
    }

    debug!("Set own_session_id to {session_id}");
}
/// Returns `true` when a connection has been elected and is still recorded as active.
pub fn is_connected(&self) -> bool {
    let has_active = self.active_connection_id.borrow().is_some();
    match self.election_state {
        ElectionState::Elected { .. } => has_active,
        ElectionState::Testing { .. } | ElectionState::Failed { .. } => false,
    }
}
pub fn disconnect(&mut self) -> anyhow::Result<()> {
self.connections.clear();
self.get_connection_state();
Ok(())
}
/// Returns the RTT records of all candidate servers, keyed by connection id.
pub fn get_rtt_measurements(&self) -> &HashMap<String, ServerRttMeasurement> {
&self.rtt_measurements
}
/// Sends one RTT probe over every connection the manager currently holds.
///
/// Per-connection failures are logged and do not abort the loop, so this
/// currently always returns `Ok(())`.
pub fn send_rtt_probes(&mut self) -> Result<()> {
    // Snapshot the ids: probing needs `&mut self` while the map would be borrowed.
    let targets: Vec<String> = self.connections.keys().cloned().collect();
    for connection_id in &targets {
        match self.send_rtt_probe(connection_id) {
            Ok(()) => {}
            Err(e) => debug!("Failed to send RTT probe to {connection_id}: {e}"),
        }
    }
    Ok(())
}
/// Completes the election once the election period has elapsed.
///
/// Does nothing unless the manager is still in the `Testing` phase.
pub fn check_and_complete_election(&mut self) {
    let period_over = match &self.election_state {
        ElectionState::Testing {
            start_time,
            duration_ms,
            ..
        } => js_sys::Date::now() - start_time >= *duration_ms as f64,
        ElectionState::Elected { .. } | ElectionState::Failed { .. } => false,
    };
    if period_over {
        self.complete_election();
    }
}
/// Maps the internal election phase to the public `ConnectionState`.
///
/// * `Testing` → `ConnectionState::Testing` with the elapsed share of the
///   election period (capped at `1.0`).
/// * `Elected` → `ConnectionState::Connected` for the elected server, or
///   `Failed` when its RTT record is missing.
/// * `Failed` → `ConnectionState::Failed` with the URL of the still-recorded
///   active connection, when there is one.
pub fn get_connection_state(&self) -> ConnectionState {
    match &self.election_state {
        ElectionState::Testing {
            start_time,
            duration_ms,
            ..
        } => {
            let elapsed = js_sys::Date::now() - start_time;
            let fraction = (elapsed / *duration_ms as f64).min(1.0);
            let configured =
                self.options.websocket_urls.len() + self.options.webtransport_urls.len();
            ConnectionState::Testing {
                progress: fraction as f32,
                servers_tested: self.connections.len(),
                total_servers: configured,
            }
        }
        ElectionState::Elected { connection_id, .. } => {
            match self.rtt_measurements.get(connection_id) {
                Some(record) => ConnectionState::Connected {
                    server_url: record.url.clone(),
                    rtt: record.average_rtt.unwrap_or(0.0),
                    is_webtransport: record.is_webtransport,
                },
                None => ConnectionState::Failed {
                    error: "Elected connection not found in measurements".to_string(),
                    last_known_server: None,
                },
            }
        }
        ElectionState::Failed { reason, .. } => {
            let active = self.active_connection_id.borrow();
            let last_known_server = active
                .as_deref()
                .and_then(|id| self.rtt_measurements.get(id))
                .map(|record| record.url.clone());
            ConnectionState::Failed {
                error: reason.clone(),
                last_known_server,
            }
        }
    }
}
}
impl Drop for ConnectionManager {
fn drop(&mut self) {
if let Some(reporter) = self.rtt_reporter.take() {
reporter.cancel();
}
if let Some(probe_timer) = self.rtt_probe_timer.take() {
probe_timer.cancel();
}
if let Some(election_timer) = self.election_timer.take() {
election_timer.cancel();
}
if let ElectionState::Testing { probe_timer, .. } = &mut self.election_state {
if let Some(timer) = probe_timer.take() {
timer.cancel();
}
}
}
}