1use std::{cell::Cell, cmp, fmt, future::poll_fn, mem, ops, rc::Rc, task::Context, task::Poll};
2
3use ntex_bytes::Bytes;
4use ntex_http::{HeaderMap, StatusCode, header::CONTENT_LENGTH};
5use ntex_util::task::LocalWaker;
6
7use crate::error::{OperationError, StreamError};
8use crate::frame::{
9 Data, Headers, PseudoHeaders, Reason, Reset, StreamId, WindowSize, WindowUpdate,
10};
11use crate::{connection::Connection, frame, message::Message, window::Window};
12
/// Owning handle to an HTTP/2 stream.
///
/// Dereferences to [`StreamRef`]. Dropping this handle calls
/// `reset(Reason::CANCEL)`, which only emits a reset if at least one half of
/// the stream is still open.
pub struct Stream(StreamRef);
15
/// A portion of received-data flow-control capacity held by the application.
///
/// The bytes are charged to the stream when the value is created and released
/// back to the stream either through [`Capacity::consume`] or when the value is
/// dropped. Releasing capacity may cause a `WINDOW_UPDATE` frame to be sent.
#[derive(Debug)]
pub struct Capacity {
    // bytes that have not been released back to the stream yet
    size: Cell<u32>,
    // stream this capacity is accounted against
    stream: Rc<StreamState>,
}
22
23impl Capacity {
24 fn new(size: u32, stream: &Rc<StreamState>) -> Self {
25 stream.add_recv_capacity(size);
26
27 Self {
28 size: Cell::new(size),
29 stream: stream.clone(),
30 }
31 }
32
33 #[inline]
34 pub fn size(&self) -> usize {
36 self.size.get() as usize
37 }
38
39 pub fn consume(&self, sz: u32) {
43 let size = self.size.get();
44 if let Some(sz) = size.checked_sub(sz) {
45 log::trace!(
46 "{}: {:?} capacity consumed from {} to {}",
47 self.stream.tag(),
48 self.stream.id,
49 size,
50 sz
51 );
52 self.size.set(sz);
53 self.stream.consume_capacity(size - sz);
54 } else {
55 panic!("Capacity overflow");
56 }
57 }
58}
59
60impl ops::Add for Capacity {
62 type Output = Self;
63
64 fn add(self, other: Self) -> Self {
65 if Rc::ptr_eq(&self.stream, &other.stream) {
66 let size = Cell::new(self.size.get() + other.size.get());
67 self.size.set(0);
68 other.size.set(0);
69 Self {
70 size,
71 stream: self.stream.clone(),
72 }
73 } else {
74 panic!("Cannot add capacity from different streams");
75 }
76 }
77}
78
79impl ops::AddAssign for Capacity {
81 fn add_assign(&mut self, other: Self) {
82 if Rc::ptr_eq(&self.stream, &other.stream) {
83 let size = self.size.get() + other.size.get();
84 self.size.set(size);
85 other.size.set(0);
86 } else {
87 panic!("Cannot add capacity from different streams");
88 }
89 }
90}
91
92impl Drop for Capacity {
93 fn drop(&mut self) {
94 let size = self.size.get();
95 if size > 0 {
96 self.stream.consume_capacity(size);
97 }
98 }
99}
100
/// Expected body length of the message received on a stream.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ContentLength {
    /// No `content-length` has been recorded (initial value).
    Omitted,
    /// The body must be empty; a non-empty DATA frame is rejected with
    /// `NonEmptyPayload` (set e.g. by `send_headers` for 1xx statuses).
    Head,
    /// Number of body bytes still expected.
    Remaining(u64),
}
108
/// Cloneable, shared reference to a stream's state.
///
/// Unlike [`Stream`], this type has no `Drop` logic: dropping a `StreamRef`
/// does not reset the stream.
#[derive(Clone, Debug)]
pub struct StreamRef(pub(crate) Rc<StreamState>);
111
bitflags::bitflags! {
    /// Boolean per-stream flags.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct StreamFlags: u8 {
        /// Stream was created with `remote == true` (presumably peer-initiated).
        const REMOTE = 0b0000_0001;
        /// Stream was failed through `StreamState::failed`.
        const FAILED = 0b0000_0010;
        /// Set by `StreamRef::disconnect_on_drop`; acted upon outside this file.
        const DISCONNECT_ON_DROP = 0b0000_0100;
        /// A task is parked in `poll_send_capacity` waiting for send window.
        const WAIT_FOR_CAPACITY = 0b0000_1000;
    }
}
121
/// Shared, interior-mutable state of a single stream.
pub(crate) struct StreamState {
    id: StreamId,
    flags: Cell<StreamFlags>,
    // expected body length of the received message
    content_length: Cell<ContentLength>,
    // state of the receiving half
    recv: Cell<HalfState>,
    // receive flow-control window
    recv_window: Cell<Window>,
    // bytes of received data currently held in `Capacity` values
    recv_size: Cell<u32>,
    // state of the sending half
    send: Cell<HalfState>,
    // send flow-control window (starts at the connection's remote window size)
    send_window: Cell<Window>,
    // woken when send capacity may have changed or the send half closes
    send_cap: LocalWaker,
    // woken by `review_state` once the receiving half is closed
    send_reset: LocalWaker,
    pub(crate) con: Connection,
    // sticky error returned by later operations on this stream
    error: Cell<Option<OperationError>>,
}
141
/// State of one direction (send or receive) of a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum HalfState {
    /// Nothing has been sent/received in this direction yet.
    Idle,
    /// Headers are done; payload frames are in progress.
    Payload,
    /// This direction is closed, optionally with a reset reason.
    Closed(Option<Reason>),
}
148
149impl HalfState {
150 pub(crate) fn is_closed(&self) -> bool {
151 matches!(self, HalfState::Closed(_))
152 }
153}
154
155impl StreamState {
156 fn tag(&self) -> &'static str {
157 self.con.tag()
158 }
159
160 fn state_send_payload(&self) {
161 self.send.set(HalfState::Payload);
162 }
163
164 fn state_send_close(&self, reason: Option<Reason>) {
165 log::trace!(
166 "{}: {:?} send side is closed with reason {:?}",
167 self.tag(),
168 self.id,
169 reason
170 );
171 self.send.set(HalfState::Closed(reason));
172 self.send_cap.wake();
173 self.review_state();
174 }
175
176 fn state_recv_payload(&self) {
177 self.recv.set(HalfState::Payload);
178 }
179
180 fn state_recv_close(&self, reason: Option<Reason>) {
181 log::trace!("{}: {:?} receive side is closed", self.tag(), self.id);
182 self.recv.set(HalfState::Closed(reason));
183 self.review_state();
184 }
185
186 fn reset_stream(&self, reason: Option<Reason>) {
187 self.recv.set(HalfState::Closed(reason));
188 self.send.set(HalfState::Closed(None));
189 if let Some(reason) = reason {
190 self.error.set(Some(OperationError::LocalReset(reason)));
191 }
192 self.review_state();
193 }
194
195 fn remote_reset_stream(&self, reason: Reason) {
196 self.recv.set(HalfState::Closed(None));
197 self.send.set(HalfState::Closed(Some(reason)));
198 self.error.set(Some(OperationError::RemoteReset(reason)));
199 self.review_state();
200 }
201
202 fn failed(&self, err: OperationError) {
203 if !self.recv.get().is_closed() {
204 self.recv.set(HalfState::Closed(None));
205 }
206 if !self.send.get().is_closed() {
207 self.send.set(HalfState::Closed(None));
208 }
209 self.error.set(Some(err));
210 self.insert_flag(StreamFlags::FAILED);
211 self.review_state();
212 }
213
214 fn insert_flag(&self, f: StreamFlags) {
215 let mut flags = self.flags.get();
216 flags.insert(f);
217 self.flags.set(flags);
218 }
219
220 fn remove_flag(&self, f: StreamFlags) {
221 let mut flags = self.flags.get();
222 flags.remove(f);
223 self.flags.set(flags);
224 }
225
226 fn check_error(&self) -> Result<(), OperationError> {
227 if let Some(err) = self.error.take() {
228 self.error.set(Some(err.clone()));
229 Err(err)
230 } else {
231 Ok(())
232 }
233 }
234
235 fn review_state(&self) {
236 if self.recv.get().is_closed() {
237 self.send_reset.wake();
238
239 if let HalfState::Closed(reason) = self.send.get() {
240 if let Some(reason) = reason {
242 log::trace!(
243 "{}: {:?} is closed with remote reset {:?}, dropping stream",
244 self.tag(),
245 self.id,
246 reason
247 );
248 } else {
249 log::trace!(
250 "{}: {:?} both sides are closed, dropping stream",
251 self.tag(),
252 self.id
253 );
254 }
255 self.send_cap.wake();
256 self.con.drop_stream(self.id);
257 }
258 }
259 }
260
261 fn add_recv_capacity(&self, size: u32) {
263 let cap = self.recv_size.get();
264 self.recv_size.set(cap + size);
265 self.recv_window.set(self.recv_window.get().dec(size));
266 log::trace!(
267 "{}: {:?} capacity incresed from {} to {}",
268 self.tag(),
269 self.id,
270 cap,
271 cap + size
272 );
273
274 self.con.add_recv_capacity(size);
276 }
277
278 fn consume_capacity(&self, size: u32) {
280 let cap = self.recv_size.get();
281 let size = cap - size;
282 log::trace!(
283 "{}: {:?} capacity decresed from {} to {}",
284 self.tag(),
285 self.id,
286 cap,
287 size
288 );
289
290 self.recv_size.set(size);
291 let mut window = self.recv_window.get();
292 if let Some(val) = window.update(
293 size,
294 self.con.config().window_sz,
295 self.con.config().window_sz_threshold,
296 ) {
297 log::trace!(
298 "{}: {:?} capacity decresed below threshold {} increase by {} ({})",
299 self.tag(),
300 self.id,
301 self.con.config().window_sz_threshold,
302 val,
303 self.con.config().window_sz,
304 );
305 self.recv_window.set(window);
306 self.con.encode(WindowUpdate::new(self.id, val));
307 }
308 }
309}
310
311impl StreamRef {
312 pub(crate) fn new(id: StreamId, remote: bool, con: Connection) -> Self {
313 let recv_window = if con.settings_processed() {
316 Window::new(con.config().window_sz as i32)
317 } else {
318 Window::new(frame::DEFAULT_INITIAL_WINDOW_SIZE as i32)
319 };
320 let send_window = Window::new(con.remote_window_size() as i32);
321
322 StreamRef(Rc::new(StreamState {
323 id,
324 con,
325 recv: Cell::new(HalfState::Idle),
326 recv_window: Cell::new(recv_window),
327 recv_size: Cell::new(0),
328 send: Cell::new(HalfState::Idle),
329 send_window: Cell::new(send_window),
330 send_cap: LocalWaker::new(),
331 send_reset: LocalWaker::new(),
332 error: Cell::new(None),
333 content_length: Cell::new(ContentLength::Omitted),
334 flags: Cell::new(if remote {
335 StreamFlags::REMOTE
336 } else {
337 StreamFlags::empty()
338 }),
339 }))
340 }
341
342 #[inline]
343 pub fn id(&self) -> StreamId {
344 self.0.id
345 }
346
347 #[inline]
348 pub fn tag(&self) -> &'static str {
349 self.0.con.tag()
350 }
351
352 #[inline]
354 pub fn is_remote(&self) -> bool {
355 self.0.flags.get().contains(StreamFlags::REMOTE)
356 }
357
358 #[inline]
360 pub fn is_failed(&self) -> bool {
361 self.0.flags.get().contains(StreamFlags::FAILED)
362 }
363
364 pub(crate) fn send_state(&self) -> HalfState {
365 self.0.send.get()
366 }
367
368 pub(crate) fn recv_state(&self) -> HalfState {
369 self.0.recv.get()
370 }
371
372 pub(crate) fn disconnect_on_drop(&self) {
373 self.0.insert_flag(StreamFlags::DISCONNECT_ON_DROP);
374 }
375
376 pub(crate) fn is_disconnect_on_drop(&self) -> bool {
377 self.0.flags.get().contains(StreamFlags::DISCONNECT_ON_DROP)
378 }
379
380 #[inline]
385 pub fn reset(&self, reason: Reason) -> bool {
386 if !self.0.recv.get().is_closed() || !self.0.send.get().is_closed() {
387 self.0.con.encode(Reset::new(self.0.id, reason));
388 self.0.reset_stream(Some(reason));
389 true
390 } else {
391 false
392 }
393 }
394
395 #[inline]
397 pub fn empty_capacity(&self) -> Capacity {
398 Capacity {
399 size: Cell::new(0),
400 stream: self.0.clone(),
401 }
402 }
403
404 #[inline]
405 pub(crate) fn into_stream(self) -> Stream {
406 Stream(self)
407 }
408
409 pub(crate) fn send_headers(&self, mut hdrs: Headers) {
410 hdrs.set_end_headers();
411 if hdrs.is_end_stream() {
412 self.0.state_send_close(None);
413 } else {
414 self.0.state_send_payload();
415 }
416 log::trace!(
417 "{}: send headers {:#?} eos: {:?}",
418 self.tag(),
419 hdrs,
420 hdrs.is_end_stream()
421 );
422
423 if hdrs
424 .pseudo()
425 .status
426 .is_some_and(|status| status.is_informational())
427 {
428 self.0.content_length.set(ContentLength::Head)
429 }
430 self.0.con.encode(hdrs);
431 }
432
433 pub(crate) fn set_failed(&self, reason: Option<Reason>) {
434 self.0.reset_stream(reason);
435 }
436
437 pub(crate) fn set_go_away(&self, reason: Reason) {
438 self.0.remote_reset_stream(reason)
439 }
440
441 pub(crate) fn set_failed_stream(&self, err: OperationError) {
442 self.0.failed(err);
443 }
444
445 pub(crate) fn recv_headers(&self, hdrs: Headers) -> Result<Option<Message>, StreamError> {
446 log::trace!(
447 "{}: processing HEADERS for {:?}:\n{:#?}\nrecv_state:{:?}, send_state: {:?}",
448 self.tag(),
449 self.0.id,
450 hdrs,
451 self.0.recv.get(),
452 self.0.send.get(),
453 );
454
455 match self.0.recv.get() {
456 HalfState::Idle => {
457 let eof = hdrs.is_end_stream();
458 if eof {
459 self.0.state_recv_close(None);
460 } else {
461 self.0.state_recv_payload();
462 }
463 let (pseudo, headers) = hdrs.into_parts();
464
465 if self.0.content_length.get() != ContentLength::Head {
466 if let Some(content_length) = headers.get(CONTENT_LENGTH) {
467 if let Some(v) = parse_u64(content_length.as_bytes()) {
468 self.0.content_length.set(ContentLength::Remaining(v));
469 } else {
470 proto_err!(stream: "could not parse content-length; stream={:?}", self.0.id);
471 return Err(StreamError::InvalidContentLength);
472 }
473 }
474 }
475 Ok(Some(Message::new(pseudo, headers, eof, self)))
476 }
477 HalfState::Payload => {
478 if !hdrs.is_end_stream() {
480 Err(StreamError::TrailersWithoutEos)
481 } else {
482 self.0.state_recv_close(None);
483 Ok(Some(Message::trailers(hdrs.into_fields(), self)))
484 }
485 }
486 HalfState::Closed(_) => Err(StreamError::Closed),
487 }
488 }
489
490 pub(crate) fn recv_data(&self, data: Data) -> Result<Option<Message>, StreamError> {
491 let cap = Capacity::new(data.payload().len() as u32, &self.0);
492 log::trace!(
493 "{}: processing DATA frame for {:?}, len: {:?}",
494 self.tag(),
495 self.0.id,
496 data.payload().len()
497 );
498
499 match self.0.recv.get() {
500 HalfState::Payload => {
501 let eof = data.is_end_stream();
502
503 match self.0.content_length.get() {
505 ContentLength::Remaining(rem) => {
506 match rem.checked_sub(data.payload().len() as u64) {
507 Some(val) => {
508 self.0.content_length.set(ContentLength::Remaining(val));
509 if eof && val != 0 {
510 return Err(StreamError::WrongPayloadLength);
511 }
512 }
513 None => return Err(StreamError::WrongPayloadLength),
514 }
515 }
516 ContentLength::Head => {
517 if !data.payload().is_empty() {
518 return Err(StreamError::NonEmptyPayload);
519 }
520 }
521 _ => (),
522 }
523
524 if eof {
525 self.0.state_recv_close(None);
526 Ok(Some(Message::eof_data(data.into_payload(), self)))
527 } else {
528 Ok(Some(Message::data(data.into_payload(), cap, self)))
529 }
530 }
531 HalfState::Idle => Err(StreamError::Idle("DATA framed received")),
532 HalfState::Closed(_) => Err(StreamError::Closed),
533 }
534 }
535
536 pub(crate) fn recv_rst_stream(&self, frm: &Reset) {
537 self.0.remote_reset_stream(frm.reason())
538 }
539
540 pub(crate) fn recv_window_update_connection(&self) {
541 if self.0.flags.get().contains(StreamFlags::WAIT_FOR_CAPACITY)
542 && self.0.send_window.get().window_size() > 0
543 {
544 self.0.send_cap.wake();
545 }
546 }
547
548 pub(crate) fn recv_window_update(&self, frm: WindowUpdate) -> Result<(), StreamError> {
549 if frm.size_increment() == 0 {
550 Err(StreamError::WindowZeroUpdateValue)
551 } else {
552 let window = self
553 .0
554 .send_window
555 .get()
556 .inc(frm.size_increment())
557 .map_err(|_| StreamError::WindowOverflowed)?;
558 self.0.send_window.set(window);
559
560 if window.window_size() > 0 {
561 self.0.send_cap.wake();
562 }
563 Ok(())
564 }
565 }
566
567 pub(crate) fn update_send_window(&self, upd: i32) -> Result<(), StreamError> {
568 let orig = self.0.send_window.get();
569 let window = match upd.cmp(&0) {
570 cmp::Ordering::Less => orig.dec(upd.unsigned_abs()), cmp::Ordering::Greater => orig
572 .inc(upd as u32)
573 .map_err(|_| StreamError::WindowOverflowed)?,
574 cmp::Ordering::Equal => return Ok(()),
575 };
576 log::trace!(
577 "{}: Updating send window size from {} to {}",
578 self.tag(),
579 orig.window_size,
580 window.window_size
581 );
582 self.0.send_window.set(window);
583 Ok(())
584 }
585
586 pub(crate) fn update_recv_window(&self, upd: i32) -> Result<Option<WindowSize>, StreamError> {
587 let mut window = match upd.cmp(&0) {
588 cmp::Ordering::Less => self.0.recv_window.get().dec(upd.unsigned_abs()), cmp::Ordering::Greater => self
590 .0
591 .recv_window
592 .get()
593 .inc(upd as u32)
594 .map_err(|_| StreamError::WindowOverflowed)?,
595 cmp::Ordering::Equal => return Ok(None),
596 };
597 if let Some(val) = window.update(
598 self.0.recv_size.get(),
599 self.0.con.config().window_sz,
600 self.0.con.config().window_sz_threshold,
601 ) {
602 self.0.recv_window.set(window);
603 Ok(Some(val))
604 } else {
605 self.0.recv_window.set(window);
606 Ok(None)
607 }
608 }
609
610 pub fn send_response(
612 &self,
613 status: StatusCode,
614 headers: HeaderMap,
615 eof: bool,
616 ) -> Result<(), OperationError> {
617 self.0.check_error()?;
618
619 match self.0.send.get() {
620 HalfState::Idle => {
621 let pseudo = PseudoHeaders::response(status);
622 let mut hdrs = Headers::new(self.0.id, pseudo, headers, eof);
623
624 if eof {
625 hdrs.set_end_stream();
626 self.0.state_send_close(None);
627 } else {
628 self.0.state_send_payload();
629 }
630 self.0.con.encode(hdrs);
631 Ok(())
632 }
633 HalfState::Payload => Err(OperationError::Payload),
634 HalfState::Closed(r) => Err(OperationError::Closed(r)),
635 }
636 }
637
638 pub async fn send_payload(&self, mut res: Bytes, eof: bool) -> Result<(), OperationError> {
640 match self.0.send.get() {
641 HalfState::Payload => {
642 self.0.check_error()?;
644
645 log::trace!(
646 "{}: {:?} sending {} bytes, eof: {}, send: {:?}",
647 self.0.tag(),
648 self.0.id,
649 res.len(),
650 eof,
651 self.0.send.get()
652 );
653
654 if eof && res.is_empty() {
656 let mut data = Data::new(self.0.id, Bytes::new());
657 data.set_end_stream();
658 self.0.state_send_close(None);
659
660 self.0.con.encode(data);
662 return Ok(());
663 }
664
665 loop {
666 let win = self.available_send_capacity() as usize;
668 if win > 0 {
669 let size =
670 cmp::min(win, cmp::min(res.len(), self.0.con.remote_frame_size()));
671 let mut data = if size >= res.len() {
672 Data::new(self.0.id, mem::replace(&mut res, Bytes::new()))
673 } else {
674 log::trace!(
675 "{}: {:?} sending {} out of {} bytes",
676 self.0.tag(),
677 self.0.id,
678 size,
679 res.len()
680 );
681 Data::new(self.0.id, res.split_to(size))
682 };
683 if eof && res.is_empty() {
684 data.set_end_stream();
685 self.0.state_send_close(None);
686 }
687
688 self.0
690 .send_window
691 .set(self.0.send_window.get().dec(size as u32));
692
693 self.0.con.consume_send_window(size as u32);
695
696 self.0.con.encode(data);
698 if res.is_empty() {
699 return Ok(());
700 }
701 } else {
702 log::trace!(
703 "{}: Not enough sending capacity for {:?} remaining {:?}",
704 self.0.tag(),
705 self.0.id,
706 res.len()
707 );
708 self.send_capacity().await?;
710 }
711 }
712 }
713 HalfState::Idle => Err(OperationError::Idle),
714 HalfState::Closed(reason) => Err(OperationError::Closed(reason)),
715 }
716 }
717
718 pub fn send_trailers(&self, map: HeaderMap) {
720 if self.0.send.get() == HalfState::Payload {
721 let mut hdrs = Headers::trailers(self.0.id, map);
722 hdrs.set_end_headers();
723 hdrs.set_end_stream();
724 self.0.con.encode(hdrs);
725 self.0.state_send_close(None);
726 }
727 }
728
729 pub fn available_send_capacity(&self) -> WindowSize {
730 cmp::min(
731 self.0.send_window.get().window_size(),
732 self.0.con.send_window_size(),
733 )
734 }
735
736 pub async fn send_capacity(&self) -> Result<WindowSize, OperationError> {
737 poll_fn(|cx| self.poll_send_capacity(cx)).await
738 }
739
740 pub fn poll_send_capacity(&self, cx: &Context<'_>) -> Poll<Result<WindowSize, OperationError>> {
742 self.0.check_error()?;
743 self.0.con.check_error()?;
744
745 let win = self.available_send_capacity();
746 if win > 0 {
747 self.0.remove_flag(StreamFlags::WAIT_FOR_CAPACITY);
748 Poll::Ready(Ok(win))
749 } else {
750 self.0.insert_flag(StreamFlags::WAIT_FOR_CAPACITY);
751 self.0.send_cap.register(cx.waker());
752 Poll::Pending
753 }
754 }
755
756 pub fn poll_send_reset(&self, cx: &Context<'_>) -> Poll<Result<(), OperationError>> {
758 if self.0.send.get().is_closed() {
759 Poll::Ready(Ok(()))
760 } else {
761 self.0.check_error()?;
762 self.0.con.check_error()?;
763 self.0.send_reset.register(cx.waker());
764 Poll::Pending
765 }
766 }
767}
768
769impl PartialEq for StreamRef {
770 fn eq(&self, other: &StreamRef) -> bool {
771 Rc::as_ptr(&self.0) == Rc::as_ptr(&other.0)
772 }
773}
774
775impl ops::Deref for Stream {
776 type Target = StreamRef;
777
778 #[inline]
779 fn deref(&self) -> &Self::Target {
780 &self.0
781 }
782}
783
784impl Drop for Stream {
785 fn drop(&mut self) {
786 self.0.reset(Reason::CANCEL);
787 }
788}
789
790impl fmt::Debug for Stream {
791 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
792 let mut builder = f.debug_struct("Stream");
793 builder
794 .field("stream_id", &self.0.0.id)
795 .field("recv_state", &self.0.0.recv.get())
796 .field("send_state", &self.0.0.send.get())
797 .finish()
798 }
799}
800
801impl fmt::Debug for StreamState {
802 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
803 let mut builder = f.debug_struct("StreamState");
804 builder
805 .field("id", &self.id)
806 .field("recv", &self.recv.get())
807 .field("recv_window", &self.recv_window.get())
808 .field("recv_size", &self.recv_size.get())
809 .field("send", &self.send.get())
810 .field("send_window", &self.send_window.get())
811 .field("flags", &self.flags.get())
812 .finish()
813 }
814}
815
/// Parse a `content-length` value, which must be one or more ASCII digits.
///
/// Returns `None` in these cases:
/// * the value is empty (previously accepted as `0`);
/// * it contains any non-digit byte, including signs and whitespace;
/// * it is longer than 19 digits.
///
/// The 19-digit cap guarantees the accumulation cannot overflow `u64`.
pub fn parse_u64(src: &[u8]) -> Option<u64> {
    if src.is_empty() || src.len() > 19 {
        return None;
    }
    src.iter().try_fold(0u64, |acc, &d| {
        d.is_ascii_digit().then(|| acc * 10 + u64::from(d - b'0'))
    })
}