1use crate::error::CdpError;
2use base64::Engine;
3use bon::bon;
4use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
5use reqwest::{Request, Response};
6use reqwest_middleware::{Middleware, Next};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use sha2::{Digest, Sha256};
10use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12use uuid::Uuid;
13
/// Crate version captured from Cargo at compile time; reported as `sdk_version` in the
/// `Correlation-Context` header.
const VERSION: &str = env!("CARGO_PKG_VERSION");
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17struct Claims {
18 sub: String,
19 iss: String,
20 aud: Vec<String>,
21 exp: u64,
22 iat: u64,
23 nbf: u64,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 uris: Option<Vec<String>>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29struct WalletClaims {
30 iat: u64,
31 nbf: u64,
32 jti: String,
33 uris: Vec<String>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 #[serde(rename = "reqHash")]
36 req_hash: Option<String>,
37}
38
/// Request-signing configuration for CDP API calls, used as a `reqwest_middleware`
/// middleware.
///
/// `Debug` is implemented by hand so that the API key secret and the wallet secret are
/// never printed; all other fields are shown as-is.
#[derive(Clone, Default)]
pub struct WalletAuth {
    /// CDP API key ID; used as the JWT `sub` and `kid`.
    pub api_key_id: String,
    /// CDP API key secret: either a PEM EC private key or a base64 64-byte Ed25519 key.
    pub api_key_secret: String,
    /// Base64 DER wallet secret used to sign `X-Wallet-Auth` JWTs, if configured.
    pub wallet_secret: Option<String>,
    /// When `true`, requests and response statuses are printed to stdout.
    pub debug: bool,
    /// Reported as `source` in the `Correlation-Context` header.
    pub source: String,
    /// Optional version reported as `source_version` in `Correlation-Context`.
    pub source_version: Option<String>,
    /// Lifetime of the API-key JWT, in seconds.
    pub expires_in: u64,
}

impl std::fmt::Debug for WalletAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced by a placeholder so that logging or `dbg!`-ing the
        // middleware cannot leak credentials.
        f.debug_struct("WalletAuth")
            .field("api_key_id", &self.api_key_id)
            .field("api_key_secret", &"[REDACTED]")
            .field(
                "wallet_secret",
                &self.wallet_secret.as_ref().map(|_| "[REDACTED]"),
            )
            .field("debug", &self.debug)
            .field("source", &self.source)
            .field("source_version", &self.source_version)
            .field("expires_in", &self.expires_in)
            .finish()
    }
}
57
58#[bon]
59impl WalletAuth {
60 #[builder]
61 pub fn new(
62 api_key_id: Option<String>,
64 api_key_secret: Option<String>,
66 wallet_secret: Option<String>,
68 debug: Option<bool>,
70 source: Option<String>,
72 source_version: Option<String>,
74 expires_in: Option<u64>,
76 ) -> Result<Self, CdpError> {
77 use std::env;
78
79 let api_key_id = api_key_id
81 .or_else(|| env::var("CDP_API_KEY_ID").ok())
82 .ok_or_else(|| {
83 CdpError::Config(
84 "Missing required CDP API Key ID configuration.\n\n\
85 You can set them as environment variables:\n\
86 CDP_API_KEY_ID=your-api-key-id\n\
87 CDP_API_KEY_SECRET=your-api-key-secret\n\n\
88 Or pass them directly to the CdpClientOptions."
89 .to_string(),
90 )
91 })?;
92
93 let api_key_secret = api_key_secret
94 .or_else(|| env::var("CDP_API_KEY_SECRET").ok())
95 .ok_or_else(|| {
96 CdpError::Config(
97 "Missing required CDP API Key Secret configuration.\n\n\
98 You can set them as environment variables:\n\
99 CDP_API_KEY_ID=your-api-key-id\n\
100 CDP_API_KEY_SECRET=your-api-key-secret\n\n\
101 Or pass them directly to the CdpClientOptions."
102 .to_string(),
103 )
104 })?;
105
106 let wallet_secret = wallet_secret.or_else(|| env::var("CDP_WALLET_SECRET").ok());
107
108 let debug = debug.unwrap_or(false);
109 let expires_in = expires_in.unwrap_or(120);
110 let source = source.unwrap_or("sdk-auth".to_string());
111
112 Ok(WalletAuth {
113 api_key_id,
114 api_key_secret,
115 wallet_secret,
116 debug,
117 source,
118 source_version,
119 expires_in,
120 })
121 }
122
123 fn generate_jwt(
124 &self,
125 method: &str,
126 host: &str,
127 path: &str,
128 expires_in: u64,
129 ) -> Result<String, CdpError> {
130 let now = SystemTime::now()
131 .duration_since(UNIX_EPOCH)
132 .unwrap()
133 .as_secs();
134
135 let claims = Claims {
136 sub: self.api_key_id.clone(),
137 iss: "cdp".to_string(),
138 aud: vec!["cdp_service".to_string()],
139 exp: now + expires_in,
140 iat: now,
141 nbf: now,
142 uris: Some(vec![format!("{} {}{}", method, host, path)]),
143 };
144
145 let (algorithm, encoding_key) = if is_ec_pem_key(&self.api_key_secret) {
147 let key = EncodingKey::from_ec_pem(self.api_key_secret.as_bytes())
149 .map_err(|e| CdpError::Auth(format!("Failed to parse EC PEM key: {}", e)))?;
150 (Algorithm::ES256, key)
151 } else if is_ed25519_key(&self.api_key_secret) {
152 let decoded = base64::engine::general_purpose::STANDARD
154 .decode(&self.api_key_secret)
155 .map_err(|e| CdpError::Auth(format!("Failed to decode Ed25519 key: {}", e)))?;
156
157 if decoded.len() != 64 {
158 return Err(CdpError::Auth(
159 "Invalid Ed25519 key length, expected 64 bytes".to_string(),
160 ));
161 }
162
163 let seed = &decoded[0..32];
166
167 let mut pkcs8_der = Vec::new();
169 let header = hex::decode("302e020100300506032b657004220420").unwrap();
171 pkcs8_der.extend_from_slice(&header);
172 pkcs8_der.extend_from_slice(seed);
173
174 let pem_content = base64::engine::general_purpose::STANDARD.encode(&pkcs8_der);
176 let pem_formatted = format!(
177 "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----",
178 pem_content
179 .chars()
180 .collect::<Vec<_>>()
181 .chunks(64)
182 .map(|chunk| chunk.iter().collect::<String>())
183 .collect::<Vec<_>>()
184 .join("\n")
185 );
186
187 let key = EncodingKey::from_ed_pem(pem_formatted.as_bytes())
188 .map_err(|e| CdpError::Auth(format!("Failed to parse Ed25519 key: {}", e)))?;
189 (Algorithm::EdDSA, key)
190 } else {
191 return Err(CdpError::Auth(
192 "Invalid key format - must be either PEM EC key or base64 Ed25519 key".to_string(),
193 ));
194 };
195
196 let mut header = Header::new(algorithm);
197 header.kid = Some(self.api_key_id.clone());
198
199 encode(&header, &claims, &encoding_key)
200 .map_err(|e| CdpError::Auth(format!("Failed to encode JWT: {}", e)))
201 }
202
203 pub fn generate_wallet_jwt(
204 &self,
205 method: &str,
206 host: &str,
207 path: &str,
208 body: &[u8],
209 ) -> Result<String, CdpError> {
210 let wallet_secret = self.wallet_secret.as_ref().ok_or_else(|| {
211 CdpError::Auth("Wallet secret required for this operation".to_string())
212 })?;
213
214 let now = SystemTime::now()
215 .duration_since(UNIX_EPOCH)
216 .unwrap()
217 .as_secs();
218
219 let uri = format!("{} {}{}", method, host, path);
220 let jti = format!("{:x}", Uuid::new_v4().simple()); let req_hash = if !body.is_empty() {
224 let body_str = std::str::from_utf8(body)
226 .map_err(|e| CdpError::Auth(format!("Invalid UTF-8 in request body: {}", e)))?;
227
228 if !body_str.trim().is_empty() {
229 let parsed: Value = serde_json::from_str(body_str)
230 .map_err(|e| CdpError::Auth(format!("Failed to parse JSON body: {}", e)))?;
231
232 let sorted = sort_keys(parsed);
233 let sorted_json = serde_json::to_string(&sorted).map_err(|e| {
234 CdpError::Auth(format!("Failed to serialize sorted JSON: {}", e))
235 })?;
236
237 let mut hasher = Sha256::new();
238 hasher.update(sorted_json.as_bytes());
239 Some(format!("{:x}", hasher.finalize()))
240 } else {
241 None
242 }
243 } else {
244 None
245 };
246
247 let claims = WalletClaims {
248 iat: now,
249 nbf: now, jti,
251 uris: vec![uri],
252 req_hash,
253 };
254
255 let header = Header::new(Algorithm::ES256);
256
257 let der_bytes = base64::engine::general_purpose::STANDARD
259 .decode(wallet_secret)
260 .map_err(|e| CdpError::Auth(format!("Failed to decode wallet secret: {}", e)))?;
261
262 let encoding_key = EncodingKey::from_ec_der(&der_bytes);
263
264 encode(&header, &claims, &encoding_key)
265 .map_err(|e| CdpError::Auth(format!("Failed to encode wallet JWT: {}", e)))
266 }
267
268 fn requires_wallet_auth(&self, method: &str, path: &str) -> bool {
269 (path.contains("/accounts") || path.contains("/spend-permissions"))
270 && (method == "POST" || method == "DELETE" || method == "PUT")
271 }
272
273 fn get_correlation_data(&self) -> String {
274 let mut data = HashMap::new();
275
276 data.insert("sdk_version".to_string(), VERSION.to_string());
277 data.insert("sdk_language".to_string(), "rust".to_string());
278 data.insert("source".to_string(), self.source.clone());
279
280 if let Some(ref source_version) = self.source_version {
281 data.insert("source_version".to_string(), source_version.clone());
282 }
283
284 data.into_iter()
285 .map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v)))
286 .collect::<Vec<_>>()
287 .join(",")
288 }
289}
290
291#[async_trait::async_trait]
292impl Middleware for WalletAuth {
293 async fn handle(
294 &self,
295 mut req: Request,
296 extensions: &mut http::Extensions,
297 next: Next<'_>,
298 ) -> reqwest_middleware::Result<Response> {
299 let method = req.method().as_str().to_uppercase();
300 let url = req.url().clone();
301 let host = url.host_str().unwrap_or("api.cdp.coinbase.com");
302 let path = url.path();
303
304 let body = if let Some(body) = req.body() {
306 body.as_bytes().unwrap_or_default().to_vec()
307 } else {
308 Vec::new()
309 };
310
311 let expires_in = self.expires_in;
312
313 let jwt = self
315 .generate_jwt(&method, host, path, expires_in)
316 .map_err(reqwest_middleware::Error::middleware)?;
317
318 req.headers_mut()
320 .insert("Authorization", format!("Bearer {}", jwt).parse().unwrap());
321
322 req.headers_mut()
324 .insert("Content-Type", "application/json".parse().unwrap());
325
326 if self.requires_wallet_auth(&method, path)
328 && (!req.headers().contains_key("X-Wallet-Auth")
329 || req
330 .headers()
331 .get("X-Wallet-Auth")
332 .is_none_or(|v| v.is_empty()))
333 {
334 let wallet_jwt = self
335 .generate_wallet_jwt(&method, host, path, &body)
336 .map_err(reqwest_middleware::Error::middleware)?;
337
338 req.headers_mut()
339 .insert("X-Wallet-Auth", wallet_jwt.parse().unwrap());
340 }
341
342 req.headers_mut().insert(
344 "Correlation-Context",
345 self.get_correlation_data().parse().unwrap(),
346 );
347
348 if self.debug {
349 println!("Request: {} {}", method, url);
350 println!("Headers: {:?}", req.headers());
351 }
352
353 let response = next.run(req, extensions).await;
354
355 if self.debug {
356 if let Ok(ref resp) = response {
357 println!(
358 "Response: {} {}",
359 resp.status(),
360 resp.status().canonical_reason().unwrap_or("")
361 );
362 }
363 }
364
365 response
366 }
367}
368
369fn sort_keys(value: Value) -> Value {
370 match value {
371 Value::Object(map) => {
372 let mut sorted_map = serde_json::Map::new();
373 let mut keys: Vec<_> = map.keys().collect();
374 keys.sort();
375 for key in keys {
376 if let Some(val) = map.get(key) {
377 sorted_map.insert(key.clone(), sort_keys(val.clone()));
378 }
379 }
380 Value::Object(sorted_map)
381 }
382 Value::Array(arr) => Value::Array(arr.into_iter().map(sort_keys).collect()),
383 _ => value,
384 }
385}
386
387fn is_ed25519_key(key: &str) -> bool {
388 if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(key) {
390 decoded.len() == 64
391 } else {
392 false
393 }
394}
395
/// Heuristic for a PEM-encoded private key.
///
/// Returns `true` when `key` contains a `-----BEGIN` marker, an `-----END` marker and the
/// text `PRIVATE KEY`. This covers both SEC1 (`EC PRIVATE KEY`) and PKCS#8
/// (`PRIVATE KEY`) labels. Marker order is not checked.
fn is_ec_pem_key(key: &str) -> bool {
    const REQUIRED_MARKERS: [&str; 3] = ["-----BEGIN", "-----END", "PRIVATE KEY"];
    REQUIRED_MARKERS.iter().all(|marker| key.contains(marker))
}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether `name` is set (with a Unicode value) in the process environment.
    ///
    /// `WalletAuth::new` falls back to `CDP_*` variables the same way. Assertions about
    /// "missing" configuration therefore only hold when the variable is absent, and the
    /// tests below skip them otherwise instead of failing spuriously.
    fn env_is_set(name: &str) -> bool {
        std::env::var(name).is_ok()
    }

    #[test]
    fn test_wallet_auth_builder_with_all_fields() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .wallet_secret("test_wallet_secret".to_string())
            .debug(true)
            .source("test_source".to_string())
            .source_version("1.0.0".to_string())
            .expires_in(300)
            .build()
            .unwrap();

        assert_eq!(auth.api_key_id, "test_key_id");
        assert_eq!(auth.api_key_secret, "test_key_secret");
        assert_eq!(auth.wallet_secret, Some("test_wallet_secret".to_string()));
        assert!(auth.debug);
        assert_eq!(auth.source, "test_source");
        assert_eq!(auth.source_version, Some("1.0.0".to_string()));
        assert_eq!(auth.expires_in, 300);
    }

    #[test]
    fn test_wallet_auth_builder_with_required_fields_only() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .build()
            .unwrap();

        assert_eq!(auth.api_key_id, "test_key_id");
        assert_eq!(auth.api_key_secret, "test_key_secret");
        // The wallet secret falls back to the environment when not passed explicitly.
        if !env_is_set("CDP_WALLET_SECRET") {
            assert_eq!(auth.wallet_secret, None);
        }
        assert!(!auth.debug);
        assert_eq!(auth.source, "sdk-auth");
        assert_eq!(auth.source_version, None);
        assert_eq!(auth.expires_in, 120);
    }

    #[test]
    fn test_wallet_auth_builder_with_optional_fields() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .debug(true)
            .expires_in(600)
            .build()
            .unwrap();

        assert_eq!(auth.api_key_id, "test_key_id");
        assert_eq!(auth.api_key_secret, "test_key_secret");
        assert!(auth.debug);
        assert_eq!(auth.expires_in, 600);
        assert_eq!(auth.source, "sdk-auth");
    }

    #[test]
    fn test_wallet_auth_builder_missing_api_key_id() {
        if env_is_set("CDP_API_KEY_ID") {
            // The builder would pick the ID up from the environment; nothing to check.
            return;
        }

        let result = WalletAuth::builder()
            .api_key_secret("test_key_secret".to_string())
            .build();

        match result {
            Err(CdpError::Config(msg)) => {
                assert!(msg.contains("Missing required CDP API Key ID configuration"));
            }
            _ => panic!("Expected Config error for missing api_key_id"),
        }
    }

    #[test]
    fn test_wallet_auth_builder_missing_api_key_secret() {
        if env_is_set("CDP_API_KEY_SECRET") {
            // The builder would pick the secret up from the environment; nothing to check.
            return;
        }

        let result = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .build();

        match result {
            Err(CdpError::Config(msg)) => {
                assert!(msg.contains("Missing required CDP API Key Secret configuration"));
            }
            _ => panic!("Expected Config error for missing api_key_secret"),
        }
    }

    #[test]
    fn test_wallet_auth_builder_custom_source() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .source("my-custom-app".to_string())
            .source_version("2.1.0".to_string())
            .build()
            .unwrap();

        assert_eq!(auth.source, "my-custom-app");
        assert_eq!(auth.source_version, Some("2.1.0".to_string()));
    }

    #[test]
    fn test_requires_wallet_auth() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .build()
            .unwrap();

        assert!(auth.requires_wallet_auth("POST", "/v2/evm/accounts"));
        assert!(auth.requires_wallet_auth("PUT", "/v2/evm/accounts/0x123"));
        assert!(auth.requires_wallet_auth("DELETE", "/v2/evm/accounts/0x123"));
        assert!(auth.requires_wallet_auth("POST", "/v2/spend-permissions"));

        assert!(!auth.requires_wallet_auth("GET", "/v2/evm/accounts"));
        assert!(!auth.requires_wallet_auth("GET", "/v2/spend-permissions"));
        assert!(!auth.requires_wallet_auth("POST", "/v2/other/endpoint"));
    }

    #[test]
    fn test_is_ed25519_key() {
        let valid_ed25519 = base64::engine::general_purpose::STANDARD.encode([0u8; 64]);
        assert!(is_ed25519_key(&valid_ed25519));

        let invalid_key = base64::engine::general_purpose::STANDARD.encode([0u8; 32]);
        assert!(!is_ed25519_key(&invalid_key));

        assert!(!is_ed25519_key("not-base64"));
    }

    #[test]
    fn test_is_ec_pem_key() {
        let pem_key = "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----";
        assert!(is_ec_pem_key(pem_key));

        let generic_pem_key = "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----";
        assert!(is_ec_pem_key(generic_pem_key));

        let not_pem_key = "just-a-string";
        assert!(!is_ec_pem_key(not_pem_key));
    }

    #[test]
    fn test_sort_keys_orders_nested_objects() {
        let input = serde_json::json!({
            "b": [{"f": 1, "e": 2}],
            "a": {"d": 2, "c": 3}
        });
        let sorted = serde_json::to_string(&sort_keys(input)).unwrap();
        assert_eq!(sorted, r#"{"a":{"c":3,"d":2},"b":[{"e":2,"f":1}]}"#);
    }

    #[test]
    fn test_generate_wallet_jwt_requires_wallet_secret() {
        // Built directly so no CDP_WALLET_SECRET from the environment can leak in.
        let auth = WalletAuth {
            api_key_id: "test_key_id".to_string(),
            api_key_secret: "test_key_secret".to_string(),
            ..Default::default()
        };

        let result =
            auth.generate_wallet_jwt("POST", "api.cdp.coinbase.com", "/v2/evm/accounts", b"{}");
        assert!(matches!(result, Err(CdpError::Auth(_))));
    }

    #[test]
    fn test_generate_jwt_rejects_unknown_key_format() {
        let auth = WalletAuth {
            api_key_id: "test_key_id".to_string(),
            api_key_secret: "test_key_secret".to_string(),
            expires_in: 120,
            ..Default::default()
        };

        let result = auth.generate_jwt("GET", "api.cdp.coinbase.com", "/v2/evm/accounts", 120);
        assert!(matches!(result, Err(CdpError::Auth(_))));
    }

    #[test]
    fn test_correlation_data_contains_sdk_fields() {
        let auth = WalletAuth::builder()
            .api_key_id("test_key_id".to_string())
            .api_key_secret("test_key_secret".to_string())
            .build()
            .unwrap();

        let data = auth.get_correlation_data();
        let pairs: Vec<&str> = data.split(',').collect();
        assert_eq!(pairs.len(), 3);
        assert!(pairs.contains(&"sdk_language=rust"));
        assert!(pairs.contains(&"source=sdk-auth"));
        assert!(pairs.iter().any(|p| p.starts_with("sdk_version=")));
        assert!(!data.contains("source_version="));
    }
}