dprint_core/plugins/process/
message_processor.rs1use serde::Serialize;
2use std::io::Read;
3use std::io::Write;
4use std::rc::Rc;
5use std::sync::Arc;
6use tokio_util::sync::CancellationToken;
7
8use super::PLUGIN_SCHEMA_VERSION;
9use super::context::ProcessContext;
10use super::context::StoredConfig;
11use super::messages::CheckConfigUpdatesMessageBody;
12use super::messages::CheckConfigUpdatesResponseBody;
13use super::messages::HostFormatMessageBody;
14use super::messages::MessageBody;
15use super::messages::ProcessPluginMessage;
16use super::messages::ResponseBody;
17use super::utils::setup_exit_process_panic_hook;
18
19use crate::async_runtime::FutureExt;
20use crate::async_runtime::LocalBoxFuture;
21use crate::communication::MessageReader;
22use crate::communication::MessageWriter;
23use crate::communication::SingleThreadMessageWriter;
24use crate::configuration::ConfigKeyMap;
25use crate::configuration::GlobalConfiguration;
26use crate::plugins::AsyncPluginHandler;
27use crate::plugins::FormatConfigId;
28use crate::plugins::FormatError;
29use crate::plugins::FormatRequest;
30use crate::plugins::FormatResult;
31use crate::plugins::HostFormatRequest;
32use crate::plugins::error_to_string;
33
34type Result<T> = std::result::Result<T, FormatError>;
35
36#[derive(Debug, thiserror::Error)]
38enum SchemaEstablishmentError {
39 #[error("Expected a schema version request of `0`.")]
40 UnexpectedRequest,
41 #[error(transparent)]
42 Io(#[from] std::io::Error),
43}
44
45#[derive(Debug, thiserror::Error)]
47enum MessageProcessorError {
48 #[error("Failed estabilishing schema.")]
49 SchemaEstablishment(#[from] SchemaEstablishmentError),
50 #[error("Did not find configuration for id: {0}")]
51 ConfigNotFound(FormatConfigId),
52 #[error("Could not deserialize the check config updates message body.")]
53 DeserializeCheckConfigUpdates(#[source] serde_json::Error),
54 #[error("Cannot host format with a plugin.")]
55 CannotHostFormat,
56}
57
58impl From<MessageProcessorError> for FormatError {
59 fn from(err: MessageProcessorError) -> Self {
60 FormatError::new(err)
61 }
62}
63
64pub async fn handle_process_stdio_messages<THandler: AsyncPluginHandler>(handler: THandler) -> Result<()> {
66 setup_exit_process_panic_hook();
68
69 let (mut stdin_reader, stdout_writer) = crate::async_runtime::spawn_blocking(move || {
71 let mut stdin_reader = MessageReader::new(std::io::stdin());
72 let mut stdout_writer = MessageWriter::new(std::io::stdout());
73
74 schema_establishment_phase(&mut stdin_reader, &mut stdout_writer).map_err(MessageProcessorError::SchemaEstablishment)?;
75 Ok::<_, FormatError>((stdin_reader, stdout_writer))
76 })
77 .await??;
78
79 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<std::io::Result<ProcessPluginMessage>>();
81 crate::async_runtime::spawn_blocking(move || {
82 loop {
83 let message_result = ProcessPluginMessage::read(&mut stdin_reader);
84 let is_err = message_result.is_err();
85 if tx.send(message_result).is_err() {
86 return; }
88 if is_err {
89 return; }
91 }
92 });
93
94 crate::async_runtime::spawn(async move {
95 let handler = Rc::new(handler);
96 let stdout_message_writer = SingleThreadMessageWriter::for_stdout(stdout_writer);
97 let context: Rc<ProcessContext<THandler::Configuration>> = Rc::new(ProcessContext::new(stdout_message_writer));
98
99 loop {
101 let message = match rx.recv().await {
102 Some(message_result) => message_result?,
103 None => return Ok(()), };
105
106 match message.body {
107 MessageBody::Close => {
108 handle_message(&context, message.id, || Ok(MessageBody::Success(message.id)));
109 return Ok(());
110 }
111 MessageBody::IsAlive => {
112 handle_message(&context, message.id, || Ok(MessageBody::Success(message.id)));
113 }
114 MessageBody::GetPluginInfo => {
115 handle_message(&context, message.id, || {
116 let plugin_info = handler.plugin_info();
117 let data = serde_json::to_vec(&plugin_info)?;
118 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
119 });
120 }
121 MessageBody::GetLicenseText => {
122 handle_message(&context, message.id, || {
123 let data = handler.license_text().into_bytes();
124 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
125 });
126 }
127 MessageBody::RegisterConfig(body) => {
128 handle_async_message(
129 &context,
130 message.id,
131 async {
132 let global_config: GlobalConfiguration = serde_json::from_slice(&body.global_config)?;
133 let config_map: ConfigKeyMap = serde_json::from_slice(&body.plugin_config)?;
134 let result = handler.resolve_config(config_map.clone(), global_config.clone()).await;
135 context.configs.store(
136 body.config_id.as_raw(),
137 Rc::new(StoredConfig {
138 config: Arc::new(result.config),
139 file_matching: result.file_matching,
140 diagnostics: Rc::new(result.diagnostics),
141 config_map,
142 global_config,
143 }),
144 );
145 Ok(MessageBody::Success(message.id))
146 }
147 .boxed_local(),
148 )
149 .await;
150 }
151 MessageBody::ReleaseConfig(config_id) => {
152 handle_message(&context, message.id, || {
153 context.configs.take(config_id.as_raw());
154 Ok(MessageBody::Success(message.id))
155 });
156 }
157 MessageBody::GetConfigDiagnostics(config_id) => {
158 handle_message(&context, message.id, || {
159 let diagnostics = context
160 .configs
161 .get_cloned(config_id.as_raw())
162 .map(|c| c.diagnostics.clone())
163 .unwrap_or_default();
164 let data = serde_json::to_vec(&*diagnostics)?;
165 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
166 });
167 }
168 MessageBody::GetFileMatchingInfo(config_id) => {
169 handle_message(&context, message.id, || {
170 let data = match context.configs.get_cloned(config_id.as_raw()) {
171 Some(config) => serde_json::to_vec(&config.file_matching)?,
172 None => return Err(MessageProcessorError::ConfigNotFound(config_id).into()),
173 };
174 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
175 });
176 }
177 MessageBody::GetResolvedConfig(config_id) => {
178 handle_message(&context, message.id, || {
179 let data = match context.configs.get_cloned(config_id.as_raw()) {
180 Some(config) => serde_json::to_vec(&*config.config)?,
181 None => return Err(MessageProcessorError::ConfigNotFound(config_id).into()),
182 };
183 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
184 });
185 }
186 MessageBody::CheckConfigUpdates(body_bytes) => {
187 handle_async_message(
188 &context,
189 message.id,
190 async {
191 let message_body =
192 serde_json::from_slice::<CheckConfigUpdatesMessageBody>(&body_bytes).map_err(MessageProcessorError::DeserializeCheckConfigUpdates)?;
193 let changes = handler.check_config_updates(message_body).await?;
194 let response = CheckConfigUpdatesResponseBody { changes };
195 let data = serde_json::to_vec(&response)?;
196 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
197 }
198 .boxed_local(),
199 )
200 .await;
201 }
202 MessageBody::Format(body) => {
203 let token = Arc::new(CancellationToken::new());
205 let request = FormatRequest {
206 file_path: body.file_path,
207 range: body.range,
208 config_id: body.config_id,
209 config: match context.configs.get_cloned(body.config_id.as_raw()) {
210 Some(config) => {
211 if body.override_config.is_empty() {
212 config.config.clone()
213 } else {
214 let mut config_map = config.config_map.clone();
215 let override_config_map: ConfigKeyMap = serde_json::from_slice(&body.override_config)?;
216 for (key, value) in override_config_map {
217 config_map.insert(key, value);
218 }
219 let result = handler.resolve_config(config_map, config.global_config.clone()).await;
220 Arc::new(result.config)
221 }
222 }
223 None => {
224 send_error_response(&context, message.id, MessageProcessorError::ConfigNotFound(body.config_id).into());
225 continue;
226 }
227 },
228 file_bytes: body.file_bytes,
229 token: token.clone(),
230 };
231
232 let context = context.clone();
234 let handler = handler.clone();
235 let token_storage_guard = context.cancellation_tokens.store_with_owned_guard(message.id, token.clone());
236 crate::async_runtime::spawn(async move {
237 let original_message_id = message.id;
238 let result = handler
239 .format(request, {
240 let context = context.clone();
241 move |request| host_format(&context, original_message_id, request)
242 })
243 .await;
244 drop(token_storage_guard);
245 if !token.is_cancelled() {
246 let body = match result {
247 Ok(text) => MessageBody::FormatResponse(ResponseBody {
248 message_id: message.id,
249 data: text,
250 }),
251 Err(err) => MessageBody::Error(ResponseBody {
252 message_id: message.id,
253 data: error_to_string(&err).into_bytes(),
254 }),
255 };
256 send_response_body(&context, body)
257 }
258 });
259 }
260 MessageBody::CancelFormat(message_id) => {
261 if let Some(token) = context.cancellation_tokens.take(message_id) {
262 token.cancel();
263 }
264 }
265 MessageBody::Error(body) => {
266 let text = String::from_utf8_lossy(&body.data);
267 if let Some(sender) = context.format_host_senders.take(body.message_id) {
268 sender.send(Err(text.into_owned().into())).unwrap();
269 } else {
270 #[allow(clippy::print_stderr)]
271 {
272 eprintln!("Received error from CLI. {}", text);
273 }
274 }
275 }
276 MessageBody::FormatResponse(body) => {
277 if let Some(sender) = context.format_host_senders.take(body.message_id) {
278 sender.send(Ok(body.data)).unwrap();
279 }
280 }
281 MessageBody::Success(_) | MessageBody::DataResponse(_) => {
282 }
284 MessageBody::HostFormat(_) => {
285 send_error_response(&context, message.id, MessageProcessorError::CannotHostFormat.into());
286 }
287 MessageBody::Unknown(message_kind) => panic!("Received unknown message kind: {}", message_kind),
288 }
289 }
290 })
291 .await
292 .unwrap()
293}
294
295fn host_format<TConfiguration: Serialize + Clone + Send + Sync>(
296 context: &ProcessContext<TConfiguration>,
297 original_message_id: u32,
298 request: HostFormatRequest,
299) -> LocalBoxFuture<'static, FormatResult> {
300 let (tx, rx) = tokio::sync::oneshot::channel::<FormatResult>();
301 let id = context.id_generator.next();
302 context.format_host_senders.store(id, tx);
303
304 context
305 .stdout_writer
306 .send(ProcessPluginMessage {
307 id,
308 body: MessageBody::HostFormat(HostFormatMessageBody {
309 original_message_id,
310 file_path: request.file_path,
311 file_text: request.file_bytes,
312 range: request.range,
313 override_config: serde_json::to_vec(&request.override_config).unwrap(),
314 }),
315 })
316 .unwrap_or_else(|err| panic!("Error sending host format response: {:#}", err));
317
318 let token = request.token;
319 let stdout_writer = context.stdout_writer.clone();
320 let id_generator = context.id_generator.clone();
321 let original_message_id = id;
322
323 async move {
324 tokio::select! {
325 _ = token.wait_cancellation() => {
326 stdout_writer.send(ProcessPluginMessage {
328 id: id_generator.next(),
329 body: MessageBody::CancelFormat(original_message_id),
330 }).unwrap_or_else(|err| panic!("Error sending host format cancellation: {:#}", err));
331
332 Ok(None)
334 }
335 value = rx => {
336 match value {
337 Ok(Ok(Some(value))) => Ok(Some(value)),
338 Ok(Ok(None)) => Ok(None),
339 Ok(Err(err)) => Err(err),
340 Err(err) => Err(err.into()),
342 }
343 }
344 }
345 }
346 .boxed_local()
347}
348
349fn handle_message<TConfiguration: Serialize + Clone + Send + Sync>(
350 context: &ProcessContext<TConfiguration>,
351 original_message_id: u32,
352 action: impl FnOnce() -> Result<MessageBody>,
353) {
354 match action() {
355 Ok(body) => send_response_body(context, body),
356 Err(err) => send_error_response(context, original_message_id, err),
357 };
358}
359
360async fn handle_async_message<TConfiguration: Serialize + Clone + Send + Sync>(
361 context: &ProcessContext<TConfiguration>,
362 original_message_id: u32,
363 action: LocalBoxFuture<'_, Result<MessageBody>>,
364) {
365 match action.await {
366 Ok(body) => send_response_body(context, body),
367 Err(err) => send_error_response(context, original_message_id, err),
368 };
369}
370
371fn send_error_response<TConfiguration: Serialize + Clone + Send + Sync>(context: &ProcessContext<TConfiguration>, original_message_id: u32, err: FormatError) {
372 let body = MessageBody::Error(ResponseBody {
373 message_id: original_message_id,
374 data: error_to_string(&err).into_bytes(),
375 });
376 send_response_body(context, body)
377}
378
379fn send_response_body<TConfiguration: Serialize + Clone + Send + Sync>(context: &ProcessContext<TConfiguration>, body: MessageBody) {
380 let message = ProcessPluginMessage {
381 id: context.id_generator.next(),
382 body,
383 };
384 if let Err(err) = context.stdout_writer.send(message) {
385 panic!("Receiver dropped. {:#}", err);
386 }
387}
388
389fn schema_establishment_phase<TRead: Read + Unpin, TWrite: Write + Unpin>(
391 stdin: &mut MessageReader<TRead>,
392 stdout: &mut MessageWriter<TWrite>,
393) -> std::result::Result<(), SchemaEstablishmentError> {
394 if stdin.read_u32()? != 0 {
396 return Err(SchemaEstablishmentError::UnexpectedRequest);
397 }
398
399 stdout.send_u32(0)?;
401 stdout.send_u32(PLUGIN_SCHEMA_VERSION)?;
403 stdout.flush()?;
404
405 Ok(())
406}