use crate::{KernelError, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Kernel connection parameters; the field set mirrors the Jupyter kernel
/// connection-file format and is (de)serialized to/from JSON via serde.
///
/// A `*_port` value of `0` means "not yet assigned": `ConnectionInfo::validate`
/// rejects it, and `ConnectionInfo::assign_ports` replaces it with a free port.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConnectionInfo {
    /// Address the channel endpoints are built from (defaults to `127.0.0.1`).
    pub ip: String,
    /// Endpoint transport scheme (defaults to `"tcp"`).
    pub transport: String,
    /// Name of the message-signing scheme (defaults to `"hmac-sha256"`).
    pub signature_scheme: String,
    /// Secret signing key; must be non-empty to pass validation.
    pub key: String,
    /// Port of the shell channel.
    pub shell_port: u16,
    /// Port of the iopub channel.
    pub iopub_port: u16,
    /// Port of the stdin channel.
    pub stdin_port: u16,
    /// Port of the control channel.
    pub control_port: u16,
    /// Port of the heartbeat channel.
    pub hb_port: u16,
}
impl Default for ConnectionInfo {
    /// Builds a loopback (`127.0.0.1`) TCP configuration signed with
    /// `hmac-sha256`, using a freshly generated random UUIDv4 string as the key.
    ///
    /// Every port is left at `0` (unassigned); call `assign_ports` before use.
    fn default() -> Self {
        let key = uuid::Uuid::new_v4().to_string();
        Self {
            ip: String::from("127.0.0.1"),
            transport: String::from("tcp"),
            signature_scheme: String::from("hmac-sha256"),
            key,
            shell_port: 0,
            iopub_port: 0,
            stdin_port: 0,
            control_port: 0,
            hb_port: 0,
        }
    }
}
impl ConnectionInfo {
    /// Reads and parses a connection file from `path`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::Connection`] if the file cannot be read or its
    /// contents are not valid connection JSON.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = std::fs::read_to_string(path)
            .map_err(|e| KernelError::Connection(format!("Failed to read connection file: {e}")))?;
        Self::from_json(&content)
    }

    /// Parses connection info from a JSON string.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::Connection`] if `json` is malformed or lacks a
    /// required field.
    pub fn from_json(json: &str) -> Result<Self> {
        serde_json::from_str(json)
            .map_err(|e| KernelError::Connection(format!("Invalid connection JSON: {e}")))
    }

    /// Serializes the connection info as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::Json`] if serialization fails.
    pub fn to_json(&self) -> Result<String> {
        serde_json::to_string_pretty(self).map_err(KernelError::Json)
    }

    /// Writes the connection info as JSON to `path`, creating or truncating it.
    ///
    /// The file holds the secret signing `key`, so on Unix it is created with
    /// mode `0o600`, and a pre-existing file is re-permissioned to `0o600`
    /// before the key is written, leaving it readable only by its owner.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::Json`] if serialization fails, or
    /// [`KernelError::Connection`] if the file cannot be opened, permissioned
    /// or written.
    pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        use std::io::Write;

        let json = self.to_json()?;
        let write_err = |e: std::io::Error| {
            KernelError::Connection(format!("Failed to write connection file: {e}"))
        };

        let mut options = std::fs::OpenOptions::new();
        options.write(true).create(true).truncate(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt;
            // Only takes effect when the file is newly created.
            options.mode(0o600);
        }
        let mut file = options.open(path).map_err(write_err)?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // Also tighten a file that already existed with broader permissions.
            file.set_permissions(std::fs::Permissions::from_mode(0o600))
                .map_err(write_err)?;
        }
        file.write_all(json.as_bytes()).map_err(write_err)
    }

    /// Builds the endpoint URL for `port` from this connection's transport and IP.
    ///
    /// `tcp` endpoints are formatted `tcp://{ip}:{port}`. Any other transport
    /// (e.g. `ipc`) uses the Jupyter convention `{transport}://{ip}-{port}`,
    /// where `ip` acts as a path prefix rather than a host.
    pub fn socket_url(&self, port: u16) -> String {
        if self.transport == "tcp" {
            format!("tcp://{}:{}", self.ip, port)
        } else {
            format!("{}://{}-{}", self.transport, self.ip, port)
        }
    }

    /// Endpoint URL of the shell channel.
    pub fn shell_url(&self) -> String {
        self.socket_url(self.shell_port)
    }

    /// Endpoint URL of the iopub channel.
    pub fn iopub_url(&self) -> String {
        self.socket_url(self.iopub_port)
    }

    /// Endpoint URL of the stdin channel.
    pub fn stdin_url(&self) -> String {
        self.socket_url(self.stdin_port)
    }

    /// Endpoint URL of the control channel.
    pub fn control_url(&self) -> String {
        self.socket_url(self.control_port)
    }

    /// Endpoint URL of the heartbeat channel.
    pub fn heartbeat_url(&self) -> String {
        self.socket_url(self.hb_port)
    }

    /// Checks that the connection info is usable.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::Connection`] if `ip`, `transport` or `key` is
    /// empty, if any port is unassigned (`0`), or if two channels share a port.
    pub fn validate(&self) -> Result<()> {
        if self.ip.is_empty() {
            return Err(KernelError::Connection(
                "IP address cannot be empty".to_string(),
            ));
        }
        if self.transport.is_empty() {
            return Err(KernelError::Connection(
                "Transport cannot be empty".to_string(),
            ));
        }
        if self.key.is_empty() {
            return Err(KernelError::Connection("Key cannot be empty".to_string()));
        }
        let ports = [
            ("shell", self.shell_port),
            ("iopub", self.iopub_port),
            ("stdin", self.stdin_port),
            ("control", self.control_port),
            ("hb", self.hb_port),
        ];
        if let Some(&(name, _)) = ports.iter().find(|&&(_, port)| port == 0) {
            return Err(KernelError::Connection(format!(
                "{name} port must be assigned"
            )));
        }
        // Each channel needs its own socket, so a shared port can never bind.
        for (i, &(name, port)) in ports.iter().enumerate() {
            if let Some(&(other, _)) = ports[..i].iter().find(|&&(_, p)| p == port) {
                return Err(KernelError::Connection(format!(
                    "{name} port {port} duplicates the {other} port"
                )));
            }
        }
        Ok(())
    }

    /// Assigns a free TCP port (probed on `127.0.0.1`) to every channel.
    ///
    /// All five ports are distinct. Fields are only updated if every port
    /// was found. The probe sockets are released on return, so another process
    /// could still claim a port before the kernel binds it.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::Connection`] if a probe socket cannot be bound or
    /// its local address cannot be read; `self` is left unchanged in that case.
    pub fn assign_ports(&mut self) -> Result<()> {
        use std::net::TcpListener;

        // Keep every probe listener alive until all ports are chosen: dropping
        // one frees its port, and the OS could hand the same port back on the
        // next bind, giving two channels the same port.
        let mut listeners: Vec<TcpListener> = Vec::with_capacity(5);
        let mut next_port = || -> Result<u16> {
            let listener = TcpListener::bind("127.0.0.1:0").map_err(|e| {
                KernelError::Connection(format!("Failed to find available port: {e}"))
            })?;
            let port = listener
                .local_addr()
                .map_err(|e| KernelError::Connection(format!("Failed to get port: {e}")))?
                .port();
            listeners.push(listener);
            Ok(port)
        };
        let shell_port = next_port()?;
        let iopub_port = next_port()?;
        let stdin_port = next_port()?;
        let control_port = next_port()?;
        let hb_port = next_port()?;

        self.shell_port = shell_port;
        self.iopub_port = iopub_port;
        self.stdin_port = stdin_port;
        self.control_port = control_port;
        self.hb_port = hb_port;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// A connection with every channel on a distinct fixed port.
    fn sample_connection() -> ConnectionInfo {
        ConnectionInfo {
            shell_port: 12345,
            iopub_port: 12346,
            stdin_port: 12347,
            control_port: 12348,
            hb_port: 12349,
            ..Default::default()
        }
    }

    /// Asserts that every field of `actual` equals the same field of `expected`.
    fn assert_same_connection(expected: &ConnectionInfo, actual: &ConnectionInfo) {
        assert_eq!(expected.ip, actual.ip);
        assert_eq!(expected.transport, actual.transport);
        assert_eq!(expected.signature_scheme, actual.signature_scheme);
        assert_eq!(expected.key, actual.key);
        assert_eq!(expected.shell_port, actual.shell_port);
        assert_eq!(expected.iopub_port, actual.iopub_port);
        assert_eq!(expected.stdin_port, actual.stdin_port);
        assert_eq!(expected.control_port, actual.control_port);
        assert_eq!(expected.hb_port, actual.hb_port);
    }

    /// Runs `assign_ports`, returning `false` so the caller can skip when the
    /// environment forbids binding sockets; any other error fails the test.
    fn assign_ports_or_skip(conn: &mut ConnectionInfo, test_name: &str) -> bool {
        match conn.assign_ports() {
            Ok(()) => true,
            Err(err) if err.to_string().contains("Operation not permitted") => {
                eprintln!("skipping {test_name} test: {err}");
                false
            }
            Err(err) => panic!("{err}"),
        }
    }

    #[test]
    fn test_default_connection() {
        let conn = ConnectionInfo::default();
        assert_eq!(conn.ip, "127.0.0.1");
        assert_eq!(conn.transport, "tcp");
        assert_eq!(conn.signature_scheme, "hmac-sha256");
        assert!(!conn.key.is_empty());
    }

    #[test]
    fn test_connection_json_roundtrip() {
        let conn = sample_connection();
        let json = conn.to_json().unwrap();
        let parsed = ConnectionInfo::from_json(&json).unwrap();
        assert_same_connection(&conn, &parsed);
    }

    #[test]
    fn test_connection_file_io() {
        let conn = sample_connection();
        let temp_file = NamedTempFile::new().unwrap();
        conn.write_to_file(temp_file.path()).unwrap();
        let loaded = ConnectionInfo::from_file(temp_file.path()).unwrap();
        assert_same_connection(&conn, &loaded);
    }

    #[test]
    fn test_socket_urls() {
        let conn = sample_connection();
        assert_eq!(conn.shell_url(), "tcp://127.0.0.1:12345");
        assert_eq!(conn.iopub_url(), "tcp://127.0.0.1:12346");
        assert_eq!(conn.stdin_url(), "tcp://127.0.0.1:12347");
        assert_eq!(conn.control_url(), "tcp://127.0.0.1:12348");
        assert_eq!(conn.heartbeat_url(), "tcp://127.0.0.1:12349");
    }

    #[test]
    fn test_port_assignment() {
        let mut conn = ConnectionInfo::default();
        if !assign_ports_or_skip(&mut conn, "port assignment") {
            return;
        }
        assert_ne!(conn.shell_port, 0);
        assert_ne!(conn.iopub_port, 0);
        assert_ne!(conn.stdin_port, 0);
        assert_ne!(conn.control_port, 0);
        assert_ne!(conn.hb_port, 0);
        conn.validate().unwrap();
    }

    #[test]
    fn test_validation() {
        let mut conn = ConnectionInfo::default();
        assert!(conn.validate().is_err());
        if !assign_ports_or_skip(&mut conn, "validation") {
            return;
        }
        conn.validate().unwrap();
        conn.key.clear();
        assert!(conn.validate().is_err());
    }
}