1use std::{cell::Cell, cmp, fmt, future::poll_fn, mem, ops, rc::Rc, task::Context, task::Poll};
2
3use ntex_bytes::Bytes;
4use ntex_http::{HeaderMap, StatusCode, header::CONTENT_LENGTH};
5use ntex_util::task::LocalWaker;
6
7use crate::error::{OperationError, StreamError};
8use crate::frame::{
9 Data, Headers, PseudoHeaders, Reason, Reset, StreamId, WindowSize, WindowUpdate,
10};
11use crate::{connection::Connection, frame, message::Message, window::Window};
12
/// Owning handle to a stream.
///
/// Dereferences to [`StreamRef`]. When dropped it resets the stream with
/// `Reason::CANCEL`, unless both halves of the stream are already closed.
pub struct Stream(StreamRef);
15
/// Receive-side flow-control capacity held for a stream.
///
/// The size is registered with the stream (and its connection) as
/// buffered-but-unconsumed bytes when the handle is created. Releasing bytes
/// via [`Capacity::consume`], or dropping the handle, gives them back to the
/// stream's receive window. That may cause a WINDOW_UPDATE frame to be sent.
#[derive(Debug)]
pub struct Capacity {
    // Bytes this handle still holds.
    size: Cell<u32>,
    // Stream whose receive accounting these bytes belong to.
    stream: Rc<StreamState>,
}
22
23impl Capacity {
24 fn new(size: u32, stream: &Rc<StreamState>) -> Self {
25 stream.add_recv_capacity(size);
26
27 Self {
28 size: Cell::new(size),
29 stream: stream.clone(),
30 }
31 }
32
33 #[inline]
34 pub fn size(&self) -> usize {
36 self.size.get() as usize
37 }
38
39 pub fn consume(&self, sz: u32) {
45 let size = self.size.get();
46 if let Some(sz) = size.checked_sub(sz) {
47 log::trace!(
48 "{}: {:?} capacity consumed from {} to {}",
49 self.stream.tag(),
50 self.stream.id,
51 size,
52 sz
53 );
54 self.size.set(sz);
55 self.stream.consume_capacity(size - sz);
56 } else {
57 panic!("Capacity overflow");
58 }
59 }
60}
61
62impl ops::Add for Capacity {
64 type Output = Self;
65
66 fn add(self, other: Self) -> Self {
67 if Rc::ptr_eq(&self.stream, &other.stream) {
68 let size = Cell::new(self.size.get() + other.size.get());
69 self.size.set(0);
70 other.size.set(0);
71 Self {
72 size,
73 stream: self.stream.clone(),
74 }
75 } else {
76 panic!("Cannot add capacity from different streams");
77 }
78 }
79}
80
81impl ops::AddAssign for Capacity {
83 fn add_assign(&mut self, other: Self) {
84 if Rc::ptr_eq(&self.stream, &other.stream) {
85 let size = self.size.get() + other.size.get();
86 self.size.set(size);
87 other.size.set(0);
88 } else {
89 panic!("Cannot add capacity from different streams");
90 }
91 }
92}
93
94impl Drop for Capacity {
95 fn drop(&mut self) {
96 let size = self.size.get();
97 if size > 0 {
98 self.stream.consume_capacity(size);
99 }
100 }
101}
102
/// How received DATA payload size is checked for a stream.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) enum ContentLength {
    /// Initial state: no `content-length` recorded, so payload size is not checked.
    Omitted,
    /// Received DATA must be empty; `content-length` headers are not parsed.
    Head,
    /// Payload bytes still expected, as announced by `content-length`.
    Remaining(u64),
}
110
/// Shared, cloneable reference to a stream's state.
///
/// All clones point at the same state; equality compares identity, not contents.
#[derive(Clone, Debug)]
pub struct StreamRef(pub(crate) Rc<StreamState>);
113
bitflags::bitflags! {
    /// Boolean per-stream markers stored in `StreamState::flags`.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct StreamFlags: u8 {
        /// Stream was created with `remote == true` in `StreamRef::new`.
        const REMOTE = 0b0000_0001;
        /// Stream was marked failed (see `StreamState::failed`).
        const FAILED = 0b0000_0010;
        /// Set by `StreamRef::disconnect_on_drop`; acted upon outside this file.
        const DISCONNECT_ON_DROP = 0b0000_0100;
        /// A sender is parked in `poll_send_capacity` waiting for send window.
        const WAIT_FOR_CAPACITY = 0b0000_1000;
    }
}
123
/// Mutable state of one HTTP/2 stream, shared through `Rc`.
pub(crate) struct StreamState {
    // Stream identifier.
    id: StreamId,
    // See `StreamFlags`.
    flags: Cell<StreamFlags>,
    // Expected remaining length of received DATA payload.
    content_length: Cell<ContentLength>,
    // State of the receive half.
    recv: Cell<HalfState>,
    // Local receive flow-control window for this stream.
    recv_window: Cell<Window>,
    // Bytes received but not yet released by `Capacity` handles.
    recv_size: Cell<u32>,
    // State of the send half.
    send: Cell<HalfState>,
    // Send flow-control window for this stream.
    send_window: Cell<Window>,
    // Woken when send capacity may have become available or the send side closes.
    send_cap: LocalWaker,
    // Woken whenever the receive side is closed; polled by `poll_send_reset`.
    send_reset: LocalWaker,
    // Connection this stream belongs to (frames are encoded through it).
    pub(crate) con: Connection,
    // Last recorded reset/failure; reported by `check_error`.
    error: Cell<Option<OperationError>>,
}
143
/// State of one direction (send or receive) of a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum HalfState {
    /// No headers have been exchanged in this direction yet.
    Idle,
    /// Headers were exchanged without END_STREAM; payload may follow.
    Payload,
    /// Direction is finished, optionally carrying the reset reason.
    Closed(Option<Reason>),
}
150
151impl HalfState {
152 pub(crate) fn is_closed(self) -> bool {
153 matches!(self, HalfState::Closed(_))
154 }
155}
156
impl StreamState {
    /// Logging tag of the connection this stream belongs to.
    fn tag(&self) -> &'static str {
        self.con.tag()
    }

    /// Moves the send half to `Payload` (headers sent without END_STREAM).
    fn state_send_payload(&self) {
        self.send.set(HalfState::Payload);
    }

    /// Closes the send half, wakes a parked capacity waiter and re-evaluates
    /// whether the stream can be dropped.
    fn state_send_close(&self, reason: Option<Reason>) {
        log::trace!(
            "{}: {:?} send side is closed with reason {:?}",
            self.tag(),
            self.id,
            reason
        );
        self.send.set(HalfState::Closed(reason));
        // A sender waiting in `poll_send_capacity` must re-poll and notice the close.
        self.send_cap.wake();
        self.review_state();
    }

    /// Moves the receive half to `Payload`.
    fn state_recv_payload(&self) {
        self.recv.set(HalfState::Payload);
    }

    /// Closes the receive half and re-evaluates whether the stream can be dropped.
    fn state_recv_close(&self, reason: Option<Reason>) {
        log::trace!("{}: {:?} receive side is closed", self.tag(), self.id);
        self.recv.set(HalfState::Closed(reason));
        self.review_state();
    }

    /// Local reset: closes both halves and, if `reason` is given, records
    /// `OperationError::LocalReset`. The reason is stored on the receive half;
    /// the send half is closed without one.
    fn reset_stream(&self, reason: Option<Reason>) {
        self.recv.set(HalfState::Closed(reason));
        self.send.set(HalfState::Closed(None));
        if let Some(reason) = reason {
            self.error.set(Some(OperationError::LocalReset(reason)));
        }
        self.review_state();
    }

    /// Remote reset (used for received RST_STREAM and for GOAWAY): closes both
    /// halves, stores `reason` on the send half and records
    /// `OperationError::RemoteReset`.
    fn remote_reset_stream(&self, reason: Reason) {
        self.recv.set(HalfState::Closed(None));
        self.send.set(HalfState::Closed(Some(reason)));
        self.error.set(Some(OperationError::RemoteReset(reason)));
        self.review_state();
    }

    /// Marks the stream as failed: closes any still-open half without a
    /// reason, stores `err` and sets the `FAILED` flag.
    fn failed(&self, err: OperationError) {
        // Halves that are already closed keep their existing reason.
        if !self.recv.get().is_closed() {
            self.recv.set(HalfState::Closed(None));
        }
        if !self.send.get().is_closed() {
            self.send.set(HalfState::Closed(None));
        }
        self.error.set(Some(err));
        self.insert_flag(StreamFlags::FAILED);
        self.review_state();
    }

    /// Sets `f` in the stream flags.
    fn insert_flag(&self, f: StreamFlags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    /// Clears `f` from the stream flags.
    fn remove_flag(&self, f: StreamFlags) {
        let mut flags = self.flags.get();
        flags.remove(f);
        self.flags.set(flags);
    }

    /// Returns the recorded error (as a clone), leaving it stored so that
    /// later calls report it again.
    fn check_error(&self) -> Result<(), OperationError> {
        // A `Cell<Option<T>>` with non-`Copy` `T` cannot be read in place:
        // take the value out and put a clone back.
        if let Some(err) = self.error.take() {
            self.error.set(Some(err.clone()));
            Err(err)
        } else {
            Ok(())
        }
    }

    /// Reacts to a half-state change.
    ///
    /// Once the receive half is closed, `send_reset` waiters are woken. If the
    /// send half is closed as well, capacity waiters are woken and the stream
    /// is removed from the connection via `drop_stream`.
    fn review_state(&self) {
        if self.recv.get().is_closed() {
            self.send_reset.wake();

            // NOTE(review): this branch can run more than once for the same
            // stream (e.g. `failed` on an already-closed stream), so
            // `drop_stream` presumably tolerates repeated ids — confirm.
            if let HalfState::Closed(reason) = self.send.get() {
                // Only `remote_reset_stream` stores a reason on the send half
                // among the callers visible here.
                if let Some(reason) = reason {
                    log::trace!(
                        "{}: {:?} is closed with remote reset {:?}, dropping stream",
                        self.tag(),
                        self.id,
                        reason
                    );
                } else {
                    log::trace!(
                        "{}: {:?} both sides are closed, dropping stream",
                        self.tag(),
                        self.id
                    );
                }
                self.send_cap.wake();
                self.con.drop_stream(self.id);
            }
        }
    }

    /// Accounts `size` newly received bytes. They are added to `recv_size`
    /// (buffered, not yet released), taken out of the stream receive window,
    /// and reported to the connection.
    fn add_recv_capacity(&self, size: u32) {
        let cap = self.recv_size.get();
        self.recv_size.set(cap + size);
        self.recv_window.set(self.recv_window.get().dec(size));
        log::trace!(
            "{}: {:?} capacity incresed from {} to {}",
            self.tag(),
            self.id,
            cap,
            cap + size
        );

        self.con.add_recv_capacity(size);
    }

    /// Releases `size` buffered bytes. If `Window::update` yields an
    /// increment, the updated window is stored and a WINDOW_UPDATE frame for
    /// this stream is encoded.
    ///
    /// Callers must not release more than is buffered, otherwise
    /// `cap - size` underflows.
    fn consume_capacity(&self, size: u32) {
        let cap = self.recv_size.get();
        // From here on `size` is the amount that stays buffered, not the amount released.
        let size = cap - size;
        log::trace!(
            "{}: {:?} capacity decresed from {} to {}",
            self.tag(),
            self.id,
            cap,
            size
        );

        self.recv_size.set(size);
        let mut window = self.recv_window.get();
        // NOTE(review): the window is written back only when `update` returns
        // `Some`; presumably `update` leaves it unchanged otherwise — confirm.
        if let Some(val) = window.update(
            size,
            self.con.config().window_sz,
            self.con.config().window_sz_threshold,
        ) {
            log::trace!(
                "{}: {:?} capacity decresed below threshold {} increase by {} ({})",
                self.tag(),
                self.id,
                self.con.config().window_sz_threshold,
                val,
                self.con.config().window_sz,
            );
            self.recv_window.set(window);
            self.con.encode(WindowUpdate::new(self.id, val));
        }
    }
}
312
313impl StreamRef {
314 pub(crate) fn new(id: StreamId, remote: bool, con: Connection) -> Self {
315 let recv_window = if con.settings_processed() {
318 Window::new(con.config().window_sz)
319 } else {
320 Window::new(frame::DEFAULT_INITIAL_WINDOW_SIZE)
321 };
322 let send_window = Window::new(con.remote_window_size());
323
324 StreamRef(Rc::new(StreamState {
325 id,
326 con,
327 recv: Cell::new(HalfState::Idle),
328 recv_window: Cell::new(recv_window),
329 recv_size: Cell::new(0),
330 send: Cell::new(HalfState::Idle),
331 send_window: Cell::new(send_window),
332 send_cap: LocalWaker::new(),
333 send_reset: LocalWaker::new(),
334 error: Cell::new(None),
335 content_length: Cell::new(ContentLength::Omitted),
336 flags: Cell::new(if remote {
337 StreamFlags::REMOTE
338 } else {
339 StreamFlags::empty()
340 }),
341 }))
342 }
343
344 #[inline]
345 pub fn id(&self) -> StreamId {
346 self.0.id
347 }
348
349 #[inline]
350 pub fn tag(&self) -> &'static str {
351 self.0.con.tag()
352 }
353
354 #[inline]
356 pub fn is_remote(&self) -> bool {
357 self.0.flags.get().contains(StreamFlags::REMOTE)
358 }
359
360 #[inline]
362 pub fn is_failed(&self) -> bool {
363 self.0.flags.get().contains(StreamFlags::FAILED)
364 }
365
366 pub(crate) fn send_state(&self) -> HalfState {
367 self.0.send.get()
368 }
369
370 pub(crate) fn recv_state(&self) -> HalfState {
371 self.0.recv.get()
372 }
373
374 pub(crate) fn disconnect_on_drop(&self) {
375 self.0.insert_flag(StreamFlags::DISCONNECT_ON_DROP);
376 }
377
378 pub(crate) fn is_disconnect_on_drop(&self) -> bool {
379 self.0.flags.get().contains(StreamFlags::DISCONNECT_ON_DROP)
380 }
381
382 #[inline]
387 pub fn reset(&self, reason: Reason) -> bool {
388 if !self.0.recv.get().is_closed() || !self.0.send.get().is_closed() {
389 self.0.con.encode(Reset::new(self.0.id, reason));
390 self.0.reset_stream(Some(reason));
391 true
392 } else {
393 false
394 }
395 }
396
397 #[inline]
399 pub fn empty_capacity(&self) -> Capacity {
400 Capacity {
401 size: Cell::new(0),
402 stream: self.0.clone(),
403 }
404 }
405
406 #[inline]
407 pub(crate) fn into_stream(self) -> Stream {
408 Stream(self)
409 }
410
411 pub(crate) fn send_headers(&self, mut hdrs: Headers) {
412 hdrs.set_end_headers();
413 if hdrs.is_end_stream() {
414 self.0.state_send_close(None);
415 } else {
416 self.0.state_send_payload();
417 }
418 log::trace!(
419 "{}: send headers {:#?} eos: {:?}",
420 self.tag(),
421 hdrs,
422 hdrs.is_end_stream()
423 );
424
425 if hdrs
426 .pseudo()
427 .status
428 .is_some_and(|status| status.is_informational())
429 {
430 self.0.content_length.set(ContentLength::Head);
431 }
432 self.0.con.encode(hdrs);
433 }
434
435 pub(crate) fn set_go_away(&self, reason: Reason) {
436 self.0.remote_reset_stream(reason);
437 }
438
439 pub(crate) fn set_failed_stream(&self, err: OperationError) {
440 self.0.failed(err);
441 }
442
443 pub(crate) fn recv_headers(&self, hdrs: Headers) -> Result<Option<Message>, StreamError> {
444 log::trace!(
445 "{}: processing HEADERS for {:?}:\n{:#?}\nrecv_state:{:?}, send_state: {:?}",
446 self.tag(),
447 self.0.id,
448 hdrs,
449 self.0.recv.get(),
450 self.0.send.get(),
451 );
452
453 match self.0.recv.get() {
454 HalfState::Idle => {
455 let eof = hdrs.is_end_stream();
456 if eof {
457 self.0.state_recv_close(None);
458 } else {
459 self.0.state_recv_payload();
460 }
461 let (pseudo, headers) = hdrs.into_parts();
462
463 if self.0.content_length.get() != ContentLength::Head
464 && let Some(content_length) = headers.get(CONTENT_LENGTH)
465 {
466 if let Some(v) = parse_u64(content_length.as_bytes()) {
467 self.0.content_length.set(ContentLength::Remaining(v));
468 } else {
469 proto_err!(stream: "could not parse content-length; stream={:?}", self.0.id);
470 return Err(StreamError::InvalidContentLength);
471 }
472 }
473 Ok(Some(Message::new(pseudo, headers, eof, self)))
474 }
475 HalfState::Payload => {
476 if hdrs.is_end_stream() {
478 self.0.state_recv_close(None);
479 Ok(Some(Message::trailers(hdrs.into_fields(), self)))
480 } else {
481 Err(StreamError::TrailersWithoutEos)
482 }
483 }
484 HalfState::Closed(_) => Err(StreamError::Closed),
485 }
486 }
487
488 pub(crate) fn recv_data(&self, data: Data) -> Result<Option<Message>, StreamError> {
489 let cap = Capacity::new(data.payload().len() as u32, &self.0);
490 log::trace!(
491 "{}: processing DATA frame for {:?}, len: {:?}",
492 self.tag(),
493 self.0.id,
494 data.payload().len()
495 );
496
497 match self.0.recv.get() {
498 HalfState::Payload => {
499 let eof = data.is_end_stream();
500
501 match self.0.content_length.get() {
503 ContentLength::Remaining(rem) => {
504 match rem.checked_sub(data.payload().len() as u64) {
505 Some(val) => {
506 self.0.content_length.set(ContentLength::Remaining(val));
507 if eof && val != 0 {
508 return Err(StreamError::WrongPayloadLength);
509 }
510 }
511 None => return Err(StreamError::WrongPayloadLength),
512 }
513 }
514 ContentLength::Head => {
515 if !data.payload().is_empty() {
516 return Err(StreamError::NonEmptyPayload);
517 }
518 }
519 ContentLength::Omitted => (),
520 }
521
522 if eof {
523 self.0.state_recv_close(None);
524 Ok(Some(Message::eof_data(data.into_payload(), self)))
525 } else {
526 Ok(Some(Message::data(data.into_payload(), cap, self)))
527 }
528 }
529 HalfState::Idle => Err(StreamError::Idle("DATA framed received")),
530 HalfState::Closed(_) => Err(StreamError::Closed),
531 }
532 }
533
534 pub(crate) fn recv_rst_stream(&self, frm: Reset) {
535 self.0.remote_reset_stream(frm.reason());
536 }
537
538 pub(crate) fn recv_window_update_connection(&self) {
539 if self.0.flags.get().contains(StreamFlags::WAIT_FOR_CAPACITY)
540 && self.0.send_window.get().window_size() > 0
541 {
542 self.0.send_cap.wake();
543 }
544 }
545
546 pub(crate) fn recv_window_update(&self, frm: WindowUpdate) -> Result<(), StreamError> {
547 if frm.size_increment() == 0 {
548 Err(StreamError::WindowZeroUpdateValue)
549 } else {
550 let window = self
551 .0
552 .send_window
553 .get()
554 .inc(frm.size_increment())
555 .map_err(|()| StreamError::WindowOverflowed)?;
556 self.0.send_window.set(window);
557
558 if window.window_size() > 0 {
559 self.0.send_cap.wake();
560 }
561 Ok(())
562 }
563 }
564
565 pub(crate) fn update_send_window(&self, upd: i32) -> Result<(), StreamError> {
566 let orig = self.0.send_window.get();
567 let window = match upd.cmp(&0) {
568 cmp::Ordering::Less => orig.dec(upd.unsigned_abs()), cmp::Ordering::Greater => orig.inc(upd).map_err(|()| StreamError::WindowOverflowed)?,
570 cmp::Ordering::Equal => return Ok(()),
571 };
572 log::trace!(
573 "{}: Updating send window size from {} to {}",
574 self.tag(),
575 orig.window_size,
576 window.window_size
577 );
578 self.0.send_window.set(window);
579 Ok(())
580 }
581
582 pub(crate) fn update_recv_window(&self, upd: i32) -> Result<Option<WindowSize>, StreamError> {
583 let mut window = match upd.cmp(&0) {
584 cmp::Ordering::Less => self.0.recv_window.get().dec(upd.unsigned_abs()), cmp::Ordering::Greater => self
586 .0
587 .recv_window
588 .get()
589 .inc(upd)
590 .map_err(|()| StreamError::WindowOverflowed)?,
591 cmp::Ordering::Equal => return Ok(None),
592 };
593 if let Some(val) = window.update(
594 self.0.recv_size.get(),
595 self.0.con.config().window_sz,
596 self.0.con.config().window_sz_threshold,
597 ) {
598 self.0.recv_window.set(window);
599 Ok(Some(val))
600 } else {
601 self.0.recv_window.set(window);
602 Ok(None)
603 }
604 }
605
606 pub fn send_response(
608 &self,
609 status: StatusCode,
610 headers: HeaderMap,
611 eof: bool,
612 ) -> Result<(), OperationError> {
613 self.0.check_error()?;
614
615 match self.0.send.get() {
616 HalfState::Idle => {
617 let pseudo = PseudoHeaders::response(status);
618 let mut hdrs = Headers::new(self.0.id, pseudo, headers, eof);
619
620 if eof {
621 hdrs.set_end_stream();
622 self.0.state_send_close(None);
623 } else {
624 self.0.state_send_payload();
625 }
626 self.0.con.encode(hdrs);
627 Ok(())
628 }
629 HalfState::Payload => Err(OperationError::Payload),
630 HalfState::Closed(r) => Err(OperationError::Closed(r)),
631 }
632 }
633
634 pub async fn send_payload(&self, mut res: Bytes, eof: bool) -> Result<(), OperationError> {
636 match self.0.send.get() {
637 HalfState::Payload => {
638 self.0.check_error()?;
640
641 log::trace!(
642 "{}: {:?} sending {} bytes, eof: {}, send: {:?}",
643 self.0.tag(),
644 self.0.id,
645 res.len(),
646 eof,
647 self.0.send.get()
648 );
649
650 if eof && res.is_empty() {
652 let mut data = Data::new(self.0.id, Bytes::new());
653 data.set_end_stream();
654 self.0.state_send_close(None);
655
656 self.0.con.encode(data);
658 return Ok(());
659 }
660
661 loop {
662 let win = self.available_send_capacity() as usize;
664 if win > 0 {
665 let size =
666 cmp::min(win, cmp::min(res.len(), self.0.con.remote_frame_size()));
667 let mut data = if size >= res.len() {
668 Data::new(self.0.id, mem::replace(&mut res, Bytes::new()))
669 } else {
670 log::trace!(
671 "{}: {:?} sending {} out of {} bytes",
672 self.0.tag(),
673 self.0.id,
674 size,
675 res.len()
676 );
677 Data::new(self.0.id, res.split_to(size))
678 };
679 if eof && res.is_empty() {
680 data.set_end_stream();
681 self.0.state_send_close(None);
682 }
683
684 self.0
686 .send_window
687 .set(self.0.send_window.get().dec(size as u32));
688
689 self.0.con.consume_send_window(size as u32);
691
692 self.0.con.encode(data);
694 if res.is_empty() {
695 return Ok(());
696 }
697 } else {
698 log::trace!(
699 "{}: Not enough sending capacity for {:?} remaining {:?}",
700 self.0.tag(),
701 self.0.id,
702 res.len()
703 );
704 self.send_capacity().await?;
706 }
707 }
708 }
709 HalfState::Idle => Err(OperationError::Idle),
710 HalfState::Closed(reason) => Err(OperationError::Closed(reason)),
711 }
712 }
713
714 pub fn send_trailers(&self, map: HeaderMap) {
716 if self.0.send.get() == HalfState::Payload {
717 let mut hdrs = Headers::trailers(self.0.id, map);
718 hdrs.set_end_headers();
719 hdrs.set_end_stream();
720 self.0.con.encode(hdrs);
721 self.0.state_send_close(None);
722 }
723 }
724
725 pub fn available_send_capacity(&self) -> WindowSize {
726 cmp::min(
727 self.0.send_window.get().window_size(),
728 self.0.con.send_window_size(),
729 )
730 }
731
732 pub async fn send_capacity(&self) -> Result<WindowSize, OperationError> {
733 poll_fn(|cx| self.poll_send_capacity(cx)).await
734 }
735
736 pub fn poll_send_capacity(&self, cx: &Context<'_>) -> Poll<Result<WindowSize, OperationError>> {
738 self.0.check_error()?;
739 self.0.con.check_error()?;
740
741 let win = self.available_send_capacity();
742 if win > 0 {
743 self.0.remove_flag(StreamFlags::WAIT_FOR_CAPACITY);
744 Poll::Ready(Ok(win))
745 } else {
746 self.0.insert_flag(StreamFlags::WAIT_FOR_CAPACITY);
747 self.0.send_cap.register(cx.waker());
748 Poll::Pending
749 }
750 }
751
752 pub fn poll_send_reset(&self, cx: &Context<'_>) -> Poll<Result<(), OperationError>> {
754 if self.0.send.get().is_closed() {
755 Poll::Ready(Ok(()))
756 } else {
757 self.0.check_error()?;
758 self.0.con.check_error()?;
759 self.0.send_reset.register(cx.waker());
760 Poll::Pending
761 }
762 }
763}
764
765impl PartialEq for StreamRef {
766 fn eq(&self, other: &StreamRef) -> bool {
767 Rc::as_ptr(&self.0) == Rc::as_ptr(&other.0)
768 }
769}
770
771impl ops::Deref for Stream {
772 type Target = StreamRef;
773
774 #[inline]
775 fn deref(&self) -> &Self::Target {
776 &self.0
777 }
778}
779
780impl Drop for Stream {
781 fn drop(&mut self) {
782 self.0.reset(Reason::CANCEL);
783 }
784}
785
786impl fmt::Debug for Stream {
787 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
788 let mut builder = f.debug_struct("Stream");
789 builder
790 .field("stream_id", &self.0.0.id)
791 .field("recv_state", &self.0.0.recv.get())
792 .field("send_state", &self.0.0.send.get())
793 .finish()
794 }
795}
796
797impl fmt::Debug for StreamState {
798 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
799 let mut builder = f.debug_struct("StreamState");
800 builder
801 .field("id", &self.id)
802 .field("recv", &self.recv.get())
803 .field("recv_window", &self.recv_window.get())
804 .field("recv_size", &self.recv_size.get())
805 .field("send", &self.send.get())
806 .field("send_window", &self.send_window.get())
807 .field("flags", &self.flags.get())
808 .finish()
809 }
810}
811
812pub(super) fn parse_u64(src: &[u8]) -> Option<u64> {
813 if src.len() > 19 {
814 None
816 } else {
817 let mut ret = 0;
818 for &d in src {
819 if !d.is_ascii_digit() {
820 return None;
821 }
822
823 ret *= 10;
824 ret += u64::from(d - b'0');
825 }
826
827 Some(ret)
828 }
829}