1use log::{info, trace, warn};
2use response::{Response, ResponseBuilder};
3use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
4use tokio::net::{TcpStream, ToSocketAddrs};
5use tokio::sync::broadcast::Sender;
6use tokio::sync::mpsc::Receiver;
7use tokio::sync::{broadcast, mpsc, oneshot};
8
9mod response;
10
/// A single `key: value` entry of a [`Packet`].
///
/// Derived equality, ordering and hashing consider `key` first and then
/// `value` (field declaration order) and are case-sensitive. `Eq`, `Ord` and
/// `Hash` are derived because both fields are `String`, so tags can be stored
/// in hash sets/maps and sorted directly.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Tag {
    /// Entry name (the part before `": "` when serialised).
    pub key: String,
    /// Entry value (the part after `": "` when serialised).
    pub value: String,
}
19
20impl Tag {
21 pub fn of(key: String, value: String) -> Self {
22 Self { key, value }
23 }
24
25 pub fn from(key: &str, value: &str) -> Self {
26 Self {
27 key: key.to_string(),
28 value: value.to_string(),
29 }
30 }
31}
32
33pub type Packet = Vec<Tag>;
35
36pub type Responder<T> = oneshot::Sender<T>;
38
39#[derive(Debug)]
42struct Command {
43 packet: Packet,
44 resp: Responder<Vec<Packet>>,
45}
46
47pub struct AmiConnection {
48 cmd_tx: mpsc::Sender<Command>,
49 events_tx: broadcast::Sender<Option<Packet>>,
50}
51
52impl AmiConnection {
53 pub async fn connect<A: ToSocketAddrs + std::fmt::Debug>(
59 server: A,
60 ) -> Result<AmiConnection, std::io::Error> {
61 let reader = Self::connect_to_server(server).await?;
62
63 let (cmd_tx, cmd_rx) = mpsc::channel::<Command>(32);
64 let (events_tx, _) = broadcast::channel::<Option<Packet>>(32);
65
66 let events_tx2 = events_tx.clone();
67
68 tokio::spawn(async move {
69 Self::handle_server_connection(reader, cmd_rx, events_tx2).await;
70 });
71
72 Ok(AmiConnection { cmd_tx, events_tx })
73 }
74
75 async fn handle_server_connection(
76 mut server_connection: BufReader<TcpStream>,
77 mut command_channel_rx: Receiver<Command>,
78 event_channel_tx: Sender<Option<Packet>>,
79 ) {
80 let mut current_command: Option<Command> = None;
81 let mut response_builder = ResponseBuilder::new();
82 let mut line = String::new();
83 let mut maybe_response: Option<Response> = None;
84 loop {
85 if current_command.is_none() {
86 tokio::select! {
87 bytes_read = server_connection.read_line(&mut line) => {
88 match bytes_read {
89 Err(e) => {
90 warn!("Error reading from server connection: {:?}", e);
91 break;
92 }
93 Ok(0) => {
94 trace!("Server connection closed");
95 break;
96 }
97 Ok(_) => {
98 maybe_response = response_builder.add_line(line.trim());
99 }
100 }
101 }
102
103 cmd = command_channel_rx.recv() => {
104 if let Some(c) = cmd {
105 let chunk = format!("{}\r\n\r\n", packet_to_string(&c.packet));
106 current_command = Some(c);
107 if let Err(e) = server_connection.write_all(chunk.as_bytes()).await {
108 warn!("Error writing to server connection: {:?}", e);
109 break;
110 }
111 } else {
112 trace!("Channel has been closed");
113 break;
114 }
115 }
116 }
117 } else {
118 tokio::select! {
119 bytes_read = server_connection.read_line(&mut line) => {
120 match bytes_read {
121 Err(e) => {
122 warn!("Error reading from server connection: {:?}", e);
123 break;
124 }
125 Ok(0) => {
126 trace!("Server connection closed");
127 break;
128 }
129 Ok(_) => {
130 maybe_response = response_builder.add_line(line.trim());
131 }
132 }
133 }
134 }
135 }
136
137 if let Some(resp) = maybe_response {
138 match resp {
139 Response::Event(pkt) => {
140 if !Self::publish_event(&event_channel_tx, Some(pkt)) {
141 break;
142 }
143 }
144 Response::CommandResponse(cr) => {
145 if let Some(cmd) = current_command {
146 current_command = None;
147 if let Err(e) = cmd.resp.send(cr) {
148 warn!(
149 "Cannot send command response back: {:?}",
150 e
151 );
152 break;
153 }
154 }
155 }
156 }
157 }
158 maybe_response = None;
159 line.clear();
160 }
161
162 trace!("Packet passing loop ended! Publishing 'None' event");
163 Self::publish_event(&event_channel_tx, None);
164
165 trace!("Closing command channel");
166 command_channel_rx.close();
167 if let Some(cmd) = current_command {
168 info!("There was a running command on closed connection: {:?}", cmd);
169 if let Err(e) = cmd.resp.send(vec![]) {
170 warn!("Cannot terminate current command on close: {:?}", e);
171 }
172 }
173 }
174
175 fn publish_event(
176 event_channel_tx: &Sender<Option<Packet>>,
177 pkt: Option<Packet>,
178 ) -> bool {
179 if event_channel_tx.receiver_count() > 0 {
180 if let Err(e) = event_channel_tx.send(pkt) {
181 warn!("Could not send event to subscribers: {:?}", e);
182 return false;
183 }
184 }
185 true
186 }
187
188 async fn connect_to_server<A: ToSocketAddrs + std::fmt::Debug>(
189 server: A,
190 ) -> Result<BufReader<TcpStream>, std::io::Error> {
191 trace!("Connecting to {:?}", server);
192 let mut reader = BufReader::new(TcpStream::connect(server).await?);
193 Self::read_greeting(&mut reader).await?;
194 Ok(reader)
195 }
196
197 async fn read_greeting(
198 reader: &mut BufReader<TcpStream>,
199 ) -> Result<(), std::io::Error> {
200 let mut greeting = String::new();
201 reader.read_line(&mut greeting).await?;
202
203 Ok(())
204 }
205
206 pub async fn send(&self, pkt: Packet) -> Option<Vec<Packet>> {
217 let (tx, rx) = oneshot::channel();
218 self.cmd_tx
219 .send(Command {
220 packet: pkt,
221 resp: tx,
222 })
223 .await
224 .ok()?;
225 rx.await.ok()
226 }
227
228 pub fn events(&self) -> broadcast::Receiver<Option<Packet>> {
229 self.events_tx.subscribe()
230 }
231}
232
233pub fn find_tag<'a>(pkt: &'a Packet, key: &str) -> Option<&'a String> {
240 pkt.iter()
241 .find(|&tag| tag.key.eq_ignore_ascii_case(key))
242 .map(|t| &t.value)
243}
244
245fn packet_to_string(pkt: &Packet) -> String {
246 pkt.iter()
247 .map(|Tag { key, value }| format!("{}: {}", key, value))
248 .collect::<Vec<String>>()
249 .join("\r\n")
250}
251
/// Unit tests for the synchronous helpers of this module (the networking
/// code needs a live server and is not covered here).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tag_constructors_agree() {
        let owned = Tag::of("Action".to_string(), "Ping".to_string());
        let borrowed = Tag::from("Action", "Ping");
        assert_eq!(owned, borrowed);
        assert_eq!(borrowed.key, "Action");
        assert_eq!(borrowed.value, "Ping");
    }

    #[test]
    fn find_tag_is_case_insensitive_and_returns_first_match() {
        let pkt: Packet = vec![
            Tag::from("Response", "Success"),
            Tag::from("response", "Ignored"),
        ];
        assert_eq!(find_tag(&pkt, "RESPONSE").map(String::as_str), Some("Success"));
        assert_eq!(find_tag(&pkt, "ActionID"), None);
    }

    #[test]
    fn packet_to_string_joins_tags_with_crlf() {
        let pkt: Packet = vec![Tag::from("Action", "Login"), Tag::from("Username", "admin")];
        assert_eq!(packet_to_string(&pkt), "Action: Login\r\nUsername: admin");
        assert_eq!(packet_to_string(&Vec::new()), "");
    }
}