runmat_kernel/
connection.rs1use crate::{KernelError, Result};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ConnectionInfo {
13 pub ip: String,
15 pub transport: String,
17 pub signature_scheme: String,
19 pub key: String,
21 pub shell_port: u16,
23 pub iopub_port: u16,
25 pub stdin_port: u16,
27 pub control_port: u16,
29 pub hb_port: u16,
31}
32
33impl Default for ConnectionInfo {
34 fn default() -> Self {
35 Self {
36 ip: "127.0.0.1".to_string(),
37 transport: "tcp".to_string(),
38 signature_scheme: "hmac-sha256".to_string(),
39 key: uuid::Uuid::new_v4().to_string(),
40 shell_port: 0, iopub_port: 0, stdin_port: 0, control_port: 0, hb_port: 0, }
46 }
47}
48
49impl ConnectionInfo {
50 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
52 let content = std::fs::read_to_string(path)
53 .map_err(|e| KernelError::Connection(format!("Failed to read connection file: {e}")))?;
54
55 Self::from_json(&content)
56 }
57
58 pub fn from_json(json: &str) -> Result<Self> {
60 serde_json::from_str(json)
61 .map_err(|e| KernelError::Connection(format!("Invalid connection JSON: {e}")))
62 }
63
64 pub fn to_json(&self) -> Result<String> {
66 serde_json::to_string_pretty(self).map_err(KernelError::Json)
67 }
68
69 pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
71 let json = self.to_json()?;
72 std::fs::write(path, json)
73 .map_err(|e| KernelError::Connection(format!("Failed to write connection file: {e}")))
74 }
75
76 pub fn socket_url(&self, port: u16) -> String {
78 format!("{}://{}:{}", self.transport, self.ip, port)
79 }
80
81 pub fn shell_url(&self) -> String {
83 self.socket_url(self.shell_port)
84 }
85
86 pub fn iopub_url(&self) -> String {
88 self.socket_url(self.iopub_port)
89 }
90
91 pub fn stdin_url(&self) -> String {
93 self.socket_url(self.stdin_port)
94 }
95
96 pub fn control_url(&self) -> String {
98 self.socket_url(self.control_port)
99 }
100
101 pub fn heartbeat_url(&self) -> String {
103 self.socket_url(self.hb_port)
104 }
105
106 pub fn validate(&self) -> Result<()> {
108 if self.ip.is_empty() {
109 return Err(KernelError::Connection(
110 "IP address cannot be empty".to_string(),
111 ));
112 }
113
114 if self.transport.is_empty() {
115 return Err(KernelError::Connection(
116 "Transport cannot be empty".to_string(),
117 ));
118 }
119
120 if self.key.is_empty() {
121 return Err(KernelError::Connection("Key cannot be empty".to_string()));
122 }
123
124 let ports = [
126 ("shell", self.shell_port),
127 ("iopub", self.iopub_port),
128 ("stdin", self.stdin_port),
129 ("control", self.control_port),
130 ("hb", self.hb_port),
131 ];
132
133 for (name, port) in ports {
134 if port == 0 {
135 return Err(KernelError::Connection(format!(
136 "{name} port must be assigned"
137 )));
138 }
139 }
140
141 Ok(())
142 }
143
144 pub fn assign_ports(&mut self) -> Result<()> {
146 use std::net::TcpListener;
147
148 fn find_available_port() -> Result<u16> {
150 let listener = TcpListener::bind("127.0.0.1:0").map_err(|e| {
151 KernelError::Connection(format!("Failed to find available port: {e}"))
152 })?;
153 Ok(listener
154 .local_addr()
155 .map_err(|e| KernelError::Connection(format!("Failed to get port: {e}")))?
156 .port())
157 }
158
159 self.shell_port = find_available_port()?;
160 self.iopub_port = find_available_port()?;
161 self.stdin_port = find_available_port()?;
162 self.control_port = find_available_port()?;
163 self.hb_port = find_available_port()?;
164
165 Ok(())
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// A connection whose ports are all distinct and non-zero.
    fn sample_connection() -> ConnectionInfo {
        ConnectionInfo {
            shell_port: 12345,
            iopub_port: 12346,
            stdin_port: 12347,
            control_port: 12348,
            hb_port: 12349,
            ..Default::default()
        }
    }

    /// Calls `assign_ports`. Returns `false` (after logging) when the environment
    /// forbids binding sockets, e.g. a sandboxed runner, so the caller can skip.
    fn assign_ports_or_skip(conn: &mut ConnectionInfo, test_name: &str) -> bool {
        match conn.assign_ports() {
            Ok(()) => true,
            Err(err) if err.to_string().contains("Operation not permitted") => {
                eprintln!("skipping {test_name} test: {err}");
                false
            }
            Err(err) => panic!("{err}"),
        }
    }

    /// Asserts that two connections agree on every field.
    ///
    /// `ConnectionInfo` does not implement `PartialEq`, so their JSON forms
    /// (fields in declaration order) are compared instead.
    fn assert_same_connection(expected: &ConnectionInfo, actual: &ConnectionInfo) {
        assert_eq!(expected.to_json().unwrap(), actual.to_json().unwrap());
    }

    #[test]
    fn test_default_connection() {
        let conn = ConnectionInfo::default();
        assert_eq!(conn.ip, "127.0.0.1");
        assert_eq!(conn.transport, "tcp");
        assert_eq!(conn.signature_scheme, "hmac-sha256");
        assert!(!conn.key.is_empty());
    }

    #[test]
    fn test_connection_json_roundtrip() {
        let conn = sample_connection();

        let json = conn.to_json().unwrap();
        let parsed = ConnectionInfo::from_json(&json).unwrap();

        assert_same_connection(&conn, &parsed);
    }

    #[test]
    fn test_from_json_rejects_invalid_input() {
        assert!(ConnectionInfo::from_json("not json").is_err());
        // Syntactically valid JSON that lacks the required fields.
        assert!(ConnectionInfo::from_json("{}").is_err());
    }

    #[test]
    fn test_connection_file_io() {
        let conn = sample_connection();

        let temp_file = NamedTempFile::new().unwrap();
        conn.write_to_file(temp_file.path()).unwrap();

        let loaded = ConnectionInfo::from_file(temp_file.path()).unwrap();
        assert_same_connection(&conn, &loaded);
    }

    #[test]
    fn test_socket_urls() {
        let conn = ConnectionInfo {
            shell_port: 12345,
            iopub_port: 12346,
            ..Default::default()
        };

        assert_eq!(conn.shell_url(), "tcp://127.0.0.1:12345");
        assert_eq!(conn.iopub_url(), "tcp://127.0.0.1:12346");
    }

    #[test]
    fn test_port_assignment() {
        let mut conn = ConnectionInfo::default();
        if !assign_ports_or_skip(&mut conn, "port assignment") {
            return;
        }

        assert_ne!(conn.shell_port, 0);
        assert_ne!(conn.iopub_port, 0);
        assert_ne!(conn.stdin_port, 0);
        assert_ne!(conn.control_port, 0);
        assert_ne!(conn.hb_port, 0);

        conn.validate().unwrap();
    }

    #[test]
    fn test_validation() {
        let mut conn = ConnectionInfo::default();

        // Default ports are all 0 (unassigned), which validation must reject.
        assert!(conn.validate().is_err());

        if !assign_ports_or_skip(&mut conn, "validation") {
            return;
        }
        conn.validate().unwrap();

        // An empty key is rejected even when every port is assigned.
        conn.key.clear();
        assert!(conn.validate().is_err());
    }

    #[test]
    fn test_validation_rejects_empty_fields() {
        assert!(sample_connection().validate().is_ok());

        let mut no_ip = sample_connection();
        no_ip.ip.clear();
        assert!(no_ip.validate().is_err());

        let mut no_transport = sample_connection();
        no_transport.transport.clear();
        assert!(no_transport.validate().is_err());
    }
}