pub mod anthropic;
pub mod circuit;
const ERROR_BODY_MAX_BYTES: usize = 8 << 10;
pub(crate) const STREAM_MAX_TEXT_BYTES: usize = 16 << 20;
pub(crate) const STREAM_MAX_TOOL_ARGS_BYTES: usize = 1 << 20;
pub(crate) const STREAM_MAX_TOOL_CALLS: usize = 256;
pub(crate) async fn api_error_from_response(response: reqwest::Response) -> Error {
use futures::TryStreamExt;
let status = response.status().as_u16();
let message = if status == 401 || status == 403 {
format!("authentication failed (HTTP {status})")
} else {
let mut buf: Vec<u8> = Vec::with_capacity(2048);
let mut stream = response.bytes_stream();
let mut overflowed = false;
loop {
match stream.try_next().await {
Ok(Some(chunk)) => {
let remaining = ERROR_BODY_MAX_BYTES.saturating_sub(buf.len());
if remaining == 0 {
overflowed = true;
break;
}
let take = chunk.len().min(remaining);
buf.extend_from_slice(&chunk[..take]);
if take < chunk.len() {
overflowed = true;
break;
}
}
Ok(None) => break,
Err(e) => {
return Error::Api {
status,
message: format!("<body read error: {e}>"),
};
}
}
}
let mut text = String::from_utf8_lossy(&buf).to_string();
text.retain(|c| c == '\t' || (!c.is_control() && c != '\u{1b}'));
if overflowed {
text.push_str("…[truncated]");
}
text
};
Error::Api { status, message }
}
pub mod cascade;
pub mod error_class;
pub mod gemini;
pub mod openai_compat;
pub mod openrouter;
pub mod pricing;
pub mod registry;
pub mod retry;
pub mod types;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::error::Error;
use crate::llm::types::{CompletionRequest, CompletionResponse};
pub type OnText = dyn Fn(&str) + Send + Sync;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApprovalDecision {
Allow,
Deny,
AlwaysAllow,
AlwaysDeny,
}
impl ApprovalDecision {
pub fn is_allowed(self) -> bool {
matches!(self, Self::Allow | Self::AlwaysAllow)
}
pub fn is_persistent(self) -> bool {
matches!(self, Self::AlwaysAllow | Self::AlwaysDeny)
}
}
impl From<bool> for ApprovalDecision {
fn from(allowed: bool) -> Self {
if allowed { Self::Allow } else { Self::Deny }
}
}
pub type OnApproval = dyn Fn(&[crate::llm::types::ToolCall]) -> ApprovalDecision + Send + Sync;
pub trait LlmProvider: Send + Sync {
fn complete(
&self,
request: CompletionRequest,
) -> impl Future<Output = Result<CompletionResponse, Error>> + Send;
fn stream_complete(
&self,
request: CompletionRequest,
on_text: &OnText,
) -> impl Future<Output = Result<CompletionResponse, Error>> + Send {
let _ = on_text;
self.complete(request)
}
fn model_name(&self) -> Option<&str> {
None
}
}
pub trait DynLlmProvider: Send + Sync {
fn complete<'a>(
&'a self,
request: CompletionRequest,
) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, Error>> + Send + 'a>>;
fn stream_complete<'a>(
&'a self,
request: CompletionRequest,
on_text: &'a OnText,
) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, Error>> + Send + 'a>>;
fn model_name(&self) -> Option<&str>;
}
impl<P: LlmProvider> DynLlmProvider for P {
fn complete<'a>(
&'a self,
request: CompletionRequest,
) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, Error>> + Send + 'a>> {
Box::pin(LlmProvider::complete(self, request))
}
fn stream_complete<'a>(
&'a self,
request: CompletionRequest,
on_text: &'a OnText,
) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, Error>> + Send + 'a>> {
Box::pin(LlmProvider::stream_complete(self, request, on_text))
}
fn model_name(&self) -> Option<&str> {
LlmProvider::model_name(self)
}
}
pub struct BoxedProvider(Box<dyn DynLlmProvider>);
impl BoxedProvider {
pub fn new<P: LlmProvider + 'static>(provider: P) -> Self {
Self(Box::new(provider))
}
pub fn from_arc<P: LlmProvider + 'static>(provider: Arc<P>) -> Self {
struct ArcAdapter<P>(Arc<P>);
impl<P: LlmProvider> LlmProvider for ArcAdapter<P> {
async fn complete(
&self,
request: CompletionRequest,
) -> Result<CompletionResponse, Error> {
self.0.complete(request).await
}
async fn stream_complete(
&self,
request: CompletionRequest,
on_text: &OnText,
) -> Result<CompletionResponse, Error> {
self.0.stream_complete(request, on_text).await
}
fn model_name(&self) -> Option<&str> {
self.0.model_name()
}
}
Self(Box::new(ArcAdapter(provider)))
}
}
impl LlmProvider for BoxedProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, Error> {
self.0.complete(request).await
}
async fn stream_complete(
&self,
request: CompletionRequest,
on_text: &OnText,
) -> Result<CompletionResponse, Error> {
self.0.stream_complete(request, on_text).await
}
fn model_name(&self) -> Option<&str> {
self.0.model_name()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::types::{ContentBlock, Message, StopReason, TokenUsage};
use std::sync::{Arc, Mutex};
struct FakeProvider;
impl LlmProvider for FakeProvider {
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
Ok(CompletionResponse {
content: vec![ContentBlock::Text {
text: "fake".into(),
}],
stop_reason: StopReason::EndTurn,
usage: TokenUsage::default(),
model: None,
})
}
}
struct StreamingFakeProvider;
impl LlmProvider for StreamingFakeProvider {
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
panic!("should call stream_complete, not complete");
}
async fn stream_complete(
&self,
_request: CompletionRequest,
on_text: &OnText,
) -> Result<CompletionResponse, Error> {
on_text("hello");
on_text(" world");
Ok(CompletionResponse {
content: vec![ContentBlock::Text {
text: "hello world".into(),
}],
stop_reason: StopReason::EndTurn,
usage: TokenUsage::default(),
model: None,
})
}
}
fn test_request() -> CompletionRequest {
CompletionRequest {
system: String::new(),
messages: vec![Message::user("test")],
tools: vec![],
max_tokens: 100,
tool_choice: None,
reasoning_effort: None,
}
}
#[test]
fn dyn_llm_provider_wraps_provider() {
let provider = FakeProvider;
let dyn_provider: &dyn DynLlmProvider = &provider;
let _ = dyn_provider;
}
#[tokio::test]
async fn boxed_provider_delegates_complete() {
let provider = BoxedProvider::new(FakeProvider);
let response = LlmProvider::complete(&provider, test_request())
.await
.unwrap();
assert_eq!(response.text(), "fake");
}
#[tokio::test]
async fn boxed_provider_delegates_stream_complete() {
let provider = BoxedProvider::new(StreamingFakeProvider);
let received = Arc::new(Mutex::new(Vec::<String>::new()));
let received_clone = received.clone();
let on_text: &OnText = &move |text: &str| {
received_clone
.lock()
.expect("test lock")
.push(text.to_string());
};
let response = LlmProvider::stream_complete(&provider, test_request(), on_text)
.await
.unwrap();
assert_eq!(response.text(), "hello world");
let texts = received.lock().expect("test lock");
assert_eq!(*texts, vec!["hello", " world"]);
}
#[test]
fn boxed_provider_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<BoxedProvider>();
}
#[tokio::test]
async fn boxed_provider_default_stream_falls_back_to_complete() {
let provider = BoxedProvider::new(FakeProvider);
let on_text: &OnText = &|_| {};
let response = LlmProvider::stream_complete(&provider, test_request(), on_text)
.await
.unwrap();
assert_eq!(response.text(), "fake");
}
#[tokio::test]
async fn boxed_provider_from_arc_delegates_complete() {
let provider = Arc::new(FakeProvider);
let boxed = BoxedProvider::from_arc(provider);
let response = LlmProvider::complete(&boxed, test_request()).await.unwrap();
assert_eq!(response.text(), "fake");
}
#[tokio::test]
async fn boxed_provider_from_arc_delegates_stream_complete() {
let provider = Arc::new(StreamingFakeProvider);
let boxed = BoxedProvider::from_arc(provider);
let received = Arc::new(Mutex::new(Vec::<String>::new()));
let received_clone = received.clone();
let on_text: &OnText = &move |text: &str| {
received_clone
.lock()
.expect("test lock")
.push(text.to_string());
};
let response = LlmProvider::stream_complete(&boxed, test_request(), on_text)
.await
.unwrap();
assert_eq!(response.text(), "hello world");
let texts = received.lock().expect("test lock");
assert_eq!(*texts, vec!["hello", " world"]);
}
#[test]
fn model_name_default_is_none() {
let provider = FakeProvider;
assert!(LlmProvider::model_name(&provider).is_none());
}
#[test]
fn boxed_provider_preserves_model_name() {
struct NamedProvider;
impl LlmProvider for NamedProvider {
async fn complete(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse, Error> {
unimplemented!()
}
fn model_name(&self) -> Option<&str> {
Some("test-model")
}
}
let boxed = BoxedProvider::new(NamedProvider);
assert_eq!(LlmProvider::model_name(&boxed), Some("test-model"));
}
#[test]
fn boxed_provider_from_arc_preserves_model_name() {
struct NamedProvider;
impl LlmProvider for NamedProvider {
async fn complete(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse, Error> {
unimplemented!()
}
fn model_name(&self) -> Option<&str> {
Some("arc-model")
}
}
let boxed = BoxedProvider::from_arc(Arc::new(NamedProvider));
assert_eq!(LlmProvider::model_name(&boxed), Some("arc-model"));
}
#[tokio::test]
async fn boxed_provider_from_arc_shares_underlying_provider() {
let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
struct CountingProvider(Arc<std::sync::atomic::AtomicUsize>);
impl LlmProvider for CountingProvider {
async fn complete(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse, crate::error::Error> {
self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Ok(CompletionResponse {
content: vec![ContentBlock::Text {
text: "counted".into(),
}],
stop_reason: StopReason::EndTurn,
usage: TokenUsage::default(),
model: None,
})
}
}
let inner = Arc::new(CountingProvider(call_count.clone()));
let boxed1 = BoxedProvider::from_arc(inner.clone());
let boxed2 = BoxedProvider::from_arc(inner);
LlmProvider::complete(&boxed1, test_request())
.await
.unwrap();
LlmProvider::complete(&boxed2, test_request())
.await
.unwrap();
assert_eq!(
call_count.load(std::sync::atomic::Ordering::Relaxed),
2,
"both boxed providers should share the same underlying provider"
);
}
#[test]
fn approval_decision_from_true() {
let decision = ApprovalDecision::from(true);
assert_eq!(decision, ApprovalDecision::Allow);
assert!(decision.is_allowed());
assert!(!decision.is_persistent());
}
#[test]
fn approval_decision_from_false() {
let decision = ApprovalDecision::from(false);
assert_eq!(decision, ApprovalDecision::Deny);
assert!(!decision.is_allowed());
assert!(!decision.is_persistent());
}
#[test]
fn approval_decision_always_allow() {
let decision = ApprovalDecision::AlwaysAllow;
assert!(decision.is_allowed());
assert!(decision.is_persistent());
}
#[test]
fn approval_decision_always_deny() {
let decision = ApprovalDecision::AlwaysDeny;
assert!(!decision.is_allowed());
assert!(decision.is_persistent());
}
}