use std::collections::BTreeMap;
use std::time::{SystemTime, UNIX_EPOCH};

use base64::{Engine as _, engine::general_purpose};
use rand::{Rng, distr::Alphanumeric, rng};
use rsa::{
    RsaPrivateKey,
    pkcs1::DecodeRsaPrivateKey,
    pkcs1v15::SigningKey,
    pkcs8::DecodePrivateKey,
    signature::{SignatureEncoding, Signer},
};
use sha1::Sha1;

use crate::{Error, Result};
pub fn generate_oauth_header(
method: &str,
url: &str,
consumer_key: &str,
private_key_pem: &str,
token: &str,
_token_secret: &str, ) -> Result<String> {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs()
.to_string();
let nonce = generate_nonce();
let mut oauth_params = BTreeMap::new();
oauth_params.insert("oauth_consumer_key", consumer_key.to_string());
oauth_params.insert("oauth_nonce", nonce);
oauth_params.insert("oauth_signature_method", "RSA-SHA1".to_string());
oauth_params.insert("oauth_timestamp", timestamp);
oauth_params.insert("oauth_token", token.to_string());
oauth_params.insert("oauth_version", "1.0".to_string());
let (base_url, query_params) = parse_url(url);
let signature_base = build_signature_base(method, &base_url, &oauth_params, &query_params);
let signature = sign_rsa_sha1(private_key_pem, &signature_base)?;
oauth_params.insert("oauth_signature", signature);
let auth_header = oauth_params
.iter()
.map(|(k, v)| format!("{}=\"{}\"", k, percent_encode(v)))
.collect::<Vec<_>>()
.join(", ");
Ok(format!("OAuth {}", auth_header))
}
/// Produces a random 32-character `oauth_nonce` made of ASCII letters and
/// digits, drawn from `rand::rng()`.
fn generate_nonce() -> String {
    const NONCE_LEN: usize = 32;
    let mut generator = rng();
    let mut nonce = String::with_capacity(NONCE_LEN);
    for _ in 0..NONCE_LEN {
        let byte: u8 = generator.sample(Alphanumeric);
        nonce.push(char::from(byte));
    }
    nonce
}
/// Splits `url` into its base URI (everything before `?`) and its decoded
/// query parameters.
///
/// The query string is parsed as `application/x-www-form-urlencoded`, as
/// RFC 5849 §3.4.1.3.1 requires. Pairs are split on `&` and on the first
/// `=`. Names and values are decoded (`%XX` escapes, `+` as space) so that
/// the signature base string can re-encode them exactly once.
///
/// Additional rules:
/// - A fragment (`#…`) is never sent to the server, so it is dropped before
///   parsing.
/// - A parameter without `=` (e.g. `?flag`) is kept with an empty value.
/// - Empty segments (`a=1&&b=2`) are skipped.
///
/// Note: results are keyed by name, so a repeated parameter keeps only its
/// last value.
fn parse_url(url: &str) -> (String, BTreeMap<String, String>) {
    let url = url.split_once('#').map_or(url, |(before_fragment, _)| before_fragment);
    match url.split_once('?') {
        None => (url.to_string(), BTreeMap::new()),
        Some((base_url, query_string)) => {
            let query_params = query_string
                .split('&')
                .filter(|pair| !pair.is_empty())
                .map(|pair| {
                    let (name, value) = pair.split_once('=').unwrap_or((pair, ""));
                    (form_decode(name), form_decode(value))
                })
                .collect();
            (base_url.to_string(), query_params)
        }
    }
}

/// Decodes one `application/x-www-form-urlencoded` component.
///
/// `+` becomes a space and `%XX` (hex, either case) becomes that byte. A `%`
/// not followed by two hex digits is kept literally. Bytes that do not form
/// valid UTF-8 are replaced with U+FFFD.
fn form_decode(component: &str) -> String {
    let bytes = component.as_bytes();
    let hex_digit = |index: usize| bytes.get(index).and_then(|&b| char::from(b).to_digit(16));
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => match (hex_digit(i + 1), hex_digit(i + 2)) {
                (Some(hi), Some(lo)) => {
                    // Both digits are < 16, so the combined value fits in a u8.
                    decoded.push((hi * 16 + lo) as u8);
                    i += 3;
                }
                _ => {
                    decoded.push(b'%');
                    i += 1;
                }
            },
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Builds the OAuth 1.0 signature base string (RFC 5849 §3.4.1), i.e.
/// `METHOD&encode(base_url)&encode(normalized_params)`.
///
/// Every entry of `oauth_params` except `oauth_signature` is combined with
/// `query_params`. Each name and value is percent-encoded, and the pairs are
/// sorted by encoded name, then encoded value (RFC 5849 §3.4.1.3.2). A query
/// parameter that shares a name with a protocol parameter is therefore kept
/// alongside it rather than overwriting it.
///
/// Because encoding happens here, `query_params` must be supplied in decoded
/// form. `base_url` is encoded as given: this function does not lowercase the
/// scheme/host or strip default ports.
fn build_signature_base(
    method: &str,
    base_url: &str,
    oauth_params: &BTreeMap<&str, String>,
    query_params: &BTreeMap<String, String>,
) -> String {
    let protocol_pairs = oauth_params
        .iter()
        .filter(|&(&name, _)| name != "oauth_signature")
        .map(|(&name, value)| (percent_encode(name), percent_encode(value)));
    let query_pairs = query_params
        .iter()
        .map(|(name, value)| (percent_encode(name), percent_encode(value)));

    // Sort the *encoded* pairs. Encoded order differs from raw order whenever
    // a name contains reserved or non-ASCII bytes: raw `[` sorts after `Z`,
    // but its encoding `%5B` sorts before it.
    let mut encoded_pairs: Vec<(String, String)> = protocol_pairs.chain(query_pairs).collect();
    encoded_pairs.sort_unstable();

    let normalized_params = encoded_pairs
        .iter()
        .map(|(name, value)| format!("{}={}", name, value))
        .collect::<Vec<_>>()
        .join("&");

    format!(
        "{}&{}&{}",
        method.to_uppercase(),
        percent_encode(base_url),
        percent_encode(&normalized_params)
    )
}
/// Signs `data` with RSASSA-PKCS1-v1_5 over SHA-1 and returns the signature
/// as standard (padded) base64, as the OAuth 1.0 `RSA-SHA1` method requires
/// (RFC 5849 §3.4.3).
///
/// `private_key_pem` may be a PKCS#8 (`BEGIN PRIVATE KEY`) or a PKCS#1
/// (`BEGIN RSA PRIVATE KEY`) PEM document.
///
/// # Errors
///
/// Returns [`Error::OAuthError`] if the key cannot be parsed in either format
/// or if the signing operation fails.
fn sign_rsa_sha1(private_key_pem: &str, data: &str) -> Result<String> {
    let private_key = RsaPrivateKey::from_pkcs8_pem(private_key_pem)
        .or_else(|pkcs8_err| {
            // Fall back to PKCS#1 (what `openssl genrsa` wrote before OpenSSL 3);
            // if neither format parses, report the PKCS#8 error.
            RsaPrivateKey::from_pkcs1_pem(private_key_pem).map_err(|_| pkcs8_err)
        })
        .map_err(|e| Error::OAuthError {
            message: format!("Failed to parse RSA private key: {}", e),
        })?;
    // Standard RSASSA-PKCS1-v1_5 (RFC 8017 §8.2) embeds the SHA-1 DigestInfo
    // in the padded message, and verifiers check for it. An unprefixed
    // signature would be rejected. `SigningKey::new` requires
    // `Sha1: AssociatedOid` (the `oid` feature of the `sha1` crate).
    let signing_key = SigningKey::<Sha1>::new(private_key);
    let signature = signing_key
        .try_sign(data.as_bytes())
        .map_err(|e| Error::OAuthError {
            message: format!("Failed to sign data: {}", e),
        })?;
    Ok(general_purpose::STANDARD.encode(signature.to_bytes()))
}
/// Percent-encodes `input` as OAuth 1.0 requires (RFC 5849 §3.6).
///
/// Only the RFC 3986 unreserved characters (`A-Z a-z 0-9 - . _ ~`) pass
/// through unchanged. Every other byte of the UTF-8 encoding becomes `%XX`
/// with uppercase hex digits.
fn percent_encode(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    // One output buffer for the whole string instead of a String per byte;
    // it grows on its own if many bytes need escaping.
    let mut encoded = String::with_capacity(input.len());
    for byte in input.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only inverse of `percent_encode` for ASCII fixtures.
    ///
    /// Each `%XX` escape becomes the char with that code point. If the two
    /// characters after a `%` do not parse as hex, the `%` and those
    /// characters are dropped.
    fn percent_decode(input: &str) -> String {
        let mut out = String::new();
        let mut rest = input.chars();
        while let Some(c) = rest.next() {
            if c != '%' {
                out.push(c);
                continue;
            }
            let digits: String = rest.by_ref().take(2).collect();
            if let Ok(value) = u8::from_str_radix(&digits, 16) {
                out.push(char::from(value));
            }
        }
        out
    }

    #[test]
    fn test_percent_encode() {
        let cases = [
            ("hello world", "hello%20world"),
            ("a-b.c_d~e", "a-b.c_d~e"),
            ("special!@#", "special%21%40%23"),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(percent_encode(raw), expected, "encoding {:?}", raw);
        }
    }

    #[test]
    fn test_generate_nonce() {
        let first = generate_nonce();
        let second = generate_nonce();
        assert_eq!(first.len(), 32);
        assert_eq!(second.len(), 32);
        assert_ne!(first, second);
    }

    #[test]
    fn test_parse_url() {
        let (base, params) = parse_url("https://example.com/api/search?q=test&limit=10");
        assert_eq!(base, "https://example.com/api/search");
        assert_eq!(params.get("q").map(String::as_str), Some("test"));
        assert_eq!(params.get("limit").map(String::as_str), Some("10"));
    }

    #[test]
    fn test_parse_url_no_query() {
        let (base, params) = parse_url("https://example.com/api/search");
        assert_eq!(base, "https://example.com/api/search");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_signature_base() {
        let oauth_params: BTreeMap<&str, String> = vec![
            ("oauth_consumer_key", "key".to_string()),
            ("oauth_nonce", "nonce".to_string()),
        ]
        .into_iter()
        .collect();
        let query_params: BTreeMap<String, String> =
            vec![("q".to_string(), "test".to_string())].into_iter().collect();

        let base = build_signature_base(
            "GET",
            "https://example.com/api/search",
            &oauth_params,
            &query_params,
        );

        assert!(base.starts_with("GET&"));
        let expected_fragments = [
            "https%3A%2F%2Fexample.com%2Fapi%2Fsearch",
            "oauth_consumer_key",
            "q%3Dtest",
        ];
        for &fragment in &expected_fragments {
            assert!(base.contains(fragment), "missing {} in {}", fragment, base);
        }
    }

    #[test]
    fn test_parse_url_with_multiple_params() {
        let (base, params) = parse_url(
            "https://jira.example.com/rest/api/2/search?jql=project=TEST&maxResults=50&startAt=0",
        );
        assert_eq!(base, "https://jira.example.com/rest/api/2/search");

        let expected = [("jql", "project=TEST"), ("maxResults", "50"), ("startAt", "0")];
        for &(key, value) in &expected {
            assert_eq!(params.get(key).map(String::as_str), Some(value), "param {}", key);
        }
        assert_eq!(params.len(), expected.len());
    }

    #[test]
    fn test_percent_encode_unreserved_chars() {
        let unreserved = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
        assert_eq!(percent_encode(unreserved), unreserved);
    }

    #[test]
    fn test_percent_encode_reserved_chars() {
        let reserved = ":/?#[]@!$&'()*+,;=";
        let expected = "%3A%2F%3F%23%5B%5D%40%21%24%26%27%28%29%2A%2B%2C%3B%3D";
        assert_eq!(percent_encode(reserved), expected);
    }

    #[test]
    fn test_sign_rsa_sha1_invalid_key() {
        let invalid_key = "-----BEGIN PRIVATE KEY-----\nINVALID\n-----END PRIVATE KEY-----";
        assert!(sign_rsa_sha1(invalid_key, "test data").is_err());
    }

    #[test]
    fn test_oauth_header_invalid_key_returns_error() {
        let invalid_key = "-----BEGIN PRIVATE KEY-----\nNOT_A_VALID_KEY\n-----END PRIVATE KEY-----";
        let outcome = generate_oauth_header(
            "GET",
            "https://jira.example.com/rest/api/2/myself",
            "consumer-key",
            invalid_key,
            "access-token",
            "access-token-secret",
        );
        match outcome {
            Ok(_) => panic!("expected an error for an invalid private key"),
            Err(e) => assert!(e.to_string().contains("OAuth")),
        }
    }

    #[test]
    fn test_build_signature_base_excludes_oauth_signature() {
        let oauth_params: BTreeMap<&str, String> = vec![
            ("oauth_consumer_key", "key".to_string()),
            ("oauth_signature", "should-be-excluded".to_string()),
            ("oauth_nonce", "nonce".to_string()),
        ]
        .into_iter()
        .collect();
        let query_params: BTreeMap<String, String> = BTreeMap::new();

        let base = build_signature_base(
            "POST",
            "https://jira.example.com/rest/api/2/issue",
            &oauth_params,
            &query_params,
        );

        assert!(!base.contains("should-be-excluded"));
        assert!(!base.contains("oauth_signature"));
    }

    #[test]
    fn test_build_signature_base_parameter_ordering() {
        // Inserted out of order on purpose; the base string must be sorted.
        let oauth_params: BTreeMap<&str, String> = vec![
            ("oauth_version", "1.0".to_string()),
            ("oauth_consumer_key", "key".to_string()),
            ("oauth_nonce", "nonce".to_string()),
        ]
        .into_iter()
        .collect();
        let query_params: BTreeMap<String, String> = BTreeMap::new();

        let base = build_signature_base("GET", "https://example.com/api", &oauth_params, &query_params);
        // The third `&`-separated part is the encoded parameter string.
        let params_part = base.split('&').nth(2).unwrap();
        let decoded = percent_decode(params_part);

        let positions: Vec<usize> = ["oauth_consumer_key", "oauth_nonce", "oauth_version"]
            .iter()
            .map(|name| decoded.find(*name).unwrap())
            .collect();
        assert!(positions[0] < positions[1]);
        assert!(positions[1] < positions[2]);
    }
}