Skip to main content

haystack_client/
auth.rs

1//! Client-side SCRAM SHA-256 authentication handshake.
2//!
3//! Performs the three-phase Haystack auth handshake (HELLO, SCRAM, BEARER)
4//! against a Haystack server, returning the auth token on success.
5
6use base64::Engine;
7use base64::engine::general_purpose::STANDARD as BASE64;
8use reqwest::Client;
9
10use crate::error::ClientError;
11use haystack_core::auth;
12
13/// Perform SCRAM SHA-256 authentication handshake against a Haystack server.
14///
15/// Executes the three-phase handshake:
16/// 1. HELLO: sends username, receives SCRAM challenge
17/// 2. SCRAM: sends client proof, receives auth token and server signature
18/// 3. Returns the auth token for subsequent Bearer authentication
19///
20/// # Arguments
21/// * `client` - The reqwest HTTP client to use
22/// * `base_url` - The server API root (e.g. `http://localhost:8080/api`)
23/// * `username` - The username to authenticate as
24/// * `password` - The user's plaintext password
25///
26/// # Returns
27/// The auth token string on success.
28pub async fn authenticate(
29    client: &Client,
30    base_url: &str,
31    username: &str,
32    password: &str,
33) -> Result<String, ClientError> {
34    let base_url = base_url.trim_end_matches('/');
35    let about_url = format!("{}/about", base_url);
36
37    // -----------------------------------------------------------------------
38    // Phase 1: HELLO
39    // -----------------------------------------------------------------------
40    // Send GET /api/about with Authorization: HELLO username=<base64(username)>
41    let username_b64 = BASE64.encode(username.as_bytes());
42    let hello_header = format!("HELLO username={}", username_b64);
43
44    let (client_nonce, _client_first_b64) = auth::client_first_message(username);
45
46    let hello_resp = client
47        .get(&about_url)
48        .header("Authorization", &hello_header)
49        .send()
50        .await
51        .map_err(|e| ClientError::Transport(e.to_string()))?;
52
53    if hello_resp.status() != reqwest::StatusCode::UNAUTHORIZED {
54        return Err(ClientError::AuthFailed(format!(
55            "expected 401 from HELLO, got {}",
56            hello_resp.status()
57        )));
58    }
59
60    // Parse WWW-Authenticate header
61    // Expected: SCRAM handshakeToken=..., hash=SHA-256, data=<server_first_b64>
62    let www_auth = hello_resp
63        .headers()
64        .get("www-authenticate")
65        .and_then(|v| v.to_str().ok())
66        .ok_or_else(|| {
67            ClientError::AuthFailed("missing WWW-Authenticate header in 401 response".to_string())
68        })?
69        .to_string();
70
71    let (handshake_token, server_first_b64) = parse_www_authenticate(&www_auth)?;
72
73    // -----------------------------------------------------------------------
74    // Phase 2: SCRAM
75    // -----------------------------------------------------------------------
76    // Compute client-final-message from server-first-message
77    let (client_final_b64, expected_server_sig) =
78        auth::client_final_message(password, &client_nonce, &server_first_b64, username)
79            .map_err(|e| ClientError::AuthFailed(e.to_string()))?;
80
81    // Send GET /api/about with Authorization: SCRAM handshakeToken=..., data=<client_final>
82    let scram_header = format!(
83        "SCRAM handshakeToken={}, data={}",
84        handshake_token, client_final_b64
85    );
86
87    let scram_resp = client
88        .get(&about_url)
89        .header("Authorization", &scram_header)
90        .send()
91        .await
92        .map_err(|e| ClientError::Transport(e.to_string()))?;
93
94    if !scram_resp.status().is_success() {
95        return Err(ClientError::AuthFailed(format!(
96            "SCRAM phase failed with status {}",
97            scram_resp.status()
98        )));
99    }
100
101    // -----------------------------------------------------------------------
102    // Phase 3: Extract auth token
103    // -----------------------------------------------------------------------
104    // Parse Authentication-Info header: authToken=..., data=<server_final_b64>
105    let auth_info = scram_resp
106        .headers()
107        .get("authentication-info")
108        .and_then(|v| v.to_str().ok())
109        .ok_or_else(|| {
110            ClientError::AuthFailed(
111                "missing Authentication-Info header in SCRAM response".to_string(),
112            )
113        })?
114        .to_string();
115
116    let (auth_token, server_final_b64) = parse_auth_info(&auth_info)?;
117
118    // Verify the server signature from the server-final-message
119    let server_final_bytes = BASE64.decode(&server_final_b64).map_err(|e| {
120        ClientError::AuthFailed(format!("invalid base64 in server-final data: {}", e))
121    })?;
122    let server_final_msg = String::from_utf8(server_final_bytes).map_err(|e| {
123        ClientError::AuthFailed(format!("invalid UTF-8 in server-final data: {}", e))
124    })?;
125    let server_sig_b64 = server_final_msg.strip_prefix("v=").ok_or_else(|| {
126        ClientError::AuthFailed("server-final message missing v= prefix".to_string())
127    })?;
128    let received_server_sig = BASE64.decode(server_sig_b64).map_err(|e| {
129        ClientError::AuthFailed(format!("invalid base64 in server signature: {}", e))
130    })?;
131
132    if received_server_sig != expected_server_sig {
133        return Err(ClientError::AuthFailed(
134            "server signature verification failed".to_string(),
135        ));
136    }
137
138    Ok(auth_token)
139}
140
141/// Parse the WWW-Authenticate header from a SCRAM challenge response.
142///
143/// Expected format: `SCRAM handshakeToken=<token>, hash=SHA-256, data=<b64>`
144///
145/// Returns `(handshake_token, server_first_data_b64)`.
146fn parse_www_authenticate(header: &str) -> Result<(String, String), ClientError> {
147    let rest = header
148        .trim()
149        .strip_prefix("SCRAM ")
150        .ok_or_else(|| ClientError::AuthFailed("WWW-Authenticate not SCRAM scheme".to_string()))?;
151
152    let mut handshake_token = None;
153    let mut data = None;
154
155    for part in rest.split(',') {
156        let part = part.trim();
157        if let Some(val) = part.strip_prefix("handshakeToken=") {
158            handshake_token = Some(val.trim().to_string());
159        } else if let Some(val) = part.strip_prefix("data=") {
160            data = Some(val.trim().to_string());
161        }
162        // hash= is informational; we always use SHA-256
163    }
164
165    let handshake_token = handshake_token.ok_or_else(|| {
166        ClientError::AuthFailed("missing handshakeToken in WWW-Authenticate".to_string())
167    })?;
168    let data = data
169        .ok_or_else(|| ClientError::AuthFailed("missing data in WWW-Authenticate".to_string()))?;
170
171    Ok((handshake_token, data))
172}
173
174/// Parse the Authentication-Info header to extract the auth token and server-final data.
175///
176/// Expected format: `authToken=<token>, data=<b64>`
177///
178/// Returns `(auth_token, server_final_data_b64)`.
179fn parse_auth_info(header: &str) -> Result<(String, String), ClientError> {
180    let mut auth_token = None;
181    let mut data = None;
182
183    for part in header.split(',') {
184        let part = part.trim();
185        if let Some(val) = part.strip_prefix("authToken=") {
186            auth_token = Some(val.trim().to_string());
187        } else if let Some(val) = part.strip_prefix("data=") {
188            data = Some(val.trim().to_string());
189        }
190    }
191
192    let auth_token = auth_token.ok_or_else(|| {
193        ClientError::AuthFailed("missing authToken in Authentication-Info header".to_string())
194    })?;
195    let data = data.ok_or_else(|| {
196        ClientError::AuthFailed("missing data in Authentication-Info header".to_string())
197    })?;
198
199    Ok((auth_token, data))
200}