pub mod anthropic;
pub mod defaults;
pub mod failover;
pub mod gemini;
pub mod model_defaults;
pub mod openai;
pub mod registry;
use std::pin::Pin;
use anyhow::Result;
/// `User-Agent` header value used when the caller does not supply one.
pub(crate) const DEFAULT_USER_AGENT: &str = "rsclaw/1.0";

/// Builds an HTTP client that identifies itself with [`DEFAULT_USER_AGENT`].
///
/// Equivalent to `http_client_with_ua(None)`, which falls back to the same
/// constant; the default is spelled out here for readability.
pub(crate) fn http_client() -> reqwest::Client {
    http_client_with_ua(Some(DEFAULT_USER_AGENT))
}
pub(crate) async fn send_with_transport_retry(
builder: reqwest::RequestBuilder,
) -> reqwest::Result<reqwest::Response> {
let retryable = |e: &reqwest::Error| -> bool {
use std::error::Error;
if e.is_connect() {
return true;
}
let mut src: Option<&dyn Error> = e.source();
while let Some(s) = src {
let msg = s.to_string();
if msg.contains("closed before message completed")
|| msg.contains("Connection reset")
|| msg.contains("Connection refused")
|| msg.contains("connection closed")
{
return true;
}
src = s.source();
}
false
};
let Some(retry_builder) = builder.try_clone() else {
return builder.send().await;
};
match builder.send().await {
Ok(resp) => Ok(resp),
Err(e) if retryable(&e) => {
tracing::debug!(error = %e, "http: retrying once after transport error");
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
retry_builder.send().await
}
Err(e) => Err(e),
}
}
/// Builds a `reqwest::Client` with this module's standard transport settings.
///
/// `user_agent` overrides the `User-Agent` header; `None` falls back to
/// [`DEFAULT_USER_AGENT`]. The client is configured with:
/// - a 20 s connect timeout,
/// - a 10 s idle timeout for pooled connections,
/// - a 30 s TCP keepalive interval.
///
/// # Panics
/// Panics if `reqwest` fails to build the client.
pub(crate) fn http_client_with_ua(user_agent: Option<&str>) -> reqwest::Client {
    use std::time::Duration;

    let agent = match user_agent {
        Some(custom) => custom,
        None => DEFAULT_USER_AGENT,
    };
    let configured = reqwest::Client::builder()
        .user_agent(agent)
        .connect_timeout(Duration::from_secs(20))
        .pool_idle_timeout(Duration::from_secs(10))
        .tcp_keepalive(Duration::from_secs(30));
    configured.build().expect("failed to build HTTP client")
}
use futures::{Stream, future::BoxFuture};
use serde::{Deserialize, Serialize};
/// A single chat message exchanged with an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Message body: plain text or a list of structured parts.
    pub content: MessageContent,
}
/// Author role of a [`Message`].
///
/// (De)serialized as the lowercase variant name: `"system"`, `"user"`,
/// `"assistant"`, `"tool"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
/// Body of a [`Message`].
///
/// Untagged: `Text` serializes as a bare JSON string and `Parts` as a JSON
/// array of [`ContentPart`] objects. When deserializing, serde tries the
/// variants in declaration order, so keep `Text` first.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    Text(String),
    Parts(Vec<ContentPart>),
}
/// One structured piece of a [`MessageContent::Parts`] body.
///
/// Internally tagged: each part serializes as an object whose `"type"` field
/// is the snake_case variant name (`"text"`, `"image"`, `"tool_use"`,
/// `"tool_result"`), followed by the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image referenced by URL.
    /// NOTE(review): which URL schemes (http(s), data:) are accepted is
    /// decided by the provider implementations — confirm.
    Image {
        url: String,
    },
    /// A tool invocation, presumably requested by the model.
    ToolUse {
        id: String,
        name: String,
        /// Tool arguments as arbitrary JSON.
        input: serde_json::Value,
    },
    /// Output of a tool invocation; `tool_use_id` presumably matches the `id`
    /// of the corresponding `ToolUse` part.
    ToolResult {
        tool_use_id: String,
        content: String,
        /// Optional error flag. There is no `skip_serializing_if`, so `None`
        /// serializes as `"is_error": null`.
        is_error: Option<bool>,
    },
}
/// Declaration of a tool the model may call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter specification as raw JSON.
    /// NOTE(review): presumably a JSON Schema object — confirm against providers.
    pub parameters: serde_json::Value,
}
/// Provider-agnostic description of one LLM completion request.
#[derive(Debug, Clone)]
pub struct LlmRequest {
    /// Model identifier.
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<Message>,
    /// Tools offered to the model; may be empty.
    pub tools: Vec<ToolDef>,
    /// System prompt, held separately from `messages`.
    pub system: Option<String>,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Optional token budget for reasoning/"thinking".
    /// NOTE(review): interpretation is provider-specific — confirm.
    pub thinking_budget: Option<u32>,
    /// KV/prompt-cache mode selector.
    /// NOTE(review): the meaning of each value is defined by the provider
    /// implementations, not here.
    pub kv_cache_mode: u8,
    /// Optional session identifier.
    /// NOTE(review): how providers use it is not visible here — confirm.
    pub session_key: Option<String>,
}
/// Converts an `f32` into a JSON number without f32→f64 widening noise.
///
/// Widening `0.7_f32` straight to `f64` yields `0.699999988079071`, which
/// would leak into request bodies. Instead, the value is rendered with `f32`'s
/// shortest round-trip decimal form (e.g. `"0.7"`) and re-parsed as `f64`, so
/// the JSON carries the decimal the caller actually wrote.
///
/// The previous implementation rounded to two decimal places instead. That
/// silently altered finer values (`0.001` → `0.0`, `0.005` → `0.01`,
/// `0.999` → `1.0`); this version preserves them.
///
/// Non-finite inputs (NaN, ±inf) become `null`, as `serde_json` does for any
/// non-finite `f64`.
pub fn json_f32(v: f32) -> serde_json::Value {
    // `Display` for f32 prints the shortest string that round-trips to `v`,
    // and "NaN"/"inf"/"-inf" also parse, so the fallback is purely defensive.
    let widened = v
        .to_string()
        .parse::<f64>()
        .unwrap_or_else(|_| f64::from(v));
    serde_json::json!(widened)
}
/// One item produced by an [`LlmStream`].
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Incremental chunk of response text.
    TextDelta(String),
    /// Incremental chunk of reasoning text, presumably only from providers
    /// that expose it.
    ReasoningDelta(String),
    /// A tool call, with arguments as JSON.
    ToolCall {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Completion marker, optionally carrying token usage.
    Done { usage: Option<TokenUsage> },
    /// An error described as a message string.
    Error(String),
}
/// Token counts attached to [`StreamEvent::Done`].
/// NOTE(review): presumably `input` = prompt tokens and `output` = generated
/// tokens as reported by the provider — confirm per provider.
#[derive(Debug, Clone)]
pub struct TokenUsage {
    pub input: u32,
    pub output: u32,
}
/// Pinned, boxed, `Send` stream of [`StreamEvent`]s. Each item is an
/// `anyhow::Result`, so individual events can fail mid-stream.
pub type LlmStream = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>;
/// A backend that turns an [`LlmRequest`] into an [`LlmStream`].
///
/// The `Send + Sync` supertraits allow a provider to be shared across threads.
pub trait LlmProvider: Send + Sync {
    /// Identifier for this provider.
    fn name(&self) -> &str;
    /// Issues `req`. The returned boxed future resolves either to the event
    /// stream or to an error (presumably when the request cannot be started).
    fn stream(&self, req: LlmRequest) -> BoxFuture<'_, Result<LlmStream>>;
}
/// Retry/backoff tuning.
///
/// Because of `#[serde(default)]`, any field missing from the deserialized
/// input takes its value from [`RetryConfig::default`].
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(default)]
pub struct RetryConfig {
    /// Attempt budget (default 3).
    /// NOTE(review): whether this counts the first try or only retries is
    /// decided by callers not visible here.
    pub attempts: u32,
    /// Delay in milliseconds for attempt 0; doubled per attempt by
    /// `backoff_delay` (default 400).
    pub min_delay_ms: u64,
    /// Cap in milliseconds on the exponential delay, before jitter
    /// (default 30 000).
    pub max_delay_ms: u64,
    /// Fraction of the capped delay that may be added as jitter (default 0.1).
    pub jitter: f64,
}

impl Default for RetryConfig {
    fn default() -> Self {
        RetryConfig {
            jitter: 0.1,
            max_delay_ms: 30_000,
            min_delay_ms: 400,
            attempts: 3,
        }
    }
}
/// Computes how long to wait before retry number `attempt` (0-based).
///
/// 1. The base delay is `min_delay_ms * 2^attempt`, clamped to `max_delay_ms`.
/// 2. A jitter term of `clamped * jitter * f` is added, where
///    `f = (attempt * 0.31) mod 1` lies in `[0, 1)`.
///
/// The returned delay can therefore exceed `max_delay_ms` by at most the
/// `jitter` fraction. The jitter is derived from `attempt`, not from a random
/// source, so identical inputs always produce identical delays.
pub fn backoff_delay(attempt: u32, config: &RetryConfig) -> std::time::Duration {
    // 2^1023 is the largest finite power of two in f64. Clamping the exponent:
    // - avoids `attempt as i32` wrapping negative for huge attempts, which
    //   would shrink the delay instead of saturating it;
    // - keeps `0 * 2^n` at 0 instead of NaN (0 * inf) when min_delay_ms is 0.
    const MAX_EXPONENT: u32 = 1023;
    let exponent = attempt.min(MAX_EXPONENT) as i32;
    let base = config.min_delay_ms as f64 * 2f64.powi(exponent);
    let clamped = base.min(config.max_delay_ms as f64);
    let jitter_factor = (f64::from(attempt) * 0.31) % 1.0;
    let jitter = clamped * config.jitter * jitter_factor;
    std::time::Duration::from_millis((clamped + jitter) as u64)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Below the cap, each successive attempt must wait strictly longer.
    #[test]
    fn backoff_increases_with_attempt() {
        let cfg = RetryConfig::default();
        let [d0, d1, d2] = [0, 1, 2].map(|attempt| backoff_delay(attempt, &cfg));
        assert!(
            d0 < d1,
            "attempt 0 ({d0:?}) should be less than attempt 1 ({d1:?})"
        );
        assert!(
            d1 < d2,
            "attempt 1 ({d1:?}) should be less than attempt 2 ({d2:?})"
        );
    }

    /// A late attempt saturates at `max_delay_ms`, plus at most the configured
    /// jitter fraction.
    #[test]
    fn backoff_clamped_at_max() {
        let cfg = RetryConfig::default();
        let max_with_jitter = (cfg.max_delay_ms as f64 * (1.0 + cfg.jitter)) as u64;
        let d = backoff_delay(20, &cfg);
        assert!(
            d.as_millis() <= u128::from(max_with_jitter),
            "delay {d:?} exceeds max+jitter bound ({max_with_jitter} ms)"
        );
    }
}