Skip to main content

dprint_core/plugins/process/
message_processor.rs

1use serde::Serialize;
2use std::io::Read;
3use std::io::Write;
4use std::rc::Rc;
5use std::sync::Arc;
6use tokio_util::sync::CancellationToken;
7
8use super::PLUGIN_SCHEMA_VERSION;
9use super::context::ProcessContext;
10use super::context::StoredConfig;
11use super::messages::CheckConfigUpdatesMessageBody;
12use super::messages::CheckConfigUpdatesResponseBody;
13use super::messages::HostFormatMessageBody;
14use super::messages::MessageBody;
15use super::messages::ProcessPluginMessage;
16use super::messages::ResponseBody;
17use super::utils::setup_exit_process_panic_hook;
18
19use crate::async_runtime::FutureExt;
20use crate::async_runtime::LocalBoxFuture;
21use crate::communication::MessageReader;
22use crate::communication::MessageWriter;
23use crate::communication::SingleThreadMessageWriter;
24use crate::configuration::ConfigKeyMap;
25use crate::configuration::GlobalConfiguration;
26use crate::plugins::AsyncPluginHandler;
27use crate::plugins::FormatConfigId;
28use crate::plugins::FormatError;
29use crate::plugins::FormatRequest;
30use crate::plugins::FormatResult;
31use crate::plugins::HostFormatRequest;
32use crate::plugins::error_to_string;
33
34type Result<T> = std::result::Result<T, FormatError>;
35
36/// The detailed reason the schema establishment handshake failed.
37#[derive(Debug, thiserror::Error)]
38enum SchemaEstablishmentError {
39  #[error("Expected a schema version request of `0`.")]
40  UnexpectedRequest,
41  #[error(transparent)]
42  Io(#[from] std::io::Error),
43}
44
45/// Errors that can occur while processing messages for a process plugin.
46#[derive(Debug, thiserror::Error)]
47enum MessageProcessorError {
48  #[error("Failed estabilishing schema.")]
49  SchemaEstablishment(#[from] SchemaEstablishmentError),
50  #[error("Did not find configuration for id: {0}")]
51  ConfigNotFound(FormatConfigId),
52  #[error("Could not deserialize the check config updates message body.")]
53  DeserializeCheckConfigUpdates(#[source] serde_json::Error),
54  #[error("Cannot host format with a plugin.")]
55  CannotHostFormat,
56}
57
58impl From<MessageProcessorError> for FormatError {
59  fn from(err: MessageProcessorError) -> Self {
60    FormatError::new(err)
61  }
62}
63
64/// Handles the process' messages based on the provided handler.
65pub async fn handle_process_stdio_messages<THandler: AsyncPluginHandler>(handler: THandler) -> Result<()> {
66  // ensure all process plugins exit on panic on any tokio task
67  setup_exit_process_panic_hook();
68
69  // estabilish the schema
70  let (mut stdin_reader, stdout_writer) = crate::async_runtime::spawn_blocking(move || {
71    let mut stdin_reader = MessageReader::new(std::io::stdin());
72    let mut stdout_writer = MessageWriter::new(std::io::stdout());
73
74    schema_establishment_phase(&mut stdin_reader, &mut stdout_writer).map_err(MessageProcessorError::SchemaEstablishment)?;
75    Ok::<_, FormatError>((stdin_reader, stdout_writer))
76  })
77  .await??;
78
79  // now start reading messages
80  let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<std::io::Result<ProcessPluginMessage>>();
81  crate::async_runtime::spawn_blocking(move || {
82    loop {
83      let message_result = ProcessPluginMessage::read(&mut stdin_reader);
84      let is_err = message_result.is_err();
85      if tx.send(message_result).is_err() {
86        return; // disconnected
87      }
88      if is_err {
89        return; // shut down
90      }
91    }
92  });
93
94  crate::async_runtime::spawn(async move {
95    let handler = Rc::new(handler);
96    let stdout_message_writer = SingleThreadMessageWriter::for_stdout(stdout_writer);
97    let context: Rc<ProcessContext<THandler::Configuration>> = Rc::new(ProcessContext::new(stdout_message_writer));
98
99    // read messages over stdin
100    loop {
101      let message = match rx.recv().await {
102        Some(message_result) => message_result?,
103        None => return Ok(()), // disconnected
104      };
105
106      match message.body {
107        MessageBody::Close => {
108          handle_message(&context, message.id, || Ok(MessageBody::Success(message.id)));
109          return Ok(());
110        }
111        MessageBody::IsAlive => {
112          handle_message(&context, message.id, || Ok(MessageBody::Success(message.id)));
113        }
114        MessageBody::GetPluginInfo => {
115          handle_message(&context, message.id, || {
116            let plugin_info = handler.plugin_info();
117            let data = serde_json::to_vec(&plugin_info)?;
118            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
119          });
120        }
121        MessageBody::GetLicenseText => {
122          handle_message(&context, message.id, || {
123            let data = handler.license_text().into_bytes();
124            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
125          });
126        }
127        MessageBody::RegisterConfig(body) => {
128          handle_async_message(
129            &context,
130            message.id,
131            async {
132              let global_config: GlobalConfiguration = serde_json::from_slice(&body.global_config)?;
133              let config_map: ConfigKeyMap = serde_json::from_slice(&body.plugin_config)?;
134              let result = handler.resolve_config(config_map.clone(), global_config.clone()).await;
135              context.configs.store(
136                body.config_id.as_raw(),
137                Rc::new(StoredConfig {
138                  config: Arc::new(result.config),
139                  file_matching: result.file_matching,
140                  diagnostics: Rc::new(result.diagnostics),
141                  config_map,
142                  global_config,
143                }),
144              );
145              Ok(MessageBody::Success(message.id))
146            }
147            .boxed_local(),
148          )
149          .await;
150        }
151        MessageBody::ReleaseConfig(config_id) => {
152          handle_message(&context, message.id, || {
153            context.configs.take(config_id.as_raw());
154            Ok(MessageBody::Success(message.id))
155          });
156        }
157        MessageBody::GetConfigDiagnostics(config_id) => {
158          handle_message(&context, message.id, || {
159            let diagnostics = context
160              .configs
161              .get_cloned(config_id.as_raw())
162              .map(|c| c.diagnostics.clone())
163              .unwrap_or_default();
164            let data = serde_json::to_vec(&*diagnostics)?;
165            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
166          });
167        }
168        MessageBody::GetFileMatchingInfo(config_id) => {
169          handle_message(&context, message.id, || {
170            let data = match context.configs.get_cloned(config_id.as_raw()) {
171              Some(config) => serde_json::to_vec(&config.file_matching)?,
172              None => return Err(MessageProcessorError::ConfigNotFound(config_id).into()),
173            };
174            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
175          });
176        }
177        MessageBody::GetResolvedConfig(config_id) => {
178          handle_message(&context, message.id, || {
179            let data = match context.configs.get_cloned(config_id.as_raw()) {
180              Some(config) => serde_json::to_vec(&*config.config)?,
181              None => return Err(MessageProcessorError::ConfigNotFound(config_id).into()),
182            };
183            Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
184          });
185        }
186        MessageBody::CheckConfigUpdates(body_bytes) => {
187          handle_async_message(
188            &context,
189            message.id,
190            async {
191              let message_body =
192                serde_json::from_slice::<CheckConfigUpdatesMessageBody>(&body_bytes).map_err(MessageProcessorError::DeserializeCheckConfigUpdates)?;
193              let changes = handler.check_config_updates(message_body).await?;
194              let response = CheckConfigUpdatesResponseBody { changes };
195              let data = serde_json::to_vec(&response)?;
196              Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
197            }
198            .boxed_local(),
199          )
200          .await;
201        }
202        MessageBody::Format(body) => {
203          // now parse
204          let token = Arc::new(CancellationToken::new());
205          let request = FormatRequest {
206            file_path: body.file_path,
207            range: body.range,
208            config_id: body.config_id,
209            config: match context.configs.get_cloned(body.config_id.as_raw()) {
210              Some(config) => {
211                if body.override_config.is_empty() {
212                  config.config.clone()
213                } else {
214                  let mut config_map = config.config_map.clone();
215                  let override_config_map: ConfigKeyMap = serde_json::from_slice(&body.override_config)?;
216                  for (key, value) in override_config_map {
217                    config_map.insert(key, value);
218                  }
219                  let result = handler.resolve_config(config_map, config.global_config.clone()).await;
220                  Arc::new(result.config)
221                }
222              }
223              None => {
224                send_error_response(&context, message.id, MessageProcessorError::ConfigNotFound(body.config_id).into());
225                continue;
226              }
227            },
228            file_bytes: body.file_bytes,
229            token: token.clone(),
230          };
231
232          // start the task
233          let context = context.clone();
234          let handler = handler.clone();
235          let token_storage_guard = context.cancellation_tokens.store_with_owned_guard(message.id, token.clone());
236          crate::async_runtime::spawn(async move {
237            let original_message_id = message.id;
238            let result = handler
239              .format(request, {
240                let context = context.clone();
241                move |request| host_format(&context, original_message_id, request)
242              })
243              .await;
244            drop(token_storage_guard);
245            if !token.is_cancelled() {
246              let body = match result {
247                Ok(text) => MessageBody::FormatResponse(ResponseBody {
248                  message_id: message.id,
249                  data: text,
250                }),
251                Err(err) => MessageBody::Error(ResponseBody {
252                  message_id: message.id,
253                  data: error_to_string(&err).into_bytes(),
254                }),
255              };
256              send_response_body(&context, body)
257            }
258          });
259        }
260        MessageBody::CancelFormat(message_id) => {
261          if let Some(token) = context.cancellation_tokens.take(message_id) {
262            token.cancel();
263          }
264        }
265        MessageBody::Error(body) => {
266          let text = String::from_utf8_lossy(&body.data);
267          if let Some(sender) = context.format_host_senders.take(body.message_id) {
268            sender.send(Err(text.into_owned().into())).unwrap();
269          } else {
270            #[allow(clippy::print_stderr)]
271            {
272              eprintln!("Received error from CLI. {}", text);
273            }
274          }
275        }
276        MessageBody::FormatResponse(body) => {
277          if let Some(sender) = context.format_host_senders.take(body.message_id) {
278            sender.send(Ok(body.data)).unwrap();
279          }
280        }
281        MessageBody::Success(_) | MessageBody::DataResponse(_) => {
282          // ignore
283        }
284        MessageBody::HostFormat(_) => {
285          send_error_response(&context, message.id, MessageProcessorError::CannotHostFormat.into());
286        }
287        MessageBody::Unknown(message_kind) => panic!("Received unknown message kind: {}", message_kind),
288      }
289    }
290  })
291  .await
292  .unwrap()
293}
294
295fn host_format<TConfiguration: Serialize + Clone + Send + Sync>(
296  context: &ProcessContext<TConfiguration>,
297  original_message_id: u32,
298  request: HostFormatRequest,
299) -> LocalBoxFuture<'static, FormatResult> {
300  let (tx, rx) = tokio::sync::oneshot::channel::<FormatResult>();
301  let id = context.id_generator.next();
302  context.format_host_senders.store(id, tx);
303
304  context
305    .stdout_writer
306    .send(ProcessPluginMessage {
307      id,
308      body: MessageBody::HostFormat(HostFormatMessageBody {
309        original_message_id,
310        file_path: request.file_path,
311        file_text: request.file_bytes,
312        range: request.range,
313        override_config: serde_json::to_vec(&request.override_config).unwrap(),
314      }),
315    })
316    .unwrap_or_else(|err| panic!("Error sending host format response: {:#}", err));
317
318  let token = request.token;
319  let stdout_writer = context.stdout_writer.clone();
320  let id_generator = context.id_generator.clone();
321  let original_message_id = id;
322
323  async move {
324    tokio::select! {
325      _ = token.wait_cancellation() => {
326        // send a cancellation to the host
327        stdout_writer.send(ProcessPluginMessage {
328          id: id_generator.next(),
329          body: MessageBody::CancelFormat(original_message_id),
330        }).unwrap_or_else(|err| panic!("Error sending host format cancellation: {:#}", err));
331
332        // return no change
333        Ok(None)
334      }
335      value = rx => {
336        match value {
337          Ok(Ok(Some(value))) => Ok(Some(value)),
338          Ok(Ok(None)) => Ok(None),
339          Ok(Err(err)) => Err(err),
340          // means the rx was closed, so just ignore
341          Err(err) => Err(err.into()),
342        }
343      }
344    }
345  }
346  .boxed_local()
347}
348
349fn handle_message<TConfiguration: Serialize + Clone + Send + Sync>(
350  context: &ProcessContext<TConfiguration>,
351  original_message_id: u32,
352  action: impl FnOnce() -> Result<MessageBody>,
353) {
354  match action() {
355    Ok(body) => send_response_body(context, body),
356    Err(err) => send_error_response(context, original_message_id, err),
357  };
358}
359
360async fn handle_async_message<TConfiguration: Serialize + Clone + Send + Sync>(
361  context: &ProcessContext<TConfiguration>,
362  original_message_id: u32,
363  action: LocalBoxFuture<'_, Result<MessageBody>>,
364) {
365  match action.await {
366    Ok(body) => send_response_body(context, body),
367    Err(err) => send_error_response(context, original_message_id, err),
368  };
369}
370
371fn send_error_response<TConfiguration: Serialize + Clone + Send + Sync>(context: &ProcessContext<TConfiguration>, original_message_id: u32, err: FormatError) {
372  let body = MessageBody::Error(ResponseBody {
373    message_id: original_message_id,
374    data: error_to_string(&err).into_bytes(),
375  });
376  send_response_body(context, body)
377}
378
379fn send_response_body<TConfiguration: Serialize + Clone + Send + Sync>(context: &ProcessContext<TConfiguration>, body: MessageBody) {
380  let message = ProcessPluginMessage {
381    id: context.id_generator.next(),
382    body,
383  };
384  if let Err(err) = context.stdout_writer.send(message) {
385    panic!("Receiver dropped. {:#}", err);
386  }
387}
388
389/// For backwards compatibility asking for the schema version.
390fn schema_establishment_phase<TRead: Read + Unpin, TWrite: Write + Unpin>(
391  stdin: &mut MessageReader<TRead>,
392  stdout: &mut MessageWriter<TWrite>,
393) -> std::result::Result<(), SchemaEstablishmentError> {
394  // 1. An initial `0` (4 bytes) is sent asking for the schema version.
395  if stdin.read_u32()? != 0 {
396    return Err(SchemaEstablishmentError::UnexpectedRequest);
397  }
398
399  // 2. The client responds with `0` (4 bytes) for success
400  stdout.send_u32(0)?;
401  // 3. Then 4 bytes for the schema version
402  stdout.send_u32(PLUGIN_SCHEMA_VERSION)?;
403  stdout.flush()?;
404
405  Ok(())
406}