use rlmesh_proto::{
CURRENT_WORKFLOW_EDITION_SPEC_SHA256, CURRENT_WORKFLOW_EDITION_STATUS, PROTOCOL_GENERATION,
SUPPORTED_PROTOCOL_GENERATIONS, capabilities, capability_map, check_provisional_edition_pin,
core::v1::OperationTelemetry,
is_protocol_generation_supported,
model::v1::{
CloseRequest, CloseRouteRequest, ConfigureRouteRequest, HandshakeRequest, JoinRequest,
JoinResponse, PredictRequest, PredictResponse, ShutdownRequest, ShutdownResponse,
join_request, join_response, model_service_client::ModelServiceClient,
},
supported_workflow_editions,
};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
use crate::error::{Error as GrpcError, ProtocolError, TransportError};
use crate::helpers::normalize_tcp_session_address;
use crate::states::ClientState;
use super::stream::{PendingResponses, spawn_response_pump};
use super::validation::{decode_error, route_request_id, validate_predict_route, validate_route};
use super::wire::{join_request_kind_name, model_error_to_grpc_error};
/// gRPC client for the model service.
///
/// After [`ModelClient::handshake`] succeeds, requests are written to a single
/// join stream and each response is matched back to its caller through the
/// `pending` map, keyed by request id.
pub struct ModelClient {
/// Normalized session address, as returned by `normalize_tcp_session_address`.
address: String,
/// Generated tonic client, used for the unary calls and to open the join stream.
client: ModelServiceClient<tonic::transport::Channel>,
/// Value sent as the `authorization` metadata entry; nothing is attached when empty.
token: String,
/// Lifecycle state: `Connected` after `connect`, `Ready` after `handshake`,
/// `Closed` after a close response or an accepted shutdown.
state: ClientState,
/// Sending half of the channel that feeds the join request stream; `None`
/// until the handshake opens the stream and again after an accepted shutdown.
request_tx: Option<mpsc::Sender<JoinRequest>>,
/// Waiters for in-flight requests, shared with the response pump task.
pending: PendingResponses,
/// Counter backing the generated `model-grpc-req-N` request ids.
request_counter: Arc<AtomicU64>,
/// Telemetry from the most recent response handled by a `&mut self` operation.
last_telemetry: Option<OperationTelemetry>,
/// Capabilities reported by the server in its handshake response.
server_capabilities: HashMap<String, String>,
}
impl ModelClient {
/// Opens a transport channel to `address` and returns a client in the
/// `Connected` state.
///
/// The address is normalized first; its first `tcp://` occurrence is then
/// rewritten to `http://` before the endpoint is built. No handshake is
/// performed here.
///
/// # Errors
///
/// Returns `TransportError::InvalidAddress` when the endpoint cannot be built
/// from the address and `TransportError::ConnectFailed` when the connection
/// attempt fails. Errors from address normalization are propagated.
pub async fn connect(address: &str, token: &str) -> Result<Self, GrpcError> {
    let address = normalize_tcp_session_address(address)?;

    let uri = address.replacen("tcp://", "http://", 1);
    let endpoint = tonic::transport::Endpoint::from_shared(uri)
        .map_err(|err| TransportError::InvalidAddress(err.to_string()))?;
    let channel = crate::configure_endpoint(endpoint)
        .connect()
        .await
        .map_err(|err| TransportError::ConnectFailed(err.to_string()))?;

    // Apply the crate-wide message size limit in both directions.
    let client = ModelServiceClient::new(channel)
        .max_decoding_message_size(crate::MAX_MESSAGE_SIZE)
        .max_encoding_message_size(crate::MAX_MESSAGE_SIZE);

    Ok(Self {
        address,
        client,
        token: token.to_owned(),
        state: ClientState::Connected,
        request_tx: None,
        pending: Default::default(),
        request_counter: Arc::new(AtomicU64::new(0)),
        last_telemetry: None,
        server_capabilities: HashMap::new(),
    })
}
/// Like [`ModelClient::connect`], but delegates retrying to
/// `crate::connect::retry_connect`, configured by `options`.
///
/// # Errors
///
/// Returns whatever error `retry_connect` reports.
pub async fn connect_with_retry(
    address: &str,
    token: &str,
    options: &crate::connect::ConnectOptions,
) -> Result<Self, GrpcError> {
    // Each invocation builds a fresh connection attempt against the same target.
    let attempt = || Self::connect(address, token);
    crate::connect::retry_connect(options, attempt).await
}
/// Returns the normalized address this client was connected with.
pub fn address(&self) -> &str {
    self.address.as_str()
}
/// Hands out the most recently recorded telemetry and leaves `None` behind,
/// so each telemetry record is returned at most once.
pub fn take_last_telemetry(&mut self) -> Option<OperationTelemetry> {
    std::mem::take(&mut self.last_telemetry)
}
/// Reports whether the capabilities stored from the handshake response
/// include `MODEL_CONCURRENT_PREDICT_V1`.
///
/// The capability map is empty until a handshake succeeds.
pub fn server_pipelines_predict(&self) -> bool {
    let capability = capabilities::MODEL_CONCURRENT_PREDICT_V1;
    rlmesh_proto::has_capability(&self.server_capabilities, capability)
}
/// Performs the protocol handshake and, on success, opens the join stream
/// and moves the client to the `Ready` state.
///
/// The server's reply is checked in this order: the `compatible` flag, the
/// provisional workflow-edition pin, then the server protocol generation.
/// Only after all three pass are the server capabilities stored and the
/// join stream opened.
///
/// # Errors
///
/// * `ClientError::NotConnected` when the client is not in the `Connected`
///   state (this includes an already-`Ready` or `Closed` client).
/// * `ProtocolError::HandshakeFailed` when any of the reply checks fails.
/// * The mapped gRPC status when the handshake or join call itself fails.
pub async fn handshake(&mut self) -> Result<(), GrpcError> {
    if self.state != ClientState::Connected {
        return Err(crate::error::ClientError::NotConnected.into());
    }

    let offer = HandshakeRequest {
        protocol_generation: PROTOCOL_GENERATION.to_string(),
        client_name: "rlmesh-rust-model-grpc".to_string(),
        client_version: env!("CARGO_PKG_VERSION").to_string(),
        capabilities: capability_map(&[
            capabilities::MODEL_SERVICE_V1,
            capabilities::SPACES_CORE_V1,
        ]),
        supported_workflow_editions: supported_workflow_editions(),
        offered_edition_spec_sha256: CURRENT_WORKFLOW_EDITION_SPEC_SHA256.to_string(),
        offered_edition_status: CURRENT_WORKFLOW_EDITION_STATUS.to_string(),
    };
    let request = self.authorized_request(offer)?;
    let reply = self
        .client
        .handshake(request)
        .await
        .map_err(crate::error::status_to_grpc_error)?
        .into_inner();

    if !reply.compatible {
        return Err(ProtocolError::HandshakeFailed(reply.error_message).into());
    }

    check_provisional_edition_pin(
        &reply.selected_workflow_edition,
        &reply.selected_edition_status,
        &reply.selected_edition_spec_sha256,
        &reply.server_version,
    )
    .map_err(ProtocolError::HandshakeFailed)?;

    if !is_protocol_generation_supported(&reply.server_protocol_generation) {
        let message = format!(
            "server protocol generation {} is unsupported by this client (supports {SUPPORTED_PROTOCOL_GENERATIONS:?})",
            reply.server_protocol_generation
        );
        return Err(ProtocolError::HandshakeFailed(message).into());
    }

    self.server_capabilities = reply.capabilities;
    self.setup_join_stream().await?;
    self.state = ClientState::Ready;
    Ok(())
}
pub async fn configure_route(
&mut self,
request: ConfigureRouteRequest,
) -> Result<(), GrpcError> {
self.ensure_ready()?;
validate_route(
request
.context
.as_ref()
.ok_or_else(|| decode_error("configure_route missing route context"))?,
)?;
let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
let response = self
.send_on_stream(JoinRequest {
kind: Some(join_request::Kind::ConfigureRoute(request)),
request_id,
})
.await?;
self.last_telemetry = response.telemetry.clone();
match response.kind {
Some(join_response::Kind::ConfigureRoute(_)) => Ok(()),
Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
_ => Err(ProtocolError::UnexpectedMessage {
expected: "ConfigureRouteResponse".to_string(),
actual: format!("{:?}", response.kind),
}
.into()),
}
}
pub async fn predict(&mut self, request: PredictRequest) -> Result<PredictResponse, GrpcError> {
self.ensure_ready()?;
validate_predict_route(
request
.context
.as_ref()
.ok_or_else(|| decode_error("predict missing route context"))?,
)?;
let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
let response = self
.send_on_stream(JoinRequest {
kind: Some(join_request::Kind::Predict(request)),
request_id,
})
.await?;
self.last_telemetry = response.telemetry.clone();
match response.kind {
Some(join_response::Kind::Predict(predict)) => Ok(predict),
Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
_ => Err(ProtocolError::UnexpectedMessage {
expected: "PredictResponse".to_string(),
actual: format!("{:?}", response.kind),
}
.into()),
}
}
pub async fn close_route(&mut self, request: CloseRouteRequest) -> Result<(), GrpcError> {
self.ensure_ready()?;
validate_route(
request
.context
.as_ref()
.ok_or_else(|| decode_error("close_route missing route context"))?,
)?;
let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
let response = self
.send_on_stream(JoinRequest {
kind: Some(join_request::Kind::CloseRoute(request)),
request_id,
})
.await?;
self.last_telemetry = response.telemetry.clone();
match response.kind {
Some(join_response::Kind::CloseRoute(_)) => Ok(()),
Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
_ => Err(ProtocolError::UnexpectedMessage {
expected: "CloseRouteResponse".to_string(),
actual: format!("{:?}", response.kind),
}
.into()),
}
}
/// Closes the session via [`ModelClient::close_with_timeout`], waiting at
/// most five seconds for the server's response.
///
/// # Errors
///
/// Same as [`ModelClient::close_with_timeout`].
pub async fn close(&mut self, reason: impl Into<String>) -> Result<(), GrpcError> {
    const DEFAULT_CLOSE_TIMEOUT: Duration = Duration::from_secs(5);
    self.close_with_timeout(reason, DEFAULT_CLOSE_TIMEOUT).await
}
/// Sends a `Close` request on the join stream and waits up to `timeout` for
/// the server's response.
///
/// Once any response arrives, its telemetry is recorded and the client moves
/// to the `Closed` state, even if that response is an error or has an
/// unexpected kind. If the wait times out or the stream fails, the state is
/// left unchanged.
///
/// # Errors
///
/// * `ClientError::NotHandshaked` / `ClientError::NotConnected` when the
///   client is not `Ready`.
/// * `GrpcError::Timeout` when no response arrives within `timeout`.
/// * The stream error, the server-reported error, or
///   `ProtocolError::UnexpectedMessage` for any other response kind.
pub async fn close_with_timeout(
    &mut self,
    reason: impl Into<String>,
    timeout: Duration,
) -> Result<(), GrpcError> {
    // `ensure_ready` already reports `NotConnected` for a closed client, so no
    // separate closed-state check is needed.
    self.ensure_ready()?;

    let request_id = self.next_request_id();
    let request = JoinRequest {
        kind: Some(join_request::Kind::Close(CloseRequest {
            reason: reason.into(),
        })),
        request_id: request_id.clone(),
    };

    let outcome = tokio::time::timeout(timeout, self.send_on_stream(request)).await;
    let response = match outcome {
        Ok(result) => result?,
        Err(_elapsed) => {
            // The timed-out send future was dropped while its waiter may still
            // be registered; unregister it so the entry does not leak. The id
            // was generated above, so no other request can own this key.
            self.pending
                .lock()
                .expect("pending map poisoned")
                .remove(&request_id);
            return Err(GrpcError::Timeout(timeout));
        }
    };

    self.last_telemetry = response.telemetry.clone();
    self.state = ClientState::Closed;
    match response.kind {
        Some(join_response::Kind::Close(_)) => Ok(()),
        Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
        _ => Err(ProtocolError::UnexpectedMessage {
            expected: "CloseResponse".to_string(),
            actual: format!("{:?}", response.kind),
        }
        .into()),
    }
}
/// Issues the unary `Shutdown` call and returns the server's response.
///
/// When the server accepts, the client moves to `Closed`, drops its join
/// stream sender and discards every pending waiter. A rejected shutdown
/// leaves the client untouched. Unlike the stream operations, this only
/// requires that the client is not already `Closed`.
///
/// # Errors
///
/// Returns `ClientError::NotConnected` when the client is already `Closed`,
/// or the mapped gRPC status when the call fails.
pub async fn shutdown(
    &mut self,
    reason: impl Into<String>,
) -> Result<ShutdownResponse, GrpcError> {
    if self.state == ClientState::Closed {
        return Err(crate::error::ClientError::NotConnected.into());
    }

    let message = ShutdownRequest {
        reason: reason.into(),
    };
    let request = self.authorized_request(message)?;
    let response = self
        .client
        .shutdown(request)
        .await
        .map_err(crate::error::status_to_grpc_error)?
        .into_inner();

    if !response.accepted {
        return Ok(response);
    }

    self.state = ClientState::Closed;
    self.request_tx = None;
    self.pending.lock().expect("pending map poisoned").clear();
    Ok(response)
}
pub async fn predict_concurrent(
&self,
request: PredictRequest,
) -> Result<PredictResponse, GrpcError> {
self.ensure_ready()?;
validate_predict_route(
request
.context
.as_ref()
.ok_or_else(|| decode_error("predict missing route context"))?,
)?;
let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
let response = self
.send_on_stream(JoinRequest {
kind: Some(join_request::Kind::Predict(request)),
request_id,
})
.await?;
match response.kind {
Some(join_response::Kind::Predict(predict)) => Ok(predict),
Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
_ => Err(ProtocolError::UnexpectedMessage {
expected: "PredictResponse".to_string(),
actual: format!("{:?}", response.kind),
}
.into()),
}
}
async fn setup_join_stream(&mut self) -> Result<(), GrpcError> {
let (tx, rx) = mpsc::channel::<JoinRequest>(32);
let request_stream = ReceiverStream::new(rx);
let request = self.authorized_request(request_stream)?;
let response = self
.client
.join(request)
.await
.map_err(crate::error::status_to_grpc_error)?;
self.request_tx = Some(tx);
spawn_response_pump(response.into_inner(), Arc::clone(&self.pending));
Ok(())
}
async fn send_on_stream(&self, request: JoinRequest) -> Result<JoinResponse, GrpcError> {
let request_id = request.request_id.clone();
let request_kind = join_request_kind_name(request.kind.as_ref());
let tx = self
.request_tx
.clone()
.ok_or(crate::error::ClientError::NotHandshaked)?;
let (response_tx, response_rx) = oneshot::channel();
{
let mut pending = self.pending.lock().expect("pending map poisoned");
if pending.contains_key(&request_id) {
return Err(crate::error::ProtocolError::DecodeError(format!(
"request_id {request_id:?} is already in flight on this stream"
))
.into());
}
pending.insert(request_id.clone(), response_tx);
}
if tx.send(request).await.is_err() {
self.pending
.lock()
.expect("pending map poisoned")
.remove(&request_id);
return Err(TransportError::ConnectionClosed.into());
}
match response_rx.await {
Ok(Ok(response)) => Ok(response),
Ok(Err(status)) => {
tracing::error!(
request_id = %request_id,
request_kind,
code = ?status.code(),
message = %status.message(),
"model join stream returned an error status"
);
Err(crate::error::status_to_grpc_error(status))
}
Err(_) => {
tracing::error!(
request_id = %request_id,
request_kind,
"model join stream closed while waiting for response"
);
Err(TransportError::ConnectionClosed.into())
}
}
}
fn ensure_ready(&self) -> Result<(), GrpcError> {
match self.state {
ClientState::Ready => Ok(()),
ClientState::Connected => Err(crate::error::ClientError::NotHandshaked.into()),
ClientState::Closed => Err(crate::error::ClientError::NotConnected.into()),
}
}
/// Generates the next `model-grpc-req-N` request id.
///
/// `fetch_add` returns the value before the increment, so one is added to
/// get the number this call claimed.
fn next_request_id(&self) -> String {
    let previous = self.request_counter.fetch_add(1, Ordering::Relaxed);
    let id = previous + 1;
    format!("model-grpc-req-{id}")
}
/// Wraps `message` in a `tonic::Request`, attaching the token as the
/// `authorization` metadata entry when the token is non-empty.
///
/// # Errors
///
/// Returns `TransportError::InvalidAddress("invalid token")` when the token
/// cannot be parsed into a metadata value.
fn authorized_request<T>(&self, message: T) -> Result<tonic::Request<T>, GrpcError> {
    let mut request = tonic::Request::new(message);
    if self.token.is_empty() {
        return Ok(request);
    }

    let header = self
        .token
        .parse()
        .map_err(|_| TransportError::InvalidAddress("invalid token".to_string()))?;
    request.metadata_mut().insert("authorization", header);
    Ok(request)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rlmesh_proto::model::v1::PredictResponse;
use tonic::transport::Endpoint;
/// Builds a `Ready` client without any network traffic.
///
/// Returns the client, the receiving end of its request channel (what the
/// client "sent"), and a handle to its pending map so tests can play the
/// role of the response pump.
fn ready_client() -> (ModelClient, mpsc::Receiver<JoinRequest>, PendingResponses) {
    let pending: PendingResponses = Default::default();
    let (request_tx, request_rx) = mpsc::channel(8);
    // Lazy channel: never dials unless a unary RPC is actually issued.
    let channel = Endpoint::from_static("http://127.0.0.1:1").connect_lazy();

    let client = ModelClient {
        address: "tcp://127.0.0.1:1".to_string(),
        client: ModelServiceClient::new(channel),
        token: String::new(),
        state: ClientState::Ready,
        request_tx: Some(request_tx),
        pending: Arc::clone(&pending),
        request_counter: Arc::new(AtomicU64::new(0)),
        last_telemetry: None,
        server_capabilities: HashMap::new(),
    };
    (client, request_rx, pending)
}
/// Completes the waiter registered under `request_id` with `response`.
///
/// Panics when no waiter is registered for the id or when the waiter has
/// already gone away.
fn deliver(pending: &PendingResponses, request_id: &str, response: JoinResponse) {
    let waiter = {
        let mut waiters = pending.lock().unwrap();
        waiters
            .remove(request_id)
            .expect("expected a pending waiter for the request id")
    };
    waiter.send(Ok(response)).expect("waiter still alive");
}
/// Builds a default predict response tagged with `request_id` and carrying
/// no telemetry.
fn predict_response_for(request_id: &str) -> JoinResponse {
    let kind = join_response::Kind::Predict(PredictResponse::default());
    JoinResponse {
        request_id: request_id.to_owned(),
        kind: Some(kind),
        telemetry: None,
    }
}
/// A single request is queued on the stream and resolved by the response
/// delivered under the same request id.
#[tokio::test]
async fn send_on_stream_resolves_by_request_id() {
    let (client, mut request_rx, pending) = ready_client();

    let request = JoinRequest {
        request_id: "target".to_string(),
        kind: Some(join_request::Kind::Predict(PredictRequest::default())),
    };
    let send = tokio::spawn(async move { client.send_on_stream(request).await });

    let sent = request_rx.recv().await.unwrap();
    assert_eq!(sent.request_id, "target");

    deliver(&pending, "target", predict_response_for("target"));

    let response = send.await.unwrap().unwrap();
    assert_eq!(response.request_id, "target");
}
/// Two requests in flight at once each receive their own response, even
/// when the responses are delivered in the opposite order.
#[tokio::test]
async fn two_overlapping_requests_demux_out_of_order() {
    let (client, mut request_rx, pending) = ready_client();
    let client = Arc::new(client);

    let spawn_predict = |request_id: &str| {
        let client = Arc::clone(&client);
        let request = JoinRequest {
            request_id: request_id.to_string(),
            kind: Some(join_request::Kind::Predict(PredictRequest::default())),
        };
        tokio::spawn(async move { client.send_on_stream(request).await })
    };
    let first = spawn_predict("req-1");
    let second = spawn_predict("req-2");

    // The two tasks may queue their requests in either order.
    let mut sent_ids = vec![
        request_rx.recv().await.unwrap().request_id,
        request_rx.recv().await.unwrap().request_id,
    ];
    sent_ids.sort();
    assert_eq!(sent_ids, vec!["req-1".to_string(), "req-2".to_string()]);

    // Answer the second request first to prove matching is by id, not order.
    deliver(&pending, "req-2", predict_response_for("req-2"));
    deliver(&pending, "req-1", predict_response_for("req-1"));

    assert_eq!(first.await.unwrap().unwrap().request_id, "req-1");
    assert_eq!(second.await.unwrap().unwrap().request_id, "req-2");
}
/// Dropping a registered waiter without answering it (as happens when the
/// pending map is cleared) must fail the in-flight request.
#[tokio::test]
async fn send_on_stream_errors_when_waiter_dropped_by_stream_close() {
    let (client, mut request_rx, pending) = ready_client();
    let send = tokio::spawn(async move {
        client
            .send_on_stream(JoinRequest {
                request_id: "orphan".to_string(),
                kind: Some(join_request::Kind::Predict(PredictRequest::default())),
            })
            .await
    });

    // `send_on_stream` registers its waiter before queueing the request, so
    // receiving the request proves the waiter is in the map. This replaces a
    // fixed sleep, which could clear the map too early and hang the test.
    let sent = request_rx
        .recv()
        .await
        .expect("request should have been queued");
    assert_eq!(sent.request_id, "orphan");

    pending.lock().unwrap().clear();

    let result = send.await.unwrap();
    assert!(result.is_err(), "a closed stream must fail the waiter");
}
}