oxify_connect_llm/
timeout.rs

1//! Timeout handling for LLM providers
2
3use crate::{
4    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider, LlmRequest,
5    LlmResponse, LlmStream, Result, StreamingLlmProvider,
6};
7use async_trait::async_trait;
8use std::time::Duration;
9
/// Per-operation timeout settings used by [`TimeoutProvider`].
///
/// Each field bounds a different kind of provider call. Build one with
/// [`TimeoutConfig::default`], [`TimeoutConfig::uniform`], or the `with_*`
/// builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Maximum time to wait for a completion request (default: 60s).
    pub request_timeout: Duration,

    /// Maximum time to wait for a streaming request to hand back its stream
    /// (default: 120s).
    ///
    /// As applied by [`TimeoutProvider`], this bounds only the call that
    /// produces the stream; items read from the stream afterwards are not
    /// covered by this timeout.
    pub stream_timeout: Duration,

    /// Maximum time to wait for an embedding request (default: 30s).
    pub embedding_timeout: Duration,
}
22
23impl Default for TimeoutConfig {
24    fn default() -> Self {
25        Self {
26            request_timeout: Duration::from_secs(60),
27            stream_timeout: Duration::from_secs(120),
28            embedding_timeout: Duration::from_secs(30),
29        }
30    }
31}
32
33impl TimeoutConfig {
34    /// Create a new config with a uniform timeout for all request types
35    pub fn uniform(timeout: Duration) -> Self {
36        Self {
37            request_timeout: timeout,
38            stream_timeout: timeout,
39            embedding_timeout: timeout,
40        }
41    }
42
43    /// Set the request timeout
44    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
45        self.request_timeout = timeout;
46        self
47    }
48
49    /// Set the stream timeout
50    pub fn with_stream_timeout(mut self, timeout: Duration) -> Self {
51        self.stream_timeout = timeout;
52        self
53    }
54
55    /// Set the embedding timeout
56    pub fn with_embedding_timeout(mut self, timeout: Duration) -> Self {
57        self.embedding_timeout = timeout;
58        self
59    }
60}
61
62/// A wrapper that adds timeout functionality to any LLM provider
63pub struct TimeoutProvider<P> {
64    inner: P,
65    config: TimeoutConfig,
66}
67
68impl<P> TimeoutProvider<P> {
69    /// Create a new TimeoutProvider with default timeout settings
70    pub fn new(provider: P) -> Self {
71        Self {
72            inner: provider,
73            config: TimeoutConfig::default(),
74        }
75    }
76
77    /// Create a new TimeoutProvider with custom timeout configuration
78    pub fn with_config(provider: P, config: TimeoutConfig) -> Self {
79        Self {
80            inner: provider,
81            config,
82        }
83    }
84
85    /// Create a new TimeoutProvider with a uniform timeout
86    pub fn with_timeout(provider: P, timeout: Duration) -> Self {
87        Self {
88            inner: provider,
89            config: TimeoutConfig::uniform(timeout),
90        }
91    }
92
93    /// Get a reference to the inner provider
94    pub fn inner(&self) -> &P {
95        &self.inner
96    }
97
98    /// Get a mutable reference to the inner provider
99    pub fn inner_mut(&mut self) -> &mut P {
100        &mut self.inner
101    }
102
103    /// Get the timeout configuration
104    pub fn config(&self) -> &TimeoutConfig {
105        &self.config
106    }
107}
108
109#[async_trait]
110impl<P: LlmProvider> LlmProvider for TimeoutProvider<P> {
111    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
112        let timeout = self.config.request_timeout;
113
114        match tokio::time::timeout(timeout, self.inner.complete(request)).await {
115            Ok(result) => result,
116            Err(_) => {
117                tracing::warn!(timeout_ms = timeout.as_millis(), "LLM request timed out");
118                Err(LlmError::Timeout(timeout))
119            }
120        }
121    }
122}
123
124#[async_trait]
125impl<P: StreamingLlmProvider> StreamingLlmProvider for TimeoutProvider<P> {
126    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
127        let timeout = self.config.stream_timeout;
128
129        match tokio::time::timeout(timeout, self.inner.complete_stream(request)).await {
130            Ok(result) => result,
131            Err(_) => {
132                tracing::warn!(
133                    timeout_ms = timeout.as_millis(),
134                    "LLM stream request timed out"
135                );
136                Err(LlmError::Timeout(timeout))
137            }
138        }
139    }
140}
141
142#[async_trait]
143impl<P: EmbeddingProvider> EmbeddingProvider for TimeoutProvider<P> {
144    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
145        let timeout = self.config.embedding_timeout;
146
147        match tokio::time::timeout(timeout, self.inner.embed(request)).await {
148            Ok(result) => result,
149            Err(_) => {
150                tracing::warn!(
151                    timeout_ms = timeout.as_millis(),
152                    "Embedding request timed out"
153                );
154                Err(LlmError::Timeout(timeout))
155            }
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timeout_config_default() {
        let config = TimeoutConfig::default();
        assert_eq!(config.request_timeout, Duration::from_secs(60));
        assert_eq!(config.stream_timeout, Duration::from_secs(120));
        assert_eq!(config.embedding_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_timeout_config_uniform() {
        let config = TimeoutConfig::uniform(Duration::from_secs(45));
        assert_eq!(config.request_timeout, Duration::from_secs(45));
        assert_eq!(config.stream_timeout, Duration::from_secs(45));
        assert_eq!(config.embedding_timeout, Duration::from_secs(45));
    }

    #[test]
    fn test_timeout_config_builder() {
        let config = TimeoutConfig::default()
            .with_request_timeout(Duration::from_secs(90))
            .with_stream_timeout(Duration::from_secs(180))
            .with_embedding_timeout(Duration::from_secs(15));

        assert_eq!(config.request_timeout, Duration::from_secs(90));
        assert_eq!(config.stream_timeout, Duration::from_secs(180));
        assert_eq!(config.embedding_timeout, Duration::from_secs(15));
    }

    // The wrapper constructors and accessors need no real provider, so a
    // unit/integer/str stand-in is enough to exercise them.

    #[test]
    fn test_timeout_provider_new_uses_default_config() {
        let provider = TimeoutProvider::new(());
        assert_eq!(provider.config().request_timeout, Duration::from_secs(60));
        assert_eq!(provider.config().stream_timeout, Duration::from_secs(120));
        assert_eq!(provider.config().embedding_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_timeout_provider_with_timeout_is_uniform() {
        let provider = TimeoutProvider::with_timeout((), Duration::from_secs(5));
        assert_eq!(provider.config().request_timeout, Duration::from_secs(5));
        assert_eq!(provider.config().stream_timeout, Duration::from_secs(5));
        assert_eq!(provider.config().embedding_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_timeout_provider_with_config() {
        let config = TimeoutConfig::default().with_embedding_timeout(Duration::from_secs(3));
        let provider = TimeoutProvider::with_config("p", config);
        assert_eq!(*provider.inner(), "p");
        assert_eq!(provider.config().request_timeout, Duration::from_secs(60));
        assert_eq!(provider.config().embedding_timeout, Duration::from_secs(3));
    }

    #[test]
    fn test_timeout_provider_inner_mut() {
        let mut provider = TimeoutProvider::new(1_u32);
        *provider.inner_mut() += 1;
        assert_eq!(*provider.inner(), 2);
    }
}