payjoin/core/
io.rs

1//! IO-related types and functions. Specifically, fetching OHTTP keys from a payjoin directory.
2
3use http::header::ACCEPT;
4use reqwest::{Client, Proxy};
5
6use crate::into_url::IntoUrl;
7use crate::OhttpKeys;
8
9/// Fetch the ohttp keys from the specified payjoin directory via proxy.
10///
11/// * `ohttp_relay`: The http CONNECT method proxy to request the ohttp keys from a payjoin
12///   directory.  Proxying requests for ohttp keys ensures a client IP address is never revealed to
13///   the payjoin directory.
14///
15/// * `payjoin_directory`: The payjoin directory from which to fetch the ohttp keys.  This
16///   directory stores and forwards payjoin client payloads.
17pub async fn fetch_ohttp_keys(
18    ohttp_relay: impl IntoUrl,
19    payjoin_directory: impl IntoUrl,
20) -> Result<OhttpKeys, Error> {
21    let ohttp_keys_url = payjoin_directory.into_url()?.join("/.well-known/ohttp-gateway")?;
22    let proxy = Proxy::all(ohttp_relay.into_url()?.as_str())?;
23    let client = Client::builder().proxy(proxy).build()?;
24    let res = client.get(ohttp_keys_url).header(ACCEPT, "application/ohttp-keys").send().await?;
25    parse_ohttp_keys_response(res).await
26}
27
/// Fetch the OHTTP keys of a payjoin directory through a proxy, trusting an extra root
/// certificate.
///
/// This behaves like `fetch_ohttp_keys`, except that the HTTP client uses rustls and adds
/// `cert_der` to its trusted root certificates. It is only available with the
/// `_danger-local-https` feature.
///
/// * `ohttp_relay`: The http CONNECT method proxy to request the ohttp keys from a payjoin
///   directory. Proxying requests for ohttp keys ensures a client IP address is never revealed
///   to the payjoin directory.
///
/// * `payjoin_directory`: The payjoin directory from which to fetch the ohttp keys. This
///   directory stores and forwards payjoin client payloads.
///
/// * `cert_der`: A DER-encoded certificate added as a trusted root for HTTPS connections
///   made by this request (e.g. a self-signed certificate for local HTTPS).
///
/// # Errors
///
/// Returns [`Error::UnexpectedStatusCode`] if the directory answers with a non-success status.
/// Every other failure is reported as an internal error. That covers an unparsable URL, an
/// invalid proxy, a certificate that cannot be parsed as DER, a client that cannot be built,
/// a failed request, and a body that is not valid OHTTP keys.
#[cfg(feature = "_danger-local-https")]
pub async fn fetch_ohttp_keys_with_cert(
    ohttp_relay: impl IntoUrl,
    payjoin_directory: impl IntoUrl,
    cert_der: Vec<u8>,
) -> Result<OhttpKeys, Error> {
    let directory = payjoin_directory.into_url()?;
    // The leading `/` makes the join absolute: the well-known path replaces any path
    // component of the directory URL and is resolved against its origin.
    let gateway_url = directory.join("/.well-known/ohttp-gateway")?;

    let relay = ohttp_relay.into_url()?;
    let proxy = Proxy::all(relay.as_str())?;
    let trusted_root = reqwest::tls::Certificate::from_der(&cert_der)?;
    let http_client = Client::builder()
        .use_rustls_tls()
        .add_root_certificate(trusted_root)
        .proxy(proxy)
        .build()?;

    let request = http_client.get(gateway_url).header(ACCEPT, "application/ohttp-keys");
    let response = request.send().await?;
    parse_ohttp_keys_response(response).await
}
54
55async fn parse_ohttp_keys_response(res: reqwest::Response) -> Result<OhttpKeys, Error> {
56    if !res.status().is_success() {
57        return Err(Error::UnexpectedStatusCode(res.status()));
58    }
59
60    let body = res.bytes().await?.to_vec();
61    OhttpKeys::decode(&body).map_err(|e| {
62        Error::Internal(InternalError(InternalErrorInner::InvalidOhttpKeys(e.to_string())))
63    })
64}
65
66#[derive(Debug)]
67#[non_exhaustive]
68pub enum Error {
69    /// When the payjoin directory returns an unexpected status code
70    UnexpectedStatusCode(http::StatusCode),
71    /// Internal errors that should not be pattern matched by users
72    #[doc(hidden)]
73    Internal(InternalError),
74}
75
76#[derive(Debug)]
77pub struct InternalError(InternalErrorInner);
78
79#[derive(Debug)]
80enum InternalErrorInner {
81    ParseUrl(crate::into_url::Error),
82    Reqwest(reqwest::Error),
83    Io(std::io::Error),
84    #[cfg(feature = "_danger-local-https")]
85    Rustls(rustls::Error),
86    InvalidOhttpKeys(String),
87}
88
89impl From<url::ParseError> for Error {
90    fn from(value: url::ParseError) -> Self {
91        Self::Internal(InternalError(InternalErrorInner::ParseUrl(value.into())))
92    }
93}
94
95macro_rules! impl_from_error {
96    ($from:ty, $to:ident) => {
97        impl From<$from> for Error {
98            fn from(value: $from) -> Self {
99                Self::Internal(InternalError(InternalErrorInner::$to(value)))
100            }
101        }
102    };
103}
104
105impl_from_error!(crate::into_url::Error, ParseUrl);
106impl_from_error!(reqwest::Error, Reqwest);
107impl_from_error!(std::io::Error, Io);
108#[cfg(feature = "_danger-local-https")]
109impl_from_error!(rustls::Error, Rustls);
110
111impl std::fmt::Display for Error {
112    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
113        match self {
114            Self::UnexpectedStatusCode(code) => {
115                write!(f, "Unexpected status code from payjoin directory: {code}")
116            }
117            Self::Internal(InternalError(e)) => e.fmt(f),
118        }
119    }
120}
121
122impl std::fmt::Display for InternalErrorInner {
123    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
124        use InternalErrorInner::*;
125
126        match &self {
127            Reqwest(e) => e.fmt(f),
128            ParseUrl(e) => e.fmt(f),
129            Io(e) => e.fmt(f),
130            InvalidOhttpKeys(e) => {
131                write!(f, "Invalid ohttp keys returned from payjoin directory: {e}")
132            }
133            #[cfg(feature = "_danger-local-https")]
134            Rustls(e) => e.fmt(f),
135        }
136    }
137}
138
139impl std::error::Error for Error {
140    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
141        match self {
142            Self::Internal(InternalError(e)) => e.source(),
143            Self::UnexpectedStatusCode(_) => None,
144        }
145    }
146}
147
148impl std::error::Error for InternalErrorInner {
149    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
150        use InternalErrorInner::*;
151
152        match self {
153            Reqwest(e) => Some(e),
154            ParseUrl(e) => Some(e),
155            Io(e) => Some(e),
156            InvalidOhttpKeys(_) => None,
157            #[cfg(feature = "_danger-local-https")]
158            Rustls(e) => Some(e),
159        }
160    }
161}
162
163impl From<InternalError> for Error {
164    fn from(value: InternalError) -> Self { Self::Internal(value) }
165}
166
167impl From<InternalErrorInner> for Error {
168    fn from(value: InternalErrorInner) -> Self { Self::Internal(InternalError(value)) }
169}
170
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use http::StatusCode;
    use reqwest::Response;

    use super::*;

    /// Build a `reqwest::Response` with the given status and body without any network I/O.
    fn mock_response(status: StatusCode, body: Vec<u8>) -> Response {
        Response::from(http::response::Response::builder().status(status).body(body).unwrap())
    }

    #[tokio::test]
    async fn test_parse_success_response() {
        let valid_keys =
            OhttpKeys::from_str("OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC")
                .expect("valid keys")
                .encode()
                .expect("encode valid keys");

        let response = mock_response(StatusCode::OK, valid_keys);
        assert!(parse_ohttp_keys_response(response).await.is_ok(), "expected valid keys response");
    }

    #[tokio::test]
    async fn test_parse_error_status_codes() {
        let error_codes = [
            StatusCode::BAD_REQUEST,
            StatusCode::NOT_FOUND,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::SERVICE_UNAVAILABLE,
        ];

        for status in error_codes {
            let response = mock_response(status, vec![]);
            match parse_ohttp_keys_response(response).await {
                Err(Error::UnexpectedStatusCode(code)) => assert_eq!(code, status),
                result => panic!(
                    "Expected UnexpectedStatusCode error for status code: {status}, got: {result:?}"
                ),
            }
        }
    }

    #[tokio::test]
    async fn test_parse_invalid_keys() {
        // Invalid OHTTP keys (not properly encoded)
        let invalid_keys = vec![1, 2, 3, 4];

        let response = mock_response(StatusCode::OK, invalid_keys);

        assert!(
            matches!(
                parse_ohttp_keys_response(response).await,
                Err(Error::Internal(InternalError(InternalErrorInner::InvalidOhttpKeys(_))))
            ),
            "expected InvalidOhttpKeys error"
        );
    }

    #[tokio::test]
    async fn test_invalid_keys_error_display_and_source() {
        let response = mock_response(StatusCode::OK, vec![1, 2, 3, 4]);
        let err = parse_ohttp_keys_response(response).await.expect_err("expected invalid keys");

        // The decoder's message is appended after this fixed prefix.
        assert!(
            err.to_string().starts_with("Invalid ohttp keys returned from payjoin directory: "),
            "unexpected message: {err}"
        );
        // The decoder error is stored only as a string, so there is no source to walk.
        assert!(std::error::Error::source(&err).is_none());
    }

    #[test]
    fn test_unexpected_status_code_display_and_source() {
        let err = Error::UnexpectedStatusCode(StatusCode::NOT_FOUND);
        assert_eq!(err.to_string(), "Unexpected status code from payjoin directory: 404 Not Found");
        assert!(std::error::Error::source(&err).is_none());
    }
}