use anyhow::{Context, Result};
use log::debug;
use memmap2::MmapOptions;
use once_cell::sync::Lazy;
use regex::Regex;
use std::fs::File as StdFile;
use std::path::Component;
use std::path::{Path, PathBuf};
use std::str;
use tokio::fs as tokio_fs;
use tokio::io::AsyncWriteExt;
/// Matches an opening code fence: three or more backticks, an optional
/// alphanumeric language tag, then end of line.
///
/// The optional `\r` makes CRLF documents work: in the `regex` crate's
/// multi-line mode, `$` only matches immediately before `\n`, so without it a
/// line like "```rust\r\n" would never match and extraction would silently
/// skip the section.
static FENCE_OPEN_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?m)^(`{3,})([a-zA-Z0-9]*)\r?$").expect("Invalid fence open regex")
});
/// A top-level `## <path>` header found in the markdown document, with byte
/// offsets into the document used to slice out the section that follows it.
#[derive(Debug, Clone)]
struct HeaderSection {
    /// File path exactly as written after the `## ` prefix (not yet validated).
    path: String,
    /// Byte offset of the start of the header line itself.
    header_start: usize,
    /// Byte offset just past the header line, i.e. where the section's
    /// content begins.
    block_start: usize,
}
/// Extracts the files embedded in a markdown document back onto disk.
///
/// The document is expected to contain one top-level `## <path>` header per
/// file, each followed by a fenced code block holding that file's contents.
/// Sections containing the literal marker `(binary file omitted)` are
/// skipped, as are sections with no fenced code block. Output paths are
/// resolved under `extract_root` when given, otherwise relative to the
/// current working directory.
///
/// # Errors
/// Fails if the markdown file cannot be opened, memory-mapped, or is not
/// valid UTF-8; if a header names an absolute or `..`-escaping path; or if
/// any directory creation, file write, or sync fails.
pub async fn extract_from_markdown(
    md_path: &PathBuf,
    extract_root: Option<&PathBuf>,
) -> Result<()> {
    let file = StdFile::open(md_path)
        .with_context(|| format!("Failed to open markdown file: {}", md_path.display()))?;
    // SAFETY: the mapping is read-only and is dropped before this function
    // returns. NOTE(review): as with any memory map, behavior is undefined if
    // another process truncates the file while it is mapped — this assumes
    // the input is not modified concurrently; confirm that holds for callers.
    let mmap = unsafe {
        MmapOptions::new()
            .map(&file)
            .with_context(|| format!("Failed to memory-map markdown file: {}", md_path.display()))?
    };
    let content = str::from_utf8(&mmap)
        .with_context(|| format!("Markdown file is not valid UTF-8: {}", md_path.display()))?;
    let headers = find_top_level_headers(content);
    let mut extracted_count = 0;
    for (idx, header) in headers.iter().enumerate() {
        let file_path_str = header.path.as_str();
        // Reject absolute and traversing paths before touching the filesystem.
        let out_path = build_output_path(file_path_str, extract_root)
            .with_context(|| format!("Invalid restore path in header: {file_path_str:?}"))?;
        // A section runs from just after its header line to the start of the
        // next header, or to end-of-document for the last section.
        let start = header.block_start;
        let end = headers
            .get(idx + 1)
            .map(|next| next.header_start)
            .unwrap_or(content.len());
        let block = &content[start..end];
        if block.contains("(binary file omitted)") {
            debug!("Skipping binary file: {}", out_path.display());
            continue;
        }
        if let Some(code) = extract_fenced_code(block) {
            debug!("Extracting: {} ({} bytes)", out_path.display(), code.len());
            if let Some(parent) = out_path.parent() {
                tokio_fs::create_dir_all(parent)
                    .await
                    .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
            }
            let mut file = tokio_fs::File::create(&out_path)
                .await
                .with_context(|| format!("Failed to create file: {}", out_path.display()))?;
            file.write_all(code.as_bytes())
                .await
                .with_context(|| format!("Failed to write content to: {}", out_path.display()))?;
            // Flush to stable storage so a crash cannot leave a partly
            // written file behind.
            file.sync_all()
                .await
                .with_context(|| format!("Failed to sync file: {}", out_path.display()))?;
            extracted_count += 1;
        }
    }
    debug!("Extracted {} files", extracted_count);
    Ok(())
}
/// Resolves the on-disk destination for a header path.
///
/// An absolute path is reinterpreted as relative (its leading `/` is
/// dropped), then validated against traversal, and finally joined onto
/// `extract_root` when one is provided.
fn build_output_path(file_path_str: &str, extract_root: Option<&PathBuf>) -> Result<PathBuf> {
    let candidate = Path::new(file_path_str);
    // Treat "/src/main.rs" the same as "src/main.rs".
    let relative = candidate.strip_prefix("/").unwrap_or(candidate);
    validate_relative_restore_path(relative)?;
    match extract_root {
        Some(root) => Ok(root.join(relative)),
        None => Ok(relative.to_path_buf()),
    }
}
/// Rejects paths that are empty, absolute, or that climb above their own
/// starting point via `..` components.
///
/// `a/../b` is allowed (it stays inside the root), while `../x` or `a/../../x`
/// is rejected.
fn validate_relative_restore_path(path: &Path) -> Result<()> {
    if path.as_os_str().is_empty() {
        anyhow::bail!("empty path");
    }
    // Number of real directory levels descended so far; a `..` is only legal
    // while we are at least one level deep.
    let mut level: usize = 0;
    for part in path.components() {
        match part {
            Component::CurDir => continue,
            Component::Normal(_) => level += 1,
            Component::ParentDir if level > 0 => level -= 1,
            Component::ParentDir => {
                anyhow::bail!("path traversal is not allowed: {}", path.display())
            }
            Component::RootDir | Component::Prefix(_) => {
                anyhow::bail!("absolute path is not allowed: {}", path.display())
            }
        }
    }
    Ok(())
}
/// Scans `content` for top-level `## <path>` headers, ignoring header-like
/// lines that appear inside fenced code blocks.
///
/// Returns the headers in document order, each carrying the byte offset of
/// its own line and of the content that follows it.
fn find_top_level_headers(content: &str) -> Vec<HeaderSection> {
    let mut sections = Vec::new();
    // Backtick count of the currently open fence, or None when outside one.
    let mut open_fence: Option<usize> = None;
    let mut cursor = 0usize;
    for raw_line in content.split_inclusive('\n') {
        // Drop the trailing newline (and a CR for CRLF input) before matching.
        let no_newline = raw_line.strip_suffix('\n').unwrap_or(raw_line);
        let logical = no_newline.strip_suffix('\r').unwrap_or(no_newline);
        match open_fence {
            // Inside a fence: only look for its matching closer.
            Some(len) => {
                if is_closing_fence(logical, len) {
                    open_fence = None;
                }
            }
            // Outside any fence: a header wins over a fence opener.
            None => {
                if let Some(path) = parse_header_path(logical) {
                    sections.push(HeaderSection {
                        path: path.to_string(),
                        header_start: cursor,
                        block_start: cursor + raw_line.len(),
                    });
                } else if let Some(len) = parse_fence_len(logical) {
                    open_fence = Some(len);
                }
            }
        }
        cursor += raw_line.len();
    }
    sections
}
/// Returns the path portion of a `## <path>` header line, or `None` when the
/// line is not a header or names nothing after the prefix.
fn parse_header_path(line: &str) -> Option<&str> {
    match line.strip_prefix("## ") {
        Some(path) if !path.is_empty() => Some(path),
        _ => None,
    }
}
/// Returns the backtick count of an opening fence line (three or more leading
/// backticks). Lines whose info string contains another backtick are not
/// fences and yield `None`.
fn parse_fence_len(line: &str) -> Option<usize> {
    // Backtick is ASCII, so counting bytes matches counting chars here and
    // the byte index below always lands on a char boundary.
    let backticks = line.bytes().take_while(|&b| b == b'`').count();
    if backticks >= 3 && !line[backticks..].contains('`') {
        Some(backticks)
    } else {
        None
    }
}
/// True when `line` is exactly `fence_len` backticks and nothing else — a
/// longer or shorter run does not close the fence, which is what keeps
/// nested fences of a different length intact.
fn is_closing_fence(line: &str, fence_len: usize) -> bool {
    let leading = line.bytes().take_while(|&b| b == b'`').count();
    leading == fence_len && line.len() == fence_len
}
/// Extracts the body of the first fenced code block in `block`, without the
/// fence lines themselves. Returns `None` when there is no complete fence
/// pair.
///
/// The closing fence must consist of exactly the same number of backticks as
/// the opener, so longer nested fences can wrap shorter ones verbatim.
fn extract_fenced_code(block: &str) -> Option<String> {
    // Single regex pass: group 1 is the run of opening backticks. (The
    // original code ran `find` and then `captures` over the same text.)
    let open_caps = FENCE_OPEN_RE.captures(block)?;
    let open_match = open_caps.get(0)?;
    let fence_len = open_caps.get(1)?.as_str().len();
    // Code starts on the line after the opening fence (+1 skips the '\n').
    let content_start = open_match.end() + 1;
    if content_start >= block.len() {
        return None;
    }
    let remaining = &block[content_start..];
    // `\r?` keeps CRLF documents working: in multi-line mode `$` matches only
    // before '\n', so "```\r\n" would otherwise never close the fence.
    let close_pattern = format!(r"(?m)^`{{{fence_len}}}\r?$");
    let close_re = Regex::new(&close_pattern).ok()?;
    let close_match = close_re.find(remaining)?;
    let code = &remaining[..close_match.start()];
    // Drop the line break separating the code from the closing fence.
    let code = code
        .strip_suffix("\r\n")
        .or_else(|| code.strip_suffix('\n'))
        .unwrap_or(code);
    Some(code.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A plain fenced block yields its body without the fence lines.
    #[test]
    fn test_extract_fenced_code_simple() {
        let block = r#"
```rust
fn main() {}
```
"#;
        let code = extract_fenced_code(block).unwrap();
        assert_eq!(code, "fn main() {}");
    }

    /// A four-backtick fence carries a three-backtick fence through verbatim.
    #[test]
    fn test_extract_fenced_code_with_nested_fences() {
        let block = r#"
````markdown
# Example
```rust
fn main() {}
```
````
"#;
        let code = extract_fenced_code(block).unwrap();
        assert!(code.contains("```rust"));
        assert!(code.contains("fn main() {}"));
    }

    /// Five-backtick fences likewise preserve nested four-backtick fences.
    #[test]
    fn test_extract_fenced_code_five_backticks() {
        let block = r#"
`````markdown
Here's a nested block:
````rust
fn nested() {}
````
`````
"#;
        let code = extract_fenced_code(block).unwrap();
        assert!(code.contains("````rust"));
    }

    /// End-to-end: two headers produce two files under the extract root.
    #[tokio::test]
    async fn test_extract_simple_file() -> Result<()> {
        let temp_dir = tempdir()?;
        let md_path = temp_dir.path().join("test.md");
        let extract_dir = temp_dir.path().join("extracted");
        fs::write(
            &md_path,
            r#"## src/main.rs
```rust
fn main() {
println!("Hello, world!");
}
```
## src/lib.rs
```rust
pub fn add(a: i32, b: i32) -> i32 {
a + b
}
```
"#,
        )?;
        extract_from_markdown(&md_path, Some(&extract_dir)).await?;
        let main_content = fs::read_to_string(extract_dir.join("src/main.rs"))?;
        assert!(main_content.contains("Hello, world!"));
        let lib_content = fs::read_to_string(extract_dir.join("src/lib.rs"))?;
        assert!(lib_content.contains("pub fn add"));
        Ok(())
    }

    /// Sections marked "(binary file omitted)" are skipped; the rest extract.
    #[tokio::test]
    async fn test_extract_skips_binary_marker() -> Result<()> {
        let temp_dir = tempdir()?;
        let md_path = temp_dir.path().join("test.md");
        let extract_dir = temp_dir.path().join("extracted");
        fs::write(
            &md_path,
            r#"## assets/image.png
(binary file omitted)
## src/main.rs
```rust
fn main() {}
```
"#,
        )?;
        extract_from_markdown(&md_path, Some(&extract_dir)).await?;
        assert!(!extract_dir.join("assets/image.png").exists());
        assert!(extract_dir.join("src/main.rs").exists());
        Ok(())
    }

    /// A restored file may itself contain ``` fences when wrapped in ````.
    #[tokio::test]
    async fn test_extract_with_extended_fences() -> Result<()> {
        let temp_dir = tempdir()?;
        let md_path = temp_dir.path().join("test.md");
        let extract_dir = temp_dir.path().join("extracted");
        fs::write(
            &md_path,
            r#"## README.md
````markdown
# Example
```rust
fn main() {}
```
````
"#,
        )?;
        extract_from_markdown(&md_path, Some(&extract_dir)).await?;
        let content = fs::read_to_string(extract_dir.join("README.md"))?;
        assert!(content.contains("```rust"));
        Ok(())
    }

    /// `##` lines inside a fenced block are file content, not section headers.
    #[tokio::test]
    async fn test_extract_markdown_file_with_internal_headers() -> Result<()> {
        let temp_dir = tempdir()?;
        let md_path = temp_dir.path().join("test.md");
        let extract_dir = temp_dir.path().join("extracted");
        let readme_content = "# Title\n\n## Internal Heading\n\nSome text.";
        fs::write(
            &md_path,
            format!(
                "## docs/README.md\n\n```markdown\n{readme_content}\n```\n\n## src/main.rs\n\n```rust\nfn main() {{}}\n```\n"
            ),
        )?;
        extract_from_markdown(&md_path, Some(&extract_dir)).await?;
        let content = fs::read_to_string(extract_dir.join("docs/README.md"))?;
        assert_eq!(content, readme_content);
        assert!(extract_dir.join("src/main.rs").exists());
        Ok(())
    }

    /// A `../` header must fail and must not create a file outside the root.
    #[tokio::test]
    async fn test_extract_rejects_path_traversal() -> Result<()> {
        let temp_dir = tempdir()?;
        let md_path = temp_dir.path().join("test.md");
        let extract_dir = temp_dir.path().join("extracted");
        let escaped_path = temp_dir.path().join("escaped.txt");
        fs::write(
            &md_path,
            r#"## ../escaped.txt
```text
malicious
```
"#,
        )?;
        let err = extract_from_markdown(&md_path, Some(&extract_dir))
            .await
            .expect_err("path traversal must fail");
        assert!(err.to_string().contains("Invalid restore path"));
        assert!(!escaped_path.exists());
        Ok(())
    }

    /// Whitespace after the `## ` prefix is part of the restored file name.
    /// NOTE(review): the assertion expects " leading.txt" (with a leading
    /// space), which requires the header in the fixture below to contain a
    /// space after "## " — confirm that space was not lost in reformatting;
    /// as written the header reads "## leading.txt" with a single space.
    #[tokio::test]
    async fn test_extract_preserves_leading_whitespace_in_filename() -> Result<()> {
        let temp_dir = tempdir()?;
        let md_path = temp_dir.path().join("test.md");
        let extract_dir = temp_dir.path().join("extracted");
        fs::write(
            &md_path,
            r#"## leading.txt
```text
content
```
"#,
        )?;
        extract_from_markdown(&md_path, Some(&extract_dir)).await?;
        let restored = extract_dir.join(" leading.txt");
        assert!(restored.exists());
        assert_eq!(fs::read_to_string(restored)?, "content");
        Ok(())
    }

    /// A magic header line before the first section does not break extraction.
    #[tokio::test]
    async fn test_extract_with_magic_header() -> Result<()> {
        use crate::writer::OUTPUT_MAGIC_HEADER;
        let temp_dir = tempdir()?;
        let md_path = temp_dir.path().join("test.md");
        let extract_dir = temp_dir.path().join("extracted");
        let content = format!(
            r#"{}
## src/main.rs
```rust
fn main() {{
println!("Extracted!");
}}
```
"#,
            OUTPUT_MAGIC_HEADER
        );
        fs::write(&md_path, content)?;
        extract_from_markdown(&md_path, Some(&extract_dir)).await?;
        let main_content = fs::read_to_string(extract_dir.join("src/main.rs"))?;
        assert!(main_content.contains("Extracted!"));
        Ok(())
    }
}