Skip to main content

objectiveai_sdk/cli/command/command_executor/
plugin.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4
5use dashmap::DashMap;
6use futures::{Stream, StreamExt};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::sync::{Mutex, mpsc};
9
10use crate::cli::command::{
11    AgentArguments, CommandExecutor, CommandRequest,
12    CommandResponse as CommandResponseTrait,
13};
14use crate::cli::plugins::{Command, CommandType, Output};
15
16/// Demultiplex many in-flight `CommandRequest` calls over a plugin's
17/// stdin/stdout. Each `execute` mints a fresh id, emits a
18/// `Output::Command(Command { id, command })` line on the plugin's stdout,
19/// and returns a stream that yields whatever the overlord writes back
20/// to the plugin's stdin under the same id.
21///
22/// Only one instance per process — the constructor consumes the global
23/// `tokio::io::stdin()` / `stdout()` handles.
24pub struct PluginExecutor {
25    stdout: Arc<Mutex<tokio::io::Stdout>>,
26    counter: AtomicU64,
27    pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
28    /// `true` while the listener task is still reading stdin. Flipped
29    /// to `false` immediately before the listener drops its pending
30    /// senders, so `execute()` can re-check after registering its own
31    /// sender and bail with `Error::Closed` instead of installing a
32    /// channel nothing will ever drain.
33    listener_alive: Arc<AtomicBool>,
34}
35
36impl Default for PluginExecutor {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl PluginExecutor {
43    /// Capture the plugin's stdin/stdout and spawn the demuxer task.
44    pub fn new() -> Self {
45        let pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>> =
46            Arc::new(DashMap::new());
47        let listener_alive = Arc::new(AtomicBool::new(true));
48        Self::spawn_listener(
49            tokio::io::stdin(),
50            pending.clone(),
51            listener_alive.clone(),
52        );
53        Self {
54            stdout: Arc::new(Mutex::new(tokio::io::stdout())),
55            counter: AtomicU64::new(0),
56            pending,
57            listener_alive,
58        }
59    }
60
61    fn spawn_listener(
62        stdin: tokio::io::Stdin,
63        pending: Arc<DashMap<String, mpsc::UnboundedSender<serde_json::Value>>>,
64        listener_alive: Arc<AtomicBool>,
65    ) {
66        tokio::spawn(async move {
67            let mut lines = BufReader::new(stdin).lines();
68            while let Ok(Some(line)) = lines.next_line().await {
69                let env = match serde_json::from_str::<CommandResponse>(&line) {
70                    Ok(e) => e,
71                    Err(_) => continue,
72                };
73                match env {
74                    CommandResponse::Value { id, value } => {
75                        if let Some(sender) = pending.get(&id) {
76                            if sender.send(value).is_err() {
77                                drop(sender);
78                                pending.remove(&id);
79                            }
80                        }
81                    }
82                    CommandResponse::Done { id, .. } => {
83                        pending.remove(&id);
84                    }
85                }
86            }
87            // stdin EOF or read error. Flip the liveness flag BEFORE
88            // dropping any senders — `execute()` does an insert-then-
89            // re-check, and this ordering guarantees a concurrent
90            // `execute()` either sees the flag and removes its own
91            // entry, or completes its insert before `clear()` runs and
92            // gets drained by it.
93            listener_alive.store(false, Ordering::Release);
94            pending.clear();
95        });
96    }
97}
98
/// One line the overlord writes to a plugin's stdin in response to a
/// previously-emitted `Output::Command`.
///
/// Wire shape:
/// - Value: `{"id":"42","value":<JSON>}`
/// - Done:  `{"id":"42","done":true}`
///
/// `Done` signals end-of-stream for that id from the receiver's
/// perspective — the request's stream ends right after.
///
/// Decoding is `#[serde(untagged)]`: variants are tried in declaration
/// order and the first one that matches wins, so the order below is
/// load-bearing.
///
/// NOTE(review): the value of `done` is never inspected, so
/// `{"id":"42","done":false}` also decodes as `Done` and ends the stream.
/// An object carrying both `done` and `value` decodes as `Done` too
/// (unknown fields are not denied). Confirm that is intended.
#[derive(serde::Deserialize, Debug, Clone)]
#[serde(untagged)]
enum CommandResponse {
    /// Listed first so the untagged decoder tries it before `Value` —
    /// the `done` discriminator field is what tells the variants apart.
    Done {
        id: String,
        // Must be present for this variant to match, but its value is never
        // read (the listener matches `Done { id, .. }`), hence the allow.
        #[allow(dead_code)]
        done: bool,
    },
    /// A payload for request `id`, forwarded as-is to that request's stream.
    Value {
        id: String,
        value: serde_json::Value,
    },
}
123
/// Failures reported by `PluginExecutor`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Stdin closed (clean EOF or read error). The listener task has
    /// exited and no new requests can be served.
    #[error("plugin executor stdin closed")]
    Closed,
    /// Writing or flushing a command line to stdout failed.
    #[error("plugin executor io: {0}")]
    Io(std::io::Error),
    /// A value routed to a request could not be decoded as either a
    /// `crate::cli::Error` payload or the caller's response type.
    #[error("plugin executor decode line: {0}")]
    Json(serde_json::Error),
    /// The overlord replied with an error payload for this request.
    #[error("{0}")]
    Cli(crate::cli::Error),
    /// `execute_one` was called but the stream produced no items.
    #[error("plugin executor stream produced no items")]
    Empty,
}
140
/// Per-value untagged decode. `Err` first so `cli::Error`'s `type:"error"`
/// constant short-circuits non-error wire shapes; `Ok(T)` is the
/// fallthrough. Mirrors the helper in `binary.rs`.
///
/// Every value forwarded to a request stream is decoded as `Line<T>`. A
/// value matching `crate::cli::Error` becomes `Line::Err`; anything else is
/// tried as the caller's response type `T`. Because the decode is untagged
/// (first matching variant wins), keep `Err` declared before `Ok`.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Line<T> {
    Err(crate::cli::Error),
    Ok(T),
}
150
151impl<T> From<Line<T>> for Result<T, Error> {
152    fn from(line: Line<T>) -> Self {
153        match line {
154            Line::Err(e) => Err(Error::Cli(e)),
155            Line::Ok(t) => Ok(t),
156        }
157    }
158}
159
160impl CommandExecutor for PluginExecutor {
161    type Error = Error;
162    type Stream<T>
163        = Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>
164    where
165        T: Send + 'static;
166
167    async fn execute<R, T>(
168        &self,
169        request: R,
170        // Plugin runs in-process — no subprocess to stamp env on. The
171        // bag is accepted for trait-signature symmetry with the
172        // binary executor and intentionally ignored.
173        _agent_arguments: Option<&AgentArguments>,
174    ) -> Result<Self::Stream<T>, Error>
175    where
176        R: CommandRequest + Send,
177        T: CommandResponseTrait + serde::de::DeserializeOwned + Send + 'static,
178    {
179        let id = self.counter.fetch_add(1, Ordering::Relaxed).to_string();
180        let (tx, rx) = mpsc::unbounded_channel::<serde_json::Value>();
181        self.pending.insert(id.clone(), tx);
182
183        // Re-check liveness AFTER insert. The listener stores `false`
184        // before it calls `pending.clear()`, so any of these happens:
185        //   - We see `true`: listener is still running; if it dies
186        //     later, its `clear()` will drop our sender and the stream
187        //     will end naturally.
188        //   - We see `false` and our entry got cleared: remove() is a
189        //     no-op, sender is already dropped.
190        //   - We see `false` and our entry survived (we inserted after
191        //     `clear()` ran): remove() drops the sender ourselves.
192        // In every `false` path we bail with `Closed`.
193        if !self.listener_alive.load(Ordering::Acquire) {
194            self.pending.remove(&id);
195            return Err(Error::Closed);
196        }
197
198        let argv = request.into_command();
199        let envelope = Output::Command(Command {
200            r#type: CommandType::Command,
201            id: id.clone(),
202            command: argv.join(" "),
203        });
204        let line = serde_json::to_string(&envelope).expect("Output serializes");
205
206        {
207            let mut stdout = self.stdout.lock().await;
208            if let Err(e) = stdout.write_all(line.as_bytes()).await {
209                self.pending.remove(&id);
210                return Err(Error::Io(e));
211            }
212            if let Err(e) = stdout.write_all(b"\n").await {
213                self.pending.remove(&id);
214                return Err(Error::Io(e));
215            }
216            if let Err(e) = stdout.flush().await {
217                self.pending.remove(&id);
218                return Err(Error::Io(e));
219            }
220        }
221
222        let pending = self.pending.clone();
223        let stream = futures::stream::unfold(
224            (rx, id, pending),
225            |(mut rx, id, pending)| async move {
226                match rx.recv().await {
227                    Some(value) => {
228                        let item = match serde_json::from_value::<Line<T>>(value) {
229                            Ok(line) => line.into(),
230                            Err(e) => Err(Error::Json(e)),
231                        };
232                        Some((item, (rx, id, pending)))
233                    }
234                    None => {
235                        pending.remove(&id);
236                        None
237                    }
238                }
239            },
240        );
241
242        Ok(Box::pin(stream))
243    }
244
245    async fn execute_one<R, T>(
246        &self,
247        request: R,
248        agent_arguments: Option<&AgentArguments>,
249    ) -> Result<T, Error>
250    where
251        R: CommandRequest + Send,
252        T: CommandResponseTrait + serde::de::DeserializeOwned + Send + 'static,
253    {
254        let mut stream = self.execute::<R, T>(request, agent_arguments).await?;
255        stream.next().await.ok_or(Error::Empty)?
256    }
257}