1use crate::{Builder, Inner, IrohSsh, api::ClientOptions};
2use std::process::Stdio;
3
4use anyhow::bail;
5use ed25519_dalek::SECRET_KEY_LENGTH;
6use homedir::my_home;
7use iroh::{
8 Endpoint, NodeId, SecretKey, Watcher,
9 endpoint::Connection,
10 protocol::{ProtocolHandler, Router},
11};
12use tokio::{
13 net::{TcpListener, TcpStream},
14 process::{Child, Command},
15};
16
17impl Builder {
18 pub fn new() -> Self {
19 Self {
20 secret_key: SecretKey::generate(rand::rngs::OsRng).to_bytes(),
21 accept_incoming: false,
22 accept_port: None,
23 }
24 }
25
26 pub fn accept_incoming(mut self, accept_incoming: bool) -> Self {
27 self.accept_incoming = accept_incoming;
28 self
29 }
30
31 pub fn accept_port(mut self, accept_port: u16) -> Self {
32 self.accept_port = Some(accept_port);
33 self
34 }
35
36 pub fn secret_key(mut self, secret_key: &[u8; SECRET_KEY_LENGTH]) -> Self {
37 self.secret_key = *secret_key;
38 self
39 }
40
41 pub fn dot_ssh_integration(mut self, persist: bool, service: bool) -> Self {
42 if let Ok(secret_key) = dot_ssh(&SecretKey::from_bytes(&self.secret_key), persist, service)
43 {
44 self.secret_key = secret_key.to_bytes();
45 }
46 self
47 }
48
49 pub async fn build(&mut self) -> anyhow::Result<IrohSsh> {
50 let secret_key = SecretKey::from_bytes(&self.secret_key);
52 let endpoint = Endpoint::builder()
53 .secret_key(secret_key)
54 .discovery_n0()
55 .bind()
56 .await?;
57
58 let _ = endpoint.home_relay().initialized().await?;
59
60 let mut iroh_ssh = IrohSsh {
61 public_key: *endpoint.node_id().as_bytes(),
62 secret_key: self.secret_key,
63 inner: None,
64 ssh_port: self.accept_port.unwrap_or(22),
65 };
66
67 let router = if self.accept_incoming {
68 Router::builder(endpoint.clone()).accept(IrohSsh::ALPN(), iroh_ssh.clone())
69 } else {
70 Router::builder(endpoint.clone())
71 }
72 .spawn();
73
74 iroh_ssh.add_inner(endpoint, router);
75
76 Ok(iroh_ssh)
77 }
78}
79
80impl Default for Builder {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl IrohSsh {
87 pub fn builder() -> Builder {
88 Builder::new()
89 }
90
91 #[allow(non_snake_case)]
92 pub fn ALPN() -> Vec<u8> {
93 b"/iroh/ssh".to_vec()
94 }
95
96 fn add_inner(&mut self, endpoint: Endpoint, router: Router) {
97 self.inner = Some(Inner { endpoint, router });
98 }
99
100 pub async fn connect(
101 &self,
102 ssh_user: &str,
103 node_id: NodeId,
104 client_options: ClientOptions,
105 execute_command: Vec<String>,
106 ) -> anyhow::Result<Child> {
107 let inner = self.inner.as_ref().expect("inner not set");
108 let conn = inner.endpoint.connect(node_id, &IrohSsh::ALPN()).await?;
109 let listener = TcpListener::bind("127.0.0.1:0").await?;
110 let port = listener.local_addr()?.port();
111
112 tokio::spawn(async move {
113 loop {
114 let _ = handle_next(&listener, &conn).await;
115 }
116
117 async fn handle_next(
118 listener: &TcpListener,
119 conn: &Connection,
120 ) -> Result<(), std::io::Error> {
121 let (mut stream, _) = listener.accept().await?;
122 let (mut iroh_send, mut iroh_recv) = conn.open_bi().await?;
123 tokio::spawn(async move {
124 let (mut local_read, mut local_write) = stream.split();
125 let a_to_b =
126 async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
127 let b_to_a =
128 async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
129
130 tokio::select! {
131 result = a_to_b => {
132 let _ = result;
133 },
134 result = b_to_a => {
135 let _ = result;
136 },
137 };
138 });
139 Ok(())
140 }
141 });
142 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
143 let mut cmd = &mut Command::new("ssh");
144 cmd = cmd
145 .arg("-tt") .arg(format!("{ssh_user}@127.0.0.1"))
147 .arg("-p")
148 .arg(port.to_string())
149 .arg("-o")
150 .arg("StrictHostKeyChecking=no")
151 .arg("-o")
152 .arg("UserKnownHostsFile=/dev/null")
153 .arg("-o")
154 .arg("LogLevel=ERROR"); if let Some(identity_file) = client_options.identity_file {
157 cmd.arg("-i").arg(identity_file);
158 }
159 if let Some(local_forward) = client_options.local_forward {
160 cmd.arg("-L").arg(local_forward);
161 }
162 if let Some(remote_forward) = client_options.remote_forward {
163 cmd.arg("-R").arg(remote_forward);
164 }
165 cmd.arg(execute_command.join(" "));
166
167 let ssh_process = cmd
168 .stdin(Stdio::inherit())
169 .stdout(Stdio::inherit())
170 .stderr(Stdio::inherit())
171 .spawn()?;
172
173 Ok(ssh_process)
174 }
175
176 pub fn node_id(&self) -> NodeId {
177 self.inner
178 .as_ref()
179 .expect("inner not set")
180 .endpoint
181 .node_id()
182 }
183}
184
185impl ProtocolHandler for IrohSsh {
186 async fn accept(&self, connection: Connection) -> Result<(), iroh::protocol::AcceptError> {
187 let alpn = connection
188 .alpn()
189 .ok_or_else(|| iroh::protocol::AcceptError::NotAllowed {})?;
190 if alpn != IrohSsh::ALPN() {
191 return Err(iroh::protocol::AcceptError::NotAllowed {});
192 }
193
194 let node_id = connection.remote_node_id()?;
195 println!("{}: {node_id} connected", String::from_utf8_lossy(&alpn));
196
197 match connection.accept_bi().await {
198 Ok((mut iroh_send, mut iroh_recv)) => {
199 println!("Accepted bidirectional stream from {node_id}");
200
201 match TcpStream::connect(format!("127.0.0.1:{}", self.ssh_port)).await {
202 Ok(mut ssh_stream) => {
203 println!("Connected to local SSH server on port {}", self.ssh_port);
204
205 let (mut local_read, mut local_write) = ssh_stream.split();
206
207 let a_to_b =
208 async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
209 let b_to_a =
210 async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
211
212 tokio::select! {
213 result = a_to_b => {
214 println!("SSH->Iroh stream ended: {result:?}");
215 },
216 result = b_to_a => {
217 println!("Iroh->SSH stream ended: {result:?}");
218 },
219 };
220 }
221 Err(e) => {
222 println!("Failed to connect to SSH server: {e}");
223 }
224 }
225 }
226 Err(e) => {
227 println!("Failed to accept bidirectional stream: {e}");
228 }
229 }
230
231 Ok(())
232 }
233}
234
235pub fn dot_ssh(
236 default_secret_key: &SecretKey,
237 persist: bool,
238 _service: bool,
239) -> anyhow::Result<SecretKey> {
240 let distro_home = my_home()?.ok_or_else(|| anyhow::anyhow!("home directory not found"))?;
241 #[allow(unused_mut)]
242 let mut ssh_dir = distro_home.join(".ssh");
243
244 #[cfg(target_os = "linux")]
247 if _service {
248 ssh_dir = std::path::PathBuf::from("/root/.ssh");
249 }
250
251 #[cfg(target_os = "windows")]
254 if _service {
255 ssh_dir = std::path::PathBuf::from(r#"C:\WINDOWS\system32\config\systemprofile\.ssh"#);
256 }
257
258 let pub_key = ssh_dir.join("irohssh_ed25519.pub");
259 let priv_key = ssh_dir.join("irohssh_ed25519");
260
261 match (ssh_dir.exists(), persist) {
262 (false, false) => {
263 bail!(
264 "no .ssh folder found in {}, use --persist flag to create it",
265 distro_home.display()
266 )
267 }
268 (false, true) => {
269 std::fs::create_dir_all(&ssh_dir)?;
270 println!("[INFO] created .ssh folder: {}", ssh_dir.display());
271 dot_ssh(default_secret_key, persist, _service)
272 }
273 (true, true) => {
274 if pub_key.exists() && priv_key.exists() {
276 if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
278 let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
279 sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
280 Ok(SecretKey::from_bytes(&sk_bytes))
281 } else {
282 bail!("failed to read secret key from {}", priv_key.display())
283 }
284 } else {
285 let key = default_secret_key.clone();
286 let secret_key = key.secret();
287 let public_key = key.public();
288
289 std::fs::write(pub_key, z32::encode(public_key.as_bytes()))?;
290 std::fs::write(priv_key, z32::encode(secret_key.as_bytes()))?;
291 Ok(key)
292 }
293 }
294 (true, false) => {
295 if pub_key.exists() && priv_key.exists() {
297 if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
299 let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
300 sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
301 return Ok(SecretKey::from_bytes(&sk_bytes));
302 }
303 }
304 bail!(
305 "no iroh-ssh keys found in {}, use --persist flag to create it",
306 ssh_dir.display()
307 )
308 }
309 }
310}