use std::path::PathBuf;
use std::pin::Pin;
use futures::{Stream, StreamExt};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use crate::cli::command::{AgentArguments, CommandExecutor, CommandRequest, CommandResponse};
/// Executes CLI commands by spawning the `objectiveai` binary as a child process.
///
/// The binary is located either through an explicit path or by looking for the
/// platform-specific executable name inside a configuration directory.
#[derive(Debug, Clone)]
pub struct BinaryExecutor {
    /// Directory expected to contain the binary. When `None` (and no explicit
    /// path is set) the lookup falls back to `.objectiveai` under the home directory.
    config_base_dir: Option<PathBuf>,
    /// Full path to the binary; when set it takes precedence over `config_base_dir`.
    explicit_path: Option<PathBuf>,
    /// Additional `(key, value)` environment variables set on every spawned child.
    extra_env: Vec<(String, String)>,
}
impl BinaryExecutor {
pub fn new(config_base_dir: Option<impl Into<PathBuf>>) -> Self {
Self {
config_base_dir: config_base_dir.map(Into::into),
explicit_path: None,
extra_env: Vec::new(),
}
}
pub fn from_path(binary: impl Into<PathBuf>) -> Self {
Self {
config_base_dir: None,
explicit_path: Some(binary.into()),
extra_env: Vec::new(),
}
}
pub fn env(
mut self,
key: impl Into<String>,
value: impl Into<String>,
) -> Self {
self.extra_env.push((key.into(), value.into()));
self
}
fn binary_path(&self) -> Result<PathBuf, Error> {
if let Some(p) = &self.explicit_path {
return Ok(p.clone());
}
let base = match &self.config_base_dir {
Some(d) => d.clone(),
None => dirs::home_dir()
.ok_or(Error::NoHomeDir)?
.join(".objectiveai"),
};
let name = if cfg!(windows) { "objectiveai.exe" } else { "objectiveai" };
Ok(base.join(name))
}
}
/// Errors produced while locating, spawning, or reading from the CLI binary.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The home-directory fallback was needed to locate the binary, but
/// `dirs::home_dir()` returned `None`.
#[error("no home directory and no config_base_dir set")]
NoHomeDir,
/// Spawning the child process failed; carries the underlying I/O error.
#[error("failed to spawn cli binary: {0}")]
Spawn(std::io::Error),
/// The spawned child exposed no stdout handle to read from.
#[error("cli binary child has no stdout handle")]
NoStdout,
/// Reading a line from the child's stdout failed.
#[error("read cli binary stdout: {0}")]
Io(std::io::Error),
/// A stdout line could not be deserialized as JSON into the expected shape.
#[error("decode cli binary stdout line: {0}")]
Json(serde_json::Error),
/// A stdout line deserialized as a `crate::cli::Error`, i.e. the CLI itself
/// reported a failure.
#[error("{0}")]
Cli(crate::cli::Error),
/// `execute_one` expected an item, but the stream ended without yielding any.
#[error("cli binary stream produced no items")]
Empty,
}
/// One decoded line of the CLI binary's stdout: either an error reported by the
/// CLI or a successful payload of type `T`.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Err` is attempted first and `Ok(T)` is only tried when the line does not
/// match `crate::cli::Error`. Do not reorder the variants.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Line<T> {
// Tried first: a line shaped like `crate::cli::Error`.
Err(crate::cli::Error),
// Fallback: the expected response payload.
Ok(T),
}
impl<T> From<Line<T>> for Result<T, Error> {
fn from(line: Line<T>) -> Self {
match line {
Line::Err(e) => Err(Error::Cli(e)),
Line::Ok(t) => Ok(t),
}
}
}
impl CommandExecutor for BinaryExecutor {
type Error = Error;
type Stream<T>
= Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>
where
T: Send + 'static;
async fn execute<R, T>(
&self,
request: R,
agent_arguments: Option<&AgentArguments>,
) -> Result<Self::Stream<T>, Error>
where
R: CommandRequest + Send,
T: CommandResponse + serde::de::DeserializeOwned + Send + 'static,
{
let argv = request.into_command();
let binary = self.binary_path()?;
let mut command = Command::new(&binary);
command
.args(&argv)
.stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit());
for (k, v) in &self.extra_env {
command.env(k, v);
}
if let Some(args) = agent_arguments {
args.apply_to_command(&mut command);
}
let mut child = command.spawn().map_err(Error::Spawn)?;
let stdout = child.stdout.take().ok_or(Error::NoStdout)?;
let lines = BufReader::new(stdout).lines();
let stream = futures::stream::unfold(
(child, lines),
|(child, mut lines)| async move {
match lines.next_line().await {
Ok(Some(line)) => {
let item = match serde_json::from_str::<Line<T>>(&line) {
Ok(line) => line.into(),
Err(e) => Err(Error::Json(e)),
};
Some((item, (child, lines)))
}
Ok(None) => None,
Err(e) => Some((Err(Error::Io(e)), (child, lines))),
}
},
);
Ok(Box::pin(stream))
}
async fn execute_one<R, T>(
&self,
request: R,
agent_arguments: Option<&AgentArguments>,
) -> Result<T, Error>
where
R: CommandRequest + Send,
T: CommandResponse + serde::de::DeserializeOwned + Send + 'static,
{
let mut stream = self.execute::<R, T>(request, agent_arguments).await?;
stream.next().await.ok_or(Error::Empty)?
}
}