1use futures::FutureExt;
2use serde::{Deserialize, Serialize};
3use std::{error::Error, fmt, marker::PhantomData, sync::Mutex};
4
5use super::{
6 super::{
7 RemoteSendError, SendErrorExt,
8 base::{self, PortDeserializer, PortSerializer},
9 },
10 Receiver, Ref,
11 receiver::RecvError,
12};
13use crate::{RemoteSend, chmux, codec};
14
/// Error returned when sending a value through a watch [`Sender`] fails.
///
/// Besides a locally closed channel, the variants carry errors reported by the
/// task that forwards values to a remote endpoint (converted from
/// `RemoteSendError`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SendError {
    /// The channel is closed (displayed as "channel is closed").
    Closed,
    /// Sending to the remote endpoint failed; wraps the underlying send error kind.
    RemoteSend(base::SendErrorKind),
    /// Connecting the chmux port for remote forwarding failed.
    RemoteConnect(chmux::ConnectError),
    /// Listening for / accepting the chmux port for remote forwarding failed.
    RemoteListen(chmux::ListenerError),
    /// Forwarding failed (displayed as "forwarding error").
    RemoteForward,
}
29
30impl SendError {
31 pub fn is_closed(&self) -> bool {
33 matches!(self, Self::Closed)
34 }
35
36 pub fn is_disconnected(&self) -> bool {
38 !matches!(self, Self::RemoteSend(base::SendErrorKind::Serialize(_)))
39 }
40
41 pub fn is_final(&self) -> bool {
43 match self {
44 Self::RemoteSend(err) => err.is_final(),
45 Self::Closed | Self::RemoteConnect(_) | Self::RemoteListen(_) | Self::RemoteForward => true,
46 }
47 }
48
49 pub fn is_item_specific(&self) -> bool {
51 matches!(self, Self::RemoteSend(err) if err.is_item_specific())
52 }
53}
54
55impl SendErrorExt for SendError {
56 fn is_closed(&self) -> bool {
57 self.is_closed()
58 }
59
60 fn is_disconnected(&self) -> bool {
61 self.is_disconnected()
62 }
63
64 fn is_final(&self) -> bool {
65 self.is_final()
66 }
67
68 fn is_item_specific(&self) -> bool {
69 self.is_item_specific()
70 }
71}
72
73impl fmt::Display for SendError {
74 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
75 match self {
76 Self::Closed => write!(f, "channel is closed"),
77 Self::RemoteSend(err) => write!(f, "send error: {err}"),
78 Self::RemoteConnect(err) => write!(f, "connect error: {err}"),
79 Self::RemoteListen(err) => write!(f, "listen error: {err}"),
80 Self::RemoteForward => write!(f, "forwarding error"),
81 }
82 }
83}
84
// `SendError` is a standard error type; its message comes from the `Display` impl.
impl Error for SendError {}
86
87impl From<RemoteSendError> for SendError {
88 fn from(err: RemoteSendError) -> Self {
89 match err {
90 RemoteSendError::Send(err) => Self::RemoteSend(err),
91 RemoteSendError::Connect(err) => Self::RemoteConnect(err),
92 RemoteSendError::Listen(err) => Self::RemoteListen(err),
93 RemoteSendError::Forward => Self::RemoteForward,
94 RemoteSendError::Closed => Self::Closed,
95 }
96 }
97}
98
/// Sending half of a watch channel that can itself be serialized and sent to
/// a remote endpoint.
pub struct Sender<T, Codec = codec::Default> {
    // Channel state. It is `None` only after `Drop` has moved it to a successor.
    inner: Option<SenderInner<T, Codec>>,
    // Installed by `Serialize`. On drop, `inner` is sent through it to the
    // task that forwards values over the remote port.
    successor_tx: Mutex<Option<tokio::sync::oneshot::Sender<SenderInner<T, Codec>>>>,
}
106
107impl<T, Codec> fmt::Debug for Sender<T, Codec> {
108 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
109 f.debug_struct("Sender").finish()
110 }
111}
112
/// State owned by a [`Sender`]. It is handed to the forwarding task when a
/// serialized sender is dropped.
pub(crate) struct SenderInner<T, Codec> {
    // Local watch channel holding the current value.
    tx: tokio::sync::watch::Sender<Result<T, RecvError>>,
    // Error queue sender, cloned into every receiver created by `subscribe`.
    remote_send_err_tx: tokio::sync::mpsc::UnboundedSender<RemoteSendError>,
    // Error queue receiver, polled by `update_error` with `try_recv`.
    remote_send_err_rx: Mutex<tokio::sync::mpsc::UnboundedReceiver<RemoteSendError>>,
    // The error currently reported by `error()`, until `clear_error` resets it.
    current_err: Mutex<Option<RemoteSendError>>,
    // Maximum item size passed on to the remote forwarding tasks.
    max_item_size: usize,
    _codec: PhantomData<Codec>,
}
121
/// Serialized form of a [`Sender`]. The field order defines the wire format.
#[derive(Serialize, Deserialize)]
pub(crate) struct TransportedSender<T, Codec> {
    // Port number obtained from `PortSerializer::connect` on the sending side.
    port: u32,
    // Value of the watch channel at serialization time; it must be `Ok` when deserialized.
    data: Result<T, RecvError>,
    codec: PhantomData<Codec>,
    // Maximum item size. If the field is absent, `default_max_item_size`
    // (no limit) is used.
    #[serde(default = "default_max_item_size")]
    max_item_size: u64,
}
135
/// Serde default for `TransportedSender::max_item_size` when the field is
/// missing: the largest `u64`, i.e. effectively no size limit.
const fn default_max_item_size() -> u64 {
    u64::MAX
}
139
140impl<T, Codec> Sender<T, Codec>
141where
142 T: Send + 'static,
143{
144 pub(crate) fn new(
146 tx: tokio::sync::watch::Sender<Result<T, RecvError>>,
147 remote_send_err_tx: tokio::sync::mpsc::UnboundedSender<RemoteSendError>,
148 remote_send_err_rx: tokio::sync::mpsc::UnboundedReceiver<RemoteSendError>, max_item_size: usize,
149 ) -> Self {
150 let inner = SenderInner {
151 tx,
152 remote_send_err_tx,
153 remote_send_err_rx: Mutex::new(remote_send_err_rx),
154 current_err: Mutex::new(None),
155 max_item_size,
156 _codec: PhantomData,
157 };
158 Self { inner: Some(inner), successor_tx: Mutex::new(None) }
159 }
160
161 pub fn send(&self, value: T) -> Result<(), SendError> {
170 match self.inner.as_ref().unwrap().tx.send(Ok(value)) {
171 Ok(()) => Ok(()),
172 Err(_) => match self.error() {
173 Some(err) => Err(err),
174 None => Err(SendError::Closed),
175 },
176 }
177 }
178
179 pub fn send_modify<F>(&self, func: F)
187 where
188 F: FnOnce(&mut T),
189 {
190 self.inner.as_ref().unwrap().tx.send_modify(move |v| func(v.as_mut().unwrap()))
191 }
192
193 pub fn send_replace(&self, value: T) -> T {
199 self.inner.as_ref().unwrap().tx.send_replace(Ok(value)).unwrap()
200 }
201
202 pub fn borrow(&self) -> Ref<'_, T> {
204 Ref(self.inner.as_ref().unwrap().tx.borrow())
205 }
206
207 pub async fn closed(&self) {
209 self.inner.as_ref().unwrap().tx.closed().await
210 }
211
212 pub fn is_closed(&self) -> bool {
214 self.inner.as_ref().unwrap().tx.is_closed()
215 }
216
217 pub fn subscribe(&self) -> Receiver<T, Codec> {
219 let inner = self.inner.as_ref().unwrap();
220 Receiver::new(inner.tx.subscribe(), inner.remote_send_err_tx.clone(), None)
221 }
222
223 fn update_error(&self) {
224 let inner = self.inner.as_ref().unwrap();
225 let mut current_err = inner.current_err.lock().unwrap();
226 if current_err.is_some() {
227 return;
228 }
229
230 let mut remote_send_err_rx = inner.remote_send_err_rx.lock().unwrap();
231 if let Ok(err) = remote_send_err_rx.try_recv() {
232 *current_err = Some(err);
233 }
234 }
235
236 pub fn error(&self) -> Option<SendError> {
242 self.update_error();
243
244 let inner = self.inner.as_ref().unwrap();
245 let current_err = inner.current_err.lock().unwrap();
246 current_err.clone().map(|err| err.into())
247 }
248
249 pub fn clear_error(&mut self) {
251 self.update_error();
252
253 let inner = self.inner.as_ref().unwrap();
254 let mut current_err = inner.current_err.lock().unwrap();
255 *current_err = None;
256 }
257
258 pub fn check(&mut self) -> Result<(), SendError> {
271 while let Some(err) = self.error() {
272 if err.is_item_specific() {
273 return Err(err);
274 }
275 self.clear_error();
276 }
277 Ok(())
278 }
279
280 pub fn max_item_size(&self) -> usize {
282 self.inner.as_ref().unwrap().max_item_size
283 }
284
285 pub fn set_max_item_size(&mut self, max_item_size: usize) {
287 self.inner.as_mut().unwrap().max_item_size = max_item_size;
288 }
289}
290
291impl<T, Codec> Drop for Sender<T, Codec> {
292 fn drop(&mut self) {
293 if let Some(successor_tx) = self.successor_tx.lock().unwrap().take() {
294 let _ = successor_tx.send(self.inner.take().unwrap());
295 }
296 }
297}
298
299impl<T, Codec> Serialize for Sender<T, Codec>
300where
301 T: RemoteSend + Sync + Clone,
302 Codec: codec::Codec,
303{
304 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
306 where
307 S: serde::Serializer,
308 {
309 let max_item_size = self.max_item_size();
310
311 let (successor_tx, successor_rx) = tokio::sync::oneshot::channel();
313 *self.successor_tx.lock().unwrap() = Some(successor_tx);
314
315 let port = PortSerializer::connect(move |connect| {
316 async move {
317 let SenderInner { tx, remote_send_err_rx, current_err, .. } = match successor_rx.await {
319 Ok(inner) => inner,
320 Err(_) => return,
321 };
322 let remote_send_err_rx = remote_send_err_rx.into_inner().unwrap();
323 let current_err = current_err.into_inner().unwrap();
324
325 let (raw_tx, raw_rx) = match connect.await {
327 Ok(tx_rx) => tx_rx,
328 Err(err) => {
329 let _ = tx.send(Err(RecvError::RemoteConnect(err)));
330 return;
331 }
332 };
333
334 super::recv_impl::<T, Codec>(tx, raw_tx, raw_rx, remote_send_err_rx, current_err, max_item_size)
335 .await;
336 }
337 .boxed()
338 })?;
339
340 let data = self.inner.as_ref().unwrap().tx.borrow().clone();
342 let transported = TransportedSender::<T, Codec> {
343 port,
344 data,
345 max_item_size: max_item_size.try_into().unwrap_or(u64::MAX),
346 codec: PhantomData,
347 };
348 transported.serialize(serializer)
349 }
350}
351
352impl<'de, T, Codec> Deserialize<'de> for Sender<T, Codec>
353where
354 T: RemoteSend + Sync + Clone,
355 Codec: codec::Codec,
356{
357 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
359 where
360 D: serde::Deserializer<'de>,
361 {
362 let TransportedSender { port, data, max_item_size, .. } =
364 TransportedSender::<T, Codec>::deserialize(deserializer)?;
365 let max_item_size = usize::try_from(max_item_size).unwrap_or(usize::MAX);
366 if data.is_err() {
367 return Err(serde::de::Error::custom("received watch data with error"));
368 }
369
370 let (tx, rx) = tokio::sync::watch::channel(data);
372 let (remote_send_err_tx, remote_send_err_rx) = tokio::sync::mpsc::unbounded_channel();
373 let remote_send_err_tx2 = remote_send_err_tx.clone();
374
375 PortDeserializer::accept(port, move |local_port, request| {
377 async move {
378 let (raw_tx, raw_rx) = match request.accept_from(local_port).await {
380 Ok(tx_rx) => tx_rx,
381 Err(err) => {
382 let _ = remote_send_err_tx.send(RemoteSendError::Listen(err));
383 return;
384 }
385 };
386
387 super::send_impl::<T, Codec>(rx, raw_tx, raw_rx, remote_send_err_tx, max_item_size).await;
388 }
389 .boxed()
390 })?;
391
392 Ok(Self::new(tx, remote_send_err_tx2, remote_send_err_rx, max_item_size))
393 }
394}