Skip to main content

ts_control/tokio/
id_token.rs

1//! Control RPC to mint an OIDC ID token for this node (workload-identity federation).
2//!
3//! Mirrors Go's `POST /machine/id-token` over the Noise (ts2021) transport: the node sends a
4//! [`TokenRequest`] (`{CapVersion, NodeKey, Audience}`) and control returns a [`TokenResponse`]
5//! carrying a signed JWT whose `aud` claim is the requested audience. The node is the token
6//! *subject*, not the authenticator — this is token issuance for presenting to a third-party relying
7//! party (e.g. AWS/GCP workload-identity federation), not a registration auth path.
8//!
9//! Requires control capability version ≥ 30 (Go: "2022-03-22: client can request id tokens").
10
11use core::time::Duration;
12use std::fmt;
13
14use bytes::Bytes;
15use ts_capabilityversion::CapabilityVersion;
16use ts_control_serde::{TokenRequest, TokenResponse};
17use ts_http_util::{BytesBody, ClientExt, Http2, ResponseExt, StatusCode};
18use url::Url;
19
20use crate::tokio::connect::ConnectionError;
21
/// Deadline for one whole id-token RPC: the fresh Noise connect, the POST, and reading the
/// response.
///
/// Without it a hung control plane could keep a half-open connection alive indefinitely; once it
/// elapses the RPC future is dropped and the failure is reported as the transient
/// [`IdTokenError::NetworkError`].
const ID_TOKEN_TIMEOUT: Duration = Duration::from_secs(30);

/// Name of the header that carries this node's public key on the id-token POST.
const LOAD_BALANCER_HEADER_KEY: &str = "Ts-Lb";
29
/// The internal failure kinds an id-token request can surface.
///
/// This is the id-token RPC's own vocabulary rather than one borrowed from a sibling module
/// (e.g. registration's), and it only names the generic kinds this RPC actually produces. The type
/// is `pub` because it is carried by the public [`IdTokenError::Internal`] variant, so callers can
/// see and match on it; it deliberately stays coarse.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum IdTokenInternalErrorKind {
    /// Failed to build/parse a URL for the request.
    Url,
    /// Failed to serialize the request or deserialize the response body.
    SerDe,
    /// An unsuccessful (non-2xx) HTTP request, or an HTTP/transport/connect-stage error not
    /// classed as transient.
    Http,
    /// The response body was not valid UTF-8.
    Utf8,
}

impl fmt::Display for IdTokenInternalErrorKind {
    /// Writes a short, human-readable description of the failure kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every description is a fixed literal, so write it directly rather than through `write!`.
        f.write_str(match self {
            Self::Url => "URL parsing error",
            Self::SerDe => "serialization/deserialization error",
            Self::Http => "unsuccessful HTTP request",
            Self::Utf8 => "invalid UTF8",
        })
    }
}
57
58/// Errors from an ID-token request.
59#[derive(Debug, thiserror::Error, Clone, Eq, PartialEq)]
60pub enum IdTokenError {
61    /// A transient network error; the request may succeed on retry.
62    #[error("network error requesting id token")]
63    NetworkError,
64    /// An internal failure (URL/serde/HTTP/UTF-8). Detail kept coarse for the public surface.
65    #[error("error requesting id token: {0}")]
66    Internal(IdTokenInternalErrorKind),
67}
68
69impl From<url::ParseError> for IdTokenError {
70    fn from(error: url::ParseError) -> Self {
71        tracing::error!(%error, "bad URL building id-token request");
72        IdTokenError::Internal(IdTokenInternalErrorKind::Url)
73    }
74}
75
76impl From<serde_json::Error> for IdTokenError {
77    fn from(error: serde_json::Error) -> Self {
78        tracing::error!(%error, "serde error in id-token request");
79        IdTokenError::Internal(IdTokenInternalErrorKind::SerDe)
80    }
81}
82
83impl From<core::str::Utf8Error> for IdTokenError {
84    fn from(error: core::str::Utf8Error) -> Self {
85        tracing::error!(%error, "invalid utf8 in id-token response");
86        IdTokenError::Internal(IdTokenInternalErrorKind::Utf8)
87    }
88}
89
90impl From<ts_http_util::Error> for IdTokenError {
91    fn from(error: ts_http_util::Error) -> Self {
92        tracing::error!(%error, "http error in id-token request");
93        if crate::http_error_is_recoverable(error) {
94            IdTokenError::NetworkError
95        } else {
96            IdTokenError::Internal(IdTokenInternalErrorKind::Http)
97        }
98    }
99}
100
101// The shared Noise `connect` surfaces a `ConnectionError`; fold it into our error. The connect
102// crate's richer `InternalErrorKind` is collapsed onto the coarser id-token kinds.
103impl From<ConnectionError> for IdTokenError {
104    fn from(error: ConnectionError) -> Self {
105        use crate::tokio::connect::InternalErrorKind as Conn;
106        match error {
107            ConnectionError::NetworkError => IdTokenError::NetworkError,
108            ConnectionError::Internal(k) => IdTokenError::Internal(match k {
109                Conn::Url => IdTokenInternalErrorKind::Url,
110                Conn::SerDe => IdTokenInternalErrorKind::SerDe,
111                // Everything else is an unsuccessful request/handshake at the Noise layer.
112                Conn::Http
113                | Conn::MessageFormat
114                | Conn::Io
115                | Conn::ChallengeLength
116                | Conn::NoiseHandshake => IdTokenInternalErrorKind::Http,
117            }),
118        }
119    }
120}
121
122/// Request an OIDC ID token for this node from control, scoped to `audience` (the `aud` claim of the
123/// returned JWT). Opens a fresh Noise channel and POSTs to `/machine/id-token`. Returns the signed
124/// JWT string on success.
125///
126/// The whole connect + POST + response read is bounded by `ID_TOKEN_TIMEOUT`: a hung control
127/// plane is abandoned and reported as [`IdTokenError::NetworkError`] rather than pinning a
128/// half-open connection.
129pub async fn fetch_id_token(
130    config: &crate::Config,
131    node_keystate: &ts_keys::NodeState,
132    audience: &str,
133) -> Result<String, IdTokenError> {
134    let control_url = &config.server_url;
135    let rpc = async {
136        let http2_conn = crate::tokio::connect(
137            control_url,
138            &node_keystate.machine_keys,
139            config.allow_http_key_fetch,
140        )
141        .await?;
142        fetch_id_token_with(control_url, node_keystate, audience, &http2_conn).await
143    };
144
145    match tokio::time::timeout(ID_TOKEN_TIMEOUT, rpc).await {
146        Ok(result) => result,
147        Err(_elapsed) => {
148            tracing::error!(timeout = ?ID_TOKEN_TIMEOUT, "id-token request timed out");
149            Err(IdTokenError::NetworkError)
150        }
151    }
152}
153
154/// Inner: send the `/machine/id-token` POST over an already-established Noise channel.
155///
156/// Split out from [`fetch_id_token`] so the response-parsing logic ([`parse_token_response`]) is
157/// unit-testable independent of the Noise connect.
158pub(crate) async fn fetch_id_token_with(
159    control_url: &Url,
160    node_keystate: &ts_keys::NodeState,
161    audience: &str,
162    http2_conn: &Http2<BytesBody>,
163) -> Result<String, IdTokenError> {
164    let node_public_key = node_keystate.node_keys.public;
165
166    let req = TokenRequest {
167        cap_version: CapabilityVersion::CURRENT,
168        node_key: node_public_key,
169        audience: audience.to_string(),
170    };
171
172    let body = serde_json::to_string(&req)?;
173    let url = control_url.join("machine/id-token")?;
174
175    tracing::debug!(url = %url.as_str(), "requesting id token from control");
176
177    let response = http2_conn
178        .post(
179            &url,
180            [(
181                LOAD_BALANCER_HEADER_KEY.parse().unwrap(),
182                node_public_key.to_string().parse().unwrap(),
183            )],
184            Bytes::from(body).into(),
185        )
186        .await?;
187
188    let status = response.status();
189    let body = response.collect_bytes().await?;
190    parse_token_response(status, &body)
191}
192
193/// Turn a `/machine/id-token` HTTP response into the signed JWT string.
194///
195/// Pure (no I/O): factored out of [`fetch_id_token_with`] so the status/body branch logic is
196/// unit-testable without a live stream. A non-2xx status is [`IdTokenInternalErrorKind::Http`]
197/// (logging a truncated body); a 2xx body must be UTF-8 JSON deserializing to a [`TokenResponse`].
198fn parse_token_response(status: StatusCode, body: &[u8]) -> Result<String, IdTokenError> {
199    if !status.is_success() {
200        let mut truncated = body.to_vec();
201        truncated.truncate(512);
202        let preview = core::str::from_utf8(&truncated).unwrap_or("<invalid utf8>");
203        tracing::error!(body = %preview, %status, "id-token request failed");
204        return Err(IdTokenError::Internal(IdTokenInternalErrorKind::Http));
205    }
206
207    let body = core::str::from_utf8(body)?;
208    let resp: TokenResponse = serde_json::from_str(body)?;
209
210    Ok(resp.id_token)
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokio::connect::{ConnectionError, InternalErrorKind as ConnKind};

    /// Builds the `IdTokenError::Internal(kind)` value the assertions below compare against.
    fn internal(kind: IdTokenInternalErrorKind) -> IdTokenError {
        IdTokenError::Internal(kind)
    }

    // ---- `From` conversions into `IdTokenError` ----

    #[test]
    fn connection_error_network_maps_to_network() {
        let mapped = IdTokenError::from(ConnectionError::NetworkError);
        assert_eq!(mapped, IdTokenError::NetworkError);
    }

    #[test]
    fn connection_error_internal_kinds_map_correctly() {
        use IdTokenInternalErrorKind as Id;

        // Only `Url` and `SerDe` keep their identity; every other connect kind collapses to `Http`.
        let table = [
            (ConnKind::Url, Id::Url),
            (ConnKind::SerDe, Id::SerDe),
            (ConnKind::Http, Id::Http),
            (ConnKind::MessageFormat, Id::Http),
            (ConnKind::Io, Id::Http),
            (ConnKind::ChallengeLength, Id::Http),
            (ConnKind::NoiseHandshake, Id::Http),
        ];
        for (conn, expected) in table {
            let mapped = IdTokenError::from(ConnectionError::Internal(conn));
            assert_eq!(
                mapped,
                internal(expected),
                "ConnectionError::Internal({conn:?}) should map to Internal({expected:?})"
            );
        }
    }

    #[test]
    fn serde_error_maps_to_internal_serde() {
        let source = serde_json::from_str::<TokenResponse>("not json").unwrap_err();
        let mapped = IdTokenError::from(source);
        assert_eq!(mapped, internal(IdTokenInternalErrorKind::SerDe));
    }

    #[test]
    fn url_parse_error_maps_to_internal_url() {
        let source = Url::parse("not a url").unwrap_err();
        let mapped = IdTokenError::from(source);
        assert_eq!(mapped, internal(IdTokenInternalErrorKind::Url));
    }

    #[test]
    fn utf8_error_maps_to_internal_utf8() {
        // The bad bytes live in a heap `Vec` rather than a literal. The `invalid_from_utf8` lint
        // only inspects compile-time-known inputs, and this failure is deliberate.
        let invalid: Vec<u8> = vec![0xff, 0xfe];
        let source = core::str::from_utf8(&invalid).unwrap_err();
        let mapped = IdTokenError::from(source);
        assert_eq!(mapped, internal(IdTokenInternalErrorKind::Utf8));
    }

    #[test]
    fn http_util_error_non_recoverable_maps_to_internal_http() {
        // An invalid response is not retryable, so it lands on `Internal(Http)`.
        let mapped = IdTokenError::from(ts_http_util::Error::InvalidResponse);
        assert_eq!(mapped, internal(IdTokenInternalErrorKind::Http));
    }

    #[test]
    fn http_util_error_recoverable_maps_to_network() {
        // Transient I/O is retryable, so it surfaces as `NetworkError`.
        let mapped = IdTokenError::from(ts_http_util::Error::Io);
        assert_eq!(mapped, IdTokenError::NetworkError);
    }

    // ---- `parse_token_response` ----

    #[test]
    fn parse_token_response_ok() {
        let json = br#"{"id_token":"abc.def.ghi"}"#;
        let jwt = parse_token_response(StatusCode::OK, json).unwrap();
        assert_eq!(jwt, "abc.def.ghi");
    }

    #[test]
    fn parse_token_response_non_success_is_http() {
        let outcome = parse_token_response(StatusCode::INTERNAL_SERVER_ERROR, b"upstream boom");
        assert_eq!(outcome.unwrap_err(), internal(IdTokenInternalErrorKind::Http));
    }

    #[test]
    fn parse_token_response_invalid_json_is_serde() {
        let outcome = parse_token_response(StatusCode::OK, b"{not json");
        assert_eq!(outcome.unwrap_err(), internal(IdTokenInternalErrorKind::SerDe));
    }

    #[test]
    fn parse_token_response_invalid_utf8_is_utf8() {
        let outcome = parse_token_response(StatusCode::OK, &[0xff, 0xfe, 0xfd]);
        assert_eq!(outcome.unwrap_err(), internal(IdTokenInternalErrorKind::Utf8));
    }

    #[test]
    fn parse_token_response_missing_id_token_errors() {
        // A body without the required `id_token` field fails deserialization.
        let outcome = parse_token_response(StatusCode::OK, b"{}");
        assert_eq!(outcome.unwrap_err(), internal(IdTokenInternalErrorKind::SerDe));
    }
}