use crate::provider::prelude::*;
use futures::StreamExt;
use serde_derive::{Deserialize, Serialize};
use std::process::Stdio;
use tokio::{io::AsyncWriteExt, process::Command};
use tokio_util::{
bytes::{Buf, BytesMut},
codec::Decoder,
};
// Errors raised while running and talking to a custom provider process.
//
// NOTE(review): plain `//` comments are used on this type and its variants on purpose.
// If the `Display` derive here is displaydoc-style (as the `ThisError`/`Display` pairing
// suggests), `///` doc comments would become the `Display` text — TODO confirm before
// turning these into doc comments.
#[derive(Debug, ThisError, Display)]
#[non_exhaustive]
pub enum Error {
    // The spawned child has no piped stdin handle to take.
    NoStdin,
    // The spawned child has no piped stdout handle to take.
    NoStdout,
    // Captured output of a command that exited non-zero was not valid UTF-8.
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    // Writing a response to the child's stdin failed.
    BrokenStdin(#[from] BrokenStdinError),
    // Presumably a failure reading the child's stdout; not constructed in the code
    // visible alongside this type — confirm against other callers.
    BrokenStdout(#[from] BrokenStdoutError),
    // A message from the child could not be framed or parsed.
    Decode(#[from] DecoderError),
    // The child reported an error itself; the payload is its message.
    Child(String),
}
// Errors produced by the message decoder that frames a custom provider's stdout.
//
// NOTE(review): `//` rather than `///` because doc comments may feed the derived
// `Display` (displaydoc-style) — TODO confirm before converting.
#[derive(Debug, ThisError, Display)]
pub enum DecoderError {
    // A complete frame was not a valid JSON message.
    Deserialize(#[from] serde_json::Error),
    // An I/O error from the underlying byte stream. A `From<std::io::Error>` conversion
    // is also required of every `tokio_util::codec::Decoder` error type.
    StdIo(#[from] std::io::Error),
}
impl From<Error> for ProviderError {
fn from(value: Error) -> Self {
Self::ApplicationError(anyhow::Error::new(value))
}
}
// Wraps the I/O error from writing a response to the custom provider's stdin.
//
// NOTE(review): `//` not `///` — a doc comment on this struct may become its derived
// `Display` text (displaydoc-style) — TODO confirm.
#[derive(Debug, ThisError, Display)]
pub struct BrokenStdinError(#[from] std::io::Error);
// Wraps an I/O error related to the custom provider's stdout. It is not constructed in
// the code visible next to it — presumably meant for read failures; confirm.
//
// NOTE(review): `//` not `///` — a doc comment on this struct may become its derived
// `Display` text (displaydoc-style) — TODO confirm.
#[derive(Debug, ThisError, Display)]
pub struct BrokenStdoutError(#[from] std::io::Error);
/// Byte sequence (BEL followed by a newline) that ends every JSON message exchanged
/// with a custom provider, in both directions.
pub const MESSAGE_TERMINATOR: [u8; 2] = *b"\x07\n";
/// A user-defined provider backed by an external executable.
///
/// The executable is spawned once per search and talks to cnf over its stdin/stdout.
#[derive(Default, Debug, PartialEq, Clone, Deserialize, Serialize)]
pub struct Custom {
    /// Human-readable provider name, used in the `Display` output and in error messages.
    pub name: String,
    /// Executable to spawn for each search.
    pub command: String,
    /// Arguments passed to `command`, placed before the command being searched for.
    pub args: Vec<String>,
}
impl fmt::Display for Custom {
    /// Formats the provider as `custom (<name>)`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { name, .. } = self;
        write!(formatter, "custom ({})", name)
    }
}
/// Messages sent from cnf to a custom provider (written to the provider's stdin).
///
/// Serialized as externally tagged JSON. `rename_all` on an enum renames only the
/// variants, so field names keep their Rust spelling; a response looks like
/// `{"command-response": {"stdout": "...", "stderr": "...", "exit_code": 0}}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "kebab-case")]
enum CnfToCustom {
    /// Outcome of a command line the provider asked cnf to execute.
    CommandResponse {
        /// Standard output of the command.
        stdout: String,
        /// Standard error of the command (sent empty when the command succeeded).
        stderr: String,
        /// Exit code of the command (0 on success).
        exit_code: i32,
    },
}
/// Messages a custom provider sends to cnf (read from the provider's stdout).
///
/// Deserialized from externally tagged JSON with kebab-case variant names, e.g.
/// `{"results": [...]}` or `{"error": "..."}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum CustomToCnf {
    /// Ask cnf to run a command line in the target environment; cnf answers with a
    /// `command-response` message.
    Execute(CommandLine),
    /// The final search results; this ends the conversation.
    Results(Vec<Candidate>),
    /// The provider failed; the string is its error message.
    Error(String),
}
/// Stateless frame decoder that splits a custom provider's stdout into
/// [`CustomToCnf`] messages; all buffering lives in the surrounding `FramedRead`.
#[derive(Debug, Default)]
struct MessageDecoder;
impl MessageDecoder {
    /// Creates a new decoder; equivalent to [`Default::default`] since it has no fields.
    fn new() -> Self {
        <Self as Default>::default()
    }
}
impl Decoder for MessageDecoder {
    type Item = CustomToCnf;
    type Error = DecoderError;

    /// Extracts the next complete message from `src`, if one is fully buffered.
    ///
    /// A frame is everything up to the first occurrence of [`MESSAGE_TERMINATOR`]. Both
    /// the frame and its terminator are consumed from `src`, and the frame is parsed as a
    /// JSON-encoded [`CustomToCnf`].
    ///
    /// Returns `Ok(None)`, leaving `src` untouched, while no terminator has arrived yet.
    ///
    /// # Errors
    ///
    /// Returns [`DecoderError::Deserialize`] if a complete frame is not a valid
    /// `CustomToCnf` message.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let Some(frame_end_pos) = src
            .windows(MESSAGE_TERMINATOR.len())
            .position(|window| window == MESSAGE_TERMINATOR)
        else {
            return Ok(None);
        };
        // Split the payload off the front of the buffer without copying it, then drop the
        // *whole* terminator. Consuming only part of it would leave a stray `\n` behind,
        // which the default `decode_eof` reports as "bytes remaining on stream" once the
        // provider closes its stdout.
        let frame = src.split_to(frame_end_pos);
        src.advance(MESSAGE_TERMINATOR.len());
        let msg =
            serde_json::from_slice::<CustomToCnf>(&frame).map_err(DecoderError::Deserialize)?;
        Ok(Some(msg))
    }
}
#[async_trait]
impl IsProvider for Custom {
async fn search_internal(
&self,
command: &str,
target_env: Arc<Environment>,
) -> ProviderResult<Vec<Candidate>> {
let mut result: Vec<Candidate> = vec![];
let mut child = Command::new(&self.command)
.args(&self.args)
.arg(command)
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.spawn()
.map_err(|e| match e.kind() {
std::io::ErrorKind::NotFound => {
ProviderError::Requirements(self.command.to_string())
}
_ => ProviderError::ApplicationError(anyhow::Error::new(e)),
})?;
let mut stdin = child.stdin.take().ok_or(Error::NoStdin)?;
let mut stdout = tokio_util::codec::FramedRead::new(
child.stdout.take().ok_or(Error::NoStdout)?,
MessageDecoder::new(),
);
while let Some(message) = stdout.next().await {
match message.map_err(Error::Decode)? {
CustomToCnf::Execute(commandline) => {
let output = target_env.output_of(commandline).await;
let message = match output {
Ok(stdout) => CnfToCustom::CommandResponse {
stdout,
stderr: "".to_string(),
exit_code: 0,
},
Err(ExecutionError::NonZero { output, .. }) => {
CnfToCustom::CommandResponse {
stdout: String::from_utf8(output.stdout)
.map_err(Error::InvalidUtf8)?,
stderr: String::from_utf8(output.stderr)
.map_err(Error::InvalidUtf8)?,
exit_code: output.status.code().unwrap_or(256),
}
}
_ => return output.map(|_| vec![]).map_err(ProviderError::from),
};
let mut response = serde_json::to_vec(&message)
.with_context(|| format!("failed to send response to provider {}", self))
.map_err(ProviderError::ApplicationError)?;
response.push(MESSAGE_TERMINATOR[0]);
response.push(MESSAGE_TERMINATOR[1]);
stdin
.write_all(&response)
.await
.map_err(BrokenStdinError)
.map_err(Error::BrokenStdin)?;
}
CustomToCnf::Results(results) => {
result = results;
break;
}
CustomToCnf::Error(error) => {
return Err(Error::Child(error).into());
}
}
}
stdin
.shutdown()
.await
.with_context(|| format!("failed to close stdin of custom provider '{}'", self.name))?;
if let Err(e) = child
.wait()
.await
.with_context(|| format!("child process of '{}' terminated unexpectedly", self))
{
error!("{:#?}", e);
};
Ok(result)
}
}