use std::collections::HashMap;
use std::sync::Arc;
use tracing::info;
use crate::classifier::BodyClassifier;
use crate::config::ScatterProxyConfig;
use crate::error::ScatterProxyError;
use crate::metrics::PoolMetrics;
use crate::task::TaskHandle;
use crate::ScatterProxy;
pub struct ScatterProxyRouter {
routes: HashMap<String, ScatterProxy>,
}
impl ScatterProxyRouter {
pub async fn new(
routes: impl IntoIterator<Item = (impl Into<String>, ScatterProxyConfig)>,
classifier: impl BodyClassifier,
) -> Result<Self, ScatterProxyError> {
let classifier: Arc<dyn BodyClassifier> = Arc::new(classifier);
let mut map = HashMap::new();
for (host, mut config) in routes {
let host = host.into();
if config.name.is_none() {
config.name = Some(host.clone());
}
let sp = ScatterProxy::new_arc(config, Arc::clone(&classifier)).await?;
map.insert(host, sp);
}
info!(hosts = map.len(), "ScatterProxyRouter initialised");
Ok(Self { routes: map })
}
pub async fn submit(&self, request: reqwest::Request) -> Result<TaskHandle, ScatterProxyError> {
let sp = self.route(&request)?;
Ok(sp.submit(request).await)
}
pub fn try_submit(&self, request: reqwest::Request) -> Result<TaskHandle, ScatterProxyError> {
let sp = self.route(&request)?;
sp.try_submit(request)
}
pub fn metrics_for(&self, host: &str) -> Option<PoolMetrics> {
self.routes.get(host).map(|sp| sp.metrics())
}
pub fn all_metrics(&self) -> HashMap<String, PoolMetrics> {
self.routes
.iter()
.map(|(host, sp)| (host.clone(), sp.metrics()))
.collect()
}
pub fn hosts(&self) -> Vec<&str> {
self.routes.keys().map(String::as_str).collect()
}
pub async fn shutdown(self) {
for (_, sp) in self.routes {
sp.shutdown().await;
}
}
fn route(&self, request: &reqwest::Request) -> Result<&ScatterProxy, ScatterProxyError> {
let host = request.url().host_str().unwrap_or("").to_string();
self.routes
.get(&host)
.ok_or(ScatterProxyError::UnknownHost(host))
}
}