Skip to main content

objectiveai_sdk/cli/command/command_executor/
plugin.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4
5use dashmap::DashMap;
6use futures::{Stream, StreamExt};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::sync::{Mutex, mpsc};
9
10use crate::cli::command::{
11    AgentArguments, CommandExecutor, CommandRequest,
12    CommandResponse as CommandResponseTrait,
13};
14use crate::cli::plugins::{Command, CommandType, Output};
15
16/// Demultiplex many in-flight `CommandRequest` calls over a plugin's
17/// stdin/stdout. Each `execute` mints a fresh id, emits a
18/// `Output::Command(Command { id, command })` line on the plugin's stdout,
19/// and returns a stream that yields whatever the overlord writes back
20/// to the plugin's stdin under the same id.
21///
22/// Only one instance per process — the constructor consumes the global
23/// `tokio::io::stdin()` / `stdout()` handles. The struct is [`Clone`]
24/// so callers that need a second handle can share without an outer
25/// `Arc`: every field is already behind `Arc`, including `counter`, so
26/// clones share the id sequence and pending map.
27#[derive(Clone)]
28pub struct PluginExecutor {
29    stdout: Arc<Mutex<tokio::io::Stdout>>,
30    counter: Arc<AtomicU64>,
31    pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
32    /// `true` while the listener task is still reading stdin. Flipped
33    /// to `false` immediately before the listener drops its pending
34    /// senders, so `execute()` can re-check after registering its own
35    /// sender and bail with `Error::Closed` instead of installing a
36    /// channel nothing will ever drain.
37    listener_alive: Arc<AtomicBool>,
38}
39
40impl Default for PluginExecutor {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl PluginExecutor {
47    /// Capture the plugin's stdin/stdout and spawn the demuxer task.
48    pub fn new() -> Self {
49        let pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>> =
50            Arc::new(DashMap::new());
51        let listener_alive = Arc::new(AtomicBool::new(true));
52        Self::spawn_listener(
53            tokio::io::stdin(),
54            pending.clone(),
55            listener_alive.clone(),
56        );
57        Self {
58            stdout: Arc::new(Mutex::new(tokio::io::stdout())),
59            counter: Arc::new(AtomicU64::new(0)),
60            pending,
61            listener_alive,
62        }
63    }
64
65    fn spawn_listener(
66        stdin: tokio::io::Stdin,
67        pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
68        listener_alive: Arc<AtomicBool>,
69    ) {
70        tokio::spawn(async move {
71            let mut lines = BufReader::new(stdin).lines();
72            while let Ok(Some(line)) = lines.next_line().await {
73                let env = match serde_json::from_str::<CommandResponse>(&line) {
74                    Ok(e) => e,
75                    Err(_) => continue,
76                };
77                match env {
78                    CommandResponse::Value { id, value } => {
79                        if let Some(sender) = pending.get(&id) {
80                            if sender.send(value).is_err() {
81                                drop(sender);
82                                pending.remove(&id);
83                            }
84                        }
85                    }
86                    CommandResponse::Done { id, .. } => {
87                        pending.remove(&id);
88                    }
89                }
90            }
91            // stdin EOF or read error. Flip the liveness flag BEFORE
92            // dropping any senders — `execute()` does an insert-then-
93            // re-check, and this ordering guarantees a concurrent
94            // `execute()` either sees the flag and removes its own
95            // entry, or completes its insert before `clear()` runs and
96            // gets drained by it.
97            listener_alive.store(false, Ordering::Release);
98            pending.clear();
99        });
100    }
101}
102
103/// One line the overlord writes to a plugin's stdin in response to a
104/// previously-emitted `Output::Command`.
105///
106/// Wire shape:
107/// - Value: `{"id":"42","value":<JSON>}`
108/// - Done:  `{"id":"42","done":true}`
109///
110/// `Done` signals end-of-stream for that id from the receiver's
111/// perspective — the request's stream ends right after.
112#[derive(serde::Deserialize, Debug, Clone)]
113#[serde(untagged)]
114enum CommandResponse {
115    /// Listed first so the untagged decoder tries it before `Value` —
116    /// the `done` discriminator field is what tells the variants apart.
117    Done {
118        id: String,
119        #[allow(dead_code)]
120        done: bool,
121    },
122    Value {
123        id: String,
124        value: serde_json::Value,
125    },
126}
127
128#[derive(Debug, thiserror::Error)]
129pub enum Error {
130    /// Stdin closed (clean EOF or read error). The listener task has
131    /// exited and no new requests can be served.
132    #[error("plugin executor stdin closed")]
133    Closed,
134    #[error("plugin executor io: {0}")]
135    Io(std::io::Error),
136    #[error("plugin executor decode line: {0}")]
137    Json(serde_json::Error),
138    #[error("{0}")]
139    Cli(crate::cli::Error),
140    /// `execute_one` was called but the stream produced no items.
141    #[error("plugin executor stream produced no items")]
142    Empty,
143}
144
145/// Per-value untagged decode. `Err` first so `cli::Error`'s `type:"error"`
146/// constant short-circuits non-error wire shapes; `Ok(T)` is the
147/// fallthrough. Mirrors the helper in `binary.rs`.
148#[derive(serde::Deserialize)]
149#[serde(untagged)]
150enum Line<T> {
151    Err(crate::cli::Error),
152    Ok(T),
153}
154
155impl<T> From<Line<T>> for Result<T, Error> {
156    fn from(line: Line<T>) -> Self {
157        match line {
158            Line::Err(e) => Err(Error::Cli(e)),
159            Line::Ok(t) => Ok(t),
160        }
161    }
162}
163
164impl CommandExecutor for PluginExecutor {
165    type Error = Error;
166    type Stream<T>
167        = Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>
168    where
169        T: Send + 'static;
170
171    async fn execute<R, T>(
172        &self,
173        request: R,
174        // Plugin runs in-process — no subprocess to stamp env on. The
175        // bag is accepted for trait-signature symmetry with the
176        // binary executor and intentionally ignored.
177        _agent_arguments: Option<&AgentArguments>,
178    ) -> Result<Self::Stream<T>, Error>
179    where
180        R: CommandRequest + Send,
181        T: CommandResponseTrait + serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
182    {
183        let id = self.counter.fetch_add(1, Ordering::Relaxed).to_string();
184        let (tx, rx) = mpsc::unbounded_channel::<serde_json::Value>();
185        self.pending.insert(id.clone(), tx);
186
187        // Re-check liveness AFTER insert. The listener stores `false`
188        // before it calls `pending.clear()`, so any of these happens:
189        //   - We see `true`: listener is still running; if it dies
190        //     later, its `clear()` will drop our sender and the stream
191        //     will end naturally.
192        //   - We see `false` and our entry got cleared: remove() is a
193        //     no-op, sender is already dropped.
194        //   - We see `false` and our entry survived (we inserted after
195        //     `clear()` ran): remove() drops the sender ourselves.
196        // In every `false` path we bail with `Closed`.
197        if !self.listener_alive.load(Ordering::Acquire) {
198            self.pending.remove(&id);
199            return Err(Error::Closed);
200        }
201
202        let argv = request.into_command();
203        let envelope = Output::Command(Command {
204            r#type: CommandType::Command,
205            id: id.clone(),
206            // Carry argv structured — joining into a single string
207            // would lose argument boundaries for any value containing
208            // whitespace (e.g. `--simple "a b c"`), which the host
209            // could not recover.
210            command: argv,
211        });
212        let line = serde_json::to_string(&envelope).expect("Output serializes");
213
214        {
215            let mut stdout = self.stdout.lock().await;
216            if let Err(e) = stdout.write_all(line.as_bytes()).await {
217                self.pending.remove(&id);
218                return Err(Error::Io(e));
219            }
220            if let Err(e) = stdout.write_all(b"\n").await {
221                self.pending.remove(&id);
222                return Err(Error::Io(e));
223            }
224            if let Err(e) = stdout.flush().await {
225                self.pending.remove(&id);
226                return Err(Error::Io(e));
227            }
228        }
229
230        let pending = self.pending.clone();
231        let stream = futures::stream::unfold(
232            (rx, id, pending),
233            |(mut rx, id, pending)| async move {
234                match rx.recv().await {
235                    Some(value) => {
236                        let item = match serde_json::from_value::<Line<T>>(value) {
237                            Ok(line) => line.into(),
238                            Err(e) => Err(Error::Json(e)),
239                        };
240                        Some((item, (rx, id, pending)))
241                    }
242                    None => {
243                        pending.remove(&id);
244                        None
245                    }
246                }
247            },
248        );
249
250        Ok(Box::pin(stream))
251    }
252
253    async fn execute_one<R, T>(
254        &self,
255        request: R,
256        agent_arguments: Option<&AgentArguments>,
257    ) -> Result<T, Error>
258    where
259        R: CommandRequest + Send,
260        T: CommandResponseTrait + serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
261    {
262        let mut stream = self.execute::<R, T>(request, agent_arguments).await?;
263        stream.next().await.ok_or(Error::Empty)?
264    }
265}