use super::Dest; use super::error::{NetworkError, NetworkResult};
use super::peer_transport::WireBuilder;
use super::wire_handle::WireHandle;
use crate::lifecycle::CredentialState;
use crate::wire::webrtc::WebRtcCoordinator;
use crate::wire::websocket::WebSocketConnection;
use actr_protocol::ActrId;
use actr_protocol::prost::Message as ProstMessage;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
/// Configuration consumed by [`DefaultWireBuilder`].
///
/// The two `enable_*` flags select which transports the builder attempts; the
/// remaining fields are only used when creating WebSocket connections.
pub struct DefaultWireBuilderConfig {
/// Local identifier (hex string) attached to each WebSocket connection via
/// `WebSocketConnection::with_local_id`.
pub local_id_hex: String,
/// When `true`, the builder attempts a WebRTC connection, provided a
/// `WebRtcCoordinator` was supplied and the destination is an actor.
pub enable_webrtc: bool,
/// When `true`, the builder creates a WebSocket connection for destinations
/// that have an entry in `discovered_ws_addresses`.
pub enable_websocket: bool,
/// Shared map from actor id to WebSocket URL, read on every connection
/// attempt. NOTE(review): nothing in this file writes to it outside tests —
/// presumably a discovery component fills it in; confirm against the owner.
pub discovered_ws_addresses: Arc<RwLock<HashMap<ActrId, String>>>,
/// Optional credential source. When present, the current credential is
/// protobuf-encoded, base64-encoded, and attached to WebSocket connections.
pub credential_state: Option<CredentialState>,
}
impl Default for DefaultWireBuilderConfig {
    /// Returns a configuration with both transports enabled, an empty local
    /// id, an empty discovered-address map, and no credential source.
    fn default() -> Self {
        // Start from an empty address map; it is shared behind `Arc<RwLock<_>>`
        // so other holders of the same `Arc` can modify it later.
        let empty_addresses = HashMap::new();
        let discovered_ws_addresses = Arc::new(RwLock::new(empty_addresses));

        Self {
            enable_websocket: true,
            enable_webrtc: true,
            local_id_hex: String::new(),
            credential_state: None,
            discovered_ws_addresses,
        }
    }
}
/// Default [`WireBuilder`] implementation: produces WebSocket and/or WebRTC
/// wire handles for a destination, according to its configuration.
pub struct DefaultWireBuilder {
/// Coordinator used to create WebRTC connections. When `None`, no WebRTC
/// connection is created even if WebRTC is enabled in `config`.
webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
/// Copy of `config.local_id_hex`, taken in `new`.
local_id_hex: String,
/// Clone of the `Arc` in `config.discovered_ws_addresses`, taken in `new`;
/// both handles point at the same underlying map.
discovered_ws_addresses: Arc<RwLock<HashMap<ActrId, String>>>,
/// Clone of `config.credential_state`, taken in `new`.
credential_state: Option<CredentialState>,
/// The configuration as passed to `new`; the `enable_webrtc` and
/// `enable_websocket` flags are read from here.
config: DefaultWireBuilderConfig,
}
impl DefaultWireBuilder {
    /// Builds a wire builder from an optional WebRTC coordinator and a
    /// configuration.
    ///
    /// The local id, discovered-address map handle, and credential source are
    /// cloned out of `config` into dedicated fields; `config` itself is kept
    /// as well.
    pub fn new(
        webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
        config: DefaultWireBuilderConfig,
    ) -> Self {
        let local_id_hex = config.local_id_hex.clone();
        let discovered_ws_addresses = config.discovered_ws_addresses.clone();
        let credential_state = config.credential_state.clone();

        Self {
            webrtc_coordinator,
            local_id_hex,
            discovered_ws_addresses,
            credential_state,
            config,
        }
    }

    /// Looks up the WebSocket URL recorded for `dest` in the discovered-address
    /// map.
    ///
    /// Returns `None` when `dest` is not an actor destination or when the map
    /// has no entry for that actor.
    async fn resolve_websocket_url(&self, dest: &Dest) -> Option<String> {
        // Only actor destinations are keyed in the discovered-address map.
        let actor_id = match dest {
            Dest::Actor(id) => id,
            _ => return None,
        };

        let discovered = self.discovered_ws_addresses.read().await;
        let url = discovered.get(actor_id)?;
        tracing::debug!(
            "đ [Factory] Using discovered WebSocket URL for {}: {}",
            actor_id,
            url
        );
        Some(url.clone())
    }
}
#[async_trait]
impl WireBuilder for DefaultWireBuilder {
    /// Creates every enabled connection to `dest` without a cancellation
    /// token. Equivalent to `create_connections_with_cancel(dest, None)`.
    #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
    async fn create_connections(&self, dest: &Dest) -> NetworkResult<Vec<Arc<dyn WireHandle>>> {
        self.create_connections_with_cancel(dest, None).await
    }

    /// Creates every enabled connection to `dest`, honouring an optional
    /// cancellation token.
    ///
    /// A WebSocket connection is added when WebSocket is enabled and a URL has
    /// been discovered for `dest`. A WebRTC connection is added when WebRTC is
    /// enabled, a coordinator is available, and `dest` is an actor. A WebRTC
    /// failure is logged and skipped, so the returned list may be empty.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConnectionClosed` if `cancel_token` is observed
    /// as cancelled at any checkpoint: on entry, after the WebSocket phase, or
    /// after the WebRTC attempt, whether that attempt succeeded or failed.
    #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
    async fn create_connections_with_cancel(
        &self,
        dest: &Dest,
        cancel_token: Option<CancellationToken>,
    ) -> NetworkResult<Vec<Arc<dyn WireHandle>>> {
        let mut connections: Vec<Arc<dyn WireHandle>> = Vec::new();

        // Cancellation checkpoint: fails once the token, if any, is cancelled.
        let ensure_active = || -> NetworkResult<()> {
            match &cancel_token {
                Some(token) if token.is_cancelled() => Err(NetworkError::ConnectionClosed(
                    "Connection creation cancelled".to_string(),
                )),
                _ => Ok(()),
            }
        };

        ensure_active()?;

        if self.config.enable_websocket {
            if let Some(url) = self.resolve_websocket_url(dest).await {
                tracing::debug!("đ [Factory] Create WebSocket Connect: {}", url);
                let mut ws_conn =
                    WebSocketConnection::new(url).with_local_id(self.local_id_hex.clone());
                if let Some(ref cred_state) = self.credential_state {
                    // Attach the current credential: protobuf bytes, base64-encoded.
                    let credential = cred_state.credential().await;
                    let cred_bytes = credential.encode_to_vec();
                    use base64::Engine as _;
                    let cred_b64 = base64::engine::general_purpose::STANDARD.encode(&cred_bytes);
                    ws_conn = ws_conn.with_credential_b64(cred_b64);
                }
                connections.push(Arc::new(ws_conn) as Arc<dyn WireHandle>);
            } else {
                tracing::debug!(
                    "đ [Factory] No WebSocket URL available for {:?}, skipping WS connection",
                    dest
                );
            }
        }

        // The WebSocket phase awaited (address lookup, credential fetch), so
        // re-check before starting the WebRTC attempt.
        ensure_active()?;

        if self.config.enable_webrtc {
            match &self.webrtc_coordinator {
                None => {
                    tracing::warn!(
                        "â ī¸ [Factory] WebRTC is enabled but no WebRtcCoordinator was provided"
                    );
                }
                Some(_) if !dest.is_actor() => {
                    tracing::debug!(
                        "âšī¸ [Factory] WebRTC does not support this destination type, skipping"
                    );
                }
                Some(coordinator) => {
                    tracing::debug!("đ [Factory] Creating WebRTC connection to: {:?}", dest);
                    match coordinator
                        .create_connection(dest, cancel_token.clone())
                        .await
                    {
                        Ok(webrtc_conn) => {
                            // Cancelled while connecting: release the fresh
                            // connection instead of handing it to the caller.
                            if let Err(cancelled) = ensure_active() {
                                if let Err(close_err) = webrtc_conn.close().await {
                                    tracing::warn!(
                                        "â ī¸ [Factory] Failed to close cancelled connection: {}",
                                        close_err
                                    );
                                }
                                return Err(cancelled);
                            }
                            connections.push(Arc::new(webrtc_conn) as Arc<dyn WireHandle>);
                        }
                        Err(e) => {
                            // A failure observed after cancellation must surface
                            // as cancellation, matching the `Ok` arm, rather than
                            // being downgraded to a best-effort WebRTC failure
                            // that still returns `Ok`.
                            ensure_active()?;
                            tracing::warn!(
                                "â [Factory] WebRTC connection creation failed: {:?}: {}",
                                dest,
                                e
                            );
                        }
                    }
                }
            }
        }

        tracing::info!(
            "⨠[Factory] Finished creating {} connections for {:?}",
            connections.len(),
            dest,
        );
        Ok(connections)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::ConnType;
    use actr_protocol::ActrId;

    /// Builds a WebSocket-only builder (no WebRTC coordinator, no credentials)
    /// whose discovered-address map starts with the given entries.
    fn ws_only_builder(discovered: HashMap<ActrId, String>) -> DefaultWireBuilder {
        let config = DefaultWireBuilderConfig {
            enable_websocket: true,
            enable_webrtc: false,
            local_id_hex: "deadbeef".to_string(),
            discovered_ws_addresses: Arc::new(RwLock::new(discovered)),
            credential_state: None,
        };
        DefaultWireBuilder::new(None, config)
    }

    /// With no discovered address for the actor, no connection is produced.
    #[tokio::test]
    async fn test_no_ws_connection_without_discovery() {
        let factory = ws_only_builder(HashMap::new());
        let dest = Dest::actor(ActrId::default());

        let connections = factory.create_connections(&dest).await.unwrap();

        assert!(connections.is_empty());
    }

    /// A discovered address yields exactly one WebSocket connection.
    #[tokio::test]
    async fn test_ws_connection_from_discovery() {
        let actor_id = ActrId::default();
        let discovered =
            HashMap::from([(actor_id.clone(), "ws://localhost:9001".to_string())]);
        let factory = ws_only_builder(discovered);
        let dest = Dest::actor(actor_id);

        let connections = factory.create_connections(&dest).await.unwrap();

        assert_eq!(connections.len(), 1);
        assert_eq!(connections[0].connection_type(), ConnType::WebSocket);
    }
}