Skip to main content

mentra_provider/
definition.rs

1use http::HeaderMap;
2use http::HeaderName;
3use http::HeaderValue;
4use http::header;
5use serde::Deserialize;
6use serde::Serialize;
7use std::borrow::Cow;
8use std::collections::HashMap;
9use std::fmt::Display;
10use std::time::Duration;
11use strum::Display as StrumDisplay;
12use strum::IntoStaticStr;
13use url::Url;
14
15use crate::request::SessionRequestOptions;
16
17/// Builtin provider families Mentra can construct from presets.
18#[derive(Debug, Clone, Copy, PartialEq, Eq, StrumDisplay, IntoStaticStr)]
19#[strum(serialize_all = "lowercase")]
20pub enum BuiltinProvider {
21    Anthropic,
22    Gemini,
23    OpenAI,
24    OpenRouter,
25    Ollama,
26    LmStudio,
27}
28
29impl From<BuiltinProvider> for ProviderId {
30    fn from(value: BuiltinProvider) -> Self {
31        Self(Cow::Borrowed(value.into()))
32    }
33}
34
35/// Stable identifier for a registered provider implementation.
36#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
37pub struct ProviderId(Cow<'static, str>);
38
39impl ProviderId {
40    pub fn new(id: impl Into<String>) -> Self {
41        Self(Cow::Owned(id.into()))
42    }
43
44    pub fn as_str(&self) -> &str {
45        self.0.as_ref()
46    }
47}
48
49impl Display for ProviderId {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.write_str(self.as_str())
52    }
53}
54
55impl From<&str> for ProviderId {
56    fn from(value: &str) -> Self {
57        Self::new(value)
58    }
59}
60
61impl From<String> for ProviderId {
62    fn from(value: String) -> Self {
63        Self(Cow::Owned(value))
64    }
65}
66
67impl From<&String> for ProviderId {
68    fn from(value: &String) -> Self {
69        Self::new(value.as_str())
70    }
71}
72
73/// Human-facing metadata about a provider.
74#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
75pub struct ProviderDescriptor {
76    pub id: ProviderId,
77    pub display_name: Option<String>,
78    pub description: Option<String>,
79}
80
81impl ProviderDescriptor {
82    pub fn new(id: impl Into<ProviderId>) -> Self {
83        Self {
84            id: id.into(),
85            display_name: None,
86            description: None,
87        }
88    }
89}
90
91/// Capabilities advertised by a provider instance.
92#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
93pub struct ProviderCapabilities {
94    pub supports_model_listing: bool,
95    pub supports_streaming: bool,
96    pub supports_websockets: bool,
97    pub supports_tool_calls: bool,
98    pub supports_images: bool,
99    pub supports_history_compaction: bool,
100    pub supports_memory_summarization: bool,
101    pub supports_deferred_tools: bool,
102    pub supports_hosted_tool_search: bool,
103    pub supports_hosted_web_search: bool,
104    pub supports_image_generation: bool,
105    pub supports_reasoning_effort: bool,
106    pub reports_reasoning_tokens: bool,
107    pub reports_thoughts_tokens: bool,
108    pub supports_structured_tool_results: bool,
109    pub supports_embeddings: bool,
110}
111
112/// Wire protocol supported by a provider.
113#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
114#[serde(rename_all = "lowercase")]
115pub enum WireApi {
116    #[default]
117    Responses,
118    AnthropicMessages,
119    GeminiGenerateContent,
120}
121
122impl Display for WireApi {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        let value = match self {
125            Self::Responses => "responses",
126            Self::AnthropicMessages => "anthropic_messages",
127            Self::GeminiGenerateContent => "gemini_generate_content",
128        };
129        f.write_str(value)
130    }
131}
132
133/// Retry configuration for provider transport calls.
134#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
135pub struct RetryPolicy {
136    pub max_attempts: u64,
137    pub base_delay: Duration,
138    pub retry_429: bool,
139    pub retry_5xx: bool,
140    pub retry_transport: bool,
141}
142
143impl Default for RetryPolicy {
144    fn default() -> Self {
145        Self {
146            max_attempts: 5,
147            base_delay: Duration::from_millis(200),
148            retry_429: false,
149            retry_5xx: true,
150            retry_transport: true,
151        }
152    }
153}
154
155/// Serializable provider definition used by runtime and adapter layers.
156#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
157pub struct ProviderDefinition {
158    pub descriptor: ProviderDescriptor,
159    #[serde(default)]
160    pub wire_api: WireApi,
161    #[serde(default)]
162    pub auth_scheme: crate::AuthScheme,
163    #[serde(default)]
164    pub capabilities: ProviderCapabilities,
165    pub base_url: Option<String>,
166    #[serde(default)]
167    pub query_params: Option<HashMap<String, String>>,
168    #[serde(default)]
169    pub headers: Option<HashMap<String, String>>,
170    #[serde(default)]
171    pub retry: RetryPolicy,
172    #[serde(default = "default_stream_idle_timeout")]
173    pub stream_idle_timeout: Duration,
174    #[serde(default = "default_websocket_connect_timeout")]
175    pub websocket_connect_timeout: Duration,
176}
177
/// Default stream idle timeout: five minutes.
fn default_stream_idle_timeout() -> Duration {
    Duration::from_secs(300)
}
181
/// Default websocket connect timeout: fifteen seconds.
fn default_websocket_connect_timeout() -> Duration {
    Duration::from_secs(15)
}
185
186impl ProviderDefinition {
187    pub fn new(id: impl Into<ProviderId>) -> Self {
188        Self {
189            descriptor: ProviderDescriptor::new(id),
190            wire_api: WireApi::default(),
191            auth_scheme: crate::AuthScheme::default(),
192            capabilities: ProviderCapabilities {
193                supports_model_listing: true,
194                supports_streaming: true,
195                supports_websockets: false,
196                supports_tool_calls: true,
197                supports_images: true,
198                supports_history_compaction: false,
199                supports_memory_summarization: false,
200                supports_deferred_tools: false,
201                supports_hosted_tool_search: false,
202                supports_hosted_web_search: false,
203                supports_image_generation: false,
204                supports_reasoning_effort: false,
205                reports_reasoning_tokens: false,
206                reports_thoughts_tokens: false,
207                supports_structured_tool_results: false,
208                supports_embeddings: false,
209            },
210            base_url: None,
211            query_params: None,
212            headers: None,
213            retry: RetryPolicy::default(),
214            stream_idle_timeout: default_stream_idle_timeout(),
215            websocket_connect_timeout: default_websocket_connect_timeout(),
216        }
217    }
218
219    pub fn descriptor(&self) -> ProviderDescriptor {
220        self.descriptor.clone()
221    }
222
223    pub fn provider_id(&self) -> &ProviderId {
224        &self.descriptor.id
225    }
226
227    pub fn url_for_path(&self, path: &str) -> String {
228        let base = self
229            .base_url
230            .as_deref()
231            .unwrap_or_default()
232            .trim_end_matches('/');
233        let path = path.trim_start_matches('/');
234        let mut url = if path.is_empty() {
235            base.to_string()
236        } else {
237            format!("{base}/{path}")
238        };
239
240        if let Some(params) = &self.query_params
241            && !params.is_empty()
242        {
243            let qs = params
244                .iter()
245                .map(|(key, value)| format!("{key}={value}"))
246                .collect::<Vec<_>>()
247                .join("&");
248            url.push('?');
249            url.push_str(&qs);
250        }
251
252        url
253    }
254
255    pub fn build_headers(
256        &self,
257        credentials: &crate::ProviderCredentials,
258    ) -> Result<HeaderMap, crate::ProviderError> {
259        let mut headers = HeaderMap::new();
260
261        if let Some(configured_headers) = &self.headers {
262            for (name, value) in configured_headers {
263                insert_header(&mut headers, name, value)?;
264            }
265        }
266
267        for (name, value) in &credentials.headers {
268            insert_header(&mut headers, name, value)?;
269        }
270
271        match &self.auth_scheme {
272            crate::AuthScheme::None | crate::AuthScheme::QueryParam { .. } => {}
273            crate::AuthScheme::BearerToken => {
274                let token = required_auth_value(credentials)?;
275                let auth_value =
276                    HeaderValue::from_str(&format!("Bearer {token}")).map_err(|error| {
277                        crate::ProviderError::InvalidRequest(format!(
278                            "invalid bearer token header: {error}"
279                        ))
280                    })?;
281                headers.insert(header::AUTHORIZATION, auth_value);
282            }
283            crate::AuthScheme::Header { name } => {
284                let token = required_auth_value(credentials)?;
285                insert_header(&mut headers, name, token)?;
286            }
287        }
288
289        Ok(headers)
290    }
291
292    pub fn build_headers_for_session(
293        &self,
294        credentials: &crate::ProviderCredentials,
295        session: Option<&SessionRequestOptions>,
296        fallback_turn_state: Option<&str>,
297    ) -> Result<HeaderMap, crate::ProviderError> {
298        let mut headers = self.build_headers(credentials)?;
299
300        if let Some(turn_state) = session
301            .and_then(|session| session.sticky_turn_state.as_deref())
302            .or(fallback_turn_state)
303            && let Ok(value) = HeaderValue::from_str(turn_state)
304        {
305            headers.insert("x-mentra-turn-state", value.clone());
306            headers.insert("x-codex-turn-state", value);
307        }
308        if let Some(value) = session.and_then(|session| session.turn_metadata.as_deref())
309            && let Ok(value) = HeaderValue::from_str(value)
310        {
311            headers.insert("x-mentra-turn-metadata", value.clone());
312            headers.insert("x-codex-turn-metadata", value);
313        }
314        if let Some(value) = session.and_then(|session| session.session_affinity.as_deref())
315            && let Ok(value) = HeaderValue::from_str(value)
316        {
317            headers.insert("x-mentra-session-affinity", value);
318        }
319        if let Some(prefer_connection_reuse) =
320            session.and_then(|session| session.prefer_connection_reuse)
321        {
322            headers.insert(
323                "x-mentra-connection-reuse",
324                HeaderValue::from_static(if prefer_connection_reuse {
325                    "prefer-reuse"
326                } else {
327                    "prefer-fresh"
328                }),
329            );
330        }
331        if let Some(value) = session.and_then(|session| session.subagent.as_deref())
332            && let Ok(value) = HeaderValue::from_str(value)
333        {
334            headers.insert("x-openai-subagent", value);
335        }
336        if let Some(extra_headers) = session.map(|session| &session.extra_headers) {
337            for (name, value) in extra_headers {
338                if let (Ok(name), Ok(value)) = (
339                    name.parse::<http::HeaderName>(),
340                    HeaderValue::from_str(value),
341                ) {
342                    headers.insert(name, value);
343                }
344            }
345        }
346
347        Ok(headers)
348    }
349
350    pub fn request_url_with_auth_for_path(
351        &self,
352        path: &str,
353        credentials: &crate::ProviderCredentials,
354    ) -> Result<Url, crate::ProviderError> {
355        let mut url = Url::parse(&self.url_for_path(path))
356            .map_err(|error| crate::ProviderError::InvalidRequest(error.to_string()))?;
357
358        if let crate::AuthScheme::QueryParam { name } = &self.auth_scheme {
359            let token = required_auth_value(credentials)?;
360            url.query_pairs_mut().append_pair(name, token);
361        }
362
363        Ok(url)
364    }
365
366    pub fn websocket_url_for_path(&self, path: &str) -> Result<Url, url::ParseError> {
367        let mut url = Url::parse(&self.url_for_path(path))?;
368
369        let scheme = match url.scheme() {
370            "http" => "ws",
371            "https" => "wss",
372            "ws" | "wss" => return Ok(url),
373            _ => return Ok(url),
374        };
375        let _ = url.set_scheme(scheme);
376        Ok(url)
377    }
378
379    pub fn websocket_url_with_auth_for_path(
380        &self,
381        path: &str,
382        credentials: &crate::ProviderCredentials,
383    ) -> Result<Url, crate::ProviderError> {
384        let mut url = self
385            .websocket_url_for_path(path)
386            .map_err(|error| crate::ProviderError::InvalidRequest(error.to_string()))?;
387
388        if let crate::AuthScheme::QueryParam { name } = &self.auth_scheme {
389            let token = required_auth_value(credentials)?;
390            url.query_pairs_mut().append_pair(name, token);
391        }
392
393        Ok(url)
394    }
395}
396
397fn insert_header(
398    headers: &mut HeaderMap,
399    name: &str,
400    value: &str,
401) -> Result<(), crate::ProviderError> {
402    let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
403        crate::ProviderError::InvalidRequest(format!(
404            "invalid provider header name {name:?}: {error}"
405        ))
406    })?;
407    let header_value = HeaderValue::from_str(value).map_err(|error| {
408        crate::ProviderError::InvalidRequest(format!(
409            "invalid provider header value for {name:?}: {error}"
410        ))
411    })?;
412    headers.insert(header_name, header_value);
413    Ok(())
414}
415
416fn required_auth_value(
417    credentials: &crate::ProviderCredentials,
418) -> Result<&str, crate::ProviderError> {
419    credentials.bearer_token.as_deref().ok_or_else(|| {
420        crate::ProviderError::InvalidRequest("missing provider auth credential".to_string())
421    })
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_headers_applies_bearer_auth_and_static_headers() {
        let mut definition = ProviderDefinition::new("test");
        definition.auth_scheme = crate::AuthScheme::BearerToken;
        definition.headers = Some(HashMap::from([(
            "x-provider-header".to_string(),
            "static".to_string(),
        )]));

        let headers = definition
            .build_headers(&crate::ProviderCredentials {
                bearer_token: Some("secret".to_string()),
                account_id: None,
                headers: HashMap::from([("x-runtime-header".to_string(), "dynamic".to_string())]),
            })
            .expect("headers should build");

        assert_eq!(headers["x-provider-header"], "static");
        assert_eq!(headers["x-runtime-header"], "dynamic");
        assert_eq!(headers[header::AUTHORIZATION], "Bearer secret");
    }

    #[test]
    fn build_headers_rejects_missing_bearer_token() {
        let mut definition = ProviderDefinition::new("test");
        definition.auth_scheme = crate::AuthScheme::BearerToken;

        let error = definition
            .build_headers(&crate::ProviderCredentials {
                bearer_token: None,
                account_id: None,
                headers: HashMap::new(),
            })
            .expect_err("a missing bearer token should be rejected");

        assert!(matches!(error, crate::ProviderError::InvalidRequest(_)));
    }

    #[test]
    fn session_headers_fall_back_to_provided_turn_state() {
        let mut definition = ProviderDefinition::new("test");
        definition.auth_scheme = crate::AuthScheme::None;

        let headers = definition
            .build_headers_for_session(
                &crate::ProviderCredentials {
                    bearer_token: None,
                    account_id: None,
                    headers: HashMap::new(),
                },
                None,
                Some("turn-1"),
            )
            .expect("headers should build");

        assert_eq!(headers["x-mentra-turn-state"], "turn-1");
        assert_eq!(headers["x-codex-turn-state"], "turn-1");
    }

    #[test]
    fn request_url_with_auth_appends_query_param_auth() {
        let mut definition = ProviderDefinition::new("test");
        definition.base_url = Some("https://example.com/v1".to_string());
        definition.query_params = Some(HashMap::from([(
            "api-version".to_string(),
            "2026".to_string(),
        )]));
        definition.auth_scheme = crate::AuthScheme::QueryParam {
            name: "api-key".to_string(),
        };

        let url = definition
            .request_url_with_auth_for_path(
                "responses",
                &crate::ProviderCredentials {
                    bearer_token: Some("secret".to_string()),
                    account_id: None,
                    headers: HashMap::new(),
                },
            )
            .expect("url should build");

        assert_eq!(
            url.as_str(),
            "https://example.com/v1/responses?api-version=2026&api-key=secret"
        );
    }

    #[test]
    fn websocket_url_maps_http_schemes_to_websocket_schemes() {
        let mut definition = ProviderDefinition::new("test");

        definition.base_url = Some("https://example.com/v1/".to_string());
        let url = definition
            .websocket_url_for_path("/responses")
            .expect("url should parse");
        assert_eq!(url.as_str(), "wss://example.com/v1/responses");

        definition.base_url = Some("http://localhost:1234".to_string());
        let url = definition
            .websocket_url_for_path("v1/realtime")
            .expect("url should parse");
        assert_eq!(url.as_str(), "ws://localhost:1234/v1/realtime");
    }
}