use super::{
RpcError, RpcErrorCode, RpcInvocationData, RpcTransport, ATTR_METHOD, ATTR_REQUEST_ID,
ATTR_RESPONSE_TIMEOUT_MS, ATTR_VERSION, MAX_V1_PAYLOAD_BYTES, RPC_RESPONSE_TOPIC,
RPC_VERSION_V1, RPC_VERSION_V2,
};
use crate::data_stream::{StreamReader, StreamTextOptions, TextStreamReader};
use crate::room::id::ParticipantIdentity;
use livekit_protocol as proto;
use parking_lot::Mutex;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, time::Duration};
/// Shared, type-erased RPC method handler as stored in the server registry.
///
/// Invoking it with an [`RpcInvocationData`] yields a boxed `Send` future that
/// resolves to either the response payload or an [`RpcError`]. The `Arc`
/// wrapper lets a handler be cloned out of the registry cheaply.
pub(crate) type RpcHandlerFn = Arc<
    dyn Fn(RpcInvocationData) -> Pin<Box<dyn Future<Output = Result<String, RpcError>> + Send>>
        + Sync
        + Send
        + 'static,
>;
/// Everything needed to answer a single RPC request that arrived inline
/// (consumed by `RpcServerManager::handle_v1_request`).
pub struct HandleRequestOptions {
    /// Protocol version of the request; `handle_v1_request` answers anything
    /// other than `RPC_VERSION_V1` with an `UnsupportedVersion` error.
    pub version: u32,
    /// Request id echoed back in both the ACK and the response packet.
    pub request_id: String,
    /// Participant that issued the request; the ACK and response are
    /// addressed to this identity.
    pub caller_identity: ParticipantIdentity,
    /// Name of the registered method to invoke.
    pub method: String,
    /// Request payload handed to the method handler.
    pub payload: String,
    /// Timeout forwarded to the handler through `RpcInvocationData`.
    pub response_timeout: Duration,
}
/// Registry of RPC method handlers, together with the logic (in its `impl`)
/// that acknowledges and answers incoming RPC requests.
///
/// `Default` produces the same empty registry as `RpcServerManager::new`,
/// which keeps clippy's `new_without_default` lint satisfied.
#[derive(Default)]
pub struct RpcServerManager {
    /// Handlers keyed by method name. The lock is only taken for individual
    /// map operations (insert / remove / lookup-and-clone).
    handlers: Mutex<HashMap<String, RpcHandlerFn>>,
}
impl RpcServerManager {
pub fn new() -> Self {
Self { handlers: Mutex::new(HashMap::new()) }
}
pub fn register_method(
&self,
method: String,
handler: impl Fn(RpcInvocationData) -> Pin<Box<dyn Future<Output = Result<String, RpcError>> + Send>>
+ Send
+ Sync
+ 'static,
) {
self.handlers.lock().insert(method, Arc::new(handler));
}
pub fn unregister_method(&self, method: &str) {
self.handlers.lock().remove(method);
}
pub(crate) fn get_handler(&self, method: &str) -> Option<RpcHandlerFn> {
self.handlers.lock().get(method).cloned()
}
pub(crate) async fn handle_v1_request(
&self,
options: HandleRequestOptions,
transport: &(impl RpcTransport + 'static),
) {
let HandleRequestOptions {
caller_identity,
request_id,
method,
payload,
response_timeout,
version,
} = options;
if let Err(e) = self.publish_rpc_ack(transport, &caller_identity.0, &request_id).await {
log::error!("Failed to publish RPC ACK: {:?}", e);
}
let response = if version != RPC_VERSION_V1 {
Err(RpcError::built_in(RpcErrorCode::UnsupportedVersion, None))
} else {
self.invoke_handler(&caller_identity, &request_id, &method, &payload, response_timeout)
.await
};
let (resp_payload, error) = match response {
Ok(response_payload) if response_payload.len() <= MAX_V1_PAYLOAD_BYTES => {
(Some(response_payload), None)
}
Ok(_) => (
None,
Some(RpcError::built_in(RpcErrorCode::ResponsePayloadTooLarge, None).to_proto()),
),
Err(e) => (None, Some(e.to_proto())),
};
if let Err(e) = self
.publish_rpc_response_packet(
transport,
&caller_identity.0,
&request_id,
resp_payload,
error,
)
.await
{
log::error!("Failed to publish RPC response: {:?}", e);
}
}
pub(crate) async fn handle_v2_request_stream(
&self,
reader: TextStreamReader,
caller_identity: ParticipantIdentity,
transport: &(impl RpcTransport + 'static),
) {
let attrs = &reader.info().attributes;
let request_id = attrs.get(ATTR_REQUEST_ID).cloned().unwrap_or_default();
let method = attrs.get(ATTR_METHOD).cloned().unwrap_or_default();
let response_timeout_ms: u64 =
attrs.get(ATTR_RESPONSE_TIMEOUT_MS).and_then(|v| v.parse().ok()).unwrap_or(15000);
let version: u32 = attrs.get(ATTR_VERSION).and_then(|v| v.parse().ok()).unwrap_or(0);
let response_timeout = Duration::from_millis(response_timeout_ms);
if let Err(e) = self.publish_rpc_ack(transport, &caller_identity.0, &request_id).await {
log::error!("Failed to publish RPC ACK: {:?}", e);
}
if version != RPC_VERSION_V2 {
let error = RpcError::built_in(RpcErrorCode::UnsupportedVersion, None);
let _ = self
.publish_rpc_response_packet(
transport,
&caller_identity.0,
&request_id,
None,
Some(error.to_proto()),
)
.await;
return;
}
let payload = match reader.read_all().await {
Ok(payload) => payload,
Err(e) => {
log::error!("Failed to read RPC v2 request stream: {:?}", e);
let error = RpcError::built_in(
RpcErrorCode::ApplicationError,
Some(format!("Failed to read request stream: {}", e)),
);
let _ = self
.publish_rpc_response_packet(
transport,
&caller_identity.0,
&request_id,
None,
Some(error.to_proto()),
)
.await;
return;
}
};
let response = self
.invoke_handler(&caller_identity, &request_id, &method, &payload, response_timeout)
.await;
match response {
Ok(response_payload) => {
let mut attributes = HashMap::new();
attributes.insert(ATTR_REQUEST_ID.to_string(), request_id.clone());
let options = StreamTextOptions {
topic: RPC_RESPONSE_TOPIC.to_string(),
attributes,
destination_identities: vec![caller_identity.clone()],
..Default::default()
};
if let Err(e) = transport.send_text(&response_payload, options).await {
log::error!("Failed to send RPC v2 response stream: {:?}", e);
let error = RpcError::built_in(RpcErrorCode::SendFailed, Some(e.to_string()));
let _ = self
.publish_rpc_response_packet(
transport,
&caller_identity.0,
&request_id,
None,
Some(error.to_proto()),
)
.await;
}
}
Err(e) => {
if let Err(send_err) = self
.publish_rpc_response_packet(
transport,
&caller_identity.0,
&request_id,
None,
Some(e.to_proto()),
)
.await
{
log::error!("Failed to publish RPC error response: {:?}", send_err);
}
}
}
}
async fn invoke_handler(
&self,
caller_identity: &ParticipantIdentity,
request_id: &str,
method: &str,
payload: &str,
response_timeout: Duration,
) -> Result<String, RpcError> {
let handler = self.get_handler(method);
match handler {
Some(handler) => {
let caller_id = caller_identity.clone();
let req_id = request_id.to_string();
let req_payload = payload.to_string();
match tokio::task::spawn(async move {
handler(RpcInvocationData {
request_id: req_id,
caller_identity: caller_id,
payload: req_payload,
response_timeout,
})
.await
})
.await
{
Ok(result) => result,
Err(e) => {
log::error!("RPC method handler returned an error: {:?}", e);
Err(RpcError::built_in(RpcErrorCode::ApplicationError, None))
}
}
}
None => Err(RpcError::built_in(RpcErrorCode::UnsupportedMethod, None)),
}
}
async fn publish_rpc_response_packet(
&self,
transport: &impl RpcTransport,
destination_identity: &str,
request_id: &str,
payload: Option<String>,
error: Option<proto::RpcError>,
) -> Result<(), crate::room::RoomError> {
let rpc_response_message = proto::RpcResponse {
request_id: request_id.to_string(),
value: Some(match error {
Some(error) => proto::rpc_response::Value::Error(error),
None => proto::rpc_response::Value::Payload(payload.unwrap()),
}),
..Default::default()
};
let data = proto::DataPacket {
value: Some(proto::data_packet::Value::RpcResponse(rpc_response_message)),
destination_identities: vec![destination_identity.to_string()],
..Default::default()
};
transport.publish_data(data).await
}
async fn publish_rpc_ack(
&self,
transport: &impl RpcTransport,
destination_identity: &str,
request_id: &str,
) -> Result<(), crate::room::RoomError> {
let rpc_ack_message =
proto::RpcAck { request_id: request_id.to_string(), ..Default::default() };
let data = proto::DataPacket {
value: Some(proto::data_packet::Value::RpcAck(rpc_ack_message)),
destination_identities: vec![destination_identity.to_string()],
..Default::default()
};
transport.publish_data(data).await
}
}