snocat 0.7.0

Streaming Network Overlay Connection Arbitration Tunnel
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license OR Apache 2.0
use std::{fmt::Debug, sync::Arc};

use futures::{
  future::{BoxFuture, FutureExt},
  Future,
};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing_futures::Instrument;

use crate::util::tunnel_stream::TunnelStream;

use super::{traits::ServiceRegistry, tunnel::Tunnel, RouteAddress, Service, ServiceError};

/// Identifies the SNOCAT protocol over a stream.
///
/// Both peers send these four bytes first during negotiation; they spell "NYAN" in UTF-8.
pub const SNOCAT_NEGOTIATION_MAGIC: &[u8; 4] = b"NYAN";

/// Failure modes of the SNOCAT negotiation handshake.
///
/// The type parameter `ApplicationError` is the payload carried by the `ApplicationError`
/// and `FatalError` variants; `map_err` / `err_into` convert that payload to another type.
#[derive(thiserror::Error, Debug)]
pub enum NegotiationError<ApplicationError> {
  /// Reading from the stream failed. In this module: the remote's protocol version byte,
  /// or (client side) the remote's accept/refuse status byte, could not be read.
  #[error("Stream read failed")]
  ReadError,
  /// Writing to, or flushing, the stream failed.
  #[error("Stream write failed")]
  WriteError,
  /// The remote did not follow the protocol: its magic could not be read or did not match,
  /// or (service side) the address frame could not be read, was not UTF-8, or did not parse.
  #[error("Protocol violated by remote")]
  ProtocolViolation,
  /// The requested address was refused: client side, the remote answered with a non-zero
  /// status byte; service side, no registered service accepted the address.
  #[error("Protocol refused")]
  Refused,
  /// The remote advertised a negotiation protocol version other than 0, the only one
  /// supported by this module.
  #[error("Protocol version not supported")]
  UnsupportedProtocolVersion,
  /// Not produced by the code in this module.
  /// NOTE(review): presumably reserved for service-level version mismatches — confirm.
  #[error("Service version not supported")]
  UnsupportedServiceVersion,
  /// Wraps an application-defined error.
  /// NOTE(review): not constructed in this module; confirm when callers should prefer this
  /// over `FatalError`.
  #[error("Negotiation application error: {0:?}")]
  ApplicationError(ApplicationError),
  /// Wraps an application-defined error reported as fatal.
  /// NOTE(review): not constructed in this module; the exact "fatal" semantics are up to callers.
  #[error("Negotiation fatal error: {0:?}")]
  FatalError(ApplicationError),
}

impl<ApplicationError> NegotiationError<ApplicationError> {
  /// Converts the application-error payload with `f`, leaving every other variant as-is.
  ///
  /// Only `ApplicationError` and `FatalError` carry a payload, so `f` is called at most once
  /// and only for those two variants; they keep their variant after conversion.
  pub fn map_err<F, TErr>(self, f: F) -> NegotiationError<TErr>
  where
    F: FnOnce(ApplicationError) -> TErr,
  {
    match self {
      NegotiationError::ReadError => NegotiationError::ReadError,
      NegotiationError::WriteError => NegotiationError::WriteError,
      NegotiationError::ProtocolViolation => NegotiationError::ProtocolViolation,
      NegotiationError::Refused => NegotiationError::Refused,
      NegotiationError::UnsupportedProtocolVersion => NegotiationError::UnsupportedProtocolVersion,
      NegotiationError::UnsupportedServiceVersion => NegotiationError::UnsupportedServiceVersion,
      NegotiationError::ApplicationError(e) => NegotiationError::ApplicationError(f(e)),
      NegotiationError::FatalError(e) => NegotiationError::FatalError(f(e)),
    }
  }

  /// Converts the application-error payload via `From`, leaving every other variant as-is.
  pub fn err_into<TNewErr: From<ApplicationError>>(self) -> NegotiationError<TNewErr> {
    self.map_err(TNewErr::from)
  }
}

impl<SourceError: Into<OutError>, OutError> From<NegotiationError<SourceError>>
  for ServiceError<OutError>
{
  /// Folds a negotiation failure into the coarser [`ServiceError`] categories.
  ///
  /// - stream read and write failures become `UnexpectedEnd`,
  /// - a protocol violation becomes `IllegalResponse`,
  /// - an explicit refusal and both kinds of version mismatch become `Refused`,
  /// - application and fatal payloads are converted with `Into` and wrapped in `InternalError`.
  fn from(failure: NegotiationError<SourceError>) -> Self {
    match failure {
      NegotiationError::ReadError | NegotiationError::WriteError => Self::UnexpectedEnd,
      NegotiationError::ProtocolViolation => Self::IllegalResponse,
      NegotiationError::Refused
      | NegotiationError::UnsupportedProtocolVersion
      | NegotiationError::UnsupportedServiceVersion => Self::Refused,
      NegotiationError::ApplicationError(payload) | NegotiationError::FatalError(payload) => {
        Self::InternalError(payload.into())
      }
    }
  }
}

/// Write future to send our magic and version to the remote,
/// returning an error if writes are refused by the stream.
#[tracing::instrument(level = tracing::Level::TRACE, err, skip(stream))]
async fn write_magic_and_version<S: AsyncWrite + Send + Unpin, AE: Debug>(
  mut stream: S,
  protocol_version: u8,
) -> Result<S, NegotiationError<AE>> {
  stream
    .write_all(SNOCAT_NEGOTIATION_MAGIC)
    .await
    .map_err(|_| NegotiationError::WriteError)?;
  stream
    .write_u8(protocol_version)
    .await
    .map_err(|_| NegotiationError::WriteError)?;
  stream
    .flush()
    .await
    .map_err(|_| NegotiationError::WriteError)?;
  Result::<S, NegotiationError<AE>>::Ok(stream)
}

// Note: Protocol v0 is symmetric until negotiation handshake completes
/// Exchanges the negotiation magic and protocol version with the remote peer.
///
/// Both peers run this same exchange. Ours is sent (via [`write_magic_and_version`])
/// concurrently with reading the remote's magic, over the two halves of a
/// `tokio::io::split`. Once both directions succeed, the halves are rejoined and the
/// remote's one-byte protocol version is read and returned.
///
/// # Errors
/// - [`NegotiationError::ProtocolViolation`] if the remote magic cannot be read or does not
///   equal [`SNOCAT_NEGOTIATION_MAGIC`]
/// - [`NegotiationError::WriteError`] if sending our magic and version fails
/// - [`NegotiationError::ReadError`] if the remote's version byte cannot be read
fn protocol_magic<'a, S: TunnelStream + Send + 'a, AE: Debug + 'a>(
  stream: S,
  protocol_version: u8,
) -> impl Future<Output = Result<u8, NegotiationError<AE>>> + 'a {
  let (mut read_half, write_half) = tokio::io::split(stream);

  // Sending side: our magic and version, flushed.
  let announce = write_magic_and_version(write_half, protocol_version);

  // Receiving side: the remote's magic, which must match ours byte-for-byte.
  let verify_remote_magic = async {
    let mut received = [0u8; 4];
    let received_len = read_half
      .read_exact(&mut received)
      .await
      .map_err(|_| NegotiationError::ProtocolViolation)?;
    let magic_ok = received_len >= received.len() && &received == SNOCAT_NEGOTIATION_MAGIC;
    if !magic_ok {
      tracing::trace!("magic mismatch");
      return Err(NegotiationError::ProtocolViolation);
    }
    tracing::trace!("magic matched expectation");
    Result::<_, NegotiationError<AE>>::Ok(read_half)
  };

  async move {
    // Both directions must finish the magic exchange before the stream is reassembled.
    let (reader, writer) = futures::future::try_join(verify_remote_magic, announce).await?;
    let mut rejoined = reader.unsplit(writer);
    let remote_version = rejoined
      .read_u8()
      .await
      .map_err(|_| NegotiationError::ReadError)?;
    Result::<u8, NegotiationError<AE>>::Ok(remote_version)
  }
  .instrument(tracing::trace_span!(
    stringify!(protocol_magic),
    ?protocol_version
  ))
}

/// Client side of the SNOCAT negotiation handshake.
///
/// Carries no state; [`NegotiationClient::negotiate`] consumes it and drives the whole
/// exchange over a single stream.
#[derive(Debug, Default)]
pub struct NegotiationClient;

impl NegotiationClient {
  /// Creates a negotiation client (equivalent to [`Default::default`]).
  pub fn new() -> Self {
    Self
  }

  /// Requests `addr` from the remote over `link`, returning `link` once the remote accepts.
  ///
  /// Protocol v0, client side:
  /// 1. exchange magic and protocol version with the remote; any remote version other than
  ///    0 is rejected,
  /// 2. send `addr` as a single frame and flush it,
  /// 3. read one status byte: `0` means accepted, any other value means refused.
  ///
  /// `link` is moved into the returned future; if that future is dropped part-way, the
  /// stream is dropped with it at an indeterminate point in the protocol.
  ///
  /// # Errors
  /// - `ProtocolViolation`, `WriteError` or `ReadError` if the magic/version exchange fails
  /// - `UnsupportedProtocolVersion` if the remote advertises a version above 0
  /// - `WriteError` if the address frame cannot be written or flushed
  /// - `ReadError` if the status byte cannot be read
  /// - `Refused` if the remote answers with a non-zero status byte
  pub fn negotiate<'stream, S, AE: Debug + 'stream>(
    self,
    addr: RouteAddress,
    mut link: S,
  ) -> impl Future<Output = Result<S, NegotiationError<AE>>> + 'stream
  where
    S: TunnelStream + Send + 'stream,
    for<'a> &'a mut S: TunnelStream + Send,
  {
    const LOCAL_PROTOCOL_VERSION: u8 = 0;
    let negotiation_span = tracing::trace_span!("protocol_negotiation_client", addr=?addr);
    async move {
      // Absolute most-basic negotiation protocol - sends the address in a frame and waits for 0u8-or-fail

      tracing::trace!("performing negotiation protocol handshake");
      let remote_version = protocol_magic::<&mut S, AE>(&mut link, LOCAL_PROTOCOL_VERSION).await?;
      // TODO: Consider adding a confirmation for negotiation protocol acceptance here

      // TODO: support multiple versions of negotiation mechanism
      if remote_version > 0 {
        // We don't support anything beyond this basic protocol yet
        tracing::trace!(
          version = remote_version,
          "unsupported remote protocol version"
        );
        return Err(NegotiationError::UnsupportedProtocolVersion);
      }

      // TODO: support service/client protocol versioning (May require Service/Client cooperation)

      tracing::trace!("writing address");
      // Write address to the remote, and see if the requested protocol is supported
      crate::util::framed::write_frame(&mut link, &addr.into_bytes())
        .await
        .map_err(|_| NegotiationError::WriteError)?;
      // Flush before blocking on the reply, as `write_magic_and_version` does: if `S` buffers
      // writes, the address frame could otherwise sit locally while both peers wait on each
      // other. (NOTE(review): `write_frame` may already flush; flushing again is harmless.)
      link
        .flush()
        .await
        .map_err(|_| NegotiationError::WriteError)?;

      tracing::trace!("awaiting remote protocol service acceptance");
      // Await acceptance of address by a service, or refusal if none are compatible
      let accepted = link
        .read_u8()
        .await
        .map_err(|_| NegotiationError::ReadError)?;
      if accepted > 0 {
        // For v0, this byte doesn't carry any useful info beyond accepted or not
        tracing::trace!(
          code = accepted,
          "address refused by remote protocol services"
        );
        Err(NegotiationError::Refused)
      } else {
        tracing::trace!("address accepted by remote protocol services");
        Ok(link)
      }
    }
    .instrument(negotiation_span)
  }
}

/// Service side of the SNOCAT negotiation handshake.
///
/// Holds the registry consulted to find a [`Service`] for each address a client requests.
// The type parameter is `R` (matching the impl blocks) rather than `ServiceRegistry`, which
// would shadow the imported `ServiceRegistry` trait inside this definition.
pub struct NegotiationService<R: ?Sized> {
  service_registry: Arc<R>,
}

/// A shared, type-erased [`Service`] whose error type is `TServiceError`.
pub type ArcService<TServiceError> =
  Arc<dyn Service<Error = TServiceError> + Send + Sync + 'static>;

impl<R: ?Sized> NegotiationService<R> {
  /// Creates a negotiation service that resolves requested addresses via `service_registry`.
  pub fn new(service_registry: Arc<R>) -> Self {
    Self { service_registry }
  }
}

impl<R> NegotiationService<R>
where
  R: ServiceRegistry + Send + Sync + ?Sized,
{
  /// Performs negotiation, returning the stream if successful
  ///
  /// Protocol v0, service side:
  /// 1. exchange magic and protocol version with the remote; any remote version other than
  ///    0 is rejected,
  /// 2. read the requested address as a single frame (`read_frame` is passed a limit of
  ///    `Some(2048)`); it must be UTF-8 and parse as a [`RouteAddress`],
  /// 3. ask the service registry for a service accepting that address on `tunnel`, then
  ///    write and flush one status byte: `0` to accept, `1` to refuse.
  ///
  /// On acceptance, resolves to the stream, the parsed address, and the chosen service.
  ///
  /// If the negotiation task is dropped, the stream is dropped in an indeterminate state.
  /// In scenarios involving an owned stream, this will drop the stream, otherwise the
  /// other end of the stream may be at an unknown point in the protocol. As such, any
  /// timeout mechanism here must not expect to resume the stream after a ref drop.
  ///
  /// # Errors
  /// - `ProtocolViolation`, `WriteError` or `ReadError` if the magic/version exchange fails
  /// - `UnsupportedProtocolVersion` if the remote advertises a version above 0
  /// - `ProtocolViolation` if the address frame cannot be read, is not UTF-8, or does not parse
  /// - `Refused`, after the refusal byte has been sent, if no registered service accepts it
  /// - `WriteError` if the status byte cannot be written or flushed
  pub fn negotiate<'stream, S, TTunnel>(
    &self,
    mut link: S,
    tunnel: TTunnel,
  ) -> BoxFuture<
    'stream,
    Result<
      (S, RouteAddress, ArcService<<R as ServiceRegistry>::Error>),
      NegotiationError<anyhow::Error>,
    >,
  >
  where
    R: 'stream,
    S: TunnelStream + Send + 'stream,
    for<'a> &'a mut S: TunnelStream + Send,
    TTunnel: Tunnel + 'static,
  {
    const CURRENT_PROTOCOL_VERSION: u8 = 0u8;
    let service_registry = Arc::clone(&self.service_registry);
    let tunnel_id = *tunnel.id();
    async move {
      tracing::trace!("performing negotiation protocol handshake");
      let remote_version = protocol_magic(&mut link, CURRENT_PROTOCOL_VERSION).await?;
      // TODO: Consider adding a confirmation for negotiation protocol acceptance here

      if remote_version > 0 {
        // This should map to multiple supported versions, where possible
        tracing::trace!(
          version = remote_version,
          "unsupported remote protocol version"
        );
        return Err(NegotiationError::UnsupportedProtocolVersion);
      }

      let addr: RouteAddress = crate::util::framed::read_frame(&mut link, Some(2048))
        .await
        .map_err(|_| NegotiationError::ProtocolViolation) // Address must be sent as a frame in v0
        // Addresses must be valid UTF-8
        .and_then(|raw| String::from_utf8(raw).map_err(|_| NegotiationError::ProtocolViolation))
        // Addresses must be legal SlashAddrs
        .and_then(|raw| raw.parse().map_err(|_| NegotiationError::ProtocolViolation))?;

      tracing::trace!("searching service registry for address handlers");
      let found = service_registry.find_service(&addr, &(Arc::new(tunnel) as Arc<_>));

      match found {
        None => {
          // Write refusal
          // v0 calls for a non-zero u8 to be written to the stream to refuse an address
          tracing::trace!(?addr, "refusing address");
          link
            .write_u8(1)
            .await
            .map_err(|_| NegotiationError::WriteError)?;
          // `link` is dropped together with this error result, so flush now: if `S` buffers
          // writes, the refusal byte might otherwise never be sent and the client would see a
          // read failure instead of a refusal.
          link
            .flush()
            .await
            .map_err(|_| NegotiationError::WriteError)?;
          Err(NegotiationError::Refused)
        }
        Some(service) => {
          // Write acceptance
          // v0 calls for a 0u8 to be written to the stream to accept an address
          tracing::trace!("accepting address");
          link
            .write_u8(0)
            .await
            .map_err(|_| NegotiationError::WriteError)?;
          // Flush so the client sees the acceptance before the stream is handed onward;
          // the client blocks on this byte before sending anything further.
          link
            .flush()
            .await
            .map_err(|_| NegotiationError::WriteError)?;
          Ok((link, addr, service))
        }
      }
    }
    .instrument(tracing::trace_span!("protocol_negotiation_service", source_tunnel=?tunnel_id))
    .boxed()
  }
}

#[cfg(test)]
mod tests {
  use futures::{FutureExt, TryStreamExt};
  use std::{sync::Arc, time::Duration};
  use tokio::time::timeout;

  use super::{ArcService, NegotiationClient, NegotiationError, NegotiationService};
  use crate::common::protocol::{
    traits::ServiceRegistry,
    tunnel::{
      duplex::EntangledTunnels, ArcTunnel, Tunnel, TunnelDownlink, TunnelIncomingType, TunnelUplink,
    },
    Service,
  };

  /// Registry returning the first of `services` whose `accepts` returns true.
  struct TestServiceRegistry {
    services: Vec<ArcService<<Self as ServiceRegistry>::Error>>,
  }

  impl ServiceRegistry for TestServiceRegistry {
    type Error = anyhow::Error;

    fn find_service(
      self: std::sync::Arc<Self>,
      addr: &crate::common::protocol::RouteAddress,
      tunnel: &ArcTunnel,
    ) -> Option<std::sync::Arc<dyn Service<Error = Self::Error> + Send + Sync + 'static>> {
      self
        .services
        .iter()
        .find(|s| s.accepts(addr, tunnel))
        .map(Arc::clone)
    }
  }

  /// Service that accepts every address and completes immediately without using the stream.
  struct NoOpServiceAcceptAll;

  impl Service for NoOpServiceAcceptAll {
    type Error = anyhow::Error;

    fn accepts(&self, _addr: &crate::common::protocol::RouteAddress, _tunnel: &ArcTunnel) -> bool {
      true
    }

    fn handle(
      &'_ self,
      _addr: crate::common::protocol::RouteAddress,
      _stream: Box<dyn crate::util::tunnel_stream::TunnelStream + Send + 'static>,
      _tunnel: ArcTunnel,
    ) -> futures::future::BoxFuture<
      '_,
      Result<(), crate::common::protocol::ServiceError<Self::Error>>,
    > {
      futures::future::ready(Ok(())).boxed()
    }
  }

  /// Test that negotiation between client and server sends an address successfully
  #[tokio::test]
  async fn negotiate() {
    let collector = tracing_subscriber::fmt()
      .pretty()
      .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
      .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
      .finish();
    // Keep the subscriber installed for the whole async body. `with_default` only installs it
    // while its closure runs, and that closure merely *builds* the lazy test future, which was
    // then awaited after the subscriber had already been removed.
    let _subscriber_guard = tracing::subscriber::set_default(collector);

    const TEST_ADDR: &str = "/test/addr";
    let service_registry = TestServiceRegistry {
      services: vec![Arc::new(NoOpServiceAcceptAll)],
    };
    let EntangledTunnels {
      connector,
      listener,
    } = super::super::tunnel::duplex::channel();

    let service = NegotiationService::new(Arc::new(service_registry));
    let client = NegotiationClient::new();

    let client_future = async move {
      let client_stream = connector
        .open_link()
        .await
        .expect("Must open client stream");
      let _stream = client
        .negotiate(
          TEST_ADDR.parse().expect("Illegal test address"),
          client_stream,
        )
        .await?;
      Result::<_, NegotiationError<anyhow::Error>>::Ok(())
    };

    let server_future = async move {
      // server
      let server_stream = listener
        .downlink()
        .await
        .expect("Must successfully fetch server downlink")
        .as_stream()
        .try_next()
        .await
        .expect("Must fetch next connection");
      let server_stream = match server_stream {
        Some(TunnelIncomingType::BiStream(s)) => s,
        #[allow(unreachable_patterns)]
        Some(_other) => unreachable!("Non-bistream opened to the test server"),
        None => panic!("No stream was opened to the test server"),
      };
      let (_stream, addr, service) = service.negotiate(server_stream, listener).await?;
      Result::<_, NegotiationError<anyhow::Error>>::Ok((addr, service))
    };
    let fut = futures::future::try_join(client_future, server_future);
    let fut = timeout(Duration::from_secs(5), fut);
    let ((), (addr, _service)) = fut.await.expect("Must not time out").unwrap();
    assert_eq!(&addr.to_string(), TEST_ADDR);
  }
}