use std::time::Duration;
use async_trait::async_trait;
use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;
/// Provider middleware that puts an upper bound on how long a single
/// `complete` call to the wrapped provider may take.
///
/// The limit is fixed at construction time and applies to every call made
/// through this middleware instance.
#[derive(Debug, Clone)]
pub struct TimeoutMiddleware {
    /// Maximum time a wrapped call is allowed to run before it is abandoned.
    duration: Duration,
}

impl TimeoutMiddleware {
    /// Creates a middleware that enforces `duration` as the per-call limit.
    #[must_use]
    pub fn new(duration: Duration) -> Self {
        Self { duration }
    }
}
#[async_trait]
impl ProviderMiddleware for TimeoutMiddleware {
    /// Forwards the request to `next.complete`, racing it against the
    /// configured time limit.
    ///
    /// # Errors
    ///
    /// * Returns [`PeError::Timeout`] when the limit elapses first; `seconds`
    ///   carries the configured limit (not the measured elapsed time).
    /// * Otherwise returns whatever `next.complete` produced, success or
    ///   error, unchanged.
    async fn wrap_complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        next: &dyn LlmProvider,
    ) -> Result<LlmResponse, PeError> {
        let limit = self.duration;

        match tokio::time::timeout(limit, next.complete(messages, tools)).await {
            // The provider answered in time: hand its own result straight back.
            Ok(provider_result) => provider_result,
            // The timer won the race: report the limit that was exceeded.
            Err(_elapsed) => Err(PeError::Timeout {
                seconds: limit.as_secs_f64(),
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    // Shorthand for the boxed, `Send` future shape that the hand-written
    // `LlmProvider` methods below return.
    type BoxedResult<'a, T> =
        std::pin::Pin<Box<dyn std::future::Future<Output = Result<T, PeError>> + Send + 'a>>;

    // Provider whose `complete` sleeps far longer than any limit used in
    // these tests, so the middleware's timer is expected to fire first.
    struct SlowProvider;

    impl LlmProvider for SlowProvider {
        fn complete(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
        ) -> BoxedResult<'_, LlmResponse> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                unreachable!("should have timed out")
            })
        }

        // Not exercised by the timeout tests.
        fn stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
        ) -> crate::llm::StreamFuture<'_> {
            unimplemented!()
        }

        // Not exercised by the timeout tests.
        fn embed(&self, _text: &str) -> BoxedResult<'_, Vec<f32>> {
            unimplemented!()
        }

        fn provider_name(&self) -> &'static str {
            "slow"
        }
    }

    // A call that finishes within the limit passes its response through.
    #[tokio::test]
    async fn test_timeout_fast_call_succeeds() {
        let provider = MockProvider::new().respond_with("fast");
        let middleware = TimeoutMiddleware::new(Duration::from_secs(5));

        let response = middleware
            .wrap_complete(&[], &[], &provider)
            .await
            .unwrap();

        assert_eq!(response.message.content.as_text(), Some("fast"));
    }

    // A call that outlives the limit yields `PeError::Timeout` carrying the
    // configured limit in seconds.
    #[tokio::test]
    async fn test_timeout_slow_call_returns_error() {
        let middleware = TimeoutMiddleware::new(Duration::from_millis(10));

        let failure = middleware
            .wrap_complete(&[], &[], &SlowProvider)
            .await
            .unwrap_err();

        let seconds = match failure {
            PeError::Timeout { seconds } => seconds,
            other => panic!("expected Timeout, got {other:?}"),
        };

        // 10 ms expressed in seconds, with a small tolerance for float conversion.
        assert!(
            (seconds - 0.01).abs() < 0.001,
            "expected ~0.01s, got {seconds}"
        );
    }
}