1use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
2use nu_protocol::{ShellError, Span, Value};
3use std::{
4 collections::{BTreeMap, btree_map},
5 iter::FusedIterator,
6 marker::PhantomData,
7 sync::{Arc, Condvar, Mutex, MutexGuard, Weak, mpsc},
8};
9
10#[cfg(test)]
11mod tests;
12
/// Reads the messages of one stream, converting each into `T`.
///
/// Every data message received is acknowledged through `writer`, and dropping the reader sends
/// a [`StreamMessage::Drop`] for the stream through the same writer.
#[derive(Debug)]
pub struct StreamReader<T, W>
where
    W: WriteStreamMessage,
{
    /// The stream being read.
    id: StreamId,
    /// Channel carrying data (`Ok(Some(_))`), end of stream (`Ok(None)`), or errors for this
    /// stream. Cleared to `None` once the stream has ended (or, when iterating, after an
    /// error), after which reads return `Ok(None)`.
    receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
    /// Used to send acknowledgements and the drop message for this stream.
    writer: W,
    /// Records the item type `T` without storing one; the `fn() -> T` form keeps the reader's
    /// auto traits (`Send`/`Sync`) independent of `T`.
    marker: PhantomData<fn() -> T>,
}
39
40impl<T, W> StreamReader<T, W>
41where
42 T: TryFrom<StreamData, Error = ShellError>,
43 W: WriteStreamMessage,
44{
45 fn new(
47 id: StreamId,
48 receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
49 writer: W,
50 ) -> StreamReader<T, W> {
51 StreamReader {
52 id,
53 receiver: Some(receiver),
54 writer,
55 marker: PhantomData,
56 }
57 }
58
59 pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
65 let connection_lost = || ShellError::GenericError {
66 error: "Stream ended unexpectedly".into(),
67 msg: "connection lost before explicit end of stream".into(),
68 span: None,
69 help: None,
70 inner: vec![],
71 };
72
73 if let Some(ref rx) = self.receiver {
74 let msg = match rx.try_recv() {
76 Ok(msg) => msg?,
77 Err(mpsc::TryRecvError::Empty) => {
78 self.writer.flush()?;
82 rx.recv().map_err(|_| connection_lost())??
83 }
84 Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
85 };
86
87 if let Some(data) = msg {
88 self.writer
90 .write_stream_message(StreamMessage::Ack(self.id))?;
91 Ok(Some(data.try_into()?))
93 } else {
94 self.receiver = None;
96 Ok(None)
97 }
98 } else {
99 Ok(None)
101 }
102 }
103}
104
105impl<T, W> Iterator for StreamReader<T, W>
106where
107 T: FromShellError + TryFrom<StreamData, Error = ShellError>,
108 W: WriteStreamMessage,
109{
110 type Item = T;
111
112 fn next(&mut self) -> Option<T> {
113 match self.recv() {
115 Ok(option) => option,
116 Err(err) => {
117 self.receiver = None;
119 Some(T::from_shell_error(err))
120 }
121 }
122 }
123}
124
/// Once `next` has returned `None`, or has yielded an error value, the reader's receiver has
/// been cleared, so every further call to `next` returns `None`.
impl<T, W> FusedIterator for StreamReader<T, W>
where
    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
    W: WriteStreamMessage,
{
}
132
133impl<T, W> Drop for StreamReader<T, W>
134where
135 W: WriteStreamMessage,
136{
137 fn drop(&mut self) {
138 if let Err(err) = self
139 .writer
140 .write_stream_message(StreamMessage::Drop(self.id))
141 .and_then(|_| self.writer.flush())
142 {
143 log::warn!("Failed to send message to drop stream: {err}");
144 }
145 }
146}
147
/// Converts a [`ShellError`] into a value of the implementing type, so that the
/// [`StreamReader`] iterator can yield errors as items.
pub trait FromShellError {
    /// Wrap `err` as a value of this type.
    fn from_shell_error(err: ShellError) -> Self;
}
152
153impl FromShellError for Value {
155 fn from_shell_error(err: ShellError) -> Self {
156 Value::error(err, Span::unknown())
157 }
158}
159
160impl<T> FromShellError for Result<T, ShellError> {
162 fn from_shell_error(err: ShellError) -> Self {
163 Err(err)
164 }
165}
166
/// Writes data messages to one stream, blocking when too many are unacknowledged.
#[derive(Debug)]
pub struct StreamWriter<W: WriteStreamMessage> {
    /// The stream being written.
    id: StreamId,
    /// Flow-control state; the [`StreamManager`] updates it when acknowledgements or drop
    /// messages arrive for this stream.
    signal: Arc<StreamWriterSignal>,
    /// Destination for the stream's messages.
    writer: W,
    /// Whether `end` has already been called (set before the end message is written).
    ended: bool,
}
177
178impl<W> StreamWriter<W>
179where
180 W: WriteStreamMessage,
181{
182 fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
183 StreamWriter {
184 id,
185 signal,
186 writer,
187 ended: false,
188 }
189 }
190
191 pub fn is_dropped(&self) -> Result<bool, ShellError> {
194 self.signal.is_dropped()
195 }
196
197 pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
202 if !self.ended {
203 self.writer
204 .write_stream_message(StreamMessage::Data(self.id, data.into()))?;
205 self.writer.flush()?;
210 if !self.signal.notify_sent()? {
212 self.signal.wait_for_drain()
213 } else {
214 Ok(())
215 }
216 } else {
217 Err(ShellError::GenericError {
218 error: "Wrote to a stream after it ended".into(),
219 msg: format!(
220 "tried to write to stream {} after it was already ended",
221 self.id
222 ),
223 span: None,
224 help: Some("this may be a bug in the nu-plugin crate".into()),
225 inner: vec![],
226 })
227 }
228 }
229
230 pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
239 where
240 T: Into<StreamData>,
241 {
242 if self.is_dropped()? {
244 return Ok(false);
245 }
246
247 for item in data {
248 if self.is_dropped()? {
251 return Ok(false);
252 }
253 self.write(item)?;
254 }
255 Ok(true)
256 }
257
258 pub fn end(&mut self) -> Result<(), ShellError> {
261 if !self.ended {
262 self.ended = true;
264 self.writer
265 .write_stream_message(StreamMessage::End(self.id))?;
266 self.writer.flush()
267 } else {
268 Ok(())
269 }
270 }
271}
272
273impl<W> Drop for StreamWriter<W>
274where
275 W: WriteStreamMessage,
276{
277 fn drop(&mut self) {
278 if let Err(err) = self.end() {
280 log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
281 }
282 }
283}
284
/// Flow-control state for one [`StreamWriter`], shared with the [`StreamManager`], which keeps
/// only a weak reference to it.
#[derive(Debug)]
pub struct StreamWriterSignal {
    /// Guards the drop flag and the acknowledgement counter.
    mutex: Mutex<StreamWriterSignalState>,
    /// Notified when the stream is dropped or a message is acknowledged, so a writer blocked in
    /// `wait_for_drain` can re-check the state.
    change_cond: Condvar,
}
292
/// The mutable part of a [`StreamWriterSignal`], protected by its mutex.
#[derive(Debug)]
pub struct StreamWriterSignalState {
    /// Set once the stream has been dropped; a waiting writer stops waiting when this is true.
    dropped: bool,
    /// Data messages counted as sent minus acknowledgements received.
    unacknowledged: i32,
    /// While `unacknowledged` is at or above this value, a writer waits for acknowledgements.
    high_pressure_mark: i32,
}
302
303impl StreamWriterSignal {
304 fn new(high_pressure_mark: i32) -> StreamWriterSignal {
310 assert!(high_pressure_mark > 0);
311
312 StreamWriterSignal {
313 mutex: Mutex::new(StreamWriterSignalState {
314 dropped: false,
315 unacknowledged: 0,
316 high_pressure_mark,
317 }),
318 change_cond: Condvar::new(),
319 }
320 }
321
322 fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
323 self.mutex.lock().map_err(|_| ShellError::NushellFailed {
324 msg: "StreamWriterSignal mutex poisoned due to panic".into(),
325 })
326 }
327
328 pub fn is_dropped(&self) -> Result<bool, ShellError> {
331 Ok(self.lock()?.dropped)
332 }
333
334 pub fn set_dropped(&self) -> Result<(), ShellError> {
336 let mut state = self.lock()?;
337 state.dropped = true;
338 self.change_cond.notify_all();
340 Ok(())
341 }
342
343 pub fn notify_sent(&self) -> Result<bool, ShellError> {
347 let mut state = self.lock()?;
348 state.unacknowledged =
349 state
350 .unacknowledged
351 .checked_add(1)
352 .ok_or_else(|| ShellError::NushellFailed {
353 msg: "Overflow in counter: too many unacknowledged messages".into(),
354 })?;
355
356 Ok(state.unacknowledged < state.high_pressure_mark)
357 }
358
359 pub fn wait_for_drain(&self) -> Result<(), ShellError> {
361 let mut state = self.lock()?;
362 while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
363 state = self
364 .change_cond
365 .wait(state)
366 .map_err(|_| ShellError::NushellFailed {
367 msg: "StreamWriterSignal mutex poisoned due to panic".into(),
368 })?;
369 }
370 Ok(())
371 }
372
373 pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
376 let mut state = self.lock()?;
377 state.unacknowledged =
378 state
379 .unacknowledged
380 .checked_sub(1)
381 .ok_or_else(|| ShellError::NushellFailed {
382 msg: "Underflow in counter: too many message acknowledgements".into(),
383 })?;
384 self.change_cond.notify_one();
386 Ok(())
387 }
388}
389
/// A sink that stream messages can be written to.
pub trait WriteStreamMessage {
    /// Write a single stream message.
    fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
    /// Flush whatever has been written so far. In this module it is called after data, end,
    /// and drop messages, and before a reader blocks waiting for input. Acknowledgements are
    /// written without an immediate flush.
    fn flush(&mut self) -> Result<(), ShellError>;
}
395
/// Registry of the streams known to a [`StreamManager`].
#[derive(Debug, Default)]
struct StreamManagerState {
    /// Senders feeding each stream's [`StreamReader`], keyed by stream id.
    reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
    /// Weak references to each writing stream's flow-control signal, keyed by stream id.
    writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
}
401
402impl StreamManagerState {
403 fn lock(
405 state: &Mutex<StreamManagerState>,
406 ) -> Result<MutexGuard<StreamManagerState>, ShellError> {
407 state.lock().map_err(|_| ShellError::NushellFailed {
408 msg: "StreamManagerState mutex poisoned due to a panic".into(),
409 })
410 }
411}
412
/// Routes incoming stream messages to the readers and writers registered through its handles.
#[derive(Debug)]
pub struct StreamManager {
    /// Shared registry; handles hold only a weak reference to it.
    state: Arc<Mutex<StreamManagerState>>,
}
417
418impl StreamManager {
419 pub fn new() -> StreamManager {
421 StreamManager {
422 state: Default::default(),
423 }
424 }
425
426 fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
427 StreamManagerState::lock(&self.state)
428 }
429
430 pub fn get_handle(&self) -> StreamManagerHandle {
432 StreamManagerHandle {
433 state: Arc::downgrade(&self.state),
434 }
435 }
436
437 pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
439 let mut state = self.lock()?;
440 match message {
441 StreamMessage::Data(id, data) => {
442 if let Some(sender) = state.reading_streams.get(&id) {
443 let _ = sender.send(Ok(Some(data)));
447 Ok(())
448 } else {
449 Err(ShellError::PluginFailedToDecode {
450 msg: format!("received Data for unknown stream {id}"),
451 })
452 }
453 }
454 StreamMessage::End(id) => {
455 if let Some(sender) = state.reading_streams.remove(&id) {
456 let _ = sender.send(Ok(None));
459 Ok(())
460 } else {
461 Err(ShellError::PluginFailedToDecode {
462 msg: format!("received End for unknown stream {id}"),
463 })
464 }
465 }
466 StreamMessage::Drop(id) => {
467 if let Some(signal) = state.writing_streams.remove(&id) {
468 if let Some(signal) = signal.upgrade() {
469 signal.set_dropped()?;
471 }
472 }
473 Ok(())
476 }
477 StreamMessage::Ack(id) => {
478 if let Some(signal) = state.writing_streams.get(&id) {
479 if let Some(signal) = signal.upgrade() {
480 signal.notify_acknowledged()?;
482 } else {
483 state.writing_streams.remove(&id);
485 }
486 }
487 Ok(())
490 }
491 }
492 }
493
494 pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
496 let state = self.lock()?;
497 for channel in state.reading_streams.values() {
498 let _ = channel.send(Err(error.clone()));
500 }
501 Ok(())
502 }
503
504 fn drop_all_writers(&self) -> Result<(), ShellError> {
508 let mut state = self.lock()?;
509 let writers = std::mem::take(&mut state.writing_streams);
510 for (_, signal) in writers {
511 if let Some(signal) = signal.upgrade() {
512 let _ = signal.set_dropped();
514 }
515 }
516 Ok(())
517 }
518}
519
520impl Default for StreamManager {
521 fn default() -> Self {
522 Self::new()
523 }
524}
525
526impl Drop for StreamManager {
527 fn drop(&mut self) {
528 if let Err(err) = self.drop_all_writers() {
529 log::warn!("error during Drop for StreamManager: {err}")
530 }
531 }
532}
533
/// A cloneable handle to a [`StreamManager`], used to register readers and writers.
///
/// The handle holds only a weak reference, so operations fail once the manager is dropped.
#[derive(Debug, Clone)]
pub struct StreamManagerHandle {
    state: Weak<Mutex<StreamManagerState>>,
}
542
543impl StreamManagerHandle {
544 fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
548 where
549 F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
550 {
551 let upgraded = self
552 .state
553 .upgrade()
554 .ok_or_else(|| ShellError::NushellFailed {
555 msg: "StreamManager is no longer alive".into(),
556 })?;
557 let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
558 msg: "StreamManagerState mutex poisoned due to a panic".into(),
559 })?;
560 f(guard)
561 }
562
563 pub fn read_stream<T, W>(
567 &self,
568 id: StreamId,
569 writer: W,
570 ) -> Result<StreamReader<T, W>, ShellError>
571 where
572 T: TryFrom<StreamData, Error = ShellError>,
573 W: WriteStreamMessage,
574 {
575 let (tx, rx) = mpsc::channel();
576 self.with_lock(|mut state| {
577 if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
579 e.insert(tx);
580 Ok(())
581 } else {
582 Err(ShellError::GenericError {
583 error: format!("Failed to acquire reader for stream {id}"),
584 msg: "tried to get a reader for a stream that's already being read".into(),
585 span: None,
586 help: Some("this may be a bug in the nu-plugin crate".into()),
587 inner: vec![],
588 })
589 }
590 })?;
591 Ok(StreamReader::new(id, rx, writer))
592 }
593
594 pub fn write_stream<W>(
601 &self,
602 id: StreamId,
603 writer: W,
604 high_pressure_mark: i32,
605 ) -> Result<StreamWriter<W>, ShellError>
606 where
607 W: WriteStreamMessage,
608 {
609 let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
610 self.with_lock(|mut state| {
611 state
613 .writing_streams
614 .retain(|_, signal| signal.strong_count() > 0);
615 if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
617 e.insert(Arc::downgrade(&signal));
618 Ok(())
619 } else {
620 Err(ShellError::GenericError {
621 error: format!("Failed to acquire writer for stream {id}"),
622 msg: "tried to get a writer for a stream that's already being written".into(),
623 span: None,
624 help: Some("this may be a bug in the nu-plugin crate".into()),
625 inner: vec![],
626 })
627 }
628 })?;
629 Ok(StreamWriter::new(id, signal, writer))
630 }
631}