Skip to main content

atomcode_core/coding_plan/
client.rs

1// crates/atomcode-core/src/coding_plan/client.rs
2//
3// Blocking HTTP client for the three CodingPlan REST endpoints. Reuses the
4// OAuth token already on disk (from `crate::auth`) — the token authenticates
5// both `atomgit.com` and `api.gitcode.com` (same backend, different front
6// domains). Every request carries `ATOMCODE_USER_AGENT` so AtomGit's
7// API gateway sees a consistent identifier.
8//
9// Blocking (not async) is deliberate: the coding-plan flow runs synchronously
10// before / outside the agent event loop. Async would force tokio::block_on
11// or channel plumbing at every call site for zero concurrency benefit.
12
13use anyhow::{anyhow, Context, Result};
14
15use super::types::{ClaimResponse, ModelEntry, PlanType, StatusResponse};
16use crate::auth;
17
/// Fallback CodingPlan REST API base URL.
///
/// Used whenever the `ATOMCODE_CODINGPLAN_API_BASE` environment variable is
/// unset, not valid Unicode, or blank after trimming.
const DEFAULT_API_BASE: &str = "https://api.gitcode.com/api/v5";

/// Return the CodingPlan REST API base URL.
///
/// The value is resolved once, on the first call, and cached for the rest
/// of the process lifetime; later changes to the environment are not seen.
///
/// Resolution order:
///   1. `ATOMCODE_CODINGPLAN_API_BASE`, with surrounding whitespace trimmed
///      and trailing `/` characters stripped (a value that ends up empty is
///      treated as unset).
///   2. [`DEFAULT_API_BASE`].
///
/// The returned URL never ends in `/`, so callers append `"/path"` directly.
pub fn api_base_url() -> String {
    use std::sync::OnceLock;

    static CACHED_BASE: OnceLock<String> = OnceLock::new();

    let resolved = CACHED_BASE.get_or_init(|| {
        // Missing or non-Unicode values collapse to "" and fall through to
        // the default below, same as a blank override.
        let from_env = match std::env::var("ATOMCODE_CODINGPLAN_API_BASE") {
            Ok(raw) => raw.trim().trim_end_matches('/').to_owned(),
            Err(_) => String::new(),
        };
        if from_env.is_empty() {
            DEFAULT_API_BASE.to_owned()
        } else {
            from_env
        }
    });
    resolved.clone()
}
42
/// Typed error raised when the CodingPlan API rejects the bearer token
/// (HTTP 401/403), or when the locally stored token cannot be made valid.
///
/// `Client` wraps it in an `anyhow::Error`. The orchestrator can then
/// recognise it (via `is_auth_expired` or `downcast_ref::<AuthExpired>()`)
/// and re-run OAuth instead of only printing the failure.
///
/// The `Display` text deliberately reproduces the legacy error string
/// verbatim. Rendered reports therefore stay byte-identical whenever no
/// recovery is attempted, e.g. in non-interactive, scripted invocations.
#[derive(Debug)]
pub struct AuthExpired {
    /// HTTP status that triggered the error. The API path reports 401 or
    /// 403; the local-token path in `Client::from_stored_auth` reports 401.
    pub status: u16,
}

impl std::fmt::Display for AuthExpired {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let code = self.status;
        write!(f, "authentication failed ({code}) — run `atomcode login` again")
    }
}

// Leaf error: no underlying `source()`, so the default trait methods suffice.
impl std::error::Error for AuthExpired {}
72
73/// True iff `err` (or any error in its cause chain) is an `AuthExpired`.
74/// Centralised so the orchestrator and shell callers agree on what
75/// "stale token" looks like — anywhere we want to decide "rerun OAuth?"
76/// goes through here.
77pub fn is_auth_expired(err: &anyhow::Error) -> bool {
78    err.chain().any(|e| e.is::<AuthExpired>())
79}
80
81/// Token-authenticated blocking REST client for CodingPlan endpoints.
82pub struct Client {
83    http: reqwest::blocking::Client,
84    token: String,
85}
86
87impl Client {
88    /// Build a client using the currently-stored OAuth token. Refreshes
89    /// the token if expired. Errors with a user-facing message if the
90    /// user isn't logged in.
91    pub fn from_stored_auth() -> Result<Self> {
92        if !auth::is_logged_in() {
93            return Err(anyhow!(
94                "not logged in — run `atomcode login` (or the codingplan flow) first"
95            ));
96        }
97        // If the local access token can't be made valid (expired and the
98        // broker refused our refresh_token, or the file is malformed),
99        // there's no way to proceed without a fresh OAuth round-trip.
100        // Surface as `AuthExpired` so the orchestrator triggers the same
101        // recovery path it uses for an API-side 401, instead of bailing
102        // with a generic "build client" error that callers can't act on.
103        let token = match auth::get_valid_token() {
104            Ok(t) => t,
105            Err(e) => {
106                return Err(anyhow::Error::new(AuthExpired { status: 401 })
107                    .context(format!("local token unusable: {:#}", e)));
108            }
109        };
110        // Timeouts are critical here: `/status` and the background drift
111        // monitor both call these endpoints synchronously from the TUI
112        // event loop, and without a cap a slow / unreachable gateway
113        // hangs the entire UI until the OS eventually gives up (minutes
114        // on a VPN flap). 5s connect + 10s total covers every realistic
115        // latency for a healthy path and fails fast otherwise — the
116        // error surfaces as a benign "status fetch failed" line next to
117        // the rest of the status report.
118        let http = reqwest::blocking::Client::builder()
119            .connect_timeout(std::time::Duration::from_secs(5))
120            .timeout(std::time::Duration::from_secs(10))
121            .user_agent(crate::ATOMCODE_USER_AGENT)
122            .build()
123            .unwrap_or_else(|_| reqwest::blocking::Client::new());
124        Ok(Self { http, token })
125    }
126
127    /// `POST /coding-plan/claim-v2` — claim a specific CodingPlan tier.
128    /// Server reports `duplicate=true` when the user already holds the
129    /// tier (or a higher one); callers should treat that as success and
130    /// stop the cascade rather than retrying lower tiers — those would
131    /// either also report duplicate or unnecessarily downgrade.
132    ///
133    /// Body shape: `{"plan_type": "Max" | "Pro" | "Lite"}`. The user
134    /// asked us to start at Max and walk down, so the orchestrator
135    /// (`step_claim`) calls this in `PlanType::CASCADE_ORDER`.
136    pub fn claim_v2(&self, plan_type: PlanType) -> Result<ClaimResponse> {
137        let url = format!("{}/coding-plan/claim-v2", api_base_url());
138        let body_str = format!(r#"{{"plan_type":"{}"}}"#, plan_type.as_str());
139        let resp = self
140            .http
141            .post(&url)
142            .bearer_auth(&self.token)
143            .header("Content-Type", "application/json")
144            .body(body_str)
145            .send()
146            .with_context(|| format!("POST {} failed", url))?;
147
148        let status = resp.status();
149        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
150            return Err(anyhow::Error::new(AuthExpired {
151                status: status.as_u16(),
152            }));
153        }
154        let body = resp.text().unwrap_or_default();
155        if !status.is_success() {
156            return Err(anyhow!("{}", format_api_error("claim-v2", status, &body)));
157        }
158        serde_json::from_str::<ClaimResponse>(&body).with_context(|| {
159            format!(
160                "parse claim-v2 response (body: {})",
161                truncate_for_error(&body, 200)
162            )
163        })
164    }
165
166    /// `GET /coding-plan/models-v2?plan_type=<tier>` — model catalogue
167    /// from the v2 endpoint. Every entry now carries `plan_available`
168    /// telling the caller whether the user's tier covers that model;
169    /// the renderer uses it to apply strikethrough on locked entries.
170    /// Empty list is a legitimate return when the entitlement hasn't
171    /// been provisioned yet.
172    pub fn list_models_v2(&self, plan_type: PlanType) -> Result<Vec<ModelEntry>> {
173        let url = format!(
174            "{}/coding-plan/models-v2?plan_type={}",
175            api_base_url(),
176            plan_type.as_str()
177        );
178        let resp = self
179            .http
180            .get(&url)
181            .bearer_auth(&self.token)
182            .send()
183            .with_context(|| format!("GET {} failed", url))?;
184
185        let status = resp.status();
186        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
187            return Err(anyhow::Error::new(AuthExpired {
188                status: status.as_u16(),
189            }));
190        }
191        let body = resp.text().unwrap_or_default();
192        if !status.is_success() {
193            return Err(anyhow!("{}", format_api_error("models-v2", status, &body)));
194        }
195        serde_json::from_str::<Vec<ModelEntry>>(&body).with_context(|| {
196            format!(
197                "parse models-v2 response (body: {})",
198                truncate_for_error(&body, 200)
199            )
200        })
201    }
202
203    /// `GET /coding-plan/status-v2` — audit/quota/expiry snapshot. Same
204    /// response envelope as v1; only the path changed under the v2
205    /// rollout, so the parser type stays put.
206    pub fn status_v2(&self) -> Result<StatusResponse> {
207        let url = format!("{}/coding-plan/status-v2", api_base_url());
208        let resp = self
209            .http
210            .get(&url)
211            .bearer_auth(&self.token)
212            .send()
213            .with_context(|| format!("GET {} failed", url))?;
214
215        let status = resp.status();
216        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
217            return Err(anyhow::Error::new(AuthExpired {
218                status: status.as_u16(),
219            }));
220        }
221        let body = resp.text().unwrap_or_default();
222        if !status.is_success() {
223            return Err(anyhow!("{}", format_api_error("status-v2", status, &body)));
224        }
225        serde_json::from_str::<StatusResponse>(&body).with_context(|| {
226            format!(
227                "parse status-v2 response (body: {})",
228                truncate_for_error(&body, 200)
229            )
230        })
231    }
232}
233
/// Cap `s` at `max_chars` Unicode scalar values for inclusion in an error
/// message. Truncated text gets a trailing `…`.
///
/// The count is in `char`s, not bytes, so multi-byte text (e.g. CJK) is
/// never split mid-character.
fn truncate_for_error(s: &str, max_chars: usize) -> String {
    // The byte offset of the first char beyond the limit exists only when
    // the string is longer than `max_chars` characters.
    match s.char_indices().nth(max_chars) {
        None => s.to_string(),
        Some((cut, _)) => format!("{}…", &s[..cut]),
    }
}
242
243/// Format a non-2xx response from any CodingPlan endpoint into a
244/// user-facing error string. Tries three body shapes in priority order:
245///
246///   1. Product payload `{"message": "..."}` (non-empty) — shown verbatim
247///      (e.g. `全平台日限额已满` from a 429).
248///   2. Spring default error body `{"timestamp":..,"status":..,"error":..,
249///      "path":".."}` — rendered as `HTTP <code> — 接口暂不可用 (<path>)`
250///      so the user sees which endpoint 404'd without staring at raw JSON.
251///   3. Raw text fallback — `CodingPlan <descriptor> returned <status> — <body>`
252///      with a 200-char body cap.
253///
254/// `descriptor` is a short name for the caller (`claim` / `models` /
255/// `status`) used only in shape-3 fallback where no structured info exists.
256fn format_api_error(descriptor: &str, status: reqwest::StatusCode, body: &str) -> String {
257    if let Ok(val) = serde_json::from_str::<serde_json::Value>(body) {
258        // Shape 1: product payload with non-empty `message`.
259        if let Some(msg) = val.get("message").and_then(|v| v.as_str()) {
260            if !msg.is_empty() {
261                return msg.to_string();
262            }
263        }
264        // Shape 2: Spring error body with `path` (and no usable message).
265        if let Some(path) = val.get("path").and_then(|v| v.as_str()) {
266            return format!("HTTP {} — 接口暂不可用 ({})", status.as_u16(), path);
267        }
268    }
269    // Shape 3: raw text fallback.
270    format!(
271        "CodingPlan {} returned {} — {}",
272        descriptor,
273        status,
274        truncate_for_error(body, 200)
275    )
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// `AuthExpired` must Display identically to the legacy
    /// `anyhow!("authentication failed (NNN) — run `atomcode login` again")`
    /// string, so existing renderers, log scrapers, and users who grep
    /// for the hint don't see a stealth wording change.
    #[test]
    fn auth_expired_display_matches_legacy_string() {
        let e = AuthExpired { status: 401 };
        assert_eq!(
            e.to_string(),
            "authentication failed (401) — run `atomcode login` again"
        );
        let e = AuthExpired { status: 403 };
        assert_eq!(
            e.to_string(),
            "authentication failed (403) — run `atomcode login` again"
        );
    }

    /// `is_auth_expired` finds the marker on a direct `AuthExpired` error
    /// AND through the `with_context` chain that callers wrap around it.
    /// Without walking the chain, the marker would be missed as soon as any
    /// step layers context onto the original `anyhow::Error`.
    #[test]
    fn is_auth_expired_walks_cause_chain() {
        let raw = anyhow::Error::new(AuthExpired { status: 401 });
        assert!(is_auth_expired(&raw));

        let wrapped: anyhow::Error =
            Err::<(), _>(anyhow::Error::new(AuthExpired { status: 401 }))
                .context("list models-v2")
                .unwrap_err();
        assert!(is_auth_expired(&wrapped));

        let unrelated = anyhow!("some other failure");
        assert!(!is_auth_expired(&unrelated));
    }

    /// Mirrors how `Client::from_stored_auth` reports an unusable local
    /// token: `AuthExpired` with `.context(...)` attached directly on the
    /// `anyhow::Error`. The marker must still be detected, while the
    /// top-level message describes the local failure.
    #[test]
    fn is_auth_expired_sees_through_local_token_context() {
        let err = anyhow::Error::new(AuthExpired { status: 401 })
            .context(format!("local token unusable: {:#}", anyhow!("refresh refused")));
        assert!(is_auth_expired(&err));
        assert!(
            err.to_string().starts_with("local token unusable"),
            "unexpected top-level message: {}",
            err
        );
    }

    #[test]
    fn format_api_error_extracts_message_from_product_payload() {
        let body = r#"{"success":false,"duplicate":false,"message":"全平台日限额已满"}"#;
        let msg = format_api_error("claim", reqwest::StatusCode::TOO_MANY_REQUESTS, body);
        assert_eq!(msg, "全平台日限额已满");
    }

    #[test]
    fn format_api_error_uses_path_from_spring_error_body() {
        let body = r#"{"timestamp":"2026-04-23T06:44:11.638+00:00","status":404,"error":"Not Found","path":"/api/v5/coding-plan/claim"}"#;
        let msg = format_api_error("claim", reqwest::StatusCode::NOT_FOUND, body);
        assert_eq!(msg, "HTTP 404 — 接口暂不可用 (/api/v5/coding-plan/claim)");
    }

    #[test]
    fn format_api_error_path_takes_precedence_when_message_empty() {
        let body = r#"{"message":"","path":"/api/v5/coding-plan/models"}"#;
        let msg = format_api_error("models", reqwest::StatusCode::NOT_FOUND, body);
        assert_eq!(msg, "HTTP 404 — 接口暂不可用 (/api/v5/coding-plan/models)");
    }

    #[test]
    fn format_api_error_falls_back_on_non_json_body() {
        let body = "<html>502 Bad Gateway</html>";
        let msg = format_api_error("status", reqwest::StatusCode::BAD_GATEWAY, body);
        assert!(msg.contains("502"), "status code missing: {}", msg);
        assert!(
            msg.contains("CodingPlan status"),
            "descriptor missing: {}",
            msg
        );
        assert!(
            msg.contains("502 Bad Gateway"),
            "body should be echoed: {}",
            msg
        );
    }

    #[test]
    fn format_api_error_falls_back_on_json_with_no_known_fields() {
        let body = r#"{"foo":"bar"}"#;
        let msg = format_api_error("claim", reqwest::StatusCode::INTERNAL_SERVER_ERROR, body);
        assert!(msg.contains("500"), "status code missing: {}", msg);
        assert!(
            msg.contains("CodingPlan claim"),
            "descriptor missing: {}",
            msg
        );
    }

    #[test]
    fn format_api_error_message_wins_over_path_when_both_present() {
        let body = r#"{"message":"全平台日限额已满","path":"/api/v5/coding-plan/claim"}"#;
        let msg = format_api_error("claim", reqwest::StatusCode::TOO_MANY_REQUESTS, body);
        assert_eq!(msg, "全平台日限额已满");
    }

    /// The raw-text fallback must cap long bodies (e.g. a full HTML error
    /// page) at 200 characters and mark the cut with `…`.
    #[test]
    fn format_api_error_caps_raw_body_at_200_chars() {
        let body = "x".repeat(250);
        let msg = format_api_error("status", reqwest::StatusCode::BAD_GATEWAY, &body);
        let expected_tail = format!("{}…", "x".repeat(200));
        assert!(msg.ends_with(&expected_tail), "body not capped: {}", msg);
        assert!(!msg.contains(&"x".repeat(201)), "body over cap: {}", msg);
    }

    /// Truncation counts characters, not bytes, so CJK error payloads are
    /// never split mid-character. Bodies at or under the cap pass through.
    #[test]
    fn truncate_for_error_counts_chars_not_bytes() {
        assert_eq!(truncate_for_error("abc", 3), "abc");
        assert_eq!(truncate_for_error("abcd", 3), "abc…");
        assert_eq!(truncate_for_error("全平台日限额已满", 3), "全平台…");
        assert_eq!(truncate_for_error("全平台", 3), "全平台");
        assert_eq!(truncate_for_error("", 0), "");
    }
}