tt_provider_compat/compat.rs
1//! Shared base for OpenAI-compatible provider adapters.
2//!
3//! Many inference providers (Mistral, Groq, Together AI, OpenRouter) expose an
4//! endpoint that is wire-compatible with OpenAI's `POST /chat/completions` API.
5//! Rather than duplicating HTTP plumbing in each adapter crate, this module
6//! provides [`OpenAICompatibleProvider`] — a single generic implementation that
7//! each adapter instantiates with its own [`CompatConfig`].
8//!
9//! # Billing note
10//!
11//! [`CompatConfig::fee_multiplier`] is stored but **not applied at request time**.
12//! Token counts and raw per-token costs flow through the response unchanged;
13//! the billing layer (in `tt-core`) multiplies by `fee_multiplier` when it
14//! computes the final USD charge displayed on the dashboard. This is intentional:
15//! the adapter should not alter usage numbers, only report them faithfully.
16//! (Tracked as a follow-up in the cost-accounting work item.)
17//!
18//! # Usage
19//!
20//! ```rust,no_run
21//! use std::collections::HashMap;
22//! use tt_provider_compat::{CompatConfig, OpenAICompatibleProvider, ClientConfig};
23//! use tt_shared::pricing::{Capability, ModelInfo, ModelPricing};
24//! use chrono::Utc;
25//!
26//! let cfg = CompatConfig {
27//! id: "my-provider",
28//! default_base_url: "https://api.example.com/v1".to_string(),
29//! models: vec![],
30//! pricing_table: HashMap::new(),
31//! fee_multiplier: 1.0,
32//! allow_local: false,
33//! };
34//! let provider = OpenAICompatibleProvider::new(ClientConfig::default(), cfg);
35//! ```
36
37use std::collections::HashMap;
38
39use async_trait::async_trait;
40use futures::stream::BoxStream;
41use reqwest::Client;
42use tracing::instrument;
43use tt_shared::{
44 filter_extra_headers, validate_provider_url, ChatCompletionChunk, ChatCompletionRequest,
45 ChatCompletionResponse, EmbeddingsRequest, EmbeddingsResponse, ModelInfo, ModelPricing,
46 Provider, ProviderError, RequestContext,
47};
48
49use crate::client::{build_client, ClientConfig};
50use crate::errors::{map_reqwest_error, map_response_error};
51use crate::{stream, translate};
52
53// ---------------------------------------------------------------------------
54// Configuration
55// ---------------------------------------------------------------------------
56
57/// Per-provider configuration for [`OpenAICompatibleProvider`].
58///
59/// Construct this once at startup (or in a lazy static) and pass it to
60/// [`OpenAICompatibleProvider::new`].
61pub struct CompatConfig {
62 /// Stable, lower-case identifier used by the routing and telemetry layers.
63 ///
64 /// Examples: `"mistral"`, `"groq"`, `"together"`, `"openrouter"`.
65 pub id: &'static str,
66
67 /// Default base URL, used when the caller's [`RequestContext`] does not
68 /// supply a `base_url` override in its credentials.
69 pub default_base_url: String,
70
71 /// All models exposed by this provider configuration.
72 pub models: Vec<ModelInfo>,
73
74 /// Pricing keyed by model ID string, mirroring the per-provider tables in
75 /// the OpenAI adapter's `pricing.rs`.
76 pub pricing_table: HashMap<String, ModelPricing>,
77
78 /// Optional fee multiplier stored for the billing layer (e.g. `1.05` for a
79 /// 5% BYOK fee on OpenRouter).
80 ///
81 /// **This value is NOT applied to usage at request time.** The adapter
82 /// faithfully reports raw token counts; the dashboard billing pass applies
83 /// the multiplier when computing the final USD charge. Default: `1.0`.
84 pub fee_multiplier: f64,
85
86 /// When `true`, skip SSRF URL validation for private/loopback addresses.
87 ///
88 /// Set to `true` only for local providers (Ollama, vLLM, LM Studio) that
89 /// legitimately target `http://localhost` or `http://127.0.0.1`. All hosted
90 /// providers must use `false`.
91 pub allow_local: bool,
92}
93
94// ---------------------------------------------------------------------------
95// Provider struct
96// ---------------------------------------------------------------------------
97
98/// Generic OpenAI-compatible chat-completion adapter.
99///
100/// Holds an HTTP client and a [`CompatConfig`] that varies per provider.
101/// All four thin adapter crates (Mistral, Groq, Together, OpenRouter) wrap
102/// this struct and forward every [`Provider`] method to it.
103pub struct OpenAICompatibleProvider {
104 client: Client,
105 cfg: CompatConfig,
106}
107
108impl OpenAICompatibleProvider {
109 /// Construct a new adapter from the given HTTP client configuration and
110 /// provider-specific configuration.
111 ///
112 /// # Panics
113 ///
114 /// Panics if the underlying [`reqwest::Client`] cannot be constructed (very
115 /// rare — only happens with invalid TLS configuration).
116 pub fn new(client_cfg: ClientConfig, cfg: CompatConfig) -> Self {
117 let client = build_client(&client_cfg)
118 .unwrap_or_else(|e| panic!("failed to build HTTP client for {}: {e}", cfg.id));
119 Self { client, cfg }
120 }
121
122 /// The fee multiplier stored in this provider's config.
123 ///
124 /// Exposed so that the billing layer can retrieve it without accessing
125 /// private fields. See [`CompatConfig::fee_multiplier`] for semantics.
126 pub fn fee_multiplier(&self) -> f64 {
127 self.cfg.fee_multiplier
128 }
129
130 /// Resolve the base URL: prefer the credential override, fall back to the
131 /// compiled-in default.
132 fn base_url<'a>(&'a self, ctx: &'a RequestContext) -> &'a str {
133 ctx.credentials
134 .base_url
135 .as_deref()
136 .unwrap_or(self.cfg.default_base_url.as_str())
137 }
138}
139
140// ---------------------------------------------------------------------------
141// Provider trait implementation
142// ---------------------------------------------------------------------------
143
144#[async_trait]
145impl Provider for OpenAICompatibleProvider {
146 fn id(&self) -> &'static str {
147 self.cfg.id
148 }
149
150 fn models(&self) -> Vec<ModelInfo> {
151 self.cfg.models.clone()
152 }
153
154 fn pricing(&self, model: &str) -> Option<ModelPricing> {
155 self.cfg.pricing_table.get(model).cloned()
156 }
157
158 fn dropped_params(&self, req: &tt_shared::ChatCompletionRequest) -> Vec<String> {
159 crate::translate::dropped_params(req)
160 }
161
162 /// Non-streaming chat completion via `POST /chat/completions`.
163 ///
164 /// Translates the canonical request, sends it to the provider's endpoint
165 /// (resolved from credentials or the default base URL), and maps any HTTP
166 /// error to the appropriate [`ProviderError`] variant.
167 #[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
168 async fn chat_completion(
169 &self,
170 req: ChatCompletionRequest,
171 ctx: &RequestContext,
172 ) -> Result<ChatCompletionResponse, ProviderError> {
173 let base_url = self.base_url(ctx);
174 validate_provider_url(base_url, self.cfg.allow_local)
175 .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
176
177 let url = format!("{base_url}/chat/completions");
178 let body = translate::translate_request(req)?;
179
180 let mut rb = self
181 .client
182 .post(&url)
183 .header(
184 "Authorization",
185 format!("Bearer {}", ctx.credentials.api_key.expose()),
186 )
187 .header("Content-Type", "application/json")
188 .json(&body);
189
190 for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
191 rb = rb.header(name, value);
192 }
193
194 let response = rb.send().await.map_err(map_reqwest_error)?;
195
196 let status = response.status().as_u16();
197 let retry_after = response
198 .headers()
199 .get("retry-after")
200 .and_then(|v| v.to_str().ok())
201 .map(|s| s.to_string());
202
203 let response_text = response.text().await.map_err(map_reqwest_error)?;
204
205 if status >= 400 {
206 return Err(map_response_error(
207 status,
208 &response_text,
209 retry_after.as_deref(),
210 ));
211 }
212
213 translate::deserialize_response(&response_text)
214 }
215
216 /// Streaming chat completion via `POST /chat/completions` with `stream: true`.
217 ///
218 /// Returns a [`BoxStream`] that yields [`ChatCompletionChunk`] values parsed
219 /// from OpenAI-compatible SSE events. HTTP errors before the first byte are
220 /// surfaced as `Err` before any chunk is produced.
221 #[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
222 async fn chat_completion_stream(
223 &self,
224 req: ChatCompletionRequest,
225 ctx: &RequestContext,
226 ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, ProviderError>>, ProviderError> {
227 let base_url = self.base_url(ctx);
228 validate_provider_url(base_url, self.cfg.allow_local)
229 .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
230
231 let base_url = base_url.to_string();
232 let client = self.client.clone();
233 stream::stream_chat_completion(client, &base_url, req, ctx).await
234 }
235
236 /// Embeddings via `POST /embeddings`.
237 ///
238 /// Sends the canonical [`EmbeddingsRequest`] to the provider's `/embeddings`
239 /// endpoint (resolved from credentials or the default base URL). All
240 /// OpenAI-compatible providers (Mistral, Together, etc.) expose the same
241 /// `/embeddings` path and wire format, so no translation is needed beyond
242 /// what [`translate::translate_embeddings_request`] provides.
243 #[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
244 async fn embeddings(
245 &self,
246 req: EmbeddingsRequest,
247 ctx: &RequestContext,
248 ) -> Result<EmbeddingsResponse, ProviderError> {
249 let base_url = self.base_url(ctx);
250 validate_provider_url(base_url, self.cfg.allow_local)
251 .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
252
253 let url = format!("{base_url}/embeddings");
254 let body = translate::translate_embeddings_request(req)?;
255
256 let mut rb = self
257 .client
258 .post(&url)
259 .header(
260 "Authorization",
261 format!("Bearer {}", ctx.credentials.api_key.expose()),
262 )
263 .header("Content-Type", "application/json")
264 .json(&body);
265
266 for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
267 rb = rb.header(name, value);
268 }
269
270 let response = rb.send().await.map_err(map_reqwest_error)?;
271
272 let status = response.status().as_u16();
273 let retry_after = response
274 .headers()
275 .get("retry-after")
276 .and_then(|v| v.to_str().ok())
277 .map(|s| s.to_string());
278
279 let response_text = response.text().await.map_err(map_reqwest_error)?;
280
281 if status >= 400 {
282 return Err(map_response_error(
283 status,
284 &response_text,
285 retry_after.as_deref(),
286 ));
287 }
288
289 translate::deserialize_embeddings_response(&response_text)
290 }
291}