1use super::{
2 typed_socket::{websocket_connect, WebSocketRecv, WebSocketSend},
3 util::authorized_request,
4 ClientError,
5};
6use crate::api::{
7 api_types::{ExecResult, Instruction},
8 id_types::{InstructionSeq, MachineName, RequestSeq},
9 protocol::{MessageFromServer, MessageToServer, StandardOutput},
10 token::ApiToken,
11};
12use std::{
13 ops::{Deref, DerefMut},
14 sync::{atomic::AtomicU32, Arc, Mutex},
15};
16use tokio::{
17 sync::{broadcast, oneshot},
18 task::JoinHandle,
19};
20
/// Timeout, in seconds, that [`ReplConnection::exec`] puts in the `timeout_seconds` field of
/// the [`Instruction`] it sends. Callers needing a different limit can build their own
/// `Instruction` and use [`ReplConnection::exec_instruction`].
pub const DEFAULT_INSTRUCTION_TIMEOUT_SECONDS: i32 = 15;
22
23#[derive(Default)]
24pub struct RequestSeqGenerator {
25 next: AtomicU32,
26}
27
28impl RequestSeqGenerator {
29 pub fn next(&self) -> RequestSeq {
30 let r = self.next.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
31 r.into()
32 }
33}
34
35#[derive(Debug)]
36pub enum ReplConnectionState {
37 Idle,
38 WaitingForInstructionSeq {
39 request_id: RequestSeq,
40 send_result_handle: oneshot::Sender<ExecResultHandle>,
41 },
42 WaitingForResult {
43 instruction_id: InstructionSeq,
44 output_sender: broadcast::Sender<StandardOutput>,
45 result_sender: oneshot::Sender<ExecResult>,
46 },
47}
48
49impl Default for ReplConnectionState {
50 fn default() -> Self {
51 Self::Idle
52 }
53}
54
/// Client connection to a remote REPL over a WebSocket.
///
/// Instructions are sent with [`ReplConnection::exec`] / [`ReplConnection::exec_instruction`].
/// Server messages are processed by a background task spawned in [`ReplConnection::new`],
/// which updates the shared `state`.
pub struct ReplConnection {
    /// Machine name reported by the server in its `Connected` greeting.
    pub machine_name: MachineName,
    /// Source of request ids for outgoing `Exec` messages.
    request_seq_generator: RequestSeqGenerator,
    /// Sending half of the WebSocket.
    sender: WebSocketSend<MessageToServer>,

    /// Background receive-loop task. It is an `Option` so `Drop` can take and abort it.
    receiver_handle: Option<JoinHandle<()>>,
    /// Exec state shared with the background receive loop.
    state: Arc<Mutex<ReplConnectionState>>,
}
63
64fn handle_message(
65 message: MessageFromServer,
66 state: Arc<Mutex<ReplConnectionState>>,
67) -> Result<(), ClientError> {
68 let msg = message;
69 match msg {
70 MessageFromServer::ExecReceived { seq, request_id } => {
71 let mut state = state.lock().expect("State lock poisoned");
72 let old_state = std::mem::take(state.deref_mut());
73
74 match old_state {
75 ReplConnectionState::WaitingForInstructionSeq {
76 request_id: expected_request_seq,
77 send_result_handle: receiver_sender,
78 } => {
79 if request_id != expected_request_seq {
80 tracing::warn!(
81 ?request_id,
82 ?expected_request_seq,
83 "Unexpected request seq"
84 );
85 return Ok(());
86 }
87
88 let (output_sender, output_receiver) = broadcast::channel::<StandardOutput>(50);
89 let (result_sender, result_receiver) = oneshot::channel();
90
91 *state = ReplConnectionState::WaitingForResult {
92 instruction_id: seq,
93 output_sender,
94 result_sender,
95 };
96
97 let _ = receiver_sender.send(ExecResultHandle {
98 result: result_receiver,
99 receiver: output_receiver,
100 });
101 }
102 state => {
103 tracing::error!(?state, "Unexpected ExecReceived while in state {state:?}");
104 }
105 }
106 }
107 MessageFromServer::Result(result) => {
108 let mut state = state.lock().expect("State lock poisoned");
109 let old_state = std::mem::take(state.deref_mut());
110
111 match old_state {
112 ReplConnectionState::WaitingForResult {
113 instruction_id: instruction_seq,
114 result_sender,
115 ..
116 } => {
117 if result.instruction_id != instruction_seq {
118 tracing::warn!(
119 ?instruction_seq,
120 ?result.instruction_id,
121 "Unexpected instruction seq"
122 );
123 return Ok(());
124 }
125
126 let _ = result_sender.send(result.result);
127 }
128 state => {
129 tracing::error!(?state, "Unexpected Result while in state {state:?}");
130 }
131 }
132 }
133 MessageFromServer::Output {
134 chunk,
135 instruction_id: instruction,
136 } => {
137 let state = state.lock().expect("State lock poisoned");
138
139 match state.deref() {
140 ReplConnectionState::WaitingForResult {
141 instruction_id: instruction_seq,
142 output_sender,
143 ..
144 } => {
145 if *instruction_seq != instruction {
146 tracing::warn!(
147 ?instruction_seq,
148 ?instruction,
149 "Unexpected instruction seq"
150 );
151 return Ok(());
152 }
153
154 let _ = output_sender.send(chunk);
155 }
156 state => {
157 tracing::error!(?state, "Unexpected Output while in state {state:?}");
158 }
159 }
160 }
161 MessageFromServer::Error(err) => {
162 return Err(ClientError::ApiError(err));
163 }
164 MessageFromServer::Connected { machine_name: _ } => {}
165 msg => tracing::warn!("message type not implmented: {msg:?}"),
166 }
167
168 Ok(())
169}
170
171async fn receive_loop(
172 mut receiver: WebSocketRecv<MessageFromServer>,
173 state: Arc<Mutex<ReplConnectionState>>,
174) {
175 while let Ok(Some(msg)) = receiver.recv().await {
176 if let Err(err) = handle_message(msg, state.clone()) {
177 tracing::error!(?err, "Failed to handle message");
178 }
179 }
180}
181
182impl ReplConnection {
183 pub async fn new(url: reqwest::Url, token: ApiToken) -> Result<Self, ClientError> {
184 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
185
186 let req = authorized_request(url, token)?;
187 let (sender, mut receiver) =
188 websocket_connect::<MessageToServer, MessageFromServer>(req).await?;
189
190 let state: Arc<Mutex<ReplConnectionState>> = Arc::default();
191
192 let machine_name = match receiver.recv().await? {
193 Some(MessageFromServer::Connected { machine_name }) => machine_name,
194 _ => {
195 return Err(ClientError::Other(String::from(
196 "Expected `connected` message from REPL.",
197 )))
198 }
199 };
200
201 let receiver_handle = tokio::spawn(receive_loop(receiver, state.clone()));
202
203 Ok(Self {
204 machine_name,
205 request_seq_generator: Default::default(),
206 sender,
207 receiver_handle: Some(receiver_handle),
208 state,
209 })
210 }
211
212 pub async fn exec(&mut self, code: &str) -> Result<ExecResultHandle, ClientError> {
213 let instruction = Instruction {
214 code: code.to_string(),
215 timeout_seconds: DEFAULT_INSTRUCTION_TIMEOUT_SECONDS,
216 };
217 self.exec_instruction(instruction).await
218 }
219
220 pub async fn exec_instruction(
221 &mut self,
222 instruction: Instruction,
223 ) -> Result<ExecResultHandle, ClientError> {
224 let request_id = self.request_seq_generator.next();
225 let message = MessageToServer::Exec {
226 instruction,
227 request_id,
228 };
229 self.sender.send(&message).await?;
230
231 let (send_result_handle, receive_result_handle) = oneshot::channel::<ExecResultHandle>();
232 {
233 let mut state = self.state.lock().expect("State lock poisoned");
234
235 *state.deref_mut() = ReplConnectionState::WaitingForInstructionSeq {
236 request_id,
237 send_result_handle,
238 };
239 }
240
241 receive_result_handle
242 .await
243 .map_err(|_| ClientError::InstructionInterrupted)
244 }
245}
246
247impl Drop for ReplConnection {
248 fn drop(&mut self) {
249 if let Some(handle) = self.receiver_handle.take() {
250 handle.abort();
251 }
252 }
253}
254
/// Caller-side handle for an instruction the server has acknowledged.
///
/// Streams the instruction's output via [`ExecResultHandle::next`] and yields its final
/// result via [`ExecResultHandle::result`].
#[derive(Debug)]
pub struct ExecResultHandle {
    /// Completed by the receive loop when the server's `Result` message arrives.
    result: oneshot::Receiver<ExecResult>,
    /// Output chunks forwarded by the receive loop from `Output` messages.
    receiver: broadcast::Receiver<StandardOutput>,
}
260
261impl ExecResultHandle {
262 pub async fn next(&mut self) -> Option<StandardOutput> {
263 self.receiver.recv().await.ok()
264 }
265
266 pub async fn result(self) -> Result<ExecResult, ClientError> {
267 self.result
268 .await
269 .map_err(|_| ClientError::InstructionInterrupted)
270 }
271}