Skip to main content

haystack_client/
auth.rs

//! Client-side SCRAM SHA-256 authentication handshake.
//!
//! Performs the HELLO and SCRAM phases of the Haystack auth handshake against a
//! Haystack server and returns the auth token for use in the final BEARER phase.
5
6use base64::Engine;
7use base64::engine::general_purpose::STANDARD as BASE64;
8use reqwest::Client;
9
10use crate::error::ClientError;
11use haystack_core::auth;
12
13/// Perform SCRAM SHA-256 authentication handshake against a Haystack server.
14///
15/// Executes the three-phase handshake:
16/// 1. HELLO: sends username, receives SCRAM challenge
17/// 2. SCRAM: sends client proof, receives auth token and server signature
18/// 3. Returns the auth token for subsequent Bearer authentication
19///
20/// # Arguments
21/// * `client` - The reqwest HTTP client to use
22/// * `base_url` - The server API root (e.g. `http://localhost:8080/api`)
23/// * `username` - The username to authenticate as
24/// * `password` - The user's plaintext password
25///
26/// # Returns
27/// The auth token string on success.
28pub async fn authenticate(
29    client: &Client,
30    base_url: &str,
31    username: &str,
32    password: &str,
33) -> Result<String, ClientError> {
34    let base_url = base_url.trim_end_matches('/');
35    let about_url = format!("{}/about", base_url);
36
37    // -----------------------------------------------------------------------
38    // Phase 1: HELLO
39    // -----------------------------------------------------------------------
40    // Send GET /api/about with Authorization: HELLO username=<base64(username)>, data=<client_first>
41    let username_b64 = BASE64.encode(username.as_bytes());
42    let (client_nonce, client_first_b64) = auth::client_first_message(username);
43    let hello_header = format!("HELLO username={}, data={}", username_b64, client_first_b64);
44
45    let hello_resp = client
46        .get(&about_url)
47        .header("Authorization", &hello_header)
48        .send()
49        .await
50        .map_err(|e| ClientError::Transport(e.to_string()))?;
51
52    if hello_resp.status() != reqwest::StatusCode::UNAUTHORIZED {
53        return Err(ClientError::AuthFailed(format!(
54            "expected 401 from HELLO, got {}",
55            hello_resp.status()
56        )));
57    }
58
59    // Parse WWW-Authenticate header
60    // Expected: SCRAM handshakeToken=..., hash=SHA-256, data=<server_first_b64>
61    let www_auth = hello_resp
62        .headers()
63        .get("www-authenticate")
64        .and_then(|v| v.to_str().ok())
65        .ok_or_else(|| {
66            ClientError::AuthFailed("missing WWW-Authenticate header in 401 response".to_string())
67        })?
68        .to_string();
69
70    let (handshake_token, server_first_b64) = parse_www_authenticate(&www_auth)?;
71
72    // -----------------------------------------------------------------------
73    // Phase 2: SCRAM
74    // -----------------------------------------------------------------------
75    // Compute client-final-message from server-first-message
76    let (client_final_b64, expected_server_sig) =
77        auth::client_final_message(password, &client_nonce, &server_first_b64, username)
78            .map_err(|e| ClientError::AuthFailed(e.to_string()))?;
79
80    // Send GET /api/about with Authorization: SCRAM handshakeToken=..., data=<client_final>
81    let scram_header = format!(
82        "SCRAM handshakeToken={}, data={}",
83        handshake_token, client_final_b64
84    );
85
86    let scram_resp = client
87        .get(&about_url)
88        .header("Authorization", &scram_header)
89        .send()
90        .await
91        .map_err(|e| ClientError::Transport(e.to_string()))?;
92
93    if !scram_resp.status().is_success() {
94        return Err(ClientError::AuthFailed(format!(
95            "SCRAM phase failed with status {}",
96            scram_resp.status()
97        )));
98    }
99
100    // -----------------------------------------------------------------------
101    // Phase 3: Extract auth token
102    // -----------------------------------------------------------------------
103    // Parse Authentication-Info header: authToken=..., data=<server_final_b64>
104    let auth_info = scram_resp
105        .headers()
106        .get("authentication-info")
107        .and_then(|v| v.to_str().ok())
108        .ok_or_else(|| {
109            ClientError::AuthFailed(
110                "missing Authentication-Info header in SCRAM response".to_string(),
111            )
112        })?
113        .to_string();
114
115    let (auth_token, server_final_b64) = parse_auth_info(&auth_info)?;
116
117    // Verify the server signature from the server-final-message
118    let server_final_bytes = BASE64.decode(&server_final_b64).map_err(|e| {
119        ClientError::AuthFailed(format!("invalid base64 in server-final data: {}", e))
120    })?;
121    let server_final_msg = String::from_utf8(server_final_bytes).map_err(|e| {
122        ClientError::AuthFailed(format!("invalid UTF-8 in server-final data: {}", e))
123    })?;
124    let server_sig_b64 = server_final_msg.strip_prefix("v=").ok_or_else(|| {
125        ClientError::AuthFailed("server-final message missing v= prefix".to_string())
126    })?;
127    let received_server_sig = BASE64.decode(server_sig_b64).map_err(|e| {
128        ClientError::AuthFailed(format!("invalid base64 in server signature: {}", e))
129    })?;
130
131    if received_server_sig != expected_server_sig {
132        return Err(ClientError::AuthFailed(
133            "server signature verification failed".to_string(),
134        ));
135    }
136
137    Ok(auth_token)
138}
139
140/// Parse the WWW-Authenticate header from a SCRAM challenge response.
141///
142/// Expected format: `SCRAM handshakeToken=<token>, hash=SHA-256, data=<b64>`
143///
144/// Returns `(handshake_token, server_first_data_b64)`.
145fn parse_www_authenticate(header: &str) -> Result<(String, String), ClientError> {
146    let rest = header
147        .trim()
148        .strip_prefix("SCRAM ")
149        .ok_or_else(|| ClientError::AuthFailed("WWW-Authenticate not SCRAM scheme".to_string()))?;
150
151    let mut handshake_token = None;
152    let mut data = None;
153
154    for part in rest.split(',') {
155        let part = part.trim();
156        if let Some(val) = part.strip_prefix("handshakeToken=") {
157            handshake_token = Some(val.trim().to_string());
158        } else if let Some(val) = part.strip_prefix("data=") {
159            data = Some(val.trim().to_string());
160        }
161        // hash= is informational; we always use SHA-256
162    }
163
164    let handshake_token = handshake_token.ok_or_else(|| {
165        ClientError::AuthFailed("missing handshakeToken in WWW-Authenticate".to_string())
166    })?;
167    let data = data
168        .ok_or_else(|| ClientError::AuthFailed("missing data in WWW-Authenticate".to_string()))?;
169
170    Ok((handshake_token, data))
171}
172
173/// Parse the Authentication-Info header to extract the auth token and server-final data.
174///
175/// Expected format: `authToken=<token>, data=<b64>`
176///
177/// Returns `(auth_token, server_final_data_b64)`.
178fn parse_auth_info(header: &str) -> Result<(String, String), ClientError> {
179    let mut auth_token = None;
180    let mut data = None;
181
182    for part in header.split(',') {
183        let part = part.trim();
184        if let Some(val) = part.strip_prefix("authToken=") {
185            auth_token = Some(val.trim().to_string());
186        } else if let Some(val) = part.strip_prefix("data=") {
187            data = Some(val.trim().to_string());
188        }
189    }
190
191    let auth_token = auth_token.ok_or_else(|| {
192        ClientError::AuthFailed("missing authToken in Authentication-Info header".to_string())
193    })?;
194    let data = data.ok_or_else(|| {
195        ClientError::AuthFailed("missing data in Authentication-Info header".to_string())
196    })?;
197
198    Ok((auth_token, data))
199}