prt_core/core/
ssh_tunnel.rs1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum TunnelKind {
12 Local,
14 Dynamic,
16}
17
18impl TunnelKind {
19 pub fn label(self) -> &'static str {
20 match self {
21 Self::Local => "local",
22 Self::Dynamic => "dynamic",
23 }
24 }
25}
26
27#[derive(Debug, Clone, PartialEq, Eq)]
29pub struct SshTunnelSpec {
30 pub name: Option<String>,
32 pub kind: TunnelKind,
33 pub local_port: u16,
35 pub remote_host: Option<String>,
37 pub remote_port: Option<u16>,
39 pub host_alias: String,
41}
42
/// Connection parameters already resolved for a host, borrowed from the caller.
///
/// Every field is optional; a `None` field adds no flag to the argument list
/// built by `SshTunnelSpec::ssh_args_with`.
#[derive(Debug, Clone, Default)]
pub struct ResolvedHost<'a> {
    /// Address to connect to; when `None` the spec's `host_alias` is used.
    pub hostname: Option<&'a str>,
    /// Login user, emitted as `-l <user>`.
    pub user: Option<&'a str>,
    /// SSH port, emitted as `-p <port>`.
    pub port: Option<u16>,
    /// Private key path, emitted as `-i <path>`.
    pub identity_file: Option<&'a str>,
}
53
54impl SshTunnelSpec {
55 fn forward_args(&self) -> Vec<String> {
57 match self.kind {
58 TunnelKind::Local => {
59 let host = self.remote_host.as_deref().unwrap_or("localhost");
60 let port = self.remote_port.unwrap_or(0);
61 vec![
62 "-N".into(),
63 "-L".into(),
64 format!("{}:{}:{}", self.local_port, host, port),
65 ]
66 }
67 TunnelKind::Dynamic => {
68 vec!["-N".into(), "-D".into(), self.local_port.to_string()]
69 }
70 }
71 }
72
73 pub fn ssh_args(&self) -> Vec<String> {
77 let mut args = self.forward_args();
78 args.push(self.host_alias.clone());
79 args
80 }
81
82 pub fn ssh_args_with(&self, host: &ResolvedHost<'_>) -> Vec<String> {
86 let mut args = self.forward_args();
87 if let Some(u) = host.user {
88 args.push("-l".into());
89 args.push(u.into());
90 }
91 if let Some(p) = host.port {
92 args.push("-p".into());
93 args.push(p.to_string());
94 }
95 if let Some(id) = host.identity_file {
96 args.push("-i".into());
97 args.push(id.into());
98 }
99 let target = host.hostname.unwrap_or(self.host_alias.as_str());
100 args.push(target.into());
101 args
102 }
103
104 pub fn summary(&self) -> String {
106 match self.kind {
107 TunnelKind::Local => {
108 let host = self.remote_host.as_deref().unwrap_or("?");
109 let port = self
110 .remote_port
111 .map(|p| p.to_string())
112 .unwrap_or_else(|| "?".into());
113 format!(
114 "L localhost:{} \u{2192} {}:{}:{}",
115 self.local_port, self.host_alias, host, port
116 )
117 }
118 TunnelKind::Dynamic => format!(
119 "D socks5://localhost:{} \u{2192} {}",
120 self.local_port, self.host_alias
121 ),
122 }
123 }
124
125 pub fn validate(&self) -> Result<(), String> {
127 if self.host_alias.trim().is_empty() {
128 return Err("host_alias is empty".into());
129 }
130 if self.local_port == 0 {
131 return Err("local_port must be > 0".into());
132 }
133 if self.kind == TunnelKind::Local {
134 if self
135 .remote_host
136 .as_deref()
137 .map(str::is_empty)
138 .unwrap_or(true)
139 {
140 return Err("remote_host required for Local tunnel".into());
141 }
142 match self.remote_port {
143 Some(p) if p > 0 => {}
144 _ => return Err("remote_port required for Local tunnel".into()),
145 }
146 }
147 Ok(())
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned argument vector the builders return.
    fn owned(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|p| p.to_string()).collect()
    }

    /// Local forward: localhost:5433 -> 127.0.0.1:5432 through `prod`.
    fn local_spec() -> SshTunnelSpec {
        SshTunnelSpec {
            name: Some("pg".into()),
            kind: TunnelKind::Local,
            local_port: 5433,
            remote_host: Some("127.0.0.1".into()),
            remote_port: Some(5432),
            host_alias: "prod".into(),
        }
    }

    /// Dynamic SOCKS forward on local port 1080 through `prod`.
    fn dynamic_spec() -> SshTunnelSpec {
        SshTunnelSpec {
            name: None,
            kind: TunnelKind::Dynamic,
            local_port: 1080,
            remote_host: None,
            remote_port: None,
            host_alias: "prod".into(),
        }
    }

    #[test]
    fn local_args() {
        let expected = owned(&["-N", "-L", "5433:127.0.0.1:5432", "prod"]);
        assert_eq!(local_spec().ssh_args(), expected);
    }

    #[test]
    fn dynamic_args() {
        let expected = owned(&["-N", "-D", "1080", "prod"]);
        assert_eq!(dynamic_spec().ssh_args(), expected);
    }

    #[test]
    fn summary_local_contains_endpoints() {
        let summary = local_spec().summary();
        for needle in ["5433", "prod", "127.0.0.1", "5432"] {
            assert!(
                summary.contains(needle),
                "missing {:?} in {:?}",
                needle,
                summary
            );
        }
    }

    #[test]
    fn summary_dynamic_mentions_socks() {
        let summary = dynamic_spec().summary();
        assert!(summary.contains("1080"));
        assert!(summary.to_lowercase().contains("socks"));
        assert!(summary.contains("prod"));
    }

    #[test]
    fn validate_local_ok_and_errors() {
        assert!(local_spec().validate().is_ok());

        // Each mutation on its own must turn a valid Local spec into an invalid one.
        let breakers: [fn(&mut SshTunnelSpec); 4] = [
            |s: &mut SshTunnelSpec| s.host_alias = "".into(),
            |s: &mut SshTunnelSpec| s.local_port = 0,
            |s: &mut SshTunnelSpec| s.remote_host = None,
            |s: &mut SshTunnelSpec| s.remote_port = None,
        ];
        for break_spec in breakers.iter() {
            let mut bad = local_spec();
            break_spec(&mut bad);
            assert!(bad.validate().is_err());
        }
    }

    #[test]
    fn validate_dynamic_ok_with_no_remote() {
        assert!(dynamic_spec().validate().is_ok());
    }

    #[test]
    fn ssh_args_with_resolved_host_local() {
        let host = ResolvedHost {
            hostname: Some("real.example.com"),
            user: Some("deploy"),
            port: Some(2222),
            identity_file: Some("/home/u/.ssh/id"),
        };
        let expected = owned(&[
            "-N",
            "-L",
            "5433:127.0.0.1:5432",
            "-l",
            "deploy",
            "-p",
            "2222",
            "-i",
            "/home/u/.ssh/id",
            "real.example.com",
        ]);
        assert_eq!(local_spec().ssh_args_with(&host), expected);
    }

    #[test]
    fn ssh_args_with_empty_host_falls_back_to_alias() {
        let args = local_spec().ssh_args_with(&ResolvedHost::default());
        // Nothing resolved: the alias is the destination and no -l / -p flags appear.
        assert_eq!(args.last().map(String::as_str), Some("prod"));
        assert!(!args.iter().any(|a| a == "-l"));
        assert!(!args.iter().any(|a| a == "-p"));
    }

    #[test]
    fn kind_serde_lowercase() {
        let cases = [
            ("\"local\"", TunnelKind::Local),
            ("\"dynamic\"", TunnelKind::Dynamic),
        ];
        for (json, kind) in cases.iter() {
            let parsed: TunnelKind = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, *kind);
        }
    }
}