pub mod anthropic;
pub mod defaults;
pub mod failover;
pub mod gemini;
pub mod health;
pub mod model_defaults;
pub mod openai;
pub mod registry;
pub mod rsclaw;
pub mod rsclaw_http;
use std::pin::Pin;
use anyhow::Result;
/// Default `User-Agent` value: `rsclaw/` followed by the crate version that
/// `CARGO_PKG_VERSION` holds at compile time. Used by `http_client_with_ua`
/// when the caller supplies no override.
pub const DEFAULT_USER_AGENT: &str = concat!("rsclaw/", env!("CARGO_PKG_VERSION"));
/// Test-only initializer that installs the aws-lc-rs `rustls` crypto provider
/// as the process default.
///
/// It is registered through `#[ctor::ctor]`, which is expected to run it at
/// process start-up, ahead of the tests themselves.
#[cfg(test)]
#[ctor::ctor]
fn init_test_crypto() {
    let provider = rustls::crypto::aws_lc_rs::default_provider();
    // The outcome is deliberately discarded.
    // NOTE(review): presumably an `Err` only means a default provider was
    // already installed — confirm against the rustls documentation.
    let _ = provider.install_default();
}
/// Logs a one-time warning when a request asks for KV-cache mode 2 (or higher)
/// from a provider that cannot honor it.
///
/// Requests with `kv_cache_mode < 2` are ignored. Otherwise the
/// `provider:session` pair is recorded in a process-wide set and the warning is
/// emitted only the first time that pair is seen. A request without a session
/// key is recorded under `<no-session>`.
pub(crate) fn warn_unsupported_kv_cache_mode_2(provider: &str, req: &LlmRequest) {
    use std::{
        collections::HashSet,
        sync::{Mutex, OnceLock},
    };

    // Process-wide record of the `provider:session` pairs already warned about.
    static SEEN: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();

    if req.kv_cache_mode < 2 {
        return;
    }

    let session = req.session_key.as_deref().unwrap_or("<no-session>");

    let first_time = {
        let registry = SEEN.get_or_init(|| Mutex::new(HashSet::new()));
        // A poisoned lock still guards a usable set, so recover it instead of
        // panicking inside a logging helper.
        let mut warned = registry
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // The guard is released at the end of this block, before logging.
        warned.insert(format!("{provider}:{session}"))
    };
    if !first_time {
        return;
    }

    tracing::warn!(
        provider,
        session = session,
        "kv_cache_mode=2 requested but {} provider does not support it; \
         degrading to mode 0 — route mode 2 traffic through the rsclaw \
         provider (RSCLAW_KEY/RSCLAW_URL) for incremental session caching",
        provider,
    );
}
/// Builds the shared HTTP client using the default user agent.
///
/// Equivalent to calling `http_client_with_ua` with no user-agent override.
pub(crate) fn http_client() -> reqwest::Client {
    let no_override: Option<&str> = None;
    http_client_with_ua(no_override)
}
pub(crate) async fn send_with_transport_retry(
builder: reqwest::RequestBuilder,
) -> reqwest::Result<reqwest::Response> {
let retryable = |e: &reqwest::Error| -> bool {
use std::error::Error;
if e.is_connect() {
return true;
}
let mut src: Option<&dyn Error> = e.source();
while let Some(s) = src {
let msg = s.to_string();
if msg.contains("closed before message completed")
|| msg.contains("Connection reset")
|| msg.contains("Connection refused")
|| msg.contains("connection closed")
{
return true;
}
src = s.source();
}
false
};
let Some(retry_builder) = builder.try_clone() else {
return builder.send().await;
};
match builder.send().await {
Ok(resp) => Ok(resp),
Err(e) if retryable(&e) => {
tracing::debug!(error = %e, "http: retrying once after transport error");
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
retry_builder.send().await
}
Err(e) => Err(e),
}
}
/// Builds a `reqwest::Client` with this module's transport settings.
///
/// `user_agent` overrides the `User-Agent` header; `None` falls back to
/// `DEFAULT_USER_AGENT`. The client uses a 20 s connect timeout, a 10 s idle
/// timeout for pooled connections and a 30 s TCP keepalive.
///
/// # Panics
///
/// Panics if the underlying builder fails to construct a client.
pub(crate) fn http_client_with_ua(user_agent: Option<&str>) -> reqwest::Client {
    use std::time::Duration;

    let agent = match user_agent {
        Some(custom) => custom,
        None => DEFAULT_USER_AGENT,
    };
    let connect_timeout = Duration::from_secs(20);
    let pool_idle_timeout = Duration::from_secs(10);
    let tcp_keepalive = Duration::from_secs(30);

    reqwest::Client::builder()
        .user_agent(agent)
        .connect_timeout(connect_timeout)
        .pool_idle_timeout(pool_idle_timeout)
        .tcp_keepalive(tcp_keepalive)
        .build()
        .expect("failed to build HTTP client")
}
use futures::{Stream, future::BoxFuture};
use serde::{Deserialize, Serialize};
/// A single conversation message exchanged with an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Author of the message.
pub role: Role,
/// Message body: either plain text or a list of structured parts.
pub content: MessageContent,
/// Optional hidden recall payload. Deserializes to `None` when absent and
/// is left out of serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub rsclaw_hidden: Option<RsclawHidden>,
}
/// Hidden recall payload that can ride along on a [`Message`].
///
/// `RecallBundle::to_rsclaw_hidden` builds one from a bundle's context and
/// metadata.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RsclawHidden {
/// Recall text (copied from `RecallBundle::context`).
pub recall_context: String,
/// Format label of the recall text (copied from `RecallMetadata::format`).
pub recall_format: String,
/// Recall mode label (copied from `RecallMetadata::mode`).
pub recall_mode: String,
/// Document identifiers (copied from `RecallMetadata::doc_ids`).
pub recall_doc_ids: Vec<String>,
/// Hash string (copied from `RecallMetadata::hash`).
pub recall_hash: String,
/// Truncation flag (copied from `RecallMetadata::truncated`).
pub recall_truncated: bool,
/// Optional token count; omitted from serialized output when `None`.
/// `RecallBundle::to_rsclaw_hidden` always leaves it `None`.
// NOTE(review): presumably filled in by a provider/server — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub recall_input_tokens: Option<u32>,
/// Optional trace identifier; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub recall_trace_id: Option<String>,
}
/// Author role of a [`Message`].
///
/// Variant names are serialized in lowercase (`system`, `user`, `assistant`,
/// `tool`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// System-level instructions.
System,
/// End-user input.
User,
/// Model output.
Assistant,
/// Tool-related message.
Tool,
}
/// Body of a [`Message`].
///
/// The enum is `untagged`, so the serialized form carries no variant name: it
/// is either a bare string or a sequence of parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// Plain text body.
Text(String),
/// Body made of structured parts.
Parts(Vec<ContentPart>),
}
/// One structured piece of a multi-part message body.
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (`text`, `image`, `tool_use`, `tool_result`, `reasoning`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// Plain text fragment.
Text {
text: String,
},
/// Image referenced by a URL string.
Image {
url: String,
},
/// A tool invocation: call identifier, tool name and JSON input.
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
/// Outcome of a tool invocation.
ToolResult {
// presumably matches the `id` of the corresponding `ToolUse` — verify against callers
tool_use_id: String,
content: String,
is_error: Option<bool>,
},
/// Reasoning text.
Reasoning {
text: String,
},
}
/// Definition of a tool offered to the model in an [`LlmRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
/// Tool name.
pub name: String,
/// Human-readable description of the tool.
pub description: String,
/// Parameter description as arbitrary JSON.
// NOTE(review): presumably a JSON Schema object — confirm against the providers.
pub parameters: serde_json::Value,
}
/// Selects which endpoint a request targets; defaults to [`AgentEndpoint::Primary`].
// NOTE(review): the exact routing meaning of each variant is decided by the
// providers, which are not visible here.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum AgentEndpoint {
/// Default endpoint.
#[default]
Primary,
/// Alternative endpoint named "flash".
Flash,
/// Alternative endpoint named "vision".
Vision,
}
/// Provider-independent description of one LLM call, passed to
/// `LlmProvider::stream`.
#[derive(Debug, Clone, Default)]
pub struct LlmRequest {
/// Model identifier to call.
pub model: String,
/// Additional model identifiers.
// presumably tried in order when `model` fails — verify against the failover code
pub fallback_models: Vec<String>,
/// Conversation messages.
pub messages: Vec<Message>,
/// Tools made available to the model.
pub tools: Vec<ToolDef>,
/// Optional system prompt.
pub system: Option<String>,
/// Optional cap on generated tokens.
pub max_tokens: Option<u32>,
/// Optional sampling temperature.
pub temperature: Option<f32>,
/// Optional frequency penalty.
pub frequency_penalty: Option<f32>,
/// Optional budget for thinking/reasoning.
// NOTE(review): unit is presumably tokens — confirm.
pub thinking_budget: Option<u32>,
/// Endpoint selector; defaults to `AgentEndpoint::Primary`.
pub endpoint: AgentEndpoint,
/// KV-cache mode. Values of 2 or more make
/// `warn_unsupported_kv_cache_mode_2` log a warning when it is called for
/// this request.
pub kv_cache_mode: u8,
/// Optional session key; used by `warn_unsupported_kv_cache_mode_2` to
/// de-duplicate its warning per session.
pub session_key: Option<String>,
/// Optional shared system text.
// NOTE(review): relationship to `system` is not visible here — confirm.
pub system_shared: Option<String>,
/// Optional user-supplied system text.
// NOTE(review): relationship to `system` is not visible here — confirm.
pub user_system: Option<String>,
/// Optional recall context and metadata.
pub recall: Option<RecallBundle>,
}
/// Recall text together with the metadata describing it.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct RecallBundle {
/// Recall text.
pub context: String,
/// Metadata describing `context`.
pub metadata: RecallMetadata,
}
impl RecallBundle {
    /// Converts this bundle into the hidden payload carried on a message.
    ///
    /// Returns `None` when the context is empty after trimming whitespace, or
    /// when the metadata mode is anything other than `"committed"`. Otherwise
    /// the context (untrimmed) and the metadata fields are cloned into a new
    /// [`RsclawHidden`]; `recall_input_tokens` is always left as `None`.
    pub fn to_rsclaw_hidden(&self) -> Option<RsclawHidden> {
        let meta = &self.metadata;
        let has_text = !self.context.trim().is_empty();
        let is_committed = meta.mode == "committed";

        (has_text && is_committed).then(|| RsclawHidden {
            // Trimming is only used for the emptiness check above; the
            // payload keeps the original text.
            recall_context: self.context.clone(),
            recall_format: meta.format.clone(),
            recall_mode: meta.mode.clone(),
            recall_doc_ids: meta.doc_ids.clone(),
            recall_hash: meta.hash.clone(),
            recall_truncated: meta.truncated,
            recall_input_tokens: None,
            recall_trace_id: meta.trace_id.clone(),
        })
    }
}
/// Removes the top-level `rsclaw_hidden` key from a JSON value.
///
/// Only a JSON object is modified, and only its own keys are inspected (nested
/// objects are not searched). Any other kind of value is returned unchanged.
pub fn redact_rsclaw_hidden_value(mut value: serde_json::Value) -> serde_json::Value {
    if let serde_json::Value::Object(fields) = &mut value {
        fields.remove("rsclaw_hidden");
    }
    value
}
/// Metadata describing the recall text held by a [`RecallBundle`].
///
/// The `Default` implementation uses mode `"committed"`, format `"xml"` and
/// source `"server"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RecallMetadata {
/// Recall mode label; `RecallBundle::to_rsclaw_hidden` only produces a
/// payload when this equals `"committed"`.
pub mode: String,
/// Format label of the recall text.
pub format: String,
/// Label naming where the recall came from.
pub source: String,
/// Optional trace identifier; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub trace_id: Option<String>,
/// Optional token limit; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Document identifiers associated with the recall.
pub doc_ids: Vec<String>,
/// Hash string associated with the recall.
// NOTE(review): what exactly is hashed is not visible here — confirm.
pub hash: String,
/// Whether the recall text was truncated.
pub truncated: bool,
}
impl Default for RecallMetadata {
fn default() -> Self {
Self {
mode: "committed".to_owned(),
format: "xml".to_owned(),
source: "server".to_owned(),
trace_id: None,
max_tokens: None,
doc_ids: Vec::new(),
hash: String::new(),
truncated: false,
}
}
}
/// Converts an `f32` into a JSON value rounded to two decimal places.
///
/// The value is widened to `f64` before rounding, so the emitted number is the
/// `f64` nearest to the two-decimal result rather than the widened `f32` bit
/// pattern.
pub fn json_f32(v: f32) -> serde_json::Value {
    let scaled = f64::from(v) * 100.0;
    let rounded = scaled.round() / 100.0;
    serde_json::json!(rounded)
}
/// One item yielded by an [`LlmStream`].
#[derive(Debug, Clone)]
pub enum StreamEvent {
/// A chunk of generated text.
TextDelta(String),
/// A chunk of reasoning text.
ReasoningDelta(String),
/// A tool invocation requested by the model: call identifier, tool name
/// and JSON input.
ToolCall {
id: String,
name: String,
input: serde_json::Value,
},
/// End of the stream, optionally carrying token usage.
Done { usage: Option<TokenUsage> },
/// An error reported as a message string.
Error(String),
}
/// Token accounting reported with `StreamEvent::Done`; every field defaults to
/// zero / empty / `None` / `false`.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
/// Input token count.
pub input: u64,
/// Output token count.
pub output: u64,
/// Cache-creation token count.
// presumably tokens written to a provider prompt cache — verify against the providers
pub cache_creation: u64,
/// Cache-read token count.
// presumably tokens served from a provider prompt cache — verify against the providers
pub cache_read: u64,
/// Token count attributed to recall.
pub recall_tokens: u64,
/// Document identifiers associated with the recall.
pub recall_doc_ids: Vec<String>,
/// Optional recall hash.
pub recall_hash: Option<String>,
/// Whether the recall was truncated.
pub recall_truncated: bool,
}
/// Pinned, boxed, `Send` stream of fallible [`StreamEvent`]s, as returned by
/// `LlmProvider::stream`.
pub type LlmStream = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>;
/// Common interface implemented by every LLM backend.
pub trait LlmProvider: Send + Sync {
/// Name of the provider; also used in the default `compact_splice` error.
fn name(&self) -> &str;
/// Starts a call for `req` and resolves to a stream of [`StreamEvent`]s.
fn stream(&self, req: LlmRequest) -> BoxFuture<'_, Result<LlmStream>>;
/// Optional "compact splice" operation on a session.
///
/// The default implementation ignores every argument and resolves to an
/// error stating that the provider (by name) does not support it.
// NOTE(review): the argument names suggest that the middle of a session's
// history is replaced by `summary` while `keep_head_messages` /
// `keep_tail_messages` are preserved, and that the `usize` result is a
// message count — confirm against an implementing provider.
// The arguments are unused only in this default body, hence the allow.
#[allow(unused_variables)]
fn compact_splice<'a>(
&'a self,
session_key: &'a str,
keep_head_messages: usize,
summary: &'a str,
keep_tail_messages: usize,
expected_msgs_count: Option<usize>,
) -> BoxFuture<'a, Result<usize>> {
// Copy the name out so the returned future does not borrow it from `self`.
let name = self.name().to_owned();
Box::pin(async move { anyhow::bail!("compact splice not supported by provider {name}") })
}
}
/// Retry policy consumed by [`backoff_delay`].
///
/// Deserialization is lenient: because of `#[serde(default)]`, any field that
/// is missing from the input takes its value from `RetryConfig::default()`.
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(default)]
pub struct RetryConfig {
    /// Attempt count (default 3).
    // NOTE(review): whether this counts total tries or only retries is
    // decided by the caller — confirm.
    pub attempts: u32,
    /// Delay for attempt 0, in milliseconds; doubled for each later attempt.
    pub min_delay_ms: u64,
    /// Upper bound on the exponential delay, in milliseconds (before jitter).
    pub max_delay_ms: u64,
    /// Maximum jitter, as a fraction of the capped delay.
    pub jitter: f64,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
attempts: 3,
min_delay_ms: 400,
max_delay_ms: 30_000,
jitter: 0.1,
}
}
}
/// Computes how long to wait before retry number `attempt` (0-based) using
/// exponential backoff.
///
/// The base delay is `min_delay_ms * 2^attempt`, capped at `max_delay_ms`. A
/// deterministic pseudo-jitter is then added: `jitter` × the capped delay ×
/// the fractional part of `attempt * 0.31`. Because the jitter depends only on
/// `attempt`, identical inputs always produce the same delay, and the result
/// may exceed `max_delay_ms` by up to the `jitter` fraction.
///
/// Any `attempt` value is handled: very large attempt numbers saturate at the
/// capped delay, and a `min_delay_ms` of 0 always yields a zero delay.
pub fn backoff_delay(attempt: u32, config: &RetryConfig) -> std::time::Duration {
    // 2^64 times any non-zero millisecond count already reaches or exceeds
    // every possible `max_delay_ms`, so capping the exponent never changes a
    // clamped result. The cap keeps the power finite, so `0 * 2^n` stays 0
    // instead of turning into NaN (which `f64::min` would silently replace
    // with `max_delay_ms`). It also prevents the `i32` conversion from
    // wrapping negative for attempt numbers above `i32::MAX`.
    const MAX_EXPONENT: u32 = 64;
    let exponent = attempt.min(MAX_EXPONENT) as i32; // lossless: value is <= 64

    let base = config.min_delay_ms as f64 * 2f64.powi(exponent);
    let clamped = base.min(config.max_delay_ms as f64);
    // The jitter still uses the real attempt number, not the capped exponent.
    let jitter = clamped * config.jitter * (attempt as f64 * 0.31 % 1.0);
    // The float-to-int cast saturates, so negative or NaN sums become 0.
    std::time::Duration::from_millis((clamped + jitter) as u64)
}
// Unit tests for `backoff_delay` using the default `RetryConfig`.
#[cfg(test)]
mod tests {
use super::*;
/// With the default config, the delay grows strictly over the first three
/// attempts (none of them reaches the `max_delay_ms` cap yet).
#[test]
fn backoff_increases_with_attempt() {
let cfg = RetryConfig::default();
let d0 = backoff_delay(0, &cfg);
let d1 = backoff_delay(1, &cfg);
let d2 = backoff_delay(2, &cfg);
assert!(
d0 < d1,
"attempt 0 ({d0:?}) should be less than attempt 1 ({d1:?})"
);
assert!(
d1 < d2,
"attempt 1 ({d1:?}) should be less than attempt 2 ({d2:?})"
);
}
/// For a late attempt the delay never exceeds `max_delay_ms` plus the
/// configured jitter fraction.
#[test]
fn backoff_clamped_at_max() {
let cfg = RetryConfig::default();
let d = backoff_delay(20, &cfg);
// Jitter is added on top of the cap, so the bound is max * (1 + jitter).
let max_with_jitter = (cfg.max_delay_ms as f64 * (1.0 + cfg.jitter)) as u64;
assert!(
d.as_millis() as u64 <= max_with_jitter,
"delay {d:?} exceeds max+jitter bound ({max_with_jitter} ms)"
);
}
}
pub mod build;