use std::error::Error as StdError;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use rapace::transport::shm::{ShmSession, ShmSessionConfig, ShmTransport};
use rapace::{Frame, RpcError, RpcSession, TransportError};
/// SHM session parameters used by [`run`] and [`run_multi`] when the caller
/// does not supply its own [`ShmSessionConfig`].
///
/// NOTE(review): the exact meaning/units of these fields are defined by
/// `rapace`; `slot_size` is presumably in bytes — confirm against the
/// `ShmSessionConfig` documentation.
pub const DEFAULT_SHM_CONFIG: ShmSessionConfig = ShmSessionConfig {
    ring_capacity: 256,
    slot_size: 64 * 1024,
    slot_count: 128,
};

/// Channel id passed to `RpcSession::with_channel_start` when the plugin
/// builds its session.
// NOTE(review): presumably chosen so ids allocated on the plugin side do not
// collide with ids allocated by the host — confirm against the host side.
const PLUGIN_CHANNEL_START: u32 = 2;
/// Errors returned by the plugin bootstrap and `run*` functions in this module.
#[derive(Debug)]
pub enum PluginError {
    /// The command line did not yield an SHM path; holds a description.
    Args(String),
    /// The SHM file at this path did not appear before the wait timed out.
    ShmTimeout(PathBuf),
    /// Opening the SHM file failed; holds the `Debug` rendering of the cause.
    ShmOpen(String),
    /// An RPC-layer failure, usually lifted here via `?`.
    Rpc(RpcError),
    /// A transport-layer failure, usually lifted here via `?`.
    Transport(TransportError),
}
impl std::fmt::Display for PluginError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Args(msg) => write!(f, "Argument error: {}", msg),
Self::ShmTimeout(path) => write!(f, "SHM file not created by host: {}", path.display()),
Self::ShmOpen(msg) => write!(f, "Failed to open SHM: {}", msg),
Self::Rpc(e) => write!(f, "RPC error: {:?}", e),
Self::Transport(e) => write!(f, "Transport error: {:?}", e),
}
}
}
impl StdError for PluginError {}
impl From<RpcError> for PluginError {
fn from(e: RpcError) -> Self {
Self::Rpc(e)
}
}
impl From<TransportError> for PluginError {
fn from(e: TransportError) -> Self {
Self::Transport(e)
}
}
/// A service that answers RPC requests addressed to it by method id.
///
/// The trait is object-safe: [`DispatcherBuilder`] stores implementors as
/// `Box<dyn ServiceDispatch>`, which is why `dispatch` returns a boxed future.
pub trait ServiceDispatch: Send + Sync + 'static {
    /// Handles a single request.
    ///
    /// `payload` is only borrowed for the duration of this call. The returned
    /// future is `'static`, so it cannot borrow `payload`; copy or decode what
    /// it needs before returning.
    ///
    /// When several services are combined through [`DispatcherBuilder`],
    /// returning an `RpcError::Status` with code `ErrorCode::Unimplemented`
    /// makes the dispatcher try the next registered service.
    fn dispatch(
        &self,
        method_id: u32,
        payload: &[u8],
    ) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send + 'static>>;
}
/// Collects several [`ServiceDispatch`] implementations so that a single
/// plugin session can serve all of them.
pub struct DispatcherBuilder {
    /// Registered services, in the order they were added; the built
    /// dispatcher consults them in this order.
    services: Vec<Box<dyn ServiceDispatch>>,
}
impl DispatcherBuilder {
pub fn new() -> Self {
Self {
services: Vec::new(),
}
}
pub fn add_service<S>(mut self, service: S) -> Self
where
S: ServiceDispatch,
{
self.services.push(Box::new(service));
self
}
#[allow(clippy::type_complexity)]
pub fn build(
self,
) -> impl Fn(u32, u32, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>>
+ Send
+ Sync
+ 'static {
let services = Arc::new(self.services);
move |_channel_id, method_id, payload| {
let services = services.clone();
Box::pin(async move {
for service in services.iter() {
let result = service.dispatch(method_id, &payload).await;
if !matches!(
&result,
Err(RpcError::Status {
code: rapace::ErrorCode::Unimplemented,
..
})
) {
return result;
}
}
Err(RpcError::Status {
code: rapace::ErrorCode::Unimplemented,
message: format!("Unknown method_id: {}", method_id),
})
})
}
}
}
impl Default for DispatcherBuilder {
fn default() -> Self {
Self::new()
}
}
/// Extracts the SHM file path from the process command line.
///
/// See [`parse_args_from`] for the accepted forms.
///
/// # Errors
///
/// Returns [`PluginError::Args`] when no usable path is supplied.
fn parse_args() -> Result<PathBuf, PluginError> {
    parse_args_from(std::env::args().skip(1))
}

/// Parses the SHM path from `args`, which must exclude the program name.
///
/// The first argument that is either `--shm-path=PATH` or a positional
/// (non-`--`) argument wins. Any other `--flag` argument is skipped.
///
/// # Errors
///
/// Returns [`PluginError::Args`] if `--shm-path=` is given with an empty value.
/// Returns it as well if the arguments contain no path at all.
fn parse_args_from<I>(args: I) -> Result<PathBuf, PluginError>
where
    I: IntoIterator<Item = String>,
{
    for arg in args {
        if let Some(value) = arg.strip_prefix("--shm-path=") {
            // An empty path can never exist, so waiting for it would only time
            // out later with a misleading "not created by host" error.
            if value.is_empty() {
                return Err(PluginError::Args(
                    "Empty value for --shm-path (expected --shm-path=PATH)".to_string(),
                ));
            }
            return Ok(PathBuf::from(value));
        }
        if !arg.starts_with("--") {
            return Ok(PathBuf::from(arg));
        }
    }
    Err(PluginError::Args(
        "Missing SHM path (use --shm-path=PATH or provide as first argument)".to_string(),
    ))
}
/// Polls every 100 ms until `path` exists or `timeout_ms` milliseconds elapse.
///
/// The path is always checked at least once, even when `timeout_ms` is shorter
/// than the polling interval. The final sleep is shortened so the last check
/// lands on the deadline rather than past it.
///
/// # Errors
///
/// Returns [`PluginError::ShmTimeout`] if the file still does not exist once
/// the timeout has elapsed.
async fn wait_for_shm(path: &std::path::Path, timeout_ms: u64) -> Result<(), PluginError> {
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(100);

    let timeout = std::time::Duration::from_millis(timeout_ms);
    let start = tokio::time::Instant::now();
    loop {
        if path.exists() {
            return Ok(());
        }
        let elapsed = start.elapsed();
        if elapsed >= timeout {
            return Err(PluginError::ShmTimeout(path.to_path_buf()));
        }
        // `elapsed < timeout` here, so the subtraction cannot underflow.
        tokio::time::sleep(POLL_INTERVAL.min(timeout - elapsed)).await;
    }
}
/// Shared connection bootstrap for every `run*` entry point.
///
/// The steps are:
/// 1. Read the SHM path from the command line.
/// 2. Wait up to 5000 ms for the file to appear.
/// 3. Open it with `config`.
/// 4. Wrap it in an [`RpcSession`] built with
///    `RpcSession::with_channel_start(.., PLUGIN_CHANNEL_START)`.
///
/// Returns the session together with the path that was opened.
async fn setup_plugin(
    config: ShmSessionConfig,
) -> Result<(Arc<RpcSession<ShmTransport>>, PathBuf), PluginError> {
    let path = parse_args()?;
    wait_for_shm(&path, 5000).await?;

    let opened = match ShmSession::open_file(&path, config) {
        Ok(shm) => shm,
        Err(cause) => return Err(PluginError::ShmOpen(format!("{:?}", cause))),
    };
    let session = RpcSession::with_channel_start(
        Arc::new(ShmTransport::new(opened)),
        PLUGIN_CHANNEL_START,
    );
    Ok((Arc::new(session), path))
}
/// Connects to the host with [`DEFAULT_SHM_CONFIG`] and serves `service`.
///
/// This is shorthand for `run_with_config(service, DEFAULT_SHM_CONFIG)`.
///
/// # Errors
///
/// Propagates any [`PluginError`] from connecting or from running the session.
pub async fn run<S>(service: S) -> Result<(), PluginError>
where
    S: ServiceDispatch,
{
    let config = DEFAULT_SHM_CONFIG;
    run_with_config(service, config).await
}
/// Connects to the host using `config` and serves a single `service`.
///
/// The SHM path is taken from the command line. The call then awaits the
/// session's `run()` and returns once it does.
///
/// # Errors
///
/// Returns a [`PluginError`] in either of these cases:
/// - argument parsing, waiting for the SHM file, or opening it fails;
/// - the session's `run()` reports an error.
pub async fn run_with_config<S>(service: S, config: ShmSessionConfig) -> Result<(), PluginError>
where
    S: ServiceDispatch,
{
    let (session, shm_path) = setup_plugin(config).await?;
    tracing::info!("Connected to host via SHM: {}", shm_path.display());
    // Reuse the shared single-service adapter instead of duplicating its
    // dispatcher closure here, so the two cannot drift apart.
    session.set_service(service);
    session.run().await?;
    Ok(())
}
/// Connects with [`DEFAULT_SHM_CONFIG`] and serves every service that
/// `builder_fn` registers on the provided [`DispatcherBuilder`].
///
/// This is shorthand for `run_multi_with_config(builder_fn, DEFAULT_SHM_CONFIG)`.
///
/// # Errors
///
/// Propagates any [`PluginError`] from connecting or from running the session.
pub async fn run_multi<F>(builder_fn: F) -> Result<(), PluginError>
where
    F: FnOnce(DispatcherBuilder) -> DispatcherBuilder,
{
    let config = DEFAULT_SHM_CONFIG;
    run_multi_with_config(builder_fn, config).await
}
/// Connects using `config` and serves multiple services through one session.
///
/// The connection is established first. After that, `builder_fn` receives an
/// empty [`DispatcherBuilder`] and the builder it returns is turned into the
/// session's dispatcher. The call then awaits the session's `run()`.
///
/// # Errors
///
/// Returns a [`PluginError`] in either of these cases:
/// - the connection bootstrap fails;
/// - the session's `run()` reports an error.
pub async fn run_multi_with_config<F>(
    builder_fn: F,
    config: ShmSessionConfig,
) -> Result<(), PluginError>
where
    F: FnOnce(DispatcherBuilder) -> DispatcherBuilder,
{
    let (session, path) = setup_plugin(config).await?;
    tracing::info!("Connected to host via SHM: {}", path.display());

    let configured = builder_fn(DispatcherBuilder::new());
    session.set_dispatcher(configured.build());

    session.run().await?;
    Ok(())
}
/// Extension methods that let an [`RpcSession`] be driven by a
/// [`ServiceDispatch`] implementation instead of a raw dispatcher closure.
pub trait RpcSessionExt<T> {
    /// Installs `service` as the session's request dispatcher.
    fn set_service<S>(&self, service: S)
    where
        S: ServiceDispatch;
}
impl<T> RpcSessionExt<T> for RpcSession<T>
where
    T: rapace::Transport + 'static,
{
    /// Shares `service` behind an `Arc` and registers a dispatcher with the
    /// session.
    ///
    /// The dispatcher forwards each request's method id and payload to
    /// `service.dispatch`; the channel id is ignored.
    fn set_service<S>(&self, service: S)
    where
        S: ServiceDispatch,
    {
        let shared: Arc<S> = Arc::new(service);
        self.set_dispatcher(
            move |_channel_id: u32,
                  method_id: u32,
                  payload: Vec<u8>|
                  -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>> {
                let handler = Arc::clone(&shared);
                Box::pin(async move { handler.dispatch(method_id, &payload).await })
            },
        );
    }
}