use super::Dest; use super::error::{NetworkError, NetworkResult};
use super::lane::DataLane;
use super::peer_transport::WireBuilder;
use super::wire_handle::WireHandle;
use super::wire_pool::ConnType;
use crate::lifecycle::CredentialState;
use crate::outbound::PendingRequestsMap;
use crate::wire::webrtc::WebRtcCoordinator;
use crate::wire::websocket::WebSocketConnection;
use actr_protocol::prost::Message as ProstMessage;
use actr_protocol::{ActrError, ActrId, PayloadType, RpcEnvelope};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
/// Configuration for [`DefaultWireBuilder`].
///
/// The `Default` value enables both transports, uses an empty local id and an empty
/// discovery map, and configures neither credentials nor response routing.
pub struct DefaultWireBuilderConfig {
    /// Local identity string forwarded to every WebSocket connection via `with_local_id`.
    pub local_id_hex: String,
    /// Attempt WebRTC connections (also requires a coordinator passed to `DefaultWireBuilder::new`).
    pub enable_webrtc: bool,
    /// Attempt WebSocket connections for destinations that have a discovered URL.
    pub enable_websocket: bool,
    /// Shared actor-id -> WebSocket URL map; it is read on every connection-creation call,
    /// so entries inserted later are picked up without rebuilding the builder.
    pub discovered_ws_addresses: Arc<RwLock<HashMap<ActrId, String>>>,
    /// When set, the current credential is prost-encoded, base64-encoded (standard alphabet)
    /// and attached to each new WebSocket connection.
    pub credential_state: Option<CredentialState>,
    /// When set, WebSocket connections are wrapped so incoming RPC responses are delivered
    /// to the callers waiting in this map.
    pub pending_requests: Option<PendingRequestsMap>,
}

impl Default for DefaultWireBuilderConfig {
    fn default() -> Self {
        // Start with nothing discovered; callers populate the shared map over time.
        let discovered_ws_addresses = Arc::new(RwLock::new(HashMap::new()));
        Self {
            enable_webrtc: true,
            enable_websocket: true,
            local_id_hex: String::new(),
            discovered_ws_addresses,
            credential_state: None,
            pending_requests: None,
        }
    }
}
/// Default [`WireBuilder`] that produces WebSocket and/or WebRTC wires for a destination,
/// according to the `enable_*` flags of its [`DefaultWireBuilderConfig`].
pub struct DefaultWireBuilder {
    /// Coordinator used to negotiate WebRTC connections; without it no WebRTC wire is built.
    webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
    /// Copied from `config.local_id_hex`.
    local_id_hex: String,
    /// Shared handle to `config.discovered_ws_addresses` (same underlying map).
    discovered_ws_addresses: Arc<RwLock<HashMap<ActrId, String>>>,
    /// Copied from `config.credential_state`.
    credential_state: Option<CredentialState>,
    /// Copied from `config.pending_requests`.
    pending_requests: Option<PendingRequestsMap>,
    /// The full configuration; the `enable_*` flags are read from here.
    config: DefaultWireBuilderConfig,
}

impl DefaultWireBuilder {
    /// Creates a builder from an optional WebRTC coordinator and a configuration.
    ///
    /// The frequently used config fields are cloned out for direct access; the
    /// configuration itself is kept for the transport enable flags.
    pub fn new(
        webrtc_coordinator: Option<Arc<WebRtcCoordinator>>,
        config: DefaultWireBuilderConfig,
    ) -> Self {
        Self {
            webrtc_coordinator,
            local_id_hex: config.local_id_hex.clone(),
            discovered_ws_addresses: config.discovered_ws_addresses.clone(),
            credential_state: config.credential_state.clone(),
            pending_requests: config.pending_requests.clone(),
            config,
        }
    }

    /// Returns the discovered WebSocket URL for `dest`, if any.
    ///
    /// Only `Dest::Actor` destinations can have a discovered address; every other
    /// destination, and actors without a recorded URL, yield `None`.
    async fn resolve_websocket_url(&self, dest: &Dest) -> Option<String> {
        let actor_id = match dest {
            Dest::Actor(actor_id) => actor_id,
            _ => return None,
        };
        // The read guard is a temporary, so the lock is released before logging.
        let url = self
            .discovered_ws_addresses
            .read()
            .await
            .get(actor_id)
            .cloned()?;
        tracing::debug!(
            "[Factory] Using discovered WebSocket URL for {}: {}",
            actor_id,
            url
        );
        Some(url)
    }
}
#[derive(Debug)]
struct ClientWebSocketHandle {
inner: WebSocketConnection,
pending_requests: PendingRequestsMap,
}
impl ClientWebSocketHandle {
fn new(inner: WebSocketConnection, pending_requests: PendingRequestsMap) -> Self {
Self {
inner,
pending_requests,
}
}
fn spawn_response_readers(&self) {
for pt in [PayloadType::RpcReliable, PayloadType::RpcSignal] {
let conn = self.inner.clone();
let pending = self.pending_requests.clone();
tokio::spawn(async move {
let lane = match conn.get_lane(pt).await {
Ok(l) => l,
Err(e) => {
tracing::error!("ClientWebSocketHandle: get_lane({pt:?}) failed: {e:?}");
return;
}
};
tracing::debug!("ClientWebSocketHandle: response reader started for {pt:?}");
loop {
let data = match lane.recv().await {
Ok(d) => d,
Err(e) => {
tracing::info!("ClientWebSocketHandle: lane {pt:?} closed: {e:?}");
break;
}
};
match RpcEnvelope::decode(&data[..]) {
Ok(envelope) => {
let request_id = &envelope.request_id;
let mut guard = pending.write().await;
if let Some((_target, tx)) = guard.remove(request_id.as_str()) {
drop(guard);
let result: actr_protocol::ActorResult<actr_framework::Bytes> =
match (envelope.payload, envelope.error) {
(Some(payload), None) => Ok(payload),
(None, Some(err)) => {
Err(crate::lifecycle::node::wire_code_to_actr_error(
err.code,
err.message,
))
}
_ => Err(ActrError::DecodeFailure(
"invalid RpcEnvelope: payload/error inconsistent"
.to_string(),
)),
};
let _ = tx.send(result);
} else {
drop(guard);
tracing::debug!(
request_id = %request_id,
"ClientWebSocketHandle: no pending request for incoming envelope, dropping"
);
}
}
Err(e) => {
tracing::error!(
"ClientWebSocketHandle: RpcEnvelope decode failed: {e:?}"
);
}
}
}
tracing::debug!("ClientWebSocketHandle: response reader exited for {pt:?}");
});
}
}
}
/// Pure delegation to the inner WebSocket connection, except that a successful
/// `connect` also starts the RPC response readers.
#[async_trait]
impl WireHandle for ClientWebSocketHandle {
    async fn connect(&self) -> NetworkResult<()> {
        self.inner.connect().await?;
        // NOTE(review): this runs on every successful connect(). If the same handle is
        // reconnected, a fresh pair of reader tasks is spawned. Confirm that the previous
        // readers always exit (lane closed) before this happens.
        self.spawn_response_readers();
        Ok(())
    }

    async fn close(&self) -> NetworkResult<()> {
        self.inner.close().await
    }

    fn is_connected(&self) -> bool {
        self.inner.is_connected()
    }

    async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<Arc<dyn DataLane>> {
        self.inner.get_lane(payload_type).await
    }

    fn connection_type(&self) -> ConnType {
        ConnType::WebSocket
    }

    fn priority(&self) -> u8 {
        self.inner.priority()
    }
}
impl DefaultWireBuilder {
    /// Builds an unconnected WebSocket wire for `url`.
    ///
    /// The local id is attached to the connection. When a credential source is configured,
    /// the current credential is also attached: prost-encoded, then base64-encoded with the
    /// standard alphabet.
    ///
    /// When a pending-requests map is configured, the connection is wrapped in
    /// [`ClientWebSocketHandle`] so that RPC responses arriving on it reach their callers.
    async fn build_websocket_handle(&self, url: String) -> Arc<dyn WireHandle> {
        let mut ws_conn = WebSocketConnection::new(url).with_local_id(self.local_id_hex.clone());

        if let Some(cred_state) = &self.credential_state {
            use base64::Engine as _;
            let credential = cred_state.credential().await;
            let cred_b64 =
                base64::engine::general_purpose::STANDARD.encode(credential.encode_to_vec());
            ws_conn = ws_conn.with_credential_b64(cred_b64);
        }

        match &self.pending_requests {
            Some(pending) => {
                Arc::new(ClientWebSocketHandle::new(ws_conn, pending.clone())) as Arc<dyn WireHandle>
            }
            None => Arc::new(ws_conn) as Arc<dyn WireHandle>,
        }
    }
}

#[async_trait]
impl WireBuilder for DefaultWireBuilder {
    /// Same as [`create_connections_with_cancel`](WireBuilder::create_connections_with_cancel)
    /// without a cancellation token.
    #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
    async fn create_connections(&self, dest: &Dest) -> NetworkResult<Vec<Arc<dyn WireHandle>>> {
        self.create_connections_with_cancel(dest, None).await
    }

    /// Builds every enabled wire for `dest`: WebSocket first, then WebRTC.
    ///
    /// - **WebSocket**: added only when a URL has been discovered for `dest`. The returned
    ///   wire is not connected yet.
    /// - **WebRTC**: negotiated through the coordinator, and only for actor destinations.
    ///   A negotiation failure is logged and skipped rather than returned.
    ///
    /// The result may therefore be empty.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConnectionClosed` if `cancel_token` is cancelled at any of
    /// these points:
    /// - on entry;
    /// - after the WebSocket step;
    /// - right after a WebRTC connection is established. That connection is closed before
    ///   returning.
    #[cfg_attr(feature = "opentelemetry", tracing::instrument(skip_all))]
    async fn create_connections_with_cancel(
        &self,
        dest: &Dest,
        cancel_token: Option<CancellationToken>,
    ) -> NetworkResult<Vec<Arc<dyn WireHandle>>> {
        // Only meaningful after an `.await`; back-to-back checks without one add nothing.
        let check_cancelled = || -> NetworkResult<()> {
            if matches!(&cancel_token, Some(token) if token.is_cancelled()) {
                Err(NetworkError::ConnectionClosed(
                    "Connection creation cancelled".to_string(),
                ))
            } else {
                Ok(())
            }
        };

        let mut connections: Vec<Arc<dyn WireHandle>> = Vec::new();
        check_cancelled()?;

        if self.config.enable_websocket {
            match self.resolve_websocket_url(dest).await {
                Some(url) => {
                    tracing::debug!("[Factory] Create WebSocket Connect: {}", url);
                    connections.push(self.build_websocket_handle(url).await);
                }
                None => {
                    tracing::debug!(
                        "[Factory] No WebSocket URL available for {:?}, skipping WS connection",
                        dest
                    );
                }
            }
        }

        // The WebSocket step may have awaited (URL lookup, credential fetch).
        check_cancelled()?;

        if self.config.enable_webrtc {
            match &self.webrtc_coordinator {
                None => {
                    tracing::warn!(
                        "[Factory] WebRTC is enabled but no WebRtcCoordinator was provided"
                    );
                }
                Some(_) if !dest.is_actor() => {
                    tracing::debug!(
                        "[Factory] WebRTC does not support this destination type, skipping"
                    );
                }
                Some(coordinator) => {
                    tracing::debug!("[Factory] Creating WebRTC connection to: {:?}", dest);
                    match coordinator
                        .create_connection(dest, cancel_token.clone())
                        .await
                    {
                        Ok(webrtc_conn) => {
                            // Cancellation may have fired during negotiation. Do not hand
                            // back (or leak) a live connection in that case.
                            if let Err(e) = check_cancelled() {
                                if let Err(close_err) = webrtc_conn.close().await {
                                    tracing::warn!(
                                        "[Factory] Failed to close cancelled connection: {}",
                                        close_err
                                    );
                                }
                                return Err(e);
                            }
                            connections.push(Arc::new(webrtc_conn) as Arc<dyn WireHandle>);
                        }
                        Err(e) => {
                            tracing::warn!(
                                "[Factory] WebRTC connection creation failed: {:?}: {}",
                                dest,
                                e
                            );
                        }
                    }
                }
            }
        }

        tracing::info!(
            "[Factory] Finished creating {} connections for {:?}",
            connections.len(),
            dest,
        );
        Ok(connections)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::ConnType;
    use actr_protocol::ActrId;

    /// WebSocket-only builder (WebRTC disabled, no coordinator) over the given discovery map.
    fn ws_only_factory(addresses: Arc<RwLock<HashMap<ActrId, String>>>) -> DefaultWireBuilder {
        DefaultWireBuilder::new(
            None,
            DefaultWireBuilderConfig {
                enable_websocket: true,
                enable_webrtc: false,
                local_id_hex: "deadbeef".to_string(),
                discovered_ws_addresses: addresses,
                credential_state: None,
                pending_requests: None,
            },
        )
    }

    /// Without a discovered URL, no WebSocket wire is produced.
    #[tokio::test]
    async fn test_no_ws_connection_without_discovery() {
        let factory = ws_only_factory(Arc::new(RwLock::new(HashMap::new())));
        let connections = factory
            .create_connections(&Dest::actor(ActrId::default()))
            .await
            .unwrap();
        assert!(connections.is_empty());
    }

    /// A discovered URL yields exactly one WebSocket wire.
    #[tokio::test]
    async fn test_ws_connection_from_discovery() {
        let actor_id = ActrId::default();
        let addresses = Arc::new(RwLock::new(HashMap::from([(
            actor_id.clone(),
            "ws://localhost:9001".to_string(),
        )])));
        let factory = ws_only_factory(addresses);

        let connections = factory
            .create_connections(&Dest::actor(actor_id))
            .await
            .unwrap();

        assert_eq!(connections.len(), 1);
        assert_eq!(connections[0].connection_type(), ConnType::WebSocket);
    }
}