nu_plugin_core/interface/stream/mod.rs

1use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
2use nu_protocol::{ShellError, Span, Value};
3use std::{
4    collections::{BTreeMap, btree_map},
5    iter::FusedIterator,
6    marker::PhantomData,
7    sync::{Arc, Condvar, Mutex, MutexGuard, Weak, mpsc},
8};
9
10#[cfg(test)]
11mod tests;
12
13/// Receives messages from a stream read from input by a [`StreamManager`].
14///
15/// The receiver reads for messages of type `Result<Option<StreamData>, ShellError>` from the
16/// channel, which is managed by a [`StreamManager`]. Signalling for end-of-stream is explicit
17/// through `Ok(Some)`.
18///
19/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
20/// and all further calls to `next()` return `None`.
21///
22/// The type `T` must implement [`FromShellError`], so that errors in the stream can be represented,
23/// and `TryFrom<StreamData>` to convert it to the correct type.
24///
25/// For each message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
26/// it sends [`StreamMessage::Drop`].
27#[derive(Debug)]
28pub struct StreamReader<T, W>
29where
30    W: WriteStreamMessage,
31{
32    id: StreamId,
33    receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
34    writer: W,
35    /// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
36    /// even though we're actually receiving dynamic data.
37    marker: PhantomData<fn() -> T>,
38}
39
40impl<T, W> StreamReader<T, W>
41where
42    T: TryFrom<StreamData, Error = ShellError>,
43    W: WriteStreamMessage,
44{
45    /// Create a new StreamReader from parts
46    fn new(
47        id: StreamId,
48        receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
49        writer: W,
50    ) -> StreamReader<T, W> {
51        StreamReader {
52            id,
53            receiver: Some(receiver),
54            writer,
55            marker: PhantomData,
56        }
57    }
58
59    /// Receive a message from the channel, or return an error if:
60    ///
61    /// * the channel couldn't be received from
62    /// * an error was sent on the channel
63    /// * the message received couldn't be converted to `T`
64    pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
65        let connection_lost = || ShellError::GenericError {
66            error: "Stream ended unexpectedly".into(),
67            msg: "connection lost before explicit end of stream".into(),
68            span: None,
69            help: None,
70            inner: vec![],
71        };
72
73        if let Some(ref rx) = self.receiver {
74            // Try to receive a message first
75            let msg = match rx.try_recv() {
76                Ok(msg) => msg?,
77                Err(mpsc::TryRecvError::Empty) => {
78                    // The receiver doesn't have any messages waiting for us. It's possible that the
79                    // other side hasn't seen our acknowledgements. Let's flush the writer and then
80                    // wait
81                    self.writer.flush()?;
82                    rx.recv().map_err(|_| connection_lost())??
83                }
84                Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
85            };
86
87            if let Some(data) = msg {
88                // Acknowledge the message
89                self.writer
90                    .write_stream_message(StreamMessage::Ack(self.id))?;
91                // Try to convert it into the correct type
92                Ok(Some(data.try_into()?))
93            } else {
94                // Remove the receiver, so that future recv() calls always return Ok(None)
95                self.receiver = None;
96                Ok(None)
97            }
98        } else {
99            // Closed already
100            Ok(None)
101        }
102    }
103}
104
105impl<T, W> Iterator for StreamReader<T, W>
106where
107    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
108    W: WriteStreamMessage,
109{
110    type Item = T;
111
112    fn next(&mut self) -> Option<T> {
113        // Converting the error to the value here makes the implementation a lot easier
114        match self.recv() {
115            Ok(option) => option,
116            Err(err) => {
117                // Drop the receiver so we don't keep returning errors
118                self.receiver = None;
119                Some(T::from_shell_error(err))
120            }
121        }
122    }
123}
124
125// Guaranteed not to return anything after the end
126impl<T, W> FusedIterator for StreamReader<T, W>
127where
128    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
129    W: WriteStreamMessage,
130{
131}
132
133impl<T, W> Drop for StreamReader<T, W>
134where
135    W: WriteStreamMessage,
136{
137    fn drop(&mut self) {
138        if let Err(err) = self
139            .writer
140            .write_stream_message(StreamMessage::Drop(self.id))
141            .and_then(|_| self.writer.flush())
142        {
143            log::warn!("Failed to send message to drop stream: {err}");
144        }
145    }
146}
147
148/// Values that can contain a `ShellError` to signal an error has occurred.
149pub trait FromShellError {
150    fn from_shell_error(err: ShellError) -> Self;
151}
152
153// For List streams.
154impl FromShellError for Value {
155    fn from_shell_error(err: ShellError) -> Self {
156        Value::error(err, Span::unknown())
157    }
158}
159
160// For Raw streams, mostly.
161impl<T> FromShellError for Result<T, ShellError> {
162    fn from_shell_error(err: ShellError) -> Self {
163        Err(err)
164    }
165}
166
167/// Writes messages to a stream, with flow control.
168///
169/// The `signal` contained
170#[derive(Debug)]
171pub struct StreamWriter<W: WriteStreamMessage> {
172    id: StreamId,
173    signal: Arc<StreamWriterSignal>,
174    writer: W,
175    ended: bool,
176}
177
178impl<W> StreamWriter<W>
179where
180    W: WriteStreamMessage,
181{
182    fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
183        StreamWriter {
184            id,
185            signal,
186            writer,
187            ended: false,
188        }
189    }
190
191    /// Check if the stream was dropped from the other end. Recommended to do this before calling
192    /// [`.write()`](Self::write), especially in a loop.
193    pub fn is_dropped(&self) -> Result<bool, ShellError> {
194        self.signal.is_dropped()
195    }
196
197    /// Write a single piece of data to the stream.
198    ///
199    /// Error if something failed with the write, or if [`.end()`](Self::end) was already called
200    /// previously.
201    pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
202        if !self.ended {
203            self.writer
204                .write_stream_message(StreamMessage::Data(self.id, data.into()))?;
205            // Flush after each data message to ensure they do predictably appear on the other side
206            // when they're generated
207            //
208            // TODO: make the buffering configurable, as this is a factor for performance
209            self.writer.flush()?;
210            // This implements flow control, so we don't write too many messages:
211            if !self.signal.notify_sent()? {
212                self.signal.wait_for_drain()
213            } else {
214                Ok(())
215            }
216        } else {
217            Err(ShellError::GenericError {
218                error: "Wrote to a stream after it ended".into(),
219                msg: format!(
220                    "tried to write to stream {} after it was already ended",
221                    self.id
222                ),
223                span: None,
224                help: Some("this may be a bug in the nu-plugin crate".into()),
225                inner: vec![],
226            })
227        }
228    }
229
230    /// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
231    /// still call [`.end()`](Self::end).
232    ///
233    /// If the stream is dropped from the other end, the iterator will not be fully consumed, and
234    /// writing will terminate.
235    ///
236    /// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
237    /// the stream from the other side.
238    pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
239    where
240        T: Into<StreamData>,
241    {
242        // Check before starting
243        if self.is_dropped()? {
244            return Ok(false);
245        }
246
247        for item in data {
248            // Check again after each item is consumed from the iterator, just in case the iterator
249            // takes a while to produce a value
250            if self.is_dropped()? {
251                return Ok(false);
252            }
253            self.write(item)?;
254        }
255        Ok(true)
256    }
257
258    /// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
259    /// error.
260    pub fn end(&mut self) -> Result<(), ShellError> {
261        if !self.ended {
262            // Set the flag first so we don't double-report in the Drop
263            self.ended = true;
264            self.writer
265                .write_stream_message(StreamMessage::End(self.id))?;
266            self.writer.flush()
267        } else {
268            Ok(())
269        }
270    }
271}
272
273impl<W> Drop for StreamWriter<W>
274where
275    W: WriteStreamMessage,
276{
277    fn drop(&mut self) {
278        // Make sure we ended the stream
279        if let Err(err) = self.end() {
280            log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
281        }
282    }
283}
284
285/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
286/// A key part of managing stream lifecycle and flow control.
287#[derive(Debug)]
288pub struct StreamWriterSignal {
289    mutex: Mutex<StreamWriterSignalState>,
290    change_cond: Condvar,
291}
292
/// The state protected by the mutex inside [`StreamWriterSignal`].
#[derive(Debug)]
pub struct StreamWriterSignalState {
    /// The consumer dropped the stream and no longer wants any messages.
    dropped: bool,
    /// How many sent messages are still waiting for an acknowledgement.
    unacknowledged: i32,
    /// Once `unacknowledged` reaches this value, writers wait for acknowledgements before
    /// sending more.
    high_pressure_mark: i32,
}
302
303impl StreamWriterSignal {
304    /// Create a new signal.
305    ///
306    /// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
307    /// `notify_acknowledge()` is called by another thread enough times to bring the number of
308    /// unacknowledged sent messages below that threshold.
309    fn new(high_pressure_mark: i32) -> StreamWriterSignal {
310        assert!(high_pressure_mark > 0);
311
312        StreamWriterSignal {
313            mutex: Mutex::new(StreamWriterSignalState {
314                dropped: false,
315                unacknowledged: 0,
316                high_pressure_mark,
317            }),
318            change_cond: Condvar::new(),
319        }
320    }
321
322    fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
323        self.mutex.lock().map_err(|_| ShellError::NushellFailed {
324            msg: "StreamWriterSignal mutex poisoned due to panic".into(),
325        })
326    }
327
328    /// True if the stream was dropped and the consumer is no longer interested in it. Indicates
329    /// that no more messages should be sent, other than `End`.
330    pub fn is_dropped(&self) -> Result<bool, ShellError> {
331        Ok(self.lock()?.dropped)
332    }
333
334    /// Notify the writers that the stream has been dropped, so they can stop writing.
335    pub fn set_dropped(&self) -> Result<(), ShellError> {
336        let mut state = self.lock()?;
337        state.dropped = true;
338        // Unblock the writers so they can terminate
339        self.change_cond.notify_all();
340        Ok(())
341    }
342
343    /// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
344    /// or `Ok(false)` if the high pressure mark has been reached and
345    /// [`.wait_for_drain()`](Self::wait_for_drain) should be called to block.
346    pub fn notify_sent(&self) -> Result<bool, ShellError> {
347        let mut state = self.lock()?;
348        state.unacknowledged =
349            state
350                .unacknowledged
351                .checked_add(1)
352                .ok_or_else(|| ShellError::NushellFailed {
353                    msg: "Overflow in counter: too many unacknowledged messages".into(),
354                })?;
355
356        Ok(state.unacknowledged < state.high_pressure_mark)
357    }
358
359    /// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
360    pub fn wait_for_drain(&self) -> Result<(), ShellError> {
361        let mut state = self.lock()?;
362        while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
363            state = self
364                .change_cond
365                .wait(state)
366                .map_err(|_| ShellError::NushellFailed {
367                    msg: "StreamWriterSignal mutex poisoned due to panic".into(),
368                })?;
369        }
370        Ok(())
371    }
372
373    /// Notify the writers that a message has been acknowledged, so they can continue to write
374    /// if they were waiting.
375    pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
376        let mut state = self.lock()?;
377        state.unacknowledged =
378            state
379                .unacknowledged
380                .checked_sub(1)
381                .ok_or_else(|| ShellError::NushellFailed {
382                    msg: "Underflow in counter: too many message acknowledgements".into(),
383                })?;
384        // Unblock the writer
385        self.change_cond.notify_one();
386        Ok(())
387    }
388}
389
390/// A sink for a [`StreamMessage`]
391pub trait WriteStreamMessage {
392    fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
393    fn flush(&mut self) -> Result<(), ShellError>;
394}
395
396#[derive(Debug, Default)]
397struct StreamManagerState {
398    reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
399    writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
400}
401
402impl StreamManagerState {
403    /// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
404    fn lock(
405        state: &Mutex<StreamManagerState>,
406    ) -> Result<MutexGuard<StreamManagerState>, ShellError> {
407        state.lock().map_err(|_| ShellError::NushellFailed {
408            msg: "StreamManagerState mutex poisoned due to a panic".into(),
409        })
410    }
411}
412
413#[derive(Debug)]
414pub struct StreamManager {
415    state: Arc<Mutex<StreamManagerState>>,
416}
417
418impl StreamManager {
419    /// Create a new StreamManager.
420    pub fn new() -> StreamManager {
421        StreamManager {
422            state: Default::default(),
423        }
424    }
425
426    fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
427        StreamManagerState::lock(&self.state)
428    }
429
430    /// Create a new handle to the StreamManager for registering streams.
431    pub fn get_handle(&self) -> StreamManagerHandle {
432        StreamManagerHandle {
433            state: Arc::downgrade(&self.state),
434        }
435    }
436
437    /// Process a stream message, and update internal state accordingly.
438    pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
439        let mut state = self.lock()?;
440        match message {
441            StreamMessage::Data(id, data) => {
442                if let Some(sender) = state.reading_streams.get(&id) {
443                    // We should ignore the error on send. This just means the reader has dropped,
444                    // but it will have sent a Drop message to the other side, and we will receive
445                    // an End message at which point we can remove the channel.
446                    let _ = sender.send(Ok(Some(data)));
447                    Ok(())
448                } else {
449                    Err(ShellError::PluginFailedToDecode {
450                        msg: format!("received Data for unknown stream {id}"),
451                    })
452                }
453            }
454            StreamMessage::End(id) => {
455                if let Some(sender) = state.reading_streams.remove(&id) {
456                    // We should ignore the error on the send, because the reader might have dropped
457                    // already
458                    let _ = sender.send(Ok(None));
459                    Ok(())
460                } else {
461                    Err(ShellError::PluginFailedToDecode {
462                        msg: format!("received End for unknown stream {id}"),
463                    })
464                }
465            }
466            StreamMessage::Drop(id) => {
467                if let Some(signal) = state.writing_streams.remove(&id) {
468                    if let Some(signal) = signal.upgrade() {
469                        // This will wake blocked writers so they can stop writing, so it's ok
470                        signal.set_dropped()?;
471                    }
472                }
473                // It's possible that the stream has already finished writing and we don't have it
474                // anymore, so we fall through to Ok
475                Ok(())
476            }
477            StreamMessage::Ack(id) => {
478                if let Some(signal) = state.writing_streams.get(&id) {
479                    if let Some(signal) = signal.upgrade() {
480                        // This will wake up a blocked writer
481                        signal.notify_acknowledged()?;
482                    } else {
483                        // We know it doesn't exist, so might as well remove it
484                        state.writing_streams.remove(&id);
485                    }
486                }
487                // It's possible that the stream has already finished writing and we don't have it
488                // anymore, so we fall through to Ok
489                Ok(())
490            }
491        }
492    }
493
494    /// Broadcast an error to all stream readers. This is useful for error propagation.
495    pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
496        let state = self.lock()?;
497        for channel in state.reading_streams.values() {
498            // Ignore send errors.
499            let _ = channel.send(Err(error.clone()));
500        }
501        Ok(())
502    }
503
504    // If the `StreamManager` is dropped, we should let all of the stream writers know that they
505    // won't be able to write anymore. We don't need to do anything about the readers though
506    // because they'll know when the `Sender` is dropped automatically
507    fn drop_all_writers(&self) -> Result<(), ShellError> {
508        let mut state = self.lock()?;
509        let writers = std::mem::take(&mut state.writing_streams);
510        for (_, signal) in writers {
511            if let Some(signal) = signal.upgrade() {
512                // more important that we send to all than handling an error
513                let _ = signal.set_dropped();
514            }
515        }
516        Ok(())
517    }
518}
519
520impl Default for StreamManager {
521    fn default() -> Self {
522        Self::new()
523    }
524}
525
526impl Drop for StreamManager {
527    fn drop(&mut self) {
528        if let Err(err) = self.drop_all_writers() {
529            log::warn!("error during Drop for StreamManager: {err}")
530        }
531    }
532}
533
534/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
535///
536/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
537/// a [`StreamWriter`].
538#[derive(Debug, Clone)]
539pub struct StreamManagerHandle {
540    state: Weak<Mutex<StreamManagerState>>,
541}
542
543impl StreamManagerHandle {
544    /// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
545    /// first try to upgrade to a strong reference and then lock. This function wraps those two
546    /// operations together, handling errors appropriately.
547    fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
548    where
549        F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
550    {
551        let upgraded = self
552            .state
553            .upgrade()
554            .ok_or_else(|| ShellError::NushellFailed {
555                msg: "StreamManager is no longer alive".into(),
556            })?;
557        let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
558            msg: "StreamManagerState mutex poisoned due to a panic".into(),
559        })?;
560        f(guard)
561    }
562
563    /// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
564    /// on the values received. A [`StreamMessage`] writer is required for writing control messages
565    /// back to the producer.
566    pub fn read_stream<T, W>(
567        &self,
568        id: StreamId,
569        writer: W,
570    ) -> Result<StreamReader<T, W>, ShellError>
571    where
572        T: TryFrom<StreamData, Error = ShellError>,
573        W: WriteStreamMessage,
574    {
575        let (tx, rx) = mpsc::channel();
576        self.with_lock(|mut state| {
577            // Must be exclusive
578            if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
579                e.insert(tx);
580                Ok(())
581            } else {
582                Err(ShellError::GenericError {
583                    error: format!("Failed to acquire reader for stream {id}"),
584                    msg: "tried to get a reader for a stream that's already being read".into(),
585                    span: None,
586                    help: Some("this may be a bug in the nu-plugin crate".into()),
587                    inner: vec![],
588                })
589            }
590        })?;
591        Ok(StreamReader::new(id, rx, writer))
592    }
593
594    /// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
595    /// data to the stream.
596    ///
597    /// The `high_pressure_mark` value controls how many messages can be written without receiving
598    /// an acknowledgement before any further attempts to write will wait for the consumer to
599    /// acknowledge them. This prevents overwhelming the reader.
600    pub fn write_stream<W>(
601        &self,
602        id: StreamId,
603        writer: W,
604        high_pressure_mark: i32,
605    ) -> Result<StreamWriter<W>, ShellError>
606    where
607        W: WriteStreamMessage,
608    {
609        let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
610        self.with_lock(|mut state| {
611            // Remove dead writing streams
612            state
613                .writing_streams
614                .retain(|_, signal| signal.strong_count() > 0);
615            // Must be exclusive
616            if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
617                e.insert(Arc::downgrade(&signal));
618                Ok(())
619            } else {
620                Err(ShellError::GenericError {
621                    error: format!("Failed to acquire writer for stream {id}"),
622                    msg: "tried to get a writer for a stream that's already being written".into(),
623                    span: None,
624                    help: Some("this may be a bug in the nu-plugin crate".into()),
625                    inner: vec![],
626                })
627            }
628        })?;
629        Ok(StreamWriter::new(id, signal, writer))
630    }
631}