shvclient 5.0.2

A Rust framework for Silicon Heaven RPC devices
use std::{error::Error, sync::{Arc, RwLock}, time::Duration};
use event_listener::Event;
use futures::{StreamExt, channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}, future::{Either, select}};
use futures_timer::Delay;
use log::debug;
use shvproto::RpcValue;
use shvrpc::{RpcMessage, RpcMessageMetaTags, rpc::ShvRI, rpcmessage::{RpcError, RpcErrorCode}};
use crate::{ConnectionCommand, ConnectionEvent, runtime::TaskHandle};

const TIMEOUT_DURATION: Duration = Duration::from_secs(5);

#[must_use]
pub struct PendingResponse<'a>(i64, &'a TestApp);

impl PendingResponse<'_> {
    pub async fn await_response(self, expected_result: Result<&str, &str>) {
        let rpc_message = self.1.await_and_remove(|msg| msg.is_response() && msg.request_id() == Some(self.0)).await;
        let response = rpc_message.response().map(|resp| resp.success().expect("Expected a success response"));
        match expected_result {
            Ok(expected_result) => {
                let result = response.expect("Expected a success response");
                assert_eq!(*result, RpcValue::from_cpon(expected_result).unwrap_or_else(|err| panic!("Invalid CPON '{expected_result}': {err}")), "Unexpected value of the result");
            },
            Err(err) => {
                let result = response.expect_err("Expected an Err response");
                assert_eq!(result.to_string(), err, "Unexpected value of the result");
            },
        }
    }
}

/// A request received from the app (see [`TestApp::await_request`]) that still needs an answer.
///
/// Holds the response message prepared from that request (`.0`) and the sender (`.1`)
/// through which the answer is delivered back to the app.
#[must_use]
pub struct PendingRequest(RpcMessage, UnboundedSender<ConnectionEvent>);
impl PendingRequest {
    /// Answers the request and delivers the response frame to the app.
    ///
    /// `Ok(cpon)` sends a successful result parsed from CPON; `Err(text)` sends a
    /// `MethodCallException` error carrying `text`.
    ///
    /// # Panics
    ///
    /// Panics if the prepared message has no request ID, if `cpon` is invalid CPON,
    /// or if the response cannot be serialized or sent.
    pub fn respond(self, result: Result<&str, &str>) {
        let Self(mut response, conn_evt_tx) = self;
        let rq_id = response.request_id().expect("RpcMessage must be request");
        debug!(target: "test-driver", "==> rq_id:{rq_id} {result:?}");

        match result {
            Ok(cpon) => {
                let value = RpcValue::from_cpon(cpon).unwrap_or_else(|err| panic!("Invalid CPON '{cpon}': {err}"));
                response.set_result(value);
            }
            Err(message) => {
                response.set_error(RpcError::new(RpcErrorCode::MethodCallException, message));
            }
        }

        let frame = response.to_frame().expect("to_frame() must work");
        conn_evt_tx
            .unbounded_send(ConnectionEvent::RpcFrameReceived(frame))
            .expect("sending ConnectionEvent must work");
    }
}

/// Predicate on the parameter of a message awaited by [`TestApp::await_signal`] or
/// [`TestApp::await_request`].
///
/// Implemented for CPON string slices (equality with the parsed value) and for any
/// `Fn(&RpcValue) -> bool` closure.
pub trait ParamMatcher {
    /// Returns `true` if `actual` is an acceptable parameter value.
    fn matches(&self, actual: &RpcValue) -> bool;
}

/// Matches when the parameter equals the value parsed from this CPON text.
///
/// # Panics
///
/// Panics if the string is not valid CPON.
impl ParamMatcher for &str {
    fn matches(&self, actual: &RpcValue) -> bool {
        match RpcValue::from_cpon(self) {
            Ok(expected) => *actual == expected,
            Err(err) => panic!("Invalid CPON '{self}': {err}"),
        }
    }
}

/// Delegates the decision to the closure itself.
impl<F> ParamMatcher for F
where
    F: Fn(&RpcValue) -> bool,
{
    fn matches(&self, actual: &RpcValue) -> bool {
        (self)(actual)
    }
}

/// Test driver standing in for the broker connection of an app under test.
///
/// Requests and signals are injected into the app as
/// [`ConnectionEvent::RpcFrameReceived`] events through `conn_evt_tx`. Every message the
/// app sends back is appended to `pending_msgs` by `msg_task`, which then wakes all
/// waiters through `notifier`; the `await_*` methods consume matching messages from
/// that queue.
pub struct TestApp {
    /// Task running the app under test (built by the `app_maker` passed to [`TestApp::new`]).
    app: TaskHandle<Result<(), Box<dyn Error + Send + Sync>>>,
    /// Task collecting messages sent by the app into `pending_msgs`.
    msg_task: TaskHandle<()>,
    /// Sender of the connection events consumed by the app.
    conn_evt_tx: UnboundedSender<ConnectionEvent>,
    /// Messages received from the app that no `await_*` call has consumed yet.
    pending_msgs: Arc<RwLock<Vec<RpcMessage>>>,
    /// Notified after each message is pushed to `pending_msgs`.
    notifier: Arc<Event>,
}

impl TestApp {
    /// Sends a request for `ri` with the CPON-encoded `param` to the app.
    ///
    /// The request carries user ID `test-driver` and superuser access level. Use the
    /// returned [`PendingResponse`] to check the app's answer.
    ///
    /// # Panics
    ///
    /// Panics if `ri` is not a valid [`ShvRI`] or `param` is not valid CPON.
    pub fn request(&self, ri: impl TryInto<ShvRI, Error = impl std::fmt::Display>, param: &str) -> PendingResponse<'_> {
        let ri = ri.try_into().unwrap_or_else(|err| panic!("Invalid RI: {err}"));
        let path = ri.path();
        let method = ri.method();
        let mut rpc_message = RpcMessage::new_request(path, method).with_param(RpcValue::from_cpon(param).unwrap_or_else(|err| panic!("Invalid CPON '{param}': {err}")));
        rpc_message.set_user_id("test-driver");
        rpc_message.set_access_level(shvrpc::metamethod::AccessLevel::Superuser);
        let rq_id = rpc_message.request_id().expect("This must be request");
        debug!(target: "test-driver", "==> rq_id:{rq_id} {path}:{method}, param: {param}");
        self.conn_evt_tx.unbounded_send(ConnectionEvent::RpcFrameReceived(rpc_message.to_frame().expect("to_frame must work"))).expect("events must work");
        PendingResponse(rq_id, self)
    }

    /// Waits until a message satisfying `matcher` is queued, removes it and returns it.
    ///
    /// # Panics
    ///
    /// Panics if no matching message shows up within [`TIMEOUT_DURATION`] of the call.
    async fn await_and_remove<F>(&self, matcher: F) -> RpcMessage
    where
        F: Fn(&RpcMessage) -> bool
    {
        // A single deadline for the whole wait: re-arming the timer after every wake-up
        // would let a steady stream of unrelated messages postpone the timeout forever.
        let deadline = Delay::new(TIMEOUT_DURATION);
        futures::pin_mut!(deadline);

        loop {
            // Register the listener before inspecting the queue so that a message pushed
            // between the check and the wait still wakes us up.
            let event_listener = self.notifier.listen();

            {
                let mut pending = self.pending_msgs.write().expect("pending_msgs write should succeed");
                if let Some(pos) = pending.iter().position(&matcher) {
                    return pending.remove(pos);
                }
            }

            futures::pin_mut!(event_listener);
            if let Either::Right(_) = select(event_listener, deadline.as_mut()).await {
                panic!("Timed out after {TIMEOUT_DURATION:?} while waiting for a matching message");
            }
        }
    }

    /// Returns `true` if `rpc_message` is addressed to `path`:`method` and its parameter
    /// satisfies `expected_param`.
    ///
    /// A missing path or method is treated as an empty string and a missing parameter as
    /// `null`, so an incomplete message is compared rather than causing a panic while the
    /// pending-message lock is held.
    fn message_matches(rpc_message: &RpcMessage, path: &str, method: &str, expected_param: &impl ParamMatcher) -> bool {
        let null = RpcValue::null();
        rpc_message.shv_path().unwrap_or_default() == path
            && rpc_message.method().unwrap_or_default() == method
            && expected_param.matches(rpc_message.param().unwrap_or(&null))
    }

    /// Injects a signal into the app.
    ///
    /// `ri` supplies the path, the source method and the signal name; `param` is CPON.
    ///
    /// # Panics
    ///
    /// Panics if `ri` is not a valid [`ShvRI`], has no signal part, or `param` is not valid CPON.
    pub fn signal(&self, ri: impl TryInto<ShvRI, Error = impl std::fmt::Display>, param: &str) {
        let ri = ri.try_into().unwrap_or_else(|err| panic!("Invalid RI: {err}"));
        let path = ri.path();
        let method = ri.method();
        let signal = ri.signal().unwrap_or_else(|| panic!("Signal RI must have a signal: {ri}"));
        debug!(target: "test-driver", "==> {path}:{method}:{signal}, param: {param}");
        let param = RpcValue::from_cpon(param).unwrap_or_else(|err| panic!("Invalid CPON '{param}': {err}"));
        let rpc_message = RpcMessage::new_signal_with_source(path, signal, method).with_param(param);
        self.conn_evt_tx.unbounded_send(ConnectionEvent::RpcFrameReceived(rpc_message.to_frame().expect("to_frame must work"))).expect("events must work");
    }

    /// Waits for the app to emit a signal matching `expected_ri` and consumes it.
    ///
    /// Only the path and the signal name of `expected_ri` are compared.
    ///
    /// # Panics
    ///
    /// Panics if `expected_ri` is invalid or has no signal part, or if no matching
    /// signal arrives in time.
    pub async fn await_signal(
        &self,
        expected_ri: impl TryInto<ShvRI, Error = impl std::fmt::Display>,
        expected_param: impl ParamMatcher,
    ) {
        let expected_ri = expected_ri.try_into().unwrap_or_else(|err| panic!("Invalid RI: {err}"));
        let expected_signal = expected_ri.signal().unwrap_or_else(|| panic!("Signal RI must have a signal: {expected_ri}"));

        self.await_and_remove(|rpc_message| {
            rpc_message.is_signal()
                && Self::message_matches(rpc_message, expected_ri.path(), expected_signal, &expected_param)
        }).await;
    }

    /// Waits for the app to send a request matching `expected_ri` and `expected_param`.
    ///
    /// Returns a [`PendingRequest`] holding the prepared response, which must be answered
    /// with [`PendingRequest::respond`].
    ///
    /// # Panics
    ///
    /// Panics if `expected_ri` is invalid or no matching request arrives in time.
    pub async fn await_request(
        &self,
        expected_ri: impl TryInto<ShvRI, Error = impl std::fmt::Display>,
        expected_param: impl ParamMatcher,
    ) -> PendingRequest {
        let expected_ri = expected_ri.try_into().unwrap_or_else(|err| panic!("Invalid RI: {err}"));

        let rpc_message = self.await_and_remove(|rpc_message| {
            rpc_message.is_request()
                && Self::message_matches(rpc_message, expected_ri.path(), expected_ri.method(), &expected_param)
        }).await;

        PendingRequest(
            rpc_message.prepare_response().expect("prepare_response must work"),
            self.conn_evt_tx.clone()
        )
    }

    /// Waits for the app to call `.broker/currentClient:subscribe` for `expected_ri`
    /// and answers the request with `true`.
    ///
    /// # Panics
    ///
    /// Panics if `expected_ri` is invalid or no matching subscribe request arrives in time.
    pub async fn await_subscription(&self, expected_ri: impl TryInto<ShvRI, Error = impl std::fmt::Display>) {
        let expected_ri = expected_ri.try_into().unwrap_or_else(|err| panic!("Invalid RI: {err}"));
        // Let the CPON writer quote and escape the RI instead of splicing it into a literal by hand.
        let expected_param = format!("[{},null]", RpcValue::from(expected_ri.as_str()).to_cpon());
        self.await_request(".broker/currentClient:subscribe", expected_param.as_str()).await
            .respond(Ok("true"));
    }

    /// Creates the driver.
    ///
    /// Builds the app with `app_maker` (handing it the receiving end of the event channel)
    /// and starts the task that collects the app's outgoing messages. It then sends the app
    /// [`ConnectionEvent::Connected`] with the command sender.
    ///
    /// # Panics
    ///
    /// Panics if the initial `Connected` event cannot be sent.
    pub fn new(app_maker: impl FnOnce(UnboundedReceiver<ConnectionEvent>) -> TaskHandle<Result<(), Box<dyn Error + Send + Sync>>>) -> Self {
        let (conn_evt_tx, conn_evt_rx) = unbounded::<ConnectionEvent>();
        let app = app_maker(conn_evt_rx);

        let (conn_cmd_sender_in, mut conn_cmd_receiver_in) = unbounded();
        let pending_msgs = Arc::new(RwLock::new(Vec::new()));
        let notifier = Arc::new(Event::new());

        let msg_task = {
            let pending_msgs = pending_msgs.clone();
            let notifier = notifier.clone();
            crate::runtime::spawn_task(async move {
                while let Some(ConnectionCommand::SendMessage(rpc_message)) = conn_cmd_receiver_in.next().await {
                    let shv_path = rpc_message.shv_path().unwrap_or_default();
                    let method = rpc_message.method().unwrap_or_default();
                    let param = rpc_message.param().cloned().unwrap_or_else(RpcValue::null);
                    let msg_prefix = rpc_message.request_id().map_or_else(String::new, |rq_id| format!("rq_id:{rq_id} "));
                    let msg_suffix = rpc_message.response().map(|resp| resp.success().cloned().unwrap_or_else(RpcValue::null));
                    if method.is_empty() && shv_path.is_empty() {
                        debug!(target: "test-driver", "<== {msg_prefix}-> {msg_suffix:?}");
                    } else {
                        debug!(target: "test-driver", "<== {msg_prefix}{shv_path}:{method}, param: {param}");
                    }

                    {
                        pending_msgs.write().expect("pending_msgs write should succeed").push(rpc_message);
                    }
                    // Wake every waiter; each one re-checks the queue with its own matcher.
                    notifier.notify(usize::MAX);
                }
            })
        };
        conn_evt_tx.unbounded_send(ConnectionEvent::Connected(conn_cmd_sender_in)).expect("Events must work");
        Self {
            app,
            msg_task,
            conn_evt_tx,
            pending_msgs,
            notifier,
        }
    }

    /// Finishes the test.
    ///
    /// The steps are, in order:
    /// 1. Waits 500 ms so that late messages from the app can arrive.
    /// 2. Closes the event channel.
    /// 3. Checks that every message sent by the app was consumed (unsubscribe requests to
    ///    `.broker/currentClient` are tolerated).
    /// 4. Waits up to [`TIMEOUT_DURATION`] for the message task and the app to end.
    ///
    /// # Errors
    ///
    /// Returns an error if unconsumed messages remain (they are logged at debug level),
    /// if the tasks do not finish in time, or if either task or the app fails.
    pub async fn wait_until_finished(self) -> shvrpc::Result<()> {
        // Wait for a bit to ensure silence from the app.
        Delay::new(Duration::from_millis(500)).await;
        self.conn_evt_tx.close_channel();

        {
            let mut pending_msgs = self.pending_msgs.write().expect("pending_msgs write should succeed");
            // We'll let unsubscribe calls slide.
            pending_msgs.retain(|msg|
                !msg.is_request() || msg.shv_path() != Some(".broker/currentClient") || msg.method() != Some("unsubscribe")
            );
            for rpc_message in pending_msgs.iter() {
                let shv_path = rpc_message.shv_path().unwrap_or_default();
                let method = rpc_message.method().unwrap_or_default();
                let param = rpc_message.param().cloned().unwrap_or_else(RpcValue::null);
                let msg_prefix = rpc_message.request_id().map_or_else(String::new, |rq_id| format!("rq_id:{rq_id} "));
                if rpc_message.is_response() {
                    // Diagnostics only: never panic here, or the real error below would be masked.
                    let result = rpc_message.response().map(|resp| resp.success().cloned().unwrap_or_else(RpcValue::null));
                    debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix} -> {result:?}");
                } else if method.is_empty() && shv_path.is_empty() {
                    debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{param}");
                } else {
                    debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}");
                }
            }

            if !pending_msgs.is_empty() {
                return Err("There were unexpected messages received from the app.".into());
            }
        }

        let end = async move {
            self.msg_task.await?;
            self.app.await??;

            Ok(())
        };

        timeout(TIMEOUT_DURATION, end).await?
    }
}

/// Drives `fut` to completion unless `dur` elapses first.
///
/// # Errors
///
/// Returns the error "Timed out while waiting for a future" if the timer fires before
/// `fut` completes; in that case `fut` is dropped unfinished.
pub async fn timeout<F, T>(dur: Duration, fut: F) -> shvrpc::Result<T>
where
    F: Future<Output = T>,
{
    let timer = Delay::new(dur);
    futures::pin_mut!(fut, timer);

    if let Either::Left((output, _)) = select(fut, timer).await {
        Ok(output)
    } else {
        Err("Timed out while waiting for a future".into())
    }
}