1use http::header::ACCEPT;
4use reqwest::{Client, Proxy};
5
6use crate::into_url::IntoUrl;
7use crate::OhttpKeys;
8
9pub async fn fetch_ohttp_keys(
18 ohttp_relay: impl IntoUrl,
19 payjoin_directory: impl IntoUrl,
20) -> Result<OhttpKeys, Error> {
21 let ohttp_keys_url = payjoin_directory.into_url()?.join("/.well-known/ohttp-gateway")?;
22 let proxy = Proxy::all(ohttp_relay.into_url()?.as_str())?;
23 let client = Client::builder().proxy(proxy).build()?;
24 let res = client.get(ohttp_keys_url).header(ACCEPT, "application/ohttp-keys").send().await?;
25 parse_ohttp_keys_response(res).await
26}
27
/// Fetches OHTTP keys like [`fetch_ohttp_keys`], additionally trusting a caller-supplied
/// root certificate.
///
/// `cert_der` is parsed as a DER-encoded certificate and added to the client's root
/// certificates; the client is built with the rustls TLS backend. As in [`fetch_ohttp_keys`],
/// the request goes to the `/.well-known/ohttp-gateway` path of `payjoin_directory` and is
/// proxied through `ohttp_relay`.
///
/// Only compiled when the `_danger-local-https` feature is enabled.
///
/// # Errors
///
/// Returns an [`Error`] when either URL cannot be parsed, the certificate is rejected, the
/// proxy or client cannot be constructed, the request fails, the response status is not a
/// success, or the response body cannot be decoded as OHTTP keys.
#[cfg(feature = "_danger-local-https")]
pub async fn fetch_ohttp_keys_with_cert(
    ohttp_relay: impl IntoUrl,
    payjoin_directory: impl IntoUrl,
    cert_der: Vec<u8>,
) -> Result<OhttpKeys, Error> {
    let directory_url = payjoin_directory.into_url()?;
    let keys_endpoint = directory_url.join("/.well-known/ohttp-gateway")?;

    let relay_url = ohttp_relay.into_url()?;
    let relay_proxy = Proxy::all(relay_url.as_str())?;

    // Parsed after the proxy so that failures surface in the same order as before.
    let trusted_root = reqwest::tls::Certificate::from_der(&cert_der)?;
    let http_client = Client::builder()
        .use_rustls_tls()
        .add_root_certificate(trusted_root)
        .proxy(relay_proxy)
        .build()?;

    let request = http_client.get(keys_endpoint).header(ACCEPT, "application/ohttp-keys");
    let response = request.send().await?;
    parse_ohttp_keys_response(response).await
}
54
55async fn parse_ohttp_keys_response(res: reqwest::Response) -> Result<OhttpKeys, Error> {
56 if !res.status().is_success() {
57 return Err(Error::UnexpectedStatusCode(res.status()));
58 }
59
60 let body = res.bytes().await?.to_vec();
61 OhttpKeys::decode(&body).map_err(|e| {
62 Error::Internal(InternalError(InternalErrorInner::InvalidOhttpKeys(e.to_string())))
63 })
64}
65
66#[derive(Debug)]
67#[non_exhaustive]
68pub enum Error {
69 UnexpectedStatusCode(http::StatusCode),
71 #[doc(hidden)]
73 Internal(InternalError),
74}
75
76#[derive(Debug)]
77pub struct InternalError(InternalErrorInner);
78
79#[derive(Debug)]
80enum InternalErrorInner {
81 ParseUrl(crate::into_url::Error),
82 Reqwest(reqwest::Error),
83 Io(std::io::Error),
84 #[cfg(feature = "_danger-local-https")]
85 Rustls(rustls::Error),
86 InvalidOhttpKeys(String),
87}
88
89impl From<url::ParseError> for Error {
90 fn from(value: url::ParseError) -> Self {
91 Self::Internal(InternalError(InternalErrorInner::ParseUrl(value.into())))
92 }
93}
94
95macro_rules! impl_from_error {
96 ($from:ty, $to:ident) => {
97 impl From<$from> for Error {
98 fn from(value: $from) -> Self {
99 Self::Internal(InternalError(InternalErrorInner::$to(value)))
100 }
101 }
102 };
103}
104
105impl_from_error!(crate::into_url::Error, ParseUrl);
106impl_from_error!(reqwest::Error, Reqwest);
107impl_from_error!(std::io::Error, Io);
108#[cfg(feature = "_danger-local-https")]
109impl_from_error!(rustls::Error, Rustls);
110
111impl std::fmt::Display for Error {
112 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
113 match self {
114 Self::UnexpectedStatusCode(code) => {
115 write!(f, "Unexpected status code from payjoin directory: {code}")
116 }
117 Self::Internal(InternalError(e)) => e.fmt(f),
118 }
119 }
120}
121
122impl std::fmt::Display for InternalErrorInner {
123 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
124 use InternalErrorInner::*;
125
126 match &self {
127 Reqwest(e) => e.fmt(f),
128 ParseUrl(e) => e.fmt(f),
129 Io(e) => e.fmt(f),
130 InvalidOhttpKeys(e) => {
131 write!(f, "Invalid ohttp keys returned from payjoin directory: {e}")
132 }
133 #[cfg(feature = "_danger-local-https")]
134 Rustls(e) => e.fmt(f),
135 }
136 }
137}
138
139impl std::error::Error for Error {
140 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
141 match self {
142 Self::Internal(InternalError(e)) => e.source(),
143 Self::UnexpectedStatusCode(_) => None,
144 }
145 }
146}
147
148impl std::error::Error for InternalErrorInner {
149 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
150 use InternalErrorInner::*;
151
152 match self {
153 Reqwest(e) => Some(e),
154 ParseUrl(e) => Some(e),
155 Io(e) => Some(e),
156 InvalidOhttpKeys(_) => None,
157 #[cfg(feature = "_danger-local-https")]
158 Rustls(e) => Some(e),
159 }
160 }
161}
162
163impl From<InternalError> for Error {
164 fn from(value: InternalError) -> Self { Self::Internal(value) }
165}
166
167impl From<InternalErrorInner> for Error {
168 fn from(value: InternalErrorInner) -> Self { Self::Internal(InternalError(value)) }
169}
170
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use http::StatusCode;
    use reqwest::Response;

    use super::*;

    /// Builds a `reqwest::Response` with the given status and body, without any network I/O.
    fn mock_response(status: StatusCode, body: Vec<u8>) -> Response {
        let raw = http::response::Response::builder().status(status).body(body).unwrap();
        Response::from(raw)
    }

    /// A success status with an encoded key set must parse.
    #[tokio::test]
    async fn test_parse_success_response() {
        let keys =
            OhttpKeys::from_str("OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC")
                .expect("valid keys");
        let valid_keys = keys.encode().expect("encodevalid keys");

        let outcome = parse_ohttp_keys_response(mock_response(StatusCode::OK, valid_keys)).await;
        assert!(outcome.is_ok(), "expected valid keys response");
    }

    /// Every non-success status must be reported back verbatim as `UnexpectedStatusCode`.
    #[tokio::test]
    async fn test_parse_error_status_codes() {
        let error_codes = [
            StatusCode::BAD_REQUEST,
            StatusCode::NOT_FOUND,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::SERVICE_UNAVAILABLE,
        ];

        for status in error_codes {
            let outcome = parse_ohttp_keys_response(mock_response(status, Vec::new())).await;
            match outcome {
                Err(Error::UnexpectedStatusCode(code)) => assert_eq!(code, status),
                result => panic!(
                    "Expected UnexpectedStatusCode error for status code: {status}, got: {result:?}"
                ),
            }
        }
    }

    /// A success status whose body does not decode must yield `InvalidOhttpKeys`.
    #[tokio::test]
    async fn test_parse_invalid_keys() {
        // Four arbitrary bytes; the assertion below expects decoding them to fail.
        let invalid_keys = vec![1, 2, 3, 4];

        let outcome = parse_ohttp_keys_response(mock_response(StatusCode::OK, invalid_keys)).await;
        let is_invalid_keys = matches!(
            outcome,
            Err(Error::Internal(InternalError(InternalErrorInner::InvalidOhttpKeys(_))))
        );

        assert!(is_invalid_keys, "expected InvalidOhttpKeys error");
    }
}