use std::error::Error as StdError;
use std::sync::Arc;
use anyhow::{Error, Result};
use futures::{stream, stream::StreamExt};
use crate::{
http::service::metrics::Metrics, model_card::ModelDeploymentCard, preprocessor::BackendOutput,
protocols::common::llm_backend::PreprocessedRequest,
};
use dynamo_runtime::error::{self, BackendError, DynamoError, ErrorType};
use dynamo_runtime::pipeline::{
AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn, async_trait,
};
use dynamo_runtime::protocols::{annotated::Annotated, maybe_error::MaybeError};
/// Returns `true` when `err` (or an error in its source chain) represents a
/// transient transport/backend failure for which the request can be migrated
/// to a fresh downstream stream instead of being failed outright.
fn is_migratable(err: &(dyn StdError + 'static)) -> bool {
    // Failure classes that justify re-issuing the request elsewhere.
    const MIGRATABLE: &[ErrorType] = &[
        ErrorType::CannotConnect,
        ErrorType::Disconnected,
        ErrorType::ConnectionTimeout,
        ErrorType::Backend(BackendError::EngineShutdown),
    ];
    // Nothing is currently pinned as explicitly non-migratable.
    const NON_MIGRATABLE: &[ErrorType] = &[];
    error::match_error_chain(err, MIGRATABLE, NON_MIGRATABLE)
}
/// Pipeline operator that transparently retries ("migrates") an in-flight
/// request onto a new downstream stream when a migratable failure occurs.
pub struct Migration {
    // Number of additional stream-creation attempts allowed per request
    // (0 => a single attempt, no migration; see RetryManager::build).
    migration_limit: u32,
    // Model display name, used to label the migration metrics.
    model_name: Arc<String>,
    // Metrics sink counting new-request and ongoing-request migrations.
    metrics: Arc<Metrics>,
}
impl Migration {
    /// Builds a shared `Migration` operator for the given model.
    pub fn new(migration_limit: u32, model_name: String, metrics: Arc<Metrics>) -> Arc<Self> {
        tracing::debug!("model {} migration limit {}", model_name, migration_limit);
        let model_name = Arc::new(model_name);
        Arc::new(Self {
            migration_limit,
            model_name,
            metrics,
        })
    }

    /// Convenience constructor that takes the model name from its
    /// deployment card.
    pub fn from_mdc(
        mdc: &ModelDeploymentCard,
        migration_limit: u32,
        metrics: Arc<Metrics>,
    ) -> Arc<Self> {
        let display_name = mdc.display_name.clone();
        Self::new(migration_limit, display_name, metrics)
    }
}
#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for Migration
{
    /// Wraps the downstream engine `next` in a `RetryManager` and exposes its
    /// (possibly migrated) responses as one continuous output stream.
    ///
    /// The first stream is created eagerly in `RetryManager::build`, so a
    /// request that cannot be started at all — even after retries — fails
    /// here with `Err` rather than emitting an error item on the stream.
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
    ) -> Result<ManyOut<Annotated<BackendOutput>>> {
        let (preprocessed_request, context) = request.transfer(());
        let engine_ctx = context.context();
        // Two handles to the same context: one moves into the RetryManager,
        // the other is attached to the outgoing ResponseStream.
        let engine_ctx_ = engine_ctx.clone();
        let retry_manager = RetryManager::build(
            engine_ctx,
            preprocessed_request,
            next,
            self.migration_limit,
            self.model_name.clone(),
            self.metrics.clone(),
        )
        .await?;
        // Drive the RetryManager as a stream: each unfold step yields the
        // next response (possibly from a freshly recreated stream) and
        // threads the manager's state through to the following step.
        let response_stream = stream::unfold(retry_manager, move |mut retry_manager| async move {
            retry_manager
                .next()
                .await
                .map(|response| (response, retry_manager))
        })
        .fuse();
        Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx_))
    }
}
/// Per-request state machine that owns the downstream response stream and
/// recreates it (resuming the remaining work) after a migratable failure.
struct RetryManager {
    // Engine context of the original request; child request contexts are
    // linked to it and it is polled for stop/kill signals.
    context: Arc<dyn AsyncEngineContext>,
    // The request to (re)send; token_ids and max_tokens are updated as
    // responses arrive so a re-issued request resumes where it left off.
    request: PreprocessedRequest,
    next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
    // Currently active response stream, if any.
    next_stream: Option<ManyOut<Annotated<BackendOutput>>>,
    // Remaining stream-creation attempts (migration_limit + 1 initially).
    retries_left: u32,
    model_name: Arc<String>,
    metrics: Arc<Metrics>,
}
impl RetryManager {
pub async fn build(
context: Arc<dyn AsyncEngineContext>,
preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
retries_left: u32,
model_name: Arc<String>,
metrics: Arc<Metrics>,
) -> Result<Self> {
let mut slf = Self {
context,
request: preprocessed_request,
next_generate: next,
next_stream: None,
retries_left: retries_left + 1, model_name,
metrics,
};
slf.new_stream().await?;
Ok(slf)
}
pub async fn next(&mut self) -> Option<Annotated<BackendOutput>> {
loop {
let response_stream = match self.next_stream.as_mut() {
Some(stream) => stream,
None => {
tracing::error!("next() called with next_stream is None - should not happen");
return Some(Annotated::from_err(DynamoError::msg("next_stream is None")));
}
};
if let Some(response) = response_stream.next().await {
if let Some(err) = response.err()
&& is_migratable(&err)
{
tracing::warn!("Stream disconnected... recreating stream... {}", err);
self.metrics.inc_migration_ongoing_request(&self.model_name);
if let Err(err) = self.new_stream().await {
tracing::warn!("Cannot recreate stream: {:#}", err);
} else {
continue;
}
}
self.track_response(&response);
return Some(response);
}
return None;
}
}
async fn new_stream(&mut self) -> Result<()> {
let mut response_stream: Option<Result<ManyOut<Annotated<BackendOutput>>>> = None;
while self.retries_left > 0 {
self.retries_left -= 1;
let request = Context::with_id(self.request.clone(), self.context.id().to_string());
self.context.link_child(request.context());
if self.context.is_stopped() || self.context.is_killed() {
tracing::debug!("Abort creating new stream after context is stopped or killed");
return Err(Error::msg(format!(
"Context id {} is stopped or killed",
self.context.id()
)));
}
response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
&& is_migratable(err.as_ref())
{
tracing::warn!("Creating new stream... retrying... {}", err);
self.metrics.inc_migration_new_request(&self.model_name);
continue;
}
break;
}
match response_stream {
Some(Ok(next_stream)) => {
self.next_stream = Some(next_stream);
Ok(())
}
Some(Err(err)) => Err(err), None => Err(Error::msg(
"Migration limit exhausted", )),
}
}
fn track_response(&mut self, response: &Annotated<BackendOutput>) {
if self.retries_left == 0 {
return;
}
let llm_engine_output = match response.data.as_ref() {
Some(output) => output,
None => return,
};
if let Some(max_tokens) = self.request.stop_conditions.max_tokens {
self.request.stop_conditions.max_tokens =
Some(max_tokens.saturating_sub(llm_engine_output.token_ids.len() as u32));
}
for token_id in llm_engine_output.token_ids.iter() {
self.request.token_ids.push(*token_id);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::http::service::metrics::Metrics;
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::error::{DynamoError, ErrorType};
use dynamo_runtime::pipeline::AsyncEngine;
use dynamo_runtime::pipeline::context::Controller;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::mpsc;
const TEST_MODEL: &str = "test-model";
fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
PreprocessedRequest::builder()
.model("mock".to_string())
.token_ids(vec![1, 2, 3])
.stop_conditions(StopConditions {
max_tokens: Some(max_tokens),
..Default::default()
})
.sampling_options(SamplingOptions::default())
.output_options(OutputOptions::default())
.eos_token_ids(vec![])
.annotations(vec![])
.build()
.unwrap()
}
fn create_mock_output(token_id: u32) -> Annotated<BackendOutput> {
Annotated::from_data(BackendOutput {
token_ids: vec![token_id],
tokens: vec![],
text: Some(format!("token_{token_id}")),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: None,
disaggregated_params: None,
completion_usage: None,
})
}
#[derive(Debug, Clone)]
enum MockBehavior {
Success,
FailThenSuccess,
MidStreamFail { fail_after: usize },
MidStreamFailAlways { fail_after: usize },
MidStreamFailAlwaysStreamError { fail_after: usize },
AlwaysFail,
}
struct MockEngine {
behavior: MockBehavior,
num_responses: usize,
token_offset: u32,
call_count: Arc<AtomicU32>,
context_id: String,
}
impl MockEngine {
fn new(
behavior: MockBehavior,
num_responses: usize,
token_offset: u32,
context_id: String,
) -> Self {
Self {
behavior,
num_responses,
token_offset,
call_count: Arc::new(AtomicU32::new(0)),
context_id,
}
}
}
#[async_trait]
impl
AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, anyhow::Error>
for MockEngine
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<BackendOutput>>> {
let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
let (preprocessed_request, context) = request.transfer(());
assert_eq!(
context.id().to_string(),
self.context_id,
"Context ID mismatch"
);
let initial_tokens = 3; let responses_already_generated = preprocessed_request
.token_ids
.len()
.saturating_sub(initial_tokens);
let expected_max_tokens =
self.num_responses
.saturating_sub(responses_already_generated) as u32;
assert_eq!(
preprocessed_request.stop_conditions.max_tokens,
Some(expected_max_tokens),
"max_tokens should be {} but got {:?}",
expected_max_tokens,
preprocessed_request.stop_conditions.max_tokens
);
match &self.behavior {
MockBehavior::Success => {
self.send_responses(responses_already_generated, self.num_responses)
.await
}
MockBehavior::FailThenSuccess => {
if call_num == 0 {
return Err(anyhow::anyhow!(
DynamoError::builder()
.error_type(ErrorType::CannotConnect)
.message("no responders")
.build()
));
} else {
self.send_responses(responses_already_generated, self.num_responses)
.await
}
}
MockBehavior::MidStreamFail { fail_after } => {
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
let fail_after = *fail_after;
let num_responses = self.num_responses;
if call_num == 0 {
tokio::spawn(async move {
for i in responses_already_generated..fail_after.min(num_responses) {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
let error_response = Annotated::from_err(
DynamoError::builder()
.error_type(ErrorType::Disconnected)
.message("Stream ended before generation completed")
.build(),
);
let _ = tx.send(error_response).await;
});
} else {
tokio::spawn(async move {
for i in responses_already_generated..num_responses {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
});
}
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
}
MockBehavior::MidStreamFailAlways { fail_after } => {
if call_num == 0 {
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
let fail_after = *fail_after;
let num_responses = self.num_responses;
tokio::spawn(async move {
for i in responses_already_generated..fail_after.min(num_responses) {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
let error_response = Annotated::from_err(
DynamoError::builder()
.error_type(ErrorType::Disconnected)
.message("Stream ended before generation completed")
.build(),
);
let _ = tx.send(error_response).await;
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
} else {
Err(anyhow::anyhow!(
DynamoError::builder()
.error_type(ErrorType::CannotConnect)
.message("no responders")
.build()
))
}
}
MockBehavior::MidStreamFailAlwaysStreamError { fail_after } => {
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
let fail_after = *fail_after;
let num_responses = self.num_responses;
if call_num == 0 {
tokio::spawn(async move {
for i in responses_already_generated..fail_after.min(num_responses) {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
let error_response = Annotated::from_err(
DynamoError::builder()
.error_type(ErrorType::Disconnected)
.message("Stream ended before generation completed")
.build(),
);
let _ = tx.send(error_response).await;
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
} else {
tokio::spawn(async move {
let error_response = Annotated::from_err(
DynamoError::builder()
.error_type(ErrorType::Disconnected)
.message("Stream ended before generation completed")
.build(),
);
let _ = tx.send(error_response).await;
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
}
}
MockBehavior::AlwaysFail => {
Err(anyhow::anyhow!(
DynamoError::builder()
.error_type(ErrorType::CannotConnect)
.message("no responders")
.build()
))
}
}
}
}
impl MockEngine {
async fn send_responses(
&self,
start: usize,
end: usize,
) -> Result<ManyOut<Annotated<BackendOutput>>> {
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
tokio::spawn(async move {
for i in start..end {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
}
}
#[tokio::test]
async fn test_retry_manager_no_migration() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::Success,
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let mut retry_manager = RetryManager::build(
ctx,
request,
next_generate,
0,
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
assert_eq!(responses.len(), 10);
for (i, response) in responses.iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); }
}
assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
}
#[tokio::test]
async fn test_retry_manager_new_request_migration() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::FailThenSuccess,
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let mut retry_manager = RetryManager::build(
ctx,
request,
next_generate,
3,
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
assert_eq!(responses.len(), 10);
for (i, response) in responses.iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); }
}
assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 1);
assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
}
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFail { fail_after: 5 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let mut retry_manager = RetryManager::build(
ctx,
request,
next_generate,
3,
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
assert_eq!(responses.len(), 10);
for (i, response) in responses.iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); }
}
assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1);
}
#[tokio::test]
async fn test_retry_manager_new_request_migration_indefinite_failure() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(0);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::AlwaysFail,
0,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let retry_manager_result = RetryManager::build(
ctx,
request,
next_generate,
3,
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
)
.await;
assert!(retry_manager_result.is_err());
if let Err(error) = retry_manager_result {
assert!(error.to_string().contains("no responders"));
}
assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 4); assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
}
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlways { fail_after: 3 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let mut retry_manager = RetryManager::build(
ctx,
request,
next_generate,
3,
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
) .await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
assert_eq!(responses.len(), 4);
for (i, response) in responses[0..3].iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); }
}
let error_response = &responses[3];
let err = error_response.err().expect("expected error response");
assert_eq!(err.error_type(), ErrorType::Disconnected);
assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 3); assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1); }
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let mut retry_manager = RetryManager::build(
ctx,
request,
next_generate,
3,
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
) .await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
assert_eq!(responses.len(), 4);
for (i, response) in responses[0..3].iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); }
}
let error_response = &responses[3];
let err = error_response.err().expect("expected error response");
assert_eq!(err.error_type(), ErrorType::Disconnected);
assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 4); }
#[tokio::test]
async fn test_retry_manager_context_stopped_before_stream() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::Success,
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
ctx.stop_generating();
let metrics = Arc::new(Metrics::new());
let retry_manager_result = RetryManager::build(
ctx,
request,
next_generate,
3,
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
)
.await;
assert!(retry_manager_result.is_err());
if let Err(error) = retry_manager_result {
assert!(
error
.to_string()
.contains(&format!("Context id {} is stopped or killed", context_id))
);
}
assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
}
}