1use amq_protocol::protocol::AMQPClass;
2use log::debug;
3use parking_lot::Mutex;
4
5use std::{
6 collections::HashMap,
7 sync::Arc,
8};
9
10use crate::{
11 BasicProperties, Channel, ChannelState, Error, ErrorKind,
12 connection::Connection,
13 id_sequence::IdSequence,
14};
15
16#[derive(Clone, Debug, Default)]
17pub(crate) struct Channels {
18 inner: Arc<Mutex<Inner>>,
19}
20
21impl Channels {
22 pub(crate) fn create(&self, connection: Connection) -> Result<Channel, Error> {
23 self.inner.lock().create(connection)
24 }
25
26 pub(crate) fn create_zero(&self, connection: Connection) {
27 self.inner.lock().create_channel(0, connection).set_state(ChannelState::Connected);
28 }
29
30 pub(crate) fn get(&self, id: u16) -> Option<Channel> {
31 self.inner.lock().channels.get(&id).cloned()
32 }
33
34 pub(crate) fn remove(&self, id: u16) -> Result<(), Error> {
35 if self.inner.lock().channels.remove(&id).is_some() {
36 Ok(())
37 } else {
38 Err(ErrorKind::InvalidChannel(id).into())
39 }
40 }
41
42 pub(crate) fn receive_method(&self, id: u16, method: AMQPClass) -> Result<(), Error> {
43 if let Some(channel) = self.get(id) {
44 channel.receive_method(method)
45 } else {
46 Err(ErrorKind::InvalidChannel(id).into())
47 }
48 }
49
50 pub(crate) fn handle_content_header_frame(&self, id: u16, size: u64, properties: BasicProperties) -> Result<(), Error> {
51 if let Some(channel) = self.get(id) {
52 channel.handle_content_header_frame(size, properties)
53 } else {
54 Err(ErrorKind::InvalidChannel(id).into())
55 }
56 }
57
58 pub(crate) fn handle_body_frame(&self, id: u16, payload: Vec<u8>) -> Result<(), Error> {
59 if let Some(channel) = self.get(id) {
60 channel.handle_body_frame(payload)
61 } else {
62 Err(ErrorKind::InvalidChannel(id).into())
63 }
64 }
65
66 pub(crate) fn set_closing(&self) {
67 for channel in self.inner.lock().channels.values() {
68 channel.set_closing();
69 }
70 }
71
72 pub(crate) fn set_closed(&self) -> Result<(), Error> {
73 for (_, channel) in self.inner.lock().channels.drain() {
74 channel.set_state(ChannelState::Closed);
75 }
76 Ok(())
77 }
78
79 pub(crate) fn set_error(&self) -> Result<(), Error> {
80 for channel in self.inner.lock().channels.values() {
81 channel.set_error()?;
82 }
83 Ok(())
84 }
85
86 pub(crate) fn flow(&self) -> bool {
87 self.inner.lock().channels.values().all(|c| c.status().flow())
88 }
89}
90
91#[derive(Debug)]
92struct Inner {
93 channels: HashMap<u16, Channel>,
94 channel_id: IdSequence<u16>,
95}
96
97impl Default for Inner {
98 fn default() -> Self {
99 Self {
100 channels: HashMap::default(),
101 channel_id: IdSequence::new(false),
102 }
103 }
104}
105
106impl Inner {
107 fn create_channel(&mut self, id: u16, connection: Connection) -> Channel {
108 debug!("create channel with id {}", id);
109 let channel = Channel::new(id, connection);
110 self.channels.insert(id, channel.clone());
111 channel
112 }
113
114 fn create(&mut self, connection: Connection) -> Result<Channel, Error> {
115 debug!("create channel");
116 self.channel_id.set_max(connection.configuration().channel_max());
117 let first_id = self.channel_id.next();
118 let mut looped = false;
119 let mut id = first_id;
120 while !looped || id < first_id {
121 if id == 1 {
122 looped = true;
123 }
124 if !self.channels.contains_key(&id) {
125 return Ok(self.create_channel(id, connection))
126 }
127 id = self.channel_id.next();
128 }
129 Err(ErrorKind::ChannelLimitReached.into())
130 }
131}