// nu_plugin_core/interface/stream/mod.rs

1use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
2use nu_protocol::{ShellError, Span, Value, shell_error::generic::GenericError};
3use std::{
4    collections::{BTreeMap, btree_map},
5    iter::FusedIterator,
6    marker::PhantomData,
7    sync::{Arc, Condvar, Mutex, MutexGuard, Weak, mpsc},
8};
9
10#[cfg(test)]
11mod tests;
12
/// Receives messages from a stream read from input by a [`StreamManager`].
///
/// The receiver reads messages of type `Result<Option<StreamData>, ShellError>` from the
/// channel, which is managed by a [`StreamManager`]. Signalling for end-of-stream is explicit
/// through `Ok(None)`.
///
/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
/// and all further calls to `next()` return `None`.
///
/// The type `T` must implement `TryFrom<StreamData>` to convert received data to the correct
/// type. To use the reader as an iterator, `T` must additionally implement [`FromShellError`],
/// so that errors in the stream can be represented as items.
///
/// For each data message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
/// it sends [`StreamMessage::Drop`].
#[derive(Debug)]
pub struct StreamReader<T, W>
where
    W: WriteStreamMessage,
{
    /// Identifier used in the `Ack` / `Drop` messages sent back through `writer`.
    id: StreamId,
    /// Channel the stream's messages arrive on; `None` once the stream has ended (or failed
    /// while iterating).
    receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
    /// Sink for the control messages (`Ack`, `Drop`) this reader emits.
    writer: W,
    /// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
    /// even though we're actually receiving dynamic data.
    marker: PhantomData<fn() -> T>,
}
39
40impl<T, W> StreamReader<T, W>
41where
42    T: TryFrom<StreamData, Error = ShellError>,
43    W: WriteStreamMessage,
44{
45    /// Create a new StreamReader from parts
46    fn new(
47        id: StreamId,
48        receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
49        writer: W,
50    ) -> StreamReader<T, W> {
51        StreamReader {
52            id,
53            receiver: Some(receiver),
54            writer,
55            marker: PhantomData,
56        }
57    }
58
59    /// Receive a message from the channel, or return an error if:
60    ///
61    /// * the channel couldn't be received from
62    /// * an error was sent on the channel
63    /// * the message received couldn't be converted to `T`
64    pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
65        let connection_lost = || {
66            ShellError::Generic(GenericError::new_internal(
67                "Stream ended unexpectedly",
68                "connection lost before explicit end of stream",
69            ))
70        };
71
72        if let Some(ref rx) = self.receiver {
73            // Try to receive a message first
74            let msg = match rx.try_recv() {
75                Ok(msg) => msg?,
76                Err(mpsc::TryRecvError::Empty) => {
77                    // The receiver doesn't have any messages waiting for us. It's possible that the
78                    // other side hasn't seen our acknowledgements. Let's flush the writer and then
79                    // wait
80                    self.writer.flush()?;
81                    rx.recv().map_err(|_| connection_lost())??
82                }
83                Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
84            };
85
86            if let Some(data) = msg {
87                // Acknowledge the message
88                self.writer
89                    .write_stream_message(StreamMessage::Ack(self.id))?;
90                // Try to convert it into the correct type
91                Ok(Some(data.try_into()?))
92            } else {
93                // Remove the receiver, so that future recv() calls always return Ok(None)
94                self.receiver = None;
95                Ok(None)
96            }
97        } else {
98            // Closed already
99            Ok(None)
100        }
101    }
102}
103
104impl<T, W> Iterator for StreamReader<T, W>
105where
106    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
107    W: WriteStreamMessage,
108{
109    type Item = T;
110
111    fn next(&mut self) -> Option<T> {
112        // Converting the error to the value here makes the implementation a lot easier
113        match self.recv() {
114            Ok(option) => option,
115            Err(err) => {
116                // Drop the receiver so we don't keep returning errors
117                self.receiver = None;
118                Some(T::from_shell_error(err))
119            }
120        }
121    }
122}
123
// Guaranteed not to return anything after the end: once `next()` has seen end-of-stream or an
// error, the receiver is cleared and `recv()` keeps answering `Ok(None)`.
impl<T, W> FusedIterator for StreamReader<T, W>
where
    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
    W: WriteStreamMessage,
{
}
131
132impl<T, W> Drop for StreamReader<T, W>
133where
134    W: WriteStreamMessage,
135{
136    fn drop(&mut self) {
137        if let Err(err) = self
138            .writer
139            .write_stream_message(StreamMessage::Drop(self.id))
140            .and_then(|_| self.writer.flush())
141        {
142            log::warn!("Failed to send message to drop stream: {err}");
143        }
144    }
145}
146
/// Values that can contain a `ShellError` to signal an error has occurred.
///
/// Required of the item type when a [`StreamReader`] is used as an iterator, so that a failed
/// receive can still be yielded as an item.
pub trait FromShellError {
    /// Build a value of this type that carries `err`.
    fn from_shell_error(err: ShellError) -> Self;
}
151
// For List streams.
// Note: Span::unknown() is unavoidable here because this is called from Iterator::next(),
// which has no span context. The ShellError itself carries its own span information.
impl FromShellError for Value {
    /// Wrap `err` in an error value with an unknown span.
    fn from_shell_error(err: ShellError) -> Self {
        Value::error(err, Span::unknown())
    }
}
160
// For Raw streams, mostly.
impl<T> FromShellError for Result<T, ShellError> {
    /// The error simply becomes the `Err` variant.
    fn from_shell_error(err: ShellError) -> Self {
        Err(err)
    }
}
167
/// Writes messages to a stream, with flow control.
///
/// The `signal` contained is shared (weakly) with the [`StreamManager`] state, which updates it
/// when `Ack` and `Drop` messages for this stream are handled. The writer consults it to decide
/// when to pause and whether the consumer is still interested.
#[derive(Debug)]
pub struct StreamWriter<W: WriteStreamMessage> {
    /// Identifier placed in every `Data` / `End` message written.
    id: StreamId,
    /// Flow-control and drop state for this stream.
    signal: Arc<StreamWriterSignal>,
    /// Sink that the stream messages are written to.
    writer: W,
    /// Set once `End` has been sent (or attempted); further writes are rejected.
    ended: bool,
}
178
179impl<W> StreamWriter<W>
180where
181    W: WriteStreamMessage,
182{
183    fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
184        StreamWriter {
185            id,
186            signal,
187            writer,
188            ended: false,
189        }
190    }
191
192    /// Check if the stream was dropped from the other end. Recommended to do this before calling
193    /// [`.write()`](Self::write), especially in a loop.
194    pub fn is_dropped(&self) -> Result<bool, ShellError> {
195        self.signal.is_dropped()
196    }
197
198    /// Write a single piece of data to the stream.
199    ///
200    /// Error if something failed with the write, or if [`.end()`](Self::end) was already called
201    /// previously.
202    pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
203        if !self.ended {
204            self.writer
205                .write_stream_message(StreamMessage::Data(self.id, data.into()))?;
206            // Flush after each data message to ensure they do predictably appear on the other side
207            // when they're generated
208            //
209            // TODO: make the buffering configurable, as this is a factor for performance
210            self.writer.flush()?;
211            // This implements flow control, so we don't write too many messages:
212            if !self.signal.notify_sent()? {
213                self.signal.wait_for_drain()
214            } else {
215                Ok(())
216            }
217        } else {
218            Err(ShellError::Generic(
219                GenericError::new_internal(
220                    "Wrote to a stream after it ended",
221                    format!(
222                        "tried to write to stream {} after it was already ended",
223                        self.id
224                    ),
225                )
226                .with_help("this may be a bug in the nu-plugin crate"),
227            ))
228        }
229    }
230
231    /// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
232    /// still call [`.end()`](Self::end).
233    ///
234    /// If the stream is dropped from the other end, the iterator will not be fully consumed, and
235    /// writing will terminate.
236    ///
237    /// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
238    /// the stream from the other side.
239    pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
240    where
241        T: Into<StreamData>,
242    {
243        // Check before starting
244        if self.is_dropped()? {
245            return Ok(false);
246        }
247
248        for item in data {
249            // Check again after each item is consumed from the iterator, just in case the iterator
250            // takes a while to produce a value
251            if self.is_dropped()? {
252                return Ok(false);
253            }
254            self.write(item)?;
255        }
256        Ok(true)
257    }
258
259    /// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
260    /// error.
261    pub fn end(&mut self) -> Result<(), ShellError> {
262        if !self.ended {
263            // Set the flag first so we don't double-report in the Drop
264            self.ended = true;
265            self.writer
266                .write_stream_message(StreamMessage::End(self.id))?;
267            self.writer.flush()
268        } else {
269            Ok(())
270        }
271    }
272}
273
274impl<W> Drop for StreamWriter<W>
275where
276    W: WriteStreamMessage,
277{
278    fn drop(&mut self) {
279        // Make sure we ended the stream
280        if let Err(err) = self.end() {
281            log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
282        }
283    }
284}
285
/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
/// A key part of managing stream lifecycle and flow control.
#[derive(Debug)]
pub struct StreamWriterSignal {
    /// The guarded counters and flags.
    mutex: Mutex<StreamWriterSignalState>,
    /// Notified when the stream is dropped or a message is acknowledged, to wake a writer
    /// blocked in `wait_for_drain()`.
    change_cond: Condvar,
}
293
/// The data protected by the mutex inside [`StreamWriterSignal`].
#[derive(Debug)]
pub struct StreamWriterSignalState {
    /// Stream has been dropped and consumer is no longer interested in any messages.
    dropped: bool,
    /// Number of messages that have been sent without acknowledgement.
    // NOTE(review): signed, so the count can dip below zero — presumably to tolerate an
    // acknowledgement being processed before the matching `notify_sent()`; confirm.
    unacknowledged: i32,
    /// Max number of messages to send before waiting for acknowledgement.
    high_pressure_mark: i32,
}
303
304impl StreamWriterSignal {
305    /// Create a new signal.
306    ///
307    /// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
308    /// `notify_acknowledge()` is called by another thread enough times to bring the number of
309    /// unacknowledged sent messages below that threshold.
310    fn new(high_pressure_mark: i32) -> StreamWriterSignal {
311        assert!(high_pressure_mark > 0);
312
313        StreamWriterSignal {
314            mutex: Mutex::new(StreamWriterSignalState {
315                dropped: false,
316                unacknowledged: 0,
317                high_pressure_mark,
318            }),
319            change_cond: Condvar::new(),
320        }
321    }
322
323    fn lock(&self) -> Result<MutexGuard<'_, StreamWriterSignalState>, ShellError> {
324        self.mutex.lock().map_err(|_| ShellError::NushellFailed {
325            msg: "StreamWriterSignal mutex poisoned due to panic".into(),
326        })
327    }
328
329    /// True if the stream was dropped and the consumer is no longer interested in it. Indicates
330    /// that no more messages should be sent, other than `End`.
331    pub fn is_dropped(&self) -> Result<bool, ShellError> {
332        Ok(self.lock()?.dropped)
333    }
334
335    /// Notify the writers that the stream has been dropped, so they can stop writing.
336    pub fn set_dropped(&self) -> Result<(), ShellError> {
337        let mut state = self.lock()?;
338        state.dropped = true;
339        // Unblock the writers so they can terminate
340        self.change_cond.notify_all();
341        Ok(())
342    }
343
344    /// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
345    /// or `Ok(false)` if the high pressure mark has been reached and
346    /// [`.wait_for_drain()`](Self::wait_for_drain) should be called to block.
347    pub fn notify_sent(&self) -> Result<bool, ShellError> {
348        let mut state = self.lock()?;
349        state.unacknowledged =
350            state
351                .unacknowledged
352                .checked_add(1)
353                .ok_or_else(|| ShellError::NushellFailed {
354                    msg: "Overflow in counter: too many unacknowledged messages".into(),
355                })?;
356
357        Ok(state.unacknowledged < state.high_pressure_mark)
358    }
359
360    /// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
361    pub fn wait_for_drain(&self) -> Result<(), ShellError> {
362        let mut state = self.lock()?;
363        while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
364            state = self
365                .change_cond
366                .wait(state)
367                .map_err(|_| ShellError::NushellFailed {
368                    msg: "StreamWriterSignal mutex poisoned due to panic".into(),
369                })?;
370        }
371        Ok(())
372    }
373
374    /// Notify the writers that a message has been acknowledged, so they can continue to write
375    /// if they were waiting.
376    pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
377        let mut state = self.lock()?;
378        state.unacknowledged =
379            state
380                .unacknowledged
381                .checked_sub(1)
382                .ok_or_else(|| ShellError::NushellFailed {
383                    msg: "Underflow in counter: too many message acknowledgements".into(),
384                })?;
385        // Unblock the writer
386        self.change_cond.notify_one();
387        Ok(())
388    }
389}
390
/// A sink for a [`StreamMessage`]
///
/// Readers use it to send `Ack` / `Drop` control messages, writers to send `Data` / `End`.
pub trait WriteStreamMessage {
    /// Hand a single message to the sink.
    fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
    /// Flush the sink — presumably pushing any buffered messages through to the other side;
    /// the exact buffering behaviour is up to the implementation.
    fn flush(&mut self) -> Result<(), ShellError>;
}
396
/// Registry of the streams currently known to a [`StreamManager`].
#[derive(Debug, Default)]
struct StreamManagerState {
    /// Sending halves of the channels feeding each registered reader, keyed by stream id.
    reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
    /// Weak references to the flow-control signal of each registered writer, keyed by stream
    /// id. Weak, so a writer that has gone away doesn't keep its signal alive.
    writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
}
402
403impl StreamManagerState {
404    /// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
405    fn lock(
406        state: &Mutex<StreamManagerState>,
407    ) -> Result<MutexGuard<'_, StreamManagerState>, ShellError> {
408        state.lock().map_err(|_| ShellError::NushellFailed {
409            msg: "StreamManagerState mutex poisoned due to a panic".into(),
410        })
411    }
412}
413
/// Owns the registry of active streams and routes incoming [`StreamMessage`]s to them.
///
/// Streams are registered through a [`StreamManagerHandle`], which only holds a weak reference
/// to this state.
#[derive(Debug)]
pub struct StreamManager {
    /// Shared, lock-protected stream registry.
    state: Arc<Mutex<StreamManagerState>>,
}
418
419impl StreamManager {
420    /// Create a new StreamManager.
421    pub fn new() -> StreamManager {
422        StreamManager {
423            state: Default::default(),
424        }
425    }
426
427    fn lock(&self) -> Result<MutexGuard<'_, StreamManagerState>, ShellError> {
428        StreamManagerState::lock(&self.state)
429    }
430
431    /// Create a new handle to the StreamManager for registering streams.
432    pub fn get_handle(&self) -> StreamManagerHandle {
433        StreamManagerHandle {
434            state: Arc::downgrade(&self.state),
435        }
436    }
437
438    /// Process a stream message, and update internal state accordingly.
439    pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
440        let mut state = self.lock()?;
441        match message {
442            StreamMessage::Data(id, data) => {
443                if let Some(sender) = state.reading_streams.get(&id) {
444                    // We should ignore the error on send. This just means the reader has dropped,
445                    // but it will have sent a Drop message to the other side, and we will receive
446                    // an End message at which point we can remove the channel.
447                    let _ = sender.send(Ok(Some(data)));
448                    Ok(())
449                } else {
450                    Err(ShellError::PluginFailedToDecode {
451                        msg: format!("received Data for unknown stream {id}"),
452                    })
453                }
454            }
455            StreamMessage::End(id) => {
456                if let Some(sender) = state.reading_streams.remove(&id) {
457                    // We should ignore the error on the send, because the reader might have dropped
458                    // already
459                    let _ = sender.send(Ok(None));
460                    Ok(())
461                } else {
462                    Err(ShellError::PluginFailedToDecode {
463                        msg: format!("received End for unknown stream {id}"),
464                    })
465                }
466            }
467            StreamMessage::Drop(id) => {
468                if let Some(signal) = state.writing_streams.remove(&id)
469                    && let Some(signal) = signal.upgrade()
470                {
471                    // This will wake blocked writers so they can stop writing, so it's ok
472                    signal.set_dropped()?;
473                }
474                // It's possible that the stream has already finished writing and we don't have it
475                // anymore, so we fall through to Ok
476                Ok(())
477            }
478            StreamMessage::Ack(id) => {
479                if let Some(signal) = state.writing_streams.get(&id) {
480                    if let Some(signal) = signal.upgrade() {
481                        // This will wake up a blocked writer
482                        signal.notify_acknowledged()?;
483                    } else {
484                        // We know it doesn't exist, so might as well remove it
485                        state.writing_streams.remove(&id);
486                    }
487                }
488                // It's possible that the stream has already finished writing and we don't have it
489                // anymore, so we fall through to Ok
490                Ok(())
491            }
492        }
493    }
494
495    /// Broadcast an error to all stream readers. This is useful for error propagation.
496    pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
497        let state = self.lock()?;
498        for channel in state.reading_streams.values() {
499            // Ignore send errors.
500            let _ = channel.send(Err(error.clone()));
501        }
502        Ok(())
503    }
504
505    // If the `StreamManager` is dropped, we should let all of the stream writers know that they
506    // won't be able to write anymore. We don't need to do anything about the readers though
507    // because they'll know when the `Sender` is dropped automatically
508    fn drop_all_writers(&self) -> Result<(), ShellError> {
509        let mut state = self.lock()?;
510        let writers = std::mem::take(&mut state.writing_streams);
511        for (_, signal) in writers {
512            if let Some(signal) = signal.upgrade() {
513                // more important that we send to all than handling an error
514                let _ = signal.set_dropped();
515            }
516        }
517        Ok(())
518    }
519}
520
impl Default for StreamManager {
    /// Same as [`StreamManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
526
527impl Drop for StreamManager {
528    fn drop(&mut self) {
529        if let Err(err) = self.drop_all_writers() {
530            log::warn!("error during Drop for StreamManager: {err}")
531        }
532    }
533}
534
/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
///
/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
/// a [`StreamWriter`].
///
/// The handle is cheap to clone and does not keep the manager's state alive; operations fail
/// with an error once the [`StreamManager`] has been dropped.
#[derive(Debug, Clone)]
pub struct StreamManagerHandle {
    /// Weak reference to the manager's registry.
    state: Weak<Mutex<StreamManagerState>>,
}
543
544impl StreamManagerHandle {
545    /// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
546    /// first try to upgrade to a strong reference and then lock. This function wraps those two
547    /// operations together, handling errors appropriately.
548    fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
549    where
550        F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
551    {
552        let upgraded = self
553            .state
554            .upgrade()
555            .ok_or_else(|| ShellError::NushellFailed {
556                msg: "StreamManager is no longer alive".into(),
557            })?;
558        let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
559            msg: "StreamManagerState mutex poisoned due to a panic".into(),
560        })?;
561        f(guard)
562    }
563
564    /// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
565    /// on the values received. A [`StreamMessage`] writer is required for writing control messages
566    /// back to the producer.
567    pub fn read_stream<T, W>(
568        &self,
569        id: StreamId,
570        writer: W,
571    ) -> Result<StreamReader<T, W>, ShellError>
572    where
573        T: TryFrom<StreamData, Error = ShellError>,
574        W: WriteStreamMessage,
575    {
576        let (tx, rx) = mpsc::channel();
577        self.with_lock(|mut state| {
578            // Must be exclusive
579            if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
580                e.insert(tx);
581                Ok(())
582            } else {
583                Err(ShellError::Generic(
584                    GenericError::new_internal(
585                        format!("Failed to acquire reader for stream {id}"),
586                        "tried to get a reader for a stream that's already being read",
587                    )
588                    .with_help("this may be a bug in the nu-plugin crate"),
589                ))
590            }
591        })?;
592        Ok(StreamReader::new(id, rx, writer))
593    }
594
595    /// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
596    /// data to the stream.
597    ///
598    /// The `high_pressure_mark` value controls how many messages can be written without receiving
599    /// an acknowledgement before any further attempts to write will wait for the consumer to
600    /// acknowledge them. This prevents overwhelming the reader.
601    pub fn write_stream<W>(
602        &self,
603        id: StreamId,
604        writer: W,
605        high_pressure_mark: i32,
606    ) -> Result<StreamWriter<W>, ShellError>
607    where
608        W: WriteStreamMessage,
609    {
610        let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
611        self.with_lock(|mut state| {
612            // Remove dead writing streams
613            state
614                .writing_streams
615                .retain(|_, signal| signal.strong_count() > 0);
616            // Must be exclusive
617            if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
618                e.insert(Arc::downgrade(&signal));
619                Ok(())
620            } else {
621                Err(ShellError::Generic(
622                    GenericError::new_internal(
623                        format!("Failed to acquire writer for stream {id}"),
624                        "tried to get a writer for a stream that's already being written",
625                    )
626                    .with_help("this may be a bug in the nu-plugin crate"),
627                ))
628            }
629        })?;
630        Ok(StreamWriter::new(id, signal, writer))
631    }
632}