1use core::time::Duration;
12use std::fmt;
13
14use bytes::Bytes;
15use ts_capabilityversion::CapabilityVersion;
16use ts_control_serde::{TokenRequest, TokenResponse};
17use ts_http_util::{BytesBody, ClientExt, Http2, ResponseExt, StatusCode};
18use url::Url;
19
20use crate::tokio::connect::ConnectionError;
21
/// Name of the header sent with the id-token request; its value is this node's
/// public key (presumably used by control's load balancer for routing — confirm).
const LOAD_BALANCER_HEADER_KEY: &str = "Ts-Lb";

/// Upper bound on the whole id-token RPC, covering both connection setup and the
/// request itself.
const ID_TOKEN_TIMEOUT: Duration = Duration::from_millis(30_000);
29
/// Which stage of an id-token fetch failed, for failures that are not classified
/// as transient network errors.
///
/// Carried by `IdTokenError::Internal`; the `Display` text below is embedded in
/// that error's message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IdTokenInternalErrorKind {
    /// A URL could not be parsed or joined.
    Url,
    /// Serializing the request or deserializing a response failed.
    SerDe,
    /// An HTTP/connection-level failure that is not a recoverable network error,
    /// including a non-success response status from control.
    Http,
    /// The response body was not valid UTF-8.
    Utf8,
}

impl fmt::Display for IdTokenInternalErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Self::Url => "URL parsing error",
            Self::SerDe => "serialization/deserialization error",
            Self::Http => "unsuccessful HTTP request",
            Self::Utf8 => "invalid UTF8",
        };
        f.write_str(text)
    }
}
57
58#[derive(Debug, thiserror::Error, Clone, Eq, PartialEq)]
60pub enum IdTokenError {
61 #[error("network error requesting id token")]
63 NetworkError,
64 #[error("error requesting id token: {0}")]
66 Internal(IdTokenInternalErrorKind),
67}
68
69impl From<url::ParseError> for IdTokenError {
70 fn from(error: url::ParseError) -> Self {
71 tracing::error!(%error, "bad URL building id-token request");
72 IdTokenError::Internal(IdTokenInternalErrorKind::Url)
73 }
74}
75
76impl From<serde_json::Error> for IdTokenError {
77 fn from(error: serde_json::Error) -> Self {
78 tracing::error!(%error, "serde error in id-token request");
79 IdTokenError::Internal(IdTokenInternalErrorKind::SerDe)
80 }
81}
82
83impl From<core::str::Utf8Error> for IdTokenError {
84 fn from(error: core::str::Utf8Error) -> Self {
85 tracing::error!(%error, "invalid utf8 in id-token response");
86 IdTokenError::Internal(IdTokenInternalErrorKind::Utf8)
87 }
88}
89
90impl From<ts_http_util::Error> for IdTokenError {
91 fn from(error: ts_http_util::Error) -> Self {
92 tracing::error!(%error, "http error in id-token request");
93 if crate::http_error_is_recoverable(error) {
94 IdTokenError::NetworkError
95 } else {
96 IdTokenError::Internal(IdTokenInternalErrorKind::Http)
97 }
98 }
99}
100
101impl From<ConnectionError> for IdTokenError {
104 fn from(error: ConnectionError) -> Self {
105 use crate::tokio::connect::InternalErrorKind as Conn;
106 match error {
107 ConnectionError::NetworkError => IdTokenError::NetworkError,
108 ConnectionError::Internal(k) => IdTokenError::Internal(match k {
109 Conn::Url => IdTokenInternalErrorKind::Url,
110 Conn::SerDe => IdTokenInternalErrorKind::SerDe,
111 Conn::Http
113 | Conn::MessageFormat
114 | Conn::Io
115 | Conn::ChallengeLength
116 | Conn::NoiseHandshake => IdTokenInternalErrorKind::Http,
117 }),
118 }
119 }
120}
121
122pub async fn fetch_id_token(
130 config: &crate::Config,
131 node_keystate: &ts_keys::NodeState,
132 audience: &str,
133) -> Result<String, IdTokenError> {
134 let control_url = &config.server_url;
135 let rpc = async {
136 let http2_conn = crate::tokio::connect(
137 control_url,
138 &node_keystate.machine_keys,
139 config.allow_http_key_fetch,
140 )
141 .await?;
142 fetch_id_token_with(control_url, node_keystate, audience, &http2_conn).await
143 };
144
145 match tokio::time::timeout(ID_TOKEN_TIMEOUT, rpc).await {
146 Ok(result) => result,
147 Err(_elapsed) => {
148 tracing::error!(timeout = ?ID_TOKEN_TIMEOUT, "id-token request timed out");
149 Err(IdTokenError::NetworkError)
150 }
151 }
152}
153
154pub(crate) async fn fetch_id_token_with(
159 control_url: &Url,
160 node_keystate: &ts_keys::NodeState,
161 audience: &str,
162 http2_conn: &Http2<BytesBody>,
163) -> Result<String, IdTokenError> {
164 let node_public_key = node_keystate.node_keys.public;
165
166 let req = TokenRequest {
167 cap_version: CapabilityVersion::CURRENT,
168 node_key: node_public_key,
169 audience: audience.to_string(),
170 };
171
172 let body = serde_json::to_string(&req)?;
173 let url = control_url.join("machine/id-token")?;
174
175 tracing::debug!(url = %url.as_str(), "requesting id token from control");
176
177 let response = http2_conn
178 .post(
179 &url,
180 [(
181 LOAD_BALANCER_HEADER_KEY.parse().unwrap(),
182 node_public_key.to_string().parse().unwrap(),
183 )],
184 Bytes::from(body).into(),
185 )
186 .await?;
187
188 let status = response.status();
189 let body = response
190 .collect_bytes_limited(crate::MAX_CONTROL_RESPONSE)
191 .await?;
192 parse_token_response(status, &body)
193}
194
195fn parse_token_response(status: StatusCode, body: &[u8]) -> Result<String, IdTokenError> {
201 if !status.is_success() {
202 let mut truncated = body.to_vec();
203 truncated.truncate(512);
204 let preview = core::str::from_utf8(&truncated).unwrap_or("<invalid utf8>");
205 tracing::error!(body = %preview, %status, "id-token request failed");
206 return Err(IdTokenError::Internal(IdTokenInternalErrorKind::Http));
207 }
208
209 let body = core::str::from_utf8(body)?;
210 let resp: TokenResponse = serde_json::from_str(body)?;
211
212 Ok(resp.id_token)
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use super::IdTokenInternalErrorKind as Kind;
    use crate::tokio::connect::{ConnectionError, InternalErrorKind as ConnKind};

    /// Builds the expected `IdTokenError::Internal` value for `kind`.
    fn internal(kind: Kind) -> IdTokenError {
        IdTokenError::Internal(kind)
    }

    #[test]
    fn connection_error_network_maps_to_network() {
        let mapped: IdTokenError = ConnectionError::NetworkError.into();
        assert_eq!(mapped, IdTokenError::NetworkError);
    }

    #[test]
    fn connection_error_internal_kinds_map_correctly() {
        // Every connection-level kind, paired with the id-token kind it must become.
        let cases = [
            (ConnKind::Url, Kind::Url),
            (ConnKind::SerDe, Kind::SerDe),
            (ConnKind::Http, Kind::Http),
            (ConnKind::MessageFormat, Kind::Http),
            (ConnKind::Io, Kind::Http),
            (ConnKind::ChallengeLength, Kind::Http),
            (ConnKind::NoiseHandshake, Kind::Http),
        ];
        for (conn, expected) in cases {
            let actual = IdTokenError::from(ConnectionError::Internal(conn));
            assert_eq!(
                actual,
                internal(expected),
                "ConnectionError::Internal({conn:?}) should map to Internal({expected:?})"
            );
        }
    }

    #[test]
    fn serde_error_maps_to_internal_serde() {
        let source = serde_json::from_str::<TokenResponse>("not json").unwrap_err();
        let mapped: IdTokenError = source.into();
        assert_eq!(mapped, internal(Kind::SerDe));
    }

    #[test]
    fn url_parse_error_maps_to_internal_url() {
        let source = Url::parse("not a url").unwrap_err();
        let mapped: IdTokenError = source.into();
        assert_eq!(mapped, internal(Kind::Url));
    }

    #[test]
    fn utf8_error_maps_to_internal_utf8() {
        // 0xff never appears in well-formed UTF-8.
        let invalid = [0xffu8, 0xfe];
        let source = core::str::from_utf8(&invalid).unwrap_err();
        let mapped: IdTokenError = source.into();
        assert_eq!(mapped, internal(Kind::Utf8));
    }

    #[test]
    fn http_util_error_non_recoverable_maps_to_internal_http() {
        let mapped: IdTokenError = ts_http_util::Error::InvalidResponse.into();
        assert_eq!(mapped, internal(Kind::Http));
    }

    #[test]
    fn http_util_error_recoverable_maps_to_network() {
        let mapped: IdTokenError = ts_http_util::Error::Io.into();
        assert_eq!(mapped, IdTokenError::NetworkError);
    }

    #[test]
    fn parse_token_response_ok() {
        let token =
            parse_token_response(StatusCode::OK, br#"{"id_token":"abc.def.ghi"}"#).unwrap();
        assert_eq!(token, "abc.def.ghi");
    }

    #[test]
    fn parse_token_response_non_success_is_http() {
        let result = parse_token_response(StatusCode::INTERNAL_SERVER_ERROR, b"upstream boom");
        assert_eq!(result, Err(internal(Kind::Http)));
    }

    #[test]
    fn parse_token_response_invalid_json_is_serde() {
        let result = parse_token_response(StatusCode::OK, b"{not json");
        assert_eq!(result, Err(internal(Kind::SerDe)));
    }

    #[test]
    fn parse_token_response_invalid_utf8_is_utf8() {
        let result = parse_token_response(StatusCode::OK, &[0xff, 0xfe, 0xfd]);
        assert_eq!(result, Err(internal(Kind::Utf8)));
    }

    #[test]
    fn parse_token_response_missing_id_token_errors() {
        // An empty object lacks `id_token`, so deserialization must fail.
        let result = parse_token_response(StatusCode::OK, b"{}");
        assert_eq!(result, Err(internal(Kind::SerDe)));
    }
}