1use std::{
20 fmt::{Debug, Formatter, Result as FmtResult},
21 io::{Cursor, Error as IoError},
22 ops::RangeInclusive,
23 pin::Pin,
24 task::{Context, Poll},
25 time::Duration,
26};
27
28use bytes::{BufMut, Bytes, BytesMut};
29use futures_util::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
30use quinn::{
31 ClosedStream, Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError,
32 SendStream, VarInt,
33};
34use thiserror::Error;
35use uuid::Uuid;
36
37use asport::{
38 model::{
39 side::{Rx, Tx},
40 AssembleError, ClientHello as ClientHelloModel, Connect as ConnectModel,
41 Connection as ConnectionModel, KeyingMaterialExporter as KeyingMaterialExporterImpl,
42 Packet as PacketModel, ServerHello as ServerHelloModel,
43 },
44 Address, Flags, Header, ServerHello as ServerHelloHeader, UnmarshalError,
45};
46
47use self::side::Side;
48
/// Role markers used to specialise [`Connection`] for each end of the link.
pub mod side {
    /// Marker value selecting the client-side API.
    #[derive(Debug, Clone)]
    pub struct Client;

    /// Marker value selecting the server-side API.
    #[derive(Debug, Clone)]
    pub struct Server;

    /// Holds either a client-flavoured value (`C`) or a server-flavoured
    /// value (`S`); visible to the parent module only.
    #[derive(Debug)]
    pub(super) enum Side<C, S> {
        /// Value belonging to the client role.
        Client(C),
        /// Value belonging to the server role.
        Server(S),
    }
}
62
63#[derive(Clone)]
67pub struct Connection<Side> {
68 conn: QuinnConnection,
69 model: ConnectionModel<Bytes>,
70 _marker: Side,
71}
72
73impl<Side> Connection<Side> {
74 pub fn packet_native(
76 &self,
77 pkt: impl AsRef<[u8]>,
78 addr: Address,
79 assoc_id: u16,
80 ) -> Result<(), Error> {
81 let Some(max_pkt_size) = self.conn.max_datagram_size() else {
82 return Err(Error::SendDatagram(SendDatagramError::Disabled));
83 };
84
85 let model = self.model.send_packet(assoc_id, addr, max_pkt_size);
86
87 for (header, frag) in model.into_fragments(pkt) {
88 let mut buf = BytesMut::with_capacity(header.len() + frag.len());
89 header.write(&mut buf);
90 buf.put_slice(frag);
91 self.conn.send_datagram(Bytes::from(buf))?;
92 }
93
94 Ok(())
95 }
96
97 pub async fn packet_quic(
99 &self,
100 pkt: impl AsRef<[u8]>,
101 addr: Address,
102 assoc_id: u16,
103 ) -> Result<(), Error> {
104 let model = self.model.send_packet(assoc_id, addr, u16::MAX as usize);
105
106 for (header, frag) in model.into_fragments(pkt) {
107 let mut send = self.conn.open_uni().await?;
108 header.async_marshal(&mut send).await?;
109 AsyncWriteExt::write_all(&mut send, frag).await?;
110 send.close().await?;
111 }
112
113 Ok(())
114 }
115
116 pub fn task_connect_count(&self) -> usize {
118 self.model.task_connect_count()
119 }
120
121 pub fn task_associate_count(&self) -> usize {
123 self.model.task_associate_count()
124 }
125
126 pub fn collect_garbage(&self, timeout: Duration) {
128 self.model.collect_garbage(timeout);
129 }
130
131 fn keying_material_exporter(&self) -> KeyingMaterialExporter {
132 KeyingMaterialExporter(self.conn.clone())
133 }
134}
135
136impl Connection<side::Client> {
137 pub fn new(conn: QuinnConnection) -> Self {
139 Self {
140 conn,
141 model: ConnectionModel::new(),
142 _marker: side::Client,
143 }
144 }
145
146 pub async fn client_hello(
148 &self,
149 uuid: Uuid,
150 password: impl AsRef<[u8]>,
151 flags: impl Into<Flags>,
152 expected_port_range: RangeInclusive<u16>,
153 ) -> Result<(), Error> {
154 let model = self.model.send_client_hello(
155 uuid,
156 password,
157 &self.keying_material_exporter(),
158 flags,
159 expected_port_range,
160 );
161
162 let mut send = self.conn.open_uni().await?;
163 model.header().async_marshal(&mut send).await?;
164 send.close().await?;
165 Ok(())
166 }
167
168 pub async fn heartbeat(&self) -> Result<(), Error> {
170 let model = self.model.send_heartbeat();
171 let mut buf = Vec::with_capacity(model.header().len());
172 model.header().async_marshal(&mut buf).await.unwrap();
173 self.conn.send_datagram(Bytes::from(buf))?;
174 Ok(())
175 }
176
177 pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result<Task, Error> {
181 let header = match Header::async_unmarshal(&mut recv).await {
182 Ok(header) => header,
183 Err(err) => return Err(Error::UnmarshalUniStream(err, recv)),
184 };
185
186 match header {
187 Header::ClientHello(_) => Err(Error::BadCommandUniStream("clienthello", recv)),
188 Header::ServerHello(server_hello) => {
189 let model = self.model.recv_server_hello(server_hello);
190 Ok(Task::ServerHello(ServerHello::new(model)))
191 }
192 Header::Packet(pkt) => {
193 let model = self.model.recv_packet_unrestricted(pkt);
194 Ok(Task::Packet(Packet::new(model, PacketSource::Quic(recv))))
195 }
196 Header::Dissociate(dissoc) => {
197 let model = self.model.recv_dissociate(dissoc);
198 Ok(Task::Dissociate(model.assoc_id()))
199 }
200 Header::Connect(_) => Err(Error::BadCommandUniStream("connect", recv)),
201 Header::Heartbeat(_) => Err(Error::BadCommandUniStream("heartbeat", recv)),
202 _ => unreachable!(),
203 }
204 }
205
206 pub async fn accept_bi_stream(
210 &self,
211 send: SendStream,
212 mut recv: RecvStream,
213 ) -> Result<Task, Error> {
214 let header = match Header::async_unmarshal(&mut recv).await {
215 Ok(header) => header,
216 Err(err) => return Err(Error::UnmarshalBiStream(err, send, recv)),
217 };
218
219 match header {
220 Header::ClientHello(_) => Err(Error::BadCommandBiStream("clienthello", send, recv)),
221 Header::ServerHello(_) => Err(Error::BadCommandBiStream("serverhello", send, recv)),
222 Header::Connect(connect) => {
223 let model = self.model.recv_connect(connect);
224 Ok(Task::Connect(Connect::new(Side::Client(model), send, recv)))
225 }
226 Header::Packet(_) => Err(Error::BadCommandBiStream("packet", send, recv)),
227 Header::Dissociate(_) => Err(Error::BadCommandBiStream("dissociate", send, recv)),
228 Header::Heartbeat(_) => Err(Error::BadCommandBiStream("heartbeat", send, recv)),
229 _ => unreachable!(),
230 }
231 }
232
233 pub fn accept_datagram(&self, dg: Bytes) -> Result<Task, Error> {
237 let mut dg = Cursor::new(dg);
238
239 let header = match Header::unmarshal(&mut dg) {
240 Ok(header) => header,
241 Err(err) => return Err(Error::UnmarshalDatagram(err, dg.into_inner())),
242 };
243
244 match header {
245 Header::ClientHello(_) => {
246 Err(Error::BadCommandDatagram("clienthello", dg.into_inner()))
247 }
248 Header::ServerHello(_) => {
249 Err(Error::BadCommandDatagram("serverhello", dg.into_inner()))
250 }
251 Header::Connect(_) => Err(Error::BadCommandDatagram("connect", dg.into_inner())),
252 Header::Packet(pkt) => {
253 let model = self.model.recv_packet_unrestricted(pkt);
254 let pos = dg.position() as usize;
255 let buf = dg.into_inner().slice(pos..pos + model.size() as usize);
256 Ok(Task::Packet(Packet::new(model, PacketSource::Native(buf))))
257 }
258 Header::Dissociate(_) => Err(Error::BadCommandDatagram("dissociate", dg.into_inner())),
259 Header::Heartbeat(_) => Err(Error::BadCommandDatagram("heartbeat", dg.into_inner())),
260 _ => unreachable!(),
261 }
262 }
263}
264
265impl Connection<side::Server> {
266 pub fn new(conn: QuinnConnection) -> Self {
268 Self {
269 conn,
270 model: ConnectionModel::new(),
271 _marker: side::Server,
272 }
273 }
274
275 pub async fn server_hello(&self, result: ServerHelloHeader) -> Result<(), Error> {
277 let model = self.model.send_server_hello(result);
278 let mut send = self.conn.open_uni().await?;
279 model.header().async_marshal(&mut send).await?;
280 send.close().await?;
281 Ok(())
282 }
283
284 pub async fn connect(&self, addr: Address) -> Result<Connect, Error> {
286 let model = self.model.send_connect(addr);
287 let (mut send, recv) = self.conn.open_bi().await?;
288 model.header().async_marshal(&mut send).await?;
289 Ok(Connect::new(Side::Server(model), send, recv))
290 }
291
292 pub async fn dissociate(&self, assoc_id: u16) -> Result<(), Error> {
294 let model = self.model.send_dissociate(assoc_id);
295 let mut send = self.conn.open_uni().await?;
296 model.header().async_marshal(&mut send).await?;
297 send.close().await?;
298 Ok(())
299 }
300
301 pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result<Task, Error> {
305 let header = match Header::async_unmarshal(&mut recv).await {
306 Ok(header) => header,
307 Err(err) => return Err(Error::UnmarshalUniStream(err, recv)),
308 };
309
310 match header {
311 Header::ClientHello(client_hello) => {
312 let model = self.model.recv_client_hello(client_hello);
313 Ok(Task::ClientHello(ClientHello::new(
314 model,
315 self.keying_material_exporter(),
316 )))
317 }
318 Header::ServerHello(_) => Err(Error::BadCommandUniStream("serverhello", recv)),
319 Header::Connect(_) => Err(Error::BadCommandUniStream("connect", recv)),
320 Header::Packet(pkt) => {
321 let assoc_id = pkt.assoc_id();
322 let pkt_id = pkt.pkt_id();
323 self.model
324 .recv_packet(pkt)
325 .map_or(Err(Error::InvalidUdpSession(assoc_id, pkt_id)), |pkt| {
326 Ok(Task::Packet(Packet::new(pkt, PacketSource::Quic(recv))))
327 })
328 }
329 Header::Dissociate(_) => Err(Error::BadCommandUniStream("dissociate", recv)),
330 Header::Heartbeat(_) => Err(Error::BadCommandUniStream("heartbeat", recv)),
331 _ => unreachable!(),
332 }
333 }
334
335 pub async fn accept_bi_stream(
339 &self,
340 send: SendStream,
341 mut recv: RecvStream,
342 ) -> Result<Task, Error> {
343 let header = match Header::async_unmarshal(&mut recv).await {
344 Ok(header) => header,
345 Err(err) => return Err(Error::UnmarshalBiStream(err, send, recv)),
346 };
347
348 match header {
349 Header::ClientHello(_) => Err(Error::BadCommandUniStream("clienthello", recv)),
350 Header::ServerHello(_) => Err(Error::BadCommandBiStream("serverhello", send, recv)),
351 Header::Connect(_) => Err(Error::BadCommandBiStream("connect", send, recv)),
352 Header::Packet(_) => Err(Error::BadCommandBiStream("packet", send, recv)),
353 Header::Dissociate(_) => Err(Error::BadCommandBiStream("dissociate", send, recv)),
354 Header::Heartbeat(_) => Err(Error::BadCommandBiStream("heartbeat", send, recv)),
355 _ => unreachable!(),
356 }
357 }
358
359 pub fn accept_datagram(&self, dg: Bytes) -> Result<Task, Error> {
363 let mut dg = Cursor::new(dg);
364
365 let header = match Header::unmarshal(&mut dg) {
366 Ok(header) => header,
367 Err(err) => return Err(Error::UnmarshalDatagram(err, dg.into_inner())),
368 };
369
370 match header {
371 Header::ClientHello(_) => {
372 Err(Error::BadCommandDatagram("clienthello", dg.into_inner()))
373 }
374 Header::ServerHello(_) => {
375 Err(Error::BadCommandDatagram("serverhello", dg.into_inner()))
376 }
377 Header::Connect(_) => Err(Error::BadCommandDatagram("connect", dg.into_inner())),
378 Header::Packet(pkt) => {
379 let assoc_id = pkt.assoc_id();
380 let pkt_id = pkt.pkt_id();
381 if let Some(pkt) = self.model.recv_packet(pkt) {
382 let pos = dg.position() as usize;
383 let mut buf = dg.into_inner();
384 if (pos + pkt.size() as usize) <= buf.len() {
385 buf = buf.slice(pos..pos + pkt.size() as usize);
386 Ok(Task::Packet(Packet::new(pkt, PacketSource::Native(buf))))
387 } else {
388 Err(Error::PayloadLength(pkt.size() as usize, buf.len() - pos))
389 }
390 } else {
391 Err(Error::InvalidUdpSession(assoc_id, pkt_id))
392 }
393 }
394 Header::Dissociate(_) => Err(Error::BadCommandDatagram("dissociate", dg.into_inner())),
395 Header::Heartbeat(hb) => {
396 let _ = self.model.recv_heartbeat(hb);
397 Ok(Task::Heartbeat)
398 }
399 _ => unreachable!(),
400 }
401 }
402}
403
404#[derive(Debug)]
406pub struct ClientHello {
407 model: ClientHelloModel<Rx>,
408 exporter: KeyingMaterialExporter,
409}
410
411impl ClientHello {
412 fn new(model: ClientHelloModel<Rx>, exporter: KeyingMaterialExporter) -> Self {
413 Self { model, exporter }
414 }
415
416 pub fn uuid(&self) -> Uuid {
418 self.model.uuid()
419 }
420
421 pub fn token(&self) -> [u8; 32] {
423 self.model.token()
424 }
425
426 pub fn flags(&self) -> Flags {
427 self.model.flags()
428 }
429
430 pub fn expected_port_range(&self) -> RangeInclusive<u16> {
431 self.model.expected_port_range()
432 }
433
434 pub fn validate(&self, password: impl AsRef<[u8]>) -> bool {
436 self.model.is_valid(password, &self.exporter)
437 }
438}
439
440#[derive(Debug)]
442pub struct ServerHello {
443 model: ServerHelloModel<Rx>,
444}
445
446impl ServerHello {
447 fn new(model: ServerHelloModel<Rx>) -> Self {
448 Self { model }
449 }
450
451 pub fn handshake_code(&self) -> u8 {
453 self.model.handshake_code()
454 }
455
456 pub fn port(&self) -> Option<u16> {
458 self.model.port()
459 }
460}
461
462pub struct Connect {
464 model: Side<ConnectModel<Rx>, ConnectModel<Tx>>,
465 send: SendStream,
466 recv: RecvStream,
467}
468
469impl Connect {
470 fn new(
471 model: Side<ConnectModel<Rx>, ConnectModel<Tx>>,
472 send: SendStream,
473 recv: RecvStream,
474 ) -> Self {
475 Self { model, send, recv }
476 }
477
478 pub fn addr(&self) -> &Address {
480 match &self.model {
481 Side::Server(model) => {
482 let Header::Connect(conn) = model.header() else {
483 unreachable!()
484 };
485 conn.addr()
486 }
487 Side::Client(model) => model.addr(),
488 }
489 }
490
491 pub fn reset(
493 &mut self,
494 error_code: VarInt,
495 ) -> (Result<(), ClosedStream>, Result<(), ClosedStream>) {
496 let send_res = self.send.reset(error_code);
497 let recv_res = self.recv.stop(error_code);
498 (send_res, recv_res)
499 }
500}
501
502impl AsyncRead for Connect {
503 fn poll_read(
504 self: Pin<&mut Self>,
505 cx: &mut Context<'_>,
506 buf: &mut [u8],
507 ) -> Poll<Result<usize, IoError>> {
508 AsyncRead::poll_read(Pin::new(&mut self.get_mut().recv), cx, buf)
509 }
510}
511
512impl AsyncWrite for Connect {
513 fn poll_write(
514 self: Pin<&mut Self>,
515 cx: &mut Context<'_>,
516 buf: &[u8],
517 ) -> Poll<Result<usize, IoError>> {
518 AsyncWrite::poll_write(Pin::new(&mut self.get_mut().send), cx, buf)
519 }
520
521 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
522 AsyncWrite::poll_flush(Pin::new(&mut self.get_mut().send), cx)
523 }
524
525 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
526 AsyncWrite::poll_close(Pin::new(&mut self.get_mut().send), cx)
527 }
528}
529
530impl Debug for Connect {
531 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
532 let model = match &self.model {
533 Side::Client(model) => model as &dyn Debug,
534 Side::Server(model) => model as &dyn Debug,
535 };
536
537 f.debug_struct("Connect")
538 .field("model", model)
539 .field("send", &self.send)
540 .field("recv", &self.recv)
541 .finish()
542 }
543}
544
545#[derive(Debug)]
547pub struct Packet {
548 model: PacketModel<Rx, Bytes>,
549 src: PacketSource,
550}
551
552#[derive(Debug)]
553enum PacketSource {
554 Quic(RecvStream),
555 Native(Bytes),
556}
557
558impl Packet {
559 fn new(model: PacketModel<Rx, Bytes>, src: PacketSource) -> Self {
560 Self { src, model }
561 }
562
563 pub fn assoc_id(&self) -> u16 {
565 self.model.assoc_id()
566 }
567
568 pub fn pkt_id(&self) -> u16 {
570 self.model.pkt_id()
571 }
572
573 pub fn frag_id(&self) -> u8 {
575 self.model.frag_id()
576 }
577
578 pub fn frag_total(&self) -> u8 {
580 self.model.frag_total()
581 }
582
583 pub fn is_from_quic(&self) -> bool {
585 matches!(self.src, PacketSource::Quic(_))
586 }
587
588 pub fn is_from_native(&self) -> bool {
590 matches!(self.src, PacketSource::Native(_))
591 }
592
593 pub async fn accept(self) -> Result<Option<(Bytes, Address, u16)>, Error> {
595 let pkt = match self.src {
596 PacketSource::Quic(mut recv) => {
597 let mut buf = vec![0; self.model.size() as usize];
598 AsyncReadExt::read_exact(&mut recv, &mut buf).await?;
599 Bytes::from(buf)
600 }
601 PacketSource::Native(pkt) => pkt,
602 };
603
604 let mut asm = Vec::new();
605
606 Ok(self
607 .model
608 .assemble(pkt)?
609 .map(|pkt| pkt.assemble(&mut asm))
610 .map(|(addr, assoc_id)| (Bytes::from(asm), addr, assoc_id)))
611 }
612}
613
614#[non_exhaustive]
615#[derive(Debug)]
616pub enum Task {
617 ClientHello(ClientHello),
618 ServerHello(ServerHello),
619 Connect(Connect),
620 Packet(Packet),
621 Dissociate(u16),
622 Heartbeat,
623}
624
625#[derive(Debug)]
626struct KeyingMaterialExporter(QuinnConnection);
627
628impl KeyingMaterialExporterImpl for KeyingMaterialExporter {
629 fn export_keying_material(&self, label: &[u8], context: &[u8]) -> [u8; 32] {
630 let mut buf = [0; 32];
631 match self.0.export_keying_material(&mut buf, label, context) {
632 Ok(_) => {}
633 Err(_) => {
634 let info = "asport key derivation";
638 let derived_key = blake3::derive_key(&info, context);
639
640 let mac = blake3::keyed_hash(&derived_key, label.as_ref());
641 buf.copy_from_slice(mac.as_bytes());
642 }
643 }
644 buf
645 }
646}
647
648#[derive(Debug, Error)]
650pub enum Error {
651 #[error(transparent)]
652 Io(#[from] IoError),
653 #[error(transparent)]
654 Connection(#[from] ConnectionError),
655 #[error(transparent)]
656 SendDatagram(#[from] SendDatagramError),
657 #[error("expecting payload length {0} but got {1}")]
658 PayloadLength(usize, usize),
659 #[error("packet {1:#06x} on invalid udp session {0:#06x}")]
660 InvalidUdpSession(u16, u16),
661 #[error(transparent)]
662 Assemble(#[from] AssembleError),
663 #[error("error unmarshalling uni_stream: {0}")]
664 UnmarshalUniStream(UnmarshalError, RecvStream),
665 #[error("error unmarshalling bi_stream: {0}")]
666 UnmarshalBiStream(UnmarshalError, SendStream, RecvStream),
667 #[error("error unmarshalling datagram: {0}")]
668 UnmarshalDatagram(UnmarshalError, Bytes),
669 #[error("bad command `{0}` from uni_stream")]
670 BadCommandUniStream(&'static str, RecvStream),
671 #[error("bad command `{0}` from bi_stream")]
672 BadCommandBiStream(&'static str, SendStream, RecvStream),
673 #[error("bad command `{0}` from datagram")]
674 BadCommandDatagram(&'static str, Bytes),
675}