use crate::utils::get_language_tag;
use anyhow::{Context, Result};
use content_inspector::{ContentType, inspect};
use ignore::DirEntry;
use log::debug;
use memmap2::MmapOptions;
use std::fs::File as StdFile;
use std::path::Path;
use std::str;
use tokio::fs::File;
use tokio::io::{AsyncWriteExt, BufWriter};
/// Expands to the literal marker that identifies src2md output (no trailing newline).
///
/// Both public constants below are derived from this single literal so that the
/// string form and the byte form can never drift apart.
macro_rules! src2md_magic_marker {
    () => {
        "<!-- src2md:v1 -->"
    };
}

/// First line written to every generated Markdown document: the marker plus `\n`.
pub const OUTPUT_MAGIC_HEADER: &str = concat!(src2md_magic_marker!(), "\n");
/// The marker alone, as raw bytes and without the trailing newline, suitable for
/// prefix checks against file contents.
pub const OUTPUT_MAGIC_BYTES: &[u8] = src2md_magic_marker!().as_bytes();
pub struct MarkdownWriter<W: AsyncWriteExt + Unpin> {
writer: BufWriter<W>,
header_written: bool,
}
impl MarkdownWriter<tokio::fs::File> {
    /// Creates a writer around an already-buffered output file.
    ///
    /// Nothing is written yet: the magic header is emitted lazily by the first call
    /// to [`write_entry`](Self::write_entry).
    pub fn new(writer: BufWriter<File>) -> Self {
        Self {
            writer,
            header_written: false,
        }
    }

    /// Writes [`OUTPUT_MAGIC_HEADER`] followed by an empty line, at most once per writer.
    async fn ensure_header_written(&mut self) -> Result<()> {
        if self.header_written {
            return Ok(());
        }
        self.writer
            .write_all(OUTPUT_MAGIC_HEADER.as_bytes())
            .await
            .context("Failed to write output header")?;
        self.writer
            .write_all(b"\n")
            .await
            .context("Failed to write newline after header")?;
        self.header_written = true;
        Ok(())
    }

    /// Appends one file to the output as a `## <path>` heading followed by its
    /// contents in a fenced code block tagged with the language from `get_language_tag`.
    ///
    /// `<path>` is the entry's path relative to `project_root`, or the full path when
    /// it does not lie under `project_root`. Files that `content_inspector` classifies
    /// as binary (judged from their first 8 KiB) get a `(binary file omitted)` line
    /// instead of a code block. Files carrying a UTF-16 byte-order mark are transcoded;
    /// everything else is read as UTF-8, with invalid sequences replaced by U+FFFD
    /// rather than aborting the run.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened, stat'ed or memory-mapped, or if writing to
    /// the output fails. The heading may already have been written when an error is
    /// returned.
    pub async fn write_entry(&mut self, entry: &DirEntry, project_root: &Path) -> Result<()> {
        use std::borrow::Cow;

        self.ensure_header_written().await?;
        let path = entry.path();
        let rel_path = path.strip_prefix(project_root).unwrap_or(path);
        debug!("Processing: {}", rel_path.display());
        self.writer
            .write_all(format!("## {}\n\n", rel_path.display()).as_bytes())
            .await
            .with_context(|| format!("Failed to write heading for {}", rel_path.display()))?;

        // NOTE(review): open/stat/mmap below are blocking calls made directly on the
        // async executor; consider `tokio::task::spawn_blocking` if that matters.
        let file = StdFile::open(path)
            .with_context(|| format!("Failed to open file: {}", path.display()))?;
        let file_len = file
            .metadata()
            .with_context(|| format!("Failed to read metadata: {}", path.display()))?
            .len();
        // mmap(2) rejects zero-length mappings, so an empty file is never mapped; it is
        // simply treated as empty text below.
        let mmap = if file_len == 0 {
            None
        } else {
            // SAFETY: the mapping is read-only and dropped before this function returns.
            // Soundness additionally assumes no other process truncates or rewrites the
            // file while it is mapped; if that happens, reads may fault or observe
            // inconsistent bytes.
            let mapped = unsafe { MmapOptions::new().map(&file) }
                .with_context(|| format!("Failed to mmap file: {}", path.display()))?;
            Some(mapped)
        };
        let bytes: &[u8] = mmap.as_deref().unwrap_or_default();

        let sample_size = bytes.len().min(8192);
        let content_type = inspect(&bytes[..sample_size]);
        if content_type == ContentType::BINARY {
            self.writer
                .write_all(b"(binary file omitted)\n\n")
                .await
                .with_context(|| {
                    format!("Failed to write binary marker for {}", rel_path.display())
                })?;
            return Ok(());
        }
        let lang = get_language_tag(path);

        // Decode leniently: a stray invalid byte in an otherwise-text file must not abort
        // the whole conversion. (Re-reading with `fs::read_to_string` is no fallback: it
        // rejects exactly the bytes `str::from_utf8` rejects.) A byte-order mark, if any,
        // is kept and decoded as U+FEFF on both the UTF-8 and UTF-16 paths.
        let text: Cow<'_, str> = match content_type {
            ContentType::UTF_16LE => Cow::Owned(decode_utf16_lossy(bytes, u16::from_le_bytes)),
            ContentType::UTF_16BE => Cow::Owned(decode_utf16_lossy(bytes, u16::from_be_bytes)),
            _ => match str::from_utf8(bytes) {
                Ok(valid) => Cow::Borrowed(valid),
                Err(err) => {
                    debug!(
                        "{} is not valid UTF-8 ({}); replacing invalid sequences",
                        rel_path.display(),
                        err
                    );
                    String::from_utf8_lossy(bytes)
                }
            },
        };

        let fence = calculate_fence(&text);
        self.writer
            .write_all(format!("{}{}\n", fence, lang).as_bytes())
            .await
            .with_context(|| format!("Failed to write opening fence for {}", rel_path.display()))?;
        self.writer
            .write_all(text.as_bytes())
            .await
            .with_context(|| format!("Failed to write content for {}", rel_path.display()))?;
        self.writer
            .write_all(format!("\n{}\n\n", fence).as_bytes())
            .await
            .with_context(|| format!("Failed to write closing fence for {}", rel_path.display()))?;
        debug!("Completed: {}", rel_path.display());
        Ok(())
    }

    /// Flushes all buffered output to the underlying file.
    pub async fn flush(&mut self) -> Result<()> {
        self.writer.flush().await.context("Failed to flush output")
    }
}

/// Decodes UTF-16 `bytes`, with `to_unit` fixing the byte order of each code unit.
///
/// Unpaired surrogates, and a trailing odd byte that cannot form a code unit, are
/// replaced with U+FFFD instead of being reported as errors.
fn decode_utf16_lossy(bytes: &[u8], to_unit: fn([u8; 2]) -> u16) -> String {
    let pairs = bytes.chunks_exact(2);
    let has_dangling_byte = !pairs.remainder().is_empty();
    let mut decoded: String = char::decode_utf16(pairs.map(|pair| to_unit([pair[0], pair[1]])))
        .map(|unit| unit.unwrap_or(char::REPLACEMENT_CHARACTER))
        .collect();
    if has_dangling_byte {
        decoded.push(char::REPLACEMENT_CHARACTER);
    }
    decoded
}
/// Chooses a backtick fence long enough that no line of `content` can close it early.
///
/// CommonMark closes a backtick fence on a line that starts (after a little
/// indentation) with at least as many backticks as the opening fence. Every line is
/// therefore measured after stripping all leading whitespace, which is deliberately
/// more conservative than the spec. The fence is one backtick longer than the longest
/// such run, and never shorter than three backticks.
///
/// Lines are split on `\n`, `\r\n` *and* a lone `\r`, matching CommonMark's definition
/// of a line ending. `str::lines` does not split on a bare `\r` and would miss a fence
/// that follows one.
fn calculate_fence(content: &str) -> String {
    let longest_run = content
        .split(|c: char| c == '\n' || c == '\r')
        .map(|line| line.trim_start().bytes().take_while(|&b| b == b'`').count())
        .max()
        .unwrap_or(0);
    "`".repeat(longest_run.max(2) + 1)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_fence_no_backticks() {
        let content = "fn main() {\n println!(\"Hello\");\n}";
        assert_eq!(calculate_fence(content), "```");
    }

    #[test]
    fn test_calculate_fence_empty_content() {
        assert_eq!(calculate_fence(""), "```");
    }

    #[test]
    fn test_calculate_fence_with_triple_backticks() {
        let content = "# Example\n\n```rust\nfn main() {}\n```";
        assert_eq!(calculate_fence(content), "````");
    }

    #[test]
    fn test_calculate_fence_with_quad_backticks() {
        let content = "````\nsome code\n````";
        assert_eq!(calculate_fence(content), "`````");
    }

    #[test]
    fn test_calculate_fence_indented_backticks() {
        let content = " ```rust\n fn main() {}\n ```";
        assert_eq!(calculate_fence(content), "````");
    }

    #[test]
    fn test_calculate_fence_single_backticks() {
        let content = "`inline code` at start of line\nmore text";
        assert_eq!(calculate_fence(content), "```");
    }

    #[test]
    fn test_calculate_fence_double_backticks() {
        let content = "``double backtick`` code";
        assert_eq!(calculate_fence(content), "```");
    }

    #[test]
    fn test_magic_header_format() {
        assert!(OUTPUT_MAGIC_HEADER.starts_with("<!--"));
        assert!(OUTPUT_MAGIC_HEADER.contains("src2md"));
        assert!(OUTPUT_MAGIC_HEADER.contains("v1"));
    }

    #[test]
    fn test_magic_bytes_match_header() {
        // The byte form must be exactly the header line minus its trailing newline;
        // the two constants are spelled separately and could otherwise drift apart.
        assert_eq!(
            OUTPUT_MAGIC_HEADER.strip_suffix('\n').map(str::as_bytes),
            Some(OUTPUT_MAGIC_BYTES)
        );
    }
}