partition_sim/
supervisor.rs1use openssh::Session;
2use std::collections::HashMap;
3use std::env::var;
4use std::process::Output;
5use tokio::sync::mpsc::Receiver;
6use uuid::Uuid;
7
8use crate::commands::{Command, Commands, SshCommands};
9use crate::peer::Peer;
10
11#[derive(Debug, Default)]
12pub struct Supervisor {
13 peers: HashMap<Uuid, Peer>,
14 peer_ids: Vec<Uuid>,
15 path_to_key: String,
16}
17
18pub type Message = (Uuid, Commands);
19pub type Request<I, O> = (I, tokio::sync::oneshot::Sender<O>);
20
21impl Supervisor {
22 pub fn get_peer_ids(&self) -> &[Uuid] {
23 &self.peer_ids
24 }
25 pub fn get_peer(&self, peer_id: Uuid) -> crate::Result<&Peer> {
26 self.peers
27 .get(&peer_id)
28 .ok_or(crate::Error::PeerNotFound(peer_id))
29 }
30 pub fn get_peer_mut(&mut self, peer_id: Uuid) -> crate::Result<&mut Peer> {
31 self.peers
32 .get_mut(&peer_id)
33 .ok_or(crate::Error::PeerNotFound(peer_id))
34 }
35
36 fn copy_id(&self, peer_id: Uuid) -> crate::Result<()> {
37 let peer = self.peers.get(&peer_id).unwrap();
38 let mut command = SshCommands::CopyId {
39 ip_addr: peer.ip_addr,
40 path_to_key: self.path_to_key.clone(),
41 }
42 .build();
43
44 let output = command.output()?;
45 tracing::info!(
46 "sshpass stdout: {}",
47 String::from_utf8(output.stdout).unwrap()
48 );
49
50 if !output.status.success() {
51 tracing::error!(
52 "sshpass stderr: {}",
53 String::from_utf8(output.stderr).unwrap()
54 );
55 return Err(crate::Error::SshCopyIdFailed);
56 }
57 Ok(())
58 }
59
60 pub fn set_up_ssh(&self) -> crate::Result<()> {
61 for peer in self.peers.values() {
62 self.copy_id(peer.id)?;
63 }
64 Ok(())
65 }
66
67 pub fn new(peers: Vec<Peer>) -> Self {
68 let peer_ids = peers.iter().map(|peer| peer.id).collect::<Vec<_>>();
69
70 let hmap = peers
71 .into_iter()
72 .map(|peer| (peer.id, peer))
73 .collect::<HashMap<_, _>>();
74
75 let home = var("HOME").unwrap_or_else(|_| "/root".into());
76 let path_to_key = format!("{}/.ssh/id_ed25519.pub", home);
77
78 Self {
79 peers: hmap,
80 peer_ids,
81 path_to_key,
82 }
83 }
84
85 pub fn with_key(mut self, path_to_key: &str) -> Self {
86 self.path_to_key = path_to_key.into();
87 self
88 }
89
90 pub async fn connect(&mut self, peer_id: Uuid) -> crate::Result<()> {
91 if let Some(peer) = self.peers.get_mut(&peer_id) {
92 peer.connect().await?;
93 Ok(())
94 } else {
95 Err(crate::errors::PartitionSimError::PeerNotFound(peer_id))
96 }
97 }
98
99 pub async fn connect_all(&mut self) -> crate::Result<()> {
100 for peer_id in self.peer_ids.clone().iter() {
101 self.connect(*peer_id).await?;
102 }
103 Ok(())
104 }
105
106 fn get_session(&self, peer_id: Uuid) -> crate::Result<&Session> {
107 self.peers
108 .get(&peer_id)
109 .and_then(|peer| peer.session.as_ref())
110 .ok_or(crate::errors::PartitionSimError::SessionUninitialized)
111 }
112
113 pub async fn execute(
114 &mut self,
115 peer_id: Uuid,
116 command: impl Into<Commands>,
117 ) -> crate::Result<Output> {
118 self.connect(peer_id).await?;
119 let session = self.get_session(peer_id)?;
120 Ok(command.into().build(session).output().await?)
121 }
122
123 pub async fn run(
124 mut self,
125 mut commands_rx: Receiver<Request<Message, Output>>,
126 ) -> crate::Result<()> {
127 while let Some((msg, result_tx)) = commands_rx.recv().await {
128 let (peer_id, command) = msg;
129 self.connect(peer_id).await?;
130 let session = self.get_session(peer_id)?;
131 let output = command.build(session).output().await?;
132 result_tx.send(output)?;
133 }
134 Ok(())
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    use std::env;
    use tokio::sync::mpsc::channel;

    /// End-to-end smoke test against real hardware.
    ///
    /// Needs SSH access to `pi@192.168.1.137` and the `SSH_KEYFILE` environment
    /// variable, so it is ignored by default; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires SSH access to 192.168.1.137 and SSH_KEYFILE to be set"]
    async fn test_supervisor_new() {
        let keyfile = env::var("SSH_KEYFILE").expect("SSH_KEYFILE must be set");
        let peer1 = Peer::new(
            "192.168.1.137".parse().unwrap(),
            Some("pi"),
            Some(keyfile.as_str()),
        );
        let peer2 = Peer::new(
            "192.168.1.137".parse().unwrap(),
            Some("pi"),
            Some(keyfile.as_str()),
        );

        let peers = vec![peer1, peer2];
        let mut supervisor =
            Supervisor::new(peers).with_key("/home/infinity/.ssh/id_ed25519.pub");
        supervisor.connect_all().await.unwrap();

        let (tx, rx) = channel(10);
        // Sender half goes out with the request; the supervisor replies on it.
        let (response_tx, response_rx) = tokio::sync::oneshot::channel();

        let first_peer = supervisor.get_peer_ids()[0];

        let sender = tokio::spawn(async move {
            tx.send((
                (
                    first_peer,
                    Commands::IpTables(crate::commands::IpTablesCommands::Get),
                ),
                response_tx,
            ))
            .await
            .unwrap();
            // `tx` is dropped here, which lets `Supervisor::run` finish.
        });
        let receiver = tokio::spawn(async move {
            let resp = response_rx.await.unwrap();
            println!("Response: {:?}", resp);
        });

        let (sent, received, ran) = tokio::join!(sender, receiver, supervisor.run(rx));
        // Check the supervisor first: if it failed, the receiver panics only as
        // a consequence (its reply channel was dropped).
        ran.expect("supervisor run failed");
        sent.expect("sender task panicked");
        received.expect("receiver task panicked");
    }
}