pe-core 0.1.0

Core types for Potential Expectations — messages, channels, state, traits
Documentation
//! Timeout middleware — enforces a per-call time limit on LLM completions.
//!
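//! Wraps each `complete()` call with [`tokio::time::timeout`]. If the call
//! exceeds the deadline, the in-flight request future is dropped and
//! [`PeError::Timeout`] is returned; errors produced by the wrapped provider
//! itself are passed through unchanged.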
//!
//! # Example
//!
//! ```ignore
//! use std::time::Duration;
//! use pe_core::timeout_middleware::TimeoutMiddleware;
//!
//! let timeout = TimeoutMiddleware::new(Duration::from_secs(30));
//! let stack = MiddlewareStack::new(provider).with(timeout);
//! ```

use std::time::Duration;

use async_trait::async_trait;

use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;

/// Enforces a per-call time limit on LLM completion requests.
///
/// Each completion forwarded through this middleware is raced against the
/// configured [`duration`](Self::duration). If the deadline passes first, the
/// in-flight request future is dropped and [`PeError::Timeout`] is returned.
/// Its `seconds` field reports the configured limit, not a measured elapsed time.
///
/// The value is a small `Copy` configuration object, so one instance can be
/// cheaply duplicated into several middleware stacks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutMiddleware {
    /// Maximum time allowed for a single completion call.
    duration: Duration,
}

impl TimeoutMiddleware {
    /// Create a timeout middleware with the given deadline per call.
    ///
    /// This is a `const fn`, so a fixed policy can be declared as a `const` or `static`.
    #[must_use]
    pub const fn new(duration: Duration) -> Self {
        Self { duration }
    }

    /// The per-call deadline this middleware enforces.
    #[must_use]
    pub const fn duration(&self) -> Duration {
        self.duration
    }
}

#[async_trait]
impl ProviderMiddleware for TimeoutMiddleware {
    /// Forwards the completion request to `next`, racing it against `self.duration`.
    ///
    /// # Errors
    ///
    /// - Returns [`PeError::Timeout`] when the deadline elapses before `next`
    ///   finishes. The error carries the configured limit in seconds.
    /// - Otherwise, whatever `next.complete` produced (success or error) is
    ///   returned unchanged.
    async fn wrap_complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        next: &dyn LlmProvider,
    ) -> Result<LlmResponse, PeError> {
        let limit = self.duration;
        let inner = next.complete(messages, tools);
        match tokio::time::timeout(limit, inner).await {
            // The inner call finished in time: pass its result straight through.
            Ok(outcome) => outcome,
            // The deadline won the race; tokio has already dropped `inner`.
            Err(_elapsed) => Err(PeError::Timeout {
                seconds: limit.as_secs_f64(),
            }),
        }
    }
}

// Unit tests for `TimeoutMiddleware`: one call that resolves well inside the
// deadline, and one that is deliberately driven past it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    // A generous 5 s deadline around a mock provider: the mock's response must
    // pass through the middleware unchanged.
    // NOTE(review): assumes `MockProvider` resolves without delay — confirm.
    #[tokio::test]
    async fn test_timeout_fast_call_succeeds() {
        let timeout = TimeoutMiddleware::new(Duration::from_secs(5));
        let provider = MockProvider::new().respond_with("fast");

        let resp = timeout.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("fast"));
    }

    // A 10 ms deadline against a provider whose `complete` sleeps for 10 s.
    // When the timeout fires, the sleeping future is dropped, so the test does
    // not actually wait 10 s.
    #[tokio::test]
    async fn test_timeout_slow_call_returns_error() {
        let timeout = TimeoutMiddleware::new(Duration::from_millis(10));

        // Use a provider that sleeps longer than the timeout
        struct SlowProvider;
        impl LlmProvider for SlowProvider {
            fn complete(
                &self,
                _messages: &[Message],
                _tools: &[ToolSchema],
            ) -> std::pin::Pin<
                Box<
                    dyn std::future::Future<Output = Result<crate::llm::LlmResponse, PeError>>
                        + Send
                        + '_,
                >,
            > {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    // Reaching this line means the timeout failed to cancel the
                    // future, which is a test failure.
                    unreachable!("should have timed out")
                })
            }

            // `stream` and `embed` are never invoked by `wrap_complete`; they
            // exist only to satisfy the `LlmProvider` trait.
            fn stream(
                &self,
                _messages: &[Message],
                _tools: &[ToolSchema],
            ) -> crate::llm::StreamFuture<'_> {
                unimplemented!()
            }

            fn embed(
                &self,
                _text: &str,
            ) -> std::pin::Pin<
                Box<dyn std::future::Future<Output = Result<Vec<f32>, PeError>> + Send + '_>,
            > {
                unimplemented!()
            }

            fn provider_name(&self) -> &'static str {
                "slow"
            }
        }

        let err = timeout
            .wrap_complete(&[], &[], &SlowProvider)
            .await
            .unwrap_err();
        match err {
            // The error reports the configured limit (10 ms = 0.01 s), not a
            // measured elapsed time, so a tight tolerance is safe.
            PeError::Timeout { seconds } => {
                assert!(
                    (seconds - 0.01).abs() < 0.001,
                    "expected ~0.01s, got {seconds}"
                );
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }
}