dprint_core/plugins/process/
message_processor.rs1use anyhow::anyhow;
2use anyhow::bail;
3use anyhow::Context;
4use anyhow::Result;
5use serde::Serialize;
6use std::io::Read;
7use std::io::Write;
8use std::rc::Rc;
9use std::sync::Arc;
10use tokio_util::sync::CancellationToken;
11
12use super::context::ProcessContext;
13use super::context::StoredConfig;
14use super::messages::CheckConfigUpdatesMessageBody;
15use super::messages::CheckConfigUpdatesResponseBody;
16use super::messages::HostFormatMessageBody;
17use super::messages::MessageBody;
18use super::messages::ProcessPluginMessage;
19use super::messages::ResponseBody;
20use super::utils::setup_exit_process_panic_hook;
21use super::PLUGIN_SCHEMA_VERSION;
22
23use crate::async_runtime::FutureExt;
24use crate::async_runtime::LocalBoxFuture;
25use crate::communication::MessageReader;
26use crate::communication::MessageWriter;
27use crate::communication::SingleThreadMessageWriter;
28use crate::configuration::ConfigKeyMap;
29use crate::configuration::GlobalConfiguration;
30use crate::plugins::AsyncPluginHandler;
31use crate::plugins::FormatRequest;
32use crate::plugins::FormatResult;
33use crate::plugins::HostFormatRequest;
34
35pub async fn handle_process_stdio_messages<THandler: AsyncPluginHandler>(handler: THandler) -> Result<()> {
37 setup_exit_process_panic_hook();
39
40 let (mut stdin_reader, stdout_writer) = crate::async_runtime::spawn_blocking(move || {
42 let mut stdin_reader = MessageReader::new(std::io::stdin());
43 let mut stdout_writer = MessageWriter::new(std::io::stdout());
44
45 schema_establishment_phase(&mut stdin_reader, &mut stdout_writer).context("Failed estabilishing schema.")?;
46 Ok::<_, anyhow::Error>((stdin_reader, stdout_writer))
47 })
48 .await??;
49
50 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<std::io::Result<ProcessPluginMessage>>();
52 crate::async_runtime::spawn_blocking(move || loop {
53 let message_result = ProcessPluginMessage::read(&mut stdin_reader);
54 let is_err = message_result.is_err();
55 if tx.send(message_result).is_err() {
56 return; }
58 if is_err {
59 return; }
61 });
62
63 crate::async_runtime::spawn(async move {
64 let handler = Rc::new(handler);
65 let stdout_message_writer = SingleThreadMessageWriter::for_stdout(stdout_writer);
66 let context: Rc<ProcessContext<THandler::Configuration>> = Rc::new(ProcessContext::new(stdout_message_writer));
67
68 loop {
70 let message = match rx.recv().await {
71 Some(message_result) => message_result?,
72 None => return Ok(()), };
74
75 match message.body {
76 MessageBody::Close => {
77 handle_message(&context, message.id, || Ok(MessageBody::Success(message.id)));
78 return Ok(());
79 }
80 MessageBody::IsAlive => {
81 handle_message(&context, message.id, || Ok(MessageBody::Success(message.id)));
82 }
83 MessageBody::GetPluginInfo => {
84 handle_message(&context, message.id, || {
85 let plugin_info = handler.plugin_info();
86 let data = serde_json::to_vec(&plugin_info)?;
87 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
88 });
89 }
90 MessageBody::GetLicenseText => {
91 handle_message(&context, message.id, || {
92 let data = handler.license_text().into_bytes();
93 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
94 });
95 }
96 MessageBody::RegisterConfig(body) => {
97 handle_async_message(
98 &context,
99 message.id,
100 async {
101 let global_config: GlobalConfiguration = serde_json::from_slice(&body.global_config)?;
102 let config_map: ConfigKeyMap = serde_json::from_slice(&body.plugin_config)?;
103 let result = handler.resolve_config(config_map.clone(), global_config.clone()).await;
104 context.configs.store(
105 body.config_id.as_raw(),
106 Rc::new(StoredConfig {
107 config: Arc::new(result.config),
108 file_matching: result.file_matching,
109 diagnostics: Rc::new(result.diagnostics),
110 config_map,
111 global_config,
112 }),
113 );
114 Ok(MessageBody::Success(message.id))
115 }
116 .boxed_local(),
117 )
118 .await;
119 }
120 MessageBody::ReleaseConfig(config_id) => {
121 handle_message(&context, message.id, || {
122 context.configs.take(config_id.as_raw());
123 Ok(MessageBody::Success(message.id))
124 });
125 }
126 MessageBody::GetConfigDiagnostics(config_id) => {
127 handle_message(&context, message.id, || {
128 let diagnostics = context
129 .configs
130 .get_cloned(config_id.as_raw())
131 .map(|c| c.diagnostics.clone())
132 .unwrap_or_default();
133 let data = serde_json::to_vec(&*diagnostics)?;
134 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
135 });
136 }
137 MessageBody::GetFileMatchingInfo(config_id) => {
138 handle_message(&context, message.id, || {
139 let data = match context.configs.get_cloned(config_id.as_raw()) {
140 Some(config) => serde_json::to_vec(&config.file_matching)?,
141 None => bail!("Did not find configuration for id: {}", config_id),
142 };
143 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
144 });
145 }
146 MessageBody::GetResolvedConfig(config_id) => {
147 handle_message(&context, message.id, || {
148 let data = match context.configs.get_cloned(config_id.as_raw()) {
149 Some(config) => serde_json::to_vec(&*config.config)?,
150 None => bail!("Did not find configuration for id: {}", config_id),
151 };
152 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
153 });
154 }
155 MessageBody::CheckConfigUpdates(body_bytes) => {
156 handle_async_message(
157 &context,
158 message.id,
159 async {
160 let message_body = serde_json::from_slice::<CheckConfigUpdatesMessageBody>(&body_bytes)
161 .with_context(|| "Could not deserialize the check config updates message body.".to_string())?;
162 let changes = handler.check_config_updates(message_body).await?;
163 let response = CheckConfigUpdatesResponseBody { changes };
164 let data = serde_json::to_vec(&response)?;
165 Ok(MessageBody::DataResponse(ResponseBody { message_id: message.id, data }))
166 }
167 .boxed_local(),
168 )
169 .await;
170 }
171 MessageBody::Format(body) => {
172 let token = Arc::new(CancellationToken::new());
174 let request = FormatRequest {
175 file_path: body.file_path,
176 range: body.range,
177 config_id: body.config_id,
178 config: match context.configs.get_cloned(body.config_id.as_raw()) {
179 Some(config) => {
180 if body.override_config.is_empty() {
181 config.config.clone()
182 } else {
183 let mut config_map = config.config_map.clone();
184 let override_config_map: ConfigKeyMap = serde_json::from_slice(&body.override_config)?;
185 for (key, value) in override_config_map {
186 config_map.insert(key, value);
187 }
188 let result = handler.resolve_config(config_map, config.global_config.clone()).await;
189 Arc::new(result.config)
190 }
191 }
192 None => {
193 send_error_response(&context, message.id, anyhow!("Did not find configuration for id: {}", body.config_id));
194 continue;
195 }
196 },
197 file_bytes: body.file_bytes,
198 token: token.clone(),
199 };
200
201 let context = context.clone();
203 let handler = handler.clone();
204 let token_storage_guard = context.cancellation_tokens.store_with_owned_guard(message.id, token.clone());
205 crate::async_runtime::spawn(async move {
206 let original_message_id = message.id;
207 let result = handler
208 .format(request, {
209 let context = context.clone();
210 move |request| host_format(&context, original_message_id, request)
211 })
212 .await;
213 drop(token_storage_guard);
214 if !token.is_cancelled() {
215 let body = match result {
216 Ok(text) => MessageBody::FormatResponse(ResponseBody {
217 message_id: message.id,
218 data: text,
219 }),
220 Err(err) => MessageBody::Error(ResponseBody {
221 message_id: message.id,
222 data: format!("{:#}", err).into_bytes(),
223 }),
224 };
225 send_response_body(&context, body)
226 }
227 });
228 }
229 MessageBody::CancelFormat(message_id) => {
230 if let Some(token) = context.cancellation_tokens.take(message_id) {
231 token.cancel();
232 }
233 }
234 MessageBody::Error(body) => {
235 let text = String::from_utf8_lossy(&body.data);
236 if let Some(sender) = context.format_host_senders.take(body.message_id) {
237 sender.send(Err(anyhow!("{}", text))).unwrap();
238 } else {
239 #[allow(clippy::print_stderr)]
240 {
241 eprintln!("Received error from CLI. {}", text);
242 }
243 }
244 }
245 MessageBody::FormatResponse(body) => {
246 if let Some(sender) = context.format_host_senders.take(body.message_id) {
247 sender.send(Ok(body.data)).unwrap();
248 }
249 }
250 MessageBody::Success(_) | MessageBody::DataResponse(_) => {
251 }
253 MessageBody::HostFormat(_) => {
254 send_error_response(&context, message.id, anyhow!("Cannot host format with a plugin."));
255 }
256 MessageBody::Unknown(message_kind) => panic!("Received unknown message kind: {}", message_kind),
257 }
258 }
259 })
260 .await
261 .unwrap()
262}
263
264fn host_format<TConfiguration: Serialize + Clone + Send + Sync>(
265 context: &ProcessContext<TConfiguration>,
266 original_message_id: u32,
267 request: HostFormatRequest,
268) -> LocalBoxFuture<'static, FormatResult> {
269 let (tx, rx) = tokio::sync::oneshot::channel::<FormatResult>();
270 let id = context.id_generator.next();
271 context.format_host_senders.store(id, tx);
272
273 context
274 .stdout_writer
275 .send(ProcessPluginMessage {
276 id,
277 body: MessageBody::HostFormat(HostFormatMessageBody {
278 original_message_id,
279 file_path: request.file_path,
280 file_text: request.file_bytes,
281 range: request.range,
282 override_config: serde_json::to_vec(&request.override_config).unwrap(),
283 }),
284 })
285 .unwrap_or_else(|err| panic!("Error sending host format response: {:#}", err));
286
287 let token = request.token;
288 let stdout_writer = context.stdout_writer.clone();
289 let id_generator = context.id_generator.clone();
290 let original_message_id = id;
291
292 async move {
293 tokio::select! {
294 _ = token.wait_cancellation() => {
295 stdout_writer.send(ProcessPluginMessage {
297 id: id_generator.next(),
298 body: MessageBody::CancelFormat(original_message_id),
299 }).unwrap_or_else(|err| panic!("Error sending host format cancellation: {:#}", err));
300
301 Ok(None)
303 }
304 value = rx => {
305 match value {
306 Ok(Ok(Some(value))) => Ok(Some(value)),
307 Ok(Ok(None)) => Ok(None),
308 Ok(Err(err)) => Err(err),
309 Err(err) => Err(err.into()),
311 }
312 }
313 }
314 }
315 .boxed_local()
316}
317
318fn handle_message<TConfiguration: Serialize + Clone + Send + Sync>(
319 context: &ProcessContext<TConfiguration>,
320 original_message_id: u32,
321 action: impl FnOnce() -> Result<MessageBody>,
322) {
323 match action() {
324 Ok(body) => send_response_body(context, body),
325 Err(err) => send_error_response(context, original_message_id, err),
326 };
327}
328
329async fn handle_async_message<'a, TConfiguration: Serialize + Clone + Send + Sync>(
330 context: &ProcessContext<TConfiguration>,
331 original_message_id: u32,
332 action: LocalBoxFuture<'a, Result<MessageBody>>,
333) {
334 match action.await {
335 Ok(body) => send_response_body(context, body),
336 Err(err) => send_error_response(context, original_message_id, err),
337 };
338}
339
340fn send_error_response<TConfiguration: Serialize + Clone + Send + Sync>(
341 context: &ProcessContext<TConfiguration>,
342 original_message_id: u32,
343 err: anyhow::Error,
344) {
345 let body = MessageBody::Error(ResponseBody {
346 message_id: original_message_id,
347 data: format!("{:#}", err).into_bytes(),
348 });
349 send_response_body(context, body)
350}
351
352fn send_response_body<TConfiguration: Serialize + Clone + Send + Sync>(context: &ProcessContext<TConfiguration>, body: MessageBody) {
353 let message = ProcessPluginMessage {
354 id: context.id_generator.next(),
355 body,
356 };
357 if let Err(err) = context.stdout_writer.send(message) {
358 panic!("Receiver dropped. {:#}", err);
359 }
360}
361
362fn schema_establishment_phase<TRead: Read + Unpin, TWrite: Write + Unpin>(stdin: &mut MessageReader<TRead>, stdout: &mut MessageWriter<TWrite>) -> Result<()> {
364 if stdin.read_u32()? != 0 {
366 bail!("Expected a schema version request of `0`.");
367 }
368
369 stdout.send_u32(0)?;
371 stdout.send_u32(PLUGIN_SCHEMA_VERSION)?;
373 stdout.flush()?;
374
375 Ok(())
376}