1use serde::de::DeserializeOwned;
2use std::cell::RefCell;
3use std::io::BufRead;
4use std::io::ErrorKind;
5use std::io::Read;
6use std::io::Write;
7use std::path::Path;
8use std::path::PathBuf;
9use std::process::Child;
10use std::process::ChildStderr;
11use std::process::Command;
12use std::process::Stdio;
13use std::rc::Rc;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::oneshot;
17use tokio_util::sync::CancellationToken;
18
19use super::PLUGIN_SCHEMA_VERSION;
20use super::messages::CheckConfigUpdatesMessageBody;
21use super::messages::CheckConfigUpdatesResponseBody;
22use super::messages::FormatMessageBody;
23use super::messages::HostFormatMessageBody;
24use super::messages::MessageBody;
25use super::messages::ProcessPluginMessage;
26use super::messages::RegisterConfigMessageBody;
27use super::messages::ResponseBody;
28use crate::async_runtime::DropGuardAction;
29use crate::async_runtime::LocalBoxFuture;
30use crate::communication::AtomicFlag;
31use crate::communication::IdGenerator;
32use crate::communication::MessageReader;
33use crate::communication::MessageWriter;
34use crate::communication::RcIdStore;
35use crate::communication::SingleThreadMessageWriter;
36use crate::configuration::ConfigKeyMap;
37use crate::configuration::ConfigurationDiagnostic;
38use crate::configuration::GlobalConfiguration;
39use crate::plugins::ConfigChange;
40use crate::plugins::CriticalFormatError;
41use crate::plugins::FileMatchingInfo;
42use crate::plugins::FormatConfigId;
43use crate::plugins::FormatError;
44use crate::plugins::FormatRange;
45use crate::plugins::FormatResult;
46use crate::plugins::HostFormatRequest;
47use crate::plugins::NullCancellationToken;
48use crate::plugins::PluginInfo;
49use crate::plugins::error_to_string;
50
/// Result type used throughout this module; all failures are reported as `FormatError`.
type Result<T> = std::result::Result<T, FormatError>;

/// Shared, type-erased cancellation token as accepted by the send helpers.
type DprintCancellationToken = Arc<dyn super::super::CancellationToken>;
54
/// Failures that can occur during the initial schema-version handshake with a
/// plugin process (performed by `get_plugin_schema_version`).
#[derive(Debug, thiserror::Error)]
enum SchemaVersionError {
  /// Writing the schema version request to the plugin failed.
  #[error("Failed asking for schema version: {0}")]
  Ask(std::io::Error),
  /// Flushing the schema version request to the plugin failed.
  #[error("Failed flushing schema version request: {0}")]
  Flush(std::io::Error),
  /// Reading the acknowledgement value from the plugin failed.
  #[error("Could not read success response: {0}")]
  ReadAcknowledgement(std::io::Error),
  /// The plugin acknowledged with a value other than `0`; holds the value received.
  #[error("Plugin response was unexpected ({0}).")]
  UnexpectedAcknowledgement(u32),
  /// Reading the schema version value from the plugin failed.
  #[error("Could not read schema version: {0}")]
  ReadVersion(std::io::Error),
}
69
/// Errors detected by the communicator itself, as opposed to error messages
/// the plugin sends back (those are converted from their text into `FormatError`).
#[derive(Debug, thiserror::Error)]
enum CommunicatorError {
  /// Spawning the plugin executable failed.
  #[error("Error starting {executable} with args [{args}]. {error}")]
  StartProcess { executable: String, args: String, error: std::io::Error },
  /// The schema-version handshake with the plugin failed.
  #[error("Failed plugin schema verification. This may indicate you are using an old version of the dprint CLI or plugin and should upgrade")]
  SchemaVerification(#[from] SchemaVersionError),
  /// The plugin reported a schema version lower than the one this CLI expects.
  #[error(
    "This plugin is too old to run in the dprint CLI and you will need to manually upgrade it (version was {actual}, but expected {expected}).\n\nUpgrade instructions: https://github.com/dprint/dprint/issues/731"
  )]
  PluginTooOld { actual: u32, expected: u32 },
  /// The plugin reported a schema version higher than the one this CLI expects.
  #[error("Your dprint CLI is too old to run this plugin (version was {actual}, but expected {expected}). Try running: dprint upgrade")]
  CliTooOld { actual: u32, expected: u32 },
  /// The response channel for a sent message was closed before a response arrived.
  #[error("Error waiting on message ({message_id}). {error}")]
  WaitOnMessage {
    message_id: u32,
    error: tokio::sync::oneshot::error::RecvError,
  },
  // The following six variants are sent to a waiter when the kind of response
  // received does not match the kind of request that was sent. Each carries
  // the id of the request.
  /// A `Success` response arrived for a request expecting data.
  #[error("Unexpected data channel for success response: {0}")]
  UnexpectedDataForSuccess(u32),
  /// A `Success` response arrived for a format request.
  #[error("Unexpected format channel for success response: {0}")]
  UnexpectedFormatForSuccess(u32),
  /// A `DataResponse` arrived for a request expecting only an acknowledgement.
  #[error("Unexpected success channel for data response: {0}")]
  UnexpectedSuccessForData(u32),
  /// A `DataResponse` arrived for a format request.
  #[error("Unexpected format channel for data response: {0}")]
  UnexpectedFormatForData(u32),
  /// A `FormatResponse` arrived for a request expecting only an acknowledgement.
  #[error("Unexpected success channel for format response: {0}")]
  UnexpectedSuccessForFormat(u32),
  /// A `FormatResponse` arrived for a request expecting data.
  #[error("Unexpected data channel for format response: {0}")]
  UnexpectedDataForFormat(u32),
  /// The plugin sent a message kind this CLI does not recognise.
  #[error("Unknown message kind: {0}")]
  UnknownMessageKind(u32),
  /// A host format request referenced an original message id with no registered callback.
  #[error("Could not find host format callback for message id: {0}")]
  HostFormatCallbackNotFound(u32),
}
105
106impl From<CommunicatorError> for FormatError {
107 fn from(err: CommunicatorError) -> Self {
108 FormatError::new(err)
109 }
110}
111
/// Callback invoked when the plugin sends a `HostFormat` request (asking the
/// CLI to format some text) while one of our format requests is in flight.
pub type HostFormatCallback = Rc<dyn Fn(HostFormatRequest) -> LocalBoxFuture<'static, FormatResult>>;
113
/// A request to format one file through the process plugin.
pub struct ProcessPluginCommunicatorFormatRequest {
  /// Path of the file being formatted.
  pub file_path: PathBuf,
  /// Contents of the file being formatted.
  pub file_bytes: Vec<u8>,
  /// Range of the file to format.
  pub range: FormatRange,
  /// Id of a configuration previously registered with `register_config`.
  pub config_id: FormatConfigId,
  /// Configuration overrides; serialized to JSON before being sent to the plugin.
  pub override_config: ConfigKeyMap,
  /// Invoked for each `HostFormat` request the plugin makes while handling this request.
  pub on_host_format: HostFormatCallback,
  /// Cancels the request; when cancelled, `format_text` returns `Ok(None)`.
  pub token: DprintCancellationToken,
}
123
/// Where to deliver the response to an outgoing message, keyed by the kind of
/// response the request expects.
enum MessageResponseChannel {
  /// Completed by a `Success` response.
  Acknowledgement(oneshot::Sender<Result<()>>),
  /// Completed by a `DataResponse`.
  Data(oneshot::Sender<Result<Vec<u8>>>),
  /// Completed by a `FormatResponse`.
  Format(oneshot::Sender<Result<Option<Vec<u8>>>>),
}
129
/// State shared between the communicator and its stdout dispatch task.
struct Context {
  /// Writes messages to the plugin's stdin.
  stdin_writer: SingleThreadMessageWriter<ProcessPluginMessage>,
  /// Raised on shutdown/kill; once raised, read errors are no longer reported.
  shutdown_flag: Arc<AtomicFlag>,
  /// Generates ids for messages sent by the CLI.
  id_generator: IdGenerator,
  /// Pending response channels, keyed by the id of the CLI message awaiting a response.
  messages: RcIdStore<MessageResponseChannel>,
  /// Cancellation tokens for in-flight host format requests, keyed by the
  /// plugin's `HostFormat` message id.
  format_request_tokens: RcIdStore<Arc<CancellationToken>>,
  /// Host format callbacks, keyed by the CLI's `Format` message id (the
  /// `original_message_id` referenced by `HostFormat` requests).
  host_format_callbacks: RcIdStore<HostFormatCallback>,
}
138
/// Communicates with a dprint process plugin over its stdin/stdout.
pub struct ProcessPluginCommunicator {
  /// The plugin process; `None` once it has been killed.
  child: RefCell<Option<Child>>,
  /// State shared with the background dispatch task.
  context: Rc<Context>,
}
144
145impl Drop for ProcessPluginCommunicator {
146 fn drop(&mut self) {
147 self.kill();
148 }
149}
150
151impl ProcessPluginCommunicator {
152 pub async fn new(executable_file_path: &Path, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
153 ProcessPluginCommunicator::new_internal(executable_file_path, false, on_std_err).await
154 }
155
156 pub async fn new_with_init(executable_file_path: &Path, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
158 ProcessPluginCommunicator::new_internal(executable_file_path, true, on_std_err).await
159 }
160
161 async fn new_internal(executable_file_path: &Path, is_init: bool, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
162 let mut args = vec!["--parent-pid".to_string(), std::process::id().to_string()];
163 if is_init {
164 args.push("--init".to_string());
165 }
166
167 let shutdown_flag = Arc::new(AtomicFlag::default());
168 let mut child = Command::new(executable_file_path)
169 .args(&args)
170 .stdin(Stdio::piped())
171 .stderr(Stdio::piped())
172 .stdout(Stdio::piped())
173 .spawn()
174 .map_err(|err| CommunicatorError::StartProcess {
175 executable: executable_file_path.display().to_string(),
176 args: args.join(" "),
177 error: err,
178 })?;
179
180 let stderr = child.stderr.take().unwrap();
182 crate::async_runtime::spawn_blocking({
183 let shutdown_flag = shutdown_flag.clone();
184 let on_std_err = on_std_err.clone();
185 move || {
186 std_err_redirect(shutdown_flag, stderr, on_std_err);
187 }
188 });
189
190 let mut stdout_reader = MessageReader::new(child.stdout.take().unwrap());
192 let mut stdin_writer = MessageWriter::new(child.stdin.take().unwrap());
193
194 let (mut stdout_reader, stdin_writer, schema_version) = crate::async_runtime::spawn_blocking(move || {
195 let schema_version = get_plugin_schema_version(&mut stdout_reader, &mut stdin_writer).map_err(CommunicatorError::SchemaVerification)?;
196 Ok::<_, FormatError>((stdout_reader, stdin_writer, schema_version))
197 })
198 .await??;
199
200 if schema_version != PLUGIN_SCHEMA_VERSION {
201 let _ = child.kill();
203 let err = if schema_version < PLUGIN_SCHEMA_VERSION {
204 CommunicatorError::PluginTooOld {
205 actual: schema_version,
206 expected: PLUGIN_SCHEMA_VERSION,
207 }
208 } else {
209 CommunicatorError::CliTooOld {
210 actual: schema_version,
211 expected: PLUGIN_SCHEMA_VERSION,
212 }
213 };
214 return Err(err.into());
215 }
216
217 let stdin_writer = SingleThreadMessageWriter::for_stdin(stdin_writer);
218 let context = Rc::new(Context {
219 id_generator: Default::default(),
220 shutdown_flag,
221 stdin_writer,
222 messages: Default::default(),
223 format_request_tokens: Default::default(),
224 host_format_callbacks: Default::default(),
225 });
226
227 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
229 crate::async_runtime::spawn_blocking({
230 let shutdown_flag = context.shutdown_flag.clone();
231 let on_std_err = on_std_err.clone();
232 move || {
233 loop {
234 match ProcessPluginMessage::read(&mut stdout_reader) {
235 Ok(message) => {
236 if tx.send(message).is_err() {
237 break; }
239 }
240 Err(err) if err.kind() == ErrorKind::BrokenPipe => {
241 break;
242 }
243 Err(err) => {
244 if !shutdown_flag.is_raised() {
245 on_std_err(format!("Error reading stdout message: {:#}", err));
246 }
247 break;
248 }
249 }
250 }
251 }
252 });
253 crate::async_runtime::spawn({
254 let context = context.clone();
255 async move {
256 while let Some(message) = rx.recv().await {
257 if let Err(err) = handle_stdout_message(message, &context) {
258 if !context.shutdown_flag.is_raised() {
259 on_std_err(format!("Error reading stdout message: {:#}", err));
260 }
261 break;
262 }
263 }
264 context.messages.take_all();
266 }
267 });
268
269 Ok(Self {
270 child: RefCell::new(Some(child)),
271 context,
272 })
273 }
274
275 pub async fn shutdown(&self) {
277 if self.context.shutdown_flag.raise() {
278 tokio::select! {
280 _ = self.send_with_acknowledgement(MessageBody::Close) => {}
284 _ = tokio::time::sleep(Duration::from_millis(250)) => {
285 self.kill();
286 }
287 }
288 } else {
289 self.kill();
290 }
291 }
292
293 pub fn kill(&self) {
294 self.context.shutdown_flag.raise();
295 if let Some(mut child) = self.child.borrow_mut().take() {
296 let _ignore = child.kill();
297 }
298 }
299
300 pub async fn register_config(&self, config_id: FormatConfigId, global_config: &GlobalConfiguration, plugin_config: &ConfigKeyMap) -> Result<()> {
301 let global_config = serde_json::to_vec(global_config)?;
302 let plugin_config = serde_json::to_vec(plugin_config)?;
303 self
304 .send_with_acknowledgement(MessageBody::RegisterConfig(RegisterConfigMessageBody {
305 config_id,
306 global_config,
307 plugin_config,
308 }))
309 .await?;
310 Ok(())
311 }
312
313 pub async fn release_config(&self, config_id: FormatConfigId) -> Result<()> {
314 self.send_with_acknowledgement(MessageBody::ReleaseConfig(config_id)).await?;
315 Ok(())
316 }
317
318 pub async fn ask_is_alive(&self) -> bool {
319 self.send_with_acknowledgement(MessageBody::IsAlive).await.is_ok()
320 }
321
322 pub async fn plugin_info(&self) -> Result<PluginInfo> {
323 self.send_receiving_data(MessageBody::GetPluginInfo).await
324 }
325
326 pub async fn license_text(&self) -> Result<String> {
327 self.send_receiving_string(MessageBody::GetLicenseText).await
328 }
329
330 pub async fn resolved_config(&self, config_id: FormatConfigId) -> Result<String> {
331 self.send_receiving_string(MessageBody::GetResolvedConfig(config_id)).await
332 }
333
334 pub async fn file_matching_info(&self, config_id: FormatConfigId) -> Result<FileMatchingInfo> {
335 self.send_receiving_data(MessageBody::GetFileMatchingInfo(config_id)).await
336 }
337
338 pub async fn config_diagnostics(&self, config_id: FormatConfigId) -> Result<Vec<ConfigurationDiagnostic>> {
339 self.send_receiving_data(MessageBody::GetConfigDiagnostics(config_id)).await
340 }
341
342 pub async fn check_config_updates(&self, message: &CheckConfigUpdatesMessageBody) -> Result<Vec<ConfigChange>> {
343 let bytes = serde_json::to_vec(&message)?;
344 let response: CheckConfigUpdatesResponseBody = self.send_receiving_data(MessageBody::CheckConfigUpdates(bytes)).await?;
345 Ok(response.changes)
346 }
347
348 pub async fn format_text(&self, request: ProcessPluginCommunicatorFormatRequest) -> FormatResult {
349 let (tx, rx) = oneshot::channel::<Result<Option<Vec<u8>>>>();
350
351 let message_id = self.context.id_generator.next();
352 let store_guard = self.context.host_format_callbacks.store_with_guard(message_id, request.on_host_format);
353 let maybe_result = self
354 .send_message_with_id(
355 message_id,
356 MessageBody::Format(FormatMessageBody {
357 file_path: request.file_path,
358 file_bytes: request.file_bytes,
359 range: request.range,
360 config_id: request.config_id,
361 override_config: serde_json::to_vec(&request.override_config).unwrap(),
362 }),
363 MessageResponseChannel::Format(tx),
364 rx,
365 request.token.clone(),
366 )
367 .await;
368
369 drop(store_guard); if request.token.is_cancelled() {
372 Ok(None)
373 } else {
374 match maybe_result {
375 Ok(result) => result,
376 Err(err) => Err(CriticalFormatError(err).into()),
377 }
378 }
379 }
380
381 pub async fn is_process_alive(&self) -> bool {
383 if self.context.shutdown_flag.is_raised() {
384 false
385 } else {
386 self.ask_is_alive().await
387 }
388 }
389
390 async fn send_with_acknowledgement(&self, body: MessageBody) -> Result<()> {
391 let (tx, rx) = oneshot::channel::<Result<()>>();
392 self
393 .send_message(body, MessageResponseChannel::Acknowledgement(tx), rx, Arc::new(NullCancellationToken))
394 .await?
395 }
396
397 async fn send_receiving_string(&self, body: MessageBody) -> Result<String> {
398 let data = self.send_receiving_bytes(body).await??;
399 Ok(String::from_utf8(data)?)
400 }
401
402 async fn send_receiving_data<T: DeserializeOwned>(&self, body: MessageBody) -> Result<T> {
403 let data = self.send_receiving_bytes(body).await??;
404 Ok(serde_json::from_slice(&data)?)
405 }
406
407 async fn send_receiving_bytes(&self, body: MessageBody) -> Result<Result<Vec<u8>>> {
408 let (tx, rx) = oneshot::channel::<Result<Vec<u8>>>();
409 self
410 .send_message(body, MessageResponseChannel::Data(tx), rx, Arc::new(NullCancellationToken))
411 .await
412 }
413
414 async fn send_message<T: Default>(
415 &self,
416 body: MessageBody,
417 response_channel: MessageResponseChannel,
418 receiver: oneshot::Receiver<Result<T>>,
419 token: Arc<dyn super::super::CancellationToken>,
420 ) -> Result<Result<T>> {
421 let message_id = self.context.id_generator.next();
422 self.send_message_with_id(message_id, body, response_channel, receiver, token).await
423 }
424
425 async fn send_message_with_id<T: Default>(
426 &self,
427 message_id: u32,
428 body: MessageBody,
429 response_channel: MessageResponseChannel,
430 receiver: oneshot::Receiver<Result<T>>,
431 token: Arc<dyn super::super::CancellationToken>,
432 ) -> Result<Result<T>> {
433 let mut drop_guard = DropGuardAction::new(|| {
434 self.context.messages.take(message_id);
436 let _ = self.context.stdin_writer.send(ProcessPluginMessage {
438 id: self.context.id_generator.next(),
439 body: MessageBody::CancelFormat(message_id),
440 });
441 });
442
443 self.context.messages.store(message_id, response_channel);
444 self.context.stdin_writer.send(ProcessPluginMessage { id: message_id, body })?;
445 tokio::select! {
446 _ = token.wait_cancellation() => {
447 drop(drop_guard); Ok(Ok(Default::default()))
449 }
450 response = receiver => {
451 drop_guard.forget(); match response {
453 Ok(data) => Ok(data),
454 Err(err) => Err(CommunicatorError::WaitOnMessage { message_id, error: err }.into()),
455 }
456 }
457 }
458 }
459}
460
461fn get_plugin_schema_version<TRead: Read + Unpin, TWrite: Write + Unpin>(
462 reader: &mut MessageReader<TRead>,
463 writer: &mut MessageWriter<TWrite>,
464) -> std::result::Result<u32, SchemaVersionError> {
465 writer.send_u32(0).map_err(SchemaVersionError::Ask)?; writer.flush().map_err(SchemaVersionError::Flush)?;
468 let acknowledgement_response = reader.read_u32().map_err(SchemaVersionError::ReadAcknowledgement)?;
469 if acknowledgement_response != 0 {
470 return Err(SchemaVersionError::UnexpectedAcknowledgement(acknowledgement_response));
471 }
472 reader.read_u32().map_err(SchemaVersionError::ReadVersion)
473}
474
475fn std_err_redirect(shutdown_flag: Arc<AtomicFlag>, stderr: ChildStderr, on_std_err: impl Fn(String) + Send + Sync + 'static) {
476 let reader = std::io::BufReader::new(stderr);
477 for line in reader.lines() {
478 match line {
479 Ok(line) => on_std_err(line),
480 Err(err) => {
481 if shutdown_flag.is_raised() || err.kind() == ErrorKind::BrokenPipe {
482 return;
483 } else {
484 on_std_err(format!("Error reading line from process plugin stderr. {:#}", err));
485 }
486 }
487 }
488 }
489}
490
491fn handle_stdout_message(message: ProcessPluginMessage, context: &Rc<Context>) -> Result<()> {
492 match message.body {
493 MessageBody::Success(message_id) => match context.messages.take(message_id) {
494 Some(MessageResponseChannel::Acknowledgement(channel)) => {
495 let _ignore = channel.send(Ok(()));
496 }
497 Some(MessageResponseChannel::Data(channel)) => {
498 let _ignore = channel.send(Err(CommunicatorError::UnexpectedDataForSuccess(message_id).into()));
499 }
500 Some(MessageResponseChannel::Format(channel)) => {
501 let _ignore = channel.send(Err(CommunicatorError::UnexpectedFormatForSuccess(message_id).into()));
502 }
503 None => {}
504 },
505 MessageBody::DataResponse(response) => match context.messages.take(response.message_id) {
506 Some(MessageResponseChannel::Acknowledgement(channel)) => {
507 let _ignore = channel.send(Err(CommunicatorError::UnexpectedSuccessForData(response.message_id).into()));
508 }
509 Some(MessageResponseChannel::Data(channel)) => {
510 let _ignore = channel.send(Ok(response.data));
511 }
512 Some(MessageResponseChannel::Format(channel)) => {
513 let _ignore = channel.send(Err(CommunicatorError::UnexpectedFormatForData(response.message_id).into()));
514 }
515 None => {}
516 },
517 MessageBody::Error(response) => {
518 let err: FormatError = String::from_utf8_lossy(&response.data).into_owned().into();
519 match context.messages.take(response.message_id) {
520 Some(MessageResponseChannel::Acknowledgement(channel)) => {
521 let _ignore = channel.send(Err(err));
522 }
523 Some(MessageResponseChannel::Data(channel)) => {
524 let _ignore = channel.send(Err(err));
525 }
526 Some(MessageResponseChannel::Format(channel)) => {
527 let _ignore = channel.send(Err(err));
528 }
529 None => {}
530 }
531 }
532 MessageBody::FormatResponse(response) => match context.messages.take(response.message_id) {
533 Some(MessageResponseChannel::Acknowledgement(channel)) => {
534 let _ignore = channel.send(Err(CommunicatorError::UnexpectedSuccessForFormat(response.message_id).into()));
535 }
536 Some(MessageResponseChannel::Data(channel)) => {
537 let _ignore = channel.send(Err(CommunicatorError::UnexpectedDataForFormat(response.message_id).into()));
538 }
539 Some(MessageResponseChannel::Format(channel)) => {
540 let _ignore = channel.send(Ok(response.data));
541 }
542 None => {}
543 },
544 MessageBody::CancelFormat(message_id) => {
545 if let Some(token) = context.format_request_tokens.take(message_id) {
546 token.cancel();
547 }
548 context.host_format_callbacks.take(message_id);
549 }
551 MessageBody::HostFormat(body) => {
552 let context = context.clone();
555 crate::async_runtime::spawn(async move {
556 let result = host_format(context.clone(), message.id, body).await;
557
558 let _ignore = context.stdin_writer.send(ProcessPluginMessage {
561 id: context.id_generator.next(),
562 body: match result {
563 Ok(result) => MessageBody::FormatResponse(ResponseBody {
564 message_id: message.id,
565 data: result,
566 }),
567 Err(err) => MessageBody::Error(ResponseBody {
568 message_id: message.id,
569 data: error_to_string(&err).into_bytes(),
570 }),
571 },
572 });
573 });
574 }
575 MessageBody::IsAlive => {
576 let _ = context.stdin_writer.send(ProcessPluginMessage {
578 id: context.id_generator.next(),
579 body: MessageBody::Success(message.id),
580 });
581 }
582 MessageBody::Format(_)
583 | MessageBody::Close
584 | MessageBody::GetPluginInfo
585 | MessageBody::GetLicenseText
586 | MessageBody::RegisterConfig(_)
587 | MessageBody::ReleaseConfig(_)
588 | MessageBody::GetConfigDiagnostics(_)
589 | MessageBody::GetFileMatchingInfo(_)
590 | MessageBody::GetResolvedConfig(_)
591 | MessageBody::CheckConfigUpdates(_) => {
592 let _ = context.stdin_writer.send(ProcessPluginMessage {
593 id: context.id_generator.next(),
594 body: MessageBody::Error(ResponseBody {
595 message_id: message.id,
596 data: "Unsupported plugin to CLI message.".as_bytes().to_vec(),
597 }),
598 });
599 }
600 MessageBody::Unknown(message_kind) => {
603 return Err(CommunicatorError::UnknownMessageKind(message_kind).into());
604 }
605 }
606
607 Ok(())
608}
609
610async fn host_format(context: Rc<Context>, message_id: u32, body: HostFormatMessageBody) -> FormatResult {
611 let Some(callback) = context.host_format_callbacks.get_cloned(body.original_message_id) else {
612 return FormatResult::Err(CommunicatorError::HostFormatCallbackNotFound(body.original_message_id).into());
613 };
614
615 let token = Arc::new(CancellationToken::new());
616 let store_guard = context.format_request_tokens.store_with_guard(message_id, token.clone());
617 let result = callback(HostFormatRequest {
618 file_path: body.file_path,
619 file_bytes: body.file_text,
620 range: body.range,
621 override_config: serde_json::from_slice(&body.override_config).unwrap(),
622 token,
623 })
624 .await;
625 drop(store_guard); result
627}