use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Poll, Waker},
time::Duration,
};
use futures::{
channel::mpsc::{self},
Sink, SinkExt, Stream, StreamExt,
};
use monitor::UpstreamConnectionSnapshot;
use serde::Serialize;
use super::{
message::{
ChannelType, DownstreamControlChannelAction, ExchangeAction, Message, MessageSink, MessageStream, ObjectContractAction, UninitializedChannelAction,
UpstreamControlChannelAction,
},
state_stream::{receive_state, send_state},
termination::{CoTerminatingSet, ConnectionTerminationReason},
};
use crate::{
connection::message::TagContractAction, exchange::ExchangeShared, object::core::Object, tag::core::Tag, utils::Generator, Exchange, ObjectDescriptor,
};
/// Identifier of a single upstream connection.
///
/// A thin newtype over `u64`; `UpstreamConnectionControl::new` obtains a fresh value from
/// `exchange.shared.next_upstream_id()`.
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug, Serialize)]
pub(crate) struct UpstreamConnectionId(pub u64);
/// Handle representing one upstream connection of an `Exchange`.
///
/// Dropping the handle terminates the connection's `CoTerminatingSet` and removes the
/// connection from the exchange (see the `Drop` implementation).
pub struct UpstreamConnectionControl {
/// Identifier allocated by the exchange when the connection was created.
id: UpstreamConnectionId,
/// Termination set shared with every channel created for this connection.
termination: CoTerminatingSet,
/// Local address as supplied by the caller of `new`; only reported through `snapshot`.
local_address: String,
/// Local name as supplied by the caller of `new`; reported through `snapshot`.
local_name: String,
/// Name announced by the remote side via `SetName` on the downstream control channel;
/// `None` until such a message has been received.
remote_name: Arc<Mutex<Option<String>>>,
/// Exchange state; used on drop to remove this upstream connection.
exchange_shared: Arc<ExchangeShared>,
/// Sending half of the outgoing-channel queue. Never read after construction in the visible
/// code — presumably retained so the stream returned by `new` does not end (TODO confirm).
_outgoing_channel_sender:
mpsc::UnboundedSender<(Pin<Box<dyn Sink<Message, Error = mpsc::SendError> + Send + Sync>>, Pin<Box<dyn Stream<Item = Message> + Send + Sync>>)>,
}
impl UpstreamConnectionControl {
/// Creates the control handle for a new upstream connection together with its channel sources.
///
/// Returns a tuple of:
/// 1. the `UpstreamConnectionControl` handle itself,
/// 2. an iterator of incoming channels: every call to `next` builds a fresh
///    `(MessageSink, MessageStream)` pair handled by `incoming_uninitialized_channel`
///    (the closure always returns `Some`, so the iterator never ends on its own),
/// 3. a stream of outgoing channels; the upstream control channel is queued on it before
///    this function returns.
///
/// Both returned sources are wrapped with `terminate_on_drop`, using the shutdown reasons
/// "ChannelsIterator dropped" and "ChannelsStream dropped" respectively.
///
/// `local_name` and `load_limit` are handed to `outgoing_control_channel`, which sends them
/// to the remote side as `SetName` / `SetLoadLimit`.
///
/// # Panics
///
/// Panics when `exchange.shared.set_upstream_connection` returns `false`; per the panic
/// message this means another `UpstreamConnectionControl` is still alive for the exchange.
/// Also panics if queueing the control channel fails (`unbounded_send(..).unwrap()`); the
/// receiving half is still held locally at that point, so this is presumably unreachable.
pub fn new(
exchange: &Exchange,
local_name: String,
local_address: String,
load_limit: Vec<(Duration, u32)>,
) -> (UpstreamConnectionControl, impl Iterator<Item = (MessageSink, MessageStream)>, impl Stream<Item = (MessageSink, MessageStream)>) {
let id = exchange.shared.next_upstream_id();
let termination = CoTerminatingSet::new();
let (outgoing_channel_sender, outgoing_channel_receiver) = futures::channel::mpsc::unbounded::<(MessageSink, MessageStream)>();
let object_interface = UpstreamObjectInterface::new(id);
// Queue the upstream control channel as the first outgoing channel.
outgoing_channel_sender
.unbounded_send({
let (input_sender, input_receiver) = mpsc::unbounded();
let (output_sender, output_receiver) = mpsc::unbounded::<Message>();
(
Box::pin(input_sender),
// The output stream is combined with the future running `outgoing_control_channel`;
// `abort_on_termination` ties that future to this connection's termination set.
// NOTE(review): exact semantics of `with_generator` live in `utils::Generator` — confirm.
Box::pin(output_receiver.with_generator(termination.clone().abort_on_termination(outgoing_control_channel(
Box::pin(input_receiver),
Box::pin(output_sender),
object_interface.clone(),
local_name.clone(),
load_limit,
)))),
)
})
.unwrap();
let remote_name = Arc::new(Mutex::new(None));
// Register the contract-change interface with the exchange; `false` is treated as fatal.
if !exchange.shared.set_upstream_connection(object_interface) {
panic!("You can't have more than one UpstreamCOnnectionControl alive per Exchange")
};
// Factory for incoming channels: each item is a new sink/stream pair whose stream side is
// combined with `incoming_uninitialized_channel`.
let incoming_channels = std::iter::from_fn::<(MessageSink, MessageStream), _>({
let termination = termination.clone();
let exchange_shared = exchange.shared.clone();
let remote_name = remote_name.clone();
move || {
let (input_sender, input_receiver) = mpsc::unbounded();
let (output_sender, output_receiver) = mpsc::unbounded::<Message>();
Some((
Box::pin(input_sender),
Box::pin(output_receiver.with_generator(termination.clone().abort_on_termination(Box::pin(incoming_uninitialized_channel(
Box::pin(input_receiver),
// Incoming handlers report `ConnectionTerminationReason`, so the sink error is converted.
Box::pin(output_sender.sink_map_err(|error| error.into())),
exchange_shared.clone(),
id,
remote_name.clone(),
))))),
))
}
});
(
UpstreamConnectionControl {
id,
termination: termination.clone(),
local_name,
remote_name,
local_address,
exchange_shared: exchange.shared.clone(),
_outgoing_channel_sender: outgoing_channel_sender,
},
termination.clone().terminate_on_drop(ConnectionTerminationReason::Shutdown("ChannelsIterator dropped".to_string()), incoming_channels),
termination.terminate_on_drop(ConnectionTerminationReason::Shutdown("ChannelsStream dropped".to_string()), outgoing_channel_receiver),
)
}
/// Returns the termination set shared by this connection and its channels.
pub fn termination(&self) -> &CoTerminatingSet {
&self.termination
}
/// Returns the identifier of this upstream connection.
pub fn id(&self) -> UpstreamConnectionId {
self.id
}
/// Builds a serializable snapshot of the connection's identity and names.
///
/// # Panics
///
/// Panics if the `remote_name` mutex is poisoned.
pub fn snapshot(&self) -> UpstreamConnectionSnapshot {
UpstreamConnectionSnapshot {
upstream_connection_id: self.id.0,
local_address: self.local_address.clone(),
local_name: self.local_name.clone(),
remote_name: self.remote_name.lock().unwrap().clone(),
}
}
}
impl Drop for UpstreamConnectionControl {
    /// Terminates the connection's termination set with a `Shutdown` reason and then removes
    /// this upstream connection from the exchange.
    fn drop(&mut self) {
        let reason = ConnectionTerminationReason::Shutdown(String::from("UpstreamConnectionControl dropped"));
        self.termination.terminate(reason);

        let connection_id = self.id;
        self.exchange_shared.remove_upstream_connection(connection_id);
    }
}
/// Hand-over point for contract changes destined for the upstream connection.
///
/// Changes are reported through `object_contracts_changed` / `tag_contracts_changed` and are
/// picked up by `outgoing_control_channel` via `changed`.
pub(crate) struct UpstreamObjectInterface {
/// Identifier of the upstream connection this interface belongs to.
pub id: UpstreamConnectionId,
/// Pending changes and the waker of the waiting task, guarded by a mutex.
mutable: Mutex<UpstreamObjectInterfaceShared>,
}
/// Mutex-protected state of `UpstreamObjectInterface`.
struct UpstreamObjectInterfaceShared {
/// Waker stored by `changed` when it found nothing pending; taken and woken by the
/// `*_contracts_changed` methods.
waker: Option<Waker>,
/// Latest reported state per object that `changed` has not handed out yet.
object_contracts_state_new: HashMap<Arc<Object>, ContractsState>,
/// Latest reported state per tag that `changed` has not handed out yet.
tag_contracts_state_new: HashMap<Arc<Tag>, ContractsState>,
}
/// Contract state of one object or tag, as reported to the upstream connection.
#[derive(Clone, Debug)]
struct ContractsState {
/// Whether an observe contract exists.
observe_contract_exists: bool,
/// Whether an expose contract exists.
expose_contract_exists: bool,
/// Capacity sent along with `Expose` / `SetExposeCapacity` actions.
expose_capacity: u32,
}
impl UpstreamObjectInterface {
    /// Creates an interface for the connection `id` with no pending changes and no parked waker.
    fn new(id: UpstreamConnectionId) -> Arc<UpstreamObjectInterface> {
        let initial_state = UpstreamObjectInterfaceShared {
            waker: None,
            object_contracts_state_new: HashMap::new(),
            tag_contracts_state_new: HashMap::new(),
        };
        Arc::new(Self { id, mutable: Mutex::new(initial_state) })
    }
}
impl UpstreamObjectInterface {
    /// Records the latest contract state for `object` and wakes the task waiting in
    /// [`changed`](Self::changed), if there is one.
    ///
    /// A report for an object that still has an undelivered entry overwrites that entry, so
    /// only the most recent state is ever handed out.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn object_contracts_changed(&self, object: Arc<Object>, observe_contract_exists: bool, expose_contract_exists: bool, expose_capacity: u32) {
        let state = ContractsState { observe_contract_exists, expose_contract_exists, expose_capacity };
        let parked_waker = {
            let mut locked_mutable = self.mutable.lock().unwrap();
            locked_mutable.object_contracts_state_new.insert(object, state);
            locked_mutable.waker.take()
        };
        // Wake only after the guard has been released: a waker that polls the task inline would
        // re-enter `changed` and try to lock the same non-reentrant mutex.
        if let Some(waker) = parked_waker {
            waker.wake();
        }
    }

    /// Records the latest contract state for `tag` and wakes the task waiting in
    /// [`changed`](Self::changed), if there is one.
    ///
    /// A report for a tag that still has an undelivered entry overwrites that entry, so only
    /// the most recent state is ever handed out.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn tag_contracts_changed(&self, tag: Arc<Tag>, observe_contract_exists: bool, expose_contract_exists: bool, expose_capacity: u32) {
        let state = ContractsState { observe_contract_exists, expose_contract_exists, expose_capacity };
        let parked_waker = {
            let mut locked_mutable = self.mutable.lock().unwrap();
            locked_mutable.tag_contracts_state_new.insert(tag, state);
            locked_mutable.waker.take()
        };
        // See `object_contracts_changed`: never call `wake` while holding the mutex.
        if let Some(waker) = parked_waker {
            waker.wake();
        }
    }

    /// Returns a future that resolves once at least one object or tag change is pending.
    ///
    /// On completion both pending maps are taken (left empty) and returned as
    /// `(object_changes, tag_changes)`. While nothing is pending, the polling task's waker is
    /// stored so that the `*_contracts_changed` methods can wake it; only the most recently
    /// stored waker is kept.
    fn changed(self: &Arc<Self>) -> impl Future<Output = (HashMap<Arc<Object>, ContractsState>, HashMap<Arc<Tag>, ContractsState>)> {
        let shared = self.clone();
        futures::future::poll_fn(move |cx| {
            let mut locked_mutable = shared.mutable.lock().unwrap();
            if !locked_mutable.object_contracts_state_new.is_empty() || !locked_mutable.tag_contracts_state_new.is_empty() {
                // The changes are being delivered right now, so a stored waker is obsolete.
                locked_mutable.waker = None;
                Poll::Ready((std::mem::take(&mut locked_mutable.object_contracts_state_new), std::mem::take(&mut locked_mutable.tag_contracts_state_new)))
            } else {
                locked_mutable.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        })
    }
}
/// Drives the outgoing upstream control channel.
///
/// First announces the channel type (`ChannelType::UpstreamControl`), then sends the local
/// `name` and the `load_limit` (durations converted to seconds as `f32`). Afterwards it loops
/// forever: it waits for `shared.changed()`, compares every reported object/tag state with the
/// state last sent for it, and sends the resulting actions as one batched
/// `Message::UpstreamControlChannelActions`. Batches that turn out empty are not sent.
///
/// `_input` is accepted but never read.
///
/// # Errors
///
/// The loop has no regular exit; the function returns only when sending on `output` fails,
/// with that error converted into a `ConnectionTerminationReason`.
async fn outgoing_control_channel(
    _input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    mut output: Pin<Box<dyn Sink<Message, Error = mpsc::SendError> + Send + Sync + 'static>>,
    shared: Arc<UpstreamObjectInterface>,
    name: String,
    load_limit: Vec<(Duration, u32)>,
) -> Result<(), ConnectionTerminationReason> {
    let channel_announcement = Message::UninitializedChannelActions(vec![UninitializedChannelAction::Initialize(ChannelType::UpstreamControl)]);
    output.send(channel_announcement).await?;

    let exchange_setup = Message::UpstreamControlChannelActions(vec![
        UpstreamControlChannelAction::ExchangeAction(ExchangeAction::SetName(name)),
        UpstreamControlChannelAction::ExchangeAction(ExchangeAction::SetLoadLimit(
            load_limit.into_iter().map(|(window, limit)| (window.as_secs_f32(), limit)).collect(),
        )),
    ]);
    output.send(exchange_setup).await?;

    // State assumed for anything that has no entry in the "last sent" maps.
    let nothing_sent = ContractsState {
        observe_contract_exists: false,
        expose_contract_exists: false,
        expose_capacity: 0,
    };

    #[allow(clippy::mutable_key_type)]
    let mut object_contracts_last_sent: HashMap<Arc<Object>, ContractsState> = HashMap::new();
    #[allow(clippy::mutable_key_type)]
    let mut tag_contracts_last_sent: HashMap<Arc<Tag>, ContractsState> = HashMap::new();

    loop {
        let (changed_objects, changed_tags) = shared.changed().await;
        let mut pending_actions = Vec::new();

        for (object, current) in changed_objects {
            let previous = object_contracts_last_sent.get(&object).cloned().unwrap_or_else(|| nothing_sent.clone());

            match (current.observe_contract_exists, previous.observe_contract_exists) {
                (true, false) => pending_actions.push(UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::Observe(
                    object.descriptor.tags_to_strings(),
                    object.descriptor.json_to_string(),
                ))),
                (false, true) => pending_actions.push(UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::Unobserve(
                    object.descriptor.tags_to_strings(),
                    object.descriptor.json_to_string(),
                ))),
                _ => {}
            }

            match (current.expose_contract_exists, previous.expose_contract_exists) {
                (true, false) => pending_actions.push(UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::Expose(
                    object.descriptor.tags_to_strings(),
                    object.descriptor.json_to_string(),
                    current.expose_capacity,
                ))),
                (false, true) => pending_actions.push(UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::Unexpose(
                    object.descriptor.tags_to_strings(),
                    object.descriptor.json_to_string(),
                ))),
                // Expose flag unchanged: a capacity change alone is reported separately.
                _ if current.expose_capacity != previous.expose_capacity => {
                    pending_actions.push(UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::SetExposeCapacity(
                        object.descriptor.tags_to_strings(),
                        object.descriptor.json_to_string(),
                        current.expose_capacity,
                    )))
                }
                _ => {}
            }

            // Entries without any contract are forgotten instead of being remembered as "off".
            if !current.expose_contract_exists && !current.observe_contract_exists {
                object_contracts_last_sent.remove(&object);
            } else {
                object_contracts_last_sent.insert(object, current);
            }
        }

        for (tag, current) in changed_tags {
            let previous = tag_contracts_last_sent.get(&tag).cloned().unwrap_or_else(|| nothing_sent.clone());

            match (current.observe_contract_exists, previous.observe_contract_exists) {
                (true, false) => {
                    pending_actions.push(UpstreamControlChannelAction::TagContractAction(TagContractAction::Observe(tag.descriptor.json_to_string())))
                }
                (false, true) => {
                    pending_actions.push(UpstreamControlChannelAction::TagContractAction(TagContractAction::Unobserve(tag.descriptor.json_to_string())))
                }
                _ => {}
            }

            match (current.expose_contract_exists, previous.expose_contract_exists) {
                (true, false) => pending_actions.push(UpstreamControlChannelAction::TagContractAction(TagContractAction::Expose(
                    tag.descriptor.json_to_string(),
                    current.expose_capacity,
                ))),
                (false, true) => {
                    pending_actions.push(UpstreamControlChannelAction::TagContractAction(TagContractAction::Unexpose(tag.descriptor.json_to_string())))
                }
                // Expose flag unchanged: a capacity change alone is reported separately.
                _ if current.expose_capacity != previous.expose_capacity => {
                    pending_actions.push(UpstreamControlChannelAction::TagContractAction(TagContractAction::SetExposeCapacity(
                        tag.descriptor.json_to_string(),
                        current.expose_capacity,
                    )))
                }
                _ => {}
            }

            if !current.expose_contract_exists && !current.observe_contract_exists {
                tag_contracts_last_sent.remove(&tag);
            } else {
                tag_contracts_last_sent.insert(tag, current);
            }
        }

        if pending_actions.is_empty() {
            continue;
        }
        output.send(Message::UpstreamControlChannelActions(pending_actions)).await?;
    }
}
/// Handles a freshly opened incoming channel whose purpose is not known yet.
///
/// Reads the first message, which must be a `Message::UninitializedChannelActions` carrying
/// exactly one `Initialize` action, and then hands `input`/`output` to the handler for the
/// requested channel type (downstream control, object sink to upstream, or object stream from
/// upstream).
///
/// Returns `Ok(())` when the stream ends before any message arrives; otherwise returns whatever
/// the selected handler returns.
///
/// # Errors
///
/// Returns `ConnectionTerminationReason::SeriousError` when the first message is of another
/// kind, carries a number of actions other than one, or requests an unsupported channel type.
async fn incoming_uninitialized_channel(
    mut input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    output: Pin<Box<dyn Sink<Message, Error = ConnectionTerminationReason> + Send + Sync + 'static>>,
    exchange_shared: Arc<ExchangeShared>,
    upstream_connection_id: UpstreamConnectionId,
    remote_name: Arc<Mutex<Option<String>>>,
) -> Result<(), ConnectionTerminationReason> {
    let Some(first_message) = input.next().await else {
        return Ok(());
    };

    let Message::UninitializedChannelActions(mut actions) = first_message else {
        return Err(ConnectionTerminationReason::SeriousError("First message received in unititialized channel was not UninitializedChannelActions".to_string()));
    };

    // Exactly one action is allowed: popping must yield it and leave the list empty.
    let requested = match (actions.pop(), actions.is_empty()) {
        (Some(action), true) => action,
        _ => {
            return Err(ConnectionTerminationReason::SeriousError(
                "First message received in unititialized channel had wrong number of actions".to_string(),
            ))
        }
    };

    match requested {
        UninitializedChannelAction::Initialize(ChannelType::DownstreamControl) => {
            incoming_downstream_control_channel(input, output, remote_name.clone()).await
        }
        UninitializedChannelAction::Initialize(ChannelType::ObjectSinkToUpstream(tag_descriptor_strings, object_descriptor_string)) => {
            let handler = incoming_object_state_sink_channel(
                input,
                output,
                tag_descriptor_strings,
                upstream_connection_id,
                object_descriptor_string,
                exchange_shared.clone(),
            );
            handler.await
        }
        UninitializedChannelAction::Initialize(ChannelType::ObjectStreamFromUpstream(tag_descriptor_strings, object_descriptor_string)) => {
            let handler = incoming_object_state_stream_channel(
                input,
                output,
                tag_descriptor_strings,
                upstream_connection_id,
                object_descriptor_string,
                exchange_shared.clone(),
            );
            handler.await
        }
        _ => Err(ConnectionTerminationReason::SeriousError("First message received in unititialized channel had wrong action type".to_string())),
    }
}
/// Processes the downstream control channel of an upstream connection.
///
/// The sink is dropped immediately; this handler never sends anything. Every incoming message
/// must be a `Message::DownstreamControlChannelActions`; a `SetName` action stores the name in
/// `remote_name`.
///
/// # Errors
///
/// Always returns an error:
/// - `Shutdown("downstream channel closed")` once `input` ends,
/// - `SeriousError` when a message of another kind or a `SetLoadLimit` action is received.
///
/// # Panics
///
/// Panics if the `remote_name` mutex is poisoned.
async fn incoming_downstream_control_channel(
    mut input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    output: Pin<Box<dyn Sink<Message, Error = ConnectionTerminationReason> + Send + Sync + 'static>>,
    remote_name: Arc<Mutex<Option<String>>>,
) -> Result<(), ConnectionTerminationReason> {
    drop(output);

    loop {
        let Some(message) = input.next().await else {
            return Err(ConnectionTerminationReason::Shutdown("downstream channel closed".to_string()));
        };

        let Message::DownstreamControlChannelActions(actions) = message else {
            return Err(ConnectionTerminationReason::SeriousError("incoming_downstream_control_channel received message of invalid type".to_string()));
        };

        for action in actions {
            match action {
                DownstreamControlChannelAction::ExchangeAction(ExchangeAction::SetName(name)) => {
                    *remote_name.lock().unwrap() = Some(name);
                }
                DownstreamControlChannelAction::ExchangeAction(ExchangeAction::SetLoadLimit(_)) => {
                    return Err(ConnectionTerminationReason::SeriousError(
                        "incoming_downstream_control_channel received SetLoadLimit. and that's, like, unexpected.".to_string(),
                    ));
                }
            }
        }
    }
}
/// Serves an incoming "object sink to upstream" channel.
///
/// Parses the object descriptor from the given strings, acquires the matching object from the
/// exchange and asks it for an upstream object state stream for `upstream_connection_id`.
/// When the object returns no stream the channel ends with `Ok(())`; otherwise the result of
/// `send_state(input, output, stream)` is returned.
///
/// # Errors
///
/// Propagates the descriptor parse error (converted by `?`) and any error from `send_state`.
async fn incoming_object_state_sink_channel(
    input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    output: Pin<Box<dyn Sink<Message, Error = ConnectionTerminationReason> + Send + Sync + 'static>>,
    tag_descriptor_strings: Vec<String>,
    upstream_connection_id: UpstreamConnectionId,
    object_descriptor_string: String,
    exchange_shared: Arc<ExchangeShared>,
) -> Result<(), ConnectionTerminationReason> {
    let descriptor = ObjectDescriptor::from_json_strings(tag_descriptor_strings, object_descriptor_string)?;
    let object_core = exchange_shared.object_acquire(&descriptor);
    match object_core.create_upstream_object_state_stream(upstream_connection_id) {
        Some(object_state_stream) => send_state(input, output, object_state_stream).await,
        None => Ok(()),
    }
}
/// Serves an incoming "object stream from upstream" channel.
///
/// Parses the object descriptor from the given strings, acquires the matching object from the
/// exchange and asks it for an upstream object state sink for `upstream_connection_id`.
/// When the object returns no sink the channel ends with `Ok(())`; otherwise the result of
/// `receive_state(input, output, sink)` is returned.
///
/// # Errors
///
/// Returns `ConnectionTerminationReason::SeriousError` when the descriptor cannot be parsed,
/// and propagates any error from `receive_state`.
async fn incoming_object_state_stream_channel(
    input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    output: Pin<Box<dyn Sink<Message, Error = ConnectionTerminationReason> + Send + Sync + 'static>>,
    tag_descriptor_strings: Vec<String>,
    upstream_connection_id: UpstreamConnectionId,
    object_descriptor_string: String,
    exchange_shared: Arc<ExchangeShared>,
) -> Result<(), ConnectionTerminationReason> {
    let descriptor = match ObjectDescriptor::from_json_strings(tag_descriptor_strings, object_descriptor_string) {
        Ok(parsed) => parsed,
        Err(_) => {
            return Err(ConnectionTerminationReason::SeriousError("Received object descriptor is not a parseable JSON".to_string()));
        }
    };
    let object_core = exchange_shared.object_acquire(&descriptor);
    match object_core.create_upstream_object_state_sink(upstream_connection_id) {
        Some(object_state_sink) => receive_state(input, output, object_state_sink).await,
        None => Ok(()),
    }
}
/// Serializable monitoring view of an upstream connection.
pub mod monitor {
    use serde::{Deserialize, Serialize};

    /// Point-in-time description of an upstream connection, as produced by
    /// `UpstreamConnectionControl::snapshot`.
    #[derive(Serialize, Deserialize)]
    pub struct UpstreamConnectionSnapshot {
        /// Numeric value of the connection's `UpstreamConnectionId`.
        pub upstream_connection_id: u64,
        /// Local address of the connection.
        pub local_address: String,
        /// Local name of the connection.
        pub local_name: String,
        /// Name announced by the remote side, if one has been received.
        pub remote_name: Option<String>,
    }

    /// Multi-line, box-drawing style rendering: the id, the local address and — when known —
    /// the remote name, one per line. `local_name` is not part of this output.
    impl std::fmt::Debug for UpstreamConnectionSnapshot {
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            writeln!(fmt, "│ Id: {:?}", self.upstream_connection_id)?;
            writeln!(fmt, "│ Local address: {}", self.local_address)?;
            match &self.remote_name {
                Some(remote_name) => writeln!(fmt, "│ Remote name: {}", remote_name),
                None => Ok(()),
            }
        }
    }
}