use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use parking_lot::Mutex;
use reqwest::Client;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
use super::client::{Interaction, InteractionProtocol, InteractshClient};
use super::InteractshError;
/// Renders an [`InteractshError`] as a log-safe string.
///
/// Transport errors coming from `reqwest` are collapsed to a coarse failure category so the
/// request URL never appears in logs. Every other variant is rendered through its `Display`
/// implementation unchanged.
pub fn redact_interactsh_error(e: &InteractshError) -> String {
    let InteractshError::Transport(req_err) = e else {
        return e.to_string();
    };
    // Categories are checked in this fixed order; the first predicate that holds wins, and
    // anything unclassified falls back to the generic "transport" label.
    let kind = [
        (req_err.is_connect(), "connect"),
        (req_err.is_timeout(), "timeout"),
        (req_err.is_request(), "request"),
        (req_err.is_body(), "body"),
        (req_err.is_decode(), "decode"),
        (req_err.is_status(), "status"),
    ]
    .iter()
    .find_map(|&(hit, label)| hit.then_some(label))
    .unwrap_or("transport");
    format!("interactsh transport error: kind={kind} (url redacted)")
}
/// Tunables for an out-of-band (OOB) verification session.
#[derive(Debug, Clone)]
pub struct OobConfig {
    /// Interactsh server passed to `InteractshClient::register` when the session starts.
    pub server: String,
    /// Timeout reported by `OobSession::config_default_timeout`.
    pub default_timeout: Duration,
    /// Upper bound applied to any timeout passed to `OobSession::wait_for`.
    pub max_timeout: Duration,
    /// Pause between successful polls of the interactsh server.
    pub poll_interval: Duration,
    /// How long a received interaction is retained before garbage collection discards it.
    pub max_observation_age: Duration,
}

impl Default for OobConfig {
    /// `oast.fun`, 30 s default wait, 2 min maximum wait, 2 s polling, 10 min retention.
    fn default() -> Self {
        Self {
            server: String::from("oast.fun"),
            poll_interval: Duration::from_secs(2),
            default_timeout: Duration::from_secs(30),
            max_timeout: Duration::from_secs(2 * 60),
            max_observation_age: Duration::from_secs(10 * 60),
        }
    }
}
/// Outcome of waiting for an out-of-band interaction (see `OobSession::wait_for`).
#[derive(Debug, Clone)]
pub enum OobObservation {
/// A stored interaction matched the requested unique id and protocol filter.
/// All fields are copied from that interaction.
Observed {
/// Protocol the interaction arrived over.
protocol: InteractionProtocol,
/// The interaction's `remote_address`, copied verbatim.
remote_address: String,
/// The interaction's `timestamp` string, copied verbatim.
/// NOTE(review): format is whatever the interactsh client produces — confirm.
timestamp: String,
/// The interaction's `raw_payload`, copied verbatim.
raw_payload: String,
},
/// No matching interaction was recorded before the wait deadline.
NotObserved,
/// OOB verification is unavailable; the string gives the reason
/// (this module produces "session shut down").
Disabled(String),
}
/// An interaction recorded by `OobSession::store_and_notify`, tagged with its local arrival time.
struct StoredInteraction {
/// The interaction as returned by the interactsh client.
interaction: Interaction,
/// Monotonic time at which it was stored; `OobSession::gc` uses it to expire old entries.
received_at: Instant,
}
/// A live interactsh registration plus the bookkeeping used to match incoming interactions
/// to callers of `wait_for`.
pub struct OobSession {
/// Registered interactsh client used to mint URLs, poll, and deregister.
client: Arc<InteractshClient>,
/// Session configuration (server, timeouts, poll interval, observation retention).
config: OobConfig,
/// Interactions received so far, keyed by unique id; appended by `store_and_notify`
/// and pruned by `gc`.
observations: Arc<DashMap<String, Vec<StoredInteraction>>>,
/// Wake-up handles keyed by unique id: created by `wait_for`, signalled by
/// `store_and_notify`, and drained on shutdown.
waiters: Arc<Mutex<HashMap<String, Arc<Notify>>>>,
/// Handle of the background poll task; taken (and aborted) when the session shuts down.
poller_handle: Mutex<Option<JoinHandle<()>>>,
/// Set once by `shutdown`/`abort_poller_for_drop`; checked by waiters and the poller.
shutdown: Arc<AtomicBool>,
}
impl OobSession {
pub async fn start(
http: Client,
config: OobConfig,
) -> Result<Arc<Self>, super::InteractshError> {
let client = InteractshClient::register(http, &config.server).await?;
let client = Arc::new(client);
info!(
target: "keyhog::oob",
correlation_id = %client.correlation_id(),
server = %config.server,
"OOB verification enabled"
);
let session = Arc::new(Self {
client: Arc::clone(&client),
config: config.clone(),
observations: Arc::new(DashMap::new()),
waiters: Arc::new(Mutex::new(HashMap::new())),
poller_handle: Mutex::new(None),
shutdown: Arc::new(AtomicBool::new(false)),
});
let handle = spawn_poller(Arc::clone(&session));
*session.poller_handle.lock() = Some(handle);
Ok(session)
}
pub fn mint(&self) -> super::client::MintedUrl {
self.client.mint_url()
}
pub fn config_default_timeout(&self) -> Duration {
self.config.default_timeout
}
pub async fn wait_for(
&self,
unique_id: &str,
accepts: OobAccept,
timeout: Duration,
) -> OobObservation {
if self.shutdown.load(Ordering::Acquire) {
return OobObservation::Disabled("session shut down".into());
}
let timeout = timeout.min(self.config.max_timeout);
if let Some(obs) = self.peek_match(unique_id, accepts) {
return obs;
}
let notify = {
let mut waiters = self.waiters.lock();
waiters
.entry(unique_id.to_string())
.or_insert_with(|| Arc::new(Notify::new()))
.clone()
};
let deadline = Instant::now() + timeout;
loop {
if self.shutdown.load(Ordering::Acquire) {
self.waiters.lock().remove(unique_id);
return OobObservation::Disabled("session shut down".into());
}
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
self.waiters.lock().remove(unique_id);
return OobObservation::NotObserved;
}
let mut notified = std::pin::pin!(notify.notified());
notified.as_mut().enable();
if let Some(obs) = self.peek_match(unique_id, accepts) {
self.waiters.lock().remove(unique_id);
return obs;
}
let woken = tokio::time::timeout(remaining, notified.as_mut()).await;
if let Some(obs) = self.peek_match(unique_id, accepts) {
self.waiters.lock().remove(unique_id);
return obs;
}
if woken.is_err() {
self.waiters.lock().remove(unique_id);
return OobObservation::NotObserved;
}
}
}
pub async fn shutdown(self: &Arc<Self>) {
if self.shutdown.swap(true, Ordering::AcqRel) {
return;
}
self.wake_all_waiters();
let handle = self.poller_handle.lock().take();
if let Some(h) = handle {
h.abort();
let _ = h.await;
}
if let Err(e) = self.client.deregister().await {
debug!(target: "keyhog::oob", error = %e, "deregister failed (non-fatal)");
}
}
pub fn abort_poller_for_drop(&self) {
if self.shutdown.swap(true, Ordering::AcqRel) {
return;
}
self.wake_all_waiters();
if let Some(h) = self.poller_handle.lock().take() {
h.abort();
}
}
fn wake_all_waiters(&self) {
let drained: Vec<Arc<Notify>> = {
let mut waiters = self.waiters.lock();
waiters.drain().map(|(_, n)| n).collect()
};
for notify in drained {
notify.notify_waiters();
}
}
fn peek_match(&self, unique_id: &str, accepts: OobAccept) -> Option<OobObservation> {
let entries = self.observations.get(unique_id)?;
let stored = entries
.iter()
.find(|s| accepts.matches(s.interaction.protocol))?;
Some(OobObservation::Observed {
protocol: stored.interaction.protocol,
remote_address: stored.interaction.remote_address.clone(),
timestamp: stored.interaction.timestamp.clone(),
raw_payload: stored.interaction.raw_payload.clone(),
})
}
fn store_and_notify(&self, interaction: Interaction) {
let id = interaction.unique_id.clone();
let stored = StoredInteraction {
interaction,
received_at: Instant::now(),
};
self.observations
.entry(id.clone())
.or_default()
.push(stored);
if let Some(notify) = self.waiters.lock().get(&id) {
notify.notify_waiters();
}
}
fn gc(&self) {
let cutoff = Instant::now()
.checked_sub(self.config.max_observation_age)
.unwrap_or_else(Instant::now);
self.observations.retain(|_, entries| {
entries.retain(|stored| stored.received_at >= cutoff);
!entries.is_empty()
});
}
pub fn for_test(client: Arc<InteractshClient>, config: OobConfig) -> Arc<Self> {
Arc::new(Self {
client,
config,
observations: Arc::new(DashMap::new()),
waiters: Arc::new(Mutex::new(HashMap::new())),
poller_handle: Mutex::new(None),
shutdown: Arc::new(AtomicBool::new(false)),
})
}
pub fn store_and_notify_for_test(&self, interaction: super::client::Interaction) {
self.store_and_notify(interaction);
}
}
#[derive(Debug, Clone, Copy)]
pub enum OobAccept {
Dns,
Http,
Smtp,
Any,
}
impl OobAccept {
pub fn matches(self, p: InteractionProtocol) -> bool {
matches!(
(self, p),
(Self::Any, _)
| (Self::Dns, InteractionProtocol::Dns)
| (Self::Http, InteractionProtocol::Http)
| (Self::Smtp, InteractionProtocol::Smtp)
)
}
}
impl From<keyhog_core::OobProtocol> for OobAccept {
fn from(p: keyhog_core::OobProtocol) -> Self {
match p {
keyhog_core::OobProtocol::Dns => Self::Dns,
keyhog_core::OobProtocol::Http => Self::Http,
keyhog_core::OobProtocol::Smtp => Self::Smtp,
keyhog_core::OobProtocol::Any => Self::Any,
}
}
}
fn spawn_poller(session: Arc<OobSession>) -> JoinHandle<()> {
tokio::spawn(async move {
let mut consecutive_errors = 0u32;
let mut next_gc = Instant::now() + Duration::from_secs(60);
loop {
if session.shutdown.load(Ordering::Acquire) {
break;
}
match session.client.poll().await {
Ok(interactions) => {
consecutive_errors = 0;
for interaction in interactions {
session.store_and_notify(interaction);
}
}
Err(e) => {
consecutive_errors += 1;
let backoff_secs = (1u64 << consecutive_errors.min(5)).min(30);
let redacted = redact_interactsh_error(&e);
warn!(
target: "keyhog::oob",
error = %redacted,
consecutive_errors,
backoff_secs,
"interactsh poll failed; backing off"
);
tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
continue;
}
}
if Instant::now() >= next_gc {
session.gc();
next_gc = Instant::now() + Duration::from_secs(60);
}
tokio::time::sleep(session.config.poll_interval).await;
}
debug!(target: "keyhog::oob", "poller exiting");
})
}