use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
use nu_protocol::{ShellError, Span, Value, shell_error::generic::GenericError};
use std::{
collections::{BTreeMap, btree_map},
iter::FusedIterator,
marker::PhantomData,
sync::{Arc, Condvar, Mutex, MutexGuard, Weak, mpsc},
};
#[cfg(test)]
mod tests;
/// Receives the values of one incoming stream and converts them to `T`.
///
/// Created by [`StreamManagerHandle::read_stream`]. Every data message received is
/// acknowledged by writing [`StreamMessage::Ack`] through `writer`, and dropping the reader
/// writes [`StreamMessage::Drop`] for the stream.
#[derive(Debug)]
pub struct StreamReader<T, W>
where
    W: WriteStreamMessage,
{
    /// The id of the stream being read.
    id: StreamId,
    /// Channel fed by [`StreamManager::handle_message`]. Set to `None` once the end of the
    /// stream has been received, or after the `Iterator` impl has yielded an error.
    receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
    /// Used to send `Ack` and `Drop` messages back to the writing side.
    writer: W,
    /// Records the produced type `T` without storing a value of it.
    marker: PhantomData<fn() -> T>,
}
impl<T, W> StreamReader<T, W>
where
    T: TryFrom<StreamData, Error = ShellError>,
    W: WriteStreamMessage,
{
    /// Builds a reader for stream `id`. Messages are pulled from `receiver`, and replies to
    /// the writing side are sent through `writer`.
    fn new(
        id: StreamId,
        receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
        writer: W,
    ) -> StreamReader<T, W> {
        Self {
            id,
            receiver: Some(receiver),
            writer,
            marker: PhantomData,
        }
    }

    /// Receives the next value of the stream, blocking if nothing is queued yet.
    ///
    /// Returns:
    /// - `Ok(Some(value))` for each data message, after writing an `Ack` for it.
    /// - `Ok(None)` once the end of the stream is received, and on every call after that.
    /// - `Err` in these cases:
    ///   - an error was delivered through the channel;
    ///   - the conversion to `T` failed;
    ///   - writing or flushing `writer` failed;
    ///   - the channel disconnected before an explicit end of stream.
    pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
        fn lost_connection_error() -> ShellError {
            ShellError::Generic(GenericError::new_internal(
                "Stream ended unexpectedly",
                "connection lost before explicit end of stream",
            ))
        }

        // Stream already finished: keep reporting the end.
        let Some(rx) = self.receiver.as_ref() else {
            return Ok(None);
        };

        let queued = match rx.try_recv() {
            Ok(item) => item,
            Err(mpsc::TryRecvError::Disconnected) => return Err(lost_connection_error()),
            Err(mpsc::TryRecvError::Empty) => {
                // Nothing is waiting. `Ack`s written by earlier calls are not flushed here, so
                // flush the writer before blocking on the channel.
                self.writer.flush()?;
                match rx.recv() {
                    Ok(item) => item,
                    Err(_) => return Err(lost_connection_error()),
                }
            }
        };

        match queued? {
            Some(data) => {
                self.writer
                    .write_stream_message(StreamMessage::Ack(self.id))?;
                T::try_from(data).map(Some)
            }
            None => {
                // Explicit end of stream: forget the channel so later calls return `Ok(None)`.
                self.receiver = None;
                Ok(None)
            }
        }
    }
}
impl<T, W> Iterator for StreamReader<T, W>
where
    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
    W: WriteStreamMessage,
{
    type Item = T;

    /// Yields the next value of the stream.
    ///
    /// An error from `recv` is yielded once, as an item built with
    /// [`FromShellError::from_shell_error`]. After that the iterator only returns `None`.
    fn next(&mut self) -> Option<T> {
        self.recv().unwrap_or_else(|err| {
            // Drop the channel so the same error is not reported again.
            self.receiver = None;
            Some(T::from_shell_error(err))
        })
    }
}
// Sound because `next` only returns `None` when `recv` returns `Ok(None)`. That happens
// either when `receiver` is already `None`, or when the end message is received and
// `receiver` is cleared. Every later call then also returns `Ok(None)`.
impl<T, W> FusedIterator for StreamReader<T, W>
where
    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
    W: WriteStreamMessage,
{
}
impl<T, W> Drop for StreamReader<T, W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
if let Err(err) = self
.writer
.write_stream_message(StreamMessage::Drop(self.id))
.and_then(|_| self.writer.flush())
{
log::warn!("Failed to send message to drop stream: {err}");
}
}
}
/// Builds a value of `Self` out of a [`ShellError`].
///
/// The `Iterator` impl of [`StreamReader`] uses this to yield an error as a regular item.
pub trait FromShellError {
    /// Wraps `err` in a value of `Self`.
    fn from_shell_error(err: ShellError) -> Self;
}
impl FromShellError for Value {
fn from_shell_error(err: ShellError) -> Self {
Value::error(err, Span::unknown())
}
}
impl<T> FromShellError for Result<T, ShellError> {
fn from_shell_error(err: ShellError) -> Self {
Err(err)
}
}
/// Writes the values of one outgoing stream, with flow control.
///
/// Created by [`StreamManagerHandle::write_stream`]. After writing a data message, the writer
/// blocks while the number of unacknowledged messages is at or above the high pressure mark,
/// unless the reader has dropped the stream. Dropping the writer ends the stream.
#[derive(Debug)]
pub struct StreamWriter<W: WriteStreamMessage> {
    /// The id of the stream being written.
    id: StreamId,
    /// Flow-control state. The [`StreamManager`] holds a weak reference to it and updates it
    /// when `Ack` or `Drop` messages arrive.
    signal: Arc<StreamWriterSignal>,
    /// Destination for `Data` and `End` messages.
    writer: W,
    /// Set once `end` has been called; later writes are rejected.
    ended: bool,
}
impl<W> StreamWriter<W>
where
    W: WriteStreamMessage,
{
    /// Builds a writer for stream `id` that has not ended yet.
    fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
        Self {
            ended: false,
            id,
            signal,
            writer,
        }
    }

    /// Returns `true` if the reader has dropped the stream.
    pub fn is_dropped(&self) -> Result<bool, ShellError> {
        self.signal.is_dropped()
    }

    /// Writes one data message and flushes it.
    ///
    /// If that brings the number of unacknowledged messages up to the high pressure mark,
    /// this blocks until enough acknowledgements arrive or the stream is dropped.
    ///
    /// # Errors
    /// Fails if the stream has already ended, if writing or flushing fails, or if the
    /// flow-control bookkeeping fails.
    pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
        if self.ended {
            return Err(ShellError::Generic(
                GenericError::new_internal(
                    "Wrote to a stream after it ended",
                    format!(
                        "tried to write to stream {} after it was already ended",
                        self.id
                    ),
                )
                .with_help("this may be a bug in the nu-plugin crate"),
            ));
        }

        let message = StreamMessage::Data(self.id, data.into());
        self.writer.write_stream_message(message)?;
        // Flush each data message (presumably so it reaches the other side promptly).
        self.writer.flush()?;

        let below_mark = self.signal.notify_sent()?;
        if below_mark {
            Ok(())
        } else {
            self.signal.wait_for_drain()
        }
    }

    /// Writes every item of `data`, stopping early if the stream is dropped.
    ///
    /// The dropped flag is checked before starting and again after each item is produced.
    /// Returns `Ok(false)` if the stream was found dropped, and `Ok(true)` once every item
    /// has been written.
    pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
    where
        T: Into<StreamData>,
    {
        if self.is_dropped()? {
            return Ok(false);
        }
        for value in data {
            // Producing `value` may have taken a while, so look again before writing it.
            if self.is_dropped()? {
                return Ok(false);
            }
            self.write(value)?;
        }
        Ok(true)
    }

    /// Writes `End` for the stream and flushes, on the first call only.
    ///
    /// Later calls do nothing and return `Ok(())`. The stream counts as ended even if
    /// writing the `End` message fails, because `ended` is set first.
    pub fn end(&mut self) -> Result<(), ShellError> {
        if self.ended {
            return Ok(());
        }
        self.ended = true;
        self.writer
            .write_stream_message(StreamMessage::End(self.id))?;
        self.writer.flush()
    }
}
impl<W> Drop for StreamWriter<W>
where
    W: WriteStreamMessage,
{
    /// Ends the stream if that has not been done explicitly. Failures are logged.
    fn drop(&mut self) {
        if let Some(err) = self.end().err() {
            log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
        }
    }
}
/// Flow-control state for one outgoing stream.
///
/// Shared between its [`StreamWriter`], which holds an `Arc`, and the [`StreamManager`],
/// which holds a `Weak`.
#[derive(Debug)]
pub struct StreamWriterSignal {
    /// Guards the counters and the dropped flag.
    mutex: Mutex<StreamWriterSignalState>,
    /// Notified when the stream is marked dropped or a message is acknowledged, waking a
    /// writer blocked in `wait_for_drain`.
    change_cond: Condvar,
}
/// The state protected by the mutex of [`StreamWriterSignal`].
#[derive(Debug)]
pub struct StreamWriterSignalState {
    /// Set once the stream has been dropped. A blocked writer then stops waiting.
    dropped: bool,
    /// Messages sent minus acknowledgements received.
    ///
    /// Only `i32` overflow/underflow is rejected, so the value can drop below zero.
    /// NOTE(review): presumably this tolerates an `Ack` being processed before the writer
    /// records the matching send; confirm.
    unacknowledged: i32,
    /// The writer waits once `unacknowledged` reaches this value. Always positive, because
    /// `StreamWriterSignal::new` asserts it.
    high_pressure_mark: i32,
}
impl StreamWriterSignal {
    /// Creates a signal with no unacknowledged messages that is not dropped.
    ///
    /// # Panics
    /// Panics if `high_pressure_mark` is not positive.
    fn new(high_pressure_mark: i32) -> StreamWriterSignal {
        assert!(high_pressure_mark > 0);
        let initial = StreamWriterSignalState {
            dropped: false,
            unacknowledged: 0,
            high_pressure_mark,
        };
        StreamWriterSignal {
            mutex: Mutex::new(initial),
            change_cond: Condvar::new(),
        }
    }

    /// The error reported when the mutex was poisoned by a panicking thread.
    fn poisoned() -> ShellError {
        ShellError::NushellFailed {
            msg: "StreamWriterSignal mutex poisoned due to panic".into(),
        }
    }

    /// Locks the state, turning mutex poisoning into a [`ShellError`].
    fn lock(&self) -> Result<MutexGuard<'_, StreamWriterSignalState>, ShellError> {
        self.mutex.lock().map_err(|_| Self::poisoned())
    }

    /// Returns whether the stream has been marked as dropped.
    pub fn is_dropped(&self) -> Result<bool, ShellError> {
        self.lock().map(|state| state.dropped)
    }

    /// Marks the stream as dropped and wakes every waiting writer.
    pub fn set_dropped(&self) -> Result<(), ShellError> {
        let mut guard = self.lock()?;
        guard.dropped = true;
        self.change_cond.notify_all();
        Ok(())
    }

    /// Records that one more message has been sent.
    ///
    /// Returns `true` if the writer may continue. Returns `false` if the unacknowledged
    /// count has reached the high pressure mark; the writer should then call
    /// [`wait_for_drain`](Self::wait_for_drain).
    pub fn notify_sent(&self) -> Result<bool, ShellError> {
        let mut guard = self.lock()?;
        let Some(pending) = guard.unacknowledged.checked_add(1) else {
            return Err(ShellError::NushellFailed {
                msg: "Overflow in counter: too many unacknowledged messages".into(),
            });
        };
        guard.unacknowledged = pending;
        Ok(pending < guard.high_pressure_mark)
    }

    /// Blocks until the stream is dropped or the unacknowledged count falls below the high
    /// pressure mark. Returns immediately if either condition already holds.
    pub fn wait_for_drain(&self) -> Result<(), ShellError> {
        let guard = self.lock()?;
        // `wait_while` re-checks the predicate after every wakeup, so spurious wakeups are
        // handled.
        let _guard = self
            .change_cond
            .wait_while(guard, |state| {
                !state.dropped && state.unacknowledged >= state.high_pressure_mark
            })
            .map_err(|_| Self::poisoned())?;
        Ok(())
    }

    /// Records one acknowledged message and wakes a waiting writer.
    ///
    /// Only `i32` underflow is an error, so the count may become negative.
    pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
        let mut guard = self.lock()?;
        let Some(pending) = guard.unacknowledged.checked_sub(1) else {
            return Err(ShellError::NushellFailed {
                msg: "Underflow in counter: too many message acknowledgements".into(),
            });
        };
        guard.unacknowledged = pending;
        self.change_cond.notify_one();
        Ok(())
    }
}
/// A destination for stream protocol messages.
///
/// NOTE(review): code in this file calls `flush` after writes that must reach the other side
/// (data, end, drop) and before a reader blocks. This suggests implementations may buffer
/// written messages until `flush`; confirm against implementors.
pub trait WriteStreamMessage {
    /// Writes one stream message.
    fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
    /// Flushes previously written messages; exact semantics are up to the implementation.
    fn flush(&mut self) -> Result<(), ShellError>;
}
/// Registry of the streams known to a [`StreamManager`].
#[derive(Debug, Default)]
struct StreamManagerState {
    /// Incoming streams: the sending half of each [`StreamReader`]'s channel.
    ///
    /// An entry is removed when `End` arrives for its stream.
    reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
    /// Outgoing streams: weak references to each [`StreamWriter`]'s flow-control signal.
    ///
    /// An entry is removed when:
    /// - `Drop` arrives for its stream;
    /// - an `Ack` finds the writer gone;
    /// - `write_stream` prunes dead entries;
    /// - `drop_all_writers` clears the whole map.
    writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
}
impl StreamManagerState {
    /// Locks `state`, turning mutex poisoning into a [`ShellError`].
    fn lock(
        state: &Mutex<StreamManagerState>,
    ) -> Result<MutexGuard<'_, StreamManagerState>, ShellError> {
        let Ok(guard) = state.lock() else {
            return Err(ShellError::NushellFailed {
                msg: "StreamManagerState mutex poisoned due to a panic".into(),
            });
        };
        Ok(guard)
    }
}
/// Routes incoming stream messages to the matching [`StreamReader`]s and [`StreamWriter`]s.
///
/// Readers and writers are registered through a [`StreamManagerHandle`] obtained from
/// [`get_handle`](StreamManager::get_handle). Dropping the manager marks every registered
/// writer as dropped.
#[derive(Debug)]
pub struct StreamManager {
    /// Shared registry. Handles hold only a weak reference to it.
    state: Arc<Mutex<StreamManagerState>>,
}
impl StreamManager {
pub fn new() -> StreamManager {
StreamManager {
state: Default::default(),
}
}
fn lock(&self) -> Result<MutexGuard<'_, StreamManagerState>, ShellError> {
StreamManagerState::lock(&self.state)
}
pub fn get_handle(&self) -> StreamManagerHandle {
StreamManagerHandle {
state: Arc::downgrade(&self.state),
}
}
pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
let mut state = self.lock()?;
match message {
StreamMessage::Data(id, data) => {
if let Some(sender) = state.reading_streams.get(&id) {
let _ = sender.send(Ok(Some(data)));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received Data for unknown stream {id}"),
})
}
}
StreamMessage::End(id) => {
if let Some(sender) = state.reading_streams.remove(&id) {
let _ = sender.send(Ok(None));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received End for unknown stream {id}"),
})
}
}
StreamMessage::Drop(id) => {
if let Some(signal) = state.writing_streams.remove(&id)
&& let Some(signal) = signal.upgrade()
{
signal.set_dropped()?;
}
Ok(())
}
StreamMessage::Ack(id) => {
if let Some(signal) = state.writing_streams.get(&id) {
if let Some(signal) = signal.upgrade() {
signal.notify_acknowledged()?;
} else {
state.writing_streams.remove(&id);
}
}
Ok(())
}
}
}
pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
let state = self.lock()?;
for channel in state.reading_streams.values() {
let _ = channel.send(Err(error.clone()));
}
Ok(())
}
fn drop_all_writers(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
let writers = std::mem::take(&mut state.writing_streams);
for (_, signal) in writers {
if let Some(signal) = signal.upgrade() {
let _ = signal.set_dropped();
}
}
Ok(())
}
}
impl Default for StreamManager {
fn default() -> Self {
Self::new()
}
}
impl Drop for StreamManager {
    /// Marks all registered writers as dropped, so a writer blocked in `wait_for_drain` is
    /// released. Failures are logged.
    fn drop(&mut self) {
        if let Some(err) = self.drop_all_writers().err() {
            log::warn!("error during Drop for StreamManager: {err}")
        }
    }
}
/// A cloneable handle for registering stream readers and writers with a [`StreamManager`].
///
/// It holds only a weak reference, so its operations fail once the manager has been dropped.
#[derive(Debug, Clone)]
pub struct StreamManagerHandle {
    /// Weak reference to the manager's shared registry.
    state: Weak<Mutex<StreamManagerState>>,
}
impl StreamManagerHandle {
fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
where
F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
{
let upgraded = self
.state
.upgrade()
.ok_or_else(|| ShellError::NushellFailed {
msg: "StreamManager is no longer alive".into(),
})?;
let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})?;
f(guard)
}
pub fn read_stream<T, W>(
&self,
id: StreamId,
writer: W,
) -> Result<StreamReader<T, W>, ShellError>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
let (tx, rx) = mpsc::channel();
self.with_lock(|mut state| {
if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
e.insert(tx);
Ok(())
} else {
Err(ShellError::Generic(
GenericError::new_internal(
format!("Failed to acquire reader for stream {id}"),
"tried to get a reader for a stream that's already being read",
)
.with_help("this may be a bug in the nu-plugin crate"),
))
}
})?;
Ok(StreamReader::new(id, rx, writer))
}
pub fn write_stream<W>(
&self,
id: StreamId,
writer: W,
high_pressure_mark: i32,
) -> Result<StreamWriter<W>, ShellError>
where
W: WriteStreamMessage,
{
let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
self.with_lock(|mut state| {
state
.writing_streams
.retain(|_, signal| signal.strong_count() > 0);
if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
e.insert(Arc::downgrade(&signal));
Ok(())
} else {
Err(ShellError::Generic(
GenericError::new_internal(
format!("Failed to acquire writer for stream {id}"),
"tried to get a writer for a stream that's already being written",
)
.with_help("this may be a bug in the nu-plugin crate"),
))
}
})?;
Ok(StreamWriter::new(id, signal, writer))
}
}