1use http::HeaderMap;
2use http::HeaderName;
3use http::HeaderValue;
4use http::header;
5use serde::Deserialize;
6use serde::Serialize;
7use std::borrow::Cow;
8use std::collections::HashMap;
9use std::fmt::Display;
10use std::time::Duration;
11use strum::Display as StrumDisplay;
12use strum::IntoStaticStr;
13use url::Url;
14
15use crate::request::SessionRequestOptions;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, StrumDisplay, IntoStaticStr)]
19#[strum(serialize_all = "lowercase")]
20pub enum BuiltinProvider {
21 Anthropic,
22 Gemini,
23 OpenAI,
24 OpenRouter,
25 Ollama,
26 LmStudio,
27}
28
29impl From<BuiltinProvider> for ProviderId {
30 fn from(value: BuiltinProvider) -> Self {
31 Self(Cow::Borrowed(value.into()))
32 }
33}
34
35#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
37pub struct ProviderId(Cow<'static, str>);
38
39impl ProviderId {
40 pub fn new(id: impl Into<String>) -> Self {
41 Self(Cow::Owned(id.into()))
42 }
43
44 pub fn as_str(&self) -> &str {
45 self.0.as_ref()
46 }
47}
48
49impl Display for ProviderId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.write_str(self.as_str())
52 }
53}
54
55impl From<&str> for ProviderId {
56 fn from(value: &str) -> Self {
57 Self::new(value)
58 }
59}
60
61impl From<String> for ProviderId {
62 fn from(value: String) -> Self {
63 Self(Cow::Owned(value))
64 }
65}
66
67impl From<&String> for ProviderId {
68 fn from(value: &String) -> Self {
69 Self::new(value.as_str())
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
75pub struct ProviderDescriptor {
76 pub id: ProviderId,
77 pub display_name: Option<String>,
78 pub description: Option<String>,
79}
80
81impl ProviderDescriptor {
82 pub fn new(id: impl Into<ProviderId>) -> Self {
83 Self {
84 id: id.into(),
85 display_name: None,
86 description: None,
87 }
88 }
89}
90
91#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
93pub struct ProviderCapabilities {
94 pub supports_model_listing: bool,
95 pub supports_streaming: bool,
96 pub supports_websockets: bool,
97 pub supports_tool_calls: bool,
98 pub supports_images: bool,
99 pub supports_history_compaction: bool,
100 pub supports_memory_summarization: bool,
101 pub supports_deferred_tools: bool,
102 pub supports_hosted_tool_search: bool,
103 pub supports_hosted_web_search: bool,
104 pub supports_image_generation: bool,
105 pub supports_reasoning_effort: bool,
106 pub reports_reasoning_tokens: bool,
107 pub reports_thoughts_tokens: bool,
108 pub supports_structured_tool_results: bool,
109}
110
111#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
113#[serde(rename_all = "lowercase")]
114pub enum WireApi {
115 #[default]
116 Responses,
117 AnthropicMessages,
118 GeminiGenerateContent,
119}
120
121impl Display for WireApi {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 let value = match self {
124 Self::Responses => "responses",
125 Self::AnthropicMessages => "anthropic_messages",
126 Self::GeminiGenerateContent => "gemini_generate_content",
127 };
128 f.write_str(value)
129 }
130}
131
132#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
134pub struct RetryPolicy {
135 pub max_attempts: u64,
136 pub base_delay: Duration,
137 pub retry_429: bool,
138 pub retry_5xx: bool,
139 pub retry_transport: bool,
140}
141
142impl Default for RetryPolicy {
143 fn default() -> Self {
144 Self {
145 max_attempts: 5,
146 base_delay: Duration::from_millis(200),
147 retry_429: false,
148 retry_5xx: true,
149 retry_transport: true,
150 }
151 }
152}
153
154#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
156pub struct ProviderDefinition {
157 pub descriptor: ProviderDescriptor,
158 #[serde(default)]
159 pub wire_api: WireApi,
160 #[serde(default)]
161 pub auth_scheme: crate::AuthScheme,
162 #[serde(default)]
163 pub capabilities: ProviderCapabilities,
164 pub base_url: Option<String>,
165 #[serde(default)]
166 pub query_params: Option<HashMap<String, String>>,
167 #[serde(default)]
168 pub headers: Option<HashMap<String, String>>,
169 #[serde(default)]
170 pub retry: RetryPolicy,
171 #[serde(default = "default_stream_idle_timeout")]
172 pub stream_idle_timeout: Duration,
173 #[serde(default = "default_websocket_connect_timeout")]
174 pub websocket_connect_timeout: Duration,
175}
176
/// Serde default for `ProviderDefinition::stream_idle_timeout`: five minutes.
fn default_stream_idle_timeout() -> Duration {
    const FIVE_MINUTES_SECS: u64 = 5 * 60;
    Duration::from_secs(FIVE_MINUTES_SECS)
}
180
/// Serde default for `ProviderDefinition::websocket_connect_timeout`: fifteen seconds.
fn default_websocket_connect_timeout() -> Duration {
    const FIFTEEN_SECS: u64 = 15;
    Duration::from_secs(FIFTEEN_SECS)
}
184
185impl ProviderDefinition {
186 pub fn new(id: impl Into<ProviderId>) -> Self {
187 Self {
188 descriptor: ProviderDescriptor::new(id),
189 wire_api: WireApi::default(),
190 auth_scheme: crate::AuthScheme::default(),
191 capabilities: ProviderCapabilities {
192 supports_model_listing: true,
193 supports_streaming: true,
194 supports_websockets: false,
195 supports_tool_calls: true,
196 supports_images: true,
197 supports_history_compaction: false,
198 supports_memory_summarization: false,
199 supports_deferred_tools: false,
200 supports_hosted_tool_search: false,
201 supports_hosted_web_search: false,
202 supports_image_generation: false,
203 supports_reasoning_effort: false,
204 reports_reasoning_tokens: false,
205 reports_thoughts_tokens: false,
206 supports_structured_tool_results: false,
207 },
208 base_url: None,
209 query_params: None,
210 headers: None,
211 retry: RetryPolicy::default(),
212 stream_idle_timeout: default_stream_idle_timeout(),
213 websocket_connect_timeout: default_websocket_connect_timeout(),
214 }
215 }
216
217 pub fn descriptor(&self) -> ProviderDescriptor {
218 self.descriptor.clone()
219 }
220
221 pub fn provider_id(&self) -> &ProviderId {
222 &self.descriptor.id
223 }
224
225 pub fn url_for_path(&self, path: &str) -> String {
226 let base = self
227 .base_url
228 .as_deref()
229 .unwrap_or_default()
230 .trim_end_matches('/');
231 let path = path.trim_start_matches('/');
232 let mut url = if path.is_empty() {
233 base.to_string()
234 } else {
235 format!("{base}/{path}")
236 };
237
238 if let Some(params) = &self.query_params
239 && !params.is_empty()
240 {
241 let qs = params
242 .iter()
243 .map(|(key, value)| format!("{key}={value}"))
244 .collect::<Vec<_>>()
245 .join("&");
246 url.push('?');
247 url.push_str(&qs);
248 }
249
250 url
251 }
252
253 pub fn build_headers(
254 &self,
255 credentials: &crate::ProviderCredentials,
256 ) -> Result<HeaderMap, crate::ProviderError> {
257 let mut headers = HeaderMap::new();
258
259 if let Some(configured_headers) = &self.headers {
260 for (name, value) in configured_headers {
261 insert_header(&mut headers, name, value)?;
262 }
263 }
264
265 for (name, value) in &credentials.headers {
266 insert_header(&mut headers, name, value)?;
267 }
268
269 match &self.auth_scheme {
270 crate::AuthScheme::None | crate::AuthScheme::QueryParam { .. } => {}
271 crate::AuthScheme::BearerToken => {
272 let token = required_auth_value(credentials)?;
273 let auth_value =
274 HeaderValue::from_str(&format!("Bearer {token}")).map_err(|error| {
275 crate::ProviderError::InvalidRequest(format!(
276 "invalid bearer token header: {error}"
277 ))
278 })?;
279 headers.insert(header::AUTHORIZATION, auth_value);
280 }
281 crate::AuthScheme::Header { name } => {
282 let token = required_auth_value(credentials)?;
283 insert_header(&mut headers, name, token)?;
284 }
285 }
286
287 Ok(headers)
288 }
289
290 pub fn build_headers_for_session(
291 &self,
292 credentials: &crate::ProviderCredentials,
293 session: Option<&SessionRequestOptions>,
294 fallback_turn_state: Option<&str>,
295 ) -> Result<HeaderMap, crate::ProviderError> {
296 let mut headers = self.build_headers(credentials)?;
297
298 if let Some(turn_state) = session
299 .and_then(|session| session.sticky_turn_state.as_deref())
300 .or(fallback_turn_state)
301 && let Ok(value) = HeaderValue::from_str(turn_state)
302 {
303 headers.insert("x-mentra-turn-state", value.clone());
304 headers.insert("x-codex-turn-state", value);
305 }
306 if let Some(value) = session.and_then(|session| session.turn_metadata.as_deref())
307 && let Ok(value) = HeaderValue::from_str(value)
308 {
309 headers.insert("x-mentra-turn-metadata", value.clone());
310 headers.insert("x-codex-turn-metadata", value);
311 }
312 if let Some(value) = session.and_then(|session| session.session_affinity.as_deref())
313 && let Ok(value) = HeaderValue::from_str(value)
314 {
315 headers.insert("x-mentra-session-affinity", value);
316 }
317 if let Some(prefer_connection_reuse) =
318 session.and_then(|session| session.prefer_connection_reuse)
319 {
320 headers.insert(
321 "x-mentra-connection-reuse",
322 HeaderValue::from_static(if prefer_connection_reuse {
323 "prefer-reuse"
324 } else {
325 "prefer-fresh"
326 }),
327 );
328 }
329 if let Some(value) = session.and_then(|session| session.subagent.as_deref())
330 && let Ok(value) = HeaderValue::from_str(value)
331 {
332 headers.insert("x-openai-subagent", value);
333 }
334 if let Some(extra_headers) = session.map(|session| &session.extra_headers) {
335 for (name, value) in extra_headers {
336 if let (Ok(name), Ok(value)) = (
337 name.parse::<http::HeaderName>(),
338 HeaderValue::from_str(value),
339 ) {
340 headers.insert(name, value);
341 }
342 }
343 }
344
345 Ok(headers)
346 }
347
348 pub fn request_url_with_auth_for_path(
349 &self,
350 path: &str,
351 credentials: &crate::ProviderCredentials,
352 ) -> Result<Url, crate::ProviderError> {
353 let mut url = Url::parse(&self.url_for_path(path))
354 .map_err(|error| crate::ProviderError::InvalidRequest(error.to_string()))?;
355
356 if let crate::AuthScheme::QueryParam { name } = &self.auth_scheme {
357 let token = required_auth_value(credentials)?;
358 url.query_pairs_mut().append_pair(name, token);
359 }
360
361 Ok(url)
362 }
363
364 pub fn websocket_url_for_path(&self, path: &str) -> Result<Url, url::ParseError> {
365 let mut url = Url::parse(&self.url_for_path(path))?;
366
367 let scheme = match url.scheme() {
368 "http" => "ws",
369 "https" => "wss",
370 "ws" | "wss" => return Ok(url),
371 _ => return Ok(url),
372 };
373 let _ = url.set_scheme(scheme);
374 Ok(url)
375 }
376
377 pub fn websocket_url_with_auth_for_path(
378 &self,
379 path: &str,
380 credentials: &crate::ProviderCredentials,
381 ) -> Result<Url, crate::ProviderError> {
382 let mut url = self
383 .websocket_url_for_path(path)
384 .map_err(|error| crate::ProviderError::InvalidRequest(error.to_string()))?;
385
386 if let crate::AuthScheme::QueryParam { name } = &self.auth_scheme {
387 let token = required_auth_value(credentials)?;
388 url.query_pairs_mut().append_pair(name, token);
389 }
390
391 Ok(url)
392 }
393}
394
395fn insert_header(
396 headers: &mut HeaderMap,
397 name: &str,
398 value: &str,
399) -> Result<(), crate::ProviderError> {
400 let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
401 crate::ProviderError::InvalidRequest(format!(
402 "invalid provider header name {name:?}: {error}"
403 ))
404 })?;
405 let header_value = HeaderValue::from_str(value).map_err(|error| {
406 crate::ProviderError::InvalidRequest(format!(
407 "invalid provider header value for {name:?}: {error}"
408 ))
409 })?;
410 headers.insert(header_name, header_value);
411 Ok(())
412}
413
414fn required_auth_value(
415 credentials: &crate::ProviderCredentials,
416) -> Result<&str, crate::ProviderError> {
417 credentials.bearer_token.as_deref().ok_or_else(|| {
418 crate::ProviderError::InvalidRequest("missing provider auth credential".to_string())
419 })
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Credentials holding `token` as the bearer value plus the given runtime headers.
    fn credentials_with(
        token: &str,
        headers: HashMap<String, String>,
    ) -> crate::ProviderCredentials {
        crate::ProviderCredentials {
            bearer_token: Some(token.to_string()),
            account_id: None,
            headers,
        }
    }

    #[test]
    fn build_headers_applies_bearer_auth_and_static_headers() {
        let mut definition = ProviderDefinition::new("test");
        definition.auth_scheme = crate::AuthScheme::BearerToken;
        definition.headers = Some(HashMap::from([(
            "x-provider-header".to_string(),
            "static".to_string(),
        )]));
        let credentials = credentials_with(
            "secret",
            HashMap::from([("x-runtime-header".to_string(), "dynamic".to_string())]),
        );

        let headers = definition
            .build_headers(&credentials)
            .expect("headers should build");

        assert_eq!(headers["x-provider-header"], "static");
        assert_eq!(headers["x-runtime-header"], "dynamic");
        assert_eq!(headers[header::AUTHORIZATION], "Bearer secret");
    }

    #[test]
    fn request_url_with_auth_appends_query_param_auth() {
        let mut definition = ProviderDefinition::new("test");
        definition.base_url = Some("https://example.com/v1".to_string());
        definition.query_params = Some(HashMap::from([(
            "api-version".to_string(),
            "2026".to_string(),
        )]));
        definition.auth_scheme = crate::AuthScheme::QueryParam {
            name: "api-key".to_string(),
        };
        let credentials = credentials_with("secret", HashMap::new());

        let url = definition
            .request_url_with_auth_for_path("responses", &credentials)
            .expect("url should build");

        assert_eq!(
            url.as_str(),
            "https://example.com/v1/responses?api-version=2026&api-key=secret"
        );
    }
}
476}