use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use aho_corasick::{AhoCorasick, MatchKind};
use bytes::Bytes;
use futures_core::Stream;
use crate::restore::RestoreError;
use crate::restore_core::process_safe_region;
use crate::types::{Entry, SessionKey};
/// Stream adapter returned by [`restore_stream`].
///
/// Chunks pulled from `inner` are appended to `buffer`. `process_buffer`
/// moves releasable bytes from `buffer` into `pending`, and `poll_next` yields
/// items from `pending` in order.
pub struct RestoreStream<S> {
/// Upstream byte stream being restored.
inner: S,
/// Automaton over every entry's `fake`; `None` when there are no entries, in
/// which case bytes are forwarded unchanged.
ac: Option<AhoCorasick>,
/// Swap entries handed to `process_safe_region`; the automaton's pattern ids
/// follow the order of this vector.
entries: Vec<Entry>,
/// Key handed to `process_safe_region` (presumably used to decrypt each
/// entry's ciphertext — confirm in `restore_core`).
session_key: SessionKey,
/// Bytes received from `inner` that have not been emitted yet.
buffer: Vec<u8>,
/// Length of the longest `fake` (0 with no entries); passed to
/// `process_safe_region` as its hold-back bound.
max_hold: usize,
/// Set once `inner` has returned `Ready(None)`; `inner` is never polled again.
eof: bool,
/// Items ready to be yielded, in output order.
pending: VecDeque<Result<Bytes, RestoreError>>,
/// Terminal flag: set after the first error or after EOF has been drained;
/// every later poll returns `Ready(None)`.
done: bool,
}
/// Wraps `inner` so that fake tokens produced by an earlier swap are turned
/// back into their original bytes as data streams through.
///
/// Each entry's `fake` becomes one pattern of a leftmost-first Aho-Corasick
/// automaton. When two fakes could match at the same position, the one that
/// appears earlier in `entries` wins. The replacement itself (presumably
/// decrypting each entry with `session_key`) is delegated to
/// [`process_safe_region`] as the stream is polled. When `entries` is empty, no
/// automaton is built and upstream bytes are forwarded unchanged.
///
/// The returned stream reports upstream I/O failures as [`RestoreError::Io`]
/// and ends after the first error of any kind.
///
/// # Errors
///
/// * [`RestoreError::Build`] if any entry has an empty `fake`. The message names
///   the first offending index.
/// * The [`RestoreError`] converted from `aho_corasick::BuildError` if the
///   automaton cannot be built.
#[must_use = "the stream must be polled to process data"]
pub fn restore_stream<S>(
    inner: S,
    entries: Vec<Entry>,
    session_key: SessionKey,
) -> Result<RestoreStream<S>, RestoreError>
where
    S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{
    // Reject empty fakes up front: an empty pattern would match at every
    // offset of the stream.
    if let Some(idx) = entries.iter().position(|e| e.fake.is_empty()) {
        return Err(RestoreError::Build {
            msg: format!("empty fake in entry at index {idx}"),
        });
    }

    let ac = if entries.is_empty() {
        None
    } else {
        // `build` takes any iterator of `AsRef<[u8]>`, so the fakes are fed
        // straight from `entries` without an intermediate `Vec`.
        let automaton = AhoCorasick::builder()
            .match_kind(MatchKind::LeftmostFirst)
            .build(entries.iter().map(|e| e.fake.as_slice()))
            .map_err(RestoreError::from)?;
        Some(automaton)
    };

    // Longest fake. It is handed to `process_safe_region` as the hold-back
    // bound and is 0 when there are no entries.
    let max_hold = entries.iter().map(|e| e.fake.len()).max().unwrap_or(0);

    Ok(RestoreStream {
        inner,
        ac,
        entries,
        session_key,
        buffer: Vec::new(),
        max_hold,
        eof: false,
        pending: VecDeque::new(),
        done: false,
    })
}
impl<S> RestoreStream<S> {
    /// Moves every byte of `self.buffer` that is currently safe to release into
    /// `self.pending`.
    ///
    /// * No entries (`self.ac` is `None`): the whole buffer is forwarded
    ///   unchanged as a single chunk.
    /// * Otherwise the buffer, automaton, entries, session key, EOF flag and
    ///   `max_hold` are handed to [`process_safe_region`]. It reports output
    ///   through the emit callback, and each slice is queued as an `Ok` chunk.
    ///   It may leave bytes in `self.buffer` for a later call (presumably a tail
    ///   that could still begin a fake — confirm in `restore_core`).
    ///
    /// An error from `process_safe_region` is queued *after* whatever output it
    /// had already emitted, so consumers observe it in stream order.
    fn process_buffer(&mut self) {
        let Some(ac) = &self.ac else {
            if !self.buffer.is_empty() {
                // Hand the buffer's allocation to `Bytes` instead of copying
                // it; the next upstream chunk simply starts a fresh `Vec`.
                let chunk = std::mem::take(&mut self.buffer);
                self.pending.push_back(Ok(Bytes::from(chunk)));
            }
            return;
        };

        let pending = &mut self.pending;
        let result = process_safe_region(
            &mut self.buffer,
            ac,
            &self.entries,
            &self.session_key,
            self.eof,
            self.max_hold,
            &mut |bytes| {
                pending.push_back(Ok(Bytes::copy_from_slice(bytes)));
                Ok::<(), RestoreError>(())
            },
            // Error reported for the entry whose authentication tag fails to
            // verify during restoration.
            &mut |entry_index| RestoreError::AeadTagFailure { entry_index },
        );
        if let Err(e) = result {
            pending.push_back(Err(e));
        }
    }
}
impl<S> Stream for RestoreStream<S>
where
    S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{
    type Item = Result<Bytes, RestoreError>;

    /// Yields restored chunks in order.
    ///
    /// Each round first hands out anything already queued. Otherwise it re-runs
    /// `process_buffer`, and only if that produced nothing does it pull the next
    /// chunk from `inner`. `Poll::Pending` is returned only when `inner`
    /// returned it, so a wakeup is always registered. Upstream failures surface
    /// as [`RestoreError::Io`]. After the first error, or once upstream EOF has
    /// been fully drained, the stream is finished and every later poll returns
    /// `Ready(None)`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let me = self.get_mut();
        while !me.done {
            match me.pending.pop_front() {
                Some(Err(failure)) => {
                    // A queued restoration error is terminal.
                    me.done = true;
                    return Poll::Ready(Some(Err(failure)));
                }
                Some(chunk) => return Poll::Ready(Some(chunk)),
                None => {}
            }

            me.process_buffer();
            if !me.pending.is_empty() {
                continue;
            }
            if me.eof {
                // Upstream is exhausted and nothing is left to emit.
                me.done = true;
                break;
            }

            match Pin::new(&mut me.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => me.eof = true,
                Poll::Ready(Some(Ok(bytes))) => me.buffer.extend_from_slice(&bytes),
                Poll::Ready(Some(Err(io_err))) => {
                    me.done = true;
                    return Poll::Ready(Some(Err(RestoreError::Io(io_err))));
                }
            }
        }
        Poll::Ready(None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::patterns;
    use crate::swap::swap;
    use futures::StreamExt;

    /// Anthropic-shaped key used as the secret that `swap` replaces with a fake.
    const ANTHROPIC_SECRET: &[u8] = b"sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA-AAAAAA";

    /// Drains `stream`, concatenating every chunk; stops at the first error.
    async fn collect_stream<S>(stream: RestoreStream<S>) -> Result<Vec<u8>, RestoreError>
    where
        S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
    {
        let mut out = Vec::new();
        futures::pin_mut!(stream);
        while let Some(item) = stream.next().await {
            out.extend_from_slice(&item?);
        }
        Ok(out)
    }

    /// Upstream that yields `data` in `chunk_size`-byte chunks (the last chunk
    /// may be shorter). `chunk_size` must be non-zero.
    fn chunked_stream(
        data: Vec<u8>,
        chunk_size: usize,
    ) -> impl Stream<Item = Result<Bytes, io::Error>> + Unpin {
        let chunks: Vec<Result<Bytes, io::Error>> = data
            .chunks(chunk_size)
            .map(|c| Ok(Bytes::copy_from_slice(c)))
            .collect();
        futures::stream::iter(chunks)
    }

    /// Upstream that yields all of `data` as one chunk (nothing when empty).
    fn single_chunk_stream(data: Vec<u8>) -> impl Stream<Item = Result<Bytes, io::Error>> + Unpin {
        chunked_stream(data, usize::MAX)
    }

    /// Upstream that returns `Pending` once (waking itself) before yielding its
    /// chunks, exercising the `Poll::Pending` path of `RestoreStream`.
    ///
    /// Both fields are `Unpin`, so the auto trait already makes this type
    /// `Unpin` as `restore_stream` requires; no manual impl is needed.
    struct PendingOnceStream {
        state: u8,
        chunks: VecDeque<Bytes>,
    }

    impl PendingOnceStream {
        fn new(data: Vec<u8>, chunk_size: usize) -> Self {
            let chunks = data
                .chunks(chunk_size)
                .map(Bytes::copy_from_slice)
                .collect();
            Self { state: 0, chunks }
        }
    }

    impl Stream for PendingOnceStream {
        type Item = Result<Bytes, io::Error>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            match self.state {
                0 => {
                    self.state = 1;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                _ => {
                    if let Some(chunk) = self.chunks.pop_front() {
                        Poll::Ready(Some(Ok(chunk)))
                    } else {
                        Poll::Ready(None)
                    }
                }
            }
        }
    }

    #[test]
    fn test_async_basic_roundtrip() {
        let payload = [b"Authorization: ".as_slice(), ANTHROPIC_SECRET].concat();
        let sr = swap(&payload, &[patterns::anthropic()]).unwrap();
        let inner = single_chunk_stream(sr.payload);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(result, payload);
    }

    #[test]
    fn test_async_chunk_boundary() {
        let payload = [b"ctx: ".as_slice(), ANTHROPIC_SECRET, b" end"].concat();
        let sr = swap(&payload, &[patterns::anthropic()]).unwrap();
        let inner = chunked_stream(sr.payload, 1);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(
            result, payload,
            "INV-4: single-byte chunks must restore correctly"
        );
    }

    #[test]
    fn test_async_no_fake_in_stream() {
        let sr = swap(b"unrelated", &[patterns::anthropic()]).unwrap();
        let response = b"response with no fakes in it";
        let inner = single_chunk_stream(response.to_vec());
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(
            result, response,
            "INV-7: passthrough must be byte-identical"
        );
    }

    #[test]
    fn test_async_empty_entries_passthrough() {
        let sr = swap(b"payload", &[]).unwrap();
        assert!(sr.entries.is_empty());
        let response = b"any response bytes";
        let inner = single_chunk_stream(response.to_vec());
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(result, response);
    }

    #[test]
    fn test_async_tampered_aead_tag_returns_err() {
        let payload = [b"Authorization: ".as_slice(), ANTHROPIC_SECRET].concat();
        let mut sr = swap(&payload, &[patterns::anthropic()]).unwrap();
        let last = sr.entries[0].ciphertext.len() - 1;
        sr.entries[0].ciphertext[last] ^= 0xFF;
        let inner = single_chunk_stream(sr.payload);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(async {
            let mut out: Vec<u8> = Vec::new();
            futures::pin_mut!(stream);
            loop {
                match stream.next().await {
                    Some(Ok(chunk)) => out.extend_from_slice(&chunk),
                    Some(Err(e)) => return Err((e, out)),
                    None => return Ok(out),
                }
            }
        });
        let (err, out) = result.unwrap_err();
        assert!(
            matches!(err, RestoreError::AeadTagFailure { .. }),
            "INV-6: must yield AeadTagFailure"
        );
        assert!(
            !out.windows(ANTHROPIC_SECRET.len())
                .any(|w| w == ANTHROPIC_SECRET),
            "INV-6: no secret bytes emitted before tag failure"
        );
    }

    #[test]
    fn test_async_multiple_fakes_in_stream() {
        let payload = [ANTHROPIC_SECRET, b" and again: ", ANTHROPIC_SECRET].concat();
        let sr = swap(&payload, &[patterns::anthropic()]).unwrap();
        assert_eq!(sr.entries.len(), 1);
        let inner = single_chunk_stream(sr.payload);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(result, payload, "both occurrences must be restored");
    }

    #[test]
    fn test_async_exact_matching_only() {
        let real_key = b"sk-ant-api03-BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB-BBBBBB";
        let sr = swap(b"unrelated payload", &[patterns::anthropic()]).unwrap();
        let inner = single_chunk_stream(real_key.to_vec());
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(
            result, real_key,
            "INV-19: real key must pass through unchanged"
        );
    }

    #[test]
    fn test_async_empty_stream() {
        let sr = swap(b"payload", &[patterns::anthropic()]).unwrap();
        let inner = futures::stream::empty::<Result<Bytes, io::Error>>();
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert!(result.is_empty(), "empty stream must yield no bytes");
    }

    #[test]
    fn test_async_fake_exactly_at_chunk_boundary() {
        let prefix = b"prefix ";
        let payload = [prefix.as_slice(), ANTHROPIC_SECRET, b" suffix"].concat();
        let sr = swap(&payload, &[patterns::anthropic()]).unwrap();
        let fake_len = sr.entries[0].fake.len();
        // Assuming `swap` leaves the prefix untouched, the fake starts right
        // after it, so the first chunk boundary falls halfway through the fake.
        // `split_at >= prefix.len()`, so it is never zero.
        let split_at = prefix.len() + fake_len / 2;
        let inner = chunked_stream(sr.payload, split_at);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(
            result, payload,
            "fake split exactly at chunk boundary must restore"
        );
    }

    #[test]
    fn test_empty_fake_rejected_stream() {
        let bad_entry = Entry {
            fake: vec![],
            ciphertext: vec![0u8; 32],
            nonce: vec![0u8; 24],
        };
        let session_key = SessionKey::from_bytes([0u8; 32]);
        let inner = futures::stream::empty::<Result<Bytes, io::Error>>();
        let result = restore_stream(inner, vec![bad_entry], session_key);
        assert!(
            matches!(result, Err(RestoreError::Build { .. })),
            "empty fake must return Err(Build)"
        );
    }

    #[test]
    fn test_async_two_distinct_fakes() {
        let openai_key: Vec<u8> = {
            let mut k = b"sk-proj-".to_vec();
            k.extend(std::iter::repeat_n(b'A', 58));
            k.extend_from_slice(b"T3BlbkFJ");
            k.extend(std::iter::repeat_n(b'B', 58));
            k
        };
        let payload = [
            b"anthropic: ".as_slice(),
            ANTHROPIC_SECRET,
            b" openai: ",
            openai_key.as_slice(),
        ]
        .concat();
        let sr = swap(
            &payload,
            &[patterns::anthropic(), patterns::openai_project()],
        )
        .unwrap();
        assert_eq!(sr.entries.len(), 2, "must detect two distinct secrets");
        let inner = single_chunk_stream(sr.payload);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(
            result, payload,
            "both distinct fakes must restore correctly"
        );
    }

    #[test]
    fn test_async_pending_then_ready() {
        let payload = [b"prefix ".as_slice(), ANTHROPIC_SECRET, b" suffix"].concat();
        let sr = swap(&payload, &[patterns::anthropic()]).unwrap();
        let inner = PendingOnceStream::new(sr.payload, 32);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(
            result, payload,
            "Pending → Ready path must restore correctly"
        );
    }

    #[test]
    fn test_async_registered_secret_restoration() {
        let secret = b"my-custom-tier2-api-token-that-is-long-enough-for-registration-abcd1234";
        let pattern = crate::register(secret).expect("register failed");
        let payload = [b"Bearer ".as_slice(), secret, b" end"].concat();
        let sr = swap(&payload, &[pattern]).unwrap();
        assert_eq!(
            sr.entries.len(),
            1,
            "Registered secret swap must produce one entry"
        );
        let inner = single_chunk_stream(sr.payload);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(
            result, payload,
            "registered secret fake must restore correctly"
        );
    }

    #[test]
    fn test_async_empty_entries_multichunk() {
        let response = b"hello world this is a multi-chunk passthrough test";
        let sr = swap(b"unrelated", &[]).unwrap();
        assert!(sr.entries.is_empty());
        let inner = chunked_stream(response.to_vec(), 4);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        let result = futures::executor::block_on(collect_stream(stream)).unwrap();
        assert_eq!(
            result, response,
            "multi-chunk passthrough with no entries must be byte-identical"
        );
    }

    #[test]
    fn test_async_inner_io_error_is_propagated_and_terminal() {
        let sr = swap(b"payload", &[patterns::anthropic()]).unwrap();
        let inner = futures::stream::iter(vec![
            Ok(Bytes::from_static(b"partial")),
            Err(io::Error::other("boom")),
            Ok(Bytes::from_static(b"never seen")),
        ]);
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        futures::executor::block_on(async {
            futures::pin_mut!(stream);
            let err = loop {
                match stream.next().await {
                    Some(Ok(_)) => {}
                    Some(Err(e)) => break e,
                    None => panic!("stream ended without surfacing the upstream I/O error"),
                }
            };
            assert!(
                matches!(err, RestoreError::Io(_)),
                "upstream I/O error must surface as RestoreError::Io"
            );
            assert!(
                stream.next().await.is_none(),
                "stream must be finished after an error"
            );
        });
    }

    #[test]
    fn test_async_none_after_completion() {
        let sr = swap(b"payload", &[]).unwrap();
        assert!(sr.entries.is_empty());
        let inner = single_chunk_stream(b"abc".to_vec());
        let stream = restore_stream(inner, sr.entries, sr.session_key).unwrap();
        futures::executor::block_on(async {
            futures::pin_mut!(stream);
            let chunk = stream.next().await.unwrap().unwrap();
            assert_eq!(&chunk[..], b"abc");
            assert!(stream.next().await.is_none());
            assert!(
                stream.next().await.is_none(),
                "polling a finished stream must keep returning None"
            );
        });
    }
}