dynamo_llm/
migration.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use anyhow::{Error, Result};
7use futures::{stream, stream::StreamExt};
8
9use async_nats::client::{
10    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
11};
12
13use crate::{
14    model_card::ModelDeploymentCard,
15    protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
16};
17
18use dynamo_runtime::{
19    pipeline::{
20        AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
21        ServerStreamingEngine, SingleIn, async_trait, network::STREAM_ERR_MSG,
22    },
23    protocols::{annotated::Annotated, maybe_error::MaybeError},
24};
25
/// Pipeline operator that adds migration (failover/retry) support around a
/// downstream server-streaming engine: when a stream disconnects mid-request
/// or no responders are available, the request is transparently re-dispatched.
pub struct Migration {
    // Maximum number of migrations (extra attempts beyond the initial one)
    // permitted per request; taken from the model deployment card.
    migration_limit: u32,
}
29
30impl Migration {
31    pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
32        tracing::debug!(
33            "model {} migration limit {}",
34            mdc.display_name,
35            mdc.migration_limit
36        );
37        Arc::new(Self {
38            migration_limit: mdc.migration_limit,
39        })
40    }
41}
42
43#[async_trait]
44impl
45    Operator<
46        SingleIn<PreprocessedRequest>,
47        ManyOut<Annotated<LLMEngineOutput>>,
48        SingleIn<PreprocessedRequest>,
49        ManyOut<Annotated<LLMEngineOutput>>,
50    > for Migration
51{
52    async fn generate(
53        &self,
54        request: SingleIn<PreprocessedRequest>,
55        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
56    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
57        let (preprocessed_request, context) = request.transfer(());
58        let engine_ctx = context.context();
59        let engine_ctx_ = engine_ctx.clone();
60        let retry_manager =
61            RetryManager::build(engine_ctx, preprocessed_request, next, self.migration_limit)
62                .await?;
63        let response_stream = stream::unfold(retry_manager, move |mut retry_manager| async move {
64            retry_manager
65                .next()
66                .await
67                .map(|response| (response, retry_manager))
68        });
69        Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx_))
70    }
71}
72
/// Drives a request against the downstream engine, recreating the response
/// stream when it disconnects mid-generation or when dispatch finds no
/// responders, up to a bounded number of attempts.
struct RetryManager {
    // Engine context of the original request; its id is reused for retries so
    // downstream sees a single logical request.
    context: Arc<dyn AsyncEngineContext>,
    // The request to (re)send; `token_ids` and `max_tokens` are updated as
    // responses arrive so a retry resumes where the previous stream left off.
    request: PreprocessedRequest,
    // Downstream engine used to open each new response stream.
    next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    // Currently active response stream, if any.
    next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
    // Remaining attempts (includes the initial attempt; see `build`).
    retries_left: u32,
}
80
impl RetryManager {
    /// Create a `RetryManager` and eagerly open the first response stream.
    ///
    /// `retries_left` is the number of migrations allowed; one extra attempt is
    /// added internally so the initial dispatch is counted alongside retries.
    /// Fails if the first stream cannot be created within the retry budget.
    pub async fn build(
        context: Arc<dyn AsyncEngineContext>,
        preprocessed_request: PreprocessedRequest,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
        retries_left: u32,
    ) -> Result<Self> {
        let mut slf = Self {
            context,
            request: preprocessed_request,
            next_generate: next,
            next_stream: None,
            retries_left: retries_left + 1, // +1 to account for the initial attempt
        };
        slf.new_stream().await?;
        Ok(slf)
    }

    /// Yield the next response, transparently migrating to a new stream when
    /// the current one reports a disconnection (an in-band error whose chain
    /// starts with `STREAM_ERR_MSG`). Returns `None` once the stream ends.
    pub async fn next(&mut self) -> Option<Annotated<LLMEngineOutput>> {
        loop {
            let response_stream = match self.next_stream.as_mut() {
                Some(stream) => stream,
                None => {
                    // Invariant: build()/new_stream() always leave a stream in
                    // place on success, so reaching here indicates a logic bug.
                    tracing::error!("next() called with next_stream is None - should not happen");
                    return Some(Annotated::from_err(
                        Error::msg("next_stream is None").into(),
                    ));
                }
            };
            if let Some(response) = response_stream.next().await {
                if let Some(err) = response.err()
                    && err
                        .chain()
                        .any(|e| e.to_string().starts_with(STREAM_ERR_MSG))
                {
                    tracing::warn!("Stream disconnected... recreating stream...");
                    if let Err(err) = self.new_stream().await {
                        // Retries exhausted (or unrecoverable): fall through and
                        // surface the original disconnection error to the caller.
                        tracing::warn!("Cannot recreate stream: {:#}", err);
                    } else {
                        continue;
                    }
                }
                // Record delivered tokens so a later retry resumes mid-request.
                self.track_response(&response);
                return Some(response);
            }
            return None;
        }
    }

    /// Open a new response stream for the (possibly partially completed)
    /// request, consuming one retry per attempt. Only NATS "no responders"
    /// errors are retried here; any other dispatch error is propagated.
    async fn new_stream(&mut self) -> Result<()> {
        let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
        while self.retries_left > 0 {
            self.retries_left -= 1;
            // Reuse the original request id so downstream sees one logical request.
            let request = Context::with_id(self.request.clone(), self.context.id().to_string());
            self.context.link_child(request.context());
            response_stream = Some(self.next_generate.generate(request).await);
            if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
                && let Some(req_err) = err.downcast_ref::<NatsRequestError>()
                && matches!(req_err.kind(), NatsNoResponders)
            {
                tracing::warn!("Creating new stream... retrying...");
                continue;
            }
            break;
        }
        match response_stream {
            Some(Ok(next_stream)) => {
                self.next_stream = Some(next_stream);
                Ok(())
            }
            Some(Err(err)) => Err(err), // should propagate original error if any
            None => Err(Error::msg(
                "Migration limit exhausted", // should propagate original error if any
            )),
        }
    }

    /// Account for tokens already delivered: shrink `max_tokens` and append the
    /// emitted token ids to the request, so a migrated stream continues from
    /// the current position instead of regenerating from scratch.
    fn track_response(&mut self, response: &Annotated<LLMEngineOutput>) {
        if self.retries_left == 0 {
            // No retries remain, so there is nothing to resume; skip bookkeeping.
            return;
        }
        let llm_engine_output = match response.data.as_ref() {
            Some(output) => output,
            None => return,
        };
        if let Some(max_tokens) = self.request.stop_conditions.max_tokens {
            self.request.stop_conditions.max_tokens =
                Some(max_tokens.saturating_sub(llm_engine_output.token_ids.len() as u32));
        }
        for token_id in llm_engine_output.token_ids.iter() {
            self.request.token_ids.push(*token_id);
        }
    }
}
175
176#[cfg(test)]
177mod tests {
178    use super::*;
179    use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
180    use dynamo_runtime::pipeline::AsyncEngine;
181    use dynamo_runtime::pipeline::context::Controller;
182    use std::sync::atomic::{AtomicU32, Ordering};
183    use tokio::sync::mpsc;
184
    // Helper to create a mock preprocessed request with initial tokens [1, 2, 3]
    // and the given max_tokens stop condition (MockEngine asserts against both).
    fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
        PreprocessedRequest::builder()
            .model("mock".to_string())
            .token_ids(vec![1, 2, 3])
            .stop_conditions(StopConditions {
                max_tokens: Some(max_tokens),
                ..Default::default()
            })
            .sampling_options(SamplingOptions::default())
            .output_options(OutputOptions::default())
            .eos_token_ids(vec![])
            .annotations(vec![])
            .build()
            .unwrap()
    }
201
    // Helper to create a mock LLM engine output carrying exactly one token id
    // (and a matching "token_<id>" text); all optional fields are left unset.
    fn create_mock_output(token_id: u32) -> Annotated<LLMEngineOutput> {
        Annotated::from_data(LLMEngineOutput {
            token_ids: vec![token_id],
            tokens: None,
            text: Some(format!("token_{}", token_id)),
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason: None,
            index: None,
        })
    }
215
    /// Scripted behaviors for `MockEngine`, covering the failure modes the
    /// `RetryManager` must handle (dispatch failures and mid-stream drops).
    #[derive(Debug, Clone)]
    enum MockBehavior {
        /// Always succeeds with all responses
        Success,
        /// Fails on first call with NoResponders error, then succeeds on subsequent calls
        FailThenSuccess,
        /// Succeeds initially, fails mid-stream with specific error, then succeeds on retry
        MidStreamFail { fail_after: usize },
        /// Succeeds initially, fails mid-stream with specific error, then always fails on retry attempts
        MidStreamFailAlways { fail_after: usize },
        /// Succeeds initially, fails mid-stream, then always fails with stream error on retry attempts
        MidStreamFailAlwaysStreamError { fail_after: usize },
        /// Always fails with NoResponders error (same as FailThenSuccess first call)
        AlwaysFail,
    }
231
    // Unified mock server streaming engine that can simulate different scenarios
    struct MockEngine {
        // Scripted scenario this engine plays out (see `MockBehavior`).
        behavior: MockBehavior,
        // Total number of responses to produce across all attempts.
        num_responses: usize,
        // Added to generated token ids: response i carries token_offset + 1 + i.
        token_offset: u32,
        // Number of generate() calls observed; drives call-dependent behaviors.
        call_count: Arc<AtomicU32>,
        // Expected context id of every incoming request (asserted in generate()).
        context_id: String,
    }
240
241    impl MockEngine {
242        fn new(
243            behavior: MockBehavior,
244            num_responses: usize,
245            token_offset: u32,
246            context_id: String,
247        ) -> Self {
248            Self {
249                behavior,
250                num_responses,
251                token_offset,
252                call_count: Arc::new(AtomicU32::new(0)),
253                context_id,
254            }
255        }
256    }
257
    #[async_trait]
    impl
        AsyncEngine<
            SingleIn<PreprocessedRequest>,
            ManyOut<Annotated<LLMEngineOutput>>,
            anyhow::Error,
        > for MockEngine
    {
        /// Mock `generate`: validates that the (possibly resumed) request
        /// carries the expected context id and a correctly-decremented
        /// `max_tokens`, then plays out the scripted `MockBehavior`.
        async fn generate(
            &self,
            request: SingleIn<PreprocessedRequest>,
        ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
            // Which attempt this is (0 = initial dispatch, >0 = retries).
            let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
            let (preprocessed_request, context) = request.transfer(());

            // Assert that the context_id matches the expected one
            assert_eq!(
                context.id().to_string(),
                self.context_id,
                "Context ID mismatch"
            );

            // Calculate how many responses we've already generated based on request token_ids
            // Initial request has [1, 2, 3], so anything beyond that are generated responses
            let initial_tokens = 3; // [1, 2, 3]
            let responses_already_generated = preprocessed_request
                .token_ids
                .len()
                .saturating_sub(initial_tokens);

            // Assert that max_tokens reflects the expected remaining tokens
            let expected_max_tokens =
                self.num_responses
                    .saturating_sub(responses_already_generated) as u32;
            assert_eq!(
                preprocessed_request.stop_conditions.max_tokens,
                Some(expected_max_tokens),
                "max_tokens should be {} but got {:?}",
                expected_max_tokens,
                preprocessed_request.stop_conditions.max_tokens
            );

            match &self.behavior {
                MockBehavior::Success => {
                    // Always succeed with remaining responses
                    self.send_responses(responses_already_generated, self.num_responses)
                        .await
                }
                MockBehavior::FailThenSuccess => {
                    if call_num == 0 {
                        // First call - return "No responders available" error to trigger retry
                        let nats_error: NatsRequestError = NatsNoResponders.into();
                        return Err(nats_error.into());
                    } else {
                        // Subsequent calls - succeed with remaining responses
                        self.send_responses(responses_already_generated, self.num_responses)
                            .await
                    }
                }
                MockBehavior::MidStreamFail { fail_after } => {
                    let (tx, rx) = mpsc::channel(1);
                    let token_offset = self.token_offset;
                    let fail_after = *fail_after;
                    let num_responses = self.num_responses;

                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
                            let error_response =
                                Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
                            let _ = tx.send(error_response).await;
                        });
                    } else {
                        // Second call - send remaining responses from where we left off
                        tokio::spawn(async move {
                            for i in responses_already_generated..num_responses {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                        });
                    }

                    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                    let ctx = Arc::new(Controller::new(self.context_id.clone()));
                    Ok(dynamo_runtime::pipeline::ResponseStream::new(
                        Box::pin(stream),
                        ctx,
                    ))
                }
                MockBehavior::MidStreamFailAlways { fail_after } => {
                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        let (tx, rx) = mpsc::channel(1);
                        let token_offset = self.token_offset;
                        let fail_after = *fail_after;
                        let num_responses = self.num_responses;

                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
                            let error_response =
                                Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    } else {
                        // Subsequent calls - always fail with NoResponders error (same as AlwaysFail)
                        let nats_error: NatsRequestError = NatsNoResponders.into();
                        Err(nats_error.into())
                    }
                }
                MockBehavior::MidStreamFailAlwaysStreamError { fail_after } => {
                    let (tx, rx) = mpsc::channel(1);
                    let token_offset = self.token_offset;
                    let fail_after = *fail_after;
                    let num_responses = self.num_responses;

                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
                            let error_response =
                                Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    } else {
                        // Subsequent calls - immediately send stream error (no successful responses)
                        tokio::spawn(async move {
                            // Send the stream error immediately
                            let error_response =
                                Annotated::from_err(anyhow::Error::msg(STREAM_ERR_MSG).into());
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    }
                }
                MockBehavior::AlwaysFail => {
                    // Always fail with NoResponders error (same as FailThenSuccess first call)
                    let nats_error: NatsRequestError = NatsNoResponders.into();
                    Err(nats_error.into())
                }
            }
        }
    }
444
445    impl MockEngine {
446        async fn send_responses(
447            &self,
448            start: usize,
449            end: usize,
450        ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
451            let (tx, rx) = mpsc::channel(1);
452            let token_offset = self.token_offset;
453
454            tokio::spawn(async move {
455                for i in start..end {
456                    let response = create_mock_output(token_offset + 1 + i as u32);
457                    if tx.send(response).await.is_err() {
458                        break;
459                    }
460                }
461            });
462
463            let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
464            let ctx = Arc::new(Controller::new(self.context_id.clone()));
465            Ok(dynamo_runtime::pipeline::ResponseStream::new(
466                Box::pin(stream),
467                ctx,
468            ))
469        }
470    }
471
472    /// Test case 1: No migration needed
473    /// Tests the normal case where the RetryManager successfully processes all responses
474    /// from a single stream without any failures or need for retries/migration.
475    /// Expected behavior: All 10 responses should be received successfully.
476    #[tokio::test]
477    async fn test_retry_manager_no_migration() {
478        dynamo_runtime::logging::init();
479        let context_id = uuid::Uuid::new_v4().to_string();
480        let request = create_mock_request(10);
481        let mock_engine = Arc::new(MockEngine::new(
482            MockBehavior::Success,
483            10,
484            100,
485            context_id.clone(),
486        ));
487        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
488            mock_engine;
489
490        let ctx = Arc::new(Controller::new(context_id.clone()));
491        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 0)
492            .await
493            .expect("Failed to build RetryManager");
494
495        let mut responses = Vec::new();
496        while let Some(response) = retry_manager.next().await {
497            responses.push(response);
498        }
499
500        assert_eq!(responses.len(), 10);
501        for (i, response) in responses.iter().enumerate() {
502            assert!(response.err().is_none());
503            if let Some(output) = &response.data {
504                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
505            }
506        }
507    }
508
509    /// Test case 2: New request migration
510    /// Tests the scenario where a worker becomes unreachable for new requests initially,
511    /// triggering the RetryManager to retry the request. The MockEngine with FailThenSuccess
512    /// fails on the first call with a "No responders available" error, then succeeds
513    /// on subsequent calls, simulating a worker becoming available after initial failure.
514    /// Expected behavior: All 10 responses should be received successfully after retry.
515    #[tokio::test]
516    async fn test_retry_manager_new_request_migration() {
517        dynamo_runtime::logging::init();
518        let context_id = uuid::Uuid::new_v4().to_string();
519        let request = create_mock_request(10);
520        let mock_engine = Arc::new(MockEngine::new(
521            MockBehavior::FailThenSuccess,
522            10,
523            100,
524            context_id.clone(),
525        ));
526        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
527            mock_engine;
528
529        let ctx = Arc::new(Controller::new(context_id.clone()));
530        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
531            .await
532            .expect("Failed to build RetryManager");
533
534        let mut responses = Vec::new();
535        while let Some(response) = retry_manager.next().await {
536            responses.push(response);
537        }
538
539        assert_eq!(responses.len(), 10);
540        for (i, response) in responses.iter().enumerate() {
541            assert!(response.err().is_none());
542            if let Some(output) = &response.data {
543                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
544            }
545        }
546    }
547
548    /// Test case 3: Ongoing request migration
549    /// Tests the scenario where a worker fails mid-stream during an ongoing request.
550    /// This simulates a connection being lost after partial response delivery, requiring
551    /// the RetryManager to detect the failure (via "Stream ended before generation completed" error),
552    /// create a new stream, and continue from where it left off.
553    /// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total.
554    #[tokio::test]
555    async fn test_retry_manager_ongoing_request_migration() {
556        dynamo_runtime::logging::init();
557
558        let context_id = uuid::Uuid::new_v4().to_string();
559        let request = create_mock_request(10);
560        let mock_engine = Arc::new(MockEngine::new(
561            MockBehavior::MidStreamFail { fail_after: 5 },
562            10,
563            100,
564            context_id.clone(),
565        ));
566        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
567            mock_engine;
568
569        let ctx = Arc::new(Controller::new(context_id.clone()));
570        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3)
571            .await
572            .expect("Failed to build RetryManager");
573
574        let mut responses = Vec::new();
575        while let Some(response) = retry_manager.next().await {
576            responses.push(response);
577        }
578
579        // Should have received all 10 responses (5 from first stream + 5 from second stream)
580        assert_eq!(responses.len(), 10);
581
582        // Check that we received responses from both streams
583        for (i, response) in responses.iter().enumerate() {
584            assert!(response.err().is_none());
585            if let Some(output) = &response.data {
586                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
587            }
588        }
589    }
590
591    /// Test case 4: New request migration - indefinite failure
592    /// Tests the scenario where a worker becomes unreachable for new requests indefinitely.
593    /// The RetryManager should exhaust all retries and return the original error from the first attempt.
594    /// Expected behavior: Should receive an error after all retries are exhausted, with the original error.
595    #[tokio::test]
596    async fn test_retry_manager_new_request_migration_indefinite_failure() {
597        dynamo_runtime::logging::init();
598        let context_id = uuid::Uuid::new_v4().to_string();
599        let request = create_mock_request(0);
600        let mock_engine = Arc::new(MockEngine::new(
601            MockBehavior::AlwaysFail,
602            0,
603            100,
604            context_id.clone(),
605        ));
606        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
607            mock_engine;
608
609        // Should fail to build due to initial stream creation failure after exhausting all 3 retries
610        let ctx = Arc::new(Controller::new(context_id.clone()));
611        let retry_manager_result = RetryManager::build(ctx, request, next_generate, 3).await;
612
613        assert!(retry_manager_result.is_err());
614        if let Err(error) = retry_manager_result {
615            assert!(error.to_string().contains("no responders"));
616        }
617    }
618
619    /// Test case 5: Ongoing request migration - indefinite failure
620    /// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests.
621    /// The RetryManager should exhaust all retries and return the original stream disconnection error.
622    /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
623    #[tokio::test]
624    async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
625        dynamo_runtime::logging::init();
626        let context_id = uuid::Uuid::new_v4().to_string();
627        let request = create_mock_request(10);
628        let mock_engine = Arc::new(MockEngine::new(
629            MockBehavior::MidStreamFailAlways { fail_after: 3 },
630            10,
631            100,
632            context_id.clone(),
633        ));
634        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
635            mock_engine;
636
637        let ctx = Arc::new(Controller::new(context_id.clone()));
638        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3) // 3 retries
639            .await
640            .expect("Failed to build RetryManager");
641
642        let mut responses = Vec::new();
643
644        // Collect all responses (both successful and error responses)
645        while let Some(response) = retry_manager.next().await {
646            responses.push(response);
647        }
648
649        // Should have received 4 total responses: 3 successful + 1 error
650        assert_eq!(responses.len(), 4);
651
652        // First 3 responses should be successful with tokens 101, 102, 103
653        for (i, response) in responses[0..3].iter().enumerate() {
654            assert!(response.err().is_none());
655            if let Some(output) = &response.data {
656                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
657            }
658        }
659
660        // 4th response should be an error after retries are exhausted
661        let error_response = &responses[3];
662        assert!(error_response.err().is_some());
663        if let Some(error) = error_response.err() {
664            assert!(error.to_string().contains(STREAM_ERR_MSG));
665        }
666    }
667
668    /// Test case 6: Ongoing request migration - indefinite failure with stream errors
669    /// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests,
670    /// and all retry attempts also fail with stream errors instead of NATS errors.
671    /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
672    #[tokio::test]
673    async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
674        dynamo_runtime::logging::init();
675        let context_id = uuid::Uuid::new_v4().to_string();
676        let request = create_mock_request(10);
677        let mock_engine = Arc::new(MockEngine::new(
678            MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
679            10,
680            100,
681            context_id.clone(),
682        ));
683        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
684            mock_engine;
685
686        let ctx = Arc::new(Controller::new(context_id.clone()));
687        let mut retry_manager = RetryManager::build(ctx, request, next_generate, 3) // 3 retries
688            .await
689            .expect("Failed to build RetryManager");
690
691        let mut responses = Vec::new();
692
693        // Collect all responses (both successful and error responses)
694        while let Some(response) = retry_manager.next().await {
695            responses.push(response);
696        }
697
698        // Should have received 4 total responses: 3 successful + 1 error
699        assert_eq!(responses.len(), 4);
700
701        // First 3 responses should be successful with tokens 101, 102, 103
702        for (i, response) in responses[0..3].iter().enumerate() {
703            assert!(response.err().is_none());
704            if let Some(output) = &response.data {
705                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
706            }
707        }
708
709        // 4th response should be an error after retries are exhausted
710        let error_response = &responses[3];
711        assert!(error_response.err().is_some());
712        if let Some(error) = error_response.err() {
713            assert!(error.to_string().contains(STREAM_ERR_MSG));
714        }
715    }
716}