mentra 0.6.0

An agent runtime for tool-using LLM applications
Documentation
use std::{collections::BTreeMap, net::SocketAddr, time::Duration as StdDuration};

use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use rand::RngCore;
use reqwest::StatusCode;
use ring::digest::{SHA256, digest};
use serde::{Deserialize, Serialize};
use time::{Duration, OffsetDateTime};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::TcpListener,
    time::timeout,
};
use url::Url;

/// OAuth client identifier used by [`OpenAIOAuthClient::default`].
pub const DEFAULT_CLIENT_ID: &str = "T19P7LJMcLZgUbhBzA85goHf";
/// Authorization endpoint that [`OpenAIOAuthClient::new`] parses into its authorize URL.
pub const DEFAULT_AUTH_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint that [`OpenAIOAuthClient::new`] parses into its token URL.
pub const DEFAULT_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Space-separated scopes sent as the `scope` query parameter of the authorize URL.
pub const DEFAULT_SCOPE: &str = "openid profile email offline_access";

/// Path component appended to the loopback address when building the redirect URI.
const CALLBACK_PATH: &str = "/callback";
/// Maximum time (300 seconds) to wait for the browser to hit the loopback listener.
const CALLBACK_TIMEOUT: StdDuration = StdDuration::from_secs(300);

#[derive(Debug, thiserror::Error)]
pub enum OpenAIOAuthError {
    #[error("http transport error: {0}")]
    Transport(#[from] reqwest::Error),
    #[error("failed to parse url: {0}")]
    Url(#[from] url::ParseError),
    #[error("failed to encode or decode json: {0}")]
    Json(#[from] serde_json::Error),
    #[error("failed to bind loopback callback listener: {0}")]
    Io(#[from] std::io::Error),
    #[error("oauth endpoint returned HTTP {status}: {body}")]
    Http { status: StatusCode, body: String },
    #[error("callback timed out waiting for browser redirect")]
    CallbackTimeout,
    #[error("callback did not include an authorization code")]
    MissingCode,
    #[error("callback state mismatch")]
    StateMismatch,
    #[error("oauth endpoint did not return an API key")]
    MissingApiKey,
    #[error("no stored OAuth tokens found")]
    MissingStoredTokens,
    #[error("token store is unsupported on this platform: {0}")]
    UnsupportedStore(&'static str),
    #[error("credential store command `{command}` failed with status {status}: {stderr}")]
    CredentialCommand {
        command: &'static str,
        status: i32,
        stderr: String,
    },
}

/// Tokens obtained from the OAuth token endpoint, together with the absolute
/// instant at which they expire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAITokenSet {
    /// Access token returned by the token endpoint.
    pub access_token: String,
    /// Refresh token returned by the token endpoint.
    pub refresh_token: String,
    /// ID token, when the endpoint returned one.
    pub id_token: Option<String>,
    /// API key, when the endpoint returned one; see `require_api_key`.
    pub api_key: Option<String>,
    /// Absolute expiry instant, compared against the current UTC time by `is_expired`.
    pub expires_at: OffsetDateTime,
}

impl OpenAITokenSet {
    /// Reports whether the tokens should be treated as expired.
    ///
    /// Returns `true` when `expires_at` is at or before the current UTC time
    /// moved forward by `refresh_skew`, so a positive skew makes the tokens
    /// count as expired that much earlier.
    pub fn is_expired(&self, refresh_skew: Duration) -> bool {
        let refresh_deadline = OffsetDateTime::now_utc() + refresh_skew;
        refresh_deadline >= self.expires_at
    }

    /// Borrows the API key held by this token set.
    ///
    /// # Errors
    ///
    /// Returns [`OpenAIOAuthError::MissingApiKey`] when `api_key` is `None`.
    pub fn require_api_key(&self) -> Result<&str, OpenAIOAuthError> {
        match &self.api_key {
            Some(key) => Ok(key.as_str()),
            None => Err(OpenAIOAuthError::MissingApiKey),
        }
    }
}

/// JSON body of a successful token endpoint response, as deserialized by
/// `parse_token_response`.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Access token; required in the response.
    access_token: String,
    /// Refresh token; required in the response.
    refresh_token: String,
    /// ID token; `None` when the field is absent.
    #[serde(default)]
    id_token: Option<String>,
    /// API key; `None` when the field is absent.
    #[serde(default)]
    api_key: Option<String>,
    /// Token lifetime in seconds, relative to the moment the response is converted.
    // NOTE(review): RFC 6749 names this field `expires_in`; confirm the
    // endpoint really emits `expires_in_seconds`, otherwise deserialization
    // fails with a missing-field error.
    expires_in_seconds: i64,
}

impl TokenResponse {
    /// Converts the wire response into an [`OpenAITokenSet`].
    ///
    /// The relative lifetime `expires_in_seconds` is anchored to the current
    /// UTC time to produce the absolute `expires_at`.
    fn into_tokens(self) -> OpenAITokenSet {
        let Self {
            access_token,
            refresh_token,
            id_token,
            api_key,
            expires_in_seconds,
        } = self;

        // The lifetime comes from the server. A plain `+` panics when the sum
        // leaves the representable date range, so clamp instead of letting a
        // malformed response crash the caller.
        let expires_at =
            OffsetDateTime::now_utc().saturating_add(Duration::seconds(expires_in_seconds));

        OpenAITokenSet {
            access_token,
            refresh_token,
            id_token,
            api_key,
            expires_at,
        }
    }
}

/// Client for the OAuth authorization-code flow with PKCE and for token refresh.
#[derive(Debug, Clone)]
pub struct OpenAIOAuthClient {
    /// HTTP client used for token endpoint requests.
    client: reqwest::Client,
    /// OAuth client identifier sent with every authorize and token request.
    client_id: String,
    /// Base URL the authorization query parameters are appended to.
    auth_url: Url,
    /// URL that code-exchange and refresh requests are POSTed to.
    token_url: Url,
}

impl Default for OpenAIOAuthClient {
    fn default() -> Self {
        Self::new(DEFAULT_CLIENT_ID)
    }
}

impl OpenAIOAuthClient {
    /// Builds a client for `client_id` that targets [`DEFAULT_AUTH_URL`] and
    /// [`DEFAULT_TOKEN_URL`].
    ///
    /// # Panics
    ///
    /// Panics if the `reqwest` client cannot be built or if either default
    /// URL fails to parse.
    pub fn new(client_id: impl Into<String>) -> Self {
        let client = reqwest::Client::builder()
            .build()
            .expect("Failed to build OpenAI OAuth client");
        let client_id = client_id.into();
        let auth_url = Url::parse(DEFAULT_AUTH_URL).expect("Failed to parse auth url");
        let token_url = Url::parse(DEFAULT_TOKEN_URL).expect("Failed to parse token url");

        Self {
            client,
            client_id,
            auth_url,
            token_url,
        }
    }

    /// Starts an authorization attempt.
    ///
    /// Binds a listener on `127.0.0.1` with an OS-assigned port, generates a
    /// fresh PKCE verifier and `state` value, and assembles the authorize
    /// URL. The returned [`PendingAuthorization`] owns the listener.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound, its local address
    /// cannot be read, or the redirect URI fails to parse.
    pub async fn start_authorization(&self) -> Result<PendingAuthorization, OpenAIOAuthError> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let redirect_uri = loopback_redirect_uri(listener.local_addr()?)?;

        let code_verifier = random_base64_url(32);
        let code_challenge = pkce_s256(&code_verifier);
        let state = random_base64_url(32);

        let mut authorize_url = self.auth_url.clone();
        {
            // The serializer writes the query back into the URL when it is
            // dropped at the end of this scope.
            let mut query = authorize_url.query_pairs_mut();
            query.append_pair("response_type", "code");
            query.append_pair("code_challenge_method", "S256");
            query.append_pair("client_id", &self.client_id);
            query.append_pair("redirect_uri", redirect_uri.as_str());
            query.append_pair("code_challenge", &code_challenge);
            query.append_pair("scope", DEFAULT_SCOPE);
            query.append_pair("state", &state);
        }

        Ok(PendingAuthorization {
            authorize_url,
            redirect_uri,
            state,
            code_verifier,
            listener,
        })
    }

    /// Exchanges an authorization `code` for tokens.
    ///
    /// `redirect_uri` and `code_verifier` are sent alongside the code in an
    /// `authorization_code` grant request.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails, the endpoint answers with a
    /// non-success status, or the response body cannot be decoded.
    pub async fn exchange_code(
        &self,
        code: &str,
        redirect_uri: &Url,
        code_verifier: &str,
    ) -> Result<OpenAITokenSet, OpenAIOAuthError> {
        let form = BTreeMap::from([
            ("grant_type", "authorization_code"),
            ("client_id", self.client_id.as_str()),
            ("redirect_uri", redirect_uri.as_str()),
            ("code", code),
            ("code_verifier", code_verifier),
        ]);

        self.post_token_form(&form).await
    }

    /// Obtains a new token set using `refresh_token` in a `refresh_token`
    /// grant request.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails, the endpoint answers with a
    /// non-success status, or the response body cannot be decoded.
    pub async fn refresh_tokens(
        &self,
        refresh_token: &str,
    ) -> Result<OpenAITokenSet, OpenAIOAuthError> {
        let form = BTreeMap::from([
            ("grant_type", "refresh_token"),
            ("client_id", self.client_id.as_str()),
            ("refresh_token", refresh_token),
        ]);

        self.post_token_form(&form).await
    }

    /// POSTs `form` as a URL-encoded body to the token URL and parses the reply.
    async fn post_token_form(
        &self,
        form: &BTreeMap<&str, &str>,
    ) -> Result<OpenAITokenSet, OpenAIOAuthError> {
        let endpoint = self.token_url.clone();
        let response = self.client.post(endpoint).form(form).send().await?;

        parse_token_response(response).await
    }
}

/// An authorization attempt that has been started but not yet completed.
///
/// Holds the loopback listener together with the values needed to validate
/// the callback and exchange the resulting code.
pub struct PendingAuthorization {
    /// Full authorize URL, including all query parameters.
    authorize_url: Url,
    /// Loopback redirect URI sent in the authorize and token requests.
    redirect_uri: Url,
    /// Random value the callback's `state` parameter must equal.
    state: String,
    /// PKCE code verifier sent when exchanging the code.
    code_verifier: String,
    /// Listener that receives the browser redirect.
    listener: TcpListener,
}

impl PendingAuthorization {
    /// URL to open in the browser to begin authorization.
    pub fn authorize_url(&self) -> &Url {
        &self.authorize_url
    }

    /// Loopback redirect URI registered for this attempt.
    pub fn redirect_uri(&self) -> &Url {
        &self.redirect_uri
    }

    /// Waits for the browser redirect and exchanges the received code for tokens.
    ///
    /// # Errors
    ///
    /// Returns [`OpenAIOAuthError::CallbackTimeout`] when no callback is
    /// handled within `CALLBACK_TIMEOUT`, any error raised while receiving
    /// and validating the callback, or any error from the code exchange.
    pub async fn complete(
        self,
        client: &OpenAIOAuthClient,
    ) -> Result<OpenAITokenSet, OpenAIOAuthError> {
        let Self {
            redirect_uri,
            state,
            code_verifier,
            listener,
            ..
        } = self;

        let code = match timeout(CALLBACK_TIMEOUT, receive_code(listener, &state)).await {
            Ok(received) => received?,
            Err(_) => return Err(OpenAIOAuthError::CallbackTimeout),
        };

        client
            .exchange_code(&code, &redirect_uri, &code_verifier)
            .await
    }
}

/// Turns a token endpoint response into an [`OpenAITokenSet`].
///
/// A non-success status becomes [`OpenAIOAuthError::Http`] carrying the
/// status and the response text (empty if the text cannot be read); a
/// success status has its JSON body decoded and converted.
async fn parse_token_response(
    response: reqwest::Response,
) -> Result<OpenAITokenSet, OpenAIOAuthError> {
    let status = response.status();

    if status.is_success() {
        let parsed: TokenResponse = response.json().await?;
        return Ok(parsed.into_tokens());
    }

    let body = response.text().await.unwrap_or_default();
    Err(OpenAIOAuthError::Http { status, body })
}

/// Accepts one connection on `listener`, extracts the OAuth callback
/// parameters from its request line, answers the browser and returns the
/// authorization code.
///
/// Only the first connection is handled, and only the bytes returned by a
/// single read of up to 8192 bytes are inspected. The request target is the
/// second whitespace-separated token of the first line; `/` is assumed when
/// it is missing.
///
/// The page sent to the browser always reflects the outcome: the success
/// page is shown only when a code is actually returned, so a callback with a
/// valid `state` but no `code` gets an error page rather than "Authorization
/// complete".
///
/// # Errors
///
/// * [`OpenAIOAuthError::Io`] if accepting, reading, writing or shutting
///   down the connection fails.
/// * [`OpenAIOAuthError::Url`] if the request target cannot be parsed.
/// * [`OpenAIOAuthError::StateMismatch`] if `state` is absent or differs
///   from `expected_state`.
/// * [`OpenAIOAuthError::MissingCode`] if `state` matches but `code` is absent.
async fn receive_code(
    listener: TcpListener,
    expected_state: &str,
) -> Result<String, OpenAIOAuthError> {
    let (mut stream, _) = listener.accept().await?;
    let mut buffer = [0_u8; 8192];
    let bytes_read = stream.read(&mut buffer).await?;
    let request = String::from_utf8_lossy(&buffer[..bytes_read]);
    let path = request
        .lines()
        .next()
        .and_then(|line| line.split_whitespace().nth(1))
        .unwrap_or("/");

    // The request target is origin-form (path + query), so give it a dummy
    // base purely to reuse the URL query parser.
    let callback_url = Url::parse(&format!("http://localhost{path}"))?;
    let mut params: BTreeMap<String, String> = callback_url.query_pairs().into_owned().collect();

    // Decide the outcome once so the page shown in the browser and the value
    // returned to the caller can never disagree. The state is validated
    // before the code is looked at.
    let (response, outcome) = if params
        .get("state")
        .is_none_or(|state| state != expected_state)
    {
        (
            error_response("State mismatch"),
            Err(OpenAIOAuthError::StateMismatch),
        )
    } else if let Some(code) = params.remove("code") {
        (success_response().to_string(), Ok(code))
    } else {
        (
            error_response("Missing authorization code"),
            Err(OpenAIOAuthError::MissingCode),
        )
    };

    stream.write_all(response.as_bytes()).await?;
    stream.shutdown().await?;

    outcome
}

/// Builds the `http://<addr><CALLBACK_PATH>` redirect URI for a loopback listener.
///
/// # Errors
///
/// Returns [`OpenAIOAuthError::Url`] if the assembled string is not a valid URL.
fn loopback_redirect_uri(addr: SocketAddr) -> Result<Url, OpenAIOAuthError> {
    // `SocketAddr`'s `Display` writes `ip:port` for IPv4 and `[ip]:port` for
    // IPv6. Formatting the IP and port separately would leave an IPv6 host
    // unbracketed and make the URL unparsable.
    let redirect_uri = Url::parse(&format!("http://{addr}{CALLBACK_PATH}"))?;
    Ok(redirect_uri)
}

/// Derives the `S256` PKCE code challenge for `verifier`: the SHA-256 digest
/// of its bytes, encoded as unpadded URL-safe base64.
fn pkce_s256(verifier: &str) -> String {
    let hashed = digest(&SHA256, verifier.as_bytes());
    URL_SAFE_NO_PAD.encode(hashed.as_ref())
}

/// Returns `len` random bytes encoded as unpadded URL-safe base64.
///
/// `len` counts raw bytes, not characters of the returned string.
fn random_base64_url(len: usize) -> String {
    let mut entropy = vec![0_u8; len];
    let mut generator = rand::rng();
    generator.fill_bytes(entropy.as_mut_slice());
    URL_SAFE_NO_PAD.encode(&entropy)
}

/// Complete HTTP/1.1 `200 OK` response (status line, headers and HTML body)
/// shown in the browser once the callback has been accepted.
fn success_response() -> &'static str {
    concat!(
        "HTTP/1.1 200 OK\r\n",
        "Content-Type: text/html; charset=utf-8\r\n",
        "\r\n",
        "<!doctype html><html><body>",
        "<h1>Authorization complete</h1>",
        "<p>You can return to Mentra.</p>",
        "</body></html>",
    )
}

/// Complete HTTP/1.1 `400 Bad Request` response whose HTML body embeds
/// `message` verbatim inside a paragraph.
///
/// `message` is inserted without HTML escaping.
fn error_response(message: &str) -> String {
    const HEAD: &str =
        "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html; charset=utf-8\r\n\r\n";
    const BODY_OPEN: &str = "<!doctype html><html><body><h1>Authorization failed</h1><p>";
    const BODY_CLOSE: &str = "</p></body></html>";

    [HEAD, BODY_OPEN, message, BODY_CLOSE].concat()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A token expiring in 30 seconds counts as expired under a 60-second
    /// skew but not under a 5-second skew.
    #[test]
    fn token_expiry_uses_refresh_skew() {
        let expires_at = OffsetDateTime::now_utc() + Duration::seconds(30);
        let tokens = OpenAITokenSet {
            access_token: String::from("access"),
            refresh_token: String::from("refresh"),
            id_token: None,
            api_key: Some(String::from("api")),
            expires_at,
        };

        let wide_skew = Duration::seconds(60);
        let narrow_skew = Duration::seconds(5);

        assert!(tokens.is_expired(wide_skew));
        assert!(!tokens.is_expired(narrow_skew));
    }

    /// The challenge must not contain padding or characters from the
    /// standard (non-URL-safe) base64 alphabet.
    #[test]
    fn pkce_challenge_is_url_safe() {
        let challenge = pkce_s256("test-verifier");

        for forbidden in ['=', '+', '/'] {
            assert!(!challenge.contains(forbidden));
        }
    }
}