use std::path::PathBuf;
use std::pin::Pin;
use std::time::Duration;
use futures::{Stream, StreamExt};
use notify::{RecursiveMode, Watcher};
use objectiveai_sdk::cli::command::binary::BinaryExecutor;
use objectiveai_sdk::cli::command::{
AgentArguments, CommandExecutor, CommandRequest, CommandResponse,
};
use tokio::sync::mpsc;
use tokio::time::Instant;
/// Maximum time a running child may go without producing stream output or
/// causing filesystem activity under its config base directory before the
/// hang-preventing watchdog gives up on it.
pub const HANG_TIMEOUT: Duration = Duration::from_millis(60_000);
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("{0}")]
Inner(#[from] objectiveai_sdk::cli::command::binary::Error),
#[error(
"cli child went silent on CONFIG_BASE_DIR ({config_base_dir}) for {elapsed:?} — \
hang-preventing watchdog killed the child"
)]
HangTimeout {
elapsed: Duration,
config_base_dir: PathBuf,
},
#[error("hang-prevention fs watcher setup failed for {config_base_dir}: {source}")]
WatcherSetup {
config_base_dir: PathBuf,
#[source]
source: notify::Error,
},
}
/// Command executor that wraps a [`BinaryExecutor`] with a hang-preventing
/// watchdog: a running command is abandoned if it produces no stream output
/// and causes no filesystem activity under `config_base_dir` for
/// [`HANG_TIMEOUT`].
pub struct HangPreventingBinaryCommandExecutor {
    /// Wrapped executor; `new` configures it with `kill_on_drop(true)`.
    inner: BinaryExecutor,
    /// Directory watched (recursively) for signs of life while a command runs.
    config_base_dir: PathBuf,
}
impl HangPreventingBinaryCommandExecutor {
    /// Wraps `inner` so that commands it runs are supervised by a watchdog
    /// that watches `config_base_dir`.
    ///
    /// `inner` is switched to `kill_on_drop(true)`, so dropping a child's
    /// output stream (as the watchdog does when it stops) is expected to tear
    /// the child process down.
    pub fn new(inner: BinaryExecutor, config_base_dir: PathBuf) -> Self {
        let supervised = inner.kill_on_drop(true);
        Self {
            inner: supervised,
            config_base_dir,
        }
    }

    /// Builder-style helper: passes `key`/`value` to the wrapped executor's
    /// `env` builder (presumably an environment variable for the child) and
    /// returns the updated executor.
    pub fn env(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let Self {
            inner,
            config_base_dir,
        } = self;
        Self {
            inner: inner.env(key, value),
            config_base_dir,
        }
    }
}
impl CommandExecutor for HangPreventingBinaryCommandExecutor {
type Error = Error;
type Stream<T>
= Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>
where
T: Send + 'static;
async fn execute<R, T>(
&self,
request: R,
agent_arguments: Option<&AgentArguments>,
) -> Result<Self::Stream<T>, Self::Error>
where
R: CommandRequest + Send,
T: CommandResponse + serde::de::DeserializeOwned + Send + 'static,
{
let inner_stream: Pin<
Box<
dyn Stream<
Item = Result<
T,
objectiveai_sdk::cli::command::binary::Error,
>,
> + Send,
>,
> = self.inner.execute::<R, T>(request, agent_arguments).await?;
let config_base_dir = self.config_base_dir.clone();
let (out_tx, out_rx) = mpsc::channel::<Result<T, Error>>(16);
let (notify_tx, notify_rx) =
mpsc::unbounded_channel::<notify::Result<notify::Event>>();
let mut watcher = notify::recommended_watcher(
move |res: notify::Result<notify::Event>| {
let _ = notify_tx.send(res);
},
)
.map_err(|e| Error::WatcherSetup {
config_base_dir: config_base_dir.clone(),
source: e,
})?;
watcher
.watch(&config_base_dir, RecursiveMode::Recursive)
.map_err(|e| Error::WatcherSetup {
config_base_dir: config_base_dir.clone(),
source: e,
})?;
tokio::spawn(watchdog_task(
inner_stream,
out_tx,
notify_rx,
config_base_dir,
watcher,
));
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(out_rx)))
}
async fn execute_one<R, T>(
&self,
request: R,
agent_arguments: Option<&AgentArguments>,
) -> Result<T, Self::Error>
where
R: CommandRequest + Send,
T: CommandResponse + serde::de::DeserializeOwned + Send + 'static,
{
let mut stream = self.execute::<R, T>(request, agent_arguments).await?;
match stream.next().await {
Some(item) => item,
None => Err(Error::Inner(
objectiveai_sdk::cli::command::binary::Error::Empty,
)),
}
}
}
/// Forwards items from `inner_stream` to `out_tx` while enforcing an
/// inactivity watchdog.
///
/// The deadline is pushed back by [`HANG_TIMEOUT`] whenever the child yields
/// a stream item (once it has been handed to the consumer). It is also pushed
/// back whenever the filesystem watcher reports anything (event or error)
/// under `config_base_dir`.
///
/// If the deadline passes, `inner_stream` is dropped *before* the timeout is
/// reported. With `kill_on_drop(true)` set on the executor, this is expected
/// to kill the child even if the consumer is slow to drain the channel. A
/// single [`Error::HangTimeout`], carrying the time since the last activity,
/// is then sent.
///
/// The task also stops when the child stream ends, or as soon as the consumer
/// drops the receiving end. `_watcher` is owned here purely so that
/// filesystem events keep flowing for the task's lifetime.
async fn watchdog_task<T>(
    mut inner_stream: Pin<
        Box<
            dyn Stream<
                    Item = Result<T, objectiveai_sdk::cli::command::binary::Error>,
                > + Send,
        >,
    >,
    out_tx: mpsc::Sender<Result<T, Error>>,
    mut notify_rx: mpsc::UnboundedReceiver<notify::Result<notify::Event>>,
    config_base_dir: PathBuf,
    _watcher: notify::RecommendedWatcher,
) where
    T: Send + 'static,
{
    // Most recent sign of life. `HangTimeout::elapsed` is measured from here,
    // so it reports the silence duration rather than the total run time.
    let mut last_activity = Instant::now();
    let mut sleeper = Box::pin(tokio::time::sleep_until(last_activity + HANG_TIMEOUT));
    loop {
        tokio::select! {
            _ = &mut sleeper => {
                let elapsed = last_activity.elapsed();
                // Kill the child first; the send below may wait on a full channel.
                drop(inner_stream);
                let _ = out_tx
                    .send(Err(Error::HangTimeout {
                        elapsed,
                        config_base_dir,
                    }))
                    .await;
                return;
            }
            // Consumer dropped the output stream: stop now so the child is torn
            // down instead of running unobserved until its next output.
            _ = out_tx.closed() => return,
            next = inner_stream.next() => {
                let Some(item) = next else { return };
                if out_tx.send(item.map_err(Error::Inner)).await.is_err() {
                    return;
                }
                last_activity = Instant::now();
                sleeper.as_mut().reset(last_activity + HANG_TIMEOUT);
            }
            Some(_event) = notify_rx.recv() => {
                last_activity = Instant::now();
                sleeper.as_mut().reset(last_activity + HANG_TIMEOUT);
            }
        }
    }
}