harmony_protocol/
encoding.rs

1use crate::{
2    chat::{Author, Content, Message, ReasoningEffort, Role, SystemContent, TextContent},
3    tiktoken::{CoreBPE, Rank},
4};
5use anyhow::Context as _;
6use std::{
7    collections::{HashMap, HashSet},
8    sync::Arc,
9};
10
/// Header fields of one message, parsed from the tokens that arrive between
/// the `Start` and `Message` formatting tokens.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ParsedHeader {
    /// Who wrote the message (role plus optional name).
    author: Author,
    /// Recipient of the message, if any (rendered as ` to=...` in the header).
    recipient: Option<String>,
    /// Channel named after the channel marker, if any.
    channel: Option<String>,
    /// Content type of the message body, if any.
    content_type: Option<String>,
}
18
/// Errors produced while turning a [`FormattingToken`] into a single tokenizer rank.
#[derive(thiserror::Error, Debug)]
pub(crate) enum RenderFormattingTokenError {
    /// The encoding's `format_token_mapping` has no entry for this token.
    #[error("tried to render unmapped formatting token {0}")]
    UnmappedToken(FormattingToken),

    /// The mapped string did not encode to exactly one token.
    #[error(
        "Expected encoding of formatting token {token} to be a single token, but got {encoding:?}"
    )]
    InvalidEncoding {
        /// The formatting token that was being rendered.
        token: FormattingToken,
        /// The ranks the mapped string actually encoded to.
        encoding: Vec<Rank>,
    },
}
32
/// Special formatting tokens of the message format.
///
/// Rendering and parsing look these up through
/// `HarmonyEncoding::format_token_mapping`; the `Display` impl only shows a
/// canonical spelling.
// `dead_code` is allowed because not every variant is necessarily constructed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) enum FormattingToken {
    /// Opens a message; the author follows it.
    Start,
    /// Separates the message header from the message content.
    Message,
    /// Closes a message.
    EndMessage,
    /// Closes a message and marks sampling as done (substituted for the last
    /// token of a trailing assistant `final` message in training renders).
    EndMessageDoneSampling,
    /// Closes an assistant message that has a recipient.
    EndMessageAssistantToTool,
    /// Displayed as `<|refusal|>`.
    Refusal,
    /// Marks a constrained content type in the header.
    ConstrainedFormat,
    /// Precedes the channel name in the header.
    Channel,
    /// Displayed as `<|untrusted|>`.
    BeginUntrusted,
    /// Displayed as `<|end_untrusted|>`.
    EndUntrusted,
    /// NOTE(review): displayed with the same spelling as `Channel` — confirm intent.
    MetaSep,
    /// Displayed as `<|meta_end|>`.
    MetaEnd,
}
49
50impl FormattingToken {
51    fn as_str(&self) -> &str {
52        match self {
53            FormattingToken::Start => "<|start|>",
54            FormattingToken::Message => "<|message|>",
55            FormattingToken::EndMessage => "<|end|>",
56            FormattingToken::EndMessageDoneSampling => "<|return|>",
57            FormattingToken::EndMessageAssistantToTool => "<|call|>",
58            FormattingToken::Refusal => "<|refusal|>",
59            FormattingToken::ConstrainedFormat => "<|constrain|>",
60            FormattingToken::Channel => "<|channel|>",
61            FormattingToken::BeginUntrusted => "<|untrusted|>",
62            FormattingToken::EndUntrusted => "<|end_untrusted|>",
63            FormattingToken::MetaSep => "<|channel|>",
64            FormattingToken::MetaEnd => "<|meta_end|>",
65        }
66    }
67}
68
69impl std::fmt::Display for FormattingToken {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(f, "{}", self.as_str())
72    }
73}
74
/// A concrete message format: tokenizer, formatting-token table, stop tokens
/// and size limits. Used both to render messages into tokens and to parse
/// tokens back into messages.
// `dead_code` is allowed because not every field is necessarily read.
#[allow(dead_code)]
#[derive(Clone)]
pub struct HarmonyEncoding {
    /// Name of the encoding (shown by `Display` and `Debug`).
    pub(crate) name: String,
    /// NOTE(review): presumably the context window size in tokens — confirm.
    pub(crate) n_ctx: usize,
    /// Configured per-message token limit (exposed via `max_message_tokens()`).
    pub(crate) max_message_tokens: usize,
    /// NOTE(review): presumably a length limit for assistant actions — confirm units.
    pub(crate) max_action_length: usize,
    /// Name of the tokenizer backing this encoding.
    pub(crate) tokenizer_name: String,
    /// Shared BPE tokenizer used for all encoding and decoding.
    pub(crate) tokenizer: Arc<CoreBPE>,
    /// Maps each formatting token to the string that must encode (with special
    /// tokens enabled) to exactly one rank.
    pub(crate) format_token_mapping: HashMap<FormattingToken, String>,
    /// Formatting tokens that terminate a message while parsing.
    pub(crate) stop_formatting_tokens: HashSet<FormattingToken>,
    /// Alternative stop set exposed via `stop_tokens_for_assistant_actions()`.
    pub(crate) stop_formatting_tokens_for_assistant_actions: HashSet<FormattingToken>,
}
88
89impl std::fmt::Debug for HarmonyEncoding {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        f.debug_struct("HarmonyEncoding")
92            .field("name", &self.name)
93            .field("tokenizer_name", &self.tokenizer_name)
94            .field("n_ctx", &self.n_ctx)
95            .field("max_message_tokens", &self.max_message_tokens)
96            .field("max_action_length", &self.max_action_length)
97            .finish()
98    }
99}
100
101impl std::fmt::Display for HarmonyEncoding {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        write!(f, "Renderer({})", self.name)
104    }
105}
106
107impl HarmonyEncoding {
108    pub fn name(&self) -> &str {
109        &self.name
110    }
111
112    pub fn tokenizer_name(&self) -> &str {
113        &self.tokenizer_name
114    }
115
116    pub fn max_message_tokens(&self) -> usize {
117        self.max_message_tokens
118    }
119
120    pub fn tokenizer(&self) -> &CoreBPE {
121        &self.tokenizer
122    }
123
124    pub fn stop_tokens(&self) -> anyhow::Result<HashSet<Rank>> {
125        self.stop_formatting_tokens
126            .iter()
127            .copied()
128            .map(|t| match self.render_formatting_token(t) {
129                Ok(t) => Ok(t),
130                Err(RenderFormattingTokenError::UnmappedToken(_)) => Err(anyhow::anyhow!(
131                    "token {t} was specified as a stop token, but is not mapped"
132                )),
133                Err(e) => Err(anyhow::anyhow!(e).context("could not render stop token")),
134            })
135            .collect()
136    }
137
138    pub fn stop_tokens_for_assistant_actions(&self) -> anyhow::Result<HashSet<Rank>> {
139        self.stop_formatting_tokens_for_assistant_actions
140            .iter()
141            .copied()
142            .map(|t| match self.render_formatting_token(t) {
143                Ok(t) => Ok(t),
144                Err(RenderFormattingTokenError::UnmappedToken(_)) => Err(anyhow::anyhow!(
145                    "token {t} was specified as a stop token, but is not mapped"
146                )),
147                Err(e) => Err(anyhow::anyhow!(e).context("could not render stop token")),
148            })
149            .collect()
150    }
151
152    pub fn render_conversation_into<'a, I, B>(
153        &self,
154        conversation: I,
155        into: &mut B,
156        config: Option<&RenderConversationConfig>,
157    ) -> anyhow::Result<()>
158    where
159        I: IntoIterator<Item = &'a Message>,
160        B: Extend<Rank>,
161    {
162        let messages: Vec<_> = conversation.into_iter().collect();
163        let has_function_tools = messages.iter().any(|msg| {
164            msg.content.iter().any(|c| {
165                if let Content::DeveloperContent(dev) = c {
166                    if let Some(tools) = &dev.tools {
167                        if let Some(ns) = tools.get("functions") {
168                            !ns.tools.is_empty()
169                        } else {
170                            false
171                        }
172                    } else {
173                        false
174                    }
175                } else {
176                    false
177                }
178            })
179        });
180        let render_options = RenderOptions {
181            conversation_has_function_tools: has_function_tools,
182        };
183        let last_assistant_is_final = messages
184            .iter()
185            .rev()
186            .find_map(|msg| {
187                (msg.author.role == Role::Assistant)
188                    .then(|| msg.channel.as_deref() == Some("final"))
189            })
190            .unwrap_or(false);
191
192        let should_drop_analysis =
193            config.is_some_and(|c| c.auto_drop_analysis && last_assistant_is_final);
194
195        let first_final_idx = messages
196            .iter()
197            .position(|msg| msg.channel.as_deref() == Some("final"));
198
199        let result = messages
200            .iter()
201            .enumerate()
202            .filter(|(idx, msg)| {
203                !(should_drop_analysis
204                    && first_final_idx.is_some_and(|first| *idx < first)
205                    && msg.channel.as_deref() == Some("analysis"))
206            })
207            .try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options)));
208        result?;
209        Ok(())
210    }
211
212    pub fn render_conversation_for_completion_into<'a, I, B>(
213        &self,
214        conversation: I,
215        next_turn_role: Role,
216        into: &mut B,
217        config: Option<&RenderConversationConfig>,
218    ) -> anyhow::Result<()>
219    where
220        I: IntoIterator<Item = &'a Message>,
221        B: Extend<Rank>,
222    {
223        let _config = config.unwrap_or(&RenderConversationConfig::default());
224        self.render_conversation_into(conversation, into, config)?;
225        self.render_formatting_token_into(FormattingToken::Start, into)?;
226        self.render_text_into(next_turn_role.as_str(), into)?;
227        Ok(())
228    }
229
230    pub fn render_conversation_for_completion<'a, I>(
231        &self,
232        conversation: I,
233        next_turn_role: Role,
234        config: Option<&RenderConversationConfig>,
235    ) -> anyhow::Result<Vec<Rank>>
236    where
237        I: IntoIterator<Item = &'a Message>,
238    {
239        let mut into = vec![];
240        self.render_conversation_for_completion_into(
241            conversation,
242            next_turn_role,
243            &mut into,
244            config,
245        )?;
246        Ok(into)
247    }
248
249    pub fn render_conversation_for_training<'a, I>(
250        &self,
251        conversation: I,
252        config: Option<&RenderConversationConfig>,
253    ) -> anyhow::Result<Vec<Rank>>
254    where
255        I: IntoIterator<Item = &'a Message>,
256    {
257        let messages: Vec<&Message> = conversation.into_iter().collect();
258        let mut out = vec![];
259        self.render_conversation_into(messages.iter().copied(), &mut out, config)?;
260        if let Some(last) = messages.last() {
261            if last.author.role == Role::Assistant && last.channel.as_deref() == Some("final") {
262                if let Some(last_token) = out.last_mut() {
263                    *last_token =
264                        self.render_formatting_token(FormattingToken::EndMessageDoneSampling)?;
265                }
266            }
267        }
268        Ok(out)
269    }
270
271    pub fn render_conversation<'a, I>(
272        &self,
273        conversation: I,
274        config: Option<&RenderConversationConfig>,
275    ) -> anyhow::Result<Vec<Rank>>
276    where
277        I: IntoIterator<Item = &'a Message>,
278    {
279        let mut out = vec![];
280        self.render_conversation_into(conversation, &mut out, config)?;
281        Ok(out)
282    }
283
284    pub fn render(
285        &self,
286        message: &Message,
287        render_options: Option<&RenderOptions>,
288    ) -> anyhow::Result<Vec<Rank>> {
289        let mut out = vec![];
290        Render::<Message>::render(self, message, &mut out, render_options)?;
291        Ok(out)
292    }
293
294    pub fn render_into<B>(
295        &self,
296        message: &Message,
297        into: &mut B,
298        render_options: Option<&RenderOptions>,
299    ) -> anyhow::Result<()>
300    where
301        B: Extend<Rank>,
302    {
303        Render::<Message>::render(self, message, into, render_options)
304    }
305
306    fn mapped_format_token(&self, t: FormattingToken) -> Option<&str> {
307        self.format_token_mapping.get(&t).map(|s| s.as_str())
308    }
309
310    fn render_formatting_token(
311        &self,
312        t: FormattingToken,
313    ) -> Result<Rank, RenderFormattingTokenError> {
314        let mapped = self
315            .mapped_format_token(t)
316            .ok_or(RenderFormattingTokenError::UnmappedToken(t))?;
317        let encoded = self.tokenizer.encode_with_special_tokens(mapped);
318        if encoded.len() != 1 {
319            return Err(RenderFormattingTokenError::InvalidEncoding {
320                token: t,
321                encoding: encoded,
322            });
323        }
324        Ok(encoded[0])
325    }
326
327    fn render_formatting_token_into<B>(
328        &self,
329        t: FormattingToken,
330        into: &mut B,
331    ) -> anyhow::Result<()>
332    where
333        B: Extend<Rank>,
334    {
335        let r = self.render_formatting_token(t)?;
336        into.extend(std::iter::once(r));
337        Ok(())
338    }
339
340    fn render_text_into<T, B>(&self, text: T, into: &mut B) -> anyhow::Result<()>
341    where
342        T: AsRef<str>,
343        B: Extend<Rank>,
344    {
345        into.extend(self.tokenizer.encode_ordinary(text.as_ref()));
346        Ok(())
347    }
348
349    pub fn parse_messages_from_completion_tokens<I>(
350        &self,
351        tokens: I,
352        role: Option<Role>,
353    ) -> anyhow::Result<Vec<Message>>
354    where
355        I: IntoIterator<Item = Rank>,
356    {
357        let mut parser = StreamableParser::new(self.clone(), role)?;
358        for token in tokens {
359            parser.process(token)?;
360        }
361        parser.process_eos()?;
362        Ok(parser.into_messages())
363    }
364
365    fn template_tools_section(
366        tools: &std::collections::BTreeMap<String, crate::chat::ToolNamespaceConfig>,
367    ) -> String {
368        let mut tool_sections = Vec::<String>::new();
369        tool_sections.push("# Tools".to_string());
370        for ns_config in tools.values() {
371            let mut tool_section_content = Vec::<String>::new();
372            tool_section_content.push(format!("## {}\n", ns_config.name));
373            if let Some(desc) = &ns_config.description {
374                for line in desc.lines() {
375                    if !ns_config.tools.is_empty() {
376                        tool_section_content.push(format!("// {line}"));
377                    } else {
378                        tool_section_content.push(line.to_string());
379                    }
380                }
381            }
382            if !ns_config.tools.is_empty() {
383                tool_section_content.push(format!("namespace {} {{\n", ns_config.name));
384                for tool in &ns_config.tools {
385                    for line in tool.description.lines() {
386                        tool_section_content.push(format!("// {line}"));
387                    }
388                    if let Some(params) = &tool.parameters {
389                        let param_type = Self::json_schema_to_typescript(params, "");
390                        tool_section_content.push(format!(
391                            "type {} = (_: {}) => any;\n",
392                            tool.name, param_type
393                        ));
394                    } else {
395                        tool_section_content.push(format!("type {} = () => any;\n", tool.name));
396                    }
397                }
398                tool_section_content.push(format!("}} // namespace {}", ns_config.name));
399            }
400            tool_sections.push(tool_section_content.join("\n"));
401        }
402        tool_sections.join("\n\n")
403    }
404
405    fn json_schema_to_typescript(schema: &serde_json::Value, indent: &str) -> String {
406        // Simple implementation for basic schema conversion
407        match schema.get("type").and_then(|v| v.as_str()) {
408            Some("object") => {
409                let mut out = String::new();
410                out.push_str("{\n");
411                if let Some(props) = schema.get("properties") {
412                    if let Some(props_map) = props.as_object() {
413                        let mut required = std::collections::HashSet::new();
414                        if let Some(req) = schema.get("required") {
415                            if let Some(req_arr) = req.as_array() {
416                                for r in req_arr {
417                                    if let Some(s) = r.as_str() {
418                                        required.insert(s);
419                                    }
420                                }
421                            }
422                        }
423                        for (key, val) in props_map {
424                            out.push_str(&format!(
425                                "{}{}{}: ",
426                                indent,
427                                key,
428                                if required.contains(key.as_str()) {
429                                    ""
430                                } else {
431                                    "?"
432                                }
433                            ));
434                            let type_str = Self::json_schema_to_typescript(val, &format!("{indent}    "));
435                            out.push_str(&type_str);
436                            out.push_str(",\n");
437                        }
438                    }
439                }
440                out.push_str(&format!("{indent}}}"));
441                out
442            }
443            Some("string") => "string".to_string(),
444            Some("number") | Some("integer") => "number".to_string(),
445            Some("boolean") => "boolean".to_string(),
446            Some("array") => {
447                if let Some(items) = schema.get("items") {
448                    format!("{}[]", Self::json_schema_to_typescript(items, indent))
449                } else {
450                    "Array<any>".to_string()
451                }
452            }
453            _ => "any".to_string(),
454        }
455    }
456}
457
/// Conversation-level flags that influence how individual messages render.
#[derive(Clone, Copy, Debug, Default)]
pub struct RenderOptions {
    /// Whether some developer message declares a non-empty `functions` tool
    /// namespace. When set, the valid-channels section of system content adds
    /// a note that calls to these tools must go to the commentary channel.
    pub conversation_has_function_tools: bool,
}
462
/// Renders a value of type `T` as tokenizer ranks.
trait Render<T: ?Sized> {
    /// Appends the tokens for `item` to `into`. `render_options` carries
    /// conversation-level flags (see [`RenderOptions`]) and may be `None`.
    fn render<B>(
        &self,
        item: &T,
        into: &mut B,
        render_options: Option<&RenderOptions>,
    ) -> anyhow::Result<()>
    where
        B: Extend<Rank>;
}
473
474impl Render<Message> for HarmonyEncoding {
475    fn render<B>(
476        &self,
477        message: &Message,
478        into: &mut B,
479        render_options: Option<&RenderOptions>,
480    ) -> anyhow::Result<()>
481    where
482        B: Extend<Rank>,
483    {
484        self.render_formatting_token_into(FormattingToken::Start, into)?;
485
486        if matches!(message.author.role, Role::Tool) {
487            if let Some(name) = &message.author.name {
488                self.render_text_into(name, into)?;
489            } else {
490                anyhow::bail!("Tools should have a name!");
491            }
492        } else {
493            self.render_text_into(message.author.role.as_str(), into)?;
494            if let Some(name) = &message.author.name {
495                self.render_text_into(format!(":{name}"), into)?;
496            }
497        };
498
499        if let Some(recipient) = &message.recipient {
500            if recipient != "all" {
501                self.render_text_into(format!(" to={recipient}"), into)?;
502            }
503        }
504
505        if let Some(channel) = &message.channel {
506            self.render_formatting_token_into(FormattingToken::Channel, into)?;
507            self.render_text_into(channel, into)?;
508        }
509
510        if let Some(content_type) = &message.content_type {
511            if let Some(constrain_marker) =
512                self.mapped_format_token(FormattingToken::ConstrainedFormat)
513            {
514                if let Some(rest) = content_type.strip_prefix(constrain_marker) {
515                    self.render_text_into(" ", into)?;
516                    self.render_formatting_token_into(FormattingToken::ConstrainedFormat, into)?;
517                    if !rest.is_empty() {
518                        self.render_text_into(rest, into)?;
519                    }
520                } else {
521                    self.render_text_into(format!(" {content_type}"), into)?;
522                }
523            } else {
524                self.render_text_into(format!(" {content_type}"), into)?;
525            }
526        }
527
528        self.render_formatting_token_into(FormattingToken::Message, into)?;
529        for content in message.content.iter() {
530            if let crate::chat::Content::SystemContent(_) = content {
531                anyhow::ensure!(
532                    message.author.role == crate::chat::Role::System,
533                    "SystemContent may only appear in system messages, found in {:?}",
534                    message.author.role
535                );
536            }
537            if let crate::chat::Content::DeveloperContent(_) = content {
538                anyhow::ensure!(
539                    message.author.role == crate::chat::Role::Developer,
540                    "DeveloperContent may only appear in developer messages, found in {:?}",
541                    message.author.role
542                );
543            }
544            Render::<Content>::render(self, content, into, render_options)?;
545        }
546
547        if message.author.role == crate::chat::Role::Assistant && message.recipient.is_some() {
548            self.render_formatting_token_into(FormattingToken::EndMessageAssistantToTool, into)?;
549        } else {
550            self.render_formatting_token_into(FormattingToken::EndMessage, into)?;
551        }
552        Ok(())
553    }
554}
555
556impl Render<Content> for HarmonyEncoding {
557    fn render<B>(
558        &self,
559        content: &Content,
560        into: &mut B,
561        render_options: Option<&RenderOptions>,
562    ) -> anyhow::Result<()>
563    where
564        B: Extend<Rank>,
565    {
566        match content {
567            Content::Text(text) => Render::<TextContent>::render(self, text, into, render_options),
568            Content::SystemContent(sys) => {
569                Render::<SystemContent>::render(self, sys, into, render_options)
570            }
571            Content::DeveloperContent(dev) => {
572                Render::<crate::chat::DeveloperContent>::render(self, dev, into, render_options)
573            }
574        }
575    }
576}
577
578impl Render<TextContent> for HarmonyEncoding {
579    fn render<B>(
580        &self,
581        text: &TextContent,
582        into: &mut B,
583        _render_options: Option<&RenderOptions>,
584    ) -> anyhow::Result<()>
585    where
586        B: Extend<Rank>,
587    {
588        self.render_text_into(&text.text, into)
589    }
590}
591
592impl Render<SystemContent> for HarmonyEncoding {
593    fn render<B>(
594        &self,
595        sys: &SystemContent,
596        into: &mut B,
597        render_options: Option<&RenderOptions>,
598    ) -> anyhow::Result<()>
599    where
600        B: Extend<Rank>,
601    {
602        let mut sections = Vec::<String>::new();
603
604        let mut top_section = Vec::<String>::new();
605        if let Some(model_id) = &sys.model_identity {
606            top_section.push(model_id.clone());
607        }
608        if let Some(knowledge_cutoff) = &sys.knowledge_cutoff {
609            top_section.push(format!("Knowledge cutoff: {knowledge_cutoff}"));
610        }
611        if let Some(conversation_start_date) = &sys.conversation_start_date {
612            top_section.push(format!("Current date: {conversation_start_date}"));
613        }
614        if !top_section.is_empty() {
615            sections.push(top_section.join("\n"));
616        }
617
618        let mut instructions_and_reasoning = Vec::<String>::new();
619        if let Some(effort) = sys.reasoning_effort {
620            let effort_str = match effort {
621                ReasoningEffort::Low => "low",
622                ReasoningEffort::Medium => "medium",
623                ReasoningEffort::High => "high",
624            };
625            instructions_and_reasoning.push(format!("Reasoning: {effort_str}"));
626        }
627        if !instructions_and_reasoning.is_empty() {
628            sections.push(instructions_and_reasoning.join("\n"));
629        }
630
631        if let Some(tools) = &sys.tools {
632            if !tools.is_empty() {
633                sections.push(Self::template_tools_section(tools));
634            }
635        }
636
637        if let Some(channel_config) = &sys.channel_config {
638            if !channel_config.valid_channels.is_empty() {
639                let channels_str = channel_config.valid_channels.join(", ");
640                let mut channels_header = format!("# Valid channels: {channels_str}.");
641                if channel_config.channel_required {
642                    channels_header.push_str(" Channel must be included for every message.");
643                }
644                if render_options.is_some_and(|o| o.conversation_has_function_tools) {
645                    channels_header.push('\n');
646                    channels_header.push_str(
647                        "Calls to these tools must go to the commentary channel: 'functions'.",
648                    );
649                }
650                sections.push(channels_header);
651            }
652        }
653        let formatted = sections.join("\n\n");
654        self.render_text_into(&formatted, into)?;
655        Ok(())
656    }
657}
658
659impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
660    fn render<B>(
661        &self,
662        dev: &crate::chat::DeveloperContent,
663        into: &mut B,
664        _render_options: Option<&RenderOptions>,
665    ) -> anyhow::Result<()>
666    where
667        B: Extend<Rank>,
668    {
669        let mut sections = Vec::<String>::new();
670
671        if let Some(instr) = &dev.instructions {
672            sections.push("# Instructions".to_string());
673            sections.push(instr.clone());
674        }
675
676        if let Some(tools) = &dev.tools {
677            if !tools.is_empty() {
678                sections.push(Self::template_tools_section(tools));
679            }
680        }
681        let formatted = sections.join("\n\n");
682        self.render_text_into(&formatted, into)?;
683        Ok(())
684    }
685}
686
/// Incremental parser that turns a stream of token ranks into [`Message`]s.
///
/// Feed tokens with `process` and finish the stream with `process_eos`.
pub struct StreamableParser {
    /// Encoding used to recognise formatting tokens and decode text.
    encoding: HarmonyEncoding,
    /// Role to use for the next parsed header; set from `new(.., Some(role))`
    /// and cleared once a header has been parsed.
    next_role: Option<Role>,
    /// Every token passed to the parser, in order (stop tokens included).
    tokens: Vec<Rank>,
    /// Messages completed so far.
    messages: Vec<Message>,
    /// Current position in the message grammar.
    state: StreamState,
    /// Ranks that end the current message (from `HarmonyEncoding::stop_tokens`).
    stop_tokens: HashSet<Rank>,
    /// Text decoded from the most recent content token(s); `None` while the
    /// pending tokens do not yet decode as UTF-8, and after a message ends.
    last_content_delta: Option<String>,
    /// Content tokens held back until they decode as valid UTF-8.
    undecoded_tokens: Vec<Rank>,
}
697
/// Parser state for [`StreamableParser`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum StreamState {
    /// Waiting for the `Start` formatting token that opens a message.
    ExpectStart,
    /// Inside a message header: collecting tokens until the `Message` token.
    Header {
        /// Header tokens seen so far (the `Message` token is not included).
        header_tokens: Vec<Rank>,
    },
    /// Inside message content: collecting tokens until a stop token or end of stream.
    Content {
        /// Header parsed when the `Message` token arrived.
        header: ParsedHeader,
        /// Content tokens that have been successfully decoded so far.
        content_tokens: Vec<Rank>,
    },
}
709
710impl StreamableParser {
711    pub fn new(encoding: HarmonyEncoding, role: Option<Role>) -> anyhow::Result<Self> {
712        let stop_tokens = encoding.stop_tokens()?;
713        let (state, next_role) = match role {
714            Some(role) => (
715                StreamState::Header {
716                    header_tokens: Vec::new(),
717                },
718                Some(role),
719            ),
720            None => (StreamState::ExpectStart, None),
721        };
722        Ok(Self {
723            encoding,
724            next_role,
725            tokens: Vec::new(),
726            messages: Vec::new(),
727            state,
728            stop_tokens,
729            last_content_delta: None,
730            undecoded_tokens: Vec::new(),
731        })
732    }
733
    /// Advances the parser state machine by one step.
    ///
    /// `token = Some(rank)` feeds one token; `token = None` signals end of
    /// stream. Every fed token is also recorded in `self.tokens`.
    ///
    /// # Errors
    /// - a non-`Start` token while waiting for a message to start;
    /// - end of stream in the middle of a header;
    /// - header parsing failures (the parser is then left in `ExpectStart`);
    /// - content that cannot be decoded as UTF-8 when the message completes;
    /// - unmapped or invalid `Start`/`Message` formatting tokens.
    fn process_next(&mut self, token: Option<Rank>) -> anyhow::Result<&mut Self> {
        if let Some(token) = token {
            self.tokens.push(token);
        }
        let next_role_clone = self.next_role.clone();
        match &mut self.state {
            StreamState::ExpectStart => {
                let start = self
                    .encoding
                    .render_formatting_token(FormattingToken::Start)?;
                match token {
                    Some(token) if token == start => {
                        self.state = StreamState::Header {
                            header_tokens: Vec::new(),
                        };
                    }
                    Some(token) => {
                        anyhow::bail!(
                            "Unexpected token {} while expecting start token {}",
                            token,
                            start
                        );
                    }
                    None => {
                        // EOS while waiting for start token is fine
                    }
                }
            }
            StreamState::Header { header_tokens } => {
                let msg_tok = self
                    .encoding
                    .render_formatting_token(FormattingToken::Message)?;
                match token {
                    Some(token) if token == msg_tok => {
                        let header_tokens_cloned = header_tokens.clone();
                        let next_role_cloned = next_role_clone;
                        // Reset before parsing: if header parsing fails the
                        // parser is left waiting for a new `Start` token.
                        self.state = StreamState::ExpectStart;
                        let header =
                            self.parse_header_from_tokens(&header_tokens_cloned, next_role_cloned)?;
                        // A role supplied to `new` applies to the first header only.
                        self.next_role = None;
                        self.state = StreamState::Content {
                            header,
                            content_tokens: Vec::new(),
                        };
                    }
                    Some(token) => {
                        header_tokens.push(token);
                    }
                    None => {
                        anyhow::bail!(
                            "Unexpected EOS while waiting for message header to complete"
                        );
                    }
                }
            }
            StreamState::Content {
                header,
                content_tokens,
            } => {
                let is_eos = if let Some(token) = token {
                    if self.stop_tokens.contains(&token) {
                        // Stop tokens end the message and are not part of its content.
                        true
                    } else {
                        // A single token may hold a partial UTF-8 sequence, so tokens
                        // are buffered until the buffer decodes as a whole.
                        self.undecoded_tokens.push(token);
                        match self
                            .encoding
                            .tokenizer()
                            .decode_utf8(&self.undecoded_tokens)
                        {
                            Ok(decoded) => {
                                content_tokens.extend(self.undecoded_tokens.iter().copied());
                                self.last_content_delta = Some(decoded);
                                self.undecoded_tokens.clear();
                            }
                            Err(_) => {
                                self.last_content_delta = None;
                            }
                        }
                        false
                    }
                } else {
                    // `None` means end of stream, which also finishes the message.
                    true
                };
                if is_eos {
                    let text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
                    let message = Message {
                        author: header.author.clone(),
                        recipient: header.recipient.clone(),
                        channel: header.channel.clone(),
                        content_type: header.content_type.clone(),
                        content: vec![Content::Text(TextContent { text })],
                    };
                    self.messages.push(message);
                    self.state = StreamState::ExpectStart;
                    self.last_content_delta = None;
                    // NOTE(review): tokens still pending here (an incomplete UTF-8
                    // sequence) are discarded without reaching the message text;
                    // confirm that silently dropping them is intended.
                    self.undecoded_tokens.clear();
                }
            }
        }
        Ok(self)
    }
835
836    pub fn process(&mut self, token: Rank) -> anyhow::Result<&mut Self> {
837        self.process_next(Some(token))
838    }
839
840    pub fn process_eos(&mut self) -> anyhow::Result<&mut Self> {
841        self.process_next(None)?;
842        Ok(self)
843    }
844
    /// Parses a message header, supplied as raw tokens, into a [`ParsedHeader`].
    ///
    /// Steps, in order:
    /// 1. Decode the tokens to UTF-8.
    /// 2. If the `Channel` formatting token occurs, take the text after it (up to the next
    ///    whitespace or `<`) as the channel, and splice marker + value out of the header.
    /// 3. Put a space in front of every `ConstrainedFormat` marker, so that it splits off as
    ///    its own whitespace-separated part. The marker stays glued to the word that follows it,
    ///    so it ends up inside `content_type`.
    /// 4. Resolve the role. Use `role` if given; otherwise parse the first part. An
    ///    unrecognised first part is accepted as a `Role::Tool` author when more parts follow
    ///    it, or when it starts with `to=`. Any other unrecognised first part is an error.
    /// 5. Interpret the remaining parts as `[recipient] [content_type]` (see inline comments).
    ///
    /// Only `Role::Tool` authors get a name: the raw first header part, or `None` when the
    /// caller supplied `role`.
    ///
    /// # Arguments
    /// * `header_tokens` - the header's tokens.
    /// * `role` - a role already known to the caller. When `Some`, the header does not need to
    ///   start with a role string. A leading part equal to that role is still skipped.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// * decoding fails;
    /// * the channel marker is present with an empty value;
    /// * the header has no parts and no `role` was given;
    /// * the first part is an unknown role that cannot be read as a tool author;
    /// * parts remain after the recipient/content type have been extracted.
    fn parse_header_from_tokens(
        &self,
        header_tokens: &[Rank],
        role: Option<Role>,
    ) -> anyhow::Result<ParsedHeader> {
        let mut header_string = self
            .encoding
            .tokenizer()
            .decode_utf8(header_tokens)
            .context("could not decode header")?;

        // Extract the channel and remove "<channel marker><value>" from the header. The value
        // ends at whitespace or at '<' (the start of the next special token), so both
        // "…<|channel|>analysis to=x" and "…<|channel|>commentary<|constrain|>json" work.
        let mut channel: Option<String> = None;
        if let Some(channel_marker) = self.encoding.mapped_format_token(FormattingToken::Channel) {
            if let Some(idx) = header_string.find(channel_marker) {
                let after_marker = &header_string[idx + channel_marker.len()..];
                let channel_end = after_marker
                    .find(|c: char| c.is_whitespace() || c == '<')
                    .unwrap_or(after_marker.len());
                let channel_value = &after_marker[..channel_end];
                if channel_value.is_empty() {
                    anyhow::bail!("channel marker present but no channel value found in header");
                }
                channel = Some(channel_value.to_string());

                // Rebuild the header from the text before the marker and the text after the
                // channel value.
                let mut new_header = String::new();
                new_header.push_str(&header_string[..idx]);
                new_header.push_str(&after_marker[channel_end..]);
                header_string = new_header;
            }
        }

        // Splicing out the channel can leave whitespace at either end.
        header_string = header_string.trim().to_string();

        // A content type may be written flush against the recipient
        // (e.g. "to=foo<|constrain|>json"). Prefix each marker with a space so the whitespace
        // split below yields "to=foo" and "<|constrain|>json" as separate parts.
        if let Some(constrain_marker) = self
            .encoding
            .mapped_format_token(FormattingToken::ConstrainedFormat)
        {
            if header_string.contains(constrain_marker) {
                header_string = header_string
                    .replace(constrain_marker, &format!(" {constrain_marker}"))
                    .trim()
                    .to_string();
            }
        }

        let mut parts: Vec<&str> = header_string.split_ascii_whitespace().collect();

        // Raw first part of the header, recorded only when the role had to be parsed from the
        // header. It becomes the author name for tool messages.
        let mut role_str_opt: Option<String> = None;
        let role = match role {
            Some(r) => r,
            None => {
                let role_str = parts
                    .first()
                    .context("message header did not contain a role")?;
                role_str_opt = Some((*role_str).to_string());
                let parsed_role = Role::try_from(*role_str);
                match parsed_role {
                    Ok(r) => r,
                    Err(_) => {
                        // Unknown role string: treat the message as authored by a tool named
                        // by that string, and drop it from `parts`.
                        // NOTE(review): when the header is a single "to=..." part, that part is
                        // consumed here as the tool name (including the "to=" prefix) rather
                        // than as a recipient — confirm this is intended.
                        if parts.len() > 1 || (parts.len() == 1 && parts[0].starts_with("to=")) {
                            parts.remove(0);
                            Role::Tool
                        } else {
                            return Err(anyhow::anyhow!("Unknown role: {}", role_str));
                        }
                    }
                }
            }
        };

        // Skip a leading part that spells the resolved role. It is still present when the role
        // was parsed successfully above, and may be repeated when `role` was supplied.
        if let Some(&first) = parts.first() {
            if first == role.as_str() {
                parts.remove(0);
            }
        }

        let mut recipient: Option<String> = None;
        let mut content_type: Option<String> = None;

        if !parts.is_empty() {
            // Decide whether the last part is a content type or the recipient.
            let num_parts = parts.len();
            // Cannot panic: `parts` is non-empty in this branch.
            let last_part = parts.pop().unwrap();

            if let Some(stripped) = last_part.strip_prefix("to=") {
                // Explicit recipient and no content type.
                recipient = Some(stripped.to_string());
            } else if num_parts == 1 {
                // A single remaining part without "to=" is still taken as the recipient.
                recipient = Some(last_part.to_string());
            } else {
                // Two or more parts: the last is the content type, and the one before it is
                // the recipient (its "to=" prefix is optional).
                content_type = Some(last_part.to_string());
                if let Some(raw_recipient) = parts.pop() {
                    recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") {
                        Some(stripped.to_string())
                    } else {
                        Some(raw_recipient.to_string())
                    };
                }
            }
        }
        // Any part not consumed above means the header had an unexpected shape.
        anyhow::ensure!(
            parts.is_empty(),
            "unexpected tokens remaining in message header: {:?}",
            parts
        );

        // Only tool authors carry a name.
        let author = if role == Role::Tool {
            let name = role_str_opt;
            Author { role, name }
        } else {
            Author { role, name: None }
        };
        Ok(ParsedHeader {
            author,
            recipient,
            channel,
            content_type,
        })
    }
962
963    pub fn current_content(&self) -> anyhow::Result<String> {
964        match &self.state {
965            StreamState::Content { content_tokens, .. } => self
966                .encoding
967                .tokenizer()
968                .decode_utf8(content_tokens)
969                .map_err(|e| anyhow::anyhow!(e)),
970            _ => Ok(String::new()),
971        }
972    }
973
974    pub fn current_role(&self) -> Option<Role> {
975        match &self.state {
976            StreamState::Content { header, .. } => Some(header.author.role.clone()),
977            _ => self.next_role.clone(),
978        }
979    }
980
981    pub fn current_content_type(&self) -> Option<String> {
982        match &self.state {
983            StreamState::Content { header, .. } => header.content_type.clone(),
984            _ => None,
985        }
986    }
987
988    pub fn last_content_delta(&self) -> anyhow::Result<Option<String>> {
989        Ok(self.last_content_delta.clone())
990    }
991
992    pub fn into_messages(self) -> Vec<Message> {
993        self.messages
994    }
995
996    pub fn messages(&self) -> &[Message] {
997        &self.messages
998    }
999
1000    pub fn tokens(&self) -> &[Rank] {
1001        &self.tokens
1002    }
1003
1004    pub fn current_recipient(&self) -> Option<String> {
1005        match &self.state {
1006            StreamState::Content { header, .. } => header.recipient.clone(),
1007            _ => None,
1008        }
1009    }
1010
1011    pub fn current_channel(&self) -> Option<String> {
1012        match &self.state {
1013            StreamState::Content { header, .. } => header.channel.clone(),
1014            _ => None,
1015        }
1016    }
1017}
1018
/// Options for rendering a conversation.
///
/// The [`Default`] value is *not* the all-`false` value: `auto_drop_analysis` defaults to
/// `true`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RenderConversationConfig {
    /// Judging by its name, this controls whether analysis-channel messages are dropped
    /// automatically during rendering.
    /// NOTE(review): confirm the exact rule against the rendering code.
    pub auto_drop_analysis: bool,
}

impl Default for RenderConversationConfig {
    /// Returns the default configuration, with `auto_drop_analysis` enabled.
    fn default() -> Self {
        Self {
            auto_drop_analysis: true,
        }
    }
}