// marlin/scallop.rs
// Scallop protocol
//
// Broad goals:
// - Transport layer security (the concept, not the protocol)
// - Enclave-native protocol (no Web PKI)
// - Modern cryptography
//
// Handshake shapes and security levels in the user stories are modelled on the Noise protocol specification.
// Security levels are written as authentication/confidentiality, following the payload security
// properties defined there.
//
// User story 1 - HTTP query against a known server:
//
// Server is running inside an enclave.
// Client has authenticated the attestation of the enclave and has its static key.
//
// Client wants to create a secure channel with the server to make an HTTP query.
//
// Client to server requires a security level of 0/5.
// Server to client requires a security level of 2/1.
// NK is the minimum viable handshake, with 1 RTT client delay and 0.5 RTT server delay.
//
// Bonuses:
// - Authentication refresh on expiry
// - Client can request a new attestation in the first message
// - Server can send the attestation in the second message
//
// User story 2 - HTTP query against an unknown server:
//
// Server is running inside an enclave.
// Client knows the expected PCRs of the server.
//
// Client wants to create a secure channel with the server to make an HTTP query.
//
// Client to server requires a security level of 0/5.
// Server to client requires a security level of 2/1.
// NX is the minimum viable handshake, with 1 RTT client delay and 0.5 RTT server delay,
// and with additional handshake payloads:
// - Client requests a new attestation in the first message
// - Server sends the attestation in the second message
//
// User story 3 - webhook trigger from a known client to a known server:
//
// Client is running inside an enclave.
// Client has the static key of the server.
// Server has previously authenticated the attestation of the client and has its static key.
//
// Client wants to create a secure channel with the server and trigger a webhook.
//
// Client to server requires a security level of 2/5.
// Server to client requires a security level of 2/5.
// KK is the minimum viable handshake, with 1 RTT client delay and 0.5 RTT server delay.
//
// User story 4 - webhook trigger from a known client to an unknown server:
//
// Client is running inside an enclave.
// Server is running inside an enclave.
// Client knows the expected PCRs of the server.
// Server has previously authenticated the attestation of the client and has its static key.
//
// Client wants to create a secure channel with the server and trigger a webhook.
//
// Client to server requires a security level of 2/5.
// Server to client requires a security level of 2/5.
// KX is the minimum viable handshake, with 1 RTT client delay and 0.5 RTT server delay,
// and with additional handshake payloads:
// - Client requests a new attestation in the first message
// - Server sends the attestation in the second message
//
// User story 5 - webhook trigger from an unknown client to a known server:
//
// Client is running inside an enclave.
// Client has the static key of the server.
// Server knows the expected PCRs of the client.
//
// Client wants to create a secure channel with the server and trigger a webhook.
//
// Client to server requires a security level of 2/5.
// Server to client requires a security level of 2/5.
// XK is the minimum viable handshake, with 1 RTT client delay and 1.5 RTT server delay,
// and with additional handshake payloads:
// - Server requests a new attestation in the second message
// - Client sends the attestation in the third message
//
// User story 6 - webhook trigger from an unknown client to an unknown server:
//
// Client is running inside an enclave.
// Server is running inside an enclave.
// Client knows the expected PCRs of the server.
// Server knows the expected PCRs of the client.
//
// Client wants to create a secure channel with the server and trigger a webhook.
//
// Client to server requires a security level of 2/5.
// Server to client requires a security level of 2/5.
// XX is the minimum viable handshake, with 1 RTT client delay and 1.5 RTT server delay,
// and with additional handshake payloads:
// - Client requests a new attestation in the first message
// - Server sends the attestation in the second message
// - Server requests a new attestation in the second message
// - Client sends the attestation in the third message
//
// Attestation efficiency:
//
// How does the server know whether to request a new attestation or not?
// The client sends its static key only in the third message.
//
// Either switch to I* handshakes or incur additional messages and RTT delays.
//
// How does the client know whether to request a new attestation or not?
// The server sends its static key only in the second message.
//
// Nothing can really be done, since that is the first message the server sends.
// TLS works around this by always sending certificates, but that seems very inefficient.
//
// The worst case here is 2 RTT client delay and 1.5 RTT server delay.
//
// Once cached on both sides, it is 1 RTT client delay and 1.5 RTT server delay.
// (The server still has to wait for the client to say whether it requests an attestation.)
//
// User story considerations:
// - various handshake shapes
// - various security levels
// - handshake latency
// - handshake efficiency
// - not having to send attestations unless requested by the other party
// - (questionable?) not having to send static keys unless requested by the other party
//
// General considerations:
// - Different cipher suites
// - Protocol evolution
//
// Conclusions:
// Pick IX as the Noise handshake pattern
// - Most flexible and covers a wide variety of use cases
// - At the cost of a higher server delay
// - At the cost of a lower security level for the handshake messages themselves
// - At the cost of larger handshake messages
// - But allows the significantly larger attestations to be optional in both directions
//
// Pick NoiseSocket as the negotiation protocol
//
// TODOs:
// - (desirable?) 0-RTT
//   - main concern is replay attacks

// TODO: vectored reads/writes

use snow::{Builder, TransportState};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};

#[derive(Debug, thiserror::Error)]
pub enum ScallopError {
    #[error("failed to init builder")]
    InitFailed(#[source] snow::Error),
    #[error("transport error")]
    TransportError(#[from] tokio::io::Error),
    #[error("noise error")]
    NoiseError(#[from] snow::Error),
    // include the reason in the Display output, not only in Debug
    #[error("protocol error: {0}")]
    ProtocolError(String),
    #[error("auth error: {0}")]
    AuthError(String),
}

#[derive(Debug, PartialEq)]
enum ReadMode {
    // waiting for the 2 byte record length
    Length,
    // waiting for the encrypted record body
    Body,
    // handing decrypted bytes to the caller
    Read,
}

pub type Key = [u8; 32];

#[derive(PartialEq)]
pub enum ContainsResponse<State: PartialEq> {
    // the key was found and is approved
    Approved(State),
    // the key was not found
    NotFound,
    // the key is rejected
    Rejected,
}

pub trait ScallopAuthStore {
    type State: PartialEq;

    // intended as a caching mechanism so attestations do not have to be
    // requested every time; always returning NotFound is valid
    fn contains(&mut self, _key: &Key) -> ContainsResponse<Self::State> {
        ContainsResponse::NotFound
    }

    // returns the connection state if the attestation is valid for key,
    // None otherwise
    fn verify(&mut self, attestation: &[u8], key: Key) -> Option<Self::State>;
}

impl<T: ScallopAuthStore> ScallopAuthStore for &mut T {
    type State = T::State;

    fn contains(&mut self, key: &Key) -> ContainsResponse<Self::State> {
        (**self).contains(key)
    }

    fn verify(&mut self, attestation: &[u8], key: Key) -> Option<Self::State> {
        (**self).verify(attestation, key)
    }
}

// lets callers pass in None with the unit type
impl ScallopAuthStore for () {
    type State = ();

    fn contains(&mut self, _key: &Key) -> ContainsResponse<Self::State> {
        unimplemented!()
    }

    fn verify(&mut self, _attestation: &[u8], _key: Key) -> Option<Self::State> {
        unimplemented!()
    }
}
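
// Example (sketch): a caching auth store.
//
// `contains` is meant as a cache so that attestations do not have to be
// requested on every connection. The test-only module below is a minimal
// sketch of such a store. `CachingAuthStore` and `accept_ok` are hypothetical:
// `accept_ok` stands in for real attestation verification (for example
// checking the expected PCRs), and the u64 state is an arbitrary example of
// per-connection data.
#[cfg(test)]
mod auth_store_sketch {
    use super::{ContainsResponse, Key, ScallopAuthStore};
    use std::collections::HashMap;

    // hypothetical store that remembers every key it has approved
    struct CachingAuthStore {
        approved: HashMap<Key, u64>,
        check: fn(&[u8], &Key) -> Option<u64>,
    }

    impl ScallopAuthStore for CachingAuthStore {
        type State = u64;

        fn contains(&mut self, key: &Key) -> ContainsResponse<Self::State> {
            match self.approved.get(key) {
                Some(state) => ContainsResponse::Approved(*state),
                // unknown keys make the handshake ask for an attestation
                None => ContainsResponse::NotFound,
            }
        }

        fn verify(&mut self, attestation: &[u8], key: Key) -> Option<Self::State> {
            let state = (self.check)(attestation, &key)?;
            // remember the key so later handshakes can skip the attestation
            self.approved.insert(key, state);
            Some(state)
        }
    }

    // hypothetical stand-in for real checks: accepts only b"ok", with state 7
    fn accept_ok(attestation: &[u8], _key: &Key) -> Option<u64> {
        (attestation == b"ok").then_some(7)
    }

    #[test]
    pub(super) fn verified_keys_are_cached() {
        let mut store = CachingAuthStore {
            approved: HashMap::new(),
            check: accept_ok,
        };
        let key: Key = [1u8; 32];

        // ContainsResponse does not implement Debug, so compare with assert!
        assert!(store.contains(&key) == ContainsResponse::NotFound);
        assert_eq!(store.verify(b"bad", key), None);
        assert_eq!(store.verify(b"ok", key), Some(7));
        assert!(store.contains(&key) == ContainsResponse::Approved(7));
    }
}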

// Send bound not ideal
// Investigate both Send and non-Send versions
pub trait ScallopAuther: Send {
    type Error: std::fmt::Debug;

    // produces a fresh attestation to send when the peer requests one
    fn new_auth(
        &mut self,
    ) -> impl std::future::Future<Output = Result<Box<[u8]>, Self::Error>> + Send;
}

impl<T: ScallopAuther> ScallopAuther for &mut T {
    type Error = T::Error;
    async fn new_auth(&mut self) -> Result<Box<[u8]>, T::Error> {
        (**self).new_auth().await
    }
}

// lets callers pass in None with the unit type
impl ScallopAuther for () {
    type Error = ();
    // was not able to implement in the impl Future form
    // requires higher MSRV
    async fn new_auth(&mut self) -> Result<Box<[u8]>, ()> {
        unimplemented!();
    }
}

#[derive(Debug)]
pub struct ScallopStream<Stream: AsyncWrite + AsyncRead + Unpin, State = ()> {
    noise: TransportState,
    stream: Stream,

    // read buffer
    // rbuf holds the 2 byte length, the encrypted record being read, or the
    // decrypted record; pending counts the bytes still missing from it
    rbuf: Box<[u8]>,
    pending: usize,
    mode: ReadMode,
    // decrypted bytes not yet handed to the caller are rbuf[read_start..read_end]
    read_end: usize,
    read_start: usize,

    // write buffer
    // the encrypted record still to be written is wbuf[write_start..write_end]
    wbuf: Box<[u8]>,
    write_start: usize,
    write_end: usize,

    // extra connection state from the AuthStore, if any
    pub state: Option<State>,
}

// common interface over snow's handshake and transport states, so the
// framing helpers below work in both phases
trait Noiser {
    fn read_message(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, snow::Error>;
    fn write_message(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, snow::Error>;
}

impl Noiser for snow::HandshakeState {
    fn read_message(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, snow::Error> {
        // resolves to the inherent method, not this trait method
        self.read_message(input, output)
    }

    fn write_message(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, snow::Error> {
        self.write_message(input, output)
    }
}

impl Noiser for snow::TransportState {
    fn read_message(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, snow::Error> {
        snow::TransportState::read_message(self, input, output)
    }

    fn write_message(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, snow::Error> {
        snow::TransportState::write_message(self, input, output)
    }
}

// Reads one length-prefixed noise message from `stream` into `src`,
// decrypts it into `dst` and returns the plaintext length.
async fn noise_read(
    noise: &mut impl Noiser,
    stream: &mut (impl AsyncRead + Unpin),
    src: &mut [u8],
    dst: &mut [u8],
) -> Result<usize, ScallopError> {
    // read noise message length
    let len = stream.read_u16().await? as usize;

    // check if buffer is big enough
    if len > src.len() {
        return Err(ScallopError::ProtocolError("message too big".into()));
    }

    // read noise message
    stream.read_exact(&mut src[0..len]).await?;

    // decrypt noise message
    let len = noise.read_message(&src[0..len], dst)?;

    Ok(len)
}

// Encrypts `src` into `dst` after `dst_offset` as a length-prefixed noise
// message, then writes `dst` up to the end of that message to `stream`.
async fn noise_write(
    noise: &mut impl Noiser,
    stream: &mut (impl AsyncWrite + Unpin),
    src: &[u8],
    dst: &mut [u8],
    // in case dst has data encoded already
    dst_offset: usize,
) -> Result<(), ScallopError> {
    // set noise message, reporting failures as noise errors like noise_read
    let len = noise.write_message(src, &mut dst[dst_offset + 2..])?;

    // set length
    dst[dst_offset..dst_offset + 2].copy_from_slice(&(len as u16).to_be_bytes());

    // send
    stream.write_all(&dst[0..dst_offset + len + 2]).await?;
    stream.flush().await?;

    Ok(())
}
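
// Example (sketch): handshake framing.
//
// noise_write and noise_read frame every noise message as a 2 byte big-endian
// length followed by the message, and during the handshake a 2 byte
// NoiseSocket negotiation length (always zero here) comes first. The
// test-only helpers below are a hypothetical, crypto-free illustration of
// that byte layout; real messages are the ciphertexts produced by snow.
#[cfg(test)]
mod framing_sketch {
    // hypothetical helper: [negotiation len = 0][noise len][noise message]
    pub(super) fn frame_handshake_message(noise_message: &[u8]) -> Vec<u8> {
        let len = u16::try_from(noise_message.len()).expect("noise messages fit in a u16");
        let mut out = Vec::with_capacity(4 + noise_message.len());
        out.extend_from_slice(&0u16.to_be_bytes()); // empty negotiation data
        out.extend_from_slice(&len.to_be_bytes());
        out.extend_from_slice(noise_message);
        out
    }

    // hypothetical helper: parses one [len][body] frame and returns the body
    // and the unread remainder, or None if the input is truncated
    pub(super) fn parse_frame(input: &[u8]) -> Option<(&[u8], &[u8])> {
        let len_bytes: [u8; 2] = input.get(0..2)?.try_into().ok()?;
        let len = u16::from_be_bytes(len_bytes) as usize;
        let body = input.get(2..2 + len)?;
        Some((body, &input[2 + len..]))
    }

    #[test]
    fn handshake_frame_layout() {
        let framed = frame_handshake_message(&[0xaa, 0xbb, 0xcc]);
        assert_eq!(framed, vec![0u8, 0, 0, 3, 0xaa, 0xbb, 0xcc]);

        // the negotiation frame is empty, the noise frame follows it
        let (negotiation, rest) = parse_frame(&framed).unwrap();
        assert!(negotiation.is_empty());
        let (message, rest) = parse_frame(rest).unwrap();
        assert_eq!(message.to_vec(), vec![0xaa_u8, 0xbb, 0xcc]);
        assert!(rest.is_empty());

        // truncated input is rejected instead of being read past
        assert_eq!(parse_frame(&[0, 5, 1, 2]), None);
    }
}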

#[allow(non_snake_case)]
pub async fn new_client_async_Noise_IX_25519_ChaChaPoly_BLAKE2b<
    Base: AsyncWrite + AsyncRead + Unpin,
    AS: ScallopAuthStore,
>(
    mut stream: Base,
    secret: &[u8; 32],
    // will not auth remote if None
    mut auth_store: Option<AS>,
    // will not respond to auth requests if None
    auther: Option<impl ScallopAuther>,
) -> Result<ScallopStream<Base, AS::State>, ScallopError> {
    let mut buf = vec![0u8; 65000].into_boxed_slice();
    let mut noise_buf = vec![0u8; 65000].into_boxed_slice();

    let prologue = b"NoiseSocketInit1\x00\x00";

    let mut noise = Builder::new(
        "Noise_IX_25519_ChaChaPoly_BLAKE2b"
            .parse()
            .map_err(ScallopError::InitFailed)?,
    )
    .local_private_key(secret)
    .prologue(prologue)
    .build_initiator()
    .map_err(ScallopError::InitFailed)?;

    //---- -> e, s start ----//

    // the first two bytes of buf are already zero, which encodes an empty
    // negotiation payload, so there is nothing to write for it

    // encode and send handshake message
    noise_write(&mut noise, &mut stream, &[], &mut buf, 2).await?;

    //---- -> e, s end ----//

    //---- <- e, ee, se, s, es start ----//

    // read negotiation length
    let len = stream.read_u16().await?;

    // length should be zero
    if len != 0 {
        return Err(ScallopError::ProtocolError(
            "non zero second negotiation length".into(),
        ));
    }

    // read and handle handshake message
    let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;

    // handshake payload should be exactly [0x00, 0x01, auth request]
    if len != 3 || noise_buf[0] != 0 || noise_buf[1] != 1 {
        return Err(ScallopError::ProtocolError(
            "invalid second payload length".into(),
        ));
    }

    // auth request should be 0 or 1
    if noise_buf[2] > 1 {
        return Err(ScallopError::ProtocolError(
            "invalid auth request in second payload".into(),
        ));
    }

    let should_send_auth = noise_buf[2] == 1;

    //---- <- e, ee, se, s, es end ----//

    // check if auth is possible
    if should_send_auth && auther.is_none() {
        // auth requested and no auther available, error out
        return Err(ScallopError::ProtocolError(
            "auth requested but no auther available".into(),
        ));
    }

    // safe to unwrap since the IX handshake has delivered the remote static key by now
    let remote_static: [u8; 32] = noise.get_remote_static().unwrap().try_into().unwrap();

    let contains = auth_store.as_mut().map(|x| x.contains(&remote_static));
    // error out if key is considered rejected
    if contains == Some(ContainsResponse::Rejected) {
        return Err(ScallopError::ProtocolError(
            "remote static key rejected".into(),
        ));
    }

    // start tracking what state should go in the stream
    // it is expected to be set if contains returns state
    // or verify returns state
    let mut state = None::<AS::State>;

    // only ask for an attestation if there is an auth store and it does not know the key
    let should_ask_auth = contains == Some(ContainsResponse::NotFound);

    // set state immediately if key was approved previously
    if let Some(ContainsResponse::Approved(approved)) = contains {
        state = Some(approved);
    }

    // handshake is done, switch to transport mode
    let mut noise = noise.into_transport_mode()?;

    //---- -> CLIENTFIN start ----//
    //
    // not part of the noise protocol, needed for optional attestations
    //
    // layout:
    // two bytes noise message length (cleartext)
    // encrypted:
    //   one byte, 0x00 for no auth request, 0x01 for auth request
    //   two bytes payload size
    //   payload

    async fn send_CLIENTFIN(
        noise: &mut impl Noiser,
        stream: &mut (impl AsyncWrite + Unpin),
        buf: &mut [u8],
        noise_buf: &mut [u8],
        payload: &[u8],
        should_ask_auth: bool,
    ) -> Result<(), ScallopError> {
        // assemble message for encryption
        noise_buf[0] = if !should_ask_auth { 0 } else { 1 };
        // safe to cast since the caller has checked the payload size
        noise_buf[1..3].copy_from_slice(&(payload.len() as u16).to_be_bytes());
        noise_buf[3..3 + payload.len()].copy_from_slice(payload);

        // encrypt and send the message
        noise_write(noise, stream, &noise_buf[0..payload.len() + 3], buf, 0).await?;

        Ok(())
    }

    if should_send_auth {
        // safe to unwrap since it has been checked above
        let payload = auther
            .unwrap()
            .new_auth()
            .await
            .map_err(|e| ScallopError::AuthError(format!("{e:?}")))?;
        // check if payload is not too big
        if payload.len() > 60000 {
            return Err(ScallopError::ProtocolError("auth payload too big".into()));
        }

        send_CLIENTFIN(
            &mut noise,
            &mut stream,
            &mut buf,
            &mut noise_buf,
            &payload,
            should_ask_auth,
        )
        .await?;
    } else {
        send_CLIENTFIN(
            &mut noise,
            &mut stream,
            &mut buf,
            &mut noise_buf,
            &[],
            should_ask_auth,
        )
        .await?;
    }

    //---- -> CLIENTFIN end ----//

    //---- <- SERVERFIN start ----//
    //
    // not part of the noise protocol, needed for optional attestations
    // only sent by the server if CLIENTFIN requested auth
    //
    // layout:
    // two bytes noise message length (cleartext)
    // encrypted:
    //   two bytes payload size
    //   payload

    if should_ask_auth {
        // read and decrypt message
        let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;

        // should be at least 2 bytes
        if len < 2 {
            return Err(ScallopError::ProtocolError(
                "invalid SERVERFIN length".into(),
            ));
        }

        // payload size should match
        if u16::from_be_bytes([noise_buf[0], noise_buf[1]]) as usize != len - 2 {
            return Err(ScallopError::ProtocolError(
                "invalid SERVERFIN payload length".into(),
            ));
        }

        // verify, safe to unwrap since should_ask_auth implies auth_store is Some
        let Some(verified) = auth_store
            .as_mut()
            .unwrap()
            .verify(&noise_buf[2..len], remote_static)
        else {
            return Err(ScallopError::ProtocolError("invalid attestation".into()));
        };

        // set state
        state = Some(verified);
    }

    //---- <- SERVERFIN end ----//

    Ok(ScallopStream {
        noise,
        stream,
        // initialize with 2 sized buffer to read length
        rbuf: vec![0u8; 2].into_boxed_slice(),
        pending: 2,
        mode: ReadMode::Length,
        read_start: 0,
        read_end: 0,
        wbuf: vec![].into_boxed_slice(),
        write_start: 0,
        write_end: 0,
        state,
    })
}
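
// Example (sketch): CLIENTFIN and SERVERFIN plaintexts.
//
// After the IX handshake the client always sends CLIENTFIN, and the server
// answers with SERVERFIN only if CLIENTFIN asked for an attestation. The
// test-only helpers below are a hypothetical restatement of the plaintext
// layouts that the two handshake functions build and check, before noise
// encryption and framing:
//   CLIENTFIN: [auth request: 0 or 1][payload size: u16 BE][payload]
//   SERVERFIN: [payload size: u16 BE][payload]
#[cfg(test)]
mod fin_sketch {
    pub(super) fn encode_clientfin(ask_auth: bool, attestation: &[u8]) -> Vec<u8> {
        let size = u16::try_from(attestation.len()).expect("attestation too big");
        let mut out = vec![u8::from(ask_auth)];
        out.extend_from_slice(&size.to_be_bytes());
        out.extend_from_slice(attestation);
        out
    }

    // returns (auth requested, attestation), with the same checks as the server
    pub(super) fn decode_clientfin(msg: &[u8]) -> Result<(bool, &[u8]), &'static str> {
        if msg.len() < 3 {
            return Err("invalid CLIENTFIN length");
        }
        if u16::from_be_bytes([msg[1], msg[2]]) as usize != msg.len() - 3 {
            return Err("invalid CLIENTFIN payload length");
        }
        match msg[0] {
            0 => Ok((false, &msg[3..])),
            1 => Ok((true, &msg[3..])),
            _ => Err("invalid auth request in third payload"),
        }
    }

    // returns the attestation, with the same checks as the client
    pub(super) fn decode_serverfin(msg: &[u8]) -> Result<&[u8], &'static str> {
        if msg.len() < 2 {
            return Err("invalid SERVERFIN length");
        }
        if u16::from_be_bytes([msg[0], msg[1]]) as usize != msg.len() - 2 {
            return Err("invalid SERVERFIN payload length");
        }
        Ok(&msg[2..])
    }

    #[test]
    fn fin_layouts() {
        let msg = encode_clientfin(true, b"att");
        assert_eq!(msg, vec![1u8, 0, 3, b'a', b't', b't']);
        assert_eq!(decode_clientfin(&msg), Ok((true, &b"att"[..])));

        // a client that neither asks nor answers sends a bare 3 byte CLIENTFIN
        assert_eq!(decode_clientfin(&[0, 0, 0]), Ok((false, &[][..])));
        assert!(decode_clientfin(&[2, 0, 0]).is_err());
        assert!(decode_clientfin(&[1, 0, 5, 0]).is_err());

        assert_eq!(decode_serverfin(&[0, 2, 9, 9]), Ok(&[9u8, 9][..]));
        assert!(decode_serverfin(&[0]).is_err());
    }
}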

#[allow(non_snake_case)]
pub async fn new_server_async_Noise_IX_25519_ChaChaPoly_BLAKE2b<
    Base: AsyncWrite + AsyncRead + Unpin,
    AS: ScallopAuthStore,
>(
    mut stream: Base,
    secret: &[u8; 32],
    // will not auth remote if None
    mut auth_store: Option<AS>,
    // will not respond to auth requests if None
    auther: Option<impl ScallopAuther>,
) -> Result<ScallopStream<Base, AS::State>, ScallopError> {
    let mut buf = vec![0u8; 65000].into_boxed_slice();
    let mut noise_buf = vec![0u8; 65000].into_boxed_slice();

    let prologue = b"NoiseSocketInit1\x00\x00";

    let mut noise = Builder::new(
        "Noise_IX_25519_ChaChaPoly_BLAKE2b"
            .parse()
            .map_err(ScallopError::InitFailed)?,
    )
    .local_private_key(secret)
    .prologue(prologue)
    .build_responder()
    .map_err(ScallopError::InitFailed)?;

    //---- -> e, s start ----//

    // read negotiation length
    let len = stream.read_u16().await?;

    // length should be zero
    if len != 0 {
        return Err(ScallopError::ProtocolError(
            "non zero first negotiation length".into(),
        ));
    }

    // read and handle handshake message
    let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;

    // handshake payload should be empty
    if len != 0 {
        return Err(ScallopError::ProtocolError(
            "non zero first handshake payload".into(),
        ));
    }

    //---- -> e, s end ----//

    //---- <- e, ee, se, s, es start ----//

    // negotiation length, reset to zero since buf was just used for reading
    buf[0..2].copy_from_slice(&0u16.to_be_bytes());

    // request auth if auth_store is available
    // and static key is not found in the auth store
    let remote_static: [u8; 32] = noise
        .get_remote_static()
        .expect("handshake should have static key by now")
        .try_into()
        .expect("expected 32 byte key");

    let contains = auth_store.as_mut().map(|x| x.contains(&remote_static));
    // error out if key is considered rejected
    if contains == Some(ContainsResponse::Rejected) {
        return Err(ScallopError::ProtocolError(
            "remote static key rejected".into(),
        ));
    }

    // start tracking what state should go in the stream
    // it is expected to be set if contains returns state
    // or verify returns state
    let mut state = None::<AS::State>;

    let should_ask_auth = contains == Some(ContainsResponse::NotFound);

    // set state immediately if key was approved previously
    if let Some(ContainsResponse::Approved(approved)) = contains {
        state = Some(approved);
    }

    // handshake payload is [0x00, 0x01, auth request]
    let payload = &[0u8, 1u8, if !should_ask_auth { 0u8 } else { 1u8 }];

    // encode and send handshake message
    noise_write(&mut noise, &mut stream, payload, &mut buf, 2).await?;

    //---- <- e, ee, se, s, es end ----//

    // handshake is done, switch to transport mode
    let mut noise = noise.into_transport_mode()?;

    //---- -> CLIENTFIN start ----//
    //
    // not part of the noise protocol, needed for optional attestations
    //
    // layout:
    // two bytes noise message length (cleartext)
    // encrypted:
    //   one byte, 0x00 for no auth request, 0x01 for auth request
    //   two bytes payload size
    //   payload

    // read and decrypt message
    let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;

    // should be at least 3 bytes
    if len < 3 {
        return Err(ScallopError::ProtocolError(
            "invalid CLIENTFIN length".into(),
        ));
    }

    // payload size should match
    if u16::from_be_bytes([noise_buf[1], noise_buf[2]]) as usize != len - 3 {
        return Err(ScallopError::ProtocolError(
            "invalid CLIENTFIN payload length".into(),
        ));
    }

    // verify auth if we asked for it
    if should_ask_auth {
        // verify, safe to unwrap since should_ask_auth implies auth_store is Some
        let Some(verified) = auth_store
            .as_mut()
            .unwrap()
            .verify(&noise_buf[3..len], remote_static)
        else {
            return Err(ScallopError::ProtocolError("invalid attestation".into()));
        };

        // set state
        state = Some(verified);
    }

    // auth request should be 0 or 1
    if noise_buf[0] > 1 {
        return Err(ScallopError::ProtocolError(
            "invalid auth request in third payload".into(),
        ));
    }

    let should_send_auth = noise_buf[0] == 1;

    //---- -> CLIENTFIN end ----//

    // check if auth is possible
    if should_send_auth && auther.is_none() {
        // auth requested and no auther available, error out
        return Err(ScallopError::ProtocolError(
            "auth requested but no auther available".into(),
        ));
    }

    //---- <- SERVERFIN start ----//
    //
    // not part of the noise protocol, needed for optional attestations
    // only sent if CLIENTFIN requested auth
    //
    // layout:
    // two bytes noise message length (cleartext)
    // encrypted:
    //   two bytes payload size
    //   payload

    if should_send_auth {
        // safe to unwrap since it has been checked above
        let payload = auther
            .unwrap()
            .new_auth()
            .await
            .map_err(|e| ScallopError::AuthError(format!("{e:?}")))?;
        // check if payload is not too big
        if payload.len() > 60000 {
            return Err(ScallopError::ProtocolError("auth payload too big".into()));
        }

        // safe to cast since range has been checked above
        noise_buf[0..2].copy_from_slice(&(payload.len() as u16).to_be_bytes());
        noise_buf[2..2 + payload.len()].copy_from_slice(&payload);

        // encrypt and send the message
        noise_write(
            &mut noise,
            &mut stream,
            &noise_buf[0..payload.len() + 2],
            &mut buf,
            0,
        )
        .await?;
    }

    //---- <- SERVERFIN end ----//

    Ok(ScallopStream {
        noise,
        stream,
        // initialize with 2 sized buffer to read length
        rbuf: vec![0u8; 2].into_boxed_slice(),
        pending: 2,
        mode: ReadMode::Length,
        read_start: 0,
        read_end: 0,
        wbuf: vec![].into_boxed_slice(),
        write_start: 0,
        write_end: 0,
        state,
    })
}
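
// Example (sketch): a full handshake where the server asks for an attestation.
//
// The test-only module below wires the two constructors together over an
// in-memory pipe. It assumes tokio's `macros`, `rt` and `io-util` features are
// available to tests. `FixedAuther` and `AcceptAuthStore` are hypothetical
// stand-ins: a real auther would produce a fresh enclave attestation, and a
// real store would validate it (for example against expected PCRs) instead of
// comparing bytes. Since the client passes no auth store it never asks for an
// attestation, so no SERVERFIN is exchanged.
#[cfg(test)]
mod handshake_sketch {
    use super::*;

    // hypothetical auther returning a fixed attestation document
    struct FixedAuther;

    impl ScallopAuther for FixedAuther {
        type Error = ();
        async fn new_auth(&mut self) -> Result<Box<[u8]>, ()> {
            Ok(b"fake attestation".to_vec().into_boxed_slice())
        }
    }

    // hypothetical store: never caches, accepts only the fixed attestation
    // and keeps the verified key as connection state
    struct AcceptAuthStore;

    impl ScallopAuthStore for AcceptAuthStore {
        type State = Key;
        fn verify(&mut self, attestation: &[u8], key: Key) -> Option<Key> {
            (attestation == b"fake attestation").then_some(key)
        }
    }

    fn keypair() -> ([u8; 32], [u8; 32]) {
        let params: snow::params::NoiseParams =
            "Noise_IX_25519_ChaChaPoly_BLAKE2b".parse().unwrap();
        let kp = Builder::new(params).generate_keypair().unwrap();
        (kp.private.try_into().unwrap(), kp.public.try_into().unwrap())
    }

    #[tokio::test]
    pub(super) async fn server_requests_client_attestation() {
        let (client_secret, client_public) = keypair();
        let (server_secret, server_public) = keypair();
        let (client_io, server_io) = tokio::io::duplex(1 << 16);

        let (client, server) = tokio::join!(
            // no auth store: the client does not ask for an attestation
            new_client_async_Noise_IX_25519_ChaChaPoly_BLAKE2b(
                client_io,
                &client_secret,
                None::<()>,
                Some(FixedAuther),
            ),
            // no auther: the server cannot attest itself, but requires one
            new_server_async_Noise_IX_25519_ChaChaPoly_BLAKE2b(
                server_io,
                &server_secret,
                Some(AcceptAuthStore),
                None::<()>,
            )
        );
        let (client, server) = (client.unwrap(), server.unwrap());

        // both sides learned each other's static keys
        assert_eq!(client.get_remote_static(), Some(server_public));
        assert_eq!(server.get_remote_static(), Some(client_public));
        // only the server verified an attestation, so only it carries state
        assert_eq!(client.state, None);
        assert_eq!(server.state, Some(client_public));
    }
}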

impl<Base: AsyncWrite + AsyncRead + Unpin, State> ScallopStream<Base, State> {
    pub fn get_remote_static(&self) -> Option<[u8; 32]> {
        self.noise
            .get_remote_static()
            .map(|x| x.try_into().expect("expected 32 byte key"))
    }
}

impl<Base: AsyncWrite + AsyncRead + Unpin, State: Unpin> AsyncRead for ScallopStream<Base, State> {
    // IMPORTANT: Return Pending only as a direct result of base returning Pending
    // Ensures wakers are set up correctly
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let stream = self.get_mut();
        loop {
            while stream.pending != 0 {
                let base = std::pin::pin!(&mut stream.stream);

                // do not have enough data, try to read more
                let len = stream.rbuf.len();
                let mut inner = ReadBuf::new(&mut stream.rbuf[(len - stream.pending)..]);
                std::task::ready!(base.poll_read(cx, &mut inner))?;

                // check eof
                if inner.filled().is_empty() {
                    // eof is only clean on a record boundary, i.e. before any
                    // byte of the next length has arrived; anything else
                    // means a record was truncated
                    if stream.mode == ReadMode::Length && stream.pending == len {
                        return std::task::Poll::Ready(Ok(()));
                    }
                    return std::task::Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "eof in the middle of a record",
                    )));
                }
                stream.pending -= inner.filled().len();
            }

            // pending is always 0 after this point

            if stream.mode == ReadMode::Length {
                // we have read the length

                // parse length
                let record_length = u16::from_be_bytes(stream.rbuf[0..2].try_into().unwrap());

                // set up to read record
                stream.pending = record_length.into();
                stream.mode = ReadMode::Body;
                stream.rbuf = vec![0u8; stream.pending].into_boxed_slice();
            } else if stream.mode == ReadMode::Body {
                // we have the full record

                // decrypt the record into rbuf
                let len = stream
                    .noise
                    .read_message(&stream.rbuf.clone(), &mut stream.rbuf)
                    .map_err(std::io::Error::other)?;

                // set up to send body upstream
                stream.read_start = 0;
                stream.read_end = len;
                stream.mode = ReadMode::Read;
            } else if stream.read_start == stream.read_end {
                // empty record: returning Ready without data would look like
                // eof to the caller, so move on to the next record instead
                stream.rbuf = vec![0u8; 2].into_boxed_slice();
                stream.pending = 2;
                stream.mode = ReadMode::Length;
            } else {
                // hand over as much decrypted data as fits in the caller's buffer
                let n = std::cmp::min(buf.remaining(), stream.read_end - stream.read_start);
                buf.put_slice(&stream.rbuf[stream.read_start..stream.read_start + n]);
                stream.read_start += n;

                if stream.read_start == stream.read_end {
                    // record fully consumed, set up to read the next length
                    stream.rbuf = vec![0u8; 2].into_boxed_slice();
                    stream.pending = 2;
                    stream.mode = ReadMode::Length;
                }
                return std::task::Poll::Ready(Ok(()));
            }
        }
    }
}
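
// Example (sketch): record reassembly.
//
// The base stream may hand poll_read bytes in arbitrarily small pieces, which
// is why it tracks `pending` and switches between ReadMode::Length and
// ReadMode::Body. The test-only `Reassembler` below is a hypothetical,
// synchronous and crypto-free model of that state machine: it accepts chunks
// of any size and yields complete record bodies, which poll_read would then
// decrypt with noise.
#[cfg(test)]
mod reassembly_sketch {
    pub(super) struct Reassembler {
        buf: Vec<u8>,
        pending: usize,
        reading_length: bool,
    }

    impl Reassembler {
        pub(super) fn new() -> Self {
            // like ScallopStream, start by waiting for a 2 byte length
            Reassembler { buf: Vec::new(), pending: 2, reading_length: true }
        }

        // feeds bytes, pushing every record body completed along the way
        pub(super) fn feed(&mut self, mut chunk: &[u8], out: &mut Vec<Vec<u8>>) {
            while !chunk.is_empty() {
                let take = self.pending.min(chunk.len());
                self.buf.extend_from_slice(&chunk[..take]);
                chunk = &chunk[take..];
                self.pending -= take;

                if self.pending == 0 && self.reading_length {
                    // length complete: switch to reading the body
                    self.pending = u16::from_be_bytes([self.buf[0], self.buf[1]]).into();
                    self.reading_length = false;
                    self.buf.clear();
                }
                if self.pending == 0 && !self.reading_length {
                    // body complete: emit it and wait for the next length
                    out.push(std::mem::take(&mut self.buf));
                    self.pending = 2;
                    self.reading_length = true;
                }
            }
        }
    }

    #[test]
    fn records_survive_any_chunking() {
        // two records, "abc" and "z", each behind a 2 byte big-endian length
        let wire = [0u8, 3, b'a', b'b', b'c', 0, 1, b'z'];
        let expected = vec![b"abc".to_vec(), b"z".to_vec()];

        // all at once
        let mut out = Vec::new();
        Reassembler::new().feed(&wire, &mut out);
        assert_eq!(out, expected);

        // one byte at a time
        let mut out = Vec::new();
        let mut r = Reassembler::new();
        for piece in wire.chunks(1) {
            r.feed(piece, &mut out);
        }
        assert_eq!(out, expected);
    }
}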

impl<Base: AsyncWrite + AsyncRead + Unpin, State: Unpin> AsyncWrite for ScallopStream<Base, State> {
    // IMPORTANT: Return Pending only as a direct result of base returning Pending
    // Ensures wakers are set up correctly
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        // writing nothing should not put an empty record on the wire
        if buf.is_empty() {
            return std::task::Poll::Ready(Ok(0));
        }

        // flush existing data first
        std::task::ready!(self.as_mut().poll_flush(cx))?;

        let mut stream = self.as_mut();

        // construct new buf
        // up to 64000 bytes of plaintext per record
        let len = std::cmp::min(buf.len(), 64000) as u16;
        // extra room for the 2 byte length and the 16 byte authentication tag
        let mut new_buf = vec![0u8; len as usize + 1000].into_boxed_slice();

        // set noise message
        let noise_len = stream
            .noise
            .write_message(&buf[0..len as usize], &mut new_buf[2..])
            .map_err(std::io::Error::other)?;

        // set length
        new_buf[0..2].copy_from_slice(&(noise_len as u16).to_be_bytes());

        // queue up new buf
        stream.wbuf = new_buf;
        stream.write_start = 0;
        stream.write_end = noise_len + 2;

        // TODO: Should we flush here so it does not need to be called in the common case?
        // How do we implement this?
        //
        // Not sure how the semantics will play out though.
        //
        // Happy path looks great.
        // We make a call to poll_flush, it returns Ready and we return Ready with length.
        //
        // But what if it returns Pending?
        // If we return Pending, the caller will assume nothing was sent.
        // If we return Ready, poll_flush has potentially set up wakers.
        // What happens on repeated calls? Unsure if it is supposed to be idempotent.

        std::task::Poll::Ready(Ok(len as usize))
    }

    // IMPORTANT: Return Pending only as a direct result of base returning Pending
    // Ensures wakers are set up correctly
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let stream = self.get_mut();

        while stream.write_start != stream.write_end {
            let base = std::pin::pin!(&mut stream.stream);

            // try to send existing messages first
            let size = std::task::ready!(
                base.poll_write(cx, &stream.wbuf[stream.write_start..stream.write_end])
            )?;
            // a zero length write means the base can no longer accept data,
            // bail out instead of looping forever
            if size == 0 {
                return std::task::Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
            }
            stream.write_start += size;
        }

        // flush data after write since base could be buffered
        let base = std::pin::pin!(&mut stream.stream);
        base.poll_flush(cx)
    }

    // IMPORTANT: Return Pending only as a direct result of base returning Pending
    // Ensures wakers are set up correctly
    //
    // Shutdown is supposed to be graceful
    //
    // From the tokio docs:
    // Invocation of a shutdown implies an invocation of flush.
    // Once this method returns Ready it implies that a flush successfully happened
    // before the shutdown happened. That is, callers don’t need to call flush before
    // calling shutdown. They can rely that by calling shutdown any pending buffered
    // data will be written out.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        // flush data for graceful shutdowns
        std::task::ready!(self.as_mut().poll_flush(cx))?;

        let stream = self.get_mut();
        let base = std::pin::pin!(&mut stream.stream);

        base.poll_shutdown(cx)
    }
}
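
// Example (sketch): using ScallopStream as an ordinary tokio byte stream.
//
// poll_write only queues one encrypted record (at most 64000 plaintext bytes)
// and sends it on the next write or on flush, so callers must flush (or shut
// down) before waiting for a reply. The test below sends 100000 bytes, which
// takes two records, and checks that they arrive intact. It assumes tokio's
// `macros`, `rt` and `io-util` features are available to tests, and uses no
// authentication on either side.
#[cfg(test)]
mod stream_sketch {
    use super::*;

    fn secret() -> [u8; 32] {
        let params: snow::params::NoiseParams =
            "Noise_IX_25519_ChaChaPoly_BLAKE2b".parse().unwrap();
        let kp = Builder::new(params).generate_keypair().unwrap();
        kp.private.try_into().unwrap()
    }

    #[tokio::test]
    pub(super) async fn large_write_roundtrip() {
        let (client_io, server_io) = tokio::io::duplex(1 << 16);
        let (client_secret, server_secret) = (secret(), secret());

        let (client, server) = tokio::join!(
            new_client_async_Noise_IX_25519_ChaChaPoly_BLAKE2b(
                client_io,
                &client_secret,
                None::<()>,
                None::<()>,
            ),
            new_server_async_Noise_IX_25519_ChaChaPoly_BLAKE2b(
                server_io,
                &server_secret,
                None::<()>,
                None::<()>,
            )
        );
        let (mut client, mut server) = (client.unwrap(), server.unwrap());

        let message: Vec<u8> = (0..100_000u32).map(|i| i as u8).collect();
        let mut received = vec![0u8; message.len()];

        let (_, read) = tokio::join!(
            async {
                client.write_all(&message).await.unwrap();
                // without this flush the last record would stay queued
                client.flush().await.unwrap();
            },
            server.read_exact(&mut received)
        );
        read.unwrap();
        assert_eq!(received, message);
    }
}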