Skip to main content

ai_agent/bridge/
work_secret.rs

1//! Work secret handling and session ID utilities.
2//!
3//! Translated from openclaudecode/src/bridge/workSecret.ts
4
5use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
6use serde::{Deserialize, Serialize};
7
8#[cfg(feature = "reqwest")]
9use reqwest;
10
/// Work secret structure decoded from base64url-encoded JSON.
///
/// `decode_work_secret` builds this from a parsed JSON object. The derived
/// `Deserialize` impl can also build it directly; in that case `sources` and
/// `auth` are mandatory keys because they carry no `#[serde(default)]`.
///
/// NOTE(review): the derived `Debug` output prints `session_ingress_token`
/// (and any tokens nested in `sources` / `auth`) verbatim — avoid logging
/// whole values of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkSecret {
    /// Format version of the secret. `decode_work_secret` accepts only `1`.
    pub version: u32,
    /// Must be a non-empty string for `decode_work_secret` to succeed.
    pub session_ingress_token: String,
    /// Must be present as a string for `decode_work_secret` to succeed.
    pub api_base_url: String,
    /// Source descriptors; `decode_work_secret` falls back to an empty list
    /// when the key is missing or does not deserialize.
    pub sources: Vec<WorkSource>,
    /// Auth descriptors; `decode_work_secret` falls back to an empty list
    /// when the key is missing or does not deserialize.
    pub auth: Vec<WorkAuth>,
    /// Optional string-to-string map; `None` when the key is absent.
    #[serde(default)]
    pub claude_code_args: Option<std::collections::HashMap<String, String>>,
    /// Optional free-form JSON value; its schema is not constrained here.
    #[serde(default)]
    pub mcp_config: Option<serde_json::Value>,
    /// Optional string-to-string map; `None` when the key is absent.
    #[serde(default)]
    pub environment_variables: Option<std::collections::HashMap<String, String>>,
    /// Server-driven CCR v2 selector. Set when the session was created
    /// via the v2 compat layer.
    #[serde(default)]
    pub use_code_sessions: Option<bool>,
}
30
/// One entry of `WorkSecret::sources`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkSource {
    /// Kind discriminator for this source; stored under the JSON key `type`.
    #[serde(rename = "type")]
    pub source_type: String,
    /// Optional git details; `None` when the key is absent or null.
    #[serde(default)]
    pub git_info: Option<GitInfo>,
}
38
/// Git details attached to a `WorkSource`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitInfo {
    /// Kind discriminator; stored under the JSON key `type`.
    #[serde(rename = "type")]
    pub git_type: String,
    /// Required string identifying the repository.
    /// NOTE(review): exact format (URL vs. `owner/name`) is not established
    /// in this file — confirm against the producer.
    pub repo: String,
    /// Optional git ref; serialized under the key `ref` (a raw identifier is
    /// needed because `ref` is a Rust keyword).
    #[serde(default)]
    pub r#ref: Option<String>,
    /// Optional token.
    /// NOTE(review): presumably a credential for `repo`; the derived `Debug`
    /// prints it verbatim — confirm and avoid logging this struct.
    #[serde(default)]
    pub token: Option<String>,
}
49
/// One entry of `WorkSecret::auth`.
///
/// NOTE(review): the derived `Debug` prints `token` verbatim — avoid logging
/// values of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkAuth {
    /// Kind discriminator for this credential; stored under the JSON key `type`.
    #[serde(rename = "type")]
    pub auth_type: String,
    /// Required token string.
    pub token: String,
}
56
57/// Decode a base64url-encoded work secret and validate its version.
58pub fn decode_work_secret(secret: &str) -> Result<WorkSecret, String> {
59    let json = URL_SAFE_NO_PAD
60        .decode(secret)
61        .map_err(|e| format!("Failed to decode base64url: {}", e))?;
62
63    let parsed: serde_json::Value =
64        serde_json::from_slice(&json).map_err(|e| format!("Failed to parse JSON: {}", e))?;
65
66    if let Some(obj) = parsed.as_object() {
67        let version = obj.get("version").and_then(|v| v.as_u64()).unwrap_or(0);
68
69        if version != 1 {
70            return Err(format!(
71                "Unsupported work secret version: {}",
72                obj.get("version")
73                    .map(|v| v.to_string())
74                    .unwrap_or_else(|| "unknown".to_string())
75            ));
76        }
77
78        // Validate required fields
79        let session_ingress_token = obj
80            .get("session_ingress_token")
81            .and_then(|v| v.as_str())
82            .filter(|s| !s.is_empty())
83            .ok_or("Invalid work secret: missing or empty session_ingress_token")?;
84
85        let api_base_url = obj
86            .get("api_base_url")
87            .and_then(|v| v.as_str())
88            .ok_or("Invalid work secret: missing api_base_url")?;
89
90        let work_secret = WorkSecret {
91            version: version as u32,
92            session_ingress_token: session_ingress_token.to_string(),
93            api_base_url: api_base_url.to_string(),
94            sources: obj
95                .get("sources")
96                .and_then(|v| serde_json::from_value(v.clone()).ok())
97                .unwrap_or_default(),
98            auth: obj
99                .get("auth")
100                .and_then(|v| serde_json::from_value(v.clone()).ok())
101                .unwrap_or_default(),
102            claude_code_args: obj
103                .get("claude_code_args")
104                .and_then(|v| serde_json::from_value(v.clone()).ok()),
105            mcp_config: obj.get("mcp_config").cloned(),
106            environment_variables: obj
107                .get("environment_variables")
108                .and_then(|v| serde_json::from_value(v.clone()).ok()),
109            use_code_sessions: obj.get("use_code_sessions").and_then(|v| v.as_bool()),
110        };
111
112        Ok(work_secret)
113    } else {
114        Err("Invalid work secret: not an object".to_string())
115    }
116}
117
/// Build the WebSocket SDK URL for a session.
///
/// The `http://` / `https://` prefix and any trailing slashes are removed
/// from `api_base_url`, and the remainder is used as the host part of a
/// `ws://` or `wss://` session-ingress URL.
///
/// When the base URL mentions `localhost` or `127.0.0.1` the result uses
/// `ws` and `/v2/` (direct to session-ingress, no Envoy rewrite); otherwise
/// it uses `wss` and `/v1/` (Envoy rewrites /v1/ -> /v2/).
pub fn build_sdk_url(api_base_url: &str, session_id: &str) -> String {
    let local = ["localhost", "127.0.0.1"]
        .iter()
        .any(|needle| api_base_url.contains(*needle));
    let (scheme, api_version) = if local { ("ws", "v2") } else { ("wss", "v1") };

    // Drop the HTTP(S) scheme, then any trailing slashes.
    let without_https = api_base_url.trim_start_matches("https://");
    let without_scheme = without_https.trim_start_matches("http://");
    let authority = without_scheme.trim_end_matches('/');

    format!(
        "{}://{}/{}/session_ingress/ws/{}",
        scheme, authority, api_version, session_id
    )
}
137
/// Compare two session IDs regardless of their tagged-ID prefix.
///
/// Tagged IDs have the form {tag}_{body} or {tag}_staging_{body}, where the
/// body encodes a UUID. CCR v2's compat layer returns `session_*` to v1 API
/// clients but the infrastructure layer uses `cse_*`. Both have the same
/// underlying UUID.
///
/// Returns `true` when the two strings are identical, or when the segments
/// after their last underscore are equal and at least 4 bytes long.
pub fn same_session_id(a: &str, b: &str) -> bool {
    if a == b {
        return true;
    }

    // The body is everything after the last underscore — this handles both
    // `{tag}_{body}` and `{tag}_staging_{body}`. Searching from the end
    // avoids walking every segment just to reach the final one.
    let a_body = a.rsplit('_').next().unwrap_or(a);
    let b_body = b.rsplit('_').next().unwrap_or(b);

    // IDs with no underscore (bare UUIDs) yield the whole string as the body.
    // Require a minimum length to avoid accidental matches on short suffixes.
    a_body.len() >= 4 && a_body == b_body
}
158
/// Build the CCR v2 session URL for a session.
///
/// Trailing slashes are stripped from `api_base_url` and
/// `/v1/code/sessions/{session_id}` is appended. The scheme of the base URL
/// is kept as-is, so the result is an HTTP(S) URL rather than a ws:// one.
pub fn build_ccr_v2_sdk_url(api_base_url: &str, session_id: &str) -> String {
    let mut url = String::from(api_base_url.trim_end_matches('/'));
    url.push_str("/v1/code/sessions/");
    url.push_str(session_id);
    url
}
165
166/// Register this bridge as the worker for a CCR v2 session.
167/// Returns the worker_epoch, which must be passed to the child CC process.
168pub async fn register_worker(session_url: &str, access_token: &str) -> Result<u64, String> {
169    let client = reqwest::Client::new();
170
171    let response = client
172        .post(&format!("{}/worker/register", session_url))
173        .header("Authorization", format!("Bearer {}", access_token))
174        .header("Content-Type", "application/json")
175        .header("anthropic-version", "2023-06-01")
176        .timeout(std::time::Duration::from_secs(10))
177        .send()
178        .await
179        .map_err(|e| format!("Request failed: {}", e))?;
180
181    let data: serde_json::Value = response
182        .json()
183        .await
184        .map_err(|e| format!("Failed to parse response: {}", e))?;
185
186    let raw = data.get("worker_epoch");
187
188    let epoch = match raw {
189        Some(v) if v.is_number() => v.as_u64(),
190        Some(v) if v.is_string() => v.as_str().and_then(|s| s.parse().ok()),
191        _ => None,
192    };
193
194    epoch.ok_or_else(|| {
195        format!(
196            "register_worker: invalid worker_epoch in response: {}",
197            data
198        )
199    })
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Production hosts map to wss + /v1/, localhost maps to ws + /v2/.
    #[test]
    fn test_build_sdk_url() {
        let cases = [
            // Production
            (
                "https://api.anthropic.com",
                "wss://api.anthropic.com/v1/session_ingress/ws/session_abc",
            ),
            // Localhost
            (
                "http://localhost:8080",
                "ws://localhost:8080/v2/session_ingress/ws/session_abc",
            ),
        ];

        for &(base, expected) in &cases {
            assert_eq!(build_sdk_url(base, "session_abc"), expected);
        }
    }

    /// IDs match when identical or when they share the same body after the tag.
    #[test]
    fn test_same_session_id() {
        let cases = [
            // Same ID
            ("session_abc123", "session_abc123", true),
            // Same UUID with different tags
            ("cse_abc123", "session_abc123", true),
            // Different UUIDs
            ("session_abc123", "session_xyz789", false),
            // Staging format
            ("cse_staging_abc123", "session_staging_abc123", true),
        ];

        for &(left, right, expected) in &cases {
            assert_eq!(same_session_id(left, right), expected);
        }
    }

    /// The CCR v2 URL keeps the HTTP(S) scheme and appends the session path.
    #[test]
    fn test_build_ccr_v2_sdk_url() {
        let url = build_ccr_v2_sdk_url("https://api.anthropic.com", "session_abc");
        assert_eq!(url, "https://api.anthropic.com/v1/code/sessions/session_abc");
    }

    /// A minimal version-1 secret round-trips through base64url + JSON.
    #[test]
    fn test_decode_work_secret() {
        let payload = r#"{"version":1,"session_ingress_token":"tok123","api_base_url":"https://api.example.com","sources":[],"auth":[]}"#;
        let secret = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload);

        let decoded = decode_work_secret(&secret).unwrap();
        assert_eq!(decoded.version, 1);
        assert_eq!(decoded.session_ingress_token, "tok123");
        assert_eq!(decoded.api_base_url, "https://api.example.com");
    }
}