use std::io::{Read, Write};
use aho_corasick::{AhoCorasick, MatchKind};
use crate::restore_core::process_safe_region;
use crate::types::{Entry, SessionKey};
/// Errors that [`restore`] can return.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RestoreError {
    /// Tag verification failed while restoring the entry at `entry_index`
    /// (index into the `entries` slice passed to [`restore`]).
    #[error("AEAD tag verification failed for entry {entry_index}")]
    AeadTagFailure { entry_index: usize },

    /// Reading the input stream or writing the output stream failed.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The pattern set was rejected: either an entry had an empty fake, or the
    /// Aho-Corasick builder reported an error (carried as its rendered message).
    #[error("failed to build Aho-Corasick automaton: {msg}")]
    Build { msg: String },
}
impl From<aho_corasick::BuildError> for RestoreError {
fn from(e: aho_corasick::BuildError) -> Self {
RestoreError::Build { msg: e.to_string() }
}
}
/// Size, in bytes, of the read buffer that `restore` hands to `Read::read`.
const CHUNK_SIZE: usize = 4 * 1024;
/// Streams `input` to `output`, turning every fake recorded in `entries` back
/// into the original bytes it replaced.
///
/// When `entries` is empty nothing can match, so the input is copied through
/// verbatim. Otherwise input is read in chunks of `CHUNK_SIZE` bytes into an
/// internal buffer, and after every read (and once more at EOF) the buffer is
/// handed to `process_safe_region` together with a leftmost-first
/// Aho-Corasick automaton over all fakes and `max_hold`, the length of the
/// longest fake. Bytes reach `output` only through that function's write
/// callback. `session_key` is forwarded to it as well (presumably to decrypt
/// each entry's `ciphertext`; see `restore_core`).
///
/// Reads that fail with `std::io::ErrorKind::Interrupted` are retried, as the
/// `Read::read` contract requires, instead of aborting the stream.
///
/// # Errors
///
/// - [`RestoreError::Build`] if any entry has an empty `fake` (checked before
///   any byte is read or written) or the automaton cannot be built.
/// - [`RestoreError::Io`] if reading `input` or writing `output` fails.
/// - [`RestoreError::AeadTagFailure`] if `process_safe_region` reports a tag
///   failure for an entry.
pub fn restore<R: Read, W: Write>(
    input: &mut R,
    output: &mut W,
    entries: &[Entry],
    session_key: &SessionKey,
) -> Result<(), RestoreError> {
    if entries.is_empty() {
        // No fakes to look for: plain copy. `io::copy` also retries `Interrupted`.
        std::io::copy(input, output)?;
        return Ok(());
    }

    // An empty pattern would match at every offset; reject it before any I/O.
    if let Some(idx) = entries.iter().position(|e| e.fake.is_empty()) {
        return Err(RestoreError::Build {
            msg: format!("empty fake in entry at index {idx}"),
        });
    }

    let fakes: Vec<&[u8]> = entries.iter().map(|e| e.fake.as_slice()).collect();
    // `?` converts `aho_corasick::BuildError` through the `From` impl above.
    let ac = AhoCorasick::builder()
        .match_kind(MatchKind::LeftmostFirst)
        .build(&fakes)?;

    // `entries` is non-empty here, so `max()` always yields `Some`.
    let max_hold: usize = entries.iter().map(|e| e.fake.len()).max().unwrap_or(0);

    let mut buffer: Vec<u8> = Vec::new();
    let mut chunk = [0u8; CHUNK_SIZE];
    loop {
        let n = read_retrying(input, &mut chunk)?;
        let eof = n == 0;
        // At EOF `n == 0`, so this appends nothing.
        buffer.extend_from_slice(&chunk[..n]);

        process_safe_region(
            &mut buffer,
            &ac,
            entries,
            session_key,
            eof,
            max_hold,
            &mut |bytes| output.write_all(bytes).map_err(RestoreError::Io),
            &mut |entry_idx| RestoreError::AeadTagFailure {
                entry_index: entry_idx,
            },
        )?;

        if eof {
            break;
        }
    }
    Ok(())
}

/// Calls `input.read(buf)`, retrying while it fails with `ErrorKind::Interrupted`.
fn read_retrying<R: Read>(input: &mut R, buf: &mut [u8]) -> std::io::Result<usize> {
    loop {
        match input.read(buf) {
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            other => return other,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::patterns;
    use crate::swap::swap;

    /// Synthetic Anthropic-format key used by every test that needs one secret.
    const ANTHROPIC_SECRET: &[u8] = b"sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA-AAAAAA";

    #[test]
    fn test_restore_basic_roundtrip() {
        let secret = ANTHROPIC_SECRET;
        let payload = [b"Authorization: ".as_slice(), secret].concat();
        let swap_result = swap(&payload, &[patterns::anthropic()]).expect("swap failed");
        let mut input = swap_result.payload.as_slice();
        let mut output = Vec::new();
        restore(
            &mut input,
            &mut output,
            &swap_result.entries,
            &swap_result.session_key,
        )
        .unwrap();
        assert_eq!(output, payload, "restore must restore original payload");
    }

    #[test]
    fn test_restore_chunk_boundary() {
        let secret = ANTHROPIC_SECRET;
        let payload = [b"ctx: ".as_slice(), secret, b" end"].concat();
        let swap_result = swap(&payload, &[patterns::anthropic()]).expect("swap failed");

        /// Reader that yields at most one byte per call, so every possible
        /// chunk boundary falls inside the fake.
        struct OneByteReader<'a> {
            data: &'a [u8],
            pos: usize,
        }
        impl Read for OneByteReader<'_> {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                match (self.data.get(self.pos), buf.first_mut()) {
                    (Some(&byte), Some(slot)) => {
                        *slot = byte;
                        self.pos += 1;
                        Ok(1)
                    }
                    // EOF, or an empty `buf` (which `Read` says must yield Ok(0)
                    // rather than panic).
                    _ => Ok(0),
                }
            }
        }

        let mut reader = OneByteReader {
            data: &swap_result.payload,
            pos: 0,
        };
        let mut output = Vec::new();
        restore(
            &mut reader,
            &mut output,
            &swap_result.entries,
            &swap_result.session_key,
        )
        .unwrap();
        assert_eq!(output, payload, "INV-4: chunk-boundary restoration failed");
    }

    #[test]
    fn test_restore_no_fake_in_stream() {
        let payload = b"no secrets here";
        let swap_result = swap(payload, &[patterns::anthropic()]).expect("swap failed");
        let response = b"response with no fakes in it";
        let mut input = response.as_slice();
        let mut output = Vec::new();
        let result = restore(
            &mut input,
            &mut output,
            &swap_result.entries,
            &swap_result.session_key,
        );
        assert!(result.is_ok(), "INV-7: no error when no fake present");
        assert_eq!(
            output, response,
            "INV-7: output byte-for-byte identical to input"
        );
    }

    #[test]
    fn test_restore_tampered_aead_tag_returns_err() {
        let secret = ANTHROPIC_SECRET;
        let payload = [b"Authorization: ".as_slice(), secret].concat();
        let mut swap_result = swap(&payload, &[patterns::anthropic()]).expect("swap failed");
        let entry = &mut swap_result.entries[0];
        let last = entry.ciphertext.len() - 1;
        entry.ciphertext[last] ^= 0xFF;
        let mut input = swap_result.payload.as_slice();
        let mut output = Vec::new();
        let result = restore(
            &mut input,
            &mut output,
            &swap_result.entries,
            &swap_result.session_key,
        );
        // Pin the failure to the tag-check variant, not merely "some error".
        assert!(
            matches!(result, Err(RestoreError::AeadTagFailure { .. })),
            "INV-6: tampered tag must return Err"
        );
        assert!(
            !output.windows(secret.len()).any(|w| w == secret),
            "INV-6: no secret bytes in output after tag failure"
        );
    }

    #[test]
    fn test_restore_exact_matching_only() {
        let real_key_in_response = b"sk-ant-api03-BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB-BBBBBB";
        let swap_result =
            swap(b"unrelated payload", &[patterns::anthropic()]).expect("swap failed");
        let mut input = real_key_in_response.as_slice();
        let mut output = Vec::new();
        restore(
            &mut input,
            &mut output,
            &swap_result.entries,
            &swap_result.session_key,
        )
        .unwrap();
        assert_eq!(
            output, real_key_in_response,
            "INV-19: real key in response must pass through unchanged"
        );
    }

    #[test]
    fn test_empty_fake_rejected_sync() {
        // `Entry` and `SessionKey` are already in scope through `use super::*`.
        let bad_entry = Entry {
            fake: vec![],
            ciphertext: vec![0u8; 32],
            nonce: vec![0u8; 24],
        };
        let session_key = SessionKey::from_bytes([0u8; 32]);
        let mut input = b"some input".as_slice();
        let mut output = Vec::new();
        let result = restore(&mut input, &mut output, &[bad_entry], &session_key);
        assert!(
            matches!(result, Err(RestoreError::Build { .. })),
            "empty fake must return Err(Build)"
        );
        assert!(
            output.is_empty(),
            "guard must fire before any bytes are written"
        );
    }

    #[test]
    fn test_restore_empty_input() {
        let secret = ANTHROPIC_SECRET;
        let payload = [b"Authorization: ".as_slice(), secret].concat();
        let sr = swap(&payload, &[patterns::anthropic()]).unwrap();
        assert!(
            !sr.entries.is_empty(),
            "entries must be present for this test to be meaningful"
        );
        let mut input = b"".as_slice();
        let mut output = Vec::new();
        restore(&mut input, &mut output, &sr.entries, &sr.session_key).unwrap();
        assert!(output.is_empty(), "empty input must produce empty output");
    }

    #[test]
    fn test_restore_two_distinct_secrets() {
        let anthropic_key = ANTHROPIC_SECRET;
        let openai_key: Vec<u8> = {
            let mut k = b"sk-proj-".to_vec();
            k.extend(std::iter::repeat_n(b'A', 58));
            k.extend_from_slice(b"T3BlbkFJ");
            k.extend(std::iter::repeat_n(b'B', 58));
            k
        };
        let payload = [
            b"anthropic: ".as_slice(),
            anthropic_key,
            b" openai: ",
            openai_key.as_slice(),
        ]
        .concat();
        let sr = swap(
            &payload,
            &[patterns::anthropic(), patterns::openai_project()],
        )
        .unwrap();
        assert_eq!(sr.entries.len(), 2, "must detect two distinct secrets");
        let mut input = sr.payload.as_slice();
        let mut output = Vec::new();
        restore(&mut input, &mut output, &sr.entries, &sr.session_key).unwrap();
        assert_eq!(output, payload, "both distinct secrets must be restored");
    }

    #[test]
    fn test_restore_registered_roundtrip() {
        let secret = b"my-custom-tier2-api-token-that-is-long-enough-for-registration-abcd1234";
        let pattern = crate::register(secret).expect("register failed");
        let payload = [b"Bearer ".as_slice(), secret, b" end"].concat();
        let sr = swap(&payload, &[pattern]).unwrap();
        assert_eq!(
            sr.entries.len(),
            1,
            "registered swap must produce one entry"
        );
        let mut input = sr.payload.as_slice();
        let mut output = Vec::new();
        restore(&mut input, &mut output, &sr.entries, &sr.session_key).unwrap();
        assert_eq!(
            output, payload,
            "registered secret must be restored correctly"
        );
    }
}