use std::collections::HashMap;
use std::{panic, time::Instant};
use crate::message::*;
use crate::{
jsonrpc::{Id, Message, Request, Response},
plugin_actor::Method,
};
use anyhow::Result;
use async_trait::async_trait;
use dashmap::DashMap;
use serde_json::{Value, json};
use tokio::{
io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
sync::mpsc::{self, UnboundedSender},
time::sleep,
};
use tracing::{Dispatch, dispatcher, error, level_filters::LevelFilter};
use tracing_appender::rolling::daily;
use tracing_subscriber::Layer;
use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt;
use tracing_subscriber::{Registry, fmt};
/// Access to the two key/value stores a plugin keeps: one for configuration
/// values and one for secrets.
///
/// The default `PluginHandler` methods (`set_config`, `set_secrets`,
/// `get_config`, `get_secret`) read and write through these accessors.
pub trait HasStore {
/// Returns the concurrent map holding configuration values, keyed by name.
fn config_store(&self) -> &DashMap<String, String>;
/// Returns the concurrent map holding secret values, keyed by name.
fn secret_store(&self) -> &DashMap<String, String>;
}
/// Version string reported by the default `PluginHandler::version` implementation.
pub const VERSION: &str = "0.1.0";
#[async_trait]
pub trait PluginHandler: HasStore + Send + Sync + Clone + 'static {
async fn init(&mut self, params: InitParams) -> InitResult;
async fn start(&mut self, params: InitParams) -> InitResult {
let res = self.init_from_params(¶ms).await;
if !res.success {
return res;
}
let result = self.init(params).await;
result
}
async fn drain(&mut self) -> DrainResult;
async fn wait_until_drained(&self, params: WaitUntilDrainedParams) -> WaitUntilDrainedResult {
let deadline = Instant::now() + std::time::Duration::from_millis(params.timeout_ms);
loop {
let state = self.state().await.state;
if state == ChannelState::STOPPED {
return WaitUntilDrainedResult {
stopped: true,
error: false,
};
}
if Instant::now() >= deadline {
return WaitUntilDrainedResult {
stopped: false,
error: true,
};
}
sleep(std::time::Duration::from_millis(100)).await;
}
}
async fn stop(&mut self) -> StopResult;
async fn send_message(&mut self, params: MessageOutParams) -> MessageOutResult;
async fn receive_message(&mut self) -> MessageInResult;
async fn health(&self) -> HealthResult {
HealthResult {
healthy: true,
reason: None,
}
}
async fn version(&self) -> VersionResult {
VersionResult {
version: VERSION.to_string(),
}
}
async fn state(&self) -> StateResult;
async fn set_config(&mut self, p: SetConfigParams) -> SetConfigResult {
for (k, v) in p.config {
self.config_store().insert(k, v);
}
let required: std::collections::HashSet<_> = self
.list_config_keys()
.required_keys
.into_iter()
.map(|(k, _)| k)
.collect();
let missing: Vec<_> = required
.into_iter()
.filter(|k| !self.config_store().contains_key(k))
.collect();
if missing.is_empty() {
SetConfigResult {
success: true,
error: None,
}
} else {
let msg = format!("missing required config keys: {}", missing.join(", "));
SetConfigResult {
success: false,
error: Some(msg),
}
}
}
async fn set_secrets(&mut self, p: SetSecretsParams) -> SetSecretsResult {
for (k, v) in p.secrets {
self.secret_store().insert(k, v);
}
let required: std::collections::HashSet<_> = self
.list_secret_keys()
.required_keys
.into_iter()
.map(|(k, _)| k)
.collect();
let missing: Vec<_> = required
.into_iter()
.filter(|k| !self.secret_store().contains_key(k))
.collect();
if missing.is_empty() {
SetSecretsResult {
success: true,
error: None,
}
} else {
let msg = format!("missing required secret keys: {}", missing.join(", "));
SetSecretsResult {
success: false,
error: Some(msg),
}
}
}
fn get_config(&self, key: &str) -> Option<String> {
self.config_store() .get(key) .map(|guard| guard.value().clone())
}
fn get_secret(&self, key: &str) -> Option<String> {
self.secret_store() .get(key) .map(|guard| guard.value().clone())
}
fn name(&self) -> NameResult;
fn list_config_keys(&self) -> ListKeysResult;
fn list_secret_keys(&self) -> ListKeysResult;
fn capabilities(&self) -> CapabilitiesResult;
async fn init_from_params(&mut self, params: &InitParams) -> InitResult {
static LOG_INIT: std::sync::Once = std::sync::Once::new();
LOG_INIT.call_once(|| {
let result = panic::catch_unwind(|| {
let level = match params.log_level {
LogLevel::Trace => LevelFilter::TRACE,
LogLevel::Debug => LevelFilter::DEBUG,
LogLevel::Info => LevelFilter::INFO,
LogLevel::Warn => LevelFilter::WARN,
LogLevel::Error => LevelFilter::ERROR,
LogLevel::Critical => LevelFilter::ERROR,
};
let subscriber_dispatch: Dispatch = if let Some(dir) = ¶ms.log_dir {
std::fs::create_dir_all(dir).ok(); let file_app = daily(dir, "plugin.log");
Dispatch::new(
Registry::default().with(
fmt::layer()
.with_ansi(false)
.with_target(false)
.with_writer(file_app)
.with_filter(level),
),
)
} else {
panic!("❌ Logging requires a `log_dir`. None was provided.");
};
dispatcher::set_global_default(subscriber_dispatch)
.expect("failed to install tracing subscriber");
if cfg!(debug_assertions) && std::io::IsTerminal::is_terminal(&std::io::stdout()) {
tracing::warn!(
"‼️ A tracing layer is writing to STDOUT – \
this WILL break the JSON-RPC protocol."
);
}
});
if result.is_err() {
eprintln!("❌ Logging setup failed");
}
});
let res = self
.set_config(SetConfigParams {
config: params.config.clone(),
})
.await;
if !res.success {
return InitResult {
success: false,
error: res.error,
};
}
let res = self
.set_secrets(SetSecretsParams {
secrets: params.secrets.clone(),
})
.await;
if !res.success {
return InitResult {
success: false,
error: res.error,
};
}
InitResult {
success: true,
error: None,
}
}
}
pub async fn run<P: PluginHandler>(mut plugin: P) -> Result<()> {
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
tokio::spawn(async move {
let mut w = BufWriter::new(io::stdout());
while let Some(line) = rx.recv().await {
if let Err(e) = w.write_all(line.as_bytes()).await {
error!("stdout write error: {e}");
break; }
if w.flush().await.is_err() {
error!("stdout flush error");
break;
}
}
});
let mut plugin_clone = plugin.clone();
tokio::spawn(async move {
loop {
if plugin_clone.state().await.state == ChannelState::RUNNING {
let result = plugin_clone.receive_message().await;
match serde_json::to_value(&result) {
Ok(v) => {
let notif = Request::notification("messageIn", Some(v));
let mut w = BufWriter::new(io::stdout());
let msg = format!("{}\n", serde_json::to_string(¬if).unwrap());
if let Err(e) = w.write_all(msg.as_bytes()).await {
error!("stdout write error: {e}");
}
if w.flush().await.is_err() {
error!("stdout flush error");
}
}
Err(e) => {
error!("serde_json error serialising MessageInResult: {e}");
break;
}
}
} else {
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
}
}
});
let mut reader = BufReader::new(io::stdin());
let mut line = String::new();
while reader.read_line(&mut line).await? != 0 {
trim_newlines(&mut line);
if line.is_empty() {
continue;
}
match serde_json::from_str::<Message>(&line) {
Ok(Message::Request(req)) => handle_request(&mut plugin, req, &tx.clone()).await,
Ok(_) => { }
Err(e) => {
let err =
Response::fail(Id::Null, -32700, "Parse error", Some(json!(e.to_string())));
let _ = tx.send(format!("{}\n", serde_json::to_string(&err).unwrap()));
}
}
line.clear();
}
Ok(())
}
/// Strips every trailing `'\n'` and `'\r'` from `s` in place.
///
/// Line breaks in the middle of the string are left untouched.
fn trim_newlines(s: &mut String) {
    let kept_len = s
        .trim_end_matches(|c: char| c == '\n' || c == '\r')
        .len();
    s.truncate(kept_len);
}
async fn handle_request<P>(plugin: &mut P, req: Request, tx: &UnboundedSender<String>)
where
P: PluginHandler,
{
fn enqueue(tx: &UnboundedSender<String>, resp: Response) {
let _ = tx.send(format!("{}\n", serde_json::to_string(&resp).unwrap()));
}
match req.method.parse::<Method>() {
Ok(Method::Init) => {
if let Some(v) = req.params {
if let Ok(p) = serde_json::from_value::<InitParams>(v) {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.init(p).await)));
}
}
}
}
Ok(Method::Start) => {
if let Some(v) = req.params {
if let Ok(p) = serde_json::from_value::<InitParams>(v) {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.start(p).await)));
}
}
}
}
Ok(Method::Drain) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.drain().await)));
}
}
Ok(Method::Stop) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.stop().await)));
}
}
Ok(Method::MessageOut) => {
match serde_json::from_value::<MessageOutParams>(req.params.unwrap_or(Value::Null)) {
Ok(p) => {
let result = plugin.send_message(p).await;
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(result)));
}
}
Err(e) => {
if let Some(id) = req.id {
enqueue(
tx,
Response::fail(
id,
-32602,
"Invalid params",
Some(json!(e.to_string())),
),
);
}
}
}
}
Ok(Method::Name) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.name())));
}
}
Ok(Method::Health) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.health().await)));
}
}
Ok(Method::State) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.state().await)));
}
}
Ok(Method::Capabilities) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.capabilities())));
}
}
Ok(Method::ListConfigKeys) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.list_config_keys())));
}
}
Ok(Method::ListSecretKeys) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.list_secret_keys())));
}
}
Ok(Method::WaitUntilDrained) => {
if let Some(v) = req.params {
if let Some(id) = req.id {
if let Ok(p) = serde_json::from_value::<WaitUntilDrainedParams>(v) {
enqueue(
tx,
Response::success(id, json!(plugin.wait_until_drained(p).await)),
);
}
}
}
}
Ok(Method::SetConfig) => {
if let Some(v) = req.params {
if let Ok(p) = serde_json::from_value::<SetConfigParams>(v) {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.set_config(p).await)));
}
}
}
}
Ok(Method::SetSecrets) => {
if let Some(v) = req.params {
if let Ok(p) = serde_json::from_value::<SetSecretsParams>(v) {
if let Some(id) = req.id {
enqueue(
tx,
Response::success(id, json!(plugin.set_secrets(p).await)),
);
}
}
}
}
method => {
error!("Failed to implement method {:?} in handle_request", method);
if let Some(id) = req.id {
enqueue(tx, Response::fail(id, -32601, "Method not found", None));
}
}
}
}
/// Expands `{key}` placeholders in `template` with the matching entry of
/// `values`.
///
/// The template is scanned once from left to right. At every `{` the text up
/// to a following `}` is looked up in `values`; on a hit the whole placeholder
/// is replaced by the value, otherwise the `{` is copied through unchanged.
/// Placeholders whose key is not in `values` therefore stay as written.
///
/// Substituted text is never scanned again, so a value that itself contains
/// `{other_key}` is emitted literally and the result does not depend on the
/// iteration order of the map.
///
/// # Examples
///
/// ```ignore
/// let mut values = HashMap::new();
/// values.insert("name", "Bob");
/// assert_eq!(fill_dynamic_fields("Hi {name}!", &values), "Hi Bob!");
/// ```
pub fn fill_dynamic_fields(template: &str, values: &HashMap<&str, &str>) -> String {
    // No key is longer than this, which bounds how far past a `{` we must look.
    let longest_key = values.keys().map(|k| k.len()).max().unwrap_or(0);

    let mut out = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(open) = rest.find('{') {
        out.push_str(&rest[..open]);
        let after = &rest[open + 1..];

        // Try each `}` within reach as the placeholder's end; the first
        // candidate that names a known key wins. Both braces are ASCII, so
        // every slice boundary below is a valid char boundary.
        let hit = after
            .match_indices('}')
            .map(|(close, _)| close)
            .take_while(|&close| close <= longest_key)
            .find_map(|close| values.get(&after[..close]).map(|value| (close, *value)));

        match hit {
            Some((close, value)) => {
                out.push_str(value);
                rest = &after[close + 1..];
            }
            None => {
                // Not a known placeholder: keep the brace and move on.
                out.push('{');
                rest = after;
            }
        }
    }

    out.push_str(rest);
    out
}