Skip to main content

rtb_ai/
client.rs

1//! [`AiClient`] — typed façade over `genai` + the Anthropic-direct
2//! path.
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures_util::{Stream, StreamExt};
9use schemars::JsonSchema;
10use secrecy::ExposeSecret;
11use serde::de::DeserializeOwned;
12
13use crate::anthropic::{AnthropicTransport, ReqwestAnthropic};
14use crate::config::{validate_base_url, Config, Provider};
15use crate::error::{redact, AiError};
16use crate::message::{ContentBlock, Message, Usage};
17use crate::thinking::ThinkingMode;
18
19/// One-shot or streaming chat request.
20#[derive(Debug, Clone, Default)]
21pub struct ChatRequest {
22    /// Optional system prompt. Goes to Anthropic's top-level
23    /// `system` field; for `genai`-backed providers it lands in the
24    /// first message with role `system`.
25    pub system: Option<String>,
26    /// Conversation history + the current user message. Last item
27    /// is conventionally the user's turn.
28    pub messages: Vec<Message>,
29    /// Sampling temperature.
30    pub temperature: Option<f32>,
31    /// Hard cap on the assistant's reply.
32    pub max_tokens: Option<u32>,
33    /// Anthropic-only: enables prompt caching at the system prompt
34    /// + first user message. Silently ignored on other providers.
35    pub cache_control: bool,
36    /// Anthropic-only: enables extended-thinking with the supplied
37    /// budget. Silently ignored on other providers.
38    pub thinking: Option<ThinkingMode>,
39}
40
41/// Non-streaming chat response.
42#[derive(Debug, Clone)]
43pub struct ChatResponse {
44    /// Assistant's reply.
45    pub message: Message,
46    /// Token counts the provider reported.
47    pub usage: Usage,
48    /// Citations, populated only on the Anthropic-direct path when
49    /// the model emits them. Empty otherwise.
50    pub citations: Vec<crate::message::Citation>,
51}
52
53/// One event from the streaming chat path.
54#[derive(Debug, Clone)]
55#[non_exhaustive]
56pub enum ChatStreamEvent {
57    /// Regular assistant token.
58    Token(String),
59    /// Anthropic-only: a token from the extended-thinking stream.
60    /// Other providers never emit this.
61    ThinkingToken(String),
62    /// Final event, carrying the cumulative usage.
63    Done(Usage),
64    /// Stream-level error; ends the stream.
65    Error(AiError),
66}
67
68/// Async stream of [`ChatStreamEvent`]s. The stream is `!Sync` to
69/// avoid pinning trade-offs in callers; it is `Send` so it can move
70/// across `tokio::spawn` boundaries.
71pub struct ChatStream {
72    inner: Pin<Box<dyn Stream<Item = ChatStreamEvent> + Send>>,
73}
74
75impl ChatStream {
76    pub(crate) fn new(stream: Pin<Box<dyn Stream<Item = ChatStreamEvent> + Send>>) -> Self {
77        Self { inner: stream }
78    }
79}
80
81impl Stream for ChatStream {
82    type Item = ChatStreamEvent;
83    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
84        self.inner.poll_next_unpin(cx)
85    }
86}
87
88impl std::fmt::Debug for ChatStream {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("ChatStream").finish_non_exhaustive()
91    }
92}
93
94// ---------------------------------------------------------------------
95// AiClient
96// ---------------------------------------------------------------------
97
98enum Backend {
99    Anthropic(Arc<dyn AnthropicTransport>),
100    Genai(genai::Client),
101}
102
103impl std::fmt::Debug for Backend {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        match self {
106            Self::Anthropic(_) => f.debug_struct("Backend::Anthropic").finish_non_exhaustive(),
107            Self::Genai(_) => f.debug_struct("Backend::Genai").finish_non_exhaustive(),
108        }
109    }
110}
111
112/// Typed AI client. Construct via [`AiClient::new`].
113#[derive(Debug)]
114pub struct AiClient {
115    config: Config,
116    backend: Backend,
117}
118
119impl AiClient {
120    /// Build a client. Validates `config.base_url`, builds the
121    /// underlying HTTP client, and stamps the appropriate backend
122    /// (Anthropic-direct or `genai`).
123    ///
124    /// # Errors
125    ///
126    /// [`AiError::InvalidConfig`] on a bad base URL, an empty API
127    /// key, or a `reqwest::Client` build failure.
128    pub fn new(config: Config) -> Result<Self, AiError> {
129        Self::validate(&config)?;
130        let backend = if config.provider.is_anthropic() {
131            let client = build_reqwest_client(&config)?;
132            tracing::info!(
133                provider = ?config.provider,
134                host = %backend_host(&config),
135                "rtb-ai: AiClient ready (anthropic-direct)",
136            );
137            Backend::Anthropic(Arc::new(ReqwestAnthropic::new(Arc::new(client))))
138        } else {
139            // For genai-backed providers we let genai create the
140            // underlying HTTP client. The API key is supplied via
141            // env var of the relevant provider; genai resolves it
142            // internally. We set the variable for the duration of
143            // the constructor — see `genai_set_key` below.
144            genai_set_key(&config);
145            tracing::info!(
146                provider = ?config.provider,
147                host = %backend_host(&config),
148                "rtb-ai: AiClient ready (genai)",
149            );
150            Backend::Genai(genai::Client::default())
151        };
152        Ok(Self { config, backend })
153    }
154
155    fn validate(config: &Config) -> Result<(), AiError> {
156        if config.api_key.expose_secret().is_empty() {
157            return Err(AiError::InvalidConfig("api_key must not be empty".into()));
158        }
159        if config.model.is_empty() {
160            return Err(AiError::InvalidConfig("model must not be empty".into()));
161        }
162        if let Some(url) = &config.base_url {
163            validate_base_url(url, config.allow_insecure_base_url)?;
164        }
165        Ok(())
166    }
167
168    /// One-shot chat completion.
169    ///
170    /// # Errors
171    ///
172    /// Any [`AiError`] variant — provider errors (4xx / 5xx), HTTP
173    /// transport failures, rate-limit responses.
174    pub async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, AiError> {
175        match &self.backend {
176            Backend::Anthropic(t) => t.chat(&self.config, req).await,
177            Backend::Genai(c) => genai_chat(c, &self.config, req).await,
178        }
179    }
180
181    /// Streaming chat completion.
182    ///
183    /// # Errors
184    ///
185    /// Connection-time errors surface synchronously; per-event
186    /// errors surface as [`ChatStreamEvent::Error`] inside the
187    /// returned stream.
188    pub async fn chat_stream(&self, req: ChatRequest) -> Result<ChatStream, AiError> {
189        match &self.backend {
190            Backend::Anthropic(t) => t.chat_stream(&self.config, req).await,
191            Backend::Genai(c) => genai_chat_stream(c, &self.config, req).await,
192        }
193    }
194
195    /// Structured output: validates the response against `T`'s
196    /// `JsonSchema` before deserialising.
197    ///
198    /// The request is augmented to instruct the model to emit JSON
199    /// matching the schema — see [`ChatRequest::system`]; the
200    /// caller's system prompt (if any) is prepended.
201    ///
202    /// # Errors
203    ///
204    /// [`AiError::SchemaValidation`] when the response doesn't match
205    /// the schema; [`AiError::Deserialize`] when it matches the
206    /// schema but `serde::Deserialize` for `T` rejects it; any
207    /// underlying [`AiError`] from the chat call.
208    pub async fn chat_structured<T>(&self, req: ChatRequest) -> Result<T, AiError>
209    where
210        T: DeserializeOwned + JsonSchema,
211    {
212        let schema = serde_json::to_value(schemars::schema_for!(T))
213            .map_err(|e| AiError::InvalidConfig(redact(&e.to_string())))?;
214        let augmented = augment_request_for_schema(req, &schema);
215        let resp = self.chat(augmented).await?;
216        let body =
217            resp.message.content.iter().filter_map(ContentBlock::as_text).collect::<String>();
218        let parsed: serde_json::Value = serde_json::from_str(&body)
219            .map_err(|e| AiError::Deserialize(redact(&e.to_string())))?;
220        let validator = jsonschema::validator_for(&schema)
221            .map_err(|e| AiError::SchemaValidation(redact(&e.to_string())))?;
222        if let Err(err) = validator.validate(&parsed) {
223            return Err(AiError::SchemaValidation(redact(&err.to_string())));
224        }
225        serde_json::from_value::<T>(parsed)
226            .map_err(|e| AiError::Deserialize(redact(&e.to_string())))
227    }
228}
229
230fn build_reqwest_client(config: &Config) -> Result<reqwest::Client, AiError> {
231    let mut builder = reqwest::Client::builder()
232        .https_only(!config.allow_insecure_base_url)
233        .timeout(config.timeout)
234        .user_agent(concat!("rtb-ai/", env!("CARGO_PKG_VERSION")));
235    if config.allow_insecure_base_url {
236        // `https_only(false)` already accepts http; nothing more
237        // needed but the explicit reset keeps the call self-
238        // documenting.
239        builder = builder.https_only(false);
240    }
241    builder.build().map_err(|e| AiError::InvalidConfig(redact(&e.to_string())))
242}
243
244fn backend_host(config: &Config) -> String {
245    config.base_url.as_ref().and_then(|u| u.host_str().map(String::from)).unwrap_or_else(|| {
246        match config.provider {
247            Provider::Anthropic | Provider::AnthropicLocal => "api.anthropic.com".into(),
248            Provider::OpenAi => "api.openai.com".into(),
249            Provider::Gemini => "generativelanguage.googleapis.com".into(),
250            Provider::Ollama => "localhost".into(),
251            Provider::OpenAiCompatible => "openai-compatible".into(),
252        }
253    })
254}
255
256fn augment_request_for_schema(mut req: ChatRequest, schema: &serde_json::Value) -> ChatRequest {
257    let instructions = format!(
258        "You MUST respond with a single JSON value matching this schema. \
259         No prose, no code fences:\n{schema}",
260    );
261    req.system = match req.system.take() {
262        Some(prefix) => Some(format!("{prefix}\n\n{instructions}")),
263        None => Some(instructions),
264    };
265    req
266}
267
268// ---------------------------------------------------------------------
269// genai-backed path
270// ---------------------------------------------------------------------
271
272fn genai_set_key(config: &Config) {
273    // genai reads provider keys from environment variables. Setting
274    // the var here makes the key reachable to genai's lazy client
275    // builder. SAFETY: env var mutation is a known footgun under
276    // racy multi-threaded constructor calls; rtb-ai callers
277    // construct one client per process (the typical pattern).
278    let var = match config.provider {
279        Provider::OpenAi | Provider::OpenAiCompatible => "OPENAI_API_KEY",
280        Provider::Gemini => "GEMINI_API_KEY",
281        // Local inference + the Anthropic-direct path don't go
282        // through genai's env-var key resolution.
283        Provider::Ollama | Provider::Anthropic | Provider::AnthropicLocal => return,
284    };
285    // Safety: same-process env mutation. Documented above; matches
286    // the rtb-config tests' rationale.
287    #[allow(unsafe_code)]
288    unsafe {
289        std::env::set_var(var, config.api_key.expose_secret());
290    }
291}
292
293async fn genai_chat(
294    client: &genai::Client,
295    config: &Config,
296    req: ChatRequest,
297) -> Result<ChatResponse, AiError> {
298    let chat_req = build_genai_request(&req);
299    let resp = client
300        .exec_chat(&config.model, chat_req, None)
301        .await
302        .map_err(|e| AiError::Provider(redact(&e.to_string())))?;
303    let text = resp.first_text().unwrap_or_default().to_string();
304    let usage = genai_usage(&resp);
305    Ok(ChatResponse { message: Message::assistant(text), usage, citations: Vec::new() })
306}
307
308async fn genai_chat_stream(
309    client: &genai::Client,
310    config: &Config,
311    req: ChatRequest,
312) -> Result<ChatStream, AiError> {
313    let chat_req = build_genai_request(&req);
314    let resp = client
315        .exec_chat_stream(&config.model, chat_req, None)
316        .await
317        .map_err(|e| AiError::Provider(redact(&e.to_string())))?;
318    let stream = futures_util::StreamExt::map(resp.stream, |event| {
319        use genai::chat::ChatStreamEvent as G;
320        match event {
321            Ok(G::Chunk(chunk)) => ChatStreamEvent::Token(chunk.content),
322            Ok(G::ReasoningChunk(chunk)) => ChatStreamEvent::ThinkingToken(chunk.content),
323            Ok(G::End(end)) => ChatStreamEvent::Done(genai_usage_from_end(&end)),
324            // Filtered out below — emitted as empty `Token`s and
325            // dropped by the filter step. Keeps the match exhaustive
326            // for future genai event variants.
327            Ok(G::Start | G::ToolCallChunk(_) | G::ThoughtSignatureChunk(_)) => {
328                ChatStreamEvent::Token(String::new())
329            }
330            Err(e) => ChatStreamEvent::Error(AiError::Provider(redact(&e.to_string()))),
331        }
332    });
333    // Filter out empty `Start` / tool-chunk emits.
334    let stream = futures_util::StreamExt::filter(stream, |e| {
335        let keep = !matches!(e, ChatStreamEvent::Token(t) if t.is_empty());
336        std::future::ready(keep)
337    });
338    Ok(ChatStream::new(Box::pin(stream)))
339}
340
341fn build_genai_request(req: &ChatRequest) -> genai::chat::ChatRequest {
342    let mut chat = genai::chat::ChatRequest::default();
343    if let Some(system) = &req.system {
344        chat = chat.with_system(system.clone());
345    }
346    for msg in &req.messages {
347        let text =
348            msg.content.iter().filter_map(ContentBlock::as_text).collect::<Vec<_>>().join("\n");
349        match msg.role {
350            crate::message::Role::User => {
351                chat = chat.append_message(genai::chat::ChatMessage::user(text));
352            }
353            crate::message::Role::Assistant => {
354                chat = chat.append_message(genai::chat::ChatMessage::assistant(text));
355            }
356            crate::message::Role::System => {
357                chat = chat.with_system(text);
358            }
359        }
360    }
361    chat
362}
363
364fn genai_usage(resp: &genai::chat::ChatResponse) -> Usage {
365    let u = &resp.usage;
366    Usage {
367        input_tokens: u.prompt_tokens.unwrap_or(0) as u32,
368        output_tokens: u.completion_tokens.unwrap_or(0) as u32,
369        cache_creation_input_tokens: 0,
370        cache_read_input_tokens: 0,
371    }
372}
373
374fn genai_usage_from_end(end: &genai::chat::StreamEnd) -> Usage {
375    end.captured_usage.as_ref().map_or_else(Usage::default, |u| Usage {
376        input_tokens: u.prompt_tokens.unwrap_or(0) as u32,
377        output_tokens: u.completion_tokens.unwrap_or(0) as u32,
378        cache_creation_input_tokens: 0,
379        cache_read_input_tokens: 0,
380    })
381}