1use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures_util::{Stream, StreamExt};
9use schemars::JsonSchema;
10use secrecy::ExposeSecret;
11use serde::de::DeserializeOwned;
12
13use crate::anthropic::{AnthropicTransport, ReqwestAnthropic};
14use crate::config::{validate_base_url, Config, Provider};
15use crate::error::{redact, AiError};
16use crate::message::{ContentBlock, Message, Usage};
17use crate::thinking::ThinkingMode;
18
19#[derive(Debug, Clone, Default)]
21pub struct ChatRequest {
22 pub system: Option<String>,
26 pub messages: Vec<Message>,
29 pub temperature: Option<f32>,
31 pub max_tokens: Option<u32>,
33 pub cache_control: bool,
36 pub thinking: Option<ThinkingMode>,
39}
40
41#[derive(Debug, Clone)]
43pub struct ChatResponse {
44 pub message: Message,
46 pub usage: Usage,
48 pub citations: Vec<crate::message::Citation>,
51}
52
53#[derive(Debug, Clone)]
55#[non_exhaustive]
56pub enum ChatStreamEvent {
57 Token(String),
59 ThinkingToken(String),
62 Done(Usage),
64 Error(AiError),
66}
67
68pub struct ChatStream {
72 inner: Pin<Box<dyn Stream<Item = ChatStreamEvent> + Send>>,
73}
74
75impl ChatStream {
76 pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = ChatStreamEvent> + Send>>) -> Self {
77 Self { inner: stream }
78 }
79}
80
81impl Stream for ChatStream {
82 type Item = ChatStreamEvent;
83 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
84 self.inner.poll_next_unpin(cx)
85 }
86}
87
88impl std::fmt::Debug for ChatStream {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("ChatStream").finish_non_exhaustive()
91 }
92}
93
94enum Backend {
99 Anthropic(Arc<dyn AnthropicTransport>),
100 Genai(genai::Client),
101}
102
103impl std::fmt::Debug for Backend {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 Self::Anthropic(_) => f.debug_struct("Backend::Anthropic").finish_non_exhaustive(),
107 Self::Genai(_) => f.debug_struct("Backend::Genai").finish_non_exhaustive(),
108 }
109 }
110}
111
112#[derive(Debug)]
114pub struct AiClient {
115 config: Config,
116 backend: Backend,
117}
118
119impl AiClient {
120 pub fn new(config: Config) -> Result<Self, AiError> {
129 Self::validate(&config)?;
130 let backend = if config.provider.is_anthropic() {
131 let client = build_reqwest_client(&config)?;
132 tracing::info!(
133 provider = ?config.provider,
134 host = %backend_host(&config),
135 "rtb-ai: AiClient ready (anthropic-direct)",
136 );
137 Backend::Anthropic(Arc::new(ReqwestAnthropic::new(Arc::new(client))))
138 } else {
139 genai_set_key(&config);
145 tracing::info!(
146 provider = ?config.provider,
147 host = %backend_host(&config),
148 "rtb-ai: AiClient ready (genai)",
149 );
150 Backend::Genai(genai::Client::default())
151 };
152 Ok(Self { config, backend })
153 }
154
155 fn validate(config: &Config) -> Result<(), AiError> {
156 if config.api_key.expose_secret().is_empty() {
157 return Err(AiError::InvalidConfig("api_key must not be empty".into()));
158 }
159 if config.model.is_empty() {
160 return Err(AiError::InvalidConfig("model must not be empty".into()));
161 }
162 if let Some(url) = &config.base_url {
163 validate_base_url(url, config.allow_insecure_base_url)?;
164 }
165 Ok(())
166 }
167
168 pub async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, AiError> {
175 match &self.backend {
176 Backend::Anthropic(t) => t.chat(&self.config, req).await,
177 Backend::Genai(c) => genai_chat(c, &self.config, req).await,
178 }
179 }
180
181 pub async fn chat_stream(&self, req: ChatRequest) -> Result<ChatStream, AiError> {
189 match &self.backend {
190 Backend::Anthropic(t) => t.chat_stream(&self.config, req).await,
191 Backend::Genai(c) => genai_chat_stream(c, &self.config, req).await,
192 }
193 }
194
195 pub async fn chat_structured<T>(&self, req: ChatRequest) -> Result<T, AiError>
209 where
210 T: DeserializeOwned + JsonSchema,
211 {
212 let schema = serde_json::to_value(schemars::schema_for!(T))
213 .map_err(|e| AiError::InvalidConfig(redact(&e.to_string())))?;
214 let augmented = augment_request_for_schema(req, &schema);
215 let resp = self.chat(augmented).await?;
216 let body =
217 resp.message.content.iter().filter_map(ContentBlock::as_text).collect::<String>();
218 let parsed: serde_json::Value = serde_json::from_str(&body)
219 .map_err(|e| AiError::Deserialize(redact(&e.to_string())))?;
220 let validator = jsonschema::validator_for(&schema)
221 .map_err(|e| AiError::SchemaValidation(redact(&e.to_string())))?;
222 if let Err(err) = validator.validate(&parsed) {
223 return Err(AiError::SchemaValidation(redact(&err.to_string())));
224 }
225 serde_json::from_value::<T>(parsed)
226 .map_err(|e| AiError::Deserialize(redact(&e.to_string())))
227 }
228}
229
230fn build_reqwest_client(config: &Config) -> Result<reqwest::Client, AiError> {
231 let mut builder = reqwest::Client::builder()
232 .https_only(!config.allow_insecure_base_url)
233 .timeout(config.timeout)
234 .user_agent(concat!("rtb-ai/", env!("CARGO_PKG_VERSION")));
235 if config.allow_insecure_base_url {
236 builder = builder.https_only(false);
240 }
241 builder.build().map_err(|e| AiError::InvalidConfig(redact(&e.to_string())))
242}
243
244fn backend_host(config: &Config) -> String {
245 config.base_url.as_ref().and_then(|u| u.host_str().map(String::from)).unwrap_or_else(|| {
246 match config.provider {
247 Provider::Anthropic | Provider::AnthropicLocal => "api.anthropic.com".into(),
248 Provider::OpenAi => "api.openai.com".into(),
249 Provider::Gemini => "generativelanguage.googleapis.com".into(),
250 Provider::Ollama => "localhost".into(),
251 Provider::OpenAiCompatible => "openai-compatible".into(),
252 }
253 })
254}
255
256fn augment_request_for_schema(mut req: ChatRequest, schema: &serde_json::Value) -> ChatRequest {
257 let instructions = format!(
258 "You MUST respond with a single JSON value matching this schema. \
259 No prose, no code fences:\n{schema}",
260 );
261 req.system = match req.system.take() {
262 Some(prefix) => Some(format!("{prefix}\n\n{instructions}")),
263 None => Some(instructions),
264 };
265 req
266}
267
268fn genai_set_key(config: &Config) {
273 let var = match config.provider {
279 Provider::OpenAi | Provider::OpenAiCompatible => "OPENAI_API_KEY",
280 Provider::Gemini => "GEMINI_API_KEY",
281 Provider::Ollama | Provider::Anthropic | Provider::AnthropicLocal => return,
284 };
285 #[allow(unsafe_code)]
288 unsafe {
289 std::env::set_var(var, config.api_key.expose_secret());
290 }
291}
292
293async fn genai_chat(
294 client: &genai::Client,
295 config: &Config,
296 req: ChatRequest,
297) -> Result<ChatResponse, AiError> {
298 let chat_req = build_genai_request(&req);
299 let resp = client
300 .exec_chat(&config.model, chat_req, None)
301 .await
302 .map_err(|e| AiError::Provider(redact(&e.to_string())))?;
303 let text = resp.first_text().unwrap_or_default().to_string();
304 let usage = genai_usage(&resp);
305 Ok(ChatResponse { message: Message::assistant(text), usage, citations: Vec::new() })
306}
307
308async fn genai_chat_stream(
309 client: &genai::Client,
310 config: &Config,
311 req: ChatRequest,
312) -> Result<ChatStream, AiError> {
313 let chat_req = build_genai_request(&req);
314 let resp = client
315 .exec_chat_stream(&config.model, chat_req, None)
316 .await
317 .map_err(|e| AiError::Provider(redact(&e.to_string())))?;
318 let stream = futures_util::StreamExt::map(resp.stream, |event| {
319 use genai::chat::ChatStreamEvent as G;
320 match event {
321 Ok(G::Chunk(chunk)) => ChatStreamEvent::Token(chunk.content),
322 Ok(G::ReasoningChunk(chunk)) => ChatStreamEvent::ThinkingToken(chunk.content),
323 Ok(G::End(end)) => ChatStreamEvent::Done(genai_usage_from_end(&end)),
324 Ok(G::Start | G::ToolCallChunk(_) | G::ThoughtSignatureChunk(_)) => {
328 ChatStreamEvent::Token(String::new())
329 }
330 Err(e) => ChatStreamEvent::Error(AiError::Provider(redact(&e.to_string()))),
331 }
332 });
333 let stream = futures_util::StreamExt::filter(stream, |e| {
335 let keep = !matches!(e, ChatStreamEvent::Token(t) if t.is_empty());
336 std::future::ready(keep)
337 });
338 Ok(ChatStream::new(Box::pin(stream)))
339}
340
341fn build_genai_request(req: &ChatRequest) -> genai::chat::ChatRequest {
342 let mut chat = genai::chat::ChatRequest::default();
343 if let Some(system) = &req.system {
344 chat = chat.with_system(system.clone());
345 }
346 for msg in &req.messages {
347 let text =
348 msg.content.iter().filter_map(ContentBlock::as_text).collect::<Vec<_>>().join("\n");
349 match msg.role {
350 crate::message::Role::User => {
351 chat = chat.append_message(genai::chat::ChatMessage::user(text));
352 }
353 crate::message::Role::Assistant => {
354 chat = chat.append_message(genai::chat::ChatMessage::assistant(text));
355 }
356 crate::message::Role::System => {
357 chat = chat.with_system(text);
358 }
359 }
360 }
361 chat
362}
363
364fn genai_usage(resp: &genai::chat::ChatResponse) -> Usage {
365 let u = &resp.usage;
366 Usage {
367 input_tokens: u.prompt_tokens.unwrap_or(0) as u32,
368 output_tokens: u.completion_tokens.unwrap_or(0) as u32,
369 cache_creation_input_tokens: 0,
370 cache_read_input_tokens: 0,
371 }
372}
373
374fn genai_usage_from_end(end: &genai::chat::StreamEnd) -> Usage {
375 end.captured_usage.as_ref().map_or_else(Usage::default, |u| Usage {
376 input_tokens: u.prompt_tokens.unwrap_or(0) as u32,
377 output_tokens: u.completion_tokens.unwrap_or(0) as u32,
378 cache_creation_input_tokens: 0,
379 cache_read_input_tokens: 0,
380 })
381}