use std::{
collections::{btree_map, BTreeMap},
sync::{mpsc, Arc},
};
use nu_protocol::{
IntoInterruptiblePipelineData, ListStream, PipelineData, PluginSignature, ShellError, Spanned,
Value,
};
use crate::{
plugin::{context::PluginExecutionContext, PluginIdentity},
protocol::{
CallInfo, CustomValueOp, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue,
PluginInput, PluginOutput, ProtocolInfo,
},
sequence::Sequence,
};
use super::{
stream::{StreamManager, StreamManagerHandle},
Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite,
};
#[cfg(test)]
mod tests;
/// Message delivered by the [`PluginInterfaceManager`] to the caller waiting on a plugin call.
#[derive(Debug)]
enum ReceivedPluginCallMessage {
/// The plugin's response to the call.
Response(PluginCallResponse<PipelineData>),
/// An error that stopped the manager from reading plugin messages before a response arrived.
Error(ShellError),
}
/// Shared handle to the [`PluginExecutionContext`] of a plugin call.
///
/// Cloning only bumps the `Arc` reference count. The newtype carries its own `Debug` impl and
/// derefs to the underlying context.
#[derive(Clone)]
pub(crate) struct Context(Arc<dyn PluginExecutionContext>);
/// Opaque `Debug` output: only the type name is printed, the wrapped context is not inspected.
impl std::fmt::Debug for Context {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Context")
    }
}
/// Lets a [`Context`] be used directly wherever a `&dyn PluginExecutionContext` is expected.
impl std::ops::Deref for Context {
    type Target = dyn PluginExecutionContext;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
/// State shared, behind an `Arc`, between the [`PluginInterfaceManager`] and every
/// [`PluginInterface`] created from it.
struct PluginInterfaceState {
/// Identity of the plugin. The manager passes it to `PluginCustomValue::add_source`, and the
/// interface passes it to `PluginCustomValue::verify_source`.
identity: Arc<PluginIdentity>,
/// Source of IDs for new plugin calls.
plugin_call_id_sequence: Sequence,
/// Source of stream IDs, exposed through `Interface::stream_id_sequence`.
stream_id_sequence: Sequence,
/// Registers new plugin calls with the manager, which picks them up in
/// `receive_plugin_call_subscriptions`.
plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallSubscription)>,
/// Writer for all messages sent to the plugin.
writer: Box<dyn PluginWrite<PluginInput>>,
}
impl std::fmt::Debug for PluginInterfaceState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PluginInterfaceState")
.field("identity", &self.identity)
.field("plugin_call_id_sequence", &self.plugin_call_id_sequence)
.field("stream_id_sequence", &self.stream_id_sequence)
.field(
"plugin_call_subscription_sender",
&self.plugin_call_subscription_sender,
)
.finish_non_exhaustive()
}
}
/// Bookkeeping for one in-flight plugin call, registered by `PluginInterface::write_plugin_call`.
#[derive(Debug)]
struct PluginCallSubscription {
/// Where the response (or an error) for the call is delivered.
sender: mpsc::Sender<ReceivedPluginCallMessage>,
/// Execution context of the call, if any. Its `ctrlc` is used when reading a pipeline data
/// response for the call.
context: Option<Context>,
}
/// Reads and dispatches the messages a plugin sends.
///
/// Consumes [`PluginOutput`] messages, routes plugin call responses to the callers waiting on
/// them, and hands out [`PluginInterface`]s for talking to the plugin.
#[derive(Debug)]
pub(crate) struct PluginInterfaceManager {
/// State shared with every interface. Its strong count tells whether any interface remains
/// (see `is_finished`).
state: Arc<PluginInterfaceState>,
/// Stream manager. Its handle is given to each interface, and read errors are broadcast to it.
stream_manager: StreamManager,
/// Protocol info from the plugin's `Hello`. `None` until a compatible `Hello` is received.
protocol_info: Option<ProtocolInfo>,
/// Plugin calls awaiting a response, keyed by call ID.
plugin_call_subscriptions: BTreeMap<PluginCallId, PluginCallSubscription>,
/// Receiving end for subscriptions registered by interfaces; drained on demand.
plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallSubscription)>,
}
impl PluginInterfaceManager {
    /// Create a manager for the plugin identified by `identity`.
    ///
    /// All messages to the plugin go through `writer`.
    pub(crate) fn new(
        identity: Arc<PluginIdentity>,
        writer: impl PluginWrite<PluginInput> + 'static,
    ) -> PluginInterfaceManager {
        let (plugin_call_subscription_sender, plugin_call_subscription_receiver) =
            mpsc::channel();

        let state = PluginInterfaceState {
            identity,
            plugin_call_id_sequence: Sequence::default(),
            stream_id_sequence: Sequence::default(),
            plugin_call_subscription_sender,
            writer: Box::new(writer),
        };

        Self {
            state: Arc::new(state),
            stream_manager: StreamManager::new(),
            protocol_info: None,
            plugin_call_subscriptions: BTreeMap::new(),
            plugin_call_subscription_receiver,
        }
    }

    /// Move every pending subscription from the channel into `plugin_call_subscriptions`.
    ///
    /// This never blocks. A subscription whose ID is already present is dropped with a warning.
    fn receive_plugin_call_subscriptions(&mut self) {
        for (id, subscription) in self.plugin_call_subscription_receiver.try_iter() {
            match self.plugin_call_subscriptions.entry(id) {
                btree_map::Entry::Vacant(slot) => {
                    slot.insert(subscription);
                }
                btree_map::Entry::Occupied(_) => {
                    log::warn!("Duplicate plugin call ID ignored: {id}");
                }
            }
        }
    }

    /// Look up the execution context registered for plugin call `id`.
    ///
    /// # Errors
    /// `PluginFailedToDecode` if no call with that ID is known.
    fn get_context(&mut self, id: PluginCallId) -> Result<Option<Context>, ShellError> {
        self.receive_plugin_call_subscriptions();
        match self.plugin_call_subscriptions.get(&id) {
            Some(subscription) => Ok(subscription.context.clone()),
            None => Err(ShellError::PluginFailedToDecode {
                msg: format!("Unknown plugin call ID: {id}"),
            }),
        }
    }

    /// Deliver `response` to the caller of plugin call `id` and forget the call.
    ///
    /// A caller that has already hung up is only logged, not treated as an error.
    ///
    /// # Errors
    /// `PluginFailedToDecode` if no call with that ID is known.
    fn send_plugin_call_response(
        &mut self,
        id: PluginCallId,
        response: PluginCallResponse<PipelineData>,
    ) -> Result<(), ShellError> {
        self.receive_plugin_call_subscriptions();

        let subscription = match self.plugin_call_subscriptions.remove(&id) {
            Some(subscription) => subscription,
            None => {
                return Err(ShellError::PluginFailedToDecode {
                    msg: format!("Unknown plugin call ID: {id}"),
                });
            }
        };

        let caller_hung_up = subscription
            .sender
            .send(ReceivedPluginCallMessage::Response(response))
            .is_err();
        if caller_hung_up {
            log::warn!("Received a plugin call response for id={id}, but the caller hung up");
        }
        Ok(())
    }

    /// True once no [`PluginInterface`] shares this manager's state any more.
    pub(crate) fn is_finished(&self) -> bool {
        // The manager itself always holds one strong reference.
        Arc::strong_count(&self.state) <= 1
    }

    /// Consume messages from `reader` until it is exhausted or the manager is finished.
    ///
    /// The finished check runs after each read. A message read after the last interface was
    /// dropped is therefore discarded.
    ///
    /// # Errors
    /// Returns the first error from reading or consuming a message. Before returning, it reports
    /// the error to the stream manager and sends it to every pending plugin call.
    pub(crate) fn consume_all(
        &mut self,
        mut reader: impl PluginRead<PluginOutput>,
    ) -> Result<(), ShellError> {
        loop {
            let msg = match reader.read().transpose() {
                Some(msg) => msg,
                // The reader has no more messages.
                None => return Ok(()),
            };

            if self.is_finished() {
                return Ok(());
            }

            let err = match msg.and_then(|msg| self.consume(msg)) {
                Ok(()) => continue,
                Err(err) => err,
            };

            // The error is fatal: report it to the streams first...
            let _ = self.stream_manager.broadcast_read_error(err.clone());

            // ...then to every pending call, including ones still sitting in the channel.
            self.receive_plugin_call_subscriptions();
            let pending = std::mem::take(&mut self.plugin_call_subscriptions);
            for subscription in pending.into_values() {
                let _ = subscription
                    .sender
                    .send(ReceivedPluginCallMessage::Error(err.clone()));
            }

            return Err(err);
        }
    }
}
impl InterfaceManager for PluginInterfaceManager {
type Interface = PluginInterface;
type Input = PluginOutput;
fn get_interface(&self) -> Self::Interface {
PluginInterface {
state: self.state.clone(),
stream_manager_handle: self.stream_manager.get_handle(),
}
}
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
log::trace!("from plugin: {:?}", input);
match input {
PluginOutput::Hello(info) => {
let local_info = ProtocolInfo::default();
if local_info.is_compatible_with(&info)? {
self.protocol_info = Some(info);
Ok(())
} else {
self.protocol_info = None;
Err(ShellError::PluginFailedToLoad {
msg: format!(
"Plugin is compiled for nushell version {}, \
which is not compatible with version {}",
info.version, local_info.version
),
})
}
}
_ if self.protocol_info.is_none() => {
Err(ShellError::PluginFailedToLoad {
msg: "Failed to receive initial Hello message. \
This plugin might be too old"
.into(),
})
}
PluginOutput::Stream(message) => self.consume_stream_message(message),
PluginOutput::CallResponse(id, response) => {
let response = match response {
PluginCallResponse::Error(err) => PluginCallResponse::Error(err),
PluginCallResponse::Signature(sigs) => PluginCallResponse::Signature(sigs),
PluginCallResponse::PipelineData(data) => {
let exec_context = self.get_context(id)?;
let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc());
match self.read_pipeline_data(data, ctrlc) {
Ok(data) => PluginCallResponse::PipelineData(data),
Err(err) => PluginCallResponse::Error(err.into()),
}
}
};
self.send_plugin_call_response(id, response)
}
}
}
fn stream_manager(&self) -> &StreamManager {
&self.stream_manager
}
fn prepare_pipeline_data(&self, mut data: PipelineData) -> Result<PipelineData, ShellError> {
match data {
PipelineData::Value(ref mut value, _) => {
PluginCustomValue::add_source(value, &self.state.identity);
Ok(data)
}
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => {
let identity = self.state.identity.clone();
Ok(stream
.map(move |mut value| {
PluginCustomValue::add_source(&mut value, &identity);
value
})
.into_pipeline_data_with_metadata(meta, ctrlc))
}
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
}
}
}
/// Handle for sending messages and plugin calls to a plugin.
///
/// Clones share the same state and stream manager handle. When the last interface is dropped,
/// a `Goodbye` is attempted (see the `Drop` impl).
#[derive(Debug, Clone)]
pub(crate) struct PluginInterface {
/// State shared with the manager and with other interfaces.
state: Arc<PluginInterfaceState>,
/// Handle to the manager's stream manager.
stream_manager_handle: StreamManagerHandle,
}
impl PluginInterface {
    /// Send a `Hello` carrying this side's [`ProtocolInfo`], then flush.
    pub(crate) fn hello(&self) -> Result<(), ShellError> {
        self.write(PluginInput::Hello(ProtocolInfo::default()))
            .and_then(|()| self.flush())
    }

    /// Send a `Goodbye` to the plugin, then flush.
    pub(crate) fn goodbye(&self) -> Result<(), ShellError> {
        self.write(PluginInput::Goodbye)
            .and_then(|()| self.flush())
    }

    /// Register a new plugin call with the manager and write it to the plugin.
    ///
    /// For a `Run` call, the pipeline input is replaced by the header from
    /// `init_write_pipeline_data`, and the returned writer sends the rest of the input. Other
    /// calls get a default writer.
    ///
    /// Returns that writer and the receiver on which the call's response will arrive.
    ///
    /// # Errors
    /// - Errors from allocating the call ID or preparing the input.
    /// - `NushellFailed` if the manager is gone.
    /// - Write/flush errors.
    fn write_plugin_call(
        &self,
        call: PluginCall<PipelineData>,
        context: Option<Context>,
    ) -> Result<
        (
            PipelineDataWriter<Self>,
            mpsc::Receiver<ReceivedPluginCallMessage>,
        ),
        ShellError,
    > {
        let id = self.state.plugin_call_id_sequence.next()?;
        let (response_tx, response_rx) = mpsc::channel();

        let (call, writer) = match call {
            PluginCall::Run(CallInfo {
                name,
                call,
                input,
                config,
            }) => {
                let (header, writer) = self.init_write_pipeline_data(input)?;
                let info = CallInfo {
                    name,
                    call,
                    input: header,
                    config,
                };
                (PluginCall::Run(info), writer)
            }
            PluginCall::Signature => (PluginCall::Signature, Default::default()),
            PluginCall::CustomValueOp(value, op) => {
                (PluginCall::CustomValueOp(value, op), Default::default())
            }
        };

        // The subscription must be registered before the call is written, so that the manager
        // knows the ID by the time the response arrives.
        let subscription = PluginCallSubscription {
            sender: response_tx,
            context,
        };
        let registered = self
            .state
            .plugin_call_subscription_sender
            .send((id, subscription));
        if registered.is_err() {
            return Err(ShellError::NushellFailed {
                msg: "PluginInterfaceManager hung up and is no longer accepting plugin calls"
                    .into(),
            });
        }

        self.write(PluginInput::Call(id, call))?;
        self.flush()?;
        Ok((writer, response_rx))
    }

    /// Block until the manager delivers the response (or an error) for a call.
    ///
    /// # Errors
    /// The delivered error, or `PluginFailedToDecode` if the sending side was dropped.
    fn receive_plugin_call_response(
        &self,
        rx: mpsc::Receiver<ReceivedPluginCallMessage>,
    ) -> Result<PluginCallResponse<PipelineData>, ShellError> {
        match rx.recv() {
            Ok(ReceivedPluginCallMessage::Response(resp)) => Ok(resp),
            Ok(ReceivedPluginCallMessage::Error(err)) => Err(err),
            Err(_) => Err(ShellError::PluginFailedToDecode {
                msg: "Failed to receive response to plugin call".into(),
            }),
        }
    }

    /// Perform a plugin call and wait for its response.
    fn plugin_call(
        &self,
        call: PluginCall<PipelineData>,
        context: &Option<Context>,
    ) -> Result<PluginCallResponse<PipelineData>, ShellError> {
        let (pipeline_writer, response_rx) = self.write_plugin_call(call, context.clone())?;
        // Hand any pipeline input to `write_background` before blocking on the response.
        pipeline_writer.write_background()?;
        self.receive_plugin_call_response(response_rx)
    }

    /// Ask the plugin for the signatures of its commands.
    pub(crate) fn get_signature(&self) -> Result<Vec<PluginSignature>, ShellError> {
        let response = self.plugin_call(PluginCall::Signature, &None)?;
        match response {
            PluginCallResponse::Signature(signatures) => Ok(signatures),
            PluginCallResponse::Error(error) => Err(error.into()),
            PluginCallResponse::PipelineData(_) => Err(ShellError::PluginFailedToDecode {
                msg: "Received unexpected response to plugin Signature call".into(),
            }),
        }
    }

    /// Run a plugin command with the given call info and execution context.
    pub(crate) fn run(
        &self,
        call: CallInfo<PipelineData>,
        context: Arc<impl PluginExecutionContext + 'static>,
    ) -> Result<PipelineData, ShellError> {
        let context = Some(Context(context));
        let response = self.plugin_call(PluginCall::Run(call), &context)?;
        match response {
            PluginCallResponse::PipelineData(data) => Ok(data),
            PluginCallResponse::Error(error) => Err(error.into()),
            PluginCallResponse::Signature(_) => Err(ShellError::PluginFailedToDecode {
                msg: "Received unexpected response to plugin Run call".into(),
            }),
        }
    }

    /// Ask the plugin to convert one of its custom values into a base [`Value`].
    ///
    /// The span of the input value is used for the result.
    pub(crate) fn custom_value_to_base_value(
        &self,
        value: Spanned<PluginCustomValue>,
    ) -> Result<Value, ShellError> {
        let span = value.span;
        let call = PluginCall::CustomValueOp(value, CustomValueOp::ToBaseValue);
        let response = self.plugin_call(call, &None)?;
        match response {
            PluginCallResponse::PipelineData(out_data) => Ok(out_data.into_value(span)),
            PluginCallResponse::Error(error) => Err(error.into()),
            PluginCallResponse::Signature(_) => Err(ShellError::PluginFailedToDecode {
                msg: "Received unexpected response to plugin CustomValueOp::ToBaseValue call"
                    .into(),
            }),
        }
    }
}
impl Interface for PluginInterface {
type Output = PluginInput;
fn write(&self, input: PluginInput) -> Result<(), ShellError> {
log::trace!("to plugin: {:?}", input);
self.state.writer.write(&input)
}
fn flush(&self) -> Result<(), ShellError> {
self.state.writer.flush()
}
fn stream_id_sequence(&self) -> &Sequence {
&self.state.stream_id_sequence
}
fn stream_manager_handle(&self) -> &StreamManagerHandle {
&self.stream_manager_handle
}
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
match data {
PipelineData::Value(mut value, meta) => {
PluginCustomValue::verify_source(&mut value, &self.state.identity)?;
Ok(PipelineData::Value(value, meta))
}
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => {
let identity = self.state.identity.clone();
Ok(stream
.map(move |mut value| {
match PluginCustomValue::verify_source(&mut value, &identity) {
Ok(()) => value,
Err(err) => Value::error(err, value.span()),
}
})
.into_pipeline_data_with_metadata(meta, ctrlc))
}
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
}
}
}
/// Sends `Goodbye` to the plugin when the last interface goes away.
impl Drop for PluginInterface {
    fn drop(&mut self) {
        // `state` is held by the manager and by each interface clone. A count of 3 or more
        // therefore means another interface is still alive.
        // NOTE(review): this check is not atomic with the drop itself. If two clones are dropped
        // concurrently on different threads, both may observe 3 and neither sends `Goodbye`.
        let other_interfaces_alive = Arc::strong_count(&self.state) >= 3;
        if other_interfaces_alive {
            return;
        }

        if let Err(err) = self.goodbye() {
            log::warn!("Error during plugin Goodbye: {err}");
        }
    }
}