use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use dashmap::DashMap;
use futures::{Stream, StreamExt};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::{Mutex, mpsc};
use crate::cli::command::{
AgentArguments, CommandExecutor, CommandRequest,
CommandResponse as CommandResponseTrait,
};
use crate::cli::plugins::{Command, CommandType, Output};
/// Command executor that exchanges newline-delimited JSON over the process's
/// standard streams.
///
/// Requests are written to stdout, one JSON document per line. A background
/// task started by `new` reads lines from stdin and routes each one to the
/// caller waiting on the matching request id.
pub struct PluginExecutor {
/// Shared stdout handle. `execute` holds the async mutex while it writes and
/// flushes one request line.
stdout: Arc<Mutex<tokio::io::Stdout>>,
/// Source of request ids: each `execute` call takes the next value and uses
/// its decimal string form as the id.
counter: AtomicU64,
/// In-flight requests: request id -> sender feeding that request's response stream.
pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
/// `true` until the stdin listener task exits; once `false`, `execute`
/// returns `Error::Closed`.
listener_alive: Arc<AtomicBool>,
}
impl Default for PluginExecutor {
fn default() -> Self {
Self::new()
}
}
impl PluginExecutor {
    /// Creates an executor bound to the process's stdin and stdout.
    ///
    /// The stdin listener task is spawned right away via `tokio::spawn`, so
    /// this must be called from inside a Tokio runtime.
    pub fn new() -> Self {
        let listener_alive = Arc::new(AtomicBool::new(true));
        // Key/value types are inferred from `spawn_listener`'s signature.
        let pending = Arc::new(DashMap::new());

        Self::spawn_listener(
            tokio::io::stdin(),
            Arc::clone(&pending),
            Arc::clone(&listener_alive),
        );

        Self {
            stdout: Arc::new(Mutex::new(tokio::io::stdout())),
            counter: AtomicU64::new(0),
            pending,
            listener_alive,
        }
    }

    /// Spawns the background task that reads response lines from `stdin` and
    /// dispatches them to the per-request senders registered in `pending`.
    ///
    /// The task runs until `next_line` reports end of input or an error. On
    /// exit it clears `listener_alive` and empties `pending`, which drops every
    /// sender and thereby ends all outstanding response streams.
    fn spawn_listener(
        stdin: tokio::io::Stdin,
        pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
        listener_alive: Arc<AtomicBool>,
    ) {
        tokio::spawn(async move {
            let mut reader = BufReader::new(stdin).lines();

            while let Ok(Some(raw)) = reader.next_line().await {
                // Lines that do not decode as a response envelope are skipped.
                let Ok(envelope) = serde_json::from_str::<CommandResponse>(&raw) else {
                    continue;
                };

                match envelope {
                    CommandResponse::Done { id, .. } => {
                        // Dropping the sender lets the receiver drain and finish.
                        pending.remove(&id);
                    }
                    CommandResponse::Value { id, value } => {
                        // The map guard is consumed by the closure, so it is
                        // released before `remove` touches the map again.
                        let receiver_gone = pending
                            .get(&id)
                            .is_some_and(|sender| sender.send(value).is_err());
                        if receiver_gone {
                            pending.remove(&id);
                        }
                    }
                }
            }

            // Flag first, then clear: `execute` registers its sender before
            // checking the flag, so it either sees `false` or has its sender
            // dropped by the `clear` below.
            listener_alive.store(false, Ordering::Release);
            pending.clear();
        });
    }
}
/// One decoded line received on stdin, correlated to a request by `id`.
///
/// The enum is `untagged`: serde selects the variant from the fields present,
/// trying `Done` before `Value` (declaration order).
#[derive(serde::Deserialize, Debug, Clone)]
#[serde(untagged)]
enum CommandResponse {
/// End-of-stream marker for request `id`; the stdin listener removes that
/// request's sender when it sees this.
// NOTE(review): the listener ignores the value of `done`, so `done: false`
// ends the stream exactly like `done: true` — confirm this is intended.
Done {
id: String,
// Must be present for a line to decode as `Done`, but the value is never
// read (the listener matches `Done { id, .. }`), hence the lint allowance.
#[allow(dead_code)]
done: bool,
},
/// A response payload for request `id`, forwarded undecoded to the waiting stream.
Value {
id: String,
value: serde_json::Value,
},
}
/// Errors returned by `PluginExecutor` and by the response streams it produces.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The stdin listener task has stopped, so no response can be delivered.
#[error("plugin executor stdin closed")]
Closed,
/// Writing or flushing the request line on stdout failed.
#[error("plugin executor io: {0}")]
Io(std::io::Error),
/// A response value decoded neither as a `crate::cli::Error` nor as the
/// requested item type.
#[error("plugin executor decode line: {0}")]
Json(serde_json::Error),
/// The response value decoded as a `crate::cli::Error` (presumably a failure
/// reported by the command itself — confirm against the sender side).
#[error("{0}")]
Cli(crate::cli::Error),
/// `execute_one` found the response stream finished before yielding any item.
#[error("plugin executor stream produced no items")]
Empty,
}
/// Shape of a single response value: either a CLI error or a successful `T`.
///
/// The enum is `untagged` and `Err` is declared first, so a value that
/// deserializes as `crate::cli::Error` is treated as an error even if it
/// would also deserialize as `T`.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Line<T> {
/// The value matched the `crate::cli::Error` shape.
Err(crate::cli::Error),
/// The value matched the expected payload type.
Ok(T),
}
impl<T> From<Line<T>> for Result<T, Error> {
fn from(line: Line<T>) -> Self {
match line {
Line::Err(e) => Err(Error::Cli(e)),
Line::Ok(t) => Ok(t),
}
}
}
impl CommandExecutor for PluginExecutor {
type Error = Error;
type Stream<T>
= Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>
where
T: Send + 'static;
async fn execute<R, T>(
&self,
request: R,
_agent_arguments: Option<&AgentArguments>,
) -> Result<Self::Stream<T>, Error>
where
R: CommandRequest + Send,
T: CommandResponseTrait + serde::de::DeserializeOwned + Send + 'static,
{
let id = self.counter.fetch_add(1, Ordering::Relaxed).to_string();
let (tx, rx) = mpsc::unbounded_channel::<serde_json::Value>();
self.pending.insert(id.clone(), tx);
if !self.listener_alive.load(Ordering::Acquire) {
self.pending.remove(&id);
return Err(Error::Closed);
}
let argv = request.into_command();
let envelope = Output::Command(Command {
r#type: CommandType::Command,
id: id.clone(),
command: argv.join(" "),
});
let line = serde_json::to_string(&envelope).expect("Output serializes");
{
let mut stdout = self.stdout.lock().await;
if let Err(e) = stdout.write_all(line.as_bytes()).await {
self.pending.remove(&id);
return Err(Error::Io(e));
}
if let Err(e) = stdout.write_all(b"\n").await {
self.pending.remove(&id);
return Err(Error::Io(e));
}
if let Err(e) = stdout.flush().await {
self.pending.remove(&id);
return Err(Error::Io(e));
}
}
let pending = self.pending.clone();
let stream = futures::stream::unfold(
(rx, id, pending),
|(mut rx, id, pending)| async move {
match rx.recv().await {
Some(value) => {
let item = match serde_json::from_value::<Line<T>>(value) {
Ok(line) => line.into(),
Err(e) => Err(Error::Json(e)),
};
Some((item, (rx, id, pending)))
}
None => {
pending.remove(&id);
None
}
}
},
);
Ok(Box::pin(stream))
}
async fn execute_one<R, T>(
&self,
request: R,
agent_arguments: Option<&AgentArguments>,
) -> Result<T, Error>
where
R: CommandRequest + Send,
T: CommandResponseTrait + serde::de::DeserializeOwned + Send + 'static,
{
let mut stream = self.execute::<R, T>(request, agent_arguments).await?;
stream.next().await.ok_or(Error::Empty)?
}
}