// kernel_sidecar/actions.rs

1use std::fmt::Debug;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll, Waker};
6
7use tokio::sync::{mpsc, Mutex};
8
9use crate::handlers::Handler;
10use crate::jupyter::iopub_content::status::KernelStatus;
11use crate::jupyter::request::Request;
12use crate::jupyter::response::Response;
13
/// The kind of reply message a request waits for, used to decide when an `Action`
/// has finished.
///
/// This is a small fieldless enum, so it is `Copy` and has total equality (`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedReplyType {
    /// A kernel-info reply (`Request::KernelInfo` / `Response::KernelInfo`).
    KernelInfo,
    /// An execute reply (`Request::Execute` / `Response::Execute`).
    ExecuteReply,
    /// No reply to wait for. This is also what `From<&Response>` yields for responses
    /// that are not replies (for example status messages).
    None,
}
20
21impl From<&Request> for ExpectedReplyType {
22    fn from(request: &Request) -> Self {
23        match request {
24            Request::KernelInfo(_) => ExpectedReplyType::KernelInfo,
25            Request::Execute(_) => ExpectedReplyType::ExecuteReply,
26        }
27    }
28}
29
30impl From<&Response> for ExpectedReplyType {
31    fn from(response: &Response) -> Self {
32        match response {
33            Response::KernelInfo(_) => ExpectedReplyType::KernelInfo,
34            Response::Execute(_) => ExpectedReplyType::ExecuteReply,
35            _ => ExpectedReplyType::None,
36        }
37    }
38}
39
/// Completion state shared (behind a mutex) between an `Action` future and the
/// background task that listens for its responses.
#[derive(Debug)]
struct ActionState {
    /// Becomes `true` once the listener considers the action finished.
    completed: bool,
    /// Waker saved by the most recent pending `poll`; taken and woken on completion.
    waker: Option<Waker>,
}
45
46#[derive(Debug)]
47pub struct Action {
48    pub request: Request,
49    state: Arc<Mutex<ActionState>>,
50}
51
52impl Action {
53    pub fn new(
54        request: Request,
55        handlers: Vec<Arc<Mutex<dyn Handler>>>,
56        msg_rx: mpsc::Receiver<Response>,
57    ) -> Self {
58        let action_state = Arc::new(Mutex::new(ActionState {
59            completed: false,
60            waker: None,
61        }));
62        let expected_reply = ExpectedReplyType::from(&request);
63        // spawn background task for listening
64        tokio::spawn(Action::listen(
65            msg_rx,
66            expected_reply,
67            handlers,
68            action_state.clone(),
69        ));
70        Action {
71            request,
72            state: action_state,
73        }
74    }
75
76    async fn listen(
77        mut msg_rx: mpsc::Receiver<Response>,
78        expected_reply: ExpectedReplyType,
79        handlers: Vec<Arc<Mutex<dyn Handler>>>,
80        action_state: Arc<Mutex<ActionState>>,
81    ) {
82        // We "finish" this background task when kernel idle and expected reply (if relevant) seen
83        let mut kernel_idle = false;
84        let mut expected_reply_seen = match expected_reply {
85            ExpectedReplyType::KernelInfo => false,
86            ExpectedReplyType::ExecuteReply => false,
87            ExpectedReplyType::None => true,
88        };
89        while let Some(response) = msg_rx.recv().await {
90            for handler_arc in &handlers {
91                let mut handler = handler_arc.lock().await;
92                handler.handle(&response).await;
93            }
94            match response {
95                Response::Status(status) => {
96                    if status.content.execution_state == KernelStatus::Idle {
97                        kernel_idle = true;
98                    }
99                }
100                _ => {
101                    if expected_reply == ExpectedReplyType::from(&response) {
102                        expected_reply_seen = true;
103                    }
104                }
105            }
106            if kernel_idle && expected_reply_seen {
107                let mut state = action_state.lock().await;
108                state.completed = true;
109                if let Some(waker) = state.waker.take() {
110                    waker.wake();
111                }
112                break;
113            }
114        }
115    }
116}
117
118impl Future for Action {
119    type Output = ();
120
121    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
122        let mut state = match self.state.try_lock() {
123            Ok(state) => state,
124            Err(_) => {
125                // If we can't get the lock, it means the background task is still running
126                // and we need to wait for it to complete
127                return Poll::Pending;
128            }
129        };
130        if state.completed {
131            Poll::Ready(())
132        } else {
133            state.waker = Some(cx.waker().clone());
134            Poll::Pending
135        }
136    }
137}