use crate::{
connection::Connection,
errors::JupyterResult,
jupyter_message::{JupyterMessage, JupyterMessageType},
ExecutionRequest, JupyterConnection, JupyterKernelProtocol, JupyterKernelSockets,
};
use crate::{
commands::start::KernelControl,
jupyter_message::{CommonInfoRequest, KernelInfoReply},
};
use serde_json::Value;
use std::{sync::Arc, time::SystemTime};
use tokio::{sync::Mutex, task::JoinHandle};
use zeromq::{PubSocket, RepSocket, RouterSocket, Socket, SocketRecv, SocketSend, ZmqMessage};
/// Shared state of a running kernel: the five Jupyter sockets plus the shutdown plumbing.
///
/// Each socket sits behind an `Arc<Mutex<_>>`, so a clone of this struct shares the same
/// sockets; every spawned handler task works on its own clone.
#[derive(Clone)]
pub(crate) struct SealedServer {
    /// REP socket serving heartbeat requests.
    heartbeat: Arc<Mutex<Connection<RepSocket>>>,
    /// PUB socket used to publish status updates and execution results.
    iopub: Arc<Mutex<Connection<PubSocket>>>,
    /// ROUTER socket for the stdin channel.
    stdin: Arc<Mutex<Connection<RouterSocket>>>,
    /// ROUTER socket for the control channel (kernel info, debug, interrupt, shutdown).
    control: Arc<Mutex<Connection<RouterSocket>>>,
    /// ROUTER socket for the shell channel (kernel info, execute, common info).
    shell_socket: Arc<Mutex<Connection<RouterSocket>>>,
    /// Initialised to `None` and not read anywhere yet; reserved for future use, hence the
    /// targeted `dead_code` allowance instead of a struct-wide `#[allow(unused)]`.
    #[allow(dead_code)]
    latest_execution_request: Arc<Mutex<Option<JupyterMessage>>>,
    /// Holds the only shutdown sender; taking it out (dropping it) unblocks the receiver.
    shutdown_sender: Arc<Mutex<Option<crossbeam_channel::Sender<()>>>>,
    /// Handle of the runtime the server was started on; stored but not read yet.
    #[allow(dead_code)]
    tokio_handle: tokio::runtime::Handle,
}
/// Bundle handed to every socket handler so it can reach the user's kernel implementation.
///
/// The kernel value itself is shared behind a single async mutex; the socket bundle is
/// cloned alongside it.
pub struct ExecuteProvider<T> {
    /// The kernel implementation, locked whenever a request needs to call into it.
    pub(crate) context: Arc<Mutex<T>>,
    /// Sockets/state shared with the kernel (execution counter, iopub channel, debug flag).
    pub(crate) sockets: JupyterKernelSockets,
}
impl<T> Clone for ExecuteProvider<T> {
fn clone(&self) -> Self {
Self { context: self.context.clone(), sockets: self.sockets.clone() }
}
}
impl<T> ExecuteProvider<T> {
    /// Takes ownership of the kernel `context`, wraps it for shared async access, and pairs
    /// it with the given `sockets`.
    pub fn new(context: T, sockets: JupyterKernelSockets) -> Self
    where
        T: JupyterKernelProtocol + 'static,
    {
        let guarded = Mutex::new(context);
        Self { context: Arc::new(guarded), sockets }
    }
}
/// Receiving end of the kernel's shutdown channel.
///
/// No value is ever sent on this channel in this file; shutdown is signalled by dropping
/// the matching sender, which makes a blocked `recv` return.
struct ShutdownReceiver {
    /// Channel whose disconnection marks the end of the kernel's lifetime.
    recv: crossbeam_channel::Receiver<()>,
}
impl SealedServer {
pub(crate) fn run<T>(config: &KernelControl, server: T) -> JupyterResult<()>
where
T: JupyterKernelProtocol + 'static,
{
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()
.unwrap();
let handle = runtime.handle().clone();
runtime.block_on(async {
let shutdown_receiver = Self::start(config, handle, server).await?;
shutdown_receiver.wait_for_shutdown().await;
let result: JupyterResult<()> = Ok(());
result
})?;
Ok(())
}
async fn start<T>(
config: &KernelControl,
tokio_handle: tokio::runtime::Handle,
mut server: T,
) -> JupyterResult<ShutdownReceiver>
where
T: JupyterKernelProtocol + 'static,
{
let heartbeat = bind_socket::<RepSocket>(config, config.hb_port).await?;
let shell_socket = bind_socket::<RouterSocket>(config, config.shell_port).await?;
let control_socket = bind_socket::<RouterSocket>(config, config.control_port).await?;
let stdin_socket = bind_socket::<RouterSocket>(config, config.stdin_port).await?;
let io_pub_socket = bind_socket::<PubSocket>(config, config.iopub_port).await?;
let io_pub = Arc::new(Mutex::new(io_pub_socket));
let (shutdown_sender, shutdown_receiver) = crossbeam_channel::unbounded();
let latest_execution_request = Arc::new(Mutex::new(None));
let sockets = JupyterKernelSockets {
execute_count: Arc::new(Mutex::new(1)),
io_channel: Some(io_pub.clone()),
debugging: Arc::new(Mutex::new(false)),
};
let setup = JupyterConnection { boot_path: Default::default(), sockets: sockets.clone() };
server.connected(setup);
let here = SealedServer {
iopub: io_pub,
heartbeat: Arc::new(Mutex::new(heartbeat)),
latest_execution_request,
stdin: Arc::new(Mutex::new(stdin_socket)),
control: Arc::new(Mutex::new(control_socket)),
shutdown_sender: Arc::new(Mutex::new(Some(shutdown_sender))),
tokio_handle,
shell_socket: Arc::new(Mutex::new(shell_socket)),
};
let context = ExecuteProvider::new(server, sockets);
here.clone().spawn_heart_beat();
here.clone().spawn_shell_execution(context.clone());
here.clone().spawn_control(context.clone());
here.clone().spawn_std_in(context.clone());
Ok(ShutdownReceiver { recv: shutdown_receiver })
}
async fn signal_shutdown(&self) {
self.shutdown_sender.lock().await.take();
}
fn spawn_heart_beat(self) -> JoinHandle<()> {
tokio::spawn(async move {
loop {
if let Err(e) = self.clone().handle_heart_beat().await {
tracing::warn!("Error sending heartbeat: {:?}", e);
}
}
})
}
async fn handle_heart_beat(self) -> JupyterResult<()> {
let mut connection = match self.heartbeat.try_lock() {
Ok(o) => o,
Err(_) => return Ok(()),
};
let _ = connection.socket.recv().await?;
connection.socket.send(ZmqMessage::from(b"ping".to_vec())).await?;
Ok(())
}
fn spawn_shell_execution<T>(self, executor: ExecuteProvider<T>) -> JoinHandle<()>
where
T: JupyterKernelProtocol + Send + 'static,
{
tokio::spawn(async move {
tracing::info!("Shell Executor Spawned");
loop {
if let Err(e) = self.clone().handle_shell(executor.clone()).await {
tracing::error!("Error sending shell execution: {:?}", e);
}
}
})
}
async fn handle_shell<'a, T>(self, executor: ExecuteProvider<T>) -> JupyterResult<()>
where
T: JupyterKernelProtocol + Send + 'static,
{
let request = JupyterMessage::read(&mut &mut self.shell_socket.lock().await).await?;
request.send_state(self.iopub.clone(), true).await?;
match request.kind() {
JupyterMessageType::KernelInfoRequest => {
let info = executor.context.lock().await.language_info();
let cont = KernelInfoReply::build(info);
request.as_reply().with_content(cont)?.send_by(&mut &mut self.shell_socket.lock().await).await?
}
JupyterMessageType::ExecuteRequest => {
let time = SystemTime::now();
let mut task = request.recast::<ExecutionRequest>()?;
task.header = request.clone();
let mut runner = executor.context.lock().await;
let count = executor.sockets.get_counter();
let reply = runner.running(task.clone()).await.with_count(count);
match time.elapsed() {
Ok(o) => {
let escape = runner.running_time(o.as_secs_f64());
if !escape.is_empty() {
let time = task.as_result("text/html".to_string(), Value::String(escape));
request
.as_reply()
.with_message_type(JupyterMessageType::ExecuteResult)
.with_content(time)?
.send_by(&mut &mut self.iopub.lock().await)
.await?;
}
}
Err(_) => {}
}
request.as_reply().with_content(reply)?.send_by(&mut &mut self.shell_socket.lock().await).await?;
}
JupyterMessageType::CommonInfoRequest => {
let task = request.recast::<CommonInfoRequest>()?;
request.as_reply().with_content(task.as_reply())?.send_by(&mut &mut self.shell_socket.lock().await).await?;
}
JupyterMessageType::Custom(v) => {
tracing::error!("Got unknown shell message: {:?}", v);
}
_ => {
tracing::warn!("Got custom shell message: {:?}", request);
}
}
request.send_state(self.iopub, false).await?;
Ok(())
}
#[allow(dead_code)]
fn spawn_execution_queue<T>(self, executor: ExecuteProvider<T>) -> JoinHandle<()>
where
T: JupyterKernelProtocol + Send + 'static,
{
let mut running_count = 0;
tokio::spawn(async move {
tracing::trace!("Queue Executor Spawned");
loop {
if let Err(e) = self.clone().handle_execution_queue(executor.clone(), running_count).await {
eprintln!("Error sending execution queue: {:?}", e);
}
running_count += 1;
}
})
}
#[allow(dead_code)]
async fn handle_execution_queue<T>(self, _executor: ExecuteProvider<T>, _count: i32) -> JupyterResult<()>
where
T: JupyterKernelProtocol + Send + 'static,
{
todo!()
}
fn spawn_control<T>(self, executor: ExecuteProvider<T>) -> JoinHandle<()>
where
T: JupyterKernelProtocol + Send + 'static,
{
tokio::spawn(async move {
tracing::info!("Control Executor Spawned");
loop {
if let Err(e) = self.clone().handle_control(executor.clone()).await {
tracing::error!("Error sending control execution: {:?}", e);
}
}
})
}
async fn handle_control<'a, T>(self, executor: ExecuteProvider<T>) -> JupyterResult<()>
where
T: JupyterKernelProtocol + Send + 'static,
{
let control = &mut self.control.lock().await;
let request = JupyterMessage::read(control).await?;
request.send_state(self.iopub.clone(), true).await?;
match request.kind() {
JupyterMessageType::KernelInfoRequest => {
let info = executor.context.lock().await;
let cont = KernelInfoReply::build(info.language_info());
request.as_reply().with_content(cont)?.send_by(control).await?
}
JupyterMessageType::DebugRequest => {
let result = request.debug_response(executor).await?;
request.as_reply().with_content(result)?.send_by(control).await?;
}
JupyterMessageType::InterruptRequest => {
let runner = executor.context.lock().await;
match runner.interrupt_kernel() {
Some(_) => {
request
.as_reply()
.create_message(JupyterMessageType::StatusReply)
.with_content("ok")?
.send_by(control)
.await?
}
None => {
request
.as_reply()
.create_message(JupyterMessageType::StatusReply)
.with_content("ok")?
.send_by(control)
.await?
}
}
}
JupyterMessageType::ShutdownRequest => self.signal_shutdown().await,
JupyterMessageType::Custom(v) => {
tracing::error!("Got unknown control message: {:#?}", v);
}
_ => {
tracing::warn!("Got custom control message: {:#?}", request);
}
}
request.send_state(self.iopub.clone(), false).await?;
Ok(())
}
fn spawn_std_in<T>(self, executor: ExecuteProvider<T>) -> JoinHandle<()>
where
T: JupyterKernelProtocol + Send + 'static,
{
tokio::spawn(async move {
tracing::info!("IO Executor Spawned");
loop {
if let Err(e) = self.clone().handle_std_in(executor.clone()).await {
tracing::error!("Error sending io execution: {:?}", e);
}
}
})
}
async fn handle_std_in<'a, T>(self, _: ExecuteProvider<T>) -> JupyterResult<()>
where
T: JupyterKernelProtocol + Send + 'static,
{
let io = &mut self.stdin.lock().await;
let request = JupyterMessage::read(io).await?;
match request.kind() {
JupyterMessageType::Custom(v) => {
tracing::error!("Got unknown io message: {:?}", v);
}
_ => {
tracing::warn!("Got custom io message: {:?}", request);
}
}
Ok(())
}
}
impl ShutdownReceiver {
    /// Resolves once the shutdown channel disconnects (or a value arrives).
    ///
    /// `crossbeam_channel::Receiver::recv` blocks the OS thread, so it runs on tokio's
    /// blocking pool instead of a runtime worker. Both the receive result and any join
    /// error are discarded: any outcome means "stop waiting".
    async fn wait_for_shutdown(self) {
        let Self { recv } = self;
        let waiter = tokio::task::spawn_blocking(move || recv.recv());
        let _ = waiter.await;
    }
}
/// Creates a fresh zeromq socket of type `S`, binds it to
/// `<transport>://<ip>:<port>` taken from `config`, and wraps it in a `Connection`
/// keyed with `config.key`.
///
/// # Errors
/// Propagates bind failures and any error from `Connection::new`.
async fn bind_socket<S: Socket>(config: &KernelControl, port: u16) -> JupyterResult<Connection<S>> {
    let address = format!(
        "{transport}://{ip}:{port}",
        transport = config.transport,
        ip = config.ip,
        port = port
    );
    let mut zmq_socket = S::new();
    zmq_socket.bind(&address).await?;
    Connection::new(zmq_socket, &config.key)
}