use crate::{
ZmqResult, sealed,
socket::{MultipartReceiver, MultipartSender, Socket, SocketOption, SocketType},
};
/// Marker type that selects the ZMQ `STREAM` socket type for [`Socket`].
pub struct Stream {}

/// A [`Socket`] whose raw socket type is [`SocketType::Stream`].
pub type StreamSocket = Socket<Stream>;

impl sealed::SocketType for Stream {
    fn raw_socket_type() -> SocketType {
        SocketType::Stream
    }
}

// A stream socket is both a sender and a receiver, and exchanges multipart messages
// in both directions.
impl sealed::SenderFlag for Stream {}
impl sealed::ReceiverFlag for Stream {}

impl MultipartSender for Socket<Stream> {}
impl MultipartReceiver for Socket<Stream> {}

// SAFETY: NOTE(review): no justification for these impls is visible in this file.
// libzmq documents most socket types as not thread-safe, so `Sync` in particular
// presumably relies on `Socket` serialising access to the underlying handle — confirm.
unsafe impl Send for Socket<Stream> {}
unsafe impl Sync for Socket<Stream> {}
/// Socket options that are specific to [`StreamSocket`].
impl Socket<Stream> {
    /// Returns the current value of the `RoutingId` socket option.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while reading the option.
    pub fn routing_id(&self) -> ZmqResult<String> {
        self.get_sockopt_string(SocketOption::RoutingId)
    }

    /// Sets the `RoutingId` socket option to `value`.
    ///
    /// NOTE(review): libzmq is understood to reject an empty routing id with `EINVAL` —
    /// confirm against the linked libzmq version before relying on empty values.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while setting the option.
    pub fn set_routing_id<V>(&self, value: V) -> ZmqResult<()>
    where
        V: AsRef<str>,
    {
        let option = SocketOption::RoutingId;
        self.set_sockopt_string(option, value)
    }

    /// Sets the `ConnectRoutingId` socket option to `value`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while setting the option.
    pub fn set_connect_routing_id<V>(&self, value: V) -> ZmqResult<()>
    where
        V: AsRef<str>,
    {
        let option = SocketOption::ConnectRoutingId;
        self.set_sockopt_string(option, value)
    }

    /// Enables or disables the `StreamNotify` socket option.
    ///
    /// Only compiled when the `draft-api` feature is enabled.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while setting the option.
    #[cfg(feature = "draft-api")]
    pub fn set_stream_notify(&self, value: bool) -> ZmqResult<()> {
        let option = SocketOption::StreamNotify;
        self.set_sockopt_bool(option, value)
    }
}
#[cfg(test)]
mod stream_tests {
    use core::error::Error;
    use std::{
        io::{Read, Write},
        net::TcpStream,
    };

    use super::StreamSocket;
    use crate::prelude::{
        Context, MultipartReceiver, MultipartSender, RecvFlags, SendFlags, ZmqResult,
    };

    /// Plays the plain-TCP peer in the stream-server tests.
    ///
    /// Connects to `tcp_endpoint` (the `tcp://host:port` string returned by
    /// `last_endpoint`), writes `Hello`, and reads back exactly five bytes.
    ///
    /// `read_exact` is used rather than a single `read` so that EOF, a read error or a
    /// short reply is reported as an error and fails the calling test, instead of
    /// silently skipping its assertion.
    fn exchange_hello_over_tcp(tcp_endpoint: &str) -> Result<[u8; 5], Box<dyn Error>> {
        let address = tcp_endpoint
            .strip_prefix("tcp://")
            .ok_or("bound endpoint is not a tcp:// address")?;
        let mut tcp_stream = TcpStream::connect(address)?;
        tcp_stream.write_all(b"Hello")?;
        let mut reply = [0; 5];
        tcp_stream.read_exact(&mut reply)?;
        Ok(reply)
    }

    #[test]
    fn set_routing_sets_routing_id() -> ZmqResult<()> {
        let context = Context::new()?;
        let socket = StreamSocket::from_context(&context)?;
        socket.set_routing_id("asdf")?;
        assert_eq!(socket.routing_id()?, "asdf");
        Ok(())
    }

    #[test]
    fn set_connect_routing_sets_connect_routing_id() -> ZmqResult<()> {
        let context = Context::new()?;
        let socket = StreamSocket::from_context(&context)?;
        // Only checks that the option is accepted; no getter is exposed in this file.
        socket.set_connect_routing_id("asdf")?;
        Ok(())
    }

    #[cfg(feature = "draft-api")]
    #[test]
    fn set_stream_notify_sets_stream_notify() -> ZmqResult<()> {
        let context = Context::new()?;
        let socket = StreamSocket::from_context(&context)?;
        socket.set_stream_notify(true)?;
        Ok(())
    }

    #[test]
    fn stream_server() -> Result<(), Box<dyn Error>> {
        let context = Context::new()?;
        let socket = StreamSocket::from_context(&context)?;
        socket.bind("tcp://127.0.0.1:*")?;
        let tcp_endpoint = socket.last_endpoint()?;

        let server = std::thread::spawn(move || {
            // The first multipart is the connection notification for the new TCP peer
            // (routing id + empty frame, per the libzmq ZMQ_STREAM docs); it is unused.
            let _connect_notification = socket.recv_multipart(RecvFlags::empty()).unwrap();
            let mut multipart = socket.recv_multipart(RecvFlags::empty()).unwrap();
            let msg = multipart.pop_back().unwrap();
            assert_eq!(msg.to_string(), "Hello");
            // The remaining leading frame (the peer's routing id) addresses the reply.
            multipart.push_back("World".into());
            socket
                .send_multipart(multipart, SendFlags::empty())
                .unwrap();
        });

        // On a client-side error, return immediately: the server may still be blocked
        // in `recv`, so joining it first could hang the test.
        let reply = exchange_hello_over_tcp(&tcp_endpoint)?;
        // A reply arrived, so the server has done its work; join it so any panic in the
        // thread fails this test instead of being lost with a detached thread.
        server.join().expect("stream server thread panicked");
        assert_eq!(&reply, b"World");
        Ok(())
    }

    #[cfg(feature = "futures")]
    #[test]
    fn stream_server_async() -> Result<(), Box<dyn Error>> {
        let context = Context::new()?;
        let socket = StreamSocket::from_context(&context)?;
        socket.bind("tcp://127.0.0.1:*")?;
        let tcp_endpoint = socket.last_endpoint()?;

        let server = std::thread::spawn(move || {
            futures::executor::block_on(async {
                // Connection notification for the new TCP peer; unused.
                let _connect_notification = socket.recv_multipart_async().await;
                let mut multipart = socket.recv_multipart_async().await;
                let msg = multipart.pop_back().unwrap();
                assert_eq!(msg.to_string(), "Hello");
                multipart.push_back("World".into());
                socket
                    .send_multipart_async(multipart, SendFlags::empty())
                    .await;
            })
        });

        // See `stream_server` for why the client error is returned before joining.
        let reply = exchange_hello_over_tcp(&tcp_endpoint)?;
        server.join().expect("async stream server thread panicked");
        assert_eq!(&reply, b"World");
        Ok(())
    }
}
#[cfg(feature = "builder")]
pub(crate) mod builder {
    use core::default::Default;

    use derive_builder::Builder;
    use serde::{Deserialize, Serialize};

    use super::StreamSocket;
    use crate::{ZmqResult, context::Context, socket::SocketBuilder};

    // `StreamConfig` only exists so that `derive_builder` can generate the public
    // `StreamBuilder`. Its `build` fn is skipped (`build_fn(skip, ..)`), so the config
    // struct itself is never constructed; hence `dead_code` is allowed.
    #[derive(Default, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Builder)]
    #[builder(
        pattern = "owned",
        name = "StreamBuilder",
        public,
        build_fn(skip, error = "ZmqError"),
        derive(PartialEq, Eq, Hash, Clone, serde::Serialize, serde::Deserialize)
    )]
    #[builder_struct_attr(doc = "Builder for [`StreamSocket`].\n\n")]
    #[allow(dead_code)]
    struct StreamConfig {
        /// Generic socket options, applied before the stream-specific ones.
        socket_builder: SocketBuilder,
        /// Value for the `RoutingId` socket option.
        #[builder(setter(into), default = "Default::default()")]
        routing_id: String,
        /// Value for the `ConnectRoutingId` socket option.
        #[builder(setter(into), default = "Default::default()")]
        connect_routing_id: String,
        /// Value for the `StreamNotify` socket option.
        #[cfg(feature = "draft-api")]
        #[builder(default = false)]
        stream_notify: bool,
    }

    impl StreamBuilder {
        /// Applies every option explicitly set on this builder to `socket`.
        ///
        /// The nested [`SocketBuilder`], if present, is applied first. Options that
        /// were never set (`None`) are skipped and leave the socket untouched.
        ///
        /// # Errors
        ///
        /// Returns the first error raised while setting an option; the options after
        /// it are not applied.
        pub fn apply(self, socket: &StreamSocket) -> ZmqResult<()> {
            if let Some(socket_builder) = self.socket_builder {
                socket_builder.apply(socket)?;
            }
            if let Some(routing_id) = &self.routing_id {
                socket.set_routing_id(routing_id)?;
            }
            if let Some(connect_routing_id) = &self.connect_routing_id {
                socket.set_connect_routing_id(connect_routing_id)?;
            }
            #[cfg(feature = "draft-api")]
            {
                if let Some(stream_notify) = self.stream_notify {
                    socket.set_stream_notify(stream_notify)?;
                }
            }
            Ok(())
        }

        /// Creates a new [`StreamSocket`] in `context` and applies this builder to it.
        ///
        /// # Errors
        ///
        /// Fails if the socket cannot be created or any configured option is rejected.
        pub fn build_from_context(self, context: &Context) -> ZmqResult<StreamSocket> {
            let stream_socket = StreamSocket::from_context(context)?;
            self.apply(&stream_socket).map(|()| stream_socket)
        }
    }

    #[cfg(test)]
    mod stream_builder_tests {
        use super::StreamBuilder;
        use crate::prelude::{Context, SocketBuilder, ZmqResult};

        #[test]
        fn default_stream_builder() -> ZmqResult<()> {
            let context = Context::new()?;
            let stream_socket = StreamBuilder::default().build_from_context(&context)?;
            // Nothing was configured, so the routing id keeps its empty default.
            assert_eq!(stream_socket.routing_id()?, "");
            Ok(())
        }

        #[test]
        fn stream_builder_with_custom_values() -> ZmqResult<()> {
            let context = Context::new()?;
            let builder = StreamBuilder::default()
                .socket_builder(SocketBuilder::default())
                .routing_id("asdf")
                .connect_routing_id("qwertz");
            #[cfg(feature = "draft-api")]
            let builder = builder.stream_notify(true);
            let stream_socket = builder.build_from_context(&context)?;
            assert_eq!(stream_socket.routing_id()?, "asdf");
            Ok(())
        }
    }
}