// cdp_sdk/auth.rs

1use crate::error::CdpError;
2use base64::Engine;
3use bon::bon;
4use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
5use reqwest::{Request, Response};
6use reqwest_middleware::{Middleware, Next};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use sha2::{Digest, Sha256};
10use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12use uuid::Uuid;
13
/// Version of this crate, read from Cargo's `CARGO_PKG_VERSION` at compile time.
/// Sent as `sdk_version` in the `Correlation-Context` request header.
const VERSION: &str = env!("CARGO_PKG_VERSION");
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17struct Claims {
18    sub: String,
19    iss: String,
20    aud: Vec<String>,
21    exp: u64,
22    iat: u64,
23    nbf: u64,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    uris: Option<Vec<String>>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29struct WalletClaims {
30    iat: u64,
31    nbf: u64,
32    jti: String,
33    uris: Vec<String>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    #[serde(rename = "reqHash")]
36    req_hash: Option<String>,
37}
38
/// Configuration options for the CDP Wallet Auth client
///
/// `Debug` is implemented by hand so that `api_key_secret` and `wallet_secret`
/// are redacted; the configuration can therefore be logged without leaking
/// credentials.
#[derive(Clone, Default)]
pub struct WalletAuth {
    /// The API key ID
    pub api_key_id: String,
    /// The API key secret
    pub api_key_secret: String,
    /// The wallet secret
    pub wallet_secret: Option<String>,
    /// Whether to enable debugging
    pub debug: bool,
    /// The source identifier for requests
    pub source: String,
    /// The version of the source making requests
    pub source_version: Option<String>,
    /// JWT expiration time in seconds
    pub expires_in: u64,
}

impl std::fmt::Debug for WalletAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("WalletAuth")
            .field("api_key_id", &self.api_key_id)
            .field("api_key_secret", &REDACTED)
            // Keep the Some/None distinction visible, but never the value.
            .field("wallet_secret", &self.wallet_secret.as_ref().map(|_| REDACTED))
            .field("debug", &self.debug)
            .field("source", &self.source)
            .field("source_version", &self.source_version)
            .field("expires_in", &self.expires_in)
            .finish()
    }
}
57
58#[bon]
59impl WalletAuth {
60    #[builder]
61    pub fn new(
62        /// The API key ID
63        api_key_id: Option<String>,
64        /// The API key secret
65        api_key_secret: Option<String>,
66        /// The wallet secret
67        wallet_secret: Option<String>,
68        /// Whether to enable debugging
69        debug: Option<bool>,
70        /// The source identifier for requests
71        source: Option<String>,
72        /// The version of the source making requests
73        source_version: Option<String>,
74        /// JWT expiration time in seconds
75        expires_in: Option<u64>,
76    ) -> Result<Self, CdpError> {
77        use std::env;
78
79        // Get configuration from options or environment variables
80        let api_key_id = api_key_id
81            .or_else(|| env::var("CDP_API_KEY_ID").ok())
82            .ok_or_else(|| {
83                CdpError::Config(
84                    "Missing required CDP API Key ID configuration.\n\n\
85                        You can set them as environment variables:\n\
86                        CDP_API_KEY_ID=your-api-key-id\n\
87                        CDP_API_KEY_SECRET=your-api-key-secret\n\n\
88                        Or pass them directly to the CdpClientOptions."
89                        .to_string(),
90                )
91            })?;
92
93        let api_key_secret = api_key_secret
94            .or_else(|| env::var("CDP_API_KEY_SECRET").ok())
95            .ok_or_else(|| {
96                CdpError::Config(
97                    "Missing required CDP API Key Secret configuration.\n\n\
98                        You can set them as environment variables:\n\
99                        CDP_API_KEY_ID=your-api-key-id\n\
100                        CDP_API_KEY_SECRET=your-api-key-secret\n\n\
101                        Or pass them directly to the CdpClientOptions."
102                        .to_string(),
103                )
104            })?;
105
106        let wallet_secret = wallet_secret.or_else(|| env::var("CDP_WALLET_SECRET").ok());
107
108        let debug = debug.unwrap_or(false);
109        let expires_in = expires_in.unwrap_or(120);
110        let source = source.unwrap_or("sdk-auth".to_string());
111
112        Ok(WalletAuth {
113            api_key_id,
114            api_key_secret,
115            wallet_secret,
116            debug,
117            source,
118            source_version,
119            expires_in,
120        })
121    }
122
123    fn generate_jwt(
124        &self,
125        method: &str,
126        host: &str,
127        path: &str,
128        expires_in: u64,
129    ) -> Result<String, CdpError> {
130        let now = SystemTime::now()
131            .duration_since(UNIX_EPOCH)
132            .unwrap()
133            .as_secs();
134
135        let claims = Claims {
136            sub: self.api_key_id.clone(),
137            iss: "cdp".to_string(),
138            aud: vec!["cdp_service".to_string()],
139            exp: now + expires_in,
140            iat: now,
141            nbf: now,
142            uris: Some(vec![format!("{} {}{}", method, host, path)]),
143        };
144
145        // Determine key format and algorithm
146        let (algorithm, encoding_key) = if is_ec_pem_key(&self.api_key_secret) {
147            // PEM format EC key - use ES256
148            let key = EncodingKey::from_ec_pem(self.api_key_secret.as_bytes())
149                .map_err(|e| CdpError::Auth(format!("Failed to parse EC PEM key: {}", e)))?;
150            (Algorithm::ES256, key)
151        } else if is_ed25519_key(&self.api_key_secret) {
152            // Base64 Ed25519 key - use EdDSA
153            let decoded = base64::engine::general_purpose::STANDARD
154                .decode(&self.api_key_secret)
155                .map_err(|e| CdpError::Auth(format!("Failed to decode Ed25519 key: {}", e)))?;
156
157            if decoded.len() != 64 {
158                return Err(CdpError::Auth(
159                    "Invalid Ed25519 key length, expected 64 bytes".to_string(),
160                ));
161            }
162
163            // For Ed25519 keys, we need to create a proper PKCS#8 DER format
164            // Extract the seed (first 32 bytes)
165            let seed = &decoded[0..32];
166
167            // Create PKCS#8 DER format for Ed25519 private key
168            let mut pkcs8_der = Vec::new();
169            // PKCS#8 header for Ed25519
170            let header = hex::decode("302e020100300506032b657004220420").unwrap();
171            pkcs8_der.extend_from_slice(&header);
172            pkcs8_der.extend_from_slice(seed);
173
174            // Convert to PEM format
175            let pem_content = base64::engine::general_purpose::STANDARD.encode(&pkcs8_der);
176            let pem_formatted = format!(
177                "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----",
178                pem_content
179                    .chars()
180                    .collect::<Vec<_>>()
181                    .chunks(64)
182                    .map(|chunk| chunk.iter().collect::<String>())
183                    .collect::<Vec<_>>()
184                    .join("\n")
185            );
186
187            let key = EncodingKey::from_ed_pem(pem_formatted.as_bytes())
188                .map_err(|e| CdpError::Auth(format!("Failed to parse Ed25519 key: {}", e)))?;
189            (Algorithm::EdDSA, key)
190        } else {
191            return Err(CdpError::Auth(
192                "Invalid key format - must be either PEM EC key or base64 Ed25519 key".to_string(),
193            ));
194        };
195
196        let mut header = Header::new(algorithm);
197        header.kid = Some(self.api_key_id.clone());
198
199        encode(&header, &claims, &encoding_key)
200            .map_err(|e| CdpError::Auth(format!("Failed to encode JWT: {}", e)))
201    }
202
203    pub fn generate_wallet_jwt(
204        &self,
205        method: &str,
206        host: &str,
207        path: &str,
208        body: &[u8],
209    ) -> Result<String, CdpError> {
210        let wallet_secret = self.wallet_secret.as_ref().ok_or_else(|| {
211            CdpError::Auth("Wallet secret required for this operation".to_string())
212        })?;
213
214        let now = SystemTime::now()
215            .duration_since(UNIX_EPOCH)
216            .unwrap()
217            .as_secs();
218
219        let uri = format!("{} {}{}", method, host, path);
220        let jti = format!("{:x}", Uuid::new_v4().simple()); // Use hex format like JavaScript
221
222        // Calculate reqHash only if body is not empty, using hex format like JavaScript
223        let req_hash = if !body.is_empty() {
224            // Parse body as JSON and sort keys
225            let body_str = std::str::from_utf8(body)
226                .map_err(|e| CdpError::Auth(format!("Invalid UTF-8 in request body: {}", e)))?;
227
228            if !body_str.trim().is_empty() {
229                let parsed: Value = serde_json::from_str(body_str)
230                    .map_err(|e| CdpError::Auth(format!("Failed to parse JSON body: {}", e)))?;
231
232                let sorted = sort_keys(parsed);
233                let sorted_json = serde_json::to_string(&sorted).map_err(|e| {
234                    CdpError::Auth(format!("Failed to serialize sorted JSON: {}", e))
235                })?;
236
237                let mut hasher = Sha256::new();
238                hasher.update(sorted_json.as_bytes());
239                Some(format!("{:x}", hasher.finalize()))
240            } else {
241                None
242            }
243        } else {
244            None
245        };
246
247        let claims = WalletClaims {
248            iat: now,
249            nbf: now, // Add nbf like JavaScript
250            jti,
251            uris: vec![uri],
252            req_hash,
253        };
254
255        let header = Header::new(Algorithm::ES256);
256
257        // Decode the base64 wallet secret
258        let der_bytes = base64::engine::general_purpose::STANDARD
259            .decode(wallet_secret)
260            .map_err(|e| CdpError::Auth(format!("Failed to decode wallet secret: {}", e)))?;
261
262        let encoding_key = EncodingKey::from_ec_der(&der_bytes);
263
264        encode(&header, &claims, &encoding_key)
265            .map_err(|e| CdpError::Auth(format!("Failed to encode wallet JWT: {}", e)))
266    }
267
268    fn requires_wallet_auth(&self, method: &str, path: &str) -> bool {
269        (path.contains("/accounts") || path.contains("/spend-permissions"))
270            && (method == "POST" || method == "DELETE" || method == "PUT")
271    }
272
273    fn get_correlation_data(&self) -> String {
274        let mut data = HashMap::new();
275
276        data.insert("sdk_version".to_string(), VERSION.to_string());
277        data.insert("sdk_language".to_string(), "rust".to_string());
278        data.insert("source".to_string(), self.source.clone());
279
280        if let Some(ref source_version) = self.source_version {
281            data.insert("source_version".to_string(), source_version.clone());
282        }
283
284        data.into_iter()
285            .map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v)))
286            .collect::<Vec<_>>()
287            .join(",")
288    }
289}
290
291#[async_trait::async_trait]
292impl Middleware for WalletAuth {
293    async fn handle(
294        &self,
295        mut req: Request,
296        extensions: &mut http::Extensions,
297        next: Next<'_>,
298    ) -> reqwest_middleware::Result<Response> {
299        let method = req.method().as_str().to_uppercase();
300        let url = req.url().clone();
301        let host = url.host_str().unwrap_or("api.cdp.coinbase.com");
302        let path = url.path();
303
304        // Get request body for wallet auth
305        let body = if let Some(body) = req.body() {
306            body.as_bytes().unwrap_or_default().to_vec()
307        } else {
308            Vec::new()
309        };
310
311        let expires_in = self.expires_in;
312
313        // Generate main JWT
314        let jwt = self
315            .generate_jwt(&method, host, path, expires_in)
316            .map_err(reqwest_middleware::Error::middleware)?;
317
318        // Add authorization header
319        req.headers_mut()
320            .insert("Authorization", format!("Bearer {}", jwt).parse().unwrap());
321
322        // Add content type
323        req.headers_mut()
324            .insert("Content-Type", "application/json".parse().unwrap());
325
326        // Add wallet auth if needed, and not already provided or if empty
327        if self.requires_wallet_auth(&method, path)
328            && (!req.headers().contains_key("X-Wallet-Auth")
329                || req
330                    .headers()
331                    .get("X-Wallet-Auth")
332                    .is_none_or(|v| v.is_empty()))
333        {
334            let wallet_jwt = self
335                .generate_wallet_jwt(&method, host, path, &body)
336                .map_err(reqwest_middleware::Error::middleware)?;
337
338            req.headers_mut()
339                .insert("X-Wallet-Auth", wallet_jwt.parse().unwrap());
340        }
341
342        // Add correlation data
343        req.headers_mut().insert(
344            "Correlation-Context",
345            self.get_correlation_data().parse().unwrap(),
346        );
347
348        if self.debug {
349            println!("Request: {} {}", method, url);
350            println!("Headers: {:?}", req.headers());
351        }
352
353        let response = next.run(req, extensions).await;
354
355        if self.debug {
356            if let Ok(ref resp) = response {
357                println!(
358                    "Response: {} {}",
359                    resp.status(),
360                    resp.status().canonical_reason().unwrap_or("")
361                );
362            }
363        }
364
365        response
366    }
367}
368
369fn sort_keys(value: Value) -> Value {
370    match value {
371        Value::Object(map) => {
372            let mut sorted_map = serde_json::Map::new();
373            let mut keys: Vec<_> = map.keys().collect();
374            keys.sort();
375            for key in keys {
376                if let Some(val) = map.get(key) {
377                    sorted_map.insert(key.clone(), sort_keys(val.clone()));
378                }
379            }
380            Value::Object(sorted_map)
381        }
382        Value::Array(arr) => Value::Array(arr.into_iter().map(sort_keys).collect()),
383        _ => value,
384    }
385}
386
387fn is_ed25519_key(key: &str) -> bool {
388    // Try to decode as base64 and check if it's 64 bytes (Ed25519 format)
389    if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(key) {
390        decoded.len() == 64
391    } else {
392        false
393    }
394}
395
/// Heuristically reports whether `key` looks like a PEM-encoded private key.
///
/// Requires PEM `-----BEGIN`/`-----END` armor and a `PRIVATE KEY` label, which
/// covers both `EC PRIVATE KEY` and plain `PRIVATE KEY` documents. Nothing is
/// parsed here; invalid contents surface later when the key is loaded.
fn is_ec_pem_key(key: &str) -> bool {
    let has_pem_armor = key.contains("-----BEGIN") && key.contains("-----END");
    // "EC PRIVATE KEY" itself contains "PRIVATE KEY", so one check covers both.
    has_pem_armor && key.contains("PRIVATE KEY")
}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `name` is present in the process environment.
    ///
    /// `WalletAuth::new` falls back to `CDP_*` environment variables, so
    /// assertions about missing or default values only hold when the matching
    /// variable is absent; tests use this to skip those assertions otherwise.
    fn env_is_set(name: &str) -> bool {
        std::env::var_os(name).is_some()
    }

    /// A config built directly (bypassing the environment) for helper tests.
    fn direct_auth(api_key_secret: &str) -> WalletAuth {
        WalletAuth {
            api_key_id: "test_key_id".to_string(),
            api_key_secret: api_key_secret.to_string(),
            source: "sdk-auth".to_string(),
            expires_in: 120,
            ..Default::default()
        }
    }

    #[test]
    fn test_wallet_auth_builder_with_all_fields() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .wallet_secret("test_wallet_secret".to_string())
            .debug(true)
            .source("test_source".to_string())
            .source_version("1.0.0".to_string())
            .expires_in(300)
            .build()
            .unwrap();

        assert_eq!(auth.api_key_id, "test_key_id");
        assert_eq!(auth.api_key_secret, "test_key_secret");
        assert_eq!(auth.wallet_secret, Some("test_wallet_secret".to_string()));
        assert!(auth.debug);
        assert_eq!(auth.source, "test_source");
        assert_eq!(auth.source_version, Some("1.0.0".to_string()));
        assert_eq!(auth.expires_in, 300);
    }

    #[test]
    fn test_wallet_auth_builder_with_required_fields_only() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .build()
            .unwrap();

        assert_eq!(auth.api_key_id, "test_key_id");
        assert_eq!(auth.api_key_secret, "test_key_secret");
        // An ambient CDP_WALLET_SECRET would legitimately be picked up.
        if !env_is_set("CDP_WALLET_SECRET") {
            assert_eq!(auth.wallet_secret, None);
        }
        assert!(!auth.debug);
        assert_eq!(auth.source, "sdk-auth");
        assert_eq!(auth.source_version, None);
        assert_eq!(auth.expires_in, 120);
    }

    #[test]
    fn test_wallet_auth_builder_with_optional_fields() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .debug(true)
            .expires_in(600)
            .build()
            .unwrap();

        assert_eq!(auth.api_key_id, "test_key_id");
        assert_eq!(auth.api_key_secret, "test_key_secret");
        assert!(auth.debug);
        assert_eq!(auth.expires_in, 600);
        assert_eq!(auth.source, "sdk-auth"); // default value
    }

    #[test]
    fn test_wallet_auth_builder_missing_api_key_id() {
        if env_is_set("CDP_API_KEY_ID") {
            // The builder would read the ID from the environment; nothing to test.
            return;
        }

        let result = WalletAuth::builder()
            .api_key_secret("test_key_secret".to_string())
            .build();

        assert!(result.is_err());
        if let Err(CdpError::Config(msg)) = result {
            assert!(msg.contains("Missing required CDP API Key ID configuration"));
        } else {
            panic!("Expected Config error for missing api_key_id");
        }
    }

    #[test]
    fn test_wallet_auth_builder_missing_api_key_secret() {
        if env_is_set("CDP_API_KEY_SECRET") {
            // The builder would read the secret from the environment; nothing to test.
            return;
        }

        let result = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .build();

        assert!(result.is_err());
        if let Err(CdpError::Config(msg)) = result {
            assert!(msg.contains("Missing required CDP API Key Secret configuration"));
        } else {
            panic!("Expected Config error for missing api_key_secret");
        }
    }

    #[test]
    fn test_wallet_auth_builder_custom_source() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .source("my-custom-app".to_string())
            .source_version("2.1.0".to_string())
            .build()
            .unwrap();

        assert_eq!(auth.source, "my-custom-app");
        assert_eq!(auth.source_version, Some("2.1.0".to_string()));
    }

    #[test]
    fn test_requires_wallet_auth() {
        let auth = direct_auth("test_key_secret");

        // Should require wallet auth for POST to accounts
        assert!(auth.requires_wallet_auth("POST", "/v2/evm/accounts"));

        // Should require wallet auth for PUT to accounts
        assert!(auth.requires_wallet_auth("PUT", "/v2/evm/accounts/0x123"));

        // Should require wallet auth for DELETE to accounts
        assert!(auth.requires_wallet_auth("DELETE", "/v2/evm/accounts/0x123"));

        // Should require wallet auth for spend-permissions
        assert!(auth.requires_wallet_auth("POST", "/v2/spend-permissions"));

        // Should NOT require wallet auth for GET requests
        assert!(!auth.requires_wallet_auth("GET", "/v2/evm/accounts"));

        // Should NOT require wallet auth for non-account endpoints
        assert!(!auth.requires_wallet_auth("POST", "/v2/other/endpoint"));
    }

    #[test]
    fn test_is_ed25519_key() {
        // Valid base64 encoded 64-byte key
        let valid_ed25519 = base64::engine::general_purpose::STANDARD.encode([0u8; 64]);
        assert!(is_ed25519_key(&valid_ed25519));

        // Invalid key (wrong length)
        let invalid_key = base64::engine::general_purpose::STANDARD.encode([0u8; 32]);
        assert!(!is_ed25519_key(&invalid_key));

        // Not base64
        assert!(!is_ed25519_key("not-base64"));
    }

    #[test]
    fn test_is_ec_pem_key() {
        let pem_key = "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----";
        assert!(is_ec_pem_key(pem_key));

        let generic_pem_key = "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----";
        assert!(is_ec_pem_key(generic_pem_key));

        let not_pem_key = "just-a-string";
        assert!(!is_ec_pem_key(not_pem_key));
    }

    #[test]
    fn test_sort_keys_orders_nested_objects() {
        let value = serde_json::json!({"b": 1, "a": {"d": [{"f": 2, "e": 3}], "c": null}});
        let sorted = serde_json::to_string(&sort_keys(value)).unwrap();
        assert_eq!(sorted, r#"{"a":{"c":null,"d":[{"e":3,"f":2}]},"b":1}"#);
    }

    #[test]
    fn test_generate_jwt_rejects_unknown_key_format() {
        let auth = direct_auth("not-a-key");
        let result = auth.generate_jwt("GET", "api.cdp.coinbase.com", "/v2/evm/accounts", 120);
        assert!(matches!(result, Err(CdpError::Auth(_))));
    }

    #[test]
    fn test_generate_wallet_jwt_requires_wallet_secret() {
        let auth = direct_auth("test_key_secret");
        let result =
            auth.generate_wallet_jwt("POST", "api.cdp.coinbase.com", "/v2/evm/accounts", b"{}");
        assert!(matches!(result, Err(CdpError::Auth(_))));
    }

    #[test]
    fn test_correlation_data_contains_encoded_fields() {
        let auth = WalletAuth {
            source: "my app".to_string(),
            source_version: Some("1.0".to_string()),
            ..direct_auth("test_key_secret")
        };
        let data = auth.get_correlation_data();
        let parts: Vec<&str> = data.split(',').collect();

        assert_eq!(parts.len(), 4);
        assert!(parts.contains(&"sdk_language=rust"));
        assert!(parts.contains(&"source=my%20app"));
        assert!(parts.contains(&"source_version=1.0"));
    }
}