// Scallop protocol
//
// Broad goals:
// - Transport layer security (the concept, not the protocol)
// - Enclave-native protocol (no Web PKI)
// - Modern cryptography
//
// The handshake shapes and security levels in the user stories are modelled on the Noise
// Protocol Framework specification. Security levels are written as source/destination,
// using the payload security properties defined there (source 0-2, destination 0-5).
//
// User story 1 - HTTP query against a known server:
//
// Server is running inside an enclave.
// Client has authenticated the attestation of the enclave and has the static key.
//
// Client wants to create a secure channel with the server to make an HTTP query.
//
// Client to server requires a security level of 0/5.
// Server to client requires a security level of 2/1.
// NK is the minimum viable handshake.
// With 1 RTT client delay and 0.5 RTT server delay.
//
// Bonuses:
// - Authentication refresh on expiry
//   - Client can request a new attestation in the first message
//   - Server can send the attestation in the second message
//
// User story 2 - HTTP query against an unknown server:
//
// Server is running inside an enclave.
// Client knows the expected PCRs of the server.
//
// Client wants to create a secure channel with the server to make an HTTP query.
//
// Client to server requires a security level of 0/5.
// Server to client requires a security level of 2/1.
// NX is the minimum viable handshake.
// With 1 RTT client delay and 0.5 RTT server delay.
// With additional handshake payloads:
//   - Client requests a new attestation in the first message
//   - Server sends the attestation in the second message
//
// User story 3 - webhook trigger from a known client to a known server:
//
// Client is running inside an enclave.
// Client has the static key of the server.
// Server has previously authenticated the attestation of the client and has the static key.
//
// Client wants to create a secure channel with the server and trigger a webhook.
//
// Client to server requires a security level of 2/5.
// Server to client requires a security level of 2/5.
// KK is the minimum viable handshake.
// With 1 RTT client delay and 0.5 RTT server delay.
//
// User story 4 - webhook trigger from a known client to an unknown server:
//
// Client is running inside an enclave.
// Server is running inside an enclave.
// Client knows the expected PCRs of the server.
// Server has previously authenticated the attestation of the client and has the static key.
//
// Client wants to create a secure channel with the server and trigger a webhook.
//
// Client to server requires a security level of 2/5.
// Server to client requires a security level of 2/5.
// KX is the minimum viable handshake.
// With 1 RTT client delay and 0.5 RTT server delay.
// With additional handshake payloads:
//   - Client requests a new attestation in the first message
//   - Server sends the attestation in the second message
//
// User story 5 - webhook trigger from an unknown client to a known server:
//
// Client is running inside an enclave.
// Client has the static key of the server.
// Server knows the expected PCRs of the client.
//
// Client wants to create a secure channel with the server and trigger a webhook.
//
// Client to server requires a security level of 2/5.
// Server to client requires a security level of 2/5.
// XK is the minimum viable handshake.
// With 1 RTT client delay and 1.5 RTT server delay.
// With additional handshake payloads:
//   - Server requests a new attestation in the second message
//   - Client sends the attestation in the third message
//
// User story 6 - webhook trigger from an unknown client to an unknown server:
//
// Client is running inside an enclave.
// Server is running inside an enclave.
// Client knows the expected PCRs of the server.
// Server knows the expected PCRs of the client.
//
// Client wants to create a secure channel with the server and trigger a webhook.
//
// Client to server requires a security level of 2/5.
// Server to client requires a security level of 2/5.
// XX is the minimum viable handshake.
// With 1 RTT client delay and 1.5 RTT server delay.
// With additional handshake payloads:
//   - Client requests a new attestation in the first message
//   - Server sends the attestation in the second message
//   - Server requests a new attestation in the second message
//   - Client sends the attestation in the third message
//
// Attestation efficiency:
//
// How does the server know whether to request a new attestation or not?
// In the X* handshakes the client sends its static key only in the third message,
// so the server cannot look the key up before sending its own (second) message.
//
// Either switch to I* handshakes or incur additional messages and RTT delays.
//
// How does the client know whether to request a new attestation or not?
// In the *X handshakes the server sends its static key only in the second message.
//
// Nothing can really be done here, since that is the first message sent by the server.
// TLS works around this by always sending certificates, but that seems very inefficient.
//
// The worst case here is 2 RTT client delay and 1.5 RTT server delay.
//
// Once cached on both sides, 1 RTT client delay and 1.5 RTT server delay.
// (The server still has to wait to learn whether the client requests an attestation.)
//
// User story considerations:
// - various handshake shapes
// - various security levels
// - handshake latency
// - handshake efficiency
//   - not having to send attestations unless requested by the other party
//   - (questionable?) not having to send static keys unless requested by the other party
//
// General considerations:
// - Different cipher suites
// - Protocol evolution
//
// Conclusions:
// Pick IX as the Noise handshake pattern
// - Most flexible, covers a wide variety of use cases
// - At the cost of a higher server delay
// - At the cost of a lower security level for the handshake messages themselves
// - At the cost of larger handshake messages
// - But allows the much larger attestations to be optional in both directions
//
// Pick NoiseSocket as the negotiation protocol
//
// TODOs:
// - (desirable?) 0-RTT
//   - main concern is replay attacks
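
// Sketch (illustrative only, not used by the protocol code): how the minimum
// viable handshakes in the user stories follow from the Noise pattern naming
// convention. The first letter describes the initiator (client) static key and
// the second the responder (server) static key: N = no static key,
// K = static key already known to the peer, X = static key transmitted during
// the handshake. Scallop itself always uses IX, where I means the initiator
// static key is transmitted immediately in the first message. The module,
// type and function names are hypothetical.
#[allow(dead_code)]
mod noise_pattern_sketch {
    // what the peer knows about a party's static key
    #[derive(Clone, Copy, Debug, PartialEq)]
    pub enum StaticKey {
        // the party does not authenticate
        Absent,
        // the peer already has the key, e.g. from an earlier attestation
        Known,
        // the peer only knows the expected PCRs, so the key has to be sent
        Transmitted,
    }

    fn letter(key: StaticKey) -> char {
        match key {
            StaticKey::Absent => 'N',
            StaticKey::Known => 'K',
            StaticKey::Transmitted => 'X',
        }
    }

    // minimum viable Noise pattern for a client/server pair
    pub fn minimum_pattern(client: StaticKey, server: StaticKey) -> String {
        [letter(client), letter(server)].iter().collect()
    }
}

// For example, user story 5 (the server only knows the expected client PCRs,
// the client already has the server key) is
// minimum_pattern(StaticKey::Transmitted, StaticKey::Known), i.e. "XK".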

// TODO: vectored reads/writes

use snow::{Builder, TransportState};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};

#[derive(Debug, thiserror::Error)]
pub enum ScallopError {
    #[error("failed to init builder")]
    InitFailed(#[source] snow::Error),
    #[error("transport error")]
    TransportError(#[from] tokio::io::Error),
    #[error("noise error")]
    NoiseError(#[from] snow::Error),
    #[error("protocol error")]
    ProtocolError(String),
    #[error("auth error")]
    AuthError(String),
}

// which part of a record poll_read is currently working on
#[derive(Debug, PartialEq)]
enum ReadMode {
    // reading the 2 byte record length
    Length,
    // reading the encrypted record body
    Body,
    // handing decrypted plaintext to the caller
    Read,
}

pub type Key = [u8; 32];
#[derive(PartialEq)]
pub enum ContainsResponse<State: PartialEq> {
    // the key was found and is approved
    Approved(State),
    // the key was not found
    NotFound,
    // the key is rejected
    Rejected,
}

pub trait ScallopAuthStore {
    type State: PartialEq;

    // intended as a caching mechanism so that attestations do not have to be
    // requested on every connection; always returning NotFound is valid and
    // simply means an attestation is requested every time
    fn contains(&mut self, _key: &Key) -> ContainsResponse<Self::State> {
        ContainsResponse::NotFound
    }
    // verifies an attestation for the remote static key, returning the
    // connection state on success or None if the attestation is invalid
    fn verify(&mut self, attestation: &[u8], key: Key) -> Option<Self::State>;
}

impl<T: ScallopAuthStore> ScallopAuthStore for &mut T {
    type State = T::State;

    fn contains(&mut self, key: &Key) -> ContainsResponse<Self::State> {
        (**self).contains(key)
    }

    fn verify(&mut self, attestation: &[u8], key: Key) -> Option<Self::State> {
        (**self).verify(attestation, key)
    }
}

// lets callers pass in None::<()> when there is no auth store;
// never called, since the handshake only uses the store when it is Some
impl ScallopAuthStore for () {
    type State = ();

    fn contains(&mut self, _key: &Key) -> ContainsResponse<Self::State> {
        unimplemented!()
    }

    fn verify(&mut self, _attestation: &[u8], _key: Key) -> Option<Self::State> {
        unimplemented!()
    }
}
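
// Sketch (illustrative only, not used by the protocol code): the kind of
// in-memory cache ScallopAuthStore::contains is meant to front, so that repeat
// connections from the same static key can skip the attestation request.
// The module, type and method names are hypothetical, and the attestation
// verifier is stubbed as a closure. A real store would implement
// ScallopAuthStore on top of something like this, would likely evict entries
// when attestations expire, and has to decide for itself whether (and for how
// long) to remember rejections; this sketch keeps both forever.
#[allow(dead_code)]
mod auth_cache_sketch {
    use std::collections::{HashMap, HashSet};

    pub type Key = [u8; 32];

    // mirrors the three outcomes of ContainsResponse
    #[derive(Debug, PartialEq)]
    pub enum Lookup<S> {
        Approved(S),
        NotFound,
        Rejected,
    }

    pub struct AuthCache<S> {
        approved: HashMap<Key, S>,
        rejected: HashSet<Key>,
    }

    impl<S: Clone> AuthCache<S> {
        pub fn new() -> Self {
            AuthCache {
                approved: HashMap::new(),
                rejected: HashSet::new(),
            }
        }

        // what contains would report for this key
        pub fn lookup(&self, key: &Key) -> Lookup<S> {
            if self.rejected.contains(key) {
                Lookup::Rejected
            } else if let Some(state) = self.approved.get(key) {
                Lookup::Approved(state.clone())
            } else {
                Lookup::NotFound
            }
        }

        // runs the verifier and remembers the outcome for later lookups
        pub fn verify_and_record(
            &mut self,
            key: Key,
            attestation: &[u8],
            check: impl FnOnce(&[u8]) -> Option<S>,
        ) -> Option<S> {
            match check(attestation) {
                Some(state) => {
                    self.approved.insert(key, state.clone());
                    Some(state)
                }
                None => {
                    self.rejected.insert(key);
                    None
                }
            }
        }
    }
}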

// Send bound not ideal
// Investigate both Send and non-Send versions
pub trait ScallopAuther: Send {
    type Error: std::fmt::Debug;
    // produces a fresh attestation to send to the remote party
    fn new_auth(
        &mut self,
    ) -> impl std::future::Future<Output = Result<Box<[u8]>, Self::Error>> + Send;
}

impl<T: ScallopAuther> ScallopAuther for &mut T {
    type Error = T::Error;
    async fn new_auth(&mut self) -> Result<Box<[u8]>, T::Error> {
        (**self).new_auth().await
    }
}

// lets callers pass in None::<()> when there is no auther;
// never called, since the handshake errors out before using a None auther
impl ScallopAuther for () {
    type Error = ();
    // could not be written in the impl Future form,
    // that requires a higher MSRV
    async fn new_auth(&mut self) -> Result<Box<[u8]>, ()> {
        unimplemented!();
    }
}

#[derive(Debug)]
pub struct ScallopStream<Stream: AsyncWrite + AsyncRead + Unpin, State = ()> {
    noise: TransportState,
    stream: Stream,

    // read buffer
    // rbuf holds the length prefix or encrypted record being read, and the
    // decrypted plaintext once the record is complete; pending is the number
    // of bytes still missing; read_start..read_end is the plaintext not yet
    // returned to the caller
    rbuf: Box<[u8]>,
    pending: usize,
    mode: ReadMode,
    read_end: usize,
    read_start: usize,

    // write buffer
    // wbuf[write_start..write_end] is an encrypted record not yet written
    // to the base stream
    wbuf: Box<[u8]>,
    write_start: usize,
    write_end: usize,

    // extra connection state from the AuthStore if any
    pub state: Option<State>,
}

// abstracts over the snow handshake and transport states so that
// the framing helpers below work with both
trait Noiser {
    fn read_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error>;
    fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error>;
}

// fully qualified calls make it explicit that the inherent snow methods
// are called rather than these trait methods
impl Noiser for snow::HandshakeState {
    fn read_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error> {
        snow::HandshakeState::read_message(self, payload, message)
    }

    fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error> {
        snow::HandshakeState::write_message(self, payload, message)
    }
}

impl Noiser for snow::TransportState {
    fn read_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error> {
        snow::TransportState::read_message(self, payload, message)
    }

    fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, snow::Error> {
        snow::TransportState::write_message(self, payload, message)
    }
}

// reads one length-prefixed noise message into src, processes it into dst
// and returns the length of the resulting payload in dst
async fn noise_read(
    noise: &mut impl Noiser,
    stream: &mut (impl AsyncRead + Unpin),
    src: &mut [u8],
    dst: &mut [u8],
) -> Result<usize, ScallopError> {
    // read noise message length
    let len = stream.read_u16().await? as usize;

    // check if buffer is big enough
    if len > src.len() {
        return Err(ScallopError::ProtocolError("message too big".into()));
    }

    // read noise message
    stream.read_exact(&mut src[0..len]).await?;

    // process noise message
    let len = noise.read_message(&src[0..len], dst)?;

    Ok(len)
}

// turns src into one noise message written to dst after dst_offset,
// prefixes it with its 2 byte length and sends everything up to its end
async fn noise_write(
    noise: &mut impl Noiser,
    stream: &mut (impl AsyncWrite + Unpin),
    src: &[u8],
    dst: &mut [u8],
    // in case dst has data encoded already
    dst_offset: usize,
) -> Result<(), ScallopError> {
    // set noise message, snow errors surface as NoiseError
    let len = noise.write_message(src, &mut dst[dst_offset + 2..])?;

    // set length
    dst[dst_offset..dst_offset + 2].copy_from_slice(&(len as u16).to_be_bytes());

    // send
    stream.write_all(&dst[0..dst_offset + len + 2]).await?;
    stream.flush().await?;

    Ok(())
}
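
// Sketch (illustrative only, not used by the protocol code): the byte layout
// produced by noise_write and consumed by noise_read, written without snow so
// the framing is easy to see. Every noise message is prefixed with its length
// as a big-endian u16; the two handshake messages additionally start with a
// NoiseSocket negotiation length, which Scallop always sets to zero. The
// prologue helper assumes the NoiseSocket construction
// "NoiseSocketInit1" || negotiation length || negotiation data, which with
// empty negotiation data matches the prologue constant used by the handshake
// functions. All names here are hypothetical.
#[allow(dead_code)]
mod framing_sketch {
    // noise messages are at most 65535 bytes, so the length fits in a u16
    pub const MAX_NOISE_MESSAGE: usize = 65535;

    // prefixes a noise message with its big-endian u16 length
    pub fn frame(message: &[u8]) -> Option<Vec<u8>> {
        if message.len() > MAX_NOISE_MESSAGE {
            return None;
        }
        let mut out = Vec::with_capacity(2 + message.len());
        out.extend_from_slice(&(message.len() as u16).to_be_bytes());
        out.extend_from_slice(message);
        Some(out)
    }

    // handshake message: zero negotiation length, then a framed noise message
    pub fn frame_handshake(message: &[u8]) -> Option<Vec<u8>> {
        let mut out = vec![0u8, 0u8];
        out.extend(frame(message)?);
        Some(out)
    }

    // splits one framed message off the front of buf,
    // returning (message, rest) or None if buf is incomplete
    pub fn unframe(buf: &[u8]) -> Option<(&[u8], &[u8])> {
        let len = u16::from_be_bytes([*buf.first()?, *buf.get(1)?]) as usize;
        let body = buf.get(2..2 + len)?;
        Some((body, &buf[2 + len..]))
    }

    // negotiation is assumed to be shorter than 65536 bytes
    pub fn prologue(negotiation: &[u8]) -> Vec<u8> {
        let mut out = b"NoiseSocketInit1".to_vec();
        out.extend_from_slice(&(negotiation.len() as u16).to_be_bytes());
        out.extend_from_slice(negotiation);
        out
    }
}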

#[allow(non_snake_case)]
pub async fn new_client_async_Noise_IX_25519_ChaChaPoly_BLAKE2b<
    Base: AsyncWrite + AsyncRead + Unpin,
    AS: ScallopAuthStore,
>(
    mut stream: Base,
    secret: &[u8; 32],
    // will not auth remote if None
    mut auth_store: Option<AS>,
    // will not respond to auth requests if None
    auther: Option<impl ScallopAuther>,
) -> Result<ScallopStream<Base, AS::State>, ScallopError> {
    let mut buf = vec![0u8; 65000].into_boxed_slice();
    let mut noise_buf = vec![0u8; 65000].into_boxed_slice();

    // NoiseSocket prologue with an empty (zero length) negotiation payload
    let prologue = b"NoiseSocketInit1\x00\x00";

    let mut noise = Builder::new(
        "Noise_IX_25519_ChaChaPoly_BLAKE2b"
            .parse()
            .map_err(ScallopError::InitFailed)?,
    )
    .local_private_key(secret)
    .prologue(prologue)
    .build_initiator()
    .map_err(ScallopError::InitFailed)?;

    //---- -> e, s start ----//

    // the first two bytes of buf (the negotiation length) are already zero,
    // so no negotiation payload is written

    // encode and send handshake message with an empty payload
    noise_write(&mut noise, &mut stream, &[], &mut buf, 2).await?;

    //---- -> e, s end ----//

    //---- <- e, ee, se, s, es start ----//

    // read negotiation length
    let len = stream.read_u16().await?;

    // length should be zero
    if len != 0 {
        return Err(ScallopError::ProtocolError(
            "non zero second negotiation length".into(),
        ));
    }

    // read and handle handshake message
    let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;

    // handshake payload should be a 2 byte length of 1
    // followed by a single auth request byte
    if len != 3 || noise_buf[0] != 0 || noise_buf[1] != 1 {
        return Err(ScallopError::ProtocolError(
            "invalid second payload length".into(),
        ));
    }

    // auth request should be 0 or 1
    if noise_buf[2] > 1 {
        return Err(ScallopError::ProtocolError(
            "invalid auth request in second payload".into(),
        ));
    }

    let should_send_auth = noise_buf[2] == 1;

    //---- <- e, ee, se, s, es end ----//

    // check if auth is possible
    if should_send_auth && auther.is_none() {
        // auth requested and no auther available
        // error out
        return Err(ScallopError::ProtocolError(
            "auth requested but no auther available".into(),
        ));
    }

    // IX transmits the responder static key in the second message,
    // so it is available by now
    let remote_static: [u8; 32] = noise
        .get_remote_static()
        .expect("handshake should have static key by now")
        .try_into()
        .expect("expected 32 byte key");

    let contains = auth_store.as_mut().map(|x| x.contains(&remote_static));
    // error out if key is considered rejected
    if contains == Some(ContainsResponse::Rejected) {
        return Err(ScallopError::ProtocolError(
            "remote static key rejected".into(),
        ));
    }

    // start tracking what state should go in the stream
    // it is expected to be set if contains returns state
    // or verify returns state
    let mut state = None::<AS::State>;

    // ask for auth only if auth_store is available
    // and the static key is not found in it
    let should_ask_auth = contains == Some(ContainsResponse::NotFound);

    // set state immediately if key was approved previously
    if let Some(ContainsResponse::Approved(approved)) = contains {
        state = Some(approved);
    }

    // handshake is done, switch to transport mode
    let mut noise = noise.into_transport_mode()?;

    //---- -> CLIENTFIN start ----//
    //
    // not part of the noise protocol, needed for optional attestations
    //
    // two bytes record length (not encrypted)
    // then, encrypted as one transport message:
    // one byte auth request, 0x00 for no auth request, 0x01 for auth request
    // two bytes payload size
    // payload (the attestation, empty unless the server requested one)

    async fn send_CLIENTFIN(
        noise: &mut impl Noiser,
        stream: &mut (impl AsyncWrite + Unpin),
        buf: &mut [u8],
        noise_buf: &mut [u8],
        payload: &[u8],
        should_ask_auth: bool,
    ) -> Result<(), ScallopError> {
        // assemble message for encryption
        noise_buf[0] = u8::from(should_ask_auth);
        // safe to cast since the caller has checked the payload size
        noise_buf[1..3].copy_from_slice(&(payload.len() as u16).to_be_bytes());
        noise_buf[3..3 + payload.len()].copy_from_slice(payload);

        // encrypt and send as a single transport message
        noise_write(noise, stream, &noise_buf[0..payload.len() + 3], buf, 0).await?;

        Ok(())
    }

    // fetch an attestation only if the server asked for one
    let payload: Box<[u8]> = if should_send_auth {
        // safe to unwrap since it has been checked above
        let payload = auther
            .unwrap()
            .new_auth()
            .await
            .map_err(|e| ScallopError::AuthError(format!("{e:?}")))?;
        // check if payload is not too big
        if payload.len() > 60000 {
            return Err(ScallopError::ProtocolError("auth payload too big".into()));
        }
        payload
    } else {
        Box::default()
    };

    send_CLIENTFIN(
        &mut noise,
        &mut stream,
        &mut buf,
        &mut noise_buf,
        &payload,
        should_ask_auth,
    )
    .await?;

    //---- -> CLIENTFIN end ----//

    //---- <- SERVERFIN start ----//
    //
    // not part of the noise protocol, needed for optional attestations
    // only sent by the server when the client requested an attestation
    //
    // two bytes record length (not encrypted)
    // then, encrypted as one transport message:
    // two bytes payload size
    // payload (the attestation)

    if should_ask_auth {
        // read and decrypt SERVERFIN
        let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;

        // should have at least the 2 byte payload size
        if len < 2 {
            return Err(ScallopError::ProtocolError(
                "invalid SERVERFIN length".into(),
            ));
        }

        // payload size should match
        if u16::from_be_bytes([noise_buf[0], noise_buf[1]]) as usize != len - 2 {
            return Err(ScallopError::ProtocolError(
                "invalid SERVERFIN payload length".into(),
            ));
        }

        // verify, safe to unwrap since should_ask_auth implies an auth store
        let Some(verified) = auth_store
            .as_mut()
            .unwrap()
            .verify(&noise_buf[2..len], remote_static)
        else {
            return Err(ScallopError::ProtocolError("invalid attestation".into()));
        };

        // set state
        state = Some(verified);
    }

    //---- <- SERVERFIN end ----//

    Ok(ScallopStream {
        noise,
        stream,
        // initialize with 2 sized buffer to read length
        rbuf: vec![0u8; 2].into_boxed_slice(),
        pending: 2,
        mode: ReadMode::Length,
        read_start: 0,
        read_end: 0,
        wbuf: vec![].into_boxed_slice(),
        write_start: 0,
        write_end: 0,
        state,
    })
}
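
// Sketch (illustrative only, not used by the protocol code): the plaintext
// layouts of the second handshake payload, CLIENTFIN and SERVERFIN, i.e. what
// goes into and comes out of noise encryption, with the same validation the
// handshake functions perform. Function names are hypothetical; the 60000
// byte attestation limit is the one used by the handshake functions.
#[allow(dead_code)]
mod fin_layout_sketch {
    // largest attestation the handshake functions accept
    pub const MAX_ATTESTATION: usize = 60000;

    // second handshake payload: u16 length of 1, then the auth request byte
    pub fn auth_request_payload(ask: bool) -> [u8; 3] {
        [0, 1, u8::from(ask)]
    }

    // returns whether an attestation was requested, None if malformed
    pub fn parse_auth_request_payload(payload: &[u8]) -> Option<bool> {
        match payload {
            [0, 1, flag] if *flag <= 1 => Some(*flag == 1),
            _ => None,
        }
    }

    // CLIENTFIN: auth request byte, u16 attestation size, attestation
    pub fn encode_clientfin(ask: bool, attestation: &[u8]) -> Option<Vec<u8>> {
        if attestation.len() > MAX_ATTESTATION {
            return None;
        }
        let mut out = vec![u8::from(ask)];
        out.extend_from_slice(&(attestation.len() as u16).to_be_bytes());
        out.extend_from_slice(attestation);
        Some(out)
    }

    // returns (attestation requested, attestation), None if malformed
    pub fn parse_clientfin(payload: &[u8]) -> Option<(bool, &[u8])> {
        if payload.len() < 3 || payload[0] > 1 {
            return None;
        }
        let size = u16::from_be_bytes([payload[1], payload[2]]) as usize;
        if size != payload.len() - 3 {
            return None;
        }
        Some((payload[0] == 1, &payload[3..]))
    }

    // SERVERFIN: u16 attestation size, attestation
    pub fn encode_serverfin(attestation: &[u8]) -> Option<Vec<u8>> {
        if attestation.len() > MAX_ATTESTATION {
            return None;
        }
        let mut out = (attestation.len() as u16).to_be_bytes().to_vec();
        out.extend_from_slice(attestation);
        Some(out)
    }

    // returns the attestation, None if malformed
    pub fn parse_serverfin(payload: &[u8]) -> Option<&[u8]> {
        if payload.len() < 2 {
            return None;
        }
        let size = u16::from_be_bytes([payload[0], payload[1]]) as usize;
        if size != payload.len() - 2 {
            return None;
        }
        Some(&payload[2..])
    }
}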

#[allow(non_snake_case)]
pub async fn new_server_async_Noise_IX_25519_ChaChaPoly_BLAKE2b<
    Base: AsyncWrite + AsyncRead + Unpin,
    AS: ScallopAuthStore,
>(
    mut stream: Base,
    secret: &[u8; 32],
    // will not auth remote if None
    mut auth_store: Option<AS>,
    // will not respond to auth requests if None
    auther: Option<impl ScallopAuther>,
) -> Result<ScallopStream<Base, AS::State>, ScallopError> {
    let mut buf = vec![0u8; 65000].into_boxed_slice();
    let mut noise_buf = vec![0u8; 65000].into_boxed_slice();

    // NoiseSocket prologue with an empty (zero length) negotiation payload
    let prologue = b"NoiseSocketInit1\x00\x00";

    let mut noise = Builder::new(
        "Noise_IX_25519_ChaChaPoly_BLAKE2b"
            .parse()
            .map_err(ScallopError::InitFailed)?,
    )
    .local_private_key(secret)
    .prologue(prologue)
    .build_responder()
    .map_err(ScallopError::InitFailed)?;

    //---- -> e, s start ----//

    // read negotiation length
    let len = stream.read_u16().await?;

    // length should be zero
    if len != 0 {
        return Err(ScallopError::ProtocolError(
            "non zero first negotiation length".into(),
        ));
    }

    // read and handle handshake message
    let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;

    // handshake payload should be empty
    if len != 0 {
        return Err(ScallopError::ProtocolError(
            "non zero first handshake payload".into(),
        ));
    }

    //---- -> e, s end ----//

    //---- <- e, ee, se, s, es start ----//

    // negotiation length
    buf[0..2].copy_from_slice(&0u16.to_be_bytes());

    // IX transmits the initiator static key in the first message,
    // so it is available by now
    let remote_static: [u8; 32] = noise
        .get_remote_static()
        .expect("handshake should have static key by now")
        .try_into()
        .expect("expected 32 byte key");

    let contains = auth_store.as_mut().map(|x| x.contains(&remote_static));
    // error out if key is considered rejected
    if contains == Some(ContainsResponse::Rejected) {
        return Err(ScallopError::ProtocolError(
            "remote static key rejected".into(),
        ));
    }

    // start tracking what state should go in the stream
    // it is expected to be set if contains returns state
    // or verify returns state
    let mut state = None::<AS::State>;

    // request auth only if auth_store is available
    // and the static key is not found in it
    let should_ask_auth = contains == Some(ContainsResponse::NotFound);

    // set state immediately if key was approved previously
    if let Some(ContainsResponse::Approved(approved)) = contains {
        state = Some(approved);
    }

    // payload: 2 byte length of 1, then the auth request byte
    let payload = &[0u8, 1u8, u8::from(should_ask_auth)];

    // encode and send handshake message
    noise_write(&mut noise, &mut stream, payload, &mut buf, 2).await?;

    //---- <- e, ee, se, s, es end ----//

    // handshake is done, switch to transport mode
    let mut noise = noise.into_transport_mode()?;

    //---- -> CLIENTFIN start ----//
    //
    // not part of the noise protocol, needed for optional attestations
    //
    // two bytes record length (not encrypted)
    // then, encrypted as one transport message:
    // one byte auth request, 0x00 for no auth request, 0x01 for auth request
    // two bytes payload size
    // payload (the attestation, empty unless the server requested one)

    // read and decrypt CLIENTFIN
    let len = noise_read(&mut noise, &mut stream, &mut buf, &mut noise_buf).await?;

    // should have at least the auth request byte and the 2 byte payload size
    if len < 3 {
        return Err(ScallopError::ProtocolError(
            "invalid CLIENTFIN length".into(),
        ));
    }

    // payload size should match
    if u16::from_be_bytes([noise_buf[1], noise_buf[2]]) as usize != len - 3 {
        return Err(ScallopError::ProtocolError(
            "invalid CLIENTFIN payload length".into(),
        ));
    }

    // auth request should be 0 or 1
    // checked before verify so that a malformed message is rejected early
    if noise_buf[0] > 1 {
        return Err(ScallopError::ProtocolError(
            "invalid auth request in third payload".into(),
        ));
    }

    let should_send_auth = noise_buf[0] == 1;

    // verify auth if we asked for it
    if should_ask_auth {
        // verify, safe to unwrap since should_ask_auth implies an auth store
        let Some(verified) = auth_store
            .as_mut()
            .unwrap()
            .verify(&noise_buf[3..len], remote_static)
        else {
            return Err(ScallopError::ProtocolError("invalid attestation".into()));
        };

        // set state
        state = Some(verified);
    }

    //---- -> CLIENTFIN end ----//

    // check if auth is possible
    if should_send_auth && auther.is_none() {
        // auth requested and no auther available
        // error out
        return Err(ScallopError::ProtocolError(
            "auth requested but no auther available".into(),
        ));
    }

    //---- <- SERVERFIN start ----//
    //
    // not part of the noise protocol, needed for optional attestations
    // only sent when the client requested an attestation
    //
    // two bytes record length (not encrypted)
    // then, encrypted as one transport message:
    // two bytes payload size
    // payload (the attestation)

    if should_send_auth {
        // safe to unwrap since it has been checked above
        let payload = auther
            .unwrap()
            .new_auth()
            .await
            .map_err(|e| ScallopError::AuthError(format!("{e:?}")))?;
        // check if payload is not too big
        if payload.len() > 60000 {
            return Err(ScallopError::ProtocolError("auth payload too big".into()));
        }

        // safe to cast since the size has been checked above
        noise_buf[0..2].copy_from_slice(&(payload.len() as u16).to_be_bytes());
        noise_buf[2..2 + payload.len()].copy_from_slice(&payload);

        // encrypt and send as a single transport message
        noise_write(
            &mut noise,
            &mut stream,
            &noise_buf[0..payload.len() + 2],
            &mut buf,
            0,
        )
        .await?;
    }

    //---- <- SERVERFIN end ----//

    Ok(ScallopStream {
        noise,
        stream,
        // initialize with 2 sized buffer to read length
        rbuf: vec![0u8; 2].into_boxed_slice(),
        pending: 2,
        mode: ReadMode::Length,
        read_start: 0,
        read_end: 0,
        wbuf: vec![].into_boxed_slice(),
        write_start: 0,
        write_end: 0,
        state,
    })
}
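
// Sketch (illustrative only, not used by the protocol code): the on-the-wire
// size of each Scallop handshake message for Noise_IX_25519_ChaChaPoly_BLAKE2b,
// derived from the Noise specification: X25519 public keys are 32 bytes and
// every encrypted field carries a 16 byte ChaChaPoly tag. In IX the first
// message (-> e, s) is sent before any key is established, so s and the empty
// payload travel in the clear; in the second message (<- e, ee, se, s, es)
// s and the payload are encrypted. Every framed message adds a 2 byte length,
// and the two handshake messages add a 2 byte NoiseSocket negotiation length.
// Function names are hypothetical.
#[allow(dead_code)]
mod handshake_size_sketch {
    const DH_LEN: usize = 32;
    const TAG_LEN: usize = 16;
    const LEN_PREFIX: usize = 2;
    const NEGOTIATION_LEN: usize = 2;

    // -> e, s with an empty payload, all in the clear
    pub fn first_message() -> usize {
        NEGOTIATION_LEN + LEN_PREFIX + DH_LEN + DH_LEN
    }

    // <- e, ee, se, s, es with the 3 byte auth request payload
    pub fn second_message() -> usize {
        NEGOTIATION_LEN + LEN_PREFIX + DH_LEN + (DH_LEN + TAG_LEN) + (3 + TAG_LEN)
    }

    // CLIENTFIN: flag byte, u16 size, attestation, as one transport message
    pub fn clientfin(attestation_len: usize) -> usize {
        LEN_PREFIX + 1 + 2 + attestation_len + TAG_LEN
    }

    // SERVERFIN: u16 size, attestation, as one transport message
    pub fn serverfin(attestation_len: usize) -> usize {
        LEN_PREFIX + 2 + attestation_len + TAG_LEN
    }
}

// With both keys already cached, a handshake therefore costs
// 68 + 103 + 21 = 192 bytes and no SERVERFIN is sent; attestations only add
// to CLIENTFIN and SERVERFIN when they are requested.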

impl<Base: AsyncWrite + AsyncRead + Unpin, State> ScallopStream<Base, State> {
    // static key the remote party used in the handshake
    pub fn get_remote_static(&self) -> Option<[u8; 32]> {
        self.noise
            .get_remote_static()
            .map(|x| x.try_into().expect("expected 32 byte key"))
    }
}

impl<Base: AsyncWrite + AsyncRead + Unpin, State: Unpin> AsyncRead for ScallopStream<Base, State> {
    // IMPORTANT: Return Pending only as a direct result of base returning Pending
    // Ensures wakers are set up correctly
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let stream = self.get_mut();
        loop {
            while stream.pending != 0 {
                let base = std::pin::pin!(&mut stream.stream);

                // do not have enough data, try to read more
                let len = stream.rbuf.len();
                let mut read_buf = ReadBuf::new(&mut stream.rbuf[(len - stream.pending)..]);
                std::task::ready!(base.poll_read(cx, &mut read_buf))?;

                // check eof
                if read_buf.filled().is_empty() {
                    // eof is only clean between records,
                    // anywhere else the record has been truncated
                    if stream.mode == ReadMode::Length && stream.pending == len {
                        return std::task::Poll::Ready(Ok(()));
                    }
                    return std::task::Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "eof in the middle of a record",
                    )));
                }
                stream.pending -= read_buf.filled().len();
            }

            // pending should always be 0 after this point

            if stream.mode == ReadMode::Length {
                // we have read the length

                // parse length
                let record_length = u16::from_be_bytes(stream.rbuf[0..2].try_into().unwrap());

                // set up to read record
                stream.pending = record_length.into();
                stream.mode = ReadMode::Body;
                stream.rbuf = vec![0u8; stream.pending].into_boxed_slice();
            } else if stream.mode == ReadMode::Body {
                // we have the full record

                // decrypt as a noise message, snow needs separate input
                // and output buffers so decrypt from a copy
                let len = stream
                    .noise
                    .read_message(&stream.rbuf.clone(), &mut stream.rbuf)
                    .map_err(std::io::Error::other)?;

                // set up to send body upstream
                stream.read_start = 0;
                stream.read_end = len;
                stream.mode = ReadMode::Read;
            } else {
                // hand over as much decrypted data as fits into buf
                let n = std::cmp::min(buf.remaining(), stream.read_end - stream.read_start);
                buf.put_slice(&stream.rbuf[stream.read_start..stream.read_start + n]);
                stream.read_start += n;

                if stream.read_start == stream.read_end {
                    // record fully consumed, set up to read the next length
                    stream.rbuf = vec![0u8; 2].into_boxed_slice();
                    stream.pending = 2;
                    stream.mode = ReadMode::Length;

                    // an empty record carries no data, keep reading so that
                    // the caller does not mistake a zero byte read for eof
                    if n == 0 {
                        continue;
                    }
                }
                return std::task::Poll::Ready(Ok(()));
            }
        }
    }
}
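
// Sketch (illustrative only, not used by the protocol code): the
// Length -> Body -> Read state machine of poll_read, rewritten as a
// synchronous reassembler that is fed arbitrary chunks of bytes and yields
// complete record bodies. Decryption is left out (bodies are returned as-is),
// and the type and method names are hypothetical.
#[allow(dead_code)]
mod record_reader_sketch {
    pub struct RecordReader {
        // bytes of the current length prefix or body collected so far
        partial: Vec<u8>,
        // Some(len) once the 2 byte length has been read
        body_len: Option<usize>,
    }

    impl RecordReader {
        pub fn new() -> Self {
            RecordReader {
                partial: Vec::new(),
                body_len: None,
            }
        }

        // consumes a chunk, returning every record completed by it
        pub fn feed(&mut self, mut chunk: &[u8]) -> Vec<Vec<u8>> {
            let mut records = Vec::new();
            loop {
                // bytes needed to finish the current length prefix or body
                let needed = self.body_len.unwrap_or(2);
                if self.partial.len() == needed {
                    match self.body_len {
                        // length complete: switch to reading the body
                        None => {
                            let len = u16::from_be_bytes([self.partial[0], self.partial[1]]);
                            self.body_len = Some(usize::from(len));
                            self.partial.clear();
                        }
                        // body complete: emit it and go back to reading a length
                        Some(_) => {
                            records.push(std::mem::take(&mut self.partial));
                            self.body_len = None;
                        }
                    }
                    continue;
                }
                if chunk.is_empty() {
                    return records;
                }
                let take = (needed - self.partial.len()).min(chunk.len());
                self.partial.extend_from_slice(&chunk[..take]);
                chunk = &chunk[take..];
            }
        }

        // true if the stream could end here without truncating a record
        pub fn at_boundary(&self) -> bool {
            self.body_len.is_none() && self.partial.is_empty()
        }
    }
}

// at_boundary marks the only point at which an end of stream is a clean EOF
// rather than a truncated record.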

impl<Base: AsyncWrite + AsyncRead + Unpin, State: Unpin> AsyncWrite for ScallopStream<Base, State> {
    // IMPORTANT: Return Pending only as a direct result of base returning Pending
    // Ensures wakers are set up correctly
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        // nothing to send, do not emit an empty record
        if buf.is_empty() {
            return std::task::Poll::Ready(Ok(0));
        }

        // flush existing data first
        std::task::ready!(self.as_mut().poll_flush(cx))?;

        let mut stream = self.as_mut();

        // construct new buf
        // up to 64000 bytes at once, below the 65535 byte noise message limit
        let len = std::cmp::min(buf.len(), 64000) as u16;
        // leave room for the 2 byte length prefix and the 16 byte tag
        let mut new_buf = vec![0u8; len as usize + 1000].into_boxed_slice();

        // set noise message
        let noise_len = stream
            .noise
            .write_message(&buf[0..len as usize], &mut new_buf[2..])
            .map_err(std::io::Error::other)?;

        // set length
        new_buf[0..2].copy_from_slice(&(noise_len as u16).to_be_bytes());

        // queue up new buf, it is written out by the next poll_flush
        stream.wbuf = new_buf;
        stream.write_start = 0;
        stream.write_end = noise_len + 2;

        // TODO: Should we flush here so it does not need to be called in the common case?
        // How do we implement this?
        //
        // Not sure how the semantics will play out though.
        //
        // The happy path looks great:
        // we call poll_flush, it returns Ready and we return Ready with the length.
        //
        // But what if it returns Pending?
        // If we return Pending, the caller will assume nothing was sent.
        // If we return Ready, poll_flush has potentially set up wakers.
        // What happens on repeated calls? Unsure if it is supposed to be idempotent.

        std::task::Poll::Ready(Ok(len as usize))
    }

    // IMPORTANT: Return Pending only as a direct result of base returning Pending
    // Ensures wakers are set up correctly
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let stream = self.get_mut();

        while stream.write_start != stream.write_end {
            let base = std::pin::pin!(&mut stream.stream);

            // try to send existing messages first
            let size = std::task::ready!(
                base.poll_write(cx, &stream.wbuf[stream.write_start..stream.write_end])
            )?;

            // a zero byte write means base cannot accept more data,
            // error out instead of looping forever
            if size == 0 {
                return std::task::Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
            }
            stream.write_start += size;
        }

        // flush data after write since base could be buffered
        let base = std::pin::pin!(&mut stream.stream);
        base.poll_flush(cx)
    }

    // IMPORTANT: Return Pending only as a direct result of base returning Pending
    // Ensures wakers are set up correctly
    //
    // Shutdown is supposed to be graceful
    //
    // From the tokio docs:
    // Invocation of a shutdown implies an invocation of flush.
    // Once this method returns Ready it implies that a flush successfully happened
    // before the shutdown happened. That is, callers don’t need to call flush before
    // calling shutdown. They can rely that by calling shutdown any pending buffered
    // data will be written out.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        // flush data for graceful shutdowns
        std::task::ready!(self.as_mut().poll_flush(cx))?;

        let stream = self.get_mut();
        let base = std::pin::pin!(&mut stream.stream);

        base.poll_shutdown(cx)
    }
}
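
// Sketch (illustrative only, not used by the protocol code): how poll_write
// splits a large buffer into records. Each call encrypts at most 64000
// plaintext bytes, so a caller that keeps writing until everything is
// accepted (as write_all does) produces one record per chunk, each growing by
// the 16 byte ChaChaPoly tag and the 2 byte length prefix. Names are
// hypothetical.
#[allow(dead_code)]
mod write_chunking_sketch {
    pub const MAX_CHUNK: usize = 64000;
    const TAG_LEN: usize = 16;
    const LEN_PREFIX: usize = 2;

    // plaintext bytes accepted by each successive poll_write call
    pub fn chunk_sizes(total: usize) -> Vec<usize> {
        let mut sizes = Vec::new();
        let mut left = total;
        while left > 0 {
            let n = left.min(MAX_CHUNK);
            sizes.push(n);
            left -= n;
        }
        sizes
    }

    // total bytes written to the base stream for `total` plaintext bytes
    pub fn wire_bytes(total: usize) -> usize {
        chunk_sizes(total)
            .iter()
            .map(|n| n + TAG_LEN + LEN_PREFIX)
            .sum()
    }
}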