use embedded_rpc::{RequestDroppedError, RpcService};
use embassy_sync::blocking_mutex::raw::ThreadModeRawMutex;
use std::future::Future;
use std::sync::Arc;
use tokio::join;
use tokio::runtime::Builder;
use tokio::sync::oneshot;
use tokio::time::{Duration, timeout};
/// Runs `f` to completion on a fresh current-thread Tokio runtime hosted by a newly
/// spawned OS thread named `"main"`, blocking the caller until it finishes.
///
/// The tests in this file build `RpcService` with `ThreadModeRawMutex`; the thread is
/// given the name `"main"` for that mutex. NOTE(review): presumably the host (std) build
/// of `ThreadModeRawMutex` only allows locking from a thread named "main" — confirm
/// against embassy-sync.
///
/// # Panics
///
/// Panics if the thread cannot be spawned or the runtime cannot be built. If `f` panics
/// (for example a failed assertion), that panic is re-raised on the calling thread with
/// its original payload, so the test harness reports the real failure message.
fn run_on_main_named_thread<F>(f: F)
where
    F: Future<Output = ()> + Send + 'static,
{
    let handle = std::thread::Builder::new()
        .name("main".to_string())
        .spawn(move || {
            let rt = Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("tokio runtime");
            rt.block_on(f);
        })
        .expect("spawn");
    if let Err(payload) = handle.join() {
        // Re-raise the child's panic unchanged. `expect` would have wrapped it in an
        // opaque `Any { .. }` debug string and hidden the assertion message.
        std::panic::resume_unwind(payload);
    }
}
/// Request payload that is a mutable borrow of caller-owned memory, letting a test check
/// that the server can write into a buffer that lives on the client's side.
struct BufferRequest<'buf> {
    /// Slice the server writes into; borrowed mutably from the client.
    buffer: &'buf mut [u8],
}
/// The server receives a `&mut` slice borrowed from the client's local array and writes
/// through it. Once the call completes, the client observes those bytes in its own array.
#[test]
fn server_writes_through_client_buffer_slice() {
    run_on_main_named_thread(async {
        let mut scratch = [0u8; 8];
        // Inner scope: the service and both futures borrow `scratch`, so they must be
        // dropped before the assertions below read it.
        let outcome = {
            let rpc = RpcService::<ThreadModeRawMutex, BufferRequest<'_>, ()>::new();
            let serve_side = async {
                let (incoming, reply) = rpc.serve().await;
                let (head, tail) = incoming.buffer.split_at_mut(2);
                head.copy_from_slice(&[0xab, 0xcd]);
                tail.fill(0x7e);
                reply.respond(());
            };
            let request_side = async {
                rpc.request(BufferRequest { buffer: &mut scratch }).await
            };
            let ((), result) = join!(serve_side, request_side);
            result
        };
        let mut expected = [0x7e_u8; 8];
        expected[..2].copy_from_slice(&[0xab, 0xcd]);
        assert_eq!(scratch, expected);
        assert_eq!(outcome, Ok(()));
    });
}
/// A single request/response round trip. The server must observe the client's payload,
/// and the client must receive the server's reply.
#[test]
fn round_trip_response_is_delivered() {
    run_on_main_named_thread(async {
        let service = Arc::new(RpcService::<ThreadModeRawMutex, u32, u32>::new());
        let server = {
            let service = Arc::clone(&service);
            async move {
                let (req, served) = service.serve().await;
                // Verify the request payload crossed over, not just the response.
                assert_eq!(req, 5);
                served.respond(10);
            }
        };
        let client = {
            let service = Arc::clone(&service);
            async move { service.request(5).await }
        };
        // Bound the test so a lost wake-up fails it instead of hanging the test run.
        let (_, client_result) = timeout(Duration::from_secs(1), async { join!(server, client) })
            .await
            .expect("operation timed out");
        assert_eq!(client_result, Ok(10));
    });
}
/// If the server drops the response handle without calling `respond`, the pending
/// client call must resolve to `Err(RequestDroppedError)`.
#[test]
fn dropped_request_returns_error() {
    run_on_main_named_thread(async {
        let service = Arc::new(RpcService::<ThreadModeRawMutex, u32, u32>::new());
        let server = {
            let service = Arc::clone(&service);
            async move {
                // The handle is never used; it is dropped without a response when this
                // async block finishes.
                let (_, _served) = service.serve().await;
            }
        };
        let client = {
            let service = Arc::clone(&service);
            async move { service.request(42).await }
        };
        // This test exercises the drop wake-up path. Without a timeout, a missed
        // wake-up would hang the test run instead of failing this test.
        let (_, client_result) = timeout(Duration::from_secs(1), async { join!(server, client) })
            .await
            .expect("operation timed out");
        assert_eq!(client_result, Err(RequestDroppedError));
    });
}
/// The server may start waiting in `serve()` before any request exists. A request issued
/// about 10 ms later must still be picked up and answered.
#[test]
fn server_can_wait_before_client_requests() {
    run_on_main_named_thread(async {
        let shared = Arc::new(RpcService::<ThreadModeRawMutex, u32, u32>::new());

        let server_rpc = Arc::clone(&shared);
        let server = async move {
            let (_, responder) = server_rpc.serve().await;
            responder.respond(123);
        };

        let client_rpc = Arc::clone(&shared);
        let client = async move {
            // Delay the request so the server is (most likely) already waiting in `serve()`.
            tokio::time::sleep(Duration::from_millis(10)).await;
            client_rpc.request(1).await
        };

        let both = async {
            let ((), reply) = join!(server, client);
            reply
        };
        let reply = timeout(Duration::from_secs(1), both)
            .await
            .expect("operation timed out");
        assert_eq!(reply, Ok(123));
    });
}
/// Two clients issue requests concurrently against one server that answers twice. Each
/// client must get the reply computed from its own request (`req + 1`).
///
/// NOTE(review): the assertions check that both calls complete with the right replies;
/// they do not observe the order in which the server handled them.
#[test]
fn concurrent_clients_are_serialized_and_complete() {
    run_on_main_named_thread(async {
        let service = Arc::new(RpcService::<ThreadModeRawMutex, u32, u32>::new());

        let server = {
            let rpc = Arc::clone(&service);
            async move {
                for _round in 1..=2u8 {
                    let (value, reply) = rpc.serve().await;
                    reply.respond(value + 1);
                }
            }
        };

        let make_client = |value: u32| {
            let rpc = Arc::clone(&service);
            async move { rpc.request(value).await }
        };
        let first = make_client(10);
        let second = make_client(20);

        let outcome = timeout(Duration::from_secs(1), async {
            let ((), a, b) = join!(server, first, second);
            (a, b)
        })
        .await
        .expect("operation timed out");
        assert_eq!(outcome, (Ok(11), Ok(21)));
    });
}
/// Aborts a client whose request is still queued, before any server has taken it.
/// A later client must then be served normally.
#[test]
fn cancelled_client_before_server_takes_request_releases_slot() {
    run_on_main_named_thread(async {
        let service = Arc::new(RpcService::<ThreadModeRawMutex, u32, u32>::new());
        // Futures are lazy: this server is not polled until the final `join!`, so
        // nothing can answer the first client's request.
        let server = {
            let service = Arc::clone(&service);
            async move {
                let (n, served) = service.serve().await;
                served.respond(n.saturating_add(10));
            }
        };
        let client_cancelled = {
            let service = Arc::clone(&service);
            tokio::spawn(async move { service.request(1).await })
        };
        // Give the spawned client a chance to run and submit its request before aborting.
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;
        client_cancelled.abort();
        // Request `1` cannot be answered (the server is not running yet), so the task
        // must end by cancellation. Check this rather than discarding the join result,
        // so a completion or panic is not silently ignored.
        let join_err = client_cancelled
            .await
            .expect_err("cancelled client unexpectedly completed");
        assert!(
            join_err.is_cancelled(),
            "client task ended by panic, not cancellation"
        );
        let client_ok = {
            let service = Arc::clone(&service);
            async move { service.request(2).await }
        };
        // 12 = 2 + 10, i.e. the server handled the second client's request.
        let (_, r) = timeout(Duration::from_secs(1), async { join!(server, client_ok) })
            .await
            .expect("operation timed out");
        assert_eq!(r, Ok(12));
    });
}
/// Cancels a client after the server has already taken its request. Then checks that a
/// second client is still served once the server has responded to the abandoned call.
///
/// Sequence:
/// 1. The spawned server task takes the first request and signals the test body via `tx`.
/// 2. The test aborts the waiting client.
/// 3. After a short sleep, the server responds `99` to the cancelled call.
/// 4. The server calls `serve()` again and answers the next client with `7`.
#[test]
fn cancelled_client_after_server_takes_request_releases_slot_on_respond() {
    run_on_main_named_thread(async {
        let service = Arc::new(RpcService::<ThreadModeRawMutex, u32, u32>::new());
        // One-shot signal: "the server has taken the first request".
        let (tx, rx) = oneshot::channel();
        let svc_server = Arc::clone(&service);
        let server = tokio::spawn(async move {
            let (_, served) = svc_server.serve().await;
            tx.send(()).expect("signal client");
            // Presumably gives the test body time to abort the client while this
            // request is still unanswered.
            tokio::time::sleep(Duration::from_millis(20)).await;
            // Respond to a client that has been aborted by now.
            served.respond(99);
            // The next request (from `client_ok` below) must still reach the server.
            let (_, served2) = svc_server.serve().await;
            served2.respond(7);
        });
        let client_cancelled = {
            let service = Arc::clone(&service);
            tokio::spawn(async move { service.request(1).await })
        };
        // Wait until the server reports it has taken the first request.
        timeout(Duration::from_secs(1), rx)
            .await
            .expect("timed out waiting for server")
            .expect("channel closed");
        client_cancelled.abort();
        // Wait for the aborted task to finish; its JoinError is expected and ignored.
        let _ = client_cancelled.await;
        let client_ok = {
            let service = Arc::clone(&service);
            async move { service.request(2).await }
        };
        let (_, r) = timeout(Duration::from_secs(1), async {
            join!(
                async move { server.await.expect("server task panicked") },
                client_ok
            )
        })
        .await
        .expect("operation timed out");
        // `7` comes from the second `serve()`, not the `99` sent to the cancelled client.
        assert_eq!(r, Ok(7));
    });
}
/// Cancels a client after the server has taken its request. The server then drops the
/// response handle without ever responding. A fresh serve/request pair must still
/// complete afterwards.
#[test]
fn cancelled_client_after_server_takes_request_releases_slot_on_server_drop() {
    run_on_main_named_thread(async {
        let service = Arc::new(RpcService::<ThreadModeRawMutex, u32, u32>::new());
        // One-shot signal: "the server has taken the first request".
        let (tx, rx) = oneshot::channel();
        let server_drop = {
            let service = Arc::clone(&service);
            async move {
                // Bind to `_served`, not `_`. A `_` pattern would drop the handle at the
                // end of this statement; this binding keeps it alive through the sleep.
                // It is then dropped, unanswered, when this async block finishes.
                let (_, _served) = service.serve().await;
                tx.send(()).expect("signal client");
                tokio::time::sleep(Duration::from_millis(20)).await;
            }
        };
        let client_cancelled = {
            let service = Arc::clone(&service);
            tokio::spawn(async move { service.request(1).await })
        };
        // Spawned after the client task. Neither spawned task makes progress until the
        // test body yields at the `rx` await below.
        let server_handle = tokio::spawn(server_drop);
        timeout(Duration::from_secs(1), rx)
            .await
            .expect("timed out waiting for server")
            .expect("channel closed");
        client_cancelled.abort();
        // Wait for the aborted task to finish; its JoinError is expected and ignored.
        let _ = client_cancelled.await;
        // Wait until the server task has finished, i.e. its `_served` handle has been
        // dropped without a response, before starting the second RPC.
        timeout(Duration::from_secs(1), server_handle)
            .await
            .expect("server_drop timed out")
            .expect("server task panicked");
        let svc = Arc::clone(&service);
        let (_, r) = timeout(Duration::from_secs(1), async {
            join!(
                async move {
                    let (_, served) = svc.serve().await;
                    served.respond(42);
                },
                async move { service.request(2).await }
            )
        })
        .await
        .expect("second RPC timed out");
        assert_eq!(r, Ok(42));
    });
}