nu_command/filters/
tee.rs

1use nu_engine::{command_prelude::*, get_eval_block_with_early_return};
2#[cfg(feature = "os")]
3use nu_protocol::process::ChildPipe;
4use nu_protocol::{
5    byte_stream::copy_with_signals, engine::Closure, report_shell_error, shell_error::io::IoError,
6    ByteStream, ByteStreamSource, OutDest, PipelineMetadata, Signals,
7};
8use std::{
9    io::{self, Read, Write},
10    sync::{
11        mpsc::{self, Sender},
12        Arc,
13    },
14    thread::{self, JoinHandle},
15};
16
/// The `tee` command: hands a copy of its input to a closure that runs on another thread, while
/// the input itself continues down the pipeline.
#[derive(Clone)]
pub struct Tee;
19
20impl Command for Tee {
21    fn name(&self) -> &str {
22        "tee"
23    }
24
25    fn description(&self) -> &str {
26        "Copy a stream to another command in parallel."
27    }
28
29    fn extra_description(&self) -> &str {
30        r#"This is useful for doing something else with a stream while still continuing to
31use it in your pipeline."#
32    }
33
34    fn signature(&self) -> Signature {
35        Signature::build("tee")
36            .input_output_type(Type::Any, Type::Any)
37            .switch(
38                "stderr",
39                "For external commands: copy the standard error stream instead.",
40                Some('e'),
41            )
42            .required(
43                "closure",
44                SyntaxShape::Closure(None),
45                "The other command to send the stream to.",
46            )
47            .category(Category::Filters)
48    }
49
50    fn examples(&self) -> Vec<Example> {
51        vec![
52            Example {
53                example: "http get http://example.org/ | tee { save example.html }",
54                description: "Save a webpage to a file while also printing it",
55                result: None,
56            },
57            Example {
58                example:
59                    "nu -c 'print -e error; print ok' | tee --stderr { save error.log } | complete",
60                description: "Save error messages from an external command to a file without \
61                    redirecting them",
62                result: None,
63            },
64            Example {
65                example: "1..100 | tee { each { print } } | math sum | wrap sum",
66                description: "Print numbers and their sum",
67                result: None,
68            },
69            Example {
70                example: "10000 | tee { 1..$in | print } | $in * 5",
71                description: "Do something with a value on another thread, while also passing through the value",
72                result: Some(Value::test_int(50000)),
73            }
74        ]
75    }
76
77    fn run(
78        &self,
79        engine_state: &EngineState,
80        stack: &mut Stack,
81        call: &Call,
82        input: PipelineData,
83    ) -> Result<PipelineData, ShellError> {
84        let head = call.head;
85        let from_io_error = IoError::factory(head, None);
86        let use_stderr = call.has_flag(engine_state, stack, "stderr")?;
87
88        let closure: Spanned<Closure> = call.req(engine_state, stack, 0)?;
89        let closure_span = closure.span;
90        let closure = closure.item;
91
92        let engine_state_arc = Arc::new(engine_state.clone());
93
94        let mut eval_block = {
95            let closure_engine_state = engine_state_arc.clone();
96            let mut closure_stack = stack
97                .captures_to_stack_preserve_out_dest(closure.captures)
98                .reset_pipes();
99            let eval_block_with_early_return = get_eval_block_with_early_return(engine_state);
100
101            move |input| {
102                let result = eval_block_with_early_return(
103                    &closure_engine_state,
104                    &mut closure_stack,
105                    closure_engine_state.get_block(closure.block_id),
106                    input,
107                );
108                // Make sure to drain any iterator produced to avoid unexpected behavior
109                result.and_then(|data| data.drain().map(|_| ()))
110            }
111        };
112
113        // Convert values that can be represented as streams into streams. Streams can pass errors
114        // through later, so if we treat string/binary/list as a stream instead, it's likely that
115        // we can get the error back to the original thread.
116        let span = input.span().unwrap_or(head);
117        let input = input
118            .try_into_stream(engine_state)
119            .unwrap_or_else(|original_input| original_input);
120
121        if let PipelineData::ByteStream(stream, metadata) = input {
122            let type_ = stream.type_();
123
124            let info = StreamInfo {
125                span,
126                signals: engine_state.signals().clone(),
127                type_,
128                metadata: metadata.clone(),
129            };
130
131            match stream.into_source() {
132                ByteStreamSource::Read(read) => {
133                    if use_stderr {
134                        return stderr_misuse(span, head);
135                    }
136
137                    let tee_thread = spawn_tee(info, eval_block)?;
138                    let tee = IoTee::new(read, tee_thread);
139
140                    Ok(PipelineData::ByteStream(
141                        ByteStream::read(tee, span, engine_state.signals().clone(), type_),
142                        metadata,
143                    ))
144                }
145                ByteStreamSource::File(file) => {
146                    if use_stderr {
147                        return stderr_misuse(span, head);
148                    }
149
150                    let tee_thread = spawn_tee(info, eval_block)?;
151                    let tee = IoTee::new(file, tee_thread);
152
153                    Ok(PipelineData::ByteStream(
154                        ByteStream::read(tee, span, engine_state.signals().clone(), type_),
155                        metadata,
156                    ))
157                }
158                #[cfg(feature = "os")]
159                ByteStreamSource::Child(mut child) => {
160                    let stderr_thread = if use_stderr {
161                        let stderr_thread = if let Some(stderr) = child.stderr.take() {
162                            let tee_thread = spawn_tee(info.clone(), eval_block)?;
163                            let tee = IoTee::new(stderr, tee_thread);
164                            match stack.stderr() {
165                                OutDest::Pipe | OutDest::PipeSeparate | OutDest::Value => {
166                                    child.stderr = Some(ChildPipe::Tee(Box::new(tee)));
167                                    Ok(None)
168                                }
169                                OutDest::Null => copy_on_thread(tee, io::sink(), &info).map(Some),
170                                OutDest::Print | OutDest::Inherit => {
171                                    copy_on_thread(tee, io::stderr(), &info).map(Some)
172                                }
173                                OutDest::File(file) => {
174                                    copy_on_thread(tee, file.clone(), &info).map(Some)
175                                }
176                            }?
177                        } else {
178                            None
179                        };
180
181                        if let Some(stdout) = child.stdout.take() {
182                            match stack.stdout() {
183                                OutDest::Pipe | OutDest::PipeSeparate | OutDest::Value => {
184                                    child.stdout = Some(stdout);
185                                    Ok(())
186                                }
187                                OutDest::Null => copy_pipe(stdout, io::sink(), &info),
188                                OutDest::Print | OutDest::Inherit => {
189                                    copy_pipe(stdout, io::stdout(), &info)
190                                }
191                                OutDest::File(file) => copy_pipe(stdout, file.as_ref(), &info),
192                            }?;
193                        }
194
195                        stderr_thread
196                    } else {
197                        let stderr_thread = if let Some(stderr) = child.stderr.take() {
198                            let info = info.clone();
199                            match stack.stderr() {
200                                OutDest::Pipe | OutDest::PipeSeparate | OutDest::Value => {
201                                    child.stderr = Some(stderr);
202                                    Ok(None)
203                                }
204                                OutDest::Null => {
205                                    copy_pipe_on_thread(stderr, io::sink(), &info).map(Some)
206                                }
207                                OutDest::Print | OutDest::Inherit => {
208                                    copy_pipe_on_thread(stderr, io::stderr(), &info).map(Some)
209                                }
210                                OutDest::File(file) => {
211                                    copy_pipe_on_thread(stderr, file.clone(), &info).map(Some)
212                                }
213                            }?
214                        } else {
215                            None
216                        };
217
218                        if let Some(stdout) = child.stdout.take() {
219                            let tee_thread = spawn_tee(info.clone(), eval_block)?;
220                            let tee = IoTee::new(stdout, tee_thread);
221                            match stack.stdout() {
222                                OutDest::Pipe | OutDest::PipeSeparate | OutDest::Value => {
223                                    child.stdout = Some(ChildPipe::Tee(Box::new(tee)));
224                                    Ok(())
225                                }
226                                OutDest::Null => copy(tee, io::sink(), &info),
227                                OutDest::Print | OutDest::Inherit => copy(tee, io::stdout(), &info),
228                                OutDest::File(file) => copy(tee, file.as_ref(), &info),
229                            }?;
230                        }
231
232                        stderr_thread
233                    };
234
235                    if child.stdout.is_some() || child.stderr.is_some() {
236                        Ok(PipelineData::ByteStream(
237                            ByteStream::child(*child, span),
238                            metadata,
239                        ))
240                    } else {
241                        if let Some(thread) = stderr_thread {
242                            thread.join().unwrap_or_else(|_| Err(panic_error()))?;
243                        }
244                        child.wait()?;
245                        Ok(PipelineData::Empty)
246                    }
247                }
248            }
249        } else {
250            if use_stderr {
251                return stderr_misuse(input.span().unwrap_or(head), head);
252            }
253
254            let metadata = input.metadata();
255            let metadata_clone = metadata.clone();
256
257            if matches!(input, PipelineData::ListStream(..)) {
258                // Only use the iterator implementation on lists / list streams. We want to be able
259                // to preserve errors as much as possible, and only the stream implementations can
260                // really do that
261                let signals = engine_state.signals().clone();
262
263                Ok(tee(input.into_iter(), move |rx| {
264                    let input = rx.into_pipeline_data_with_metadata(span, signals, metadata_clone);
265                    eval_block(input)
266                })
267                .map_err(&from_io_error)?
268                .map(move |result| result.unwrap_or_else(|err| Value::error(err, closure_span)))
269                .into_pipeline_data_with_metadata(
270                    span,
271                    engine_state.signals().clone(),
272                    metadata,
273                ))
274            } else {
275                // Otherwise, we can spawn a thread with the input value, but we have nowhere to
276                // send an error to other than just trying to print it to stderr.
277                let value = input.into_value(span)?;
278                let value_clone = value.clone();
279                tee_once(engine_state_arc, move || {
280                    eval_block(value_clone.into_pipeline_data_with_metadata(metadata_clone))
281                })
282                .map_err(&from_io_error)?;
283                Ok(value.into_pipeline_data_with_metadata(metadata))
284            }
285        }
286    }
287
288    fn pipe_redirection(&self) -> (Option<OutDest>, Option<OutDest>) {
289        (Some(OutDest::PipeSeparate), Some(OutDest::PipeSeparate))
290    }
291}
292
293fn panic_error() -> ShellError {
294    ShellError::NushellFailed {
295        msg: "A panic occurred on a thread spawned by `tee`".into(),
296    }
297}
298
299/// Copies the iterator to a channel on another thread. If an error is produced on that thread,
300/// it is embedded in the resulting iterator as an `Err` as soon as possible. When the iterator
301/// finishes, it waits for the other thread to finish, also handling any error produced at that
302/// point.
303fn tee<T>(
304    input: impl Iterator<Item = T>,
305    with_cloned_stream: impl FnOnce(mpsc::Receiver<T>) -> Result<(), ShellError> + Send + 'static,
306) -> Result<impl Iterator<Item = Result<T, ShellError>>, std::io::Error>
307where
308    T: Clone + Send + 'static,
309{
310    // For sending the values to the other thread
311    let (tx, rx) = mpsc::channel();
312
313    let mut thread = Some(
314        thread::Builder::new()
315            .name("tee".into())
316            .spawn(move || with_cloned_stream(rx))?,
317    );
318
319    let mut iter = input.into_iter();
320    let mut tx = Some(tx);
321
322    Ok(std::iter::from_fn(move || {
323        if thread.as_ref().is_some_and(|t| t.is_finished()) {
324            // Check for an error from the other thread
325            let result = thread
326                .take()
327                .expect("thread was taken early")
328                .join()
329                .unwrap_or_else(|_| Err(panic_error()));
330            if let Err(err) = result {
331                // Embed the error early
332                return Some(Err(err));
333            }
334        }
335
336        // Get a value from the iterator
337        if let Some(value) = iter.next() {
338            // Send a copy, ignoring any error if the channel is closed
339            let _ = tx.as_ref().map(|tx| tx.send(value.clone()));
340            Some(Ok(value))
341        } else {
342            // Close the channel so the stream ends for the other thread
343            drop(tx.take());
344            // Wait for the other thread, and embed any error produced
345            thread.take().and_then(|t| {
346                t.join()
347                    .unwrap_or_else(|_| Err(panic_error()))
348                    .err()
349                    .map(Err)
350            })
351        }
352    }))
353}
354
355/// "tee" for a single value. No stream handling, just spawns a thread, printing any resulting error
356fn tee_once(
357    engine_state: Arc<EngineState>,
358    on_thread: impl FnOnce() -> Result<(), ShellError> + Send + 'static,
359) -> Result<JoinHandle<()>, std::io::Error> {
360    thread::Builder::new().name("tee".into()).spawn(move || {
361        if let Err(err) = on_thread() {
362            report_shell_error(&engine_state, &err);
363        }
364    })
365}
366
367fn stderr_misuse<T>(span: Span, head: Span) -> Result<T, ShellError> {
368    Err(ShellError::UnsupportedInput {
369        msg: "--stderr can only be used on external commands".into(),
370        input: "the input to `tee` is not an external command".into(),
371        msg_span: head,
372        input_span: span,
373    })
374}
375
376struct IoTee<R: Read> {
377    reader: R,
378    sender: Option<Sender<Vec<u8>>>,
379    thread: Option<JoinHandle<Result<(), ShellError>>>,
380}
381
382impl<R: Read> IoTee<R> {
383    fn new(reader: R, tee: TeeThread) -> Self {
384        Self {
385            reader,
386            sender: Some(tee.sender),
387            thread: Some(tee.thread),
388        }
389    }
390}
391
392impl<R: Read> Read for IoTee<R> {
393    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
394        if let Some(thread) = self.thread.take() {
395            if thread.is_finished() {
396                if let Err(err) = thread.join().unwrap_or_else(|_| Err(panic_error())) {
397                    return Err(io::Error::other(err));
398                }
399            } else {
400                self.thread = Some(thread)
401            }
402        }
403        let len = self.reader.read(buf)?;
404        if len == 0 {
405            self.sender = None;
406            if let Some(thread) = self.thread.take() {
407                if let Err(err) = thread.join().unwrap_or_else(|_| Err(panic_error())) {
408                    return Err(io::Error::other(err));
409                }
410            }
411        } else if let Some(sender) = self.sender.as_mut() {
412            if sender.send(buf[..len].to_vec()).is_err() {
413                self.sender = None;
414            }
415        }
416        Ok(len)
417    }
418}
419
420struct TeeThread {
421    sender: Sender<Vec<u8>>,
422    thread: JoinHandle<Result<(), ShellError>>,
423}
424
425fn spawn_tee(
426    info: StreamInfo,
427    mut eval_block: impl FnMut(PipelineData) -> Result<(), ShellError> + Send + 'static,
428) -> Result<TeeThread, ShellError> {
429    let (sender, receiver) = mpsc::channel();
430
431    let thread = thread::Builder::new()
432        .name("tee".into())
433        .spawn(move || {
434            // We use Signals::empty() here because we assume there already is a Signals on the other side
435            let stream = ByteStream::from_iter(
436                receiver.into_iter(),
437                info.span,
438                Signals::empty(),
439                info.type_,
440            );
441            eval_block(PipelineData::ByteStream(stream, info.metadata))
442        })
443        .map_err(|err| {
444            IoError::new_with_additional_context(err.kind(), info.span, None, "Could not spawn tee")
445        })?;
446
447    Ok(TeeThread { sender, thread })
448}
449
450#[derive(Clone)]
451struct StreamInfo {
452    span: Span,
453    signals: Signals,
454    type_: ByteStreamType,
455    metadata: Option<PipelineMetadata>,
456}
457
458fn copy(src: impl Read, dest: impl Write, info: &StreamInfo) -> Result<(), ShellError> {
459    copy_with_signals(src, dest, info.span, &info.signals)?;
460    Ok(())
461}
462
/// Copies a child process pipe (plain or already tee'd) into `dest` on the current thread.
#[cfg(feature = "os")]
fn copy_pipe(pipe: ChildPipe, dest: impl Write, info: &StreamInfo) -> Result<(), ShellError> {
    match pipe {
        ChildPipe::Tee(tee) => copy(tee, dest, info),
        ChildPipe::Pipe(pipe) => copy(pipe, dest, info),
    }
}
470
471fn copy_on_thread(
472    src: impl Read + Send + 'static,
473    dest: impl Write + Send + 'static,
474    info: &StreamInfo,
475) -> Result<JoinHandle<Result<(), ShellError>>, ShellError> {
476    let span = info.span;
477    let signals = info.signals.clone();
478    thread::Builder::new()
479        .name("stderr copier".into())
480        .spawn(move || {
481            copy_with_signals(src, dest, span, &signals)?;
482            Ok(())
483        })
484        .map_err(|err| {
485            IoError::new_with_additional_context(
486                err.kind(),
487                span,
488                None,
489                "Could not spawn stderr copier",
490            )
491            .into()
492        })
493}
494
/// Copies a child process pipe (plain or already tee'd) into `dest` on a separate thread and
/// returns that thread's join handle.
#[cfg(feature = "os")]
fn copy_pipe_on_thread(
    pipe: ChildPipe,
    dest: impl Write + Send + 'static,
    info: &StreamInfo,
) -> Result<JoinHandle<Result<(), ShellError>>, ShellError> {
    match pipe {
        ChildPipe::Tee(tee) => copy_on_thread(tee, dest, info),
        ChildPipe::Pipe(pipe) => copy_on_thread(pipe, dest, info),
    }
}
506
#[test]
fn tee_copies_values_to_other_thread_and_passes_them_through() {
    let expected = vec![1, 2, 3, 4];
    let (copy_tx, copy_rx) = mpsc::channel();

    let passed_through = tee(expected.clone().into_iter(), move |values| {
        values.into_iter().for_each(|value| {
            let _ = copy_tx.send(value);
        });
        Ok(())
    })
    .expect("io error")
    .collect::<Result<Vec<i32>, ShellError>>()
    .expect("should not produce error");
    assert_eq!(expected, passed_through);

    // The sender was dropped when the other thread finished, so this terminates.
    let copied: Vec<_> = copy_rx.into_iter().collect();
    assert_eq!(expected, copied);
}
529
#[test]
fn tee_forwards_errors_back_immediately() {
    use std::time::Duration;

    let slow_input = (0..100).inspect(|_| std::thread::sleep(Duration::from_millis(1)));
    let results = tee(slow_input, |_| {
        Err(ShellError::Io(IoError::new_with_additional_context(
            std::io::ErrorKind::Other,
            Span::test_data(),
            None,
            "test",
        )))
    })
    .expect("io error");

    for result in results {
        match result {
            // should not make it to the end
            Ok(val) => assert!(val < 99, "the error did not come early enough"),
            // got the error
            Err(_) => return,
        }
    }
    panic!("never received the error");
}
554
#[test]
fn tee_waits_for_the_other_thread() {
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };
    use std::time::Duration;

    let finished = Arc::new(AtomicBool::new(false));
    let results = {
        let finished = Arc::clone(&finished);
        tee(0..100, move |_| {
            std::thread::sleep(Duration::from_millis(10));
            finished.store(true, Ordering::Relaxed);
            Err(ShellError::Io(IoError::new_with_additional_context(
                std::io::ErrorKind::Other,
                Span::test_data(),
                None,
                "test",
            )))
        })
        .expect("io error")
    };

    let last = results.last();
    assert!(finished.load(Ordering::Relaxed), "failed to wait");
    assert!(
        matches!(last, Some(Err(_))),
        "failed to return error from wait"
    );
}