Skip to main content

soe_protocol/
session.rs

//! The session handler: an I/O-agnostic state machine driving a single SOE session.
//!
//! This ports the reference `SoeProtocolHandler`, restructured as a pure state
//! machine. Rather than owning a socket, the handler accepts incoming datagrams via
//! [`SoeSession::process_incoming`], surfaces datagrams to be sent via
//! [`SoeSession::take_outgoing`], surfaces received application data via
//! [`SoeSession::take_received`], and surfaces lifecycle events via
//! [`SoeSession::take_events`]. Time is supplied by the caller as [`Instant`], and
//! periodic work (heartbeats, the inactivity timeout and the reliable data
//! channels) is driven by calling [`SoeSession::run_tick`].
//!
//! The handler negotiates a session (contextless [`SessionRequest`]/
//! [`SessionResponse`] exchange), then dispatches contextual packets: routing
//! reliable data to the input channel, acknowledgements to the output channel, and
//! handling heartbeats and disconnects.
13
14use std::time::{Duration, Instant};
15
16use bytes::Bytes;
17
18use crate::channel::{
19    InputConfig, OutputConfig, ReliableDataInputChannel, ReliableDataOutputChannel,
20};
21use crate::constants::{
22    CRC_LENGTH, DEFAULT_SESSION_HEARTBEAT_AFTER, DEFAULT_SESSION_INACTIVITY_TIMEOUT,
23    DEFAULT_UDP_LENGTH, SOE_PROTOCOL_VERSION,
24};
25use crate::crc32::Crc32;
26use crate::io::{BinaryReader, BinaryWriter};
27use crate::packet_utils::{ValidationResult, append_crc, read_op_code, validate_packet};
28use crate::packets::{Acknowledge, AcknowledgeAll, Disconnect, SessionRequest, SessionResponse};
29use crate::protocol::{DisconnectReason, OpCode};
30use crate::rc4::Rc4KeyState;
31use crate::varint::multi_packet;
32use crate::zlib;
33
/// Number of bytes occupied by a packet's OP code, which is written as a `u16`.
const OP_CODE_SIZE: usize = std::mem::size_of::<u16>();
/// How long the output channel waits for an acknowledgement by default.
const DEFAULT_ACK_WAIT: Duration = Duration::from_millis(500);
37
/// Which side of the handshake a [`SoeSession`] plays.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionMode {
    /// Initiating side: sends the [`SessionRequest`] that starts negotiation.
    Client,
    /// Accepting side: answers an incoming [`SessionRequest`].
    Server,
}
46
/// Where a [`SoeSession`] is in its lifecycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Negotiation is still in progress.
    Negotiating,
    /// Negotiation finished; data may flow in both directions.
    Running,
    /// The session has ended and will ignore further input.
    Terminated,
}
57
58/// An event surfaced by a [`SoeSession`].
59#[derive(Debug, Clone, Copy, PartialEq, Eq)]
60pub enum SessionEvent {
61    /// The session has been established and is ready to exchange data.
62    Opened,
63    /// The session has terminated for the given reason.
64    Closed(DisconnectReason),
65}
66
/// Tunables for a single session.
///
/// Some fields are overwritten during negotiation as the two parties settle on
/// connection details (CRC seed/length, compression, the remote UDP length).
#[derive(Clone, Debug)]
pub struct SessionParameters {
    /// Name of the application protocol carried over SOE; both parties must agree.
    pub application_protocol: String,
    /// Largest UDP payload, in bytes, that this party accepts.
    pub udp_length: u32,
    /// Largest UDP payload, in bytes, that the remote party accepts.
    pub remote_udp_length: u32,
    /// Seed for packet CRC computation, agreed during negotiation.
    pub crc_seed: u32,
    /// Width of the trailing packet CRC in bytes, in the range `0..=4`.
    pub crc_length: u8,
    /// Whether contextual packets may carry compressed bodies.
    pub is_compression_enabled: bool,
    /// Cap on incoming reliable data packets held in the queue.
    pub max_queued_incoming_reliable: u16,
    /// Cap on outgoing reliable data packets that may be in flight at once.
    pub max_queued_outgoing_reliable: u16,
    /// Acknowledgement window handed to the input channel.
    pub data_ack_window: u16,
    /// Idle period after which a client sends a heartbeat; `ZERO` turns this off.
    pub heartbeat_after: Duration,
    /// Idle period after which the session is terminated; `ZERO` turns this off.
    pub inactivity_timeout: Duration,
    /// When set, each incoming reliable data packet gets its own acknowledgement.
    pub acknowledge_all_data: bool,
    /// Upper bound on how long acknowledgements of incoming sequences may be held.
    pub max_ack_delay: Duration,
}
98
99impl Default for SessionParameters {
100    fn default() -> Self {
101        Self {
102            application_protocol: String::new(),
103            udp_length: DEFAULT_UDP_LENGTH,
104            remote_udp_length: DEFAULT_UDP_LENGTH,
105            crc_seed: 0,
106            crc_length: CRC_LENGTH,
107            is_compression_enabled: false,
108            max_queued_incoming_reliable: 256,
109            max_queued_outgoing_reliable: 196,
110            data_ack_window: 32,
111            heartbeat_after: DEFAULT_SESSION_HEARTBEAT_AFTER,
112            inactivity_timeout: DEFAULT_SESSION_INACTIVITY_TIMEOUT,
113            acknowledge_all_data: false,
114            max_ack_delay: Duration::from_millis(2),
115        }
116    }
117}
118
119/// Application-level parameters: the optional encryption key state.
120#[derive(Debug, Clone, Default)]
121pub struct ApplicationParameters {
122    /// The RC4 key state used to (en/de)crypt application data, if encryption is
123    /// enabled.
124    pub encryption_key_state: Option<Rc4KeyState>,
125}
126
/// A minimal 64-bit linear-congruential generator that produces session IDs and
/// CRC seeds. It is deterministic for a given seed and not cryptographically
/// secure.
#[derive(Debug)]
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Multiplier applied on every step (the well-known MMIX LCG multiplier).
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// Increment added on every step (the matching MMIX LCG increment).
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    /// Mask XOR-ed into the caller's seed so small seeds do not start from a
    /// near-zero state.
    const SEED_MASK: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Builds a generator from `seed`.
    fn new(seed: u64) -> Self {
        Lcg {
            state: seed ^ Self::SEED_MASK,
        }
    }

    /// Advances the generator and returns the high 32 bits of the new state.
    fn next_u32(&mut self) -> u32 {
        let advanced = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state = advanced;
        (advanced >> 32) as u32
    }
}
148
149/// an I/O-agnostic handler for a single SOE protocol session.
150#[derive(Debug)]
151pub struct SoeSession {
152    mode: SessionMode,
153    state: SessionState,
154    params: SessionParameters,
155
156    input: ReliableDataInputChannel,
157    output: ReliableDataOutputChannel,
158
159    session_id: u32,
160    termination_reason: DisconnectReason,
161    terminated_by_remote: bool,
162    open_session_on_next_packet: bool,
163    last_received: Instant,
164
165    rng: Lcg,
166
167    outgoing: Vec<Bytes>,
168    received: Vec<Bytes>,
169    events: Vec<SessionEvent>,
170}
171
172impl SoeSession {
173    /// Creates a new session handler in the [`SessionState::Negotiating`] state.
174    ///
175    /// `rng_seed` seeds the generator used for the session ID (client) and CRC seed
176    /// (server); pass a fixed value for deterministic behaviour, or entropy for real
177    /// sessions.
178    pub fn new(
179        mode: SessionMode,
180        params: SessionParameters,
181        app: ApplicationParameters,
182        rng_seed: u64,
183        now: Instant,
184    ) -> Self {
185        let input = ReliableDataInputChannel::new(
186            InputConfig {
187                max_queued_incoming: params.max_queued_incoming_reliable,
188                acknowledge_all_data: params.acknowledge_all_data,
189                data_ack_window: params.data_ack_window,
190                max_ack_delay: params.max_ack_delay,
191            },
192            app.encryption_key_state.clone(),
193            now,
194        );
195
196        let mut output = ReliableDataOutputChannel::new(
197            OutputConfig {
198                max_data_length: Self::max_data_length(&params),
199                max_queued_outgoing: params.max_queued_outgoing_reliable as usize,
200                ack_wait: DEFAULT_ACK_WAIT,
201            },
202            app.encryption_key_state.clone(),
203            now,
204        );
205        output.set_max_data_length(Self::max_data_length(&params));
206
207        Self {
208            mode,
209            state: SessionState::Negotiating,
210            params,
211            input,
212            output,
213            session_id: 0,
214            termination_reason: DisconnectReason::None,
215            terminated_by_remote: false,
216            open_session_on_next_packet: false,
217            last_received: now,
218            rng: Lcg::new(rng_seed),
219            outgoing: Vec::new(),
220            received: Vec::new(),
221            events: Vec::new(),
222        }
223    }
224
225    /// Returns the current session state.
226    pub fn state(&self) -> SessionState {
227        self.state
228    }
229
230    /// Returns the session mode.
231    pub fn mode(&self) -> SessionMode {
232        self.mode
233    }
234
235    /// Returns the negotiated session ID.
236    pub fn session_id(&self) -> u32 {
237        self.session_id
238    }
239
240    /// Returns the negotiated CRC seed (meaningful once running).
241    pub fn crc_seed(&self) -> u32 {
242        self.params.crc_seed
243    }
244
245    /// Returns the reason the session terminated (meaningful once terminated).
246    pub fn termination_reason(&self) -> DisconnectReason {
247        self.termination_reason
248    }
249
250    /// Returns whether the termination was initiated by the remote party.
251    pub fn terminated_by_remote(&self) -> bool {
252        self.terminated_by_remote
253    }
254
255    /// Drains datagrams that the caller should send to the remote.
256    pub fn take_outgoing(&mut self) -> Vec<Bytes> {
257        std::mem::take(&mut self.outgoing)
258    }
259
260    /// Drains application data received from the remote.
261    pub fn take_received(&mut self) -> Vec<Bytes> {
262        std::mem::take(&mut self.received)
263    }
264
265    /// Drains session lifecycle events.
266    pub fn take_events(&mut self) -> Vec<SessionEvent> {
267        std::mem::take(&mut self.events)
268    }
269
270    fn max_data_length(params: &SessionParameters) -> usize {
271        params.udp_length as usize
272            - OP_CODE_SIZE
273            - params.is_compression_enabled as usize
274            - params.crc_length as usize
275    }
276
277    /// Sends a [`SessionRequest`] to begin negotiation. Only valid in client mode
278    /// while negotiating.
279    pub fn send_session_request(&mut self) {
280        if self.state != SessionState::Negotiating || self.mode != SessionMode::Client {
281            return;
282        }
283
284        let id = self.rng.next_u32();
285        self.session_id = id;
286        let request = SessionRequest {
287            soe_protocol_version: SOE_PROTOCOL_VERSION,
288            session_id: id,
289            udp_length: self.params.udp_length,
290            application_protocol: self.params.application_protocol.clone(),
291        };
292
293        let mut buf = vec![0u8; request.size()];
294        let n = request.serialize(&mut buf).expect("session request buffer");
295        buf.truncate(n);
296        self.outgoing.push(Bytes::from(buf));
297    }
298
299    /// Enqueues application data to be sent reliably. Returns `false` if the session
300    /// is not running.
301    #[must_use = "a false return means the data was dropped because the session is not running"]
302    pub fn enqueue_data(&mut self, data: &[u8]) -> bool {
303        if self.state != SessionState::Running {
304            return false;
305        }
306        self.output.enqueue_data(data);
307        true
308    }
309
310    /// Terminates the session, optionally notifying the remote.
311    pub fn terminate(&mut self, reason: DisconnectReason, notify_remote: bool, now: Instant) {
312        self.terminate_inner(reason, notify_remote, false, now);
313    }
314
315    /// Processes a single incoming datagram from the remote.
316    pub fn process_incoming(&mut self, datagram: Bytes, now: Instant) {
317        if self.state == SessionState::Terminated {
318            return;
319        }
320
321        let crc = Crc32::new(self.params.crc_seed);
322        let (result, op) = validate_packet(
323            &datagram,
324            &crc,
325            self.params.crc_length,
326            self.params.is_compression_enabled,
327        );
328
329        if result != ValidationResult::Valid {
330            self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
331            return;
332        }
333        let op = op.expect("valid packet has an op code");
334
335        if self.open_session_on_next_packet {
336            self.events.push(SessionEvent::Opened);
337            self.open_session_on_next_packet = false;
338        }
339
340        // Set after validation, as a primitive guard against a stream of corrupt
341        // packets keeping a session alive.
342        self.last_received = now;
343
344        let body = datagram.slice(OP_CODE_SIZE..);
345        if op.is_contextless() {
346            self.handle_contextless(op, &body, now);
347        } else {
348            let crc_length = self.params.crc_length as usize;
349            let body = body.slice(..body.len() - crc_length);
350            self.handle_contextual(op, body, now);
351        }
352
353        self.flush_channels(now);
354    }
355
356    /// Runs a single tick of the session: heartbeats, inactivity timeout, and the
357    /// reliable data channels.
358    pub fn run_tick(&mut self, now: Instant) {
359        if self.state == SessionState::Terminated {
360            return;
361        }
362
363        self.send_heartbeat_if_required(now);
364
365        if !self.params.inactivity_timeout.is_zero()
366            && now.duration_since(self.last_received) > self.params.inactivity_timeout
367        {
368            self.terminate_inner(DisconnectReason::Timeout, false, false, now);
369            return;
370        }
371
372        self.input.run_tick(now);
373        self.output.run_tick(now);
374        self.flush_channels(now);
375    }
376
377    fn handle_contextless(&mut self, op: OpCode, body: &[u8], now: Instant) {
378        match op {
379            OpCode::SessionRequest => self.handle_session_request(body, now),
380            OpCode::SessionResponse => self.handle_session_response(body, now),
381            OpCode::UnknownSender => {
382                self.terminate_inner(DisconnectReason::UnreachableConnection, false, false, now);
383            }
384            // Remap requests are the responsibility of a connection manager (Phase 7).
385            _ => {}
386        }
387    }
388
389    fn handle_session_request(&mut self, body: &[u8], now: Instant) {
390        if self.mode == SessionMode::Client {
391            self.terminate_inner(DisconnectReason::ConnectingToSelf, false, false, now);
392            return;
393        }
394
395        let request = match SessionRequest::deserialize(body, false) {
396            Ok(r) => r,
397            Err(_) => {
398                self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
399                return;
400            }
401        };
402
403        self.params.remote_udp_length = request.udp_length;
404        self.session_id = request.session_id;
405
406        if self.state != SessionState::Negotiating {
407            self.terminate_inner(DisconnectReason::ConnectError, true, false, now);
408            return;
409        }
410
411        let protocols_match = request.soe_protocol_version == SOE_PROTOCOL_VERSION
412            && request.application_protocol == self.params.application_protocol;
413        if !protocols_match {
414            self.terminate_inner(DisconnectReason::ProtocolMismatch, true, false, now);
415            return;
416        }
417
418        self.params.crc_length = CRC_LENGTH;
419        self.params.crc_seed = self.rng.next_u32();
420        self.output
421            .set_max_data_length(Self::max_data_length(&self.params));
422
423        let response = SessionResponse {
424            session_id: self.session_id,
425            crc_seed: self.params.crc_seed,
426            crc_length: self.params.crc_length,
427            is_compression_enabled: self.params.is_compression_enabled,
428            unknown_value_1: 0,
429            udp_length: self.params.udp_length,
430            soe_protocol_version: SOE_PROTOCOL_VERSION,
431        };
432
433        let mut buf = [0u8; SessionResponse::SIZE];
434        let n = response
435            .serialize(&mut buf)
436            .expect("session response buffer");
437        self.outgoing.push(Bytes::copy_from_slice(&buf[..n]));
438
439        self.state = SessionState::Running;
440        self.open_session_on_next_packet = true;
441    }
442
443    fn handle_session_response(&mut self, body: &[u8], now: Instant) {
444        if self.mode == SessionMode::Server {
445            self.terminate_inner(DisconnectReason::ConnectingToSelf, false, false, now);
446            return;
447        }
448
449        let response = match SessionResponse::deserialize(body, false) {
450            Ok(r) => r,
451            Err(_) => {
452                self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
453                return;
454            }
455        };
456
457        if self.state != SessionState::Negotiating {
458            self.terminate_inner(DisconnectReason::ConnectError, true, false, now);
459            return;
460        }
461
462        if response.soe_protocol_version != SOE_PROTOCOL_VERSION {
463            self.terminate_inner(DisconnectReason::ProtocolMismatch, true, false, now);
464            return;
465        }
466
467        self.params.remote_udp_length = response.udp_length;
468        self.params.crc_length = response.crc_length;
469        self.params.crc_seed = response.crc_seed;
470        self.params.is_compression_enabled = response.is_compression_enabled;
471        self.session_id = response.session_id;
472        self.output
473            .set_max_data_length(Self::max_data_length(&self.params));
474
475        self.state = SessionState::Running;
476        self.events.push(SessionEvent::Opened);
477    }
478
479    fn handle_contextual(&mut self, op: OpCode, body: Bytes, now: Instant) {
480        let body = if self.params.is_compression_enabled {
481            if body.is_empty() {
482                return;
483            }
484            let is_compressed = body[0] > 0;
485            let rest = body.slice(1..);
486            if is_compressed {
487                match zlib::inflate(&rest) {
488                    Ok(d) => Bytes::from(d),
489                    Err(_) => {
490                        self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
491                        return;
492                    }
493                }
494            } else {
495                rest
496            }
497        } else {
498            body
499        };
500
501        self.handle_contextual_inner(op, body, now);
502    }
503
504    fn handle_contextual_inner(&mut self, op: OpCode, body: Bytes, now: Instant) {
505        match op {
506            OpCode::MultiPacket => {
507                let mut offset = 0;
508                while offset < body.len() {
509                    let mut reader = BinaryReader::new(&body[offset..]);
510                    let len = match multi_packet::read(&mut reader) {
511                        Ok(l) => l as usize,
512                        Err(_) => {
513                            self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
514                            return;
515                        }
516                    };
517                    // Advance past the length varint by however many bytes it used.
518                    offset += reader.offset();
519
520                    if len < OP_CODE_SIZE || len > body.len() - offset {
521                        self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
522                        return;
523                    }
524
525                    let sub = body.slice(offset..offset + len);
526                    let sub_op = match read_op_code(&sub) {
527                        Some(o) => o,
528                        None => {
529                            self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
530                            return;
531                        }
532                    };
533                    self.handle_contextual_inner(sub_op, sub.slice(OP_CODE_SIZE..), now);
534                    offset += len;
535
536                    // A sub-packet may have terminated the session (e.g. a corrupt
537                    // fragment or an embedded Disconnect). Stop draining the bundle
538                    // rather than processing data on a dead session.
539                    if self.state == SessionState::Terminated {
540                        return;
541                    }
542                }
543            }
544            OpCode::Disconnect => {
545                if let Ok(disconnect) = Disconnect::deserialize(&body) {
546                    self.terminate_inner(disconnect.reason, false, true, now);
547                }
548            }
549            OpCode::Heartbeat if self.mode == SessionMode::Server => {
550                let dg = self.frame_contextual(OpCode::Heartbeat, &[]);
551                self.outgoing.push(dg);
552            }
553            OpCode::ReliableData => {
554                let outcome = self.input.handle_reliable_data(body, now);
555                if outcome.is_err() {
556                    self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
557                }
558            }
559            OpCode::ReliableDataFragment => {
560                let outcome = self.input.handle_reliable_data_fragment(body, now);
561                if outcome.is_err() {
562                    self.terminate_inner(DisconnectReason::CorruptPacket, true, false, now);
563                }
564            }
565            OpCode::Acknowledge => {
566                if let Ok(ack) = Acknowledge::deserialize(&body) {
567                    self.output.notify_of_acknowledge(ack.sequence, now);
568                }
569            }
570            OpCode::AcknowledgeAll => {
571                if let Ok(ack) = AcknowledgeAll::deserialize(&body) {
572                    self.output.notify_of_acknowledge_all(ack.sequence, now);
573                }
574            }
575            _ => {}
576        }
577    }
578
579    fn send_heartbeat_if_required(&mut self, now: Instant) {
580        let may_send = self.mode == SessionMode::Client
581            && self.state == SessionState::Running
582            && !self.params.heartbeat_after.is_zero()
583            && now.duration_since(self.last_received) > self.params.heartbeat_after;
584
585        if may_send {
586            let dg = self.frame_contextual(OpCode::Heartbeat, &[]);
587            self.outgoing.push(dg);
588        }
589    }
590
591    fn flush_channels(&mut self, _now: Instant) {
592        for ack in self.input.take_outgoing() {
593            let payload = ack.sequence.to_be_bytes();
594            let dg = self.frame_contextual(ack.op_code, &payload);
595            self.outgoing.push(dg);
596        }
597
598        for packet in self.output.take_outgoing() {
599            let dg = self.frame_contextual(packet.op_code, &packet.payload);
600            self.outgoing.push(dg);
601        }
602
603        for data in self.input.take_app_data() {
604            self.received.push(data);
605        }
606    }
607
608    /// Frames a contextual packet: OP code, optional compression flag, payload, and
609    /// CRC.
610    fn frame_contextual(&self, op: OpCode, payload: &[u8]) -> Bytes {
611        let compression = self.params.is_compression_enabled as usize;
612        let crc_length = self.params.crc_length as usize;
613        let capacity = OP_CODE_SIZE + compression + payload.len() + crc_length;
614
615        let mut buf = vec![0u8; capacity];
616        let written = {
617            let mut w = BinaryWriter::new(&mut buf);
618            w.write_u16(op.as_u16()).expect("op code");
619            if self.params.is_compression_enabled {
620                w.write_bool(false).expect("compression flag");
621            }
622            w.write_bytes(payload).expect("payload");
623            w.offset()
624        };
625
626        let crc = Crc32::new(self.params.crc_seed);
627        let total = append_crc(&mut buf, written, &crc, self.params.crc_length).expect("crc");
628        buf.truncate(total);
629        Bytes::from(buf)
630    }
631
632    fn terminate_inner(
633        &mut self,
634        reason: DisconnectReason,
635        notify_remote: bool,
636        terminated_by_remote: bool,
637        now: Instant,
638    ) {
639        if self.state == SessionState::Terminated {
640            return;
641        }
642
643        // Naive flush of the output channel.
644        self.output.run_tick(now);
645        self.flush_channels(now);
646        self.termination_reason = reason;
647
648        if notify_remote && self.state == SessionState::Running {
649            let disconnect = Disconnect::new(self.session_id, reason);
650            let mut payload = [0u8; Disconnect::SIZE];
651            let n = disconnect
652                .serialize(&mut payload)
653                .expect("disconnect buffer");
654            let dg = self.frame_contextual(OpCode::Disconnect, &payload[..n]);
655            self.outgoing.push(dg);
656        }
657
658        self.state = SessionState::Terminated;
659        self.terminated_by_remote = terminated_by_remote;
660        self.events.push(SessionEvent::Closed(reason));
661    }
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameters for `protocol`.
    ///
    /// Reliable data windows are kept small so fragmentation and windowing are
    /// exercised. Heartbeats and timeouts are disabled so runs are deterministic.
    fn test_params(protocol: &str) -> SessionParameters {
        SessionParameters {
            application_protocol: protocol.to_owned(),
            max_queued_incoming_reliable: 32,
            max_queued_outgoing_reliable: 32,
            heartbeat_after: Duration::ZERO,
            inactivity_timeout: Duration::ZERO,
            ..SessionParameters::default()
        }
    }

    /// Creates a session speaking `protocol`.
    fn new_session(
        mode: SessionMode,
        protocol: &str,
        app: ApplicationParameters,
        seed: u64,
        now: Instant,
    ) -> SoeSession {
        SoeSession::new(mode, test_params(protocol), app, seed, now)
    }

    /// Shuttles queued datagrams between `a` and `b` until neither side has
    /// anything left to send.
    fn exchange_until_idle(a: &mut SoeSession, b: &mut SoeSession, now: Instant) {
        loop {
            let a_to_b = a.take_outgoing();
            let b_to_a = b.take_outgoing();
            if a_to_b.is_empty() && b_to_a.is_empty() {
                return;
            }
            a_to_b.into_iter().for_each(|dg| b.process_incoming(dg, now));
            b_to_a.into_iter().for_each(|dg| a.process_incoming(dg, now));
        }
    }

    /// Completes a handshake between a fresh client and server and returns both.
    fn establish_session(now: Instant) -> (SoeSession, SoeSession) {
        let mut client = new_session(
            SessionMode::Client,
            "TestProtocol",
            ApplicationParameters::default(),
            1,
            now,
        );
        let mut server = new_session(
            SessionMode::Server,
            "TestProtocol",
            ApplicationParameters::default(),
            2,
            now,
        );

        client.send_session_request();
        exchange_until_idle(&mut client, &mut server, now);
        (client, server)
    }

    /// Deterministic filler bytes whose content depends on `len`.
    fn pseudo_random_bytes(len: usize) -> Vec<u8> {
        let mut seed = 0x1234_5678u32 ^ len as u32;
        std::iter::repeat_with(|| {
            seed = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (seed >> 24) as u8
        })
        .take(len)
        .collect()
    }

    #[test]
    fn negotiation_establishes_running_session() {
        let now = Instant::now();
        let (mut client, mut server) = establish_session(now);

        for session in [&client, &server] {
            assert_eq!(session.state(), SessionState::Running);
        }
        assert_eq!(client.session_id(), server.session_id());
        // The client adopted the (non-zero) CRC seed chosen by the server.
        assert_ne!(server.crc_seed(), 0);
        assert_eq!(client.crc_seed(), server.crc_seed());

        assert!(client.take_events().contains(&SessionEvent::Opened));
        // The server opens only on the first packet after its response (as in the
        // C# reference), so push one more packet through.
        assert!(client.enqueue_data(b"hi"));
        client.run_tick(now);
        exchange_until_idle(&mut client, &mut server, now);
        assert!(server.take_events().contains(&SessionEvent::Opened));
    }

    #[test]
    fn protocol_mismatch_terminates() {
        let now = Instant::now();
        let mut client = new_session(
            SessionMode::Client,
            "ClientProtocol",
            ApplicationParameters::default(),
            1,
            now,
        );
        let mut server = new_session(
            SessionMode::Server,
            "ServerProtocol",
            ApplicationParameters::default(),
            2,
            now,
        );

        client.send_session_request();
        exchange_until_idle(&mut client, &mut server, now);

        assert_eq!(server.state(), SessionState::Terminated);
        assert_eq!(
            server.termination_reason(),
            DisconnectReason::ProtocolMismatch
        );
        // Rejection happens before a CRC seed exists, so no contextual Disconnect
        // can be sent; the client remains negotiating (and would later time out),
        // matching the C# reference, which only notifies a Running remote.
        assert_eq!(client.state(), SessionState::Negotiating);
    }

    #[test]
    fn round_trips_small_and_large_data() {
        let now = Instant::now();
        let (mut client, mut server) = establish_session(now);

        let small = pseudo_random_bytes(5);
        let large = pseudo_random_bytes(2000); // big enough to be fragmented

        assert!(client.enqueue_data(&small));
        assert!(client.enqueue_data(&large));
        client.run_tick(now);
        exchange_until_idle(&mut client, &mut server, now);

        let delivered = server.take_received();
        assert_eq!(delivered.len(), 2);
        assert_eq!(&delivered[0][..], &small[..]);
        assert_eq!(&delivered[1][..], &large[..]);
    }

    #[test]
    fn round_trips_data_both_directions() {
        let now = Instant::now();
        let (mut client, mut server) = establish_session(now);

        let upstream = pseudo_random_bytes(1500);
        let downstream = pseudo_random_bytes(800);

        assert!(client.enqueue_data(&upstream));
        assert!(server.enqueue_data(&downstream));
        client.run_tick(now);
        server.run_tick(now);
        exchange_until_idle(&mut client, &mut server, now);

        assert_eq!(&server.take_received()[0][..], &upstream[..]);
        assert_eq!(&client.take_received()[0][..], &downstream[..]);
    }

    /// Once a sub-packet of a `MultiPacket` bundle terminates the session, the
    /// remaining sub-packets must not be processed, so no data reaches a dead
    /// session.
    #[test]
    fn multi_packet_stops_after_sub_packet_terminates() {
        let now = Instant::now();
        let (_client, mut server) = establish_session(now);
        assert_eq!(server.state(), SessionState::Running);

        // Each sub-packet is `[length][op-code (2 BE)][sub-payload]`; lengths below
        // 256 fit in a single length byte.
        let sub_packets: [&[u8]; 2] = [
            // ReliableDataFragment, sequence 0, but only 2 of the 4 total-length
            // bytes -> terminates the session as CorruptPacket.
            &[0x00, 0x0D, 0x00, 0x00, 0xAB, 0xCD],
            // A well-formed ReliableData, sequence 0, carrying "hi".
            &[0x00, 0x09, 0x00, 0x00, b'h', b'i'],
        ];
        let bundle: Vec<u8> = sub_packets
            .iter()
            .flat_map(|sub| std::iter::once(sub.len() as u8).chain(sub.iter().copied()))
            .collect();

        server.handle_contextual_inner(OpCode::MultiPacket, Bytes::from(bundle), now);

        assert_eq!(server.state(), SessionState::Terminated);
        assert_eq!(server.termination_reason(), DisconnectReason::CorruptPacket);
        // The trailing ReliableData must never have reached the input channel.
        assert!(
            server.input.take_app_data().is_empty(),
            "data after a terminating sub-packet was processed"
        );
    }

    #[test]
    fn disconnect_notifies_remote() {
        let now = Instant::now();
        let (mut client, mut server) = establish_session(now);

        client.terminate(DisconnectReason::Application, true, now);
        assert_eq!(client.state(), SessionState::Terminated);

        exchange_until_idle(&mut client, &mut server, now);
        assert_eq!(server.state(), SessionState::Terminated);
        assert_eq!(server.termination_reason(), DisconnectReason::Application);
        assert!(server.terminated_by_remote());
    }

    #[test]
    fn encrypted_data_round_trips() {
        let now = Instant::now();
        let app = ApplicationParameters {
            encryption_key_state: Some(Rc4KeyState::new(&[1, 2, 3, 4, 5])),
        };

        let mut client = new_session(SessionMode::Client, "TestProtocol", app.clone(), 1, now);
        let mut server = new_session(SessionMode::Server, "TestProtocol", app, 2, now);

        client.send_session_request();
        exchange_until_idle(&mut client, &mut server, now);

        let payload = pseudo_random_bytes(1200);
        assert!(client.enqueue_data(&payload));
        client.run_tick(now);
        exchange_until_idle(&mut client, &mut server, now);

        assert_eq!(&server.take_received()[0][..], &payload[..]);
    }
}