Skip to main content

oracledb_protocol/tls/
dn.rs

1//! Oracle server-certificate DN / name matching.
2//!
3//! python-oracledb thin disables rustls's standard hostname verification and
4//! instead runs its own check after the TLS handshake completes
5//! (`impl/thin/crypto.pyx::check_server_dn`). This module is a faithful,
6//! sans-I/O port of that algorithm operating on already-extracted certificate
7//! fields (subject DN string, SAN DNS names, common names).
8//!
9//! Two modes, mirroring the reference exactly:
10//!
11//! * **Explicit DN** (`ssl_server_cert_dn` is set): parse the expected DN and
12//!   the server's subject DN into `{ATTR: value}` maps and require the maps to
13//!   be equal. Order-independent; exact (no wildcards).
14//! * **Name match** (no `ssl_server_cert_dn`): match the expected host against
15//!   the certificate's SAN DNS names first, then its common names, with
16//!   wildcard support (`_name_matches`).
17
/// Reason a server-certificate DN / name check failed.
///
/// The two variants are kept apart so the driver can report the same two
/// errors as the reference implementation: ERR_INVALID_SERVER_CERT_DN and
/// ERR_INVALID_SERVER_NAME.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DnMatchError {
    /// An explicit `ssl_server_cert_dn` was configured and the server's
    /// subject DN did not equal it (reference ERR_INVALID_SERVER_CERT_DN).
    CertDnMismatch {
        /// The expected DN that the server certificate failed to match.
        expected_dn: String,
    },
    /// No `ssl_server_cert_dn` was configured and the expected host matched
    /// neither a SAN DNS name nor a common name (reference
    /// ERR_INVALID_SERVER_NAME).
    NameMismatch {
        /// The host name that was checked against the certificate.
        expected_name: String,
    },
}

impl core::fmt::Display for DnMatchError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            DnMatchError::CertDnMismatch { expected_dn: dn } => write!(
                f,
                "the distinguished name (DN) on the server certificate does not match \
                 the expected value \"{dn}\""
            ),
            DnMatchError::NameMismatch { expected_name: name } => write!(
                f,
                "the server name \"{name}\" does not match the names in the \
                 server certificate"
            ),
        }
    }
}

impl std::error::Error for DnMatchError {}
50
/// Parse a distinguished-name string into sorted `(ATTR, value)` pairs,
/// following python-oracledb's `DN_REGEX`:
///
/// `(?:^|,\s?)(?:(?P<name>[A-Z]+)=(?P<val>"(?:[^"]|"")+"|[^,]+))+`
///
/// The DN is read as comma-separated `ATTR=value` pairs:
///
/// * `ATTR` is one or more uppercase ASCII letters, immediately followed by
///   `=`.
/// * `value` is either a double-quoted string (commas allowed inside; `""`
///   stands for a literal `"`) or a non-empty run of non-comma characters.
///   Quoted values are returned without their surrounding quotes and with
///   `""` unescaped; unquoted values have all trailing spaces trimmed.
/// * Spaces before a pair (such as the one in a `, ` separator) are skipped.
///
/// Fragments the reference regex would never match produce no pair. These
/// are an attribute name that is empty or not uppercase, a missing `=`, and
/// `ATTR=` followed directly by a comma or the end of input.
///
/// The result is sorted so two DNs can be compared order-independently, much
/// like the reference compares Python dicts. Known differences from the
/// reference:
///
/// * its `val` group keeps the quotes of a quoted value;
/// * its `[^,]+` keeps trailing spaces;
/// * its `,\s?` allows only one whitespace character after a comma;
/// * a dict holds one value per attribute, whereas every pair (including a
///   repeated attribute) is returned here.
#[must_use]
pub fn parse_dn(dn: &str) -> Vec<(String, String)> {
    let chars: Vec<char> = dn.chars().collect();
    let len = chars.len();
    let mut pairs: Vec<(String, String)> = Vec::new();
    let mut pos = 0usize;

    while pos < len {
        // Separator: an optional comma followed by any number of spaces.
        if chars[pos] == ',' {
            pos += 1;
        }
        while pos < len && chars[pos] == ' ' {
            pos += 1;
        }
        if pos >= len {
            break;
        }

        // Attribute name: one or more uppercase ASCII letters, then '='.
        let name_start = pos;
        while pos < len && chars[pos].is_ascii_uppercase() {
            pos += 1;
        }
        if pos == name_start || pos >= len || chars[pos] != '=' {
            // Not a well-formed pair; resynchronise on the next comma.
            while pos < len && chars[pos] != ',' {
                pos += 1;
            }
            continue;
        }
        let name: String = chars[name_start..pos].iter().collect();
        pos += 1; // consume '='

        let value = if pos < len && chars[pos] == '"' {
            // Quoted value: runs to the next lone '"'; `""` is an escaped
            // quote. An unterminated quote consumes the rest of the input.
            pos += 1;
            let mut value = String::new();
            while pos < len {
                match (chars[pos], chars.get(pos + 1).copied()) {
                    ('"', Some('"')) => {
                        value.push('"');
                        pos += 2;
                    }
                    ('"', _) => {
                        pos += 1;
                        break;
                    }
                    (c, _) => {
                        value.push(c);
                        pos += 1;
                    }
                }
            }
            value
        } else {
            // Unquoted value: everything up to the next comma.
            let value_start = pos;
            while pos < len && chars[pos] != ',' {
                pos += 1;
            }
            if pos == value_start {
                // `[^,]+` needs at least one character, so `ATTR=` followed
                // by a comma or the end of input is not a pair.
                continue;
            }
            let raw: String = chars[value_start..pos].iter().collect();
            // Drop every trailing space so `CN=x ,O=y` and `CN=x,O=y` agree.
            raw.trim_end_matches(' ').to_owned()
        };
        pairs.push((name, value));
    }

    pairs.sort_unstable();
    pairs
}
142
143/// Compare an expected DN against the server's subject DN for equality, the
144/// `expected_dn is not None` branch of `check_server_dn`.
145///
146/// # Errors
147/// Returns [`DnMatchError::CertDnMismatch`] when the parsed attribute maps
148/// differ.
149pub fn check_cert_dn(expected_dn: &str, server_subject_dn: &str) -> Result<(), DnMatchError> {
150    let expected = parse_dn(expected_dn);
151    let server = parse_dn(server_subject_dn);
152    if expected == server {
153        Ok(())
154    } else {
155        Err(DnMatchError::CertDnMismatch {
156            expected_dn: expected_dn.to_string(),
157        })
158    }
159}
160
/// Returns whether `name_to_check` matches `cert_name`, where `cert_name` may
/// contain a wildcard (`*`) in its left-most label. Port of python-oracledb's
/// `crypto.pyx::_name_matches`; comparison is case-insensitive.
///
/// Rules, applied in order:
///
/// 1. An exact (case-insensitive) match succeeds.
/// 2. Otherwise both names must contain a `.` that is not their first
///    character, and everything from the first `.` onward must be equal.
/// 3. The left-most label of `cert_name` may then contain a `*`:
///    * `*` alone matches any label;
///    * `*suffix` matches labels ending in `suffix`;
///    * `prefix*` matches labels starting with `prefix`;
///    * `pre*post` matches labels that start with `pre` and end with `post`.
///
/// In the `pre*post` form, `pre` and `post` must not overlap inside the
/// checked label. For example, `b*b` matches `bb` and `bxb` but not `b`.
/// This is stricter than testing the prefix and suffix independently.
/// A `*` anywhere other than the left-most label never matches.
#[must_use]
pub fn name_matches(name_to_check: &str, cert_name: &str) -> bool {
    let cert_name = cert_name.to_lowercase();
    let name_to_check = name_to_check.to_lowercase();

    // Full match.
    if name_to_check == cert_name {
        return true;
    }

    // Both must have more than one label (a dot that is not the first char).
    let check_pos = name_to_check.find('.');
    let cert_pos = cert_name.find('.');
    let (Some(check_pos), Some(cert_pos)) = (check_pos, cert_pos) else {
        return false;
    };
    if check_pos == 0 || cert_pos == 0 {
        return false;
    }

    // Right-hand labels (from the first dot onward) must match exactly.
    if name_to_check[check_pos..] != cert_name[cert_pos..] {
        return false;
    }

    // Wildcard matching on the left-most label only.
    let cert_label = &cert_name[..cert_pos];
    let check_label = &name_to_check[..check_pos];
    if cert_label == "*" {
        return true;
    }
    if let Some(suffix) = cert_label.strip_prefix('*') {
        return check_label.ends_with(suffix);
    }
    if let Some(prefix) = cert_label.strip_suffix('*') {
        return check_label.starts_with(prefix);
    }

    // Wildcard in the middle of the label: split at the first `*`.
    let Some((pre, post)) = cert_label.split_once('*') else {
        return false;
    };
    // The length guard keeps `pre` and `post` from sharing characters.
    check_label.len() >= pre.len() + post.len()
        && check_label.starts_with(pre)
        && check_label.ends_with(post)
}
220
221/// Match the expected host name against the certificate's SAN DNS names and
222/// then its common names — the `expected_dn is None` branch of
223/// `check_server_dn`.
224///
225/// # Errors
226/// Returns [`DnMatchError::NameMismatch`] when no SAN DNS name and no common
227/// name matches `expected_name`.
228pub fn check_server_name(
229    expected_name: &str,
230    san_dns_names: &[String],
231    common_names: &[String],
232) -> Result<(), DnMatchError> {
233    for name in san_dns_names {
234        if name_matches(expected_name, name) {
235            return Ok(());
236        }
237    }
238    for name in common_names {
239        if name_matches(expected_name, name) {
240            return Ok(());
241        }
242    }
243    Err(DnMatchError::NameMismatch {
244        expected_name: expected_name.to_string(),
245    })
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned `(attr, value)` pair, for comparing against `parse_dn` output.
    fn pair(attr: &str, value: &str) -> (String, String) {
        (attr.to_owned(), value.to_owned())
    }

    /// Owned name list, for the SAN / common-name slices.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn parse_dn_simple() {
        assert_eq!(
            parse_dn("CN=db.example.com,O=Example,C=US"),
            vec![
                pair("C", "US"),
                pair("CN", "db.example.com"),
                pair("O", "Example"),
            ]
        );
    }

    #[test]
    fn parse_dn_order_independent_equality() {
        assert_eq!(parse_dn("CN=x,O=y"), parse_dn("O=y,CN=x"));
    }

    #[test]
    fn parse_dn_comma_space_separator() {
        assert_eq!(
            parse_dn("CN=x, O=y, C=Z"),
            vec![pair("C", "Z"), pair("CN", "x"), pair("O", "y")]
        );
    }

    #[test]
    fn parse_dn_quoted_value() {
        // The comma inside the quotes must not split the pair.
        let parsed = parse_dn(r#"CN="Acme, Inc.",C=US"#);
        assert!(parsed.contains(&pair("CN", "Acme, Inc.")));
        assert!(parsed.contains(&pair("C", "US")));
    }

    #[test]
    fn check_cert_dn_accept_exact() {
        assert_eq!(check_cert_dn("CN=x,O=y", "O=y,CN=x"), Ok(()));
    }

    #[test]
    fn check_cert_dn_reject_diff() {
        assert!(matches!(
            check_cert_dn("CN=x,O=y", "CN=z,O=y"),
            Err(DnMatchError::CertDnMismatch { .. })
        ));
    }

    #[test]
    fn check_cert_dn_reject_extra_attr() {
        assert!(matches!(
            check_cert_dn("CN=x", "CN=x,O=y"),
            Err(DnMatchError::CertDnMismatch { .. })
        ));
    }

    #[test]
    fn name_matches_full_case_insensitive() {
        assert!(name_matches("DB.example.com", "db.example.COM"));
    }

    #[test]
    fn name_matches_leading_wildcard() {
        assert!(name_matches("host.example.com", "*.example.com"));
        // The wildcard covers exactly one label.
        assert!(!name_matches("host.sub.example.com", "*.example.com"));
    }

    #[test]
    fn name_matches_prefix_wildcard_label() {
        assert!(name_matches("webserver.example.com", "web*.example.com"));
        assert!(!name_matches("appserver.example.com", "web*.example.com"));
    }

    #[test]
    fn name_matches_suffix_wildcard_label() {
        assert!(name_matches("serverweb.example.com", "*web.example.com"));
    }

    #[test]
    fn name_matches_rejects_single_label() {
        assert!(!name_matches("localhost", "*"));
    }

    #[test]
    fn check_server_name_san_first() {
        let san = names(&["db.example.com"]);
        assert_eq!(check_server_name("db.example.com", &san, &[]), Ok(()));
    }

    #[test]
    fn check_server_name_falls_back_to_cn() {
        let cn = names(&["db.example.com"]);
        assert_eq!(check_server_name("db.example.com", &[], &cn), Ok(()));
    }

    #[test]
    fn check_server_name_rejects_unknown() {
        let san = names(&["db.example.com"]);
        assert!(matches!(
            check_server_name("evil.example.com", &san, &[]),
            Err(DnMatchError::NameMismatch { .. })
        ));
    }
}