forevervm_sdk/client/
repl.rs

1use super::{
2    typed_socket::{websocket_connect, WebSocketRecv, WebSocketSend},
3    util::authorized_request,
4    ClientError,
5};
6use crate::api::{
7    api_types::{ExecResult, Instruction},
8    id_types::{InstructionSeq, MachineName, RequestSeq},
9    protocol::{MessageFromServer, MessageToServer, StandardOutput},
10    token::ApiToken,
11};
12use std::{
13    ops::{Deref, DerefMut},
14    sync::{atomic::AtomicU32, Arc, Mutex},
15};
16use tokio::{
17    sync::{broadcast, oneshot},
18    task::JoinHandle,
19};
20
/// Timeout, in seconds, attached to instructions submitted through
/// [`ReplConnection::exec`].
pub const DEFAULT_INSTRUCTION_TIMEOUT_SECONDS: i32 = 15;
22
/// Hands out request sequence numbers from an atomic counter.
///
/// The derived `Default` starts the counter at zero.
#[derive(Default)]
pub struct RequestSeqGenerator {
    /// Raw value that the next call to [`RequestSeqGenerator::next`] returns.
    next: AtomicU32,
}
27
28impl RequestSeqGenerator {
29    pub fn next(&self) -> RequestSeq {
30        let r = self.next.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
31        r.into()
32    }
33}
34
35#[derive(Debug)]
36pub enum ReplConnectionState {
37    Idle,
38    WaitingForInstructionSeq {
39        request_id: RequestSeq,
40        send_result_handle: oneshot::Sender<ExecResultHandle>,
41    },
42    WaitingForResult {
43        instruction_id: InstructionSeq,
44        output_sender: broadcast::Sender<StandardOutput>,
45        result_sender: oneshot::Sender<ExecResult>,
46    },
47}
48
49impl Default for ReplConnectionState {
50    fn default() -> Self {
51        Self::Idle
52    }
53}
54
55pub struct ReplConnection {
56    pub machine_name: MachineName,
57    request_seq_generator: RequestSeqGenerator,
58    sender: WebSocketSend<MessageToServer>,
59
60    receiver_handle: Option<JoinHandle<()>>,
61    state: Arc<Mutex<ReplConnectionState>>,
62}
63
64fn handle_message(
65    message: MessageFromServer,
66    state: Arc<Mutex<ReplConnectionState>>,
67) -> Result<(), ClientError> {
68    let msg = message;
69    match msg {
70        MessageFromServer::ExecReceived { seq, request_id } => {
71            let mut state = state.lock().expect("State lock poisoned");
72            let old_state = std::mem::take(state.deref_mut());
73
74            match old_state {
75                ReplConnectionState::WaitingForInstructionSeq {
76                    request_id: expected_request_seq,
77                    send_result_handle: receiver_sender,
78                } => {
79                    if request_id != expected_request_seq {
80                        tracing::warn!(
81                            ?request_id,
82                            ?expected_request_seq,
83                            "Unexpected request seq"
84                        );
85                        return Ok(());
86                    }
87
88                    let (output_sender, output_receiver) = broadcast::channel::<StandardOutput>(50);
89                    let (result_sender, result_receiver) = oneshot::channel();
90
91                    *state = ReplConnectionState::WaitingForResult {
92                        instruction_id: seq,
93                        output_sender,
94                        result_sender,
95                    };
96
97                    let _ = receiver_sender.send(ExecResultHandle {
98                        result: result_receiver,
99                        receiver: output_receiver,
100                    });
101                }
102                state => {
103                    tracing::error!(?state, "Unexpected ExecReceived while in state {state:?}");
104                }
105            }
106        }
107        MessageFromServer::Result(result) => {
108            let mut state = state.lock().expect("State lock poisoned");
109            let old_state = std::mem::take(state.deref_mut());
110
111            match old_state {
112                ReplConnectionState::WaitingForResult {
113                    instruction_id: instruction_seq,
114                    result_sender,
115                    ..
116                } => {
117                    if result.instruction_id != instruction_seq {
118                        tracing::warn!(
119                            ?instruction_seq,
120                            ?result.instruction_id,
121                            "Unexpected instruction seq"
122                        );
123                        return Ok(());
124                    }
125
126                    let _ = result_sender.send(result.result);
127                }
128                state => {
129                    tracing::error!(?state, "Unexpected Result while in state {state:?}");
130                }
131            }
132        }
133        MessageFromServer::Output {
134            chunk,
135            instruction_id: instruction,
136        } => {
137            let state = state.lock().expect("State lock poisoned");
138
139            match state.deref() {
140                ReplConnectionState::WaitingForResult {
141                    instruction_id: instruction_seq,
142                    output_sender,
143                    ..
144                } => {
145                    if *instruction_seq != instruction {
146                        tracing::warn!(
147                            ?instruction_seq,
148                            ?instruction,
149                            "Unexpected instruction seq"
150                        );
151                        return Ok(());
152                    }
153
154                    let _ = output_sender.send(chunk);
155                }
156                state => {
157                    tracing::error!(?state, "Unexpected Output while in state {state:?}");
158                }
159            }
160        }
161        MessageFromServer::Error(err) => {
162            return Err(ClientError::ApiError(err));
163        }
164        MessageFromServer::Connected { machine_name: _ } => {}
165        msg => tracing::warn!("message type not implmented: {msg:?}"),
166    }
167
168    Ok(())
169}
170
171async fn receive_loop(
172    mut receiver: WebSocketRecv<MessageFromServer>,
173    state: Arc<Mutex<ReplConnectionState>>,
174) {
175    while let Ok(Some(msg)) = receiver.recv().await {
176        if let Err(err) = handle_message(msg, state.clone()) {
177            tracing::error!(?err, "Failed to handle message");
178        }
179    }
180}
181
182impl ReplConnection {
183    pub async fn new(url: reqwest::Url, token: ApiToken) -> Result<Self, ClientError> {
184        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
185
186        let req = authorized_request(url, token)?;
187        let (sender, mut receiver) =
188            websocket_connect::<MessageToServer, MessageFromServer>(req).await?;
189
190        let state: Arc<Mutex<ReplConnectionState>> = Arc::default();
191
192        let machine_name = match receiver.recv().await? {
193            Some(MessageFromServer::Connected { machine_name }) => machine_name,
194            _ => {
195                return Err(ClientError::Other(String::from(
196                    "Expected `connected` message from REPL.",
197                )))
198            }
199        };
200
201        let receiver_handle = tokio::spawn(receive_loop(receiver, state.clone()));
202
203        Ok(Self {
204            machine_name,
205            request_seq_generator: Default::default(),
206            sender,
207            receiver_handle: Some(receiver_handle),
208            state,
209        })
210    }
211
212    pub async fn exec(&mut self, code: &str) -> Result<ExecResultHandle, ClientError> {
213        let instruction = Instruction {
214            code: code.to_string(),
215            timeout_seconds: DEFAULT_INSTRUCTION_TIMEOUT_SECONDS,
216        };
217        self.exec_instruction(instruction).await
218    }
219
220    pub async fn exec_instruction(
221        &mut self,
222        instruction: Instruction,
223    ) -> Result<ExecResultHandle, ClientError> {
224        let request_id = self.request_seq_generator.next();
225        let message = MessageToServer::Exec {
226            instruction,
227            request_id,
228        };
229        self.sender.send(&message).await?;
230
231        let (send_result_handle, receive_result_handle) = oneshot::channel::<ExecResultHandle>();
232        {
233            let mut state = self.state.lock().expect("State lock poisoned");
234
235            *state.deref_mut() = ReplConnectionState::WaitingForInstructionSeq {
236                request_id,
237                send_result_handle,
238            };
239        }
240
241        receive_result_handle
242            .await
243            .map_err(|_| ClientError::InstructionInterrupted)
244    }
245}
246
247impl Drop for ReplConnection {
248    fn drop(&mut self) {
249        if let Some(handle) = self.receiver_handle.take() {
250            handle.abort();
251        }
252    }
253}
254
255#[derive(Debug)]
256pub struct ExecResultHandle {
257    result: oneshot::Receiver<ExecResult>,
258    receiver: broadcast::Receiver<StandardOutput>,
259}
260
261impl ExecResultHandle {
262    pub async fn next(&mut self) -> Option<StandardOutput> {
263        self.receiver.recv().await.ok()
264    }
265
266    pub async fn result(self) -> Result<ExecResult, ClientError> {
267        self.result
268            .await
269            .map_err(|_| ClientError::InstructionInterrupted)
270    }
271}