leftwm_core/utils/state_socket.rs

1use crate::errors::{LeftError, Result};
2use crate::models::dto::ManagerState;
3use crate::models::Handle;
4use std::path::PathBuf;
5use std::sync::Arc;
6use tokio::fs;
7use tokio::io::AsyncWriteExt;
8use tokio::net::{UnixListener, UnixStream};
9use tokio::sync::Mutex;
10
11#[derive(Debug, Default)]
12struct State {
13    peers: Vec<Option<UnixStream>>,
14    last_state: String,
15}
16
17#[derive(Debug, Default)]
18pub struct StateSocket {
19    state: Arc<Mutex<State>>,
20    listener: Option<tokio::task::JoinHandle<()>>,
21    socket_file: PathBuf,
22}
23
24impl Drop for StateSocket {
25    fn drop(&mut self) {
26        assert!(
27            std::thread::panicking() || self.listener.is_none(),
28            "StateSocket has to be shutdown explicitly before drop"
29        );
30    }
31}
32
33impl StateSocket {
34    /// Bind to Unix socket and listen.
35    /// # Errors
36    ///
37    /// Will error if `build_listener()` cannot be unwrapped or awaited.
38    /// As in `build_listener()`, this is likely a filesystem issue,
39    /// such as incorrect permissions or a non-existant file.
40    pub async fn listen(&mut self, socket_file: PathBuf) -> Result<()> {
41        self.socket_file = socket_file;
42        let listener = self.build_listener().await?;
43        self.listener = Some(listener);
44        Ok(())
45    }
46
47    /// Explicitly shutdown `StateSocket` to perform cleanup.
48    pub async fn shutdown(&mut self) {
49        if let Some(listener) = self.listener.take() {
50            listener.abort();
51            listener.await.ok();
52            fs::remove_file(self.socket_file.as_path()).await.ok();
53        }
54    }
55
56    /// # Errors
57    /// Will return Err if a mut ref to the peer is unavailable.
58    /// Will return error if state cannot be serialized
59    pub async fn write_manager_state<H: Handle>(
60        &mut self,
61        raw_state: &crate::state::State<H>,
62    ) -> Result<()> {
63        if self.listener.is_some() {
64            let state: ManagerState = raw_state.into();
65            let mut json = serde_json::to_string(&state)?;
66            json.push('\n');
67            let mut state = self.state.lock().await;
68
69            let state_changed = json != state.last_state;
70            if state_changed {
71                state.peers.retain(std::option::Option::is_some);
72                for peer in &mut state.peers {
73                    if peer
74                        .as_mut()
75                        .ok_or(LeftError::StreamError)?
76                        .write_all(json.as_bytes())
77                        .await
78                        .is_err()
79                    {
80                        peer.take();
81                    }
82                }
83                state.last_state = json;
84            }
85        }
86        Ok(())
87    }
88
89    async fn build_listener(&self) -> Result<tokio::task::JoinHandle<()>> {
90        let state = self.state.clone();
91        let listener = if let Ok(m) = UnixListener::bind(&self.socket_file) {
92            m
93        } else {
94            fs::remove_file(&self.socket_file).await?;
95            UnixListener::bind(&self.socket_file)?
96        };
97
98        Ok(tokio::spawn(async move {
99            loop {
100                match listener.accept().await {
101                    Ok((mut peer, _)) => {
102                        let mut state = state.lock().await;
103                        if peer.write_all(state.last_state.as_bytes()).await.is_ok() {
104                            state.peers.push(Some(peer));
105                        }
106                    }
107                    Err(e) => tracing::error!("Accept failed = {:?}", e),
108                }
109            }
110        }))
111    }
112}
113
#[cfg(test)]
mod test {
    use super::*;
    use crate::utils::helpers::test::temp_path;
    use crate::Manager;
    use tokio::io::{AsyncBufReadExt, BufReader};

    /// The JSON line a peer should receive for `state`, without the trailing newline.
    fn expected_json<H: Handle>(state: &crate::state::State<H>) -> String {
        serde_json::to_string(&Into::<ManagerState>::into(state)).unwrap()
    }

    /// Connect a fresh client to `socket_file` and return the first line it reads.
    async fn first_line(socket_file: impl AsRef<std::path::Path>) -> String {
        let stream = UnixStream::connect(socket_file).await.unwrap();
        BufReader::new(stream)
            .lines()
            .next_line()
            .await
            .expect("Read next line")
            .unwrap()
    }

    #[tokio::test]
    async fn multiple_peers() {
        let manager = Manager::new_test(vec![]);
        let state = &manager.state;

        let socket_file = temp_path().await.unwrap();
        let mut state_socket = StateSocket::default();
        state_socket.listen(socket_file.clone()).await.unwrap();
        state_socket.write_manager_state(state).await.unwrap();

        // Every newly connected peer must immediately receive the cached state.
        for _ in 0..3 {
            assert_eq!(expected_json(state), first_line(&socket_file).await);
        }

        state_socket.shutdown().await;
    }

    #[tokio::test]
    async fn get_update() {
        let manager = Manager::new_test(vec![]);
        let state = &manager.state;

        let socket_file = temp_path().await.unwrap();
        let mut state_socket = StateSocket::default();
        state_socket.listen(socket_file.clone()).await.unwrap();
        state_socket.write_manager_state(state).await.unwrap();

        let stream = UnixStream::connect(socket_file).await.unwrap();
        let mut lines = BufReader::new(stream).lines();
        let expected = expected_json(state);

        assert_eq!(
            expected,
            lines.next_line().await.expect("Read next line").unwrap()
        );

        // Clear the cached state so the next write counts as a change.
        state_socket.state.lock().await.last_state = String::default();
        state_socket.write_manager_state(state).await.unwrap();

        assert_eq!(
            expected,
            lines.next_line().await.expect("Read next line").unwrap()
        );

        state_socket.shutdown().await;
    }

    #[tokio::test]
    async fn socket_cleanup() {
        let path = temp_path().await.unwrap();
        let mut socket = StateSocket::default();
        socket.listen(path.clone()).await.unwrap();
        socket.shutdown().await;
        assert!(!path.exists(), "shutdown should remove the socket file");
    }

    #[tokio::test]
    async fn socket_already_bound() {
        let path = temp_path().await.unwrap();

        let mut first = StateSocket::default();
        first.listen(path.clone()).await.unwrap();
        assert!(path.exists());

        // A second socket on the same path takes over the file.
        let mut second = StateSocket::default();
        second.listen(path.clone()).await.unwrap();
        second.shutdown().await;
        assert!(!path.exists());

        first.shutdown().await;
    }
}
214}