alpaca_base/
auth.rs

1use crate::error::{AlpacaError, Result};
2use base64::{Engine as _, engine::general_purpose};
3use chrono::{DateTime, Utc};
4use hmac::{Hmac, Mac};
5use sha2::Sha256;
6use std::collections::HashMap;
7
/// HMAC instantiated with SHA-256; `Credentials::sign_request` uses it to sign requests.
type HmacSha256 = Hmac<Sha256>;
9
/// Authentication credentials for Alpaca API.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// value with `{:?}` (logs, panic messages, `dbg!`) never prints the secret key.
#[derive(Clone)]
pub struct Credentials {
    /// The API key for authentication.
    pub api_key: String,
    /// The secret key for authentication.
    pub secret_key: String,
}

impl std::fmt::Debug for Credentials {
    /// Format the credentials with the API key visible and the secret key redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("api_key", &self.api_key)
            // The secret must never reach log output; show a fixed placeholder instead.
            .field("secret_key", &format_args!("<redacted>"))
            .finish()
    }
}
18
19impl Credentials {
20    /// Create new credentials
21    pub fn new(api_key: String, secret_key: String) -> Self {
22        Self {
23            api_key,
24            secret_key,
25        }
26    }
27
28    /// Create credentials from environment variables.
29    ///
30    /// This method automatically attempts to load a `.env` file if present.
31    ///
32    /// Looks for `ALPACA_API_KEY` and either `ALPACA_API_SECRET` or `ALPACA_SECRET_KEY`.
33    pub fn from_env() -> Result<Self> {
34        dotenv::dotenv().ok();
35        let api_key = std::env::var("ALPACA_API_KEY")
36            .map_err(|_| AlpacaError::Config("ALPACA_API_KEY not found".to_string()))?;
37        let secret_key = std::env::var("ALPACA_API_SECRET")
38            .or_else(|_| std::env::var("ALPACA_SECRET_KEY"))
39            .map_err(|_| {
40                AlpacaError::Config("ALPACA_API_SECRET or ALPACA_SECRET_KEY not found".to_string())
41            })?;
42
43        Ok(Self::new(api_key, secret_key))
44    }
45
46    /// Generate authorization header for HTTP requests
47    pub fn auth_header(&self) -> String {
48        format!(
49            "Basic {}",
50            general_purpose::STANDARD.encode(format!("{}:{}", self.api_key, self.secret_key))
51        )
52    }
53
54    /// Generate HMAC signature for request authentication
55    pub fn sign_request(
56        &self,
57        method: &str,
58        path: &str,
59        body: &str,
60        timestamp: DateTime<Utc>,
61    ) -> Result<String> {
62        let timestamp_str = timestamp.timestamp().to_string();
63        let message = format!("{}{}{}{}", timestamp_str, method, path, body);
64
65        let mut mac = HmacSha256::new_from_slice(self.secret_key.as_bytes())
66            .map_err(|e| AlpacaError::Auth(format!("Invalid secret key: {}", e)))?;
67
68        mac.update(message.as_bytes());
69        let result = mac.finalize();
70
71        Ok(general_purpose::STANDARD.encode(result.into_bytes()))
72    }
73
74    /// Generate headers for authenticated requests
75    pub fn auth_headers(
76        &self,
77        method: &str,
78        path: &str,
79        body: &str,
80    ) -> Result<HashMap<String, String>> {
81        let timestamp = Utc::now();
82        let signature = self.sign_request(method, path, body, timestamp)?;
83
84        let mut headers = HashMap::new();
85        headers.insert("APCA-API-KEY-ID".to_string(), self.api_key.clone());
86        headers.insert("APCA-API-SECRET-KEY".to_string(), self.secret_key.clone());
87        headers.insert(
88            "APCA-API-TIMESTAMP".to_string(),
89            timestamp.timestamp().to_string(),
90        );
91        headers.insert("APCA-API-SIGNATURE".to_string(), signature);
92
93        Ok(headers)
94    }
95}
96
97/// OAuth token for API access.
98#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
99pub struct OAuthToken {
100    /// The access token string.
101    pub access_token: String,
102    /// Refresh token for obtaining new access tokens.
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub refresh_token: Option<String>,
105    /// The token type (e.g., "Bearer").
106    pub token_type: String,
107    /// Token expiration time in seconds.
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub expires_in: Option<u64>,
110    /// OAuth scope granted.
111    #[serde(skip_serializing_if = "Option::is_none")]
112    pub scope: Option<String>,
113}
114
115impl OAuthToken {
116    /// Create authorization header from OAuth token.
117    #[must_use]
118    pub fn auth_header(&self) -> String {
119        format!("{} {}", self.token_type, self.access_token)
120    }
121
122    /// Check if token has a refresh token.
123    #[must_use]
124    pub fn has_refresh_token(&self) -> bool {
125        self.refresh_token.is_some()
126    }
127}