1use std::{cell::Cell, cmp, fmt, future::poll_fn, mem, ops, rc::Rc, task::Context, task::Poll};
2
3use ntex_bytes::Bytes;
4use ntex_http::{header::CONTENT_LENGTH, HeaderMap, StatusCode};
5use ntex_util::task::LocalWaker;
6
7use crate::error::{OperationError, StreamError};
8use crate::frame::{
9 Data, Headers, PseudoHeaders, Reason, Reset, StreamId, WindowSize, WindowUpdate,
10};
11use crate::{connection::Connection, frame, message::Message, window::Window};
12
/// An owned HTTP/2 stream handle.
///
/// Dereferences to [`StreamRef`]. Dropping a `Stream` resets the stream with
/// `Reason::CANCEL` unless both of its halves are already closed.
pub struct Stream(StreamRef);
15
/// Receive-side flow-control credit held for a stream.
///
/// Created for every received DATA frame (see `StreamRef::recv_data`). The held
/// bytes stay charged against the stream's receive window until they are
/// released with [`Capacity::consume`] or by dropping the value; releasing may
/// cause a WINDOW_UPDATE frame to be sent.
#[derive(Debug)]
pub struct Capacity {
    /// Number of bytes still held by this value.
    size: Cell<u32>,
    /// Stream the credit belongs to.
    stream: Rc<StreamState>,
}
22
23impl Capacity {
24 fn new(size: u32, stream: &Rc<StreamState>) -> Self {
25 stream.add_recv_capacity(size);
26
27 Self {
28 size: Cell::new(size),
29 stream: stream.clone(),
30 }
31 }
32
33 #[inline]
34 pub fn size(&self) -> usize {
36 self.size.get() as usize
37 }
38
39 pub fn consume(&self, sz: u32) {
43 let size = self.size.get();
44 if let Some(sz) = size.checked_sub(sz) {
45 log::trace!(
46 "{}: {:?} capacity consumed from {} to {}",
47 self.stream.tag(),
48 self.stream.id,
49 size,
50 sz
51 );
52 self.size.set(sz);
53 self.stream.consume_capacity(size - sz);
54 } else {
55 panic!("Capacity overflow");
56 }
57 }
58}
59
60impl ops::Add for Capacity {
62 type Output = Self;
63
64 fn add(self, other: Self) -> Self {
65 if Rc::ptr_eq(&self.stream, &other.stream) {
66 let size = Cell::new(self.size.get() + other.size.get());
67 self.size.set(0);
68 other.size.set(0);
69 Self {
70 size,
71 stream: self.stream.clone(),
72 }
73 } else {
74 panic!("Cannot add capacity from different streams");
75 }
76 }
77}
78
79impl ops::AddAssign for Capacity {
81 fn add_assign(&mut self, other: Self) {
82 if Rc::ptr_eq(&self.stream, &other.stream) {
83 let size = self.size.get() + other.size.get();
84 self.size.set(size);
85 other.size.set(0);
86 } else {
87 panic!("Cannot add capacity from different streams");
88 }
89 }
90}
91
92impl Drop for Capacity {
93 fn drop(&mut self) {
94 let size = self.size.get();
95 if size > 0 {
96 self.stream.consume_capacity(size);
97 }
98 }
99}
100
/// Expected length of the body received on a stream.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ContentLength {
    /// No `content-length` was received; DATA lengths are not checked.
    Omitted,
    /// A received `content-length` header is not parsed and every received
    /// DATA frame must have an empty payload.
    Head,
    /// Bytes still expected according to the received `content-length`.
    Remaining(u64),
}
108
/// Cheaply cloneable reference to a stream's shared state.
#[derive(Clone, Debug)]
pub struct StreamRef(pub(crate) Rc<StreamState>);
111
bitflags::bitflags! {
    /// Boolean per-stream state.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct StreamFlags: u8 {
        /// The stream was created with `remote = true` (see `StreamRef::new`).
        const REMOTE = 0b0000_0001;
        /// The stream failed with an error (set by `StreamState::failed`).
        const FAILED = 0b0000_0010;
        /// Set by `StreamRef::disconnect_on_drop`.
        const DISCONNECT_ON_DROP = 0b0000_0100;
        /// A task is parked in `poll_send_capacity` waiting for send window.
        const WAIT_FOR_CAPACITY = 0b0000_1000;
    }
}
121
/// Mutable state of one stream, shared by its `StreamRef` and `Capacity` handles.
pub(crate) struct StreamState {
    /// Stream identifier.
    id: StreamId,
    /// See `StreamFlags`.
    flags: Cell<StreamFlags>,
    /// Expected length of the remaining received body.
    content_length: Cell<ContentLength>,
    /// State of the receive half.
    recv: Cell<HalfState>,
    /// Receive flow-control window of this stream.
    recv_window: Cell<Window>,
    /// Received bytes currently held by `Capacity` values (not yet released).
    recv_size: Cell<u32>,
    /// State of the send half.
    send: Cell<HalfState>,
    /// Send flow-control window of this stream.
    send_window: Cell<Window>,
    /// Woken when the send window grows or the send half / stream closes.
    send_cap: LocalWaker,
    /// Woken by `review_state` once the receive half is closed.
    send_reset: LocalWaker,
    /// Connection the stream belongs to.
    pub(crate) con: Connection,
    /// Sticky error reported by `check_error`.
    error: Cell<Option<OperationError>>,
}
141
/// State of one direction (send or receive) of a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum HalfState {
    /// No HEADERS have been processed in this direction yet.
    Idle,
    /// HEADERS without END_STREAM were processed; DATA may follow.
    Payload,
    /// This direction is finished, optionally with a reset reason.
    Closed(Option<Reason>),
}
148
149impl HalfState {
150 pub(crate) fn is_closed(&self) -> bool {
151 matches!(self, HalfState::Closed(_))
152 }
153}
154
155impl StreamState {
156 fn tag(&self) -> &'static str {
157 self.con.tag()
158 }
159
160 fn state_send_payload(&self) {
161 self.send.set(HalfState::Payload);
162 }
163
164 fn state_send_close(&self, reason: Option<Reason>) {
165 log::trace!(
166 "{}: {:?} send side is closed with reason {:?}",
167 self.tag(),
168 self.id,
169 reason
170 );
171 self.send.set(HalfState::Closed(reason));
172 self.send_cap.wake();
173 self.review_state();
174 }
175
176 fn state_recv_payload(&self) {
177 self.recv.set(HalfState::Payload);
178 }
179
180 fn state_recv_close(&self, reason: Option<Reason>) {
181 log::trace!("{}: {:?} receive side is closed", self.tag(), self.id);
182 self.recv.set(HalfState::Closed(reason));
183 self.review_state();
184 }
185
186 fn reset_stream(&self, reason: Option<Reason>) {
187 self.recv.set(HalfState::Closed(reason));
188 self.send.set(HalfState::Closed(None));
189 if let Some(reason) = reason {
190 self.error.set(Some(OperationError::LocalReset(reason)));
191 }
192 self.review_state();
193 }
194
195 fn remote_reset_stream(&self, reason: Reason) {
196 self.recv.set(HalfState::Closed(None));
197 self.send.set(HalfState::Closed(Some(reason)));
198 self.error.set(Some(OperationError::RemoteReset(reason)));
199 self.review_state();
200 }
201
202 fn failed(&self, err: OperationError) {
203 if !self.recv.get().is_closed() {
204 self.recv.set(HalfState::Closed(None));
205 }
206 if !self.send.get().is_closed() {
207 self.send.set(HalfState::Closed(None));
208 }
209 self.error.set(Some(err));
210 self.insert_flag(StreamFlags::FAILED);
211 self.review_state();
212 }
213
214 fn insert_flag(&self, f: StreamFlags) {
215 let mut flags = self.flags.get();
216 flags.insert(f);
217 self.flags.set(flags);
218 }
219
220 fn remove_flag(&self, f: StreamFlags) {
221 let mut flags = self.flags.get();
222 flags.remove(f);
223 self.flags.set(flags);
224 }
225
226 fn check_error(&self) -> Result<(), OperationError> {
227 if let Some(err) = self.error.take() {
228 self.error.set(Some(err.clone()));
229 Err(err)
230 } else {
231 Ok(())
232 }
233 }
234
235 fn review_state(&self) {
236 if self.recv.get().is_closed() {
237 self.send_reset.wake();
238
239 if let HalfState::Closed(reason) = self.send.get() {
240 if let Some(reason) = reason {
242 log::trace!(
243 "{}: {:?} is closed with remote reset {:?}, dropping stream",
244 self.tag(),
245 self.id,
246 reason
247 );
248 } else {
249 log::trace!(
250 "{}: {:?} both sides are closed, dropping stream",
251 self.tag(),
252 self.id
253 );
254 }
255 self.send_cap.wake();
256 self.con.drop_stream(self.id);
257 }
258 }
259 }
260
261 fn add_recv_capacity(&self, size: u32) {
263 let cap = self.recv_size.get();
264 self.recv_size.set(cap + size);
265 self.recv_window.set(self.recv_window.get().dec(size));
266 log::trace!(
267 "{}: {:?} capacity incresed from {} to {}",
268 self.tag(),
269 self.id,
270 cap,
271 cap + size
272 );
273
274 self.con.add_recv_capacity(size);
276 }
277
278 fn consume_capacity(&self, size: u32) {
280 let cap = self.recv_size.get();
281 let size = cap - size;
282 log::trace!(
283 "{}: {:?} capacity decresed from {} to {}",
284 self.tag(),
285 self.id,
286 cap,
287 size
288 );
289
290 self.recv_size.set(size);
291 let mut window = self.recv_window.get();
292 if let Some(val) = window.update(
293 size,
294 self.con.config().window_sz.get(),
295 self.con.config().window_sz_threshold.get(),
296 ) {
297 log::trace!(
298 "{}: {:?} capacity decresed below threshold {} increase by {} ({})",
299 self.tag(),
300 self.id,
301 self.con.config().window_sz_threshold.get(),
302 val,
303 self.con.config().window_sz.get(),
304 );
305 self.recv_window.set(window);
306 self.con.encode(WindowUpdate::new(self.id, val));
307 }
308 }
309}
310
311impl StreamRef {
312 pub(crate) fn new(id: StreamId, remote: bool, con: Connection) -> Self {
313 let recv_window = if con.settings_processed() {
316 Window::new(con.config().window_sz.get() as i32)
317 } else {
318 Window::new(frame::DEFAULT_INITIAL_WINDOW_SIZE as i32)
319 };
320 let send_window = Window::new(con.remote_window_size() as i32);
321
322 StreamRef(Rc::new(StreamState {
323 id,
324 con,
325 recv: Cell::new(HalfState::Idle),
326 recv_window: Cell::new(recv_window),
327 recv_size: Cell::new(0),
328 send: Cell::new(HalfState::Idle),
329 send_window: Cell::new(send_window),
330 send_cap: LocalWaker::new(),
331 send_reset: LocalWaker::new(),
332 error: Cell::new(None),
333 content_length: Cell::new(ContentLength::Omitted),
334 flags: Cell::new(if remote {
335 StreamFlags::REMOTE
336 } else {
337 StreamFlags::empty()
338 }),
339 }))
340 }
341
342 #[inline]
343 pub fn id(&self) -> StreamId {
344 self.0.id
345 }
346
347 #[inline]
348 pub fn tag(&self) -> &'static str {
349 self.0.con.tag()
350 }
351
352 #[inline]
354 pub fn is_remote(&self) -> bool {
355 self.0.flags.get().contains(StreamFlags::REMOTE)
356 }
357
358 #[inline]
360 pub fn is_failed(&self) -> bool {
361 self.0.flags.get().contains(StreamFlags::FAILED)
362 }
363
364 pub(crate) fn send_state(&self) -> HalfState {
365 self.0.send.get()
366 }
367
368 pub(crate) fn recv_state(&self) -> HalfState {
369 self.0.recv.get()
370 }
371
372 pub(crate) fn disconnect_on_drop(&self) {
373 self.0.insert_flag(StreamFlags::DISCONNECT_ON_DROP);
374 }
375
376 pub(crate) fn is_disconnect_on_drop(&self) -> bool {
377 self.0.flags.get().contains(StreamFlags::DISCONNECT_ON_DROP)
378 }
379
380 #[inline]
382 pub fn reset(&self, reason: Reason) {
383 if !self.0.recv.get().is_closed() || !self.0.send.get().is_closed() {
384 self.0.con.encode(Reset::new(self.0.id, reason));
385 self.0.reset_stream(Some(reason));
386 }
387 }
388
389 #[inline]
391 pub fn empty_capacity(&self) -> Capacity {
392 Capacity {
393 size: Cell::new(0),
394 stream: self.0.clone(),
395 }
396 }
397
398 #[inline]
399 pub(crate) fn into_stream(self) -> Stream {
400 Stream(self)
401 }
402
403 pub(crate) fn send_headers(&self, mut hdrs: Headers) {
404 hdrs.set_end_headers();
405 if hdrs.is_end_stream() {
406 self.0.state_send_close(None);
407 } else {
408 self.0.state_send_payload();
409 }
410 log::trace!(
411 "{}: send headers {:#?} eos: {:?}",
412 self.tag(),
413 hdrs,
414 hdrs.is_end_stream()
415 );
416
417 if hdrs
418 .pseudo()
419 .status
420 .is_some_and(|status| status.is_informational())
421 {
422 self.0.content_length.set(ContentLength::Head)
423 }
424 self.0.con.encode(hdrs);
425 }
426
427 pub(crate) fn set_failed(&self, reason: Option<Reason>) {
428 self.0.reset_stream(reason);
429 }
430
431 pub(crate) fn set_go_away(&self, reason: Reason) {
432 self.0.remote_reset_stream(reason)
433 }
434
435 pub(crate) fn set_failed_stream(&self, err: OperationError) {
436 self.0.failed(err);
437 }
438
439 pub(crate) fn recv_headers(&self, hdrs: Headers) -> Result<Option<Message>, StreamError> {
440 log::trace!(
441 "{}: processing HEADERS for {:?}:\n{:#?}\nrecv_state:{:?}, send_state: {:?}",
442 self.tag(),
443 self.0.id,
444 hdrs,
445 self.0.recv.get(),
446 self.0.send.get(),
447 );
448
449 match self.0.recv.get() {
450 HalfState::Idle => {
451 let eof = hdrs.is_end_stream();
452 if eof {
453 self.0.state_recv_close(None);
454 } else {
455 self.0.state_recv_payload();
456 }
457 let (pseudo, headers) = hdrs.into_parts();
458
459 if self.0.content_length.get() != ContentLength::Head {
460 if let Some(content_length) = headers.get(CONTENT_LENGTH) {
461 if let Some(v) = parse_u64(content_length.as_bytes()) {
462 self.0.content_length.set(ContentLength::Remaining(v));
463 } else {
464 proto_err!(stream: "could not parse content-length; stream={:?}", self.0.id);
465 return Err(StreamError::InvalidContentLength);
466 }
467 }
468 }
469 Ok(Some(Message::new(pseudo, headers, eof, self)))
470 }
471 HalfState::Payload => {
472 if !hdrs.is_end_stream() {
474 Err(StreamError::TrailersWithoutEos)
475 } else {
476 self.0.state_recv_close(None);
477 Ok(Some(Message::trailers(hdrs.into_fields(), self)))
478 }
479 }
480 HalfState::Closed(_) => Err(StreamError::Closed),
481 }
482 }
483
484 pub(crate) fn recv_data(&self, data: Data) -> Result<Option<Message>, StreamError> {
485 let cap = Capacity::new(data.payload().len() as u32, &self.0);
486 log::trace!(
487 "{}: processing DATA frame for {:?}, len: {:?}",
488 self.tag(),
489 self.0.id,
490 data.payload().len()
491 );
492
493 match self.0.recv.get() {
494 HalfState::Payload => {
495 let eof = data.is_end_stream();
496
497 match self.0.content_length.get() {
499 ContentLength::Remaining(rem) => {
500 match rem.checked_sub(data.payload().len() as u64) {
501 Some(val) => {
502 self.0.content_length.set(ContentLength::Remaining(val));
503 if eof && val != 0 {
504 return Err(StreamError::WrongPayloadLength);
505 }
506 }
507 None => return Err(StreamError::WrongPayloadLength),
508 }
509 }
510 ContentLength::Head => {
511 if !data.payload().is_empty() {
512 return Err(StreamError::NonEmptyPayload);
513 }
514 }
515 _ => (),
516 }
517
518 if eof {
519 self.0.state_recv_close(None);
520 Ok(Some(Message::eof_data(data.into_payload(), self)))
521 } else {
522 Ok(Some(Message::data(data.into_payload(), cap, self)))
523 }
524 }
525 HalfState::Idle => Err(StreamError::Idle("DATA framed received")),
526 HalfState::Closed(_) => Err(StreamError::Closed),
527 }
528 }
529
530 pub(crate) fn recv_rst_stream(&self, frm: &Reset) {
531 self.0.remote_reset_stream(frm.reason())
532 }
533
534 pub(crate) fn recv_window_update_connection(&self) {
535 if self.0.flags.get().contains(StreamFlags::WAIT_FOR_CAPACITY)
536 && self.0.send_window.get().window_size() > 0
537 {
538 self.0.send_cap.wake();
539 }
540 }
541
542 pub(crate) fn recv_window_update(&self, frm: WindowUpdate) -> Result<(), StreamError> {
543 if frm.size_increment() == 0 {
544 Err(StreamError::WindowZeroUpdateValue)
545 } else {
546 let window = self
547 .0
548 .send_window
549 .get()
550 .inc(frm.size_increment())
551 .map_err(|_| StreamError::WindowOverflowed)?;
552 self.0.send_window.set(window);
553
554 if window.window_size() > 0 {
555 self.0.send_cap.wake();
556 }
557 Ok(())
558 }
559 }
560
561 pub(crate) fn update_send_window(&self, upd: i32) -> Result<(), StreamError> {
562 let orig = self.0.send_window.get();
563 let window = match upd.cmp(&0) {
564 cmp::Ordering::Less => orig.dec(upd.unsigned_abs()), cmp::Ordering::Greater => orig
566 .inc(upd as u32)
567 .map_err(|_| StreamError::WindowOverflowed)?,
568 cmp::Ordering::Equal => return Ok(()),
569 };
570 log::trace!(
571 "{}: Updating send window size from {} to {}",
572 self.tag(),
573 orig.window_size,
574 window.window_size
575 );
576 self.0.send_window.set(window);
577 Ok(())
578 }
579
580 pub(crate) fn update_recv_window(&self, upd: i32) -> Result<Option<WindowSize>, StreamError> {
581 let mut window = match upd.cmp(&0) {
582 cmp::Ordering::Less => self.0.recv_window.get().dec(upd.unsigned_abs()), cmp::Ordering::Greater => self
584 .0
585 .recv_window
586 .get()
587 .inc(upd as u32)
588 .map_err(|_| StreamError::WindowOverflowed)?,
589 cmp::Ordering::Equal => return Ok(None),
590 };
591 if let Some(val) = window.update(
592 self.0.recv_size.get(),
593 self.0.con.config().window_sz.get(),
594 self.0.con.config().window_sz_threshold.get(),
595 ) {
596 self.0.recv_window.set(window);
597 Ok(Some(val))
598 } else {
599 self.0.recv_window.set(window);
600 Ok(None)
601 }
602 }
603
604 pub fn send_response(
606 &self,
607 status: StatusCode,
608 headers: HeaderMap,
609 eof: bool,
610 ) -> Result<(), OperationError> {
611 self.0.check_error()?;
612
613 match self.0.send.get() {
614 HalfState::Idle => {
615 let pseudo = PseudoHeaders::response(status);
616 let mut hdrs = Headers::new(self.0.id, pseudo, headers, eof);
617
618 if eof {
619 hdrs.set_end_stream();
620 self.0.state_send_close(None);
621 } else {
622 self.0.state_send_payload();
623 }
624 self.0.con.encode(hdrs);
625 Ok(())
626 }
627 HalfState::Payload => Err(OperationError::Payload),
628 HalfState::Closed(r) => Err(OperationError::Closed(r)),
629 }
630 }
631
632 pub async fn send_payload(&self, mut res: Bytes, eof: bool) -> Result<(), OperationError> {
634 match self.0.send.get() {
635 HalfState::Payload => {
636 self.0.check_error()?;
638
639 log::trace!(
640 "{}: {:?} sending {} bytes, eof: {}, send: {:?}",
641 self.0.tag(),
642 self.0.id,
643 res.len(),
644 eof,
645 self.0.send.get()
646 );
647
648 if eof && res.is_empty() {
650 let mut data = Data::new(self.0.id, Bytes::new());
651 data.set_end_stream();
652 self.0.state_send_close(None);
653
654 self.0.con.encode(data);
656 return Ok(());
657 }
658
659 loop {
660 let win = self.available_send_capacity() as usize;
662 if win > 0 {
663 let size =
664 cmp::min(win, cmp::min(res.len(), self.0.con.remote_frame_size()));
665 let mut data = if size >= res.len() {
666 Data::new(self.0.id, mem::replace(&mut res, Bytes::new()))
667 } else {
668 log::trace!(
669 "{}: {:?} sending {} out of {} bytes",
670 self.0.tag(),
671 self.0.id,
672 size,
673 res.len()
674 );
675 Data::new(self.0.id, res.split_to(size))
676 };
677 if eof && res.is_empty() {
678 data.set_end_stream();
679 self.0.state_send_close(None);
680 }
681
682 self.0
684 .send_window
685 .set(self.0.send_window.get().dec(size as u32));
686
687 self.0.con.consume_send_window(size as u32);
689
690 self.0.con.encode(data);
692 if res.is_empty() {
693 return Ok(());
694 }
695 } else {
696 log::trace!(
697 "{}: Not enough sending capacity for {:?} remaining {:?}",
698 self.0.tag(),
699 self.0.id,
700 res.len()
701 );
702 self.send_capacity().await?;
704 }
705 }
706 }
707 HalfState::Idle => Err(OperationError::Idle),
708 HalfState::Closed(reason) => Err(OperationError::Closed(reason)),
709 }
710 }
711
712 pub fn send_trailers(&self, map: HeaderMap) {
714 if self.0.send.get() == HalfState::Payload {
715 let mut hdrs = Headers::trailers(self.0.id, map);
716 hdrs.set_end_headers();
717 hdrs.set_end_stream();
718 self.0.con.encode(hdrs);
719 self.0.state_send_close(None);
720 }
721 }
722
723 pub fn available_send_capacity(&self) -> WindowSize {
724 cmp::min(
725 self.0.send_window.get().window_size(),
726 self.0.con.send_window_size(),
727 )
728 }
729
730 pub async fn send_capacity(&self) -> Result<WindowSize, OperationError> {
731 poll_fn(|cx| self.poll_send_capacity(cx)).await
732 }
733
734 pub fn poll_send_capacity(&self, cx: &Context<'_>) -> Poll<Result<WindowSize, OperationError>> {
736 self.0.check_error()?;
737 self.0.con.check_error()?;
738
739 let win = self.available_send_capacity();
740 if win > 0 {
741 self.0.remove_flag(StreamFlags::WAIT_FOR_CAPACITY);
742 Poll::Ready(Ok(win))
743 } else {
744 self.0.insert_flag(StreamFlags::WAIT_FOR_CAPACITY);
745 self.0.send_cap.register(cx.waker());
746 Poll::Pending
747 }
748 }
749
750 pub fn poll_send_reset(&self, cx: &Context<'_>) -> Poll<Result<(), OperationError>> {
752 if self.0.send.get().is_closed() {
753 Poll::Ready(Ok(()))
754 } else {
755 self.0.check_error()?;
756 self.0.con.check_error()?;
757 self.0.send_reset.register(cx.waker());
758 Poll::Pending
759 }
760 }
761}
762
763impl PartialEq for StreamRef {
764 fn eq(&self, other: &StreamRef) -> bool {
765 Rc::as_ptr(&self.0) == Rc::as_ptr(&other.0)
766 }
767}
768
769impl ops::Deref for Stream {
770 type Target = StreamRef;
771
772 #[inline]
773 fn deref(&self) -> &Self::Target {
774 &self.0
775 }
776}
777
778impl Drop for Stream {
779 fn drop(&mut self) {
780 self.0.reset(Reason::CANCEL);
781 }
782}
783
784impl fmt::Debug for Stream {
785 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
786 let mut builder = f.debug_struct("Stream");
787 builder
788 .field("stream_id", &self.0 .0.id)
789 .field("recv_state", &self.0 .0.recv.get())
790 .field("send_state", &self.0 .0.send.get())
791 .finish()
792 }
793}
794
795impl fmt::Debug for StreamState {
796 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
797 let mut builder = f.debug_struct("StreamState");
798 builder
799 .field("id", &self.id)
800 .field("recv", &self.recv.get())
801 .field("recv_window", &self.recv_window.get())
802 .field("recv_size", &self.recv_size.get())
803 .field("send", &self.send.get())
804 .field("send_window", &self.send_window.get())
805 .field("flags", &self.flags.get())
806 .finish()
807 }
808}
809
/// Parses an unsigned decimal integer such as a `content-length` value.
///
/// The input must be a non-empty run of ASCII digits (`1*DIGIT`); signs,
/// whitespace and any other byte are rejected. Leading zeros are allowed.
///
/// Returns `None` if the input is empty, contains a non-digit byte, or the
/// value does not fit in a `u64`.
pub fn parse_u64(src: &[u8]) -> Option<u64> {
    // An empty value is not a valid number; previously it parsed as 0.
    if src.is_empty() {
        return None;
    }

    // Checked arithmetic rejects overflow exactly, instead of relying on a
    // fixed digit-count limit.
    src.iter().try_fold(0u64, |acc, &d| {
        if d.is_ascii_digit() {
            acc.checked_mul(10)?.checked_add(u64::from(d - b'0'))
        } else {
            None
        }
    })
}
827}