use super::support::*;
/// Default overall request timeout, in milliseconds (300 000 ms = 5 minutes).
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 300_000;
/// Default chunk timeout, in milliseconds (120 000 ms = 2 minutes).
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 120_000;
/// Timeout settings resolved into concrete `Duration` values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct LlmTimeouts {
/// Overall request timeout; `None` means the request timeout is disabled.
pub request_timeout: Option<Duration>,
/// Chunk timeout.
/// NOTE(review): presumably the maximum wait between streamed chunks — confirm against the consumer.
pub chunk_timeout: Duration,
}
impl Default for LlmTimeouts {
fn default() -> Self {
Self {
request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
}
}
}
/// Configured request timeout: either explicitly switched off or a number of milliseconds.
///
/// Serialization and deserialization are implemented by hand rather than derived.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestTimeout {
/// No request timeout; serialized as the boolean `false`.
Disabled,
/// Request timeout in milliseconds; serialized as an unsigned integer.
Millis(u64),
}
impl Serialize for RequestTimeout {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Self::Disabled => serializer.serialize_bool(false),
Self::Millis(value) => serializer.serialize_u64(*value),
}
}
}
impl<'de> Deserialize<'de> for RequestTimeout {
    /// Accepts either the boolean `false` (timeout disabled) or a strictly positive
    /// integer number of milliseconds. `true`, zero and negative integers are rejected
    /// with a custom error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Visitor that maps the visited value onto a `RequestTimeout`.
        struct RequestTimeoutVisitor;

        impl Visitor<'_> for RequestTimeoutVisitor {
            type Value = RequestTimeout;

            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                formatter.write_str("a positive timeout in milliseconds or false")
            }

            // Only `false` is meaningful: it switches the timeout off.
            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                match value {
                    false => Ok(RequestTimeout::Disabled),
                    true => Err(E::custom("timeout must be a positive integer or false")),
                }
            }

            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                match value {
                    0 => Err(E::custom("timeout must be greater than 0")),
                    millis => Ok(RequestTimeout::Millis(millis)),
                }
            }

            // Signed input is accepted as long as it is strictly positive.
            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                if value > 0 {
                    // `value` is positive here, so the magnitude equals the value itself.
                    Ok(RequestTimeout::Millis(value.unsigned_abs()))
                } else {
                    Err(E::custom("timeout must be greater than 0"))
                }
            }
        }

        deserializer.deserialize_any(RequestTimeoutVisitor)
    }
}
/// Cache retention setting.
///
/// Variant names are (de)serialized in kebab-case; `Short` is the default variant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum CacheRetention {
None,
#[default]
Short,
Long,
}
impl CacheRetention {
pub fn is_default(&self) -> bool {
matches!(self, CacheRetention::Short)
}
}
/// Provider-level options.
///
/// Deserialization rejects unknown fields. `thinking`, `max_output_tokens` and
/// `cache_retention` are left out of the serialized form while they hold their default value.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProviderOptions {
/// Timeout, retry and rate-limit settings; defaulted when absent from the input.
#[serde(default)]
pub reliability: ProviderReliability,
/// Thinking policy; its `expose` flag is copied into the resolved generation policy.
#[serde(default, skip_serializing_if = "ProviderThinkingPolicy::is_default")]
pub thinking: ProviderThinkingPolicy,
/// Optional output-token limit, used by `resolve_generation_policy` when the
/// generation options supply no cap of their own.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u64>,
/// Cache retention setting; defaults to `CacheRetention::Short`.
#[serde(default, skip_serializing_if = "CacheRetention::is_default")]
pub cache_retention: CacheRetention,
}
impl ProviderOptions {
    /// Reports whether every option still holds its default value: no output-token
    /// limit, default cache retention, default thinking policy and the
    /// `ProviderReliability::default_llm()` reliability settings.
    pub fn is_default(&self) -> bool {
        // Cheap field checks first; the struct comparisons only run when those pass.
        if self.max_output_tokens.is_some() || !self.cache_retention.is_default() {
            return false;
        }
        self.thinking == ProviderThinkingPolicy::default()
            && self.reliability == ProviderReliability::default_llm()
    }

    /// Resolves the timeout policy stored under `reliability.timeouts` into [`LlmTimeouts`].
    pub fn llm_timeouts(&self) -> LlmTimeouts {
        let timeout_policy = &self.reliability.timeouts;
        timeout_policy.llm_timeouts()
    }
}
/// Thinking-related policy for a provider. Deserialization rejects unknown fields.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProviderThinkingPolicy {
/// Copied into `ResolvedGenerationPolicy::expose_thinking`; `false` when absent.
#[serde(default)]
pub expose: bool,
}
impl ProviderThinkingPolicy {
pub fn is_default(&self) -> bool {
!self.expose
}
}
/// Generation settings after request options, provider options and provider
/// defaults have been merged by `resolve_generation_policy`.
///
/// `TThinking` is an opaque, caller-chosen payload carried through unchanged.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ResolvedGenerationPolicy<TThinking> {
/// Effective output-token limit.
pub max_output_tokens: u64,
/// Cache retention taken from the provider options.
pub cache_retention: CacheRetention,
/// Value of the provider options' `thinking.expose` flag.
pub expose_thinking: bool,
/// Caller-supplied thinking payload.
pub thinking: TThinking,
}
/// Merges per-request generation options with provider options into a
/// [`ResolvedGenerationPolicy`].
///
/// The output-token limit is chosen in this order of precedence:
/// 1. the cap reported by `generation.output_token_cap_u64()`,
/// 2. `options.max_output_tokens`,
/// 3. `provider_default_max_output_tokens`.
///
/// `cache_retention` and the `expose` flag are copied from `options`; `thinking`
/// is passed through untouched.
pub fn resolve_generation_policy<TThinking>(
    generation: &crate::GenerationOptions,
    options: &ProviderOptions,
    provider_default_max_output_tokens: u64,
    thinking: TThinking,
) -> ResolvedGenerationPolicy<TThinking> {
    // Fallback used only when the request itself carries no cap.
    let provider_limit = options
        .max_output_tokens
        .unwrap_or(provider_default_max_output_tokens);
    let token_limit = generation
        .output_token_cap_u64()
        .unwrap_or(provider_limit);

    ResolvedGenerationPolicy {
        thinking,
        expose_thinking: options.thinking.expose,
        cache_retention: options.cache_retention,
        max_output_tokens: token_limit,
    }
}
/// Reliability settings for a provider: timeouts, retry behaviour and rate limits.
///
/// Each section falls back to its `Default` when absent from the serialized input.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderReliability {
/// Request and chunk timeout configuration.
#[serde(default)]
pub timeouts: ProviderTimeoutPolicy,
/// Retry and backoff configuration.
#[serde(default)]
pub retry: ProviderRetryPolicy,
/// Rate-limit configuration.
#[serde(default)]
pub rate_limits: ProviderRateLimitPolicy,
}
impl ProviderReliability {
    /// Baseline settings: every section at its `Default` value.
    pub fn default_llm() -> Self {
        let timeouts = ProviderTimeoutPolicy::default();
        let retry = ProviderRetryPolicy::default();
        let rate_limits = ProviderRateLimitPolicy::default();
        Self {
            timeouts,
            retry,
            rate_limits,
        }
    }

    /// The `codex` preset: the baseline settings with a retry policy of 4 attempts,
    /// a 1 000 ms base delay, a 4 000 ms delay ceiling, no jitter and a 60 000 ms
    /// cap on server-provided retry-after delays.
    pub fn codex() -> Self {
        let mut preset = Self::default_llm();
        preset.retry = ProviderRetryPolicy {
            enabled: true,
            max_attempts: 4,
            base_delay_ms: 1_000,
            max_delay_ms: 4_000,
            jitter_ms: 0,
            retry_after_cap_ms: Some(60_000),
        };
        preset
    }

    /// Settings with retries switched off; timeouts and rate limits stay at their defaults.
    pub fn disabled() -> Self {
        Self {
            timeouts: ProviderTimeoutPolicy::default(),
            rate_limits: ProviderRateLimitPolicy::default(),
            retry: ProviderRetryPolicy::disabled(),
        }
    }

    /// Starts a builder seeded with [`ProviderReliability::default_llm`].
    pub fn builder() -> ProviderReliabilityBuilder {
        let reliability = Self::default_llm();
        ProviderReliabilityBuilder { reliability }
    }
}
impl Default for ProviderReliability {
fn default() -> Self {
Self::default_llm()
}
}
/// Configured (unresolved) timeouts; `llm_timeouts` turns them into concrete durations.
///
/// Both fields are optional and are left out of the serialized form when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct ProviderTimeoutPolicy {
/// Request timeout setting; `None` selects `DEFAULT_REQUEST_TIMEOUT_MS`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub request_timeout: Option<RequestTimeout>,
/// Chunk timeout in milliseconds; `None` or `0` selects `DEFAULT_CHUNK_TIMEOUT_MS`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub chunk_timeout: Option<u64>,
}
impl ProviderTimeoutPolicy {
    /// Resolves the configured policy into concrete [`LlmTimeouts`].
    ///
    /// Request timeout:
    /// * `Some(RequestTimeout::Disabled)` → `None` (no request timeout),
    /// * `Some(RequestTimeout::Millis(ms))` with `ms > 0` → that many milliseconds,
    /// * unset, or `Millis(0)` → `DEFAULT_REQUEST_TIMEOUT_MS`.
    ///
    /// Chunk timeout: a positive value is used as-is; unset or `0` falls back to
    /// `DEFAULT_CHUNK_TIMEOUT_MS`.
    pub fn llm_timeouts(&self) -> LlmTimeouts {
        let request_timeout = match self.request_timeout {
            Some(RequestTimeout::Disabled) => None,
            Some(RequestTimeout::Millis(ms)) if ms > 0 => Some(Duration::from_millis(ms)),
            // Deserialization rejects a zero timeout, but a value constructed in code can
            // still carry one. A zero-length timeout would fail every request at once, so
            // it is treated like the zero chunk timeout below: fall back to the default.
            Some(RequestTimeout::Millis(_)) | None => {
                Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS))
            }
        };

        let chunk_timeout_ms = match self.chunk_timeout {
            Some(ms) if ms > 0 => ms,
            _ => DEFAULT_CHUNK_TIMEOUT_MS,
        };

        LlmTimeouts {
            request_timeout,
            chunk_timeout: Duration::from_millis(chunk_timeout_ms),
        }
    }
}
/// Retry and backoff configuration.
///
/// Only `retry_after_cap_ms` carries `#[serde(default)]`; the other fields have no
/// serde default attribute.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderRetryPolicy {
/// When `false`, `attempts()` reports a single attempt regardless of `max_attempts`.
pub enabled: bool,
/// Attempt count reported by `attempts()` when enabled (never less than 1).
pub max_attempts: u32,
/// Delay for retry index 0, in milliseconds; doubled for each further retry index.
pub base_delay_ms: u64,
/// Ceiling applied to the doubled delay, in milliseconds, before `jitter_ms` is added.
pub max_delay_ms: u64,
/// Milliseconds added on top of the capped backoff delay.
/// NOTE(review): `delay_for_attempt` adds this value as-is, without randomizing it.
pub jitter_ms: u64,
/// Upper bound, in milliseconds, on a supplied retry-after delay; `None` leaves it uncapped.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub retry_after_cap_ms: Option<u64>,
}
impl Default for ProviderRetryPolicy {
fn default() -> Self {
Self {
enabled: true,
max_attempts: 4,
base_delay_ms: 2_000,
max_delay_ms: 10_000,
jitter_ms: 0,
retry_after_cap_ms: Some(60_000),
}
}
}
impl ProviderRetryPolicy {
pub fn disabled() -> Self {
Self {
enabled: false,
max_attempts: 1,
base_delay_ms: 0,
max_delay_ms: 0,
jitter_ms: 0,
retry_after_cap_ms: None,
}
}
pub(crate) fn attempts(&self) -> u32 {
if self.enabled {
self.max_attempts.max(1)
} else {
1
}
}
pub(crate) fn delay_for_attempt(
&self,
retry_index: u32,
retry_after: Option<Duration>,
) -> Duration {
if let Some(retry_after) = retry_after {
return self
.retry_after_cap_ms
.map(Duration::from_millis)
.map(|cap| retry_after.min(cap))
.unwrap_or(retry_after);
}
let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
let delay_ms = self
.base_delay_ms
.saturating_mul(multiplier)
.min(self.max_delay_ms);
Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
}
}
/// Optional rate-limit settings.
///
/// Every field is optional, defaults to `None`, and is left out of the serialized
/// form when `None`.
/// NOTE(review): enforcement is not visible in this section; the field descriptions
/// below are inferred from the names — confirm against the rate limiter.
#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct ProviderRateLimitPolicy {
/// Presumably the maximum number of concurrent requests.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_concurrency: Option<usize>,
/// Presumably the number of requests allowed per `request_window_ms`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub requests_per_window: Option<u32>,
/// Presumably the request window length in milliseconds.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub request_window_ms: Option<u64>,
/// Presumably the number of tokens allowed per `token_window_ms`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tokens_per_window: Option<u32>,
/// Presumably the token window length in milliseconds.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_window_ms: Option<u64>,
}
/// Fluent builder for [`ProviderReliability`].
///
/// Obtained from `ProviderReliability::builder()`, which seeds it with
/// `ProviderReliability::default_llm()`; every setter consumes and returns the
/// builder, and `build()` yields the configured value.
///
/// Marked `#[must_use]` because a builder chain whose result is dropped silently
/// discards the configuration.
#[derive(Clone, Debug)]
#[must_use = "a builder has no effect until `build()` is called"]
pub struct ProviderReliabilityBuilder {
    /// The value under construction, returned as-is by `build()`.
    reliability: ProviderReliability,
}
impl ProviderReliabilityBuilder {
    /// Applies `edit` to the value under construction and hands the builder back.
    fn map_reliability(mut self, edit: impl FnOnce(&mut ProviderReliability)) -> Self {
        edit(&mut self.reliability);
        self
    }

    /// Sets the request timeout; `None` leaves the choice to the resolution defaults.
    pub fn request_timeout(self, timeout: Option<RequestTimeout>) -> Self {
        self.map_reliability(|reliability| reliability.timeouts.request_timeout = timeout)
    }

    /// Sets the chunk timeout in milliseconds (stored in `timeouts.chunk_timeout`).
    pub fn stream_chunk_timeout_ms(self, timeout_ms: Option<u64>) -> Self {
        self.map_reliability(|reliability| reliability.timeouts.chunk_timeout = timeout_ms)
    }

    /// Sets the maximum number of attempts; values below 1 are raised to 1.
    pub fn max_attempts(self, attempts: u32) -> Self {
        let at_least_one = attempts.max(1);
        self.map_reliability(|reliability| reliability.retry.max_attempts = at_least_one)
    }

    /// Sets the retry base delay in milliseconds.
    pub fn base_delay_ms(self, delay_ms: u64) -> Self {
        self.map_reliability(|reliability| reliability.retry.base_delay_ms = delay_ms)
    }

    /// Sets the retry delay ceiling in milliseconds.
    pub fn max_delay_ms(self, delay_ms: u64) -> Self {
        self.map_reliability(|reliability| reliability.retry.max_delay_ms = delay_ms)
    }

    /// Sets the cap, in milliseconds, on supplied retry-after delays; `None` removes it.
    pub fn retry_after_cap_ms(self, cap_ms: Option<u64>) -> Self {
        self.map_reliability(|reliability| reliability.retry.retry_after_cap_ms = cap_ms)
    }

    /// Sets the `max_concurrency` rate-limit field.
    pub fn max_concurrency(self, value: Option<usize>) -> Self {
        self.map_reliability(|reliability| reliability.rate_limits.max_concurrency = value)
    }

    /// Sets the request budget together with its window length in milliseconds.
    pub fn requests_per_window(self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
        self.map_reliability(|reliability| {
            let limits = &mut reliability.rate_limits;
            limits.requests_per_window = requests;
            limits.request_window_ms = window_ms;
        })
    }

    /// Sets the token budget together with its window length in milliseconds.
    pub fn tokens_per_window(self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
        self.map_reliability(|reliability| {
            let limits = &mut reliability.rate_limits;
            limits.tokens_per_window = tokens;
            limits.token_window_ms = window_ms;
        })
    }

    /// Finishes the builder and returns the configured [`ProviderReliability`].
    pub fn build(self) -> ProviderReliability {
        let Self { reliability } = self;
        reliability
    }
}