use anyhow::{Error, Result};
use async_trait::async_trait;
use channel_plugin::{
channel_client::ChannelClient, control_client::ControlClient, message::{ChannelMessage, ChannelState}, plugin_actor::PluginHandle, plugin_helpers::PluginError,
};
use dashmap::DashMap;
use futures::stream::{self, StreamExt};
use std::{
fmt,
path::PathBuf,
sync::{Arc, Mutex},
};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use crate::{
channel::{
plugin::{PluginEventHandler, PluginWatcher},
wrapper::PluginWrapper,
},
config::ConfigManager,
flow::session::SessionStore,
logger::LogConfig,
secret::SecretsManager,
watcher::DirectoryWatcher,
};
/// Receiver for messages that arrive on channels owned by a [`ChannelManager`].
///
/// Implementations are registered through [`ChannelManager::subscribe_incoming`]. The
/// manager's receive loop calls every registered handler for each message, each call on
/// its own spawned Tokio task, which is why implementors must be `Send + Sync`.
#[async_trait]
pub trait IncomingHandler: Send + Sync {
    /// Processes a single inbound message.
    ///
    /// * `msg` - the received message; the receive loop overwrites `msg.channel` with the
    ///   name the channel is registered under before delivery.
    /// * `session_store` - the manager's session store, handed to every handler.
    async fn handle_incoming(&self, msg: ChannelMessage, session_store: SessionStore);
}
/// Owns the loaded channel plugins and fans their incoming messages out to subscribers.
///
/// The channel map and the subscriber list live behind `Arc`, so every clone of a
/// manager observes the same channels and the same subscribers.
#[derive(Clone)]
pub struct ChannelManager {
    /// Configuration values handed to a plugin when it is started.
    config: ConfigManager,
    /// Secrets handed to a plugin when it is started.
    secrets: SecretsManager,
    /// Identifier passed to the pub/sub clients of remote channels.
    greentic_id: String,
    /// Session store given to plugin wrappers and to incoming handlers.
    store: SessionStore,
    /// Loaded channels, keyed by the name they were registered under.
    channels: Arc<DashMap<String, ManagedChannel>>,
    /// Logging configuration given to every plugin wrapper.
    log_config: LogConfig,
    /// Handlers notified of every message received by a channel's receive loop.
    incoming_subscribers: Arc<Mutex<Vec<Arc<dyn IncomingHandler>>>>,
}
impl ChannelManager {
pub async fn new(
config: ConfigManager,
secrets: SecretsManager,
greentic_id: String,
store: SessionStore,
log_config: LogConfig,
) -> Result<Arc<Self>> {
let me = Arc::new(Self {
config,
secrets,
greentic_id,
store,
channels: Arc::new(DashMap::new()),
log_config,
incoming_subscribers: Arc::new(Mutex::new(vec![])),
});
Ok(me)
}
pub async fn diagnostics(&self) -> DashMap<String, ChannelState> {
let results = stream::iter(self.channels.iter())
.then(|kv| {
let key = kv.key().clone();
let wrapper = kv.value().wrapper.clone();
async move {
let state = wrapper.state().await;
(key, state)
}
})
.collect::<Vec<_>>()
.await;
results.into_iter().collect()
}
pub fn session_store(&self) -> SessionStore {
self.store.clone()
}
pub fn subscribe_incoming(&self, h: Arc<dyn IncomingHandler>) {
self.incoming_subscribers.lock().unwrap().push(h);
}
pub async fn register_channel(
&self,
name: String,
wrapper: ManagedChannel,
) -> Result<(), PluginError> {
self.channels.insert(name, wrapper);
Ok(())
}
pub async fn unload_channel(&self, name: &str) -> Result<(), PluginError> {
if let Some((_, mut wrapper)) = self.channels.remove(name) {
wrapper.wrapper.stop().await?;
}
Ok(())
}
pub async fn start_channel(&self, name: &str) -> Result<(), PluginError> {
if let Some(mut entry) = self.channels.get_mut(name) {
if entry.value_mut().wrapper.state().await == ChannelState::RUNNING {
info!(
"Ignoring start_channel {} because it is already staretd.",
name
);
return Ok(()); }
let config = self.config.0.as_ref().as_vec().await;
let secrets = self.secrets.0.as_ref().as_vec().await;
entry.value_mut().wrapper.start(config, secrets).await?;
}
Ok(())
}
pub fn channel(&self, name: &str) -> Option<PluginWrapper> {
if let Some(entry) = self.channels.get(name) {
Some(entry.value().wrapper.clone())
} else {
None
}
}
pub async fn stop_all(&self) -> Result<(), PluginError> {
for channel in self.channels.iter() {
let _ = self.stop_channel(&channel.wrapper().name()).await;
}
Ok(())
}
pub async fn stop_channel(&self, name: &str) -> Result<(), PluginError> {
if let Some(mut entry) = self.channels.get_mut(name) {
entry.value_mut().wrapper.stop().await?;
}
Ok(())
}
pub async fn send_to_channel(
&self,
name: &str,
msg: ChannelMessage,
) -> Result<(), PluginError> {
if let Some(mut wrapper) = self.channels.get_mut(name) {
wrapper.wrapper.send_message(msg).await
} else {
Err(PluginError::Other(format!("channel `{}` not loaded", name)))
}
}
pub fn list_channels(&self) -> Vec<String> {
self.channels.iter().map(|kv| kv.key().clone()).collect()
}
pub fn channels(&self) -> Arc<DashMap<String, ManagedChannel>> {
self.channels.clone()
}
pub async fn start_all(
self: Arc<Self>,
plugins_dir: PathBuf,
remote_channels: Vec<String>,
) -> Result<DirectoryWatcher, Error> {
for channel in remote_channels
{
let session_store = self.session_store();
let log_config = self.log_config.clone();
let greentic_id = self.greentic_id.clone();
let client = ChannelClient::new_pubsub(channel.clone(), greentic_id.clone()).await;
let control = ControlClient::new_pubsub(channel.clone(),greentic_id).await;
if client.is_err() {
let err = client.unwrap_err();
error!("Cannot create remote channel `{}` because {}", channel, err.to_string());
return Err(err);
} else if control.is_err() {
let err = control.unwrap_err();
error!("Cannot create remote channel `{}` because {}", channel, err.to_string());
return Err(err);
} else {
let plugin = PluginHandle::new(client.unwrap(), control.unwrap()).await;
let plugin_wrapper = PluginWrapper::new(plugin, session_store, log_config).await;
let wrapper = ManagedChannel::new(plugin_wrapper.clone(), None, None);
let _ = self.register_channel(channel, wrapper).await;
}
}
let watcher = Arc::new(PluginWatcher::new(plugins_dir.clone()).await);
watcher
.subscribe(self.clone() as Arc<dyn PluginEventHandler>, false)
.await;
match watcher.watch().await {
Ok(handle) => Ok(handle),
Err(err) => {
let error = format!(
"Could not watch the channel plugins at {}",
plugins_dir.to_string_lossy()
);
error!(error);
Err(err)
}
}
}
pub fn shutdown_all(&self, graceful: bool, timeout_ms: u64) {
for kv in self.channels.iter() {
let mut w = kv.value().wrapper.clone();
if graceful {
let _ = w.drain();
} else {
let _ = w.stop();
}
}
if graceful {
for kv in self.channels.iter() {
let mut w = kv.value().wrapper.clone();
let _ = w.wait_until_drained(timeout_ms);
}
}
}
}
#[async_trait]
impl PluginEventHandler for ChannelManager {
async fn plugin_added_or_reloaded(
&self,
name: &str,
plugin: PluginHandle,
) -> Result<(), Error> {
info!("Channel plugin added/reloaded: {}", name);
if let Some(mut old_plugin) = self.channels.get_mut(name) {
let mut wrapper = old_plugin.wrapper().clone();
if wrapper.stop().await.is_err() {
info!("Could not stop the existing plugin {} ", name);
}
old_plugin.cancel.as_ref().map(|tok| tok.cancel());
old_plugin.poller.as_ref().map(|poller| poller.abort());
if let Err(e) = old_plugin.wrapper.stop().await {
info!("Could not stop existing plugin `{}`: {:?}", name, e);
}
drop(old_plugin);
self.channels.remove(name);
info!("— replaced existing channel `{}`", name);
}
let wrapper = PluginWrapper::new(plugin, self.store.clone(), self.log_config.clone()).await;
let mut wrapper_cloned = wrapper.clone();
let plugin_name = name.to_string();
let config = self.config.0.as_vec().await;
let secrets = self.secrets.0.as_vec().await;
match wrapper_cloned.start(config, secrets).await {
Ok(()) => tracing::info!("Plugin `{}` started", plugin_name),
Err(e) => tracing::error!("Failed to start `{}`: {:?}", plugin_name, e),
}
let caps = wrapper.capabilities().await;
if caps.supports_receiving {
let channel_name = name.to_string();
let subs = self.incoming_subscribers.clone();
let cancel_token = CancellationToken::new();
let poller_cancel = cancel_token.clone();
let poller_wrapper = wrapper.clone();
let store = self.store.clone();
let poller = tokio::spawn(async move {
loop {
let store = store.clone();
if poller_cancel.is_cancelled() {
break;
}
let mut w = poller_wrapper.clone();
let poll_result = w.receive_message().await;
match poll_result {
Ok(mut msg) => {
msg.channel = channel_name.clone();
let handlers = {
let guard = subs.lock().unwrap();
guard.clone()
};
for h in handlers {
let m = msg.clone();
let store = store.clone();
tokio::spawn(async move {
let _ = h.handle_incoming(m, store.clone()).await;
});
}
}
Err(err) => {
tracing::warn!(%channel_name, ?err, "plugin.receive_message() returned error");
}
}
}
});
self.channels.insert(
name.to_string(),
ManagedChannel {
wrapper,
cancel: Some(cancel_token),
poller: Some(poller),
},
);
} else {
self.channels.insert(
name.to_string(),
ManagedChannel {
wrapper,
cancel: None,
poller: None,
},
);
}
Ok(())
}
async fn plugin_removed(&self, name: &str) -> Result<(), Error> {
if let Some(mut old_plugin) = self.channels.get_mut(name) {
let mut wrapper = old_plugin.wrapper().clone();
let drain_result = wrapper.drain().await;
if drain_result.is_err() {
info!("Could not start draining the existing plugin {} ", name);
}
let is_drained_result = wrapper.wait_until_drained(3000).await;
if is_drained_result.is_err() {
info!("Could not drain the existing plugin {} ", name);
}
old_plugin.cancel().as_ref().map(|tok| tok.cancel());
old_plugin.poller().as_ref().map(|poller| poller.abort());
if let Err(e) = old_plugin.wrapper.stop().await {
info!("Could not stop existing plugin `{}`: {:?}", name, e);
}
drop(old_plugin);
self.channels.remove(name);
info!("— replaced existing channel `{}`", name);
} else {
warn!("Tried to remove unknown channel plugin: {}", name);
}
Ok(())
}
}
/// A loaded channel plugin plus the optional background task that receives from it.
///
/// Within this module `cancel` and `poller` are set together when a receive loop was
/// spawned for the plugin, and are both `None` otherwise.
#[derive(Debug)]
pub struct ManagedChannel {
    /// Handle to the plugin.
    wrapper: PluginWrapper,
    /// Token checked by the receive loop between receives; cancelling it asks the loop to exit.
    cancel: Option<CancellationToken>,
    /// Join handle of the receive-loop task, used to abort it.
    poller: Option<JoinHandle<()>>,
}
impl ManagedChannel {
    /// Bundles a plugin wrapper with its optional receive-loop cancel token and task handle.
    pub fn new(
        wrapper: PluginWrapper,
        cancel: Option<CancellationToken>,
        poller: Option<JoinHandle<()>>,
    ) -> Self {
        Self {
            wrapper,
            cancel,
            poller,
        }
    }

    /// Cancellation token of the receive loop, if this channel has one.
    fn cancel(&self) -> &Option<CancellationToken> {
        &self.cancel
    }

    /// Join handle of the receive-loop task, if this channel has one.
    ///
    /// Takes `&self` because it only reads the field, so it is callable through
    /// shared references too.
    fn poller(&self) -> &Option<JoinHandle<()>> {
        &self.poller
    }

    /// The wrapped plugin.
    pub fn wrapper(&self) -> &PluginWrapper {
        &self.wrapper
    }
}
impl fmt::Debug for ChannelManager {
    /// Shows the config and secrets managers, the registered channel names and the number
    /// of incoming subscribers. Other fields are omitted (signalled by `..`).
    ///
    /// `incoming_subscriber_count` is `None` when the subscriber lock is currently held
    /// elsewhere, instead of a misleading numeric sentinel.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let channel_names: Vec<String> =
            self.channels.iter().map(|kv| kv.key().clone()).collect();
        // `try_lock` so formatting never blocks or deadlocks when the caller holds the lock.
        // A poisoned mutex still guards a valid Vec, so its length is reported anyway.
        let subscriber_count: Option<usize> = match self.incoming_subscribers.try_lock() {
            Ok(subscribers) => Some(subscribers.len()),
            Err(std::sync::TryLockError::Poisoned(poisoned)) => Some(poisoned.into_inner().len()),
            Err(std::sync::TryLockError::WouldBlock) => None,
        };
        f.debug_struct("ChannelManager")
            .field("config", &self.config)
            .field("secrets", &self.secrets)
            .field("channel_names", &channel_names)
            .field("incoming_subscriber_count", &subscriber_count)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        config::MapConfigManager, flow::session::InMemorySessionStore, secret::TestSecretsManager,
    };
    use channel_plugin::plugin_test_util::spawn_mock_handle;

    impl ChannelManager {
        /// Test helper: a manager built from the in-crate test doubles (`MapConfigManager`,
        /// `TestSecretsManager`, `InMemorySessionStore::new(10)`) with greentic id `"123"`,
        /// no channels and no subscribers.
        pub fn dummy() -> Arc<Self> {
            let manager = ChannelManager {
                greentic_id: String::from("123"),
                config: ConfigManager(MapConfigManager::new()),
                secrets: SecretsManager(TestSecretsManager::new()),
                store: InMemorySessionStore::new(10),
                channels: Arc::new(DashMap::new()),
                log_config: LogConfig::default(),
                incoming_subscribers: Arc::new(Mutex::new(Vec::new())),
            };
            Arc::new(manager)
        }
    }

    /// Registering a channel makes it listed; unloading it removes it again.
    #[tokio::test]
    async fn test_register_and_unload() {
        let secrets = SecretsManager(TestSecretsManager::new());
        let config = ConfigManager(MapConfigManager::new());
        let session_store = InMemorySessionStore::new(10);
        let manager = ChannelManager::new(
            config,
            secrets,
            String::from("123"),
            session_store.clone(),
            LogConfig::default(),
        )
        .await
        .expect("failed to create ChannelManager");

        // `_mock` must stay bound so the mock plugin lives for the whole test.
        let (_mock, handle) = spawn_mock_handle().await;
        let plugin = PluginWrapper::new(handle, session_store, LogConfig::default()).await;
        manager
            .register_channel(String::from("foo"), ManagedChannel::new(plugin, None, None))
            .await
            .expect("register_channel failed");
        assert_eq!(manager.list_channels(), vec![String::from("foo")]);

        manager
            .unload_channel("foo")
            .await
            .expect("unload_channel failed");
        assert!(manager.list_channels().is_empty());
    }
}