use crate::protocol::{JupyterMessage, MessageHeader};
use crate::{KernelError, Result};
use hmac::{Hmac, Mac};
use serde_json::Value as JsonValue;
use sha2::Sha256;
use std::env;
/// Frame separating the routing identities from the message body in a
/// Jupyter wire-protocol multipart message.
const DELIM: &str = "<IDS|MSG>";

/// Message-signing algorithm selected from a scheme string.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SignatureAlg {
    /// Messages are neither signed nor verified.
    None,
    /// HMAC over SHA-256.
    HmacSha256,
}

impl SignatureAlg {
    /// Maps a scheme name to an algorithm, ignoring ASCII case.
    ///
    /// Only `"hmac-sha256"` (in any ASCII case) selects
    /// [`SignatureAlg::HmacSha256`]; every other value, including the empty
    /// string and unsupported schemes, yields [`SignatureAlg::None`].
    pub fn from_scheme(scheme: &str) -> Self {
        if scheme.eq_ignore_ascii_case("hmac-sha256") {
            Self::HmacSha256
        } else {
            Self::None
        }
    }
}
/// Computes the hex-encoded signature of `frames` under `alg`.
///
/// Frames are fed to the MAC in order with no separators, so callers must
/// pass exactly the frames that are signed on the wire (header, parent
/// header, metadata, content). Returns an empty string for
/// [`SignatureAlg::None`].
///
/// Any slice of byte-like frames is accepted (`Vec<u8>`, `&[u8]`, ...), so
/// callers can sign data they only borrow without cloning it first.
fn compute_signature<F: AsRef<[u8]>>(alg: SignatureAlg, key: &[u8], frames: &[F]) -> String {
    match alg {
        SignatureAlg::None => String::new(),
        SignatureAlg::HmacSha256 => {
            // HMAC is defined for keys of any length (long keys are hashed,
            // short keys padded), so construction cannot fail.
            let mut mac = Hmac::<Sha256>::new_from_slice(key)
                .expect("HMAC-SHA256 accepts keys of any length");
            for frame in frames {
                mac.update(frame.as_ref());
            }
            hex::encode(mac.finalize().into_bytes())
        }
    }
}
/// Receives one multipart Jupyter wire-protocol message from `socket`,
/// verifies its signature, and decodes it.
///
/// Expected frame layout: zero or more routing identities, the `<IDS|MSG>`
/// delimiter, the signature, then the header, parent_header, metadata and
/// content JSON frames, followed by any number of raw buffers. Only the four
/// JSON frames are covered by the signature.
///
/// Signature checking is skipped when `key` is empty or when `scheme` is not
/// recognised by [`SignatureAlg::from_scheme`].
///
/// Returns the identity frames that preceded the delimiter, together with the
/// decoded message. An empty or `null` parent header decodes to `None`.
///
/// Setting the `RUNMAT_KERNEL_ZMQ_TRACE` environment variable prints
/// diagnostics to stderr.
///
/// # Errors
/// - [`KernelError::Zmq`] if the receive fails.
/// - [`KernelError::Protocol`] if the delimiter is missing, fewer than five
///   frames follow it, or the signature does not match.
/// - A JSON error if any of the four JSON frames fails to decode.
pub fn recv_jupyter_message(
    socket: &zmq::Socket,
    key: &str,
    scheme: &str,
) -> Result<(Vec<Vec<u8>>, JupyterMessage)> {
    let trace = env::var("RUNMAT_KERNEL_ZMQ_TRACE").is_ok();
    let mut frames = socket.recv_multipart(0).map_err(KernelError::Zmq)?;

    let delim_pos = frames
        .iter()
        .position(|frame| frame.as_slice() == DELIM.as_bytes())
        .ok_or_else(|| KernelError::Protocol("Missing <IDS|MSG> delimiter".to_string()))?;

    // Split the owned frame list in place instead of cloning each frame:
    // `frames` keeps the identities, `payload` receives everything after the
    // delimiter (the delimiter itself is dropped by the truncate).
    let mut payload = frames.split_off(delim_pos + 1);
    frames.truncate(delim_pos);
    let ids = frames;

    if payload.len() < 5 {
        return Err(KernelError::Protocol(
            "Incomplete message (expected signature + 4 JSON frames)".to_string(),
        ));
    }
    let buffers = payload.split_off(5);
    // `payload` is now exactly [signature, header, parent_header, metadata, content].
    let signature = &payload[0];
    let signed = &payload[1..];

    let alg = if key.is_empty() {
        SignatureAlg::None
    } else {
        SignatureAlg::from_scheme(scheme)
    };
    // NOTE(review): a non-empty key with an unrecognised scheme also yields
    // `SignatureAlg::None`, i.e. verification is silently disabled (fail-open).
    // Consider rejecting unknown schemes instead.
    if alg != SignatureAlg::None {
        let expected = compute_signature(alg, key.as_bytes(), signed);
        // Compare in constant time: a plain `!=` returns at the first
        // differing byte, letting an attacker recover a valid signature for a
        // forged message one byte at a time by measuring response latency.
        if !constant_time_eq(expected.as_bytes(), signature) {
            if trace {
                eprintln!(
                    "[ZMQ-TRACE] signature mismatch: expected {} provided {}",
                    expected,
                    String::from_utf8_lossy(signature)
                );
            }
            return Err(KernelError::Protocol("Invalid HMAC signature".to_string()));
        }
    }

    let header: MessageHeader = serde_json::from_slice(&signed[0])?;
    // An empty object or `null` parent header means "no parent".
    let parent_header: Option<MessageHeader> =
        match serde_json::from_slice::<JsonValue>(&signed[1])? {
            JsonValue::Null => None,
            JsonValue::Object(ref m) if m.is_empty() => None,
            other => Some(serde_json::from_value(other).map_err(KernelError::Json)?),
        };
    let metadata: std::collections::HashMap<String, JsonValue> =
        serde_json::from_slice(&signed[2])?;
    let content: JsonValue = serde_json::from_slice(&signed[3])?;

    let msg = JupyterMessage {
        header,
        parent_header,
        metadata,
        content,
        buffers,
    };
    if trace {
        eprintln!(
            "[ZMQ-TRACE] RECV type={:?} session={}",
            msg.header.msg_type, msg.header.session
        );
    }
    Ok((ids, msg))
}

/// Returns `true` when `a` and `b` are byte-for-byte equal.
///
/// Every byte is examined regardless of where the first difference occurs,
/// so timing does not reveal the length of a matching prefix. A length
/// mismatch returns early; the expected signature length is fixed by the
/// algorithm and is not secret.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Serializes `msg`, signs it, and sends it on `socket` as one multipart
/// Jupyter wire-protocol message.
///
/// Frames written, in order:
/// 1. `ids`
/// 2. the `<IDS|MSG>` delimiter
/// 3. the signature
/// 4. header
/// 5. parent_header (`{}` when there is no parent)
/// 6. metadata
/// 7. content
/// 8. `msg.buffers`
///
/// Only the four JSON frames are signed. When `key` is empty or `scheme` is
/// unrecognised, the signature frame is empty.
///
/// Setting the `RUNMAT_KERNEL_ZMQ_TRACE` environment variable prints
/// diagnostics to stderr after a successful send.
///
/// # Errors
/// - A JSON error if any part of `msg` fails to serialize.
/// - [`KernelError::Zmq`] if the send fails.
pub fn send_jupyter_message(
    socket: &zmq::Socket,
    ids: &[Vec<u8>],
    key: &str,
    scheme: &str,
    msg: &JupyterMessage,
) -> Result<()> {
    let alg = if key.is_empty() {
        SignatureAlg::None
    } else {
        SignatureAlg::from_scheme(scheme)
    };

    // ids + delimiter + signature + 4 JSON frames + buffers.
    let mut frames: Vec<Vec<u8>> = Vec::with_capacity(ids.len() + 6 + msg.buffers.len());
    frames.extend_from_slice(ids);
    frames.push(DELIM.as_bytes().to_vec());

    // Reserve the signature slot. It is filled in once the signed frames sit
    // right after it, so they can be signed in place without cloning.
    let sig_idx = frames.len();
    frames.push(Vec::new());
    frames.push(serde_json::to_vec(&msg.header)?);
    frames.push(match msg.parent_header {
        Some(ref parent) => serde_json::to_vec(parent)?,
        None => b"{}".to_vec(),
    });
    frames.push(serde_json::to_vec(&msg.metadata)?);
    frames.push(serde_json::to_vec(&msg.content)?);

    // Buffers are not appended yet, so this slice is exactly the four JSON frames.
    let signature = compute_signature(alg, key.as_bytes(), &frames[sig_idx + 1..]);
    frames[sig_idx] = signature.into_bytes();
    frames.extend_from_slice(&msg.buffers);

    socket.send_multipart(frames, 0).map_err(KernelError::Zmq)?;
    if env::var("RUNMAT_KERNEL_ZMQ_TRACE").is_ok() {
        eprintln!(
            "[ZMQ-TRACE] SEND type={:?} session={}",
            msg.header.msg_type, msg.header.session
        );
    }
    Ok(())
}