1use std::fmt;
8use std::io;
9use std::time::{Duration, Instant};
10
11use crate::runtime::RedDBRuntime;
12
/// Runtime config key for the AI transport connection pool size.
pub const CONFIG_POOL_SIZE: &str = "runtime.ai.transport_pool_size";
/// Runtime config key for the AI transport request timeout, in milliseconds.
pub const CONFIG_TIMEOUT_MS: &str = "runtime.ai.transport_timeout_ms";
/// Runtime config key for the maximum number of attempts per request (first try included).
pub const CONFIG_RETRY_MAX_ATTEMPTS: &str = "runtime.ai.transport_retry_max_attempts";
/// Runtime config key for the base retry backoff delay, in milliseconds.
pub const CONFIG_RETRY_BASE_MS: &str = "runtime.ai.transport_retry_base_ms";
17
/// Default connection pool size.
pub const DEFAULT_POOL_SIZE: usize = 16;
/// Default request timeout, in milliseconds.
pub const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Default maximum number of attempts per request (first try included).
pub const DEFAULT_RETRY_MAX_ATTEMPTS: u32 = 3;
/// Default base retry backoff delay, in milliseconds.
pub const DEFAULT_RETRY_BASE_MS: u64 = 500;
/// Default upper bound on a single retry backoff delay, in milliseconds.
pub const DEFAULT_RETRY_CAP_MS: u64 = 10_000;
23
24#[derive(Debug, Clone)]
25pub struct AiTransportConfig {
26 pub pool_size: usize,
27 pub timeout: Duration,
28 pub retry: AiRetryConfig,
29}
30
31impl Default for AiTransportConfig {
32 fn default() -> Self {
33 Self {
34 pool_size: DEFAULT_POOL_SIZE,
35 timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
36 retry: AiRetryConfig::default(),
37 }
38 }
39}
40
41impl AiTransportConfig {
42 pub fn from_runtime(runtime: &RedDBRuntime) -> Self {
43 let defaults = Self::default();
44 Self {
45 pool_size: runtime.config_u64(CONFIG_POOL_SIZE, defaults.pool_size as u64) as usize,
46 timeout: Duration::from_millis(
47 runtime.config_u64(CONFIG_TIMEOUT_MS, DEFAULT_TIMEOUT_MS),
48 ),
49 retry: AiRetryConfig {
50 max_attempts: runtime
51 .config_u64(CONFIG_RETRY_MAX_ATTEMPTS, DEFAULT_RETRY_MAX_ATTEMPTS as u64)
52 as u32,
53 base_delay: Duration::from_millis(
54 runtime.config_u64(CONFIG_RETRY_BASE_MS, DEFAULT_RETRY_BASE_MS),
55 ),
56 max_delay: defaults.retry.max_delay,
57 },
58 }
59 .normalized()
60 }
61
62 pub fn normalized(mut self) -> Self {
63 self.pool_size = self.pool_size.max(1);
64 if self.timeout.is_zero() {
65 self.timeout = Duration::from_millis(DEFAULT_TIMEOUT_MS);
66 }
67 self.retry = self.retry.normalized();
68 self
69 }
70}
71
/// Retry and backoff policy for AI provider requests.
#[derive(Debug, Clone)]
pub struct AiRetryConfig {
    /// Total number of attempts allowed per request, counting the first one.
    pub max_attempts: u32,
    /// Delay before the first retry; later retries double it.
    pub base_delay: Duration,
    /// Upper bound on the delay between two attempts.
    pub max_delay: Duration,
}
78
79impl Default for AiRetryConfig {
80 fn default() -> Self {
81 Self {
82 max_attempts: DEFAULT_RETRY_MAX_ATTEMPTS,
83 base_delay: Duration::from_millis(DEFAULT_RETRY_BASE_MS),
84 max_delay: Duration::from_millis(DEFAULT_RETRY_CAP_MS),
85 }
86 }
87}
88
89impl AiRetryConfig {
90 pub fn normalized(mut self) -> Self {
91 self.max_attempts = self.max_attempts.max(1);
92 if self.base_delay.is_zero() {
93 self.base_delay = Duration::from_millis(DEFAULT_RETRY_BASE_MS);
94 }
95 if self.max_delay < self.base_delay {
96 self.max_delay = self.base_delay;
97 }
98 self
99 }
100}
101
/// HTTP methods the AI transport can issue.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AiHttpMethod {
    /// `GET`; the request body is not sent.
    Get,
    /// `POST`; the request body (or an empty string when absent) is sent.
    Post,
}
107
108#[derive(Debug, Clone)]
109pub struct AiHttpRequest {
110 pub provider: String,
111 pub model: Option<String>,
112 pub method: AiHttpMethod,
113 pub url: String,
114 pub headers: Vec<(String, String)>,
115 pub body: Option<String>,
116}
117
118impl AiHttpRequest {
119 pub fn post_json(provider: impl Into<String>, url: impl Into<String>, body: String) -> Self {
120 Self {
121 provider: provider.into(),
122 model: None,
123 method: AiHttpMethod::Post,
124 url: url.into(),
125 headers: vec![
126 ("content-type".to_string(), "application/json".to_string()),
127 ("accept".to_string(), "application/json".to_string()),
128 ],
129 body: Some(body),
130 }
131 }
132
133 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
134 self.headers.push((name.into(), value.into()));
135 self
136 }
137
138 pub fn model(mut self, model: impl Into<String>) -> Self {
139 self.model = Some(model.into());
140 self
141 }
142}
143
/// Outcome of an AI provider HTTP request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AiHttpResponse {
    /// HTTP status code of the response.
    pub status_code: u16,
    /// Response body read as a string.
    pub body: String,
    /// Number of attempts made to obtain this response.
    pub attempt_count: u32,
    /// Total backoff time scheduled between attempts, in milliseconds.
    pub total_wait_ms: u64,
}
151
/// Terminal failure of an AI provider request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AiTransportError {
    /// Provider name taken from the request.
    pub provider: String,
    /// HTTP status of the last response, or `None` when the attempt produced no response.
    pub status_code: Option<u16>,
    /// Number of attempts made before giving up.
    pub attempt_count: u32,
    /// Total backoff time scheduled between attempts, in milliseconds.
    pub total_wait_ms: u64,
    /// Description of the last failure.
    pub message: String,
}
160
161impl fmt::Display for AiTransportError {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 write!(
164 f,
165 "AI transport error provider={} status_code={} attempt_count={} total_wait_ms={}: {}",
166 self.provider,
167 self.status_code
168 .map(|status| status.to_string())
169 .unwrap_or_else(|| "none".to_string()),
170 self.attempt_count,
171 self.total_wait_ms,
172 self.message
173 )
174 }
175}
176
177impl std::error::Error for AiTransportError {}
178
179#[derive(Clone)]
180pub struct AiTransport {
181 agent: ureq::Agent,
182 config: AiTransportConfig,
183}
184
185impl AiTransport {
186 pub fn new(config: AiTransportConfig) -> Self {
187 let config = config.normalized();
188 let agent: ureq::Agent = ureq::Agent::config_builder()
189 .max_idle_connections(config.pool_size)
190 .max_idle_connections_per_host(config.pool_size)
191 .timeout_global(Some(config.timeout))
192 .http_status_as_error(false)
193 .build()
194 .into();
195 Self { agent, config }
196 }
197
198 pub fn from_runtime(runtime: &RedDBRuntime) -> Self {
199 Self::new(AiTransportConfig::from_runtime(runtime))
200 }
201
202 pub fn config(&self) -> &AiTransportConfig {
203 &self.config
204 }
205
206 pub async fn request(
207 &self,
208 request: AiHttpRequest,
209 ) -> Result<AiHttpResponse, AiTransportError> {
210 let mut attempt = 0;
211 let mut total_wait = Duration::ZERO;
212 let started = Instant::now();
213 let provider = request.provider.clone();
214 let model = request
215 .model
216 .as_deref()
217 .filter(|model| !model.trim().is_empty())
218 .unwrap_or("unknown")
219 .to_string();
220
221 loop {
222 attempt += 1;
223 match self.try_request_once(request.clone()).await {
224 Ok(mut response) if response.status_code < 400 => {
225 let duration_ms = millis_u64(started.elapsed());
226 crate::runtime::ai::metrics::record_provider_request(
227 &provider,
228 &model,
229 "ok",
230 duration_ms,
231 );
232 response.attempt_count = attempt;
233 response.total_wait_ms = millis_u64(total_wait);
234 return Ok(response);
235 }
236 Ok(response) => {
237 let status_code = Some(response.status_code);
238 let message = format!("HTTP status {}", response.status_code);
239 let retryable = is_retryable_status(response.status_code);
240 let error = AiTransportError {
241 provider: request.provider.clone(),
242 status_code,
243 attempt_count: attempt,
244 total_wait_ms: millis_u64(total_wait),
245 message,
246 };
247 if !retryable || attempt >= self.config.retry.max_attempts {
248 let status = http_status_label(response.status_code);
249 crate::runtime::ai::metrics::record_provider_request(
250 &provider,
251 &model,
252 status,
253 millis_u64(started.elapsed()),
254 );
255 tracing::warn!(
256 target: "reddb::developer",
257 provider = %provider,
258 model = %model,
259 status_code = response.status_code,
260 attempt_count = attempt,
261 total_wait_ms = millis_u64(total_wait),
262 "ai provider request failed"
263 );
264 return Err(error);
265 }
266 let reason = retry_reason_for_status(response.status_code);
267 crate::runtime::ai::metrics::record_provider_retry(&provider, reason);
268 tracing::debug!(
269 target: "reddb::developer",
270 provider = %provider,
271 model = %model,
272 status_code = response.status_code,
273 attempt_count = attempt,
274 reason = reason,
275 "ai provider request retry scheduled"
276 );
277 }
278 Err(error) => {
279 let retryable = error.retryable;
280 let error = AiTransportError {
281 provider: request.provider.clone(),
282 status_code: None,
283 attempt_count: attempt,
284 total_wait_ms: millis_u64(total_wait),
285 message: error.message,
286 };
287 if !retryable || attempt >= self.config.retry.max_attempts {
288 crate::runtime::ai::metrics::record_provider_request(
289 &provider,
290 &model,
291 "transport_error",
292 millis_u64(started.elapsed()),
293 );
294 tracing::warn!(
295 target: "reddb::developer",
296 provider = %provider,
297 model = %model,
298 status_code = tracing::field::Empty,
299 attempt_count = attempt,
300 total_wait_ms = millis_u64(total_wait),
301 "ai provider request failed"
302 );
303 return Err(error);
304 }
305 crate::runtime::ai::metrics::record_provider_retry(
306 &provider,
307 "transport_error",
308 );
309 tracing::debug!(
310 target: "reddb::developer",
311 provider = %provider,
312 model = %model,
313 attempt_count = attempt,
314 reason = "transport_error",
315 "ai provider request retry scheduled"
316 );
317 }
318 }
319
320 let delay = backoff_delay(&self.config.retry, attempt);
321 total_wait += delay;
322 tokio::time::sleep(delay).await;
323 }
324 }
325
326 async fn try_request_once(
327 &self,
328 request: AiHttpRequest,
329 ) -> Result<AiHttpResponse, TransportAttemptError> {
330 let agent = self.agent.clone();
331 tokio::task::spawn_blocking(move || send_blocking(agent, request))
332 .await
333 .map_err(|err| TransportAttemptError {
334 retryable: false,
335 message: format!("request worker failed: {err}"),
336 })?
337 }
338}
339
/// Failure of a single request attempt, before the retry policy is applied.
#[derive(Debug)]
struct TransportAttemptError {
    /// Whether the request loop may try again after this failure.
    retryable: bool,
    /// Description of the failure.
    message: String,
}
345
346fn send_blocking(
347 agent: ureq::Agent,
348 request: AiHttpRequest,
349) -> Result<AiHttpResponse, TransportAttemptError> {
350 let result = match request.method {
351 AiHttpMethod::Get => {
352 let mut builder = agent.get(&request.url);
353 for (name, value) in &request.headers {
354 builder = builder.header(name, value);
355 }
356 builder.call()
357 }
358 AiHttpMethod::Post => {
359 let mut builder = agent.post(&request.url);
360 for (name, value) in &request.headers {
361 builder = builder.header(name, value);
362 }
363 builder.send(request.body.unwrap_or_default())
364 }
365 };
366
367 match result {
368 Ok(mut response) => {
369 let status_code = response.status().as_u16();
370 let body =
371 response
372 .body_mut()
373 .read_to_string()
374 .map_err(|err| TransportAttemptError {
375 retryable: is_retryable_ureq_error(&err),
376 message: format!("failed to read response body: {err}"),
377 })?;
378 Ok(AiHttpResponse {
379 status_code,
380 body,
381 attempt_count: 1,
382 total_wait_ms: 0,
383 })
384 }
385 Err(err) => Err(TransportAttemptError {
386 retryable: is_retryable_ureq_error(&err),
387 message: err.to_string(),
388 }),
389 }
390}
391
392fn backoff_delay(config: &AiRetryConfig, attempt: u32) -> Duration {
393 let shift = attempt.saturating_sub(1).min(31);
394 let multiplier = 1u32 << shift;
395 config
396 .base_delay
397 .saturating_mul(multiplier)
398 .min(config.max_delay)
399}
400
/// Returns `true` for HTTP statuses worth retrying: 429 and the 500..=599 range.
fn is_retryable_status(status: u16) -> bool {
    matches!(status, 429 | 500..=599)
}
404
/// Maps an HTTP status to the reason label recorded when a retry is scheduled.
fn retry_reason_for_status(status: u16) -> &'static str {
    match status {
        429 => "http_429",
        500..=599 => "http_5xx",
        _ => "http_error",
    }
}
414
/// Maps an HTTP status to the outcome label recorded for a failed request.
///
/// 429 gets its own label; other 4xx and 5xx statuses are grouped by class.
fn http_status_label(status: u16) -> &'static str {
    match status {
        // Must come first: 429 also falls in the 4xx range below.
        429 => "http_429",
        400..=499 => "http_4xx",
        500..=599 => "http_5xx",
        _ => "http_error",
    }
}
426
427fn is_retryable_ureq_error(err: &ureq::Error) -> bool {
428 match err {
429 ureq::Error::Timeout(_) | ureq::Error::ConnectionFailed => true,
430 ureq::Error::Io(err) => is_retryable_io_error(err),
431 _ => false,
432 }
433}
434
/// Returns `true` when an I/O error kind indicates a connection-level failure worth
/// retrying (refused, reset, aborted, timed out, or unexpected end of stream).
fn is_retryable_io_error(err: &io::Error) -> bool {
    const RETRYABLE_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::TimedOut,
        io::ErrorKind::UnexpectedEof,
    ];
    RETRYABLE_KINDS.contains(&err.kind())
}
445
/// Converts `duration` to whole milliseconds, saturating at `u64::MAX`.
fn millis_u64(duration: Duration) -> u64 {
    let millis = duration.as_millis();
    if millis > u128::from(u64::MAX) {
        u64::MAX
    } else {
        // In range, so the narrowing cast is lossless.
        millis as u64
    }
}