objectiveai_sdk/cli/command/command_executor/
plugin.rs1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4
5use dashmap::DashMap;
6use futures::{Stream, StreamExt};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::sync::{Mutex, mpsc};
9
10use crate::cli::command::{
11 AgentArguments, CommandExecutor, CommandRequest,
12 CommandResponse as CommandResponseTrait,
13};
14use crate::cli::plugins::{Command, CommandType, Output};
15
/// [`CommandExecutor`] that forwards commands over this process's stdio.
///
/// Each request is written to stdout as a single JSON line (an
/// [`Output::Command`] envelope carrying a unique id). A background task reads
/// response lines from stdin and routes each one, by id, to the stream of the
/// request it belongs to.
pub struct PluginExecutor {
    /// Shared stdout handle. The lock is held while a request line, its
    /// newline and the flush are written, so concurrent requests cannot
    /// interleave bytes.
    stdout: Arc<Mutex<tokio::io::Stdout>>,
    /// Source of request ids; incremented once per `execute` call.
    counter: AtomicU64,
    /// In-flight requests: request id -> sender feeding that request's
    /// response stream. Shared with the stdin listener task.
    pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
    /// `true` while the stdin listener task is still reading. Set to `false`
    /// when its read loop ends; `execute` then rejects new requests.
    listener_alive: Arc<AtomicBool>,
}
35
36impl Default for PluginExecutor {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl PluginExecutor {
43 pub fn new() -> Self {
45 let pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>> =
46 Arc::new(DashMap::new());
47 let listener_alive = Arc::new(AtomicBool::new(true));
48 Self::spawn_listener(
49 tokio::io::stdin(),
50 pending.clone(),
51 listener_alive.clone(),
52 );
53 Self {
54 stdout: Arc::new(Mutex::new(tokio::io::stdout())),
55 counter: AtomicU64::new(0),
56 pending,
57 listener_alive,
58 }
59 }
60
61 fn spawn_listener(
62 stdin: tokio::io::Stdin,
63 pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
64 listener_alive: Arc<AtomicBool>,
65 ) {
66 tokio::spawn(async move {
67 let mut lines = BufReader::new(stdin).lines();
68 while let Ok(Some(line)) = lines.next_line().await {
69 let env = match serde_json::from_str::<CommandResponse>(&line) {
70 Ok(e) => e,
71 Err(_) => continue,
72 };
73 match env {
74 CommandResponse::Value { id, value } => {
75 if let Some(sender) = pending.get(&id) {
76 if sender.send(value).is_err() {
77 drop(sender);
78 pending.remove(&id);
79 }
80 }
81 }
82 CommandResponse::Done { id, .. } => {
83 pending.remove(&id);
84 }
85 }
86 }
87 listener_alive.store(false, Ordering::Release);
94 pending.clear();
95 });
96 }
97}
98
/// One JSON line read from stdin by the listener task, tied to a request by
/// its `id`.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so `Done`
/// is attempted first. Any object with a string `id` and a boolean `done`
/// decodes as `Done`, even if it also carries a `value`, because unknown
/// fields are ignored. Reordering the variants changes that outcome.
///
/// NOTE(review): the `done` flag is never inspected, so `"done": false` would
/// also end the request's stream. Confirm the plugin protocol only sends
/// `done: true`.
#[derive(serde::Deserialize, Debug, Clone)]
#[serde(untagged)]
enum CommandResponse {
    /// End of the response stream for request `id`.
    Done {
        id: String,
        // Required field: a line without `done` fails this variant and falls
        // through to `Value`. The value itself is not read.
        #[allow(dead_code)]
        done: bool,
    },
    /// One response payload for request `id`.
    Value {
        id: String,
        value: serde_json::Value,
    },
}
123
/// Errors produced by [`PluginExecutor`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The stdin listener has stopped (its read loop ended on EOF or a read
    /// error), so no responses can arrive. `execute` returns this before
    /// writing anything.
    #[error("plugin executor stdin closed")]
    Closed,
    /// An I/O failure. In this file it is raised when writing or flushing
    /// the request line to stdout fails.
    #[error("plugin executor io: {0}")]
    Io(std::io::Error),
    /// A response value decoded neither as the expected response type nor as
    /// a CLI error.
    #[error("plugin executor decode line: {0}")]
    Json(serde_json::Error),
    /// A response value decoded as a `crate::cli::Error`. Its own message is
    /// shown unchanged.
    #[error("{0}")]
    Cli(crate::cli::Error),
    /// `execute_one` received a stream that ended without yielding any item.
    #[error("plugin executor stream produced no items")]
    Empty,
}
140
/// Payload of a single response value: either an error reported by the CLI
/// or a successful response of type `T`.
///
/// Untagged variants are tried in order. A value that deserializes as
/// `crate::cli::Error` is therefore treated as an error even if it would also
/// decode as `T`, and reordering the variants changes which interpretation
/// wins.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Line<T> {
    Err(crate::cli::Error),
    Ok(T),
}
150
151impl<T> From<Line<T>> for Result<T, Error> {
152 fn from(line: Line<T>) -> Self {
153 match line {
154 Line::Err(e) => Err(Error::Cli(e)),
155 Line::Ok(t) => Ok(t),
156 }
157 }
158}
159
160impl CommandExecutor for PluginExecutor {
161 type Error = Error;
162 type Stream<T>
163 = Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>
164 where
165 T: Send + 'static;
166
167 async fn execute<R, T>(
168 &self,
169 request: R,
170 _agent_arguments: Option<&AgentArguments>,
174 ) -> Result<Self::Stream<T>, Error>
175 where
176 R: CommandRequest + Send,
177 T: CommandResponseTrait + serde::de::DeserializeOwned + Send + 'static,
178 {
179 let id = self.counter.fetch_add(1, Ordering::Relaxed).to_string();
180 let (tx, rx) = mpsc::unbounded_channel::<serde_json::Value>();
181 self.pending.insert(id.clone(), tx);
182
183 if !self.listener_alive.load(Ordering::Acquire) {
194 self.pending.remove(&id);
195 return Err(Error::Closed);
196 }
197
198 let argv = request.into_command();
199 let envelope = Output::Command(Command {
200 r#type: CommandType::Command,
201 id: id.clone(),
202 command: argv.join(" "),
203 });
204 let line = serde_json::to_string(&envelope).expect("Output serializes");
205
206 {
207 let mut stdout = self.stdout.lock().await;
208 if let Err(e) = stdout.write_all(line.as_bytes()).await {
209 self.pending.remove(&id);
210 return Err(Error::Io(e));
211 }
212 if let Err(e) = stdout.write_all(b"\n").await {
213 self.pending.remove(&id);
214 return Err(Error::Io(e));
215 }
216 if let Err(e) = stdout.flush().await {
217 self.pending.remove(&id);
218 return Err(Error::Io(e));
219 }
220 }
221
222 let pending = self.pending.clone();
223 let stream = futures::stream::unfold(
224 (rx, id, pending),
225 |(mut rx, id, pending)| async move {
226 match rx.recv().await {
227 Some(value) => {
228 let item = match serde_json::from_value::<Line<T>>(value) {
229 Ok(line) => line.into(),
230 Err(e) => Err(Error::Json(e)),
231 };
232 Some((item, (rx, id, pending)))
233 }
234 None => {
235 pending.remove(&id);
236 None
237 }
238 }
239 },
240 );
241
242 Ok(Box::pin(stream))
243 }
244
245 async fn execute_one<R, T>(
246 &self,
247 request: R,
248 agent_arguments: Option<&AgentArguments>,
249 ) -> Result<T, Error>
250 where
251 R: CommandRequest + Send,
252 T: CommandResponseTrait + serde::de::DeserializeOwned + Send + 'static,
253 {
254 let mut stream = self.execute::<R, T>(request, agent_arguments).await?;
255 stream.next().await.ok_or(Error::Empty)?
256 }
257}