kernel_sidecar/client.rs

/*
The Kernel Sidecar Client is the main entrypoint for connecting to a Kernel over ZMQ, issuing
Actions (requests) to the Kernel, and handling all responses whose parent header msg_id matches
the original request.

Each Action is "complete" (awaitable) when the Kernel status has gone back to Idle and the expected
reply type has been seen (e.g. kernel_info_reply for a kernel_info_request).

Message passing between background tasks is done with mpsc channels:
 - background tasks listening on the iopub and shell channels push messages to a central
   process_message worker over mpsc.
 - the process_message background task deserializes messages, looks up the appropriate Action
   based on the parent header msg_id, then pushes to the Action handlers over mpsc.

Example usage: run until a kernel_info request/reply has completed, printing out all ZMQ
messages coming back over the iopub and shell channels:

let connection_info = ConnectionInfo::from_file("/tmp/kernel.json")
    .expect("Make sure to run python -m ipykernel_launcher -f /tmp/kernel.json");
let client = Client::new(connection_info).await;

#[derive(Debug)]
struct DebugHandler;

#[async_trait::async_trait]
impl Handler for DebugHandler {
    async fn handle(&self, msg: &Response) {
        dbg!(msg);
    }
}

let handler = DebugHandler {};
let handlers = vec![Arc::new(Mutex::new(handler)) as Arc<Mutex<dyn Handler>>];
let action = client.kernel_info_request(handlers).await;
action.await;
*/
37
use std::collections::HashMap;
use std::sync::Arc;
41use std::time::Duration;
42use tokio::sync::{mpsc, Mutex, Notify, RwLock};
43use tokio::time::sleep;
44use zeromq::{DealerSocket, ReqSocket, Socket, SocketRecv, SocketSend, SubSocket, ZmqMessage};
45
46use crate::actions::Action;
47use crate::handlers::Handler;
48use crate::jupyter::connection_file::ConnectionInfo;
49use crate::jupyter::request::Request;
50use crate::jupyter::response::Response;
51use crate::jupyter::shell_content::execute::ExecuteRequest;
52use crate::jupyter::shell_content::kernel_info::KernelInfoRequest;
53use crate::jupyter::wire_protocol::WireProtocol;
54
55#[derive(Debug, Clone)]
56pub struct Client {
57    actions: Arc<RwLock<HashMap<String, mpsc::Sender<Response>>>>,
58    connection_info: ConnectionInfo,
59    shell_tx: mpsc::Sender<ZmqMessage>,
60    shutdown_signal: Arc<Notify>,
61}
62
63impl Client {
64    pub async fn new(connection_info: ConnectionInfo) -> Self {
65        let actions = Arc::new(RwLock::new(HashMap::new()));
66        // message passing for methods to send requests out over shell channel via shell_worker
67        let (shell_tx, shell_rx) = mpsc::channel(100);
68
69        // message passing for iopub and shell listeners into process_message_worker
70        let (process_msg_tx, process_msg_rx) = mpsc::channel(100);
71
72        // For shutting down ZMQ listeners when Client is dropped
73        let shutdown_signal = Arc::new(Notify::new());
74
75        // spawn iopub and shell listeners
76        let iopub_address = connection_info.iopub_address();
77        let shell_address = connection_info.shell_address();
78
79        tokio::spawn(iopub_worker(
80            iopub_address,
81            process_msg_tx.clone(),
82            shutdown_signal.clone(),
83        ));
84        tokio::spawn(shell_worker(
85            shell_address,
86            shell_rx,
87            process_msg_tx.clone(),
88            shutdown_signal.clone(),
89        ));
90
91        // spawn process_message_worker
92        tokio::spawn(process_message_worker(
93            process_msg_rx,
94            actions.clone(),
95            shutdown_signal.clone(),
96        ));
97
98        Client {
99            actions,
100            connection_info,
101            shell_tx,
102            shutdown_signal,
103        }
104    }
105
    // Try to connect to the heartbeat channel and send a ping message.
    // Use this to wait for a new Kernel to come up or to check that it's still reachable.
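    //
    // Example (sketch): bound the wait with tokio's timeout if the kernel may never come up:
    //   tokio::time::timeout(Duration::from_secs(5), client.heartbeat())
    //       .await
    //       .expect("kernel did not respond to a heartbeat ping");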
108    pub async fn heartbeat(&self) {
109        loop {
110            let mut socket = ReqSocket::new();
111
112            // Try to connect to the heartbeat channel
113            if let Err(_e) = socket
114                .connect(self.connection_info.heartbeat_address().as_str())
115                .await
116            {
117                sleep(Duration::from_millis(50)).await;
118                continue; // If connection fails, retry in the next iteration of the loop
119            }
120
121            // Send a ping message
122            let ping_msg = ZmqMessage::from("ping");
123            if let Err(_e) = socket.send(ping_msg).await {
124                sleep(Duration::from_millis(50)).await;
125                continue; // If sending fails, retry in the next iteration of the loop
126            }
127
128            // Wait for a pong message
129            match socket.recv().await {
130                Ok(_) => {
131                    break; // Successful pong message received, break the loop
132                }
133                Err(_) => {
134                    sleep(Duration::from_millis(50)).await;
135                    continue; // If receiving fails, retry in the next iteration of the loop
136                }
137            }
138        }
139    }
140
    // Creates an Action from a request + handlers, serializes the request for ZMQ, sends it
    // over the shell channel, and registers the request's msg_id in the actions HashMap so
    // that all response messages can be routed to the appropriate Action handlers.
144    async fn send_request(
145        &self,
146        request: Request,
147        handlers: Vec<Arc<Mutex<dyn Handler>>>,
148    ) -> Action {
149        let (msg_tx, msg_rx) = mpsc::channel(100);
150        let action = Action::new(request, handlers, msg_rx);
151        let msg_id = action.request.msg_id();
152        self.actions.write().await.insert(msg_id.clone(), msg_tx);
153        let wp: WireProtocol = action.request.into_wire_protocol(&self.connection_info.key);
154        let zmq_msg: ZmqMessage = wp.into();
155        self.shell_tx.send(zmq_msg).await.unwrap();
156        action
157    }
158
159    pub async fn kernel_info_request(&self, handlers: Vec<Arc<Mutex<dyn Handler>>>) -> Action {
160        let request = KernelInfoRequest::new();
161        self.send_request(request.into(), handlers).await
162    }
163
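    // Example (sketch): submit code for execution and await completion:
    //   let action = client.execute_request("2 + 2".to_string(), handlers).await;
    //   action.await;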
164    pub async fn execute_request(
165        &self,
166        code: String,
167        handlers: Vec<Arc<Mutex<dyn Handler>>>,
168    ) -> Action {
169        let request = ExecuteRequest::new(code);
170        self.send_request(request.into(), handlers).await
171    }
172}
173
174impl Drop for Client {
175    fn drop(&mut self) {
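        // notify_waiters wakes the tasks currently awaiting notified(); the worker
        // loops re-await it on every select! iteration, so they observe this shutdown.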
176        self.shutdown_signal.notify_waiters();
177    }
178}
179
/// The tasks listening on the iopub and shell channels push any messages they receive into this
/// processing function. Its job is to deserialize each ZmqMessage into the appropriate Jupyter
/// message and route it to the matching Action based on the parent msg_id.
183async fn process_message_worker(
184    mut msg_rx: mpsc::Receiver<ZmqMessage>,
185    actions: Arc<RwLock<HashMap<String, mpsc::Sender<Response>>>>,
186    shutdown_signal: Arc<Notify>, // hook to shutdown background task if Client is dropped
187) {
188    loop {
189        tokio::select! {
190            Some(zmq_msg) = msg_rx.recv() => {
191                let response: Response = zmq_msg.into();
192                let msg_id = response.parent_msg_id();
                if msg_id.is_none() {
                    eprintln!("No parent msg_id; skipping msg_type {}", response.msg_type());
                    continue;
                }
197                let msg_id = msg_id.unwrap();
                if let Some(action) = actions.read().await.get(&msg_id) {
                    // A SendError here means we're still seeing ZMQ messages whose parent header
                    // msg_id matches a request / Action that has already "completed" and shut
                    // down its mpsc Receiver. That usually happens when the Action isn't
                    // configured to expect some reply type: it "finishes" when the Kernel status
                    // goes Idle, and a later reply is then sent to the closed channel.
                    if let Err(e) = action.send(response).await {
                        dbg!(e);
                    }
                }
213            },
214            _ = shutdown_signal.notified() => {
215                break;
216            }
217        }
218    }
219}
220
/// The iopub channel background task is only responsible for listening on the iopub channel and
/// pushing messages to the process_message_worker. We never send anything out on iopub.
223async fn iopub_worker(
224    iopub_address: String,
225    msg_tx: mpsc::Sender<ZmqMessage>,
226    shutdown_signal: Arc<Notify>,
227) {
228    let mut socket = SubSocket::new();
229    socket.connect(iopub_address.as_str()).await.unwrap();
230    socket.subscribe("").await.unwrap();
231
232    loop {
233        tokio::select! {
234            Ok(msg) = socket.recv() => {
235                msg_tx.send(msg).await.unwrap();
236            },
237            _ = shutdown_signal.notified() => {
238                break;
239            }
240        }
241    }
242}
243
/// The shell channel background task needs a way for the Client to send requests out over shell,
/// in addition to listening for replies coming back on the channel and pushing those to the
/// process_message_worker.
247async fn shell_worker(
248    shell_address: String,
249    mut msg_rx: mpsc::Receiver<ZmqMessage>, // Client wants to send Jupyter message over ZMQ
250    msg_tx: mpsc::Sender<ZmqMessage>,       // Kernel sent a reply over ZMQ, needs to get processed
251    shutdown_signal: Arc<Notify>,
252) {
253    let mut socket = DealerSocket::new();
254    socket.connect(shell_address.as_str()).await.unwrap();
255
256    loop {
257        tokio::select! {
258            Some(client_to_kernel_msg) = msg_rx.recv() => {
259                socket.send(client_to_kernel_msg).await.unwrap();
260            }
261            kernel_to_client_msg = socket.recv() => {
262                match kernel_to_client_msg {
263                    Ok(msg) => {
264                        msg_tx.send(msg).await.unwrap();
265                    }
266                    Err(e) => {
267                        dbg!(e);
268                    }
269                }
270            },
271            _ = shutdown_signal.notified() => {
272                break;
273            }
274        }
275    }
276}