1use anyhow::{Context, Result, anyhow, bail};
2
3use agentis_pay_shared::{
4 BipaClient, ClientConfig, CredentialsStore, proto::pb, unix_timestamp_seconds,
5};
6
7#[derive(serde::Deserialize)]
8struct OAuthTokenResponse {
9 access_token: String,
10 refresh_token: Option<String>,
11 expires_in: Option<u64>,
12 #[serde(default)]
13 token_type: Option<String>,
14}
15
/// In-memory session state sourced from `AGENTIS_PAY_*` environment variables.
///
/// Nothing here is persisted; refreshed values live only as long as the value.
#[derive(Debug, Clone, Default)]
pub struct EnvSession {
    /// Current JWT, if any.
    jwt: Option<String>,
    /// Refresh token paired with `jwt`, if any.
    refresh_token: Option<String>,
    /// Unix timestamp (seconds) at which `jwt` expires, if known.
    jwt_expires_at: Option<i64>,
}
22
23impl EnvSession {
24 pub fn from_env() -> Result<Option<Self>> {
25 Self::from_lookup(|name| {
26 std::env::var(name)
27 .ok()
28 .map(|value| value.trim().to_string())
29 .filter(|value| !value.is_empty())
30 })
31 }
32
33 fn from_lookup(mut lookup: impl FnMut(&str) -> Option<String>) -> Result<Option<Self>> {
34 let jwt = lookup("AGENTIS_PAY_JWT");
35 let refresh_token = lookup("AGENTIS_PAY_REFRESH_TOKEN");
36 let jwt_expires_at = match lookup("AGENTIS_PAY_JWT_EXPIRES_AT") {
37 Some(value) => Some(
38 value
39 .parse::<i64>()
40 .with_context(|| format!("parse AGENTIS_PAY_JWT_EXPIRES_AT: {value}"))?,
41 ),
42 None => None,
43 };
44
45 if jwt.is_none() {
46 if refresh_token.is_some() || jwt_expires_at.is_some() {
47 bail!(
48 "AGENTIS_PAY_REFRESH_TOKEN and AGENTIS_PAY_JWT_EXPIRES_AT require AGENTIS_PAY_JWT"
49 );
50 }
51 return Ok(None);
52 }
53
54 Ok(Some(Self {
55 jwt,
56 refresh_token,
57 jwt_expires_at,
58 }))
59 }
60
61 pub fn jwt(&self) -> Option<&str> {
62 self.jwt.as_deref().filter(|value| !value.trim().is_empty())
63 }
64
65 pub fn has_jwt(&self) -> bool {
66 self.jwt().is_some()
67 }
68
69 pub fn jwt_expires_at(&self) -> Option<i64> {
70 self.jwt_expires_at
71 }
72
73 pub fn refresh_token(&self) -> Option<&str> {
74 self.refresh_token
75 .as_deref()
76 .filter(|value| !value.trim().is_empty())
77 }
78
79 pub fn clear_session(&mut self) {
80 self.jwt = None;
81 self.refresh_token = None;
82 self.jwt_expires_at = None;
83 }
84
85 pub fn set_session(&mut self, jwt: String, refresh_token: Option<String>, ttl_seconds: i64) {
86 self.jwt = Some(jwt);
87 self.refresh_token = refresh_token.filter(|value| !value.trim().is_empty());
88 let ttl_seconds = ttl_seconds.max(0);
89 self.jwt_expires_at = Some(unix_timestamp_seconds() + ttl_seconds);
90 }
91}
92
93pub async fn ensure_valid_session(
94 config: &ClientConfig,
95 client: &mut BipaClient,
96 credentials: &mut CredentialsStore,
97) -> Result<()> {
98 if !credentials.has_jwt() {
99 bail!("No active session. Run `agentis-pay login` first.");
100 }
101
102 let jwt = credentials
103 .credentials()
104 .jwt
105 .as_deref()
106 .ok_or_else(|| anyhow::anyhow!("No active session. Run `agentis-pay login` first."))?
107 .to_string();
108 client.set_jwt(jwt.clone());
109
110 if !token_expired(credentials) {
111 return Ok(());
112 }
113
114 if let Some(refresh_token) = oauth_refresh_token(credentials) {
115 let refreshed =
116 match refresh_oauth_access_token(config.oauth_endpoint(), refresh_token).await {
117 Ok(response) => response,
118 Err(error) => {
119 return invalidate(
120 credentials,
121 format!("session expired and oauth refresh failed: {error}"),
122 );
123 }
124 };
125
126 if refreshed.access_token.trim().is_empty() {
127 return invalidate(credentials, "oauth refresh returned an empty access token");
128 }
129
130 if let Err(error) = validate_bearer_token_type(refreshed.token_type.as_deref()) {
131 credentials.clear_session();
132 credentials.save()?;
133 return Err(error);
134 }
135
136 let ttl_seconds = i64::try_from(refreshed.expires_in.unwrap_or(3600))
137 .context("oauth refresh expires_in exceeds i64")?;
138 credentials.set_session(
139 refreshed.access_token.clone(),
140 refreshed.refresh_token,
141 ttl_seconds,
142 );
143 client.set_jwt(refreshed.access_token);
144 credentials.save()?;
145 return Ok(());
146 }
147
148 let response = match client.refresh_auth().await {
149 Ok(response) => response,
150 Err(error) => {
151 return invalidate(
152 credentials,
153 format!("session expired and refresh failed: {error}"),
154 );
155 }
156 };
157
158 let refreshed = match response.outcome {
159 Some(pb::refresh_auth_response::Outcome::Refreshed(r)) => r,
160 Some(pb::refresh_auth_response::Outcome::Denied(_)) => {
161 return invalidate(
162 credentials,
163 "session expired and refresh was denied by the server",
164 );
165 }
166 None => {
167 return invalidate(credentials, "refresh returned an empty response");
168 }
169 };
170
171 if refreshed.token.trim().is_empty() {
172 return invalidate(credentials, "refresh returned an empty token");
173 }
174
175 let refresh_token = credentials.credentials().refresh_token.clone();
176 let ttl_seconds = i64::from(refreshed.refresh_in);
177 credentials.set_session(refreshed.token.clone(), refresh_token, ttl_seconds);
178 client.set_jwt(refreshed.token);
179 credentials.save()?;
180 Ok(())
181}
182
183pub async fn ensure_valid_env_session(
184 config: &ClientConfig,
185 client: &mut BipaClient,
186 session: &mut EnvSession,
187) -> Result<()> {
188 if !session.has_jwt() {
189 bail!("No active session. Set AGENTIS_PAY_JWT for MCP auth.");
190 }
191
192 let jwt = session
193 .jwt()
194 .ok_or_else(|| anyhow!("No active session. Set AGENTIS_PAY_JWT for MCP auth."))?
195 .to_string();
196 client.set_jwt(jwt);
197
198 if !token_expired_at(session.jwt_expires_at()) {
199 return Ok(());
200 }
201
202 if let Some(refresh_token) = session.refresh_token() {
203 let refreshed =
204 match refresh_oauth_access_token(config.oauth_endpoint(), refresh_token).await {
205 Ok(response) => response,
206 Err(error) => {
207 return invalidate_env_session(
208 session,
209 format!("session expired and oauth refresh failed: {error}"),
210 );
211 }
212 };
213
214 if refreshed.access_token.trim().is_empty() {
215 return invalidate_env_session(session, "oauth refresh returned an empty access token");
216 }
217
218 if let Err(error) = validate_bearer_token_type(refreshed.token_type.as_deref()) {
219 session.clear_session();
220 return Err(error);
221 }
222
223 let ttl_seconds = i64::try_from(refreshed.expires_in.unwrap_or(3600))
224 .context("oauth refresh expires_in exceeds i64")?;
225 session.set_session(
226 refreshed.access_token.clone(),
227 refreshed.refresh_token,
228 ttl_seconds,
229 );
230 client.set_jwt(refreshed.access_token);
231 return Ok(());
232 }
233
234 let response = match client.refresh_auth().await {
235 Ok(response) => response,
236 Err(error) => {
237 return invalidate_env_session(
238 session,
239 format!("session expired and refresh failed: {error}"),
240 );
241 }
242 };
243
244 let refreshed = match response.outcome {
245 Some(pb::refresh_auth_response::Outcome::Refreshed(r)) => r,
246 Some(pb::refresh_auth_response::Outcome::Denied(_)) => {
247 return invalidate_env_session(
248 session,
249 "session expired and refresh was denied by the server",
250 );
251 }
252 None => {
253 return invalidate_env_session(session, "refresh returned an empty response");
254 }
255 };
256
257 if refreshed.token.trim().is_empty() {
258 return invalidate_env_session(session, "refresh returned an empty token");
259 }
260
261 let refresh_token = session.refresh_token.clone();
262 let ttl_seconds = i64::from(refreshed.refresh_in);
263 session.set_session(refreshed.token.clone(), refresh_token, ttl_seconds);
264 client.set_jwt(refreshed.token);
265 Ok(())
266}
267
268fn invalidate(credentials: &mut CredentialsStore, msg: impl std::fmt::Display) -> Result<()> {
270 credentials.clear_session();
271 credentials.save()?;
272 bail!("{msg}");
273}
274
275fn invalidate_env_session(session: &mut EnvSession, msg: impl std::fmt::Display) -> Result<()> {
276 session.clear_session();
277 bail!("{msg}");
278}
279
280fn oauth_refresh_token(credentials: &CredentialsStore) -> Option<&str> {
281 credentials
282 .oauth_client_id()
283 .zip(
284 credentials
285 .credentials()
286 .refresh_token
287 .as_deref()
288 .filter(|value| !value.trim().is_empty()),
289 )
290 .map(|(_, refresh_token)| refresh_token)
291}
292
293pub fn validate_bearer_token_type(token_type: Option<&str>) -> Result<&str> {
294 let token_type =
295 token_type.ok_or_else(|| anyhow::anyhow!("oauth token response missing token_type"))?;
296 if !token_type.eq_ignore_ascii_case("Bearer") {
297 bail!("unsupported oauth token_type: {token_type}");
298 }
299 Ok(token_type)
300}
301
302async fn refresh_oauth_access_token(
303 oauth_endpoint: &str,
304 refresh_token: &str,
305) -> Result<OAuthTokenResponse> {
306 let response = reqwest::Client::new()
307 .post(format!("{oauth_endpoint}/oauth/token"))
308 .form(&[
309 ("grant_type", "refresh_token"),
310 ("refresh_token", refresh_token),
311 ])
312 .send()
313 .await
314 .context("POST /oauth/token")?;
315
316 if !response.status().is_success() {
317 let status = response.status();
318 let text = response.text().await.unwrap_or_default();
319 bail!("oauth token refresh failed ({status}): {text}");
320 }
321
322 response
323 .json()
324 .await
325 .context("parse oauth token refresh response")
326}
327
328fn token_expired(credentials: &CredentialsStore) -> bool {
329 token_expired_at(credentials.jwt_expires_at())
330}
331
332fn token_expired_at(expires_at: Option<i64>) -> bool {
333 match expires_at {
334 None => true,
335 Some(expires_at) => {
336 let now = unix_timestamp_seconds();
337 now.saturating_add(SESSION_REFRESH_SKEW_SECONDS) >= expires_at
338 }
339 }
340}
341
/// Safety margin, in seconds, added to the current time when checking expiry,
/// so a token is refreshed shortly before it actually lapses.
const SESSION_REFRESH_SKEW_SECONDS: i64 = 60;
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns a fixed table of `(name, value)` pairs into a lookup closure
    /// suitable for `EnvSession::from_lookup`.
    fn lookup_in<'a>(vars: &'a [(&'a str, &'a str)]) -> impl FnMut(&str) -> Option<String> + 'a {
        move |name| {
            vars.iter()
                .find(|(key, _)| *key == name)
                .map(|(_, value)| (*value).to_string())
        }
    }

    /// All three variables present: each one lands in its field.
    #[test]
    fn env_session_from_lookup_parses_expected_fields() {
        let vars = [
            ("AGENTIS_PAY_JWT", "jwt-token"),
            ("AGENTIS_PAY_REFRESH_TOKEN", "refresh-token"),
            ("AGENTIS_PAY_JWT_EXPIRES_AT", "123"),
        ];

        let session = EnvSession::from_lookup(lookup_in(&vars))
            .expect("parse env session")
            .expect("env session must exist");

        assert_eq!(session.jwt(), Some("jwt-token"));
        assert_eq!(session.refresh_token(), Some("refresh-token"));
        assert_eq!(session.jwt_expires_at(), Some(123));
    }

    /// A refresh token without a JWT is a configuration error.
    #[test]
    fn env_session_rejects_refresh_without_jwt() {
        let vars = [("AGENTIS_PAY_REFRESH_TOKEN", "refresh-token")];

        let error =
            EnvSession::from_lookup(lookup_in(&vars)).expect_err("missing jwt must fail");

        assert!(error.to_string().contains("require AGENTIS_PAY_JWT"));
    }
}