use std::fs;
use std::path::{Path, PathBuf};
use std::process;

/// Entry point: strips comments from every `.rs` file under the current directory.
///
/// All files are read and processed in memory first; only after every file has been
/// processed successfully are the changed ones written back. A read failure or a panic
/// inside `process` (it panics on an unclosed block comment) therefore aborts the run
/// without having modified any file. Files in which no comment was found are left
/// untouched.
///
/// The processed text is written as-is. `process` already trims trailing whitespace and
/// drops empty lines outside string literals while keeping the contents of multi-line
/// string literals intact; an extra purely line-based cleanup pass would strip blank
/// lines and trailing whitespace *inside* those literals and change program data.
fn main() {
    let current_dir = match std::env::current_dir() {
        Ok(d) => d,
        Err(e) => {
            eprintln!("无法获取当前目录: {e}");
            process::exit(1);
        }
    };

    let mut rs_files = Vec::new();
    find_rs_files(&current_dir, &mut rs_files);
    // `read_dir` yields entries in an unspecified, platform-dependent order; sort so the
    // processing order and the printed report are deterministic.
    rs_files.sort();

    // Phase 1: process everything in memory. Nothing is written to disk yet.
    let mut pending: Vec<(&PathBuf, String)> = Vec::new();
    let mut total_comments: u32 = 0;

    for file_path in &rs_files {
        let content = match fs::read_to_string(file_path) {
            Ok(c) => c,
            Err(e) => {
                eprintln!("无法读取文件 {}: {e}", file_path.display());
                process::exit(1);
            }
        };

        // `process` panics on malformed input; catch it so the offending file can be
        // named and we can exit before any file has been rewritten. The panic message
        // itself is still printed by the default panic hook.
        let (processed, comment_count) = match std::panic::catch_unwind(|| process(&content)) {
            Ok(result) => result,
            Err(_) => {
                eprintln!("处理文件 {} 失败,未修改任何文件", file_path.display());
                process::exit(1);
            }
        };

        if comment_count > 0 {
            pending.push((file_path, processed));
            total_comments += comment_count;
        }
    }

    // Phase 2: write back only the files that actually contained comments.
    for (file_path, processed) in &pending {
        if let Err(e) = fs::write(file_path, processed) {
            eprintln!("无法写入文件 {}: {e}", file_path.display());
            process::exit(1);
        }
    }

    let count = pending.len();
    if count > 0 {
        println!("处理了 {count} 个文件,删除了 {total_comments} 条注释:");
        for (path, _) in &pending {
            let relative = path.strip_prefix(&current_dir).unwrap_or(path.as_path());
            println!("  {}", relative.display());
        }
    } else {
        println!("处理了 0 个文件,删除了 0 条注释");
    }
}

/// Recursively collects every file with an `.rs` extension under `dir` into `files`.
///
/// Directory entries are classified with `DirEntry::file_type`, which describes the
/// entry itself and does not follow symbolic links. Symlinks are therefore never
/// descended into as directories: a symlink cycle cannot cause unbounded recursion, and
/// the tool cannot wander into (and later rewrite) files outside the tree. A symlink
/// whose own name ends in `.rs` is still collected.
///
/// Any error reading a directory or one of its entries is reported on stderr and
/// terminates the process with exit code 1.
fn find_rs_files(dir: &Path, files: &mut Vec<PathBuf>) {
    let entries = match fs::read_dir(dir) {
        Ok(e) => e,
        Err(e) => {
            eprintln!("无法读取目录 {}: {e}", dir.display());
            process::exit(1);
        }
    };

    for entry in entries {
        let entry = match entry {
            Ok(e) => e,
            Err(e) => {
                eprintln!("无法读取目录项: {e}");
                process::exit(1);
            }
        };
        // Unlike `Path::is_dir`, this does not traverse symlinks.
        let file_type = match entry.file_type() {
            Ok(t) => t,
            Err(e) => {
                eprintln!("无法读取目录项: {e}");
                process::exit(1);
            }
        };
        let path = entry.path();
        if file_type.is_dir() {
            find_rs_files(&path, files);
        } else if path.extension().map_or(false, |ext| ext == "rs") {
            files.push(path);
        }
    }
}

/// Lexer state of the character-by-character comment-stripping scanner in `process`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Ordinary source text outside any comment or literal.
    Code,
    /// Inside a `//` comment; it ends at the next newline.
    LineComment,
    /// Inside a `/* … */` comment; the payload is the current nesting depth.
    BlockComment(u32),
    /// Inside a `"…"` string literal.
    String_,
    /// Right after a backslash in a string literal; the next char is copied verbatim.
    StringBackslash,
    /// Inside a raw string literal; the payload is the number of `#`s in its delimiter.
    RawString(u32),
}

/// Removes every comment from the Rust source text `source`.
///
/// Comments removed:
/// - line comments (`//`, `///`, `//!`);
/// - block comments (`/* */`, `/** */`, `/*! */`), including nested ones.
///
/// Text copied verbatim, so comment-like content inside it survives:
/// - string literals and byte strings;
/// - raw strings (`r"…"`, `r#"…"#`, and the `br`/`cr`-prefixed forms);
/// - character literals.
///
/// A `'` that does not open a character literal (a lifetime or loop label) is copied as
/// ordinary code.
///
/// Output is rebuilt line by line through `flush_line`: trailing whitespace is trimmed and
/// lines that end up empty are dropped. This also removes lines that were already blank
/// in `source`. Newlines inside multi-line string literals are kept as they are.
///
/// When a block comment sits between two tokens on the same line (`let/* c */x`), it is
/// replaced by a single space so the tokens are not glued together.
///
/// Returns the stripped text and the number of comments removed; a nested block comment
/// counts as one.
///
/// # Panics
///
/// Panics if `source` ends inside an unclosed block comment.
fn process(source: &str) -> (String, u32) {
    let chars: Vec<char> = source.chars().collect();
    let mut output = String::with_capacity(source.len());
    let mut line_buf = String::new();
    let mut state = State::Code;
    let mut comment_count: u32 = 0;
    let mut i = 0;

    while i < chars.len() {
        let ch = chars[i];
        let next = chars.get(i + 1).copied();

        match state {
            State::Code => {
                if ch == '\n' {
                    flush_line(&mut output, &mut line_buf);
                    i += 1;
                } else if ch == '/' && next == Some('/') {
                    state = State::LineComment;
                    comment_count += 1;
                    i += 2;
                } else if ch == '/' && next == Some('*') {
                    state = State::BlockComment(1);
                    comment_count += 1;
                    i += 2;
                } else if ch == '"' {
                    line_buf.push(ch);
                    state = State::String_;
                    i += 1;
                } else if ch == '\'' {
                    // Copy a whole char literal at once so that e.g. `'"'` does not open a
                    // string; for a lifetime or label only the quote itself is copied.
                    let len = quote_token_len(&chars, i);
                    line_buf.extend(&chars[i..i + len]);
                    i += len;
                } else if let Some(hashes) = raw_string_opener(&chars, i) {
                    // Copy `r`, the `#`s and the opening `"` verbatim.
                    let len = 2 + hashes as usize;
                    line_buf.extend(&chars[i..i + len]);
                    state = State::RawString(hashes);
                    i += len;
                } else {
                    line_buf.push(ch);
                    i += 1;
                }
            }

            State::LineComment => {
                if ch == '\n' {
                    flush_line(&mut output, &mut line_buf);
                    state = State::Code;
                }
                i += 1;
            }

            State::BlockComment(depth) => {
                if ch == '/' && next == Some('*') {
                    state = State::BlockComment(depth + 1);
                    i += 2;
                } else if ch == '*' && next == Some('/') {
                    i += 2;
                    if depth == 1 {
                        state = State::Code;
                        // A comment separates tokens: `let/* c */x` must become `let x`,
                        // and `1/**/2` must not turn into `12`.
                        let before = line_buf.chars().next_back();
                        let after = chars.get(i).copied();
                        if matches!((before, after), (Some(b), Some(a)) if needs_separator(b, a))
                        {
                            line_buf.push(' ');
                        }
                    } else {
                        state = State::BlockComment(depth - 1);
                    }
                } else if ch == '\n' {
                    flush_line(&mut output, &mut line_buf);
                    i += 1;
                } else {
                    i += 1;
                }
            }

            State::String_ => {
                line_buf.push(ch);
                if ch == '\\' {
                    state = State::StringBackslash;
                } else if ch == '"' {
                    state = State::Code;
                }
                i += 1;
            }

            State::StringBackslash => {
                line_buf.push(ch);
                state = State::String_;
                i += 1;
            }

            State::RawString(hashes) => {
                let h = hashes as usize;
                // The literal closes at a `"` followed by exactly as many `#`s as opened it.
                let closes = ch == '"'
                    && matches!(chars.get(i + 1..i + 1 + h), Some(run) if run.iter().all(|&c| c == '#'));
                if closes {
                    line_buf.extend(&chars[i..=i + h]);
                    state = State::Code;
                    i += 1 + h;
                } else {
                    line_buf.push(ch);
                    i += 1;
                }
            }
        }
    }

    flush_line(&mut output, &mut line_buf);

    if let State::BlockComment(_) = state {
        panic!("错误:文件中有未闭合的块注释(/* 缺少匹配的 */)");
    }

    (output, comment_count)
}

/// Returns `true` for characters that may continue a Rust identifier.
fn is_ident_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_'
}

/// If `chars[i]` is the `r` that opens a raw string (`r"`, `r#"`, …, optionally after a
/// `b` or `c` literal prefix), returns the number of `#`s in the opening delimiter.
fn raw_string_opener(chars: &[char], i: usize) -> Option<u32> {
    if chars[i] != 'r' {
        return None;
    }
    // The `r` must start a token, or follow a `b`/`c` prefix that itself starts a token;
    // otherwise it is the tail of an ordinary identifier.
    let starts_token = match i.checked_sub(1).map(|p| chars[p]) {
        None => true,
        Some('b' | 'c') => i < 2 || !is_ident_char(chars[i - 2]),
        Some(prev) => !is_ident_char(prev),
    };
    if !starts_token {
        return None;
    }

    let mut hashes: u32 = 0;
    let mut j = i + 1;
    while chars.get(j) == Some(&'#') {
        hashes += 1;
        j += 1;
    }
    if chars.get(j) == Some(&'"') {
        Some(hashes)
    } else {
        None
    }
}

/// Given `chars[start] == '\''`, returns how many characters belong to the token that
/// starts there:
/// - the full length of a character literal (`'x'`, `'\n'`, `'\''`, `'\u{1F600}'`);
/// - 1 when the quote introduces a lifetime or label (`'a`, `'outer:`).
fn quote_token_len(chars: &[char], start: usize) -> usize {
    match chars.get(start + 1).copied() {
        Some('\\') => {
            // The char right after the backslash always belongs to the escape (this
            // covers `'\''`); then scan to the closing quote on the same line.
            let mut end = start + 3;
            while end < chars.len() && chars[end] != '\'' && chars[end] != '\n' {
                end += 1;
            }
            if chars.get(end) == Some(&'\'') {
                end + 1 - start
            } else {
                1
            }
        }
        Some(c) if c != '\n' && chars.get(start + 2) == Some(&'\'') => 3,
        _ => 1,
    }
}

/// Whether a block comment removed from between `before` and `after` must be replaced
/// by a space to keep the neighbouring tokens apart.
fn needs_separator(before: char, after: char) -> bool {
    const DELIMITERS: [char; 8] = ['(', ')', '[', ']', '{', '}', ',', ';'];
    !before.is_whitespace()
        && !after.is_whitespace()
        && !DELIMITERS.contains(&before)
        && !DELIMITERS.contains(&after)
}

#[cfg(test)]
mod process_regression_tests {
    use super::process;

    #[test]
    fn quote_char_literal_does_not_open_string() {
        let (out, n) = process("let q = '\"'; // c\nlet s = \"//x\";\n");
        assert_eq!(out, "let q = '\"';\nlet s = \"//x\";\n");
        assert_eq!(n, 1);
    }

    #[test]
    fn lifetimes_are_plain_code() {
        let (out, n) = process("fn f<'a>(x: &'a str) -> &'a str { x } // c\n");
        assert_eq!(out, "fn f<'a>(x: &'a str) -> &'a str { x }\n");
        assert_eq!(n, 1);
    }

    #[test]
    fn byte_raw_string_is_raw() {
        let (out, n) = process("let b = br\"\\\"; // c\nlet s = \"//x\";\n");
        assert_eq!(out, "let b = br\"\\\";\nlet s = \"//x\";\n");
        assert_eq!(n, 1);
    }

    #[test]
    fn block_comment_keeps_tokens_apart() {
        let (out, n) = process("let/* c */x = 1/**/;\n");
        assert_eq!(out, "let x = 1;\n");
        assert_eq!(n, 2);
    }
}

/// Moves the buffered line into `output` and empties `line_buf`.
///
/// Trailing whitespace is removed and a `'\n'` is appended. If nothing but whitespace
/// was buffered, nothing is appended to `output`; `line_buf` is cleared either way.
fn flush_line(output: &mut String, line_buf: &mut String) {
    let kept_len = line_buf.trim_end().len();
    if kept_len > 0 {
        line_buf.truncate(kept_len);
        line_buf.push('\n');
        output.push_str(line_buf.as_str());
    }
    line_buf.clear();
}

/// Drops blank lines and trailing whitespace from `input`, one text line at a time.
///
/// Every surviving line is terminated with `'\n'`. This is a purely line-based pass:
/// it knows nothing about string literals. Blank lines or trailing whitespace inside a
/// multi-line string literal are therefore removed as well.
fn clean_output(input: &str) -> String {
    input
        .lines()
        .map(str::trim_end)
        .filter(|line| !line.is_empty())
        .fold(String::with_capacity(input.len()), |mut acc, line| {
            acc.push_str(line);
            acc.push('\n');
            acc
        })
}

/// Unit tests for `process` and `clean_output`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Strips `input` and asserts on both the resulting text and the comment count.
    fn assert_stripped(input: &str, expected_output: &str, expected_count: u32) {
        let (output, count) = process(input);
        assert_eq!(
            output, expected_output,
            "output mismatch\ninput: {input:?}\nexpected: {expected_output:?}\ngot: {output:?}"
        );
        assert_eq!(count, expected_count, "comment count mismatch");
    }

    #[test]
    fn test_basic() {
        let source = "fn main() {\n    // 这是注释\n    let x = 1; // 行内注释\n    /* 块注释 */\n    let y = 2; /* 行内块 */\n}\n";
        let stripped = "fn main() {\n    let x = 1;\n    let y = 2;\n}\n";
        assert_stripped(source, stripped, 4);
    }

    #[test]
    fn test_nested_block_comment() {
        let source = "/* 外层 /* 内层 */ 外层继续 */\nlet x = 1;\n";
        assert_stripped(source, "let x = 1;\n", 1);
    }

    #[test]
    fn test_fake_comments_in_strings() {
        // Every comment marker here sits inside a literal, so the text must come back unchanged.
        let source = concat!(
            "let a = \"// 不是注释\";\n",
            "let b = \"/* 也不是 */\";\n",
            "let c = r\"// 还不是\";\n",
            "let d = r#\"/* 依旧不是 */\"#;\n",
        );
        assert_stripped(source, source, 0);
    }

    #[test]
    fn test_doc_comments() {
        let source = "/// 外部文档注释\n//! 内部文档注释\n/** 块文档注释 */\nfn foo() {}\n";
        assert_stripped(source, "fn foo() {}\n", 3);
    }

    #[test]
    fn test_url_not_deleted() {
        let source = concat!(
            "let url = \"https://example.com\";\n",
            "let path = \"//path/to/resource\";\n",
        );
        assert_stripped(source, source, 0);
    }

    #[test]
    #[should_panic(expected = "未闭合")]
    fn test_unclosed_block_comment() {
        let _ = process("/* 未闭合的注释\nlet x = 1;\n");
    }

    #[test]
    fn test_clean_output() {
        let cases = [
            ("a\n\n\nb\n", "a\nb\n"),
            ("a\nb\n", "a\nb\n"),
            ("\n\n\na\n", "a\n"),
            ("a\n\n\n", "a\n"),
        ];
        for (input, expected) in cases {
            assert_eq!(clean_output(input), expected);
        }
    }

    #[test]
    fn test_inline_comment_trailing_space() {
        let (output, count) = process("    let x = 1; // comment\n    let y = 2;\n");
        assert_eq!(count, 1);
        assert_eq!(clean_output(&output), "    let x = 1;\n    let y = 2;\n");
    }

    #[test]
    fn test_multiline_raw_string() {
        let source = "let s = r#\"hello\nworld\"#;\n";
        let (output, count) = process(source);
        assert_eq!(count, 0);
        assert_eq!(output, source);
    }
}