1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
62#![forbid(unsafe_code)]
63
64use std::io::{self, Write};
65#[cfg(feature = "async-runtime-tokio")]
66use std::pin::{pin, Pin};
67#[cfg(feature = "async-runtime-tokio")]
68use std::task::{Context, Poll};
69
70use chksum_core::Hash;
71#[cfg(feature = "async-runtime-tokio")]
72use tokio::io::{AsyncWrite, AsyncWriteExt};
73
74pub fn new<H>(inner: impl Write) -> Writer<impl Write, H>
76where
77 H: Hash,
78{
79 Writer::new(inner)
80}
81
82pub fn with_hash<H>(inner: impl Write, hash: H) -> Writer<impl Write, H>
84where
85 H: Hash,
86{
87 Writer::with_hash(inner, hash)
88}
89
/// Wraps the Tokio writer `inner` in an [`AsyncWriter`] whose hash state starts
/// from `H::default()`.
#[cfg(feature = "async-runtime-tokio")]
pub fn async_new<H: Hash>(inner: impl AsyncWrite) -> AsyncWriter<impl AsyncWrite, H> {
    AsyncWriter::<_, H>::new(inner)
}
98
/// Wraps the Tokio writer `inner` in an [`AsyncWriter`] that continues from the
/// caller-supplied `hash` state.
#[cfg(feature = "async-runtime-tokio")]
pub fn async_with_hash<H: Hash>(inner: impl AsyncWrite, hash: H) -> AsyncWriter<impl AsyncWrite, H> {
    AsyncWriter::with_hash(inner, hash)
}
107
108#[derive(Clone, Debug, PartialEq, Eq)]
110pub struct Writer<W, H>
111where
112 W: Write,
113 H: Hash,
114{
115 inner: W,
116 hash: H,
117}
118
119impl<W, H> Writer<W, H>
120where
121 W: Write,
122 H: Hash,
123{
124 pub fn new(inner: W) -> Self {
126 let hash = H::default();
127 Self::with_hash(inner, hash)
128 }
129
130 #[must_use]
132 pub const fn with_hash(inner: W, hash: H) -> Self {
133 Self { inner, hash }
134 }
135
136 #[must_use]
138 pub fn into_inner(self) -> W {
139 let Self { inner, .. } = self;
140 inner
141 }
142
143 #[must_use]
145 pub fn digest(&self) -> H::Digest {
146 self.hash.digest()
147 }
148}
149
150impl<W, H> Write for Writer<W, H>
151where
152 W: Write,
153 H: Hash,
154{
155 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
156 let n = self.inner.write(buf)?;
157 self.hash.update(&buf[..n]);
158 Ok(n)
159 }
160
161 fn flush(&mut self) -> io::Result<()> {
162 self.inner.flush()
163 }
164}
165
#[cfg(feature = "async-runtime-tokio")]
/// Asynchronous counterpart of [`Writer`]: wraps a Tokio writer and hashes every
/// byte the inner writer accepts.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AsyncWriter<W: AsyncWriteExt, H: Hash> {
    /// Destination that actually receives the written bytes.
    inner: W,
    /// Running hash state, updated with each chunk the inner writer accepts.
    hash: H,
}
177
#[cfg(feature = "async-runtime-tokio")]
impl<W, H> AsyncWriter<W, H>
where
    W: AsyncWriteExt,
    H: Hash,
{
    /// Creates an `AsyncWriter` that forwards to `inner` and hashes with a fresh
    /// `H::default()` state.
    ///
    /// Marked `#[must_use]` like the other constructors and accessors: discarding
    /// the wrapper also discards the hash being computed.
    #[must_use]
    pub fn new(inner: W) -> Self {
        let hash = H::default();
        Self::with_hash(inner, hash)
    }

    /// Creates an `AsyncWriter` that forwards to `inner` and continues from the
    /// given `hash` state.
    #[must_use]
    pub const fn with_hash(inner: W, hash: H) -> Self {
        Self { inner, hash }
    }

    /// Consumes the wrapper and returns the inner writer.
    ///
    /// The hash state is dropped; call [`AsyncWriter::digest`] first if it is needed.
    #[must_use]
    pub fn into_inner(self) -> W {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns the digest of the current hash state: any initial state given to
    /// [`AsyncWriter::with_hash`] plus every byte the inner writer has accepted so far.
    #[must_use]
    pub fn digest(&self) -> H::Digest {
        self.hash.digest()
    }
}
209
#[cfg(feature = "async-runtime-tokio")]
impl<W, H> AsyncWrite for AsyncWriter<W, H>
where
    W: AsyncWrite + Unpin,
    H: Hash + Unpin,
{
    /// Polls a write of `buf` on the inner writer. On `Ready(Ok(n))`, hashes
    /// exactly the `n`-byte prefix it accepted. Pending and error results are
    /// passed through without touching the hash.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
        let Self { inner, hash } = self.get_mut();
        match pin!(inner).poll_write(cx, buf) {
            Poll::Ready(Ok(n)) => {
                hash.update(&buf[..n]);
                Poll::Ready(Ok(n))
            },
            poll => poll,
        }
    }

    /// Forwards a vectored write to the inner writer. On `Ready(Ok(n))`, hashes
    /// the first `n` bytes across `bufs` in order.
    ///
    /// Without this override, Tokio's default would only ever pass the first
    /// non-empty buffer down, defeating the inner writer's vectored I/O support.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        let Self { inner, hash } = self.get_mut();
        match pin!(inner).poll_write_vectored(cx, bufs) {
            Poll::Ready(Ok(n)) => {
                let mut remaining = n;
                for buf in bufs {
                    if remaining == 0 {
                        break;
                    }
                    // A partial write may stop mid-buffer; hash only that prefix.
                    let take = remaining.min(buf.len());
                    hash.update(&buf[..take]);
                    remaining -= take;
                }
                Poll::Ready(Ok(n))
            },
            poll => poll,
        }
    }

    /// Reports vectored-write support exactly as the inner writer does, so callers
    /// pick the same strategy they would without the hashing wrapper.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    /// Polls a flush of the inner writer; the hash state is unaffected.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let Self { inner, .. } = self.get_mut();
        pin!(inner).poll_flush(cx)
    }

    /// Polls a shutdown of the inner writer; the hash state is unaffected.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let Self { inner, .. } = self.get_mut();
        pin!(inner).poll_shutdown(cx)
    }
}