Skip to main content

mentra_provider/
definition.rs

1use http::HeaderMap;
2use http::HeaderName;
3use http::HeaderValue;
4use http::header;
5use serde::Deserialize;
6use serde::Serialize;
7use std::borrow::Cow;
8use std::collections::HashMap;
9use std::fmt::Display;
10use std::time::Duration;
11use strum::Display as StrumDisplay;
12use strum::IntoStaticStr;
13use url::Url;
14
15use crate::request::SessionRequestOptions;
16
/// Builtin provider families Mentra can construct from presets.
///
/// The strum-derived `Display` and `Into<&'static str>` impls (with
/// `serialize_all = "lowercase"`) yield the variant name lowercased with no
/// separators, e.g. `OpenRouter` -> `"openrouter"`, `LmStudio` -> `"lmstudio"`.
/// `From<BuiltinProvider> for ProviderId` uses that same static string as the id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, StrumDisplay, IntoStaticStr)]
#[strum(serialize_all = "lowercase")]
pub enum BuiltinProvider {
    /// Name: `"anthropic"`.
    Anthropic,
    /// Name: `"gemini"`.
    Gemini,
    /// Name: `"openai"`.
    OpenAI,
    /// Name: `"openrouter"`.
    OpenRouter,
    /// Name: `"ollama"`.
    Ollama,
    /// Name: `"lmstudio"`.
    LmStudio,
}
28
29impl From<BuiltinProvider> for ProviderId {
30    fn from(value: BuiltinProvider) -> Self {
31        Self(Cow::Borrowed(value.into()))
32    }
33}
34
/// Stable identifier for a registered provider implementation.
///
/// The inner `Cow<'static, str>` lets ids for builtin providers borrow a static
/// name (see `From<BuiltinProvider>`), while ids built from runtime strings own
/// their text. The derived comparisons, ordering and hashing look only at the
/// string contents. A borrowed id and an owned id with the same text are
/// therefore equal and hash identically.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ProviderId(Cow<'static, str>);
38
39impl ProviderId {
40    pub fn new(id: impl Into<String>) -> Self {
41        Self(Cow::Owned(id.into()))
42    }
43
44    pub fn as_str(&self) -> &str {
45        self.0.as_ref()
46    }
47}
48
49impl Display for ProviderId {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.write_str(self.as_str())
52    }
53}
54
55impl From<&str> for ProviderId {
56    fn from(value: &str) -> Self {
57        Self::new(value)
58    }
59}
60
61impl From<String> for ProviderId {
62    fn from(value: String) -> Self {
63        Self(Cow::Owned(value))
64    }
65}
66
67impl From<&String> for ProviderId {
68    fn from(value: &String) -> Self {
69        Self::new(value.as_str())
70    }
71}
72
/// Human-facing metadata about a provider.
///
/// `ProviderDescriptor::new` sets only `id`; both optional text fields start
/// as `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDescriptor {
    /// The provider's stable identifier.
    pub id: ProviderId,
    /// Optional name intended for display; `None` when unset.
    pub display_name: Option<String>,
    /// Optional free-form description; `None` when unset.
    pub description: Option<String>,
}
80
81impl ProviderDescriptor {
82    pub fn new(id: impl Into<ProviderId>) -> Self {
83        Self {
84            id: id.into(),
85            display_name: None,
86            description: None,
87        }
88    }
89}
90
/// Capabilities advertised by a provider instance.
///
/// Every flag is `false` in `ProviderCapabilities::default()`. Serde uses that
/// same default when a serialized `ProviderDefinition` omits `capabilities`.
/// `ProviderDefinition::new` starts from a different preset, with model
/// listing, streaming, tool calls and images enabled.
///
/// These flags are declarations. The header and URL builders on
/// `ProviderDefinition` do not read them; any gating happens in callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ProviderCapabilities {
    pub supports_model_listing: bool,
    pub supports_streaming: bool,
    pub supports_websockets: bool,
    pub supports_tool_calls: bool,
    pub supports_images: bool,
    pub supports_history_compaction: bool,
    pub supports_memory_summarization: bool,
    pub supports_deferred_tools: bool,
    pub supports_hosted_tool_search: bool,
    pub supports_hosted_web_search: bool,
    pub supports_image_generation: bool,
    pub supports_reasoning_effort: bool,
    pub reports_reasoning_tokens: bool,
    pub reports_thoughts_tokens: bool,
    pub supports_structured_tool_results: bool,
}
110
111/// Wire protocol supported by a provider.
112#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
113#[serde(rename_all = "lowercase")]
114pub enum WireApi {
115    #[default]
116    Responses,
117    AnthropicMessages,
118    GeminiGenerateContent,
119}
120
121impl Display for WireApi {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        let value = match self {
124            Self::Responses => "responses",
125            Self::AnthropicMessages => "anthropic_messages",
126            Self::GeminiGenerateContent => "gemini_generate_content",
127        };
128        f.write_str(value)
129    }
130}
131
/// Retry configuration for provider transport calls.
///
/// The retry loop that consumes this policy is not in this file. The field
/// notes below describe intent as suggested by the names; confirm them against
/// the transport implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Upper bound on attempts (presumably counting the first try; TODO confirm).
    pub max_attempts: u64,
    /// Starting delay between attempts; any backoff growth is presumably applied
    /// by the consumer.
    pub base_delay: Duration,
    /// Whether rate-limited (HTTP 429) responses are retried.
    pub retry_429: bool,
    /// Whether HTTP 5xx responses are retried.
    pub retry_5xx: bool,
    /// Whether transport-level failures (no HTTP status) are retried.
    pub retry_transport: bool,
}
141
142impl Default for RetryPolicy {
143    fn default() -> Self {
144        Self {
145            max_attempts: 5,
146            base_delay: Duration::from_millis(200),
147            retry_429: false,
148            retry_5xx: true,
149            retry_transport: true,
150        }
151    }
152}
153
/// Serializable provider definition used by runtime and adapter layers.
///
/// Only `descriptor` is required when deserializing. Every other field falls
/// back to a default when absent; `base_url`, being an `Option`, becomes `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDefinition {
    /// Identifier and human-facing metadata.
    pub descriptor: ProviderDescriptor,
    /// Protocol spoken on the wire; `WireApi::Responses` when omitted.
    #[serde(default)]
    pub wire_api: WireApi,
    /// How the credential is attached:
    /// - as a request header (`BearerToken`, `Header`);
    /// - as a URL query parameter (`QueryParam`);
    /// - not at all (`None`).
    #[serde(default)]
    pub auth_scheme: crate::AuthScheme,
    /// Advertised features. When omitted, all flags are `false`
    /// (`ProviderCapabilities::default()`), which differs from the preset used
    /// by `ProviderDefinition::new`.
    #[serde(default)]
    pub capabilities: ProviderCapabilities,
    /// Root URL that request paths are joined onto; treated as `""` when `None`.
    pub base_url: Option<String>,
    /// Extra query parameters appended to every URL built from `base_url`.
    #[serde(default)]
    pub query_params: Option<HashMap<String, String>>,
    /// Static headers applied first when building request headers; credential
    /// headers and the auth header replace entries with the same name.
    #[serde(default)]
    pub headers: Option<HashMap<String, String>>,
    /// Retry policy for transport calls; `RetryPolicy::default()` when omitted.
    #[serde(default)]
    pub retry: RetryPolicy,
    /// Defaults to 300 s when omitted (presumably the maximum silence allowed
    /// on a response stream; enforced outside this file).
    #[serde(default = "default_stream_idle_timeout")]
    pub stream_idle_timeout: Duration,
    /// Defaults to 15 s when omitted (presumably bounds websocket connection
    /// setup; enforced outside this file).
    #[serde(default = "default_websocket_connect_timeout")]
    pub websocket_connect_timeout: Duration,
}
176
/// Five-minute default for `ProviderDefinition::stream_idle_timeout`, used both
/// by serde when the field is missing and by `ProviderDefinition::new`.
fn default_stream_idle_timeout() -> Duration {
    const FIVE_MINUTES_SECS: u64 = 5 * 60;
    Duration::from_secs(FIVE_MINUTES_SECS)
}
180
/// Fifteen-second default for `ProviderDefinition::websocket_connect_timeout`,
/// used both by serde when the field is missing and by `ProviderDefinition::new`.
fn default_websocket_connect_timeout() -> Duration {
    Duration::from_secs(15)
}
184
185impl ProviderDefinition {
186    pub fn new(id: impl Into<ProviderId>) -> Self {
187        Self {
188            descriptor: ProviderDescriptor::new(id),
189            wire_api: WireApi::default(),
190            auth_scheme: crate::AuthScheme::default(),
191            capabilities: ProviderCapabilities {
192                supports_model_listing: true,
193                supports_streaming: true,
194                supports_websockets: false,
195                supports_tool_calls: true,
196                supports_images: true,
197                supports_history_compaction: false,
198                supports_memory_summarization: false,
199                supports_deferred_tools: false,
200                supports_hosted_tool_search: false,
201                supports_hosted_web_search: false,
202                supports_image_generation: false,
203                supports_reasoning_effort: false,
204                reports_reasoning_tokens: false,
205                reports_thoughts_tokens: false,
206                supports_structured_tool_results: false,
207            },
208            base_url: None,
209            query_params: None,
210            headers: None,
211            retry: RetryPolicy::default(),
212            stream_idle_timeout: default_stream_idle_timeout(),
213            websocket_connect_timeout: default_websocket_connect_timeout(),
214        }
215    }
216
217    pub fn descriptor(&self) -> ProviderDescriptor {
218        self.descriptor.clone()
219    }
220
221    pub fn provider_id(&self) -> &ProviderId {
222        &self.descriptor.id
223    }
224
225    pub fn url_for_path(&self, path: &str) -> String {
226        let base = self
227            .base_url
228            .as_deref()
229            .unwrap_or_default()
230            .trim_end_matches('/');
231        let path = path.trim_start_matches('/');
232        let mut url = if path.is_empty() {
233            base.to_string()
234        } else {
235            format!("{base}/{path}")
236        };
237
238        if let Some(params) = &self.query_params
239            && !params.is_empty()
240        {
241            let qs = params
242                .iter()
243                .map(|(key, value)| format!("{key}={value}"))
244                .collect::<Vec<_>>()
245                .join("&");
246            url.push('?');
247            url.push_str(&qs);
248        }
249
250        url
251    }
252
253    pub fn build_headers(
254        &self,
255        credentials: &crate::ProviderCredentials,
256    ) -> Result<HeaderMap, crate::ProviderError> {
257        let mut headers = HeaderMap::new();
258
259        if let Some(configured_headers) = &self.headers {
260            for (name, value) in configured_headers {
261                insert_header(&mut headers, name, value)?;
262            }
263        }
264
265        for (name, value) in &credentials.headers {
266            insert_header(&mut headers, name, value)?;
267        }
268
269        match &self.auth_scheme {
270            crate::AuthScheme::None | crate::AuthScheme::QueryParam { .. } => {}
271            crate::AuthScheme::BearerToken => {
272                let token = required_auth_value(credentials)?;
273                let auth_value =
274                    HeaderValue::from_str(&format!("Bearer {token}")).map_err(|error| {
275                        crate::ProviderError::InvalidRequest(format!(
276                            "invalid bearer token header: {error}"
277                        ))
278                    })?;
279                headers.insert(header::AUTHORIZATION, auth_value);
280            }
281            crate::AuthScheme::Header { name } => {
282                let token = required_auth_value(credentials)?;
283                insert_header(&mut headers, name, token)?;
284            }
285        }
286
287        Ok(headers)
288    }
289
290    pub fn build_headers_for_session(
291        &self,
292        credentials: &crate::ProviderCredentials,
293        session: Option<&SessionRequestOptions>,
294        fallback_turn_state: Option<&str>,
295    ) -> Result<HeaderMap, crate::ProviderError> {
296        let mut headers = self.build_headers(credentials)?;
297
298        if let Some(turn_state) = session
299            .and_then(|session| session.sticky_turn_state.as_deref())
300            .or(fallback_turn_state)
301            && let Ok(value) = HeaderValue::from_str(turn_state)
302        {
303            headers.insert("x-mentra-turn-state", value.clone());
304            headers.insert("x-codex-turn-state", value);
305        }
306        if let Some(value) = session.and_then(|session| session.turn_metadata.as_deref())
307            && let Ok(value) = HeaderValue::from_str(value)
308        {
309            headers.insert("x-mentra-turn-metadata", value.clone());
310            headers.insert("x-codex-turn-metadata", value);
311        }
312        if let Some(value) = session.and_then(|session| session.session_affinity.as_deref())
313            && let Ok(value) = HeaderValue::from_str(value)
314        {
315            headers.insert("x-mentra-session-affinity", value);
316        }
317        if let Some(prefer_connection_reuse) =
318            session.and_then(|session| session.prefer_connection_reuse)
319        {
320            headers.insert(
321                "x-mentra-connection-reuse",
322                HeaderValue::from_static(if prefer_connection_reuse {
323                    "prefer-reuse"
324                } else {
325                    "prefer-fresh"
326                }),
327            );
328        }
329        if let Some(value) = session.and_then(|session| session.subagent.as_deref())
330            && let Ok(value) = HeaderValue::from_str(value)
331        {
332            headers.insert("x-openai-subagent", value);
333        }
334        if let Some(extra_headers) = session.map(|session| &session.extra_headers) {
335            for (name, value) in extra_headers {
336                if let (Ok(name), Ok(value)) = (
337                    name.parse::<http::HeaderName>(),
338                    HeaderValue::from_str(value),
339                ) {
340                    headers.insert(name, value);
341                }
342            }
343        }
344
345        Ok(headers)
346    }
347
348    pub fn request_url_with_auth_for_path(
349        &self,
350        path: &str,
351        credentials: &crate::ProviderCredentials,
352    ) -> Result<Url, crate::ProviderError> {
353        let mut url = Url::parse(&self.url_for_path(path))
354            .map_err(|error| crate::ProviderError::InvalidRequest(error.to_string()))?;
355
356        if let crate::AuthScheme::QueryParam { name } = &self.auth_scheme {
357            let token = required_auth_value(credentials)?;
358            url.query_pairs_mut().append_pair(name, token);
359        }
360
361        Ok(url)
362    }
363
364    pub fn websocket_url_for_path(&self, path: &str) -> Result<Url, url::ParseError> {
365        let mut url = Url::parse(&self.url_for_path(path))?;
366
367        let scheme = match url.scheme() {
368            "http" => "ws",
369            "https" => "wss",
370            "ws" | "wss" => return Ok(url),
371            _ => return Ok(url),
372        };
373        let _ = url.set_scheme(scheme);
374        Ok(url)
375    }
376
377    pub fn websocket_url_with_auth_for_path(
378        &self,
379        path: &str,
380        credentials: &crate::ProviderCredentials,
381    ) -> Result<Url, crate::ProviderError> {
382        let mut url = self
383            .websocket_url_for_path(path)
384            .map_err(|error| crate::ProviderError::InvalidRequest(error.to_string()))?;
385
386        if let crate::AuthScheme::QueryParam { name } = &self.auth_scheme {
387            let token = required_auth_value(credentials)?;
388            url.query_pairs_mut().append_pair(name, token);
389        }
390
391        Ok(url)
392    }
393}
394
395fn insert_header(
396    headers: &mut HeaderMap,
397    name: &str,
398    value: &str,
399) -> Result<(), crate::ProviderError> {
400    let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
401        crate::ProviderError::InvalidRequest(format!(
402            "invalid provider header name {name:?}: {error}"
403        ))
404    })?;
405    let header_value = HeaderValue::from_str(value).map_err(|error| {
406        crate::ProviderError::InvalidRequest(format!(
407            "invalid provider header value for {name:?}: {error}"
408        ))
409    })?;
410    headers.insert(header_name, header_value);
411    Ok(())
412}
413
414fn required_auth_value(
415    credentials: &crate::ProviderCredentials,
416) -> Result<&str, crate::ProviderError> {
417    credentials.bearer_token.as_deref().ok_or_else(|| {
418        crate::ProviderError::InvalidRequest("missing provider auth credential".to_string())
419    })
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Credentials carrying `token` plus the given runtime headers.
    fn credentials_with(
        token: &str,
        runtime_headers: &[(&str, &str)],
    ) -> crate::ProviderCredentials {
        crate::ProviderCredentials {
            bearer_token: Some(token.to_string()),
            account_id: None,
            headers: runtime_headers
                .iter()
                .map(|&(name, value)| (name.to_string(), value.to_string()))
                .collect(),
        }
    }

    #[test]
    fn build_headers_applies_bearer_auth_and_static_headers() {
        let definition = ProviderDefinition {
            auth_scheme: crate::AuthScheme::BearerToken,
            headers: Some(HashMap::from([(
                "x-provider-header".to_string(),
                "static".to_string(),
            )])),
            ..ProviderDefinition::new("test")
        };

        let built = definition
            .build_headers(&credentials_with(
                "secret",
                &[("x-runtime-header", "dynamic")],
            ))
            .expect("headers should build");

        let expected_plain = [
            ("x-provider-header", "static"),
            ("x-runtime-header", "dynamic"),
        ];
        for (name, expected) in expected_plain {
            assert_eq!(built[name], expected);
        }
        assert_eq!(built[header::AUTHORIZATION], "Bearer secret");
    }

    #[test]
    fn request_url_with_auth_appends_query_param_auth() {
        let definition = ProviderDefinition {
            base_url: Some("https://example.com/v1".to_string()),
            query_params: Some(HashMap::from([(
                "api-version".to_string(),
                "2026".to_string(),
            )])),
            auth_scheme: crate::AuthScheme::QueryParam {
                name: "api-key".to_string(),
            },
            ..ProviderDefinition::new("test")
        };

        let url = definition
            .request_url_with_auth_for_path("responses", &credentials_with("secret", &[]))
            .expect("url should build");

        assert_eq!(
            url.as_str(),
            "https://example.com/v1/responses?api-version=2026&api-key=secret"
        );
    }
}