1use crate::association::Association;
2use crate::association::state::AssociationState;
3use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
4use crate::queue::reassembly_queue::{Chunks, ReassemblyQueue};
5use crate::{ErrorCauseCode, Event, Side};
6use shared::error::{Error, Result};
7
8use crate::util::{ByteSlice, BytesArray, BytesSource};
9use bytes::Bytes;
10use log::{debug, error, trace};
11use std::fmt;
12
13pub type StreamId = u16;
15
16#[derive(Debug, PartialEq, Eq)]
18pub enum StreamEvent {
19 Opened { id: StreamId },
21 Readable {
23 id: StreamId,
25 },
26 Writable {
30 id: StreamId,
32 },
33 Finished {
35 id: StreamId,
37 },
38 Stopped {
40 id: StreamId,
42 error_code: ErrorCauseCode,
44 },
45 Available,
47 BufferedAmountLow {
49 id: StreamId,
51 },
52 BufferedAmountHigh {
54 id: StreamId,
56 },
57}
58
/// Partial-reliability policy applied to a stream's outbound messages.
///
/// The explicit discriminants match the values accepted by `From<u8>`.
/// `Eq` is derived alongside `PartialEq` (consistent with `RecvSendState`) so the type can be
/// used where total equality is required.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
pub enum ReliabilityType {
    /// Fully reliable delivery. This is the default.
    #[default]
    Reliable = 0,
    /// Presumably a retransmission-count limit (PR-SCTP) given by `reliability_value` — confirm.
    Rexmit = 1,
    /// Presumably a lifetime limit given by `reliability_value` — confirm.
    Timed = 2,
}
70
71impl fmt::Display for ReliabilityType {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 let s = match *self {
74 ReliabilityType::Reliable => "Reliable",
75 ReliabilityType::Rexmit => "Rexmit",
76 ReliabilityType::Timed => "Timed",
77 };
78 write!(f, "{}", s)
79 }
80}
81
82impl From<u8> for ReliabilityType {
83 fn from(v: u8) -> ReliabilityType {
84 match v {
85 1 => ReliabilityType::Rexmit,
86 2 => ReliabilityType::Timed,
87 _ => ReliabilityType::Reliable,
88 }
89 }
90}
91
92pub struct Stream<'a> {
94 pub(crate) stream_identifier: StreamId,
95 pub(crate) association: &'a mut Association,
96}
97
98impl Stream<'_> {
99 pub fn read(&mut self) -> Result<Option<Chunks>> {
103 self.read_sctp()
104 }
105
106 pub fn read_sctp(&mut self) -> Result<Option<Chunks>> {
111 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier)
112 && (s.state == RecvSendState::ReadWritable || s.state == RecvSendState::Readable)
113 {
114 Ok(s.reassembly_queue.read())
115 } else {
116 Err(Error::ErrStreamClosed)
117 }
118 }
119
120 pub fn write_sctp(&mut self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize> {
122 self.write_source(&mut ByteSlice::from_slice(p), ppi)
123 }
124
125 pub fn write(&mut self, data: &[u8]) -> Result<usize> {
131 self.write_with_ppi(data, self.get_default_payload_type()?)
132 }
133
134 pub fn write_with_ppi(&mut self, data: &[u8], ppi: PayloadProtocolIdentifier) -> Result<usize> {
138 self.write_source(&mut ByteSlice::from_slice(data), ppi)
139 }
140
141 pub fn write_chunk(&mut self, p: &Bytes) -> Result<usize> {
143 self.write_source(
144 &mut ByteSlice::from_slice(p),
145 self.get_default_payload_type()?,
146 )
147 }
148
149 pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<usize> {
156 self.write_source(
157 &mut BytesArray::from_chunks(data),
158 self.get_default_payload_type()?,
159 )
160 }
161
162 fn write_source<B: BytesSource>(
164 &mut self,
165 source: &mut B,
166 ppi: PayloadProtocolIdentifier,
167 ) -> Result<usize> {
168 if !self.is_writable() {
169 return Err(Error::ErrStreamClosed);
170 }
171
172 if source.remaining() > self.association.max_message_size() as usize {
173 return Err(Error::ErrOutboundPacketTooLarge);
174 }
175
176 let state: AssociationState = self.association.state();
177 match state {
178 AssociationState::ShutdownSent
179 | AssociationState::ShutdownAckSent
180 | AssociationState::ShutdownPending
181 | AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed),
182 _ => {}
183 };
184
185 let (p, _) = source.pop_chunk(self.association.max_message_size() as usize);
186
187 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
188 let (is_buffered_amount_high, chunks) = s.packetize(&p, ppi);
189
190 if is_buffered_amount_high {
191 trace!("StreamEvent::BufferedAmountHigh");
192 self.association
193 .events
194 .push_back(Event::Stream(StreamEvent::BufferedAmountHigh {
195 id: self.stream_identifier,
196 }))
197 }
198
199 self.association.send_payload_data(chunks)?;
200
201 Ok(p.len())
202 } else {
203 Err(Error::ErrStreamClosed)
204 }
205 }
206
207 pub fn is_readable(&self) -> bool {
208 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
209 s.state == RecvSendState::Readable || s.state == RecvSendState::ReadWritable
210 } else {
211 false
212 }
213 }
214
215 pub fn is_writable(&self) -> bool {
216 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
217 s.state == RecvSendState::Writable || s.state == RecvSendState::ReadWritable
218 } else {
219 false
220 }
221 }
222
223 pub fn stop(&mut self) -> Result<()> {
226 let mut reset = false;
227 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
228 if s.state == RecvSendState::Readable || s.state == RecvSendState::ReadWritable {
229 reset = true;
230 }
231 s.state = ((s.state as u8) & 0x2).into();
232 }
233
234 if reset {
235 self.association
238 .send_reset_request(self.stream_identifier)?;
239 }
240
241 Ok(())
242 }
243
244 pub fn finish(&mut self) -> Result<()> {
247 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
248 s.state = ((s.state as u8) & 0x1).into();
249 }
250 Ok(())
251 }
252
253 pub fn close(&mut self) -> Result<()> {
260 self.finish()?;
261 self.stop()
262 }
263
264 pub fn stream_identifier(&self) -> StreamId {
266 self.stream_identifier
267 }
268
269 pub fn set_default_payload_type(
271 &mut self,
272 default_payload_type: PayloadProtocolIdentifier,
273 ) -> Result<()> {
274 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
275 s.default_payload_type = default_payload_type;
276 Ok(())
277 } else {
278 Err(Error::ErrStreamClosed)
279 }
280 }
281
282 pub fn get_default_payload_type(&self) -> Result<PayloadProtocolIdentifier> {
284 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
285 Ok(s.default_payload_type)
286 } else {
287 Err(Error::ErrStreamClosed)
288 }
289 }
290
291 pub fn set_reliability_params(
293 &mut self,
294 unordered: bool,
295 rel_type: ReliabilityType,
296 rel_val: u32,
297 ) -> Result<()> {
298 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
299 debug!(
300 "[{}] reliability params: ordered={} type={} value={}",
301 s.side, !unordered, rel_type, rel_val
302 );
303 s.unordered = unordered;
304 s.reliability_type = rel_type;
305 s.reliability_value = rel_val;
306 Ok(())
307 } else {
308 Err(Error::ErrStreamClosed)
309 }
310 }
311
312 pub fn buffered_amount(&self) -> Result<usize> {
314 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
315 Ok(s.buffered_amount)
316 } else {
317 Err(Error::ErrStreamClosed)
318 }
319 }
320
321 pub fn buffered_amount_low_threshold(&self) -> Result<usize> {
324 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
325 Ok(s.buffered_amount_low)
326 } else {
327 Err(Error::ErrStreamClosed)
328 }
329 }
330
331 pub fn set_buffered_amount_low_threshold(&mut self, th: usize) -> Result<()> {
334 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
335 s.buffered_amount_low = th;
336 Ok(())
337 } else {
338 Err(Error::ErrStreamClosed)
339 }
340 }
341
342 pub fn buffered_amount_high_threshold(&self) -> Result<usize> {
345 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
346 Ok(s.buffered_amount_high)
347 } else {
348 Err(Error::ErrStreamClosed)
349 }
350 }
351
352 pub fn set_buffered_amount_high_threshold(&mut self, th: usize) -> Result<()> {
355 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
356 s.buffered_amount_high = th;
357 Ok(())
358 } else {
359 Err(Error::ErrStreamClosed)
360 }
361 }
362}
363
/// Read/write availability of a stream, laid out as a two-bit mask:
/// `0x1` = read side open, `0x2` = write side open.
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq)]
pub enum RecvSendState {
    /// Neither direction is open. This is the `Default`.
    #[default]
    Closed = 0,
    /// Only the read side is open.
    Readable = 1,
    /// Only the write side is open.
    Writable = 2,
    /// Both directions are open (`Readable | Writable`).
    ReadWritable = 3,
}
372
373impl From<u8> for RecvSendState {
374 fn from(v: u8) -> Self {
375 match v {
376 1 => RecvSendState::Readable,
377 2 => RecvSendState::Writable,
378 3 => RecvSendState::ReadWritable,
379 _ => RecvSendState::Closed,
380 }
381 }
382}
383
384#[derive(Default, Debug)]
386pub struct StreamState {
387 pub(crate) side: Side,
388 pub(crate) max_payload_size: u32,
389 pub(crate) stream_identifier: StreamId,
390 pub(crate) default_payload_type: PayloadProtocolIdentifier,
391 pub(crate) reassembly_queue: ReassemblyQueue,
392 pub(crate) sequence_number: u16,
393 pub(crate) state: RecvSendState,
394 pub(crate) unordered: bool,
395 pub(crate) reliability_type: ReliabilityType,
396 pub(crate) reliability_value: u32,
397 pub(crate) buffered_amount: usize,
398 pub(crate) buffered_amount_low: usize,
399 pub(crate) buffered_amount_high: usize,
400}
401impl StreamState {
402 pub(crate) fn new(
403 side: Side,
404 stream_identifier: StreamId,
405 max_payload_size: u32,
406 default_payload_type: PayloadProtocolIdentifier,
407 ) -> Self {
408 StreamState {
409 side,
410 stream_identifier,
411 max_payload_size,
412 default_payload_type,
413 reassembly_queue: ReassemblyQueue::new(stream_identifier),
414 sequence_number: 0,
415 state: RecvSendState::ReadWritable,
416 unordered: false,
417 reliability_type: ReliabilityType::Reliable,
418 reliability_value: 0,
419 buffered_amount: 0,
420 buffered_amount_low: 0,
421 buffered_amount_high: u32::MAX as usize,
422 }
423 }
424
425 pub(crate) fn handle_data(&mut self, pd: &ChunkPayloadData) -> bool {
426 self.reassembly_queue.push(pd.clone())
427 }
428
429 pub(crate) fn handle_forward_tsn_for_ordered(&mut self, ssn: u16) {
430 if self.unordered {
431 return; }
433
434 self.reassembly_queue.forward_tsn_for_ordered(ssn);
437 }
438
439 pub(crate) fn handle_forward_tsn_for_unordered(&mut self, new_cumulative_tsn: u32) {
440 if !self.unordered {
441 return; }
443
444 self.reassembly_queue
447 .forward_tsn_for_unordered(new_cumulative_tsn);
448 }
449
450 fn packetize(
451 &mut self,
452 raw: &Bytes,
453 ppi: PayloadProtocolIdentifier,
454 ) -> (bool, Vec<ChunkPayloadData>) {
455 let mut i = 0;
456 let mut remaining = raw.len();
457
458 let unordered = ppi != PayloadProtocolIdentifier::Dcep && self.unordered;
462
463 let mut chunks = vec![];
464
465 let head_abandoned = false;
466 let head_all_inflight = false;
467 while remaining != 0 {
468 let fragment_size = std::cmp::min(self.max_payload_size as usize, remaining); let user_data = raw.slice(i..i + fragment_size);
473
474 let chunk = ChunkPayloadData {
475 stream_identifier: self.stream_identifier,
476 user_data,
477 unordered,
478 beginning_fragment: i == 0,
479 ending_fragment: remaining - fragment_size == 0,
480 immediate_sack: false,
481 payload_type: ppi,
482 stream_sequence_number: self.sequence_number,
483 abandoned: head_abandoned, all_inflight: head_all_inflight, ..Default::default()
486 };
487
488 chunks.push(chunk);
489
490 remaining -= fragment_size;
491 i += fragment_size;
492 }
493
494 if !unordered {
499 self.sequence_number = self.sequence_number.wrapping_add(1);
500 }
501
502 let old_amount = self.buffered_amount;
503 let n_bytes_added = raw.len();
504 self.buffered_amount += raw.len();
505 let new_amount = self.buffered_amount;
506
507 trace!(
508 "[{}] new_amount = {}, old_amount = {}, buffered_amount_high = {}, n_bytes_added = {}",
509 self.side, new_amount, old_amount, self.buffered_amount_high, n_bytes_added,
510 );
511
512 let is_buffered_amount_high =
513 old_amount < self.buffered_amount_high && new_amount >= self.buffered_amount_high;
514
515 (is_buffered_amount_high, chunks)
516 }
517
518 pub(crate) fn on_buffer_released(&mut self, n_bytes_released: i64) -> bool {
521 if n_bytes_released <= 0 {
522 return false;
523 }
524
525 let old_amount = self.buffered_amount;
526 let new_amount = if old_amount < n_bytes_released as usize {
527 self.buffered_amount = 0;
528 error!(
529 "[{}] released buffer size {} should be <= {}",
530 self.side, n_bytes_released, 0,
531 );
532 0
533 } else {
534 self.buffered_amount -= n_bytes_released as usize;
535
536 old_amount - n_bytes_released as usize
537 };
538
539 trace!(
540 "[{}] new_amount = {}, old_amount = {}, buffered_amount_low = {}, n_bytes_released = {}",
541 self.side, new_amount, old_amount, self.buffered_amount_low, n_bytes_released,
542 );
543
544 old_amount > self.buffered_amount_low && new_amount <= self.buffered_amount_low
545 }
546
547 pub(crate) fn get_num_bytes_in_reassembly_queue(&self) -> usize {
548 self.reassembly_queue.get_num_bytes()
550 }
551}