objectiveai_sdk/cli/command/command_executor/
plugin.rs1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4
5use dashmap::DashMap;
6use futures::{Stream, StreamExt};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::sync::{Mutex, mpsc};
9
10use crate::cli::command::{
11 AgentArguments, CommandExecutor, CommandRequest,
12 CommandResponse as CommandResponseTrait,
13};
14use crate::cli::plugins::{Command, CommandType, Output};
15
16#[derive(Clone)]
28pub struct PluginExecutor {
29 stdout: Arc<Mutex<tokio::io::Stdout>>,
30 counter: Arc<AtomicU64>,
31 pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
32 listener_alive: Arc<AtomicBool>,
38}
39
40impl Default for PluginExecutor {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl PluginExecutor {
47 pub fn new() -> Self {
49 let pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>> =
50 Arc::new(DashMap::new());
51 let listener_alive = Arc::new(AtomicBool::new(true));
52 Self::spawn_listener(
53 tokio::io::stdin(),
54 pending.clone(),
55 listener_alive.clone(),
56 );
57 Self {
58 stdout: Arc::new(Mutex::new(tokio::io::stdout())),
59 counter: Arc::new(AtomicU64::new(0)),
60 pending,
61 listener_alive,
62 }
63 }
64
65 fn spawn_listener(
66 stdin: tokio::io::Stdin,
67 pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
68 listener_alive: Arc<AtomicBool>,
69 ) {
70 tokio::spawn(async move {
71 let mut lines = BufReader::new(stdin).lines();
72 while let Ok(Some(line)) = lines.next_line().await {
73 let env = match serde_json::from_str::<CommandResponse>(&line) {
74 Ok(e) => e,
75 Err(_) => continue,
76 };
77 match env {
78 CommandResponse::Value { id, value } => {
79 if let Some(sender) = pending.get(&id) {
80 if sender.send(value).is_err() {
81 drop(sender);
82 pending.remove(&id);
83 }
84 }
85 }
86 CommandResponse::Done { id, .. } => {
87 pending.remove(&id);
88 }
89 }
90 }
91 listener_alive.store(false, Ordering::Release);
98 pending.clear();
99 });
100 }
101}
102
103#[derive(serde::Deserialize, Debug, Clone)]
113#[serde(untagged)]
114enum CommandResponse {
115 Done {
118 id: String,
119 #[allow(dead_code)]
120 done: bool,
121 },
122 Value {
123 id: String,
124 value: serde_json::Value,
125 },
126}
127
128#[derive(Debug, thiserror::Error)]
129pub enum Error {
130 #[error("plugin executor stdin closed")]
133 Closed,
134 #[error("plugin executor io: {0}")]
135 Io(std::io::Error),
136 #[error("plugin executor decode line: {0}")]
137 Json(serde_json::Error),
138 #[error("{0}")]
139 Cli(crate::cli::Error),
140 #[error("plugin executor stream produced no items")]
142 Empty,
143}
144
145#[derive(serde::Deserialize)]
149#[serde(untagged)]
150enum Line<T> {
151 Err(crate::cli::Error),
152 Ok(T),
153}
154
155impl<T> From<Line<T>> for Result<T, Error> {
156 fn from(line: Line<T>) -> Self {
157 match line {
158 Line::Err(e) => Err(Error::Cli(e)),
159 Line::Ok(t) => Ok(t),
160 }
161 }
162}
163
164impl CommandExecutor for PluginExecutor {
165 type Error = Error;
166 type Stream<T>
167 = Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>
168 where
169 T: Send + 'static;
170
171 async fn execute<R, T>(
172 &self,
173 request: R,
174 _agent_arguments: Option<&AgentArguments>,
178 ) -> Result<Self::Stream<T>, Error>
179 where
180 R: CommandRequest + Send,
181 T: CommandResponseTrait + serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
182 {
183 let id = self.counter.fetch_add(1, Ordering::Relaxed).to_string();
184 let (tx, rx) = mpsc::unbounded_channel::<serde_json::Value>();
185 self.pending.insert(id.clone(), tx);
186
187 if !self.listener_alive.load(Ordering::Acquire) {
198 self.pending.remove(&id);
199 return Err(Error::Closed);
200 }
201
202 let argv = request.into_command();
203 let envelope = Output::Command(Command {
204 r#type: CommandType::Command,
205 id: id.clone(),
206 command: argv,
211 });
212 let line = serde_json::to_string(&envelope).expect("Output serializes");
213
214 {
215 let mut stdout = self.stdout.lock().await;
216 if let Err(e) = stdout.write_all(line.as_bytes()).await {
217 self.pending.remove(&id);
218 return Err(Error::Io(e));
219 }
220 if let Err(e) = stdout.write_all(b"\n").await {
221 self.pending.remove(&id);
222 return Err(Error::Io(e));
223 }
224 if let Err(e) = stdout.flush().await {
225 self.pending.remove(&id);
226 return Err(Error::Io(e));
227 }
228 }
229
230 let pending = self.pending.clone();
231 let stream = futures::stream::unfold(
232 (rx, id, pending),
233 |(mut rx, id, pending)| async move {
234 match rx.recv().await {
235 Some(value) => {
236 let item = match serde_json::from_value::<Line<T>>(value) {
237 Ok(line) => line.into(),
238 Err(e) => Err(Error::Json(e)),
239 };
240 Some((item, (rx, id, pending)))
241 }
242 None => {
243 pending.remove(&id);
244 None
245 }
246 }
247 },
248 );
249
250 Ok(Box::pin(stream))
251 }
252
253 async fn execute_one<R, T>(
254 &self,
255 request: R,
256 agent_arguments: Option<&AgentArguments>,
257 ) -> Result<T, Error>
258 where
259 R: CommandRequest + Send,
260 T: CommandResponseTrait + serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
261 {
262 let mut stream = self.execute::<R, T>(request, agent_arguments).await?;
263 stream.next().await.ok_or(Error::Empty)?
264 }
265}