Skip to main content

cnf_lib/provider/
custom.rs

1// SPDX-License-Identifier: GPL-3.0-or-later
2// SPDX-FileCopyrightText: (C) 2023 Andreas Hartmann <hartan@7x.de>
3// This file is part of cnf-lib, available at <https://gitlab.com/hartang/rust/cnf>
4
5//! # Custom provider
6//!
//! Adds integration to execute arbitrary commands as extra command providers. The main purpose of
//! this mechanism is to allow user-written scripts in any language to be used to enhance `cnf`'s
//! default experience.
10//!
//! Communication between `cnf` and custom providers takes place by passing messages via
//! stdin/stdout. The messages are JSON-formatted and must be terminated with a `BEL` char (ASCII
//! 7, 0x07), followed by a `\n` (newline) char (see [`MESSAGE_TERMINATOR`]). The search term is
//! passed to the custom provider as the last CLI argument, after any configured `args`.
15//!
//! Child processes run without restriction; there are no timeouts or similar measures. A custom
//! provider executable can exit cleanly in one of two ways:
18//!
19//! - By sending a `CustomToCnf::Results` message with valid results to display
20//! - By sending a `CustomToCnf::Error` message with an error description
21//!
22//! Any other type of exiting (regular exit, exit by signal) is recognized and will cause the
23//! provider to report an appropriate error message.
24//!
25//!
26//! ## Configuring custom providers
27//!
//! Custom providers are registered through the application config file. On Linux systems, it is
//! found under `$XDG_CONFIG_HOME/cnf` (usually `~/.config/cnf`). Registering a provider looks like
//! this:
31//!
32//! ```yml
33//! custom_providers:
34//!     # A pretty name, displayed in the application
35//!   - name: "cnf_fd (Bash)"
36//!     # Main command to execute
37//!     command: "/home/foo/Downloads/cnf_fd.sh"
38//!     # Any additional arguments you may need
39//!     args: []
40//! ```
41//!
//! You can configure an arbitrary number of providers: just copy the snippet above and add more
//! list entries!
44
45use crate::provider::prelude::*;
46use futures::StreamExt;
47use serde_derive::{Deserialize, Serialize};
48use std::process::Stdio;
49use tokio::{io::AsyncWriteExt, process::Command};
50use tokio_util::{
51    bytes::{Buf, BytesMut},
52    codec::Decoder,
53};
54
55/// Error variants for custom providers.
56///
57/// These are returned to the application as `ProviderError::ApplicationError` variant and can be
58/// accessed with `anyhow`s `downcast_*` functions.
59#[derive(Debug, ThisError, Display)]
60#[non_exhaustive]
61pub enum Error {
62    /// failed to capture stdin of spawned child process
63    NoStdin,
64
65    /// failed to capture stdout of spawned child process
66    NoStdout,
67
68    /// failed to parse command output into string
69    InvalidUtf8(#[from] std::string::FromUtf8Error),
70
71    /// failed to write message to child process
72    BrokenStdin(#[from] BrokenStdinError),
73
74    /// failed to read message from child process
75    BrokenStdout(#[from] BrokenStdoutError),
76
77    /// failed to decode message from custom provider
78    Decode(#[from] DecoderError),
79
80    /// unexpected error from custom provider
81    Child(String),
82}
83
/// Errors from decoding provider messages.
// NOTE(review): if `Display` is `displaydoc`'s derive (as the message-style doc comments
// suggest), each variant's `///` line is its Display message, i.e. a runtime string.
#[derive(Debug, ThisError, Display)]
pub enum DecoderError {
    /// failed to deserialize message for decoding
    // Produced by `MessageDecoder::decode` when a frame's payload is not valid JSON for
    // `CustomToCnf`.
    Deserialize(#[from] serde_json::Error),

    /// got unexpected I/O error while reading message for decoding
    // tokio_util's `Decoder` trait requires `Self::Error: From<std::io::Error>`; I/O failures of
    // the underlying reader (and the framing layer's own errors) are reported through this.
    StdIo(#[from] std::io::Error),
}
93
94impl From<Error> for ProviderError {
95    fn from(value: Error) -> Self {
96        Self::ApplicationError(anyhow::Error::new(value))
97    }
98}
99
100/// failed to read data from stdin
101#[derive(Debug, ThisError, Display)]
102pub struct BrokenStdinError(#[from] std::io::Error);
103
104/// failed to write data to stdout
105#[derive(Debug, ThisError, Display)]
106pub struct BrokenStdoutError(#[from] std::io::Error);
107
/// Byte sequence that ends every message exchanged between CNF and a custom plugin: `BEL`
/// (ASCII 7, 0x07) followed by a newline.
///
/// Rationale for this particular pair of bytes:
///
/// 1. The combination is unlikely to show up in ordinary shell output.
/// 2. Ending on `\n` lets plugins written in languages without easy access to raw (unbuffered)
///    stdin still receive messages using line-based reads.
/// 3. A newline inside the payload can be told apart from the end of a message, provided the
///    payload never contains this exact two-byte sequence.
pub const MESSAGE_TERMINATOR: [u8; 2] = *b"\x07\n";
118
119/// Provider for user-provided custom solutions.
120#[derive(Default, Debug, PartialEq, Clone, Deserialize, Serialize)]
121pub struct Custom {
122    /// A human-readable name/short identifier
123    pub name: String,
124    /// The main command to execute
125    pub command: String,
126    /// Additional arguments to provide
127    pub args: Vec<String>,
128}
129
130impl fmt::Display for Custom {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        write!(f, "custom ({})", self.name)
133    }
134}
135
/// Messages from `cnf` to custom provider
///
/// Each message is serialized with `serde_json` and written to the provider's stdin, followed by
/// [`MESSAGE_TERMINATOR`]. Serde's default (externally tagged) enum representation is used, and
/// `rename_all` on an enum renames only the variant tag, not the fields inside it. A response
/// therefore looks like
/// `{"command-response":{"stdout":"...","stderr":"...","exit_code":0}}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "kebab-case")]
enum CnfToCustom {
    /// Outcome of a command the provider asked `cnf` to run via [`CustomToCnf::Execute`].
    CommandResponse {
        /// Captured standard output of the command.
        stdout: String,
        /// Captured standard error of the command.
        stderr: String,
        /// Exit code of the command.
        exit_code: i32,
    },
}
146
/// Messages from custom provider to `cnf`
///
/// Read from the provider's stdout, split on [`MESSAGE_TERMINATOR`] and parsed as JSON. Serde's
/// default (externally tagged) enum representation is used with kebab-case tags, so messages
/// look like `{"execute": ...}`, `{"results": [...]}` or `{"error": "..."}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum CustomToCnf {
    /// A command that `cnf` should execute.
    ///
    /// `cnf` runs it and, if it ran (even with a non-zero exit code), replies with a
    /// `command-response` message carrying its output.
    Execute(CommandLine),

    /// List of results offered by this provider.
    ///
    /// Ends the exchange: `cnf` stops reading further messages once this arrives.
    Results(Vec<Candidate>),

    /// Error from the provider.
    ///
    /// Ends the exchange; the string is reported as the provider's error.
    Error(String),
}
160
/// Framing codec that turns the provider's raw stdout bytes into individual messages.
///
/// The decoder holds no state of its own; buffering of partially received data is left to the
/// `FramedRead` it is wrapped in.
#[derive(Debug, Default)]
struct MessageDecoder {}

impl MessageDecoder {
    /// Build a fresh decoder; identical to `MessageDecoder::default()`.
    fn new() -> Self {
        MessageDecoder {}
    }
}
171
172impl Decoder for MessageDecoder {
173    type Item = CustomToCnf;
174    type Error = DecoderError;
175
176    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
177        let Some(frame_end_pos) = src
178            .windows(MESSAGE_TERMINATOR.len())
179            .position(|window| window == MESSAGE_TERMINATOR)
180        else {
181            return Ok(None);
182        };
183
184        let message = &src[0..frame_end_pos].to_vec();
185        src.advance(frame_end_pos + 1);
186
187        let msg =
188            serde_json::from_slice::<CustomToCnf>(message).map_err(DecoderError::Deserialize)?;
189        Ok(Some(msg))
190    }
191}
192
193#[async_trait]
194impl IsProvider for Custom {
195    async fn search_internal(
196        &self,
197        command: &str,
198        target_env: Arc<Environment>,
199    ) -> ProviderResult<Vec<Candidate>> {
200        let mut result: Vec<Candidate> = vec![];
201
202        let mut child = Command::new(&self.command)
203            .args(&self.args)
204            .arg(command)
205            .kill_on_drop(true)
206            .stdin(Stdio::piped())
207            .stdout(Stdio::piped())
208            .stderr(Stdio::null())
209            .spawn()
210            .map_err(|e| match e.kind() {
211                std::io::ErrorKind::NotFound => {
212                    ProviderError::Requirements(self.command.to_string())
213                }
214                _ => ProviderError::ApplicationError(anyhow::Error::new(e)),
215            })?;
216
217        let mut stdin = child.stdin.take().ok_or(Error::NoStdin)?;
218        let mut stdout = tokio_util::codec::FramedRead::new(
219            child.stdout.take().ok_or(Error::NoStdout)?,
220            MessageDecoder::new(),
221        );
222
223        while let Some(message) = stdout.next().await {
224            match message.map_err(Error::Decode)? {
225                CustomToCnf::Execute(commandline) => {
226                    let output = target_env.output_of(commandline).await;
227
228                    let message = match output {
229                        Ok(stdout) => CnfToCustom::CommandResponse {
230                            stdout,
231                            stderr: "".to_string(),
232                            exit_code: 0,
233                        },
234                        Err(ExecutionError::NonZero { output, .. }) => {
235                            CnfToCustom::CommandResponse {
236                                stdout: String::from_utf8(output.stdout)
237                                    .map_err(Error::InvalidUtf8)?,
238                                stderr: String::from_utf8(output.stderr)
239                                    .map_err(Error::InvalidUtf8)?,
240                                exit_code: output.status.code().unwrap_or(256),
241                            }
242                        }
243                        _ => return output.map(|_| vec![]).map_err(ProviderError::from),
244                    };
245                    let mut response = serde_json::to_vec(&message)
246                        .with_context(|| format!("failed to send response to provider {}", self))
247                        .map_err(ProviderError::ApplicationError)?;
248                    response.push(MESSAGE_TERMINATOR[0]);
249                    response.push(MESSAGE_TERMINATOR[1]);
250                    stdin
251                        .write_all(&response)
252                        .await
253                        .map_err(BrokenStdinError)
254                        .map_err(Error::BrokenStdin)?;
255                }
256                CustomToCnf::Results(results) => {
257                    result = results;
258                    break;
259                }
260                CustomToCnf::Error(error) => {
261                    return Err(Error::Child(error).into());
262                }
263            }
264        }
265
266        stdin
267            .shutdown()
268            .await
269            .with_context(|| format!("failed to close stdin of custom provider '{}'", self.name))?;
270        if let Err(e) = child
271            .wait()
272            .await
273            .with_context(|| format!("child process of '{}' terminated unexpectedly", self))
274        {
275            error!("{:#?}", e);
276        };
277
278        Ok(result)
279    }
280}