1use std::{
2 cmp::Ordering,
3 collections::HashMap,
4 env,
5 ffi::OsString,
6 ops::Deref,
7 panic::AssertUnwindSafe,
8 path::Path,
9 sync::mpsc::{self, TrySendError},
10 thread,
11};
12
13use nu_engine::documentation::{FormatterValue, HelpStyle, get_flags_section};
14use nu_plugin_core::{
15 ClientCommunicationIo, CommunicationMode, InterfaceManager, PluginEncoder, PluginRead,
16 PluginWrite,
17};
18use nu_plugin_protocol::{CallInfo, CustomValueOp, PluginCustomValue, PluginInput, PluginOutput};
19use nu_protocol::{
20 CustomValue, IntoSpanned, LabeledError, PipelineData, PluginMetadata, ShellError, Span,
21 Spanned, Value, ast::Operator, casing::Casing,
22};
23use thiserror::Error;
24
25use self::{command::render_examples, interface::ReceivedPluginCall};
26
27mod command;
28mod interface;
29
30pub use command::{PluginCommand, SimplePluginCommand, create_plugin_signature};
31pub use interface::{EngineInterface, EngineInterfaceManager};
32
/// Capacity, in bytes, of the buffered reader and writer wrapped around local socket streams.
// Only referenced from code gated on the `local-socket` feature, so it is unused (and would
// otherwise warn) when that feature is disabled.
#[allow(dead_code)]
pub(crate) const OUTPUT_BUFFER_SIZE: usize = 16 * 1024;
38
39pub trait Plugin: Sync {
98 fn version(&self) -> String;
114
115 fn commands(&self) -> Vec<Box<dyn PluginCommand<Plugin = Self>>>;
123
124 fn custom_value_to_base_value(
129 &self,
130 engine: &EngineInterface,
131 custom_value: Spanned<Box<dyn CustomValue>>,
132 ) -> Result<Value, LabeledError> {
133 let _ = engine;
134 custom_value
135 .item
136 .to_base_value(custom_value.span)
137 .map_err(LabeledError::from)
138 }
139
140 fn custom_value_follow_path_int(
145 &self,
146 engine: &EngineInterface,
147 custom_value: Spanned<Box<dyn CustomValue>>,
148 index: Spanned<usize>,
149 optional: bool,
150 ) -> Result<Value, LabeledError> {
151 let _ = engine;
152 custom_value
153 .item
154 .follow_path_int(custom_value.span, index.item, index.span, optional)
155 .map_err(LabeledError::from)
156 }
157
158 fn custom_value_follow_path_string(
163 &self,
164 engine: &EngineInterface,
165 custom_value: Spanned<Box<dyn CustomValue>>,
166 column_name: Spanned<String>,
167 optional: bool,
168 casing: Casing,
169 ) -> Result<Value, LabeledError> {
170 let _ = engine;
171 custom_value
172 .item
173 .follow_path_string(
174 custom_value.span,
175 column_name.item,
176 column_name.span,
177 optional,
178 casing,
179 )
180 .map_err(LabeledError::from)
181 }
182
183 fn custom_value_partial_cmp(
192 &self,
193 engine: &EngineInterface,
194 custom_value: Box<dyn CustomValue>,
195 other_value: Value,
196 ) -> Result<Option<Ordering>, LabeledError> {
197 let _ = engine;
198 Ok(custom_value.partial_cmp(&other_value))
199 }
200
201 fn custom_value_operation(
206 &self,
207 engine: &EngineInterface,
208 left: Spanned<Box<dyn CustomValue>>,
209 operator: Spanned<Operator>,
210 right: Value,
211 ) -> Result<Value, LabeledError> {
212 let _ = engine;
213 left.item
214 .operation(left.span, operator.item, operator.span, &right)
215 .map_err(LabeledError::from)
216 }
217
218 fn custom_value_save(
223 &self,
224 engine: &EngineInterface,
225 value: Spanned<Box<dyn CustomValue>>,
226 path: Spanned<&Path>,
227 save_call_span: Span,
228 ) -> Result<(), LabeledError> {
229 let _ = engine;
230 value
231 .item
232 .save(path, value.span, save_call_span)
233 .map_err(LabeledError::from)
234 }
235
236 fn custom_value_dropped(
249 &self,
250 engine: &EngineInterface,
251 custom_value: Box<dyn CustomValue>,
252 ) -> Result<(), LabeledError> {
253 let _ = (engine, custom_value);
254 Ok(())
255 }
256}
257
258pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) {
277 let args: Vec<OsString> = env::args_os().skip(1).collect();
278
279 let exe = std::env::current_exe().ok();
281
282 let plugin_name: String = exe
283 .as_ref()
284 .and_then(|path| path.file_stem())
285 .map(|stem| stem.to_string_lossy().into_owned())
286 .map(|stem| {
287 stem.strip_prefix("nu_plugin_")
288 .map(|s| s.to_owned())
289 .unwrap_or(stem)
290 })
291 .unwrap_or_else(|| "(unknown)".into());
292
293 if args.is_empty() || args[0] == "-h" || args[0] == "--help" {
294 print_help(plugin, encoder);
295 std::process::exit(0)
296 }
297
298 let mode = if args[0] == "--stdio" && args.len() == 1 {
300 CommunicationMode::Stdio
302 } else if args[0] == "--local-socket" && args.len() == 2 {
303 #[cfg(feature = "local-socket")]
304 {
305 CommunicationMode::LocalSocket((&args[1]).into())
306 }
307 #[cfg(not(feature = "local-socket"))]
308 {
309 eprintln!("{plugin_name}: local socket mode is not supported");
310 std::process::exit(1);
311 }
312 } else {
313 eprintln!(
314 "{}: This plugin must be run from within Nushell. See `plugin add --help` for details \
315 on how to use plugins.",
316 env::current_exe()
317 .map(|path| path.display().to_string())
318 .unwrap_or_else(|_| "plugin".into())
319 );
320 eprintln!(
321 "If you are running from Nushell, this plugin may be incompatible with the \
322 version of nushell you are using."
323 );
324 std::process::exit(1)
325 };
326
327 let encoder_clone = encoder.clone();
328
329 let result = match mode.connect_as_client() {
330 Ok(ClientCommunicationIo::Stdio(stdin, mut stdout)) => {
331 tell_nushell_encoding(&mut stdout, &encoder).expect("failed to tell nushell encoding");
332 serve_plugin_io(
333 plugin,
334 &plugin_name,
335 move || (stdin.lock(), encoder_clone),
336 move || (stdout, encoder),
337 )
338 }
339 #[cfg(feature = "local-socket")]
340 Ok(ClientCommunicationIo::LocalSocket {
341 read_in,
342 mut write_out,
343 }) => {
344 use std::io::{BufReader, BufWriter};
345 use std::sync::Mutex;
346
347 tell_nushell_encoding(&mut write_out, &encoder)
348 .expect("failed to tell nushell encoding");
349
350 let read = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, read_in);
351 let write = Mutex::new(BufWriter::with_capacity(OUTPUT_BUFFER_SIZE, write_out));
352 serve_plugin_io(
353 plugin,
354 &plugin_name,
355 move || (read, encoder_clone),
356 move || (write, encoder),
357 )
358 }
359 Err(err) => {
360 eprintln!("{plugin_name}: failed to connect: {err:?}");
361 std::process::exit(1);
362 }
363 };
364
365 match result {
366 Ok(()) => (),
367 Err(ServePluginError::UnreportedError(err)) => {
369 eprintln!("Plugin `{plugin_name}` error: {err}");
370 std::process::exit(1);
371 }
372 Err(_) => std::process::exit(1),
373 }
374}
375
376fn tell_nushell_encoding(
377 writer: &mut impl std::io::Write,
378 encoder: &impl PluginEncoder,
379) -> Result<(), std::io::Error> {
380 let encoding = encoder.name();
385 let length = encoding.len() as u8;
386 let mut encoding_content: Vec<u8> = encoding.as_bytes().to_vec();
387 encoding_content.insert(0, length);
388 writer.write_all(&encoding_content)?;
389 writer.flush()
390}
391
392#[derive(Debug, Error)]
394pub enum ServePluginError {
395 #[error("{0}")]
397 UnreportedError(#[source] ShellError),
398 #[error("{0}")]
400 ReportedError(#[source] ShellError),
401 #[error("{0}")]
403 Incompatible(#[source] ShellError),
404 #[error("{0}")]
406 IOError(#[source] ShellError),
407 #[error("{0}")]
409 ThreadSpawnError(#[source] std::io::Error),
410 #[error("a panic occurred in a plugin thread")]
412 Panicked,
413}
414
415impl From<ShellError> for ServePluginError {
416 fn from(error: ShellError) -> Self {
417 match error {
418 ShellError::Io(_) => ServePluginError::IOError(error),
419 ShellError::PluginFailedToLoad { .. } => ServePluginError::Incompatible(error),
420 _ => ServePluginError::UnreportedError(error),
421 }
422 }
423}
424
425trait TryToReport {
427 type T;
428 fn try_to_report(self, engine: &EngineInterface) -> Result<Self::T, ServePluginError>;
429}
430
431impl<T, E> TryToReport for Result<T, E>
432where
433 E: Into<ServePluginError>,
434{
435 type T = T;
436 fn try_to_report(self, engine: &EngineInterface) -> Result<T, ServePluginError> {
437 self.map_err(|e| match e.into() {
438 ServePluginError::UnreportedError(err) => {
439 if engine.write_response(Err(err.clone())).is_ok() {
440 ServePluginError::ReportedError(err)
441 } else {
442 ServePluginError::UnreportedError(err)
443 }
444 }
445 other => other,
446 })
447 }
448}
449
/// Serves `plugin` over an arbitrary reader/writer pair until the engine stops sending calls.
///
/// `input` and `output` are factories rather than values: `output` is invoked on the calling
/// thread to build the [`EngineInterfaceManager`], while `input` is invoked on the reader thread
/// that this function spawns. `plugin_name` is only used in diagnostic messages.
///
/// # Errors
///
/// Returns an error if the initial `hello` fails, if a thread cannot be spawned, if handling a
/// metadata, signature or custom value call fails, or if the reader thread or a command runner
/// sent an error through the internal error channel.
#[doc(hidden)]
pub fn serve_plugin_io<I, O>(
    plugin: &impl Plugin,
    plugin_name: &str,
    input: impl FnOnce() -> I + Send + 'static,
    output: impl FnOnce() -> O + Send + 'static,
) -> Result<(), ServePluginError>
where
    I: PluginRead<PluginInput> + 'static,
    O: PluginWrite<PluginOutput> + 'static,
{
    // Errors raised on the reader thread and on runner threads are funneled through this channel
    // and picked up by the dispatch loop below.
    let (error_tx, error_rx) = mpsc::channel();

    // Index the plugin's commands by name so a call can be routed with a single lookup.
    let mut commands: HashMap<String, _> = HashMap::new();

    for command in plugin.commands() {
        // `insert` hands back the entry it replaced, which means two commands share a name.
        if let Some(previous) = commands.insert(command.name().into(), command) {
            eprintln!(
                "Plugin `{plugin_name}` warning: command `{}` shadowed by another command with the \
                 same name. Check your commands' `name()` methods",
                previous.name()
            );
        }
    }

    let mut manager = EngineInterfaceManager::new(output());
    let call_receiver = manager
        .take_plugin_call_receiver()
        .expect("take_plugin_call_receiver returned None");

    let interface = manager.get_interface();

    // Sent before the reader thread is started; a failure here aborts serving immediately.
    interface.hello()?;

    {
        let error_tx = error_tx.clone();
        // The reader owns the manager and the input stream for the rest of its life. It is a
        // regular (non-scoped) thread, which is why `input` has to be `Send + 'static`.
        std::thread::Builder::new()
            .name("engine interface reader".into())
            .spawn(move || {
                if let Err(err) = manager.consume_all(input()) {
                    // If the receiving end is already gone there is nobody left to tell.
                    let _ = error_tx.send(ServePluginError::from(err));
                }
            })
            .map_err(ServePluginError::ThreadSpawnError)?;
    }

    // Scoped threads let the runners borrow `plugin`, `commands` and `error_tx` directly.
    thread::scope(|scope| {
        // Executes one `Run` call and writes its response. A panic inside is caught and turned
        // into an immediate process exit with status 1.
        let run = |engine, call_info| {
            let unwind_result = std::panic::catch_unwind(AssertUnwindSafe(|| {
                let CallInfo { name, call, input } = call_info;
                let result = if let Some(command) = commands.get(&name) {
                    command.run(plugin, &engine, &call, input)
                } else {
                    Err(
                        LabeledError::new(format!("Plugin command not found: `{name}`"))
                            .with_label(
                                format!("plugin `{plugin_name}` doesn't have this command"),
                                call.head,
                            ),
                    )
                };
                let write_result = engine
                    .write_response(result)
                    .and_then(|writer| writer.write())
                    .try_to_report(&engine);
                if let Err(err) = write_result {
                    let _ = error_tx.send(err);
                }
            }));
            if unwind_result.is_err() {
                std::process::exit(1);
            }
        };

        // Zero-capacity channel: `try_send` only succeeds while the primary runner is blocked
        // waiting for its next call, i.e. while it is idle.
        let (run_tx, run_rx) = mpsc::sync_channel(0);
        thread::Builder::new()
            .name("plugin runner (primary)".into())
            .spawn_scoped(scope, move || {
                for (engine, call) in run_rx {
                    run(engine, call);
                }
            })
            .map_err(ServePluginError::ThreadSpawnError)?;

        for plugin_call in call_receiver {
            // Stop dispatching as soon as any background thread has reported an error.
            if let Ok(error) = error_rx.try_recv() {
                return Err(error);
            }

            match plugin_call {
                // Reply with metadata carrying the plugin's version.
                ReceivedPluginCall::Metadata { engine } => {
                    engine
                        .write_metadata(PluginMetadata::new().with_version(plugin.version()))
                        .try_to_report(&engine)?;
                }
                // Reply with one signature per registered command, with its examples rendered.
                ReceivedPluginCall::Signature { engine } => {
                    let sigs = commands
                        .values()
                        .map(|command| create_plugin_signature(command.deref()))
                        .map(|mut sig| {
                            render_examples(plugin, &engine, &mut sig.examples)?;
                            Ok(sig)
                        })
                        .collect::<Result<Vec<_>, ShellError>>()
                        .try_to_report(&engine)?;
                    engine.write_signature(sigs).try_to_report(&engine)?;
                }
                ReceivedPluginCall::Run { engine, call } => {
                    // Prefer handing the call to the primary runner; if it is busy (or gone),
                    // run the call on a fresh thread instead of waiting.
                    match run_tx.try_send((engine, call)) {
                        Ok(()) => (),
                        Err(TrySendError::Full((engine, call)))
                        | Err(TrySendError::Disconnected((engine, call))) => {
                            thread::Builder::new()
                                .name("plugin runner (secondary)".into())
                                .spawn_scoped(scope, move || run(engine, call))
                                .map_err(ServePluginError::ThreadSpawnError)?;
                        }
                    }
                }
                // Custom value operations are handled inline on this thread.
                ReceivedPluginCall::CustomValueOp {
                    engine,
                    custom_value,
                    op,
                } => {
                    custom_value_op(plugin, &engine, custom_value, op).try_to_report(&engine)?;
                }
            }
        }

        Ok::<_, ServePluginError>(())
    })?;

    // NOTE(review): presumably dropping the last interface handle lets the manager/reader wind
    // down — confirm against `EngineInterfaceManager`.
    drop(interface);

    // Surface an error that arrived after the dispatch loop had already finished.
    if let Ok(err) = error_rx.try_recv() {
        Err(err)
    } else {
        Ok(())
    }
}
616
617fn custom_value_op(
618 plugin: &impl Plugin,
619 engine: &EngineInterface,
620 custom_value: Spanned<PluginCustomValue>,
621 op: CustomValueOp,
622) -> Result<(), ShellError> {
623 let local_value = custom_value
624 .item
625 .deserialize_to_custom_value(custom_value.span)?
626 .into_spanned(custom_value.span);
627 match op {
628 CustomValueOp::ToBaseValue => {
629 let result = plugin
630 .custom_value_to_base_value(engine, local_value)
631 .map(|value| PipelineData::value(value, None));
632 engine
633 .write_response(result)
634 .and_then(|writer| writer.write())
635 }
636 CustomValueOp::FollowPathInt { index, optional } => {
637 let result = plugin
638 .custom_value_follow_path_int(engine, local_value, index, optional)
639 .map(|value| PipelineData::value(value, None));
640 engine
641 .write_response(result)
642 .and_then(|writer| writer.write())
643 }
644 CustomValueOp::FollowPathString {
645 column_name,
646 optional,
647 casing,
648 } => {
649 let result = plugin
650 .custom_value_follow_path_string(engine, local_value, column_name, optional, casing)
651 .map(|value| PipelineData::value(value, None));
652 engine
653 .write_response(result)
654 .and_then(|writer| writer.write())
655 }
656 CustomValueOp::PartialCmp(mut other_value) => {
657 PluginCustomValue::deserialize_custom_values_in(&mut other_value)?;
658 match plugin.custom_value_partial_cmp(engine, local_value.item, other_value) {
659 Ok(ordering) => engine.write_ordering(ordering),
660 Err(err) => engine
661 .write_response(Err(err))
662 .and_then(|writer| writer.write()),
663 }
664 }
665 CustomValueOp::Operation(operator, mut right) => {
666 PluginCustomValue::deserialize_custom_values_in(&mut right)?;
667 let result = plugin
668 .custom_value_operation(engine, local_value, operator, right)
669 .map(|value| PipelineData::value(value, None));
670 engine
671 .write_response(result)
672 .and_then(|writer| writer.write())
673 }
674 CustomValueOp::Save {
675 path,
676 save_call_span,
677 } => {
678 let path = Spanned {
679 item: path.item.as_path(),
680 span: path.span,
681 };
682 let result = plugin.custom_value_save(engine, local_value, path, save_call_span);
683 engine.write_ok(result)
684 }
685 CustomValueOp::Dropped => {
686 let result = plugin
687 .custom_value_dropped(engine, local_value.item)
688 .map(|_| PipelineData::empty());
689 engine
690 .write_response(result)
691 .and_then(|writer| writer.write())
692 }
693 }
694}
695
696fn print_help(plugin: &impl Plugin, encoder: impl PluginEncoder) {
697 use std::fmt::Write;
698
699 println!("Nushell Plugin");
700 println!("Encoder: {}", encoder.name());
701 println!("Version: {}", plugin.version());
702
703 let exe = std::env::current_exe().ok();
705 let plugin_name: String = exe
706 .as_ref()
707 .map(|stem| stem.to_string_lossy().into_owned())
708 .unwrap_or_else(|| "(unknown)".into());
709 println!("Plugin file path: {plugin_name}");
710
711 let mut help = String::new();
712 let help_style = HelpStyle::default();
713
714 plugin.commands().into_iter().for_each(|command| {
715 let signature = command.signature();
716 let res = write!(help, "\nCommand: {}", command.name())
717 .and_then(|_| writeln!(help, "\nDescription:\n > {}", command.description()))
718 .and_then(|_| {
719 if !command.extra_description().is_empty() {
720 writeln!(
721 help,
722 "\nExtra description:\n > {}",
723 command.extra_description()
724 )
725 } else {
726 Ok(())
727 }
728 })
729 .and_then(|_| {
730 let flags = get_flags_section(&signature, &help_style, |v| match v {
731 FormatterValue::DefaultValue(value) => format!("{value:#?}"),
732 FormatterValue::CodeString(text) => text.to_string(),
733 });
734 write!(help, "{flags}")
735 })
736 .and_then(|_| writeln!(help, "\nParameters:"))
737 .and_then(|_| {
738 signature
739 .required_positional
740 .iter()
741 .try_for_each(|positional| {
742 writeln!(
743 help,
744 " {} <{}>: {}",
745 positional.name, positional.shape, positional.desc
746 )
747 })
748 })
749 .and_then(|_| {
750 signature
751 .optional_positional
752 .iter()
753 .try_for_each(|positional| {
754 writeln!(
755 help,
756 " (optional) {} <{}>: {}",
757 positional.name, positional.shape, positional.desc
758 )
759 })
760 })
761 .and_then(|_| {
762 if let Some(rest_positional) = &signature.rest_positional {
763 writeln!(
764 help,
765 " ...{} <{}>: {}",
766 rest_positional.name, rest_positional.shape, rest_positional.desc
767 )
768 } else {
769 Ok(())
770 }
771 })
772 .and_then(|_| writeln!(help, "======================"));
773
774 if res.is_err() {
775 println!("{res:?}")
776 }
777 });
778
779 println!("{help}")
780}