partition_sim/
supervisor.rs

1use openssh::Session;
2use std::collections::HashMap;
3use std::env::var;
4use std::process::Output;
5use tokio::sync::mpsc::Receiver;
6use uuid::Uuid;
7
8use crate::commands::{Command, Commands, SshCommands};
9use crate::peer::Peer;
10
11#[derive(Debug, Default)]
12pub struct Supervisor {
13    peers: HashMap<Uuid, Peer>,
14    peer_ids: Vec<Uuid>,
15    path_to_key: String,
16}
17
18pub type Message = (Uuid, Commands);
19pub type Request<I, O> = (I, tokio::sync::oneshot::Sender<O>);
20
21impl Supervisor {
22    pub fn get_peer_ids(&self) -> &[Uuid] {
23        &self.peer_ids
24    }
25    pub fn get_peer(&self, peer_id: Uuid) -> crate::Result<&Peer> {
26        self.peers
27            .get(&peer_id)
28            .ok_or(crate::Error::PeerNotFound(peer_id))
29    }
30    pub fn get_peer_mut(&mut self, peer_id: Uuid) -> crate::Result<&mut Peer> {
31        self.peers
32            .get_mut(&peer_id)
33            .ok_or(crate::Error::PeerNotFound(peer_id))
34    }
35
36    fn copy_id(&self, peer_id: Uuid) -> crate::Result<()> {
37        let peer = self.peers.get(&peer_id).unwrap();
38        let mut command = SshCommands::CopyId {
39            ip_addr: peer.ip_addr,
40            path_to_key: self.path_to_key.clone(),
41        }
42        .build();
43
44        let output = command.output()?;
45        tracing::info!(
46            "sshpass stdout: {}",
47            String::from_utf8(output.stdout).unwrap()
48        );
49
50        if !output.status.success() {
51            tracing::error!(
52                "sshpass stderr: {}",
53                String::from_utf8(output.stderr).unwrap()
54            );
55            return Err(crate::Error::SshCopyIdFailed);
56        }
57        Ok(())
58    }
59
60    pub fn set_up_ssh(&self) -> crate::Result<()> {
61        for peer in self.peers.values() {
62            self.copy_id(peer.id)?;
63        }
64        Ok(())
65    }
66
67    pub fn new(peers: Vec<Peer>) -> Self {
68        let peer_ids = peers.iter().map(|peer| peer.id).collect::<Vec<_>>();
69
70        let hmap = peers
71            .into_iter()
72            .map(|peer| (peer.id, peer))
73            .collect::<HashMap<_, _>>();
74
75        let home = var("HOME").unwrap_or_else(|_| "/root".into());
76        let path_to_key = format!("{}/.ssh/id_ed25519.pub", home);
77
78        Self {
79            peers: hmap,
80            peer_ids,
81            path_to_key,
82        }
83    }
84
85    pub fn with_key(mut self, path_to_key: &str) -> Self {
86        self.path_to_key = path_to_key.into();
87        self
88    }
89
90    pub async fn connect(&mut self, peer_id: Uuid) -> crate::Result<()> {
91        if let Some(peer) = self.peers.get_mut(&peer_id) {
92            peer.connect().await?;
93            Ok(())
94        } else {
95            Err(crate::errors::PartitionSimError::PeerNotFound(peer_id))
96        }
97    }
98
99    pub async fn connect_all(&mut self) -> crate::Result<()> {
100        for peer_id in self.peer_ids.clone().iter() {
101            self.connect(*peer_id).await?;
102        }
103        Ok(())
104    }
105
106    fn get_session(&self, peer_id: Uuid) -> crate::Result<&Session> {
107        self.peers
108            .get(&peer_id)
109            .and_then(|peer| peer.session.as_ref())
110            .ok_or(crate::errors::PartitionSimError::SessionUninitialized)
111    }
112
113    pub async fn execute(
114        &mut self,
115        peer_id: Uuid,
116        command: impl Into<Commands>,
117    ) -> crate::Result<Output> {
118        self.connect(peer_id).await?;
119        let session = self.get_session(peer_id)?;
120        Ok(command.into().build(session).output().await?)
121    }
122
123    pub async fn run(
124        mut self,
125        mut commands_rx: Receiver<Request<Message, Output>>,
126    ) -> crate::Result<()> {
127        while let Some((msg, result_tx)) = commands_rx.recv().await {
128            let (peer_id, command) = msg;
129            self.connect(peer_id).await?;
130            let session = self.get_session(peer_id)?;
131            let output = command.build(session).output().await?;
132            result_tx.send(output)?;
133        }
134        Ok(())
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    use std::env;
    use tokio::sync::mpsc::channel;

    /// End-to-end smoke test against live hardware.
    ///
    /// Requires a host reachable at 192.168.1.137 as user `pi`. The
    /// `SSH_KEYFILE` environment variable must be set; its value is passed to
    /// `Peer::new` as the key file.
    #[tokio::test]
    async fn test_supervisor_new() {
        let keyfile = env::var("SSH_KEYFILE").expect("SSH_KEYFILE must be set for this test");
        let peer1 = Peer::new(
            "192.168.1.137".parse().unwrap(),
            Some("pi"),
            Some(keyfile.as_str()),
        );
        let peer2 = Peer::new(
            "192.168.1.137".parse().unwrap(),
            Some("pi"),
            Some(keyfile.as_str()),
        );

        let peers = vec![peer1, peer2];
        let mut supervisor =
            Supervisor::new(peers).with_key("/home/infinity/.ssh/id_ed25519.pub");
        supervisor.connect_all().await.unwrap();

        let (tx, rx) = channel(10);

        let (request_tx, response_rx) = tokio::sync::oneshot::channel();

        let peer_ids = supervisor.get_peer_ids().to_vec();

        let t1 = tokio::spawn(async move {
            let first_peer = *peer_ids.first().unwrap();
            tx.send((
                (
                    first_peer,
                    Commands::IpTables(crate::commands::IpTablesCommands::Get),
                ),
                request_tx,
            ))
            .await
            .unwrap();
            // `tx` is dropped when this task ends, closing the channel so that
            // `supervisor.run` returns once the request has been served.
        });
        let t2 = tokio::spawn(async move {
            let resp = response_rx.await.unwrap();
            println!("Response: {:?}", resp);
        });
        let t3 = supervisor.run(rx);

        // Check every result. Discarding them would let the test pass even if
        // a spawned task panicked or `run` returned an error.
        let (sent, received, served) = tokio::join!(t1, t2, t3);
        sent.unwrap();
        received.unwrap();
        served.unwrap();
    }
}