// dprint_core/plugins/process/communicator.rs

1use anyhow::anyhow;
2use anyhow::bail;
3use anyhow::Context as AnyhowContext;
4use anyhow::Result;
5use serde::de::DeserializeOwned;
6use std::cell::RefCell;
7use std::io::BufRead;
8use std::io::ErrorKind;
9use std::io::Read;
10use std::io::Write;
11use std::path::Path;
12use std::path::PathBuf;
13use std::process::Child;
14use std::process::ChildStderr;
15use std::process::Command;
16use std::process::Stdio;
17use std::rc::Rc;
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::oneshot;
21use tokio_util::sync::CancellationToken;
22
23use super::messages::CheckConfigUpdatesMessageBody;
24use super::messages::CheckConfigUpdatesResponseBody;
25use super::messages::FormatMessageBody;
26use super::messages::HostFormatMessageBody;
27use super::messages::MessageBody;
28use super::messages::ProcessPluginMessage;
29use super::messages::RegisterConfigMessageBody;
30use super::messages::ResponseBody;
31use super::PLUGIN_SCHEMA_VERSION;
32use crate::async_runtime::DropGuardAction;
33use crate::async_runtime::LocalBoxFuture;
34use crate::communication::AtomicFlag;
35use crate::communication::IdGenerator;
36use crate::communication::MessageReader;
37use crate::communication::MessageWriter;
38use crate::communication::RcIdStore;
39use crate::communication::SingleThreadMessageWriter;
40use crate::configuration::ConfigKeyMap;
41use crate::configuration::ConfigurationDiagnostic;
42use crate::configuration::GlobalConfiguration;
43use crate::plugins::ConfigChange;
44use crate::plugins::CriticalFormatError;
45use crate::plugins::FileMatchingInfo;
46use crate::plugins::FormatConfigId;
47use crate::plugins::FormatRange;
48use crate::plugins::FormatResult;
49use crate::plugins::HostFormatRequest;
50use crate::plugins::NullCancellationToken;
51use crate::plugins::PluginInfo;
52
53type DprintCancellationToken = Arc<dyn super::super::CancellationToken>;
54
55pub type HostFormatCallback = Rc<dyn Fn(HostFormatRequest) -> LocalBoxFuture<'static, FormatResult>>;
56
57pub struct ProcessPluginCommunicatorFormatRequest {
58  pub file_path: PathBuf,
59  pub file_bytes: Vec<u8>,
60  pub range: FormatRange,
61  pub config_id: FormatConfigId,
62  pub override_config: ConfigKeyMap,
63  pub on_host_format: HostFormatCallback,
64  pub token: DprintCancellationToken,
65}
66
67enum MessageResponseChannel {
68  Acknowledgement(oneshot::Sender<Result<()>>),
69  Data(oneshot::Sender<Result<Vec<u8>>>),
70  Format(oneshot::Sender<Result<Option<Vec<u8>>>>),
71}
72
73struct Context {
74  stdin_writer: SingleThreadMessageWriter<ProcessPluginMessage>,
75  shutdown_flag: Arc<AtomicFlag>,
76  id_generator: IdGenerator,
77  messages: RcIdStore<MessageResponseChannel>,
78  format_request_tokens: RcIdStore<Arc<CancellationToken>>,
79  host_format_callbacks: RcIdStore<HostFormatCallback>,
80}
81
82/// Communicates with a process plugin.
83pub struct ProcessPluginCommunicator {
84  child: RefCell<Option<Child>>,
85  context: Rc<Context>,
86}
87
88impl Drop for ProcessPluginCommunicator {
89  fn drop(&mut self) {
90    self.kill();
91  }
92}
93
94impl ProcessPluginCommunicator {
95  pub async fn new(executable_file_path: &Path, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
96    ProcessPluginCommunicator::new_internal(executable_file_path, false, on_std_err).await
97  }
98
99  /// Provides the `--init` CLI flag to tell the process plugin to do any initialization necessary
100  pub async fn new_with_init(executable_file_path: &Path, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
101    ProcessPluginCommunicator::new_internal(executable_file_path, true, on_std_err).await
102  }
103
104  async fn new_internal(executable_file_path: &Path, is_init: bool, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
105    let mut args = vec!["--parent-pid".to_string(), std::process::id().to_string()];
106    if is_init {
107      args.push("--init".to_string());
108    }
109
110    let shutdown_flag = Arc::new(AtomicFlag::default());
111    let mut child = Command::new(executable_file_path)
112      .args(&args)
113      .stdin(Stdio::piped())
114      .stderr(Stdio::piped())
115      .stdout(Stdio::piped())
116      .spawn()
117      .map_err(|err| anyhow!("Error starting {} with args [{}]. {:#}", executable_file_path.display(), args.join(" "), err))?;
118
119    // read and output stderr prefixed
120    let stderr = child.stderr.take().unwrap();
121    crate::async_runtime::spawn_blocking({
122      let shutdown_flag = shutdown_flag.clone();
123      let on_std_err = on_std_err.clone();
124      move || {
125        std_err_redirect(shutdown_flag, stderr, on_std_err);
126      }
127    });
128
129    // verify the schema version
130    let mut stdout_reader = MessageReader::new(child.stdout.take().unwrap());
131    let mut stdin_writer = MessageWriter::new(child.stdin.take().unwrap());
132
133    let (mut stdout_reader, stdin_writer, schema_version) = crate::async_runtime::spawn_blocking(move || {
134      let schema_version = get_plugin_schema_version(&mut stdout_reader, &mut stdin_writer)
135        .context("Failed plugin schema verification. This may indicate you are using an old version of the dprint CLI or plugin and should upgrade")?;
136      Ok::<_, anyhow::Error>((stdout_reader, stdin_writer, schema_version))
137    })
138    .await??;
139
140    if schema_version != PLUGIN_SCHEMA_VERSION {
141      // kill the child to prevent it from ouputting to stderr
142      let _ = child.kill();
143      if schema_version < PLUGIN_SCHEMA_VERSION {
144        bail!(
145          "This plugin is too old to run in the dprint CLI and you will need to manually upgrade it (version was {}, but expected {}).\n\nUpgrade instructions: https://github.com/dprint/dprint/issues/731",
146          schema_version,
147          PLUGIN_SCHEMA_VERSION
148        );
149      } else {
150        bail!(
151          "Your dprint CLI is too old to run this plugin (version was {}, but expected {}). Try running: dprint upgrade",
152          schema_version,
153          PLUGIN_SCHEMA_VERSION
154        );
155      }
156    }
157
158    let stdin_writer = SingleThreadMessageWriter::for_stdin(stdin_writer);
159    let context = Rc::new(Context {
160      id_generator: Default::default(),
161      shutdown_flag,
162      stdin_writer,
163      messages: Default::default(),
164      format_request_tokens: Default::default(),
165      host_format_callbacks: Default::default(),
166    });
167
168    // read from stdout
169    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
170    crate::async_runtime::spawn_blocking({
171      let shutdown_flag = context.shutdown_flag.clone();
172      let on_std_err = on_std_err.clone();
173      move || {
174        loop {
175          match ProcessPluginMessage::read(&mut stdout_reader) {
176            Ok(message) => {
177              if tx.send(message).is_err() {
178                break; // closed
179              }
180            }
181            Err(err) if err.kind() == ErrorKind::BrokenPipe => {
182              break;
183            }
184            Err(err) => {
185              if !shutdown_flag.is_raised() {
186                on_std_err(format!("Error reading stdout message: {:#}", err));
187              }
188              break;
189            }
190          }
191        }
192      }
193    });
194    crate::async_runtime::spawn({
195      let context = context.clone();
196      async move {
197        while let Some(message) = rx.recv().await {
198          if let Err(err) = handle_stdout_message(message, &context) {
199            if !context.shutdown_flag.is_raised() {
200              on_std_err(format!("Error reading stdout message: {:#}", err));
201            }
202            break;
203          }
204        }
205        // clear out all the messages
206        context.messages.take_all();
207      }
208    });
209
210    Ok(Self {
211      child: RefCell::new(Some(child)),
212      context,
213    })
214  }
215
216  /// Perform a graceful shutdown.
217  pub async fn shutdown(&self) {
218    if self.context.shutdown_flag.raise() {
219      // attempt to exit nicely
220      tokio::select! {
221        // we wait for acknowledgement in order to give the process
222        // plugin a chance to clean up (ex. in case it has spawned
223        // any processes it needs to kill or something like that)
224        _ = self.send_with_acknowledgement(MessageBody::Close) => {}
225        _ = tokio::time::sleep(Duration::from_millis(250)) => {
226          self.kill();
227        }
228      }
229    } else {
230      self.kill();
231    }
232  }
233
234  pub fn kill(&self) {
235    self.context.shutdown_flag.raise();
236    if let Some(mut child) = self.child.borrow_mut().take() {
237      let _ignore = child.kill();
238    }
239  }
240
241  pub async fn register_config(&self, config_id: FormatConfigId, global_config: &GlobalConfiguration, plugin_config: &ConfigKeyMap) -> Result<()> {
242    let global_config = serde_json::to_vec(global_config)?;
243    let plugin_config = serde_json::to_vec(plugin_config)?;
244    self
245      .send_with_acknowledgement(MessageBody::RegisterConfig(RegisterConfigMessageBody {
246        config_id,
247        global_config,
248        plugin_config,
249      }))
250      .await?;
251    Ok(())
252  }
253
254  pub async fn release_config(&self, config_id: FormatConfigId) -> Result<()> {
255    self.send_with_acknowledgement(MessageBody::ReleaseConfig(config_id)).await?;
256    Ok(())
257  }
258
259  pub async fn ask_is_alive(&self) -> bool {
260    self.send_with_acknowledgement(MessageBody::IsAlive).await.is_ok()
261  }
262
263  pub async fn plugin_info(&self) -> Result<PluginInfo> {
264    self.send_receiving_data(MessageBody::GetPluginInfo).await
265  }
266
267  pub async fn license_text(&self) -> Result<String> {
268    self.send_receiving_string(MessageBody::GetLicenseText).await
269  }
270
271  pub async fn resolved_config(&self, config_id: FormatConfigId) -> Result<String> {
272    self.send_receiving_string(MessageBody::GetResolvedConfig(config_id)).await
273  }
274
275  pub async fn file_matching_info(&self, config_id: FormatConfigId) -> Result<FileMatchingInfo> {
276    self.send_receiving_data(MessageBody::GetFileMatchingInfo(config_id)).await
277  }
278
279  pub async fn config_diagnostics(&self, config_id: FormatConfigId) -> Result<Vec<ConfigurationDiagnostic>> {
280    self.send_receiving_data(MessageBody::GetConfigDiagnostics(config_id)).await
281  }
282
283  pub async fn check_config_updates(&self, message: &CheckConfigUpdatesMessageBody) -> Result<Vec<ConfigChange>> {
284    let bytes = serde_json::to_vec(&message)?;
285    let response: CheckConfigUpdatesResponseBody = self.send_receiving_data(MessageBody::CheckConfigUpdates(bytes)).await?;
286    Ok(response.changes)
287  }
288
289  pub async fn format_text(&self, request: ProcessPluginCommunicatorFormatRequest) -> FormatResult {
290    let (tx, rx) = oneshot::channel::<Result<Option<Vec<u8>>>>();
291
292    let message_id = self.context.id_generator.next();
293    let store_guard = self.context.host_format_callbacks.store_with_guard(message_id, request.on_host_format);
294    let maybe_result = self
295      .send_message_with_id(
296        message_id,
297        MessageBody::Format(FormatMessageBody {
298          file_path: request.file_path,
299          file_bytes: request.file_bytes,
300          range: request.range,
301          config_id: request.config_id,
302          override_config: serde_json::to_vec(&request.override_config).unwrap(),
303        }),
304        MessageResponseChannel::Format(tx),
305        rx,
306        request.token.clone(),
307      )
308      .await;
309
310    drop(store_guard); // explicit for clarity
311
312    if request.token.is_cancelled() {
313      Ok(None)
314    } else {
315      match maybe_result {
316        Ok(result) => result,
317        Err(err) => Err(CriticalFormatError(err).into()),
318      }
319    }
320  }
321
322  /// Checks if the process is functioning.
323  pub async fn is_process_alive(&self) -> bool {
324    if self.context.shutdown_flag.is_raised() {
325      false
326    } else {
327      self.ask_is_alive().await
328    }
329  }
330
331  async fn send_with_acknowledgement(&self, body: MessageBody) -> Result<()> {
332    let (tx, rx) = oneshot::channel::<Result<()>>();
333    self
334      .send_message(body, MessageResponseChannel::Acknowledgement(tx), rx, Arc::new(NullCancellationToken))
335      .await?
336  }
337
338  async fn send_receiving_string(&self, body: MessageBody) -> Result<String> {
339    let data = self.send_receiving_bytes(body).await??;
340    Ok(String::from_utf8(data)?)
341  }
342
343  async fn send_receiving_data<T: DeserializeOwned>(&self, body: MessageBody) -> Result<T> {
344    let data = self.send_receiving_bytes(body).await??;
345    Ok(serde_json::from_slice(&data)?)
346  }
347
348  async fn send_receiving_bytes(&self, body: MessageBody) -> Result<Result<Vec<u8>>> {
349    let (tx, rx) = oneshot::channel::<Result<Vec<u8>>>();
350    self
351      .send_message(body, MessageResponseChannel::Data(tx), rx, Arc::new(NullCancellationToken))
352      .await
353  }
354
355  async fn send_message<T: Default>(
356    &self,
357    body: MessageBody,
358    response_channel: MessageResponseChannel,
359    receiver: oneshot::Receiver<Result<T>>,
360    token: Arc<dyn super::super::CancellationToken>,
361  ) -> Result<Result<T>> {
362    let message_id = self.context.id_generator.next();
363    self.send_message_with_id(message_id, body, response_channel, receiver, token).await
364  }
365
366  async fn send_message_with_id<T: Default>(
367    &self,
368    message_id: u32,
369    body: MessageBody,
370    response_channel: MessageResponseChannel,
371    receiver: oneshot::Receiver<Result<T>>,
372    token: Arc<dyn super::super::CancellationToken>,
373  ) -> Result<Result<T>> {
374    let mut drop_guard = DropGuardAction::new(|| {
375      // clear up memory
376      self.context.messages.take(message_id);
377      // send cancellation to the client
378      let _ = self.context.stdin_writer.send(ProcessPluginMessage {
379        id: self.context.id_generator.next(),
380        body: MessageBody::CancelFormat(message_id),
381      });
382    });
383
384    self.context.messages.store(message_id, response_channel);
385    self.context.stdin_writer.send(ProcessPluginMessage { id: message_id, body })?;
386    tokio::select! {
387      _ = token.wait_cancellation() => {
388        drop(drop_guard); // explicit
389        Ok(Ok(Default::default()))
390      }
391      response = receiver => {
392        drop_guard.forget(); // we completed, so don't run the drop guard
393        match response {
394          Ok(data) => Ok(data),
395          Err(err) => {
396            bail!("Error waiting on message ({}). {:#}", message_id, err)
397          }
398        }
399      }
400    }
401  }
402}
403
404fn get_plugin_schema_version<TRead: Read + Unpin, TWrite: Write + Unpin>(reader: &mut MessageReader<TRead>, writer: &mut MessageWriter<TWrite>) -> Result<u32> {
405  // since this is the setup, use a lot of contexts to find exactly where it failed
406  writer.send_u32(0).context("Failed asking for schema version.")?; // ask for schema version
407  writer.flush().context("Failed flushing schema version request.")?;
408  let acknowledgement_response = reader.read_u32().context("Could not read success response.")?;
409  if acknowledgement_response != 0 {
410    bail!("Plugin response was unexpected ({acknowledgement_response}).");
411  }
412  reader.read_u32().context("Could not read schema version.")
413}
414
415fn std_err_redirect(shutdown_flag: Arc<AtomicFlag>, stderr: ChildStderr, on_std_err: impl Fn(String) + Send + Sync + 'static) {
416  let reader = std::io::BufReader::new(stderr);
417  for line in reader.lines() {
418    match line {
419      Ok(line) => on_std_err(line),
420      Err(err) => {
421        if shutdown_flag.is_raised() || err.kind() == ErrorKind::BrokenPipe {
422          return;
423        } else {
424          on_std_err(format!("Error reading line from process plugin stderr. {:#}", err));
425        }
426      }
427    }
428  }
429}
430
431fn handle_stdout_message(message: ProcessPluginMessage, context: &Rc<Context>) -> Result<()> {
432  match message.body {
433    MessageBody::Success(message_id) => match context.messages.take(message_id) {
434      Some(MessageResponseChannel::Acknowledgement(channel)) => {
435        let _ignore = channel.send(Ok(()));
436      }
437      Some(MessageResponseChannel::Data(channel)) => {
438        let _ignore = channel.send(Err(anyhow!("Unexpected data channel for success response: {}", message_id)));
439      }
440      Some(MessageResponseChannel::Format(channel)) => {
441        let _ignore = channel.send(Err(anyhow!("Unexpected format channel for success response: {}", message_id)));
442      }
443      None => {}
444    },
445    MessageBody::DataResponse(response) => match context.messages.take(response.message_id) {
446      Some(MessageResponseChannel::Acknowledgement(channel)) => {
447        let _ignore = channel.send(Err(anyhow!("Unexpected success channel for data response: {}", response.message_id)));
448      }
449      Some(MessageResponseChannel::Data(channel)) => {
450        let _ignore = channel.send(Ok(response.data));
451      }
452      Some(MessageResponseChannel::Format(channel)) => {
453        let _ignore = channel.send(Err(anyhow!("Unexpected format channel for data response: {}", response.message_id)));
454      }
455      None => {}
456    },
457    MessageBody::Error(response) => {
458      let err = anyhow!("{}", String::from_utf8_lossy(&response.data));
459      match context.messages.take(response.message_id) {
460        Some(MessageResponseChannel::Acknowledgement(channel)) => {
461          let _ignore = channel.send(Err(err));
462        }
463        Some(MessageResponseChannel::Data(channel)) => {
464          let _ignore = channel.send(Err(err));
465        }
466        Some(MessageResponseChannel::Format(channel)) => {
467          let _ignore = channel.send(Err(err));
468        }
469        None => {}
470      }
471    }
472    MessageBody::FormatResponse(response) => match context.messages.take(response.message_id) {
473      Some(MessageResponseChannel::Acknowledgement(channel)) => {
474        let _ignore = channel.send(Err(anyhow!("Unexpected success channel for format response: {}", response.message_id)));
475      }
476      Some(MessageResponseChannel::Data(channel)) => {
477        let _ignore = channel.send(Err(anyhow!("Unexpected data channel for format response: {}", response.message_id)));
478      }
479      Some(MessageResponseChannel::Format(channel)) => {
480        let _ignore = channel.send(Ok(response.data));
481      }
482      None => {}
483    },
484    MessageBody::CancelFormat(message_id) => {
485      if let Some(token) = context.format_request_tokens.take(message_id) {
486        token.cancel();
487      }
488      context.host_format_callbacks.take(message_id);
489      // do not clear from context.messages here because the cancellation will do that
490    }
491    MessageBody::HostFormat(body) => {
492      // spawn a task to do the host formatting, then send a message back to the
493      // plugin with the result
494      let context = context.clone();
495      crate::async_runtime::spawn(async move {
496        let result = host_format(context.clone(), message.id, body).await;
497
498        // ignore failure, as this means that the process shut down
499        // at which point handling would have occurred elsewhere
500        let _ignore = context.stdin_writer.send(ProcessPluginMessage {
501          id: context.id_generator.next(),
502          body: match result {
503            Ok(result) => MessageBody::FormatResponse(ResponseBody {
504              message_id: message.id,
505              data: result,
506            }),
507            Err(err) => MessageBody::Error(ResponseBody {
508              message_id: message.id,
509              data: format!("{:#}", err).into_bytes(),
510            }),
511          },
512        });
513      });
514    }
515    MessageBody::IsAlive => {
516      // the CLI is not documented as supporting this, but we might as well respond
517      let _ = context.stdin_writer.send(ProcessPluginMessage {
518        id: context.id_generator.next(),
519        body: MessageBody::Success(message.id),
520      });
521    }
522    MessageBody::Format(_)
523    | MessageBody::Close
524    | MessageBody::GetPluginInfo
525    | MessageBody::GetLicenseText
526    | MessageBody::RegisterConfig(_)
527    | MessageBody::ReleaseConfig(_)
528    | MessageBody::GetConfigDiagnostics(_)
529    | MessageBody::GetFileMatchingInfo(_)
530    | MessageBody::GetResolvedConfig(_)
531    | MessageBody::CheckConfigUpdates(_) => {
532      let _ = context.stdin_writer.send(ProcessPluginMessage {
533        id: context.id_generator.next(),
534        body: MessageBody::Error(ResponseBody {
535          message_id: message.id,
536          data: "Unsupported plugin to CLI message.".as_bytes().to_vec(),
537        }),
538      });
539    }
540    // If encountered, process plugin should exit and
541    // the CLI should kill the process plugin.
542    MessageBody::Unknown(message_kind) => {
543      bail!("Unknown message kind: {}", message_kind);
544    }
545  }
546
547  Ok(())
548}
549
550async fn host_format(context: Rc<Context>, message_id: u32, body: HostFormatMessageBody) -> FormatResult {
551  let Some(callback) = context.host_format_callbacks.get_cloned(body.original_message_id) else {
552    return FormatResult::Err(anyhow!("Could not find host format callback for message id: {}", body.original_message_id));
553  };
554
555  let token = Arc::new(CancellationToken::new());
556  let store_guard = context.format_request_tokens.store_with_guard(message_id, token.clone());
557  let result = callback(HostFormatRequest {
558    file_path: body.file_path,
559    file_bytes: body.file_text,
560    range: body.range,
561    override_config: serde_json::from_slice(&body.override_config).unwrap(),
562    token,
563  })
564  .await;
565  drop(store_guard); // explicit for clarity
566  result
567}