Skip to main content

lake_client/
auth.rs

1// Copyright 2025 TiDB Cloud
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::time::{SystemTime, UNIX_EPOCH};
17
18use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
19use reqwest::RequestBuilder;
20use serde::Serialize;
21
22use crate::error::{Error, Result};
23
24pub trait Auth: Sync + Send {
25    fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder>;
26    fn can_reload(&self) -> bool {
27        false
28    }
29    fn username(&self) -> String;
30}
31
32#[derive(Clone)]
33pub struct BasicAuth {
34    username: String,
35    password: SensitiveString,
36}
37
38impl BasicAuth {
39    pub fn new(username: impl ToString, password: impl ToString) -> Self {
40        Self {
41            username: username.to_string(),
42            password: SensitiveString(password.to_string()),
43        }
44    }
45}
46
47impl Auth for BasicAuth {
48    fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
49        Ok(builder.basic_auth(&self.username, Some(self.password.inner())))
50    }
51
52    fn username(&self) -> String {
53        self.username.clone()
54    }
55}
56
57#[derive(Clone)]
58pub struct AccessTokenAuth {
59    token: SensitiveString,
60}
61
62impl AccessTokenAuth {
63    pub fn new(token: impl ToString) -> Self {
64        Self {
65            token: SensitiveString::from(token.to_string()),
66        }
67    }
68}
69
70impl Auth for AccessTokenAuth {
71    fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
72        Ok(builder.bearer_auth(self.token.inner()))
73    }
74
75    fn username(&self) -> String {
76        "token".to_string()
77    }
78}
79
80#[derive(Clone)]
81pub struct AccessTokenFileAuth {
82    token_file: String,
83}
84
85impl AccessTokenFileAuth {
86    pub fn new(token_file: impl ToString) -> Self {
87        let token_file = token_file.to_string();
88        Self { token_file }
89    }
90}
91
92impl Auth for AccessTokenFileAuth {
93    fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
94        let token = std::fs::read_to_string(&self.token_file).map_err(|e| {
95            Error::IO(format!(
96                "cannot read access token from file {}: {}",
97                self.token_file, e
98            ))
99        })?;
100        Ok(builder.bearer_auth(token.trim()))
101    }
102
103    fn can_reload(&self) -> bool {
104        true
105    }
106
107    fn username(&self) -> String {
108        "token".to_string()
109    }
110}
111
/// Lifetime of each key-pair JWT in seconds (`exp - iat`).
const KEYPAIR_TOKEN_TTL_SECS: u64 = 60;
/// Header sent with key-pair bearer tokens; its value is set to `keypair`.
const HEADER_AUTH_METHOD: &str = "X-DATABEND-AUTH-METHOD";
114
115#[derive(Serialize)]
116struct KeyPairClaims {
117    sub: String,
118    iat: u64,
119    exp: u64,
120}
121
122#[derive(Clone)]
123pub struct KeyPairAuth {
124    username: String,
125    encoding_key: Arc<EncodingKey>,
126    algorithm: Algorithm,
127}
128
129impl KeyPairAuth {
130    pub fn new(
131        username: impl ToString,
132        private_key_file: &str,
133        passphrase_file: Option<&str>,
134    ) -> Result<Self> {
135        let pem_data = std::fs::read(private_key_file).map_err(|e| {
136            Error::IO(format!(
137                "cannot read private key from file {}: {}",
138                private_key_file, e
139            ))
140        })?;
141
142        let passphrase = match passphrase_file {
143            Some(path) => {
144                let p = std::fs::read_to_string(path).map_err(|e| {
145                    Error::IO(format!("cannot read passphrase from file {}: {}", path, e))
146                })?;
147                Some(Self::strip_line_ending(p))
148            }
149            None => None,
150        };
151
152        let (encoding_key, algorithm) = Self::parse_private_key(&pem_data, passphrase.as_deref())?;
153
154        Ok(Self {
155            username: username.to_string(),
156            encoding_key: Arc::new(encoding_key),
157            algorithm,
158        })
159    }
160
161    fn strip_line_ending(mut value: String) -> String {
162        if value.ends_with('\n') {
163            value.pop();
164            if value.ends_with('\r') {
165                value.pop();
166            }
167        }
168        value
169    }
170
171    fn parse_private_key(
172        pem_data: &[u8],
173        passphrase: Option<&str>,
174    ) -> Result<(EncodingKey, Algorithm)> {
175        let pem_str = std::str::from_utf8(pem_data)
176            .map_err(|e| Error::IO(format!("private key is not valid UTF-8: {e}")))?;
177
178        if let Some(passphrase) = passphrase {
179            let pem = Self::parse_pem_block(pem_data, "ENCRYPTED PRIVATE KEY").map_err(|_| {
180                Error::IO(
181                    "encrypted private keys with passphrase must use PKCS#8 PEM (BEGIN ENCRYPTED PRIVATE KEY)".to_string(),
182                )
183            })?;
184            // Encrypted PKCS#8 key — decrypt using pkcs8 crate to get DER.
185            Self::parse_encrypted_key(pem.contents(), passphrase)
186        } else {
187            // Unencrypted key — detect type and use jsonwebtoken's PEM methods
188            Self::parse_unencrypted_key(pem_data, pem_str)
189        }
190    }
191
192    fn parse_encrypted_key(
193        encrypted_der: &[u8],
194        passphrase: &str,
195    ) -> Result<(EncodingKey, Algorithm)> {
196        let doc = pkcs8::EncryptedPrivateKeyInfoRef::try_from(encrypted_der)
197            .and_then(|encrypted| encrypted.decrypt(passphrase.as_bytes()))
198            .map_err(|e| Error::IO(format!("failed to decrypt private key: {e}")))?;
199
200        let der_bytes = doc.as_bytes();
201
202        // Try each key type with DER
203        // from_*_der returns EncodingKey directly (infallible for the struct construction),
204        // but the underlying parsing may still fail at sign time.
205        // We try RSA first, then EC, then Ed25519 by attempting to parse the key info.
206        // Since from_*_der doesn't validate, we use the OID from the PKCS#8 structure.
207        let private_key_info = pkcs8::PrivateKeyInfoRef::try_from(der_bytes)
208            .map_err(|e| Error::IO(format!("failed to parse PKCS#8 DER: {e}")))?;
209
210        let algorithm_oid = private_key_info.algorithm.oid;
211
212        // RSA: 1.2.840.113549.1.1.1
213        const RSA_OID: pkcs8::ObjectIdentifier =
214            pkcs8::ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
215        // EC: 1.2.840.10045.2.1
216        const EC_OID: pkcs8::ObjectIdentifier =
217            pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");
218        // Ed25519: 1.3.101.112
219        const ED25519_OID: pkcs8::ObjectIdentifier =
220            pkcs8::ObjectIdentifier::new_unwrap("1.3.101.112");
221
222        if algorithm_oid == RSA_OID {
223            // Ring's from_der expects PKCS#1 RSAPrivateKey DER, not full PKCS#8.
224            // Extract the inner private key bytes from PKCS#8 PrivateKeyInfo.
225            let rsa_der = private_key_info.private_key.as_bytes();
226            Ok((EncodingKey::from_rsa_der(rsa_der), Algorithm::RS256))
227        } else if algorithm_oid == EC_OID {
228            let (ec_der, algorithm) = Self::rebuild_ec_pkcs8_named_curve(private_key_info)?;
229            Ok((EncodingKey::from_ec_der(&ec_der), algorithm))
230        } else if algorithm_oid == ED25519_OID {
231            Ok((EncodingKey::from_ed_der(der_bytes), Algorithm::EdDSA))
232        } else {
233            Err(Error::IO(format!(
234                "unsupported key algorithm OID: {algorithm_oid}"
235            )))
236        }
237    }
238
239    /// Rebuild a named-curve PKCS#8 DER for supported EC keys.
240    ///
241    /// jsonwebtoken/ring supports ES256 and ES384. Reject explicit parameters or
242    /// unsupported curves instead of silently relabeling them with the wrong JWT alg.
243    fn rebuild_ec_pkcs8_named_curve(pki: pkcs8::PrivateKeyInfoRef) -> Result<(Vec<u8>, Algorithm)> {
244        use pkcs8::der::Encode;
245
246        let (curve_oid, algorithm) = Self::ec_curve_oid_to_algorithm(
247            pki.algorithm
248                .parameters
249                .and_then(|params| params.decode_as::<pkcs8::ObjectIdentifier>().ok()),
250        )?;
251
252        let alg_id = pkcs8::AlgorithmIdentifierRef {
253            oid: pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.2.1"),
254            parameters: Some(pkcs8::der::asn1::AnyRef::from(&curve_oid)),
255        };
256
257        let new_pki = pkcs8::PrivateKeyInfo {
258            algorithm: alg_id,
259            private_key: pki.private_key,
260            public_key: pki.public_key,
261        };
262
263        let der = new_pki
264            .to_der()
265            .map_err(|e| Error::IO(format!("failed to re-encode EC PKCS#8: {e}")))?;
266        Ok((der, algorithm))
267    }
268
269    fn ec_curve_oid_to_algorithm(
270        curve_oid: Option<pkcs8::ObjectIdentifier>,
271    ) -> Result<(pkcs8::ObjectIdentifier, Algorithm)> {
272        const P256_OID: pkcs8::ObjectIdentifier =
273            pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.3.1.7");
274        const P384_OID: pkcs8::ObjectIdentifier =
275            pkcs8::ObjectIdentifier::new_unwrap("1.3.132.0.34");
276
277        match curve_oid {
278            Some(P256_OID) => Ok((P256_OID, Algorithm::ES256)),
279            Some(P384_OID) => Ok((P384_OID, Algorithm::ES384)),
280            Some(curve_oid) => Err(Error::IO(format!(
281                "unsupported EC private key curve OID: {curve_oid}; supported curves are P-256 and P-384"
282            ))),
283            None => Err(Error::IO(
284                "unsupported EC private key parameters: expected named P-256 or P-384 curve"
285                    .to_string(),
286            )),
287        }
288    }
289
290    fn parse_sec1_ec_pem(pem_data: &[u8]) -> Result<(Vec<u8>, Algorithm)> {
291        use sec1::der::Decode;
292
293        let pem = Self::parse_pem_block(pem_data, "EC PRIVATE KEY")?;
294        let ec_key = sec1::EcPrivateKey::from_der(pem.contents())
295            .map_err(|e| Error::IO(format!("failed to parse EC private key: {e}")))?;
296
297        let curve_oid = ec_key
298            .parameters
299            .and_then(|params| params.named_curve())
300            .ok_or_else(|| {
301                Error::IO(
302                    "unsupported EC private key parameters: expected named P-256 or P-384 curve"
303                        .to_string(),
304                )
305            })?;
306
307        let algorithm = match curve_oid.to_string().as_str() {
308            "1.2.840.10045.3.1.7" => Ok(Algorithm::ES256),
309            "1.3.132.0.34" => Ok(Algorithm::ES384),
310            _ => Err(Error::IO(format!(
311                "unsupported EC private key curve OID: {curve_oid}; supported curves are P-256 and P-384"
312            ))),
313        }?;
314
315        Ok((pem.contents().to_vec(), algorithm))
316    }
317
318    fn parse_pem_block(pem_data: &[u8], tag: &str) -> Result<pem::Pem> {
319        pem::parse_many(pem_data)
320            .map_err(|e| Error::IO(format!("failed to parse PEM: {e}")))?
321            .into_iter()
322            .find(|pem| pem.tag() == tag)
323            .ok_or_else(|| Error::IO(format!("failed to find {tag} PEM block")))
324    }
325
326    fn parse_unencrypted_key(pem_data: &[u8], pem_str: &str) -> Result<(EncodingKey, Algorithm)> {
327        if pem_str.contains("RSA PRIVATE KEY") {
328            // PKCS#1 RSA key. Select the key block when PEM bundles include
329            // another block before the private key.
330            let pem = Self::parse_pem_block(pem_data, "RSA PRIVATE KEY")?;
331            let key_pem = pem::encode(&pem);
332            let key = EncodingKey::from_rsa_pem(key_pem.as_bytes())
333                .map_err(|e| Error::IO(format!("failed to parse RSA private key: {e}")))?;
334            return Ok((key, Algorithm::RS256));
335        }
336
337        if pem_str.contains("EC PRIVATE KEY") {
338            use pkcs8::der::Encode;
339
340            // SEC1 EC key. Choose the JWT alg from the curve instead of
341            // advertising the wrong alg for non-P-256 keys.
342            let (ec_private_key_der, algorithm) = Self::parse_sec1_ec_pem(pem_data)?;
343            let pkcs8_der = pkcs8::PrivateKeyInfo {
344                algorithm: pkcs8::AlgorithmIdentifierRef {
345                    oid: pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.2.1"),
346                    parameters: Some(pkcs8::der::asn1::AnyRef::from(&match algorithm {
347                        Algorithm::ES256 => {
348                            pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.3.1.7")
349                        }
350                        Algorithm::ES384 => pkcs8::ObjectIdentifier::new_unwrap("1.3.132.0.34"),
351                        _ => unreachable!(),
352                    })),
353                },
354                private_key: pkcs8::der::asn1::OctetStringRef::new(&ec_private_key_der)
355                    .map_err(|e| Error::IO(format!("failed to wrap EC private key: {e}")))?,
356                public_key: None::<pkcs8::der::asn1::BitStringRef<'_>>,
357            }
358            .to_der()
359            .map_err(|e| Error::IO(format!("failed to re-encode EC PKCS#8: {e}")))?;
360            return Ok((EncodingKey::from_ec_der(&pkcs8_der), algorithm));
361        }
362
363        // PKCS#8 "BEGIN PRIVATE KEY" — parse OID to determine key type,
364        // then use from_*_der with full PKCS#8 DER (from_ec_pem has issues with PKCS#8 EC keys)
365        let pem_parsed = Self::parse_pem_block(pem_data, "PRIVATE KEY")?;
366        let der_bytes = pem_parsed.contents();
367
368        let private_key_info = pkcs8::PrivateKeyInfoRef::try_from(der_bytes)
369            .map_err(|e| Error::IO(format!("failed to parse PKCS#8 DER: {e}")))?;
370
371        let algorithm_oid = private_key_info.algorithm.oid;
372
373        const RSA_OID: pkcs8::ObjectIdentifier =
374            pkcs8::ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
375        const EC_OID: pkcs8::ObjectIdentifier =
376            pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");
377        const ED25519_OID: pkcs8::ObjectIdentifier =
378            pkcs8::ObjectIdentifier::new_unwrap("1.3.101.112");
379
380        if algorithm_oid == RSA_OID {
381            // Ring's from_der expects PKCS#1 RSAPrivateKey DER, not full PKCS#8.
382            // Extract the inner private key bytes from PKCS#8 PrivateKeyInfo.
383            let rsa_der = private_key_info.private_key.as_bytes();
384            Ok((EncodingKey::from_rsa_der(rsa_der), Algorithm::RS256))
385        } else if algorithm_oid == EC_OID {
386            let (ec_der, algorithm) = Self::rebuild_ec_pkcs8_named_curve(private_key_info)?;
387            Ok((EncodingKey::from_ec_der(&ec_der), algorithm))
388        } else if algorithm_oid == ED25519_OID {
389            Ok((EncodingKey::from_ed_der(der_bytes), Algorithm::EdDSA))
390        } else {
391            Err(Error::IO(format!(
392                "unsupported key algorithm OID: {algorithm_oid}"
393            )))
394        }
395    }
396
397    fn generate_jwt(&self) -> Result<String> {
398        let now = SystemTime::now()
399            .duration_since(UNIX_EPOCH)
400            .map_err(|e| Error::IO(format!("system time error: {e}")))?
401            .as_secs();
402
403        let claims = KeyPairClaims {
404            sub: self.username.clone(),
405            iat: now,
406            exp: now + KEYPAIR_TOKEN_TTL_SECS,
407        };
408
409        let header = Header::new(self.algorithm);
410        encode(&header, &claims, &self.encoding_key)
411            .map_err(|e| Error::IO(format!("failed to sign JWT: {e}")))
412    }
413}
414
415impl Auth for KeyPairAuth {
416    fn wrap(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
417        let token = self.generate_jwt()?;
418        Ok(builder
419            .bearer_auth(token)
420            .header(HEADER_AUTH_METHOD, "keypair"))
421    }
422
423    fn can_reload(&self) -> bool {
424        true
425    }
426
427    fn username(&self) -> String {
428        self.username.clone()
429    }
430}
431
432#[derive(::serde::Deserialize, ::serde::Serialize)]
433#[serde(from = "String", into = "String")]
434#[derive(Clone, Default, PartialEq, Eq)]
435pub struct SensitiveString(String);
436
437impl From<String> for SensitiveString {
438    fn from(value: String) -> Self {
439        Self(value)
440    }
441}
442
443impl From<&str> for SensitiveString {
444    fn from(value: &str) -> Self {
445        Self(value.to_string())
446    }
447}
448
449impl From<SensitiveString> for String {
450    fn from(value: SensitiveString) -> Self {
451        value.0
452    }
453}
454
455impl std::fmt::Display for SensitiveString {
456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        write!(f, "**REDACTED**")
458    }
459}
460
461impl std::fmt::Debug for SensitiveString {
462    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463        // we keep the double quotes here to keep the String behavior
464        write!(f, "\"**REDACTED**\"")
465    }
466}
467
468impl SensitiveString {
469    #[must_use]
470    pub fn inner(&self) -> &str {
471        self.0.as_str()
472    }
473}
474
#[cfg(test)]
mod tests {
    use std::io::Write;

    use tempfile::NamedTempFile;

    use super::*;

    #[test]
    fn serialization() {
        let json_value = "\"foo\"";
        let value: SensitiveString = serde_json::from_str(json_value).unwrap();
        let result: String = serde_json::to_string(&value).unwrap();
        assert_eq!(result, json_value);
    }

    #[test]
    fn hide_content() {
        let value = SensitiveString("hello world".to_string());
        let display = format!("{value}");
        assert_eq!(display, "**REDACTED**");
        let debug = format!("{value:?}");
        assert_eq!(debug, "\"**REDACTED**\"");
    }

    /// Runs `openssl` with `args` and returns stdout. Returns `None` when openssl
    /// is unavailable or the command fails, so key-generation tests can skip.
    fn run_openssl(args: &[&str]) -> Option<Vec<u8>> {
        let output = std::process::Command::new("openssl")
            .args(args)
            .output()
            .ok()?;
        output.status.success().then_some(output.stdout)
    }

    /// Writes `contents` to a new temporary file, kept alive by the returned handle.
    fn write_temp_file(contents: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(contents).unwrap();
        file
    }

    fn path_str(file: &NamedTempFile) -> &str {
        file.path().to_str().unwrap()
    }

    /// Builds a `KeyPairAuth` for `testuser` from an unencrypted key.
    fn keypair_auth_from(key: &[u8]) -> KeyPairAuth {
        let key_file = write_temp_file(key);
        KeyPairAuth::new("testuser", path_str(&key_file), None).unwrap()
    }

    /// Builds a `KeyPairAuth` for `testuser` from an encrypted key and the raw
    /// contents of its passphrase file.
    fn encrypted_keypair_auth_from(key: &[u8], passphrase_file_contents: &[u8]) -> KeyPairAuth {
        let key_file = write_temp_file(key);
        let pass_file = write_temp_file(passphrase_file_contents);
        KeyPairAuth::new(
            "testuser",
            path_str(&key_file),
            Some(path_str(&pass_file)),
        )
        .unwrap()
    }

    /// Asserts that `auth` signs a non-empty JWT shaped `header.payload.signature`.
    fn assert_signs_jwt(auth: &KeyPairAuth) {
        let token = auth.generate_jwt().unwrap();
        assert!(!token.is_empty());
        assert_eq!(token.split('.').count(), 3);
    }

    /// Generates a 2048-bit RSA key in PKCS#8 form via `genpkey`. When
    /// `passphrase` is given, the key is encrypted with PBES2
    /// (PBKDF2-HMAC-SHA256 + AES-256-CBC).
    fn gen_rsa_private_key(passphrase: Option<&str>) -> Option<Vec<u8>> {
        let pass_arg = passphrase.map(|p| format!("pass:{p}"));
        let mut args = vec![
            "genpkey",
            "-algorithm",
            "RSA",
            "-pkeyopt",
            "rsa_keygen_bits:2048",
        ];
        if let Some(pass_arg) = &pass_arg {
            args.extend([
                "-aes-256-cbc",
                "-pass",
                pass_arg.as_str(),
                "-v2prf",
                "hmacWithSHA256",
            ]);
        }
        run_openssl(&args)
    }

    /// Generates a PKCS#1 (`BEGIN RSA PRIVATE KEY`) RSA key.
    ///
    /// OpenSSL 3's `genrsa` writes PKCS#8 unless `-traditional` is passed, and
    /// older OpenSSL/LibreSSL may not know that flag, so both forms are tried.
    /// Returns `None` if no PKCS#1 key could be produced.
    fn gen_pkcs1_rsa_private_key() -> Option<Vec<u8>> {
        let key = run_openssl(&["genrsa", "-traditional", "2048"])
            .or_else(|| run_openssl(&["genrsa", "2048"]))?;
        let is_pkcs1 = String::from_utf8_lossy(&key).contains("RSA PRIVATE KEY");
        is_pkcs1.then_some(key)
    }

    fn prepend_dummy_public_key_pem(key: &[u8]) -> Vec<u8> {
        let mut bundle = pem::encode(&pem::Pem::new("PUBLIC KEY", vec![1, 2, 3])).into_bytes();
        bundle.extend_from_slice(key);
        bundle
    }

    fn gen_sec1_ec_private_key(curve: &str) -> Option<Vec<u8>> {
        gen_sec1_ec_private_key_with_params(curve, false)
    }

    /// Generates a SEC1 EC key with `ecparam`. With `include_params`, an
    /// `EC PARAMETERS` block is emitted before the key.
    fn gen_sec1_ec_private_key_with_params(curve: &str, include_params: bool) -> Option<Vec<u8>> {
        let mut args = vec!["ecparam", "-name", curve, "-genkey"];
        if !include_params {
            args.push("-noout");
        }
        run_openssl(&args)
    }

    /// Generates a named-curve PKCS#8 EC key. `curve` is a full `-pkeyopt`
    /// value such as `ec_paramgen_curve:P-256`.
    fn gen_ec_private_key(curve: &str, encrypted: bool) -> Option<Vec<u8>> {
        let mut args = vec![
            "genpkey",
            "-algorithm",
            "EC",
            "-pkeyopt",
            curve,
            "-pkeyopt",
            "ec_param_enc:named_curve",
        ];
        if encrypted {
            args.extend([
                "-aes-256-cbc",
                "-pass",
                "pass:testpass",
                "-v2prf",
                "hmacWithSHA256",
            ]);
        }
        run_openssl(&args)
    }

    #[test]
    fn keypair_auth_rsa() {
        // `genpkey` writes PKCS#8 (`BEGIN PRIVATE KEY`) on every OpenSSL version.
        // `genrsa` output varies: PKCS#1 before OpenSSL 3, PKCS#8 from 3.0 on.
        let Some(key) = gen_rsa_private_key(None) else {
            return; // openssl unavailable
        };

        let auth = keypair_auth_from(&key);
        assert_eq!(auth.username(), "testuser");
        assert!(auth.can_reload());
        assert_eq!(auth.algorithm, Algorithm::RS256);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_rsa_pkcs1_bundle_selects_private_key_block() {
        let Some(key) = gen_pkcs1_rsa_private_key() else {
            return;
        };

        let auth = keypair_auth_from(&prepend_dummy_public_key_pem(&key));
        assert_eq!(auth.algorithm, Algorithm::RS256);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_rsa_pkcs1() {
        let Some(key) = gen_pkcs1_rsa_private_key() else {
            return;
        };

        let auth = keypair_auth_from(&key);
        assert_eq!(auth.algorithm, Algorithm::RS256);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_pkcs8_bundle_selects_private_key_block() {
        let Some(key) = gen_ec_private_key("ec_paramgen_curve:P-256", false) else {
            return;
        };

        let auth = keypair_auth_from(&prepend_dummy_public_key_pem(&key));
        assert_eq!(auth.algorithm, Algorithm::ES256);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_ec() {
        let Some(key) = gen_ec_private_key("ec_paramgen_curve:P-256", false) else {
            return;
        };

        let auth = keypair_auth_from(&key);
        assert_eq!(auth.algorithm, Algorithm::ES256);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_ec_p384() {
        let Some(key) = gen_ec_private_key("ec_paramgen_curve:P-384", false) else {
            return;
        };

        let auth = keypair_auth_from(&key);
        assert_eq!(auth.algorithm, Algorithm::ES384);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_sec1_ec_with_parameters_block() {
        let Some(key) = gen_sec1_ec_private_key_with_params("prime256v1", true) else {
            return;
        };

        let auth = keypair_auth_from(&key);
        assert_eq!(auth.algorithm, Algorithm::ES256);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_sec1_ec_p384() {
        let Some(key) = gen_sec1_ec_private_key("secp384r1") else {
            return;
        };

        let auth = keypair_auth_from(&key);
        assert_eq!(auth.algorithm, Algorithm::ES384);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_encrypted_ec_p384() {
        let Some(key) = gen_ec_private_key("ec_paramgen_curve:P-384", true) else {
            return;
        };

        let auth = encrypted_keypair_auth_from(&key, b"testpass\n");
        assert_eq!(auth.algorithm, Algorithm::ES384);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_rejects_unsupported_ec_curves() {
        let Some(key) = gen_ec_private_key("ec_paramgen_curve:P-521", false) else {
            return;
        };

        let key_file = write_temp_file(&key);
        let err = KeyPairAuth::new("testuser", path_str(&key_file), None)
            .err()
            .expect("unsupported EC curve should be rejected");
        assert!(
            err.to_string()
                .contains("unsupported EC private key curve OID"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn keypair_auth_ed25519() {
        let Some(key) = run_openssl(&["genpkey", "-algorithm", "ed25519"]) else {
            return;
        };

        let auth = keypair_auth_from(&key);
        assert_eq!(auth.algorithm, Algorithm::EdDSA);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_encrypted_pkcs8_bundle_selects_private_key_block() {
        // The passphrase deliberately has leading/trailing spaces: only the
        // final newline in the passphrase file may be stripped.
        let Some(key) = gen_rsa_private_key(Some(" testpass ")) else {
            return;
        };

        let auth = encrypted_keypair_auth_from(&prepend_dummy_public_key_pem(&key), b" testpass \n");
        assert_eq!(auth.algorithm, Algorithm::RS256);
        assert_signs_jwt(&auth);
    }

    #[test]
    fn keypair_auth_encrypted_key() {
        // Encrypted PKCS#8 RSA key using PBES2 with PBKDF2-HMAC-SHA256 and
        // AES-256-CBC, which the pkcs8 crate can decrypt.
        let Some(key) = gen_rsa_private_key(Some("testpass")) else {
            return;
        };

        // Skip if the generated key is not actually encrypted.
        if !String::from_utf8_lossy(&key).contains("ENCRYPTED") {
            return;
        }

        let auth = encrypted_keypair_auth_from(&key, b"testpass\n");
        assert_eq!(auth.algorithm, Algorithm::RS256);
        assert_signs_jwt(&auth);
    }
}