// azeventhubs/amqp/amqp_cbs_link.rs

1use fe2o3_amqp::link::DetachError;
2use futures_util::StreamExt;
3use tokio::task::JoinError;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration as StdDuration;
7
8use fe2o3_amqp_cbs::{client::CbsClient, AsyncCbsTokenProvider};
9use time::OffsetDateTime;
10use tokio::sync::mpsc;
11use tokio::sync::oneshot;
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14
15use crate::util::sharable::Sharable;
16use crate::util::time::{DelayQueue, Key};
17
18use super::error::AmqpCbsEventLoopStopped;
19use super::{cbs_token_provider::CbsTokenProvider, error::CbsAuthError};
20
/// How far in the future the placeholder entry is scheduled (30 minutes).
///
/// The placeholder keeps the delay queue non-empty so that polling it does not immediately
/// yield `None` when no real refresher is pending.
const DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION: StdDuration = StdDuration::from_secs(1_800);

/// Capacity of the bounded channel carrying [`Command`]s to the CBS event loop.
const CBS_LINK_COMMAND_QUEUE_SIZE: usize = 128;

/// Monotonically incrementing identifier assigned whenever a new link is created.
type LinkIdentifier = u32;
26
27pub(crate) enum Command {
28    NewAuthorizationRefresher {
29        auth: AuthorizationRefresher,
30        result_sender: oneshot::Sender<Result<(), CbsAuthError>>,
31    },
32    RemoveAuthorizationRefresher(LinkIdentifier),
33}
34
35pub(crate) enum Refresher {
36    /// This is a placeholder that is only used to avoid spinning the runtime when the
37    /// delay queue is exhausted.
38    Placeholder,
39    Authorization(AuthorizationRefresher),
40}
41
42pub(crate) struct AuthorizationRefresher {
43    link_identifier: LinkIdentifier,
44    endpoint: String,
45    resource: String,
46    required_claims: Vec<String>,
47}
48
49pub(crate) struct AmqpCbsLinkHandle {
50    command_sender: mpsc::Sender<Command>,
51    stop_sender: CancellationToken,
52    join_handle: JoinHandle<Result<(), DetachError>>,
53}
54
55impl std::fmt::Debug for AmqpCbsLinkHandle {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("AmqpCbsLinkHandle").finish()
58    }
59}
60
61impl AmqpCbsLinkHandle {
62    pub(crate) fn command_sender(&self) -> &mpsc::Sender<Command> {
63        &self.command_sender
64    }
65
66    pub(crate) async fn request_refreshable_authorization(
67        &mut self,
68        link_identifier: u32,
69        endpoint: String,
70        resource: String,
71        required_claims: Vec<String>,
72    ) -> Result<Result<(), CbsAuthError>, AmqpCbsEventLoopStopped> {
73        let auth = AuthorizationRefresher {
74            link_identifier,
75            endpoint,
76            resource,
77            required_claims,
78        };
79        let (result_sender, result) = oneshot::channel();
80        let command = Command::NewAuthorizationRefresher {
81            auth,
82            result_sender,
83        };
84        self.command_sender
85            .send(command)
86            .await
87            .map_err(|_| AmqpCbsEventLoopStopped {})?;
88
89        result.await.map_err(|_| AmqpCbsEventLoopStopped {})
90    }
91
92    pub(crate) fn stop(&self) {
93        self.stop_sender.cancel();
94    }
95
96    pub(crate) fn join_handle_mut(&mut self) -> &mut JoinHandle<Result<(), DetachError>> {
97        &mut self.join_handle
98    }
99}
100
101impl Sharable<AmqpCbsLinkHandle> {
102    pub(crate) async fn request_refreshable_authorization(
103        &mut self,
104        link_identifier: u32,
105        endpoint: String,
106        resource: String,
107        required_claims: Vec<String>,
108    ) -> Result<Result<(), CbsAuthError>, AmqpCbsEventLoopStopped> {
109        let result = match self {
110            Self::Owned(link) => {
111                link.request_refreshable_authorization(
112                    link_identifier,
113                    endpoint,
114                    resource,
115                    required_claims,
116                )
117                .await
118            }
119            Self::Shared(link) => {
120                link.write()
121                    .await
122                    .request_refreshable_authorization(
123                        link_identifier,
124                        endpoint,
125                        resource,
126                        required_claims,
127                    )
128                    .await
129            }
130            Self::None => unreachable!(),
131        };
132
133        match result {
134            Ok(Ok(_)) => Ok(Ok(())),
135            Ok(Err(err)) => 
136            {
137                log::error!("CBS authorization refresh failed: {}", err);
138                Ok(Err(err))
139            },
140            Err(err) => {
141                log::error!("CBS authorization refresh failed: {}", err);
142                Err(err)
143            },
144        }
145    }
146
147    pub(crate) async fn command_sender(&self) -> mpsc::Sender<Command> {
148        match self {
149            Self::Owned(link) => link.command_sender().clone(),
150            Self::Shared(link) => link.read().await.command_sender().clone(),
151            Self::None => unreachable!(),
152        }
153    }
154
155    /// Stop regardless of ownership
156    pub(crate) async fn stop(&self) {
157        match self {
158            Self::Owned(link) => link.stop(),
159            Self::Shared(link) => link.write().await.stop(),
160            Self::None => unreachable!(),
161        }
162    }
163
164    pub(crate) async fn stop_if_owned(&self) {
165        match self {
166            Self::Owned(link) => link.stop(),
167            Self::Shared(link) => {
168                if Arc::strong_count(link) == 1 {
169                    link.write().await.stop();
170                }
171            },
172            Self::None => unreachable!(),
173        }
174    }
175
176    /// Join regardless of ownership
177    pub(crate) async fn join(&mut self) -> Result<Result<(), DetachError>, JoinError> {
178        match self {
179            Self::Owned(link) => link.join_handle_mut().await,
180            Self::Shared(link) => {
181                let mut link = link.write().await;
182                link.join_handle_mut().await
183            }
184            Self::None => unreachable!(),
185        }
186    }
187
188    pub(crate) async fn join_if_owned(&mut self) -> Result<Result<(), DetachError>, JoinError> {
189        match self {
190            Self::Owned(link) => link.join_handle_mut().await,
191            Self::Shared(link) => match Arc::strong_count(link) {
192                1 => link.write().await.join_handle_mut().await,
193                _ => Ok(Ok(())),
194            },
195            Self::None => unreachable!(),
196        }
197    }
198}
199
200pub(crate) struct AmqpCbsLink {
201    pub stop: CancellationToken,
202    pub commands: mpsc::Receiver<Command>,
203    pub active_link_identifiers: HashMap<LinkIdentifier, Key>,
204    pub delay_queue: DelayQueue<Refresher>,
205    pub cbs_token_provider: CbsTokenProvider,
206    pub cbs_client: CbsClient,
207}
208
209impl AmqpCbsLink {
210    pub(crate) fn new(
211        cbs_token_provider: CbsTokenProvider,
212        cbs_client: CbsClient,
213        commands: mpsc::Receiver<Command>,
214        stop: CancellationToken,
215    ) -> Self {
216        let mut delay_queue = DelayQueue::new();
217        delay_queue.insert(
218            Refresher::Placeholder,
219            DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION,
220        );
221
222        AmqpCbsLink {
223            stop,
224            commands,
225            active_link_identifiers: HashMap::new(),
226            delay_queue,
227            cbs_token_provider,
228            cbs_client,
229        }
230    }
231
232    cfg_not_wasm32! {
233        pub(crate) fn spawn(
234            cbs_token_provider: CbsTokenProvider,
235            cbs_client: CbsClient,
236        ) -> AmqpCbsLinkHandle {
237            let (command_sender, commands) = mpsc::channel(CBS_LINK_COMMAND_QUEUE_SIZE);
238            let stop_sender = CancellationToken::new();
239            let stop = stop_sender.child_token();
240            let amqp_cbs_link = AmqpCbsLink::new(cbs_token_provider, cbs_client, commands, stop);
241
242            let join_handle = tokio::spawn(amqp_cbs_link.event_loop());
243            AmqpCbsLinkHandle {
244                command_sender,
245                stop_sender,
246                join_handle,
247            }
248        }
249    }
250
251    cfg_wasm32! {
252        pub(crate) fn spawn_local(
253            cbs_token_provider: CbsTokenProvider,
254            cbs_client: CbsClient,
255        ) -> AmqpCbsLinkHandle {
256            let (command_sender, commands) = mpsc::channel(CBS_LINK_COMMAND_QUEUE_SIZE);
257            let stop_sender = CancellationToken::new();
258            let stop = stop_sender.child_token();
259            let amqp_cbs_link = AmqpCbsLink::new(cbs_token_provider, cbs_client, commands, stop);
260
261            let join_handle = tokio::task::spawn_local(amqp_cbs_link.event_loop());
262            AmqpCbsLinkHandle {
263                command_sender,
264                stop_sender,
265                join_handle,
266            }
267        }
268    }
269
270    async fn request_authorization_using_cbs(
271        &mut self,
272        endpoint: impl AsRef<str>,
273        resource: impl AsRef<str>,
274        required_claims: impl IntoIterator<Item = impl AsRef<str>>,
275    ) -> Result<Option<crate::util::time::Instant>, CbsAuthError> {
276        log::debug!("Requesting CBS authorization.");
277
278        let resource = resource.as_ref();
279        let token = self
280            .cbs_token_provider
281            .get_token_async(endpoint, resource, required_claims)
282            .await?;
283
284        // find the smallest timeout
285        let expires_at_utc = token.expires_at_utc().clone().map(OffsetDateTime::from);
286
287        // TODO: Is there any way to convert directly from OffsetDateTime/Timestamp to StdInstant?
288        let expires_at_instant = expires_at_utc.map(|expires_at| {
289            let now_instant = crate::util::time::Instant::now();
290            let now = crate::util::time::now_utc(); // TODO: is there any way to convert instant to datetime?
291            let timespan = expires_at - now;
292            now_instant + timespan.unsigned_abs()
293        });
294
295        // TODO: There are some custom application properties in the dotnet sdk.
296        // Maybe we should have a custom type that supports this?
297        self.cbs_client.put_token(resource, token).await?;
298
299        Ok(expires_at_instant)
300    }
301
302    async fn handle_command(&mut self, command: Command) {
303        match command {
304            Command::NewAuthorizationRefresher {
305                auth,
306                result_sender,
307            } => {
308                // First request authorization once, and then schedule a refresh.
309                let result = self
310                    .request_authorization_using_cbs(
311                        &auth.endpoint,
312                        &auth.resource,
313                        &auth.required_claims,
314                    )
315                    .await;
316                match result {
317                    Ok(expires_at) => {
318                        if let Some(expires_at) = expires_at {
319                            if expires_at > crate::util::time::Instant::now() {
320                                let link_identifier = auth.link_identifier;
321                                let key = self
322                                    .delay_queue
323                                    .insert_at(Refresher::Authorization(auth), expires_at);
324                                self.active_link_identifiers.insert(link_identifier, key);
325                            }
326                        }
327                        let _ = result_sender.send(Ok(()));
328                    }
329                    Err(err) => {
330                        let _ = result_sender.send(Err(err));
331                    }
332                }
333            }
334            Command::RemoveAuthorizationRefresher(link_identifier) => {
335                let key = self.active_link_identifiers.remove(&link_identifier);
336                if let Some(key) = key {
337                    self.delay_queue.try_remove(&key);
338                }
339            }
340        }
341    }
342
343    async fn handle_refresher(&mut self, refresher: Refresher) {
344        match refresher {
345            Refresher::Placeholder => {
346                let _key = self.delay_queue.insert(
347                    Refresher::Placeholder,
348                    DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION,
349                );
350            }
351            Refresher::Authorization(auth) => {
352                let link_identifier = auth.link_identifier;
353                let result = self
354                    .request_authorization_using_cbs(
355                        &auth.endpoint,
356                        &auth.resource,
357                        &auth.required_claims,
358                    )
359                    .await;
360                match result {
361                    Ok(expires_at) => {
362                        if let Some(expires_at) = expires_at {
363                            if expires_at > crate::util::time::Instant::now() {
364                                let key = self
365                                    .delay_queue
366                                    .insert_at(Refresher::Authorization(auth), expires_at);
367                                self.active_link_identifiers.insert(link_identifier, key);
368                            }
369                        }
370                    }
371                    Err(err) => {
372                        // TODO: log error
373                        log::error!("CBS authorization refresh failed: {}", err);
374                    }
375                }
376            }
377        }
378    }
379
380    pub(crate) async fn event_loop(mut self) -> Result<(), DetachError> {
381        loop {
382            tokio::select! {
383                _stop_cbs_link = self.stop.cancelled() => {
384                    return self.cbs_client.close().await
385                },
386                command = self.commands.recv() => {
387                    if let Some(command) = command {
388                        self.handle_command(command).await;
389                    } else {
390                        // All senders including the one held by AmqpConnectionScope have been dropped, so we should stop.
391                        return self.cbs_client.close().await
392                    }
393                },
394                refresher = self.delay_queue.next() => {
395                    // A `None` is returned if the queue is exhausted. New refresher may still be
396                    // added in the future.
397                    if let Some(refresher) = refresher {
398                        self.handle_refresher(refresher.into_inner()).await;
399                    } else {
400                        // The delay queue is exhausted. We need to add a placeholder to avoid
401                        // spinning the runtime.
402                        let _key = self.delay_queue.insert(Refresher::Placeholder, DELAY_QUEUE_PLACEHOLDER_REFRESH_DURATION);
403                    }
404                }
405            }
406        }
407    }
408}