use std::future::Future;
use std::io;
use std::net::TcpStream;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use crate::async_client::{maybe_ping, next_packet_id, write_blocking, PendingRequest, Shared, SharedInner};
use crate::codec::properties::Properties;
use crate::codec::types::*;
use crate::error::{Error, Result};
use crate::transport::Transport;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
const REPLY_TOPIC_PREFIX: &str = "egress/reply/";
#[derive(Serialize)]
pub struct RequestEnvelope<'a, T: Serialize> {
#[serde(rename = "type")]
pub msg_type: &'a str,
pub action: &'a str,
pub params: &'a T,
#[serde(rename = "correlationId")]
pub correlation_id: &'a str,
#[serde(rename = "replyTo")]
pub reply_to: &'a str,
}
#[derive(Deserialize)]
pub struct ReplyEnvelope {
#[serde(rename = "correlationId")]
pub correlation_id: String,
pub result: serde_json::Value,
}
enum RequestState {
Init {
topic: String,
payload_json: Vec<u8>,
},
Waiting,
Done,
}
pub struct RequestFuture<T: Transport = TcpStream> {
inner: Shared<T>,
correlation_id: String,
reply_topic: String,
state: RequestState,
deadline: Instant,
}
impl<T: Transport> RequestFuture<T> {
pub(crate) fn new<Req: Serialize>(
inner: Shared<T>,
topic: String,
payload: &Req,
) -> Self {
let correlation_id = generate_correlation_id();
let reply_topic = format!("{REPLY_TOPIC_PREFIX}{correlation_id}");
let amqp_mode = inner.borrow().amqp_reply_format;
let reply_to = if amqp_mode {
reply_topic.replace('/', ".")
} else {
reply_topic.clone()
};
let envelope = RequestEnvelope {
msg_type: "request",
action: "chat",
params: payload,
correlation_id: &correlation_id,
reply_to: &reply_to,
};
let payload_json = serde_json::to_vec(&envelope)
.expect("serialization should not fail");
RequestFuture {
inner,
correlation_id,
reply_topic,
state: RequestState::Init {
topic,
payload_json,
},
deadline: Instant::now() + DEFAULT_TIMEOUT,
}
}
}
impl<T: Transport> Future for RequestFuture<T> {
type Output = Result<serde_json::Value>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
{
let inner = this.inner.borrow();
if let Some(ref _e) = inner.error {
return Poll::Ready(Err(Error::ConnectionClosed));
}
}
if Instant::now() >= this.deadline {
this.inner.borrow_mut().pending.remove(&this.correlation_id);
this.state = RequestState::Done;
return Poll::Ready(Err(Error::Timeout));
}
loop {
match std::mem::replace(&mut this.state, RequestState::Done) {
RequestState::Init { topic, payload_json } => {
let mut inner = this.inner.borrow_mut();
inner.pending.insert(
this.correlation_id.clone(),
PendingRequest {
waker: Some(cx.waker().clone()),
result: None,
},
);
let sub_id = next_packet_id(&mut inner);
let sub = SubscribePacket {
packet_id: sub_id,
filters: vec![(this.reply_topic.clone(), QoS::AtMostOnce)],
properties: Properties::new(),
};
let sub_bytes = sub.encode()?;
write_blocking(&mut inner.stream, &sub_bytes)?;
let pub_pkt = PublishPacket {
topic,
packet_id: None,
payload: payload_json,
qos: QoS::AtMostOnce,
retain: false,
dup: false,
properties: Properties::new(),
};
let pub_bytes = pub_pkt.encode()?;
write_blocking(&mut inner.stream, &pub_bytes)?;
inner.last_activity = Instant::now();
drop(inner);
this.state = RequestState::Waiting;
}
RequestState::Waiting => {
pump_socket(&this.inner)?;
let mut inner = this.inner.borrow_mut();
if let Some(pending) = inner.pending.get_mut(&this.correlation_id) {
if let Some(result) = pending.result.take() {
inner.pending.remove(&this.correlation_id);
let unsub_id = next_packet_id(&mut inner);
let unsub = UnsubscribePacket {
packet_id: unsub_id,
filters: vec![this.reply_topic.clone()],
properties: Properties::new(),
};
if let Ok(bytes) = unsub.encode() {
let _ = write_blocking(&mut inner.stream, &bytes);
}
drop(inner);
this.state = RequestState::Done;
return match result {
Ok(payload_bytes) => {
match serde_json::from_slice::<ReplyEnvelope>(&payload_bytes) {
Ok(env) => Poll::Ready(Ok(env.result)),
Err(e) => Poll::Ready(Err(Error::Deserialize(e.to_string()))),
}
}
Err(e) => Poll::Ready(Err(e)),
};
}
pending.waker = Some(cx.waker().clone());
}
drop(inner);
this.state = RequestState::Waiting;
cx.waker().wake_by_ref();
return Poll::Pending;
}
RequestState::Done => {
panic!("RequestFuture polled after completion");
}
}
}
}
}
fn pump_socket<T: Transport>(shared: &Shared<T>) -> Result<()> {
let mut inner = shared.borrow_mut();
if inner.error.is_some() {
return Ok(());
}
let mut tmp = [0u8; 8192];
loop {
match inner.stream.read(&mut tmp) {
Ok(0) => {
inner.error = Some(Error::ConnectionClosed);
for (_, pending) in inner.pending.iter_mut() {
pending.result = Some(Err(Error::ConnectionClosed));
if let Some(w) = pending.waker.take() {
w.wake();
}
}
return Ok(());
}
Ok(n) => {
inner.frame_reader.push(&tmp[..n]);
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
Err(e) => {
inner.error = Some(Error::Io(io::Error::new(e.kind(), e.to_string())));
return Err(Error::Io(e));
}
}
}
loop {
match inner.frame_reader.try_decode() {
Ok(Some(packet)) => dispatch_packet(&mut inner, packet)?,
Ok(None) => break,
Err(e) => {
inner.error = Some(Error::MalformedPacket("frame decode error"));
return Err(e);
}
}
}
maybe_ping(&mut inner)?;
Ok(())
}
fn dispatch_packet<T: Transport>(
inner: &mut SharedInner<T>,
packet: Packet,
) -> Result<()> {
match packet {
Packet::Publish(pub_pkt) => {
if pub_pkt.qos == QoS::AtLeastOnce {
if let Some(id) = pub_pkt.packet_id {
let ack = PubAckPacket { packet_id: id, reason_code: 0x00 };
let bytes = ack.encode()?;
write_blocking(&mut inner.stream, &bytes)?;
}
}
if let Ok(envelope) = serde_json::from_slice::<ReplyEnvelope>(&pub_pkt.payload) {
if let Some(pending) = inner.pending.get_mut(&envelope.correlation_id) {
pending.result = Some(Ok(pub_pkt.payload));
if let Some(w) = pending.waker.take() {
w.wake();
}
}
}
}
Packet::PingResp => {
inner.last_activity = Instant::now();
}
Packet::Disconnect(disc) => {
inner.error = Some(Error::ConnectionRefused(disc.reason_code));
for (_, pending) in inner.pending.iter_mut() {
pending.result = Some(Err(Error::ConnectionClosed));
if let Some(w) = pending.waker.take() {
w.wake();
}
}
}
_ => {}
}
Ok(())
}
fn generate_correlation_id() -> String {
uuid::Uuid::new_v4().to_string()
}