nu_plugin_core/interface/stream/
mod.rs1use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
2use nu_protocol::{ShellError, Span, Value, shell_error::generic::GenericError};
3use std::{
4 collections::{BTreeMap, btree_map},
5 iter::FusedIterator,
6 marker::PhantomData,
7 sync::{Arc, Condvar, Mutex, MutexGuard, Weak, mpsc},
8};
9
// Unit tests for this module; only compiled when building tests.
#[cfg(test)]
mod tests;
12
13#[derive(Debug)]
28pub struct StreamReader<T, W>
29where
30 W: WriteStreamMessage,
31{
32 id: StreamId,
33 receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
34 writer: W,
35 marker: PhantomData<fn() -> T>,
38}
39
40impl<T, W> StreamReader<T, W>
41where
42 T: TryFrom<StreamData, Error = ShellError>,
43 W: WriteStreamMessage,
44{
45 fn new(
47 id: StreamId,
48 receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
49 writer: W,
50 ) -> StreamReader<T, W> {
51 StreamReader {
52 id,
53 receiver: Some(receiver),
54 writer,
55 marker: PhantomData,
56 }
57 }
58
59 pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
65 let connection_lost = || {
66 ShellError::Generic(GenericError::new_internal(
67 "Stream ended unexpectedly",
68 "connection lost before explicit end of stream",
69 ))
70 };
71
72 if let Some(ref rx) = self.receiver {
73 let msg = match rx.try_recv() {
75 Ok(msg) => msg?,
76 Err(mpsc::TryRecvError::Empty) => {
77 self.writer.flush()?;
81 rx.recv().map_err(|_| connection_lost())??
82 }
83 Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
84 };
85
86 if let Some(data) = msg {
87 self.writer
89 .write_stream_message(StreamMessage::Ack(self.id))?;
90 Ok(Some(data.try_into()?))
92 } else {
93 self.receiver = None;
95 Ok(None)
96 }
97 } else {
98 Ok(None)
100 }
101 }
102}
103
104impl<T, W> Iterator for StreamReader<T, W>
105where
106 T: FromShellError + TryFrom<StreamData, Error = ShellError>,
107 W: WriteStreamMessage,
108{
109 type Item = T;
110
111 fn next(&mut self) -> Option<T> {
112 match self.recv() {
114 Ok(option) => option,
115 Err(err) => {
116 self.receiver = None;
118 Some(T::from_shell_error(err))
119 }
120 }
121 }
122}
123
124impl<T, W> FusedIterator for StreamReader<T, W>
126where
127 T: FromShellError + TryFrom<StreamData, Error = ShellError>,
128 W: WriteStreamMessage,
129{
130}
131
132impl<T, W> Drop for StreamReader<T, W>
133where
134 W: WriteStreamMessage,
135{
136 fn drop(&mut self) {
137 if let Err(err) = self
138 .writer
139 .write_stream_message(StreamMessage::Drop(self.id))
140 .and_then(|_| self.writer.flush())
141 {
142 log::warn!("Failed to send message to drop stream: {err}");
143 }
144 }
145}
146
147pub trait FromShellError {
149 fn from_shell_error(err: ShellError) -> Self;
150}
151
152impl FromShellError for Value {
156 fn from_shell_error(err: ShellError) -> Self {
157 Value::error(err, Span::unknown())
158 }
159}
160
161impl<T> FromShellError for Result<T, ShellError> {
163 fn from_shell_error(err: ShellError) -> Self {
164 Err(err)
165 }
166}
167
168#[derive(Debug)]
172pub struct StreamWriter<W: WriteStreamMessage> {
173 id: StreamId,
174 signal: Arc<StreamWriterSignal>,
175 writer: W,
176 ended: bool,
177}
178
179impl<W> StreamWriter<W>
180where
181 W: WriteStreamMessage,
182{
183 fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
184 StreamWriter {
185 id,
186 signal,
187 writer,
188 ended: false,
189 }
190 }
191
192 pub fn is_dropped(&self) -> Result<bool, ShellError> {
195 self.signal.is_dropped()
196 }
197
198 pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
203 if !self.ended {
204 self.writer
205 .write_stream_message(StreamMessage::Data(self.id, data.into()))?;
206 self.writer.flush()?;
211 if !self.signal.notify_sent()? {
213 self.signal.wait_for_drain()
214 } else {
215 Ok(())
216 }
217 } else {
218 Err(ShellError::Generic(
219 GenericError::new_internal(
220 "Wrote to a stream after it ended",
221 format!(
222 "tried to write to stream {} after it was already ended",
223 self.id
224 ),
225 )
226 .with_help("this may be a bug in the nu-plugin crate"),
227 ))
228 }
229 }
230
231 pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
240 where
241 T: Into<StreamData>,
242 {
243 if self.is_dropped()? {
245 return Ok(false);
246 }
247
248 for item in data {
249 if self.is_dropped()? {
252 return Ok(false);
253 }
254 self.write(item)?;
255 }
256 Ok(true)
257 }
258
259 pub fn end(&mut self) -> Result<(), ShellError> {
262 if !self.ended {
263 self.ended = true;
265 self.writer
266 .write_stream_message(StreamMessage::End(self.id))?;
267 self.writer.flush()
268 } else {
269 Ok(())
270 }
271 }
272}
273
274impl<W> Drop for StreamWriter<W>
275where
276 W: WriteStreamMessage,
277{
278 fn drop(&mut self) {
279 if let Err(err) = self.end() {
281 log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
282 }
283 }
284}
285
/// Flow-control state for one [`StreamWriter`], shared with the [`StreamManager`] that
/// processes acknowledgements and drop notifications for it.
#[derive(Debug)]
pub struct StreamWriterSignal {
    /// Protects the counters and the dropped flag.
    mutex: Mutex<StreamWriterSignalState>,
    /// Notified whenever the state changes (acknowledgement or drop) so a blocked writer can
    /// re-check whether it may continue.
    change_cond: Condvar,
}

/// The data guarded by [`StreamWriterSignal`]'s mutex.
#[derive(Debug)]
pub struct StreamWriterSignalState {
    /// Set once the reader has dropped the stream.
    dropped: bool,
    /// Number of messages sent but not yet acknowledged.
    unacknowledged: i32,
    /// Unacknowledged count at which the writer has to wait.
    high_pressure_mark: i32,
}
303
304impl StreamWriterSignal {
305 fn new(high_pressure_mark: i32) -> StreamWriterSignal {
311 assert!(high_pressure_mark > 0);
312
313 StreamWriterSignal {
314 mutex: Mutex::new(StreamWriterSignalState {
315 dropped: false,
316 unacknowledged: 0,
317 high_pressure_mark,
318 }),
319 change_cond: Condvar::new(),
320 }
321 }
322
323 fn lock(&self) -> Result<MutexGuard<'_, StreamWriterSignalState>, ShellError> {
324 self.mutex.lock().map_err(|_| ShellError::NushellFailed {
325 msg: "StreamWriterSignal mutex poisoned due to panic".into(),
326 })
327 }
328
329 pub fn is_dropped(&self) -> Result<bool, ShellError> {
332 Ok(self.lock()?.dropped)
333 }
334
335 pub fn set_dropped(&self) -> Result<(), ShellError> {
337 let mut state = self.lock()?;
338 state.dropped = true;
339 self.change_cond.notify_all();
341 Ok(())
342 }
343
344 pub fn notify_sent(&self) -> Result<bool, ShellError> {
348 let mut state = self.lock()?;
349 state.unacknowledged =
350 state
351 .unacknowledged
352 .checked_add(1)
353 .ok_or_else(|| ShellError::NushellFailed {
354 msg: "Overflow in counter: too many unacknowledged messages".into(),
355 })?;
356
357 Ok(state.unacknowledged < state.high_pressure_mark)
358 }
359
360 pub fn wait_for_drain(&self) -> Result<(), ShellError> {
362 let mut state = self.lock()?;
363 while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
364 state = self
365 .change_cond
366 .wait(state)
367 .map_err(|_| ShellError::NushellFailed {
368 msg: "StreamWriterSignal mutex poisoned due to panic".into(),
369 })?;
370 }
371 Ok(())
372 }
373
374 pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
377 let mut state = self.lock()?;
378 state.unacknowledged =
379 state
380 .unacknowledged
381 .checked_sub(1)
382 .ok_or_else(|| ShellError::NushellFailed {
383 msg: "Underflow in counter: too many message acknowledgements".into(),
384 })?;
385 self.change_cond.notify_one();
387 Ok(())
388 }
389}
390
391pub trait WriteStreamMessage {
393 fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
394 fn flush(&mut self) -> Result<(), ShellError>;
395}
396
397#[derive(Debug, Default)]
398struct StreamManagerState {
399 reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
400 writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
401}
402
403impl StreamManagerState {
404 fn lock(
406 state: &Mutex<StreamManagerState>,
407 ) -> Result<MutexGuard<'_, StreamManagerState>, ShellError> {
408 state.lock().map_err(|_| ShellError::NushellFailed {
409 msg: "StreamManagerState mutex poisoned due to a panic".into(),
410 })
411 }
412}
413
414#[derive(Debug)]
415pub struct StreamManager {
416 state: Arc<Mutex<StreamManagerState>>,
417}
418
419impl StreamManager {
420 pub fn new() -> StreamManager {
422 StreamManager {
423 state: Default::default(),
424 }
425 }
426
427 fn lock(&self) -> Result<MutexGuard<'_, StreamManagerState>, ShellError> {
428 StreamManagerState::lock(&self.state)
429 }
430
431 pub fn get_handle(&self) -> StreamManagerHandle {
433 StreamManagerHandle {
434 state: Arc::downgrade(&self.state),
435 }
436 }
437
438 pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
440 let mut state = self.lock()?;
441 match message {
442 StreamMessage::Data(id, data) => {
443 if let Some(sender) = state.reading_streams.get(&id) {
444 let _ = sender.send(Ok(Some(data)));
448 Ok(())
449 } else {
450 Err(ShellError::PluginFailedToDecode {
451 msg: format!("received Data for unknown stream {id}"),
452 })
453 }
454 }
455 StreamMessage::End(id) => {
456 if let Some(sender) = state.reading_streams.remove(&id) {
457 let _ = sender.send(Ok(None));
460 Ok(())
461 } else {
462 Err(ShellError::PluginFailedToDecode {
463 msg: format!("received End for unknown stream {id}"),
464 })
465 }
466 }
467 StreamMessage::Drop(id) => {
468 if let Some(signal) = state.writing_streams.remove(&id)
469 && let Some(signal) = signal.upgrade()
470 {
471 signal.set_dropped()?;
473 }
474 Ok(())
477 }
478 StreamMessage::Ack(id) => {
479 if let Some(signal) = state.writing_streams.get(&id) {
480 if let Some(signal) = signal.upgrade() {
481 signal.notify_acknowledged()?;
483 } else {
484 state.writing_streams.remove(&id);
486 }
487 }
488 Ok(())
491 }
492 }
493 }
494
495 pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
497 let state = self.lock()?;
498 for channel in state.reading_streams.values() {
499 let _ = channel.send(Err(error.clone()));
501 }
502 Ok(())
503 }
504
505 fn drop_all_writers(&self) -> Result<(), ShellError> {
509 let mut state = self.lock()?;
510 let writers = std::mem::take(&mut state.writing_streams);
511 for (_, signal) in writers {
512 if let Some(signal) = signal.upgrade() {
513 let _ = signal.set_dropped();
515 }
516 }
517 Ok(())
518 }
519}
520
521impl Default for StreamManager {
522 fn default() -> Self {
523 Self::new()
524 }
525}
526
527impl Drop for StreamManager {
528 fn drop(&mut self) {
529 if let Err(err) = self.drop_all_writers() {
530 log::warn!("error during Drop for StreamManager: {err}")
531 }
532 }
533}
534
535#[derive(Debug, Clone)]
540pub struct StreamManagerHandle {
541 state: Weak<Mutex<StreamManagerState>>,
542}
543
544impl StreamManagerHandle {
545 fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
549 where
550 F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
551 {
552 let upgraded = self
553 .state
554 .upgrade()
555 .ok_or_else(|| ShellError::NushellFailed {
556 msg: "StreamManager is no longer alive".into(),
557 })?;
558 let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
559 msg: "StreamManagerState mutex poisoned due to a panic".into(),
560 })?;
561 f(guard)
562 }
563
564 pub fn read_stream<T, W>(
568 &self,
569 id: StreamId,
570 writer: W,
571 ) -> Result<StreamReader<T, W>, ShellError>
572 where
573 T: TryFrom<StreamData, Error = ShellError>,
574 W: WriteStreamMessage,
575 {
576 let (tx, rx) = mpsc::channel();
577 self.with_lock(|mut state| {
578 if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
580 e.insert(tx);
581 Ok(())
582 } else {
583 Err(ShellError::Generic(
584 GenericError::new_internal(
585 format!("Failed to acquire reader for stream {id}"),
586 "tried to get a reader for a stream that's already being read",
587 )
588 .with_help("this may be a bug in the nu-plugin crate"),
589 ))
590 }
591 })?;
592 Ok(StreamReader::new(id, rx, writer))
593 }
594
595 pub fn write_stream<W>(
602 &self,
603 id: StreamId,
604 writer: W,
605 high_pressure_mark: i32,
606 ) -> Result<StreamWriter<W>, ShellError>
607 where
608 W: WriteStreamMessage,
609 {
610 let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
611 self.with_lock(|mut state| {
612 state
614 .writing_streams
615 .retain(|_, signal| signal.strong_count() > 0);
616 if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
618 e.insert(Arc::downgrade(&signal));
619 Ok(())
620 } else {
621 Err(ShellError::Generic(
622 GenericError::new_internal(
623 format!("Failed to acquire writer for stream {id}"),
624 "tried to get a writer for a stream that's already being written",
625 )
626 .with_help("this may be a bug in the nu-plugin crate"),
627 ))
628 }
629 })?;
630 Ok(StreamWriter::new(id, signal, writer))
631 }
632}