leftwm_core/utils/
state_socket.rs1use crate::errors::{LeftError, Result};
2use crate::models::dto::ManagerState;
3use crate::models::Handle;
4use std::path::PathBuf;
5use std::sync::Arc;
6use tokio::fs;
7use tokio::io::AsyncWriteExt;
8use tokio::net::{UnixListener, UnixStream};
9use tokio::sync::Mutex;
10
11#[derive(Debug, Default)]
12struct State {
13 peers: Vec<Option<UnixStream>>,
14 last_state: String,
15}
16
17#[derive(Debug, Default)]
18pub struct StateSocket {
19 state: Arc<Mutex<State>>,
20 listener: Option<tokio::task::JoinHandle<()>>,
21 socket_file: PathBuf,
22}
23
24impl Drop for StateSocket {
25 fn drop(&mut self) {
26 assert!(
27 std::thread::panicking() || self.listener.is_none(),
28 "StateSocket has to be shutdown explicitly before drop"
29 );
30 }
31}
32
33impl StateSocket {
34 pub async fn listen(&mut self, socket_file: PathBuf) -> Result<()> {
41 self.socket_file = socket_file;
42 let listener = self.build_listener().await?;
43 self.listener = Some(listener);
44 Ok(())
45 }
46
47 pub async fn shutdown(&mut self) {
49 if let Some(listener) = self.listener.take() {
50 listener.abort();
51 listener.await.ok();
52 fs::remove_file(self.socket_file.as_path()).await.ok();
53 }
54 }
55
56 pub async fn write_manager_state<H: Handle>(
60 &mut self,
61 raw_state: &crate::state::State<H>,
62 ) -> Result<()> {
63 if self.listener.is_some() {
64 let state: ManagerState = raw_state.into();
65 let mut json = serde_json::to_string(&state)?;
66 json.push('\n');
67 let mut state = self.state.lock().await;
68
69 let state_changed = json != state.last_state;
70 if state_changed {
71 state.peers.retain(std::option::Option::is_some);
72 for peer in &mut state.peers {
73 if peer
74 .as_mut()
75 .ok_or(LeftError::StreamError)?
76 .write_all(json.as_bytes())
77 .await
78 .is_err()
79 {
80 peer.take();
81 }
82 }
83 state.last_state = json;
84 }
85 }
86 Ok(())
87 }
88
89 async fn build_listener(&self) -> Result<tokio::task::JoinHandle<()>> {
90 let state = self.state.clone();
91 let listener = if let Ok(m) = UnixListener::bind(&self.socket_file) {
92 m
93 } else {
94 fs::remove_file(&self.socket_file).await?;
95 UnixListener::bind(&self.socket_file)?
96 };
97
98 Ok(tokio::spawn(async move {
99 loop {
100 match listener.accept().await {
101 Ok((mut peer, _)) => {
102 let mut state = state.lock().await;
103 if peer.write_all(state.last_state.as_bytes()).await.is_ok() {
104 state.peers.push(Some(peer));
105 }
106 }
107 Err(e) => tracing::error!("Accept failed = {:?}", e),
108 }
109 }
110 }))
111 }
112}
113
#[cfg(test)]
mod test {
    use super::*;
    use crate::utils::helpers::test::temp_path;
    use crate::Manager;
    use tokio::io::{AsyncBufReadExt, BufReader, Lines};

    /// Line reader over one client connection to the state socket.
    type Peer = Lines<BufReader<UnixStream>>;

    /// The JSON line a peer should receive for `state`, without the newline.
    fn expected_json<H: Handle>(state: &crate::state::State<H>) -> String {
        serde_json::to_string(&Into::<ManagerState>::into(state)).unwrap()
    }

    /// Opens a new client connection to the socket at `socket_file`.
    async fn connect(socket_file: &std::path::Path) -> Peer {
        BufReader::new(UnixStream::connect(socket_file).await.unwrap()).lines()
    }

    /// Reads the next line the server sent to `peer`.
    async fn read_line(peer: &mut Peer) -> String {
        peer.next_line().await.expect("Read next line").unwrap()
    }

    #[tokio::test]
    async fn multiple_peers() {
        let manager = Manager::new_test(vec![]);
        let state = &manager.state;

        let socket_file = temp_path().await.unwrap();
        let mut state_socket = StateSocket::default();
        state_socket.listen(socket_file.clone()).await.unwrap();
        state_socket.write_manager_state(state).await.unwrap();

        // Keep every connection open so that several peers are registered at
        // once. Each one must first receive the current state.
        let mut peers = Vec::new();
        for _ in 0..3 {
            let mut peer = connect(&socket_file).await;
            assert_eq!(expected_json(state), read_line(&mut peer).await);
            peers.push(peer);
        }

        // Clear the cached line so the unchanged state is broadcast again. Every
        // open peer must receive it.
        state_socket.state.lock().await.last_state = String::default();
        state_socket.write_manager_state(state).await.unwrap();
        for peer in &mut peers {
            assert_eq!(expected_json(state), read_line(peer).await);
        }

        state_socket.shutdown().await;
    }

    #[tokio::test]
    async fn get_update() {
        let manager = Manager::new_test(vec![]);
        let state = &manager.state;

        let socket_file = temp_path().await.unwrap();
        let mut state_socket = StateSocket::default();
        state_socket.listen(socket_file.clone()).await.unwrap();
        state_socket.write_manager_state(state).await.unwrap();

        let mut peer = connect(&socket_file).await;
        assert_eq!(expected_json(state), read_line(&mut peer).await);

        // Clear the cached line so the unchanged state is broadcast again.
        state_socket.state.lock().await.last_state = String::default();
        state_socket.write_manager_state(state).await.unwrap();
        assert_eq!(expected_json(state), read_line(&mut peer).await);

        state_socket.shutdown().await;
    }

    #[tokio::test]
    async fn socket_cleanup() {
        let socket_file = temp_path().await.unwrap();
        let mut state_socket = StateSocket::default();
        state_socket.listen(socket_file.clone()).await.unwrap();
        state_socket.shutdown().await;
        assert!(!socket_file.exists());
    }

    #[tokio::test]
    async fn socket_already_bound() {
        let socket_file = temp_path().await.unwrap();
        let mut old_socket = StateSocket::default();
        old_socket.listen(socket_file.clone()).await.unwrap();
        assert!(socket_file.exists());
        let mut state_socket = StateSocket::default();
        state_socket.listen(socket_file.clone()).await.unwrap();
        state_socket.shutdown().await;
        assert!(!socket_file.exists());
        old_socket.shutdown().await;
    }
}