use super::{
PerformRpcData, RpcError, RpcErrorCode, RpcTransport, ATTR_METHOD, ATTR_REQUEST_ID,
ATTR_RESPONSE_TIMEOUT_MS, ATTR_VERSION, MAX_V1_PAYLOAD_BYTES, RPC_REQUEST_TOPIC,
RPC_VERSION_V1, RPC_VERSION_V2,
};
use crate::data_stream::{StreamReader, StreamTextOptions, TextStreamReader};
use crate::room::id::ParticipantIdentity;
use libwebrtc::native::create_random_uuid;
use livekit_api::signal_client::CLIENT_PROTOCOL_DATA_STREAM_RPC;
use livekit_protocol as proto;
use parking_lot::Mutex;
use semver::Version;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::oneshot;
pub struct RpcClientManager {
pending_acks: Mutex<HashMap<String, oneshot::Sender<()>>>,
pending_responses: Mutex<HashMap<String, oneshot::Sender<Result<String, RpcError>>>>,
}
impl RpcClientManager {
pub fn new() -> Self {
Self {
pending_acks: Mutex::new(HashMap::new()),
pending_responses: Mutex::new(HashMap::new()),
}
}
pub(crate) async fn perform_rpc(
&self,
data: PerformRpcData,
transport: &(impl RpcTransport + 'static),
) -> Result<String, RpcError> {
let max_round_trip_latency = Duration::from_millis(7000);
let min_effective_timeout = Duration::from_millis(1000);
if let Some(version_str) = transport.server_version() {
let server_version = Version::parse(&version_str).unwrap();
let min_required_version = Version::parse("1.8.0").unwrap();
if server_version < min_required_version {
return Err(RpcError::built_in(RpcErrorCode::UnsupportedServer, None));
}
}
let remote_protocol = transport
.remote_client_protocol(&ParticipantIdentity(data.destination_identity.clone()));
let use_v2 = remote_protocol >= CLIENT_PROTOCOL_DATA_STREAM_RPC;
if !use_v2 && data.payload.len() > MAX_V1_PAYLOAD_BYTES {
return Err(RpcError::built_in(RpcErrorCode::RequestPayloadTooLarge, None));
}
let id = create_random_uuid();
let (ack_tx, ack_rx) = oneshot::channel();
let (response_tx, response_rx) = oneshot::channel();
let effective_timeout = std::cmp::max(
data.response_timeout.saturating_sub(max_round_trip_latency),
min_effective_timeout,
);
{
let mut pending_acks = self.pending_acks.lock();
let mut pending_responses = self.pending_responses.lock();
pending_acks.insert(id.clone(), ack_tx);
pending_responses.insert(id.clone(), response_tx);
}
let send_result = if use_v2 {
self.send_v2_request(
transport,
&data.destination_identity,
&id,
&data.method,
&data.payload,
effective_timeout,
)
.await
} else {
self.send_v1_request(
transport,
&data.destination_identity,
&id,
&data.method,
&data.payload,
effective_timeout,
)
.await
.map_err(|e| RpcError::built_in(RpcErrorCode::SendFailed, Some(e.to_string())))
};
if let Err(e) = send_result {
let mut pending_acks = self.pending_acks.lock();
let mut pending_responses = self.pending_responses.lock();
pending_acks.remove(&id);
pending_responses.remove(&id);
log::error!("Failed to publish RPC request: {}", e);
return Err(e);
}
match tokio::time::timeout(max_round_trip_latency, ack_rx).await {
Err(_) => {
let mut pending_acks = self.pending_acks.lock();
let mut pending_responses = self.pending_responses.lock();
pending_acks.remove(&id);
pending_responses.remove(&id);
return Err(RpcError::built_in(RpcErrorCode::ConnectionTimeout, None));
}
Ok(_) => {
}
}
let response = match tokio::time::timeout(data.response_timeout, response_rx).await {
Err(_) => {
self.pending_responses.lock().remove(&id);
return Err(RpcError::built_in(RpcErrorCode::ResponseTimeout, None));
}
Ok(result) => result,
};
match response {
Err(_) => {
Err(RpcError::built_in(RpcErrorCode::RecipientDisconnected, None))
}
Ok(Err(e)) => {
Err(e)
}
Ok(Ok(payload)) => {
Ok(payload)
}
}
}
pub(crate) async fn send_v1_request(
&self,
transport: &impl RpcTransport,
destination_identity: &str,
id: &str,
method: &str,
payload: &str,
response_timeout: Duration,
) -> Result<(), crate::room::RoomError> {
let rpc_request_message = proto::RpcRequest {
id: id.to_string(),
method: method.to_string(),
payload: payload.to_string(),
response_timeout_ms: response_timeout.as_millis() as u32,
version: RPC_VERSION_V1,
..Default::default()
};
let data = proto::DataPacket {
value: Some(proto::data_packet::Value::RpcRequest(rpc_request_message)),
destination_identities: vec![destination_identity.to_string()],
..Default::default()
};
transport.publish_data(data).await
}
async fn send_v2_request(
&self,
transport: &impl RpcTransport,
destination_identity: &str,
id: &str,
method: &str,
payload: &str,
response_timeout: Duration,
) -> Result<(), RpcError> {
let mut attributes = HashMap::new();
attributes.insert(ATTR_REQUEST_ID.to_string(), id.to_string());
attributes.insert(ATTR_METHOD.to_string(), method.to_string());
attributes
.insert(ATTR_RESPONSE_TIMEOUT_MS.to_string(), response_timeout.as_millis().to_string());
attributes.insert(ATTR_VERSION.to_string(), RPC_VERSION_V2.to_string());
let options = StreamTextOptions {
topic: RPC_REQUEST_TOPIC.to_string(),
attributes,
destination_identities: vec![ParticipantIdentity(destination_identity.to_string())],
..Default::default()
};
transport
.send_text(payload, options)
.await
.map(|_| ())
.map_err(|e| RpcError::built_in(RpcErrorCode::SendFailed, Some(e.to_string())))
}
#[cfg(test)]
pub(crate) fn drop_pending_response(&self, request_id: &str) {
self.pending_responses.lock().remove(request_id);
}
#[cfg(test)]
pub(crate) fn insert_pending_response(
&self,
request_id: String,
tx: tokio::sync::oneshot::Sender<Result<String, RpcError>>,
) {
self.pending_responses.lock().insert(request_id, tx);
}
pub(crate) fn handle_incoming_rpc_ack(&self, request_id: String) {
let mut pending = self.pending_acks.lock();
if let Some(tx) = pending.remove(&request_id) {
let _ = tx.send(());
} else {
log::error!("Ack received for unexpected RPC request: {}", request_id);
}
}
pub(crate) fn handle_v1_response_packet(
&self,
request_id: String,
payload: Option<String>,
error: Option<proto::RpcError>,
) {
let mut pending = self.pending_responses.lock();
if let Some(tx) = pending.remove(&request_id) {
let _ = tx.send(match error {
Some(e) => Err(RpcError::from_proto(e)),
None => Ok(payload.unwrap_or_default()),
});
} else {
log::error!("Response received for unexpected RPC request: {}", request_id);
}
}
pub(crate) async fn handle_v2_response_stream(&self, reader: TextStreamReader) {
let request_id = reader.info().attributes.get(ATTR_REQUEST_ID).cloned().unwrap_or_default();
if request_id.is_empty() {
log::error!("RPC v2 response stream missing request_id attribute");
return;
}
let payload = match reader.read_all().await {
Ok(payload) => payload,
Err(e) => {
log::error!("Failed to read RPC v2 response stream: {:?}", e);
let mut pending = self.pending_responses.lock();
if let Some(tx) = pending.remove(&request_id) {
let _ = tx.send(Err(RpcError::built_in(
RpcErrorCode::ApplicationError,
Some(format!("Failed to read response stream: {}", e)),
)));
}
return;
}
};
let mut pending = self.pending_responses.lock();
if let Some(tx) = pending.remove(&request_id) {
let _ = tx.send(Ok(payload));
} else {
log::error!("Response stream received for unexpected RPC request: {}", request_id);
}
}
}