runmat_kernel/
transport.rs

1//! ZMQ transport and Jupyter v5 framing utilities
2//!
3//! Implements multipart message encoding/decoding with HMAC signatures
4//! according to the Jupyter messaging protocol. Used by the kernel server
5//! to read requests from the shell/control channels and publish results on
6//! the IOPub channel.
7
8use crate::protocol::{JupyterMessage, MessageHeader};
9use crate::{KernelError, Result};
10use hmac::{Hmac, Mac};
11use serde_json::Value as JsonValue;
12use sha2::Sha256;
13use std::env;
14
/// Delimiter frame that separates the ZMQ routing identities from the signature
/// and JSON frames of a Jupyter message (first occurrence wins when decoding).
const DELIM: &str = "<IDS|MSG>";
16
/// Message-signing algorithm negotiated for this kernel connection.
///
/// `None` means outgoing messages carry an empty signature frame and incoming
/// signatures are not checked.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SignatureAlg {
    /// No signing or verification.
    None,
    /// HMAC with SHA-256, selected by the scheme name `hmac-sha256`.
    HmacSha256,
}

impl SignatureAlg {
    /// Map a signature scheme name onto an algorithm, ignoring ASCII case.
    ///
    /// Only `hmac-sha256` is recognised. Every other string, including the empty
    /// string, yields [`SignatureAlg::None`]. An unknown scheme therefore disables
    /// signing instead of producing an error.
    pub fn from_scheme(scheme: &str) -> Self {
        if scheme.eq_ignore_ascii_case("hmac-sha256") {
            Self::HmacSha256
        } else {
            Self::None
        }
    }
}
32
33/// Compute Jupyter message signature for the 4 JSON frames
34fn compute_signature(alg: SignatureAlg, key: &[u8], frames: &[Vec<u8>]) -> String {
35    match alg {
36        SignatureAlg::None => String::new(),
37        SignatureAlg::HmacSha256 => {
38            let mut mac = Hmac::<Sha256>::new_from_slice(key).unwrap();
39            for frame in frames {
40                mac.update(frame);
41            }
42            let bytes = mac.finalize().into_bytes();
43            hex::encode(bytes)
44        }
45    }
46}
47
48/// Decode a multipart message received from a ZMQ socket into routing identities
49/// and a structured `JupyterMessage`. Verifies the HMAC signature if a key exists.
50pub fn recv_jupyter_message(
51    socket: &zmq::Socket,
52    key: &str,
53    scheme: &str,
54) -> Result<(Vec<Vec<u8>>, JupyterMessage)> {
55    let trace = env::var("RUNMAT_KERNEL_ZMQ_TRACE").is_ok();
56    // Receive all frames for one message
57    let frames = socket.recv_multipart(0).map_err(KernelError::Zmq)?;
58
59    // Split identities and content frames
60    let mut ids: Vec<Vec<u8>> = Vec::new();
61    let mut idx = 0usize;
62    while idx < frames.len() {
63        if frames[idx] == DELIM.as_bytes() {
64            idx += 1; // skip delim
65            break;
66        }
67        ids.push(frames[idx].clone());
68        idx += 1;
69    }
70
71    if idx >= frames.len() {
72        return Err(KernelError::Protocol(
73            "Missing <IDS|MSG> delimiter".to_string(),
74        ));
75    }
76
77    // We require at least signature + 4 JSON frames
78    if frames.len() - idx < 5 {
79        return Err(KernelError::Protocol(
80            "Incomplete message (expected signature + 4 JSON frames)".to_string(),
81        ));
82    }
83
84    let signature = &frames[idx];
85    let header = &frames[idx + 1];
86    let parent_header = &frames[idx + 2];
87    let metadata = &frames[idx + 3];
88    let content = &frames[idx + 4];
89    let buffers: Vec<Vec<u8>> = frames[idx + 5..].to_vec();
90
91    // Validate signature if key present
92    let alg = if key.is_empty() {
93        SignatureAlg::None
94    } else {
95        SignatureAlg::from_scheme(scheme)
96    };
97
98    if !matches!(alg, SignatureAlg::None) {
99        let expected = compute_signature(
100            alg,
101            key.as_bytes(),
102            &[
103                header.clone(),
104                parent_header.clone(),
105                metadata.clone(),
106                content.clone(),
107            ],
108        );
109        let provided = String::from_utf8_lossy(signature).to_string();
110        if expected != provided {
111            if trace {
112                eprintln!(
113                    "[ZMQ-TRACE] signature mismatch: expected {} provided {}",
114                    expected, provided
115                );
116            }
117            return Err(KernelError::Protocol("Invalid HMAC signature".to_string()));
118        }
119    }
120
121    // Build structured message
122    let header: MessageHeader = serde_json::from_slice(header)?;
123    // Parent header can be {} or null. Treat both as None.
124    let parent_val: JsonValue = serde_json::from_slice(parent_header)?;
125    let parent_header: Option<MessageHeader> = match parent_val {
126        JsonValue::Null => None,
127        JsonValue::Object(ref m) if m.is_empty() => None,
128        other => Some(serde_json::from_value(other).map_err(KernelError::Json)?),
129    };
130    let metadata_map: serde_json::Map<String, JsonValue> = serde_json::from_slice(metadata)?;
131    let metadata: std::collections::HashMap<String, JsonValue> = metadata_map.into_iter().collect();
132    let content: JsonValue = serde_json::from_slice(content)?;
133
134    let msg = JupyterMessage {
135        header,
136        parent_header,
137        metadata,
138        content,
139        buffers,
140    };
141
142    if trace {
143        eprintln!(
144            "[ZMQ-TRACE] RECV type={:?} session={}",
145            msg.header.msg_type, msg.header.session
146        );
147    }
148
149    Ok((ids, msg))
150}
151
152/// Encode and send a `JupyterMessage` with given routing identities on a ZMQ socket.
153pub fn send_jupyter_message(
154    socket: &zmq::Socket,
155    ids: &[Vec<u8>],
156    key: &str,
157    scheme: &str,
158    msg: &JupyterMessage,
159) -> Result<()> {
160    let trace = env::var("RUNMAT_KERNEL_ZMQ_TRACE").is_ok();
161    let alg = if key.is_empty() {
162        SignatureAlg::None
163    } else {
164        SignatureAlg::from_scheme(scheme)
165    };
166
167    // Serialize frames
168    let header = serde_json::to_vec(&msg.header)?;
169    let parent_header = if let Some(ref p) = msg.parent_header {
170        serde_json::to_vec(p)?
171    } else {
172        // Parent header can be an empty JSON object according to protocol
173        serde_json::to_vec(&serde_json::json!({}))?
174    };
175    let metadata = serde_json::to_vec(&msg.metadata)?;
176    let content = serde_json::to_vec(&msg.content)?;
177
178    let signature = compute_signature(
179        alg,
180        key.as_bytes(),
181        &[
182            header.clone(),
183            parent_header.clone(),
184            metadata.clone(),
185            content.clone(),
186        ],
187    );
188
189    // Assemble multipart frames
190    let mut frames: Vec<Vec<u8>> = Vec::new();
191    frames.extend_from_slice(ids);
192    frames.push(DELIM.as_bytes().to_vec());
193    frames.push(signature.into_bytes());
194    frames.push(header);
195    frames.push(parent_header);
196    frames.push(metadata);
197    frames.push(content);
198    frames.extend_from_slice(&msg.buffers);
199
200    socket.send_multipart(frames, 0).map_err(KernelError::Zmq)?;
201
202    if trace {
203        eprintln!(
204            "[ZMQ-TRACE] SEND type={:?} session={}",
205            msg.header.msg_type, msg.header.session
206        );
207    }
208
209    Ok(())
210}