1use http::HeaderMap;
2use http::HeaderName;
3use http::HeaderValue;
4use http::header;
5use serde::Deserialize;
6use serde::Serialize;
7use std::borrow::Cow;
8use std::collections::HashMap;
9use std::fmt::Display;
10use std::time::Duration;
11use strum::Display as StrumDisplay;
12use strum::IntoStaticStr;
13use url::Url;
14
15use crate::request::SessionRequestOptions;
16
/// Providers that ship with built-in definitions.
///
/// The strum derives render each variant as its lowercase name
/// (e.g. `OpenAI` -> `"openai"`, `LmStudio` -> `"lmstudio"`). That same
/// `&'static str` is used when converting into a [`ProviderId`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, StrumDisplay, IntoStaticStr)]
#[strum(serialize_all = "lowercase")]
pub enum BuiltinProvider {
    Anthropic,
    Gemini,
    OpenAI,
    OpenRouter,
    Ollama,
    LmStudio,
}
28
29impl From<BuiltinProvider> for ProviderId {
30 fn from(value: BuiltinProvider) -> Self {
31 Self(Cow::Borrowed(value.into()))
32 }
33}
34
35#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
37pub struct ProviderId(Cow<'static, str>);
38
39impl ProviderId {
40 pub fn new(id: impl Into<String>) -> Self {
41 Self(Cow::Owned(id.into()))
42 }
43
44 pub fn as_str(&self) -> &str {
45 self.0.as_ref()
46 }
47}
48
49impl Display for ProviderId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.write_str(self.as_str())
52 }
53}
54
55impl From<&str> for ProviderId {
56 fn from(value: &str) -> Self {
57 Self::new(value)
58 }
59}
60
61impl From<String> for ProviderId {
62 fn from(value: String) -> Self {
63 Self(Cow::Owned(value))
64 }
65}
66
67impl From<&String> for ProviderId {
68 fn from(value: &String) -> Self {
69 Self::new(value.as_str())
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
75pub struct ProviderDescriptor {
76 pub id: ProviderId,
77 pub display_name: Option<String>,
78 pub description: Option<String>,
79}
80
81impl ProviderDescriptor {
82 pub fn new(id: impl Into<ProviderId>) -> Self {
83 Self {
84 id: id.into(),
85 display_name: None,
86 description: None,
87 }
88 }
89}
90
/// Boolean feature flags advertised by a provider.
///
/// The derived `Default` sets every flag to `false`. `ProviderDefinition`
/// marks its `capabilities` field `#[serde(default)]`, so an omitted
/// capabilities table deserializes as all-`false`. How each flag is consumed
/// is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ProviderCapabilities {
    pub supports_model_listing: bool,
    pub supports_streaming: bool,
    pub supports_websockets: bool,
    pub supports_tool_calls: bool,
    pub supports_images: bool,
    pub supports_history_compaction: bool,
    pub supports_memory_summarization: bool,
    pub supports_deferred_tools: bool,
    pub supports_hosted_tool_search: bool,
    pub supports_hosted_web_search: bool,
    pub supports_image_generation: bool,
    pub supports_reasoning_effort: bool,
    // The two `reports_*` flags describe token accounting in responses rather
    // than request features (inferred from naming — confirm with consumers).
    pub reports_reasoning_tokens: bool,
    pub reports_thoughts_tokens: bool,
    pub supports_structured_tool_results: bool,
    pub supports_embeddings: bool,
}
111
112#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
114#[serde(rename_all = "lowercase")]
115pub enum WireApi {
116 #[default]
117 Responses,
118 AnthropicMessages,
119 GeminiGenerateContent,
120}
121
122impl Display for WireApi {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 let value = match self {
125 Self::Responses => "responses",
126 Self::AnthropicMessages => "anthropic_messages",
127 Self::GeminiGenerateContent => "gemini_generate_content",
128 };
129 f.write_str(value)
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
135pub struct RetryPolicy {
136 pub max_attempts: u64,
137 pub base_delay: Duration,
138 pub retry_429: bool,
139 pub retry_5xx: bool,
140 pub retry_transport: bool,
141}
142
143impl Default for RetryPolicy {
144 fn default() -> Self {
145 Self {
146 max_attempts: 5,
147 base_delay: Duration::from_millis(200),
148 retry_429: false,
149 retry_5xx: true,
150 retry_transport: true,
151 }
152 }
153}
154
/// Full configuration for one provider.
///
/// Covers identity, wire protocol, auth, capabilities, endpoint, retry policy
/// and timeouts. Fields marked `#[serde(default)]` may be omitted when
/// deserializing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDefinition {
    pub descriptor: ProviderDescriptor,
    /// Falls back to `WireApi::default()` (`Responses`) when omitted.
    #[serde(default)]
    pub wire_api: WireApi,
    #[serde(default)]
    pub auth_scheme: crate::AuthScheme,
    /// Falls back to all-`false` flags when omitted.
    #[serde(default)]
    pub capabilities: ProviderCapabilities,
    /// Root URL that request paths are joined onto.
    ///
    /// `url_for_path` treats `None` as an empty base.
    pub base_url: Option<String>,
    /// Query parameters that `url_for_path` appends to every URL it builds.
    #[serde(default)]
    pub query_params: Option<HashMap<String, String>>,
    /// Static headers that `build_headers` inserts first.
    ///
    /// Credential and auth headers are inserted afterwards and replace
    /// same-named entries.
    #[serde(default)]
    pub headers: Option<HashMap<String, String>>,
    #[serde(default)]
    pub retry: RetryPolicy,
    /// Defaults to 300 seconds (`default_stream_idle_timeout`).
    #[serde(default = "default_stream_idle_timeout")]
    pub stream_idle_timeout: Duration,
    /// Defaults to 15 seconds (`default_websocket_connect_timeout`).
    #[serde(default = "default_websocket_connect_timeout")]
    pub websocket_connect_timeout: Duration,
}
177
/// Default for `ProviderDefinition::stream_idle_timeout`: five minutes.
fn default_stream_idle_timeout() -> Duration {
    Duration::from_secs(5 * 60)
}
181
/// Default for `ProviderDefinition::websocket_connect_timeout`: fifteen seconds.
fn default_websocket_connect_timeout() -> Duration {
    Duration::from_secs(15)
}
185
186impl ProviderDefinition {
187 pub fn new(id: impl Into<ProviderId>) -> Self {
188 Self {
189 descriptor: ProviderDescriptor::new(id),
190 wire_api: WireApi::default(),
191 auth_scheme: crate::AuthScheme::default(),
192 capabilities: ProviderCapabilities {
193 supports_model_listing: true,
194 supports_streaming: true,
195 supports_websockets: false,
196 supports_tool_calls: true,
197 supports_images: true,
198 supports_history_compaction: false,
199 supports_memory_summarization: false,
200 supports_deferred_tools: false,
201 supports_hosted_tool_search: false,
202 supports_hosted_web_search: false,
203 supports_image_generation: false,
204 supports_reasoning_effort: false,
205 reports_reasoning_tokens: false,
206 reports_thoughts_tokens: false,
207 supports_structured_tool_results: false,
208 supports_embeddings: false,
209 },
210 base_url: None,
211 query_params: None,
212 headers: None,
213 retry: RetryPolicy::default(),
214 stream_idle_timeout: default_stream_idle_timeout(),
215 websocket_connect_timeout: default_websocket_connect_timeout(),
216 }
217 }
218
219 pub fn descriptor(&self) -> ProviderDescriptor {
220 self.descriptor.clone()
221 }
222
223 pub fn provider_id(&self) -> &ProviderId {
224 &self.descriptor.id
225 }
226
227 pub fn url_for_path(&self, path: &str) -> String {
228 let base = self
229 .base_url
230 .as_deref()
231 .unwrap_or_default()
232 .trim_end_matches('/');
233 let path = path.trim_start_matches('/');
234 let mut url = if path.is_empty() {
235 base.to_string()
236 } else {
237 format!("{base}/{path}")
238 };
239
240 if let Some(params) = &self.query_params
241 && !params.is_empty()
242 {
243 let qs = params
244 .iter()
245 .map(|(key, value)| format!("{key}={value}"))
246 .collect::<Vec<_>>()
247 .join("&");
248 url.push('?');
249 url.push_str(&qs);
250 }
251
252 url
253 }
254
255 pub fn build_headers(
256 &self,
257 credentials: &crate::ProviderCredentials,
258 ) -> Result<HeaderMap, crate::ProviderError> {
259 let mut headers = HeaderMap::new();
260
261 if let Some(configured_headers) = &self.headers {
262 for (name, value) in configured_headers {
263 insert_header(&mut headers, name, value)?;
264 }
265 }
266
267 for (name, value) in &credentials.headers {
268 insert_header(&mut headers, name, value)?;
269 }
270
271 match &self.auth_scheme {
272 crate::AuthScheme::None | crate::AuthScheme::QueryParam { .. } => {}
273 crate::AuthScheme::BearerToken => {
274 let token = required_auth_value(credentials)?;
275 let auth_value =
276 HeaderValue::from_str(&format!("Bearer {token}")).map_err(|error| {
277 crate::ProviderError::InvalidRequest(format!(
278 "invalid bearer token header: {error}"
279 ))
280 })?;
281 headers.insert(header::AUTHORIZATION, auth_value);
282 }
283 crate::AuthScheme::Header { name } => {
284 let token = required_auth_value(credentials)?;
285 insert_header(&mut headers, name, token)?;
286 }
287 }
288
289 Ok(headers)
290 }
291
292 pub fn build_headers_for_session(
293 &self,
294 credentials: &crate::ProviderCredentials,
295 session: Option<&SessionRequestOptions>,
296 fallback_turn_state: Option<&str>,
297 ) -> Result<HeaderMap, crate::ProviderError> {
298 let mut headers = self.build_headers(credentials)?;
299
300 if let Some(turn_state) = session
301 .and_then(|session| session.sticky_turn_state.as_deref())
302 .or(fallback_turn_state)
303 && let Ok(value) = HeaderValue::from_str(turn_state)
304 {
305 headers.insert("x-mentra-turn-state", value.clone());
306 headers.insert("x-codex-turn-state", value);
307 }
308 if let Some(value) = session.and_then(|session| session.turn_metadata.as_deref())
309 && let Ok(value) = HeaderValue::from_str(value)
310 {
311 headers.insert("x-mentra-turn-metadata", value.clone());
312 headers.insert("x-codex-turn-metadata", value);
313 }
314 if let Some(value) = session.and_then(|session| session.session_affinity.as_deref())
315 && let Ok(value) = HeaderValue::from_str(value)
316 {
317 headers.insert("x-mentra-session-affinity", value);
318 }
319 if let Some(prefer_connection_reuse) =
320 session.and_then(|session| session.prefer_connection_reuse)
321 {
322 headers.insert(
323 "x-mentra-connection-reuse",
324 HeaderValue::from_static(if prefer_connection_reuse {
325 "prefer-reuse"
326 } else {
327 "prefer-fresh"
328 }),
329 );
330 }
331 if let Some(value) = session.and_then(|session| session.subagent.as_deref())
332 && let Ok(value) = HeaderValue::from_str(value)
333 {
334 headers.insert("x-openai-subagent", value);
335 }
336 if let Some(extra_headers) = session.map(|session| &session.extra_headers) {
337 for (name, value) in extra_headers {
338 if let (Ok(name), Ok(value)) = (
339 name.parse::<http::HeaderName>(),
340 HeaderValue::from_str(value),
341 ) {
342 headers.insert(name, value);
343 }
344 }
345 }
346
347 Ok(headers)
348 }
349
350 pub fn request_url_with_auth_for_path(
351 &self,
352 path: &str,
353 credentials: &crate::ProviderCredentials,
354 ) -> Result<Url, crate::ProviderError> {
355 let mut url = Url::parse(&self.url_for_path(path))
356 .map_err(|error| crate::ProviderError::InvalidRequest(error.to_string()))?;
357
358 if let crate::AuthScheme::QueryParam { name } = &self.auth_scheme {
359 let token = required_auth_value(credentials)?;
360 url.query_pairs_mut().append_pair(name, token);
361 }
362
363 Ok(url)
364 }
365
366 pub fn websocket_url_for_path(&self, path: &str) -> Result<Url, url::ParseError> {
367 let mut url = Url::parse(&self.url_for_path(path))?;
368
369 let scheme = match url.scheme() {
370 "http" => "ws",
371 "https" => "wss",
372 "ws" | "wss" => return Ok(url),
373 _ => return Ok(url),
374 };
375 let _ = url.set_scheme(scheme);
376 Ok(url)
377 }
378
379 pub fn websocket_url_with_auth_for_path(
380 &self,
381 path: &str,
382 credentials: &crate::ProviderCredentials,
383 ) -> Result<Url, crate::ProviderError> {
384 let mut url = self
385 .websocket_url_for_path(path)
386 .map_err(|error| crate::ProviderError::InvalidRequest(error.to_string()))?;
387
388 if let crate::AuthScheme::QueryParam { name } = &self.auth_scheme {
389 let token = required_auth_value(credentials)?;
390 url.query_pairs_mut().append_pair(name, token);
391 }
392
393 Ok(url)
394 }
395}
396
397fn insert_header(
398 headers: &mut HeaderMap,
399 name: &str,
400 value: &str,
401) -> Result<(), crate::ProviderError> {
402 let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
403 crate::ProviderError::InvalidRequest(format!(
404 "invalid provider header name {name:?}: {error}"
405 ))
406 })?;
407 let header_value = HeaderValue::from_str(value).map_err(|error| {
408 crate::ProviderError::InvalidRequest(format!(
409 "invalid provider header value for {name:?}: {error}"
410 ))
411 })?;
412 headers.insert(header_name, header_value);
413 Ok(())
414}
415
416fn required_auth_value(
417 credentials: &crate::ProviderCredentials,
418) -> Result<&str, crate::ProviderError> {
419 credentials.bearer_token.as_deref().ok_or_else(|| {
420 crate::ProviderError::InvalidRequest("missing provider auth credential".to_string())
421 })
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Credentials whose auth value is `token`, with the given extra headers.
    fn credentials_with(
        token: &str,
        headers: HashMap<String, String>,
    ) -> crate::ProviderCredentials {
        crate::ProviderCredentials {
            bearer_token: Some(token.to_string()),
            account_id: None,
            headers,
        }
    }

    #[test]
    fn build_headers_applies_bearer_auth_and_static_headers() {
        let mut definition = ProviderDefinition::new("test");
        definition.auth_scheme = crate::AuthScheme::BearerToken;
        definition.headers = Some(HashMap::from([(
            "x-provider-header".to_string(),
            "static".to_string(),
        )]));
        let credentials = credentials_with(
            "secret",
            HashMap::from([("x-runtime-header".to_string(), "dynamic".to_string())]),
        );

        let headers = definition
            .build_headers(&credentials)
            .expect("headers should build");

        assert_eq!(headers["x-provider-header"], "static");
        assert_eq!(headers["x-runtime-header"], "dynamic");
        assert_eq!(headers[header::AUTHORIZATION], "Bearer secret");
    }

    #[test]
    fn request_url_with_auth_appends_query_param_auth() {
        let mut definition = ProviderDefinition::new("test");
        definition.base_url = Some("https://example.com/v1".to_string());
        definition.query_params = Some(HashMap::from([(
            "api-version".to_string(),
            "2026".to_string(),
        )]));
        definition.auth_scheme = crate::AuthScheme::QueryParam {
            name: "api-key".to_string(),
        };
        let credentials = credentials_with("secret", HashMap::new());

        let url = definition
            .request_url_with_auth_for_path("responses", &credentials)
            .expect("url should build");

        assert_eq!(
            url.as_str(),
            "https://example.com/v1/responses?api-version=2026&api-key=secret"
        );
    }
}