use std::{panic, time::{Instant}};
use async_trait::async_trait;
use anyhow::Result;
use dashmap::DashMap;
use tracing::{dispatcher, error, level_filters::LevelFilter, Dispatch};
use tracing_appender::rolling::daily;
use tracing_subscriber::{fmt, Registry};
use crate::{jsonrpc::{Id, Message, Request, Response}, plugin_actor::Method};
use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt;
use tracing_subscriber::Layer;
use serde_json::{json, Value};
use tokio::{io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, sync::{mpsc::{self, UnboundedSender}}, time::sleep};
use crate::message::*;
/// Access to the key/value stores that back a plugin's configuration and secrets.
///
/// The default [`PluginHandler`] methods (`set_config`, `set_secrets`, `get_config`,
/// `get_secret`) read and write exclusively through these two accessors.
pub trait HasStore {
    /// The store holding configuration values, keyed by config key.
    fn config_store(&self) -> &DashMap<String, String>;

    /// The store holding secret values, keyed by secret key.
    fn secret_store(&self) -> &DashMap<String, String>;
}
/// SDK version string, reported by the default `PluginHandler::version` implementation.
pub const VERSION: &str = "0.1.0";
#[async_trait]
pub trait PluginHandler: HasStore + Send + Sync + Clone + 'static {
async fn init(&mut self, params: InitParams) -> InitResult;
async fn start(&mut self, params: InitParams) -> InitResult {
let res = self.init_from_params(¶ms).await;
if !res.success {
return res;
}
let result = self.init(params).await;
result
}
async fn drain(&mut self) -> DrainResult;
async fn wait_until_drained(&self, params: WaitUntilDrainedParams) -> WaitUntilDrainedResult {
let deadline = Instant::now() + std::time::Duration::from_millis(params.timeout_ms);
loop {
let state = self.state().await.state;
if state == ChannelState::STOPPED {
return WaitUntilDrainedResult{stopped: true, error: false};
}
if Instant::now() >= deadline {
return WaitUntilDrainedResult{stopped: false, error: true};
}
sleep(std::time::Duration::from_millis(100)).await;
}
}
async fn stop(&mut self) -> StopResult;
async fn send_message(&mut self, params: MessageOutParams) -> MessageOutResult;
async fn receive_message(&mut self) -> MessageInResult;
async fn health(&self) -> HealthResult {
HealthResult { healthy: true, reason: None }
}
async fn version(&self) -> VersionResult{
VersionResult{version: VERSION.to_string()}
}
async fn state(&self) -> StateResult;
async fn set_config(&mut self, p: SetConfigParams) -> SetConfigResult {
for (k, v) in p.config {
self.config_store().insert(k, v);
}
let required: std::collections::HashSet<_> =
self.list_config_keys().required_keys.into_iter().map(|(k, _)| k).collect();
let missing: Vec<_> = required
.into_iter()
.filter(|k| !self.config_store().contains_key(k))
.collect();
if missing.is_empty() {
SetConfigResult { success: true, error: None }
} else {
let msg = format!("missing required config keys: {}", missing.join(", "));
SetConfigResult { success: false, error: Some(msg) }
}
}
async fn set_secrets(&mut self, p: SetSecretsParams) -> SetSecretsResult {
for (k, v) in p.secrets {
self.secret_store().insert(k, v);
}
let required: std::collections::HashSet<_> =
self.list_secret_keys().required_keys.into_iter().map(|(k, _)| k).collect();
let missing: Vec<_> = required
.into_iter()
.filter(|k| !self.secret_store().contains_key(k))
.collect();
if missing.is_empty() {
SetSecretsResult { success: true, error: None }
} else {
let msg = format!("missing required secret keys: {}", missing.join(", "));
SetSecretsResult { success: false, error: Some(msg) }
}
}
fn get_config(&self, key: &str) -> Option<String> {
self.config_store() .get(key) .map(|guard| guard.value().clone())
}
fn get_secret(&self, key: &str) -> Option<String> {
self.secret_store() .get(key) .map(|guard| guard.value().clone())
}
fn name(&self) -> NameResult;
fn list_config_keys(&self) -> ListKeysResult;
fn list_secret_keys(&self) -> ListKeysResult;
fn capabilities(&self) -> CapabilitiesResult;
async fn init_from_params(&mut self, params: &InitParams) -> InitResult{
static LOG_INIT: std::sync::Once = std::sync::Once::new();
LOG_INIT.call_once(|| {
let result = panic::catch_unwind(|| {
let level = match params.log_level {
LogLevel::Trace => LevelFilter::TRACE,
LogLevel::Debug => LevelFilter::DEBUG,
LogLevel::Info => LevelFilter::INFO,
LogLevel::Warn => LevelFilter::WARN,
LogLevel::Error => LevelFilter::ERROR,
LogLevel::Critical => LevelFilter::ERROR,
};
let subscriber_dispatch: Dispatch = if let Some(dir) = ¶ms.log_dir {
std::fs::create_dir_all(dir).ok(); let file_app = daily(dir, "plugin.log");
Dispatch::new(
Registry::default()
.with(
fmt::layer()
.with_ansi(false)
.with_target(false)
.with_writer(file_app)
.with_filter(level),
),
)
} else {
panic!("❌ Logging requires a `log_dir`. None was provided.");
};
dispatcher::set_global_default(subscriber_dispatch)
.expect("failed to install tracing subscriber");
if cfg!(debug_assertions) && std::io::IsTerminal::is_terminal(&std::io::stdout()) {
tracing::warn!(
"‼️ A tracing layer is writing to STDOUT – \
this WILL break the JSON-RPC protocol."
);
}
});
if result.is_err() {
eprintln!("❌ Logging setup failed");
}
});
let res = self
.set_config(SetConfigParams { config: params.config.clone() })
.await;
if !res.success {
return InitResult{ success: false, error: res.error }
}
let res = self.set_secrets(SetSecretsParams{secrets:params.secrets.clone()}).await;
if !res.success {
return InitResult{ success: false, error: res.error }
}
InitResult{ success: true, error: None }
}
}
pub async fn run<P: PluginHandler>(mut plugin: P) -> Result<()> {
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
tokio::spawn(async move {
let mut w = BufWriter::new(io::stdout());
while let Some(line) = rx.recv().await {
if let Err(e) = w.write_all(line.as_bytes()).await {
error!("stdout write error: {e}");
break; }
if w.flush().await.is_err() {
error!("stdout flush error");
break;
}
}
});
let mut plugin_clone = plugin.clone();
tokio::spawn(async move {
loop {
let result = plugin_clone.receive_message().await;
match serde_json::to_value(&result) {
Ok(v) => {
let notif = Request::notification("messageIn", Some(v));
let mut w = BufWriter::new(io::stdout());
let msg = format!("{}\n", serde_json::to_string(¬if).unwrap());
if let Err(e) = w.write_all(msg.as_bytes()).await {
error!("stdout write error: {e}");
}
if w.flush().await.is_err() {
error!("stdout flush error");
}
}
Err(e) => {
error!("serde_json error serialising MessageInResult: {e}");
break;
}
}
}
});
let mut reader = BufReader::new(io::stdin());
let mut line = String::new();
while reader.read_line(&mut line).await? != 0 {
trim_newlines(&mut line);
if line.is_empty() { continue; }
match serde_json::from_str::<Message>(&line) {
Ok(Message::Request(req)) => {
handle_request(&mut plugin, req, &tx.clone()).await
}
Ok(_) => { }
Err(e) => {
let err = Response::fail(Id::Null, -32700, "Parse error", Some(json!(e.to_string())));
let _ = tx.send(format!("{}\n", serde_json::to_string(&err).unwrap()));
}
}
line.clear();
}
Ok(())
}
/// Removes every trailing `'\n'` and `'\r'` from `s` in place.
///
/// Other trailing whitespace (spaces, tabs) is left untouched.
fn trim_newlines(s: &mut String) {
    let kept = s.trim_end_matches(|c: char| matches!(c, '\n' | '\r')).len();
    s.truncate(kept);
}
async fn handle_request<P>(
plugin: &mut P,
req: Request,
tx: &UnboundedSender<String>,
)
where
P: PluginHandler,
{
fn enqueue(tx: &UnboundedSender<String>, resp: Response) {
let _ = tx.send(format!("{}\n", serde_json::to_string(&resp).unwrap()));
}
match req.method.parse::<Method>() {
Ok(Method::Init) => {
if let Some(v) = req.params {
if let Ok(p) = serde_json::from_value::<InitParams>(v) {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.init(p).await)));
}
}
}
}
Ok(Method::Start) => {
if let Some(v) = req.params {
if let Ok(p) = serde_json::from_value::<InitParams>(v) {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.start(p).await)));
}
}
}
}
Ok(Method::Drain) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.drain().await)));
}
}
Ok(Method::Stop) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.stop().await)));
}
}
Ok(Method::MessageOut) => {
match serde_json::from_value::<MessageOutParams>(req.params.unwrap_or(Value::Null)) {
Ok(p) => {
let result = plugin.send_message(p).await;
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(result)));
}
}
Err(e) => {
if let Some(id) = req.id {
enqueue(tx,
Response::fail(id, -32602, "Invalid params", Some(json!(e.to_string())))
);
}
}
}
}
Ok(Method::Name) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.name())));
}
}
Ok(Method::Health) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.health().await)));
}
}
Ok(Method::State) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.state().await)));
}
}
Ok(Method::Capabilities) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.capabilities())));
}
}
Ok(Method::ListConfigKeys) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.list_config_keys())));
}
}
Ok(Method::ListSecretKeys) => {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.list_secret_keys())));
}
}
Ok(Method::WaitUntilDrained) => {
if let Some(v) = req.params {
if let Some(id) = req.id {
if let Ok(p) = serde_json::from_value::<WaitUntilDrainedParams>(v) {
enqueue(tx, Response::success(id, json!(plugin.wait_until_drained(p).await)));
}
}
}
}
Ok(Method::SetConfig) => {
if let Some(v) = req.params {
if let Ok(p) = serde_json::from_value::<SetConfigParams>(v) {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.set_config(p).await)));
}
}
}
}
Ok(Method::SetSecrets) => {
if let Some(v) = req.params {
if let Ok(p) = serde_json::from_value::<SetSecretsParams>(v) {
if let Some(id) = req.id {
enqueue(tx, Response::success(id, json!(plugin.set_secrets(p).await)));
}
}
}
}
method => {
error!("Failed to implement method {:?} in handle_request",method);
if let Some(id) = req.id {
enqueue(tx, Response::fail(id, -32601, "Method not found", None));
}
}
}
}