1use crate::{
2 chat::{Author, Content, Message, ReasoningEffort, Role, SystemContent, TextContent},
3 tiktoken::{CoreBPE, Rank},
4};
5use anyhow::Context as _;
6use std::{
7 collections::{HashMap, HashSet},
8 sync::Arc,
9};
10
/// Header fields parsed from the tokens between a message's start token and its
/// message token.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ParsedHeader {
    /// Author role; for tool authors, also the name taken from the first header word.
    author: Author,
    /// Message recipient, with any leading `to=` removed.
    recipient: Option<String>,
    /// Text that followed the channel marker, if the header contained one.
    channel: Option<String>,
    /// Trailing content-type word; it may begin with the constrain marker.
    content_type: Option<String>,
}
18
/// Errors produced while turning a [`FormattingToken`] into a single token id.
#[derive(thiserror::Error, Debug)]
pub(crate) enum RenderFormattingTokenError {
    /// The encoding's `format_token_mapping` has no entry for this token.
    #[error("tried to render unmapped formatting token {0}")]
    UnmappedToken(FormattingToken),

    /// The mapped text did not encode to exactly one token id.
    #[error(
        "Expected encoding of formatting token {token} to be a single token, but got {encoding:?}"
    )]
    InvalidEncoding {
        token: FormattingToken,
        /// The ids the mapped text actually encoded to.
        encoding: Vec<Rank>,
    },
}
32
/// Structural tokens of the Harmony chat format.
///
/// The text actually used when rendering or parsing comes from the encoding's
/// `format_token_mapping`. [`FormattingToken::as_str`] gives the canonical spelling
/// used by `Display`, for example in error messages.
// Not every variant is referenced by the rendering/parsing code in this module.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) enum FormattingToken {
    /// `<|start|>`: opens a message header.
    Start,
    /// `<|message|>`: ends the header and begins the content.
    Message,
    /// `<|end|>`: closes a message.
    EndMessage,
    /// `<|return|>`: closes a message and ends sampling.
    EndMessageDoneSampling,
    /// `<|call|>`: closes an assistant message addressed to a recipient.
    EndMessageAssistantToTool,
    /// `<|refusal|>`
    Refusal,
    /// `<|constrain|>`: prefixes a constrained content type in the header.
    ConstrainedFormat,
    /// `<|channel|>`: precedes the channel name in the header.
    Channel,
    /// `<|untrusted|>`
    BeginUntrusted,
    /// `<|end_untrusted|>`
    EndUntrusted,
    /// `<|meta_sep|>`
    MetaSep,
    /// `<|meta_end|>`
    MetaEnd,
}

impl FormattingToken {
    /// Canonical spelling of this token. Every variant has a distinct spelling.
    fn as_str(&self) -> &'static str {
        match self {
            FormattingToken::Start => "<|start|>",
            FormattingToken::Message => "<|message|>",
            FormattingToken::EndMessage => "<|end|>",
            FormattingToken::EndMessageDoneSampling => "<|return|>",
            FormattingToken::EndMessageAssistantToTool => "<|call|>",
            FormattingToken::Refusal => "<|refusal|>",
            FormattingToken::ConstrainedFormat => "<|constrain|>",
            FormattingToken::Channel => "<|channel|>",
            FormattingToken::BeginUntrusted => "<|untrusted|>",
            FormattingToken::EndUntrusted => "<|end_untrusted|>",
            // Previously duplicated `Channel`'s "<|channel|>", which made diagnostics
            // report the wrong token.
            FormattingToken::MetaSep => "<|meta_sep|>",
            FormattingToken::MetaEnd => "<|meta_end|>",
        }
    }
}

impl std::fmt::Display for FormattingToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
74
/// A Harmony chat-format encoding: a tokenizer plus the formatting-token configuration
/// used to render conversations to token ids and to parse completions back into messages.
// NOTE(review): presumably allowed because some fields (e.g. `n_ctx`, `max_action_length`)
// are only read by the `Debug` impl in this file — confirm.
#[allow(dead_code)]
#[derive(Clone)]
pub struct HarmonyEncoding {
    /// Encoding name, returned by `name()` and shown by `Display`.
    pub(crate) name: String,
    /// Presumably the model context length — TODO confirm; not used by the code shown here.
    pub(crate) n_ctx: usize,
    /// Returned by `max_message_tokens()`; not enforced by the code shown here.
    pub(crate) max_message_tokens: usize,
    /// Only shown in `Debug` output here; its meaning is not established in this file.
    pub(crate) max_action_length: usize,
    /// Name of the tokenizer, returned by `tokenizer_name()`.
    pub(crate) tokenizer_name: String,
    /// Shared tokenizer. Being behind an `Arc` means cloning the encoding (done when building
    /// a `StreamableParser`) does not copy it.
    pub(crate) tokenizer: Arc<CoreBPE>,
    /// Text of each formatting token. Tokens missing from the map cannot be rendered.
    pub(crate) format_token_mapping: HashMap<FormattingToken, String>,
    /// Tokens whose ids `stop_tokens()` returns; the parser ends a message on them.
    pub(crate) stop_formatting_tokens: HashSet<FormattingToken>,
    /// Tokens whose ids `stop_tokens_for_assistant_actions()` returns.
    pub(crate) stop_formatting_tokens_for_assistant_actions: HashSet<FormattingToken>,
}
88
89impl std::fmt::Debug for HarmonyEncoding {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 f.debug_struct("HarmonyEncoding")
92 .field("name", &self.name)
93 .field("tokenizer_name", &self.tokenizer_name)
94 .field("n_ctx", &self.n_ctx)
95 .field("max_message_tokens", &self.max_message_tokens)
96 .field("max_action_length", &self.max_action_length)
97 .finish()
98 }
99}
100
101impl std::fmt::Display for HarmonyEncoding {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 write!(f, "Renderer({})", self.name)
104 }
105}
106
107impl HarmonyEncoding {
108 pub fn name(&self) -> &str {
109 &self.name
110 }
111
112 pub fn tokenizer_name(&self) -> &str {
113 &self.tokenizer_name
114 }
115
116 pub fn max_message_tokens(&self) -> usize {
117 self.max_message_tokens
118 }
119
120 pub fn tokenizer(&self) -> &CoreBPE {
121 &self.tokenizer
122 }
123
124 pub fn stop_tokens(&self) -> anyhow::Result<HashSet<Rank>> {
125 self.stop_formatting_tokens
126 .iter()
127 .copied()
128 .map(|t| match self.render_formatting_token(t) {
129 Ok(t) => Ok(t),
130 Err(RenderFormattingTokenError::UnmappedToken(_)) => Err(anyhow::anyhow!(
131 "token {t} was specified as a stop token, but is not mapped"
132 )),
133 Err(e) => Err(anyhow::anyhow!(e).context("could not render stop token")),
134 })
135 .collect()
136 }
137
138 pub fn stop_tokens_for_assistant_actions(&self) -> anyhow::Result<HashSet<Rank>> {
139 self.stop_formatting_tokens_for_assistant_actions
140 .iter()
141 .copied()
142 .map(|t| match self.render_formatting_token(t) {
143 Ok(t) => Ok(t),
144 Err(RenderFormattingTokenError::UnmappedToken(_)) => Err(anyhow::anyhow!(
145 "token {t} was specified as a stop token, but is not mapped"
146 )),
147 Err(e) => Err(anyhow::anyhow!(e).context("could not render stop token")),
148 })
149 .collect()
150 }
151
152 pub fn render_conversation_into<'a, I, B>(
153 &self,
154 conversation: I,
155 into: &mut B,
156 config: Option<&RenderConversationConfig>,
157 ) -> anyhow::Result<()>
158 where
159 I: IntoIterator<Item = &'a Message>,
160 B: Extend<Rank>,
161 {
162 let messages: Vec<_> = conversation.into_iter().collect();
163 let has_function_tools = messages.iter().any(|msg| {
164 msg.content.iter().any(|c| {
165 if let Content::DeveloperContent(dev) = c {
166 if let Some(tools) = &dev.tools {
167 if let Some(ns) = tools.get("functions") {
168 !ns.tools.is_empty()
169 } else {
170 false
171 }
172 } else {
173 false
174 }
175 } else {
176 false
177 }
178 })
179 });
180 let render_options = RenderOptions {
181 conversation_has_function_tools: has_function_tools,
182 };
183 let last_assistant_is_final = messages
184 .iter()
185 .rev()
186 .find_map(|msg| {
187 (msg.author.role == Role::Assistant)
188 .then(|| msg.channel.as_deref() == Some("final"))
189 })
190 .unwrap_or(false);
191
192 let should_drop_analysis =
193 config.is_some_and(|c| c.auto_drop_analysis && last_assistant_is_final);
194
195 let first_final_idx = messages
196 .iter()
197 .position(|msg| msg.channel.as_deref() == Some("final"));
198
199 let result = messages
200 .iter()
201 .enumerate()
202 .filter(|(idx, msg)| {
203 !(should_drop_analysis
204 && first_final_idx.is_some_and(|first| *idx < first)
205 && msg.channel.as_deref() == Some("analysis"))
206 })
207 .try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options)));
208 result?;
209 Ok(())
210 }
211
212 pub fn render_conversation_for_completion_into<'a, I, B>(
213 &self,
214 conversation: I,
215 next_turn_role: Role,
216 into: &mut B,
217 config: Option<&RenderConversationConfig>,
218 ) -> anyhow::Result<()>
219 where
220 I: IntoIterator<Item = &'a Message>,
221 B: Extend<Rank>,
222 {
223 let _config = config.unwrap_or(&RenderConversationConfig::default());
224 self.render_conversation_into(conversation, into, config)?;
225 self.render_formatting_token_into(FormattingToken::Start, into)?;
226 self.render_text_into(next_turn_role.as_str(), into)?;
227 Ok(())
228 }
229
230 pub fn render_conversation_for_completion<'a, I>(
231 &self,
232 conversation: I,
233 next_turn_role: Role,
234 config: Option<&RenderConversationConfig>,
235 ) -> anyhow::Result<Vec<Rank>>
236 where
237 I: IntoIterator<Item = &'a Message>,
238 {
239 let mut into = vec![];
240 self.render_conversation_for_completion_into(
241 conversation,
242 next_turn_role,
243 &mut into,
244 config,
245 )?;
246 Ok(into)
247 }
248
249 pub fn render_conversation_for_training<'a, I>(
250 &self,
251 conversation: I,
252 config: Option<&RenderConversationConfig>,
253 ) -> anyhow::Result<Vec<Rank>>
254 where
255 I: IntoIterator<Item = &'a Message>,
256 {
257 let messages: Vec<&Message> = conversation.into_iter().collect();
258 let mut out = vec![];
259 self.render_conversation_into(messages.iter().copied(), &mut out, config)?;
260 if let Some(last) = messages.last() {
261 if last.author.role == Role::Assistant && last.channel.as_deref() == Some("final") {
262 if let Some(last_token) = out.last_mut() {
263 *last_token =
264 self.render_formatting_token(FormattingToken::EndMessageDoneSampling)?;
265 }
266 }
267 }
268 Ok(out)
269 }
270
271 pub fn render_conversation<'a, I>(
272 &self,
273 conversation: I,
274 config: Option<&RenderConversationConfig>,
275 ) -> anyhow::Result<Vec<Rank>>
276 where
277 I: IntoIterator<Item = &'a Message>,
278 {
279 let mut out = vec![];
280 self.render_conversation_into(conversation, &mut out, config)?;
281 Ok(out)
282 }
283
284 pub fn render(
285 &self,
286 message: &Message,
287 render_options: Option<&RenderOptions>,
288 ) -> anyhow::Result<Vec<Rank>> {
289 let mut out = vec![];
290 Render::<Message>::render(self, message, &mut out, render_options)?;
291 Ok(out)
292 }
293
294 pub fn render_into<B>(
295 &self,
296 message: &Message,
297 into: &mut B,
298 render_options: Option<&RenderOptions>,
299 ) -> anyhow::Result<()>
300 where
301 B: Extend<Rank>,
302 {
303 Render::<Message>::render(self, message, into, render_options)
304 }
305
306 fn mapped_format_token(&self, t: FormattingToken) -> Option<&str> {
307 self.format_token_mapping.get(&t).map(|s| s.as_str())
308 }
309
310 fn render_formatting_token(
311 &self,
312 t: FormattingToken,
313 ) -> Result<Rank, RenderFormattingTokenError> {
314 let mapped = self
315 .mapped_format_token(t)
316 .ok_or(RenderFormattingTokenError::UnmappedToken(t))?;
317 let encoded = self.tokenizer.encode_with_special_tokens(mapped);
318 if encoded.len() != 1 {
319 return Err(RenderFormattingTokenError::InvalidEncoding {
320 token: t,
321 encoding: encoded,
322 });
323 }
324 Ok(encoded[0])
325 }
326
327 fn render_formatting_token_into<B>(
328 &self,
329 t: FormattingToken,
330 into: &mut B,
331 ) -> anyhow::Result<()>
332 where
333 B: Extend<Rank>,
334 {
335 let r = self.render_formatting_token(t)?;
336 into.extend(std::iter::once(r));
337 Ok(())
338 }
339
340 fn render_text_into<T, B>(&self, text: T, into: &mut B) -> anyhow::Result<()>
341 where
342 T: AsRef<str>,
343 B: Extend<Rank>,
344 {
345 into.extend(self.tokenizer.encode_ordinary(text.as_ref()));
346 Ok(())
347 }
348
349 pub fn parse_messages_from_completion_tokens<I>(
350 &self,
351 tokens: I,
352 role: Option<Role>,
353 ) -> anyhow::Result<Vec<Message>>
354 where
355 I: IntoIterator<Item = Rank>,
356 {
357 let mut parser = StreamableParser::new(self.clone(), role)?;
358 for token in tokens {
359 parser.process(token)?;
360 }
361 parser.process_eos()?;
362 Ok(parser.into_messages())
363 }
364
365 fn template_tools_section(
366 tools: &std::collections::BTreeMap<String, crate::chat::ToolNamespaceConfig>,
367 ) -> String {
368 let mut tool_sections = Vec::<String>::new();
369 tool_sections.push("# Tools".to_string());
370 for ns_config in tools.values() {
371 let mut tool_section_content = Vec::<String>::new();
372 tool_section_content.push(format!("## {}\n", ns_config.name));
373 if let Some(desc) = &ns_config.description {
374 for line in desc.lines() {
375 if !ns_config.tools.is_empty() {
376 tool_section_content.push(format!("// {line}"));
377 } else {
378 tool_section_content.push(line.to_string());
379 }
380 }
381 }
382 if !ns_config.tools.is_empty() {
383 tool_section_content.push(format!("namespace {} {{\n", ns_config.name));
384 for tool in &ns_config.tools {
385 for line in tool.description.lines() {
386 tool_section_content.push(format!("// {line}"));
387 }
388 if let Some(params) = &tool.parameters {
389 let param_type = Self::json_schema_to_typescript(params, "");
390 tool_section_content.push(format!(
391 "type {} = (_: {}) => any;\n",
392 tool.name, param_type
393 ));
394 } else {
395 tool_section_content.push(format!("type {} = () => any;\n", tool.name));
396 }
397 }
398 tool_section_content.push(format!("}} // namespace {}", ns_config.name));
399 }
400 tool_sections.push(tool_section_content.join("\n"));
401 }
402 tool_sections.join("\n\n")
403 }
404
405 fn json_schema_to_typescript(schema: &serde_json::Value, indent: &str) -> String {
406 match schema.get("type").and_then(|v| v.as_str()) {
408 Some("object") => {
409 let mut out = String::new();
410 out.push_str("{\n");
411 if let Some(props) = schema.get("properties") {
412 if let Some(props_map) = props.as_object() {
413 let mut required = std::collections::HashSet::new();
414 if let Some(req) = schema.get("required") {
415 if let Some(req_arr) = req.as_array() {
416 for r in req_arr {
417 if let Some(s) = r.as_str() {
418 required.insert(s);
419 }
420 }
421 }
422 }
423 for (key, val) in props_map {
424 out.push_str(&format!(
425 "{}{}{}: ",
426 indent,
427 key,
428 if required.contains(key.as_str()) {
429 ""
430 } else {
431 "?"
432 }
433 ));
434 let type_str = Self::json_schema_to_typescript(val, &format!("{indent} "));
435 out.push_str(&type_str);
436 out.push_str(",\n");
437 }
438 }
439 }
440 out.push_str(&format!("{indent}}}"));
441 out
442 }
443 Some("string") => "string".to_string(),
444 Some("number") | Some("integer") => "number".to_string(),
445 Some("boolean") => "boolean".to_string(),
446 Some("array") => {
447 if let Some(items) = schema.get("items") {
448 format!("{}[]", Self::json_schema_to_typescript(items, indent))
449 } else {
450 "Array<any>".to_string()
451 }
452 }
453 _ => "any".to_string(),
454 }
455 }
456}
457
/// Conversation-level flags passed down while rendering individual messages.
#[derive(Clone, Copy, Debug, Default)]
pub struct RenderOptions {
    /// Set by `render_conversation_into` when some developer message declares a non-empty
    /// `functions` tool namespace. The system-content renderer then adds a line saying
    /// those calls go to the commentary channel.
    pub conversation_has_function_tools: bool,
}
462
/// Renders a value of type `T` as token ids appended to `into`.
///
/// `HarmonyEncoding` implements this for whole messages and for each content kind.
/// `render_options` carries conversation-level flags (see `RenderOptions`).
trait Render<T: ?Sized> {
    fn render<B>(
        &self,
        item: &T,
        into: &mut B,
        render_options: Option<&RenderOptions>,
    ) -> anyhow::Result<()>
    where
        B: Extend<Rank>;
}
473
impl Render<Message> for HarmonyEncoding {
    /// Renders one message as:
    /// start token, header (author, optional ` to=recipient`, optional channel, optional
    /// content type), message token, content, and a closing token.
    ///
    /// # Errors
    /// - A tool-authored message has no name.
    /// - System content appears outside a system message.
    /// - Developer content appears outside a developer message.
    /// - A formatting token cannot be rendered.
    ///
    /// NOTE(review): the content checks run after the header tokens were already extended
    /// into `into`, so an error leaves a partial message there.
    fn render<B>(
        &self,
        message: &Message,
        into: &mut B,
        render_options: Option<&RenderOptions>,
    ) -> anyhow::Result<()>
    where
        B: Extend<Rank>,
    {
        self.render_formatting_token_into(FormattingToken::Start, into)?;

        // Tool messages are headed by the tool's name alone; every other role renders its
        // role string, optionally followed by `:name`.
        if matches!(message.author.role, Role::Tool) {
            if let Some(name) = &message.author.name {
                self.render_text_into(name, into)?;
            } else {
                anyhow::bail!("Tools should have a name!");
            }
        } else {
            self.render_text_into(message.author.role.as_str(), into)?;
            if let Some(name) = &message.author.name {
                self.render_text_into(format!(":{name}"), into)?;
            }
        };

        // The recipient "all" is omitted from the header.
        if let Some(recipient) = &message.recipient {
            if recipient != "all" {
                self.render_text_into(format!(" to={recipient}"), into)?;
            }
        }

        // The channel marker is emitted as a formatting token, then the channel name as text.
        if let Some(channel) = &message.channel {
            self.render_formatting_token_into(FormattingToken::Channel, into)?;
            self.render_text_into(channel, into)?;
        }

        // A content type that starts with the mapped constrain marker is rendered as:
        // a space, the constrain formatting token, then any remaining text.
        // Other content types are rendered as plain text after a space.
        if let Some(content_type) = &message.content_type {
            if let Some(constrain_marker) =
                self.mapped_format_token(FormattingToken::ConstrainedFormat)
            {
                if let Some(rest) = content_type.strip_prefix(constrain_marker) {
                    self.render_text_into(" ", into)?;
                    self.render_formatting_token_into(FormattingToken::ConstrainedFormat, into)?;
                    if !rest.is_empty() {
                        self.render_text_into(rest, into)?;
                    }
                } else {
                    self.render_text_into(format!(" {content_type}"), into)?;
                }
            } else {
                self.render_text_into(format!(" {content_type}"), into)?;
            }
        }

        self.render_formatting_token_into(FormattingToken::Message, into)?;
        for content in message.content.iter() {
            // System and developer content are only valid in messages of the matching role.
            if let crate::chat::Content::SystemContent(_) = content {
                anyhow::ensure!(
                    message.author.role == crate::chat::Role::System,
                    "SystemContent may only appear in system messages, found in {:?}",
                    message.author.role
                );
            }
            if let crate::chat::Content::DeveloperContent(_) = content {
                anyhow::ensure!(
                    message.author.role == crate::chat::Role::Developer,
                    "DeveloperContent may only appear in developer messages, found in {:?}",
                    message.author.role
                );
            }
            Render::<Content>::render(self, content, into, render_options)?;
        }

        // An assistant message that has any recipient (including "all") closes with the
        // assistant-to-tool token. Everything else closes with the plain end token.
        if message.author.role == crate::chat::Role::Assistant && message.recipient.is_some() {
            self.render_formatting_token_into(FormattingToken::EndMessageAssistantToTool, into)?;
        } else {
            self.render_formatting_token_into(FormattingToken::EndMessage, into)?;
        }
        Ok(())
    }
}
555
556impl Render<Content> for HarmonyEncoding {
557 fn render<B>(
558 &self,
559 content: &Content,
560 into: &mut B,
561 render_options: Option<&RenderOptions>,
562 ) -> anyhow::Result<()>
563 where
564 B: Extend<Rank>,
565 {
566 match content {
567 Content::Text(text) => Render::<TextContent>::render(self, text, into, render_options),
568 Content::SystemContent(sys) => {
569 Render::<SystemContent>::render(self, sys, into, render_options)
570 }
571 Content::DeveloperContent(dev) => {
572 Render::<crate::chat::DeveloperContent>::render(self, dev, into, render_options)
573 }
574 }
575 }
576}
577
578impl Render<TextContent> for HarmonyEncoding {
579 fn render<B>(
580 &self,
581 text: &TextContent,
582 into: &mut B,
583 _render_options: Option<&RenderOptions>,
584 ) -> anyhow::Result<()>
585 where
586 B: Extend<Rank>,
587 {
588 self.render_text_into(&text.text, into)
589 }
590}
591
592impl Render<SystemContent> for HarmonyEncoding {
593 fn render<B>(
594 &self,
595 sys: &SystemContent,
596 into: &mut B,
597 render_options: Option<&RenderOptions>,
598 ) -> anyhow::Result<()>
599 where
600 B: Extend<Rank>,
601 {
602 let mut sections = Vec::<String>::new();
603
604 let mut top_section = Vec::<String>::new();
605 if let Some(model_id) = &sys.model_identity {
606 top_section.push(model_id.clone());
607 }
608 if let Some(knowledge_cutoff) = &sys.knowledge_cutoff {
609 top_section.push(format!("Knowledge cutoff: {knowledge_cutoff}"));
610 }
611 if let Some(conversation_start_date) = &sys.conversation_start_date {
612 top_section.push(format!("Current date: {conversation_start_date}"));
613 }
614 if !top_section.is_empty() {
615 sections.push(top_section.join("\n"));
616 }
617
618 let mut instructions_and_reasoning = Vec::<String>::new();
619 if let Some(effort) = sys.reasoning_effort {
620 let effort_str = match effort {
621 ReasoningEffort::Low => "low",
622 ReasoningEffort::Medium => "medium",
623 ReasoningEffort::High => "high",
624 };
625 instructions_and_reasoning.push(format!("Reasoning: {effort_str}"));
626 }
627 if !instructions_and_reasoning.is_empty() {
628 sections.push(instructions_and_reasoning.join("\n"));
629 }
630
631 if let Some(tools) = &sys.tools {
632 if !tools.is_empty() {
633 sections.push(Self::template_tools_section(tools));
634 }
635 }
636
637 if let Some(channel_config) = &sys.channel_config {
638 if !channel_config.valid_channels.is_empty() {
639 let channels_str = channel_config.valid_channels.join(", ");
640 let mut channels_header = format!("# Valid channels: {channels_str}.");
641 if channel_config.channel_required {
642 channels_header.push_str(" Channel must be included for every message.");
643 }
644 if render_options.is_some_and(|o| o.conversation_has_function_tools) {
645 channels_header.push('\n');
646 channels_header.push_str(
647 "Calls to these tools must go to the commentary channel: 'functions'.",
648 );
649 }
650 sections.push(channels_header);
651 }
652 }
653 let formatted = sections.join("\n\n");
654 self.render_text_into(&formatted, into)?;
655 Ok(())
656 }
657}
658
659impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
660 fn render<B>(
661 &self,
662 dev: &crate::chat::DeveloperContent,
663 into: &mut B,
664 _render_options: Option<&RenderOptions>,
665 ) -> anyhow::Result<()>
666 where
667 B: Extend<Rank>,
668 {
669 let mut sections = Vec::<String>::new();
670
671 if let Some(instr) = &dev.instructions {
672 sections.push("# Instructions".to_string());
673 sections.push(instr.clone());
674 }
675
676 if let Some(tools) = &dev.tools {
677 if !tools.is_empty() {
678 sections.push(Self::template_tools_section(tools));
679 }
680 }
681 let formatted = sections.join("\n\n");
682 self.render_text_into(&formatted, into)?;
683 Ok(())
684 }
685}
686
/// Incremental parser that turns completion tokens into `Message`s, one token at a time.
pub struct StreamableParser {
    /// Encoding that supplies the tokenizer and formatting tokens.
    encoding: HarmonyEncoding,
    /// Role used for the next parsed header. It is supplied to `new` and cleared once a
    /// header has been parsed.
    next_role: Option<Role>,
    /// Every token passed to `process`, in order.
    tokens: Vec<Rank>,
    /// Messages completed so far.
    messages: Vec<Message>,
    /// Current position within the message being parsed.
    state: StreamState,
    /// Token ids that end the current message's content; they are not stored as content.
    stop_tokens: HashSet<Rank>,
    /// Text decoded from the most recent content token(s).
    /// `None` while those tokens do not yet decode as UTF-8.
    last_content_delta: Option<String>,
    /// Content tokens that do not yet decode as UTF-8 on their own.
    undecoded_tokens: Vec<Rank>,
}
697
/// Where a `StreamableParser` is within the token stream.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum StreamState {
    /// Between messages. The next token must be the start token; end of stream is accepted.
    ExpectStart,
    /// Collecting header tokens until the message token arrives.
    Header {
        header_tokens: Vec<Rank>,
    },
    /// Inside message content, after the header has been parsed.
    Content {
        header: ParsedHeader,
        content_tokens: Vec<Rank>,
    },
}
709
710impl StreamableParser {
711 pub fn new(encoding: HarmonyEncoding, role: Option<Role>) -> anyhow::Result<Self> {
712 let stop_tokens = encoding.stop_tokens()?;
713 let (state, next_role) = match role {
714 Some(role) => (
715 StreamState::Header {
716 header_tokens: Vec::new(),
717 },
718 Some(role),
719 ),
720 None => (StreamState::ExpectStart, None),
721 };
722 Ok(Self {
723 encoding,
724 next_role,
725 tokens: Vec::new(),
726 messages: Vec::new(),
727 state,
728 stop_tokens,
729 last_content_delta: None,
730 undecoded_tokens: Vec::new(),
731 })
732 }
733
734 fn process_next(&mut self, token: Option<Rank>) -> anyhow::Result<&mut Self> {
735 if let Some(token) = token {
736 self.tokens.push(token);
737 }
738 let next_role_clone = self.next_role.clone();
739 match &mut self.state {
740 StreamState::ExpectStart => {
741 let start = self
742 .encoding
743 .render_formatting_token(FormattingToken::Start)?;
744 match token {
745 Some(token) if token == start => {
746 self.state = StreamState::Header {
747 header_tokens: Vec::new(),
748 };
749 }
750 Some(token) => {
751 anyhow::bail!(
752 "Unexpected token {} while expecting start token {}",
753 token,
754 start
755 );
756 }
757 None => {
758 }
760 }
761 }
762 StreamState::Header { header_tokens } => {
763 let msg_tok = self
764 .encoding
765 .render_formatting_token(FormattingToken::Message)?;
766 match token {
767 Some(token) if token == msg_tok => {
768 let header_tokens_cloned = header_tokens.clone();
769 let next_role_cloned = next_role_clone;
770 self.state = StreamState::ExpectStart;
771 let header =
772 self.parse_header_from_tokens(&header_tokens_cloned, next_role_cloned)?;
773 self.next_role = None;
774 self.state = StreamState::Content {
775 header,
776 content_tokens: Vec::new(),
777 };
778 }
779 Some(token) => {
780 header_tokens.push(token);
781 }
782 None => {
783 anyhow::bail!(
784 "Unexpected EOS while waiting for message header to complete"
785 );
786 }
787 }
788 }
789 StreamState::Content {
790 header,
791 content_tokens,
792 } => {
793 let is_eos = if let Some(token) = token {
794 if self.stop_tokens.contains(&token) {
795 true
796 } else {
797 self.undecoded_tokens.push(token);
798 match self
799 .encoding
800 .tokenizer()
801 .decode_utf8(&self.undecoded_tokens)
802 {
803 Ok(decoded) => {
804 content_tokens.extend(self.undecoded_tokens.iter().copied());
805 self.last_content_delta = Some(decoded);
806 self.undecoded_tokens.clear();
807 }
808 Err(_) => {
809 self.last_content_delta = None;
810 }
811 }
812 false
813 }
814 } else {
815 true
816 };
817 if is_eos {
818 let text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
819 let message = Message {
820 author: header.author.clone(),
821 recipient: header.recipient.clone(),
822 channel: header.channel.clone(),
823 content_type: header.content_type.clone(),
824 content: vec![Content::Text(TextContent { text })],
825 };
826 self.messages.push(message);
827 self.state = StreamState::ExpectStart;
828 self.last_content_delta = None;
829 self.undecoded_tokens.clear();
830 }
831 }
832 }
833 Ok(self)
834 }
835
836 pub fn process(&mut self, token: Rank) -> anyhow::Result<&mut Self> {
837 self.process_next(Some(token))
838 }
839
840 pub fn process_eos(&mut self) -> anyhow::Result<&mut Self> {
841 self.process_next(None)?;
842 Ok(self)
843 }
844
845 fn parse_header_from_tokens(
846 &self,
847 header_tokens: &[Rank],
848 role: Option<Role>,
849 ) -> anyhow::Result<ParsedHeader> {
850 let mut header_string = self
851 .encoding
852 .tokenizer()
853 .decode_utf8(header_tokens)
854 .context("could not decode header")?;
855
856 let mut channel: Option<String> = None;
857 if let Some(channel_marker) = self.encoding.mapped_format_token(FormattingToken::Channel) {
858 if let Some(idx) = header_string.find(channel_marker) {
859 let after_marker = &header_string[idx + channel_marker.len()..];
860 let channel_end = after_marker
861 .find(|c: char| c.is_whitespace() || c == '<')
862 .unwrap_or(after_marker.len());
863 let channel_value = &after_marker[..channel_end];
864 if channel_value.is_empty() {
865 anyhow::bail!("channel marker present but no channel value found in header");
866 }
867 channel = Some(channel_value.to_string());
868
869 let mut new_header = String::new();
870 new_header.push_str(&header_string[..idx]);
871 new_header.push_str(&after_marker[channel_end..]);
872 header_string = new_header;
873 }
874 }
875
876 header_string = header_string.trim().to_string();
877
878 if let Some(constrain_marker) = self
879 .encoding
880 .mapped_format_token(FormattingToken::ConstrainedFormat)
881 {
882 if header_string.contains(constrain_marker) {
883 header_string = header_string
884 .replace(constrain_marker, &format!(" {constrain_marker}"))
885 .trim()
886 .to_string();
887 }
888 }
889
890 let mut parts: Vec<&str> = header_string.split_ascii_whitespace().collect();
891
892 let mut role_str_opt: Option<String> = None;
893 let role = match role {
894 Some(r) => r,
895 None => {
896 let role_str = parts
897 .first()
898 .context("message header did not contain a role")?;
899 role_str_opt = Some((*role_str).to_string());
900 let parsed_role = Role::try_from(*role_str);
901 match parsed_role {
902 Ok(r) => r,
903 Err(_) => {
904 if parts.len() > 1 || (parts.len() == 1 && parts[0].starts_with("to=")) {
905 parts.remove(0);
906 Role::Tool
907 } else {
908 return Err(anyhow::anyhow!("Unknown role: {}", role_str));
909 }
910 }
911 }
912 }
913 };
914
915 if let Some(&first) = parts.first() {
916 if first == role.as_str() {
917 parts.remove(0);
918 }
919 }
920
921 let mut recipient: Option<String> = None;
922 let mut content_type: Option<String> = None;
923
924 if !parts.is_empty() {
925 let num_parts = parts.len();
926 let last_part = parts.pop().unwrap();
927
928 if let Some(stripped) = last_part.strip_prefix("to=") {
929 recipient = Some(stripped.to_string());
930 } else if num_parts == 1 {
931 recipient = Some(last_part.to_string());
932 } else {
933 content_type = Some(last_part.to_string());
934 if let Some(raw_recipient) = parts.pop() {
935 recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") {
936 Some(stripped.to_string())
937 } else {
938 Some(raw_recipient.to_string())
939 };
940 }
941 }
942 }
943 anyhow::ensure!(
944 parts.is_empty(),
945 "unexpected tokens remaining in message header: {:?}",
946 parts
947 );
948
949 let author = if role == Role::Tool {
950 let name = role_str_opt;
951 Author { role, name }
952 } else {
953 Author { role, name: None }
954 };
955 Ok(ParsedHeader {
956 author,
957 recipient,
958 channel,
959 content_type,
960 })
961 }
962
963 pub fn current_content(&self) -> anyhow::Result<String> {
964 match &self.state {
965 StreamState::Content { content_tokens, .. } => self
966 .encoding
967 .tokenizer()
968 .decode_utf8(content_tokens)
969 .map_err(|e| anyhow::anyhow!(e)),
970 _ => Ok(String::new()),
971 }
972 }
973
974 pub fn current_role(&self) -> Option<Role> {
975 match &self.state {
976 StreamState::Content { header, .. } => Some(header.author.role.clone()),
977 _ => self.next_role.clone(),
978 }
979 }
980
981 pub fn current_content_type(&self) -> Option<String> {
982 match &self.state {
983 StreamState::Content { header, .. } => header.content_type.clone(),
984 _ => None,
985 }
986 }
987
988 pub fn last_content_delta(&self) -> anyhow::Result<Option<String>> {
989 Ok(self.last_content_delta.clone())
990 }
991
992 pub fn into_messages(self) -> Vec<Message> {
993 self.messages
994 }
995
996 pub fn messages(&self) -> &[Message] {
997 &self.messages
998 }
999
1000 pub fn tokens(&self) -> &[Rank] {
1001 &self.tokens
1002 }
1003
1004 pub fn current_recipient(&self) -> Option<String> {
1005 match &self.state {
1006 StreamState::Content { header, .. } => header.recipient.clone(),
1007 _ => None,
1008 }
1009 }
1010
1011 pub fn current_channel(&self) -> Option<String> {
1012 match &self.state {
1013 StreamState::Content { header, .. } => header.channel.clone(),
1014 _ => None,
1015 }
1016 }
1017}
1018
/// Options controlling how whole conversations are rendered.
#[derive(Clone, Debug)]
pub struct RenderConversationConfig {
    /// When set, `analysis`-channel messages that precede the first `final`-channel
    /// message are left out. This only happens if the conversation's last assistant
    /// message is on the `final` channel.
    pub auto_drop_analysis: bool,
}

impl Default for RenderConversationConfig {
    /// Analysis dropping is enabled by default.
    fn default() -> Self {
        let auto_drop_analysis = true;
        RenderConversationConfig { auto_drop_analysis }
    }
}