use std::sync::Arc;
use tokio::time::timeout;
use crate::backend::{Backend, ResponseRecord};
use crate::config::ShadowConfig;
use crate::divergence::Divergence;
use crate::error::ShadowError;
use crate::log::{DivergenceEntry, DivergenceLog};
/// Everything observed while serving one request through a [`Shadower`].
///
/// The primary response is always present. The remaining fields describe what happened (or
/// did not happen) on the shadow side.
#[derive(Debug)]
pub struct ShadowOutcome {
    /// Response from the primary backend; this is the answer callers should use.
    pub primary: ResponseRecord,
    /// Response from the shadow backend.
    ///
    /// `None` if the sampler skipped this request, or if the shadow call errored or timed out.
    pub shadow: Option<ResponseRecord>,
    /// Result of comparing `primary` against `shadow` with `Divergence::compare`.
    ///
    /// Always `None` when there is no shadow response.
    pub divergence: Option<Divergence>,
    /// `true` when the sampling policy declined to mirror this request to the shadow backend.
    pub skipped_by_sampler: bool,
    /// Why the shadow call produced no response: the shadow error's `Display` text, or
    /// `"timeout"` if it exceeded the configured shadow timeout.
    pub shadow_failed: Option<String>,
}
/// Routes requests to a primary backend while mirroring a sample of them to a shadow backend
/// and recording any divergence between the two responses.
///
/// All shared state is held behind `Arc`, so a clone reuses the same backends and the same
/// divergence log; only `config` is cloned by value.
#[derive(Clone)]
pub struct Shadower {
    /// Backend whose response is returned to callers.
    primary: Arc<dyn Backend>,
    /// Backend that receives mirrored traffic purely for comparison.
    shadow: Arc<dyn Backend>,
    /// Sampling decision, shadow timeout, and comparison settings.
    config: ShadowConfig,
    /// Sink for divergence entries; may be shared with other components via `with_log`.
    log: Arc<DivergenceLog>,
}
impl Shadower {
pub fn new(primary: Arc<dyn Backend>, shadow: Arc<dyn Backend>, config: ShadowConfig) -> Self {
Self {
primary,
shadow,
config,
log: Arc::new(DivergenceLog::default()),
}
}
#[must_use]
pub fn with_log(mut self, log: Arc<DivergenceLog>) -> Self {
self.log = log;
self
}
pub fn divergences(&self) -> Vec<DivergenceEntry> {
self.log.snapshot()
}
pub async fn call(&self, input: &[u8]) -> Result<ShadowOutcome, ShadowError> {
let should_shadow = self.config.should_shadow(input);
if !should_shadow {
let primary = self.primary.call(input).await?;
return Ok(ShadowOutcome {
primary,
shadow: None,
divergence: None,
skipped_by_sampler: true,
shadow_failed: None,
});
}
let primary_fut = self.primary.call(input);
let shadow_fut = timeout(self.config.shadow_timeout, self.shadow.call(input));
let (primary_res, shadow_res) = tokio::join!(primary_fut, shadow_fut);
let primary = primary_res?;
let (shadow, shadow_failed) = match shadow_res {
Ok(Ok(resp)) => (Some(resp), None),
Ok(Err(err)) => (None, Some(err.to_string())),
Err(_) => (None, Some("timeout".to_string())),
};
let divergence = match &shadow {
Some(s) => Divergence::compare(&primary, s, &self.config),
None => None,
};
if let Some(d) = &divergence {
self.log.push(DivergenceEntry {
key: input.to_vec(),
divergence: d.clone(),
});
}
Ok(ShadowOutcome {
primary,
shadow,
divergence,
skipped_by_sampler: false,
shadow_failed,
})
}
pub fn divergence_count(&self) -> usize {
self.log.len()
}
}