Skip to main content

dns_update/providers/
oraclecloud.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12#![cfg(any(feature = "ring", feature = "aws-lc-rs"))]
13
14use crate::crypto::sha256_digest;
15use crate::jwt::{parse_rsa_pkcs8_pem, rsa_sha256_sign};
16use crate::utils::txt_chunks_to_text;
17use crate::{DnsRecord, DnsRecordType, Error, IntoFqdn, Result};
18use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
19use chrono::Utc;
20use reqwest::Method;
21use reqwest::header::{HeaderMap, HeaderValue};
22use serde::{Deserialize, Serialize};
23use std::sync::Arc;
24use std::time::Duration;
25
26#[cfg(feature = "ring")]
27use ring::signature::RsaKeyPair;
28
29#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
30use aws_lc_rs::signature::RsaKeyPair;
31
/// Credentials and settings used by the Oracle Cloud DNS provider.
///
/// `Debug` is implemented by hand (not derived) so that `private_key_pem` and
/// `private_key_password` are redacted and never end up in logs.
#[derive(Clone)]
pub struct OracleCloudConfig {
    /// Tenancy OCID; first component of the request-signing key id.
    pub tenancy_ocid: String,
    /// User OCID; second component of the request-signing key id.
    pub user_ocid: String,
    /// API key fingerprint; third component of the request-signing key id.
    pub fingerprint: String,
    /// PEM-encoded RSA private key used to sign requests.
    pub private_key_pem: String,
    /// Key passphrase. Encrypted keys are rejected, so this must be `None` or empty.
    pub private_key_password: Option<String>,
    /// Region identifier used to build the API host name.
    pub region: String,
    /// Compartment OCID passed as `compartmentId` on every API call.
    pub compartment_ocid: String,
    /// Per-request timeout; a provider default is used when `None`.
    pub request_timeout: Option<Duration>,
}

impl std::fmt::Debug for OracleCloudConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced by a marker; only their presence is shown.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("OracleCloudConfig")
            .field("tenancy_ocid", &self.tenancy_ocid)
            .field("user_ocid", &self.user_ocid)
            .field("fingerprint", &self.fingerprint)
            .field("private_key_pem", &REDACTED)
            .field(
                "private_key_password",
                &self.private_key_password.as_ref().map(|_| REDACTED),
            )
            .field("region", &self.region)
            .field("compartment_ocid", &self.compartment_ocid)
            .field("request_timeout", &self.request_timeout)
            .finish()
    }
}
43
44#[derive(Clone)]
45pub struct OracleCloudProvider {
46    config: OracleCloudConfig,
47    key_pair: Arc<RsaKeyPair>,
48    endpoint: String,
49}
50
51#[derive(Debug, Serialize, Deserialize, Clone)]
52struct OciRecord {
53    domain: String,
54    rtype: String,
55    rdata: String,
56    ttl: u32,
57    #[serde(rename = "isProtected", skip_serializing_if = "Option::is_none")]
58    is_protected: Option<bool>,
59    #[serde(rename = "recordHash", skip_serializing_if = "Option::is_none")]
60    record_hash: Option<String>,
61}
62
63#[derive(Debug, Serialize)]
64struct UpdateRecordsRequest {
65    items: Vec<OciRecord>,
66}
67
68#[derive(Debug, Deserialize)]
69struct RecordCollection {
70    items: Vec<OciRecord>,
71}
72
73#[derive(Debug, Deserialize)]
74struct Zone {
75    name: String,
76    id: String,
77}
78
79impl OracleCloudProvider {
80    pub(crate) fn new(config: OracleCloudConfig) -> Result<Self> {
81        if config.tenancy_ocid.is_empty() {
82            return Err(Error::Client("tenancy_ocid is required".into()));
83        }
84        if config.user_ocid.is_empty() {
85            return Err(Error::Client("user_ocid is required".into()));
86        }
87        if config.fingerprint.is_empty() {
88            return Err(Error::Client("fingerprint is required".into()));
89        }
90        if config.region.is_empty() {
91            return Err(Error::Client("region is required".into()));
92        }
93        if config.compartment_ocid.is_empty() {
94            return Err(Error::Client("compartment_ocid is required".into()));
95        }
96        if config
97            .private_key_password
98            .as_ref()
99            .is_some_and(|p| !p.is_empty())
100        {
101            return Err(Error::Api(
102                "OCI private keys with a passphrase are not supported".into(),
103            ));
104        }
105
106        let key_pair = parse_rsa_pkcs8_pem(&config.private_key_pem).map_err(|e| {
107            Error::Client(format!("Failed to parse OCI private key: {}", e))
108        })?;
109
110        let endpoint = format!("https://dns.{}.oraclecloud.com", config.region);
111
112        Ok(Self {
113            config,
114            key_pair: Arc::new(key_pair),
115            endpoint,
116        })
117    }
118
119    #[cfg(test)]
120    pub(crate) fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
121        self.endpoint = endpoint.into().trim_end_matches('/').to_string();
122        self
123    }
124
125    fn key_id(&self) -> String {
126        format!(
127            "{}/{}/{}",
128            self.config.tenancy_ocid, self.config.user_ocid, self.config.fingerprint
129        )
130    }
131
132    fn sign_request(
133        &self,
134        method: &Method,
135        url: &str,
136        body: Option<&str>,
137    ) -> Result<HeaderMap> {
138        let parsed = reqwest::Url::parse(url)
139            .map_err(|e| Error::Client(format!("Failed to parse URL {}: {}", url, e)))?;
140        let host = parsed
141            .host_str()
142            .ok_or_else(|| Error::Client(format!("URL missing host: {}", url)))?
143            .to_string();
144        let host_header = if let Some(port) = parsed.port() {
145            format!("{}:{}", host, port)
146        } else {
147            host.clone()
148        };
149        let mut path_and_query = parsed.path().to_string();
150        if let Some(q) = parsed.query() {
151            path_and_query.push('?');
152            path_and_query.push_str(q);
153        }
154
155        let method_lower = method.as_str().to_lowercase();
156        let date = Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string();
157
158        let mut signed_pairs: Vec<(String, String)> = Vec::new();
159        signed_pairs.push((
160            "(request-target)".to_string(),
161            format!("{} {}", method_lower, path_and_query),
162        ));
163        signed_pairs.push(("host".to_string(), host_header.clone()));
164        signed_pairs.push(("date".to_string(), date.clone()));
165
166        let needs_body_headers = matches!(*method, Method::POST | Method::PUT | Method::PATCH);
167        let body_bytes = body.unwrap_or("").as_bytes();
168        let content_sha256 = B64.encode(sha256_digest(body_bytes));
169        let content_length = body_bytes.len().to_string();
170        if needs_body_headers {
171            signed_pairs.push(("x-content-sha256".to_string(), content_sha256.clone()));
172            signed_pairs.push(("content-type".to_string(), "application/json".to_string()));
173            signed_pairs.push(("content-length".to_string(), content_length.clone()));
174        }
175
176        let signing_string = signed_pairs
177            .iter()
178            .map(|(k, v)| format!("{}: {}", k, v))
179            .collect::<Vec<_>>()
180            .join("\n");
181        let signature = rsa_sha256_sign(&self.key_pair, signing_string.as_bytes())
182            .map_err(|e| Error::Client(format!("Failed to sign request: {}", e)))?;
183        let signature_b64 = B64.encode(&signature);
184
185        let headers_list = signed_pairs
186            .iter()
187            .map(|(k, _)| k.as_str())
188            .collect::<Vec<_>>()
189            .join(" ");
190        let authorization = format!(
191            "Signature version=\"1\",keyId=\"{}\",algorithm=\"rsa-sha256\",headers=\"{}\",signature=\"{}\"",
192            self.key_id(),
193            headers_list,
194            signature_b64,
195        );
196
197        let mut headers = HeaderMap::new();
198        headers.insert(
199            "host",
200            HeaderValue::from_str(&host_header)
201                .map_err(|e| Error::Client(format!("Invalid host header: {}", e)))?,
202        );
203        headers.insert(
204            "date",
205            HeaderValue::from_str(&date)
206                .map_err(|e| Error::Client(format!("Invalid date header: {}", e)))?,
207        );
208        headers.insert(
209            "authorization",
210            HeaderValue::from_str(&authorization)
211                .map_err(|e| Error::Client(format!("Invalid authorization header: {}", e)))?,
212        );
213        if needs_body_headers {
214            headers.insert(
215                "x-content-sha256",
216                HeaderValue::from_str(&content_sha256)
217                    .map_err(|e| Error::Client(format!("Invalid x-content-sha256: {}", e)))?,
218            );
219            headers.insert(
220                "content-type",
221                HeaderValue::from_static("application/json"),
222            );
223            headers.insert(
224                "content-length",
225                HeaderValue::from_str(&content_length)
226                    .map_err(|e| Error::Client(format!("Invalid content-length: {}", e)))?,
227            );
228        }
229
230        Ok(headers)
231    }
232
233    async fn send_signed(
234        &self,
235        method: Method,
236        url: &str,
237        body: Option<String>,
238    ) -> Result<(reqwest::StatusCode, String)> {
239        let headers = self.sign_request(&method, url, body.as_deref())?;
240        let client = reqwest::Client::builder()
241            .timeout(self.config.request_timeout.unwrap_or(Duration::from_secs(30)))
242            .build()
243            .map_err(|e| Error::Client(format!("Failed to build HTTP client: {}", e)))?;
244
245        let mut request = client.request(method, url).headers(headers);
246        if let Some(b) = body {
247            request = request.body(b);
248        }
249
250        let response = request
251            .send()
252            .await
253            .map_err(|e| Error::Api(format!("Failed to send request to {}: {}", url, e)))?;
254        let status = response.status();
255        let text = response
256            .text()
257            .await
258            .map_err(|e| Error::Api(format!("Failed to read response body: {}", e)))?;
259        Ok((status, text))
260    }
261
262    fn record_to_rdata(record: &DnsRecord) -> Result<(String, String)> {
263        let (rtype, rdata) = match record {
264            DnsRecord::A(ip) => ("A".to_string(), ip.to_string()),
265            DnsRecord::AAAA(ip) => ("AAAA".to_string(), ip.to_string()),
266            DnsRecord::CNAME(c) => ("CNAME".to_string(), format_target(c)),
267            DnsRecord::NS(n) => ("NS".to_string(), format_target(n)),
268            DnsRecord::MX(mx) => (
269                "MX".to_string(),
270                format!("{} {}", mx.priority, format_target(&mx.exchange)),
271            ),
272            DnsRecord::TXT(txt) => {
273                let mut rdata = String::new();
274                txt_chunks_to_text(&mut rdata, txt, " ");
275                ("TXT".to_string(), rdata)
276            }
277            DnsRecord::SRV(srv) => (
278                "SRV".to_string(),
279                format!(
280                    "{} {} {} {}",
281                    srv.priority,
282                    srv.weight,
283                    srv.port,
284                    format_target(&srv.target)
285                ),
286            ),
287            DnsRecord::CAA(caa) => {
288                let (flags, tag, value) = caa.clone().decompose();
289                ("CAA".to_string(), format!("{} {} \"{}\"", flags, tag, value))
290            }
291            DnsRecord::TLSA(_) => {
292                return Err(Error::Api(
293                    "TLSA records are not supported by Oracle Cloud DNS".into(),
294                ));
295            }
296        };
297        Ok((rtype, rdata))
298    }
299
300    async fn resolve_zone(&self, origin: &str) -> Result<String> {
301        let trimmed = origin.trim_end_matches('.');
302        let url = format!(
303            "{}/20180115/zones?compartmentId={}&name={}",
304            self.endpoint,
305            urlencode(&self.config.compartment_ocid),
306            urlencode(trimmed),
307        );
308        let (status, body) = self.send_signed(Method::GET, &url, None).await?;
309        if !status.is_success() {
310            return Err(map_error(status, &body));
311        }
312        let zones: Vec<Zone> = serde_json::from_str(&body)
313            .map_err(|e| Error::Serialize(format!("Failed to parse zones list: {}", e)))?;
314        zones
315            .into_iter()
316            .find(|z| z.name.trim_end_matches('.') == trimmed)
317            .map(|z| z.id)
318            .ok_or_else(|| Error::Api(format!("Zone not found for {}", origin)))
319    }
320
321    fn records_url(&self, zone_id: &str, domain: &str, rtype: &str) -> String {
322        format!(
323            "{}/20180115/zones/{}/records/{}/{}?compartmentId={}",
324            self.endpoint,
325            urlencode(zone_id),
326            urlencode(domain),
327            urlencode(rtype),
328            urlencode(&self.config.compartment_ocid),
329        )
330    }
331
332    async fn get_records(
333        &self,
334        zone_id: &str,
335        domain: &str,
336        rtype: &str,
337    ) -> Result<Vec<OciRecord>> {
338        let url = self.records_url(zone_id, domain, rtype);
339        let (status, body) = self.send_signed(Method::GET, &url, None).await?;
340        if status.as_u16() == 404 {
341            return Ok(Vec::new());
342        }
343        if !status.is_success() {
344            return Err(map_error(status, &body));
345        }
346        let collection: RecordCollection = serde_json::from_str(&body)
347            .map_err(|e| Error::Serialize(format!("Failed to parse records: {}", e)))?;
348        Ok(collection.items)
349    }
350
351    async fn put_records(
352        &self,
353        zone_id: &str,
354        domain: &str,
355        rtype: &str,
356        items: Vec<OciRecord>,
357    ) -> Result<()> {
358        let url = self.records_url(zone_id, domain, rtype);
359        let request = UpdateRecordsRequest { items };
360        let body = serde_json::to_string(&request)
361            .map_err(|e| Error::Serialize(format!("Failed to serialize request: {}", e)))?;
362        let (status, response_body) = self.send_signed(Method::PUT, &url, Some(body)).await?;
363        if !status.is_success() {
364            return Err(map_error(status, &response_body));
365        }
366        Ok(())
367    }
368
369    pub(crate) async fn create(
370        &self,
371        name: impl IntoFqdn<'_>,
372        record: DnsRecord,
373        ttl: u32,
374        origin: impl IntoFqdn<'_>,
375    ) -> Result<()> {
376        let (rtype, rdata) = Self::record_to_rdata(&record)?;
377        let name = name.into_name().to_string();
378        let origin = origin.into_name().to_string();
379        let zone_id = self.resolve_zone(&origin).await?;
380
381        let mut existing = self.get_records(&zone_id, &name, &rtype).await?;
382        existing.push(OciRecord {
383            domain: name.clone(),
384            rtype: rtype.clone(),
385            rdata,
386            ttl,
387            is_protected: None,
388            record_hash: None,
389        });
390
391        let items = existing
392            .into_iter()
393            .map(|r| OciRecord {
394                domain: r.domain,
395                rtype: r.rtype,
396                rdata: r.rdata,
397                ttl: r.ttl,
398                is_protected: None,
399                record_hash: None,
400            })
401            .collect();
402        self.put_records(&zone_id, &name, &rtype, items).await
403    }
404
405    pub(crate) async fn update(
406        &self,
407        name: impl IntoFqdn<'_>,
408        record: DnsRecord,
409        ttl: u32,
410        origin: impl IntoFqdn<'_>,
411    ) -> Result<()> {
412        let (rtype, rdata) = Self::record_to_rdata(&record)?;
413        let name = name.into_name().to_string();
414        let origin = origin.into_name().to_string();
415        let zone_id = self.resolve_zone(&origin).await?;
416
417        let items = vec![OciRecord {
418            domain: name.clone(),
419            rtype: rtype.clone(),
420            rdata,
421            ttl,
422            is_protected: None,
423            record_hash: None,
424        }];
425        self.put_records(&zone_id, &name, &rtype, items).await
426    }
427
428    pub(crate) async fn delete(
429        &self,
430        name: impl IntoFqdn<'_>,
431        origin: impl IntoFqdn<'_>,
432        record_type: DnsRecordType,
433    ) -> Result<()> {
434        if matches!(record_type, DnsRecordType::TLSA) {
435            return Err(Error::Api(
436                "TLSA records are not supported by Oracle Cloud DNS".into(),
437            ));
438        }
439        let name = name.into_name().to_string();
440        let origin = origin.into_name().to_string();
441        let zone_id = self.resolve_zone(&origin).await?;
442        let rtype = record_type.as_str();
443
444        let url = self.records_url(&zone_id, &name, rtype);
445        let (status, body) = self.send_signed(Method::DELETE, &url, None).await?;
446        if status.as_u16() == 404 {
447            return Err(Error::NotFound);
448        }
449        if !status.is_success() {
450            return Err(map_error(status, &body));
451        }
452        Ok(())
453    }
454}
455
/// Returns `value` as an absolute DNS name.
///
/// Any trailing dots are stripped and exactly one is appended, e.g.
/// `"mail.example.com"` and `"mail.example.com."` both become `"mail.example.com."`.
fn format_target(value: &str) -> String {
    let mut absolute = value.trim_end_matches('.').to_owned();
    absolute.push('.');
    absolute
}
459
/// Percent-encodes `value` for use as a URL path segment or query value.
///
/// ASCII letters, digits and `*`, `-`, `.`, `_` are kept as-is. Every other byte
/// of the UTF-8 encoding becomes `%XX` (uppercase hex). Unlike form encoding,
/// a space becomes `%20` rather than `+`, because `+` is a literal character
/// inside path segments.
fn urlencode(value: &str) -> String {
    use std::fmt::Write as _;

    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'*' | b'-' | b'.' | b'_') {
            encoded.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(encoded, "%{:02X}", byte);
        }
    }
    encoded
}
466
467fn map_error(status: reqwest::StatusCode, body: &str) -> Error {
468    match status.as_u16() {
469        400 => Error::BadRequest,
470        401 | 403 => Error::Unauthorized,
471        404 => Error::NotFound,
472        _ => Error::Api(format!("Oracle Cloud DNS error {}: {}", status, body)),
473    }
474}