Skip to main content

nu_command/filters/
tee.rs

1use nu_engine::{command_prelude::*, get_eval_block_with_early_return};
2#[cfg(feature = "os")]
3use nu_protocol::process::ChildPipe;
4#[cfg(test)]
5use nu_protocol::shell_error;
6use nu_protocol::{
7    ByteStream, ByteStreamSource, OutDest, PipelineMetadata, Signals,
8    byte_stream::copy_with_signals, engine::Closure, report_shell_error, shell_error::io::IoError,
9};
10use std::{
11    io::{self, Read, Write},
12    sync::{
13        Arc,
14        mpsc::{self, Sender},
15    },
16    thread::{self, JoinHandle},
17};
18
/// The `tee` filter command: copies a stream to another command in parallel while passing the
/// original stream through unchanged.
#[derive(Clone)]
pub struct Tee;
21
22impl Command for Tee {
23    fn name(&self) -> &str {
24        "tee"
25    }
26
27    fn description(&self) -> &str {
28        "Copy a stream to another command in parallel."
29    }
30
31    fn extra_description(&self) -> &str {
32        r#"This is useful for doing something else with a stream while still continuing to
33use it in your pipeline."#
34    }
35
36    fn signature(&self) -> Signature {
37        Signature::build("tee")
38            .input_output_type(Type::Any, Type::Any)
39            .switch(
40                "stderr",
41                "For external commands: copy the standard error stream instead.",
42                Some('e'),
43            )
44            .required(
45                "closure",
46                SyntaxShape::Closure(None),
47                "The other command to send the stream to.",
48            )
49            .category(Category::Filters)
50    }
51
52    fn examples(&self) -> Vec<Example<'_>> {
53        vec![
54            Example {
55                example: "http get http://example.org/ | tee { save example.html }",
56                description: "Save a webpage to a file while also printing it.",
57                result: None,
58            },
59            Example {
60                example: "nu -c 'print -e error; print ok' | tee --stderr { save error.log } | complete",
61                description: "Save error messages from an external command to a file without redirecting them.",
62                result: None,
63            },
64            Example {
65                example: "1..100 | tee { each { print } } | math sum | wrap sum",
66                description: "Print numbers and their sum.",
67                result: None,
68            },
69            Example {
70                example: "10000 | tee { 1..$in | print } | $in * 5",
71                description: "Do something with a value on another thread, while also passing through the value.",
72                result: Some(Value::test_int(50000)),
73            },
74        ]
75    }
76
77    fn run(
78        &self,
79        engine_state: &EngineState,
80        stack: &mut Stack,
81        call: &Call,
82        input: PipelineData,
83    ) -> Result<PipelineData, ShellError> {
84        let head = call.head;
85        let from_io_error = IoError::factory(head, None);
86        let use_stderr = call.has_flag(engine_state, stack, "stderr")?;
87
88        let closure: Spanned<Closure> = call.req(engine_state, stack, 0)?;
89        let closure_span = closure.span;
90        let closure = closure.item;
91
92        let engine_state_arc = Arc::new(engine_state.clone());
93
94        let mut eval_block = {
95            let closure_engine_state = engine_state_arc.clone();
96            let mut closure_stack = stack
97                .captures_to_stack_preserve_out_dest(closure.captures)
98                .reset_pipes();
99            let eval_block_with_early_return = get_eval_block_with_early_return(engine_state);
100
101            move |input| {
102                let result = eval_block_with_early_return(
103                    &closure_engine_state,
104                    &mut closure_stack,
105                    closure_engine_state.get_block(closure.block_id),
106                    input,
107                )
108                .map(|p| p.body);
109                // Make sure to drain any iterator produced to avoid unexpected behavior
110                result.and_then(|data| data.drain().map(|_| ()))
111            }
112        };
113
114        // Convert values that can be represented as streams into streams. Streams can pass errors
115        // through later, so if we treat string/binary/list as a stream instead, it's likely that
116        // we can get the error back to the original thread.
117        let span = input.span().unwrap_or(head);
118        let input = input
119            .try_into_stream(engine_state)
120            .unwrap_or_else(|original_input| original_input);
121
122        if let PipelineData::ByteStream(stream, metadata) = input {
123            let type_ = stream.type_();
124
125            let info = StreamInfo {
126                span,
127                signals: engine_state.signals().clone(),
128                type_,
129                metadata: metadata.clone(),
130            };
131
132            match stream.into_source() {
133                ByteStreamSource::Read(read) => {
134                    if use_stderr {
135                        return stderr_misuse(span, head);
136                    }
137
138                    let tee_thread = spawn_tee(info, eval_block)?;
139                    let tee = IoTee::new(read, tee_thread);
140
141                    Ok(PipelineData::byte_stream(
142                        ByteStream::read(tee, span, engine_state.signals().clone(), type_),
143                        metadata,
144                    ))
145                }
146                ByteStreamSource::File(file) => {
147                    if use_stderr {
148                        return stderr_misuse(span, head);
149                    }
150
151                    let tee_thread = spawn_tee(info, eval_block)?;
152                    let tee = IoTee::new(file, tee_thread);
153
154                    Ok(PipelineData::byte_stream(
155                        ByteStream::read(tee, span, engine_state.signals().clone(), type_),
156                        metadata,
157                    ))
158                }
159                #[cfg(feature = "os")]
160                ByteStreamSource::Child(mut child) => {
161                    let stderr_thread = if use_stderr {
162                        let stderr_thread = if let Some(stderr) = child.stderr.take() {
163                            let tee_thread = spawn_tee(info.clone(), eval_block)?;
164                            let tee = IoTee::new(stderr, tee_thread);
165                            match stack.stderr() {
166                                OutDest::Pipe | OutDest::PipeSeparate | OutDest::Value => {
167                                    child.stderr = Some(ChildPipe::Tee(Box::new(tee)));
168                                    Ok(None)
169                                }
170                                OutDest::Null => copy_on_thread(tee, io::sink(), &info).map(Some),
171                                OutDest::Print | OutDest::Inherit => {
172                                    copy_on_thread(tee, io::stderr(), &info).map(Some)
173                                }
174                                OutDest::File(file) => {
175                                    copy_on_thread(tee, file.clone(), &info).map(Some)
176                                }
177                            }?
178                        } else {
179                            None
180                        };
181
182                        if let Some(stdout) = child.stdout.take() {
183                            match stack.stdout() {
184                                OutDest::Pipe | OutDest::PipeSeparate | OutDest::Value => {
185                                    child.stdout = Some(stdout);
186                                    Ok(())
187                                }
188                                OutDest::Null => copy_pipe(stdout, io::sink(), &info),
189                                OutDest::Print | OutDest::Inherit => {
190                                    copy_pipe(stdout, io::stdout(), &info)
191                                }
192                                OutDest::File(file) => copy_pipe(stdout, file.as_ref(), &info),
193                            }?;
194                        }
195
196                        stderr_thread
197                    } else {
198                        let stderr_thread = if let Some(stderr) = child.stderr.take() {
199                            let info = info.clone();
200                            match stack.stderr() {
201                                OutDest::Pipe | OutDest::PipeSeparate | OutDest::Value => {
202                                    child.stderr = Some(stderr);
203                                    Ok(None)
204                                }
205                                OutDest::Null => {
206                                    copy_pipe_on_thread(stderr, io::sink(), &info).map(Some)
207                                }
208                                OutDest::Print | OutDest::Inherit => {
209                                    copy_pipe_on_thread(stderr, io::stderr(), &info).map(Some)
210                                }
211                                OutDest::File(file) => {
212                                    copy_pipe_on_thread(stderr, file.clone(), &info).map(Some)
213                                }
214                            }?
215                        } else {
216                            None
217                        };
218
219                        if let Some(stdout) = child.stdout.take() {
220                            let tee_thread = spawn_tee(info.clone(), eval_block)?;
221                            let tee = IoTee::new(stdout, tee_thread);
222                            match stack.stdout() {
223                                OutDest::Pipe | OutDest::PipeSeparate | OutDest::Value => {
224                                    child.stdout = Some(ChildPipe::Tee(Box::new(tee)));
225                                    Ok(())
226                                }
227                                OutDest::Null => copy(tee, io::sink(), &info),
228                                OutDest::Print | OutDest::Inherit => copy(tee, io::stdout(), &info),
229                                OutDest::File(file) => copy(tee, file.as_ref(), &info),
230                            }?;
231                        }
232
233                        stderr_thread
234                    };
235
236                    if child.stdout.is_some() || child.stderr.is_some() {
237                        Ok(PipelineData::byte_stream(
238                            ByteStream::child(*child, span),
239                            metadata,
240                        ))
241                    } else {
242                        if let Some(thread) = stderr_thread {
243                            thread.join().unwrap_or_else(|_| Err(panic_error()))?;
244                        }
245                        child.wait()?;
246                        Ok(PipelineData::empty())
247                    }
248                }
249            }
250        } else {
251            if use_stderr {
252                return stderr_misuse(input.span().unwrap_or(head), head);
253            }
254
255            let metadata = input.metadata();
256            let metadata_clone = metadata.clone();
257
258            if let PipelineData::ListStream(..) = input {
259                // Only use the iterator implementation on lists / list streams. We want to be able
260                // to preserve errors as much as possible, and only the stream implementations can
261                // really do that
262                let signals = engine_state.signals().clone();
263
264                Ok(tee(input.into_iter(), move |rx| {
265                    let input = rx.into_pipeline_data_with_metadata(span, signals, metadata_clone);
266                    eval_block(input)
267                })
268                .map_err(&from_io_error)?
269                .map(move |result| result.unwrap_or_else(|err| Value::error(err, closure_span)))
270                .into_pipeline_data_with_metadata(
271                    span,
272                    engine_state.signals().clone(),
273                    metadata,
274                ))
275            } else {
276                // Otherwise, we can spawn a thread with the input value, but we have nowhere to
277                // send an error to other than just trying to print it to stderr.
278                let value = input.into_value(span)?;
279                let value_clone = value.clone();
280                tee_once(stack.clone(), engine_state_arc, move || {
281                    eval_block(value_clone.into_pipeline_data_with_metadata(metadata_clone))
282                })
283                .map_err(&from_io_error)?;
284                Ok(value.into_pipeline_data_with_metadata(metadata))
285            }
286        }
287    }
288
289    fn pipe_redirection(&self) -> (Option<OutDest>, Option<OutDest>) {
290        (Some(OutDest::PipeSeparate), Some(OutDest::PipeSeparate))
291    }
292}
293
294fn panic_error() -> ShellError {
295    ShellError::NushellFailed {
296        msg: "A panic occurred on a thread spawned by `tee`".into(),
297    }
298}
299
300/// Copies the iterator to a channel on another thread. If an error is produced on that thread,
301/// it is embedded in the resulting iterator as an `Err` as soon as possible. When the iterator
302/// finishes, it waits for the other thread to finish, also handling any error produced at that
303/// point.
304fn tee<T>(
305    input: impl Iterator<Item = T>,
306    with_cloned_stream: impl FnOnce(mpsc::Receiver<T>) -> Result<(), ShellError> + Send + 'static,
307) -> Result<impl Iterator<Item = Result<T, ShellError>>, std::io::Error>
308where
309    T: Clone + Send + 'static,
310{
311    // For sending the values to the other thread
312    let (tx, rx) = mpsc::channel();
313
314    let mut thread = Some(
315        thread::Builder::new()
316            .name("tee".into())
317            .spawn(move || with_cloned_stream(rx))?,
318    );
319
320    let mut iter = input.into_iter();
321    let mut tx = Some(tx);
322
323    Ok(std::iter::from_fn(move || {
324        if thread.as_ref().is_some_and(|t| t.is_finished()) {
325            // Check for an error from the other thread
326            let result = thread
327                .take()
328                .expect("thread was taken early")
329                .join()
330                .unwrap_or_else(|_| Err(panic_error()));
331            if let Err(err) = result {
332                // Embed the error early
333                return Some(Err(err));
334            }
335        }
336
337        // Get a value from the iterator
338        if let Some(value) = iter.next() {
339            // Send a copy, ignoring any error if the channel is closed
340            let _ = tx.as_ref().map(|tx| tx.send(value.clone()));
341            Some(Ok(value))
342        } else {
343            // Close the channel so the stream ends for the other thread
344            drop(tx.take());
345            // Wait for the other thread, and embed any error produced
346            thread.take().and_then(|t| {
347                t.join()
348                    .unwrap_or_else(|_| Err(panic_error()))
349                    .err()
350                    .map(Err)
351            })
352        }
353    }))
354}
355
356/// "tee" for a single value. No stream handling, just spawns a thread, printing any resulting error
357fn tee_once(
358    stack: Stack,
359    engine_state: Arc<EngineState>,
360    on_thread: impl FnOnce() -> Result<(), ShellError> + Send + 'static,
361) -> Result<JoinHandle<()>, std::io::Error> {
362    thread::Builder::new().name("tee".into()).spawn(move || {
363        if let Err(err) = on_thread() {
364            report_shell_error(Some(&stack), &engine_state, &err);
365        }
366    })
367}
368
369fn stderr_misuse<T>(span: Span, head: Span) -> Result<T, ShellError> {
370    Err(ShellError::UnsupportedInput {
371        msg: "--stderr can only be used on external commands".into(),
372        input: "the input to `tee` is not an external command".into(),
373        msg_span: head,
374        input_span: span,
375    })
376}
377
378struct IoTee<R: Read> {
379    reader: R,
380    sender: Option<Sender<Vec<u8>>>,
381    thread: Option<JoinHandle<Result<(), ShellError>>>,
382}
383
384impl<R: Read> IoTee<R> {
385    fn new(reader: R, tee: TeeThread) -> Self {
386        Self {
387            reader,
388            sender: Some(tee.sender),
389            thread: Some(tee.thread),
390        }
391    }
392}
393
394impl<R: Read> Read for IoTee<R> {
395    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
396        if let Some(thread) = self.thread.take() {
397            if thread.is_finished() {
398                if let Err(err) = thread.join().unwrap_or_else(|_| Err(panic_error())) {
399                    return Err(io::Error::other(err));
400                }
401            } else {
402                self.thread = Some(thread)
403            }
404        }
405        let len = self.reader.read(buf)?;
406        if len == 0 {
407            self.sender = None;
408            if let Some(thread) = self.thread.take()
409                && let Err(err) = thread.join().unwrap_or_else(|_| Err(panic_error()))
410            {
411                return Err(io::Error::other(err));
412            }
413        } else if let Some(sender) = self.sender.as_mut()
414            && sender.send(buf[..len].to_vec()).is_err()
415        {
416            self.sender = None;
417        }
418        Ok(len)
419    }
420}
421
/// A spawned tee thread together with the channel used to feed it, as returned by `spawn_tee`.
struct TeeThread {
    /// Sends byte chunks to the thread.
    sender: Sender<Vec<u8>>,
    /// Join handle yielding the result of the closure run on the thread.
    thread: JoinHandle<Result<(), ShellError>>,
}
426
427fn spawn_tee(
428    info: StreamInfo,
429    mut eval_block: impl FnMut(PipelineData) -> Result<(), ShellError> + Send + 'static,
430) -> Result<TeeThread, ShellError> {
431    let (sender, receiver) = mpsc::channel();
432
433    let thread = thread::Builder::new()
434        .name("tee".into())
435        .spawn(move || {
436            // We use Signals::empty() here because we assume there already is a Signals on the other side
437            let stream = ByteStream::from_iter(
438                receiver.into_iter(),
439                info.span,
440                Signals::empty(),
441                info.type_,
442            );
443            eval_block(PipelineData::byte_stream(stream, info.metadata))
444        })
445        .map_err(|err| {
446            IoError::new_with_additional_context(err, info.span, None, "Could not spawn tee")
447        })?;
448
449    Ok(TeeThread { sender, thread })
450}
451
/// Properties of the byte stream being tee'd, bundled so they can be passed to the helpers.
#[derive(Clone)]
struct StreamInfo {
    /// Span used for the streams and errors created by the helpers.
    span: Span,
    /// Signals passed to `copy_with_signals` when copying.
    signals: Signals,
    /// Type given to the byte stream handed to the closure.
    type_: ByteStreamType,
    /// Pipeline metadata attached to the byte stream handed to the closure.
    metadata: Option<PipelineMetadata>,
}
459
460fn copy(src: impl Read, dest: impl Write, info: &StreamInfo) -> Result<(), ShellError> {
461    copy_with_signals(src, dest, info.span, &info.signals)?;
462    Ok(())
463}
464
/// Copies a child process pipe to `dest` on the current thread, whichever variant it is.
#[cfg(feature = "os")]
fn copy_pipe(pipe: ChildPipe, dest: impl Write, info: &StreamInfo) -> Result<(), ShellError> {
    match pipe {
        ChildPipe::Tee(tee_reader) => copy(tee_reader, dest, info),
        ChildPipe::Pipe(raw_pipe) => copy(raw_pipe, dest, info),
    }
}
472
473fn copy_on_thread(
474    src: impl Read + Send + 'static,
475    dest: impl Write + Send + 'static,
476    info: &StreamInfo,
477) -> Result<JoinHandle<Result<(), ShellError>>, ShellError> {
478    let span = info.span;
479    let signals = info.signals.clone();
480    thread::Builder::new()
481        .name("stderr copier".into())
482        .spawn(move || {
483            copy_with_signals(src, dest, span, &signals)?;
484            Ok(())
485        })
486        .map_err(|err| {
487            IoError::new_with_additional_context(err, span, None, "Could not spawn stderr copier")
488                .into()
489        })
490}
491
/// Copies a child process pipe to `dest` on a separate thread, whichever variant it is.
#[cfg(feature = "os")]
fn copy_pipe_on_thread(
    pipe: ChildPipe,
    dest: impl Write + Send + 'static,
    info: &StreamInfo,
) -> Result<JoinHandle<Result<(), ShellError>>, ShellError> {
    match pipe {
        ChildPipe::Tee(tee_reader) => copy_on_thread(tee_reader, dest, info),
        ChildPipe::Pipe(raw_pipe) => copy_on_thread(raw_pipe, dest, info),
    }
}
503
/// Every value must come out of the returned iterator and also reach the other thread.
#[test]
fn tee_copies_values_to_other_thread_and_passes_them_through() {
    let expected_values = vec![1, 2, 3, 4];

    // The tee thread forwards whatever it receives so the test can inspect it afterwards.
    let (forward_tx, forward_rx) = mpsc::channel();

    let teed = tee(expected_values.clone().into_iter(), move |cloned| {
        for val in cloned {
            let _ = forward_tx.send(val);
        }
        Ok(())
    })
    .expect("io error");

    let passed_through: Vec<i32> = teed
        .collect::<Result<_, ShellError>>()
        .expect("should not produce error");
    assert_eq!(expected_values, passed_through);

    let seen_on_thread: Vec<_> = forward_rx.into_iter().collect();
    assert_eq!(expected_values, seen_on_thread);
}
526
/// An error from the other thread must show up before the slow input has been fully consumed.
#[test]
fn tee_forwards_errors_back_immediately() {
    use std::time::Duration;

    let slow_input = (0..100).inspect(|_| std::thread::sleep(Duration::from_millis(1)));
    let results = tee(slow_input, |_| {
        let failure = IoError::new_with_additional_context(
            shell_error::io::ErrorKind::from_std(std::io::ErrorKind::Other),
            Span::test_data(),
            None,
            "test",
        );
        Err(ShellError::Io(failure))
    })
    .expect("io error");

    for result in results {
        match result {
            // should not make it to the end
            Ok(val) => assert!(val < 99, "the error did not come early enough"),
            // got the error
            Err(_) => return,
        }
    }
    panic!("never received the error");
}
551
/// Exhausting the iterator must block until the other thread is done and yield its error last.
#[test]
fn tee_waits_for_the_other_thread() {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };
    use std::time::Duration;

    let waited = Arc::new(AtomicBool::new(false));
    let thread_flag = Arc::clone(&waited);

    let results = tee(0..100, move |_| {
        std::thread::sleep(Duration::from_millis(10));
        thread_flag.store(true, Ordering::Relaxed);
        let failure = IoError::new_with_additional_context(
            shell_error::io::ErrorKind::from_std(std::io::ErrorKind::Other),
            Span::test_data(),
            None,
            "test",
        );
        Err(ShellError::Io(failure))
    })
    .expect("io error");

    let final_item = results.last();
    assert!(waited.load(Ordering::Relaxed), "failed to wait");
    assert!(
        final_item.is_some_and(|res| res.is_err()),
        "failed to return error from wait"
    );
}
578}