1use std::collections::HashMap;
17use std::fmt::{Debug, Formatter};
18use std::mem::replace;
19use std::num::Wrapping;
20
21use byteorder::{BigEndian, ByteOrder};
22use log::{debug, trace};
23use ssh_encoding::Encode;
24use tokio::sync::oneshot;
25
26use crate::cipher::OpeningKey;
27use crate::client::GexParams;
28use crate::kex::dh::groups::DhGroup;
29use crate::kex::{KexAlgorithm, KexAlgorithmImplementor};
30use crate::sshbuffer::PacketWriter;
31use crate::{
32 ChannelId, ChannelParams, CryptoVec, Disconnect, Limits, auth, cipher, mac, msg, negotiation,
33};
34
/// Transport state that exists once the first key exchange has completed
/// (created by `CommonSession::encrypted`).
#[derive(Debug)]
pub(crate) struct Encrypted {
    /// Progress of the post-kex handshake (service request, authentication, ...).
    pub state: EncryptedState,

    /// Exchange record from the most recent key exchange, if any.
    pub exchange: Option<Exchange>,
    /// Key-exchange algorithm instance from the most recent exchange; `flush`
    /// asks it (`skip_exchange`) whether rekeying applies at all.
    pub kex: KexAlgorithm,
    /// `key` value carried over from `NewKeys`.
    /// NOTE(review): presumably an index into the configured host keys — confirm.
    pub key: usize,
    /// Negotiated client-side MAC algorithm name.
    pub client_mac: mac::Name,
    /// Negotiated server-side MAC algorithm name.
    pub server_mac: mac::Name,
    /// Session identifier; set from the first key exchange and not replaced
    /// by later rekeys (`newkeys` leaves it untouched).
    pub session_id: CryptoVec,
    /// Open channels, keyed by our local channel id.
    pub channels: HashMap<ChannelId, ChannelParams>,
    /// Most recently allocated local channel id; incremented before each
    /// allocation by `new_channel_id` / `new_channel`.
    pub last_channel_id: Wrapping<u32>,
    /// Outgoing packets not yet handed to the packet writer. Each entry is a
    /// 4-byte big-endian length followed by that many packet bytes (the layout
    /// `flush` reads back).
    pub write: Vec<u8>,
    /// Offset into `write` of the first packet not yet handed to the writer.
    pub write_cursor: usize,
    /// When keys were last (re)installed; compared against
    /// `Limits::rekey_time_limit` in `flush`.
    pub last_rekey: russh_util::time::Instant,
    /// Negotiated server-side compression algorithm; drives `decompress`.
    pub server_compression: crate::compression::Compression,
    /// Negotiated client-side compression algorithm; drives the packet
    /// writer's compressor.
    pub client_compression: crate::compression::Compression,
    /// Decompression state for incoming packets.
    pub decompress: crate::compression::Decompress,
    /// Explicit rekey request; the next `flush` reports it and clears it.
    pub rekey_wanted: bool,
    /// Names of extensions received from the peer.
    /// NOTE(review): presumably filled from SSH_MSG_EXT_INFO — confirm.
    pub received_extensions: Vec<String>,
    /// Waiters per extension name.
    /// NOTE(review): presumably signalled when that extension is received — confirm.
    pub extension_info_awaiters: HashMap<String, Vec<oneshot::Sender<()>>>,
}
61
/// State shared by client and server sessions, generic over the side-specific
/// configuration type.
pub(crate) struct CommonSession<Config> {
    /// User name for authentication.
    /// NOTE(review): inferred from the name — confirm.
    pub auth_user: String,
    /// Identification string received from the peer.
    /// NOTE(review): inferred from the name — confirm.
    pub remote_sshid: Vec<u8>,
    /// Side-specific configuration.
    pub config: Config,
    /// Encrypted-transport state; `None` until `encrypted()` installs the
    /// first set of keys.
    pub encrypted: Option<Encrypted>,
    /// Authentication method in use, if any.
    pub auth_method: Option<auth::Method>,
    // Not read on wasm32 builds, hence the lint allowance there.
    #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
    pub(crate) auth_attempts: usize,
    /// Writer for outgoing packets; holds the local-to-remote cipher.
    pub packet_writer: PacketWriter,
    /// Key used to open (decrypt/verify) incoming packets.
    pub remote_to_local: Box<dyn OpeningKey + Send>,
    pub wants_reply: bool,
    /// Set once `disconnect` has queued a DISCONNECT message; later
    /// `disconnect` calls become no-ops.
    pub disconnected: bool,
    pub buffer: Vec<u8>,
    /// Whether strict key exchange is in effect. A rekey (`newkeys`) can turn
    /// it on but never off.
    pub strict_kex: bool,
    /// NOTE(review): presumably a count of unanswered keepalives — confirm.
    pub alive_timeouts: usize,
    /// NOTE(review): presumably set when any data arrives from the peer — confirm.
    pub received_data: bool,
}
80
81impl<C> Debug for CommonSession<C> {
82 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("CommonSession")
84 .field("auth_user", &self.auth_user)
85 .field("remote_sshid", &self.remote_sshid)
86 .field("encrypted", &self.encrypted)
87 .field("auth_method", &self.auth_method)
88 .field("auth_attempts", &self.auth_attempts)
89 .field("packet_writer", &self.packet_writer)
90 .field("wants_reply", &self.wants_reply)
91 .field("disconnected", &self.disconnected)
92 .field("buffer", &self.buffer)
93 .field("strict_kex", &self.strict_kex)
94 .field("alive_timeouts", &self.alive_timeouts)
95 .field("received_data", &self.received_data)
96 .finish()
97 }
98}
99
/// Outcome of flushing a channel's queued outgoing data.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub(crate) enum ChannelFlushResult {
    /// Some queued data is still waiting (e.g. the peer's window ran out);
    /// deferred EOF/CLOSE must not be sent yet.
    Incomplete {
        /// Bytes written during this flush.
        wrote: usize,
    },
    /// All queued data was written.
    Complete {
        /// Bytes written during this flush.
        wrote: usize,
        /// A deferred CHANNEL_EOF should now be sent.
        pending_eof: bool,
        /// A deferred CHANNEL_CLOSE should now be sent.
        pending_close: bool,
    },
}
112impl ChannelFlushResult {
113 pub(crate) fn wrote(&self) -> usize {
114 match self {
115 ChannelFlushResult::Incomplete { wrote } => *wrote,
116 ChannelFlushResult::Complete { wrote, .. } => *wrote,
117 }
118 }
119 pub(crate) fn complete(wrote: usize, channel: &mut ChannelParams) -> Self {
120 let (pending_eof, pending_close) = channel.take_pending_controls();
121 ChannelFlushResult::Complete {
122 wrote,
123 pending_eof,
124 pending_close,
125 }
126 }
127}
128
129impl<C> CommonSession<C> {
130 pub fn newkeys(&mut self, newkeys: NewKeys) {
131 if let Some(ref mut enc) = self.encrypted {
132 enc.exchange = Some(newkeys.exchange);
133 enc.kex = newkeys.kex;
134 enc.key = newkeys.key;
135 enc.client_mac = newkeys.names.client_mac;
136 enc.server_mac = newkeys.names.server_mac;
137 self.remote_to_local = newkeys.cipher.remote_to_local;
138 self.packet_writer
139 .set_cipher(newkeys.cipher.local_to_remote);
140 self.strict_kex = self.strict_kex || newkeys.names.strict_kex();
141
142 enc.client_compression
144 .init_compress(self.packet_writer.compress());
145 enc.server_compression.init_decompress(&mut enc.decompress);
146 }
147 }
148
149 pub fn encrypted(&mut self, state: EncryptedState, newkeys: NewKeys) {
150 let strict_kex = newkeys.names.strict_kex();
151 self.encrypted = Some(Encrypted {
152 exchange: Some(newkeys.exchange),
153 kex: newkeys.kex,
154 key: newkeys.key,
155 client_mac: newkeys.names.client_mac,
156 server_mac: newkeys.names.server_mac,
157 session_id: newkeys.session_id,
158 state,
159 channels: HashMap::new(),
160 last_channel_id: Wrapping(1),
161 write: Vec::new(),
162 write_cursor: 0,
163 last_rekey: russh_util::time::Instant::now(),
164 server_compression: newkeys.names.server_compression,
165 client_compression: newkeys.names.client_compression,
166 decompress: crate::compression::Decompress::None,
167 rekey_wanted: false,
168 received_extensions: Vec::new(),
169 extension_info_awaiters: HashMap::new(),
170 });
171 self.remote_to_local = newkeys.cipher.remote_to_local;
172 self.packet_writer
173 .set_cipher(newkeys.cipher.local_to_remote);
174 self.strict_kex = strict_kex;
175
176 if let Some(ref mut enc) = self.encrypted {
180 if !enc.client_compression.is_deferred() {
181 enc.client_compression
182 .init_compress(self.packet_writer.compress());
183 }
184 if !enc.server_compression.is_deferred() {
185 enc.server_compression
186 .init_decompress(&mut enc.decompress);
187 }
188 }
189 }
190
191 pub fn disconnect(
193 &mut self,
194 reason: Disconnect,
195 description: &str,
196 language_tag: &str,
197 ) -> Result<(), crate::Error> {
198 let disconnect = |buf: &mut Vec<u8>| {
199 push_packet!(buf, {
200 msg::DISCONNECT.encode(buf)?;
201 (reason as u32).encode(buf)?;
202 description.encode(buf)?;
203 language_tag.encode(buf)?;
204 });
205 Ok(())
206 };
207 if !self.disconnected {
208 self.disconnected = true;
209 return if let Some(ref mut enc) = self.encrypted {
210 disconnect(&mut enc.write)
211 } else {
212 disconnect(&mut self.packet_writer.buffer().buffer)
213 };
214 }
215 Ok(())
216 }
217
218 pub fn debug(
220 &mut self,
221 always_display: bool,
222 message: &str,
223 language_tag: &str,
224 ) -> Result<(), crate::Error> {
225 let debug = |buf: &mut Vec<u8>| {
226 push_packet!(buf, {
227 msg::DEBUG.encode(buf)?;
228 (always_display as u8).encode(buf)?;
229 message.encode(buf)?;
230 language_tag.encode(buf)?;
231 });
232 Ok(())
233 };
234 if let Some(ref mut enc) = self.encrypted {
235 debug(&mut enc.write)
236 } else {
237 debug(&mut self.packet_writer.buffer().buffer)
238 }
239 }
240
241 pub(crate) fn reset_seqn(&mut self) {
242 self.packet_writer.reset_seqn();
243 }
244}
245
246impl Encrypted {
247 pub fn byte(&mut self, channel: ChannelId, msg: u8) -> Result<(), crate::Error> {
248 if let Some(channel) = self.channels.get(&channel) {
249 push_packet!(self.write, {
250 self.write.push(msg);
251 channel.recipient_channel.encode(&mut self.write)?;
252 });
253 }
254 Ok(())
255 }
256
257 pub fn eof(&mut self, channel: ChannelId) -> Result<(), crate::Error> {
258 if let Some(channel) = self.has_pending_data_mut(channel) {
259 channel.pending_eof = true;
260 } else {
261 self.byte(channel, msg::CHANNEL_EOF)?;
262 }
263 Ok(())
264 }
265
266 pub fn close(&mut self, channel: ChannelId) -> Result<(), crate::Error> {
267 if let Some(channel) = self.has_pending_data_mut(channel) {
268 channel.pending_close = true;
269 } else {
270 self.byte(channel, msg::CHANNEL_CLOSE)?;
271 self.channels.remove(&channel);
272 }
273 Ok(())
274 }
275
276 pub fn sender_window_size(&self, channel: ChannelId) -> usize {
277 if let Some(channel) = self.channels.get(&channel) {
278 channel.sender_window_size as usize
279 } else {
280 0
281 }
282 }
283
284 pub fn adjust_window_size(
285 &mut self,
286 channel: ChannelId,
287 data: &[u8],
288 target: u32,
289 ) -> Result<bool, crate::Error> {
290 if let Some(channel) = self.channels.get_mut(&channel) {
291 trace!(
292 "adjust_window_size, channel = {}, size = {},",
293 channel.sender_channel, target
294 );
295 if data.len() as u32 <= channel.sender_window_size {
298 channel.sender_window_size -= data.len() as u32;
299 }
300 if channel.sender_window_size < target / 2 {
301 debug!(
302 "sender_window_size {:?}, target {:?}",
303 channel.sender_window_size, target
304 );
305 push_packet!(self.write, {
306 self.write.push(msg::CHANNEL_WINDOW_ADJUST);
307 channel.recipient_channel.encode(&mut self.write)?;
308 (target - channel.sender_window_size).encode(&mut self.write)?;
309 });
310 channel.sender_window_size = target;
311 return Ok(true);
312 }
313 }
314 Ok(false)
315 }
316
317 fn flush_channel(
318 write: &mut Vec<u8>,
319 channel: &mut ChannelParams,
320 ) -> Result<ChannelFlushResult, crate::Error> {
321 let mut pending_size = 0;
322 while let Some((buf, a, from)) = channel.pending_data.pop_front() {
323 let size = Self::data_noqueue(write, channel, &buf, a, from)?;
324 pending_size += size;
325 if from + size < buf.len() {
326 channel.pending_data.push_front((buf, a, from + size));
327 return Ok(ChannelFlushResult::Incomplete {
328 wrote: pending_size,
329 });
330 }
331 }
332 Ok(ChannelFlushResult::complete(pending_size, channel))
333 }
334
335 fn handle_flushed_channel(
336 &mut self,
337 channel: ChannelId,
338 flush_result: ChannelFlushResult,
339 ) -> Result<(), crate::Error> {
340 if let ChannelFlushResult::Complete {
341 wrote: _,
342 pending_eof,
343 pending_close,
344 } = flush_result
345 {
346 if pending_eof {
347 self.eof(channel)?;
348 }
349 if pending_close {
350 self.close(channel)?;
351 }
352 }
353 Ok(())
354 }
355
356 pub fn flush_pending(&mut self, channel: ChannelId) -> Result<usize, crate::Error> {
357 let flush_result = match self.channels.get_mut(&channel) {
358 Some(ch) => Self::flush_channel(&mut self.write, ch)?,
359 None => return Ok(0),
360 };
361 let wrote = flush_result.wrote();
362 self.handle_flushed_channel(channel, flush_result)?;
363 Ok(wrote)
364 }
365
366 pub fn flush_all_pending(&mut self) -> Result<(), crate::Error> {
367 let channel_ids: Vec<ChannelId> = self.channels.keys().copied().collect();
368 for channel_id in channel_ids {
369 self.flush_pending(channel_id)?;
370 }
371 Ok(())
372 }
373
374 fn has_pending_data_mut(&mut self, channel: ChannelId) -> Option<&mut ChannelParams> {
375 self.channels
376 .get_mut(&channel)
377 .filter(|c| !c.pending_data.is_empty())
378 }
379
380 pub fn has_pending_data(&self, channel: ChannelId) -> bool {
381 if let Some(channel) = self.channels.get(&channel) {
382 !channel.pending_data.is_empty()
383 } else {
384 false
385 }
386 }
387
388 fn data_noqueue(
392 write: &mut Vec<u8>,
393 channel: &mut ChannelParams,
394 buf0: &[u8],
395 a: Option<u32>,
396 from: usize,
397 ) -> Result<usize, crate::Error> {
398 if from >= buf0.len() {
399 return Ok(0);
400 }
401 let mut buf = if buf0.len() as u32 > from as u32 + channel.recipient_window_size {
402 #[allow(clippy::indexing_slicing)] &buf0[from..from + channel.recipient_window_size as usize]
404 } else {
405 #[allow(clippy::indexing_slicing)] &buf0[from..]
407 };
408 let buf_len = buf.len();
409
410 while !buf.is_empty() {
411 let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize);
413 match a {
414 None => push_packet!(write, {
415 write.push(msg::CHANNEL_DATA);
416 channel.recipient_channel.encode(write)?;
417 #[allow(clippy::indexing_slicing)] buf[..off].encode(write)?;
419 }),
420 Some(ext) => push_packet!(write, {
421 write.push(msg::CHANNEL_EXTENDED_DATA);
422 channel.recipient_channel.encode(write)?;
423 ext.encode(write)?;
424 #[allow(clippy::indexing_slicing)] buf[..off].encode(write)?;
426 }),
427 }
428 trace!(
429 "buffer: {:?} {:?}",
430 write.len(),
431 channel.recipient_window_size
432 );
433 channel.recipient_window_size -= off as u32;
434 #[allow(clippy::indexing_slicing)] {
436 buf = &buf[off..]
437 }
438 }
439 trace!("buf.len() = {:?}, buf_len = {:?}", buf.len(), buf_len);
440 Ok(buf_len)
441 }
442
443 pub fn data(
444 &mut self,
445 channel: ChannelId,
446 buf0: impl Into<bytes::Bytes>,
447 is_rekeying: bool,
448 ) -> Result<(), crate::Error> {
449 let buf0 = buf0.into();
450 if let Some(channel) = self.channels.get_mut(&channel) {
451 assert!(channel.confirmed);
452 if !channel.pending_data.is_empty() && is_rekeying {
453 channel.pending_data.push_back((buf0, None, 0));
454 return Ok(());
455 }
456 let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, None, 0)?;
457 if buf_len < buf0.len() {
458 channel.pending_data.push_back((buf0, None, buf_len))
459 }
460 } else {
461 debug!("{channel:?} not saved for this session");
462 }
463 Ok(())
464 }
465
466 pub fn extended_data(
467 &mut self,
468 channel: ChannelId,
469 ext: u32,
470 buf0: impl Into<bytes::Bytes>,
471 is_rekeying: bool,
472 ) -> Result<(), crate::Error> {
473 let buf0 = buf0.into();
474 if let Some(channel) = self.channels.get_mut(&channel) {
475 assert!(channel.confirmed);
476 if !channel.pending_data.is_empty() && is_rekeying {
477 channel.pending_data.push_back((buf0, Some(ext), 0));
478 return Ok(());
479 }
480 let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, Some(ext), 0)?;
481 if buf_len < buf0.len() {
482 channel.pending_data.push_back((buf0, Some(ext), buf_len))
483 }
484 }
485 Ok(())
486 }
487
488 pub fn flush(
489 &mut self,
490 limits: &Limits,
491 writer: &mut PacketWriter,
492 ) -> Result<bool, crate::Error> {
493 {
495 while self.write_cursor < self.write.len() {
496 #[allow(clippy::indexing_slicing)] let len = BigEndian::read_u32(&self.write[self.write_cursor..]) as usize;
499 #[allow(clippy::indexing_slicing)]
500 let to_write = &self.write[(self.write_cursor + 4)..(self.write_cursor + 4 + len)];
501 trace!("session_write_encrypted, buf = {to_write:?}");
502
503 writer.packet_raw(to_write)?;
504 self.write_cursor += 4 + len
505 }
506 }
507 if self.write_cursor >= self.write.len() {
508 self.write_cursor = 0;
510 self.write.clear();
511 }
512
513 if self.kex.skip_exchange() {
514 return Ok(false);
515 }
516
517 let now = russh_util::time::Instant::now();
518 let dur = now.duration_since(self.last_rekey);
519 Ok(replace(&mut self.rekey_wanted, false)
520 || writer.buffer().bytes >= limits.rekey_write_limit
521 || dur >= limits.rekey_time_limit)
522 }
523
524 pub fn new_channel_id(&mut self) -> ChannelId {
525 self.last_channel_id += Wrapping(1);
526 while self
527 .channels
528 .contains_key(&ChannelId(self.last_channel_id.0))
529 {
530 self.last_channel_id += Wrapping(1)
531 }
532 ChannelId(self.last_channel_id.0)
533 }
534 pub fn new_channel(&mut self, window_size: u32, maxpacket: u32) -> ChannelId {
535 loop {
536 self.last_channel_id += Wrapping(1);
537 if let std::collections::hash_map::Entry::Vacant(vacant_entry) =
538 self.channels.entry(ChannelId(self.last_channel_id.0))
539 {
540 vacant_entry.insert(ChannelParams {
541 recipient_channel: 0,
542 sender_channel: ChannelId(self.last_channel_id.0),
543 sender_window_size: window_size,
544 recipient_window_size: 0,
545 sender_maximum_packet_size: maxpacket,
546 recipient_maximum_packet_size: 0,
547 confirmed: false,
548 wants_reply: false,
549 pending_data: std::collections::VecDeque::new(),
550 pending_eof: false,
551 pending_close: false,
552 });
553 return ChannelId(self.last_channel_id.0);
554 }
555 }
556 }
557}
558
/// Post-key-exchange handshake progress stored in [`Encrypted::state`].
#[derive(Debug)]
pub enum EncryptedState {
    /// Waiting on the authentication service request.
    /// NOTE(review): `sent` / `accepted` presumably track whether the request
    /// was sent and whether the peer accepted it — confirm.
    WaitingAuthServiceRequest { sent: bool, accepted: bool },
    /// Authentication in progress with the given request state.
    WaitingAuthRequest(auth::AuthRequest),
    /// NOTE(review): presumably the point where deferred compression is
    /// started after authentication — confirm.
    InitCompression,
    /// Authentication has completed.
    Authenticated,
}
566
/// Data recorded during a key exchange: both identification strings, both
/// KEXINIT payloads, both ephemeral values and, when group exchange is used,
/// its parameters and group.
#[derive(Debug, Default, Clone)]
pub struct Exchange {
    /// Client identification string.
    pub client_id: Vec<u8>,
    /// Server identification string.
    pub server_id: Vec<u8>,
    /// Client KEXINIT payload.
    pub client_kex_init: Vec<u8>,
    /// Server KEXINIT payload.
    pub server_kex_init: Vec<u8>,
    /// Client ephemeral value.
    pub client_ephemeral: Vec<u8>,
    /// Server ephemeral value.
    pub server_ephemeral: Vec<u8>,
    /// Group-exchange parameters and the chosen DH group, if GEX is in use.
    pub gex: Option<(GexParams, DhGroup)>,
}
580
581impl Exchange {
582 pub fn new(client_id: &[u8], server_id: &[u8]) -> Self {
583 Exchange {
584 client_id: client_id.into(),
585 server_id: server_id.into(),
586 ..Default::default()
587 }
588 }
589}
590
/// Everything produced by a completed key exchange, consumed by
/// `CommonSession::encrypted` (first exchange) or `CommonSession::newkeys` (rekey).
#[derive(Debug)]
pub(crate) struct NewKeys {
    /// The exchange record that produced these keys.
    pub exchange: Exchange,
    /// Negotiated algorithm names (MACs, compression, strict-kex flag, ...).
    pub names: negotiation::Names,
    /// Key-exchange algorithm instance.
    pub kex: KexAlgorithm,
    /// NOTE(review): presumably an index into the configured host keys — confirm.
    pub key: usize,
    /// Local-to-remote and remote-to-local ciphers.
    pub cipher: cipher::CipherPair,
    /// Session identifier; only used by `encrypted`, since rekeys keep the original.
    pub session_id: CryptoVec,
}
600
/// Describes an outstanding global request and where its reply should go.
#[derive(Debug)]
pub(crate) enum GlobalRequestResponse {
    /// Keepalive request; carries no reply channel.
    Keepalive,
    /// Ping request; the sender is presumably signalled when the reply arrives.
    Ping(oneshot::Sender<()>),
    /// Request that the peer refuse further session channels; carries no reply channel.
    NoMoreSessions,
    /// TCP/IP forwarding request.
    /// NOTE(review): presumably receives the bound port on success or `None` on failure — confirm.
    TcpIpForward(oneshot::Sender<Option<u32>>),
    /// Cancellation of a TCP/IP forward; receives whether it succeeded.
    CancelTcpIpForward(oneshot::Sender<bool>),
    /// Stream-local (Unix socket) forwarding request; receives whether it succeeded.
    StreamLocalForward(oneshot::Sender<bool>),
    /// Cancellation of a stream-local forward; receives whether it succeeded.
    CancelStreamLocalForward(oneshot::Sender<bool>),
}
617
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, VecDeque};
    use std::num::Wrapping;

    use byteorder::{BigEndian, ByteOrder};
    use bytes::Bytes;

    use super::{Encrypted, EncryptedState, Exchange};
    use crate::compression::{Compression, Decompress};
    use crate::kex::{KEXES, NONE};
    use crate::{ChannelId, ChannelParams, CryptoVec, mac, msg};

    /// An authenticated session using the `none` kex, no MAC and no
    /// compression, with no channels and an empty write queue.
    fn test_encrypted() -> Encrypted {
        Encrypted {
            state: EncryptedState::Authenticated,
            exchange: Some(Exchange::default()),
            kex: KEXES.get(&NONE).unwrap().make(),
            key: 0,
            client_mac: mac::NONE,
            server_mac: mac::NONE,
            session_id: CryptoVec::new(),
            channels: HashMap::new(),
            last_channel_id: Wrapping(0),
            write: Vec::new(),
            write_cursor: 0,
            last_rekey: russh_util::time::Instant::now(),
            server_compression: Compression::None,
            client_compression: Compression::None,
            decompress: Decompress::None,
            rekey_wanted: false,
            received_extensions: Vec::new(),
            extension_info_awaiters: HashMap::new(),
        }
    }

    /// Confirmed channel whose peer window is `window_size` bytes. All other
    /// sizes are 1024, and `b"hello"` is queued in `pending_data`.
    fn test_channel_windowed(
        sender_channel: ChannelId,
        recipient_channel: u32,
        window_size: u32,
        pending_eof: bool,
        pending_close: bool,
    ) -> ChannelParams {
        ChannelParams {
            recipient_channel,
            sender_channel,
            recipient_window_size: window_size,
            sender_window_size: 1024,
            recipient_maximum_packet_size: 1024,
            sender_maximum_packet_size: 1024,
            confirmed: true,
            wants_reply: false,
            pending_data: VecDeque::from([(Bytes::from_static(b"hello"), None, 0)]),
            pending_eof,
            pending_close,
        }
    }

    /// Like [`test_channel_windowed`] with a 1024-byte peer window, which is
    /// large enough to flush the queued `b"hello"` in one go.
    fn test_channel(
        sender_channel: ChannelId,
        recipient_channel: u32,
        pending_eof: bool,
        pending_close: bool,
    ) -> ChannelParams {
        test_channel_windowed(
            sender_channel,
            recipient_channel,
            1024,
            pending_eof,
            pending_close,
        )
    }

    /// Message-type byte of each length-prefixed packet in `buf`, in order.
    fn packet_types(buf: &[u8]) -> Vec<u8> {
        let mut types = Vec::new();
        let mut rest = buf;
        while !rest.is_empty() {
            let packet_len = BigEndian::read_u32(&rest[..4]) as usize;
            types.push(rest[4]);
            rest = rest.get(4 + packet_len..).unwrap_or_default();
        }
        types
    }

    #[test]
    fn flush_pending_replays_deferred_eof_once() {
        let id = ChannelId(10);
        let mut enc = test_encrypted();
        enc.channels.insert(id, test_channel(id, 42, true, false));
        let expected = vec![msg::CHANNEL_DATA, msg::CHANNEL_EOF];

        enc.flush_pending(id).unwrap();
        assert_eq!(packet_types(&enc.write), expected);
        assert!(!enc.channels[&id].pending_eof);

        // Nothing is queued and the EOF flag is consumed: a second flush adds nothing.
        enc.flush_pending(id).unwrap();
        assert_eq!(packet_types(&enc.write), expected);
    }

    #[test]
    fn flush_pending_replays_deferred_close_and_removes_channel() {
        let id = ChannelId(11);
        let mut enc = test_encrypted();
        enc.channels.insert(id, test_channel(id, 43, true, true));

        enc.flush_pending(id).unwrap();
        assert_eq!(
            packet_types(&enc.write),
            vec![msg::CHANNEL_DATA, msg::CHANNEL_EOF, msg::CHANNEL_CLOSE]
        );
        assert!(!enc.channels.contains_key(&id));
    }

    #[test]
    fn flush_pending_no_controls_when_incomplete() {
        let id = ChannelId(12);
        // A 3-byte window cannot carry all of "hello", so the flush stays incomplete.
        let mut enc = test_encrypted();
        enc.channels
            .insert(id, test_channel_windowed(id, 44, 3, true, true));

        enc.flush_pending(id).unwrap();
        assert_eq!(packet_types(&enc.write), vec![msg::CHANNEL_DATA]);

        let channel = &enc.channels[&id];
        assert!(channel.pending_eof);
        assert!(channel.pending_close);
    }

    #[test]
    fn flush_all_pending_replays_deferred_eof_once() {
        let id = ChannelId(1);
        let mut enc = test_encrypted();
        enc.channels.insert(id, test_channel(id, 42, true, false));
        let expected = vec![msg::CHANNEL_DATA, msg::CHANNEL_EOF];

        enc.flush_all_pending().unwrap();
        assert_eq!(packet_types(&enc.write), expected);
        assert!(!enc.channels[&id].pending_eof);

        enc.flush_all_pending().unwrap();
        assert_eq!(packet_types(&enc.write), expected);
    }

    #[test]
    fn flush_all_pending_replays_deferred_close_and_removes_channel() {
        let id = ChannelId(2);
        let mut enc = test_encrypted();
        enc.channels.insert(id, test_channel(id, 43, true, true));

        enc.flush_all_pending().unwrap();
        assert_eq!(
            packet_types(&enc.write),
            vec![msg::CHANNEL_DATA, msg::CHANNEL_EOF, msg::CHANNEL_CLOSE]
        );
        assert!(!enc.channels.contains_key(&id));
    }

    #[test]
    fn flush_all_pending_handles_multiple_channels_independently() {
        let eof_only = ChannelId(3);
        let close_too = ChannelId(4);
        let mut enc = test_encrypted();
        enc.channels
            .insert(eof_only, test_channel(eof_only, 50, true, false));
        enc.channels
            .insert(close_too, test_channel(close_too, 51, true, true));

        enc.flush_all_pending().unwrap();

        // The EOF-only channel survives with its flag consumed; the other is gone.
        assert!(enc.channels.contains_key(&eof_only));
        assert!(!enc.channels[&eof_only].pending_eof);
        assert!(!enc.channels.contains_key(&close_too));

        // HashMap iteration order is unspecified, so compare counts, not sequences.
        let types = packet_types(&enc.write);
        let count = |wanted: u8| types.iter().filter(|&&t| t == wanted).count();
        assert_eq!(count(msg::CHANNEL_DATA), 2);
        assert_eq!(count(msg::CHANNEL_EOF), 2);
        assert_eq!(count(msg::CHANNEL_CLOSE), 1);
    }
}