1use rlmesh_proto::{
2 CURRENT_WORKFLOW_EDITION_SPEC_SHA256, CURRENT_WORKFLOW_EDITION_STATUS, PROTOCOL_GENERATION,
3 SUPPORTED_PROTOCOL_GENERATIONS, capabilities, capability_map, check_provisional_edition_pin,
4 core::v1::OperationTelemetry,
5 is_protocol_generation_supported,
6 model::v1::{
7 CloseRequest, CloseRouteRequest, ConfigureRouteRequest, HandshakeRequest, JoinRequest,
8 JoinResponse, PredictRequest, PredictResponse, ShutdownRequest, ShutdownResponse,
9 join_request, join_response, model_service_client::ModelServiceClient,
10 },
11 supported_workflow_editions,
12};
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::time::Duration;
17use tokio::sync::{mpsc, oneshot};
18use tokio_stream::wrappers::ReceiverStream;
19
20use crate::error::{Error as GrpcError, ProtocolError, TransportError};
21use crate::helpers::normalize_tcp_session_address;
22use crate::states::ClientState;
23
24use super::stream::{PendingResponses, spawn_response_pump};
25use super::validation::{decode_error, route_request_id, validate_predict_route, validate_route};
26use super::wire::{join_request_kind_name, model_error_to_grpc_error};
27
28pub struct ModelClient {
47 address: String,
48 client: ModelServiceClient<tonic::transport::Channel>,
49 token: String,
50 state: ClientState,
51 request_tx: Option<mpsc::Sender<JoinRequest>>,
52 pending: PendingResponses,
53 request_counter: Arc<AtomicU64>,
54 last_telemetry: Option<OperationTelemetry>,
55 server_capabilities: HashMap<String, String>,
56}
57
58impl ModelClient {
59 pub async fn connect(address: &str, token: &str) -> Result<Self, GrpcError> {
60 let address = normalize_tcp_session_address(address)?;
61 let endpoint = crate::configure_endpoint(
62 tonic::transport::Endpoint::from_shared(address.replacen("tcp://", "http://", 1))
63 .map_err(|err| TransportError::InvalidAddress(err.to_string()))?,
64 );
65 let channel = endpoint
66 .connect()
67 .await
68 .map_err(|err| TransportError::ConnectFailed(err.to_string()))?;
69
70 Ok(Self {
71 address,
72 client: ModelServiceClient::new(channel)
73 .max_decoding_message_size(crate::MAX_MESSAGE_SIZE)
74 .max_encoding_message_size(crate::MAX_MESSAGE_SIZE),
75 token: token.to_string(),
76 state: ClientState::Connected,
77 request_tx: None,
78 pending: Default::default(),
79 request_counter: Arc::new(AtomicU64::new(0)),
80 last_telemetry: None,
81 server_capabilities: HashMap::new(),
82 })
83 }
84
85 pub async fn connect_with_retry(
91 address: &str,
92 token: &str,
93 options: &crate::connect::ConnectOptions,
94 ) -> Result<Self, GrpcError> {
95 crate::connect::retry_connect(options, || Self::connect(address, token)).await
96 }
97
98 pub fn address(&self) -> &str {
99 &self.address
100 }
101
102 pub fn take_last_telemetry(&mut self) -> Option<OperationTelemetry> {
103 self.last_telemetry.take()
104 }
105
106 pub fn server_pipelines_predict(&self) -> bool {
111 rlmesh_proto::has_capability(
112 &self.server_capabilities,
113 capabilities::MODEL_CONCURRENT_PREDICT_V1,
114 )
115 }
116
117 pub async fn handshake(&mut self) -> Result<(), GrpcError> {
118 if self.state != ClientState::Connected {
119 return Err(crate::error::ClientError::NotConnected.into());
120 }
121
122 let request = self.authorized_request(HandshakeRequest {
123 protocol_generation: PROTOCOL_GENERATION.to_string(),
124 client_name: "rlmesh-rust-model-grpc".to_string(),
125 client_version: env!("CARGO_PKG_VERSION").to_string(),
126 capabilities: capability_map(&[
127 capabilities::MODEL_SERVICE_V1,
128 capabilities::SPACES_CORE_V1,
129 ]),
130 supported_workflow_editions: supported_workflow_editions(),
131 offered_edition_spec_sha256: CURRENT_WORKFLOW_EDITION_SPEC_SHA256.to_string(),
132 offered_edition_status: CURRENT_WORKFLOW_EDITION_STATUS.to_string(),
133 })?;
134
135 let response = self
136 .client
137 .handshake(request)
138 .await
139 .map_err(crate::error::status_to_grpc_error)?
140 .into_inner();
141
142 if !response.compatible {
143 return Err(ProtocolError::HandshakeFailed(response.error_message).into());
144 }
145 check_provisional_edition_pin(
146 &response.selected_workflow_edition,
147 &response.selected_edition_status,
148 &response.selected_edition_spec_sha256,
149 &response.server_version,
150 )
151 .map_err(ProtocolError::HandshakeFailed)?;
152 if !is_protocol_generation_supported(&response.server_protocol_generation) {
153 return Err(ProtocolError::HandshakeFailed(format!(
154 "server protocol generation {} is unsupported by this client (supports {SUPPORTED_PROTOCOL_GENERATIONS:?})",
155 response.server_protocol_generation
156 ))
157 .into());
158 }
159 self.server_capabilities = response.capabilities;
160
161 self.setup_join_stream().await?;
162 self.state = ClientState::Ready;
163 Ok(())
164 }
165
166 pub async fn configure_route(
167 &mut self,
168 request: ConfigureRouteRequest,
169 ) -> Result<(), GrpcError> {
170 self.ensure_ready()?;
171 validate_route(
172 request
173 .context
174 .as_ref()
175 .ok_or_else(|| decode_error("configure_route missing route context"))?,
176 )?;
177 let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
178 let response = self
179 .send_on_stream(JoinRequest {
180 kind: Some(join_request::Kind::ConfigureRoute(request)),
181 request_id,
182 })
183 .await?;
184 self.last_telemetry = response.telemetry.clone();
185 match response.kind {
186 Some(join_response::Kind::ConfigureRoute(_)) => Ok(()),
187 Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
188 _ => Err(ProtocolError::UnexpectedMessage {
189 expected: "ConfigureRouteResponse".to_string(),
190 actual: format!("{:?}", response.kind),
191 }
192 .into()),
193 }
194 }
195
196 pub async fn predict(&mut self, request: PredictRequest) -> Result<PredictResponse, GrpcError> {
197 self.ensure_ready()?;
198 validate_predict_route(
199 request
200 .context
201 .as_ref()
202 .ok_or_else(|| decode_error("predict missing route context"))?,
203 )?;
204 let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
205 let response = self
206 .send_on_stream(JoinRequest {
207 kind: Some(join_request::Kind::Predict(request)),
208 request_id,
209 })
210 .await?;
211 self.last_telemetry = response.telemetry.clone();
212
213 match response.kind {
214 Some(join_response::Kind::Predict(predict)) => Ok(predict),
215 Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
216 _ => Err(ProtocolError::UnexpectedMessage {
217 expected: "PredictResponse".to_string(),
218 actual: format!("{:?}", response.kind),
219 }
220 .into()),
221 }
222 }
223
224 pub async fn close_route(&mut self, request: CloseRouteRequest) -> Result<(), GrpcError> {
225 self.ensure_ready()?;
226 validate_route(
227 request
228 .context
229 .as_ref()
230 .ok_or_else(|| decode_error("close_route missing route context"))?,
231 )?;
232 let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
233 let response = self
234 .send_on_stream(JoinRequest {
235 kind: Some(join_request::Kind::CloseRoute(request)),
236 request_id,
237 })
238 .await?;
239 self.last_telemetry = response.telemetry.clone();
240 match response.kind {
241 Some(join_response::Kind::CloseRoute(_)) => Ok(()),
242 Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
243 _ => Err(ProtocolError::UnexpectedMessage {
244 expected: "CloseRouteResponse".to_string(),
245 actual: format!("{:?}", response.kind),
246 }
247 .into()),
248 }
249 }
250
251 pub async fn close(&mut self, reason: impl Into<String>) -> Result<(), GrpcError> {
252 self.close_with_timeout(reason, Duration::from_secs(5))
253 .await
254 }
255
256 pub async fn close_with_timeout(
257 &mut self,
258 reason: impl Into<String>,
259 timeout: Duration,
260 ) -> Result<(), GrpcError> {
261 if self.state == ClientState::Closed {
262 return Err(crate::error::ClientError::NotConnected.into());
263 }
264 self.ensure_ready()?;
265
266 let request = JoinRequest {
267 kind: Some(join_request::Kind::Close(CloseRequest {
268 reason: reason.into(),
269 })),
270 request_id: self.next_request_id(),
271 };
272
273 let response = tokio::time::timeout(timeout, self.send_on_stream(request))
274 .await
275 .map_err(|_| GrpcError::Timeout(timeout))??;
276 self.last_telemetry = response.telemetry.clone();
277 self.state = ClientState::Closed;
278
279 match response.kind {
280 Some(join_response::Kind::Close(_)) => Ok(()),
281 Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
282 _ => Err(ProtocolError::UnexpectedMessage {
283 expected: "CloseResponse".to_string(),
284 actual: format!("{:?}", response.kind),
285 }
286 .into()),
287 }
288 }
289
290 pub async fn shutdown(
291 &mut self,
292 reason: impl Into<String>,
293 ) -> Result<ShutdownResponse, GrpcError> {
294 if self.state == ClientState::Closed {
295 return Err(crate::error::ClientError::NotConnected.into());
296 }
297
298 let request = self.authorized_request(ShutdownRequest {
299 reason: reason.into(),
300 })?;
301 let response = self
302 .client
303 .shutdown(request)
304 .await
305 .map_err(crate::error::status_to_grpc_error)?
306 .into_inner();
307
308 if response.accepted {
309 self.state = ClientState::Closed;
310 self.request_tx.take();
311 self.pending.lock().expect("pending map poisoned").clear();
314 }
315
316 Ok(response)
317 }
318
319 pub async fn predict_concurrent(
329 &self,
330 request: PredictRequest,
331 ) -> Result<PredictResponse, GrpcError> {
332 self.ensure_ready()?;
333 validate_predict_route(
334 request
335 .context
336 .as_ref()
337 .ok_or_else(|| decode_error("predict missing route context"))?,
338 )?;
339 let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
340 let response = self
341 .send_on_stream(JoinRequest {
342 kind: Some(join_request::Kind::Predict(request)),
343 request_id,
344 })
345 .await?;
346 match response.kind {
347 Some(join_response::Kind::Predict(predict)) => Ok(predict),
348 Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
349 _ => Err(ProtocolError::UnexpectedMessage {
350 expected: "PredictResponse".to_string(),
351 actual: format!("{:?}", response.kind),
352 }
353 .into()),
354 }
355 }
356
357 async fn setup_join_stream(&mut self) -> Result<(), GrpcError> {
358 let (tx, rx) = mpsc::channel::<JoinRequest>(32);
359 let request_stream = ReceiverStream::new(rx);
360 let request = self.authorized_request(request_stream)?;
361
362 let response = self
363 .client
364 .join(request)
365 .await
366 .map_err(crate::error::status_to_grpc_error)?;
367
368 self.request_tx = Some(tx);
369 spawn_response_pump(response.into_inner(), Arc::clone(&self.pending));
370 Ok(())
371 }
372
373 async fn send_on_stream(&self, request: JoinRequest) -> Result<JoinResponse, GrpcError> {
377 let request_id = request.request_id.clone();
378 let request_kind = join_request_kind_name(request.kind.as_ref());
379 let tx = self
380 .request_tx
381 .clone()
382 .ok_or(crate::error::ClientError::NotHandshaked)?;
383
384 let (response_tx, response_rx) = oneshot::channel();
387 {
388 let mut pending = self.pending.lock().expect("pending map poisoned");
389 if pending.contains_key(&request_id) {
393 return Err(crate::error::ProtocolError::DecodeError(format!(
394 "request_id {request_id:?} is already in flight on this stream"
395 ))
396 .into());
397 }
398 pending.insert(request_id.clone(), response_tx);
399 }
400
401 if tx.send(request).await.is_err() {
402 self.pending
404 .lock()
405 .expect("pending map poisoned")
406 .remove(&request_id);
407 return Err(TransportError::ConnectionClosed.into());
408 }
409
410 match response_rx.await {
411 Ok(Ok(response)) => Ok(response),
412 Ok(Err(status)) => {
413 tracing::error!(
414 request_id = %request_id,
415 request_kind,
416 code = ?status.code(),
417 message = %status.message(),
418 "model join stream returned an error status"
419 );
420 Err(crate::error::status_to_grpc_error(status))
421 }
422 Err(_) => {
423 tracing::error!(
425 request_id = %request_id,
426 request_kind,
427 "model join stream closed while waiting for response"
428 );
429 Err(TransportError::ConnectionClosed.into())
430 }
431 }
432 }
433
434 fn ensure_ready(&self) -> Result<(), GrpcError> {
435 match self.state {
436 ClientState::Ready => Ok(()),
437 ClientState::Connected => Err(crate::error::ClientError::NotHandshaked.into()),
438 ClientState::Closed => Err(crate::error::ClientError::NotConnected.into()),
439 }
440 }
441
442 fn next_request_id(&self) -> String {
443 let id = self.request_counter.fetch_add(1, Ordering::Relaxed) + 1;
444 format!("model-grpc-req-{id}")
445 }
446
447 fn authorized_request<T>(&self, message: T) -> Result<tonic::Request<T>, GrpcError> {
448 let mut request = tonic::Request::new(message);
449 if !self.token.is_empty() {
450 request.metadata_mut().insert(
451 "authorization",
452 self.token
453 .parse()
454 .map_err(|_| TransportError::InvalidAddress("invalid token".to_string()))?,
455 );
456 }
457 Ok(request)
458 }
459}
460
461#[cfg(test)]
462mod tests {
463 use super::*;
464 use rlmesh_proto::model::v1::PredictResponse;
465 use tonic::transport::Endpoint;
466
467 fn ready_client() -> (ModelClient, mpsc::Receiver<JoinRequest>, PendingResponses) {
472 let (request_tx, request_rx) = mpsc::channel(8);
473 let channel = Endpoint::from_static("http://127.0.0.1:1").connect_lazy();
474 let pending: PendingResponses = Default::default();
475 let client = ModelClient {
476 address: "tcp://127.0.0.1:1".to_string(),
477 client: ModelServiceClient::new(channel),
478 token: String::new(),
479 state: ClientState::Ready,
480 request_tx: Some(request_tx),
481 pending: Arc::clone(&pending),
482 request_counter: Arc::new(AtomicU64::new(0)),
483 last_telemetry: None,
484 server_capabilities: HashMap::new(),
485 };
486 (client, request_rx, pending)
487 }
488
489 fn deliver(pending: &PendingResponses, request_id: &str, response: JoinResponse) {
491 let sender = pending
492 .lock()
493 .unwrap()
494 .remove(request_id)
495 .expect("expected a pending waiter for the request id");
496 sender.send(Ok(response)).expect("waiter still alive");
497 }
498
499 fn predict_response_for(request_id: &str) -> JoinResponse {
500 JoinResponse {
501 request_id: request_id.to_string(),
502 kind: Some(join_response::Kind::Predict(PredictResponse::default())),
503 telemetry: None,
504 }
505 }
506
507 #[tokio::test]
508 async fn send_on_stream_resolves_by_request_id() {
509 let (client, mut request_rx, pending) = ready_client();
510
511 let send = tokio::spawn(async move {
512 client
513 .send_on_stream(JoinRequest {
514 request_id: "target".to_string(),
515 kind: Some(join_request::Kind::Predict(PredictRequest::default())),
516 })
517 .await
518 });
519
520 let sent = request_rx.recv().await.unwrap();
522 assert_eq!(sent.request_id, "target");
523 deliver(&pending, "target", predict_response_for("target"));
524
525 let response = send.await.unwrap().unwrap();
526 assert_eq!(response.request_id, "target");
527 }
528
529 #[tokio::test]
530 async fn two_overlapping_requests_demux_out_of_order() {
531 let (client, mut request_rx, pending) = ready_client();
532 let client = Arc::new(client);
533
534 let c1 = Arc::clone(&client);
536 let first = tokio::spawn(async move {
537 c1.send_on_stream(JoinRequest {
538 request_id: "req-1".to_string(),
539 kind: Some(join_request::Kind::Predict(PredictRequest::default())),
540 })
541 .await
542 });
543 let c2 = Arc::clone(&client);
544 let second = tokio::spawn(async move {
545 c2.send_on_stream(JoinRequest {
546 request_id: "req-2".to_string(),
547 kind: Some(join_request::Kind::Predict(PredictRequest::default())),
548 })
549 .await
550 });
551
552 let mut sent_ids = vec![
554 request_rx.recv().await.unwrap().request_id,
555 request_rx.recv().await.unwrap().request_id,
556 ];
557 sent_ids.sort();
558 assert_eq!(sent_ids, vec!["req-1".to_string(), "req-2".to_string()]);
559
560 deliver(&pending, "req-2", predict_response_for("req-2"));
562 deliver(&pending, "req-1", predict_response_for("req-1"));
563
564 assert_eq!(first.await.unwrap().unwrap().request_id, "req-1");
566 assert_eq!(second.await.unwrap().unwrap().request_id, "req-2");
567 }
568
569 #[tokio::test]
570 async fn send_on_stream_errors_when_waiter_dropped_by_stream_close() {
571 let (client, _request_rx, pending) = ready_client();
572
573 let send = tokio::spawn(async move {
574 client
575 .send_on_stream(JoinRequest {
576 request_id: "orphan".to_string(),
577 kind: Some(join_request::Kind::Predict(PredictRequest::default())),
578 })
579 .await
580 });
581
582 tokio::time::sleep(Duration::from_millis(20)).await;
584 pending.lock().unwrap().clear();
585
586 let result = send.await.unwrap();
587 assert!(result.is_err(), "a closed stream must fail the waiter");
588 }
589}