model-rs 0.1.0

A Rust CLI tool for downloading HuggingFace models and running local LLM inference
/// Position of the stream parser relative to fenced code blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Outside any fenced block; buffered input is forwarded as plain text.
    Text,
    /// Directly after an opening triple-backtick fence; the language tag is
    /// being collected until the next newline.
    Lang,
    /// Inside a fenced block; input is collected as code until the closing
    /// triple-backtick fence.
    Code,
}

/// One step in the life of a fenced code block, passed to the code callback.
///
/// The string slices are borrowed; copy them if they are needed after the
/// callback returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodeStreamEvent<'a> {
    /// A fenced block was opened. `language` is the whitespace-trimmed text
    /// that followed the opening fence on the same line (may be empty).
    Start { language: &'a str },
    /// A piece of code belonging to the currently open block, tagged with the
    /// block's `language`.
    Chunk { language: &'a str, code: &'a str },
    /// The currently open block ended.
    End,
}

/// Incremental splitter that separates streamed Markdown into plain text and
/// fenced code blocks.
#[derive(Debug, Clone)]
pub struct MarkdownStreamRenderer {
    /// Which part of the document the parser is currently in.
    state: StreamState,
    /// Input received but not yet classified (e.g. a possibly incomplete fence).
    buffer: String,
    /// Language tag of the block being opened or currently open (untrimmed).
    lang: String,
    /// Code of the open block that has not been handed out yet.
    code: String,
}

impl MarkdownStreamRenderer {
    /// Marker that opens and closes a fenced code block (three backticks).
    const FENCE: &'static str = "```";

    /// Creates a renderer in plain-text mode with empty buffers.
    pub fn new() -> Self {
        Self {
            state: StreamState::Text,
            buffer: String::new(),
            lang: String::new(),
            code: String::new(),
        }
    }

    /// Feeds the next piece of streamed Markdown into the renderer.
    ///
    /// Everything that can already be classified is reported immediately:
    /// plain text through `on_text`, code-block events through `on_code`.
    /// Up to two trailing backticks are held back because they may be the
    /// beginning of a fence that is completed by a later chunk.
    pub fn push_with<FText, FCode>(&mut self, chunk: &str, mut on_text: FText, mut on_code: FCode)
    where
        FText: FnMut(&str),
        FCode: FnMut(CodeStreamEvent<'_>),
    {
        self.buffer.push_str(chunk);
        self.drain(&mut on_text, &mut on_code);
    }

    /// Signals the end of the stream and flushes whatever is still buffered.
    ///
    /// * In plain text, the held-back remainder is emitted via `on_text`.
    /// * After an opening fence whose language line never ended, the collected
    ///   language text is emitted via `on_text` (the fence marker itself is not
    ///   re-emitted).
    /// * Inside an unterminated code block, the pending partial line is emitted
    ///   as a final `Chunk` (if non-empty), followed by exactly one `End`, so
    ///   every `Start` is matched by one `End`.
    ///
    /// Afterwards the renderer is back in plain-text mode and can be reused.
    pub fn finish_with<FText, FCode>(&mut self, mut on_text: FText, mut on_code: FCode)
    where
        FText: FnMut(&str),
        FCode: FnMut(CodeStreamEvent<'_>),
    {
        self.drain(&mut on_text, &mut on_code);

        match self.state {
            StreamState::Text => {
                if !self.buffer.is_empty() {
                    on_text(&self.buffer);
                    self.buffer.clear();
                }
            }
            StreamState::Lang => {
                if !self.lang.is_empty() {
                    on_text(&self.lang);
                    self.lang.clear();
                }
                self.state = StreamState::Text;
            }
            StreamState::Code => {
                // Backticks held back as a potential fence turned out to be code.
                self.code.push_str(&self.buffer);
                self.buffer.clear();
                // `Start` was already emitted when the block was entered, so only
                // the remaining code and the terminating `End` are reported here.
                if !self.code.is_empty() {
                    on_code(CodeStreamEvent::Chunk {
                        language: self.lang.trim(),
                        code: &self.code,
                    });
                }
                on_code(CodeStreamEvent::End);
                self.lang.clear();
                self.code.clear();
                self.state = StreamState::Text;
            }
        }
    }

    /// Processes `buffer` as far as possible, switching states at fences and
    /// invoking the callbacks; stops when more input is required.
    fn drain<FText, FCode>(&mut self, on_text: &mut FText, on_code: &mut FCode)
    where
        FText: FnMut(&str),
        FCode: FnMut(CodeStreamEvent<'_>),
    {
        loop {
            match self.state {
                StreamState::Text => {
                    if let Some(pos) = self.buffer.find(Self::FENCE) {
                        if pos > 0 {
                            on_text(&self.buffer[..pos]);
                        }
                        self.buffer.drain(..pos + Self::FENCE.len());
                        self.lang.clear();
                        self.code.clear();
                        self.state = StreamState::Lang;
                        continue;
                    }

                    // No complete fence: emit everything except trailing
                    // backticks that might still grow into one.
                    let keep = self.trailing_backticks_to_keep();
                    let emit_len = self.buffer.len().saturating_sub(keep);
                    if emit_len > 0 {
                        on_text(&self.buffer[..emit_len]);
                        self.buffer.drain(..emit_len);
                    }
                    break;
                }
                StreamState::Lang => {
                    if let Some(nl) = self.buffer.find('\n') {
                        self.lang.push_str(&self.buffer[..nl]);
                        self.buffer.drain(..nl + 1);
                        on_code(CodeStreamEvent::Start {
                            language: self.lang.trim(),
                        });
                        self.state = StreamState::Code;
                        continue;
                    }

                    // Language line not finished yet: remember what we have.
                    self.lang.push_str(&self.buffer);
                    self.buffer.clear();
                    break;
                }
                StreamState::Code => {
                    if let Some(pos) = self.buffer.find(Self::FENCE) {
                        self.code.push_str(&self.buffer[..pos]);
                        self.buffer.drain(..pos + Self::FENCE.len());
                        // Swallow the newline that ends the closing fence line,
                        // but only if it is already buffered.
                        if self.buffer.starts_with('\n') {
                            self.buffer.drain(..1);
                        }

                        Self::emit_complete_lines(
                            &mut self.code,
                            self.lang.trim(),
                            &mut *on_code,
                        );
                        // Code between the last newline and the closing fence.
                        if !self.code.is_empty() {
                            on_code(CodeStreamEvent::Chunk {
                                language: self.lang.trim(),
                                code: &self.code,
                            });
                        }
                        on_code(CodeStreamEvent::End);
                        self.lang.clear();
                        self.code.clear();
                        self.state = StreamState::Text;
                        continue;
                    }

                    // No closing fence yet: move everything except potential
                    // fence backticks into `code` and hand out complete lines.
                    let keep = self.trailing_backticks_to_keep();
                    let emit_len = self.buffer.len().saturating_sub(keep);
                    if emit_len > 0 {
                        self.code.push_str(&self.buffer[..emit_len]);
                        self.buffer.drain(..emit_len);
                        Self::emit_complete_lines(
                            &mut self.code,
                            self.lang.trim(),
                            &mut *on_code,
                        );
                    }
                    break;
                }
            }
        }
    }

    /// Emits every newline-terminated line at the front of `code` as its own
    /// `Chunk` (newline included) and removes those lines from `code`; a
    /// trailing partial line is left in place.
    fn emit_complete_lines<FCode>(code: &mut String, lang: &str, on_code: &mut FCode)
    where
        FCode: FnMut(CodeStreamEvent<'_>),
    {
        if let Some(last_nl) = code.rfind('\n') {
            let complete = last_nl + 1;
            for line in code[..complete].split_inclusive('\n') {
                on_code(CodeStreamEvent::Chunk {
                    language: lang,
                    code: line,
                });
            }
            // Single drain instead of one per line keeps this linear.
            code.drain(..complete);
        }
    }

    /// Returns how many backticks at the end of `buffer` (at most one fewer
    /// than the fence length) must be held back because they could be the
    /// start of a fence.
    fn trailing_backticks_to_keep(&self) -> usize {
        self.buffer
            .bytes()
            .rev()
            .take(Self::FENCE.len() - 1)
            .take_while(|&b| b == b'`')
            .count()
    }
}

impl Default for MarkdownStreamRenderer {
    fn default() -> Self {
        Self::new()
    }
}