use canon_core::{Chunk, CPError, Result, text};
use std::path::Path;
use uuid::Uuid;
/// A file parser that turns a document on disk into a single text string.
///
/// Implementations are selected by file extension: `ParserRegistry::find_parser`
/// compares the requested extension against `supported_extensions()` using an
/// ASCII case-insensitive match.
pub trait Parser: Send + Sync {
/// Reads the file at `path` and returns its extracted text.
///
/// Returns an error if the file cannot be read or converted.
fn parse(&self, path: &Path) -> Result<String>;
/// Returns the file extensions this parser handles, written without a
/// leading dot (they are matched against the result of `Path::extension`).
fn supported_extensions(&self) -> &[&str];
}
/// Collection of the parsers that `parse_file` can dispatch to.
struct ParserRegistry {
    /// Registered parsers, consulted in insertion order.
    parsers: Vec<Box<dyn Parser>>,
}

impl ParserRegistry {
    /// Builds a registry holding the Markdown parser followed by the
    /// plain-text parser.
    fn new() -> Self {
        let parsers: Vec<Box<dyn Parser>> =
            vec![Box::new(MarkdownParser), Box::new(TextParser)];
        Self { parsers }
    }

    /// Returns the first registered parser that lists `extension` among its
    /// supported extensions (ASCII case-insensitive), or `None` when no
    /// parser claims it.
    fn find_parser(&self, extension: &str) -> Option<&dyn Parser> {
        self.parsers
            .iter()
            .find(|candidate| {
                candidate
                    .supported_extensions()
                    .iter()
                    .any(|known| known.eq_ignore_ascii_case(extension))
            })
            .map(|candidate| candidate.as_ref())
    }
}
/// Parses the file at `path` with the parser registered for its extension.
///
/// # Errors
///
/// Returns `CPError::Parse` when the path has no extension (or one that is
/// not valid UTF-8) or when no registered parser handles the extension.
/// Errors produced by the selected parser are passed through unchanged.
pub fn parse_file(path: &Path) -> Result<String> {
    let ext = path
        .extension()
        .and_then(|raw| raw.to_str())
        .ok_or_else(|| CPError::Parse("No file extension".into()))?;

    // The registry is only needed once the extension is known.
    let registry = ParserRegistry::new();
    registry
        .find_parser(ext)
        .ok_or_else(|| CPError::Parse(format!("No parser for extension: {}", ext)))?
        .parse(path)
}
struct MarkdownParser;
impl Parser for MarkdownParser {
fn parse(&self, path: &Path) -> Result<String> {
let content = std::fs::read_to_string(path)?;
let parser = pulldown_cmark::Parser::new(&content);
let mut out = String::new();
for event in parser {
match event {
pulldown_cmark::Event::Start(pulldown_cmark::Tag::Heading { level, .. }) => {
out.push('\n');
let level_str = match level {
pulldown_cmark::HeadingLevel::H1 => "# ",
pulldown_cmark::HeadingLevel::H2 => "## ",
pulldown_cmark::HeadingLevel::H3 => "### ",
pulldown_cmark::HeadingLevel::H4 => "#### ",
pulldown_cmark::HeadingLevel::H5 => "##### ",
pulldown_cmark::HeadingLevel::H6 => "###### ",
};
out.push_str(level_str);
}
pulldown_cmark::Event::End(pulldown_cmark::TagEnd::Heading(_)) => {
out.push('\n');
}
pulldown_cmark::Event::Text(t)
| pulldown_cmark::Event::Code(t) => {
out.push_str(&t);
}
pulldown_cmark::Event::SoftBreak
| pulldown_cmark::Event::HardBreak => {
out.push('\n');
}
pulldown_cmark::Event::End(pulldown_cmark::TagEnd::Paragraph) => {
out.push_str("\n\n");
}
_ => {}
}
}
Ok(text::normalize(&out))
}
fn supported_extensions(&self) -> &[&str] {
&["md", "markdown"]
}
}
/// Parser for plain-text files (`txt` / `text` extensions).
struct TextParser;

impl Parser for TextParser {
    /// Reads the whole file at `path` and returns it after `text::normalize`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read as UTF-8 text.
    fn parse(&self, path: &Path) -> Result<String> {
        let raw = std::fs::read_to_string(path)?;
        let normalized = text::normalize(&raw);
        Ok(normalized)
    }

    /// Extensions handled by this parser, without a leading dot.
    fn supported_extensions(&self) -> &[&str] {
        const EXTENSIONS: &[&str] = &["txt", "text"];
        EXTENSIONS
    }
}
/// Tuning knobs for `Chunker`.
///
/// Both values are counted in `char`s, because the chunker indexes into the
/// text's character sequence rather than its bytes.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Target size of a chunk: the window, measured from the chunk start,
    /// in which the chunker looks for a place to cut.
    pub chunk_size: usize,
    /// How many characters the next chunk steps back from the end of the
    /// previous one, when the previous chunk is longer than this value.
    pub overlap: usize,
}

impl Default for ChunkConfig {
    /// Returns a configuration with a 1000-character window and a
    /// 200-character overlap.
    fn default() -> Self {
        let chunk_size = 1000;
        let overlap = 200;
        Self { chunk_size, overlap }
    }
}
/// Splits text into overlapping chunks according to a `ChunkConfig`.
pub struct Chunker {
    /// Window size and overlap used when chunking.
    config: ChunkConfig,
}

impl Default for Chunker {
    /// Creates a chunker that uses `ChunkConfig::default()`.
    fn default() -> Self {
        let config = ChunkConfig::default();
        Self::new(config)
    }
}
impl Chunker {
    /// Creates a chunker that uses the given configuration.
    pub fn new(config: ChunkConfig) -> Self {
        Self { config }
    }

    /// Splits `text` into chunks belonging to document `doc_id`.
    ///
    /// The text is processed as a sequence of `char`s. Starting at offset 0,
    /// each iteration looks at a window of up to `chunk_size` characters,
    /// cuts it at the position chosen by `find_break_point`, trims the piece
    /// and, if anything is left, records it as a `Chunk` with a sequence
    /// number that increases by one per emitted chunk. The next window then
    /// starts `overlap` characters before the cut when the piece was longer
    /// than `overlap`, otherwise at the cut itself.
    ///
    /// The offset passed to `Chunk::new` is the character index (not byte
    /// index) where the untrimmed piece starts.
    ///
    /// A `chunk_size` of 0 is treated as 1 so that the loop always advances;
    /// previously it would never terminate on non-empty input.
    ///
    /// Returns an empty vector for empty input.
    pub fn chunk(&self, doc_id: Uuid, text: &str) -> Result<Vec<Chunk>> {
        let mut chunks = Vec::new();
        let chars: Vec<char> = text.chars().collect();
        let total_len = chars.len();
        if total_len == 0 {
            return Ok(chunks);
        }

        // With a zero-sized window the break point equals `offset`, which
        // would leave `offset` unchanged forever. Clamp to guarantee progress.
        let chunk_size = self.config.chunk_size.max(1);
        let overlap = self.config.overlap;

        let mut offset = 0usize;
        let mut seq = 0u32;
        while offset < total_len {
            // Saturating: very large configured sizes must not overflow.
            let end = offset.saturating_add(chunk_size).min(total_len);
            let chunk_end = self.find_break_point(&chars, offset, end, total_len);

            let chunk_text: String = chars[offset..chunk_end].iter().collect();
            let chunk_text = chunk_text.trim().to_string();
            if !chunk_text.is_empty() {
                chunks.push(Chunk::new(doc_id, chunk_text, offset as u64, seq));
                seq += 1;
            }

            if chunk_end >= total_len {
                break;
            }

            // Step back by `overlap` only when that still moves forward;
            // otherwise continue right after the cut.
            offset = if chunk_end > offset.saturating_add(overlap) {
                chunk_end - overlap
            } else {
                chunk_end
            };
        }
        Ok(chunks)
    }

    /// Chooses where to end a chunk that starts at `start` and should end no
    /// later than about `target_end` (both are indices into `chars`).
    ///
    /// If the window already reaches the end of the text, `total_len` is
    /// returned. Otherwise the window `start..target_end` is scanned from the
    /// back, and the first rule that matches anywhere in the window wins:
    ///
    /// 1. a newline followed by `#` — cut right after the newline;
    /// 2. two consecutive newlines — cut after both;
    /// 3. `.`, `!` or `?` followed by whitespace — cut after the punctuation;
    /// 4. any whitespace — cut after it;
    /// 5. no match — cut at `target_end`.
    ///
    /// The result is always greater than `start` when `target_end > start`.
    fn find_break_point(
        &self,
        chars: &[char],
        start: usize,
        target_end: usize,
        total_len: usize,
    ) -> usize {
        if target_end >= total_len {
            return total_len;
        }

        // Candidate cut positions, latest first.
        let window = || (start..target_end).rev();
        // Character following position `i`, if any.
        let next = |i: usize| chars.get(i + 1).copied();

        window()
            .find(|&i| chars[i] == '\n' && next(i) == Some('#'))
            .map(|i| i + 1)
            .or_else(|| {
                window()
                    .find(|&i| chars[i] == '\n' && next(i) == Some('\n'))
                    .map(|i| i + 2)
            })
            .or_else(|| {
                window()
                    .find(|&i| {
                        matches!(chars[i], '.' | '!' | '?')
                            && next(i).map_or(false, char::is_whitespace)
                    })
                    .map(|i| i + 1)
            })
            .or_else(|| window().find(|&i| chars[i].is_whitespace()).map(|i| i + 1))
            .unwrap_or(target_end)
    }
}