use std::collections::HashMap;
use sha2::{Sha512, Digest};
use hmac::{Hmac, Mac};
use uuid::Uuid;
use crate::core::{
Credentials, ExchangeResult, ExchangeError,
};
/// HMAC keyed with SHA-512; used to produce the HS512 JWT signature.
type HmacSha512 = Hmac<Sha512>;

/// Credentials holder that signs Upbit private-API requests with HS512 JWTs.
///
/// Only `Clone` is derived. There is deliberately no `Debug` here, so the
/// secret key is not exposed through `{:?}` formatting by this definition.
#[derive(Clone)]
pub struct UpbitAuth {
    /// Public API key, embedded verbatim in every JWT payload as `access_key`.
    access_key: String,
    /// Private API key, used only as the HMAC-SHA512 signing key.
    secret_key: String,
}
impl UpbitAuth {
    /// Builds an authenticator by copying the key pair out of `credentials`.
    ///
    /// # Errors
    /// Currently never fails; the `Result` return type is kept for API stability.
    pub fn new(credentials: &Credentials) -> ExchangeResult<Self> {
        let access_key = credentials.api_key.clone();
        let secret_key = credentials.api_secret.clone();
        Ok(Self {
            access_key,
            secret_key,
        })
    }

    /// Produces a compact HS512 JWT (`header.payload.signature`, base64url, no padding).
    ///
    /// The payload always carries `access_key` and a fresh random UUIDv4 `nonce`.
    /// When `query_string` is `Some` and non-empty, its lowercase SHA-512 hex
    /// digest is added as `query_hash`, together with `query_hash_alg = "SHA512"`.
    ///
    /// # Errors
    /// Returns [`ExchangeError::Auth`] if the header or payload cannot be
    /// serialized, or if HMAC initialisation fails.
    pub fn create_jwt_token(&self, query_string: Option<&str>) -> ExchangeResult<String> {
        let mut payload = serde_json::json!({
            "access_key": self.access_key,
            "nonce": Uuid::new_v4().to_string(),
        });

        // An empty query string is treated exactly like no query string.
        if let Some(qs) = query_string.filter(|qs| !qs.is_empty()) {
            payload["query_hash"] = serde_json::json!(self.sha512_hex(qs.as_bytes()));
            payload["query_hash_alg"] = serde_json::json!("SHA512");
        }

        let header = serde_json::json!({ "alg": "HS512", "typ": "JWT" });
        let header_json = serde_json::to_string(&header).map_err(|err| {
            ExchangeError::Auth(format!("Failed to serialize header: {}", err))
        })?;
        let payload_json = serde_json::to_string(&payload).map_err(|err| {
            ExchangeError::Auth(format!("Failed to serialize payload: {}", err))
        })?;

        // JWS signing input: base64url(header) "." base64url(payload).
        let signing_input = format!(
            "{}.{}",
            self.base64url_encode(header_json.as_bytes()),
            self.base64url_encode(payload_json.as_bytes()),
        );
        let signature = self.hmac_sha512(self.secret_key.as_bytes(), signing_input.as_bytes())?;

        Ok(format!("{}.{}", signing_input, self.base64url_encode(&signature)))
    }

    /// Returns the HTTP headers for an authenticated request.
    ///
    /// Produces `Authorization: Bearer <jwt>` and a JSON `Content-Type`. The
    /// method and endpoint are accepted for interface compatibility but unused;
    /// only `query_string` influences the token (via `query_hash`).
    ///
    /// # Errors
    /// Propagates any error from [`UpbitAuth::create_jwt_token`].
    pub fn sign_request(
        &self,
        _method: &str,
        _endpoint: &str,
        query_string: Option<&str>,
    ) -> ExchangeResult<HashMap<String, String>> {
        let token = self.create_jwt_token(query_string)?;
        Ok(HashMap::from([
            ("Authorization".to_string(), format!("Bearer {}", token)),
            (
                "Content-Type".to_string(),
                "application/json; charset=utf-8".to_string(),
            ),
        ]))
    }

    /// Lowercase hex encoding of the SHA-512 digest of `data` (128 characters).
    fn sha512_hex(&self, data: &[u8]) -> String {
        let digest = Sha512::digest(data);
        format!("{:x}", digest)
    }

    /// Computes the raw HMAC-SHA512 tag of `message` under `key`.
    ///
    /// # Errors
    /// Returns [`ExchangeError::Auth`] if the MAC cannot be initialised with `key`.
    fn hmac_sha512(&self, key: &[u8], message: &[u8]) -> ExchangeResult<Vec<u8>> {
        let mut mac = HmacSha512::new_from_slice(key)
            .map_err(|err| ExchangeError::Auth(format!("HMAC init failed: {}", err)))?;
        mac.update(message);
        let tag = mac.finalize().into_bytes();
        Ok(tag.to_vec())
    }

    /// Base64url-encodes `data` without `=` padding, as JWT segments require.
    fn base64url_encode(&self, data: &[u8]) -> String {
        use base64::Engine as _;
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
    }

    /// The public access key this authenticator signs with.
    pub fn access_key(&self) -> &str {
        &self.access_key
    }
}
pub fn json_to_query_string(json_body: &str) -> ExchangeResult<String> {
if json_body.is_empty() {
return Ok(String::new());
}
let value: serde_json::Value = serde_json::from_str(json_body)
.map_err(|e| ExchangeError::Auth(format!("Failed to parse JSON: {}", e)))?;
let obj = value.as_object()
.ok_or_else(|| ExchangeError::Auth("JSON body is not an object".to_string()))?;
let mut pairs: Vec<(String, String)> = obj.iter()
.map(|(k, v): (&String, &serde_json::Value)| {
let value_str = match v {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::Bool(b) => b.to_string(),
_ => v.to_string(),
};
(k.clone(), value_str)
})
.collect();
pairs.sort_by(|a, b| a.0.cmp(&b.0));
let query_string = url::form_urlencoded::Serializer::new(String::new())
.extend_pairs(pairs)
.finish();
Ok(query_string)
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
    use hmac::Mac;

    const ACCESS_KEY: &str = "test_access_key";
    const SECRET_KEY: &str = "test_secret_key";

    /// Builds an authenticator from fixed dummy credentials.
    fn test_auth() -> UpbitAuth {
        let credentials = Credentials::new(ACCESS_KEY, SECRET_KEY);
        UpbitAuth::new(&credentials).unwrap()
    }

    /// Splits a JWT into its segments, asserting it has exactly header.payload.signature.
    fn split_token(token: &str) -> Vec<&str> {
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3, "JWT must have header.payload.signature");
        parts
    }

    /// Decodes one unpadded base64url JWT segment and parses it as JSON.
    fn decode_segment(segment: &str) -> serde_json::Value {
        let bytes = URL_SAFE_NO_PAD
            .decode(segment)
            .expect("JWT segment must be unpadded base64url");
        serde_json::from_slice(&bytes).expect("JWT segment must be JSON")
    }

    #[test]
    fn test_create_jwt_token() {
        let auth = test_auth();

        // No query string: header is HS512 and the payload has no query_hash claims.
        let token = auth.create_jwt_token(None).unwrap();
        let parts = split_token(&token);
        let header = decode_segment(parts[0]);
        assert_eq!(header["alg"], "HS512");
        assert_eq!(header["typ"], "JWT");
        let payload = decode_segment(parts[1]);
        assert_eq!(payload["access_key"], ACCESS_KEY);
        assert!(payload["nonce"].is_string());
        assert!(payload.get("query_hash").is_none());
        assert!(payload.get("query_hash_alg").is_none());

        // An empty query string is treated like no query string.
        let token = auth.create_jwt_token(Some("")).unwrap();
        let payload = decode_segment(split_token(&token)[1]);
        assert!(payload.get("query_hash").is_none());

        // With a query string, its SHA-512 hex digest is embedded.
        let qs = "market=SGD-BTC&state=wait";
        let token = auth.create_jwt_token(Some(qs)).unwrap();
        let payload = decode_segment(split_token(&token)[1]);
        assert_eq!(payload["query_hash"], auth.sha512_hex(qs.as_bytes()));
        assert_eq!(payload["query_hash_alg"], "SHA512");
    }

    #[test]
    fn test_jwt_nonce_is_unique() {
        let auth = test_auth();
        let first = decode_segment(split_token(&auth.create_jwt_token(None).unwrap())[1]);
        let second = decode_segment(split_token(&auth.create_jwt_token(None).unwrap())[1]);
        assert_ne!(first["nonce"], second["nonce"]);
    }

    #[test]
    fn test_jwt_signature_matches_hmac_sha512() {
        let auth = test_auth();
        let token = auth.create_jwt_token(Some("market=SGD-BTC")).unwrap();
        let (signing_input, signature_b64) = token.rsplit_once('.').unwrap();

        let mut mac = HmacSha512::new_from_slice(SECRET_KEY.as_bytes()).unwrap();
        mac.update(signing_input.as_bytes());
        let expected = mac.finalize().into_bytes();

        let actual = URL_SAFE_NO_PAD.decode(signature_b64).unwrap();
        assert_eq!(actual.as_slice(), expected.as_slice());
    }

    #[test]
    fn test_sign_request() {
        let auth = test_auth();
        let headers = auth.sign_request("GET", "/v1/balances", None).unwrap();

        let authorization = headers
            .get("Authorization")
            .expect("Authorization header must be present");
        let token = authorization
            .strip_prefix("Bearer ")
            .expect("Authorization must use the Bearer scheme");
        split_token(token);

        assert_eq!(
            headers.get("Content-Type").map(String::as_str),
            Some("application/json; charset=utf-8")
        );
    }

    #[test]
    fn test_json_to_query_string() {
        // Already-sorted input.
        let json = r#"{"market":"SGD-BTC","side":"bid","volume":"0.1"}"#;
        assert_eq!(
            json_to_query_string(json).unwrap(),
            "market=SGD-BTC&side=bid&volume=0.1"
        );

        // Keys come out in ascending order regardless of input order.
        let json = r#"{"volume":"0.1","side":"bid","market":"SGD-BTC"}"#;
        assert_eq!(
            json_to_query_string(json).unwrap(),
            "market=SGD-BTC&side=bid&volume=0.1"
        );

        // Numbers and booleans are rendered without JSON quoting.
        let json = r#"{"price":100,"post_only":true}"#;
        assert_eq!(
            json_to_query_string(json).unwrap(),
            "post_only=true&price=100"
        );

        assert_eq!(json_to_query_string("").unwrap(), "");
        assert!(json_to_query_string("[1, 2]").is_err());
        assert!(json_to_query_string("{not json").is_err());
    }

    #[test]
    fn test_sha512_hex() {
        let auth = test_auth();
        // Known-answer vectors; a length-only check would not catch a wrong algorithm.
        assert_eq!(
            auth.sha512_hex(b""),
            "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce\
             47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"
        );
        assert_eq!(
            auth.sha512_hex(b"test"),
            "ee26b0dd4af7e749aa1a8ee3c10ae9923f618980772e473f8819a5d4940e0db2\
             7ac185f8a0e1d5f84f88bc887fd67b143732c304cc5fa9ad8e6f57f50028a8ff"
        );
    }
}