ezrpc 0.1.1

Ergonomic, flexible, and zero-cost RPC framework
Documentation
#![feature(async_closure)]
#![feature(lazy_cell, async_fn_in_trait)]

use ezrpc::{channel, ws::*, Adapter, ArcRpcContext, MappedService, Session};
use std::sync::Arc;
use std::sync::LazyLock;
use std::time::Duration;

// Numeric method ids for the handlers that are registered by number (rather than by
// name) in `server_thread` and `test_client`.

/// Recursive-add id: the server's `recadd` calls back the client's handler with it.
const RECURSIVE_ADD: u32 = 0x1;
/// Echoes a byte buffer back to the caller.
const ECHO_BIGDATA: u32 = 0x2;
/// Sleeps for the requested number of milliseconds.
const SLEEP_MS: u32 = 0x3;
/// Echoes a `u32` back to the caller.
const ECHO_INT: u32 = 0x4;
/// Notification carrying the next expected value of the server-side counter.
const INCREMENT: u32 = 0x5;

/// One-shot logger initialization, run the first time `LOGGER` is forced.
///
/// Uses the log spec from the environment when present, falling back to `"debug"`.
/// Panics if the spec cannot be parsed or the logger cannot be started.
static LOGGER: LazyLock<()> = LazyLock::new(|| {
    let logger = flexi_logger::Logger::try_with_env_or_str("debug").unwrap();
    // NOTE(review): the `LoggerHandle` returned by `start()` is dropped right away;
    // flexi_logger recommends keeping it alive for the program's lifetime — confirm
    // this is fine for the default stderr output.
    let _ = logger.start().unwrap();
});

/// Stateless server-side implementation of the `rpc::Service` methods.
#[derive(Default, Debug)]
pub struct TestService;

// RPC interface declared through the `ezrpc!` macro. Judging by its use in this file,
// the macro generates at least the `rpc::Service` trait (implemented by `TestService`),
// the `rpc::SessionMethod` trait (typed call helpers such as `session.recadd(..)`)
// and the `rpc::init_service!` macro.
mod rpc {
    use super::*;

    ezrpc::ezrpc! {
        // Also reachable by its string name, e.g. `session.request("partTuple", ..)`.
        fn partTuple(a: i32, b: i32, c: Option<i32>) -> i32;
        // `*ctx` presumably asks for the call context to be passed to the implementation
        // (the impl receives `ctx: ArcRpcContext`), and `[RECURSIVE_ADD]` presumably
        // assigns the numeric method id — TODO confirm against ezrpc's macro docs.
        fn recadd(*ctx, val: u32) -> u32 [RECURSIVE_ADD];
        fn test_bytes(buf: &serde_bytes::Bytes);
        fn test_str<'u>(buf: &'u str) -> String;
    }
}

// Generated registration helper for `TestService`; `server_thread` calls
// `TestService::init_service(&sv)`, which presumably this macro defines to add every
// `rpc::Service` method to the `MappedService` — TODO confirm against ezrpc's macro.
rpc::init_service!(rpc::Service, TestService);

/// Server-side implementation of the methods declared in the `rpc` module.
// `partTuple` keeps the camelCase name it is called by (`"partTuple"`).
#[allow(non_snake_case)]
impl rpc::Service for TestService {
    /// Returns `a + b`; the optional third argument is accepted but ignored.
    async fn partTuple(&self, a: i32, b: i32, _c: Option<i32>) -> anyhow::Result<i32> {
        println!("a: {a} b: {b}");
        Ok(a + b)
    }

    /// Asks the peer to handle `RECURSIVE_ADD` on `val`, then adds one to its answer.
    ///
    /// # Errors
    /// Propagates a failure of the nested request to the caller instead of
    /// panicking inside the handler.
    async fn recadd(&self, ctx: ArcRpcContext, val: u32) -> anyhow::Result<u32> {
        println!("Server Received {:?}", val);
        let val: u32 = ctx.session.request(RECURSIVE_ADD, &val).await?;
        println!("Server Requested {:?}", val);
        Ok(val + 1)
    }

    /// Checks that the received payload equals `BYTES`.
    ///
    /// # Errors
    /// Returns an error (seen by the calling client) on a mismatch, rather than
    /// panicking inside the server's handler.
    async fn test_bytes(&self, buf: &serde_bytes::Bytes) -> anyhow::Result<()> {
        anyhow::ensure!(buf == BYTES, "unexpected test_bytes payload: {:?}", buf);
        Ok(())
    }

    /// Returns an owned copy of `buf`.
    async fn test_str<'u>(&'u self, buf: &'u str) -> anyhow::Result<String> {
        Ok(buf.into())
    }
}

/// Payload sent to `test_bytes`; the server checks it arrives unchanged.
const BYTES: &[u8] = b"\x01\x02\x03";

/// Runs the test server over `adapter` until its dispatch loop ends.
///
/// Besides the `rpc::Service` methods of `TestService`, it registers:
/// - `"stream"`: streams the integers `0..10` back to the caller;
/// - `ECHO_BIGDATA`: echoes a byte buffer;
/// - `SLEEP_MS`: sleeps for the given milliseconds (no-op on wasm32);
/// - `ECHO_INT`: echoes a `u32`;
/// - `INCREMENT`: asserts that notifications carry `0, 1, 2, …` in order.
///
/// # Panics
/// Panics if the dispatch loop returns an error, or if an `INCREMENT`
/// notification carries an unexpected value.
async fn server_thread(adapter: Arc<dyn Adapter>) {
    let sv = MappedService::<TestService>::default();

    // The client reads exactly ten items from this stream.
    sv.add("stream", async move |ctx: ArcRpcContext| {
        let st = ctx.into_stream();
        for i in 0..10 {
            st.send(i).await.expect("response_part");
        }
    });

    sv.add(ECHO_BIGDATA, |data: Vec<u8>| {
        println!("BigData Received {}", data.len());
        data
    });

    sv.add(SLEEP_MS, async move |ms: u32| {
        // A statement rather than a tail expression, so the body is still a valid
        // `()` block when the cfg removes it on wasm32.
        #[cfg(not(target_arch = "wasm32"))]
        tokio::time::sleep(Duration::from_millis(u64::from(ms))).await;
    });

    sv.add(ECHO_INT, core::convert::identity::<u32>);

    // Fetch-and-increment in one atomic step, so two notifications handled
    // back-to-back can never both observe the same counter value (a separate
    // read followed by a write could).
    let expected = std::sync::atomic::AtomicU32::new(0);
    sv.add(INCREMENT, move |val: u32| {
        let prev = expected.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        assert_eq!(val, prev);
        println!("val: {}", val);
    });

    TestService::init_service(&sv);

    let session = Arc::new(Session::new(adapter, Arc::new(sv)));
    session.loop_dispatch().await.unwrap();
}

// trait Test: Send + Sync + Sized + 'static {
//     fn add<'a>(&'a self, a: i32, b: i32) -> impl Future<Output = anyhow::Result<i32>> + 'a;
// }

// fn addservice<T: Test>(map: MapService<T>) {
//     map.add(0, T::add);
// }

/// Exercises every server endpoint through `session` and asserts on the results.
///
/// While it runs, a background task prints the session's pending requests every
/// five seconds (debug builds only) to help diagnose hangs.
///
/// # Panics
/// Panics on any failed request or unexpected response, including failures in
/// the blocking `request_sync` check.
async fn session_test(session: &Arc<Session>) {
    use rpc::SessionMethod;

    let ss = session.clone();
    let monitor = tokio::spawn(async move {
        loop {
            tokio::time::sleep(Duration::from_secs(5)).await;
            #[cfg(debug_assertions)]
            println!("Pending request: {:?}", ss.pending_method.read());
        }
    });

    // The client handler adds 1 and the server's `recadd` adds 1 more: 0 -> 2.
    let val: u32 = session.recadd(0u32).await.unwrap();
    assert_eq!(val, 2);

    let mut r = session.request_recver("stream", ()).await.unwrap();
    for i in 0..10 {
        assert_eq!(r.recv_decode::<u32>().await.unwrap(), Some(i));
    }

    session
        .test_bytes(serde_bytes::Bytes::new(BYTES))
        .await
        .unwrap();

    let res: i32 = session.partTuple(1, 2, None).await.unwrap();
    assert_eq!(res, 3);

    // Keep the handle: dropping a JoinHandle detaches the task, so a panic in this
    // blocking check would be swallowed. It is awaited at the end so it still runs
    // concurrently with the async requests below.
    let sync_check = tokio::task::spawn_blocking({
        let session = session.clone();
        move || {
            let res: i32 = session.request_sync("partTuple", (1, 2, 3)).unwrap();
            assert_eq!(res, 3);
        }
    });

    let res: i32 = session.request("partTuple", (1, 2, 3)).await.unwrap();
    assert_eq!(res, 3);

    for i in 0u32..100 {
        session.notify(INCREMENT, i).await.unwrap();
    }

    // 1 MiB of zeros followed by a recognizable tail, sized up front so the tail
    // does not force a reallocation.
    const LEN: usize = 1024 * 1024;
    let tail: &[u8] = &[1, 2, 3];
    let mut data: Vec<u8> = Vec::with_capacity(LEN + tail.len());
    data.resize(LEN, 0);
    data.extend_from_slice(tail);

    let echoed: Vec<u8> = session.request(ECHO_BIGDATA, &data).await.unwrap();
    assert_eq!(echoed.len(), LEN + tail.len());
    assert_eq!(&echoed[echoed.len() - tail.len()..], tail);

    sync_check
        .await
        .expect("blocking request_sync check panicked");
    monitor.abort();
}

/// Builds a client `Session` over `adapter`, runs `session_test` against it and
/// returns the session.
///
/// The client answers the server's `RECURSIVE_ADD` callbacks with `val + 1`.
async fn test_client<A: Adapter>(adapter: A) -> Arc<Session> {
    let client_service = MappedService::<()>::default();
    client_service.add(RECURSIVE_ADD, async move |val: u32| {
        println!("Client Received {:?}", val);
        val + 1
    });

    let client: Arc<Session> = Session::from(adapter, client_service).into();
    session_test(&client).await;
    client
}

/// Browser (wasm32) variant of the websocket test.
///
/// `test_client` needs an adapter, so this connects to `ws://127.0.0.1:3333`.
/// NOTE(review): presumably the server is the one started by the ignored native
/// `only_server` test, which listens on that address — start it first; also confirm
/// `WsAdapter::connect` is available on wasm32.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen_test::wasm_bindgen_test]
async fn test_ws() {
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    test_client(WsAdapter::connect("ws://127.0.0.1:3333").await.unwrap()).await;
}

#[cfg(not(target_arch = "wasm32"))]
#[tokio::test]
async fn test_ws() {
    let _ = LOGGER.clone();
    tokio::task::LocalSet::new()
        .run_until(async {
            tokio::task::spawn_local(async {
                let listener = tokio::net::TcpListener::bind("127.0.0.1:3333")
                    .await
                    .unwrap();
                server_thread(Arc::new(WsAdapter::accept(&listener).await.unwrap())).await;
            });

            tokio::time::sleep(Duration::from_millis(100)).await;
            test_client(WsAdapter::connect("ws://127.0.0.1:3333").await.unwrap()).await;
        })
        .await;
}

#[cfg(not(target_arch = "wasm32"))]
#[tokio::test]
async fn test_channel() {
    let _ = LOGGER.clone();
    tokio::task::LocalSet::new()
        .run_until(async {
            let (s1, s2) = channel::ChannelAdapter::new();
            tokio::task::spawn_local(async {
                server_thread(Arc::new(s1)).await;
            });

            tokio::time::sleep(Duration::from_millis(100)).await;
            test_client(s2).await;
        })
        .await;
}

#[cfg(not(target_arch = "wasm32"))]
#[tokio::test]
async fn test_ipc() {
    use ezrpc::ipc::{NamedAdapter, NamedAdapterServer};

    let _ = LOGGER.clone();
    tokio::task::LocalSet::new()
        .run_until(async {
            tokio::task::spawn_local(async {
                server_thread(Arc::new(
                    NamedAdapterServer::bind_accept("test_ipc").await.unwrap(),
                ))
                .await;
            });

            tokio::time::sleep(Duration::from_millis(100)).await;
            test_client(NamedAdapter::connect("test_ipc").await.unwrap()).await;
        })
        .await;
}

#[cfg(not(target_arch = "wasm32"))]
#[ignore]
#[tokio::test]
async fn only_server() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3333")
        .await
        .unwrap();
    server_thread(Arc::new(WsAdapter::accept(&listener).await.unwrap())).await;
}