1use anyhow::anyhow;
2use anyhow::bail;
3use anyhow::Context as AnyhowContext;
4use anyhow::Result;
5use serde::de::DeserializeOwned;
6use std::cell::RefCell;
7use std::io::BufRead;
8use std::io::ErrorKind;
9use std::io::Read;
10use std::io::Write;
11use std::path::Path;
12use std::path::PathBuf;
13use std::process::Child;
14use std::process::ChildStderr;
15use std::process::Command;
16use std::process::Stdio;
17use std::rc::Rc;
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::oneshot;
21use tokio_util::sync::CancellationToken;
22
23use super::messages::CheckConfigUpdatesMessageBody;
24use super::messages::CheckConfigUpdatesResponseBody;
25use super::messages::FormatMessageBody;
26use super::messages::HostFormatMessageBody;
27use super::messages::MessageBody;
28use super::messages::ProcessPluginMessage;
29use super::messages::RegisterConfigMessageBody;
30use super::messages::ResponseBody;
31use super::PLUGIN_SCHEMA_VERSION;
32use crate::async_runtime::DropGuardAction;
33use crate::async_runtime::LocalBoxFuture;
34use crate::communication::AtomicFlag;
35use crate::communication::IdGenerator;
36use crate::communication::MessageReader;
37use crate::communication::MessageWriter;
38use crate::communication::RcIdStore;
39use crate::communication::SingleThreadMessageWriter;
40use crate::configuration::ConfigKeyMap;
41use crate::configuration::ConfigurationDiagnostic;
42use crate::configuration::GlobalConfiguration;
43use crate::plugins::ConfigChange;
44use crate::plugins::CriticalFormatError;
45use crate::plugins::FileMatchingInfo;
46use crate::plugins::FormatConfigId;
47use crate::plugins::FormatRange;
48use crate::plugins::FormatResult;
49use crate::plugins::HostFormatRequest;
50use crate::plugins::NullCancellationToken;
51use crate::plugins::PluginInfo;
52
53type DprintCancellationToken = Arc<dyn super::super::CancellationToken>;
54
55pub type HostFormatCallback = Rc<dyn Fn(HostFormatRequest) -> LocalBoxFuture<'static, FormatResult>>;
56
/// A request to format a file with a process plugin (see `ProcessPluginCommunicator::format_text`).
pub struct ProcessPluginCommunicatorFormatRequest {
  /// Path of the file being formatted.
  pub file_path: PathBuf,
  /// Contents of the file to format.
  pub file_bytes: Vec<u8>,
  /// Portion of the file to format.
  pub range: FormatRange,
  /// Id of a configuration registered with the plugin (see `register_config`).
  pub config_id: FormatConfigId,
  /// Configuration overrides; serialized to JSON before being sent to the plugin.
  pub override_config: ConfigKeyMap,
  /// Invoked when the plugin asks the CLI to format other text during this request.
  pub on_host_format: HostFormatCallback,
  /// Cancelling this token abandons the request and sends a cancellation message to the plugin.
  pub token: DprintCancellationToken,
}
66
67enum MessageResponseChannel {
68 Acknowledgement(oneshot::Sender<Result<()>>),
69 Data(oneshot::Sender<Result<Vec<u8>>>),
70 Format(oneshot::Sender<Result<Option<Vec<u8>>>>),
71}
72
/// State shared between a communicator and the task that handles the
/// plugin's stdout messages.
struct Context {
  /// Writes messages to the plugin's stdin.
  stdin_writer: SingleThreadMessageWriter<ProcessPluginMessage>,
  /// Raised by `shutdown`/`kill`; suppresses error reporting from the reader threads.
  shutdown_flag: Arc<AtomicFlag>,
  /// Generates ids for messages the CLI sends to the plugin.
  id_generator: IdGenerator,
  /// Requests sent to the plugin that are awaiting a reply, keyed by the CLI's message id.
  messages: RcIdStore<MessageResponseChannel>,
  /// Cancellation tokens of host formats requested by the plugin, keyed by the plugin's message id.
  format_request_tokens: RcIdStore<Arc<CancellationToken>>,
  /// Host format callbacks of in-flight format requests, keyed by the CLI's format message id.
  host_format_callbacks: RcIdStore<HostFormatCallback>,
}
81
/// Talks to a dprint process plugin over the child process's stdin/stdout.
pub struct ProcessPluginCommunicator {
  /// The plugin process; taken (set to `None`) once it has been killed.
  child: RefCell<Option<Child>>,
  /// State shared with the background stdout-handling task.
  context: Rc<Context>,
}
87
88impl Drop for ProcessPluginCommunicator {
89 fn drop(&mut self) {
90 self.kill();
91 }
92}
93
94impl ProcessPluginCommunicator {
95 pub async fn new(executable_file_path: &Path, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
96 ProcessPluginCommunicator::new_internal(executable_file_path, false, on_std_err).await
97 }
98
99 pub async fn new_with_init(executable_file_path: &Path, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
101 ProcessPluginCommunicator::new_internal(executable_file_path, true, on_std_err).await
102 }
103
104 async fn new_internal(executable_file_path: &Path, is_init: bool, on_std_err: impl Fn(String) + Clone + Send + Sync + 'static) -> Result<Self> {
105 let mut args = vec!["--parent-pid".to_string(), std::process::id().to_string()];
106 if is_init {
107 args.push("--init".to_string());
108 }
109
110 let shutdown_flag = Arc::new(AtomicFlag::default());
111 let mut child = Command::new(executable_file_path)
112 .args(&args)
113 .stdin(Stdio::piped())
114 .stderr(Stdio::piped())
115 .stdout(Stdio::piped())
116 .spawn()
117 .map_err(|err| anyhow!("Error starting {} with args [{}]. {:#}", executable_file_path.display(), args.join(" "), err))?;
118
119 let stderr = child.stderr.take().unwrap();
121 crate::async_runtime::spawn_blocking({
122 let shutdown_flag = shutdown_flag.clone();
123 let on_std_err = on_std_err.clone();
124 move || {
125 std_err_redirect(shutdown_flag, stderr, on_std_err);
126 }
127 });
128
129 let mut stdout_reader = MessageReader::new(child.stdout.take().unwrap());
131 let mut stdin_writer = MessageWriter::new(child.stdin.take().unwrap());
132
133 let (mut stdout_reader, stdin_writer, schema_version) = crate::async_runtime::spawn_blocking(move || {
134 let schema_version = get_plugin_schema_version(&mut stdout_reader, &mut stdin_writer)
135 .context("Failed plugin schema verification. This may indicate you are using an old version of the dprint CLI or plugin and should upgrade")?;
136 Ok::<_, anyhow::Error>((stdout_reader, stdin_writer, schema_version))
137 })
138 .await??;
139
140 if schema_version != PLUGIN_SCHEMA_VERSION {
141 let _ = child.kill();
143 if schema_version < PLUGIN_SCHEMA_VERSION {
144 bail!(
145 "This plugin is too old to run in the dprint CLI and you will need to manually upgrade it (version was {}, but expected {}).\n\nUpgrade instructions: https://github.com/dprint/dprint/issues/731",
146 schema_version,
147 PLUGIN_SCHEMA_VERSION
148 );
149 } else {
150 bail!(
151 "Your dprint CLI is too old to run this plugin (version was {}, but expected {}). Try running: dprint upgrade",
152 schema_version,
153 PLUGIN_SCHEMA_VERSION
154 );
155 }
156 }
157
158 let stdin_writer = SingleThreadMessageWriter::for_stdin(stdin_writer);
159 let context = Rc::new(Context {
160 id_generator: Default::default(),
161 shutdown_flag,
162 stdin_writer,
163 messages: Default::default(),
164 format_request_tokens: Default::default(),
165 host_format_callbacks: Default::default(),
166 });
167
168 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
170 crate::async_runtime::spawn_blocking({
171 let shutdown_flag = context.shutdown_flag.clone();
172 let on_std_err = on_std_err.clone();
173 move || {
174 loop {
175 match ProcessPluginMessage::read(&mut stdout_reader) {
176 Ok(message) => {
177 if tx.send(message).is_err() {
178 break; }
180 }
181 Err(err) if err.kind() == ErrorKind::BrokenPipe => {
182 break;
183 }
184 Err(err) => {
185 if !shutdown_flag.is_raised() {
186 on_std_err(format!("Error reading stdout message: {:#}", err));
187 }
188 break;
189 }
190 }
191 }
192 }
193 });
194 crate::async_runtime::spawn({
195 let context = context.clone();
196 async move {
197 while let Some(message) = rx.recv().await {
198 if let Err(err) = handle_stdout_message(message, &context) {
199 if !context.shutdown_flag.is_raised() {
200 on_std_err(format!("Error reading stdout message: {:#}", err));
201 }
202 break;
203 }
204 }
205 context.messages.take_all();
207 }
208 });
209
210 Ok(Self {
211 child: RefCell::new(Some(child)),
212 context,
213 })
214 }
215
216 pub async fn shutdown(&self) {
218 if self.context.shutdown_flag.raise() {
219 tokio::select! {
221 _ = self.send_with_acknowledgement(MessageBody::Close) => {}
225 _ = tokio::time::sleep(Duration::from_millis(250)) => {
226 self.kill();
227 }
228 }
229 } else {
230 self.kill();
231 }
232 }
233
234 pub fn kill(&self) {
235 self.context.shutdown_flag.raise();
236 if let Some(mut child) = self.child.borrow_mut().take() {
237 let _ignore = child.kill();
238 }
239 }
240
241 pub async fn register_config(&self, config_id: FormatConfigId, global_config: &GlobalConfiguration, plugin_config: &ConfigKeyMap) -> Result<()> {
242 let global_config = serde_json::to_vec(global_config)?;
243 let plugin_config = serde_json::to_vec(plugin_config)?;
244 self
245 .send_with_acknowledgement(MessageBody::RegisterConfig(RegisterConfigMessageBody {
246 config_id,
247 global_config,
248 plugin_config,
249 }))
250 .await?;
251 Ok(())
252 }
253
254 pub async fn release_config(&self, config_id: FormatConfigId) -> Result<()> {
255 self.send_with_acknowledgement(MessageBody::ReleaseConfig(config_id)).await?;
256 Ok(())
257 }
258
259 pub async fn ask_is_alive(&self) -> bool {
260 self.send_with_acknowledgement(MessageBody::IsAlive).await.is_ok()
261 }
262
263 pub async fn plugin_info(&self) -> Result<PluginInfo> {
264 self.send_receiving_data(MessageBody::GetPluginInfo).await
265 }
266
267 pub async fn license_text(&self) -> Result<String> {
268 self.send_receiving_string(MessageBody::GetLicenseText).await
269 }
270
271 pub async fn resolved_config(&self, config_id: FormatConfigId) -> Result<String> {
272 self.send_receiving_string(MessageBody::GetResolvedConfig(config_id)).await
273 }
274
275 pub async fn file_matching_info(&self, config_id: FormatConfigId) -> Result<FileMatchingInfo> {
276 self.send_receiving_data(MessageBody::GetFileMatchingInfo(config_id)).await
277 }
278
279 pub async fn config_diagnostics(&self, config_id: FormatConfigId) -> Result<Vec<ConfigurationDiagnostic>> {
280 self.send_receiving_data(MessageBody::GetConfigDiagnostics(config_id)).await
281 }
282
283 pub async fn check_config_updates(&self, message: &CheckConfigUpdatesMessageBody) -> Result<Vec<ConfigChange>> {
284 let bytes = serde_json::to_vec(&message)?;
285 let response: CheckConfigUpdatesResponseBody = self.send_receiving_data(MessageBody::CheckConfigUpdates(bytes)).await?;
286 Ok(response.changes)
287 }
288
289 pub async fn format_text(&self, request: ProcessPluginCommunicatorFormatRequest) -> FormatResult {
290 let (tx, rx) = oneshot::channel::<Result<Option<Vec<u8>>>>();
291
292 let message_id = self.context.id_generator.next();
293 let store_guard = self.context.host_format_callbacks.store_with_guard(message_id, request.on_host_format);
294 let maybe_result = self
295 .send_message_with_id(
296 message_id,
297 MessageBody::Format(FormatMessageBody {
298 file_path: request.file_path,
299 file_bytes: request.file_bytes,
300 range: request.range,
301 config_id: request.config_id,
302 override_config: serde_json::to_vec(&request.override_config).unwrap(),
303 }),
304 MessageResponseChannel::Format(tx),
305 rx,
306 request.token.clone(),
307 )
308 .await;
309
310 drop(store_guard); if request.token.is_cancelled() {
313 Ok(None)
314 } else {
315 match maybe_result {
316 Ok(result) => result,
317 Err(err) => Err(CriticalFormatError(err).into()),
318 }
319 }
320 }
321
322 pub async fn is_process_alive(&self) -> bool {
324 if self.context.shutdown_flag.is_raised() {
325 false
326 } else {
327 self.ask_is_alive().await
328 }
329 }
330
331 async fn send_with_acknowledgement(&self, body: MessageBody) -> Result<()> {
332 let (tx, rx) = oneshot::channel::<Result<()>>();
333 self
334 .send_message(body, MessageResponseChannel::Acknowledgement(tx), rx, Arc::new(NullCancellationToken))
335 .await?
336 }
337
338 async fn send_receiving_string(&self, body: MessageBody) -> Result<String> {
339 let data = self.send_receiving_bytes(body).await??;
340 Ok(String::from_utf8(data)?)
341 }
342
343 async fn send_receiving_data<T: DeserializeOwned>(&self, body: MessageBody) -> Result<T> {
344 let data = self.send_receiving_bytes(body).await??;
345 Ok(serde_json::from_slice(&data)?)
346 }
347
348 async fn send_receiving_bytes(&self, body: MessageBody) -> Result<Result<Vec<u8>>> {
349 let (tx, rx) = oneshot::channel::<Result<Vec<u8>>>();
350 self
351 .send_message(body, MessageResponseChannel::Data(tx), rx, Arc::new(NullCancellationToken))
352 .await
353 }
354
355 async fn send_message<T: Default>(
356 &self,
357 body: MessageBody,
358 response_channel: MessageResponseChannel,
359 receiver: oneshot::Receiver<Result<T>>,
360 token: Arc<dyn super::super::CancellationToken>,
361 ) -> Result<Result<T>> {
362 let message_id = self.context.id_generator.next();
363 self.send_message_with_id(message_id, body, response_channel, receiver, token).await
364 }
365
366 async fn send_message_with_id<T: Default>(
367 &self,
368 message_id: u32,
369 body: MessageBody,
370 response_channel: MessageResponseChannel,
371 receiver: oneshot::Receiver<Result<T>>,
372 token: Arc<dyn super::super::CancellationToken>,
373 ) -> Result<Result<T>> {
374 let mut drop_guard = DropGuardAction::new(|| {
375 self.context.messages.take(message_id);
377 let _ = self.context.stdin_writer.send(ProcessPluginMessage {
379 id: self.context.id_generator.next(),
380 body: MessageBody::CancelFormat(message_id),
381 });
382 });
383
384 self.context.messages.store(message_id, response_channel);
385 self.context.stdin_writer.send(ProcessPluginMessage { id: message_id, body })?;
386 tokio::select! {
387 _ = token.wait_cancellation() => {
388 drop(drop_guard); Ok(Ok(Default::default()))
390 }
391 response = receiver => {
392 drop_guard.forget(); match response {
394 Ok(data) => Ok(data),
395 Err(err) => {
396 bail!("Error waiting on message ({}). {:#}", message_id, err)
397 }
398 }
399 }
400 }
401 }
402}
403
404fn get_plugin_schema_version<TRead: Read + Unpin, TWrite: Write + Unpin>(reader: &mut MessageReader<TRead>, writer: &mut MessageWriter<TWrite>) -> Result<u32> {
405 writer.send_u32(0).context("Failed asking for schema version.")?; writer.flush().context("Failed flushing schema version request.")?;
408 let acknowledgement_response = reader.read_u32().context("Could not read success response.")?;
409 if acknowledgement_response != 0 {
410 bail!("Plugin response was unexpected ({acknowledgement_response}).");
411 }
412 reader.read_u32().context("Could not read schema version.")
413}
414
415fn std_err_redirect(shutdown_flag: Arc<AtomicFlag>, stderr: ChildStderr, on_std_err: impl Fn(String) + Send + Sync + 'static) {
416 let reader = std::io::BufReader::new(stderr);
417 for line in reader.lines() {
418 match line {
419 Ok(line) => on_std_err(line),
420 Err(err) => {
421 if shutdown_flag.is_raised() || err.kind() == ErrorKind::BrokenPipe {
422 return;
423 } else {
424 on_std_err(format!("Error reading line from process plugin stderr. {:#}", err));
425 }
426 }
427 }
428 }
429}
430
431fn handle_stdout_message(message: ProcessPluginMessage, context: &Rc<Context>) -> Result<()> {
432 match message.body {
433 MessageBody::Success(message_id) => match context.messages.take(message_id) {
434 Some(MessageResponseChannel::Acknowledgement(channel)) => {
435 let _ignore = channel.send(Ok(()));
436 }
437 Some(MessageResponseChannel::Data(channel)) => {
438 let _ignore = channel.send(Err(anyhow!("Unexpected data channel for success response: {}", message_id)));
439 }
440 Some(MessageResponseChannel::Format(channel)) => {
441 let _ignore = channel.send(Err(anyhow!("Unexpected format channel for success response: {}", message_id)));
442 }
443 None => {}
444 },
445 MessageBody::DataResponse(response) => match context.messages.take(response.message_id) {
446 Some(MessageResponseChannel::Acknowledgement(channel)) => {
447 let _ignore = channel.send(Err(anyhow!("Unexpected success channel for data response: {}", response.message_id)));
448 }
449 Some(MessageResponseChannel::Data(channel)) => {
450 let _ignore = channel.send(Ok(response.data));
451 }
452 Some(MessageResponseChannel::Format(channel)) => {
453 let _ignore = channel.send(Err(anyhow!("Unexpected format channel for data response: {}", response.message_id)));
454 }
455 None => {}
456 },
457 MessageBody::Error(response) => {
458 let err = anyhow!("{}", String::from_utf8_lossy(&response.data));
459 match context.messages.take(response.message_id) {
460 Some(MessageResponseChannel::Acknowledgement(channel)) => {
461 let _ignore = channel.send(Err(err));
462 }
463 Some(MessageResponseChannel::Data(channel)) => {
464 let _ignore = channel.send(Err(err));
465 }
466 Some(MessageResponseChannel::Format(channel)) => {
467 let _ignore = channel.send(Err(err));
468 }
469 None => {}
470 }
471 }
472 MessageBody::FormatResponse(response) => match context.messages.take(response.message_id) {
473 Some(MessageResponseChannel::Acknowledgement(channel)) => {
474 let _ignore = channel.send(Err(anyhow!("Unexpected success channel for format response: {}", response.message_id)));
475 }
476 Some(MessageResponseChannel::Data(channel)) => {
477 let _ignore = channel.send(Err(anyhow!("Unexpected data channel for format response: {}", response.message_id)));
478 }
479 Some(MessageResponseChannel::Format(channel)) => {
480 let _ignore = channel.send(Ok(response.data));
481 }
482 None => {}
483 },
484 MessageBody::CancelFormat(message_id) => {
485 if let Some(token) = context.format_request_tokens.take(message_id) {
486 token.cancel();
487 }
488 context.host_format_callbacks.take(message_id);
489 }
491 MessageBody::HostFormat(body) => {
492 let context = context.clone();
495 crate::async_runtime::spawn(async move {
496 let result = host_format(context.clone(), message.id, body).await;
497
498 let _ignore = context.stdin_writer.send(ProcessPluginMessage {
501 id: context.id_generator.next(),
502 body: match result {
503 Ok(result) => MessageBody::FormatResponse(ResponseBody {
504 message_id: message.id,
505 data: result,
506 }),
507 Err(err) => MessageBody::Error(ResponseBody {
508 message_id: message.id,
509 data: format!("{:#}", err).into_bytes(),
510 }),
511 },
512 });
513 });
514 }
515 MessageBody::IsAlive => {
516 let _ = context.stdin_writer.send(ProcessPluginMessage {
518 id: context.id_generator.next(),
519 body: MessageBody::Success(message.id),
520 });
521 }
522 MessageBody::Format(_)
523 | MessageBody::Close
524 | MessageBody::GetPluginInfo
525 | MessageBody::GetLicenseText
526 | MessageBody::RegisterConfig(_)
527 | MessageBody::ReleaseConfig(_)
528 | MessageBody::GetConfigDiagnostics(_)
529 | MessageBody::GetFileMatchingInfo(_)
530 | MessageBody::GetResolvedConfig(_)
531 | MessageBody::CheckConfigUpdates(_) => {
532 let _ = context.stdin_writer.send(ProcessPluginMessage {
533 id: context.id_generator.next(),
534 body: MessageBody::Error(ResponseBody {
535 message_id: message.id,
536 data: "Unsupported plugin to CLI message.".as_bytes().to_vec(),
537 }),
538 });
539 }
540 MessageBody::Unknown(message_kind) => {
543 bail!("Unknown message kind: {}", message_kind);
544 }
545 }
546
547 Ok(())
548}
549
550async fn host_format(context: Rc<Context>, message_id: u32, body: HostFormatMessageBody) -> FormatResult {
551 let Some(callback) = context.host_format_callbacks.get_cloned(body.original_message_id) else {
552 return FormatResult::Err(anyhow!("Could not find host format callback for message id: {}", body.original_message_id));
553 };
554
555 let token = Arc::new(CancellationToken::new());
556 let store_guard = context.format_request_tokens.store_with_guard(message_id, token.clone());
557 let result = callback(HostFormatRequest {
558 file_path: body.file_path,
559 file_bytes: body.file_text,
560 range: body.range,
561 override_config: serde_json::from_slice(&body.override_config).unwrap(),
562 token,
563 })
564 .await;
565 drop(store_guard); result
567}