Skip to main content

rlmesh_grpc/model/
client.rs

1use rlmesh_proto::{
2    CURRENT_WORKFLOW_EDITION_SPEC_SHA256, CURRENT_WORKFLOW_EDITION_STATUS, PROTOCOL_GENERATION,
3    SUPPORTED_PROTOCOL_GENERATIONS, capabilities, capability_map, check_provisional_edition_pin,
4    core::v1::OperationTelemetry,
5    is_protocol_generation_supported,
6    model::v1::{
7        CloseRequest, CloseRouteRequest, ConfigureRouteRequest, HandshakeRequest, JoinRequest,
8        JoinResponse, PredictRequest, PredictResponse, ShutdownRequest, ShutdownResponse,
9        join_request, join_response, model_service_client::ModelServiceClient,
10    },
11    supported_workflow_editions,
12};
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::time::Duration;
17use tokio::sync::{mpsc, oneshot};
18use tokio_stream::wrappers::ReceiverStream;
19
20use crate::error::{Error as GrpcError, ProtocolError, TransportError};
21use crate::helpers::normalize_tcp_session_address;
22use crate::states::ClientState;
23
24use super::stream::{PendingResponses, spawn_response_pump};
25use super::validation::{decode_error, route_request_id, validate_predict_route, validate_route};
26use super::wire::{join_request_kind_name, model_error_to_grpc_error};
27
28/// Client for a ModelService server's Join bidi stream.
29///
30/// # Concurrency: demux by `request_id`
31///
32/// Responses are demultiplexed by `request_id` through a shared pending map, so
33/// **multiple requests can be in flight on one connection at once**. A response
34/// pump routes each response to the matching waiter; a response with no pending
35/// waiter (a late one from an abandoned request, or an unknown id) is logged and
36/// dropped.
37///
38/// The public per-request methods ([`predict`](Self::predict),
39/// [`configure_route`](Self::configure_route), [`close_route`](Self::close_route),
40/// [`close`](Self::close)) take `&mut self` and await their own response, so used
41/// alone they behave exactly as before (one request at a time, response matched
42/// by id). To actually overlap predicts on one connection, use
43/// [`predict_concurrent`](Self::predict_concurrent), which takes `&self` and may
44/// be called from multiple tasks concurrently. The matching server advertises
45/// the `rlmesh.model.concurrent_predict.v1` capability when it pipelines.
46pub struct ModelClient {
47    address: String,
48    client: ModelServiceClient<tonic::transport::Channel>,
49    token: String,
50    state: ClientState,
51    request_tx: Option<mpsc::Sender<JoinRequest>>,
52    pending: PendingResponses,
53    request_counter: Arc<AtomicU64>,
54    last_telemetry: Option<OperationTelemetry>,
55    server_capabilities: HashMap<String, String>,
56}
57
58impl ModelClient {
59    pub async fn connect(address: &str, token: &str) -> Result<Self, GrpcError> {
60        let address = normalize_tcp_session_address(address)?;
61        let endpoint = crate::configure_endpoint(
62            tonic::transport::Endpoint::from_shared(address.replacen("tcp://", "http://", 1))
63                .map_err(|err| TransportError::InvalidAddress(err.to_string()))?,
64        );
65        let channel = endpoint
66            .connect()
67            .await
68            .map_err(|err| TransportError::ConnectFailed(err.to_string()))?;
69
70        Ok(Self {
71            address,
72            client: ModelServiceClient::new(channel)
73                .max_decoding_message_size(crate::MAX_MESSAGE_SIZE)
74                .max_encoding_message_size(crate::MAX_MESSAGE_SIZE),
75            token: token.to_string(),
76            state: ClientState::Connected,
77            request_tx: None,
78            pending: Default::default(),
79            request_counter: Arc::new(AtomicU64::new(0)),
80            last_telemetry: None,
81            server_capabilities: HashMap::new(),
82        })
83    }
84
85    /// Connect to a ModelService server, retrying until the server accepts the
86    /// connection (or the deadline/cancellation in `options` fires).
87    ///
88    /// Only the transport connect is retried; perform the handshake explicitly
89    /// on the returned client.
90    pub async fn connect_with_retry(
91        address: &str,
92        token: &str,
93        options: &crate::connect::ConnectOptions,
94    ) -> Result<Self, GrpcError> {
95        crate::connect::retry_connect(options, || Self::connect(address, token)).await
96    }
97
98    pub fn address(&self) -> &str {
99        &self.address
100    }
101
102    pub fn take_last_telemetry(&mut self) -> Option<OperationTelemetry> {
103        self.last_telemetry.take()
104    }
105
106    /// Whether the server advertised that it pipelines Join-stream predicts
107    /// (`rlmesh.model.concurrent_predict.v1`). Advisory: overlapping predicts via
108    /// [`predict_concurrent`](Self::predict_concurrent) work either way, but
109    /// serialize behind the handler when this is false.
110    pub fn server_pipelines_predict(&self) -> bool {
111        rlmesh_proto::has_capability(
112            &self.server_capabilities,
113            capabilities::MODEL_CONCURRENT_PREDICT_V1,
114        )
115    }
116
117    pub async fn handshake(&mut self) -> Result<(), GrpcError> {
118        if self.state != ClientState::Connected {
119            return Err(crate::error::ClientError::NotConnected.into());
120        }
121
122        let request = self.authorized_request(HandshakeRequest {
123            protocol_generation: PROTOCOL_GENERATION.to_string(),
124            client_name: "rlmesh-rust-model-grpc".to_string(),
125            client_version: env!("CARGO_PKG_VERSION").to_string(),
126            capabilities: capability_map(&[
127                capabilities::MODEL_SERVICE_V1,
128                capabilities::SPACES_CORE_V1,
129            ]),
130            supported_workflow_editions: supported_workflow_editions(),
131            offered_edition_spec_sha256: CURRENT_WORKFLOW_EDITION_SPEC_SHA256.to_string(),
132            offered_edition_status: CURRENT_WORKFLOW_EDITION_STATUS.to_string(),
133        })?;
134
135        let response = self
136            .client
137            .handshake(request)
138            .await
139            .map_err(crate::error::status_to_grpc_error)?
140            .into_inner();
141
142        if !response.compatible {
143            return Err(ProtocolError::HandshakeFailed(response.error_message).into());
144        }
145        check_provisional_edition_pin(
146            &response.selected_workflow_edition,
147            &response.selected_edition_status,
148            &response.selected_edition_spec_sha256,
149            &response.server_version,
150        )
151        .map_err(ProtocolError::HandshakeFailed)?;
152        if !is_protocol_generation_supported(&response.server_protocol_generation) {
153            return Err(ProtocolError::HandshakeFailed(format!(
154                "server protocol generation {} is unsupported by this client (supports {SUPPORTED_PROTOCOL_GENERATIONS:?})",
155                response.server_protocol_generation
156            ))
157            .into());
158        }
159        self.server_capabilities = response.capabilities;
160
161        self.setup_join_stream().await?;
162        self.state = ClientState::Ready;
163        Ok(())
164    }
165
166    pub async fn configure_route(
167        &mut self,
168        request: ConfigureRouteRequest,
169    ) -> Result<(), GrpcError> {
170        self.ensure_ready()?;
171        validate_route(
172            request
173                .context
174                .as_ref()
175                .ok_or_else(|| decode_error("configure_route missing route context"))?,
176        )?;
177        let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
178        let response = self
179            .send_on_stream(JoinRequest {
180                kind: Some(join_request::Kind::ConfigureRoute(request)),
181                request_id,
182            })
183            .await?;
184        self.last_telemetry = response.telemetry.clone();
185        match response.kind {
186            Some(join_response::Kind::ConfigureRoute(_)) => Ok(()),
187            Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
188            _ => Err(ProtocolError::UnexpectedMessage {
189                expected: "ConfigureRouteResponse".to_string(),
190                actual: format!("{:?}", response.kind),
191            }
192            .into()),
193        }
194    }
195
196    pub async fn predict(&mut self, request: PredictRequest) -> Result<PredictResponse, GrpcError> {
197        self.ensure_ready()?;
198        validate_predict_route(
199            request
200                .context
201                .as_ref()
202                .ok_or_else(|| decode_error("predict missing route context"))?,
203        )?;
204        let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
205        let response = self
206            .send_on_stream(JoinRequest {
207                kind: Some(join_request::Kind::Predict(request)),
208                request_id,
209            })
210            .await?;
211        self.last_telemetry = response.telemetry.clone();
212
213        match response.kind {
214            Some(join_response::Kind::Predict(predict)) => Ok(predict),
215            Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
216            _ => Err(ProtocolError::UnexpectedMessage {
217                expected: "PredictResponse".to_string(),
218                actual: format!("{:?}", response.kind),
219            }
220            .into()),
221        }
222    }
223
224    pub async fn close_route(&mut self, request: CloseRouteRequest) -> Result<(), GrpcError> {
225        self.ensure_ready()?;
226        validate_route(
227            request
228                .context
229                .as_ref()
230                .ok_or_else(|| decode_error("close_route missing route context"))?,
231        )?;
232        let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
233        let response = self
234            .send_on_stream(JoinRequest {
235                kind: Some(join_request::Kind::CloseRoute(request)),
236                request_id,
237            })
238            .await?;
239        self.last_telemetry = response.telemetry.clone();
240        match response.kind {
241            Some(join_response::Kind::CloseRoute(_)) => Ok(()),
242            Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
243            _ => Err(ProtocolError::UnexpectedMessage {
244                expected: "CloseRouteResponse".to_string(),
245                actual: format!("{:?}", response.kind),
246            }
247            .into()),
248        }
249    }
250
251    pub async fn close(&mut self, reason: impl Into<String>) -> Result<(), GrpcError> {
252        self.close_with_timeout(reason, Duration::from_secs(5))
253            .await
254    }
255
256    pub async fn close_with_timeout(
257        &mut self,
258        reason: impl Into<String>,
259        timeout: Duration,
260    ) -> Result<(), GrpcError> {
261        if self.state == ClientState::Closed {
262            return Err(crate::error::ClientError::NotConnected.into());
263        }
264        self.ensure_ready()?;
265
266        let request = JoinRequest {
267            kind: Some(join_request::Kind::Close(CloseRequest {
268                reason: reason.into(),
269            })),
270            request_id: self.next_request_id(),
271        };
272
273        let response = tokio::time::timeout(timeout, self.send_on_stream(request))
274            .await
275            .map_err(|_| GrpcError::Timeout(timeout))??;
276        self.last_telemetry = response.telemetry.clone();
277        self.state = ClientState::Closed;
278
279        match response.kind {
280            Some(join_response::Kind::Close(_)) => Ok(()),
281            Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
282            _ => Err(ProtocolError::UnexpectedMessage {
283                expected: "CloseResponse".to_string(),
284                actual: format!("{:?}", response.kind),
285            }
286            .into()),
287        }
288    }
289
290    pub async fn shutdown(
291        &mut self,
292        reason: impl Into<String>,
293    ) -> Result<ShutdownResponse, GrpcError> {
294        if self.state == ClientState::Closed {
295            return Err(crate::error::ClientError::NotConnected.into());
296        }
297
298        let request = self.authorized_request(ShutdownRequest {
299            reason: reason.into(),
300        })?;
301        let response = self
302            .client
303            .shutdown(request)
304            .await
305            .map_err(crate::error::status_to_grpc_error)?
306            .into_inner();
307
308        if response.accepted {
309            self.state = ClientState::Closed;
310            self.request_tx.take();
311            // Drop the request stream sender; the pump will then see the stream
312            // end and fail any still-pending waiters.
313            self.pending.lock().expect("pending map poisoned").clear();
314        }
315
316        Ok(response)
317    }
318
319    /// Issue a predict that may overlap other in-flight requests on the same
320    /// connection. Takes `&self`, so it can be called from multiple tasks
321    /// concurrently; responses are demuxed by `request_id`.
322    ///
323    /// Unlike [`predict`](Self::predict), this does not record
324    /// `last_telemetry` (that field is single-threaded `&mut self` state); read
325    /// per-call telemetry from the returned response if needed. The server only
326    /// pipelines these when it advertises `rlmesh.model.concurrent_predict.v1`;
327    /// against a serial server they still complete correctly, just serialized.
328    pub async fn predict_concurrent(
329        &self,
330        request: PredictRequest,
331    ) -> Result<PredictResponse, GrpcError> {
332        self.ensure_ready()?;
333        validate_predict_route(
334            request
335                .context
336                .as_ref()
337                .ok_or_else(|| decode_error("predict missing route context"))?,
338        )?;
339        let request_id = route_request_id(request.context.as_ref(), || self.next_request_id());
340        let response = self
341            .send_on_stream(JoinRequest {
342                kind: Some(join_request::Kind::Predict(request)),
343                request_id,
344            })
345            .await?;
346        match response.kind {
347            Some(join_response::Kind::Predict(predict)) => Ok(predict),
348            Some(join_response::Kind::Error(error)) => Err(model_error_to_grpc_error(error)),
349            _ => Err(ProtocolError::UnexpectedMessage {
350                expected: "PredictResponse".to_string(),
351                actual: format!("{:?}", response.kind),
352            }
353            .into()),
354        }
355    }
356
357    async fn setup_join_stream(&mut self) -> Result<(), GrpcError> {
358        let (tx, rx) = mpsc::channel::<JoinRequest>(32);
359        let request_stream = ReceiverStream::new(rx);
360        let request = self.authorized_request(request_stream)?;
361
362        let response = self
363            .client
364            .join(request)
365            .await
366            .map_err(crate::error::status_to_grpc_error)?;
367
368        self.request_tx = Some(tx);
369        spawn_response_pump(response.into_inner(), Arc::clone(&self.pending));
370        Ok(())
371    }
372
373    /// Send one request and await its response, matched by `request_id` through
374    /// the shared pending map. Takes `&self` so both the `&mut self` public
375    /// methods and the concurrent `predict_concurrent` path can use it.
376    async fn send_on_stream(&self, request: JoinRequest) -> Result<JoinResponse, GrpcError> {
377        let request_id = request.request_id.clone();
378        let request_kind = join_request_kind_name(request.kind.as_ref());
379        let tx = self
380            .request_tx
381            .clone()
382            .ok_or(crate::error::ClientError::NotHandshaked)?;
383
384        // Register the waiter *before* sending so a fast response cannot race
385        // ahead of the pending insert.
386        let (response_tx, response_rx) = oneshot::channel();
387        {
388            let mut pending = self.pending.lock().expect("pending map poisoned");
389            // The demux is keyed by request_id; silently overwriting a live
390            // entry would strand the first caller until stream end. Reject a
391            // duplicate caller-supplied id instead.
392            if pending.contains_key(&request_id) {
393                return Err(crate::error::ProtocolError::DecodeError(format!(
394                    "request_id {request_id:?} is already in flight on this stream"
395                ))
396                .into());
397            }
398            pending.insert(request_id.clone(), response_tx);
399        }
400
401        if tx.send(request).await.is_err() {
402            // The stream is gone; clean up our pending entry.
403            self.pending
404                .lock()
405                .expect("pending map poisoned")
406                .remove(&request_id);
407            return Err(TransportError::ConnectionClosed.into());
408        }
409
410        match response_rx.await {
411            Ok(Ok(response)) => Ok(response),
412            Ok(Err(status)) => {
413                tracing::error!(
414                    request_id = %request_id,
415                    request_kind,
416                    code = ?status.code(),
417                    message = %status.message(),
418                    "model join stream returned an error status"
419                );
420                Err(crate::error::status_to_grpc_error(status))
421            }
422            Err(_) => {
423                // The pump dropped our sender without sending; the stream closed.
424                tracing::error!(
425                    request_id = %request_id,
426                    request_kind,
427                    "model join stream closed while waiting for response"
428                );
429                Err(TransportError::ConnectionClosed.into())
430            }
431        }
432    }
433
434    fn ensure_ready(&self) -> Result<(), GrpcError> {
435        match self.state {
436            ClientState::Ready => Ok(()),
437            ClientState::Connected => Err(crate::error::ClientError::NotHandshaked.into()),
438            ClientState::Closed => Err(crate::error::ClientError::NotConnected.into()),
439        }
440    }
441
442    fn next_request_id(&self) -> String {
443        let id = self.request_counter.fetch_add(1, Ordering::Relaxed) + 1;
444        format!("model-grpc-req-{id}")
445    }
446
447    fn authorized_request<T>(&self, message: T) -> Result<tonic::Request<T>, GrpcError> {
448        let mut request = tonic::Request::new(message);
449        if !self.token.is_empty() {
450            request.metadata_mut().insert(
451                "authorization",
452                self.token
453                    .parse()
454                    .map_err(|_| TransportError::InvalidAddress("invalid token".to_string()))?,
455            );
456        }
457        Ok(request)
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::transport::Endpoint;

    /// Build a `Ready` client wired to an in-memory request channel.
    ///
    /// Returns the client, the receiver of everything it writes to the Join
    /// request stream, and a handle to its pending map so a test can play the
    /// response pump. This drives the real `send_on_stream` demux without a
    /// transport.
    fn ready_client() -> (ModelClient, mpsc::Receiver<JoinRequest>, PendingResponses) {
        let (request_tx, request_rx) = mpsc::channel(8);
        let channel = Endpoint::from_static("http://127.0.0.1:1").connect_lazy();
        let pending: PendingResponses = Default::default();
        let client = ModelClient {
            address: "tcp://127.0.0.1:1".to_string(),
            client: ModelServiceClient::new(channel),
            token: String::new(),
            state: ClientState::Ready,
            request_tx: Some(request_tx),
            pending: Arc::clone(&pending),
            request_counter: Arc::new(AtomicU64::new(0)),
            last_telemetry: None,
            server_capabilities: HashMap::new(),
        };
        (client, request_rx, pending)
    }

    /// Route a response into the pending map exactly as the real pump would.
    fn deliver(pending: &PendingResponses, request_id: &str, response: JoinResponse) {
        let sender = pending
            .lock()
            .unwrap()
            .remove(request_id)
            .expect("expected a pending waiter for the request id");
        sender.send(Ok(response)).expect("waiter still alive");
    }

    /// A default predict Join request carrying `request_id`.
    fn predict_request(request_id: &str) -> JoinRequest {
        JoinRequest {
            request_id: request_id.to_string(),
            kind: Some(join_request::Kind::Predict(PredictRequest::default())),
        }
    }

    fn predict_response_for(request_id: &str) -> JoinResponse {
        JoinResponse {
            request_id: request_id.to_string(),
            kind: Some(join_response::Kind::Predict(PredictResponse::default())),
            telemetry: None,
        }
    }

    #[tokio::test]
    async fn send_on_stream_resolves_by_request_id() {
        let (client, mut request_rx, pending) = ready_client();

        let send =
            tokio::spawn(async move { client.send_on_stream(predict_request("target")).await });

        // The request reaches the stream and a waiter is registered.
        let sent = request_rx.recv().await.unwrap();
        assert_eq!(sent.request_id, "target");
        deliver(&pending, "target", predict_response_for("target"));

        let response = send.await.unwrap().unwrap();
        assert_eq!(response.request_id, "target");
    }

    #[tokio::test]
    async fn two_overlapping_requests_demux_out_of_order() {
        let (client, mut request_rx, pending) = ready_client();
        let client = Arc::new(client);

        // Two predicts in flight at once on the same connection.
        let c1 = Arc::clone(&client);
        let first = tokio::spawn(async move { c1.send_on_stream(predict_request("req-1")).await });
        let c2 = Arc::clone(&client);
        let second = tokio::spawn(async move { c2.send_on_stream(predict_request("req-2")).await });

        // Both requests are sent before either response arrives.
        let mut sent_ids = vec![
            request_rx.recv().await.unwrap().request_id,
            request_rx.recv().await.unwrap().request_id,
        ];
        sent_ids.sort();
        assert_eq!(sent_ids, vec!["req-1".to_string(), "req-2".to_string()]);

        // Deliver responses out of order: req-2 first, then req-1.
        deliver(&pending, "req-2", predict_response_for("req-2"));
        deliver(&pending, "req-1", predict_response_for("req-1"));

        // Each waiter gets exactly its own response, regardless of order.
        assert_eq!(first.await.unwrap().unwrap().request_id, "req-1");
        assert_eq!(second.await.unwrap().unwrap().request_id, "req-2");
    }

    #[tokio::test]
    async fn send_on_stream_errors_when_waiter_dropped_by_stream_close() {
        let (client, mut request_rx, pending) = ready_client();

        let send =
            tokio::spawn(async move { client.send_on_stream(predict_request("orphan")).await });

        // The waiter is registered before the request is written, so once the
        // request is observed its pending entry is guaranteed to exist. (A
        // fixed sleep here could clear an empty map and hang the test.)
        let sent = request_rx.recv().await.unwrap();
        assert_eq!(sent.request_id, "orphan");

        // Simulate the pump dropping every pending sender on stream close.
        pending.lock().unwrap().clear();

        let result = send.await.unwrap();
        assert!(result.is_err(), "a closed stream must fail the waiter");
    }

    #[tokio::test]
    async fn send_on_stream_rejects_duplicate_in_flight_request_id() {
        let (client, mut request_rx, pending) = ready_client();
        let client = Arc::new(client);

        let c1 = Arc::clone(&client);
        let first = tokio::spawn(async move { c1.send_on_stream(predict_request("dup")).await });
        // Wait until the first waiter is registered and its request written.
        assert_eq!(request_rx.recv().await.unwrap().request_id, "dup");

        // A second request reusing the live id is rejected before sending.
        let second = client.send_on_stream(predict_request("dup")).await;
        assert!(second.is_err(), "a duplicate in-flight id must be rejected");
        assert!(
            request_rx.try_recv().is_err(),
            "a rejected request must not be written to the stream"
        );

        // The first caller is unaffected and still gets its response.
        deliver(&pending, "dup", predict_response_for("dup"));
        assert_eq!(first.await.unwrap().unwrap().request_id, "dup");
    }
}
585
586        let result = send.await.unwrap();
587        assert!(result.is_err(), "a closed stream must fail the waiter");
588    }
589}