#![allow(dead_code)]
use std::collections::HashMap;
use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
pub mod anthropic;
pub mod error;
pub mod gemini;
pub mod glm;
pub mod kimi;
pub mod locked_model;
pub mod observability;
pub mod ollama;
pub mod openai;
pub mod openai_compat;
pub mod openrouter;
pub mod retry;
pub mod tokens;
pub(crate) mod transport;
pub use anthropic::AnthropicBackend;
pub use error::{categorise_http, BackendError};
pub use gemini::GeminiBackend;
pub use glm::GLMBackend;
pub use kimi::KimiBackend;
pub use ollama::OllamaBackend;
pub use openai::OpenAIBackend;
pub use openai_compat::{OpenAICompatConfig, OpenAICompatibleBackend};
pub use openrouter::OpenRouterBackend;
/// Author of a chat turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}

impl Role {
    /// Lower-case wire name of this role: `"system"`, `"user"`,
    /// `"assistant"` or `"tool"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        }
    }
}
/// A single chat turn: who produced it, its text, and — for tool results —
/// which tool call it answers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message {
    /// Author of the turn.
    pub role: Role,
    /// Text content of the turn.
    pub content: String,
    /// Identifier of the tool call this message responds to; `None` for the
    /// messages built by [`Message::user`], [`Message::assistant`] and
    /// [`Message::system`].
    pub tool_call_id: Option<String>,
}

impl Message {
    /// Builds a [`Role::User`] message with no tool-call id.
    pub fn user(content: impl Into<String>) -> Self {
        Self::plain(Role::User, content)
    }

    /// Builds a [`Role::Assistant`] message with no tool-call id.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::plain(Role::Assistant, content)
    }

    /// Builds a [`Role::System`] message with no tool-call id.
    pub fn system(content: impl Into<String>) -> Self {
        Self::plain(Role::System, content)
    }

    /// Builds a [`Role::Tool`] message carrying the result of the tool call
    /// identified by `tool_call_id`.
    ///
    /// This is the only helper that populates `tool_call_id`; without it a
    /// tool result had to be assembled with a struct literal.
    pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: Role::Tool,
            content: content.into(),
            tool_call_id: Some(tool_call_id.into()),
        }
    }

    /// Shared constructor for messages that do not reference a tool call.
    fn plain(role: Role, content: impl Into<String>) -> Self {
        Self {
            role,
            content: content.into(),
            tool_call_id: None,
        }
    }
}
/// Declaration of a tool that may be offered to the model in a
/// [`ChatRequest`].
#[derive(Debug, Clone, PartialEq)]
pub struct ToolSpec {
    /// Tool name as exposed to the model.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter schema kept as a raw string.
    ///
    /// NOTE(review): presumably a serialised JSON Schema object (going by the
    /// field name) — confirm against the backends that consume it.
    pub parameters_json: String,
}
/// Optional feature a backend may offer for a given model; queried through
/// `Backend::supports`.
///
/// NOTE(review): the per-variant notes below are inferred from the variant
/// names only — confirm the exact meaning in the individual backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Capability {
    /// Incremental responses (presumably via `Backend::stream`).
    Streaming,
    /// Tool / function calling.
    ToolUse,
    /// Image input.
    Vision,
    /// Provider-side prompt caching.
    PromptCaching,
    /// Provider-configurable safety settings.
    SafetySettings,
    /// Constrained / structured output.
    StructuredOutput,
    /// Sampling parameters fixed by the provider for the model.
    LockedParams,
}
/// Provider-agnostic chat-completion request handed to a `Backend`.
///
/// `Default` yields an empty model name, no messages, no tools, every
/// optional knob unset and `stream == false`.
#[derive(Debug, Clone, Default)]
pub struct ChatRequest {
    /// Model identifier.
    ///
    /// NOTE(review): an empty string may mean "use the backend's
    /// `default_model`" — confirm in the backends.
    pub model: String,
    /// Conversation so far, oldest first (presumably — confirm ordering).
    pub messages: Vec<Message>,
    /// Optional system prompt supplied separately from `messages`.
    pub system: Option<String>,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Optional nucleus-sampling threshold.
    pub top_p: Option<f64>,
    /// Tools the model may call.
    pub tools: Vec<ToolSpec>,
    /// Whether the caller wants a streamed response.
    pub stream: bool,
    /// Optional caller-supplied id for correlating logs/traces.
    pub trace_id: Option<String>,
}
/// Provider-agnostic reason a completion stopped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// The model finished its turn or hit a stop sequence.
    Stop,
    /// Generation was cut off by the output-token limit.
    Length,
    /// The model stopped to request a tool / function call.
    ToolUse,
    /// The provider stopped or blocked output for content-policy reasons.
    SafetyBreach,
    /// Unrecognised reason, carrying the provider's raw string exactly as
    /// received (not lower-cased); empty when the provider sent no reason.
    Other(String),
}

impl FinishReason {
    /// Maps a provider's raw finish/stop reason onto [`FinishReason`].
    ///
    /// Matching is ASCII case-insensitive, so Gemini's upper-case values
    /// (`"STOP"`, `"MAX_TOKENS"`, ...) and OpenAI/Anthropic's lower-case ones
    /// are handled alike. Anthropic-specific values are only recognised when
    /// `provider == "anthropic"`; the remaining mappings apply to any
    /// provider. Anything unrecognised becomes [`FinishReason::Other`] with
    /// the original `raw` text preserved.
    pub fn from_provider(provider: &str, raw: &str) -> Self {
        let normalised = raw.to_ascii_lowercase();
        match (provider, normalised.as_str()) {
            ("anthropic", "end_turn") | ("anthropic", "stop_sequence") => Self::Stop,
            ("anthropic", "tool_use") => Self::ToolUse,
            (_, "stop") => Self::Stop,
            (_, "length") | (_, "max_tokens") => Self::Length,
            (_, "tool_calls") | (_, "function_call") => Self::ToolUse,
            // Content-policy stops: OpenAI `content_filter`, Gemini `SAFETY`,
            // Anthropic `refusal`, and Gemini's other blocking reasons
            // (`BLOCKLIST`, `PROHIBITED_CONTENT`, `SPII`, `IMAGE_SAFETY`).
            (_, "content_filter")
            | (_, "safety")
            | (_, "refusal")
            | (_, "blocklist")
            | (_, "prohibited_content")
            | (_, "spii")
            | (_, "image_safety") => Self::SafetyBreach,
            // Also covers an empty `raw`, yielding `Other(String::new())`.
            _ => Self::Other(raw.to_string()),
        }
    }

    /// Returns `true` only for [`FinishReason::SafetyBreach`].
    pub fn is_safety_breach(&self) -> bool {
        matches!(self, Self::SafetyBreach)
    }
}
/// Token accounting for a request, as reported by the backend.
///
/// Every counter defaults to zero.
///
/// NOTE(review): whether `total_tokens` always equals
/// `input_tokens + output_tokens` depends on the provider — confirm.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    /// Prompt-side tokens.
    pub input_tokens: u32,
    /// Generated tokens.
    pub output_tokens: u32,
    /// Total tokens as reported.
    pub total_tokens: u32,
    /// Tokens served from a prompt cache.
    pub cache_read_tokens: u32,
    /// Tokens written into a prompt cache.
    pub cache_creation_tokens: u32,
    /// Tokens spent on hidden reasoning, where the provider reports them.
    pub reasoning_tokens: u32,
}
/// Result of a non-streaming completion (`Backend::complete`).
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Generated text.
    pub content: String,
    /// Model that produced the response, as reported by the backend.
    pub model_name: String,
    /// Name of the backend that served the request.
    pub provider_name: String,
    /// Normalised reason generation stopped.
    pub finish_reason: FinishReason,
    /// Token accounting for this call.
    pub usage: Usage,
    /// Number of retries before this response was obtained.
    ///
    /// NOTE(review): presumably maintained by the `retry` module — confirm.
    pub retry_count: u32,
    /// Correlation id for this call.
    ///
    /// NOTE(review): presumably echoes `ChatRequest::trace_id` or a generated
    /// id when none was given — confirm in the backends.
    pub trace_id: String,
}
/// One incremental piece of a streamed completion.
///
/// `Default` gives an empty `delta` with no finish reason and no usage.
#[derive(Debug, Clone, Default)]
pub struct ChatChunk {
    /// Text appended by this chunk.
    pub delta: String,
    /// Why generation stopped, when this chunk reports it (presumably only
    /// the final chunk — confirm per backend).
    pub finish_reason: Option<FinishReason>,
    /// Token accounting, when this chunk carries it.
    pub usage: Option<Usage>,
}
/// Pinned, boxed stream yielding either a [`ChatChunk`] or a
/// [`BackendError`] per item; returned by `Backend::stream`.
///
/// The `Send` bound allows the stream to be moved to, and polled on, another
/// thread.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatChunk, BackendError>> + Send>>;
/// Common interface implemented by every chat-completion provider.
///
/// The `Send + Sync` supertraits let one instance be shared across threads,
/// e.g. when stored in a [`Registry`].
#[async_trait]
pub trait Backend: Send + Sync {
    /// Provider identifier; [`Registry::register`] uses it as the lookup key.
    fn name(&self) -> &str;

    /// Model identifier this backend reports as its default.
    ///
    /// NOTE(review): presumably used when a request leaves `model` empty —
    /// confirm in the individual backends.
    fn default_model(&self) -> &str;

    /// Performs a single, non-streaming completion for `request`.
    async fn complete(&self, request: ChatRequest) -> Result<ChatResponse, BackendError>;

    /// Starts a completion and returns its incremental [`ChatStream`].
    async fn stream(&self, request: ChatRequest) -> Result<ChatStream, BackendError>;

    /// Number of tokens `text` occupies for `model`.
    ///
    /// The default delegates to the crate-wide `tokens::count_tokens` and
    /// returns its `count`; backends may override it.
    fn count_tokens(&self, model: &str, text: &str) -> usize {
        let counted = tokens::count_tokens(model, text);
        counted.count
    }

    /// Reports whether `model` on this backend offers `capability`.
    ///
    /// The default is conservative and answers `false` for every input;
    /// backends opt in by overriding it.
    fn supports(&self, capability: Capability, model: &str) -> bool {
        // Both inputs are deliberately ignored by the conservative default.
        let _ = (capability, model);
        false
    }
}
/// Collection of backends keyed by their [`Backend::name`].
///
/// Registering a backend whose name is already present replaces the earlier
/// one (plain `HashMap::insert` semantics).
pub struct Registry {
    backends: HashMap<String, Box<dyn Backend>>,
}

impl Registry {
    /// A registry with no backends.
    pub fn empty() -> Self {
        Registry {
            backends: HashMap::default(),
        }
    }

    /// A registry holding every built-in provider, each built with its
    /// `from_env` constructor.
    ///
    /// NOTE(review): `from_env` presumably reads credentials/config from
    /// environment variables — see each backend module.
    pub fn production() -> Self {
        // Each backend is constructed and registered before the next one is
        // constructed, in this fixed order.
        Registry::empty()
            .with_backend(Box::new(anthropic::AnthropicBackend::from_env()))
            .with_backend(Box::new(gemini::GeminiBackend::from_env()))
            .with_backend(Box::new(glm::GLMBackend::from_env()))
            .with_backend(Box::new(kimi::KimiBackend::from_env()))
            .with_backend(Box::new(ollama::OllamaBackend::from_env()))
            .with_backend(Box::new(openai::OpenAIBackend::from_env()))
            .with_backend(Box::new(openrouter::OpenRouterBackend::from_env()))
    }

    /// Adds `backend` under its own name, replacing any backend already
    /// registered with that name.
    pub fn register(&mut self, backend: Box<dyn Backend>) {
        let key = backend.name().to_string();
        self.backends.insert(key, backend);
    }

    /// Looks up a backend by name.
    pub fn get(&self, name: &str) -> Option<&dyn Backend> {
        let boxed = self.backends.get(name)?;
        Some(boxed.as_ref())
    }

    /// Names of all registered backends, sorted ascending.
    pub fn provider_names(&self) -> Vec<String> {
        let mut sorted: Vec<String> = self.backends.keys().map(String::clone).collect();
        sorted.sort();
        sorted
    }

    /// Number of registered backends.
    pub fn len(&self) -> usize {
        self.backends.len()
    }

    /// `true` when no backend is registered.
    pub fn is_empty(&self) -> bool {
        self.backends.is_empty()
    }

    /// Builder-style wrapper around [`Registry::register`] used by
    /// [`Registry::production`].
    fn with_backend(mut self, backend: Box<dyn Backend>) -> Self {
        self.register(backend);
        self
    }
}
impl Default for Registry {
fn default() -> Self {
Self::production()
}
}
// Unit tests for the shared backend types and the Registry, driven by an
// in-memory stub backend (no network access).
#[cfg(test)]
mod tests {
use super::*;
use futures::StreamExt;
// Minimal Backend: fixed `complete` response, two-chunk `stream`, and it
// advertises only Capability::Streaming.
struct StubBackend {
name: String,
}
#[async_trait]
impl Backend for StubBackend {
fn name(&self) -> &str {
&self.name
}
fn default_model(&self) -> &str {
"stub-model"
}
async fn complete(
&self,
_request: ChatRequest,
) -> Result<ChatResponse, BackendError> {
Ok(ChatResponse {
content: "stubbed".into(),
model_name: "stub-model".into(),
provider_name: self.name.clone(),
finish_reason: FinishReason::Stop,
usage: Usage::default(),
retry_count: 0,
trace_id: "stub".into(),
})
}
async fn stream(
&self,
_request: ChatRequest,
) -> Result<ChatStream, BackendError> {
// First chunk carries text only; the last one also carries the finish
// reason and usage.
let chunks = vec![
Ok(ChatChunk { delta: "hi ".into(), ..Default::default() }),
Ok(ChatChunk {
delta: "world".into(),
finish_reason: Some(FinishReason::Stop),
usage: Some(Usage { input_tokens: 1, output_tokens: 2, total_tokens: 3, ..Default::default() }),
}),
];
Ok(Box::pin(futures::stream::iter(chunks)))
}
fn supports(&self, capability: Capability, _model: &str) -> bool {
matches!(capability, Capability::Streaming)
}
}
// Boxes a StubBackend that reports `name` as its provider name.
fn stub(name: &str) -> Box<dyn Backend> {
Box::new(StubBackend { name: name.to_string() })
}
// NOTE(review): despite the name, this only checks that every role has a
// non-empty wire name and that User maps to "user"; there is no parse-back.
#[test]
fn role_round_trips_via_as_str() {
for r in [Role::System, Role::User, Role::Assistant, Role::Tool] {
assert!(!r.as_str().is_empty());
}
assert_eq!(Role::User.as_str(), "user");
}
#[test]
fn message_helpers_set_role() {
assert_eq!(Message::user("a").role, Role::User);
assert_eq!(Message::assistant("b").role, Role::Assistant);
assert_eq!(Message::system("c").role, Role::System);
}
#[test]
fn chat_request_default_is_empty() {
let r = ChatRequest::default();
assert!(r.model.is_empty());
assert!(r.messages.is_empty());
assert!(r.tools.is_empty());
assert!(!r.stream);
}
// Anthropic-specific stop reasons.
#[test]
fn finish_reason_anthropic_mapping() {
assert_eq!(FinishReason::from_provider("anthropic", "end_turn"), FinishReason::Stop);
assert_eq!(FinishReason::from_provider("anthropic", "max_tokens"), FinishReason::Length);
assert_eq!(FinishReason::from_provider("anthropic", "tool_use"), FinishReason::ToolUse);
assert_eq!(FinishReason::from_provider("anthropic", "stop_sequence"), FinishReason::Stop);
}
#[test]
fn finish_reason_openai_mapping() {
assert_eq!(FinishReason::from_provider("openai", "stop"), FinishReason::Stop);
assert_eq!(FinishReason::from_provider("openai", "length"), FinishReason::Length);
assert_eq!(FinishReason::from_provider("openai", "tool_calls"), FinishReason::ToolUse);
assert_eq!(FinishReason::from_provider("openai", "content_filter"), FinishReason::SafetyBreach);
}
// Upper-case inputs exercise the case-insensitive matching.
#[test]
fn finish_reason_gemini_mapping_uppercase() {
assert_eq!(FinishReason::from_provider("gemini", "STOP"), FinishReason::Stop);
assert_eq!(FinishReason::from_provider("gemini", "MAX_TOKENS"), FinishReason::Length);
assert_eq!(FinishReason::from_provider("gemini", "SAFETY"), FinishReason::SafetyBreach);
}
#[test]
fn finish_reason_unknown_preserves_raw() {
let r = FinishReason::from_provider("openai", "weird_signal");
assert_eq!(r, FinishReason::Other("weird_signal".into()));
}
#[test]
fn finish_reason_safety_breach_predicate() {
assert!(FinishReason::SafetyBreach.is_safety_breach());
assert!(!FinishReason::Stop.is_safety_breach());
assert!(!FinishReason::Other("anything".into()).is_safety_breach());
}
#[test]
fn registry_empty_then_register() {
let mut r = Registry::empty();
assert_eq!(r.len(), 0);
r.register(stub("anthropic"));
assert_eq!(r.len(), 1);
assert!(r.get("anthropic").is_some());
assert!(r.get("openai").is_none());
}
// Registration order must not affect the sorted name listing.
#[test]
fn registry_provider_names_sorted() {
let mut r = Registry::empty();
r.register(stub("openai"));
r.register(stub("anthropic"));
r.register(stub("gemini"));
assert_eq!(
r.provider_names(),
vec!["anthropic".to_string(), "gemini".to_string(), "openai".to_string()]
);
}
// NOTE(review): both registrations are identical stubs, so this checks only
// that duplicates do not grow the registry, not which instance is kept.
#[test]
fn registry_replace_on_duplicate_register() {
let mut r = Registry::empty();
r.register(stub("anthropic"));
r.register(stub("anthropic"));
assert_eq!(r.len(), 1); }
#[tokio::test]
async fn stub_complete_returns_response() {
let b = StubBackend { name: "stub".into() };
let resp = b.complete(ChatRequest::default()).await.unwrap();
assert_eq!(resp.content, "stubbed");
assert_eq!(resp.provider_name, "stub");
assert_eq!(resp.finish_reason, FinishReason::Stop);
}
// Collects the whole stream and checks per-chunk delta, finish reason and
// usage.
#[tokio::test]
async fn stub_stream_yields_chunks() {
let b = StubBackend { name: "stub".into() };
let stream = b.stream(ChatRequest::default()).await.unwrap();
let chunks: Vec<_> = stream.collect().await;
assert_eq!(chunks.len(), 2);
let first = chunks[0].as_ref().unwrap();
assert_eq!(first.delta, "hi ");
assert!(first.finish_reason.is_none());
let last = chunks[1].as_ref().unwrap();
assert_eq!(last.delta, "world");
assert!(matches!(last.finish_reason, Some(FinishReason::Stop)));
let usage = last.usage.as_ref().unwrap();
assert_eq!(usage.total_tokens, 3);
}
// The stub echoes its registered name as provider_name, proving which
// backend `get` returned.
#[tokio::test]
async fn registry_dispatches_to_correct_backend() {
let mut r = Registry::empty();
r.register(stub("anthropic"));
r.register(stub("openai"));
let b = r.get("openai").expect("openai registered");
let resp = b.complete(ChatRequest::default()).await.unwrap();
assert_eq!(resp.provider_name, "openai");
}
// A backend that does not override `supports` gets the `false` default;
// its async methods are never called here.
#[test]
fn supports_capability_default_false() {
struct DefaultBackend;
#[async_trait]
impl Backend for DefaultBackend {
fn name(&self) -> &str {
"default"
}
fn default_model(&self) -> &str {
""
}
async fn complete(
&self,
_r: ChatRequest,
) -> Result<ChatResponse, BackendError> {
unreachable!()
}
async fn stream(
&self,
_r: ChatRequest,
) -> Result<ChatStream, BackendError> {
unreachable!()
}
}
let b = DefaultBackend;
assert!(!b.supports(Capability::Streaming, "anything"));
assert!(!b.supports(Capability::ToolUse, "anything"));
}
// Only checks that the default `count_tokens` returns a positive count.
#[test]
fn count_tokens_default_uses_unified_dispatch() {
let b = StubBackend { name: "stub".into() };
let n = b.count_tokens("gpt-4o-mini", "hello world");
assert!(n > 0);
}
}