use crate::error::ZmqError;
use crate::message::Msg;
use crate::runtime::{Command, MailboxSender};
use crate::socket::core::SocketCore;
use crate::socket::patterns::LoadBalancer;
use crate::socket::ISocket;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use parking_lot::RwLock;
use tokio::time::timeout as tokio_timeout;
use crate::{delegate_to_core, Blob, MsgFlags};
/// Socket implementing the PUSH side of the `ISocket` interface.
///
/// Outgoing messages are handed to one peer at a time; the peer is chosen by
/// the embedded [`LoadBalancer`]. Receiving is not supported.
#[derive(Debug)]
pub(crate) struct PushSocket {
    /// Shared socket core; holds options, endpoint table and the command mailbox.
    core: Arc<SocketCore>,
    /// Set of endpoint URIs that `send`/`send_multipart` may pick as a target.
    load_balancer: LoadBalancer,
    /// Pipe read id -> endpoint URI, recorded in `pipe_attached` so that
    /// `pipe_detached` knows which URI to drop from the load balancer.
    pipe_read_to_endpoint_uri: RwLock<HashMap<usize, String>>,
}
impl PushSocket {
    /// Builds a PUSH socket on top of `core`.
    ///
    /// The socket starts with an empty load balancer and an empty
    /// pipe-id-to-URI map; both are filled in as pipes get attached.
    pub fn new(core: Arc<SocketCore>) -> Self {
        let load_balancer = LoadBalancer::new();
        let pipe_read_to_endpoint_uri = RwLock::new(HashMap::new());
        Self {
            core,
            load_balancer,
            pipe_read_to_endpoint_uri,
        }
    }
}
#[async_trait]
impl ISocket for PushSocket {
/// Returns a reference to the shared `SocketCore` this socket was built on.
fn core(&self) -> &Arc<SocketCore> {
&self.core
}
/// Returns the command sender obtained from the core's `command_sender()`.
fn mailbox(&self) -> MailboxSender {
self.core.command_sender()
}
/// Requests a bind to `endpoint` by delegating a `UserBind` to the socket core.
///
/// An owned copy of the endpoint string is passed along with the request.
/// NOTE(review): presumably `delegate_to_core!` sends the command to the core
/// and awaits its reply — confirm against the macro definition.
async fn bind(&self, endpoint: &str) -> Result<(), ZmqError> {
    delegate_to_core!(self, UserBind, endpoint: endpoint.to_owned())
}
/// Requests a connection to `endpoint` by delegating a `UserConnect` to the socket core.
///
/// An owned copy of the endpoint string is passed along with the request.
/// NOTE(review): presumably `delegate_to_core!` sends the command to the core
/// and awaits its reply — confirm against the macro definition.
async fn connect(&self, endpoint: &str) -> Result<(), ZmqError> {
    delegate_to_core!(self, UserConnect, endpoint: endpoint.to_owned())
}
/// Requests a disconnect from `endpoint` by delegating a `UserDisconnect` to the socket core.
///
/// An owned copy of the endpoint string is passed along with the request.
/// NOTE(review): presumably `delegate_to_core!` sends the command to the core
/// and awaits its reply — confirm against the macro definition.
async fn disconnect(&self, endpoint: &str) -> Result<(), ZmqError> {
    delegate_to_core!(self, UserDisconnect, endpoint: endpoint.to_owned())
}
/// Requests an unbind of `endpoint` by delegating a `UserUnbind` to the socket core.
///
/// An owned copy of the endpoint string is passed along with the request.
/// NOTE(review): presumably `delegate_to_core!` sends the command to the core
/// and awaits its reply — confirm against the macro definition.
async fn unbind(&self, endpoint: &str) -> Result<(), ZmqError> {
    delegate_to_core!(self, UserUnbind, endpoint: endpoint.to_owned())
}
/// Requests socket shutdown by delegating a `UserClose` to the socket core.
///
/// NOTE(review): presumably `delegate_to_core!` sends the command to the core
/// and awaits its reply — confirm against the macro definition.
async fn close(&self) -> Result<(), ZmqError> {
delegate_to_core!(self, UserClose,)
}
/// Sends a single message to one peer selected by the load balancer.
///
/// When no peer is available the call waits according to the socket's
/// `sndtimeo` option: `None` waits until a peer shows up, a zero duration
/// fails immediately, and any other duration bounds the *total* time spent
/// waiting for a peer (the budget is not restarted on every wake-up).
///
/// # Errors
/// * `ZmqError::ResourceLimitReached` — the socket is not running, `sndtimeo`
///   is zero and no peer is available, or the connection reported it.
/// * `ZmqError::InvalidState` — the core's command mailbox is closed while no
///   peer is available.
/// * `ZmqError::Timeout` — `sndtimeo` elapsed while waiting for a peer, or the
///   connection reported a timeout.
/// * `ZmqError::HostUnreachable` — the chosen connection was closed during the
///   send; the peer is removed from the load balancer.
/// * Any other error from the connection is propagated after the peer has been
///   removed from the load balancer.
async fn send(&self, msg: Msg) -> Result<(), ZmqError> {
    // NOTE(review): a non-running socket is reported with the same error as
    // "would block" (`ResourceLimitReached`), whereas the wait loop below uses
    // `InvalidState` for a terminated socket — confirm this is intended.
    if !self.core.is_running().await {
        return Err(ZmqError::ResourceLimitReached);
    }

    let sndtimeo_opt: Option<Duration> = self.core.core_state.read().options.sndtimeo;
    // Start of the SNDTIMEO budget. Every wait below only gets what is left of
    // it, so repeated wake-ups without a usable peer cannot extend the total
    // blocking time beyond the configured timeout.
    let wait_started = std::time::Instant::now();

    let (endpoint_uri_to_send_to, connection_iface_arc) = loop {
        if let Some(uri) = self.load_balancer.get_next_connection_uri() {
            // Resolve the URI to its connection in ONE lookup under one read
            // lock. Checking for existence first and fetching later would leave
            // a window in which the endpoint can vanish and fail the send even
            // though other peers are available.
            let connection_iface = self
                .core
                .core_state
                .read()
                .endpoints
                .get(&uri)
                .map(|ep_info| ep_info.connection_iface.clone());
            match connection_iface {
                Some(iface) => break (uri, iface),
                None => {
                    tracing::warn!(handle = self.core.handle, uri = %uri, "PUSH send: Stale URI found in LoadBalancer. Removing.");
                    self.load_balancer.remove_connection(&uri);
                }
            }
        } else {
            // No peer at all: give up if the core is gone, otherwise wait.
            if self.core.command_sender().is_closed() {
                return Err(ZmqError::InvalidState("Socket terminated".into()));
            }
            match sndtimeo_opt {
                // Non-blocking mode: report "would block" right away.
                Some(duration) if duration.is_zero() => return Err(ZmqError::ResourceLimitReached),
                None => {
                    self.load_balancer.wait_for_connection().await;
                }
                Some(duration) => {
                    let remaining = duration.saturating_sub(wait_started.elapsed());
                    if remaining.is_zero() {
                        return Err(ZmqError::Timeout);
                    }
                    if tokio_timeout(remaining, self.load_balancer.wait_for_connection())
                        .await
                        .is_err()
                    {
                        return Err(ZmqError::Timeout);
                    }
                }
            }
        }
    };

    match connection_iface_arc.send_message(msg).await {
        Ok(()) => Ok(()),
        Err(ZmqError::ConnectionClosed) => {
            tracing::warn!(
                handle = self.core.handle,
                uri = %endpoint_uri_to_send_to,
                "PUSH send: Connection closed by peer or transport for URI. Removing from LB."
            );
            self.load_balancer.remove_connection(&endpoint_uri_to_send_to);
            Err(ZmqError::HostUnreachable("Peer connection closed during send".into()))
        }
        // Back-pressure on this connection: the peer stays in the load balancer.
        Err(e @ ZmqError::ResourceLimitReached) | Err(e @ ZmqError::Timeout) => {
            tracing::trace!(
                handle = self.core.handle,
                uri = %endpoint_uri_to_send_to,
                error = %e,
                "PUSH send: HWM reached or timeout on specific connection. Propagating error."
            );
            Err(e)
        }
        Err(e) => {
            tracing::error!(
                handle = self.core.handle,
                uri = %endpoint_uri_to_send_to,
                error = %e,
                "PUSH send: Unexpected error on connection_iface.send_message. Removing from LB."
            );
            self.load_balancer.remove_connection(&endpoint_uri_to_send_to);
            Err(e)
        }
    }
}
/// Always fails with `ZmqError::UnsupportedFeature`: this socket type has no receive path.
async fn recv(&self) -> Result<Msg, ZmqError> {
Err(ZmqError::UnsupportedFeature("PUSH sockets cannot receive messages"))
}
async fn send_multipart(&self, mut frames: Vec<Msg>) -> Result<(), ZmqError> {
if !self.core.is_running().await {
return Err(ZmqError::ResourceLimitReached);
}
if frames.is_empty() {
return Ok(()); }
let num_frames = frames.len();
for (i, frame) in frames.iter_mut().enumerate() {
if i < num_frames - 1 {
frame.set_flags(frame.flags() | MsgFlags::MORE);
} else {
frame.set_flags(frame.flags() & !MsgFlags::MORE);
}
}
let sndtimeo_opt: Option<Duration> = self.core.core_state.read().options.sndtimeo;
let endpoint_uri_to_send_to = loop {
if let Some(uri) = self.load_balancer.get_next_connection_uri() {
if self.core.core_state.read().endpoints.contains_key(&uri) {
break uri;
} else {
self.load_balancer.remove_connection(&uri);
}
} else {
if !self.core.is_running().await {
return Err(ZmqError::InvalidState(
"Socket terminated while waiting for peer".into(),
));
}
match sndtimeo_opt {
Some(duration) if duration.is_zero() => return Err(ZmqError::ResourceLimitReached),
None => {
self.load_balancer.wait_for_connection().await;
continue;
}
Some(duration) => match tokio_timeout(duration, self.load_balancer.wait_for_connection()).await {
Ok(()) => continue,
Err(_) => return Err(ZmqError::Timeout),
},
}
}
};
let connection_iface_arc = {
let core_s = self.core.core_state.read();
core_s
.endpoints
.get(&endpoint_uri_to_send_to)
.map(|ep_info| ep_info.connection_iface.clone())
.ok_or_else(|| {
self.load_balancer.remove_connection(&endpoint_uri_to_send_to);
ZmqError::HostUnreachable("PUSH: Peer for multipart send disappeared after selection".into())
})?
};
match connection_iface_arc.send_multipart(frames).await {
Ok(()) => Ok(()),
Err(ZmqError::ConnectionClosed) => {
self.load_balancer.remove_connection(&endpoint_uri_to_send_to);
Err(ZmqError::HostUnreachable(
"PUSH: Peer connection closed during multipart send".into(),
))
}
Err(e) => Err(e),
}
}
/// Always fails with `ZmqError::UnsupportedFeature`: this socket type has no receive path.
async fn recv_multipart(&self) -> Result<Vec<Msg>, ZmqError> {
Err(ZmqError::UnsupportedFeature("PUSH sockets cannot receive messages"))
}
/// Sets a socket option by delegating a `UserSetOpt` to the socket core.
///
/// `option` is the numeric option id; `value` is copied into an owned byte
/// vector before being handed over.
/// NOTE(review): presumably `delegate_to_core!` sends the command to the core
/// and awaits its reply — confirm against the macro definition.
async fn set_option(&self, option: i32, value: &[u8]) -> Result<(), ZmqError> {
    delegate_to_core!(self, UserSetOpt, option: option, value: Vec::from(value))
}
/// Reads a socket option by delegating a `UserGetOpt` for `option` to the socket core.
///
/// NOTE(review): presumably `delegate_to_core!` sends the command to the core
/// and returns the option bytes from its reply — confirm against the macro definition.
async fn get_option(&self, option: i32) -> Result<Vec<u8>, ZmqError> {
delegate_to_core!(self, UserGetOpt, option: option)
}
/// Always fails with `ZmqError::UnsupportedOption(option)`; the value is ignored.
async fn set_pattern_option(&self, option: i32, _value: &[u8]) -> Result<(), ZmqError> {
Err(ZmqError::UnsupportedOption(option))
}
/// Always fails with `ZmqError::UnsupportedOption(option)`.
async fn get_pattern_option(&self, option: i32) -> Result<Vec<u8>, ZmqError> {
Err(ZmqError::UnsupportedOption(option))
}
/// Ignores the command and always returns `Ok(false)`.
///
/// NOTE(review): `false` presumably tells the caller the command was not
/// handled by this socket — confirm against the `ISocket` trait contract.
async fn process_command(&self, _command: Command) -> Result<bool, ZmqError> {
Ok(false)
}
/// Logs a warning naming the received pipe event and returns `Ok(())`.
///
/// The event itself is not acted upon.
async fn handle_pipe_event(&self, pipe_id: usize, event: Command) -> Result<(), ZmqError> {
    let event_name = event.variant_name();
    tracing::warn!(
        handle = self.core.handle,
        pipe_id,
        "PUSH socket received unexpected pipe event: {:?}",
        event_name
    );
    Ok(())
}
/// Registers a newly attached pipe as a send target.
///
/// The endpoint URI for `pipe_read_id` is looked up in the core state's
/// `pipe_read_id_to_endpoint_uri` map. If found, the id-to-URI pair is stored
/// locally and the URI is added to the load balancer; if not, a warning is
/// logged and nothing is registered. The write id and peer identity are unused.
async fn pipe_attached(
    &self,
    pipe_read_id: usize,
    _pipe_write_id: usize,
    _peer_identity: Option<&[u8]>,
) {
    let resolved_uri = self
        .core
        .core_state
        .read()
        .pipe_read_id_to_endpoint_uri
        .get(&pipe_read_id)
        .cloned();

    // Without a URI there is nothing the load balancer could address.
    let endpoint_uri = match resolved_uri {
        Some(uri) => uri,
        None => {
            tracing::warn!(
                handle = self.core.handle,
                pipe_read_id,
                "PUSH pipe_attached: Could not find endpoint_uri for pipe_read_id. Cannot add to LoadBalancer."
            );
            return;
        }
    };

    tracing::debug!(
        handle = self.core.handle,
        pipe_read_id,
        uri = %endpoint_uri,
        "PUSH attaching connection"
    );
    // Remember the pairing so the detach path can find the URI again.
    self.pipe_read_to_endpoint_uri
        .write()
        .insert(pipe_read_id, endpoint_uri.clone());
    self.load_balancer.add_connection(endpoint_uri);
}
/// Emits a trace log with the pipe id and identity; the identity is not stored or used.
async fn update_peer_identity(&self, pipe_read_id: usize, identity: Option<Blob>) {
tracing::trace!(
handle = self.core.handle,
socket_type = "PUSH",
pipe_read_id,
?identity,
"update_peer_identity called, but PUSH socket does not use peer identities. Ignoring."
);
}
/// Unregisters a detached pipe.
///
/// Removes the local id-to-URI entry for `pipe_read_id`; when one existed, the
/// URI is also removed from the load balancer. A missing entry only produces a
/// warning.
async fn pipe_detached(&self, pipe_read_id: usize) {
    tracing::debug!(
        handle = self.core.handle,
        pipe_read_id,
        "PUSH detaching connection identified by pipe_read_id"
    );

    let removed_uri = self.pipe_read_to_endpoint_uri.write().remove(&pipe_read_id);
    match removed_uri {
        Some(endpoint_uri) => {
            self.load_balancer.remove_connection(&endpoint_uri);
            tracing::trace!(
                handle = self.core.handle,
                pipe_read_id,
                uri = %endpoint_uri,
                "PUSH removed detached connection from load balancer"
            );
        }
        None => {
            tracing::warn!(
                handle = self.core.handle,
                pipe_read_id,
                "PUSH detach: Endpoint URI not found for pipe_read_id. LoadBalancer may not be updated."
            );
        }
    }
}
}