// gdbmi/worker.rs

1use std::{collections::HashMap, io, num::NonZeroUsize};
2
3use crate::{
4    parser::{self, parse_message},
5    raw::{self, GeneralMessage, Response},
6    status::Status,
7    Token,
8};
9
10use tokio::{
11    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
12    process, select,
13    sync::mpsc,
14};
15use tracing::{debug, error, info, warn};
16
17type MsgOut = mpsc::Sender<Result<Response, crate::Error>>;
18type StatusOut = mpsc::Sender<Status>;
19type StatusAwaiterPred = Box<dyn Fn(&Status) -> bool + Send + Sync>;
20
21#[derive(derivative::Derivative)]
22#[derivative(Debug)]
23pub(super) enum Msg {
24    Cmd {
25        token: Token,
26        msg: String,
27        out: MsgOut,
28    },
29    ConsoleCmd {
30        token: Token,
31        msg: String,
32        out: mpsc::Sender<Result<(Response, Vec<String>), crate::Error>>,
33        capture_lines: NonZeroUsize,
34    },
35    PopGeneral(mpsc::Sender<Vec<GeneralMessage>>),
36    Status(mpsc::Sender<Status>),
37    NextStatus {
38        current: Status,
39        out: StatusOut,
40    },
41    AwaitStatus {
42        #[derivative(Debug = "ignore")]
43        pred: StatusAwaiterPred,
44        out: StatusOut,
45    },
46}
47
48pub(super) fn spawn(cmd: process::Child) -> mpsc::UnboundedSender<Msg> {
49    let (tx, rx) = mpsc::unbounded_channel::<Msg>();
50    tokio::spawn(async move { mainloop(cmd, rx).await });
51    tx
52}
53
54#[derive(Debug)]
55struct PendingConsole {
56    token: Token,
57    response: Option<Response>,
58    lines: Vec<String>,
59    target: NonZeroUsize,
60    out: mpsc::Sender<Result<(Response, Vec<String>), crate::Error>>,
61}
62
63#[derive(derivative::Derivative)]
64#[derivative(Debug)]
65struct State {
66    stdin: process::ChildStdin,
67    stdout: BufReader<process::ChildStdout>,
68    stderr: BufReader<process::ChildStderr>,
69    stdout_buf: String,
70    stderr_buf: String,
71    status: Status,
72    #[derivative(Debug = "ignore")]
73    notify_next_status: Vec<StatusOut>,
74    #[derivative(Debug = "ignore")]
75    status_awaiters: Vec<(StatusAwaiterPred, StatusOut)>,
76    pending: HashMap<Token, MsgOut>,
77    pending_general: Vec<GeneralMessage>,
78    pending_console: Option<PendingConsole>,
79}
80
81async fn mainloop(mut cmd: process::Child, mut rx: mpsc::UnboundedReceiver<Msg>) {
82    let stdin = cmd
83        .stdin
84        .take()
85        .expect("Stdin not captured. See docs of Gdb::new");
86    let stdout = BufReader::new(
87        cmd.stdout
88            .take()
89            .expect("Stdout not captured. See docs of Gdb::new"),
90    );
91    let stderr = BufReader::new(
92        cmd.stderr
93            .take()
94            .expect("Stderr not captured. See docs of Gdb::new"),
95    );
96
97    let mut state = State {
98        stdin,
99        stdout,
100        stderr,
101        stdout_buf: String::new(),
102        stderr_buf: String::new(),
103        status: Status::Unstarted,
104        notify_next_status: Vec::new(),
105        status_awaiters: Vec::new(),
106        pending: HashMap::new(),
107        pending_general: Vec::new(),
108        pending_console: None,
109    };
110
111    loop {
112        select! {
113            // Don't pull any new command while we're waiting for console output
114            msg = rx.recv(), if &state.pending_console.is_none() => {
115                if let Err(err) = process_msg(msg, &mut state).await {
116                    if log_and_check_fatal(&state, err) {
117                        break
118                    }
119                }
120            }
121
122            result = state.stdout.read_line(&mut state.stdout_buf) => {
123                if let Err(err) = process_stdout(result, &mut state).await {
124                    if log_and_check_fatal(&state, err) {
125                        break
126                    }
127                }
128            }
129
130            result = state.stderr.read_line(&mut state.stderr_buf) => {
131                if let Err(err) = process_stderr(result, &mut state).await {
132                    if log_and_check_fatal(&state, err) {
133                        break
134                    }
135                }
136            }
137        }
138    }
139}
140
141fn log_and_check_fatal(state: &State, error: Error) -> bool {
142    debug!(?state, "State after error");
143    match error {
144        Error::Transient(err) => {
145            error!("Transient error in worker: {}", err);
146            false
147        }
148        Error::Fatal(err) => {
149            error!("Fatal error in worker: {}", err);
150            true
151        }
152    }
153}
154
155#[derive(Debug, thiserror::Error, displaydoc::Display)]
156enum Error {
157    /// Fatal error in worker
158    Fatal(#[from] FatalError),
159    /// Transient error in worker
160    Transient(#[from] TransientError),
161}
162
163#[derive(Debug, thiserror::Error, displaydoc::Display)]
164enum FatalError {
165    /// Failed to write to stdin
166    Stdin(#[source] io::Error),
167    /// Request channel closed
168    RequestChanClosed,
169    /// Failed to read stdout
170    Stdout(#[source] io::Error),
171    /// Failed to send to out chan
172    Send,
173    /// Failed to parse response
174    Parse(#[from] parser::Error),
175    /// Failed to read stderr
176    Stderr(#[source] io::Error),
177}
178
179#[derive(Debug, thiserror::Error, displaydoc::Display)]
180enum TransientError {
181    /// Failed to send to out chan
182    Send,
183    /// Failed to parse response
184    Parse(#[from] parser::Error),
185}
186
187impl<T> From<mpsc::error::SendError<T>> for FatalError {
188    fn from(_: mpsc::error::SendError<T>) -> Self {
189        Self::Send
190    }
191}
192
193async fn process_msg(msg: Option<Msg>, state: &mut State) -> Result<(), Error> {
194    let msg = msg.ok_or(FatalError::RequestChanClosed)?;
195
196    match msg {
197        Msg::Cmd { token, msg, out } => {
198            write_stdin(&mut state.stdin, token, &msg).await?;
199            state.pending.insert(token, out);
200        }
201
202        Msg::ConsoleCmd {
203            token,
204            msg,
205            out,
206            capture_lines,
207        } => {
208            state.pending_console = Some(PendingConsole {
209                token,
210                response: None,
211                lines: Vec::with_capacity(capture_lines.get()),
212                target: capture_lines,
213                out,
214            });
215            write_stdin(&mut state.stdin, token, &msg).await?;
216        }
217
218        Msg::PopGeneral(out) => {
219            send(&out, state.pending_general.clone()).await?;
220            state.pending_general.clear();
221        }
222
223        Msg::Status(out) => {
224            send(&out, state.status.clone()).await?;
225        }
226
227        Msg::NextStatus {
228            current: current_belief,
229            out,
230        } => {
231            if current_belief == state.status {
232                state.notify_next_status.push(out);
233            } else {
234                debug!(?current_belief, actual = ?state.status, "Caller's current_belief incorrect, sending current status");
235                send(&out, state.status.clone()).await?;
236            }
237        }
238
239        Msg::AwaitStatus { pred, out } => {
240            state.status_awaiters.push((pred, out));
241        }
242    }
243
244    Ok(())
245}
246
247async fn write_stdin(
248    stdin: &mut process::ChildStdin,
249    token: Token,
250    msg: &str,
251) -> Result<(), FatalError> {
252    let mut input = token.serialize();
253    input.push_str(&msg);
254    input.push('\n');
255
256    info!("Sending to gdb {}", input);
257    stdin
258        .write_all(&input.as_bytes())
259        .await
260        .map_err(FatalError::Stdin)?;
261
262    Ok(())
263}
264
265async fn process_stdout(result: io::Result<usize>, state: &mut State) -> Result<(), Error> {
266    result.map_err(FatalError::Stdout)?;
267
268    let line = &state.stdout_buf[..state.stdout_buf.len() - 1]; // strip the newline
269    debug!("Got stdout: {}", line);
270    let response = parse_message(&line).map_err(TransientError::from)?;
271    state.stdout_buf.clear();
272
273    match response {
274        parser::Message::Response(response) => process_parsed_response(state, response).await?,
275        parser::Message::General(general) => process_parsed_general(state, general).await?,
276    }
277    Ok(())
278}
279
280async fn process_parsed_response(
281    state: &mut State,
282    response: parser::Response,
283) -> Result<(), Error> {
284    let token = if let Some(token) = response.token() {
285        token
286    } else {
287        match response {
288            parser::Response::Notify {
289                message, payload, ..
290            } => {
291                process_response_notify(state, message, payload).await?;
292            }
293            result @ parser::Response::Result { .. } => {
294                warn!("Ignoring result without token: {:?}", result);
295            }
296        }
297        return Ok(());
298    };
299
300    if let Some(pending_token) = state.pending_console.as_ref().map(|p| p.token) {
301        if token == pending_token {
302            match Response::from_parsed(response) {
303                Ok(response) => {
304                    let mut pending = state.pending_console.as_mut().unwrap();
305                    pending.response = Some(response);
306
307                    if pending.lines.len() != pending.target.get() {
308                        return Ok(());
309                    }
310
311                    send(
312                        &pending.out,
313                        Ok((pending.response.clone().unwrap(), pending.lines.clone())),
314                    )
315                    .await?;
316
317                    state.pending_console = None;
318                }
319                Err(err) => {
320                    send(&state.pending_console.as_ref().unwrap().out, Err(err)).await?;
321                }
322            }
323            return Ok(());
324        }
325    }
326
327    let out = if let Some(out) = state.pending.remove(&token) {
328        out
329    } else {
330        warn!(
331            "Got unexpected token {:?}, so ignoring: {:?}",
332            token, response
333        );
334        return Ok(());
335    };
336
337    let response = Response::from_parsed(response);
338    info!("Sending response: {:?}", response);
339    send(&out, response).await?;
340
341    Ok(())
342}
343
344async fn process_response_notify(
345    state: &mut State,
346    message: String,
347    payload: raw::Dict,
348) -> Result<(), Error> {
349    if let Some(new_status) = Status::from_notification(&message, payload) {
350        state.status = new_status;
351
352        info!("New status {:?}", state.status);
353
354        let to_notify = &mut state.notify_next_status;
355        debug!("Notifying {} watchers of status", to_notify.len());
356        for out in to_notify.drain(..) {
357            send(&out, state.status.clone()).await?;
358        }
359
360        let mut to_remove = Vec::new();
361        for (idx, (pred, out)) in state.status_awaiters.iter().enumerate() {
362            if pred(&state.status) {
363                send(out, state.status.clone()).await?;
364                to_remove.push(idx);
365            }
366        }
367        debug!(
368            "{} were awaiting this status, {} remain",
369            to_remove.len(),
370            state.status_awaiters.len() - to_remove.len()
371        );
372        for idx in to_remove {
373            drop(state.status_awaiters.remove(idx));
374        }
375    }
376
377    Ok(())
378}
379
380async fn process_parsed_general(
381    state: &mut State,
382    general: raw::GeneralMessage,
383) -> Result<(), Error> {
384    if let Some(pending) = state.pending_console.as_mut() {
385        if let GeneralMessage::Console(line) = general {
386            debug!(?pending, "Received console line for command: {}", line);
387
388            if pending.lines.len() < pending.target.get() {
389                pending.lines.push(line);
390            }
391
392            if pending.lines.len() != pending.target.get() || pending.response.is_none() {
393                return Ok(());
394            }
395
396            send(
397                &pending.out,
398                Ok((pending.response.clone().unwrap(), pending.lines.clone())),
399            )
400            .await?;
401
402            state.pending_console = None;
403
404            return Ok(());
405        }
406    }
407
408    if general == GeneralMessage::Done {
409        // Suppress these, as they come after every command
410        debug!("Ignoring done");
411        return Ok(());
412    }
413
414    info!("Got general message: {:?}", general);
415    state.pending_general.push(general);
416
417    Ok(())
418}
419
420async fn process_stderr(result: io::Result<usize>, state: &mut State) -> Result<(), Error> {
421    result.map_err(FatalError::Stderr)?;
422
423    let line = &state.stderr_buf[..state.stderr_buf.len() - 1]; // strip the newline
424    debug!("Got stderr: {}", line);
425    let message = GeneralMessage::InferiorStderr(line.into());
426    state.pending_general.push(message);
427    state.stderr_buf.clear();
428
429    Ok(())
430}
431
432async fn send<T>(chan: &mpsc::Sender<T>, val: T) -> Result<(), Error> {
433    chan.send(val)
434        .await
435        .map_err(|_| Error::Transient(TransientError::Send))
436}