1use fe2o3_amqp::link::DetachError;
2use futures_util::StreamExt;
3use tokio::task::JoinError;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration as StdDuration;
7
8use fe2o3_amqp_cbs::{client::CbsClient, AsyncCbsTokenProvider};
9use time::OffsetDateTime;
10use tokio::sync::mpsc;
11use tokio::sync::oneshot;
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14
15use crate::util::sharable::Sharable;
16use crate::util::time::{DelayQueue, Key};
17
18use super::error::AmqpCbsEventLoopStopped;
19use super::{cbs_token_provider::CbsTokenProvider, error::CbsAuthError};
20
/// Delay after which the placeholder entry in the CBS delay queue fires (30 minutes).
const DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION: StdDuration = StdDuration::from_secs(1_800);
/// Capacity of the bounded channel carrying `Command`s to the CBS event loop.
const CBS_LINK_COMMAND_QUEUE_SIZE: usize = 128;

/// Identifier under which a link's authorization refresher is tracked.
type LinkIdentifier = u32;
26
/// Requests sent from an `AmqpCbsLinkHandle` to the CBS event loop.
pub(crate) enum Command {
    /// Authorize a link through CBS. If the returned token has a future expiry, schedule
    /// a refresh for it. The outcome of this initial authorization is reported back
    /// through `result_sender`.
    NewAuthorizationRefresher {
        auth: AuthorizationRefresher,
        result_sender: oneshot::Sender<Result<(), CbsAuthError>>,
    },
    /// Cancel the scheduled refresh (if any) registered for the given link identifier.
    RemoveAuthorizationRefresher(LinkIdentifier),
}
34
/// Entries stored in the CBS link's delay queue.
pub(crate) enum Refresher {
    /// Dummy entry. When it fires, the event loop inserts a fresh placeholder that fires
    /// after `DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION`.
    Placeholder,
    /// A link authorization that is re-requested through CBS when its entry fires.
    Authorization(AuthorizationRefresher),
}
41
/// Everything needed to (re-)request a CBS token for one link.
pub(crate) struct AuthorizationRefresher {
    /// Key under which this refresher is tracked in `AmqpCbsLink::active_link_identifiers`.
    link_identifier: LinkIdentifier,
    /// Endpoint passed to the token provider when requesting a token.
    endpoint: String,
    /// Resource the token is requested for; also the name the token is put under via CBS.
    resource: String,
    /// Claims passed to the token provider.
    required_claims: Vec<String>,
}
48
/// Handle to a spawned CBS event loop (created by `AmqpCbsLink::spawn` /
/// `AmqpCbsLink::spawn_local`).
pub(crate) struct AmqpCbsLinkHandle {
    /// Channel used to submit `Command`s to the event loop.
    command_sender: mpsc::Sender<Command>,
    /// Parent cancellation token; the event loop listens on a child of it, so cancelling
    /// this stops the loop.
    stop_sender: CancellationToken,
    /// Task running `AmqpCbsLink::event_loop`; its output is the result of closing the
    /// CBS client.
    join_handle: JoinHandle<Result<(), DetachError>>,
}
54
55impl std::fmt::Debug for AmqpCbsLinkHandle {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("AmqpCbsLinkHandle").finish()
58 }
59}
60
61impl AmqpCbsLinkHandle {
62 pub(crate) fn command_sender(&self) -> &mpsc::Sender<Command> {
63 &self.command_sender
64 }
65
66 pub(crate) async fn request_refreshable_authorization(
67 &mut self,
68 link_identifier: u32,
69 endpoint: String,
70 resource: String,
71 required_claims: Vec<String>,
72 ) -> Result<Result<(), CbsAuthError>, AmqpCbsEventLoopStopped> {
73 let auth = AuthorizationRefresher {
74 link_identifier,
75 endpoint,
76 resource,
77 required_claims,
78 };
79 let (result_sender, result) = oneshot::channel();
80 let command = Command::NewAuthorizationRefresher {
81 auth,
82 result_sender,
83 };
84 self.command_sender
85 .send(command)
86 .await
87 .map_err(|_| AmqpCbsEventLoopStopped {})?;
88
89 result.await.map_err(|_| AmqpCbsEventLoopStopped {})
90 }
91
92 pub(crate) fn stop(&self) {
93 self.stop_sender.cancel();
94 }
95
96 pub(crate) fn join_handle_mut(&mut self) -> &mut JoinHandle<Result<(), DetachError>> {
97 &mut self.join_handle
98 }
99}
100
101impl Sharable<AmqpCbsLinkHandle> {
102 pub(crate) async fn request_refreshable_authorization(
103 &mut self,
104 link_identifier: u32,
105 endpoint: String,
106 resource: String,
107 required_claims: Vec<String>,
108 ) -> Result<Result<(), CbsAuthError>, AmqpCbsEventLoopStopped> {
109 let result = match self {
110 Self::Owned(link) => {
111 link.request_refreshable_authorization(
112 link_identifier,
113 endpoint,
114 resource,
115 required_claims,
116 )
117 .await
118 }
119 Self::Shared(link) => {
120 link.write()
121 .await
122 .request_refreshable_authorization(
123 link_identifier,
124 endpoint,
125 resource,
126 required_claims,
127 )
128 .await
129 }
130 Self::None => unreachable!(),
131 };
132
133 match result {
134 Ok(Ok(_)) => Ok(Ok(())),
135 Ok(Err(err)) =>
136 {
137 log::error!("CBS authorization refresh failed: {}", err);
138 Ok(Err(err))
139 },
140 Err(err) => {
141 log::error!("CBS authorization refresh failed: {}", err);
142 Err(err)
143 },
144 }
145 }
146
147 pub(crate) async fn command_sender(&self) -> mpsc::Sender<Command> {
148 match self {
149 Self::Owned(link) => link.command_sender().clone(),
150 Self::Shared(link) => link.read().await.command_sender().clone(),
151 Self::None => unreachable!(),
152 }
153 }
154
155 pub(crate) async fn stop(&self) {
157 match self {
158 Self::Owned(link) => link.stop(),
159 Self::Shared(link) => link.write().await.stop(),
160 Self::None => unreachable!(),
161 }
162 }
163
164 pub(crate) async fn stop_if_owned(&self) {
165 match self {
166 Self::Owned(link) => link.stop(),
167 Self::Shared(link) => {
168 if Arc::strong_count(link) == 1 {
169 link.write().await.stop();
170 }
171 },
172 Self::None => unreachable!(),
173 }
174 }
175
176 pub(crate) async fn join(&mut self) -> Result<Result<(), DetachError>, JoinError> {
178 match self {
179 Self::Owned(link) => link.join_handle_mut().await,
180 Self::Shared(link) => {
181 let mut link = link.write().await;
182 link.join_handle_mut().await
183 }
184 Self::None => unreachable!(),
185 }
186 }
187
188 pub(crate) async fn join_if_owned(&mut self) -> Result<Result<(), DetachError>, JoinError> {
189 match self {
190 Self::Owned(link) => link.join_handle_mut().await,
191 Self::Shared(link) => match Arc::strong_count(link) {
192 1 => link.write().await.join_handle_mut().await,
193 _ => Ok(Ok(())),
194 },
195 Self::None => unreachable!(),
196 }
197 }
198}
199
/// State owned by the CBS event loop task.
pub(crate) struct AmqpCbsLink {
    /// When cancelled, the event loop closes the CBS client and returns.
    pub stop: CancellationToken,
    /// Incoming requests from `AmqpCbsLinkHandle`s.
    pub commands: mpsc::Receiver<Command>,
    /// Maps each link with a scheduled refresh to its entry key in `delay_queue`.
    pub active_link_identifiers: HashMap<LinkIdentifier, Key>,
    /// Pending authorization refreshes, plus a periodically re-inserted placeholder entry.
    pub delay_queue: DelayQueue<Refresher>,
    /// Source of CBS tokens.
    pub cbs_token_provider: CbsTokenProvider,
    /// CBS client used to put tokens; closed when the event loop exits.
    pub cbs_client: CbsClient,
}
208
209impl AmqpCbsLink {
210 pub(crate) fn new(
211 cbs_token_provider: CbsTokenProvider,
212 cbs_client: CbsClient,
213 commands: mpsc::Receiver<Command>,
214 stop: CancellationToken,
215 ) -> Self {
216 let mut delay_queue = DelayQueue::new();
217 delay_queue.insert(
218 Refresher::Placeholder,
219 DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION,
220 );
221
222 AmqpCbsLink {
223 stop,
224 commands,
225 active_link_identifiers: HashMap::new(),
226 delay_queue,
227 cbs_token_provider,
228 cbs_client,
229 }
230 }
231
232 cfg_not_wasm32! {
233 pub(crate) fn spawn(
234 cbs_token_provider: CbsTokenProvider,
235 cbs_client: CbsClient,
236 ) -> AmqpCbsLinkHandle {
237 let (command_sender, commands) = mpsc::channel(CBS_LINK_COMMAND_QUEUE_SIZE);
238 let stop_sender = CancellationToken::new();
239 let stop = stop_sender.child_token();
240 let amqp_cbs_link = AmqpCbsLink::new(cbs_token_provider, cbs_client, commands, stop);
241
242 let join_handle = tokio::spawn(amqp_cbs_link.event_loop());
243 AmqpCbsLinkHandle {
244 command_sender,
245 stop_sender,
246 join_handle,
247 }
248 }
249 }
250
251 cfg_wasm32! {
252 pub(crate) fn spawn_local(
253 cbs_token_provider: CbsTokenProvider,
254 cbs_client: CbsClient,
255 ) -> AmqpCbsLinkHandle {
256 let (command_sender, commands) = mpsc::channel(CBS_LINK_COMMAND_QUEUE_SIZE);
257 let stop_sender = CancellationToken::new();
258 let stop = stop_sender.child_token();
259 let amqp_cbs_link = AmqpCbsLink::new(cbs_token_provider, cbs_client, commands, stop);
260
261 let join_handle = tokio::task::spawn_local(amqp_cbs_link.event_loop());
262 AmqpCbsLinkHandle {
263 command_sender,
264 stop_sender,
265 join_handle,
266 }
267 }
268 }
269
270 async fn request_authorization_using_cbs(
271 &mut self,
272 endpoint: impl AsRef<str>,
273 resource: impl AsRef<str>,
274 required_claims: impl IntoIterator<Item = impl AsRef<str>>,
275 ) -> Result<Option<crate::util::time::Instant>, CbsAuthError> {
276 log::debug!("Requesting CBS authorization.");
277
278 let resource = resource.as_ref();
279 let token = self
280 .cbs_token_provider
281 .get_token_async(endpoint, resource, required_claims)
282 .await?;
283
284 let expires_at_utc = token.expires_at_utc().clone().map(OffsetDateTime::from);
286
287 let expires_at_instant = expires_at_utc.map(|expires_at| {
289 let now_instant = crate::util::time::Instant::now();
290 let now = crate::util::time::now_utc(); let timespan = expires_at - now;
292 now_instant + timespan.unsigned_abs()
293 });
294
295 self.cbs_client.put_token(resource, token).await?;
298
299 Ok(expires_at_instant)
300 }
301
302 async fn handle_command(&mut self, command: Command) {
303 match command {
304 Command::NewAuthorizationRefresher {
305 auth,
306 result_sender,
307 } => {
308 let result = self
310 .request_authorization_using_cbs(
311 &auth.endpoint,
312 &auth.resource,
313 &auth.required_claims,
314 )
315 .await;
316 match result {
317 Ok(expires_at) => {
318 if let Some(expires_at) = expires_at {
319 if expires_at > crate::util::time::Instant::now() {
320 let link_identifier = auth.link_identifier;
321 let key = self
322 .delay_queue
323 .insert_at(Refresher::Authorization(auth), expires_at);
324 self.active_link_identifiers.insert(link_identifier, key);
325 }
326 }
327 let _ = result_sender.send(Ok(()));
328 }
329 Err(err) => {
330 let _ = result_sender.send(Err(err));
331 }
332 }
333 }
334 Command::RemoveAuthorizationRefresher(link_identifier) => {
335 let key = self.active_link_identifiers.remove(&link_identifier);
336 if let Some(key) = key {
337 self.delay_queue.try_remove(&key);
338 }
339 }
340 }
341 }
342
343 async fn handle_refresher(&mut self, refresher: Refresher) {
344 match refresher {
345 Refresher::Placeholder => {
346 let _key = self.delay_queue.insert(
347 Refresher::Placeholder,
348 DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION,
349 );
350 }
351 Refresher::Authorization(auth) => {
352 let link_identifier = auth.link_identifier;
353 let result = self
354 .request_authorization_using_cbs(
355 &auth.endpoint,
356 &auth.resource,
357 &auth.required_claims,
358 )
359 .await;
360 match result {
361 Ok(expires_at) => {
362 if let Some(expires_at) = expires_at {
363 if expires_at > crate::util::time::Instant::now() {
364 let key = self
365 .delay_queue
366 .insert_at(Refresher::Authorization(auth), expires_at);
367 self.active_link_identifiers.insert(link_identifier, key);
368 }
369 }
370 }
371 Err(err) => {
372 log::error!("CBS authorization refresh failed: {}", err);
374 }
375 }
376 }
377 }
378 }
379
380 pub(crate) async fn event_loop(mut self) -> Result<(), DetachError> {
381 loop {
382 tokio::select! {
383 _stop_cbs_link = self.stop.cancelled() => {
384 return self.cbs_client.close().await
385 },
386 command = self.commands.recv() => {
387 if let Some(command) = command {
388 self.handle_command(command).await;
389 } else {
390 return self.cbs_client.close().await
392 }
393 },
394 refresher = self.delay_queue.next() => {
395 if let Some(refresher) = refresher {
398 self.handle_refresher(refresher.into_inner()).await;
399 } else {
400 let _key = self.delay_queue.insert(Refresher::Placeholder, DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION);
403 }
404 }
405 }
406 }
407 }
408}