use std::collections::HashMap;
use std::error::Error as StdError;
use std::sync::{Arc, Mutex};
use rlmesh_proto::model::v1::JoinResponse;
use tokio::sync::oneshot;
use tonic::Status;
/// Sending half of the one-shot channel that delivers one request's outcome
/// (a [`JoinResponse`] or a [`Status`] error) to whoever is awaiting it.
type ResponseSender = oneshot::Sender<Result<JoinResponse, Status>>;

/// Shared table of requests still awaiting a [`JoinResponse`], keyed by
/// request id.
///
/// Entries are removed and completed by [`spawn_response_pump`]. They are
/// presumably inserted by the request side, which is not shown here.
pub(super) type PendingResponses = Arc<Mutex<HashMap<String, ResponseSender>>>;
pub(super) fn spawn_response_pump(
mut response_stream: tonic::Streaming<JoinResponse>,
pending: PendingResponses,
) {
tokio::spawn(async move {
loop {
match response_stream.message().await {
Ok(Some(message)) => {
let request_id = message.request_id.clone();
let sender = pending
.lock()
.expect("pending map poisoned")
.remove(&request_id);
match sender {
Some(sender) => {
let _ = sender.send(Ok(message));
}
None => {
tracing::warn!(
stale_request_id = %request_id,
response_kind = ?message.kind,
"discarding model response with no pending request id"
);
}
}
}
Ok(None) => {
tracing::warn!("model join stream ended");
fail_all_pending(&pending, || Status::unavailable("model join stream ended"));
break;
}
Err(error) => {
if pending.lock().expect("pending map poisoned").is_empty() {
tracing::debug!(
code = ?error.code(),
source = ?error.source(),
"model join stream closed on teardown"
);
} else {
tracing::error!(
code = ?error.code(),
message = %error.message(),
source = ?error.source(),
"model join stream error from server"
);
}
let code = error.code();
let message = error.message().to_string();
fail_all_pending(&pending, || Status::new(code, message.clone()));
break;
}
}
}
});
}
/// Empties `pending` and resolves every drained sender with `Err(status())`.
///
/// The senders are taken out while the lock is held. They are completed only
/// after the guard has been dropped, so no send happens under the mutex.
/// `status` is invoked once per sender, so each waiter gets its own [`Status`].
/// Failed sends (receiver already dropped) are ignored.
///
/// # Panics
///
/// Panics if the `pending` mutex is poisoned.
fn fail_all_pending(pending: &PendingResponses, status: impl Fn() -> Status) {
    let mut guard = pending.lock().expect("pending map poisoned");
    let waiters: Vec<_> = guard.drain().map(|(_, waiter)| waiter).collect();
    drop(guard);

    waiters.into_iter().for_each(|waiter| {
        let _ = waiter.send(Err(status()));
    });
}