extcap 0.4.0

Crate for writing extcap Wireshark plugins.
Documentation
use std::fs::File;
use std::future::Future;
use std::io::{self, Cursor};

use bytes::buf::BufMut;
use bytes::{Buf, BytesMut};
use futures::channel::mpsc::{self, Receiver, Sender};
use futures::channel::oneshot;
use futures::future::{self, lazy, BoxFuture, FutureExt};
use futures::sink::SinkExt;
use futures::stream::{StreamExt, TryStreamExt};
use log::{debug, error};
use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};

use crate::control_pipe::{ControlMsg, CtrlPipes};

/// Buffer capacity of the mpsc channels that connect the pipe tasks with the caller.
const PIPE_LEN: usize = 1 << 7;

#[derive(Debug)]
enum State {
    New {
        pipe_in: File,
        pipe_out: File,
    },
    Started {
        stop_in: oneshot::Sender<()>,
        stop_out: oneshot::Sender<()>,
    },
}

pub(crate) struct ControlPipeRuntime {
    state: Option<State>,
    tsk: Option<BoxFuture<'static, ()>>,
}

impl ControlPipeRuntime {
    /// Creates a runtime that owns the two control pipes but does not touch them yet.
    ///
    /// `pipe_in` will be read for incoming control messages and `pipe_out` written with
    /// outgoing ones once [`ControlPipeRuntime::start`] has been called.
    pub(crate) fn new(pipe_in: File, pipe_out: File) -> Self {
        Self {
            state: Some(State::New { pipe_in, pipe_out }),
            tsk: None,
        }
    }

    /// Builds the reader/writer tasks for the control pipes.
    ///
    /// Returns the channel endpoints used to talk to them:
    /// - the receiving end yields messages decoded from `pipe_in`;
    /// - messages pushed into the sending end are encoded onto `pipe_out`.
    ///
    /// The tasks do nothing until the future returned by
    /// [`ControlPipeRuntime::run_task`] is polled.
    ///
    /// # Panics
    ///
    /// Panics if the runtime is not in the freshly created state, e.g. when called twice.
    pub(crate) fn start(&mut self) -> CtrlPipes {
        debug!("start() state={:?}", self.state);

        let (pipe_in, pipe_out) = if let Some(State::New { pipe_in, pipe_out }) = self.state.take()
        {
            (pipe_in, pipe_out)
        } else {
            error!("start() called in wrong state");
            panic!("start() called in wrong state");
        };

        // Channel carrying messages read from pipe_in to the caller.
        let (snd, rcv_in) = mpsc::channel(PIPE_LEN);
        let (stop_in, stop_in_rx) = oneshot::channel::<()>();

        // Channel carrying messages from the caller to be written to pipe_out.
        let (snd_out, rcv) = mpsc::channel(PIPE_LEN);
        let (stop_out, stop_out_rx) = oneshot::channel::<()>();

        self.state = Some(State::Started { stop_in, stop_out });

        let tsk = futures::future::join(
            thread_in(stop_in_rx, pipe_in, snd),
            thread_out(stop_out_rx, pipe_out, rcv),
        )
        .map(|_| ());

        self.tsk = Some(tsk.boxed::<'static>());

        debug!("start() done state={:?}", self.state);
        (rcv_in, snd_out)
    }

    /// Hands out the combined pipe I/O future created by `start()`.
    ///
    /// The caller is responsible for polling it (e.g. by spawning or awaiting it).
    ///
    /// # Panics
    ///
    /// Panics if `start()` has not been called, or if the task was already taken.
    pub(crate) fn run_task(&mut self) -> impl Future<Output = ()> {
        self.tsk
            .take()
            .expect("run_task() called before start() or more than once")
    }

    /// Signals both pipe tasks to stop and consumes the runtime.
    ///
    /// Calling this on a runtime that was never started only logs an error.
    pub(crate) fn stop(mut self) {
        debug!("stop() state={:?}", self.state);
        let (stop_in, stop_out) = match self.state.take() {
            Some(State::Started { stop_in, stop_out }) => (stop_in, stop_out),
            Some(State::New { .. }) | None => {
                error!("stop() called in wrong state");
                return;
            }
        };
        // A pipe task drops its stop receiver as soon as it finishes on its own (for
        // example when the peer closes the pipe). The receiver is also dropped if the
        // run_task() future was dropped. A failed send then only means there is nothing
        // left to stop, so it must not panic.
        if stop_in.send(()).is_err() {
            debug!("stop() input task already finished");
        }
        if stop_out.send(()).is_err() {
            debug!("stop() output task already finished");
        }
        debug!("stop() done");
    }
}

async fn thread_in(
    stop: oneshot::Receiver<()>,
    pipe: File,
    sender: Sender<ControlMsg>,
) -> Result<(), ()> {
    debug!("thread_in starting ...");
    lazy::<_, Result<(), ()>>(|_| {
        debug!("thread_in started");
        Ok(())
    })
    .await?;
    let tpipe = tokio::fs::File::from_std(pipe);
    let strm = FramedRead::new(tpipe, ControlMsgCodec);
    let task = strm
        .inspect(|msg| debug!("thread_in received {:?}", msg))
        .map_err(|e| error!("thread_in stream_err {:?}", e))
        .forward(sender.sink_map_err(|e| error!("thread_in sink_err {:?}", e)));
    future::select(stop, task).await;
    debug!("thread_in stopped");
    Ok(())
}

async fn thread_out(
    stop: oneshot::Receiver<()>,
    pipe: File,
    mut receiver: Receiver<ControlMsg>,
) -> Result<(), ()> {
    debug!("thread_out starting ...");
    lazy::<_, Result<(), ()>>(|_| {
        debug!("thread_out started");
        Ok(())
    })
    .await?;
    let tpipe = tokio::fs::File::from_std(pipe);
    let mut strm = FramedWrite::new(tpipe, ControlMsgCodec);
    let task = async {
        while let Some(msg) = receiver.next().await {
            debug!("thread_out received {:?}", msg);
            if let Err(e) = strm.send(msg).await {
                error!("thread_out strm_err {:?}", e);
            }
        }
    };
    future::select(stop, task.boxed()).await;
    debug!("thread_out stopped");
    Ok(())
}

/// Stateless codec that frames [`ControlMsg`] values on the extcap control pipes.
///
/// Each frame is a `'T'` byte, a 3-byte big-endian message length, the control number,
/// the command byte, and the payload.
struct ControlMsgCodec;

impl Decoder for ControlMsgCodec {
    type Item = ControlMsg;
    type Error = io::Error;

    /// Tries to split one control-pipe frame off the front of `buf`.
    ///
    /// Frame layout: `'T'` sync byte, 3-byte big-endian message length, then that many
    /// bytes made of control number, command byte and payload.
    ///
    /// Returns `Ok(None)` (leaving `buf` untouched) while the header or body is still
    /// incomplete. Returns an `InvalidData` error for a bad sync byte or a message
    /// length below 2.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const HEADER_LEN: usize = 4;

        let header = match buf.get(..HEADER_LEN) {
            Some(header) => header,
            None => return Ok(None),
        };
        let mut cursor = Cursor::new(header);
        let sync = cursor.get_u8();
        let msg_len = cursor.get_uint(3) as usize;

        if sync != b'T' {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Sync Pipe Indication != 'T'",
            ));
        }
        // The length must at least cover the control number and command bytes.
        if msg_len < 2 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Message Length < 2",
            ));
        }

        let frame_len = HEADER_LEN + msg_len;
        if buf.len() < frame_len {
            return Ok(None);
        }

        let mut body = buf.split_to(frame_len);
        body.advance(HEADER_LEN);
        let ctrl_num = body.get_u8();
        let command = body.get_u8();
        Ok(Some(ControlMsg::new(ctrl_num, command.into(), body.chunk())))
    }
}

impl Encoder<ControlMsg> for ControlMsgCodec {
    type Error = io::Error;

    /// Serialises `msg` as one control-pipe frame.
    ///
    /// The frame is the `'T'` sync byte, a 3-byte big-endian length, the control number,
    /// the command byte, and the payload. The length covers the control number, the
    /// command byte and the payload.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error if the payload is too large for the 3-byte length
    /// field. Nothing is written to `buf` in that case.
    fn encode(&mut self, msg: ControlMsg, buf: &mut BytesMut) -> Result<(), Self::Error> {
        // Largest value representable in the 3-byte length field.
        const MAX_MSG_LEN: usize = 0xFF_FFFF;

        let data = msg.get_data();
        // Control number + command byte + payload.
        let msg_len = data.len() + 2;
        if msg_len > MAX_MSG_LEN {
            // put_uint(.., 3) would silently drop the high bytes and corrupt the framing.
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Message Length > 0xFFFFFF",
            ));
        }

        // 4 header bytes ('T' + 3-byte length) followed by msg_len bytes.
        buf.reserve(4 + msg_len);
        buf.put_u8(b'T');
        buf.put_uint(msg_len as u64, 3);
        buf.put_u8(msg.get_ctrl_num());
        buf.put_u8(u8::from(msg.get_command()));
        buf.put(data);
        debug!("encode() {:?}", buf);
        Ok(())
    }
}