objectiveai_sdk/cli/command/command_executor/
plugin.rs1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4
5use dashmap::DashMap;
6use futures::{Stream, StreamExt};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::sync::{Mutex, mpsc};
9
10use crate::cli::command::{
11 AgentArguments, CommandExecutor, CommandRequest,
12 CommandResponse as CommandResponseTrait,
13};
14use crate::cli::plugins::{Command, CommandType, Output};
15
/// Executes commands by delegating them to a host process over stdio.
///
/// Requests are written to this process's stdout as single JSON lines;
/// responses are read back from stdin by a background listener task and
/// routed to the matching request by id. Cloning is cheap: every field is an
/// `Arc`, so clones share the same stdout handle, id counter, routing table
/// and listener state.
#[derive(Clone)]
pub struct PluginExecutor {
    // Serializes writes so concurrent requests never interleave their lines.
    stdout: Arc<Mutex<tokio::io::Stdout>>,
    // Source of request ids; each `execute` call takes the next value.
    counter: Arc<AtomicU64>,
    // Request id -> sender feeding that request's response stream. Removing an
    // entry drops the sender, which ends the stream once buffered values drain.
    pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
    // Cleared by the listener task when it stops reading stdin.
    listener_alive: Arc<AtomicBool>,
}
39
40impl Default for PluginExecutor {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl PluginExecutor {
47 pub fn new() -> Self {
49 let pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>> =
50 Arc::new(DashMap::new());
51 let listener_alive = Arc::new(AtomicBool::new(true));
52 Self::spawn_listener(
53 tokio::io::stdin(),
54 pending.clone(),
55 listener_alive.clone(),
56 );
57 Self {
58 stdout: Arc::new(Mutex::new(tokio::io::stdout())),
59 counter: Arc::new(AtomicU64::new(0)),
60 pending,
61 listener_alive,
62 }
63 }
64
65 fn spawn_listener(
66 stdin: tokio::io::Stdin,
67 pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
68 listener_alive: Arc<AtomicBool>,
69 ) {
70 tokio::spawn(async move {
71 let mut lines = BufReader::new(stdin).lines();
72 while let Ok(Some(line)) = lines.next_line().await {
73 let env = match serde_json::from_str::<CommandResponse>(&line) {
74 Ok(e) => e,
75 Err(_) => continue,
76 };
77 match env {
78 CommandResponse::Value { id, value } => {
79 if let Some(sender) = pending.get(&id) {
80 if sender.send(value).is_err() {
81 drop(sender);
82 pending.remove(&id);
83 }
84 }
85 }
86 CommandResponse::Done { id, .. } => {
87 pending.remove(&id);
88 }
89 }
90 }
91 listener_alive.store(false, Ordering::Release);
98 pending.clear();
99 });
100 }
101}
102
/// One response line read from the host on stdin.
///
/// Untagged: serde tries the variants in declaration order, so an object
/// carrying both `done` and `value` keys decodes as `Done`.
#[derive(serde::Deserialize, Debug, Clone)]
#[serde(untagged)]
enum CommandResponse {
    /// The request with this id has finished; no more values will follow.
    Done {
        id: String,
        // Never read. The field exists so that serde requires a `done` key
        // before choosing this variant.
        #[allow(dead_code)]
        done: bool,
    },
    /// One response item for the request with this id, still undecoded.
    Value {
        id: String,
        value: serde_json::Value,
    },
}
127
/// Errors produced by [`PluginExecutor`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// `execute` was called after the stdin listener task stopped.
    #[error("plugin executor stdin closed")]
    Closed,
    /// Writing the request line to stdout failed.
    #[error("plugin executor io: {0}")]
    Io(std::io::Error),
    /// A response value could not be decoded. Also returned when the request
    /// itself fails to serialize, despite the "decode line" wording.
    #[error("plugin executor decode line: {0}")]
    Json(serde_json::Error),
    /// An error reported by the host inside a response value.
    #[error("{0}")]
    Cli(crate::cli::Error),
    /// `execute_one`'s stream ended before yielding any item.
    #[error("plugin executor stream produced no items")]
    Empty,
}
144
/// A decoded response value: either a host-reported error or a `T`.
///
/// Untagged, so `Err` is tried first: any value that deserializes as a
/// `crate::cli::Error` is treated as an error even if it would also fit `T`.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Line<T> {
    Err(crate::cli::Error),
    Ok(T),
}
154
155impl<T> From<Line<T>> for Result<T, Error> {
156 fn from(line: Line<T>) -> Self {
157 match line {
158 Line::Err(e) => Err(Error::Cli(e)),
159 Line::Ok(t) => Ok(t),
160 }
161 }
162}
163
164impl CommandExecutor for PluginExecutor {
165 type Error = Error;
166 type Stream<T>
167 = Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>
168 where
169 T: Send + 'static;
170
171 async fn execute<R, T>(
172 &self,
173 request: R,
174 _agent_arguments: Option<&AgentArguments>,
178 ) -> Result<Self::Stream<T>, Error>
179 where
180 R: CommandRequest + Send + serde::Serialize,
181 T: CommandResponseTrait + serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
182 {
183 let id = self.counter.fetch_add(1, Ordering::Relaxed).to_string();
184 let (tx, rx) = mpsc::unbounded_channel::<serde_json::Value>();
185 self.pending.insert(id.clone(), tx);
186
187 if !self.listener_alive.load(Ordering::Acquire) {
198 self.pending.remove(&id);
199 return Err(Error::Closed);
200 }
201
202 let argv = vec![
206 "--request".to_string(),
207 serde_json::to_string(&request).map_err(Error::Json)?,
208 ];
209 let envelope = Output::Command(Command {
210 r#type: CommandType::Command,
211 id: id.clone(),
212 command: argv,
217 });
218 let line = serde_json::to_string(&envelope).expect("Output serializes");
219
220 {
221 let mut stdout = self.stdout.lock().await;
222 if let Err(e) = stdout.write_all(line.as_bytes()).await {
223 self.pending.remove(&id);
224 return Err(Error::Io(e));
225 }
226 if let Err(e) = stdout.write_all(b"\n").await {
227 self.pending.remove(&id);
228 return Err(Error::Io(e));
229 }
230 if let Err(e) = stdout.flush().await {
231 self.pending.remove(&id);
232 return Err(Error::Io(e));
233 }
234 }
235
236 let pending = self.pending.clone();
237 let stream = futures::stream::unfold(
238 (rx, id, pending),
239 |(mut rx, id, pending)| async move {
240 match rx.recv().await {
241 Some(value) => {
242 let item = match serde_json::from_value::<Line<T>>(value) {
243 Ok(line) => line.into(),
244 Err(e) => Err(Error::Json(e)),
245 };
246 Some((item, (rx, id, pending)))
247 }
248 None => {
249 pending.remove(&id);
250 None
251 }
252 }
253 },
254 );
255
256 Ok(Box::pin(stream))
257 }
258
259 async fn execute_one<R, T>(
260 &self,
261 request: R,
262 agent_arguments: Option<&AgentArguments>,
263 ) -> Result<T, Error>
264 where
265 R: CommandRequest + Send + serde::Serialize,
266 T: CommandResponseTrait + serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
267 {
268 let mut stream = self.execute::<R, T>(request, agent_arguments).await?;
269 stream.next().await.ok_or(Error::Empty)?
270 }
271}