use crate::oob::provider::{OobError, OobProviderTrait};
use std::time::Duration;
use wafrift_types::oob::{OobCanary, OobConfig, OobConfirmation};
/// Confirms payloads through out-of-band (OOB) interactions.
///
/// The oracle registers a canary with an [`OobProviderTrait`] implementation
/// and then polls the provider until one of three things happens:
/// an interaction with that canary is observed, a poll fails, or the
/// configured timeout elapses.
pub struct OobOracle {
    /// Provider used to register canaries and poll for interactions.
    /// Held in an `Arc` so a background polling task can share it with the
    /// oracle.
    provider: std::sync::Arc<dyn OobProviderTrait>,
    /// Polling cadence (`poll_interval_secs`) and overall timeout (`timeout_secs`).
    config: OobConfig,
}

impl OobOracle {
    /// Creates an oracle that polls `provider` according to `config`.
    pub fn new(provider: Box<dyn OobProviderTrait>, config: OobConfig) -> Self {
        Self {
            // `From<Box<T>> for Arc<T>` moves the provider into a shared
            // allocation while keeping the `Box`-based constructor signature.
            provider: std::sync::Arc::from(provider),
            config,
        }
    }

    /// Registers a fresh canary and polls for it until an interaction is seen.
    ///
    /// `_payload` and `_payload_type` are currently unused.
    /// NOTE(review): the canary registered here is never handed back to the
    /// caller, so it cannot be embedded in the payload. Confirm whether this
    /// is intentional; `confirm_background` does return the canary.
    ///
    /// # Returns
    /// `Ok` wrapping one of:
    /// - [`OobConfirmation::Confirmed`] when a poll returns at least one interaction;
    /// - [`OobConfirmation::Error`] when a poll fails;
    /// - [`OobConfirmation::Timeout`] when `timeout_secs` elapses first.
    ///
    /// # Errors
    /// Returns the provider's [`OobError`] if registering the canary fails.
    pub async fn confirm(
        &self,
        _payload: &str,
        _payload_type: &str,
    ) -> Result<OobConfirmation, OobError> {
        let canary = self.provider.register().await?;
        Ok(poll_until(
            self.provider.as_ref(),
            &canary,
            self.config.timeout_secs,
            self.config.poll_interval_secs,
        )
        .await)
    }

    /// Registers a canary and polls for it on a spawned tokio task.
    ///
    /// Returns as soon as the canary is registered, so the caller can embed
    /// it in a payload and fire that payload while polling is in progress.
    /// The returned receiver yields exactly one [`OobConfirmation`] when
    /// polling finishes: `Confirmed`, `Error` or `Timeout`.
    ///
    /// The spawned task owns a clone of the provider `Arc`. Because
    /// `tokio::spawn` requires a `Send + 'static` future, the provider must
    /// be usable across threads.
    ///
    /// # Errors
    /// Returns the provider's [`OobError`] if registering the canary fails.
    ///
    /// # Panics
    /// Panics if called outside a tokio runtime (a `tokio::spawn` requirement).
    pub async fn confirm_background(
        &self,
    ) -> Result<
        (
            OobCanary,
            tokio::sync::mpsc::Receiver<OobConfirmation>,
        ),
        OobError,
    > {
        let canary = self.provider.register().await?;
        let (tx, rx) = tokio::sync::mpsc::channel(1);

        let provider = std::sync::Arc::clone(&self.provider);
        let task_canary = canary.clone();
        let timeout_secs = self.config.timeout_secs;
        let interval_secs = self.config.poll_interval_secs;

        // Polling must not block the return: no interaction can arrive until
        // the caller has the canary and has fired a payload carrying it.
        tokio::spawn(async move {
            let outcome =
                poll_until(provider.as_ref(), &task_canary, timeout_secs, interval_secs).await;
            // The caller may have dropped the receiver; nobody to notify then.
            let _ = tx.send(outcome).await;
        });

        Ok((canary, rx))
    }
}
/// Polls `provider` for interactions with `canary` until one arrives, a poll
/// fails, or `timeout_secs` elapses.
///
/// - `interval_secs` is clamped to at least one second, so a zero value cannot
///   turn the loop into a busy spin.
/// - The provider is always polled at least once.
/// - When the interval is longer than the time left, the loop sleeps only
///   until the deadline. It then makes one last poll instead of overshooting.
/// - A `timeout_secs` too large to represent as a deadline means "no deadline".
///
/// Returns [`OobConfirmation::Confirmed`] on the first non-empty poll,
/// [`OobConfirmation::Error`] on the first provider error (the error itself is
/// discarded), and [`OobConfirmation::Timeout`] once the deadline has passed.
async fn poll_until(
    provider: &dyn OobProviderTrait,
    canary: &OobCanary,
    timeout_secs: u64,
    interval_secs: u64,
) -> OobConfirmation {
    let interval = Duration::from_secs(interval_secs.max(1));
    // Use tokio's clock so the deadline agrees with `tokio::time::sleep`
    // (including under a paused test clock). `checked_add` avoids the panic
    // that `Instant + Duration` raises on overflow.
    let deadline = tokio::time::Instant::now().checked_add(Duration::from_secs(timeout_secs));

    loop {
        match provider.poll(canary).await {
            Ok(interactions) if !interactions.is_empty() => return OobConfirmation::Confirmed,
            Ok(_) => {}
            Err(_) => return OobConfirmation::Error,
        }

        let pause = match deadline {
            Some(deadline) => {
                let remaining =
                    deadline.saturating_duration_since(tokio::time::Instant::now());
                if remaining.is_zero() {
                    return OobConfirmation::Timeout;
                }
                // Wake at the deadline for a final poll rather than sleeping
                // a full interval past it.
                interval.min(remaining)
            }
            None => interval,
        };
        tokio::time::sleep(pause).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use uuid::Uuid;
    use wafrift_types::oob::{OobInteraction, OobProvider};

    /// Scripted provider.
    ///
    /// The first `confirm_after` polls come back empty; every later poll
    /// reports a single DNS interaction.
    #[derive(Debug)]
    struct FakeProvider {
        /// Number of `poll` calls made so far.
        polls: AtomicUsize,
        /// How many empty polls to hand out before reporting an interaction.
        confirm_after: usize,
    }

    impl FakeProvider {
        /// Builds a provider with a zeroed poll counter.
        fn with_confirm_after(confirm_after: usize) -> Self {
            FakeProvider {
                polls: AtomicUsize::new(0),
                confirm_after,
            }
        }
    }

    #[async_trait]
    impl OobProviderTrait for FakeProvider {
        async fn register(&self) -> Result<OobCanary, OobError> {
            let canary = OobCanary {
                id: Uuid::nil(),
                expected_dns: "abc.fake.oast".into(),
                expected_http_path: "/abc".into(),
                created_at: None,
            };
            Ok(canary)
        }

        async fn poll(&self, _canary: &OobCanary) -> Result<Vec<OobInteraction>, OobError> {
            let previous_polls = self.polls.fetch_add(1, Ordering::Relaxed);
            if previous_polls < self.confirm_after {
                return Ok(Vec::new());
            }
            Ok(vec![OobInteraction::DnsQuery {
                query: "abc.fake.oast".into(),
                source_ip: "203.0.113.10".into(),
            }])
        }
    }

    /// Forwards to a shared `FakeProvider`.
    ///
    /// This lets a test still read the poll counter after the boxed provider
    /// has been moved into the oracle.
    struct ArcProvider(Arc<FakeProvider>);

    #[async_trait]
    impl OobProviderTrait for ArcProvider {
        async fn register(&self) -> Result<OobCanary, OobError> {
            self.0.register().await
        }

        async fn poll(&self, c: &OobCanary) -> Result<Vec<OobInteraction>, OobError> {
            self.0.poll(c).await
        }
    }

    impl std::fmt::Debug for ArcProvider {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            std::fmt::Debug::fmt(&self.0, f)
        }
    }

    /// Test config that polls once per second and gives up after `timeout_secs`.
    fn config_with_timeout(timeout_secs: u64) -> OobConfig {
        OobConfig {
            provider: OobProvider::Interactsh {
                server: "test".into(),
            },
            poll_interval_secs: 1,
            timeout_secs,
        }
    }

    fn fast_config() -> OobConfig {
        config_with_timeout(5)
    }

    #[tokio::test]
    async fn confirm_returns_confirmed_on_first_interaction() {
        let oracle = OobOracle::new(Box::new(FakeProvider::with_confirm_after(0)), fast_config());
        let outcome = oracle.confirm("' OR 1=1--", "Sql").await.unwrap();
        assert_eq!(outcome, OobConfirmation::Confirmed);
    }

    #[tokio::test]
    async fn confirm_times_out_when_no_interaction() {
        // 100 empty polls can never be exhausted inside a 2-second window.
        let oracle = OobOracle::new(
            Box::new(FakeProvider::with_confirm_after(100)),
            config_with_timeout(2),
        );
        let outcome = oracle.confirm("benign", "Sql").await.unwrap();
        assert_eq!(outcome, OobConfirmation::Timeout);
    }

    #[tokio::test]
    async fn confirm_background_returns_canary_and_outcome() {
        let oracle = OobOracle::new(Box::new(FakeProvider::with_confirm_after(0)), fast_config());
        let (canary, mut receiver) = oracle.confirm_background().await.unwrap();
        assert_eq!(canary.expected_dns, "abc.fake.oast");
        let outcome = receiver.recv().await.unwrap();
        assert_eq!(outcome, OobConfirmation::Confirmed);
    }

    #[tokio::test]
    async fn poll_counter_advances() {
        let shared = Arc::new(FakeProvider::with_confirm_after(2));
        let oracle = OobOracle::new(Box::new(ArcProvider(Arc::clone(&shared))), fast_config());
        let outcome = oracle.confirm("x", "Sql").await.unwrap();
        assert_eq!(outcome, OobConfirmation::Confirmed);
        // Two empty polls plus the confirming one.
        assert!(shared.polls.load(Ordering::Relaxed) >= 3);
    }
}