1use std::{collections::HashMap, io, num::NonZeroUsize};
2
3use crate::{
4 parser::{self, parse_message},
5 raw::{self, GeneralMessage, Response},
6 status::Status,
7 Token,
8};
9
10use tokio::{
11 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
12 process, select,
13 sync::mpsc,
14};
15use tracing::{debug, error, info, warn};
16
17type MsgOut = mpsc::Sender<Result<Response, crate::Error>>;
18type StatusOut = mpsc::Sender<Status>;
19type StatusAwaiterPred = Box<dyn Fn(&Status) -> bool + Send + Sync>;
20
21#[derive(derivative::Derivative)]
22#[derivative(Debug)]
23pub(super) enum Msg {
24 Cmd {
25 token: Token,
26 msg: String,
27 out: MsgOut,
28 },
29 ConsoleCmd {
30 token: Token,
31 msg: String,
32 out: mpsc::Sender<Result<(Response, Vec<String>), crate::Error>>,
33 capture_lines: NonZeroUsize,
34 },
35 PopGeneral(mpsc::Sender<Vec<GeneralMessage>>),
36 Status(mpsc::Sender<Status>),
37 NextStatus {
38 current: Status,
39 out: StatusOut,
40 },
41 AwaitStatus {
42 #[derivative(Debug = "ignore")]
43 pred: StatusAwaiterPred,
44 out: StatusOut,
45 },
46}
47
48pub(super) fn spawn(cmd: process::Child) -> mpsc::UnboundedSender<Msg> {
49 let (tx, rx) = mpsc::unbounded_channel::<Msg>();
50 tokio::spawn(async move { mainloop(cmd, rx).await });
51 tx
52}
53
/// Bookkeeping for the single in-flight `Msg::ConsoleCmd`.
///
/// The requester is answered once the MI response carrying `token` has arrived and
/// `lines` holds `target` console lines.
#[derive(Debug)]
struct PendingConsole {
    /// Token the console command was written with.
    token: Token,
    /// Converted response for `token`, once it has arrived.
    response: Option<Response>,
    /// Console lines captured so far (never more than `target`).
    lines: Vec<String>,
    /// Number of console lines to capture before answering.
    target: NonZeroUsize,
    /// Where the response and captured lines (or an error) are sent.
    out: mpsc::Sender<Result<(Response, Vec<String>), crate::Error>>,
}
62
/// Everything owned by the worker task while it drives the gdb process.
#[derive(derivative::Derivative)]
#[derivative(Debug)]
struct State {
    /// Pipe used to write commands to gdb.
    stdin: process::ChildStdin,
    /// gdb's stdout, read line by line into `stdout_buf`.
    stdout: BufReader<process::ChildStdout>,
    /// gdb's stderr, read line by line into `stderr_buf`.
    stderr: BufReader<process::ChildStderr>,
    // NOTE(review): the line buffers presumably live here (not in the loop) so a partially
    // read line survives when another `select!` branch wins; see tokio's cancel-safety
    // notes for `read_line`.
    /// Buffer that `read_line` appends stdout data to; `process_stdout` clears it.
    stdout_buf: String,
    /// Buffer that `read_line` appends stderr data to; `process_stderr` clears it.
    stderr_buf: String,
    /// Latest status derived from gdb notifications (starts as `Status::Unstarted`).
    status: Status,
    /// Requesters waiting for the next status change (`Msg::NextStatus`).
    #[derivative(Debug = "ignore")]
    notify_next_status: Vec<StatusOut>,
    /// Predicate/reply pairs waiting for a matching status (`Msg::AwaitStatus`).
    #[derivative(Debug = "ignore")]
    status_awaiters: Vec<(StatusAwaiterPred, StatusOut)>,
    /// Reply channels of in-flight `Msg::Cmd` requests, keyed by token.
    pending: HashMap<Token, MsgOut>,
    /// General messages (including stderr lines) queued until `Msg::PopGeneral`.
    pending_general: Vec<GeneralMessage>,
    /// The in-flight `Msg::ConsoleCmd`, if any; while set, no new requests are read.
    pending_console: Option<PendingConsole>,
}
80
81async fn mainloop(mut cmd: process::Child, mut rx: mpsc::UnboundedReceiver<Msg>) {
82 let stdin = cmd
83 .stdin
84 .take()
85 .expect("Stdin not captured. See docs of Gdb::new");
86 let stdout = BufReader::new(
87 cmd.stdout
88 .take()
89 .expect("Stdout not captured. See docs of Gdb::new"),
90 );
91 let stderr = BufReader::new(
92 cmd.stderr
93 .take()
94 .expect("Stderr not captured. See docs of Gdb::new"),
95 );
96
97 let mut state = State {
98 stdin,
99 stdout,
100 stderr,
101 stdout_buf: String::new(),
102 stderr_buf: String::new(),
103 status: Status::Unstarted,
104 notify_next_status: Vec::new(),
105 status_awaiters: Vec::new(),
106 pending: HashMap::new(),
107 pending_general: Vec::new(),
108 pending_console: None,
109 };
110
111 loop {
112 select! {
113 msg = rx.recv(), if &state.pending_console.is_none() => {
115 if let Err(err) = process_msg(msg, &mut state).await {
116 if log_and_check_fatal(&state, err) {
117 break
118 }
119 }
120 }
121
122 result = state.stdout.read_line(&mut state.stdout_buf) => {
123 if let Err(err) = process_stdout(result, &mut state).await {
124 if log_and_check_fatal(&state, err) {
125 break
126 }
127 }
128 }
129
130 result = state.stderr.read_line(&mut state.stderr_buf) => {
131 if let Err(err) = process_stderr(result, &mut state).await {
132 if log_and_check_fatal(&state, err) {
133 break
134 }
135 }
136 }
137 }
138 }
139}
140
141fn log_and_check_fatal(state: &State, error: Error) -> bool {
142 debug!(?state, "State after error");
143 match error {
144 Error::Transient(err) => {
145 error!("Transient error in worker: {}", err);
146 false
147 }
148 Error::Fatal(err) => {
149 error!("Fatal error in worker: {}", err);
150 true
151 }
152 }
153}
154
155#[derive(Debug, thiserror::Error, displaydoc::Display)]
156enum Error {
157 Fatal(#[from] FatalError),
159 Transient(#[from] TransientError),
161}
162
163#[derive(Debug, thiserror::Error, displaydoc::Display)]
164enum FatalError {
165 Stdin(#[source] io::Error),
167 RequestChanClosed,
169 Stdout(#[source] io::Error),
171 Send,
173 Parse(#[from] parser::Error),
175 Stderr(#[source] io::Error),
177}
178
179#[derive(Debug, thiserror::Error, displaydoc::Display)]
180enum TransientError {
181 Send,
183 Parse(#[from] parser::Error),
185}
186
187impl<T> From<mpsc::error::SendError<T>> for FatalError {
188 fn from(_: mpsc::error::SendError<T>) -> Self {
189 Self::Send
190 }
191}
192
193async fn process_msg(msg: Option<Msg>, state: &mut State) -> Result<(), Error> {
194 let msg = msg.ok_or(FatalError::RequestChanClosed)?;
195
196 match msg {
197 Msg::Cmd { token, msg, out } => {
198 write_stdin(&mut state.stdin, token, &msg).await?;
199 state.pending.insert(token, out);
200 }
201
202 Msg::ConsoleCmd {
203 token,
204 msg,
205 out,
206 capture_lines,
207 } => {
208 state.pending_console = Some(PendingConsole {
209 token,
210 response: None,
211 lines: Vec::with_capacity(capture_lines.get()),
212 target: capture_lines,
213 out,
214 });
215 write_stdin(&mut state.stdin, token, &msg).await?;
216 }
217
218 Msg::PopGeneral(out) => {
219 send(&out, state.pending_general.clone()).await?;
220 state.pending_general.clear();
221 }
222
223 Msg::Status(out) => {
224 send(&out, state.status.clone()).await?;
225 }
226
227 Msg::NextStatus {
228 current: current_belief,
229 out,
230 } => {
231 if current_belief == state.status {
232 state.notify_next_status.push(out);
233 } else {
234 debug!(?current_belief, actual = ?state.status, "Caller's current_belief incorrect, sending current status");
235 send(&out, state.status.clone()).await?;
236 }
237 }
238
239 Msg::AwaitStatus { pred, out } => {
240 state.status_awaiters.push((pred, out));
241 }
242 }
243
244 Ok(())
245}
246
247async fn write_stdin(
248 stdin: &mut process::ChildStdin,
249 token: Token,
250 msg: &str,
251) -> Result<(), FatalError> {
252 let mut input = token.serialize();
253 input.push_str(&msg);
254 input.push('\n');
255
256 info!("Sending to gdb {}", input);
257 stdin
258 .write_all(&input.as_bytes())
259 .await
260 .map_err(FatalError::Stdin)?;
261
262 Ok(())
263}
264
265async fn process_stdout(result: io::Result<usize>, state: &mut State) -> Result<(), Error> {
266 result.map_err(FatalError::Stdout)?;
267
268 let line = &state.stdout_buf[..state.stdout_buf.len() - 1]; debug!("Got stdout: {}", line);
270 let response = parse_message(&line).map_err(TransientError::from)?;
271 state.stdout_buf.clear();
272
273 match response {
274 parser::Message::Response(response) => process_parsed_response(state, response).await?,
275 parser::Message::General(general) => process_parsed_general(state, general).await?,
276 }
277 Ok(())
278}
279
280async fn process_parsed_response(
281 state: &mut State,
282 response: parser::Response,
283) -> Result<(), Error> {
284 let token = if let Some(token) = response.token() {
285 token
286 } else {
287 match response {
288 parser::Response::Notify {
289 message, payload, ..
290 } => {
291 process_response_notify(state, message, payload).await?;
292 }
293 result @ parser::Response::Result { .. } => {
294 warn!("Ignoring result without token: {:?}", result);
295 }
296 }
297 return Ok(());
298 };
299
300 if let Some(pending_token) = state.pending_console.as_ref().map(|p| p.token) {
301 if token == pending_token {
302 match Response::from_parsed(response) {
303 Ok(response) => {
304 let mut pending = state.pending_console.as_mut().unwrap();
305 pending.response = Some(response);
306
307 if pending.lines.len() != pending.target.get() {
308 return Ok(());
309 }
310
311 send(
312 &pending.out,
313 Ok((pending.response.clone().unwrap(), pending.lines.clone())),
314 )
315 .await?;
316
317 state.pending_console = None;
318 }
319 Err(err) => {
320 send(&state.pending_console.as_ref().unwrap().out, Err(err)).await?;
321 }
322 }
323 return Ok(());
324 }
325 }
326
327 let out = if let Some(out) = state.pending.remove(&token) {
328 out
329 } else {
330 warn!(
331 "Got unexpected token {:?}, so ignoring: {:?}",
332 token, response
333 );
334 return Ok(());
335 };
336
337 let response = Response::from_parsed(response);
338 info!("Sending response: {:?}", response);
339 send(&out, response).await?;
340
341 Ok(())
342}
343
344async fn process_response_notify(
345 state: &mut State,
346 message: String,
347 payload: raw::Dict,
348) -> Result<(), Error> {
349 if let Some(new_status) = Status::from_notification(&message, payload) {
350 state.status = new_status;
351
352 info!("New status {:?}", state.status);
353
354 let to_notify = &mut state.notify_next_status;
355 debug!("Notifying {} watchers of status", to_notify.len());
356 for out in to_notify.drain(..) {
357 send(&out, state.status.clone()).await?;
358 }
359
360 let mut to_remove = Vec::new();
361 for (idx, (pred, out)) in state.status_awaiters.iter().enumerate() {
362 if pred(&state.status) {
363 send(out, state.status.clone()).await?;
364 to_remove.push(idx);
365 }
366 }
367 debug!(
368 "{} were awaiting this status, {} remain",
369 to_remove.len(),
370 state.status_awaiters.len() - to_remove.len()
371 );
372 for idx in to_remove {
373 drop(state.status_awaiters.remove(idx));
374 }
375 }
376
377 Ok(())
378}
379
380async fn process_parsed_general(
381 state: &mut State,
382 general: raw::GeneralMessage,
383) -> Result<(), Error> {
384 if let Some(pending) = state.pending_console.as_mut() {
385 if let GeneralMessage::Console(line) = general {
386 debug!(?pending, "Received console line for command: {}", line);
387
388 if pending.lines.len() < pending.target.get() {
389 pending.lines.push(line);
390 }
391
392 if pending.lines.len() != pending.target.get() || pending.response.is_none() {
393 return Ok(());
394 }
395
396 send(
397 &pending.out,
398 Ok((pending.response.clone().unwrap(), pending.lines.clone())),
399 )
400 .await?;
401
402 state.pending_console = None;
403
404 return Ok(());
405 }
406 }
407
408 if general == GeneralMessage::Done {
409 debug!("Ignoring done");
411 return Ok(());
412 }
413
414 info!("Got general message: {:?}", general);
415 state.pending_general.push(general);
416
417 Ok(())
418}
419
420async fn process_stderr(result: io::Result<usize>, state: &mut State) -> Result<(), Error> {
421 result.map_err(FatalError::Stderr)?;
422
423 let line = &state.stderr_buf[..state.stderr_buf.len() - 1]; debug!("Got stderr: {}", line);
425 let message = GeneralMessage::InferiorStderr(line.into());
426 state.pending_general.push(message);
427 state.stderr_buf.clear();
428
429 Ok(())
430}
431
432async fn send<T>(chan: &mpsc::Sender<T>, val: T) -> Result<(), Error> {
433 chan.send(val)
434 .await
435 .map_err(|_| Error::Transient(TransientError::Send))
436}