oxify_connect_llm/
timeout.rs1use crate::{
4 EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider, LlmRequest,
5 LlmResponse, LlmStream, Result, StreamingLlmProvider,
6};
7use async_trait::async_trait;
8use std::time::Duration;
9
/// Per-operation time limits enforced by [`TimeoutProvider`].
///
/// Each field bounds one kind of call routed through the wrapper. Use
/// [`TimeoutConfig::default`] for the stock limits, [`TimeoutConfig::uniform`]
/// for a single limit everywhere, and the `with_*` methods to override
/// individual fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Maximum time allowed for a non-streaming completion request.
    pub request_timeout: Duration,

    /// Maximum time allowed for a streaming request to hand back its stream.
    ///
    /// This bounds only the call that produces the stream; reading items from
    /// the returned stream afterwards is not limited by it.
    pub stream_timeout: Duration,

    /// Maximum time allowed for an embedding request.
    pub embedding_timeout: Duration,
}

impl Default for TimeoutConfig {
    /// Returns 60 s for requests, 120 s for streams and 30 s for embeddings.
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(60),
            stream_timeout: Duration::from_secs(120),
            embedding_timeout: Duration::from_secs(30),
        }
    }
}

impl TimeoutConfig {
    /// Builds a configuration that applies `timeout` to every operation kind.
    #[must_use]
    pub fn uniform(timeout: Duration) -> Self {
        Self {
            request_timeout: timeout,
            stream_timeout: timeout,
            embedding_timeout: timeout,
        }
    }

    /// Returns `self` with [`request_timeout`](Self::request_timeout) replaced.
    #[must_use]
    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Returns `self` with [`stream_timeout`](Self::stream_timeout) replaced.
    #[must_use]
    pub fn with_stream_timeout(mut self, timeout: Duration) -> Self {
        self.stream_timeout = timeout;
        self
    }

    /// Returns `self` with [`embedding_timeout`](Self::embedding_timeout) replaced.
    #[must_use]
    pub fn with_embedding_timeout(mut self, timeout: Duration) -> Self {
        self.embedding_timeout = timeout;
        self
    }
}

62pub struct TimeoutProvider<P> {
64 inner: P,
65 config: TimeoutConfig,
66}
67
68impl<P> TimeoutProvider<P> {
69 pub fn new(provider: P) -> Self {
71 Self {
72 inner: provider,
73 config: TimeoutConfig::default(),
74 }
75 }
76
77 pub fn with_config(provider: P, config: TimeoutConfig) -> Self {
79 Self {
80 inner: provider,
81 config,
82 }
83 }
84
85 pub fn with_timeout(provider: P, timeout: Duration) -> Self {
87 Self {
88 inner: provider,
89 config: TimeoutConfig::uniform(timeout),
90 }
91 }
92
93 pub fn inner(&self) -> &P {
95 &self.inner
96 }
97
98 pub fn inner_mut(&mut self) -> &mut P {
100 &mut self.inner
101 }
102
103 pub fn config(&self) -> &TimeoutConfig {
105 &self.config
106 }
107}
108
109#[async_trait]
110impl<P: LlmProvider> LlmProvider for TimeoutProvider<P> {
111 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
112 let timeout = self.config.request_timeout;
113
114 match tokio::time::timeout(timeout, self.inner.complete(request)).await {
115 Ok(result) => result,
116 Err(_) => {
117 tracing::warn!(timeout_ms = timeout.as_millis(), "LLM request timed out");
118 Err(LlmError::Timeout(timeout))
119 }
120 }
121 }
122}
123
124#[async_trait]
125impl<P: StreamingLlmProvider> StreamingLlmProvider for TimeoutProvider<P> {
126 async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
127 let timeout = self.config.stream_timeout;
128
129 match tokio::time::timeout(timeout, self.inner.complete_stream(request)).await {
130 Ok(result) => result,
131 Err(_) => {
132 tracing::warn!(
133 timeout_ms = timeout.as_millis(),
134 "LLM stream request timed out"
135 );
136 Err(LlmError::Timeout(timeout))
137 }
138 }
139 }
140}
141
142#[async_trait]
143impl<P: EmbeddingProvider> EmbeddingProvider for TimeoutProvider<P> {
144 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
145 let timeout = self.config.embedding_timeout;
146
147 match tokio::time::timeout(timeout, self.inner.embed(request)).await {
148 Ok(result) => result,
149 Err(_) => {
150 tracing::warn!(
151 timeout_ms = timeout.as_millis(),
152 "Embedding request timed out"
153 );
154 Err(LlmError::Timeout(timeout))
155 }
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timeout_config_default() {
        let config = TimeoutConfig::default();
        assert_eq!(config.request_timeout, Duration::from_secs(60));
        assert_eq!(config.stream_timeout, Duration::from_secs(120));
        assert_eq!(config.embedding_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_timeout_config_uniform() {
        let config = TimeoutConfig::uniform(Duration::from_secs(45));
        assert_eq!(config.request_timeout, Duration::from_secs(45));
        assert_eq!(config.stream_timeout, Duration::from_secs(45));
        assert_eq!(config.embedding_timeout, Duration::from_secs(45));
    }

    #[test]
    fn test_timeout_config_builder() {
        let config = TimeoutConfig::default()
            .with_request_timeout(Duration::from_secs(90))
            .with_stream_timeout(Duration::from_secs(180))
            .with_embedding_timeout(Duration::from_secs(15));

        assert_eq!(config.request_timeout, Duration::from_secs(90));
        assert_eq!(config.stream_timeout, Duration::from_secs(180));
        assert_eq!(config.embedding_timeout, Duration::from_secs(15));
    }

    // `TimeoutProvider`'s inherent methods place no bounds on `P`, so plain
    // values stand in for a real provider here.
    #[test]
    fn test_provider_new_uses_default_config() {
        let provider = TimeoutProvider::new(());
        let config = provider.config();
        assert_eq!(config.request_timeout, Duration::from_secs(60));
        assert_eq!(config.stream_timeout, Duration::from_secs(120));
        assert_eq!(config.embedding_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_provider_with_timeout_is_uniform() {
        let provider = TimeoutProvider::with_timeout(7u8, Duration::from_secs(5));
        assert_eq!(*provider.inner(), 7);
        assert_eq!(provider.config().request_timeout, Duration::from_secs(5));
        assert_eq!(provider.config().stream_timeout, Duration::from_secs(5));
        assert_eq!(provider.config().embedding_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_provider_with_config_and_inner_mut() {
        let config =
            TimeoutConfig::default().with_request_timeout(Duration::from_millis(250));
        let mut provider = TimeoutProvider::with_config(String::from("a"), config);
        provider.inner_mut().push('b');

        assert_eq!(provider.inner().as_str(), "ab");
        assert_eq!(provider.config().request_timeout, Duration::from_millis(250));
        assert_eq!(provider.config().stream_timeout, Duration::from_secs(120));
        assert_eq!(provider.config().embedding_timeout, Duration::from_secs(30));
    }
}