use std::sync::Arc;

use anyhow::{Error, Result};
use futures::{stream, stream::StreamExt};

use async_nats::client::{
    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
};

use crate::{
    model_card::ModelDeploymentCard,
    protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
};

use dynamo_runtime::{
    pipeline::{
        AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
        ServerStreamingEngine, SingleIn, async_trait, network::STREAM_ERR_MSG,
    },
    protocols::{annotated::Annotated, maybe_error::MaybeError},
};

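/// Operator that shields clients from losing a backend stream mid-request: when the
/// current response stream fails it re-issues the request downstream, up to
/// `migration_limit` additional attempts, resuming from the tokens already generated.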
pub struct Migration {
    migration_limit: u32,
}

impl Migration {
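    /// Builds the operator from a `ModelDeploymentCard`, using its `migration_limit` as
    /// the maximum number of times a request may be re-issued.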
    pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
        tracing::debug!(
            "model {} migration limit {}",
            mdc.display_name,
            mdc.migration_limit
        );
        Arc::new(Self {
            migration_limit: mdc.migration_limit,
        })
    }
}

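// Pipeline operator: `generate` hands the request to a `RetryManager` and adapts the
// manager into a response stream with `stream::unfold`, so migration retries stay
// invisible to whatever consumes the stream downstream.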
#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
    > for Migration
{
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
        let (preprocessed_request, context) = request.transfer(());
        let engine_ctx = context.context();
        let engine_ctx_ = engine_ctx.clone();
        let retry_manager =
            RetryManager::build(engine_ctx, preprocessed_request, next, self.migration_limit)
                .await?;
        let response_stream = stream::unfold(retry_manager, move |mut retry_manager| async move {
            retry_manager
                .next()
                .await
                .map(|response| (response, retry_manager))
        });
        Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx_))
    }
}

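/// Drives a single logical request across potentially several backend streams. It holds
/// the remaining retry budget, the (mutable) request, and the current response stream;
/// when the stream dies mid-flight it asks `next_generate` for a replacement and resumes
/// from the tokens already produced.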
struct RetryManager {
    context: Arc<dyn AsyncEngineContext>,
    request: PreprocessedRequest,
    next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
    retries_left: u32,
}

impl RetryManager {
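    /// Creates the manager and eagerly opens the first stream. `retries_left` is the
    /// number of additional attempts allowed after that initial one.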
    pub async fn build(
        context: Arc<dyn AsyncEngineContext>,
        preprocessed_request: PreprocessedRequest,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
        retries_left: u32,
    ) -> Result<Self> {
        let mut slf = Self {
            context,
            request: preprocessed_request,
            next_generate: next,
            next_stream: None,
            // +1 so the budget covers the initial attempt plus `retries_left` migrations.
            retries_left: retries_left + 1,
        };
        slf.new_stream().await?;
        Ok(slf)
    }

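    /// Yields the next response, transparently recreating the stream when it reports a
    /// disconnect (`STREAM_ERR_MSG`). Returns `None` once the current stream is exhausted.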
    pub async fn next(&mut self) -> Option<Annotated<LLMEngineOutput>> {
        loop {
            let response_stream = match self.next_stream.as_mut() {
                Some(stream) => stream,
                None => {
                    tracing::error!("next() called while next_stream is None - should not happen");
                    return Some(Annotated::from_err(
                        Error::msg("next_stream is None").into(),
                    ));
                }
            };
            if let Some(response) = response_stream.next().await {
                // A stream-level error means the connection to the worker was lost;
                // try to migrate to a new stream instead of surfacing the error.
                if let Some(err) = response.err()
                    && err
                        .chain()
                        .any(|e| e.to_string().starts_with(STREAM_ERR_MSG))
                {
                    tracing::warn!("Stream disconnected... recreating stream...");
                    if let Err(err) = self.new_stream().await {
                        tracing::warn!("Cannot recreate stream: {:#}", err);
                    } else {
                        continue;
                    }
                }
                self.track_response(&response);
                return Some(response);
            }
            return None;
        }
    }

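    /// Requests a fresh stream from `next_generate`, spending one unit of the retry
    /// budget per attempt. A NATS "no responders" error (no worker available) triggers
    /// another attempt; any other error is returned as-is.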
    async fn new_stream(&mut self) -> Result<()> {
        let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
        while self.retries_left > 0 {
            self.retries_left -= 1;
            let request = Context::with_id(self.request.clone(), self.context.id().to_string());
            self.context.link_child(request.context());
            response_stream = Some(self.next_generate.generate(request).await);
            if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
                && let Some(req_err) = err.downcast_ref::<NatsRequestError>()
                && matches!(req_err.kind(), NatsNoResponders)
            {
                tracing::warn!("Creating new stream... retrying...");
                continue;
            }
            break;
        }
        match response_stream {
            Some(Ok(next_stream)) => {
                self.next_stream = Some(next_stream);
                Ok(())
            }
            Some(Err(err)) => Err(err),
            None => Err(Error::msg("Migration limit exhausted")),
        }
    }

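    /// Records progress from a response: appends the emitted tokens to the request and
    /// shrinks `max_tokens`, so a migrated request continues where the previous stream
    /// stopped. Skipped once the retry budget is exhausted.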
    fn track_response(&mut self, response: &Annotated<LLMEngineOutput>) {
        if self.retries_left == 0 {
            return;
        }
        let llm_engine_output = match response.data.as_ref() {
            Some(output) => output,
            None => return,
        };
        if let Some(max_tokens) = self.request.stop_conditions.max_tokens {
            self.request.stop_conditions.max_tokens =
                Some(max_tokens.saturating_sub(llm_engine_output.token_ids.len() as u32));
        }
        for token_id in llm_engine_output.token_ids.iter() {
            self.request.token_ids.push(*token_id);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
    use dynamo_runtime::pipeline::AsyncEngine;
    use dynamo_runtime::pipeline::context::Controller;
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::mpsc;

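    // Builds a request with three seed token ids ([1, 2, 3]) and the given `max_tokens`.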
    fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
        PreprocessedRequest::builder()
            .model("mock".to_string())
            .token_ids(vec![1, 2, 3])
            .stop_conditions(StopConditions {
                max_tokens: Some(max_tokens),
                ..Default::default()
            })
            .sampling_options(SamplingOptions::default())
            .output_options(OutputOptions::default())
            .eos_token_ids(vec![])
            .annotations(vec![])
            .build()
            .unwrap()
    }

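    // Wraps a single token id in a successful `Annotated<LLMEngineOutput>` response.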
    fn create_mock_output(token_id: u32) -> Annotated<LLMEngineOutput> {
        Annotated::from_data(LLMEngineOutput {
            token_ids: vec![token_id],
            tokens: None,
            text: Some(format!("token_{}", token_id)),
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason: None,
            index: None,
        })
    }

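    // Failure modes the mock engine can simulate, covering failures while opening a new
    // stream as well as failures in the middle of an already-open stream.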
    #[derive(Debug, Clone)]
    enum MockBehavior {
        // Streams all responses without any failure.
        Success,
        // First call fails with NATS "no responders"; subsequent calls succeed.
        FailThenSuccess,
        // First stream drops after `fail_after` responses; the retried stream succeeds.
        MidStreamFail { fail_after: usize },
        // First stream drops after `fail_after` responses; every retry fails with
        // NATS "no responders".
        MidStreamFailAlways { fail_after: usize },
        // First stream drops after `fail_after` responses; every retried stream fails
        // immediately with another stream error.
        MidStreamFailAlwaysStreamError { fail_after: usize },
        // Every call fails with NATS "no responders".
        AlwaysFail,
    }

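    // Mock backend engine: counts how many times it is called and serves responses
    // according to the configured `MockBehavior`.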
    struct MockEngine {
        behavior: MockBehavior,
        num_responses: usize,
        token_offset: u32,
        call_count: Arc<AtomicU32>,
        context_id: String,
    }

    impl MockEngine {
        fn new(
            behavior: MockBehavior,
            num_responses: usize,
            token_offset: u32,
            context_id: String,
        ) -> Self {
            Self {
                behavior,
                num_responses,
                token_offset,
                call_count: Arc::new(AtomicU32::new(0)),
                context_id,
            }
        }
    }

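    // The mock's `generate` asserts that a (re)issued request carries the expected
    // context id, that its token_ids have grown by the tokens already streamed, and that
    // max_tokens has shrunk accordingly, before answering per `MockBehavior`.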
    #[async_trait]
    impl
        AsyncEngine<
            SingleIn<PreprocessedRequest>,
            ManyOut<Annotated<LLMEngineOutput>>,
            anyhow::Error,
        > for MockEngine
    {
        async fn generate(
            &self,
            request: SingleIn<PreprocessedRequest>,
        ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
            let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
            let (preprocessed_request, context) = request.transfer(());

            assert_eq!(
                context.id().to_string(),
                self.context_id,
                "Context ID mismatch"
            );

            // create_mock_request seeds 3 token ids; anything beyond that was appended
            // by RetryManager while tracking earlier responses.
            let initial_tokens = 3;
            let responses_already_generated = preprocessed_request
                .token_ids
                .len()
                .saturating_sub(initial_tokens);

            // max_tokens must have been decremented by the number of tokens already streamed.
            let expected_max_tokens =
                self.num_responses
                    .saturating_sub(responses_already_generated) as u32;
            assert_eq!(
                preprocessed_request.stop_conditions.max_tokens,
                Some(expected_max_tokens),
                "max_tokens should be {} but got {:?}",
                expected_max_tokens,
                preprocessed_request.stop_conditions.max_tokens
            );

            match &self.behavior {
                MockBehavior::Success => {
                    self.send_responses(responses_already_generated, self.num_responses)
                        .await
                }
                MockBehavior::FailThenSuccess => {
                    if call_num == 0 {
                        let nats_error: NatsRequestError = NatsNoResponders.into();
                        return Err(nats_error.into());
                    } else {
                        self.send_responses(responses_already_generated, self.num_responses)
                            .await
                    }
                }
                MockBehavior::MidStreamFail { fail_after } => {
                    let (tx, rx) = mpsc::channel(1);
                    let token_offset = self.token_offset;
                    let fail_after = *fail_after;
                    let num_responses = self.num_responses;

                    if call_num == 0 {
                        // First stream: emit tokens up to `fail_after`, then a stream error.
                        tokio::spawn(async move {
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            let error_response =
                                Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
                            let _ = tx.send(error_response).await;
                        });
                    } else {
                        // Retried stream: emit the remaining tokens without failing.
                        tokio::spawn(async move {
                            for i in responses_already_generated..num_responses {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                        });
                    }

                    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                    let ctx = Arc::new(Controller::new(self.context_id.clone()));
                    Ok(dynamo_runtime::pipeline::ResponseStream::new(
                        Box::pin(stream),
                        ctx,
                    ))
                }
                MockBehavior::MidStreamFailAlways { fail_after } => {
                    if call_num == 0 {
                        let (tx, rx) = mpsc::channel(1);
                        let token_offset = self.token_offset;
                        let fail_after = *fail_after;
                        let num_responses = self.num_responses;

                        tokio::spawn(async move {
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            let error_response =
                                Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    } else {
                        let nats_error: NatsRequestError = NatsNoResponders.into();
                        Err(nats_error.into())
                    }
                }
                MockBehavior::MidStreamFailAlwaysStreamError { fail_after } => {
                    let (tx, rx) = mpsc::channel(1);
                    let token_offset = self.token_offset;
                    let fail_after = *fail_after;
                    let num_responses = self.num_responses;

                    if call_num == 0 {
                        tokio::spawn(async move {
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            let error_response =
                                Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    } else {
                        tokio::spawn(async move {
                            let error_response =
                                Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    }
                }
                MockBehavior::AlwaysFail => {
                    let nats_error: NatsRequestError = NatsNoResponders.into();
                    Err(nats_error.into())
                }
            }
        }
    }

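    // Streams `create_mock_output` responses for token indices `start..end` over an
    // mpsc-backed `ResponseStream`.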
    impl MockEngine {
        async fn send_responses(
            &self,
            start: usize,
            end: usize,
        ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
            let (tx, rx) = mpsc::channel(1);
            let token_offset = self.token_offset;

            tokio::spawn(async move {
                for i in start..end {
                    let response = create_mock_output(token_offset + 1 + i as u32);
                    if tx.send(response).await.is_err() {
                        break;
                    }
                }
            });

            let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
            let ctx = Arc::new(Controller::new(self.context_id.clone()));
            Ok(dynamo_runtime::pipeline::ResponseStream::new(
                Box::pin(stream),
                ctx,
            ))
        }
    }

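    // Happy path: zero retry budget, a single healthy stream delivers all 10 responses in order.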
    #[tokio::test]
    async fn test_retry_manager_no_migration() {
        dynamo_runtime::logging::init();
        let context_id = uuid::Uuid::new_v4().to_string();
        let request = create_mock_request(10);
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::Success,
            10,
            100,
            context_id.clone(),
        ));
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
            mock_engine;

        let ctx = Arc::new(Controller::new(context_id.clone()));
        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 0)
            .await
            .expect("Failed to build RetryManager");

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 10);
        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]);
            }
        }
    }

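    // New-request migration: the first attempt fails with "no responders" before any token
    // is streamed; the retried request still delivers all 10 responses.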
    #[tokio::test]
    async fn test_retry_manager_new_request_migration() {
        dynamo_runtime::logging::init();
        let context_id = uuid::Uuid::new_v4().to_string();
        let request = create_mock_request(10);
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::FailThenSuccess,
            10,
            100,
            context_id.clone(),
        ));
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
            mock_engine;

        let ctx = Arc::new(Controller::new(context_id.clone()));
        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
            .await
            .expect("Failed to build RetryManager");

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 10);
        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]);
            }
        }
    }

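    // Ongoing-request migration: the stream drops after 5 tokens; the replacement stream
    // resumes at token 6, so the caller sees a seamless sequence of 10 responses.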
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration() {
        dynamo_runtime::logging::init();

        let context_id = uuid::Uuid::new_v4().to_string();
        let request = create_mock_request(10);
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFail { fail_after: 5 },
            10,
            100,
            context_id.clone(),
        ));
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
            mock_engine;

        let ctx = Arc::new(Controller::new(context_id.clone()));
        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
            .await
            .expect("Failed to build RetryManager");

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 10);

        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]);
            }
        }
    }

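    // Every attempt is rejected with "no responders": building the RetryManager fails once
    // the retry budget is spent, and the NATS error surfaces to the caller.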
    #[tokio::test]
    async fn test_retry_manager_new_request_migration_indefinite_failure() {
        dynamo_runtime::logging::init();
        let context_id = uuid::Uuid::new_v4().to_string();
        let request = create_mock_request(0);
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::AlwaysFail,
            0,
            100,
            context_id.clone(),
        ));
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
            mock_engine;

        let ctx = Arc::new(Controller::new(context_id.clone()));
        let retry_manager_result = RetryManager::build(ctx, request, next_generate, 3).await;

        assert!(retry_manager_result.is_err());
        if let Err(error) = retry_manager_result {
            assert!(error.to_string().contains("no responders"));
        }
    }

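    // The stream drops after 3 tokens and every retry is rejected with "no responders":
    // the 3 good responses are delivered, then the stream error is passed through once the
    // retry budget runs out.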
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
        dynamo_runtime::logging::init();
        let context_id = uuid::Uuid::new_v4().to_string();
        let request = create_mock_request(10);
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFailAlways { fail_after: 3 },
            10,
            100,
            context_id.clone(),
        ));
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
            mock_engine;

        let ctx = Arc::new(Controller::new(context_id.clone()));
        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
            .await
            .expect("Failed to build RetryManager");

        let mut responses = Vec::new();

        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 4);

        for (i, response) in responses[0..3].iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]);
            }
        }

        let error_response = &responses[3];
        assert!(error_response.err().is_some());
        if let Some(error) = error_response.err() {
            assert!(error.to_string().contains(STREAM_ERR_MSG));
        }
    }

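    // Like the previous test, but each retried stream also fails immediately with a stream
    // error: after the budget is exhausted the error is finally forwarded to the caller.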
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
        dynamo_runtime::logging::init();
        let context_id = uuid::Uuid::new_v4().to_string();
        let request = create_mock_request(10);
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
            10,
            100,
            context_id.clone(),
        ));
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
            mock_engine;

        let ctx = Arc::new(Controller::new(context_id.clone()));
        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
            .await
            .expect("Failed to build RetryManager");

        let mut responses = Vec::new();

        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 4);

        for (i, response) in responses[0..3].iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]);
            }
        }

        let error_response = &responses[3];
        assert!(error_response.err().is_some());
        if let Some(error) = error_response.err() {
            assert!(error.to_string().contains(STREAM_ERR_MSG));
        }
    }
}