ai_usagebar/openai/
oauth.rs1use serde::{Deserialize, Serialize};
11
12use crate::error::{AppError, Result};
13
/// OpenAI OAuth token endpoint used for refresh-token grants.
pub const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// OAuth `client_id` sent with every refresh request.
pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Space-separated scopes requested when refreshing.
pub const SCOPE: &str = "openid profile email";
/// How many seconds before expiry a token is considered due for refresh (five minutes).
pub const REFRESH_BUFFER_SECS: i64 = 5 * 60;
18
19#[derive(Debug, Serialize)]
20struct RefreshRequest<'a> {
21 client_id: &'a str,
22 grant_type: &'a str,
23 refresh_token: &'a str,
24 scope: &'a str,
25}
26
27#[derive(Debug, Deserialize)]
28pub struct RefreshResponse {
29 pub access_token: String,
30 #[serde(default)]
31 pub refresh_token: Option<String>,
32 #[serde(default)]
33 pub id_token: Option<String>,
34 #[serde(default, deserialize_with = "de_expires_in")]
35 pub expires_in: Option<u64>,
36}
37
38fn de_expires_in<'de, D>(d: D) -> std::result::Result<Option<u64>, D::Error>
39where
40 D: serde::Deserializer<'de>,
41{
42 let v = serde_json::Value::deserialize(d)?;
43 Ok(match v {
44 serde_json::Value::Null => None,
45 serde_json::Value::Number(n) => n.as_u64().or_else(|| n.as_f64().map(|f| f as u64)),
46 _ => None,
47 })
48}
49
50pub async fn refresh(
51 client: &reqwest::Client,
52 endpoint: &str,
53 refresh_token: &str,
54) -> Result<RefreshResponse> {
55 let req = RefreshRequest {
56 client_id: CLIENT_ID,
57 grant_type: "refresh_token",
58 refresh_token,
59 scope: SCOPE,
60 };
61
62 let resp = client
63 .post(endpoint)
64 .header("Content-Type", "application/json")
65 .json(&req)
66 .send()
67 .await?;
68
69 let status = resp.status();
70 let body = resp.text().await.unwrap_or_default();
71 if !status.is_success() {
72 let msg = crate::anthropic::oauth::parse_error_body(&body)
73 .unwrap_or_else(|| "Refresh failed".into());
74 return Err(AppError::Http {
75 status: status.as_u16(),
76 body: msg,
77 });
78 }
79 serde_json::from_str(&body)
80 .map_err(|e| AppError::Schema(format!("openai token response: {e}; body: {body}")))
81}
82
83pub fn needs_refresh(expires_at_secs: i64, now_secs: i64) -> bool {
84 expires_at_secs < now_secs + REFRESH_BUFFER_SECS
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn needs_refresh_threshold() {
        let now = 1_000_000;
        assert!(needs_refresh(now + 100, now));
        assert!(!needs_refresh(now + 1000, now));
        // Exactly `REFRESH_BUFFER_SECS` remaining is not yet due; one second less is.
        assert!(!needs_refresh(now + REFRESH_BUFFER_SECS, now));
        assert!(needs_refresh(now + REFRESH_BUFFER_SECS - 1, now));
    }

    #[test]
    fn response_optional_fields_default_to_none() {
        let r: RefreshResponse = serde_json::from_str(r#"{"access_token":"at"}"#).unwrap();
        assert_eq!(r.access_token, "at");
        assert!(r.refresh_token.is_none());
        assert!(r.id_token.is_none());
        assert!(r.expires_in.is_none());
    }

    #[test]
    fn expires_in_is_parsed_leniently() {
        let r: RefreshResponse =
            serde_json::from_str(r#"{"access_token":"at","expires_in":3600.0}"#).unwrap();
        assert_eq!(r.expires_in, Some(3600));

        let r: RefreshResponse =
            serde_json::from_str(r#"{"access_token":"at","expires_in":null}"#).unwrap();
        assert_eq!(r.expires_in, None);

        let r: RefreshResponse =
            serde_json::from_str(r#"{"access_token":"at","expires_in":true}"#).unwrap();
        assert_eq!(r.expires_in, None);
    }

    #[tokio::test]
    async fn refresh_success_parses_three_tokens() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("POST", "/oauth/token")
            .with_status(200)
            .with_body(
                r#"{"access_token":"new-at","refresh_token":"new-rt","id_token":"new-id","expires_in":3600}"#,
            )
            .create_async()
            .await;
        let client = reqwest::Client::new();
        let r = refresh(&client, &format!("{}/oauth/token", server.url()), "old")
            .await
            .unwrap();
        assert_eq!(r.access_token, "new-at");
        assert_eq!(r.refresh_token.as_deref(), Some("new-rt"));
        assert_eq!(r.id_token.as_deref(), Some("new-id"));
        assert_eq!(r.expires_in, Some(3600));
    }

    #[tokio::test]
    async fn refresh_400_returns_http_with_description() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("POST", "/oauth/token")
            .with_status(400)
            .with_body(r#"{"error":"invalid_grant","error_description":"Refresh expired"}"#)
            .create_async()
            .await;
        let client = reqwest::Client::new();
        let err = refresh(&client, &format!("{}/oauth/token", server.url()), "x")
            .await
            .unwrap_err();
        match err {
            AppError::Http { status, body } => {
                assert_eq!(status, 400);
                assert_eq!(body, "Refresh expired");
            }
            other => panic!("expected Http error, got {other:?}"),
        }
    }
}