1use std::time::Duration;
8
9use anyhow::{Context as _, Result};
10use async_trait::async_trait;
11use quinn::{ConnectionStats, Endpoint};
12use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _, Stdin, Stdout};
13use tokio::time::timeout;
14use tracing::{Instrument as _, debug, error, info, trace, warn};
15
16use crate::client::Parameters;
17use crate::config::{Configuration, Configuration_Optional, Manager};
18use crate::control::create_endpoint;
19use crate::protocol::FindTag as _;
20use crate::protocol::common::{ProtocolMessage, ReceivingStream, SendReceivePair, SendingStream};
21use crate::protocol::compat::Feature;
22use crate::protocol::control::{
23 BANNER, ClientGreeting, ClientMessage, ClientMessage2Attributes, ClientMessageV2,
24 ClosedownReport, ClosedownReportV1, Compatibility, CongestionController, ConnectionType,
25 Direction, OLD_BANNER, OUR_COMPATIBILITY_LEVEL, OUR_COMPATIBILITY_NUMERIC, ServerFailure,
26 ServerGreeting, ServerMessage, ServerMessage2Attributes, ServerMessageV2,
27};
28use crate::transport::combine_bandwidth_configurations;
29use crate::util::{Credentials, TimeFormat, TracingSetupFn};
30
31#[cfg(test)]
32use mockall::{automock, predicate::*};
33
34#[cfg_attr(test, automock)]
36#[async_trait]
37pub(crate) trait ControlChannelServerInterface<
38 S: SendingStream + 'static,
39 R: ReceivingStream + 'static,
40>
41{
42 async fn run_server(
43 &mut self,
44 remote_ip: Option<String>,
45 manager: &mut Manager,
46 setup_tracing: TracingSetupFn,
47 colours: bool,
48 force_compat: Option<Compatibility>,
49 ) -> anyhow::Result<ServerResult>;
50
51 async fn run_server_inner(&mut self, manager: &mut Manager) -> anyhow::Result<ServerResult>;
52
53 async fn send_closedown_report(&mut self, stats: &ConnectionStats) -> Result<()>;
54
55 fn compat(&self) -> Compatibility;
56}
57
58#[derive(Debug)]
60pub struct ControlChannel<S: SendingStream, R: ReceivingStream> {
61 stream: SendReceivePair<S, R>,
62 pub selected_compat: Compatibility,
64}
65
66impl SendingStream for Stdout {}
67impl ReceivingStream for Stdin {}
68
69pub(crate) fn stdio_channel() -> ControlChannel<Stdout, Stdin> {
74 ControlChannel::new((tokio::io::stdout(), tokio::io::stdin()).into())
75}
76
77#[derive(Debug)]
79pub(crate) struct ServerResult {
80 pub(crate) config: Configuration,
82 pub(crate) endpoint: Endpoint,
84}
85
86impl<S: SendingStream, R: ReceivingStream> ControlChannel<S, R> {
87 pub(crate) fn new(stream: SendReceivePair<S, R>) -> Self {
88 Self {
89 stream,
90 selected_compat: Compatibility::Unknown,
91 }
92 }
93
94 async fn send<T: ProtocolMessage>(&mut self, message: T, context: &str) -> Result<()> {
95 let send = &mut self.stream.send;
96 message
97 .to_writer_async_framed(send)
98 .await
99 .with_context(|| format!("sending {context}"))?;
100 send.flush().await?;
101 Ok(())
102 }
103
104 async fn send_error(&mut self, failure: ServerFailure) -> Result<()> {
105 self.send(ServerMessage::Failure(failure), "error").await?;
106 Ok(())
107 }
108
109 async fn recv<T: ProtocolMessage>(&mut self, context: &str) -> Result<T> {
110 T::from_reader_async_framed(&mut self.stream.recv)
111 .await
112 .with_context(|| format!("receiving {context}"))
113 }
114
115 async fn flush(&mut self) -> Result<()> {
116 self.stream.send.flush().await?;
117 Ok(())
118 }
119
120 fn choose_compatibility_level(ours: u16, theirs: u16) -> Compatibility {
122 use std::cmp::Ordering::{Equal, Greater, Less};
124 let (d, result) = match theirs.cmp(&ours) {
125 Less => (Some("older"), theirs),
126 Equal => (None, ours),
127 Greater => (Some("newer"), ours),
128 };
129 if let Some(d) = d {
130 debug!("Remote compatibility level {theirs} is {d} than ours {ours}");
131 }
132 debug!("Selected compatibility level {result}");
133 result.into()
134 }
135
136 fn process_compatibility_levels(&mut self, theirs: u16) {
137 self.selected_compat = Self::choose_compatibility_level(OUR_COMPATIBILITY_NUMERIC, theirs);
139 }
140
141 async fn client_exchange_greetings(
145 &mut self,
146 remote_debug: bool,
147 force_compat: Option<Compatibility>,
148 ) -> Result<ServerGreeting> {
149 self.send(
150 ClientGreeting {
151 compatibility: force_compat.unwrap_or(OUR_COMPATIBILITY_LEVEL).into(),
152 debug: remote_debug,
153 extension: 0,
154 },
155 "client greeting",
156 )
157 .await?;
158
159 let reply = self.recv::<ServerGreeting>("server greeting").await?;
160 self.process_compatibility_levels(reply.compatibility);
161 Ok(reply)
162 }
163
164 async fn client_send_message(
165 &mut self,
166 credentials: &Credentials,
167 connection_type: ConnectionType,
168 parameters: &Parameters,
169 config: &Configuration_Optional,
170 direction: Direction,
171 ) -> Result<()> {
172 let congestion = config
173 .congestion
174 .unwrap_or(Configuration::system_default().congestion);
175 if congestion == CongestionController::NewReno {
176 anyhow::ensure!(
177 self.selected_compat.supports(Feature::NEW_RENO),
178 "Remote host does not support NewReno"
179 );
180 }
181
182 let tagged_creds =
183 credentials.to_tagged_data(self.selected_compat, config.tls_auth_type)?;
184 let mut message = ClientMessage::new(
185 self.selected_compat,
186 tagged_creds,
187 connection_type,
188 parameters.remote_config,
189 config,
190 );
191 message.set_direction(direction);
192 debug!("Our client message: {{ {message} }}");
193 self.send(message, "client message").await
194 }
195
196 async fn client_read_server_message(&mut self) -> Result<ServerMessageV2> {
197 let message = self.recv::<ServerMessage>("server message").await?;
198 debug!("Received server message: {{ {message} }}");
199 Ok(match message {
200 ServerMessage::V1(m) => m.into(),
201 ServerMessage::V2(m) => m,
202 ServerMessage::Failure(f) => {
203 anyhow::bail!("server sent failure message: {f}");
204 }
205 ServerMessage::ToFollow => {
206 anyhow::bail!("remote or logic error: unpacked unexpected ServerMessage::ToFollow")
207 }
208 })
209 }
210
211 pub(crate) async fn run_client(
215 &mut self,
216 credentials: &Credentials,
217 connection_type: ConnectionType,
218 manager: &mut Manager,
219 parameters: &Parameters,
220 direction: Direction,
221 force_compat: Option<Compatibility>,
222 ) -> Result<ServerMessageV2> {
223 trace!("opening control channel");
224
225 self.wait_for_banner().await?;
227
228 let remote_greeting = self
230 .client_exchange_greetings(parameters.remote_debug, force_compat)
231 .await?;
232 debug!("Received server greeting: {remote_greeting:?}");
233
234 let working = manager.get::<Configuration_Optional>().unwrap_or_default();
236 self.client_send_message(
237 credentials,
238 connection_type,
239 parameters,
240 &working,
241 direction,
242 )
243 .await?;
244
245 trace!("waiting for server message");
246 let message = self.client_read_server_message().await?;
247
248 manager.merge_provider(&message);
249 manager.apply_system_default(); for attr in &message.attributes {
251 if attr.tag() == Some(ServerMessage2Attributes::WarningMessage) {
252 warn!(
253 "Remote endpoint warning: {}",
254 attr.data.as_str().unwrap_or("<invalid string>")
255 );
256 }
257 }
258 Ok(message)
259 }
260
261 pub(super) async fn wait_for_banner(&mut self) -> Result<()> {
262 let mut buf = [0u8; BANNER.len()];
263 let recv = &mut self.stream.recv;
264 let mut reader = recv.take(buf.len() as u64);
265
266 let n = reader
271 .read_exact(&mut buf[0..1])
272 .await
273 .context("failed to connect control channel")?;
274 anyhow::ensure!(n == 1, "control channel closed unexpectedly");
275
276 let _ = timeout(Duration::from_secs(1), reader.read_exact(&mut buf[1..]))
279 .await
280 .context("timed out reading server banner")?
282 .context("error reading control channel")?;
284
285 let read_banner = std::str::from_utf8(&buf).context("garbage server banner")?;
286 match read_banner {
287 BANNER => (),
288 OLD_BANNER => {
289 anyhow::bail!("unsupported protocol version (upgrade server to qcp 0.3.0 or later)")
290 }
291 b => anyhow::bail!(
292 "unsupported protocol version (unrecognised server banner `{}'; may be too new for me?)",
293 &b[0..b.len() - 1]
294 ),
295 }
296 Ok(())
297 }
298
299 pub(crate) async fn read_closedown_report(&mut self) -> Result<ClosedownReportV1> {
301 let stats = self.recv::<ClosedownReport>("closedown report").await?;
302 let ClosedownReport::V1(stats) = stats else {
304 anyhow::bail!("server sent unknown ClosedownReport message type");
305 };
306
307 debug!("remote reported stats: {:?}", stats);
308
309 Ok(stats)
310 }
311
312 async fn server_exchange_greetings(
316 &mut self,
317 force_compat: Option<Compatibility>,
318 ) -> Result<ClientGreeting> {
319 let compat = force_compat.unwrap_or(OUR_COMPATIBILITY_LEVEL);
320 self.send(
321 ServerGreeting {
322 compatibility: compat.into(),
323 extension: 0,
324 },
325 "server greeting",
326 )
327 .await?;
328
329 let reply = self.recv::<ClientGreeting>("client greeting").await?;
330 self.process_compatibility_levels(reply.compatibility);
331 Ok(reply)
332 }
333
334 async fn server_read_client_message(&mut self) -> Result<ClientMessageV2> {
335 let client_message = match self.recv::<ClientMessage>("client message").await {
336 Ok(cm) => cm,
337 Err(e) => {
338 self.send_error(ServerFailure::Malformed).await?;
339 error!("{e}");
341 anyhow::bail!(
342 "In server mode, this program expects to receive a binary data packet on stdin"
343 );
344 }
345 };
346
347 trace!("waiting for client message");
348 let message = match client_message {
349 ClientMessage::ToFollow => {
350 self.send_error(ServerFailure::Malformed).await?;
351 anyhow::bail!("remote or logic error: unpacked unexpected ClientMessage::ToFollow")
352 }
353 ClientMessage::V1(m) => m.into(),
354 ClientMessage::V2(m) => m,
355 };
356 Ok(message)
357 }
358
359 async fn server_send_message(
360 &mut self,
361 port: u16,
362 credentials: &Credentials,
363 config: &Configuration,
364 warning: String,
365 ) -> Result<()> {
366 let tagged_creds =
367 credentials.to_tagged_data(self.selected_compat, Some(config.tls_auth_type))?;
368
369 let message = ServerMessage::new(
370 self.selected_compat,
371 config,
372 port,
373 tagged_creds,
374 credentials.hostname.clone(),
375 warning,
376 );
377 debug!("sending server message: {message:?}");
378 self.send(message, "server message").await?;
379 self.flush().await?;
380 Ok(())
381 }
382
383 fn server_trace_level(debug: bool) -> &'static str {
384 if debug { "debug" } else { "info" }
385 }
386}
387
388#[async_trait]
389impl<S: SendingStream + 'static, R: ReceivingStream + 'static> ControlChannelServerInterface<S, R>
390 for ControlChannel<S, R>
391{
392 async fn run_server(
393 &mut self,
394 remote_ip: Option<String>,
395 manager: &mut Manager,
396 setup_tracing: TracingSetupFn,
397 colours: bool,
398 force_compat: Option<Compatibility>,
399 ) -> anyhow::Result<ServerResult> {
400 self.stream.send.write_all(BANNER.as_bytes()).await?;
402
403 let remote_greeting = self.server_exchange_greetings(force_compat).await?;
405 let time_format = manager.get_config_field::<TimeFormat>(
407 "time_format",
408 Some(Configuration::system_default().time_format),
409 )?;
410
411 let level = Self::server_trace_level(remote_greeting.debug);
413 setup_tracing(
414 level,
415 crate::util::ConsoleTraceType::Standard,
416 None,
417 time_format,
418 colours,
419 )?;
420 debug!(
422 "client IP is {}",
423 remote_ip.as_deref().map_or("none", |v| v)
424 );
425 debug!("Received client greeting: {remote_greeting:?}");
426
427 self.run_server_inner(manager)
428 .instrument(tracing::error_span!("[Server]").or_current())
429 .await
430 }
431
432 async fn run_server_inner(&mut self, manager: &mut Manager) -> anyhow::Result<ServerResult> {
433 let message2 = self.server_read_client_message().await?;
436
437 debug!("using {:?}", message2.connection_type,);
439 debug!("Received client message: {message2}");
440 let show_config = message2
441 .attributes
442 .find_tag(crate::protocol::control::ClientMessage2Attributes::OutputConfig)
443 .is_some();
444 if show_config {
445 info!(
446 "Static configuration:\n{}",
447 manager.to_display_adapter::<Configuration>()
448 );
449 }
450
451 let config = match combine_bandwidth_configurations(manager, &message2) {
452 Ok(cfg) => cfg,
453 Err(e) => {
454 self.send_error(ServerFailure::NegotiationFailed(format!("{e}")))
455 .await?;
456 anyhow::bail!("Config negotiation failed: {e}");
457 }
458 };
459
460 if show_config {
461 info!(
462 "Final configuration:\n{}",
463 manager.to_display_adapter::<Configuration>()
464 );
465 }
466
467 let credentials = Credentials::generate()?;
469 let direction = Direction::from(
470 message2
471 .attributes
472 .find_tag(ClientMessage2Attributes::DirectionOfTravel),
473 );
474 trace!("Direction of travel: {direction}");
475
476 let (endpoint, warning) = match create_endpoint(
477 &credentials,
478 &message2.credentials,
479 message2.connection_type,
480 &config,
481 direction.server_mode(),
482 true,
483 self.selected_compat,
484 ) {
485 Ok(t) => t,
486 Err(e) => {
487 self.send_error(ServerFailure::EndpointFailed(format!("{e}")))
488 .await?;
489 anyhow::bail!("failed to create server endpoint: {e}");
490 }
491 };
492 let local_addr = endpoint.local_addr()?;
493 debug!("Local endpoint address is {local_addr}");
494
495 self.server_send_message(
497 local_addr.port(),
498 &credentials,
499 &config,
500 warning.unwrap_or_default(),
501 )
502 .await?;
503
504 Ok(ServerResult { config, endpoint })
505 }
506
507 async fn send_closedown_report(&mut self, stats: &ConnectionStats) -> Result<()> {
508 self.send(
510 ClosedownReport::V1(ClosedownReportV1::from(stats)),
511 "closedown report",
512 )
513 .await?;
514 Ok(())
515 }
516
517 fn compat(&self) -> Compatibility {
518 self.selected_compat
519 }
520}
521
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod test {
    use crate::{
        client::Parameters,
        config::{Configuration_Optional, Manager},
        control::{ControlChannel, ControlChannelServerInterface as _},
        protocol::{
            common::{
                MessageHeader, ProtocolMessage as _, ReceivingStream, SendReceivePair,
                SendingStream,
            },
            control::{
                ClosedownReportV1, Compatibility, CongestionController, ConnectionType, OLD_BANNER,
                ServerMessageV2,
            },
            test_helpers::new_test_plumbing,
        },
        util::{Credentials, PortRange, TimeFormat},
    };
    use anyhow::Result;
    use pretty_assertions::assert_eq;
    use quinn::ConnectionStats;
    use tokio::io::AsyncWriteExt;

    /// No-op stand-in for the tracing setup callback passed to `run_server`.
    // The signature must match `TracingSetupFn`, so the always-Ok `Result` cannot be dropped.
    #[allow(clippy::unnecessary_wraps)]
    fn setup_tracing_stub(
        _trace_level: &str,
        _display: crate::util::ConsoleTraceType,
        _filename: Option<&String>,
        _time_format: TimeFormat,
        _colour: bool,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Client-side fixture: a control channel plus everything `run_client` needs.
    struct TestClient<S: SendingStream, R: ReceivingStream> {
        creds: Credentials,
        manager: Manager,
        params: Parameters,
        client: ControlChannel<S, R>,
        compat: Compatibility,
    }
    impl<S: SendingStream, R: ReceivingStream> TestClient<S, R> {
        fn new(pipe: SendReceivePair<S, R>, compat: Compatibility) -> TestClient<S, R> {
            Self {
                creds: Credentials::generate().unwrap(),
                manager: Manager::without_files(None),
                params: Parameters::default(),
                client: ControlChannel::new(pipe),
                compat,
            }
        }
        /// Like `new`, but lets the caller adjust the client's configuration first.
        fn with_prefs<F: FnOnce(&mut Manager)>(
            pipe: SendReceivePair<S, R>,
            f: F,
            compat: Compatibility,
        ) -> TestClient<S, R> {
            let mut rv = Self::new(pipe, compat);
            f(&mut rv.manager);
            rv
        }
        /// Starts the client protocol, forcing the fixture's compatibility level.
        fn run(&mut self) -> impl Future<Output = Result<ServerMessageV2>> {
            self.client.run_client(
                &self.creds,
                ConnectionType::Ipv4,
                &mut self.manager,
                &self.params,
                crate::protocol::control::Direction::Both,
                Some(self.compat),
            )
        }
    }

    /// Full client/server negotiation at `compat`, followed by a closedown report round trip.
    async fn happy_path(compat: Compatibility) {
        let (pipe1, pipe2) = new_test_plumbing();
        let mut cli = TestClient::new(pipe1, compat);
        cli.params.remote_config = true;
        let cli_fut = cli.run();

        let mut server = ControlChannel::new(pipe2);
        let mut manager = Manager::without_files(None);
        let ser_fut =
            server.run_server(None, &mut manager, setup_tracing_stub, false, Some(compat));

        let (cli_res, ser_res) = tokio::join!(cli_fut, ser_fut);
        eprintln!("Client: {cli_res:?}\nServer: {ser_res:?}");
        assert!(cli_res.is_ok());
        assert!(ser_res.is_ok());

        let stats = ConnectionStats::default();
        let expected = ClosedownReportV1::from(&stats);
        // Fail loudly here rather than surfacing a send failure as a confusing read error below.
        server
            .send_closedown_report(&stats)
            .await
            .expect("sending closedown report");
        let got = cli.client.read_closedown_report().await.unwrap();
        assert_eq!(expected, got);
    }

    #[cfg_attr(cross_target_mingw, ignore)]
    #[tokio::test]
    async fn happy_path_compat_1() {
        happy_path(Compatibility::Level(1)).await;
    }

    #[cfg_attr(cross_target_mingw, ignore)]
    #[tokio::test]
    async fn happy_path_compat_3() {
        happy_path(Compatibility::Level(3)).await;
    }

    #[tokio::test]
    async fn old_banner() {
        let (pipe1, mut pipe2) = new_test_plumbing();
        let mut cli = TestClient::new(pipe1, Compatibility::Level(1));
        let cli_fut = cli.run();
        pipe2.send.write_all(OLD_BANNER.as_bytes()).await.unwrap();
        let res = cli_fut.await;
        assert!(res.is_err_and(|e| {
            e.to_string()
                .contains("unsupported protocol version (upgrade")
        }));
    }

    #[tokio::test]
    async fn banner_junk() {
        let (pipe1, mut pipe2) = new_test_plumbing();
        let mut cli = TestClient::new(pipe1, Compatibility::Level(1));
        let cli_fut = cli.run();
        pipe2
            .send
            .write_all("qqqqqqqqqqqqqqqqq\n".as_bytes())
            .await
            .unwrap();
        let res = cli_fut.await;
        assert!(res.is_err_and(|e| e.to_string().contains("unrecognised server banner")));
    }

    /// Configuration that pins both the local and the remote port to `begin..=end`.
    fn fake_cli_with_port(begin: u16, end: u16) -> Configuration_Optional {
        Configuration_Optional {
            port: Some(PortRange { begin, end }),
            remote_port: Some(PortRange { begin, end }),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn negotiation_fails() {
        let (pipe1, pipe2) = new_test_plumbing();

        let mut cli = TestClient::with_prefs(
            pipe1,
            |mgr| {
                mgr.merge_provider(fake_cli_with_port(11111, 11111));
            },
            Compatibility::Level(1),
        );
        let cli_fut = cli.run();

        let mut server = ControlChannel::new(pipe2);
        let mut manager = Manager::without_files(None);
        // Disjoint port ranges on each side cannot be reconciled.
        manager.merge_provider(fake_cli_with_port(22222, 22222));
        let ser_fut = server.run_server(
            None,
            &mut manager,
            setup_tracing_stub,
            false,
            Some(Compatibility::Level(1)),
        );

        let (cli_res, ser_res) = tokio::join!(cli_fut, ser_fut);
        assert!(cli_res.is_err_and(|e| e.to_string().contains("Negotiation Failed")));
        assert!(ser_res.is_err_and(|e| e.to_string().contains("negotiation failed")));
    }

    #[tokio::test]
    async fn client_message_junk() {
        let (mut pipe1, pipe2) = new_test_plumbing();

        let mut server = ControlChannel::new(pipe2);
        let fut = server.server_read_client_message();
        let write_fut = pipe1.send.write_all(&[255u8; 1024]);

        let (ser_res, write_res) = tokio::join!(fut, write_fut);
        assert!(write_res.is_ok());
        assert!(ser_res.is_err_and(|e| {
            e.to_string()
                .contains("this program expects to receive a binary data packet")
        }));
    }

    #[tokio::test]
    async fn client_message_illegal() {
        let (mut pipe1, pipe2) = new_test_plumbing();

        let mut server = ControlChannel::new(pipe2);
        let fut = server.server_read_client_message();
        // A single zero byte decodes as the ToFollow variant, which must never appear on the wire.
        let mut body = vec![0u8];
        let mut packet = MessageHeader { size: 1 }.to_vec().unwrap();
        packet.append(&mut body);
        let fut2 = pipe1.send.write_all(&packet);

        let (res1, res2) = tokio::join!(fut, fut2);
        assert!(res2.is_ok());
        assert!(res1.is_err_and(|e| e.to_string().contains("unexpected ClientMessage::ToFollow")));
    }

    #[test]
    fn compatibility_level_comparison() {
        type Uut = ControlChannel<tokio::io::Stdout, tokio::io::Stdin>;
        let cases = &[(1u16, 1u16, 1u16), (1, 2, 1), (2, 1, 1), (65535, 1, 1)];
        for (a, b, result) in cases {
            assert_eq!(
                Uut::choose_compatibility_level(*a, *b),
                (*result).into(),
                "case: {a} {b} -> {result}"
            );
        }
    }

    #[tokio::test]
    async fn compat_check_newreno() {
        let (pipe1, pipe2) = new_test_plumbing();
        let mut cli = TestClient::new(pipe1, Compatibility::Level(3));
        let cfg = Configuration_Optional {
            congestion: Some(CongestionController::NewReno),
            ..Default::default()
        };
        cli.manager.merge_provider(cfg);
        let cli_fut = cli.run();

        let mut server = ControlChannel::new(pipe2);
        let mut manager = Manager::without_files(None);
        // The server advertises level 1, which lacks NewReno support.
        let ser_fut = server.run_server(
            None,
            &mut manager,
            setup_tracing_stub,
            false,
            Some(Compatibility::Level(1)),
        );

        let res = tokio::try_join!(cli_fut, ser_fut).unwrap_err();
        assert!(res.to_string().contains("does not support NewReno"));
    }
}
776}