use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Weak};
use tokio::sync::{RwLock, broadcast};
use tracing::debug;
use viewpoint_cdp::CdpConnection;
use crate::error::NetworkError;
use crate::network::{Route, RouteHandlerRegistry, UrlMatcher, UrlPattern};
/// Notification broadcast by `ContextRouteRegistry` when its set of context routes changes.
///
/// Only route additions are announced; removing routes sends no notification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RouteChangeNotification {
    /// A context route handler was registered.
    RouteAdded,
}
/// One context-level route: a URL matcher paired with the type-erased async handler that
/// is invoked for requests the matcher accepts.
struct ContextRouteHandler {
    /// Tested against request URLs to decide whether this route applies.
    pattern: Box<dyn UrlMatcher>,
    /// Shared, type-erased handler. Every call yields a boxed `Send` future that resolves
    /// to the result of handling the intercepted route.
    handler: Arc<
        dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
            + Send
            + Sync,
    >,
}
impl std::fmt::Debug for ContextRouteHandler {
    /// Prints fixed placeholders for both fields: the boxed closure cannot be formatted,
    /// and the matcher is deliberately not formatted either.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ContextRouteHandler");
        builder.field("pattern", &"<pattern>");
        builder.field("handler", &"<fn>");
        builder.finish()
    }
}
/// Route handlers registered at browser-context level.
///
/// Handlers are stored behind an async `RwLock`. Page-level registries are tracked only
/// through weak references, so they can be asked to enable request interception when a
/// context route is added without being kept alive by this registry.
#[derive(Debug)]
pub struct ContextRouteRegistry {
    /// Registered handlers in registration order; lookups scan them newest-first.
    handlers: RwLock<Vec<ContextRouteHandler>>,
    /// CDP connection supplied at construction.
    connection: Arc<CdpConnection>,
    /// Identifier of the owning browser context, included in log output.
    context_id: String,
    /// Sending half of the route-change broadcast channel.
    route_change_tx: broadcast::Sender<RouteChangeNotification>,
    /// Weak handles to page-level registries; dead entries are pruned on registration.
    page_registries: RwLock<Vec<Weak<RouteHandlerRegistry>>>,
}
impl ContextRouteRegistry {
pub fn new(connection: Arc<CdpConnection>, context_id: String) -> Self {
let (route_change_tx, _) = broadcast::channel(16);
Self {
handlers: RwLock::new(Vec::new()),
connection,
context_id,
route_change_tx,
page_registries: RwLock::new(Vec::new()),
}
}
pub async fn register_page_registry(&self, registry: &Arc<RouteHandlerRegistry>) {
let mut registries = self.page_registries.write().await;
registries.retain(|weak| weak.strong_count() > 0);
registries.push(Arc::downgrade(registry));
}
async fn enable_fetch_on_all_pages(&self) -> Result<(), NetworkError> {
let registries = self.page_registries.read().await;
for weak in registries.iter() {
if let Some(registry) = weak.upgrade() {
registry.ensure_fetch_enabled_public().await?;
}
}
Ok(())
}
pub fn subscribe_route_changes(&self) -> broadcast::Receiver<RouteChangeNotification> {
self.route_change_tx.subscribe()
}
pub async fn route<M, H, Fut>(&self, pattern: M, handler: H) -> Result<(), NetworkError>
where
M: Into<UrlPattern>,
H: Fn(Route) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
{
let pattern = pattern.into();
debug!(context_id = %self.context_id, "Registering context route");
let handler: Arc<
dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
+ Send
+ Sync,
> = Arc::new(move |route| Box::pin(handler(route)));
let mut handlers = self.handlers.write().await;
handlers.push(ContextRouteHandler {
pattern: Box::new(pattern),
handler,
});
drop(handlers);
self.enable_fetch_on_all_pages().await?;
let _ = self
.route_change_tx
.send(RouteChangeNotification::RouteAdded);
Ok(())
}
pub async fn route_predicate<P, H, Fut>(
&self,
predicate: P,
handler: H,
) -> Result<(), NetworkError>
where
P: Fn(&str) -> bool + Send + Sync + 'static,
H: Fn(Route) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
{
struct PredicateMatcher<F>(F);
impl<F: Fn(&str) -> bool + Send + Sync> UrlMatcher for PredicateMatcher<F> {
fn matches(&self, url: &str) -> bool {
(self.0)(url)
}
}
let handler: Arc<
dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
+ Send
+ Sync,
> = Arc::new(move |route| Box::pin(handler(route)));
let mut handlers = self.handlers.write().await;
handlers.push(ContextRouteHandler {
pattern: Box::new(PredicateMatcher(predicate)),
handler,
});
drop(handlers);
self.enable_fetch_on_all_pages().await?;
let _ = self
.route_change_tx
.send(RouteChangeNotification::RouteAdded);
Ok(())
}
pub async fn unroute(&self, pattern: &str) {
let mut handlers = self.handlers.write().await;
handlers.retain(|h| !h.pattern.matches(pattern));
}
pub async fn unroute_all(&self) {
let mut handlers = self.handlers.write().await;
handlers.clear();
}
pub async fn has_routes(&self) -> bool {
let handlers = self.handlers.read().await;
!handlers.is_empty()
}
pub async fn route_count(&self) -> usize {
let handlers = self.handlers.read().await;
handlers.len()
}
#[deprecated(note = "Use set_context_routes on RouteHandlerRegistry instead")]
pub async fn apply_to_page(
&self,
page_registry: &RouteHandlerRegistry,
) -> Result<(), NetworkError> {
let handlers = self.handlers.read().await;
for handler in handlers.iter() {
let handler_clone = handler.handler.clone();
page_registry
.route("*", move |route| {
let handler = handler_clone.clone();
async move { handler(route).await }
})
.await?;
}
Ok(())
}
pub async fn find_handler(
&self,
url: &str,
) -> Option<
Arc<
dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
+ Send
+ Sync,
>,
> {
let handlers = self.handlers.read().await;
for handler in handlers.iter().rev() {
if handler.pattern.matches(url) {
return Some(handler.handler.clone());
}
}
None
}
pub async fn find_all_handlers(
&self,
url: &str,
) -> Vec<
Arc<
dyn Fn(Route) -> Pin<Box<dyn Future<Output = Result<(), NetworkError>> + Send>>
+ Send
+ Sync,
>,
> {
let handlers = self.handlers.read().await;
handlers
.iter()
.rev()
.filter(|h| h.pattern.matches(url))
.map(|h| h.handler.clone())
.collect()
}
}
#[cfg(test)]
mod tests;