dprint_core/plugins/process/message_processor.rs

1use anyhow::anyhow;
2use anyhow::bail;
3use anyhow::Context;
4use anyhow::Result;
5use serde::Serialize;
6use std::io::Read;
7use std::io::Write;
8use std::rc::Rc;
9use std::sync::Arc;
10use tokio_util::sync::CancellationToken;
11
12use super::context::ProcessContext;
13use super::context::StoredConfig;
14use super::messages::CheckConfigUpdatesMessageBody;
15use super::messages::CheckConfigUpdatesResponseBody;
16use super::messages::HostFormatMessageBody;
17use super::messages::MessageBody;
18use super::messages::ProcessPluginMessage;
19use super::messages::ResponseBody;
20use super::utils::setup_exit_process_panic_hook;
21use super::PLUGIN_SCHEMA_VERSION;
22
23use crate::async_runtime::FutureExt;
24use crate::async_runtime::LocalBoxFuture;
25use crate::communication::MessageReader;
26use crate::communication::MessageWriter;
27use crate::communication::SingleThreadMessageWriter;
28use crate::configuration::ConfigKeyMap;
29use crate::configuration::GlobalConfiguration;
30use crate::plugins::AsyncPluginHandler;
31use crate::plugins::FormatRequest;
32use crate::plugins::FormatResult;
33use crate::plugins::HostFormatRequest;
34
35/// Handles the process' messages based on the provided handler.
36pub async fn handle_process_stdio_messages<THandler: AsyncPluginHandler>(handler: THandler) -> Result<()> {
37  // ensure all process plugins exit on panic on any tokio task
38  setup_exit_process_panic_hook();
39
40  // estabilish the schema
41  let (mut stdin_reader, stdout_writer) = crate::async_runtime::spawn_blocking(move || {
42    let mut stdin_reader = MessageReader::new(std::io::stdin());
43    let mut stdout_writer = MessageWriter::new(std::io::stdout());
44
45    schema_establishment_phase(&mut stdin_reader, &mut stdout_writer).context("Failed estabilishing schema.")?;
46    Ok::<_, anyhow::Error>((stdin_reader, stdout_writer))
47  })
48  .await??;
49
50  // now start reading messages
51  let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<std::io::Result<ProcessPluginMessage>>();
52  crate::async_runtime::spawn_blocking(move || loop {
53    let message_result = ProcessPluginMessage::read(&mut stdin_reader);
54    let is_err = message_result.is_err();
55    if tx.send(message_result).is_err() {
56      return; // disconnected
57    }
58    if is_err {
59      return; // shut down
60    }
61  });
62
63  crate::async_runtime::spawn(async move {
64    let handler = Rc::new(handler);
65    let stdout_message_writer = SingleThreadMessageWriter::for_stdout(stdout_writer);
66    let context: Rc<ProcessContext<THandler::Configuration>> = Rc::new(ProcessContext::new(stdout_message_writer));
67
68    // read messages over stdin
69    loop {
70      let message = match rx.recv().await {
71        Some(message_result) => message_result?,
72        None => return Ok(()), // disconnected
73      };
74
75      match message.body {
76        MessageBody::Close => {
77          handle_message(&context, message.id, || Ok(MessageBody::Success(message.id)));
78          return Ok(());
79        }
80        MessageBody::IsAlive => {
81          handle_message(&context, message.id, || Ok(MessageBody::Success(message.id)));
82        }
83        MessageBody::GetPluginInfo => {
84          handle_message(&context, message.id, || {
85            let plugin_info = handler.plugin_info();
86            let data = serde_json::to_vec(&plugin_info)?;
87            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
88          });
89        }
90        MessageBody::GetLicenseText => {
91          handle_message(&context, message.id, || {
92            let data = handler.license_text().into_bytes();
93            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
94          });
95        }
96        MessageBody::RegisterConfig(body) => {
97          handle_async_message(
98            &context,
99            message.id,
100            async {
101              let global_config: GlobalConfiguration = serde_json::from_slice(&body.global_config)?;
102              let config_map: ConfigKeyMap = serde_json::from_slice(&body.plugin_config)?;
103              let result = handler.resolve_config(config_map.clone(), global_config.clone()).await;
104              context.configs.store(
105                body.config_id.as_raw(),
106                Rc::new(StoredConfig {
107                  config: Arc::new(result.config),
108                  file_matching: result.file_matching,
109                  diagnostics: Rc::new(result.diagnostics),
110                  config_map,
111                  global_config,
112                }),
113              );
114              Ok(MessageBody::Success(message.id))
115            }
116            .boxed_local(),
117          )
118          .await;
119        }
120        MessageBody::ReleaseConfig(config_id) => {
121          handle_message(&context, message.id, || {
122            context.configs.take(config_id.as_raw());
123            Ok(MessageBody::Success(message.id))
124          });
125        }
126        MessageBody::GetConfigDiagnostics(config_id) => {
127          handle_message(&context, message.id, || {
128            let diagnostics = context
129              .configs
130              .get_cloned(config_id.as_raw())
131              .map(|c| c.diagnostics.clone())
132              .unwrap_or_default();
133            let data = serde_json::to_vec(&*diagnostics)?;
134            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
135          });
136        }
137        MessageBody::GetFileMatchingInfo(config_id) => {
138          handle_message(&context, message.id, || {
139            let data = match context.configs.get_cloned(config_id.as_raw()) {
140              Some(config) => serde_json::to_vec(&config.file_matching)?,
141              None => bail!("Did not find configuration for id: {}", config_id),
142            };
143            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
144          });
145        }
146        MessageBody::GetResolvedConfig(config_id) => {
147          handle_message(&context, message.id, || {
148            let data = match context.configs.get_cloned(config_id.as_raw()) {
149              Some(config) => serde_json::to_vec(&*config.config)?,
150              None => bail!("Did not find configuration for id: {}", config_id),
151            };
152            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
153          });
154        }
155        MessageBody::CheckConfigUpdates(body_bytes) => {
156          handle_async_message(
157            &context,
158            message.id,
159            async {
160              let message_body = serde_json::from_slice::<CheckConfigUpdatesMessageBody>(&body_bytes)
161                .with_context(|| "Could not deserialize the check config updates message body.".to_string())?;
162              let changes = handler.check_config_updates(message_body).await?;
163              let response = CheckConfigUpdatesResponseBody { changes };
164              let data = serde_json::to_vec(&response)?;
165              Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
166            }
167            .boxed_local(),
168          )
169          .await;
170        }
171        MessageBody::Format(body) => {
172          // now parse
173          let token = Arc::new(CancellationToken::new());
174          let request = FormatRequest {
175            file_path: body.file_path,
176            range: body.range,
177            config_id: body.config_id,
178            config: match context.configs.get_cloned(body.config_id.as_raw()) {
179              Some(config) => {
180                if body.override_config.is_empty() {
181                  config.config.clone()
182                } else {
183                  let mut config_map = config.config_map.clone();
184                  let override_config_map: ConfigKeyMap = serde_json::from_slice(&body.override_config)?;
185                  for (key, value) in override_config_map {
186                    config_map.insert(key, value);
187                  }
188                  let result = handler.resolve_config(config_map, config.global_config.clone()).await;
189                  Arc::new(result.config)
190                }
191              }
192              None => {
193                send_error_response(&context, message.id, anyhow!("Did not find configuration for id: {}", body.config_id));
194                continue;
195              }
196            },
197            file_bytes: body.file_bytes,
198            token: token.clone(),
199          };
200
201          // start the task
202          let context = context.clone();
203          let handler = handler.clone();
204          let token_storage_guard = context.cancellation_tokens.store_with_owned_guard(message.id, token.clone());
205          crate::async_runtime::spawn(async move {
206            let original_message_id = message.id;
207            let result = handler
208              .format(request, {
209                let context = context.clone();
210                move |request| host_format(&context, original_message_id, request)
211              })
212              .await;
213            drop(token_storage_guard);
214            if !token.is_cancelled() {
215              let body = match result {
216                Ok(text) => MessageBody::FormatResponse(ResponseBody {
217                  message_id: message.id,
218                  data: text,
219                }),
220                Err(err) => MessageBody::Error(ResponseBody {
221                  message_id: message.id,
222                  data: format!("{:#}", err).into_bytes(),
223                }),
224              };
225              send_response_body(&context, body)
226            }
227          });
228        }
229        MessageBody::CancelFormat(message_id) => {
230          if let Some(token) = context.cancellation_tokens.take(message_id) {
231            token.cancel();
232          }
233        }
234        MessageBody::Error(body) => {
235          let text = String::from_utf8_lossy(&body.data);
236          if let Some(sender) = context.format_host_senders.take(body.message_id) {
237            sender.send(Err(anyhow!("{}", text))).unwrap();
238          } else {
239            #[allow(clippy::print_stderr)]
240            {
241              eprintln!("Received error from CLI. {}", text);
242            }
243          }
244        }
245        MessageBody::FormatResponse(body) => {
246          if let Some(sender) = context.format_host_senders.take(body.message_id) {
247            sender.send(Ok(body.data)).unwrap();
248          }
249        }
250        MessageBody::Success(_) | MessageBody::DataResponse(_) => {
251          // ignore
252        }
253        MessageBody::HostFormat(_) => {
254          send_error_response(&context, message.id, anyhow!("Cannot host format with a plugin."));
255        }
256        MessageBody::Unknown(message_kind) => panic!("Received unknown message kind: {}", message_kind),
257      }
258    }
259  })
260  .await
261  .unwrap()
262}
263
264fn host_format<TConfiguration: Serialize + Clone + Send + Sync>(
265  context: &ProcessContext<TConfiguration>,
266  original_message_id: u32,
267  request: HostFormatRequest,
268) -> LocalBoxFuture<'static, FormatResult> {
269  let (tx, rx) = tokio::sync::oneshot::channel::<FormatResult>();
270  let id = context.id_generator.next();
271  context.format_host_senders.store(id, tx);
272
273  context
274    .stdout_writer
275    .send(ProcessPluginMessage {
276      id,
277      body: MessageBody::HostFormat(HostFormatMessageBody {
278        original_message_id,
279        file_path: request.file_path,
280        file_text: request.file_bytes,
281        range: request.range,
282        override_config: serde_json::to_vec(&request.override_config).unwrap(),
283      }),
284    })
285    .unwrap_or_else(|err| panic!("Error sending host format response: {:#}", err));
286
287  let token = request.token;
288  let stdout_writer = context.stdout_writer.clone();
289  let id_generator = context.id_generator.clone();
290  let original_message_id = id;
291
292  async move {
293    tokio::select! {
294      _ = token.wait_cancellation() => {
295        // send a cancellation to the host
296        stdout_writer.send(ProcessPluginMessage {
297          id: id_generator.next(),
298          body: MessageBody::CancelFormat(original_message_id),
299        }).unwrap_or_else(|err| panic!("Error sending host format cancellation: {:#}", err));
300
301        // return no change
302        Ok(None)
303      }
304      value = rx => {
305        match value {
306          Ok(Ok(Some(value))) => Ok(Some(value)),
307          Ok(Ok(None)) => Ok(None),
308          Ok(Err(err)) => Err(err),
309          // means the rx was closed, so just ignore
310          Err(err) => Err(err.into()),
311        }
312      }
313    }
314  }
315  .boxed_local()
316}
317
318fn handle_message<TConfiguration: Serialize + Clone + Send + Sync>(
319  context: &ProcessContext<TConfiguration>,
320  original_message_id: u32,
321  action: impl FnOnce() -> Result<MessageBody>,
322) {
323  match action() {
324    Ok(body) => send_response_body(context, body),
325    Err(err) => send_error_response(context, original_message_id, err),
326  };
327}
328
329async fn handle_async_message<'a, TConfiguration: Serialize + Clone + Send + Sync>(
330  context: &ProcessContext<TConfiguration>,
331  original_message_id: u32,
332  action: LocalBoxFuture<'a, Result<MessageBody>>,
333) {
334  match action.await {
335    Ok(body) => send_response_body(context, body),
336    Err(err) => send_error_response(context, original_message_id, err),
337  };
338}
339
340fn send_error_response<TConfiguration: Serialize + Clone + Send + Sync>(
341  context: &ProcessContext<TConfiguration>,
342  original_message_id: u32,
343  err: anyhow::Error,
344) {
345  let body = MessageBody::Error(ResponseBody {
346    message_id: original_message_id,
347    data: format!("{:#}", err).into_bytes(),
348  });
349  send_response_body(context, body)
350}
351
352fn send_response_body<TConfiguration: Serialize + Clone + Send + Sync>(context: &ProcessContext<TConfiguration>, body: MessageBody) {
353  let message = ProcessPluginMessage {
354    id: context.id_generator.next(),
355    body,
356  };
357  if let Err(err) = context.stdout_writer.send(message) {
358    panic!("Receiver dropped. {:#}", err);
359  }
360}
361
362/// For backwards compatibility asking for the schema version.
363fn schema_establishment_phase<TRead: Read + Unpin, TWrite: Write + Unpin>(stdin: &mut MessageReader<TRead>, stdout: &mut MessageWriter<TWrite>) -> Result<()> {
364  // 1. An initial `0` (4 bytes) is sent asking for the schema version.
365  if stdin.read_u32()? != 0 {
366    bail!("Expected a schema version request of `0`.");
367  }
368
369  // 2. The client responds with `0` (4 bytes) for success
370  stdout.send_u32(0)?;
371  // 3. Then 4 bytes for the schema version
372  stdout.send_u32(PLUGIN_SCHEMA_VERSION)?;
373  stdout.flush()?;
374
375  Ok(())
376}