Skip to main content

irtt_client/
session.rs

1use std::time::Instant;
2
3use irtt_proto::{Params, PROTOCOL_VERSION};
4
5use crate::{
6    config::{NegotiationPolicy, MAX_DSCP_CODEPOINT},
7    error::ClientError,
8    probe::{CompletedSet, PendingMap, TimedOutMap},
9};
10
/// Wrapper around the [`Params`] in effect for a test session after negotiation.
///
/// NOTE(review): judging by the name, this presumably holds the server-returned
/// parameters once they have passed `validate_negotiated_params` — confirm at the
/// construction site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NegotiatedParams {
    /// The negotiated parameter set.
    pub params: Params,
}
15
/// Phase of the client's interaction with the server.
///
/// NOTE(review): the per-variant descriptions below are inferred from variant
/// names only — confirm against the code that transitions between phases.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ClientPhase {
    /// Connected to the server; presumably no test session is open yet.
    Connected,
    /// A test session is open. `token` presumably identifies the session to the
    /// server — confirm.
    Open { token: u64 },
    /// Presumably: the exchange finished without a test having completed — confirm.
    NoTestCompleted,
    /// The client has been closed.
    Closed,
}
23
/// Mutable bookkeeping for a test session while it runs.
///
/// NOTE(review): field descriptions are inferred from names and types; confirm
/// them against the send/receive loop that updates this struct.
#[derive(Debug)]
pub(crate) struct ActiveSession {
    /// Sequence number for the next outgoing packet, in its 32-bit wire form.
    pub next_wire_seq: u32,
    /// 64-bit sequence counter for the next probe; presumably does not wrap the
    /// way the 32-bit wire sequence can — confirm.
    pub next_logical_seq: u64,
    /// Highest wire sequence number received so far; presumably `None` until the
    /// first reply arrives.
    pub highest_received_seq: Option<u32>,
    /// Count of packets sent in this session.
    pub packets_sent: u64,
    /// Monotonic (`Instant`) timestamp marking the session start.
    pub start_mono: Instant,
    /// Monotonic timestamp marking the session end, once known.
    pub end_mono: Option<Instant>,
    /// Instant at which the next packet is scheduled to be sent.
    pub next_send_at: Instant,
    /// Probes presumably sent and still awaiting a reply (`probe::PendingMap`).
    pub pending: PendingMap,
    /// Probes presumably given up on before a reply arrived (`probe::TimedOutMap`).
    pub timed_out: TimedOutMap,
    /// Probes presumably already answered and accounted for (`probe::CompletedSet`).
    pub completed: CompletedSet,
    /// Flag presumably set once the client stops sending new probes.
    pub sending_done: bool,
}
38
39pub(crate) fn validate_negotiated_params(
40    requested: &Params,
41    returned: &Params,
42    policy: NegotiationPolicy,
43) -> Result<(), ClientError> {
44    if returned.protocol_version != PROTOCOL_VERSION {
45        return Err(ClientError::ProtocolVersionMismatch {
46            requested: PROTOCOL_VERSION,
47            received: returned.protocol_version,
48        });
49    }
50    validate_duration_restriction(requested.duration_ns, returned.duration_ns)?;
51    if returned.length < 0 {
52        return Err(ClientError::NegotiationRejected {
53            reason: "length must be non-negative".to_owned(),
54        });
55    }
56    if returned.length > requested.length {
57        return Err(ClientError::NegotiationRejected {
58            reason: "length increased".to_owned(),
59        });
60    }
61    if returned.interval_ns <= 0 {
62        return Err(ClientError::NegotiationRejected {
63            reason: "interval must be positive".to_owned(),
64        });
65    }
66    validate_dscp_restriction(returned.dscp)?;
67
68    if policy == NegotiationPolicy::Strict && returned != requested {
69        return Err(ClientError::NegotiationRejected {
70            reason: "returned params differ from requested params".to_owned(),
71        });
72    }
73    Ok(())
74}
75
76fn validate_duration_restriction(requested: i64, returned: i64) -> Result<(), ClientError> {
77    if returned < 0 {
78        return Err(ClientError::NegotiationRejected {
79            reason: "duration must be non-negative".to_owned(),
80        });
81    }
82
83    if requested > 0 && returned == 0 {
84        return Err(ClientError::NegotiationRejected {
85            reason: "server returned continuous duration for finite request".to_owned(),
86        });
87    }
88
89    if requested > 0 && returned > requested {
90        return Err(ClientError::NegotiationRejected {
91            reason: "duration increased".to_owned(),
92        });
93    }
94
95    Ok(())
96}
97
98fn validate_dscp_restriction(returned: i64) -> Result<(), ClientError> {
99    if !(0..=i64::from(MAX_DSCP_CODEPOINT)).contains(&returned) {
100        return Err(ClientError::NegotiationRejected {
101            reason: format!("dscp must be in range 0..={MAX_DSCP_CODEPOINT}"),
102        });
103    }
104
105    Ok(())
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use irtt_proto::{Clock, ReceivedStats, ServerFill, StampAt};

    /// Baseline request shared by every test; individual tests mutate a clone.
    fn default_params() -> Params {
        Params {
            protocol_version: PROTOCOL_VERSION,
            duration_ns: 3_000_000_000,
            interval_ns: 1_000_000_000,
            length: 256,
            received_stats: ReceivedStats::Both,
            stamp_at: StampAt::Both,
            clock: Clock::Both,
            dscp: 46,
            server_fill: Some(ServerFill {
                value: "rand".to_owned(),
            }),
        }
    }

    /// Asserts validation fails with `NegotiationRejected`, reporting the actual
    /// result if it does not.
    fn assert_rejected(requested: &Params, returned: &Params, policy: NegotiationPolicy) {
        let result = validate_negotiated_params(requested, returned, policy);
        assert!(
            matches!(result, Err(ClientError::NegotiationRejected { .. })),
            "expected negotiation rejection, got {result:?}"
        );
    }

    /// Returns the rejection reason, panicking if validation produced anything else.
    fn rejection_reason(
        requested: &Params,
        returned: &Params,
        policy: NegotiationPolicy,
    ) -> String {
        match validate_negotiated_params(requested, returned, policy) {
            Err(ClientError::NegotiationRejected { reason }) => reason,
            other => panic!("expected negotiation rejection, got {other:?}"),
        }
    }

    #[test]
    fn strict_rejects_changed_negotiated_fields() {
        let requested = default_params();

        let mut returned = requested.clone();
        returned.length = 128;
        assert_rejected(&requested, &returned, NegotiationPolicy::Strict);

        let mut returned = requested.clone();
        returned.dscp = 8;
        assert_rejected(&requested, &returned, NegotiationPolicy::Strict);

        let mut returned = requested.clone();
        returned.received_stats = ReceivedStats::Count;
        assert_rejected(&requested, &returned, NegotiationPolicy::Strict);

        let mut returned = requested.clone();
        returned.stamp_at = StampAt::Midpoint;
        assert_rejected(&requested, &returned, NegotiationPolicy::Strict);

        let mut returned = requested.clone();
        returned.clock = Clock::Wall;
        assert_rejected(&requested, &returned, NegotiationPolicy::Strict);

        let mut returned = requested.clone();
        returned.server_fill = None;
        assert_rejected(&requested, &returned, NegotiationPolicy::Strict);
    }

    #[test]
    fn loose_duration_negotiation_uses_run_duration_semantics() {
        let requested = default_params();

        let mut returned = requested.clone();
        returned.duration_ns = requested.duration_ns / 2;
        assert!(
            validate_negotiated_params(&requested, &returned, NegotiationPolicy::Loose).is_ok()
        );

        let mut returned = requested.clone();
        returned.duration_ns = requested.duration_ns + 1;
        assert_eq!(
            rejection_reason(&requested, &returned, NegotiationPolicy::Loose),
            "duration increased"
        );

        let mut returned = requested.clone();
        returned.duration_ns = 0;
        assert_eq!(
            rejection_reason(&requested, &returned, NegotiationPolicy::Loose),
            "server returned continuous duration for finite request"
        );

        let mut continuous_requested = requested.clone();
        continuous_requested.duration_ns = 0;
        let mut finite_returned = continuous_requested.clone();
        finite_returned.duration_ns = 1_000_000_000;
        assert!(validate_negotiated_params(
            &continuous_requested,
            &finite_returned,
            NegotiationPolicy::Loose
        )
        .is_ok());

        assert_rejected(
            &continuous_requested,
            &finite_returned,
            NegotiationPolicy::Strict,
        );

        assert!(validate_negotiated_params(
            &continuous_requested,
            &continuous_requested,
            NegotiationPolicy::Strict
        )
        .is_ok());
    }

    #[test]
    fn loose_accepts_server_restrictions() {
        let requested = default_params();

        // Shorter payload, slower send rate, and DSCP at the lower bound.
        let mut returned = requested.clone();
        returned.length = requested.length / 2;
        returned.interval_ns = requested.interval_ns * 2;
        returned.dscp = 0;
        assert!(
            validate_negotiated_params(&requested, &returned, NegotiationPolicy::Loose).is_ok()
        );

        // DSCP at the upper bound of the valid range.
        let mut returned = requested.clone();
        returned.dscp = i64::from(MAX_DSCP_CODEPOINT);
        assert!(
            validate_negotiated_params(&requested, &returned, NegotiationPolicy::Loose).is_ok()
        );
    }

    #[test]
    fn loose_rejects_expansions_and_out_of_range_values() {
        let requested = default_params();

        let mut returned = requested.clone();
        returned.length = requested.length + 1;
        assert_eq!(
            rejection_reason(&requested, &returned, NegotiationPolicy::Loose),
            "length increased"
        );

        let mut returned = requested.clone();
        returned.length = -1;
        assert_eq!(
            rejection_reason(&requested, &returned, NegotiationPolicy::Loose),
            "length must be non-negative"
        );

        let mut returned = requested.clone();
        returned.interval_ns = 0;
        assert_eq!(
            rejection_reason(&requested, &returned, NegotiationPolicy::Loose),
            "interval must be positive"
        );

        let mut returned = requested.clone();
        returned.duration_ns = -1;
        assert_eq!(
            rejection_reason(&requested, &returned, NegotiationPolicy::Loose),
            "duration must be non-negative"
        );

        let dscp_reason = format!("dscp must be in range 0..={MAX_DSCP_CODEPOINT}");

        let mut returned = requested.clone();
        returned.dscp = -1;
        assert_eq!(
            rejection_reason(&requested, &returned, NegotiationPolicy::Loose),
            dscp_reason
        );

        let mut returned = requested.clone();
        returned.dscp = i64::from(MAX_DSCP_CODEPOINT) + 1;
        assert_eq!(
            rejection_reason(&requested, &returned, NegotiationPolicy::Loose),
            dscp_reason
        );
    }

    #[test]
    fn protocol_version_mismatch_is_reported() {
        let requested = default_params();

        let mut returned = requested.clone();
        returned.protocol_version = PROTOCOL_VERSION + 1;
        let result = validate_negotiated_params(&requested, &returned, NegotiationPolicy::Loose);
        assert!(
            matches!(result, Err(ClientError::ProtocolVersionMismatch { .. })),
            "expected protocol version mismatch, got {result:?}"
        );
    }
}