use std::io::Read;
use std::path::Path;
use std::time::Duration;
use keyhog_core::{Chunk, ChunkMetadata, Source, SourceError};
use reqwest::blocking::Client;
mod auth;
mod listing;
use auth::AwsSigV4Config;
use listing::{encode_s3_key_path, parse_s3_listing};
/// Host suffix for virtual-hosted-style AWS URLs (`https://<bucket>.<suffix>`).
const DEFAULT_S3_HOST_SUFFIX: &str = "s3.amazonaws.com";
/// Timeout configured on the shared HTTP client used for listings and downloads.
const S3_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Listing-entry cap used by `S3Source::new` when no explicit limit is set.
const DEFAULT_MAX_OBJECTS: usize = 100 * 1_000;
/// Objects larger than this many bytes (10 MiB) are skipped instead of scanned.
const MAX_S3_OBJECT_BYTES: u64 = 10 << 20;
/// A [`Source`] that scans the objects of a single S3 (or S3-compatible) bucket.
///
/// Build one with [`S3Source::new`] and refine it with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct S3Source {
    /// Bucket name; trimmed and validated when scanning starts, not at construction.
    bucket: String,
    /// Optional key prefix, forwarded as the ListObjectsV2 `prefix` query parameter.
    prefix: Option<String>,
    /// Optional custom endpoint; when set, path-style URLs (`<endpoint>/<bucket>`) are used.
    endpoint: Option<String>,
    /// Maximum number of listing entries considered across all pages.
    max_objects: usize,
}
impl S3Source {
    /// Creates a source for `bucket` with no prefix and the default AWS endpoint.
    ///
    /// The listing cap is [`DEFAULT_MAX_OBJECTS`]. The bucket name is not
    /// validated here; an invalid name surfaces as the single error yielded by
    /// [`Source::chunks`].
    #[must_use]
    pub fn new(bucket: impl Into<String>) -> Self {
        Self {
            bucket: bucket.into(),
            prefix: None,
            endpoint: None,
            max_objects: DEFAULT_MAX_OBJECTS,
        }
    }

    /// Forwards `prefix` as the ListObjectsV2 `prefix` query parameter, limiting
    /// the scan to keys under it.
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.prefix = Some(prefix.into());
        self
    }

    /// Targets a custom S3-compatible endpoint using path-style addressing.
    ///
    /// The endpoint is validated when scanning starts. It must be an
    /// `http`/`https` URL with a host and no credentials, query, or fragment.
    #[must_use]
    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.endpoint = Some(endpoint.into());
        self
    }

    /// Caps the number of listing entries considered.
    ///
    /// Entries skipped as empty or binary still count toward the cap. A cap of
    /// `0` performs no requests.
    #[must_use]
    pub fn with_max_objects(mut self, max_objects: usize) -> Self {
        self.max_objects = max_objects;
        self
    }
}
impl Source for S3Source {
    /// Source-type identifier reported for this source.
    fn name(&self) -> &str {
        "s3"
    }

    /// Lists and downloads the bucket eagerly, then yields every collected chunk.
    ///
    /// If collection fails, the iterator yields only that error. Chunks that
    /// were gathered before the failure are dropped, not yielded.
    fn chunks(&self) -> Box<dyn Iterator<Item = Result<Chunk, SourceError>> + '_> {
        let outcome = collect_s3_chunks(
            &self.bucket,
            self.prefix.as_deref(),
            self.endpoint.as_deref(),
            self.max_objects,
        );
        match outcome {
            Err(failure) => Box::new(std::iter::once(Err(failure))),
            Ok(collected) => Box::new(collected.into_iter().map(Ok)),
        }
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
fn collect_s3_chunks(
bucket: &str,
prefix: Option<&str>,
endpoint: Option<&str>,
max_objects: usize,
) -> Result<Vec<Chunk>, SourceError> {
let bucket = validate_bucket_name(bucket)?;
let client = Client::builder()
.timeout(S3_REQUEST_TIMEOUT)
.build()
.map_err(|e| SourceError::Other(format!("failed to build S3 client: {e}")))?;
let base_url = build_base_url(&bucket, endpoint)?;
let aws_auth = AwsSigV4Config::from_env(&base_url);
let mut continuation_token = None::<String>;
let mut chunks = Vec::new();
let mut listed_objects = 0usize;
loop {
if listed_objects >= max_objects {
break;
}
let mut request = client.get(&base_url).query(&[("list-type", "2")]);
if let Some(prefix) = prefix {
request = request.query(&[("prefix", prefix)]);
}
if let Some(token) = continuation_token.as_deref() {
request = request.query(&[("continuation-token", token)]);
}
if let Some(auth) = aws_auth.as_ref() {
request = auth.sign(request, &base_url)?;
}
let response = request
.send()
.map_err(|e| SourceError::Other(format!("failed to list S3 objects: {e}")))?;
if !response.status().is_success() {
return Err(SourceError::Other(format!(
"failed to list S3 objects: bucket request returned {}",
response.status()
)));
}
let body = response
.text()
.map_err(|e| SourceError::Other(format!("failed to read S3 listing: {e}")))?;
let listing = parse_s3_listing(&body)?;
let remaining = max_objects.saturating_sub(listed_objects);
let reached_limit = listing.contents.len() > remaining;
for object in listing.contents.into_iter().take(remaining) {
listed_objects += 1;
if object.size == 0 || !is_probably_text(&object.key) {
continue;
}
if let Some(chunk) = fetch_object_chunk(
&client,
&base_url,
&bucket,
&object.key,
object.size,
aws_auth.as_ref(),
)? {
chunks.push(chunk);
}
}
if reached_limit || !listing.is_truncated {
break;
}
continuation_token = listing.next_continuation_token;
if continuation_token.is_none() {
break;
}
}
Ok(chunks)
}
/// Downloads one S3 object and wraps its text in a [`Chunk`] tagged `"s3"`.
///
/// `object_size` is the size reported by the listing. The object is skipped,
/// returning `Ok(None)` after a `tracing::debug!` message, when:
/// - `object_size` or the response `Content-Length` exceeds [`MAX_S3_OBJECT_BYTES`];
/// - the GET returns a non-success status;
/// - the bare `Content-Type` media type is `image/*`, `audio/*`, `video/*`,
///   `application/octet-stream`, `application/zip`, or `application/gzip`;
/// - more than [`MAX_S3_OBJECT_BYTES`] bytes are actually received;
/// - the body is not valid UTF-8.
///
/// # Errors
///
/// Returns an error if request signing fails, the request cannot be sent, or
/// the body cannot be read.
fn fetch_object_chunk(
    client: &Client,
    base_url: &str,
    bucket: &str,
    key: &str,
    object_size: u64,
    aws_auth: Option<&AwsSigV4Config>,
) -> Result<Option<Chunk>, SourceError> {
    // Cheap pre-check with the listing size so oversized objects are never requested.
    if object_size > MAX_S3_OBJECT_BYTES {
        tracing::debug!(
            "failed to read S3 object: {}/{} exceeds {} byte limit with {} bytes",
            bucket,
            key,
            MAX_S3_OBJECT_BYTES,
            object_size
        );
        return Ok(None);
    }
    let encoded_key = encode_s3_key_path(key);
    let url = format!("{}/{}", base_url.trim_end_matches('/'), encoded_key);
    let request = client.get(&url);
    let request = match aws_auth {
        Some(auth) => auth.sign(request, &url)?,
        None => request,
    };
    let response = request
        .send()
        .map_err(|e| SourceError::Other(format!("failed to download S3 object: {key}: {e}")))?;
    let status = response.status();
    if !status.is_success() {
        // Best-effort: one unreadable object (e.g. 403/404) must not abort the
        // whole scan, but the skip should be visible in debug logs.
        tracing::debug!("skipping S3 object {bucket}/{key}: GET returned {status}");
        return Ok(None);
    }
    // The listing size can be stale, so re-check what the server advertises.
    let advertised_len = response.content_length();
    if let Some(content_length) = advertised_len
        && content_length > MAX_S3_OBJECT_BYTES
    {
        tracing::debug!(
            "failed to read S3 object: {}/{} content-length {} exceeds {} byte limit",
            bucket,
            key,
            content_length,
            MAX_S3_OBJECT_BYTES
        );
        return Ok(None);
    }
    if let Some(ct) = response
        .headers()
        .get("content-type")
        .and_then(|v| v.to_str().ok())
    {
        // Compare only the bare media type. A `; charset=...` (or any other
        // parameter) suffix would otherwise defeat the exact-match checks below.
        let media_type = ct
            .split(';')
            .next()
            .unwrap_or(ct)
            .trim()
            .to_ascii_lowercase();
        if media_type.starts_with("image/")
            || media_type.starts_with("audio/")
            || media_type.starts_with("video/")
            || media_type == "application/octet-stream"
            || media_type == "application/zip"
            || media_type == "application/gzip"
        {
            tracing::debug!("skipping S3 object {key}: binary content-type {ct}");
            return Ok(None);
        }
    }
    // Pre-size from Content-Length, which is already known to be within the limit here.
    let capacity = advertised_len
        .and_then(|len| usize::try_from(len).ok())
        .unwrap_or(0);
    let mut body = Vec::with_capacity(capacity);
    // Content-Length may be missing or wrong. Read at most one byte past the
    // limit so an oversized body is detected without buffering all of it.
    let mut reader = response.take(MAX_S3_OBJECT_BYTES + 1);
    reader
        .read_to_end(&mut body)
        .map_err(|e| SourceError::Other(format!("failed to read S3 object body: {key}: {e}")))?;
    if body.len() as u64 > MAX_S3_OBJECT_BYTES {
        tracing::debug!(
            "failed to read S3 object: {}/{} downloaded size exceeds {} byte limit",
            bucket,
            key,
            MAX_S3_OBJECT_BYTES
        );
        return Ok(None);
    }
    let object_text = match String::from_utf8(body) {
        Ok(text) => text,
        Err(_) => {
            tracing::debug!("skipping S3 object {bucket}/{key}: body is not valid UTF-8");
            return Ok(None);
        }
    };
    Ok(Some(Chunk {
        data: object_text,
        metadata: ChunkMetadata {
            source_type: "s3".into(),
            path: Some(format!("{bucket}/{key}")),
            commit: None,
            author: None,
            date: None,
        },
    }))
}
/// Builds the bucket's base URL.
///
/// Without a custom endpoint this is the virtual-hosted-style AWS URL
/// `https://<bucket>.s3.amazonaws.com`. With one, the endpoint is validated
/// and the URL-encoded bucket is appended as a path segment (path-style).
///
/// # Errors
///
/// Propagates the error from `validate_endpoint` when the endpoint is rejected.
fn build_base_url(bucket: &str, endpoint: Option<&str>) -> Result<String, SourceError> {
    let Some(raw_endpoint) = endpoint else {
        return Ok(format!("https://{bucket}.{DEFAULT_S3_HOST_SUFFIX}"));
    };
    let validated = validate_endpoint(raw_endpoint)?;
    let root = validated.trim_end_matches('/');
    Ok(format!("{root}/{}", urlencoding::encode(bucket)))
}
/// Checks `bucket` against a subset of the S3 bucket naming rules and returns it trimmed.
///
/// Surrounding whitespace is trimmed first. The trimmed name must then:
/// - be 3–63 bytes long;
/// - contain only lowercase ASCII letters, digits, `.` and `-`;
/// - not start or end with `.` or `-`;
/// - not contain `..`.
///
/// # Errors
///
/// Returns [`SourceError::Other`] with a length message or an
/// "invalid S3 bucket" message naming the offending value.
fn validate_bucket_name(bucket: &str) -> Result<String, SourceError> {
    let name = bucket.trim();
    if !(3..=63).contains(&name.len()) {
        return Err(SourceError::Other("invalid S3 bucket name length".into()));
    }
    // The whitelist also rules out `/` and control characters.
    let is_allowed =
        |ch: char| ch.is_ascii_lowercase() || ch.is_ascii_digit() || ch == '.' || ch == '-';
    let well_formed = name.chars().all(is_allowed)
        && !name.starts_with(['.', '-'])
        && !name.ends_with(['.', '-'])
        && !name.contains("..");
    if well_formed {
        Ok(name.to_owned())
    } else {
        Err(SourceError::Other(format!("invalid S3 bucket '{name}'")))
    }
}
/// Parses and sanity-checks a custom S3 endpoint URL.
///
/// Only `http`/`https` URLs with a host are accepted. Embedded credentials,
/// a query string, or a fragment cause rejection. The returned string is the
/// parsed URL's serialization with trailing slashes removed.
///
/// # Errors
///
/// Returns [`SourceError::Other`] if the URL does not parse or fails any check.
fn validate_endpoint(endpoint: &str) -> Result<String, SourceError> {
    let url = reqwest::Url::parse(endpoint.trim())
        .map_err(|e| SourceError::Other(format!("invalid S3 endpoint: {e}")))?;
    let scheme_ok = url.scheme() == "http" || url.scheme() == "https";
    let has_host = url.host_str().is_some();
    let no_credentials = url.username().is_empty() && url.password().is_none();
    let no_extras = url.query().is_none() && url.fragment().is_none();
    if scheme_ok && has_host && no_credentials && no_extras {
        Ok(url.as_str().trim_end_matches('/').to_owned())
    } else {
        Err(SourceError::Other("invalid S3 endpoint".into()))
    }
}
/// Heuristic pre-filter: returns `false` when the key's extension names a
/// common binary format (images, archives, fonts, media, native libraries).
///
/// The extension comparison is case-insensitive. Keys with no extension,
/// or with a non-UTF-8 extension, are treated as text.
fn is_probably_text(key: &str) -> bool {
    const BINARY_EXTENSIONS: &[&str] = &[
        "png", "jpg", "jpeg", "gif", "webp", "zip", "gz", "tgz", "tar", "7z", "pdf", "woff",
        "woff2", "mp3", "mp4", "mov", "dll", "so", "dylib",
    ];
    match Path::new(key).extension().and_then(|raw| raw.to_str()) {
        None => true,
        Some(ext) => {
            let lowered = ext.to_ascii_lowercase();
            !BINARY_EXTENSIONS.contains(&lowered.as_str())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed source uses the crate-wide default listing cap.
    #[test]
    fn s3_source_defaults_to_max_objects_limit() {
        let source = S3Source::new("bucket");
        assert_eq!(source.max_objects, DEFAULT_MAX_OBJECTS);
    }

    // `with_max_objects` overrides the default cap.
    #[test]
    fn s3_source_allows_custom_max_objects_limit() {
        let source = S3Source::new("bucket").with_max_objects(42);
        assert_eq!(source.max_objects, 42);
    }

    // Without a custom endpoint, the bucket becomes part of the AWS hostname.
    #[test]
    fn default_base_url_uses_virtual_host_style() {
        assert_eq!(
            build_base_url("acme-secrets", None).unwrap(),
            "https://acme-secrets.s3.amazonaws.com"
        );
    }

    // With a custom endpoint, the bucket is appended as a path segment.
    #[test]
    fn custom_endpoint_uses_path_style() {
        assert_eq!(
            build_base_url("acme-secrets", Some("https://minio.internal")).unwrap(),
            "https://minio.internal/acme-secrets"
        );
    }

    // Endpoints with embedded credentials or a non-HTTP(S) scheme are refused.
    #[test]
    fn rejects_invalid_custom_endpoint() {
        assert!(build_base_url("acme-secrets", Some("https://user:pass@minio.internal")).is_err());
        assert!(build_base_url("acme-secrets", Some("ftp://minio.internal")).is_err());
    }

    // Path-traversal-looking and uppercase names fail; a plain DNS-style name passes.
    #[test]
    fn rejects_invalid_bucket_names() {
        assert!(validate_bucket_name("../escape").is_err());
        assert!(validate_bucket_name("UPPERCASE").is_err());
        assert!(validate_bucket_name("ok-bucket").is_ok());
    }

    // Key encoding escapes spaces but keeps `/` so S3 "folders" stay path segments.
    #[test]
    fn s3_key_encoding_preserves_path_separators() {
        assert_eq!(
            encode_s3_key_path("folder/my file.txt"),
            "folder/my%20file.txt"
        );
    }

    // The listing-size pre-check must short-circuit before any request.
    // The reserved `.invalid` host cannot resolve, so reaching the network
    // would make `send` fail and the `unwrap` below panic.
    #[test]
    fn oversized_s3_objects_are_skipped_before_download() {
        let client = reqwest::blocking::Client::builder().build().unwrap();
        let fetched_chunk = fetch_object_chunk(
            &client,
            "https://example.invalid/bucket",
            "bucket",
            "huge.txt",
            MAX_S3_OBJECT_BYTES + 1,
            None,
        )
        .unwrap();
        assert!(fetched_chunk.is_none());
    }

    // XXE hardening: a listing carrying a DOCTYPE with an external entity is
    // rejected, and the error message mentions "DTD/entity".
    #[test]
    fn rejects_s3_xml_with_doctype() {
        let err = parse_s3_listing(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE ListBucketResult [
<!ENTITY xxe SYSTEM "file:///etc/passwd">
]>
<ListBucketResult></ListBucketResult>"#,
        )
        .unwrap_err();
        assert!(err.to_string().contains("DTD/entity"));
    }

    // An `<!ENTITY` marker is rejected even without a surrounding DOCTYPE.
    #[test]
    fn rejects_s3_xml_with_entity_declaration_marker() {
        let err = parse_s3_listing(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<ListBucketResult>
<!ENTITY xxe "boom">
</ListBucketResult>"#,
        )
        .unwrap_err();
        assert!(err.to_string().contains("DTD/entity"));
    }
}