use std::sync::Arc;
use std::time::Duration;
use camel_api::Body;
use camel_component_api::NetworkRetryPolicy;
use tower::Service;
use crate::LlmEndpointConfig;
use crate::config::LlmOperation;
use crate::error::LlmError;
use crate::producer::LlmProducer;
use crate::provider::LlmProvider;
use crate::provider::mock::{MockMode, MockProvider};
use super::producer_test_helpers::{
make_exchange, make_producer_with_concurrency_and_retry, make_producer_with_retry,
make_producer_with_timeout_and_retry,
};
/// A provider configured with `with_fail_after(1, Network)` is recovered by the
/// retry policy: the call returns `Ok` with a text body and the provider is hit
/// exactly twice.
#[tokio::test]
async fn retry_succeeds_after_transient_failure() {
    // Near-zero, non-jittered backoff keeps the test fast and deterministic.
    let retry = NetworkRetryPolicy {
        enabled: true,
        max_attempts: 3,
        initial_delay: Duration::from_millis(1),
        multiplier: 1.0,
        max_delay: Duration::from_millis(5),
        jitter_factor: 0.0,
    };
    let flaky = Arc::new(
        MockProvider::new("t", MockMode::Fixed("ok".into()))
            .with_fail_after(1, LlmError::Network("boom".into())),
    );
    let dyn_provider: Arc<dyn LlmProvider> = flaky.clone();
    let mut producer = make_producer_with_retry(dyn_provider, false, Some(retry));

    let request = make_exchange(Body::Text("x".into()));
    let response = producer.call(request).await.unwrap();

    assert!(matches!(response.input.body, Body::Text(_)));
    assert_eq!(flaky.call_count(), 2);
}
/// When the provider reports a rate limit carrying a `retry_after` hint (60 ms)
/// larger than the policy's computed backoff (capped at `max_delay` = 5 ms),
/// the producer must wait for the hint before retrying.
#[tokio::test]
async fn retry_honors_retry_after_over_backoff() {
    let mock = Arc::new(
        MockProvider::new("t", MockMode::Fixed("ok".into()))
            .with_rate_limit(Some(Duration::from_millis(60))),
    );
    let provider = mock.clone() as Arc<dyn LlmProvider>;
    let policy = NetworkRetryPolicy {
        enabled: true,
        max_attempts: 2,
        initial_delay: Duration::from_millis(1),
        multiplier: 1.0,
        max_delay: Duration::from_millis(5),
        jitter_factor: 0.0,
    };
    let mut producer = make_producer_with_retry(Arc::clone(&provider), false, Some(policy));

    let start = std::time::Instant::now();
    let result = producer.call(make_exchange(Body::Text("x".into()))).await;
    let elapsed = start.elapsed();

    // Every attempt is rate limited, so the retry budget is exhausted.
    assert!(result.is_err(), "rate-limited call must surface an error");
    // Two provider calls prove the wait happened *between* attempts (i.e. a
    // retry was scheduled), not as a one-off stall before giving up.
    assert_eq!(
        mock.call_count(),
        2,
        "provider must be called exactly max_attempts times"
    );
    // Small slack below the 60 ms hint to tolerate timer granularity.
    assert!(
        elapsed >= Duration::from_millis(55),
        "retry must wait for retry_after rather than the 5 ms backoff, elapsed: {elapsed:?}"
    );
}
/// With `stream = true` the retry policy must not re-issue the request: even
/// though the policy allows 3 attempts, the provider is called only once.
#[tokio::test]
async fn no_retry_in_streaming_mode() {
    let flaky = Arc::new(
        MockProvider::new("t", MockMode::Fixed("ok".into()))
            .with_fail_after(1, LlmError::Network("boom".into())),
    );
    let retry = NetworkRetryPolicy {
        enabled: true,
        max_attempts: 3,
        initial_delay: Duration::from_millis(1),
        multiplier: 1.0,
        max_delay: Duration::from_millis(5),
        jitter_factor: 0.0,
    };
    let streaming = true;
    let mut producer = make_producer_with_retry(
        flaky.clone() as Arc<dyn LlmProvider>,
        streaming,
        Some(retry),
    );

    // Outcome is irrelevant here; only the number of provider calls matters.
    let _ = producer.call(make_exchange(Body::Text("x".into()))).await;

    assert_eq!(flaky.call_count(), 1, "streaming must not retry");
}
/// The provider is built with `MockMode::Error(LlmError::Network(..))` and the
/// test expects exactly one call despite `max_attempts: 3`.
///
/// NOTE(review): the test name implies this mock mode surfaces the failure
/// after content has started (which the producer must not retry); that is not
/// visible from this file — confirm against `MockProvider`.
#[tokio::test]
async fn no_retry_after_content_start() {
    let failing = Arc::new(MockProvider::new(
        "t",
        MockMode::Error(LlmError::Network("boom".into())),
    ));
    let retry = NetworkRetryPolicy {
        enabled: true,
        max_attempts: 3,
        initial_delay: Duration::from_millis(1),
        multiplier: 1.0,
        max_delay: Duration::from_millis(5),
        jitter_factor: 0.0,
    };
    let mut producer = make_producer_with_retry(
        failing.clone() as Arc<dyn LlmProvider>,
        false,
        Some(retry),
    );

    let _ = producer.call(make_exchange(Body::Text("x".into()))).await;

    assert_eq!(
        failing.call_count(),
        1,
        "must not retry after content-started"
    );
}
/// A 50 ms total timeout must cut short a retry that is waiting on a 200 ms
/// `retry_after` hint: the call errors well before the hint elapses and the
/// generous `max_attempts: 10` budget is never consumed.
#[tokio::test]
async fn total_timeout_fires_during_retry_backoff() {
    let limited = Arc::new(
        MockProvider::new("t", MockMode::Fixed("ok".into()))
            .with_rate_limit(Some(Duration::from_millis(200))),
    );
    let retry = NetworkRetryPolicy {
        enabled: true,
        max_attempts: 10,
        initial_delay: Duration::from_millis(1),
        multiplier: 1.0,
        max_delay: Duration::from_millis(5),
        jitter_factor: 0.0,
    };
    let total_timeout = Duration::from_millis(50);
    let mut producer = make_producer_with_timeout_and_retry(
        limited.clone() as Arc<dyn LlmProvider>,
        false,
        total_timeout,
        Some(retry),
    );

    let started = std::time::Instant::now();
    let outcome = producer.call(make_exchange(Body::Text("x".into()))).await;
    let elapsed = started.elapsed();

    assert!(outcome.is_err(), "must error with timeout");
    assert!(
        elapsed < Duration::from_millis(150),
        "total deadline must cut backoff short, elapsed: {elapsed:?}"
    );
    let calls = limited.call_count();
    assert!(
        calls <= 2,
        "total deadline must prevent excessive retries, got: {}",
        calls
    );
}
/// With a concurrency limit of 1, a request sleeping in retry backoff must not
/// hold the concurrency permit, so a second request can run during that sleep.
///
/// Timing budget (assuming the first provider call is the one that fails):
/// each call takes 40 ms and the backoff is 80 ms. If the permit is released
/// during backoff, the second request overlaps the sleep (~160 ms total). If it
/// is held, execution serialises (~200 ms), which exceeds the 180 ms bound.
#[tokio::test]
async fn permit_released_during_retry_backoff() {
    let mock = Arc::new(
        MockProvider::new("t", MockMode::Fixed("ok".into()))
            .with_delay(Duration::from_millis(40))
            .with_fail_after(1, LlmError::Network("boom".into()))
            .with_concurrent_tracker(),
    );
    let provider = mock.clone() as Arc<dyn LlmProvider>;
    let policy = NetworkRetryPolicy {
        max_attempts: 3,
        initial_delay: Duration::from_millis(80),
        multiplier: 1.0,
        max_delay: Duration::from_millis(200),
        jitter_factor: 0.0,
        enabled: true,
    };
    let producer = make_producer_with_concurrency_and_retry(provider, 1, Some(policy));

    let start = std::time::Instant::now();
    // Each task gets its own clone of the producer, exactly as the original
    // `(*arc).clone()` did; no extra `Arc` wrapper is needed for that.
    let handles: Vec<_> = (0..2)
        .map(|_| {
            let mut prod = producer.clone();
            tokio::spawn(async move { prod.call(make_exchange(Body::Text("x".into()))).await })
        })
        .collect();
    for handle in handles {
        // Propagate panics from the spawned task instead of silently dropping
        // the `JoinError`; the producer's own `Result` is not under test here.
        let _ = handle.await.expect("spawned producer task panicked");
    }
    let elapsed = start.elapsed();

    assert!(
        elapsed < Duration::from_millis(180),
        "permit must be released during backoff — total elapsed {elapsed:?} suggests sequential execution",
    );
}
/// When every attempt is rate limited, the producer gives up after
/// `max_attempts` calls and returns an error whose message mentions the rate
/// limit.
#[tokio::test]
async fn retry_exhaustion_surfaces_last_error() {
    let always_limited = Arc::new(
        MockProvider::new("t", MockMode::Fixed("ok".into())).with_rate_limit(None),
    );
    let retry = NetworkRetryPolicy {
        enabled: true,
        max_attempts: 3,
        initial_delay: Duration::from_millis(1),
        multiplier: 1.0,
        max_delay: Duration::from_millis(5),
        jitter_factor: 0.0,
    };
    let provider: Arc<dyn LlmProvider> = always_limited.clone();
    let mut producer = make_producer_with_retry(provider, false, Some(retry));

    let outcome = producer.call(make_exchange(Body::Text("x".into()))).await;
    let err = match outcome {
        Err(e) => e,
        Ok(_) => panic!("retry exhaustion must produce an error"),
    };

    let msg = err.to_string();
    assert!(
        msg.to_lowercase().contains("rate limited"),
        "surfaced error must mention rate limit, got: {msg}"
    );
    assert_eq!(
        always_limited.call_count(),
        3,
        "provider must be called exactly max_attempts times"
    );
}
/// The retry policy also applies to the `Embed` operation. A producer built
/// directly via `LlmProducer::new(..).with_retry(..).build()` retries a
/// provider configured with `with_fail_after(1, Network)`, giving two calls.
#[tokio::test]
async fn embed_retries_on_transient_failure() {
    let embed_config = LlmEndpointConfig {
        operation: LlmOperation::Embed,
        stream: false,
        ..Default::default()
    };
    let retry = NetworkRetryPolicy {
        enabled: true,
        max_attempts: 3,
        initial_delay: Duration::from_millis(1),
        multiplier: 1.0,
        max_delay: Duration::from_millis(5),
        jitter_factor: 0.0,
    };
    let flaky = Arc::new(
        MockProvider::new("t", MockMode::Fixed("ok".into()))
            .with_fail_after(1, LlmError::Network("boom".into())),
    );
    let provider: Arc<dyn LlmProvider> = flaky.clone();

    let mut producer = LlmProducer::new(embed_config, provider, 32768, "test-route".into())
        .with_retry(Some(retry))
        .build();

    let _ = producer.call(make_exchange(Body::Text("x".into()))).await;

    assert_eq!(flaky.call_count(), 2, "embed must retry transient failures");
}