use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::debug;
/// Boxed, type-erased future produced by a [`ContextBindingCallback`].
///
/// Resolves to a JSON value on success, or to an error message on failure.
pub type ContextBindingFuture =
    Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send>>;

/// Shared, type-erased handler for a function exposed on a browser context.
///
/// The handler receives the call's arguments as a list of JSON values and
/// returns a [`ContextBindingFuture`]. It is reference-counted and
/// `Send + Sync`, so clones share one handler and may be used across threads.
pub type ContextBindingCallback =
    Arc<dyn Fn(Vec<serde_json::Value>) -> ContextBindingFuture + Send + Sync>;
/// A named function exposed on a browser context.
///
/// Clones share the same underlying handler, because `callback` is an
/// `Arc`-based [`ContextBindingCallback`].
#[derive(Clone)]
pub struct ContextBinding {
    /// Name under which the function is registered.
    pub name: String,
    /// Handler invoked with the call's JSON arguments.
    pub callback: ContextBindingCallback,
}

// `dyn Fn` has no `Debug` impl, so `#[derive(Debug)]` is impossible here.
// Print the name and mark the opaque callback as elided instead.
impl std::fmt::Debug for ContextBinding {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ContextBinding")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
/// Registry of the functions exposed on a browser context, keyed by name.
///
/// The map sits behind an async [`RwLock`], so lookups share read access while
/// registration and removal take exclusive write access.
#[derive(Default)]
pub struct ContextBindingRegistry {
    bindings: RwLock<HashMap<String, ContextBinding>>,
}

impl ContextBindingRegistry {
    /// Creates an empty registry (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `callback` under `name`.
    ///
    /// If `name` is already registered, the previous binding is replaced and
    /// the replacement is reported with a `debug!` event.
    pub async fn expose_function<F, Fut>(&self, name: &str, callback: F)
    where
        F: Fn(Vec<serde_json::Value>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, String>> + Send + 'static,
    {
        debug!("Registering context-level binding: {}", name);
        // Erase the concrete closure/future types so different callbacks can
        // be stored in the same map.
        let erased: ContextBindingCallback = Arc::new(move |args| Box::pin(callback(args)));
        let binding = ContextBinding {
            name: name.to_owned(),
            callback: erased,
        };
        // Hold the write lock only for the insert itself; any binding it
        // displaces (and the state its closure captured) is dropped after
        // the guard has been released.
        let previous = {
            let mut bindings = self.bindings.write().await;
            bindings.insert(binding.name.clone(), binding)
        };
        if previous.is_some() {
            debug!("Replaced existing context-level binding: {}", name);
        }
    }

    /// Removes the binding registered under `name`.
    ///
    /// Returns `true` if a binding was present and removed, `false` otherwise.
    pub async fn remove_function(&self, name: &str) -> bool {
        debug!("Removing context-level binding: {}", name);
        // The temporary write guard is released at the end of this statement,
        // before the removed binding is dropped.
        let removed = self.bindings.write().await.remove(name);
        removed.is_some()
    }

    /// Returns a snapshot (clones) of every registered binding, in arbitrary order.
    pub async fn get_all(&self) -> Vec<ContextBinding> {
        let bindings = self.bindings.read().await;
        bindings.values().cloned().collect()
    }

    /// Reports whether a binding is registered under `name`.
    pub async fn has(&self, name: &str) -> bool {
        let bindings = self.bindings.read().await;
        bindings.contains_key(name)
    }
}
use super::BrowserContext;
impl BrowserContext {
    /// Exposes `callback` as a function named `name` on this browser context.
    ///
    /// The registration itself is delegated to this context's binding registry.
    pub async fn expose_function<F, Fut>(&self, name: &str, callback: F)
    where
        F: Fn(Vec<serde_json::Value>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, String>> + Send + 'static,
    {
        let registry = &self.binding_registry;
        registry.expose_function(name, callback).await
    }

    /// Removes a function previously exposed under `name`.
    ///
    /// Returns the registry's answer: `true` when a binding was removed.
    pub async fn remove_exposed_function(&self, name: &str) -> bool {
        let registry = &self.binding_registry;
        registry.remove_function(name).await
    }
}
#[cfg(test)]
mod tests;