kernel_sidecar/client.rs
/*
The Kernel Sidecar Client is the main entrypoint for connecting to a Kernel over ZMQ, issuing
Actions (requests) to the Kernel, and handling all responses whose parent header msg_id matches
the original request.

Each Action is "complete" (awaitable) once the Kernel status has gone back to Idle and the expected
reply type has been seen (e.g. kernel_info_reply for a kernel_info_request).

Message passing between background tasks is done with mpsc channels:
 - background tasks listening on the iopub and shell channels push messages to a central
   process_message worker over mpsc.
 - the process_message background task deserializes messages, looks up the appropriate Action
   based on the parent header msg_id, and then pushes to that Action's handlers over mpsc
   (see the flow sketch below).
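
A rough sketch of the flow (names match the worker functions defined in this file):

    iopub_worker --mpsc--> \
                            process_message_worker --mpsc--> Action handlers
    shell_worker --mpsc--> /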

Example usage: run until a kernel_info request/reply has completed, printing every ZMQ message
that comes back over the iopub and shell channels:

let connection_info = ConnectionInfo::from_file("/tmp/kernel.json")
    .expect("Make sure to run python -m ipykernel_launcher -f /tmp/kernel.json");
let client = Client::new(connection_info).await;

#[derive(Debug)]
struct DebugHandler;

#[async_trait::async_trait]
impl Handler for DebugHandler {
    async fn handle(&self, msg: &Response) {
        dbg!(msg);
    }
}

let handler = DebugHandler {};
let handlers = vec![Arc::new(Mutex::new(handler)) as Arc<Mutex<dyn Handler>>];
let action = client.kernel_info_request(handlers).await;
action.await;
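
Executing code works the same way; a minimal sketch using execute_request (defined below in this
file) with the same DebugHandler:

let handlers = vec![Arc::new(Mutex::new(DebugHandler {})) as Arc<Mutex<dyn Handler>>];
let action = client.execute_request("print('hello')".to_string(), handlers).await;
action.await;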
*/

use std::collections::HashMap;

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex, Notify, RwLock};
use tokio::time::sleep;
use zeromq::{DealerSocket, ReqSocket, Socket, SocketRecv, SocketSend, SubSocket, ZmqMessage};

use crate::actions::Action;
use crate::handlers::Handler;
use crate::jupyter::connection_file::ConnectionInfo;
use crate::jupyter::request::Request;
use crate::jupyter::response::Response;
use crate::jupyter::shell_content::execute::ExecuteRequest;
use crate::jupyter::shell_content::kernel_info::KernelInfoRequest;
use crate::jupyter::wire_protocol::WireProtocol;

#[derive(Debug, Clone)]
pub struct Client {
    actions: Arc<RwLock<HashMap<String, mpsc::Sender<Response>>>>,
    connection_info: ConnectionInfo,
    shell_tx: mpsc::Sender<ZmqMessage>,
    shutdown_signal: Arc<Notify>,
}

impl Client {
    pub async fn new(connection_info: ConnectionInfo) -> Self {
        let actions = Arc::new(RwLock::new(HashMap::new()));
        // message passing for methods to send requests out over the shell channel via shell_worker
        let (shell_tx, shell_rx) = mpsc::channel(100);

        // message passing from the iopub and shell listeners into process_message_worker
        let (process_msg_tx, process_msg_rx) = mpsc::channel(100);

        // for shutting down ZMQ listeners when the Client is dropped
        let shutdown_signal = Arc::new(Notify::new());

        // spawn iopub and shell listeners
        let iopub_address = connection_info.iopub_address();
        let shell_address = connection_info.shell_address();

        tokio::spawn(iopub_worker(
            iopub_address,
            process_msg_tx.clone(),
            shutdown_signal.clone(),
        ));
        tokio::spawn(shell_worker(
            shell_address,
            shell_rx,
            process_msg_tx.clone(),
            shutdown_signal.clone(),
        ));

        // spawn process_message_worker
        tokio::spawn(process_message_worker(
            process_msg_rx,
            actions.clone(),
            shutdown_signal.clone(),
        ));

        Client {
            actions,
            connection_info,
            shell_tx,
            shutdown_signal,
        }
    }

    // Try to connect to the heartbeat channel and send a ping message. Useful for waiting on a
    // freshly launched Kernel to come up, or for checking whether a Kernel is still responsive.
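    //
    // A minimal usage sketch (assuming the same setup as the module-level example above):
    //
    //     let client = Client::new(connection_info).await;
    //     client.heartbeat().await; // resolves once the Kernel answers a ping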
    pub async fn heartbeat(&self) {
        loop {
            let mut socket = ReqSocket::new();

            // Try to connect to the heartbeat channel
            if let Err(_e) = socket
                .connect(self.connection_info.heartbeat_address().as_str())
                .await
            {
                sleep(Duration::from_millis(50)).await;
                continue; // if connecting fails, retry on the next iteration of the loop
            }

            // Send a ping message
            let ping_msg = ZmqMessage::from("ping");
            if let Err(_e) = socket.send(ping_msg).await {
                sleep(Duration::from_millis(50)).await;
                continue; // if sending fails, retry on the next iteration of the loop
            }

            // Wait for a pong message
            match socket.recv().await {
                Ok(_) => break, // pong received, the Kernel is up
                Err(_) => {
                    sleep(Duration::from_millis(50)).await;
                    continue; // if receiving fails, retry on the next iteration of the loop
                }
            }
        }
    }

    // Creates an Action from a request + handlers, serializes the request to be sent over ZMQ,
    // sends it out over the shell channel, and registers the request header msg_id in the actions
    // HashMap so that all response messages get routed to the appropriate Action handlers.
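    //
    // Rough flow, with a hypothetical msg_id shown for illustration:
    //
    //   request (msg_id "abc")            ->  actions["abc"] = msg_tx   (registration)
    //   request -> WireProtocol -> ZmqMessage -> shell_tx -> shell_worker -> Kernel
    //   replies with parent msg_id "abc"  ->  msg_tx -> Action handlers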
    async fn send_request(
        &self,
        request: Request,
        handlers: Vec<Arc<Mutex<dyn Handler>>>,
    ) -> Action {
        let (msg_tx, msg_rx) = mpsc::channel(100);
        let action = Action::new(request, handlers, msg_rx);
        let msg_id = action.request.msg_id();
        self.actions.write().await.insert(msg_id.clone(), msg_tx);
        let wp: WireProtocol = action.request.into_wire_protocol(&self.connection_info.key);
        let zmq_msg: ZmqMessage = wp.into();
        self.shell_tx.send(zmq_msg).await.unwrap();
        action
    }

    pub async fn kernel_info_request(&self, handlers: Vec<Arc<Mutex<dyn Handler>>>) -> Action {
        let request = KernelInfoRequest::new();
        self.send_request(request.into(), handlers).await
    }

    pub async fn execute_request(
        &self,
        code: String,
        handlers: Vec<Arc<Mutex<dyn Handler>>>,
    ) -> Action {
        let request = ExecuteRequest::new(code);
        self.send_request(request.into(), handlers).await
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        self.shutdown_signal.notify_waiters();
    }
}

/// The tasks listening on the iopub and shell channels push any messages they receive into this
/// processing function. Its job is to deserialize each ZmqMessage into the appropriate Jupyter
/// message and then delegate it to the appropriate Action to be handled, based on parent msg_id.
async fn process_message_worker(
    mut msg_rx: mpsc::Receiver<ZmqMessage>,
    actions: Arc<RwLock<HashMap<String, mpsc::Sender<Response>>>>,
    shutdown_signal: Arc<Notify>, // hook to shut down this background task when the Client is dropped
) {
    loop {
        tokio::select! {
            Some(zmq_msg) = msg_rx.recv() => {
                let response: Response = zmq_msg.into();
                let msg_id = response.parent_msg_id();
                if msg_id.is_none() {
                    dbg!("No parent msg id, skipping msg_type {}", response.msg_type());
                    continue;
                }
                let msg_id = msg_id.unwrap();
                if let Some(action) = actions.read().await.get(&msg_id) {
                    // If we see a SendError here, it means ZMQ messages are still arriving with a
                    // parent header msg_id matching a request / Action that has already "completed"
                    // and shut down its mpsc Receiver. That is probably because the Action is not
                    // configured to expect some reply type: it "finishes" once the Kernel status
                    // goes Idle, and we then send another reply message to a closed mpsc Receiver.
                    if let Err(e) = action.send(response).await {
                        dbg!(e);
                    }
                }
            },
            _ = shutdown_signal.notified() => {
                break;
            }
        }
    }
}

/// The iopub channel background task is only responsible for listening to the iopub channel and
/// pushing messages to the process_message_worker. We never send anything out on the iopub channel.
async fn iopub_worker(
    iopub_address: String,
    msg_tx: mpsc::Sender<ZmqMessage>,
    shutdown_signal: Arc<Notify>,
) {
    let mut socket = SubSocket::new();
    socket.connect(iopub_address.as_str()).await.unwrap();
    // subscribe to all topics
    socket.subscribe("").await.unwrap();

    loop {
        tokio::select! {
            Ok(msg) = socket.recv() => {
                msg_tx.send(msg).await.unwrap();
            },
            _ = shutdown_signal.notified() => {
                break;
            }
        }
    }
}

/// The shell channel background task needs to give the Client a way to send requests out over
/// shell, in addition to listening for replies coming back on the channel and pushing those to
/// the process_message_worker.
async fn shell_worker(
    shell_address: String,
    mut msg_rx: mpsc::Receiver<ZmqMessage>, // Client wants to send a Jupyter message over ZMQ
    msg_tx: mpsc::Sender<ZmqMessage>, // Kernel sent a reply over ZMQ, needs to get processed
    shutdown_signal: Arc<Notify>,
) {
    let mut socket = DealerSocket::new();
    socket.connect(shell_address.as_str()).await.unwrap();

    loop {
        tokio::select! {
            Some(client_to_kernel_msg) = msg_rx.recv() => {
                socket.send(client_to_kernel_msg).await.unwrap();
            },
            kernel_to_client_msg = socket.recv() => {
                match kernel_to_client_msg {
                    Ok(msg) => {
                        msg_tx.send(msg).await.unwrap();
                    }
                    Err(e) => {
                        dbg!(e);
                    }
                }
            },
            _ = shutdown_signal.notified() => {
                break;
            }
        }
    }
}