cnf_lib/provider/
custom.rs1use crate::provider::prelude::*;
46use futures::StreamExt;
47use logerr::LoggableError;
48use serde_derive::{Deserialize, Serialize};
49use std::process::Stdio;
50use tokio::{io::AsyncWriteExt, process::Command};
51use tokio_util::{
52 bytes::{Buf, BytesMut},
53 codec::Decoder,
54};
55
56#[derive(Debug, ThisError)]
61#[non_exhaustive]
62pub enum Error {
63 #[error("failed to capture stdin of spawned child process")]
64 NoStdin,
65
66 #[error("failed to capture stdout of spawned child process")]
67 NoStdout,
68
69 #[error("failed to parse command output into string")]
70 InvalidUtf8(#[from] std::string::FromUtf8Error),
71
72 #[error("failed to write message to child process")]
73 BrokenStdin(#[from] BrokenStdinError),
74
75 #[error("failed to read message from child process")]
76 BrokenStdout(#[from] BrokenStdoutError),
77
78 #[error("failed to decode message from custom provider")]
79 Decode(#[from] DecoderError),
80
81 #[error("{0}")]
83 Child(String),
84}
85
86#[derive(Debug, ThisError)]
87pub enum DecoderError {
88 #[error(transparent)]
89 Deserialize(#[from] serde_json::Error),
90
91 #[error(transparent)]
92 StdIo(#[from] std::io::Error),
93}
94
95impl From<Error> for ProviderError {
96 fn from(value: Error) -> Self {
97 Self::ApplicationError(anyhow::Error::new(value))
98 }
99}
100
101#[derive(Debug, ThisError)]
102#[error(transparent)]
103pub struct BrokenStdinError(#[from] std::io::Error);
104
105#[derive(Debug, ThisError)]
106#[error(transparent)]
107pub struct BrokenStdoutError(#[from] std::io::Error);
108
/// Byte sequence that ends every message exchanged with a custom provider: ASCII BEL
/// (`0x07`) immediately followed by a newline.
pub const MESSAGE_TERMINATOR: [u8; 2] = *b"\x07\n";
119
120#[derive(Default, Debug, PartialEq, Clone, Deserialize, Serialize)]
121pub struct Custom {
122 pub name: String,
124 pub command: String,
126 pub args: Vec<String>,
128}
129
130impl fmt::Display for Custom {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 write!(f, "custom ({})", self.name)
133 }
134}
135
136#[derive(Debug, Serialize)]
138#[serde(rename_all = "kebab-case")]
139enum CnfToCustom {
140 CommandResponse {
141 stdout: String,
142 stderr: String,
143 exit_code: i32,
144 },
145}
146
147#[derive(Debug, Deserialize)]
149#[serde(rename_all = "kebab-case")]
150pub enum CustomToCnf {
151 Execute(CommandLine),
152
153 Results(Vec<Candidate>),
154
155 Error(String),
156}
157
/// Frame decoder for the messages a custom provider writes to its stdout.
///
/// Holds no state of its own; all buffering is done by the surrounding framed reader.
#[derive(Default, Debug)]
struct MessageDecoder {}

impl MessageDecoder {
    /// Creates a new decoder (same as `MessageDecoder::default()`).
    pub fn new() -> Self {
        MessageDecoder {}
    }
}
167
168impl Decoder for MessageDecoder {
169 type Item = CustomToCnf;
170 type Error = DecoderError;
171
172 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
173 let Some(frame_end_pos) = src
174 .windows(MESSAGE_TERMINATOR.len())
175 .position(|window| window == MESSAGE_TERMINATOR)
176 else {
177 return Ok(None);
178 };
179
180 let message = &src[0..frame_end_pos].to_vec();
181 src.advance(frame_end_pos + 1);
182
183 let msg =
184 serde_json::from_slice::<CustomToCnf>(message).map_err(DecoderError::Deserialize)?;
185 Ok(Some(msg))
186 }
187}
188
189#[async_trait]
190impl IsProvider for Custom {
191 async fn search_internal(
192 &self,
193 command: &str,
194 target_env: Arc<Environment>,
195 ) -> ProviderResult<Vec<Candidate>> {
196 let mut result: Vec<Candidate> = vec![];
197
198 let mut child = Command::new(&self.command)
199 .args(&self.args)
200 .arg(command)
201 .kill_on_drop(true)
202 .stdin(Stdio::piped())
203 .stdout(Stdio::piped())
204 .stderr(Stdio::null())
205 .spawn()
206 .map_err(|e| match e.kind() {
207 std::io::ErrorKind::NotFound => {
208 ProviderError::Requirements(self.command.to_string())
209 }
210 _ => ProviderError::ApplicationError(anyhow::Error::new(e)),
211 })?;
212
213 let mut stdin = child.stdin.take().ok_or(Error::NoStdin)?;
214 let mut stdout = tokio_util::codec::FramedRead::new(
215 child.stdout.take().ok_or(Error::NoStdout)?,
216 MessageDecoder::new(),
217 );
218
219 while let Some(message) = stdout.next().await {
222 match message.map_err(Error::Decode)? {
223 CustomToCnf::Execute(commandline) => {
224 let output = target_env.output_of(commandline).await;
225
226 let message = match output {
227 Ok(stdout) => CnfToCustom::CommandResponse {
228 stdout,
229 stderr: "".to_string(),
230 exit_code: 0,
231 },
232 Err(ExecutionError::NonZero { output, .. }) => {
233 CnfToCustom::CommandResponse {
234 stdout: String::from_utf8(output.stdout)
235 .map_err(Error::InvalidUtf8)?,
236 stderr: String::from_utf8(output.stderr)
237 .map_err(Error::InvalidUtf8)?,
238 exit_code: output.status.code().unwrap_or(256),
239 }
240 }
241 _ => return output.map(|_| vec![]).map_err(ProviderError::from),
242 };
243 let mut response = serde_json::to_vec(&message)
244 .with_context(|| format!("failed to send response to provider {}", self))
245 .map_err(ProviderError::ApplicationError)?;
246 response.push(MESSAGE_TERMINATOR[0]);
247 response.push(MESSAGE_TERMINATOR[1]);
248 stdin
249 .write_all(&response)
250 .await
251 .map_err(BrokenStdinError)
252 .map_err(Error::BrokenStdin)?;
253 }
254 CustomToCnf::Results(results) => {
255 result = results;
256 break;
257 }
258 CustomToCnf::Error(error) => {
259 return Err(Error::Child(error).into());
260 }
261 }
262 }
263
264 stdin
265 .shutdown()
266 .await
267 .with_context(|| format!("failed to close stdin of custom provider '{}'", self.name))?;
268 let _ = child
269 .wait()
270 .await
271 .with_context(|| format!("child process of '{}' terminated unexpectedly", self))
272 .to_log();
273
274 Ok(result)
275 }
276}