use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::{Stream, StreamExt};
use schemars::JsonSchema;
use secrecy::ExposeSecret;
use serde::de::DeserializeOwned;
use crate::anthropic::{AnthropicTransport, ReqwestAnthropic};
use crate::config::{validate_base_url, Config, Provider};
use crate::error::{redact, AiError};
use crate::message::{ContentBlock, Message, Usage};
use crate::thinking::ThinkingMode;
/// A provider-agnostic chat request accepted by [`AiClient::chat`],
/// [`AiClient::chat_stream`] and [`AiClient::chat_structured`].
#[derive(Debug, Clone, Default)]
pub struct ChatRequest {
    /// Optional system prompt for the conversation.
    pub system: Option<String>,
    /// Conversation turns, in order.
    pub messages: Vec<Message>,
    /// Sampling temperature; `None` means "not specified by the caller".
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens; `None` means "not specified by the caller".
    pub max_tokens: Option<u32>,
    /// Whether to request prompt caching.
    /// NOTE(review): not read by the genai code path in this file; presumably honoured by
    /// the Anthropic transport — confirm there.
    pub cache_control: bool,
    /// Optional thinking/reasoning configuration.
    /// NOTE(review): likewise not read by the genai code path in this file.
    pub thinking: Option<ThinkingMode>,
}
/// The result of a non-streaming chat call.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// The assistant's reply.
    pub message: Message,
    /// Token accounting reported for the call.
    pub usage: Usage,
    /// Citations attached to the reply; the genai backend always returns an empty list.
    pub citations: Vec<crate::message::Citation>,
}
/// One item yielded by a [`ChatStream`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ChatStreamEvent {
    /// A fragment of the assistant's visible reply text.
    Token(String),
    /// A fragment of the model's reasoning ("thinking") output.
    ThinkingToken(String),
    /// End of the response, carrying the token usage reported for it.
    Done(Usage),
    /// An error delivered as a stream item rather than as the `Err` of `chat_stream`.
    Error(AiError),
}
/// A type-erased stream of [`ChatStreamEvent`]s returned by [`AiClient::chat_stream`].
///
/// Boxing the concrete stream lets both backends return this single type.
pub struct ChatStream {
    inner: Pin<Box<dyn Stream<Item = ChatStreamEvent> + Send>>,
}
impl ChatStream {
pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = ChatStreamEvent> + Send>>) -> Self {
Self { inner: stream }
}
}
impl Stream for ChatStream {
    type Item = ChatStreamEvent;

    /// Forwards polling straight to the boxed inner stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `ChatStream` only holds a `Pin<Box<_>>`, so it is `Unpin` and the field can be
        // reborrowed mutably through the outer pin.
        let pinned_inner = self.inner.as_mut();
        pinned_inner.poll_next(cx)
    }
}
impl std::fmt::Debug for ChatStream {
    /// Prints only the type name: the boxed `dyn Stream` carries no `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ChatStream");
        builder.finish_non_exhaustive()
    }
}
/// The transport an [`AiClient`] dispatches to, chosen once in [`AiClient::new`].
enum Backend {
    /// Direct Anthropic HTTP transport, used when `config.provider.is_anthropic()`.
    Anthropic(Arc<dyn AnthropicTransport>),
    /// The `genai` multi-provider client, used for every other provider.
    Genai(genai::Client),
}
impl std::fmt::Debug for Backend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Anthropic(_) => f.debug_struct("Backend::Anthropic").finish_non_exhaustive(),
Self::Genai(_) => f.debug_struct("Backend::Genai").finish_non_exhaustive(),
}
}
}
/// Chat client bound to a single configured provider and model.
///
/// Construct with [`AiClient::new`]; the backend is selected from the config's provider.
#[derive(Debug)]
pub struct AiClient {
    /// Validated configuration, passed to the backend on every request.
    config: Config,
    /// Transport chosen at construction time.
    backend: Backend,
}
impl AiClient {
    /// Validates `config` and builds a client for its provider.
    ///
    /// Anthropic providers get a direct `reqwest`-based transport. Every other provider is
    /// served by a default `genai::Client`, with the API key exported via `genai_set_key`.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::InvalidConfig`] for a blank API key or model. Propagates any
    /// error from [`validate_base_url`]. Fails if the HTTP client cannot be built.
    pub fn new(config: Config) -> Result<Self, AiError> {
        Self::validate(&config)?;
        let backend = if config.provider.is_anthropic() {
            let client = build_reqwest_client(&config)?;
            tracing::info!(
                provider = ?config.provider,
                host = %backend_host(&config),
                "rtb-ai: AiClient ready (anthropic-direct)",
            );
            Backend::Anthropic(Arc::new(ReqwestAnthropic::new(Arc::new(client))))
        } else {
            if config.base_url.is_some() {
                // The genai client below is built with defaults and never receives
                // `base_url`. Without this warning a configured endpoint would be
                // dropped silently.
                tracing::warn!(
                    provider = ?config.provider,
                    "rtb-ai: base_url is not applied by the genai backend and will be ignored",
                );
            }
            genai_set_key(&config);
            tracing::info!(
                provider = ?config.provider,
                host = %backend_host(&config),
                "rtb-ai: AiClient ready (genai)",
            );
            Backend::Genai(genai::Client::default())
        };
        Ok(Self { config, backend })
    }

    /// Rejects configurations that can never produce a working client.
    fn validate(config: &Config) -> Result<(), AiError> {
        // Whitespace-only values are as unusable as empty ones.
        if config.api_key.expose_secret().trim().is_empty() {
            return Err(AiError::InvalidConfig("api_key must not be empty".into()));
        }
        if config.model.trim().is_empty() {
            return Err(AiError::InvalidConfig("model must not be empty".into()));
        }
        if let Some(url) = &config.base_url {
            validate_base_url(url, config.allow_insecure_base_url)?;
        }
        Ok(())
    }

    /// Sends `req` and waits for the complete reply.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected backend reports.
    pub async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, AiError> {
        match &self.backend {
            Backend::Anthropic(t) => t.chat(&self.config, req).await,
            Backend::Genai(c) => genai_chat(c, &self.config, req).await,
        }
    }

    /// Sends `req` and returns a stream of incremental [`ChatStreamEvent`]s.
    ///
    /// # Errors
    ///
    /// Returns an error if the stream cannot be started. Failures after that point may
    /// instead arrive as [`ChatStreamEvent::Error`] items.
    pub async fn chat_stream(&self, req: ChatRequest) -> Result<ChatStream, AiError> {
        match &self.backend {
            Backend::Anthropic(t) => t.chat_stream(&self.config, req).await,
            Backend::Genai(c) => genai_chat_stream(c, &self.config, req).await,
        }
    }

    /// Asks the model for a JSON value matching `T`'s JSON Schema and deserializes it.
    ///
    /// The schema is appended to the system prompt. The reply's text blocks are
    /// concatenated and unwrapped from a surrounding Markdown code fence if one is present.
    /// The result is validated against the schema before deserializing.
    ///
    /// # Errors
    ///
    /// - [`AiError::InvalidConfig`] if the generated schema cannot be serialized.
    /// - Any error from [`AiClient::chat`].
    /// - [`AiError::Deserialize`] if the reply is not JSON or does not fit `T`.
    /// - [`AiError::SchemaValidation`] if the schema cannot be compiled or the reply
    ///   violates it.
    ///
    /// All error messages are passed through `redact`.
    pub async fn chat_structured<T>(&self, req: ChatRequest) -> Result<T, AiError>
    where
        T: DeserializeOwned + JsonSchema,
    {
        let schema = serde_json::to_value(schemars::schema_for!(T))
            .map_err(|e| AiError::InvalidConfig(redact(&e.to_string())))?;
        let augmented = augment_request_for_schema(req, &schema);
        let resp = self.chat(augmented).await?;
        let body: String =
            resp.message.content.iter().filter_map(ContentBlock::as_text).collect();
        let parsed: serde_json::Value = serde_json::from_str(strip_code_fences(&body))
            .map_err(|e| AiError::Deserialize(redact(&e.to_string())))?;
        let validator = jsonschema::validator_for(&schema)
            .map_err(|e| AiError::SchemaValidation(redact(&e.to_string())))?;
        validator
            .validate(&parsed)
            .map_err(|err| AiError::SchemaValidation(redact(&err.to_string())))?;
        serde_json::from_value::<T>(parsed)
            .map_err(|e| AiError::Deserialize(redact(&e.to_string())))
    }
}

/// Returns the content of a Markdown code fence wrapping `body`, or `body` trimmed.
///
/// Models often wrap JSON replies in a fence such as ```` ```json ```` even when told not
/// to. Unwrapping it here lets those replies parse.
///
/// The opening fence line is dropped only when it is empty or a bare language tag. A
/// fence whose first line already holds content is therefore left intact.
fn strip_code_fences(body: &str) -> &str {
    let trimmed = body.trim();
    let Some(inner) = trimmed.strip_prefix("```").and_then(|rest| rest.strip_suffix("```"))
    else {
        return trimmed;
    };
    match inner.split_once('\n') {
        Some((tag, content))
            if tag
                .trim()
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') =>
        {
            content.trim()
        }
        _ => inner.trim(),
    }
}
/// Builds the HTTP client used by the Anthropic-direct transport.
///
/// HTTPS is required unless `config.allow_insecure_base_url` is set. `config.timeout` is
/// passed to `ClientBuilder::timeout`. The user agent identifies this crate and its
/// version.
///
/// # Errors
///
/// Returns [`AiError::InvalidConfig`] (with a redacted message) if the client cannot be
/// built.
fn build_reqwest_client(config: &Config) -> Result<reqwest::Client, AiError> {
    reqwest::Client::builder()
        // This single flag already covers both modes. The former follow-up
        // `https_only(false)` for insecure configs was redundant.
        .https_only(!config.allow_insecure_base_url)
        .timeout(config.timeout)
        .user_agent(concat!("rtb-ai/", env!("CARGO_PKG_VERSION")))
        .build()
        .map_err(|e| AiError::InvalidConfig(redact(&e.to_string())))
}
/// Returns the host name to report in logs for `config`.
///
/// The host of an explicit `base_url` wins. Otherwise a fixed per-provider default is
/// used.
fn backend_host(config: &Config) -> String {
    if let Some(host) = config.base_url.as_ref().and_then(|url| url.host_str()) {
        return host.to_owned();
    }
    let default_host = match config.provider {
        Provider::Anthropic | Provider::AnthropicLocal => "api.anthropic.com",
        Provider::OpenAi => "api.openai.com",
        Provider::Gemini => "generativelanguage.googleapis.com",
        Provider::Ollama => "localhost",
        Provider::OpenAiCompatible => "openai-compatible",
    };
    default_host.to_owned()
}
/// Adds "reply with JSON matching `schema`" instructions to the system prompt of `req`.
///
/// An existing system prompt is kept and followed by a blank line and the instructions.
/// Otherwise the instructions become the whole system prompt.
fn augment_request_for_schema(mut req: ChatRequest, schema: &serde_json::Value) -> ChatRequest {
    let directive = format!(
        "You MUST respond with a single JSON value matching this schema. \
         No prose, no code fences:\n{schema}",
    );
    let system = match req.system {
        Some(existing) => format!("{existing}\n\n{directive}"),
        None => directive,
    };
    req.system = Some(system);
    req
}
/// Exports `config.api_key` under the environment variable genai reads for the provider.
///
/// `OpenAi` and `OpenAiCompatible` use `OPENAI_API_KEY`, and `Gemini` uses
/// `GEMINI_API_KEY`. Other providers leave the environment untouched.
///
/// The variable is process-global. Constructing another client for the same provider
/// with a different key replaces the key seen by every genai client in the process.
fn genai_set_key(config: &Config) {
    let var = match config.provider {
        Provider::OpenAi | Provider::OpenAiCompatible => "OPENAI_API_KEY",
        Provider::Gemini => "GEMINI_API_KEY",
        Provider::Ollama | Provider::Anthropic | Provider::AnthropicLocal => return,
    };
    let key = config.api_key.expose_secret();
    // Skip the write when the variable already holds this key. Re-creating clients then
    // performs no environment mutation, and the unsafe block runs only when it must.
    let already_set = std::env::var_os(var)
        .is_some_and(|current| current.as_os_str() == std::ffi::OsStr::new(key));
    if already_set {
        return;
    }
    // SAFETY: `set_var` is sound only while no other thread reads or writes the process
    // environment concurrently.
    // NOTE(review): this runs from `AiClient::new`, which may be called inside a
    // multi-threaded runtime. Confirm clients are built before worker threads touch the
    // environment, or pass the key through a per-client genai auth resolver instead.
    #[allow(unsafe_code)] // `unsafe_code` is presumably linted crate-wide; allowed for this call only.
    unsafe {
        std::env::set_var(var, key);
    }
}
/// Sends a non-streaming chat request through the `genai` backend.
///
/// `req.temperature` and `req.max_tokens` are forwarded as genai chat options.
/// `req.cache_control` and `req.thinking` are not used on this path.
///
/// # Errors
///
/// Returns [`AiError::Provider`] (with a redacted message) when genai reports a failure.
async fn genai_chat(
    client: &genai::Client,
    config: &Config,
    req: ChatRequest,
) -> Result<ChatResponse, AiError> {
    let chat_req = build_genai_request(&req);
    let options = build_genai_options(&req);
    let resp = client
        .exec_chat(&config.model, chat_req, Some(&options))
        .await
        .map_err(|e| AiError::Provider(redact(&e.to_string())))?;
    let text = resp.first_text().unwrap_or_default().to_string();
    let usage = genai_usage(&resp);
    Ok(ChatResponse { message: Message::assistant(text), usage, citations: Vec::new() })
}

/// Starts a streaming chat request through the `genai` backend.
///
/// Text deltas become [`ChatStreamEvent::Token`], except empty ones, which are dropped.
/// Reasoning deltas become [`ChatStreamEvent::ThinkingToken`]. The end marker becomes
/// [`ChatStreamEvent::Done`], and stream errors become [`ChatStreamEvent::Error`].
/// Start, tool-call and thought-signature events are dropped.
///
/// # Errors
///
/// Returns [`AiError::Provider`] (with a redacted message) if the stream cannot be opened.
async fn genai_chat_stream(
    client: &genai::Client,
    config: &Config,
    req: ChatRequest,
) -> Result<ChatStream, AiError> {
    let chat_req = build_genai_request(&req);
    // genai fills `StreamEnd::captured_usage` only when asked to capture usage. Without
    // this flag every `Done` event would carry zeroed token counts.
    let options = build_genai_options(&req).with_capture_usage(true);
    let resp = client
        .exec_chat_stream(&config.model, chat_req, Some(&options))
        .await
        .map_err(|e| AiError::Provider(redact(&e.to_string())))?;
    let stream = futures_util::StreamExt::filter_map(resp.stream, |event| {
        use genai::chat::ChatStreamEvent as G;
        let mapped = match event {
            // Empty text deltas carry nothing useful for callers.
            Ok(G::Chunk(chunk)) if chunk.content.is_empty() => None,
            Ok(G::Chunk(chunk)) => Some(ChatStreamEvent::Token(chunk.content)),
            Ok(G::ReasoningChunk(chunk)) => Some(ChatStreamEvent::ThinkingToken(chunk.content)),
            Ok(G::End(end)) => Some(ChatStreamEvent::Done(genai_usage_from_end(&end))),
            // These events have no `ChatStreamEvent` counterpart.
            Ok(G::Start | G::ToolCallChunk(_) | G::ThoughtSignatureChunk(_)) => None,
            Err(e) => Some(ChatStreamEvent::Error(AiError::Provider(redact(&e.to_string())))),
        };
        std::future::ready(mapped)
    });
    Ok(ChatStream::new(Box::pin(stream)))
}

/// Maps the sampling settings of `req` onto genai chat options; unset fields stay unset.
fn build_genai_options(req: &ChatRequest) -> genai::chat::ChatOptions {
    let mut options = genai::chat::ChatOptions::default();
    if let Some(temperature) = req.temperature {
        options = options.with_temperature(f64::from(temperature));
    }
    if let Some(max_tokens) = req.max_tokens {
        options = options.with_max_tokens(max_tokens);
    }
    options
}

/// Converts a [`ChatRequest`] into a genai chat request.
///
/// Each message's text blocks are joined with `\n`. `req.system` and every
/// `Role::System` message are merged, in order and separated by a blank line, into one
/// system prompt. Previously each system message replaced the prompt set before it,
/// including `req.system`.
fn build_genai_request(req: &ChatRequest) -> genai::chat::ChatRequest {
    let mut system_parts: Vec<String> = req.system.iter().cloned().collect();
    let mut chat = genai::chat::ChatRequest::default();
    for msg in &req.messages {
        let text =
            msg.content.iter().filter_map(ContentBlock::as_text).collect::<Vec<_>>().join("\n");
        match msg.role {
            crate::message::Role::User => {
                chat = chat.append_message(genai::chat::ChatMessage::user(text));
            }
            crate::message::Role::Assistant => {
                chat = chat.append_message(genai::chat::ChatMessage::assistant(text));
            }
            crate::message::Role::System => system_parts.push(text),
        }
    }
    if !system_parts.is_empty() {
        chat = chat.with_system(system_parts.join("\n\n"));
    }
    chat
}
/// Converts the usage block of a non-streaming genai response into this crate's
/// [`Usage`].
///
/// Cache counters are not mapped and are always 0.
fn genai_usage(resp: &genai::chat::ChatResponse) -> Usage {
    let u = &resp.usage;
    Usage {
        input_tokens: token_count(u.prompt_tokens),
        output_tokens: token_count(u.completion_tokens),
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: 0,
    }
}

/// Converts the usage captured at the end of a genai stream into [`Usage`].
///
/// Returns `Usage::default()` when the stream captured no usage. Cache counters are
/// always 0.
fn genai_usage_from_end(end: &genai::chat::StreamEnd) -> Usage {
    end.captured_usage.as_ref().map_or_else(Usage::default, |u| Usage {
        input_tokens: token_count(u.prompt_tokens),
        output_tokens: token_count(u.completion_tokens),
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: 0,
    })
}

/// Converts an optional provider token count to `u32`.
///
/// Uses 0 when the count is absent or cannot be represented (for example a negative
/// value). A bare `as u32` cast would wrap a negative value to about 4 billion.
fn token_count<N: TryInto<u32>>(count: Option<N>) -> u32 {
    count.and_then(|n| n.try_into().ok()).unwrap_or(0)
}