//! runmat_kernel transport: ZeroMQ framing, signing, and signature verification of Jupyter messages.
transport.rs1use crate::protocol::{JupyterMessage, MessageHeader};
9use crate::{KernelError, Result};
10use hmac::{Hmac, Mac};
11use serde_json::Value as JsonValue;
12use sha2::Sha256;
13use std::env;
14
/// Marker frame that separates the ZeroMQ routing identities (before it) from the
/// signature and JSON frames of a Jupyter message (after it).
const DELIM: &str = "<IDS|MSG>";
16
/// Algorithm used to sign and verify Jupyter wire-protocol messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SignatureAlg {
    /// No signing: the signature is left empty and not checked.
    None,
    /// HMAC keyed with the session key, using SHA-256.
    HmacSha256,
}

impl SignatureAlg {
    /// Selects an algorithm from a signature-scheme name, ignoring ASCII case.
    ///
    /// Only `"hmac-sha256"` is recognised; every other string, including an empty one or
    /// an unsupported scheme, yields [`SignatureAlg::None`].
    pub fn from_scheme(scheme: &str) -> Self {
        if scheme.eq_ignore_ascii_case("hmac-sha256") {
            Self::HmacSha256
        } else {
            Self::None
        }
    }
}
32
33fn compute_signature(alg: SignatureAlg, key: &[u8], frames: &[Vec<u8>]) -> String {
35 match alg {
36 SignatureAlg::None => String::new(),
37 SignatureAlg::HmacSha256 => {
38 let mut mac = Hmac::<Sha256>::new_from_slice(key).unwrap();
39 for frame in frames {
40 mac.update(frame);
41 }
42 let bytes = mac.finalize().into_bytes();
43 hex::encode(bytes)
44 }
45 }
46}
47
48pub fn recv_jupyter_message(
51 socket: &zmq::Socket,
52 key: &str,
53 scheme: &str,
54) -> Result<(Vec<Vec<u8>>, JupyterMessage)> {
55 let trace = env::var("RUNMAT_KERNEL_ZMQ_TRACE").is_ok();
56 let frames = socket.recv_multipart(0).map_err(KernelError::Zmq)?;
58
59 let mut ids: Vec<Vec<u8>> = Vec::new();
61 let mut idx = 0usize;
62 while idx < frames.len() {
63 if frames[idx] == DELIM.as_bytes() {
64 idx += 1; break;
66 }
67 ids.push(frames[idx].clone());
68 idx += 1;
69 }
70
71 if idx >= frames.len() {
72 return Err(KernelError::Protocol(
73 "Missing <IDS|MSG> delimiter".to_string(),
74 ));
75 }
76
77 if frames.len() - idx < 5 {
79 return Err(KernelError::Protocol(
80 "Incomplete message (expected signature + 4 JSON frames)".to_string(),
81 ));
82 }
83
84 let signature = &frames[idx];
85 let header = &frames[idx + 1];
86 let parent_header = &frames[idx + 2];
87 let metadata = &frames[idx + 3];
88 let content = &frames[idx + 4];
89 let buffers: Vec<Vec<u8>> = frames[idx + 5..].to_vec();
90
91 let alg = if key.is_empty() {
93 SignatureAlg::None
94 } else {
95 SignatureAlg::from_scheme(scheme)
96 };
97
98 if !matches!(alg, SignatureAlg::None) {
99 let expected = compute_signature(
100 alg,
101 key.as_bytes(),
102 &[
103 header.clone(),
104 parent_header.clone(),
105 metadata.clone(),
106 content.clone(),
107 ],
108 );
109 let provided = String::from_utf8_lossy(signature).to_string();
110 if expected != provided {
111 if trace {
112 eprintln!(
113 "[ZMQ-TRACE] signature mismatch: expected {} provided {}",
114 expected, provided
115 );
116 }
117 return Err(KernelError::Protocol("Invalid HMAC signature".to_string()));
118 }
119 }
120
121 let header: MessageHeader = serde_json::from_slice(header)?;
123 let parent_val: JsonValue = serde_json::from_slice(parent_header)?;
125 let parent_header: Option<MessageHeader> = match parent_val {
126 JsonValue::Null => None,
127 JsonValue::Object(ref m) if m.is_empty() => None,
128 other => Some(serde_json::from_value(other).map_err(KernelError::Json)?),
129 };
130 let metadata_map: serde_json::Map<String, JsonValue> = serde_json::from_slice(metadata)?;
131 let metadata: std::collections::HashMap<String, JsonValue> = metadata_map.into_iter().collect();
132 let content: JsonValue = serde_json::from_slice(content)?;
133
134 let msg = JupyterMessage {
135 header,
136 parent_header,
137 metadata,
138 content,
139 buffers,
140 };
141
142 if trace {
143 eprintln!(
144 "[ZMQ-TRACE] RECV type={:?} session={}",
145 msg.header.msg_type, msg.header.session
146 );
147 }
148
149 Ok((ids, msg))
150}
151
152pub fn send_jupyter_message(
154 socket: &zmq::Socket,
155 ids: &[Vec<u8>],
156 key: &str,
157 scheme: &str,
158 msg: &JupyterMessage,
159) -> Result<()> {
160 let trace = env::var("RUNMAT_KERNEL_ZMQ_TRACE").is_ok();
161 let alg = if key.is_empty() {
162 SignatureAlg::None
163 } else {
164 SignatureAlg::from_scheme(scheme)
165 };
166
167 let header = serde_json::to_vec(&msg.header)?;
169 let parent_header = if let Some(ref p) = msg.parent_header {
170 serde_json::to_vec(p)?
171 } else {
172 serde_json::to_vec(&serde_json::json!({}))?
174 };
175 let metadata = serde_json::to_vec(&msg.metadata)?;
176 let content = serde_json::to_vec(&msg.content)?;
177
178 let signature = compute_signature(
179 alg,
180 key.as_bytes(),
181 &[
182 header.clone(),
183 parent_header.clone(),
184 metadata.clone(),
185 content.clone(),
186 ],
187 );
188
189 let mut frames: Vec<Vec<u8>> = Vec::new();
191 frames.extend_from_slice(ids);
192 frames.push(DELIM.as_bytes().to_vec());
193 frames.push(signature.into_bytes());
194 frames.push(header);
195 frames.push(parent_header);
196 frames.push(metadata);
197 frames.push(content);
198 frames.extend_from_slice(&msg.buffers);
199
200 socket.send_multipart(frames, 0).map_err(KernelError::Zmq)?;
201
202 if trace {
203 eprintln!(
204 "[ZMQ-TRACE] SEND type={:?} session={}",
205 msg.header.msg_type, msg.header.session
206 );
207 }
208
209 Ok(())
210}