Skip to main content

agentis_pay/
session.rs

1use anyhow::{Context, Result, anyhow, bail};
2
3use agentis_pay_shared::{
4    BipaClient, ClientConfig, CredentialsStore, proto::pb, unix_timestamp_seconds,
5};
6
7#[derive(serde::Deserialize)]
8struct OAuthTokenResponse {
9    access_token: String,
10    refresh_token: Option<String>,
11    expires_in: Option<u64>,
12    #[serde(default)]
13    token_type: Option<String>,
14}
15
/// Session credentials supplied through environment variables (used for MCP auth).
///
/// Unlike the on-disk credential store, changes to this session live in memory only.
#[derive(Clone, Default)]
pub struct EnvSession {
    /// Bearer JWT sent with API requests (`AGENTIS_PAY_JWT`).
    jwt: Option<String>,
    /// Token used to obtain a new JWT once it expires (`AGENTIS_PAY_REFRESH_TOKEN`).
    refresh_token: Option<String>,
    /// JWT expiry as Unix seconds (`AGENTIS_PAY_JWT_EXPIRES_AT`).
    jwt_expires_at: Option<i64>,
}

// Hand-written instead of derived so `{:?}` output (logs, panics, error reports) can never
// contain the JWT or refresh token; only their presence is shown.
impl std::fmt::Debug for EnvSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn redact(value: Option<&str>) -> Option<&'static str> {
            value.map(|_| "<redacted>")
        }

        f.debug_struct("EnvSession")
            .field("jwt", &redact(self.jwt.as_deref()))
            .field("refresh_token", &redact(self.refresh_token.as_deref()))
            .field("jwt_expires_at", &self.jwt_expires_at)
            .finish()
    }
}
22
23impl EnvSession {
24    pub fn from_env() -> Result<Option<Self>> {
25        Self::from_lookup(|name| {
26            std::env::var(name)
27                .ok()
28                .map(|value| value.trim().to_string())
29                .filter(|value| !value.is_empty())
30        })
31    }
32
33    fn from_lookup(mut lookup: impl FnMut(&str) -> Option<String>) -> Result<Option<Self>> {
34        let jwt = lookup("AGENTIS_PAY_JWT");
35        let refresh_token = lookup("AGENTIS_PAY_REFRESH_TOKEN");
36        let jwt_expires_at = match lookup("AGENTIS_PAY_JWT_EXPIRES_AT") {
37            Some(value) => Some(
38                value
39                    .parse::<i64>()
40                    .with_context(|| format!("parse AGENTIS_PAY_JWT_EXPIRES_AT: {value}"))?,
41            ),
42            None => None,
43        };
44
45        if jwt.is_none() {
46            if refresh_token.is_some() || jwt_expires_at.is_some() {
47                bail!(
48                    "AGENTIS_PAY_REFRESH_TOKEN and AGENTIS_PAY_JWT_EXPIRES_AT require AGENTIS_PAY_JWT"
49                );
50            }
51            return Ok(None);
52        }
53
54        Ok(Some(Self {
55            jwt,
56            refresh_token,
57            jwt_expires_at,
58        }))
59    }
60
61    pub fn jwt(&self) -> Option<&str> {
62        self.jwt.as_deref().filter(|value| !value.trim().is_empty())
63    }
64
65    pub fn has_jwt(&self) -> bool {
66        self.jwt().is_some()
67    }
68
69    pub fn jwt_expires_at(&self) -> Option<i64> {
70        self.jwt_expires_at
71    }
72
73    pub fn refresh_token(&self) -> Option<&str> {
74        self.refresh_token
75            .as_deref()
76            .filter(|value| !value.trim().is_empty())
77    }
78
79    pub fn clear_session(&mut self) {
80        self.jwt = None;
81        self.refresh_token = None;
82        self.jwt_expires_at = None;
83    }
84
85    pub fn set_session(&mut self, jwt: String, refresh_token: Option<String>, ttl_seconds: i64) {
86        self.jwt = Some(jwt);
87        self.refresh_token = refresh_token.filter(|value| !value.trim().is_empty());
88        let ttl_seconds = ttl_seconds.max(0);
89        self.jwt_expires_at = Some(unix_timestamp_seconds() + ttl_seconds);
90    }
91}
92
93pub async fn ensure_valid_session(
94    config: &ClientConfig,
95    client: &mut BipaClient,
96    credentials: &mut CredentialsStore,
97) -> Result<()> {
98    if !credentials.has_jwt() {
99        bail!("No active session. Run `agentis-pay login` first.");
100    }
101
102    let jwt = credentials
103        .credentials()
104        .jwt
105        .as_deref()
106        .ok_or_else(|| anyhow::anyhow!("No active session. Run `agentis-pay login` first."))?
107        .to_string();
108    client.set_jwt(jwt.clone());
109
110    if !token_expired(credentials) {
111        return Ok(());
112    }
113
114    if let Some(refresh_token) = oauth_refresh_token(credentials) {
115        let refreshed =
116            match refresh_oauth_access_token(config.oauth_endpoint(), refresh_token).await {
117                Ok(response) => response,
118                Err(error) => {
119                    return invalidate(
120                        credentials,
121                        format!("session expired and oauth refresh failed: {error}"),
122                    );
123                }
124            };
125
126        if refreshed.access_token.trim().is_empty() {
127            return invalidate(credentials, "oauth refresh returned an empty access token");
128        }
129
130        if let Err(error) = validate_bearer_token_type(refreshed.token_type.as_deref()) {
131            credentials.clear_session();
132            credentials.save()?;
133            return Err(error);
134        }
135
136        let ttl_seconds = i64::try_from(refreshed.expires_in.unwrap_or(3600))
137            .context("oauth refresh expires_in exceeds i64")?;
138        credentials.set_session(
139            refreshed.access_token.clone(),
140            refreshed.refresh_token,
141            ttl_seconds,
142        );
143        client.set_jwt(refreshed.access_token);
144        credentials.save()?;
145        return Ok(());
146    }
147
148    let response = match client.refresh_auth().await {
149        Ok(response) => response,
150        Err(error) => {
151            return invalidate(
152                credentials,
153                format!("session expired and refresh failed: {error}"),
154            );
155        }
156    };
157
158    let refreshed = match response.outcome {
159        Some(pb::refresh_auth_response::Outcome::Refreshed(r)) => r,
160        Some(pb::refresh_auth_response::Outcome::Denied(_)) => {
161            return invalidate(
162                credentials,
163                "session expired and refresh was denied by the server",
164            );
165        }
166        None => {
167            return invalidate(credentials, "refresh returned an empty response");
168        }
169    };
170
171    if refreshed.token.trim().is_empty() {
172        return invalidate(credentials, "refresh returned an empty token");
173    }
174
175    let refresh_token = credentials.credentials().refresh_token.clone();
176    let ttl_seconds = i64::from(refreshed.refresh_in);
177    credentials.set_session(refreshed.token.clone(), refresh_token, ttl_seconds);
178    client.set_jwt(refreshed.token);
179    credentials.save()?;
180    Ok(())
181}
182
183pub async fn ensure_valid_env_session(
184    config: &ClientConfig,
185    client: &mut BipaClient,
186    session: &mut EnvSession,
187) -> Result<()> {
188    if !session.has_jwt() {
189        bail!("No active session. Set AGENTIS_PAY_JWT for MCP auth.");
190    }
191
192    let jwt = session
193        .jwt()
194        .ok_or_else(|| anyhow!("No active session. Set AGENTIS_PAY_JWT for MCP auth."))?
195        .to_string();
196    client.set_jwt(jwt);
197
198    if !token_expired_at(session.jwt_expires_at()) {
199        return Ok(());
200    }
201
202    if let Some(refresh_token) = session.refresh_token() {
203        let refreshed =
204            match refresh_oauth_access_token(config.oauth_endpoint(), refresh_token).await {
205                Ok(response) => response,
206                Err(error) => {
207                    return invalidate_env_session(
208                        session,
209                        format!("session expired and oauth refresh failed: {error}"),
210                    );
211                }
212            };
213
214        if refreshed.access_token.trim().is_empty() {
215            return invalidate_env_session(session, "oauth refresh returned an empty access token");
216        }
217
218        if let Err(error) = validate_bearer_token_type(refreshed.token_type.as_deref()) {
219            session.clear_session();
220            return Err(error);
221        }
222
223        let ttl_seconds = i64::try_from(refreshed.expires_in.unwrap_or(3600))
224            .context("oauth refresh expires_in exceeds i64")?;
225        session.set_session(
226            refreshed.access_token.clone(),
227            refreshed.refresh_token,
228            ttl_seconds,
229        );
230        client.set_jwt(refreshed.access_token);
231        return Ok(());
232    }
233
234    let response = match client.refresh_auth().await {
235        Ok(response) => response,
236        Err(error) => {
237            return invalidate_env_session(
238                session,
239                format!("session expired and refresh failed: {error}"),
240            );
241        }
242    };
243
244    let refreshed = match response.outcome {
245        Some(pb::refresh_auth_response::Outcome::Refreshed(r)) => r,
246        Some(pb::refresh_auth_response::Outcome::Denied(_)) => {
247            return invalidate_env_session(
248                session,
249                "session expired and refresh was denied by the server",
250            );
251        }
252        None => {
253            return invalidate_env_session(session, "refresh returned an empty response");
254        }
255    };
256
257    if refreshed.token.trim().is_empty() {
258        return invalidate_env_session(session, "refresh returned an empty token");
259    }
260
261    let refresh_token = session.refresh_token.clone();
262    let ttl_seconds = i64::from(refreshed.refresh_in);
263    session.set_session(refreshed.token.clone(), refresh_token, ttl_seconds);
264    client.set_jwt(refreshed.token);
265    Ok(())
266}
267
268/// Clear the session and bail with an error message.
269fn invalidate(credentials: &mut CredentialsStore, msg: impl std::fmt::Display) -> Result<()> {
270    credentials.clear_session();
271    credentials.save()?;
272    bail!("{msg}");
273}
274
275fn invalidate_env_session(session: &mut EnvSession, msg: impl std::fmt::Display) -> Result<()> {
276    session.clear_session();
277    bail!("{msg}");
278}
279
280fn oauth_refresh_token(credentials: &CredentialsStore) -> Option<&str> {
281    credentials
282        .oauth_client_id()
283        .zip(
284            credentials
285                .credentials()
286                .refresh_token
287                .as_deref()
288                .filter(|value| !value.trim().is_empty()),
289        )
290        .map(|(_, refresh_token)| refresh_token)
291}
292
293pub fn validate_bearer_token_type(token_type: Option<&str>) -> Result<&str> {
294    let token_type =
295        token_type.ok_or_else(|| anyhow::anyhow!("oauth token response missing token_type"))?;
296    if !token_type.eq_ignore_ascii_case("Bearer") {
297        bail!("unsupported oauth token_type: {token_type}");
298    }
299    Ok(token_type)
300}
301
302async fn refresh_oauth_access_token(
303    oauth_endpoint: &str,
304    refresh_token: &str,
305) -> Result<OAuthTokenResponse> {
306    let response = reqwest::Client::new()
307        .post(format!("{oauth_endpoint}/oauth/token"))
308        .form(&[
309            ("grant_type", "refresh_token"),
310            ("refresh_token", refresh_token),
311        ])
312        .send()
313        .await
314        .context("POST /oauth/token")?;
315
316    if !response.status().is_success() {
317        let status = response.status();
318        let text = response.text().await.unwrap_or_default();
319        bail!("oauth token refresh failed ({status}): {text}");
320    }
321
322    response
323        .json()
324        .await
325        .context("parse oauth token refresh response")
326}
327
328fn token_expired(credentials: &CredentialsStore) -> bool {
329    token_expired_at(credentials.jwt_expires_at())
330}
331
332fn token_expired_at(expires_at: Option<i64>) -> bool {
333    match expires_at {
334        None => true,
335        Some(expires_at) => {
336            let now = unix_timestamp_seconds();
337            now.saturating_add(SESSION_REFRESH_SKEW_SECONDS) >= expires_at
338        }
339    }
340}
341
342const SESSION_REFRESH_SKEW_SECONDS: i64 = 60;
343
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn env_session_from_lookup_parses_expected_fields() {
        let session = EnvSession::from_lookup(|name| match name {
            "AGENTIS_PAY_JWT" => Some("jwt-token".to_string()),
            "AGENTIS_PAY_REFRESH_TOKEN" => Some("refresh-token".to_string()),
            "AGENTIS_PAY_JWT_EXPIRES_AT" => Some("123".to_string()),
            _ => None,
        })
        .expect("parse env session")
        .expect("env session must exist");

        assert_eq!(session.jwt(), Some("jwt-token"));
        assert_eq!(session.refresh_token(), Some("refresh-token"));
        assert_eq!(session.jwt_expires_at(), Some(123));
    }

    #[test]
    fn env_session_rejects_refresh_without_jwt() {
        let error = EnvSession::from_lookup(|name| match name {
            "AGENTIS_PAY_REFRESH_TOKEN" => Some("refresh-token".to_string()),
            _ => None,
        })
        .expect_err("missing jwt must fail");

        assert!(error.to_string().contains("require AGENTIS_PAY_JWT"));
    }

    #[test]
    fn env_session_rejects_expiry_without_jwt() {
        let error = EnvSession::from_lookup(|name| match name {
            "AGENTIS_PAY_JWT_EXPIRES_AT" => Some("123".to_string()),
            _ => None,
        })
        .expect_err("expiry without jwt must fail");

        assert!(error.to_string().contains("require AGENTIS_PAY_JWT"));
    }

    #[test]
    fn env_session_absent_when_no_variables_are_set() {
        let session = EnvSession::from_lookup(|_| None).expect("empty env must not fail");

        assert!(session.is_none());
    }

    #[test]
    fn env_session_rejects_non_numeric_expiry() {
        let error = EnvSession::from_lookup(|name| match name {
            "AGENTIS_PAY_JWT" => Some("jwt-token".to_string()),
            "AGENTIS_PAY_JWT_EXPIRES_AT" => Some("soon".to_string()),
            _ => None,
        })
        .expect_err("non-numeric expiry must fail");

        assert!(error.to_string().contains("AGENTIS_PAY_JWT_EXPIRES_AT"));
    }

    #[test]
    fn env_session_clear_session_removes_all_fields() {
        let mut session = EnvSession::from_lookup(|name| match name {
            "AGENTIS_PAY_JWT" => Some("jwt-token".to_string()),
            "AGENTIS_PAY_REFRESH_TOKEN" => Some("refresh-token".to_string()),
            "AGENTIS_PAY_JWT_EXPIRES_AT" => Some("123".to_string()),
            _ => None,
        })
        .expect("parse env session")
        .expect("env session must exist");

        session.clear_session();

        assert!(!session.has_jwt());
        assert_eq!(session.refresh_token(), None);
        assert_eq!(session.jwt_expires_at(), None);
    }

    #[test]
    fn env_session_set_session_clamps_ttl_and_drops_blank_refresh_token() {
        let mut session = EnvSession::default();
        let before = unix_timestamp_seconds();
        session.set_session("jwt-token".to_string(), Some("   ".to_string()), -30);
        let after = unix_timestamp_seconds();

        assert_eq!(session.jwt(), Some("jwt-token"));
        assert_eq!(session.refresh_token(), None);
        let expires_at = session.jwt_expires_at().expect("expiry must be set");
        assert!((before..=after).contains(&expires_at));
    }

    #[test]
    fn token_expired_at_treats_missing_expiry_as_expired() {
        assert!(token_expired_at(None));
    }

    #[test]
    fn token_expired_at_applies_refresh_skew() {
        let now = unix_timestamp_seconds();

        assert!(token_expired_at(Some(now + SESSION_REFRESH_SKEW_SECONDS - 1)));
        assert!(!token_expired_at(Some(now + SESSION_REFRESH_SKEW_SECONDS + 3600)));
    }
}