cnf-lib 0.6.0

Distribution-agnostic 'command not found'-handler
Documentation
// SPDX-License-Identifier: GPL-3.0-or-later
// SPDX-FileCopyrightText: (C) 2023 Andreas Hartmann <hartan@7x.de>
// This file is part of cnf-lib, available at <https://gitlab.com/hartang/rust/cnf>

//! # Custom provider
//!
//! Adds integration to execute arbitrary commands as extra command providers. The main purpose of
//! this mechanism is to allow user-written scripts in any language to be used to enhance `cnf`s
//! default experience.
//!
//! Communication between `cnf` and custom providers takes place by passing messages via
//! stdin/stdout. The messages are JSON-formatted and must be terminated with a `BEL` char (ASCII
//! 7, 0x07), followed by a `\n` (newline) char (see [`MESSAGE_TERMINATOR`]). The search term is
//! passed to the custom provider as the last CLI argument, after any configured `args`.
//!
//! Child processes can run without restriction; there are no timeouts or similar measures. Cleanly
//! exiting from the custom provider executable can happen in one of two ways:
//!
//! - By sending a `CustomToCnf::Results` message with valid results to display
//! - By sending a `CustomToCnf::Error` message with an error description
//!
//! Any other type of exiting (regular exit, exit by signal) is recognized and will cause the
//! provider to report an appropriate error message.
//!
//!
//! ## Configuring custom providers
//!
//! Custom providers are registered through the application config file. On Linux systems, it is
//! found under `$XDG_CONFIG_HOME/cnf` (usually `~/.config/cnf`). Registering a provider looks like
//! this:
//!
//! ```yml
//! custom_providers:
//!     # A pretty name, displayed in the application
//!   - name: "cnf_fd (Bash)"
//!     # Main command to execute
//!     command: "/home/foo/Downloads/cnf_fd.sh"
//!     # Any additional arguments you may need
//!     args: []
//! ```
//!
//! You can configure an arbitrary number of providers: just copy the snippet above and add more
//! list entries!

use crate::provider::prelude::*;
use futures::StreamExt;
use serde_derive::{Deserialize, Serialize};
use std::process::Stdio;
use tokio::{io::AsyncWriteExt, process::Command};
use tokio_util::{
    bytes::{Buf, BytesMut},
    codec::Decoder,
};

/// Error variants for custom providers.
///
/// These are returned to the application as `ProviderError::ApplicationError` variant and can be
/// accessed with `anyhow`s `downcast_*` functions.
// NOTE(review): the `Display` derive comes from the provider prelude and is presumably
// `displaydoc`; if so, each variant's doc comment below doubles as its runtime error message
// and must stay a single line -- confirm before rewording or extending any of them.
#[derive(Debug, ThisError, Display)]
#[non_exhaustive]
pub enum Error {
    // Raised when the spawned child has no stdin handle to take.
    /// failed to capture stdin of spawned child process
    NoStdin,

    // Raised when the spawned child has no stdout handle to take.
    /// failed to capture stdout of spawned child process
    NoStdout,

    // Raised when stdout/stderr of a command executed on behalf of the provider (and which
    // exited non-zero) is not valid UTF-8.
    /// failed to parse command output into string
    InvalidUtf8(#[from] std::string::FromUtf8Error),

    // Raised when writing a response message to the child's stdin fails.
    /// failed to write message to child process
    BrokenStdin(#[from] BrokenStdinError),

    // Not constructed by the code visible in this file.
    /// failed to read message from child process
    BrokenStdout(#[from] BrokenStdoutError),

    // Raised when a frame read from the child's stdout cannot be decoded.
    /// failed to decode message from custom provider
    Decode(#[from] DecoderError),

    // Carries the description the provider sent in a `CustomToCnf::Error` message.
    /// unexpected error from custom provider
    Child(String),
}

/// Errors from decoding provider messages.
///
/// This is the error type of the `Decoder` implementation for `MessageDecoder`.
// NOTE(review): with a `displaydoc`-style `Display` derive the variant doc comments below are
// the runtime messages and must remain single-line -- confirm before editing them.
#[derive(Debug, ThisError, Display)]
pub enum DecoderError {
    // A complete frame was found but its payload is not valid JSON for `CustomToCnf`.
    /// failed to deserialize message for decoding
    Deserialize(#[from] serde_json::Error),

    // The `From<std::io::Error>` conversion generated here is what the `Decoder` trait requires
    // of its error type, so read errors on the underlying stream end up in this variant.
    /// got unexpected I/O error while reading message for decoding
    StdIo(#[from] std::io::Error),
}

impl From<Error> for ProviderError {
    fn from(value: Error) -> Self {
        Self::ApplicationError(anyhow::Error::new(value))
    }
}

// Wraps the I/O error raised while `cnf` writes a message into the child's stdin pipe (see the
// `write_all` call in `search_internal`). The previous message claimed a *read* from stdin,
// which is the wrong direction: `cnf` only ever writes to the child's stdin.
// NOTE(review): the `Display` derive is presumably `displaydoc`, so the single doc-comment line
// below is the runtime error message -- keep it to exactly one line.
/// failed to write data to stdin
#[derive(Debug, ThisError, Display)]
pub struct BrokenStdinError(#[from] std::io::Error);

// Wraps an I/O error raised while `cnf` reads from the child's stdout pipe. The previous
// message claimed a *write* to stdout, which is the wrong direction: `cnf` only ever reads the
// child's stdout.
// NOTE(review): the `Display` derive is presumably `displaydoc`, so the single doc-comment line
// below is the runtime error message -- keep it to exactly one line.
/// failed to read data from stdout
#[derive(Debug, ThisError, Display)]
pub struct BrokenStdoutError(#[from] std::io::Error);

/// Byte sequence that ends every message exchanged between CNF and a custom plugin.
///
/// The sequence is `BEL` (ASCII 7, `0x07`) followed by a newline. Two bytes were chosen because:
///
/// 1. The combination is reasonably unlikely to show up in ordinary shell output.
/// 2. Ending on `\n` lets even languages without a trivial way to read raw (unbuffered) stdin
///    receive messages line by line.
/// 3. A newline inside the payload can be told apart from the end of the message (assuming the
///    payload never contains this exact two-byte sequence).
pub const MESSAGE_TERMINATOR: [u8; 2] = *b"\x07\n";

/// Provider for user-provided custom solutions.
///
/// Describes one external executable that acts as a command provider. A search spawns
/// `command` with `args` followed by the search term and then exchanges JSON messages with the
/// child process over its stdin/stdout.
#[derive(Default, Debug, PartialEq, Clone, Deserialize, Serialize)]
pub struct Custom {
    /// A human-readable name/short identifier
    ///
    /// Shown in the `Display` output as `custom (<name>)` and used in error context messages.
    pub name: String,
    /// The main command to execute
    ///
    /// If spawning fails because the executable is not found, the search reports a
    /// `ProviderError::Requirements` error carrying this string.
    pub command: String,
    /// Additional arguments to provide
    ///
    /// These are passed to the command *before* the search term.
    pub args: Vec<String>,
}

impl fmt::Display for Custom {
    /// Render the provider as `custom (<name>)`, using the configured pretty name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = &self.name;
        f.write_fmt(format_args!("custom ({})", label))
    }
}

/// Messages from `cnf` to custom provider
///
/// Serialized to JSON and written to the provider's stdin, followed by the message terminator.
/// The container-level `rename_all = "kebab-case"` applies to the variant names (so the variant
/// below is tagged `command-response`); per serde's container-attribute rules it does not
/// rename the fields inside the variant, which keep their Rust names.
#[derive(Debug, Serialize)]
#[serde(rename_all = "kebab-case")]
enum CnfToCustom {
    /// Result of a command that the provider asked `cnf` to execute.
    CommandResponse {
        /// Captured standard output of the executed command.
        stdout: String,
        /// Captured standard error; sent as an empty string when the command succeeded.
        stderr: String,
        /// Exit code of the command: `0` on success, `256` when the process reported no exit
        /// code.
        exit_code: i32,
    },
}

/// Messages from custom provider to `cnf`
///
/// Deserialized from the JSON frames the provider writes to its stdout. Variant names are
/// matched in kebab-case (`execute`, `results`, `error`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum CustomToCnf {
    /// A command that `cnf` should execute.
    ///
    /// `cnf` runs the command in the target environment and answers with a command response
    /// message on the provider's stdin; the conversation then continues.
    Execute(CommandLine),

    /// List of results offered by this provider.
    ///
    /// Receiving this message ends the conversation; the candidates become the search result.
    Results(Vec<Candidate>),

    /// Error from the provider.
    ///
    /// Receiving this message aborts the search; the text is reported as `Error::Child`.
    Error(String),
}

/// Stateless decoder that turns the raw byte stream of a provider into message frames.
#[derive(Debug, Default)]
struct MessageDecoder {}

impl MessageDecoder {
    /// Build a fresh decoder; it carries no state, so every instance is equivalent.
    fn new() -> Self {
        Self {}
    }
}

impl Decoder for MessageDecoder {
    type Item = CustomToCnf;
    type Error = DecoderError;

    /// Try to extract one complete message from the front of `src`.
    ///
    /// A message is everything up to (but excluding) the first occurrence of
    /// [`MESSAGE_TERMINATOR`]. The payload is parsed as JSON into a [`CustomToCnf`].
    ///
    /// # Returns
    ///
    /// - `Ok(None)` if `src` does not yet contain a full terminator; the buffer is left untouched
    ///   so more bytes can be appended.
    /// - `Ok(Some(message))` if a frame was found and parsed. The payload and the *entire*
    ///   terminator are removed from `src`.
    ///
    /// # Errors
    ///
    /// Returns [`DecoderError::Deserialize`] if the payload is not valid JSON for
    /// [`CustomToCnf`]. The offending frame (including its terminator) has already been removed
    /// from `src` at that point.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let Some(frame_end_pos) = src
            .windows(MESSAGE_TERMINATOR.len())
            .position(|window| window == MESSAGE_TERMINATOR)
        else {
            return Ok(None);
        };

        // Detach the payload without copying it, then drop the complete terminator. Consuming
        // only a single terminator byte would leave the trailing `\n` in the buffer, where it
        // counts as unread data once the stream reaches EOF.
        let payload = src.split_to(frame_end_pos);
        src.advance(MESSAGE_TERMINATOR.len());

        let msg =
            serde_json::from_slice::<CustomToCnf>(&payload).map_err(DecoderError::Deserialize)?;
        Ok(Some(msg))
    }
}

#[async_trait]
impl IsProvider for Custom {
    async fn search_internal(
        &self,
        command: &str,
        target_env: Arc<Environment>,
    ) -> ProviderResult<Vec<Candidate>> {
        let mut result: Vec<Candidate> = vec![];

        let mut child = Command::new(&self.command)
            .args(&self.args)
            .arg(command)
            .kill_on_drop(true)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::null())
            .spawn()
            .map_err(|e| match e.kind() {
                std::io::ErrorKind::NotFound => {
                    ProviderError::Requirements(self.command.to_string())
                }
                _ => ProviderError::ApplicationError(anyhow::Error::new(e)),
            })?;

        let mut stdin = child.stdin.take().ok_or(Error::NoStdin)?;
        let mut stdout = tokio_util::codec::FramedRead::new(
            child.stdout.take().ok_or(Error::NoStdout)?,
            MessageDecoder::new(),
        );

        while let Some(message) = stdout.next().await {
            match message.map_err(Error::Decode)? {
                CustomToCnf::Execute(commandline) => {
                    let output = target_env.output_of(commandline).await;

                    let message = match output {
                        Ok(stdout) => CnfToCustom::CommandResponse {
                            stdout,
                            stderr: "".to_string(),
                            exit_code: 0,
                        },
                        Err(ExecutionError::NonZero { output, .. }) => {
                            CnfToCustom::CommandResponse {
                                stdout: String::from_utf8(output.stdout)
                                    .map_err(Error::InvalidUtf8)?,
                                stderr: String::from_utf8(output.stderr)
                                    .map_err(Error::InvalidUtf8)?,
                                exit_code: output.status.code().unwrap_or(256),
                            }
                        }
                        _ => return output.map(|_| vec![]).map_err(ProviderError::from),
                    };
                    let mut response = serde_json::to_vec(&message)
                        .with_context(|| format!("failed to send response to provider {}", self))
                        .map_err(ProviderError::ApplicationError)?;
                    response.push(MESSAGE_TERMINATOR[0]);
                    response.push(MESSAGE_TERMINATOR[1]);
                    stdin
                        .write_all(&response)
                        .await
                        .map_err(BrokenStdinError)
                        .map_err(Error::BrokenStdin)?;
                }
                CustomToCnf::Results(results) => {
                    result = results;
                    break;
                }
                CustomToCnf::Error(error) => {
                    return Err(Error::Child(error).into());
                }
            }
        }

        stdin
            .shutdown()
            .await
            .with_context(|| format!("failed to close stdin of custom provider '{}'", self.name))?;
        if let Err(e) = child
            .wait()
            .await
            .with_context(|| format!("child process of '{}' terminated unexpectedly", self))
        {
            error!("{:#?}", e);
        };

        Ok(result)
    }
}