runmat_kernel/
connection.rs

1//! Jupyter kernel connection management
2//!
3//! Handles connection file parsing and ZMQ socket configuration compatible
4//! with the Jupyter protocol.
5
6use crate::{KernelError, Result};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10/// Connection information for Jupyter kernel communication
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ConnectionInfo {
13    /// IP address to bind to (usually 127.0.0.1)
14    pub ip: String,
15    /// Transport protocol (usually "tcp")
16    pub transport: String,
17    /// Signature scheme for message authentication (usually "hmac-sha256")
18    pub signature_scheme: String,
19    /// HMAC key for message signing
20    pub key: String,
21    /// Shell socket port (handles execute requests)
22    pub shell_port: u16,
23    /// IOPub socket port (publishes execution results)
24    pub iopub_port: u16,
25    /// Stdin socket port (handles input requests)
26    pub stdin_port: u16,
27    /// Control socket port (handles kernel control)
28    pub control_port: u16,
29    /// Heartbeat socket port (kernel liveness check)
30    pub hb_port: u16,
31}
32
33impl Default for ConnectionInfo {
34    fn default() -> Self {
35        Self {
36            ip: "127.0.0.1".to_string(),
37            transport: "tcp".to_string(),
38            signature_scheme: "hmac-sha256".to_string(),
39            key: uuid::Uuid::new_v4().to_string(),
40            shell_port: 0,   // Let OS assign
41            iopub_port: 0,   // Let OS assign
42            stdin_port: 0,   // Let OS assign
43            control_port: 0, // Let OS assign
44            hb_port: 0,      // Let OS assign
45        }
46    }
47}
48
49impl ConnectionInfo {
50    /// Create connection info from a Jupyter connection file
51    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
52        let content = std::fs::read_to_string(path)
53            .map_err(|e| KernelError::Connection(format!("Failed to read connection file: {e}")))?;
54
55        Self::from_json(&content)
56    }
57
58    /// Parse connection info from JSON string
59    pub fn from_json(json: &str) -> Result<Self> {
60        serde_json::from_str(json)
61            .map_err(|e| KernelError::Connection(format!("Invalid connection JSON: {e}")))
62    }
63
64    /// Serialize connection info to JSON
65    pub fn to_json(&self) -> Result<String> {
66        serde_json::to_string_pretty(self).map_err(KernelError::Json)
67    }
68
69    /// Write connection info to a file
70    pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
71        let json = self.to_json()?;
72        std::fs::write(path, json)
73            .map_err(|e| KernelError::Connection(format!("Failed to write connection file: {e}")))
74    }
75
76    /// Generate a connection URL for a given socket type
77    pub fn socket_url(&self, port: u16) -> String {
78        format!("{}://{}:{}", self.transport, self.ip, port)
79    }
80
81    /// Get shell socket URL
82    pub fn shell_url(&self) -> String {
83        self.socket_url(self.shell_port)
84    }
85
86    /// Get IOPub socket URL
87    pub fn iopub_url(&self) -> String {
88        self.socket_url(self.iopub_port)
89    }
90
91    /// Get stdin socket URL
92    pub fn stdin_url(&self) -> String {
93        self.socket_url(self.stdin_port)
94    }
95
96    /// Get control socket URL
97    pub fn control_url(&self) -> String {
98        self.socket_url(self.control_port)
99    }
100
101    /// Get heartbeat socket URL
102    pub fn heartbeat_url(&self) -> String {
103        self.socket_url(self.hb_port)
104    }
105
106    /// Validate that all required fields are present and valid
107    pub fn validate(&self) -> Result<()> {
108        if self.ip.is_empty() {
109            return Err(KernelError::Connection(
110                "IP address cannot be empty".to_string(),
111            ));
112        }
113
114        if self.transport.is_empty() {
115            return Err(KernelError::Connection(
116                "Transport cannot be empty".to_string(),
117            ));
118        }
119
120        if self.key.is_empty() {
121            return Err(KernelError::Connection("Key cannot be empty".to_string()));
122        }
123
124        // Validate ports are non-zero (indicating they've been assigned)
125        let ports = [
126            ("shell", self.shell_port),
127            ("iopub", self.iopub_port),
128            ("stdin", self.stdin_port),
129            ("control", self.control_port),
130            ("hb", self.hb_port),
131        ];
132
133        for (name, port) in ports {
134            if port == 0 {
135                return Err(KernelError::Connection(format!(
136                    "{name} port must be assigned"
137                )));
138            }
139        }
140
141        Ok(())
142    }
143
144    /// Assign random available ports to all sockets
145    pub fn assign_ports(&mut self) -> Result<()> {
146        use std::net::TcpListener;
147
148        // Helper to find an available port
149        fn find_available_port() -> Result<u16> {
150            let listener = TcpListener::bind("127.0.0.1:0").map_err(|e| {
151                KernelError::Connection(format!("Failed to find available port: {e}"))
152            })?;
153            Ok(listener
154                .local_addr()
155                .map_err(|e| KernelError::Connection(format!("Failed to get port: {e}")))?
156                .port())
157        }
158
159        self.shell_port = find_available_port()?;
160        self.iopub_port = find_available_port()?;
161        self.stdin_port = find_available_port()?;
162        self.control_port = find_available_port()?;
163        self.hb_port = find_available_port()?;
164
165        Ok(())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Default connection info with fixed, distinct, non-zero ports.
    fn sample_connection() -> ConnectionInfo {
        ConnectionInfo {
            shell_port: 12345,
            iopub_port: 12346,
            stdin_port: 12347,
            control_port: 12348,
            hb_port: 12349,
            ..Default::default()
        }
    }

    /// Assert that every field of `actual` equals the same field of `expected`.
    /// Fields are compared one by one so a failure names the mismatching field.
    fn assert_same_connection(expected: &ConnectionInfo, actual: &ConnectionInfo) {
        assert_eq!(expected.ip, actual.ip);
        assert_eq!(expected.transport, actual.transport);
        assert_eq!(expected.signature_scheme, actual.signature_scheme);
        assert_eq!(expected.key, actual.key);
        assert_eq!(expected.shell_port, actual.shell_port);
        assert_eq!(expected.iopub_port, actual.iopub_port);
        assert_eq!(expected.stdin_port, actual.stdin_port);
        assert_eq!(expected.control_port, actual.control_port);
        assert_eq!(expected.hb_port, actual.hb_port);
    }

    /// Run `assign_ports` on `conn`.
    ///
    /// Returns `false` (after logging a skip notice naming `test_name`) when the
    /// environment forbids binding sockets, as in some sandboxes. Any other
    /// error fails the test.
    fn assign_ports_or_skip(conn: &mut ConnectionInfo, test_name: &str) -> bool {
        match conn.assign_ports() {
            Ok(()) => true,
            Err(err) if err.to_string().contains("Operation not permitted") => {
                eprintln!("skipping {test_name}: {err}");
                false
            }
            Err(err) => panic!("{err}"),
        }
    }

    #[test]
    fn test_default_connection() {
        let conn = ConnectionInfo::default();
        assert_eq!(conn.ip, "127.0.0.1");
        assert_eq!(conn.transport, "tcp");
        assert_eq!(conn.signature_scheme, "hmac-sha256");
        assert!(!conn.key.is_empty());

        // All ports start out unassigned.
        assert_eq!(conn.shell_port, 0);
        assert_eq!(conn.iopub_port, 0);
        assert_eq!(conn.stdin_port, 0);
        assert_eq!(conn.control_port, 0);
        assert_eq!(conn.hb_port, 0);
    }

    #[test]
    fn test_default_keys_differ() {
        assert_ne!(ConnectionInfo::default().key, ConnectionInfo::default().key);
    }

    #[test]
    fn test_connection_json_roundtrip() {
        let conn = sample_connection();

        let json = conn.to_json().unwrap();
        let parsed = ConnectionInfo::from_json(&json).unwrap();

        assert_same_connection(&conn, &parsed);
    }

    #[test]
    fn test_invalid_json_is_rejected() {
        assert!(ConnectionInfo::from_json("not json").is_err());
        // Every field is required.
        assert!(ConnectionInfo::from_json("{}").is_err());
    }

    #[test]
    fn test_connection_file_io() {
        let conn = sample_connection();

        let temp_file = NamedTempFile::new().unwrap();
        conn.write_to_file(temp_file.path()).unwrap();

        let loaded = ConnectionInfo::from_file(temp_file.path()).unwrap();
        assert_same_connection(&conn, &loaded);
    }

    #[test]
    fn test_socket_urls() {
        let conn = sample_connection();

        assert_eq!(conn.shell_url(), "tcp://127.0.0.1:12345");
        assert_eq!(conn.iopub_url(), "tcp://127.0.0.1:12346");
        assert_eq!(conn.stdin_url(), "tcp://127.0.0.1:12347");
        assert_eq!(conn.control_url(), "tcp://127.0.0.1:12348");
        assert_eq!(conn.heartbeat_url(), "tcp://127.0.0.1:12349");
    }

    #[test]
    fn test_port_assignment() {
        let mut conn = ConnectionInfo::default();
        if !assign_ports_or_skip(&mut conn, "port assignment test") {
            return;
        }

        assert_ne!(conn.shell_port, 0);
        assert_ne!(conn.iopub_port, 0);
        assert_ne!(conn.stdin_port, 0);
        assert_ne!(conn.control_port, 0);
        assert_ne!(conn.hb_port, 0);

        conn.validate().unwrap();
    }

    #[test]
    fn test_validation() {
        let mut conn = ConnectionInfo::default();

        // Should fail with unassigned ports.
        assert!(conn.validate().is_err());

        // Should pass after port assignment.
        if !assign_ports_or_skip(&mut conn, "validation test") {
            return;
        }
        conn.validate().unwrap();

        // Should fail with an empty key.
        conn.key.clear();
        assert!(conn.validate().is_err());
    }

    /// Covers every `validate` failure path without opening sockets, so it
    /// also runs in sandboxes where `test_validation` is skipped.
    #[test]
    fn test_validation_rejects_missing_fields() {
        sample_connection().validate().unwrap();

        let mut conn = sample_connection();
        conn.ip.clear();
        assert!(conn.validate().is_err());

        let mut conn = sample_connection();
        conn.transport.clear();
        assert!(conn.validate().is_err());

        let mut conn = sample_connection();
        conn.key.clear();
        assert!(conn.validate().is_err());

        let mut conn = sample_connection();
        conn.hb_port = 0;
        assert!(conn.validate().is_err());
    }
}