Skip to main content

pipe_io/
sink.rs

1//! The [`Sink`] trait and built-in sink adapters.
2
3use crate::error::StageError;
4use crate::source::Infallible;
5
6#[cfg(feature = "std")]
7use alloc::vec::Vec;
8use core::marker::PhantomData;
9
10/// Terminal consumer at the tail of a pipeline.
11pub trait Sink {
12    /// Type of item this sink accepts.
13    type Item;
14    /// Error type the sink can return.
15    type Error: StageError;
16
17    /// Write a single item.
18    ///
19    /// # Errors
20    ///
21    /// Returns `Err(Self::Error)` on write failure. The driver wraps
22    /// this in [`crate::Error::Sink`].
23    fn write(&mut self, item: Self::Item) -> Result<(), Self::Error>;
24
25    /// Flush any buffered output. Default impl does nothing.
26    ///
27    /// # Errors
28    ///
29    /// Returns `Err(Self::Error)` on flush failure.
30    fn flush(&mut self) -> Result<(), Self::Error> {
31        Ok(())
32    }
33
34    /// Release any resources. Default impl does nothing.
35    ///
36    /// # Errors
37    ///
38    /// Returns `Err(Self::Error)` on shutdown failure.
39    fn close(&mut self) -> Result<(), Self::Error> {
40        Ok(())
41    }
42}
43
/// Sink that discards every item.
///
/// `NullSink<T>` stores no `T`; it only records the item type through a
/// `PhantomData<fn(T)>` marker. For that reason `Debug`, `Clone` and `Copy`
/// are implemented by hand below: a `#[derive]` would add `T: Debug` /
/// `T: Clone` bounds even though no `T` is ever held.
///
/// # Example
///
/// ```
/// use pipe_io::sink::{NullSink, Sink};
///
/// let mut s: NullSink<u32> = NullSink::new();
/// s.write(42).unwrap();
/// ```
pub struct NullSink<T> {
    // `fn(T)` is `Send + Sync + Copy` for every `T`, so the marker never
    // restricts which threads the sink can move between.
    _marker: PhantomData<fn(T)>,
}

impl<T> NullSink<T> {
    /// Construct a new null sink.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T> Default for NullSink<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> core::fmt::Debug for NullSink<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("NullSink").finish()
    }
}

impl<T> Clone for NullSink<T> {
    fn clone(&self) -> Self {
        *self
    }
}

// Sound for any `T`: the only field is a zero-sized `PhantomData<fn(T)>`.
impl<T> Copy for NullSink<T> {}
73
74impl<T: 'static> Sink for NullSink<T> {
75    type Item = T;
76    type Error = Infallible;
77
78    fn write(&mut self, _item: Self::Item) -> Result<(), Self::Error> {
79        Ok(())
80    }
81}
82
/// Sink that pushes every item into an internal `Vec`. Useful in tests.
///
/// The collected items live behind a [`SharedHandle`] returned by
/// [`VecSink::handle`]. The storage is an `Arc<Mutex<Vec<T>>>`, so a
/// handle is `Send + Sync` whenever `T: Send`. It also outlives the sink:
/// callers can drop the sink and still read the result.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct VecSink<T> {
    inner: SharedHandle<T>,
}

#[cfg(feature = "std")]
impl<T> VecSink<T> {
    /// Construct a new vec sink with empty storage.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: SharedHandle::new(),
        }
    }

    /// Return a cloneable handle to the underlying storage. Items written
    /// to the sink appear in every handle, including handles obtained
    /// before the write.
    #[must_use]
    pub fn handle(&self) -> SharedHandle<T> {
        self.inner.clone()
    }
}

#[cfg(feature = "std")]
impl<T> Default for VecSink<T> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "std")]
impl<T> core::fmt::Debug for VecSink<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The buffered items are deliberately not printed: reading them
        // would lock the shared mutex, which `Debug` should not do.
        f.debug_struct("VecSink").finish_non_exhaustive()
    }
}
119
#[cfg(feature = "std")]
impl<T> Sink for VecSink<T>
where
    T: Send + 'static,
{
    type Item = T;
    type Error = Infallible;

    /// Append `item` to the shared storage. This never returns `Err`.
    fn write(&mut self, item: T) -> Result<(), Infallible> {
        let storage = &self.inner;
        storage.push(item);
        Ok(())
    }
}
130
/// Sink adapter over a closure.
///
/// Each [`Sink::write`] call is forwarded to the wrapped closure. The
/// closure's `Result` becomes the result of the write.
///
/// # Example
///
/// ```
/// use pipe_io::sink::{FnSink, Sink};
///
/// let mut total = 0u64;
/// {
///     let mut s = FnSink::new(|item: u64| -> Result<(), &'static str> {
///         total += item;
///         Ok(())
///     });
///     s.write(1).unwrap();
///     s.write(2).unwrap();
///     s.write(3).unwrap();
/// }
/// assert_eq!(total, 6);
/// ```
pub struct FnSink<F, T, E> {
    func: F,
    // `fn(T) -> Result<(), E>` records the item and error types without
    // storing them, so the sink's auto traits depend only on `F`.
    _marker: PhantomData<fn(T) -> Result<(), E>>,
}

impl<F, T, E> FnSink<F, T, E>
where
    F: FnMut(T) -> Result<(), E>,
{
    /// Wrap a closure into a sink.
    #[must_use]
    pub fn new(func: F) -> Self {
        Self {
            func,
            _marker: PhantomData,
        }
    }
}

impl<F, T, E> core::fmt::Debug for FnSink<F, T, E> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Closures are generally not `Debug`, so the wrapped function is
        // elided rather than requiring `F: Debug`.
        f.debug_struct("FnSink").finish_non_exhaustive()
    }
}
167
168impl<F, T, E> Sink for FnSink<F, T, E>
169where
170    F: FnMut(T) -> Result<(), E>,
171    E: StageError,
172    T: 'static,
173{
174    type Item = T;
175    type Error = E;
176
177    fn write(&mut self, item: Self::Item) -> Result<(), Self::Error> {
178        (self.func)(item)
179    }
180}
181
/// Sink adapter over an [`std::sync::mpsc::SyncSender`] (bounded).
///
/// Writes are forwarded with `SyncSender::send`, so a write blocks while
/// the channel's buffer is full. Cloning the sink clones the sender, which
/// gives another producer for the same channel.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct ChannelSink<T> {
    tx: std::sync::mpsc::SyncSender<T>,
}

#[cfg(feature = "std")]
impl<T> ChannelSink<T> {
    /// Wrap a sender into a sink.
    #[must_use]
    pub fn new(tx: std::sync::mpsc::SyncSender<T>) -> Self {
        Self { tx }
    }
}

// Hand-written so that cloning does not require `T: Clone`; only the
// sender is duplicated, never an item.
#[cfg(feature = "std")]
impl<T> Clone for ChannelSink<T> {
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
        }
    }
}

// Hand-written so that `T` need not be `Debug`; the sender carries no
// printable state.
#[cfg(feature = "std")]
impl<T> core::fmt::Debug for ChannelSink<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("ChannelSink").finish_non_exhaustive()
    }
}
197
#[cfg(feature = "std")]
impl<T> Sink for ChannelSink<T>
where
    T: Send + 'static,
{
    type Item = T;
    type Error = ChannelSinkError;

    /// Send `item` down the bounded channel.
    ///
    /// Fails with [`ChannelSinkError::Disconnected`] once the receiver has
    /// been dropped; the rejected item is discarded.
    fn write(&mut self, item: T) -> Result<(), ChannelSinkError> {
        match self.tx.send(item) {
            Ok(()) => Ok(()),
            Err(_) => Err(ChannelSinkError::Disconnected),
        }
    }
}
209
/// Error returned by [`ChannelSink`] when the receiver has been dropped.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChannelSinkError {
    /// The receiving end of the channel has been dropped.
    Disconnected,
}

#[cfg(feature = "std")]
impl core::fmt::Display for ChannelSinkError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            ChannelSinkError::Disconnected => "channel sink disconnected",
        };
        f.write_str(message)
    }
}

// No underlying cause is carried, so the default `source()` of `None` is used.
#[cfg(feature = "std")]
impl std::error::Error for ChannelSinkError {}
230
/// Sink that line-writes `String` items into any [`std::io::Write`].
///
/// Each item is followed by `\n`. Upstream stages can convert via
/// `.map(|x| x.to_string())` for any [`core::fmt::Display`] type.
///
/// The sink adds no buffering of its own: items go straight to the wrapped
/// writer. Wrap unbuffered writers such as files or sockets in a
/// [`std::io::BufWriter`] so the small per-item writes are batched.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
#[derive(Debug)]
pub struct WriterSink<W> {
    writer: W,
}

#[cfg(feature = "std")]
impl<W: std::io::Write> WriterSink<W> {
    /// Wrap a writer into a sink.
    #[must_use]
    pub fn new(writer: W) -> Self {
        Self { writer }
    }

    /// Consume the sink and return the wrapped writer.
    ///
    /// Nothing is flushed on the way out. Call [`Sink::flush`] first if
    /// the writer buffers output.
    #[must_use]
    pub fn into_inner(self) -> W {
        self.writer
    }
}
252
#[cfg(feature = "std")]
impl<W> Sink for WriterSink<W>
where
    W: std::io::Write + Send + 'static,
{
    type Item = alloc::string::String;
    type Error = std::io::Error;

    /// Write `item` to the wrapped writer, terminated by a newline.
    fn write(&mut self, item: alloc::string::String) -> std::io::Result<()> {
        let line = item;
        writeln!(self.writer, "{line}")
    }

    /// Flush the wrapped writer.
    fn flush(&mut self) -> std::io::Result<()> {
        <W as std::io::Write>::flush(&mut self.writer)
    }
}
269
270// ---------------------------------------------------------------------
// SharedHandle: thread-safe (Arc<Mutex<_>>) storage behind VecSink; std only.
272// ---------------------------------------------------------------------
273
/// Cloneable handle to the items collected by a [`VecSink`].
///
/// Every clone shares one `Arc<Mutex<Vec<T>>>`. Items pushed through the
/// owning sink are therefore visible from all clones, and draining through
/// one clone empties them all.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct SharedHandle<T> {
    inner: std::sync::Arc<std::sync::Mutex<Vec<T>>>,
}

#[cfg(feature = "std")]
impl<T> SharedHandle<T> {
    /// Create a handle over fresh, empty storage.
    fn new() -> Self {
        Self {
            inner: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
        }
    }

    /// Lock the shared storage.
    ///
    /// Panics if the mutex is poisoned, i.e. a thread panicked while
    /// holding it.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<T>> {
        self.inner.lock().expect("VecSink mutex poisoned")
    }

    /// Append one item to the shared storage.
    fn push(&self, item: T) {
        self.lock().push(item);
    }

    /// Drain the collected items, leaving the storage empty.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn take(&self) -> Vec<T> {
        core::mem::take(&mut *self.lock())
    }

    /// Number of items currently buffered.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock().len()
    }

    /// True if no items are currently buffered.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

// Hand-written so that cloning a handle never requires `T: Clone`: only
// the `Arc` is cloned, not the items.
#[cfg(feature = "std")]
impl<T> Clone for SharedHandle<T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
        }
    }
}

#[cfg(feature = "std")]
impl<T> core::fmt::Debug for SharedHandle<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The items are not printed, so `Debug` never has to lock the mutex.
        f.debug_struct("SharedHandle").finish_non_exhaustive()
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "std")]
    use alloc::vec;

    #[test]
    fn null_sink_discards() {
        let mut sink: NullSink<i32> = NullSink::new();
        for n in [1, 2] {
            sink.write(n).unwrap();
        }
        sink.flush().unwrap();
        sink.close().unwrap();
    }

    #[cfg(feature = "std")]
    #[test]
    fn vec_sink_collects() {
        let mut sink = VecSink::<i32>::new();
        let handle = sink.handle();
        for n in 1..=3 {
            sink.write(n).unwrap();
        }
        assert_eq!(handle.len(), 3);
        assert_eq!(handle.take(), vec![1, 2, 3]);
        assert!(handle.is_empty());
    }

    #[test]
    fn fn_sink_invokes_closure() {
        let mut sum = 0u32;
        {
            let mut sink: FnSink<_, u32, &'static str> = FnSink::new(|n: u32| {
                sum += n;
                Ok(())
            });
            for n in [2, 3] {
                sink.write(n).unwrap();
            }
        }
        assert_eq!(sum, 5);
    }

    #[cfg(feature = "std")]
    #[test]
    fn channel_sink_sends_until_disconnect() {
        let (tx, rx) = std::sync::mpsc::sync_channel::<i32>(4);
        let mut sink = ChannelSink::new(tx);
        sink.write(1).unwrap();
        sink.write(2).unwrap();
        drop(rx);
        let outcome = sink.write(3);
        assert!(outcome.is_err());
    }

    #[cfg(feature = "std")]
    #[test]
    fn writer_sink_line_writes() {
        use std::io::Cursor;
        let mut sink = WriterSink::<Cursor<Vec<u8>>>::new(Cursor::new(Vec::new()));
        for word in ["alpha", "beta"] {
            sink.write(word.into()).unwrap();
        }
        let bytes = sink.into_inner().into_inner();
        assert_eq!(String::from_utf8(bytes).unwrap(), "alpha\nbeta\n");
    }
}