1use crate::{
7 execution::{ExecutionStats, ExecutionStatus},
8 protocol::{
9 ErrorContent, ExecuteReply, ExecuteRequest, ExecuteResult, ExecutionState, JupyterMessage,
10 MessageType, Status,
11 },
12 transport::{recv_jupyter_message, send_jupyter_message},
13 ConnectionInfo, ExecutionEngine, KernelConfig, KernelError, KernelInfo, Result,
14};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::{broadcast, mpsc};
18
/// Jupyter kernel server for RunMat.
///
/// Owns the kernel configuration, the shared execution engine and the ZeroMQ
/// context, and keeps the join handles of the socket threads spawned by `start()`.
pub struct KernelServer {
    /// Kernel configuration (connection info, signing key/scheme, session id).
    config: KernelConfig,
    /// ZeroMQ context; `None` until `start()` creates it.
    ctx: Option<zmq::Context>,
    /// Execution engine, shared with the message router and the shell thread.
    engine: Arc<tokio::sync::Mutex<ExecutionEngine>>,
    /// Broadcast channel for execution-state (busy/idle/...) changes.
    status_tx: broadcast::Sender<ExecutionState>,
    /// Signalled by `stop()` to request shutdown.
    shutdown_tx: mpsc::Sender<()>,
    /// Join handles of the background threads spawned by `start()`; drained by `stop()`.
    tasks: Vec<std::thread::JoinHandle<Result<()>>>,
    /// Message router; `None` until `start()` creates it.
    router: Option<Arc<MessageRouter>>,
}
36
/// Dispatches Jupyter requests to their handlers and reports execution-state
/// transitions on `status_tx`.
struct MessageRouter {
    /// Execution engine shared with the owning `KernelServer`.
    engine: Arc<tokio::sync::Mutex<ExecutionEngine>>,
    /// Session id returned by `session_id()`.
    session_id: String,
    /// Broadcast channel that `send_status` publishes execution-state changes to.
    status_tx: broadcast::Sender<ExecutionState>,
}
43
44impl KernelServer {
45 pub fn new(config: KernelConfig) -> Self {
47 let engine = Arc::new(tokio::sync::Mutex::new(ExecutionEngine::new()));
48 let (status_tx, _) = broadcast::channel(16);
49 let (shutdown_tx, _) = mpsc::channel(1);
50
51 Self {
52 config,
53 ctx: None,
54 engine,
55 status_tx,
56 shutdown_tx,
57 tasks: Vec::new(),
58 router: None,
59 }
60 }
61
62 pub async fn start(&mut self) -> Result<()> {
64 log::info!("Starting RunMat kernel server");
65
66 self.config.connection.validate()?;
68
69 let ctx = zmq::Context::new();
71 self.ctx = Some(ctx.clone());
72
73 let shell_url = self.config.connection.shell_url();
75 let iopub_url = self.config.connection.iopub_url();
76 let stdin_url = self.config.connection.stdin_url();
77 let control_url = self.config.connection.control_url();
78 let heartbeat_url = self.config.connection.heartbeat_url();
79
80 log::info!(
81 "Kernel bound to ports: shell={}, iopub={}, stdin={}, control={}, hb={}",
82 self.config.connection.shell_port,
83 self.config.connection.iopub_port,
84 self.config.connection.stdin_port,
85 self.config.connection.control_port,
86 self.config.connection.hb_port
87 );
88
89 let router = Arc::new(MessageRouter::new(
91 Arc::clone(&self.engine),
92 self.config.session_id.clone(),
93 self.status_tx.clone(),
94 ));
95
96 log::info!(
98 "Message router initialized for session: {}",
99 router.session_id()
100 );
101
102 self.router = Some(Arc::clone(&router));
104
105 let (iopub_tx, mut iopub_rx) = tokio::sync::mpsc::unbounded_channel::<JupyterMessage>();
107
108 let session_for_iopub = self.config.session_id.clone();
110 let key_for_iopub = self.config.connection.key.clone();
111 let scheme_for_iopub = self.config.connection.signature_scheme.clone();
112 let ctx_iopub = ctx.clone();
113 let iopub_task = std::thread::spawn(move || -> Result<()> {
114 let socket = ctx_iopub.socket(zmq::PUB).map_err(KernelError::Zmq)?;
115 socket.bind(&iopub_url).map_err(KernelError::Zmq)?;
116 while let Some(mut msg) = iopub_rx.blocking_recv() {
117 msg.header.session = session_for_iopub.clone();
118 if msg.parent_header.is_none() {
119 msg.parent_header = Some(crate::protocol::MessageHeader::new(
120 MessageType::Status,
121 &session_for_iopub,
122 ));
123 }
124 if let Err(e) =
125 send_jupyter_message(&socket, &[], &key_for_iopub, &scheme_for_iopub, &msg)
126 {
127 log::error!("Failed to publish IOPub message: {e}");
128 }
129 }
130 Ok(())
131 });
132 self.tasks.push(iopub_task);
133
134 let ctx_hb = ctx.clone();
136 let hb_task = std::thread::spawn(move || -> Result<()> {
137 let socket = ctx_hb.socket(zmq::REP).map_err(KernelError::Zmq)?;
138 socket.bind(&heartbeat_url).map_err(KernelError::Zmq)?;
139 loop {
140 let msg = socket.recv_multipart(0).map_err(KernelError::Zmq)?;
141 socket.send_multipart(msg, 0).map_err(KernelError::Zmq)?;
142 }
143 });
144 self.tasks.push(hb_task);
145
146 let engine_shell = Arc::clone(&self.engine);
148 let router_shell = Arc::clone(&router);
149 let session_id_shell = self.config.session_id.clone();
150 let key_shell = self.config.connection.key.clone();
151 let scheme_shell = self.config.connection.signature_scheme.clone();
152 let iopub_tx_shell = iopub_tx.clone();
153 let ctx_shell = ctx.clone();
154 let shell_task = std::thread::spawn(move || -> Result<()> {
155 let shell_socket = ctx_shell.socket(zmq::ROUTER).map_err(KernelError::Zmq)?;
156 shell_socket.bind(&shell_url).map_err(KernelError::Zmq)?;
157 loop {
158 let (ids, msg) = recv_jupyter_message(&shell_socket, &key_shell, &scheme_shell)?;
159
160 match msg.header.msg_type.clone() {
161 MessageType::KernelInfoRequest => {
162 let status_busy = Status {
164 execution_state: ExecutionState::Busy,
165 };
166 let mut status_msg = JupyterMessage::reply(
167 &msg,
168 MessageType::Status,
169 serde_json::to_value(status_busy)?,
170 );
171 status_msg.header.session = session_id_shell.clone();
172 let _ = iopub_tx_shell.send(status_msg);
173
174 let mut reply = futures::executor::block_on(
175 router_shell.handle_kernel_info_request(&msg),
176 )?;
177 reply.header.session = session_id_shell.clone();
178 send_jupyter_message(
179 &shell_socket,
180 &ids,
181 &key_shell,
182 &scheme_shell,
183 &reply,
184 )?;
185
186 let status_idle = Status {
188 execution_state: ExecutionState::Idle,
189 };
190 let mut status_msg = JupyterMessage::reply(
191 &msg,
192 MessageType::Status,
193 serde_json::to_value(status_idle)?,
194 );
195 status_msg.header.session = session_id_shell.clone();
196 let _ = iopub_tx_shell.send(status_msg);
197 }
198 MessageType::ExecuteRequest => {
199 let exec_req: ExecuteRequest = serde_json::from_value(msg.content.clone())?;
200
201 let mut status_msg = JupyterMessage::reply(
203 &msg,
204 MessageType::Status,
205 serde_json::to_value(Status {
206 execution_state: ExecutionState::Busy,
207 })?,
208 );
209 status_msg.header.session = session_id_shell.clone();
210 let _ = iopub_tx_shell.send(status_msg);
211
212 let predicted = {
214 let eng = futures::executor::block_on(engine_shell.lock());
215 eng.execution_count() + 1
216 };
217 let mut input_msg = JupyterMessage::reply(
218 &msg,
219 MessageType::ExecuteInput,
220 serde_json::json!({"code": exec_req.code, "execution_count": predicted}),
221 );
222 input_msg.header.session = session_id_shell.clone();
223 let _ = iopub_tx_shell.send(input_msg);
224
225 let exec_result = {
226 let mut eng = futures::executor::block_on(engine_shell.lock());
227 let req_again: ExecuteRequest =
228 serde_json::from_value(msg.content.clone())?;
229 eng.execute(&req_again.code)
230 .map_err(|e| KernelError::Execution(e.to_string()))?
231 };
232
233 let status = match exec_result.status {
234 ExecutionStatus::Success => crate::protocol::ExecutionStatus::Ok,
235 ExecutionStatus::Error => crate::protocol::ExecutionStatus::Error,
236 ExecutionStatus::Interrupted | ExecutionStatus::Timeout => {
237 crate::protocol::ExecutionStatus::Abort
238 }
239 };
240
241 let exec_count = {
242 let eng = futures::executor::block_on(engine_shell.lock());
243 eng.execution_count()
244 };
245
246 match exec_result.status {
247 ExecutionStatus::Success => {
248 if let Some(val) = exec_result.result {
249 let mut data = std::collections::HashMap::new();
250 data.insert(
251 "text/plain".to_string(),
252 serde_json::json!(val.to_string()),
253 );
254 let res = ExecuteResult {
255 execution_count: exec_count,
256 data,
257 metadata: std::collections::HashMap::new(),
258 };
259 let mut res_msg = JupyterMessage::reply(
260 &msg,
261 MessageType::ExecuteResult,
262 serde_json::to_value(res)?,
263 );
264 res_msg.header.session = session_id_shell.clone();
265 let _ = iopub_tx_shell.send(res_msg);
266 }
267 }
268 ExecutionStatus::Error => {
269 if let Some(err) = exec_result.error {
270 let ec = ErrorContent {
271 ename: err.error_type,
272 evalue: err.message,
273 traceback: err.traceback,
274 };
275 let mut err_msg = JupyterMessage::reply(
276 &msg,
277 MessageType::Error,
278 serde_json::to_value(ec)?,
279 );
280 err_msg.header.session = session_id_shell.clone();
281 let _ = iopub_tx_shell.send(err_msg);
282 }
283 }
284 _ => {}
285 }
286
287 let reply = ExecuteReply {
288 status,
289 execution_count: exec_count,
290 user_expressions: HashMap::new(),
291 payload: Vec::new(),
292 };
293 let mut reply_msg = JupyterMessage::reply(
294 &msg,
295 MessageType::ExecuteReply,
296 serde_json::to_value(reply)?,
297 );
298 reply_msg.header.session = session_id_shell.clone();
299 send_jupyter_message(
300 &shell_socket,
301 &ids,
302 &key_shell,
303 &scheme_shell,
304 &reply_msg,
305 )?;
306
307 let mut status_msg = JupyterMessage::reply(
309 &msg,
310 MessageType::Status,
311 serde_json::to_value(Status {
312 execution_state: ExecutionState::Idle,
313 })?,
314 );
315 status_msg.header.session = session_id_shell.clone();
316 let _ = iopub_tx_shell.send(status_msg);
317 }
318 other => {
319 log::warn!("Unhandled shell message: {:?}", other);
320 if let Ok(Some(reply)) = futures::executor::block_on(async {
321 router_shell.route_message(&msg).await
322 }) {
323 send_jupyter_message(
324 &shell_socket,
325 &ids,
326 &key_shell,
327 &scheme_shell,
328 &reply,
329 )?;
330 }
331 }
332 }
333 }
334 });
335 self.tasks.push(shell_task);
336
337 let router_ctrl = Arc::clone(&router);
339 let key_ctrl = self.config.connection.key.clone();
340 let scheme_ctrl = self.config.connection.signature_scheme.clone();
341 let session_ctrl = self.config.session_id.clone();
342 let iopub_tx_ctrl = iopub_tx.clone();
343 let ctx_ctrl = ctx.clone();
344 let control_task = std::thread::spawn(move || -> Result<()> {
345 let control_socket = ctx_ctrl.socket(zmq::ROUTER).map_err(KernelError::Zmq)?;
346 control_socket
347 .bind(&control_url)
348 .map_err(KernelError::Zmq)?;
349 loop {
350 let (ids, msg) = recv_jupyter_message(&control_socket, &key_ctrl, &scheme_ctrl)?;
351 match msg.header.msg_type.clone() {
352 MessageType::ShutdownRequest | MessageType::InterruptRequest => {
353 let mut status_msg = JupyterMessage::reply(
354 &msg,
355 MessageType::Status,
356 serde_json::to_value(Status {
357 execution_state: ExecutionState::Busy,
358 })?,
359 );
360 status_msg.header.session = session_ctrl.clone();
361 let _ = iopub_tx_ctrl.send(status_msg);
362
363 let mut reply =
364 futures::executor::block_on(router_ctrl.route_message(&msg))?
365 .unwrap_or_else(|| {
366 JupyterMessage::reply(
367 &msg,
368 MessageType::InterruptReply,
369 serde_json::json!({"status":"ok"}),
370 )
371 });
372 reply.header.session = session_ctrl.clone();
373 send_jupyter_message(
374 &control_socket,
375 &ids,
376 &key_ctrl,
377 &scheme_ctrl,
378 &reply,
379 )?;
380
381 let mut status_msg = JupyterMessage::reply(
382 &msg,
383 MessageType::Status,
384 serde_json::to_value(Status {
385 execution_state: ExecutionState::Idle,
386 })?,
387 );
388 status_msg.header.session = session_ctrl.clone();
389 let _ = iopub_tx_ctrl.send(status_msg);
390 }
391 _ => {}
392 }
393 }
394 });
395 self.tasks.push(control_task);
396
397 let key_stdin = self.config.connection.key.clone();
399 let scheme_stdin = self.config.connection.signature_scheme.clone();
400 let session_stdin = self.config.session_id.clone();
401 let ctx_stdin = ctx.clone();
402 let stdin_task = std::thread::spawn(move || -> Result<()> {
403 let stdin_socket = ctx_stdin.socket(zmq::ROUTER).map_err(KernelError::Zmq)?;
404 stdin_socket.bind(&stdin_url).map_err(KernelError::Zmq)?;
405 loop {
406 let (ids, msg) = recv_jupyter_message(&stdin_socket, &key_stdin, &scheme_stdin)?;
407 if matches!(msg.header.msg_type, MessageType::InputRequest) {
408 let mut reply = JupyterMessage::reply(
409 &msg,
410 MessageType::InputReply,
411 serde_json::json!({"value": ""}),
412 );
413 reply.header.session = session_stdin.clone();
414 send_jupyter_message(&stdin_socket, &ids, &key_stdin, &scheme_stdin, &reply)?;
415 }
416 }
417 });
418 self.tasks.push(stdin_task);
419
420 let mut start_msg = JupyterMessage::new(
422 MessageType::Status,
423 &self.config.session_id,
424 serde_json::to_value(Status {
425 execution_state: ExecutionState::Starting,
426 })?,
427 );
428 start_msg.parent_header = None;
429 let _ = iopub_tx.send(start_msg);
430
431 let mut idle_msg = JupyterMessage::new(
432 MessageType::Status,
433 &self.config.session_id,
434 serde_json::to_value(Status {
435 execution_state: ExecutionState::Idle,
436 })?,
437 );
438 idle_msg.parent_header = None;
439 let _ = iopub_tx.send(idle_msg);
440
441 log::info!("RunMat kernel is ready for connections");
442
443 Ok(())
444 }
445
446 pub async fn stop(&mut self) -> Result<()> {
448 log::info!("Stopping kernel server");
449
450 if (self.shutdown_tx.send(()).await).is_err() {
452 log::warn!("Failed to send shutdown signal");
453 }
454
455 for task in self.tasks.drain(..) {
457 match task.join() {
458 Ok(Ok(())) => {}
459 Ok(Err(e)) => log::error!("Task failed during shutdown: {e:?}"),
460 Err(e) => log::error!("Task panicked: {e:?}"),
461 }
462 }
463
464 log::info!("Kernel server stopped");
465 Ok(())
466 }
467
468 pub fn kernel_info(&self) -> KernelInfo {
470 KernelInfo::default()
471 }
472
473 pub fn connection_info(&self) -> &ConnectionInfo {
475 &self.config.connection
476 }
477
478 pub async fn stats(&self) -> Result<ExecutionStats> {
480 let engine = self.engine.lock().await;
481 Ok(engine.stats())
482 }
483
484 pub async fn handle_message(&self, message: &JupyterMessage) -> Result<Option<JupyterMessage>> {
486 if let Some(ref router) = self.router {
487 router.route_message(message).await
488 } else {
489 Err(crate::KernelError::Internal(
490 "Message router not initialized".to_string(),
491 ))
492 }
493 }
494
495 pub fn session_id(&self) -> Option<&str> {
497 self.router.as_ref().map(|r| r.session_id())
498 }
499
500 pub async fn send_status(&self, status: ExecutionState) -> Result<()> {
502 if let Some(ref router) = self.router {
503 router.send_status(status).await
504 } else {
505 self.status_tx
506 .send(status)
507 .map_err(|e| crate::KernelError::Internal(format!("Failed to send status: {e}")))?;
508 Ok(())
509 }
510 }
511}
512
513impl MessageRouter {
514 pub fn new(
516 engine: Arc<tokio::sync::Mutex<ExecutionEngine>>,
517 session_id: String,
518 status_tx: broadcast::Sender<ExecutionState>,
519 ) -> Self {
520 Self {
521 engine,
522 session_id,
523 status_tx,
524 }
525 }
526
527 pub fn session_id(&self) -> &str {
529 &self.session_id
530 }
531
532 pub async fn send_status(&self, status: ExecutionState) -> Result<()> {
534 self.status_tx
535 .send(status)
536 .map_err(|e| KernelError::Internal(format!("Failed to send status: {e}")))?;
537 Ok(())
538 }
539
540 pub async fn route_message(&self, msg: &JupyterMessage) -> Result<Option<JupyterMessage>> {
542 let _ = self.send_status(ExecutionState::Busy).await;
544
545 let result = match msg.header.msg_type {
546 MessageType::KernelInfoRequest => Ok(Some(self.handle_kernel_info_request(msg).await?)),
547 MessageType::ExecuteRequest => Ok(Some(self.handle_execute_request(msg).await?)),
548 MessageType::ShutdownRequest => Ok(Some(self.handle_shutdown_request(msg).await?)),
549 MessageType::InterruptRequest => Ok(Some(self.handle_interrupt_request(msg).await?)),
550 _ => {
551 log::warn!("Unhandled message type: {:?}", msg.header.msg_type);
552 Ok(None)
553 }
554 };
555
556 let _ = self.send_status(ExecutionState::Idle).await;
558
559 result
560 }
561
562 async fn handle_kernel_info_request(&self, msg: &JupyterMessage) -> Result<JupyterMessage> {
564 let kernel_info = KernelInfo::default();
565 let content = serde_json::to_value(&kernel_info)?;
566 Ok(JupyterMessage::reply(
567 msg,
568 MessageType::KernelInfoReply,
569 content,
570 ))
571 }
572
573 async fn handle_execute_request(&self, msg: &JupyterMessage) -> Result<JupyterMessage> {
575 let execute_req: ExecuteRequest = serde_json::from_value(msg.content.clone())?;
577
578 let _ = self.status_tx.send(ExecutionState::Busy);
580
581 let mut engine = self.engine.lock().await;
583 let exec_result = engine.execute(&execute_req.code)?;
584
585 let status = match exec_result.status {
587 ExecutionStatus::Success => crate::protocol::ExecutionStatus::Ok,
588 ExecutionStatus::Error => crate::protocol::ExecutionStatus::Error,
589 ExecutionStatus::Interrupted => crate::protocol::ExecutionStatus::Abort,
590 ExecutionStatus::Timeout => crate::protocol::ExecutionStatus::Abort,
591 };
592
593 let reply = ExecuteReply {
594 status,
595 execution_count: engine.execution_count(),
596 user_expressions: HashMap::new(),
597 payload: Vec::new(),
598 };
599
600 let _ = self.status_tx.send(ExecutionState::Idle);
602
603 let content = serde_json::to_value(&reply)?;
604 Ok(JupyterMessage::reply(
605 msg,
606 MessageType::ExecuteReply,
607 content,
608 ))
609 }
610
611 async fn handle_shutdown_request(&self, msg: &JupyterMessage) -> Result<JupyterMessage> {
613 let shutdown_reply = serde_json::json!({
614 "restart": false
615 });
616 Ok(JupyterMessage::reply(
617 msg,
618 MessageType::ShutdownReply,
619 shutdown_reply,
620 ))
621 }
622
623 async fn handle_interrupt_request(&self, msg: &JupyterMessage) -> Result<JupyterMessage> {
625 let interrupt_reply = serde_json::json!({
626 "status": "ok"
627 });
628 Ok(JupyterMessage::reply(
629 msg,
630 MessageType::InterruptReply,
631 interrupt_reply,
632 ))
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a router over a brand-new engine under session "test".
    fn make_router() -> MessageRouter {
        let shared_engine = Arc::new(tokio::sync::Mutex::new(ExecutionEngine::new()));
        let (tx, _) = broadcast::channel(16);
        MessageRouter::new(shared_engine, "test".to_string(), tx)
    }

    #[test]
    fn test_kernel_server_creation() {
        let fresh = KernelServer::new(KernelConfig::default());
        assert!(
            fresh.tasks.is_empty(),
            "a server that was never started must not own any threads"
        );
    }

    #[tokio::test]
    async fn test_message_router_kernel_info() {
        let router = make_router();
        let info_request = JupyterMessage::new(
            MessageType::KernelInfoRequest,
            "test",
            serde_json::json!({}),
        );

        let info_reply = router
            .handle_kernel_info_request(&info_request)
            .await
            .expect("kernel_info_request should always succeed");

        assert_eq!(info_reply.header.msg_type, MessageType::KernelInfoReply);
        assert!(info_reply.parent_header.is_some());
    }

    #[tokio::test]
    async fn test_message_router_execute() {
        let router = make_router();
        let payload = serde_json::to_value(&ExecuteRequest {
            code: "x = 1 + 2".to_string(),
            silent: false,
            store_history: true,
            user_expressions: HashMap::new(),
            allow_stdin: false,
            stop_on_error: false,
        })
        .expect("ExecuteRequest should serialize");
        let exec_request = JupyterMessage::new(MessageType::ExecuteRequest, "test", payload);

        let exec_reply = router
            .handle_execute_request(&exec_request)
            .await
            .expect("simple assignment should execute");
        assert_eq!(exec_reply.header.msg_type, MessageType::ExecuteReply);

        let parsed: ExecuteReply =
            serde_json::from_value(exec_reply.content).expect("reply should be an ExecuteReply");
        assert_eq!(parsed.execution_count, 1);
    }

    #[test]
    fn test_kernel_info_default() {
        let defaults = KernelInfo::default();
        assert_eq!(defaults.implementation, "runmat");
        assert_eq!(defaults.language_info.name, "matlab");
        assert_eq!(defaults.protocol_version, "5.3");
    }
}