kernel_sidecar/jupyter/
connection_file.rs

1use rand::distributions::Alphanumeric;
2use rand::Rng;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5use std::fs::{self, File};
6use std::io::{self, Write};
7use std::net::TcpListener;
8use std::path::Path;
9
10#[derive(Serialize, Deserialize, Debug, Clone)]
11pub struct ConnectionInfo {
12    ip: String,
13    transport: String,
14    shell_port: u16,
15    iopub_port: u16,
16    stdin_port: u16,
17    control_port: u16,
18    hb_port: u16,
19    signature_scheme: String,
20    pub key: String,
21    kernel_name: Option<String>,
22}
23
/// Asks the OS for a currently free TCP port on the loopback interface.
///
/// Binding to port `0` lets the OS pick the port. The probing listener is
/// dropped before this function returns, so the port is only known to have
/// been free at the time of the call; it is not reserved.
///
/// # Errors
///
/// Returns the underlying I/O error if the bind or the address lookup fails.
fn find_open_port() -> Result<u16, std::io::Error> {
    let probe = TcpListener::bind("127.0.0.1:0")?;
    let bound_addr = probe.local_addr()?;
    Ok(bound_addr.port())
}
29
30fn generate_hmac_key() -> String {
31    rand::thread_rng()
32        .sample_iter(&Alphanumeric)
33        .take(32)
34        .map(char::from)
35        .collect()
36}
37
38impl ConnectionInfo {
39    pub fn new(kernel_name: Option<String>) -> Result<Self, io::Error> {
40        let mut ports = HashSet::new();
41        while ports.len() < 5 {
42            let port = find_open_port()?;
43            ports.insert(port);
44        }
45
46        let mut port_iter = ports.into_iter();
47        Ok(Self {
48            ip: "127.0.0.1".to_string(),
49            transport: "tcp".to_string(),
50            shell_port: port_iter.next().unwrap(),
51            iopub_port: port_iter.next().unwrap(),
52            stdin_port: port_iter.next().unwrap(),
53            control_port: port_iter.next().unwrap(),
54            hb_port: port_iter.next().unwrap(),
55            signature_scheme: "hmac-sha256".to_string(),
56            key: generate_hmac_key(),
57            kernel_name,
58        })
59    }
60
61    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
62        let file_contents = fs::read_to_string(path)?;
63        serde_json::from_str(&file_contents).map_err(io::Error::from)
64    }
65
66    pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
67        let json = serde_json::to_string_pretty(self)?;
68        let mut file = File::create(path)?;
69        file.write_all(json.as_bytes())?;
70        Ok(())
71    }
72
73    pub fn to_temp_file(&self) -> Result<std::path::PathBuf, io::Error> {
74        let mut file_path = std::env::temp_dir();
75        if self.kernel_name.is_some() {
76            file_path.push(format!(
77                "kernel-sidecar-{}-{}.json",
78                self.kernel_name.as_ref().unwrap(),
79                uuid::Uuid::new_v4()
80            ));
81        } else {
82            file_path.push(format!("kernel-sidecar-{}.json", uuid::Uuid::new_v4()));
83        }
84        self.to_file(&file_path)?;
85        Ok(file_path)
86    }
87
88    pub fn iopub_address(&self) -> String {
89        format!("{}://{}:{}", self.transport, self.ip, self.iopub_port)
90    }
91
92    pub fn shell_address(&self) -> String {
93        format!("{}://{}:{}", self.transport, self.ip, self.shell_port)
94    }
95
96    pub fn heartbeat_address(&self) -> String {
97        format!("{}://{}:{}", self.transport, self.ip, self.hb_port)
98    }
99}