Skip to main content

prt_core/core/
ssh_tunnel.rs

1//! SSH tunnel specification — a serializable description of one tunnel.
2//!
3//! This module is process-agnostic: it only describes *what* tunnel to spawn.
4//! Actual `ssh` subprocess handling lives in the binary crate (`prt::forward`).
5
6use serde::{Deserialize, Serialize};
7
8/// Type of SSH tunnel.
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum TunnelKind {
12    /// `ssh -L LOCAL:remote_host:REMOTE host` — bring a remote service to localhost.
13    Local,
14    /// `ssh -D LOCAL host` — SOCKS5 proxy on localhost.
15    Dynamic,
16}
17
18impl TunnelKind {
19    pub fn label(self) -> &'static str {
20        match self {
21            Self::Local => "local",
22            Self::Dynamic => "dynamic",
23        }
24    }
25}
26
27/// Description of one SSH tunnel.
28#[derive(Debug, Clone, PartialEq, Eq)]
29pub struct SshTunnelSpec {
30    /// Optional human-friendly name.
31    pub name: Option<String>,
32    pub kind: TunnelKind,
33    /// Local port that ssh will bind on `localhost`.
34    pub local_port: u16,
35    /// Remote target host (only used for `Local`). For `Dynamic`, ignored.
36    pub remote_host: Option<String>,
37    /// Remote target port (only used for `Local`).
38    pub remote_port: Option<u16>,
39    /// SSH host argument (alias from `~/.ssh/config` or `user@host`).
40    pub host_alias: String,
41}
42
/// Resolved SSH connection settings used to expand a spec into concrete
/// command-line flags.
///
/// Lets `[[ssh_hosts]]` aliases that don't appear in `~/.ssh/config` still
/// resolve at the OS level. Every field is optional; `None` means "don't pass
/// a flag, leave it to ssh / `~/.ssh/config`". `Default` leaves every field
/// unset.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ResolvedHost<'a> {
    /// Real host name or address, used as the positional ssh target instead
    /// of the spec's `host_alias`.
    pub hostname: Option<&'a str>,
    /// Login user, passed as `-l`.
    pub user: Option<&'a str>,
    /// SSH server port, passed as `-p`.
    pub port: Option<u16>,
    /// Private key path, passed as `-i`.
    pub identity_file: Option<&'a str>,
}
53
54impl SshTunnelSpec {
55    /// `-N -L LOCAL:host:PORT` / `-N -D LOCAL` — without the trailing host arg.
56    fn forward_args(&self) -> Vec<String> {
57        match self.kind {
58            TunnelKind::Local => {
59                let host = self.remote_host.as_deref().unwrap_or("localhost");
60                let port = self.remote_port.unwrap_or(0);
61                vec![
62                    "-N".into(),
63                    "-L".into(),
64                    format!("{}:{}:{}", self.local_port, host, port),
65                ]
66            }
67            TunnelKind::Dynamic => {
68                vec!["-N".into(), "-D".into(), self.local_port.to_string()]
69            }
70        }
71    }
72
73    /// Build the argument list passed to `ssh`.
74    /// Always includes `-N` (no remote command). Uses only `host_alias`
75    /// — relies on `~/.ssh/config` (or DNS) to resolve it.
76    pub fn ssh_args(&self) -> Vec<String> {
77        let mut args = self.forward_args();
78        args.push(self.host_alias.clone());
79        args
80    }
81
82    /// Like [`ssh_args`] but injects `-l user`, `-p port`, `-i identity_file`
83    /// from a resolved host, and uses `hostname` (when provided) as the
84    /// positional target so prt-config-only aliases resolve correctly.
85    pub fn ssh_args_with(&self, host: &ResolvedHost<'_>) -> Vec<String> {
86        let mut args = self.forward_args();
87        if let Some(u) = host.user {
88            args.push("-l".into());
89            args.push(u.into());
90        }
91        if let Some(p) = host.port {
92            args.push("-p".into());
93            args.push(p.to_string());
94        }
95        if let Some(id) = host.identity_file {
96            args.push("-i".into());
97            args.push(id.into());
98        }
99        let target = host.hostname.unwrap_or(self.host_alias.as_str());
100        args.push(target.into());
101        args
102    }
103
104    /// Human-readable one-line summary.
105    pub fn summary(&self) -> String {
106        match self.kind {
107            TunnelKind::Local => {
108                let host = self.remote_host.as_deref().unwrap_or("?");
109                let port = self
110                    .remote_port
111                    .map(|p| p.to_string())
112                    .unwrap_or_else(|| "?".into());
113                format!(
114                    "L localhost:{} \u{2192} {}:{}:{}",
115                    self.local_port, self.host_alias, host, port
116                )
117            }
118            TunnelKind::Dynamic => format!(
119                "D socks5://localhost:{} \u{2192} {}",
120                self.local_port, self.host_alias
121            ),
122        }
123    }
124
125    /// Validate that the spec is internally consistent.
126    pub fn validate(&self) -> Result<(), String> {
127        if self.host_alias.trim().is_empty() {
128            return Err("host_alias is empty".into());
129        }
130        if self.local_port == 0 {
131            return Err("local_port must be > 0".into());
132        }
133        if self.kind == TunnelKind::Local {
134            if self
135                .remote_host
136                .as_deref()
137                .map(str::is_empty)
138                .unwrap_or(true)
139            {
140                return Err("remote_host required for Local tunnel".into());
141            }
142            match self.remote_port {
143                Some(p) if p > 0 => {}
144                _ => return Err("remote_port required for Local tunnel".into()),
145            }
146        }
147        Ok(())
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned copy of a list of string literals, for comparing against the
    /// `Vec<String>` returned by the argument builders.
    fn owned(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| part.to_string()).collect()
    }

    /// `-L` fixture: localhost:5433 -> 127.0.0.1:5432 through `prod`.
    fn local_spec() -> SshTunnelSpec {
        SshTunnelSpec {
            name: Some("pg".into()),
            kind: TunnelKind::Local,
            local_port: 5433,
            remote_host: Some("127.0.0.1".into()),
            remote_port: Some(5432),
            host_alias: "prod".into(),
        }
    }

    /// `-D` fixture: SOCKS on localhost:1080 through `prod`, no remote target.
    fn dynamic_spec() -> SshTunnelSpec {
        SshTunnelSpec {
            name: None,
            kind: TunnelKind::Dynamic,
            local_port: 1080,
            remote_host: None,
            remote_port: None,
            host_alias: "prod".into(),
        }
    }

    #[test]
    fn local_args() {
        let expected = owned(&["-N", "-L", "5433:127.0.0.1:5432", "prod"]);
        assert_eq!(local_spec().ssh_args(), expected);
    }

    #[test]
    fn dynamic_args() {
        let expected = owned(&["-N", "-D", "1080", "prod"]);
        assert_eq!(dynamic_spec().ssh_args(), expected);
    }

    #[test]
    fn summary_local_contains_endpoints() {
        let summary = local_spec().summary();
        for &needle in &["5433", "prod", "127.0.0.1", "5432"] {
            assert!(summary.contains(needle), "{:?} missing from {:?}", needle, summary);
        }
    }

    #[test]
    fn summary_dynamic_mentions_socks() {
        let summary = dynamic_spec().summary();
        assert!(summary.contains("1080"));
        assert!(summary.to_lowercase().contains("socks"));
        assert!(summary.contains("prod"));
    }

    #[test]
    fn validate_local_ok_and_errors() {
        assert!(local_spec().validate().is_ok());

        // Each mutation on its own must turn a valid spec into an invalid one.
        let mutations: [fn(&mut SshTunnelSpec); 4] = [
            |s| s.host_alias = String::new(),
            |s| s.local_port = 0,
            |s| s.remote_host = None,
            |s| s.remote_port = None,
        ];
        for (idx, mutate) in mutations.iter().enumerate() {
            let mut spec = local_spec();
            mutate(&mut spec);
            assert!(spec.validate().is_err(), "mutation #{} was accepted", idx);
        }
    }

    #[test]
    fn validate_dynamic_ok_with_no_remote() {
        assert!(dynamic_spec().validate().is_ok());
    }

    #[test]
    fn ssh_args_with_resolved_host_local() {
        let resolved = ResolvedHost {
            hostname: Some("real.example.com"),
            user: Some("deploy"),
            port: Some(2222),
            identity_file: Some("/home/u/.ssh/id"),
        };
        let expected = owned(&[
            "-N",
            "-L",
            "5433:127.0.0.1:5432",
            "-l",
            "deploy",
            "-p",
            "2222",
            "-i",
            "/home/u/.ssh/id",
            "real.example.com",
        ]);
        assert_eq!(local_spec().ssh_args_with(&resolved), expected);
    }

    #[test]
    fn ssh_args_with_empty_host_falls_back_to_alias() {
        let args = local_spec().ssh_args_with(&ResolvedHost::default());
        // No -l/-p/-i and the alias is the positional target.
        assert_eq!(args.last().map(String::as_str), Some("prod"));
        assert!(args.iter().all(|a| a != "-l" && a != "-p"));
    }

    #[test]
    fn kind_serde_lowercase() {
        let cases = [("\"local\"", TunnelKind::Local), ("\"dynamic\"", TunnelKind::Dynamic)];
        for &(json, kind) in &cases {
            let parsed: TunnelKind = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, kind);
        }
    }
}