use std::{
collections::{HashSet, VecDeque},
hash::{Hash, Hasher},
pin::Pin,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use futures::{
channel::mpsc::{self, UnboundedSender},
future::pending,
stream::once,
Sink, SinkExt, Stream, StreamExt,
};
use monitor::DownstreamConnectionSnapshot;
use serde::Serialize;
use super::{
message::{
ChannelType, DownstreamControlChannelAction, ExchangeAction, Message, MessageSink, MessageStream, ObjectContractAction, TagContractAction,
UninitializedChannelAction, UpstreamControlChannelAction,
},
state_stream::{receive_state, send_state},
termination::{CoTerminatingSet, ConnectionTerminationReason},
};
use crate::{
exchange::ExchangeShared,
object::core::Object,
tag::core::Tag,
utils::{BoolUtils, Generator},
Exchange, ObjectDescriptor, ObjectStateSink, ObjectStateStream, TagDescriptor,
};
/// Numeric identifier of a downstream connection.
///
/// Allocated from the exchange (`exchange.shared.next_downstream_id()`) when a
/// `DownstreamConnectionControl` is created, and passed to linked objects/tags so they can tell
/// connections apart.
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug, Serialize)]
pub(crate) struct DownstreamConnectionId(pub(crate) u64);
/// Owner-side handle of a single downstream connection.
///
/// Equality and hashing consider only `id`. Dropping the handle terminates the connection and
/// tells every linked object and tag that the connection is gone.
pub struct DownstreamConnectionControl {
    /// Identifier of this connection.
    id: DownstreamConnectionId,
    /// Name announced to the peer on the outgoing control channel.
    local_name: String,
    /// Address of the remote peer as passed to `new`; in this file it is only read for snapshots.
    remote_address: String,
    /// Termination state shared with every channel of this connection.
    termination: CoTerminatingSet,
    /// State shared with the objects and tags linked to this connection.
    object_interface: Arc<DownsteamObjectInterface>,
}
/// An object state endpoint queued by `DownsteamObjectInterface`; each one gets its own
/// outgoing channel.
enum SinkOrStream {
    /// Served by `outgoing_object_state_sink_channel`.
    Sink(ObjectStateSink),
    /// Served by `outgoing_object_state_stream_channel`.
    Stream(ObjectStateStream),
}
/// Connection state shared (behind an `Arc`) between the connection and the objects/tags linked
/// to it.
///
/// NOTE(review): the type name is misspelled ("Downsteam"); presumably kept because other modules
/// refer to it — confirm before renaming.
pub(crate) struct DownsteamObjectInterface {
    /// Objects to notify via `downstream_connection_dropped` on termination; `None` once the
    /// connection has terminated (the set is taken in `terminate`).
    objects_to_notify_on_termination: Mutex<Option<HashSet<Arc<Object>>>>,
    /// Tags to notify via `downstream_connection_dropped` on termination; `None` once terminated.
    tags_to_notify_on_termination: Mutex<Option<HashSet<Arc<Tag>>>>,
    /// Queue of sinks/streams for which an outgoing channel has to be opened.
    outgoing_sink_and_stream_sender: UnboundedSender<SinkOrStream>,
    /// Name the peer announced with `ExchangeAction::SetName`, if any.
    remote_name: Mutex<Option<String>>,
    /// Running sum of all deltas passed to `load`.
    load: Mutex<f64>,
    /// Load-limit schedule as time-sorted `(instant, limit)` points; empty means "no limit".
    load_limit: Mutex<VecDeque<(Instant, u32)>>,
    /// Termination state of the owning connection.
    termination: CoTerminatingSet,
}
impl DownsteamObjectInterface {
pub(crate) fn stream_created(&self, object_state_stream: ObjectStateStream) -> bool {
if let Err(error) = self.outgoing_sink_and_stream_sender.unbounded_send(SinkOrStream::Stream(object_state_stream)) {
self.termination.is_terminated().assert_true();
let SinkOrStream::Stream(object_state_stream) = error.into_inner() else { panic!() };
object_state_stream.drop_without_notification();
false
} else {
true
}
}
pub(crate) fn sink_created(&self, object_state_sink: ObjectStateSink) -> bool {
if let Err(error) = self.outgoing_sink_and_stream_sender.unbounded_send(SinkOrStream::Sink(object_state_sink)) {
self.termination.is_terminated().assert_true();
let SinkOrStream::Sink(object_state_sink) = error.into_inner() else { panic!() };
object_state_sink.drop_without_notification();
false
} else {
true
}
}
pub(crate) fn linked_object_register(&self, object_core: &Arc<Object>) -> bool {
let mut locked_objects = self.objects_to_notify_on_termination.lock().unwrap();
if let Some(map) = locked_objects.as_mut() {
map.insert(object_core.clone()).assert_true();
true
} else {
false
}
}
pub(crate) fn linked_object_unregister(&self, object_core: &Arc<Object>) {
let mut locked_objects = self.objects_to_notify_on_termination.lock().unwrap();
if let Some(map) = locked_objects.as_mut() {
map.remove(object_core).assert_true();
}
}
pub(crate) fn linked_tag_register(&self, tag_core: &Arc<Tag>) -> bool {
let mut locked_tags = self.tags_to_notify_on_termination.lock().unwrap();
if let Some(map) = locked_tags.as_mut() {
map.insert(tag_core.clone()).assert_true();
true
} else {
false
}
}
pub(crate) fn linked_tag_unregister(&self, tag_core: &Arc<Tag>) {
let mut locked_tags = self.tags_to_notify_on_termination.lock().unwrap();
if let Some(map) = locked_tags.as_mut() {
map.remove(tag_core).assert_true();
}
}
fn current_load_limit(&self) -> Option<f32> {
let mut limit = self.load_limit.lock().unwrap();
if limit.is_empty() {
return None;
};
let now = Instant::now();
while limit.len() > 1 && limit[1].0 < now {
limit.pop_front();
}
if limit.len() == 1 {
Some(limit.front().unwrap().1 as f32)
} else {
Some(limit[0].1 as f32 + (limit[1].1 as f32 - limit[0].1 as f32) * (now - limit[0].0).as_secs_f32() / (limit[1].0 - limit[0].0).as_secs_f32())
}
}
pub(crate) fn expected_proportional_load(&self, delta: f64, check_load_limit: bool) -> Option<f64> {
let current = { *self.load.lock().unwrap() };
if check_load_limit && self.current_load_limit().map(|limit| current + delta > limit as f64).unwrap_or(false) {
None
} else {
Some(current + delta)
}
}
pub(crate) fn load(&self, delta: f64) {
let mut load = self.load.lock().unwrap();
*load += delta;
}
pub(crate) fn is_terminated(&self) -> bool {
self.termination.is_terminated()
}
fn terminate(&self, id: DownstreamConnectionId) {
let locked_objects = self.objects_to_notify_on_termination.lock().unwrap().take();
if let Some(mut map) = locked_objects {
map.drain().for_each(|object_core| {
object_core.downstream_connection_dropped(id);
});
}
let locked_tags = self.tags_to_notify_on_termination.lock().unwrap().take();
if let Some(mut map) = locked_tags {
map.drain().for_each(|tag_core| {
tag_core.downstream_connection_dropped(id);
});
}
}
}
impl DownstreamConnectionControl {
    /// Creates a new downstream connection on `exchange`.
    ///
    /// Returns a tuple of:
    /// - the control handle (dropping it terminates the connection);
    /// - an endless iterator that yields a fresh `(MessageSink, MessageStream)` pair for each
    ///   incoming channel. Messages pushed into the sink become the input of
    ///   `incoming_uninitialized_channel`, and what that handler sends comes out of the stream;
    /// - a stream of outgoing `(MessageSink, MessageStream)` pairs. The first pair is the
    ///   control channel (`outgoing_control_channel`, which announces `local_name`), followed by
    ///   one pair for each sink/stream queued through the object interface.
    ///
    /// Every channel handler is wrapped with `abort_on_termination`. The iterator and the stream
    /// are wrapped with `terminate_on_drop` and `Shutdown` reasons, presumably so that dropping
    /// either of them terminates the connection — confirm in `termination`.
    pub fn new(
        exchange: &Exchange,
        local_name: String,
        remote_address: String,
    ) -> (DownstreamConnectionControl, impl Iterator<Item = (MessageSink, MessageStream)>, impl Stream<Item = (MessageSink, MessageStream)>) {
        let id = exchange.shared.next_downstream_id();
        let termination = CoTerminatingSet::new();
        // Queue through which the object interface requests new outgoing sink/stream channels.
        let (outgoing_sink_and_stream_sender, outgoing_sink_and_stream_receiver) = futures::channel::mpsc::unbounded::<SinkOrStream>();
        let object_interface = Arc::new(DownsteamObjectInterface {
            objects_to_notify_on_termination: Mutex::new(Some(HashSet::new())),
            tags_to_notify_on_termination: Mutex::new(Some(HashSet::new())),
            outgoing_sink_and_stream_sender,
            remote_name: Mutex::new(None),
            load: Mutex::new(0.0),
            load_limit: Mutex::new(VecDeque::new()),
            termination: termination.clone(),
        });
        // Incoming channels: every call produces a new pair whose handler waits for the
        // channel's first (initialization) message.
        let incoming_channels = std::iter::from_fn::<(MessageSink, MessageStream), _>({
            let exchange_shared = exchange.shared.clone();
            let downstream_connection_id = id;
            let object_interface = object_interface.clone();
            let termination = termination.clone();
            move || {
                let (input_sender, input_receiver) = mpsc::unbounded();
                let (output_sender, output_receiver) = mpsc::unbounded();
                Some((
                    Box::pin(input_sender),
                    Box::pin(output_receiver.with_generator(termination.clone().abort_on_termination(incoming_uninitialized_channel(
                        Box::pin(input_receiver),
                        Box::pin(output_sender.sink_map_err(|error| error.into())),
                        downstream_connection_id,
                        exchange_shared.clone(),
                        object_interface.clone(),
                    )))),
                ))
            }
        });
        // Outgoing channels: the control channel first, then one channel per queued sink/stream.
        let outgoing_channels = once(Box::pin({
            let termination = termination.clone();
            let local_name = local_name.clone();
            async move {
                let (input_sender, input_receiver) = mpsc::unbounded();
                let (output_sender, output_receiver) = mpsc::unbounded::<Message>();
                (
                    Box::pin(input_sender) as MessageSink,
                    Box::pin(output_receiver.with_generator(termination.clone().abort_on_termination(outgoing_control_channel(
                        Box::pin(input_receiver),
                        Box::pin(output_sender),
                        local_name.clone(),
                    )))) as MessageStream,
                )
            }
        }))
        .chain(outgoing_sink_and_stream_receiver.map({
            let termination = termination.clone();
            move |sink_or_stream| {
                let (input_sender, input_receiver) = mpsc::unbounded();
                let (output_sender, output_receiver) = mpsc::unbounded();
                let message_stream: MessageStream = match sink_or_stream {
                    SinkOrStream::Sink(object_state_sink) => {
                        Box::pin(output_receiver.with_generator(termination.abort_on_termination(outgoing_object_state_sink_channel(
                            Box::pin(input_receiver),
                            Box::pin(output_sender.sink_map_err(|error| error.into())),
                            object_state_sink,
                        ))))
                    }
                    SinkOrStream::Stream(object_state_stream) => {
                        Box::pin(output_receiver.with_generator(termination.abort_on_termination(outgoing_object_state_stream_channel(
                            Box::pin(input_receiver),
                            Box::pin(output_sender.sink_map_err(|error| error.into())),
                            object_state_stream,
                        ))))
                    }
                };
                (Box::pin(input_sender) as MessageSink, message_stream)
            }
        }));
        (
            DownstreamConnectionControl { id, object_interface, termination: termination.clone(), local_name, remote_address },
            termination.clone().terminate_on_drop(ConnectionTerminationReason::Shutdown("ChannelsIterator dropped".to_string()), incoming_channels),
            termination.terminate_on_drop(ConnectionTerminationReason::Shutdown("ChannelsStream dropped".to_string()), outgoing_channels),
        )
    }

    /// Termination state shared by all channels of this connection.
    pub fn termination(&self) -> &CoTerminatingSet {
        &self.termination
    }

    /// Identifier of this connection.
    pub fn id(&self) -> DownstreamConnectionId {
        self.id
    }

    /// Captures the connection's current state for monitoring.
    ///
    /// `load_limit` is the interpolated limit at the time of the call.
    pub fn snapshot(&self) -> DownstreamConnectionSnapshot {
        DownstreamConnectionSnapshot {
            downstream_connection_id: self.id.0,
            remote_address: self.remote_address.clone(),
            local_name: self.local_name.clone(),
            remote_name: self.object_interface.remote_name.lock().unwrap().clone(),
            load: *self.object_interface.load.lock().unwrap() as f32,
            load_limit: self.object_interface.current_load_limit(),
        }
    }

    /// Name announced by the peer, if it has sent one yet.
    pub fn remote_name(&self) -> Option<String> {
        self.object_interface.remote_name.lock().unwrap().clone()
    }
}
impl Drop for DownstreamConnectionControl {
    /// Shuts the connection down and then tells every linked object and tag that this connection
    /// is gone.
    fn drop(&mut self) {
        let reason = ConnectionTerminationReason::Shutdown(String::from("DownstreamConnectionControl dropped"));
        self.termination.terminate(reason);
        let connection_id = self.id;
        self.object_interface.terminate(connection_id);
    }
}
/// Drives the outgoing downstream control channel.
///
/// Initializes the channel as `ChannelType::DownstreamControl`, announces `local_name` to the
/// peer, and then never completes. Messages arriving on `_input` are ignored. The only way the
/// future ends early is a send error; otherwise it runs until the caller aborts it (it is wrapped
/// in `abort_on_termination` in `DownstreamConnectionControl::new`).
async fn outgoing_control_channel(
    _input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    mut output: Pin<Box<dyn Sink<Message, Error = mpsc::SendError> + Send + Sync + 'static>>,
    local_name: String,
) -> Result<(), ConnectionTerminationReason> {
    let initialize = UninitializedChannelAction::Initialize(ChannelType::DownstreamControl);
    let announce_name = DownstreamControlChannelAction::ExchangeAction(ExchangeAction::SetName(local_name));
    for handshake_message in [
        Message::UninitializedChannelActions(vec![initialize]),
        Message::DownstreamControlChannelActions(vec![announce_name]),
    ] {
        output.send(handshake_message).await?;
    }
    // Keep the channel open forever.
    pending::<()>().await;
    unreachable!();
}
/// Waits for the first message on a freshly opened incoming channel and dispatches on it.
///
/// The first message must be an `UninitializedChannelActions` holding exactly one action, and
/// that action must be `Initialize(ChannelType::UpstreamControl)`. The rest of the channel is
/// then handled by `incoming_upstream_control_channel`. If the channel closes before any message
/// arrives, this returns `Ok(())`; any other first message is a `SeriousError`.
async fn incoming_uninitialized_channel(
    mut input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    output: Pin<Box<dyn Sink<Message, Error = ConnectionTerminationReason> + Send + Sync + 'static>>,
    downstream_connection_id: DownstreamConnectionId,
    exchange: Arc<ExchangeShared>,
    downstream_object_interface: Arc<DownsteamObjectInterface>,
) -> Result<(), ConnectionTerminationReason> {
    let first_message = match input.next().await {
        None => return Ok(()),
        Some(message) => message,
    };
    let Message::UninitializedChannelActions(actions) = first_message else {
        return Err(ConnectionTerminationReason::SeriousError(
            "First message received in unititialized channel was not UninitializedChannelActions".to_string(),
        ));
    };
    match actions.as_slice() {
        [UninitializedChannelAction::Initialize(ChannelType::UpstreamControl)] => {
            incoming_upstream_control_channel(input, output, downstream_connection_id, exchange, downstream_object_interface).await
        }
        [_] => Err(ConnectionTerminationReason::SeriousError("First message received in unititialized channel had wrong action type".to_string())),
        _ => Err(ConnectionTerminationReason::SeriousError(
            "First message received in unititialized channel had wrong number of actions".to_string(),
        )),
    }
}
/// Processes actions received on an incoming upstream control channel.
///
/// Actions in each `UpstreamControlChannelActions` message are applied in order:
/// - exchange settings (remote name, load-limit schedule) are stored on
///   `downstream_object_interface`;
/// - object/tag contract actions are forwarded to the object or tag acquired from `exchange` for
///   the given descriptor.
///
/// Nothing is ever sent back, so `output` is dropped immediately.
///
/// # Errors
///
/// - `SeriousError` for a message of any other type, or for a load-limit schedule with an
///   invalid time offset.
/// - Any error from descriptor parsing or contract handling is propagated.
/// - `Shutdown("upstream channel closed")` once `input` ends.
async fn incoming_upstream_control_channel(
    mut input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    output: Pin<Box<dyn Sink<Message, Error = ConnectionTerminationReason> + Send + Sync + 'static>>,
    downstream_connection_id: DownstreamConnectionId,
    exchange: Arc<ExchangeShared>,
    downstream_object_interface: Arc<DownsteamObjectInterface>,
) -> Result<(), ConnectionTerminationReason> {
    drop(output);
    while let Some(message) = input.next().await {
        match message {
            Message::UpstreamControlChannelActions(upstream_control_channel_actions) => {
                for action in upstream_control_channel_actions {
                    match action {
                        UpstreamControlChannelAction::ExchangeAction(ExchangeAction::SetName(new_name)) => {
                            *downstream_object_interface.remote_name.lock().unwrap() = Some(new_name);
                        }
                        UpstreamControlChannelAction::ExchangeAction(ExchangeAction::SetLoadLimit(load_limit)) => {
                            // Offsets come from the remote peer: validate them instead of panicking.
                            let schedule = load_limit_schedule(Instant::now(), load_limit)?;
                            *downstream_object_interface.load_limit.lock().unwrap() = schedule;
                        }
                        UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::Observe(tag_descriptor_strings, object_descriptor_string)) => {
                            let object_descriptor = ObjectDescriptor::from_json_strings(tag_descriptor_strings, object_descriptor_string)?;
                            let object_core = exchange.object_acquire(&object_descriptor);
                            object_core.link_downstream_object_observe_contract(downstream_connection_id, downstream_object_interface.clone())?;
                        }
                        UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::Unobserve(tag_descriptor_strings, object_descriptor_string)) => {
                            let object_descriptor = ObjectDescriptor::from_json_strings(tag_descriptor_strings, object_descriptor_string)?;
                            let object_core = exchange.object_acquire(&object_descriptor);
                            object_core.unlink_downstream_object_observe_contract(downstream_connection_id)?;
                        }
                        UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::Expose(tag_descriptor_strings, object_descriptor_string, capacity)) => {
                            let object_descriptor = ObjectDescriptor::from_json_strings(tag_descriptor_strings, object_descriptor_string)?;
                            let object_core = exchange.object_acquire(&object_descriptor);
                            object_core.link_downstream_object_expose_contract(downstream_connection_id, downstream_object_interface.clone(), capacity)?;
                        }
                        UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::Unexpose(tag_descriptor_strings, object_descriptor_string)) => {
                            let object_descriptor = ObjectDescriptor::from_json_strings(tag_descriptor_strings, object_descriptor_string)?;
                            let object_core = exchange.object_acquire(&object_descriptor);
                            object_core.unlink_downstream_object_expose_contract(downstream_connection_id)?;
                        }
                        UpstreamControlChannelAction::ObjectContractAction(ObjectContractAction::SetExposeCapacity(
                            tag_descriptor_strings,
                            object_descriptor_string,
                            capacity,
                        )) => {
                            let object_descriptor = ObjectDescriptor::from_json_strings(tag_descriptor_strings, object_descriptor_string)?;
                            let object_core = exchange.object_acquire(&object_descriptor);
                            object_core.set_downstream_object_expose_capacity(downstream_connection_id, capacity)?;
                        }
                        UpstreamControlChannelAction::TagContractAction(TagContractAction::Observe(tag_descriptor_string)) => {
                            let tag_descriptor = TagDescriptor::from_json_string(tag_descriptor_string)?;
                            let tag_core = exchange.tag_acquire(&tag_descriptor);
                            tag_core.link_downstream_tag_observe_contract(downstream_connection_id, downstream_object_interface.clone())?;
                        }
                        UpstreamControlChannelAction::TagContractAction(TagContractAction::Unobserve(tag_descriptor_string)) => {
                            let tag_descriptor = TagDescriptor::from_json_string(tag_descriptor_string)?;
                            let tag_core = exchange.tag_acquire(&tag_descriptor);
                            tag_core.unlink_downstream_tag_observe_contract(downstream_connection_id)?;
                        }
                        UpstreamControlChannelAction::TagContractAction(TagContractAction::Expose(tag_descriptor_string, capacity)) => {
                            let tag_descriptor = TagDescriptor::from_json_string(tag_descriptor_string)?;
                            let tag_core = exchange.tag_acquire(&tag_descriptor);
                            tag_core.link_downstream_tag_expose_contract(downstream_connection_id, downstream_object_interface.clone(), capacity)?;
                        }
                        UpstreamControlChannelAction::TagContractAction(TagContractAction::Unexpose(tag_descriptor_string)) => {
                            let tag_descriptor = TagDescriptor::from_json_string(tag_descriptor_string)?;
                            let tag_core = exchange.tag_acquire(&tag_descriptor);
                            tag_core.unlink_downstream_tag_expose_contract(downstream_connection_id)?;
                        }
                        UpstreamControlChannelAction::TagContractAction(TagContractAction::SetExposeCapacity(tag_descriptor_string, capacity)) => {
                            let tag_descriptor = TagDescriptor::from_json_string(tag_descriptor_string)?;
                            let tag_core = exchange.tag_acquire(&tag_descriptor);
                            tag_core.set_downstream_tag_expose_capacity(downstream_connection_id, capacity)?;
                        }
                    }
                }
            }
            _ => return Err(ConnectionTerminationReason::SeriousError("incoming_upstream_control_channel received message of invalid type".to_string())),
        }
    }
    Err(ConnectionTerminationReason::Shutdown("upstream channel closed".to_string()))
}

/// Converts a `SetLoadLimit` request into an absolute, time-sorted schedule.
///
/// Each `(seconds, limit)` pair becomes a point `seconds` after `now`. A non-empty request also
/// gets an implicit `(now, 0)` starting point that interpolation starts from. An empty request
/// yields an empty schedule, meaning "no limit".
///
/// # Errors
///
/// The offsets are supplied by the remote peer. Negative, non-finite or overflowing offsets are
/// rejected with `SeriousError`; `Duration::from_secs_f32` and `Instant + Duration` would panic
/// on them.
fn load_limit_schedule(now: Instant, load_limit: impl IntoIterator<Item = (f32, u32)>) -> Result<VecDeque<(Instant, u32)>, ConnectionTerminationReason> {
    let mut schedule = load_limit
        .into_iter()
        .map(|(seconds, limit)| {
            Duration::try_from_secs_f32(seconds)
                .ok()
                .and_then(|offset| now.checked_add(offset))
                .map(|at| (at, limit))
                .ok_or_else(|| ConnectionTerminationReason::SeriousError(format!("SetLoadLimit received invalid time offset {seconds}")))
        })
        .collect::<Result<Vec<_>, _>>()?;
    if !schedule.is_empty() {
        schedule.push((now, 0));
    }
    schedule.sort();
    Ok(schedule.into())
}
/// Opens an outgoing channel for `object_state_sink`.
///
/// Sends the initialization message `ChannelType::ObjectSinkToUpstream` carrying the sink's
/// descriptor, then hands the channel over to `receive_state`.
async fn outgoing_object_state_sink_channel(
    input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    mut output: Pin<Box<dyn Sink<Message, Error = ConnectionTerminationReason> + Send + Sync + 'static>>,
    object_state_sink: ObjectStateSink,
) -> Result<(), ConnectionTerminationReason> {
    let channel_type = ChannelType::ObjectSinkToUpstream(
        object_state_sink.descriptor().tags_to_strings(),
        object_state_sink.descriptor().json_to_string(),
    );
    let initialize = Message::UninitializedChannelActions(vec![UninitializedChannelAction::Initialize(channel_type)]);
    output.send(initialize).await?;
    receive_state(input, output, object_state_sink).await
}
/// Opens an outgoing channel for `object_state_stream`.
///
/// Sends the initialization message `ChannelType::ObjectStreamFromUpstream` carrying the
/// stream's descriptor, then hands the channel over to `send_state`.
async fn outgoing_object_state_stream_channel(
    input: Pin<Box<dyn Stream<Item = Message> + Send + Sync + 'static>>,
    mut output: Pin<Box<dyn Sink<Message, Error = ConnectionTerminationReason> + Send + Sync + 'static>>,
    object_state_stream: ObjectStateStream,
) -> Result<(), ConnectionTerminationReason> {
    let channel_type = {
        let stream_descriptor = object_state_stream.descriptor();
        ChannelType::ObjectStreamFromUpstream(stream_descriptor.tags_to_strings(), stream_descriptor.json_to_string())
    };
    let initialize = Message::UninitializedChannelActions(vec![UninitializedChannelAction::Initialize(channel_type)]);
    output.send(initialize).await?;
    send_state(input, output, object_state_stream).await
}
impl Hash for DownstreamConnectionControl {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state);
}
}
impl PartialEq for DownstreamConnectionControl {
    /// Two handles are equal exactly when they refer to the same connection id.
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.id, &other.id)
    }
}
impl Eq for DownstreamConnectionControl {}
pub mod monitor {
    use serde::{Deserialize, Serialize};

    /// Serializable point-in-time view of a downstream connection, used for monitoring.
    #[derive(Serialize, Deserialize)]
    pub struct DownstreamConnectionSnapshot {
        /// Numeric id of the connection.
        pub downstream_connection_id: u64,
        /// Address of the remote peer.
        pub remote_address: String,
        /// Name this side announced to the peer.
        pub local_name: String,
        /// Name the peer announced, if it has sent one.
        pub remote_name: Option<String>,
        /// Load at the time of the snapshot.
        pub load: f32,
        /// Load limit in effect at the time of the snapshot, if the peer set one.
        pub load_limit: Option<f32>,
    }

    impl std::fmt::Debug for DownstreamConnectionSnapshot {
        /// Renders the snapshot as `│ `-prefixed lines, one field per line.
        /// The remote name line is omitted when the name is unknown.
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            writeln!(fmt, "│ Id: {:?}", self.downstream_connection_id)?;
            writeln!(fmt, "│ Remote address: {}", self.remote_address)?;
            if let Some(remote_name) = &self.remote_name {
                writeln!(fmt, "│ Remote name: {}", remote_name)?;
            }
            match self.load_limit {
                // Previously printed an unbalanced ')' after the limit.
                Some(load_limit) => writeln!(fmt, "│ Load: {} limit: {}", self.load, load_limit),
                None => writeln!(fmt, "│ Load: {}", self.load),
            }
        }
    }
}