1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4};
5
6use crate::{
7 encoding::{FormattingToken, HarmonyEncoding},
8 tiktoken_ext,
9};
10
/// Identifies one of the harmony encodings that `load_harmony_encoding` can build.
///
/// `Debug` is implemented by hand further down in this file (it mirrors
/// `Display`), which is why it is absent from the derive list.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub enum HarmonyEncodingName {
    /// The encoding built on top of the `O200kHarmony` tokenizer.
    HarmonyGptOss,
}
15
16impl std::fmt::Display for HarmonyEncodingName {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 write!(
19 f,
20 "{}",
21 match self {
22 HarmonyEncodingName::HarmonyGptOss => "HarmonyGptOss",
23 }
24 )
25 }
26}
27
28impl std::str::FromStr for HarmonyEncodingName {
29 type Err = anyhow::Error;
30 fn from_str(s: &str) -> Result<Self, Self::Err> {
31 match s {
32 "HarmonyGptOss" => Ok(HarmonyEncodingName::HarmonyGptOss),
33 _ => anyhow::bail!("Invalid HarmonyEncodingName: {}", s),
34 }
35 }
36}
37
38impl std::fmt::Debug for HarmonyEncodingName {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(f, "{self}")
41 }
42}
43
44#[cfg(not(target_arch = "wasm32"))]
45pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<HarmonyEncoding> {
46 match name {
47 HarmonyEncodingName::HarmonyGptOss => {
48 let n_ctx = 1_048_576; let max_action_length = 524_288; let encoding_ext = tiktoken_ext::Encoding::O200kHarmony;
51 Ok(HarmonyEncoding {
52 name: name.to_string(),
53 n_ctx,
54 tokenizer: Arc::new(encoding_ext.load()?),
55 tokenizer_name: encoding_ext.name().to_owned(),
56 max_message_tokens: n_ctx - max_action_length,
57 max_action_length,
58 format_token_mapping: make_mapping([
59 (FormattingToken::Start, "<|start|>"),
60 (FormattingToken::Message, "<|message|>"),
61 (FormattingToken::EndMessage, "<|end|>"),
62 (FormattingToken::EndMessageDoneSampling, "<|return|>"),
63 (FormattingToken::Refusal, "<|refusal|>"),
64 (FormattingToken::ConstrainedFormat, "<|constrain|>"),
65 (FormattingToken::Channel, "<|channel|>"),
66 (FormattingToken::EndMessageAssistantToTool, "<|call|>"),
67 (FormattingToken::BeginUntrusted, "<|untrusted|>"),
68 (FormattingToken::EndUntrusted, "<|end_untrusted|>"),
69 ]),
70 stop_formatting_tokens: HashSet::from([
71 FormattingToken::EndMessageDoneSampling,
72 FormattingToken::EndMessageAssistantToTool,
73 FormattingToken::EndMessage,
74 ]),
75 stop_formatting_tokens_for_assistant_actions: HashSet::from([
76 FormattingToken::EndMessageDoneSampling,
77 FormattingToken::EndMessageAssistantToTool,
78 ]),
79 })
80 }
81 }
82}
83
/// Builds the [`HarmonyEncoding`] registered under `name`.
///
/// This is the `wasm32` variant: it is `async` because the tokenizer load is
/// awaited on that target.
///
/// # Errors
///
/// Propagates any error returned by `tiktoken_ext::Encoding::load` while
/// obtaining the tokenizer.
#[cfg(target_arch = "wasm32")]
pub async fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<HarmonyEncoding> {
    match name {
        HarmonyEncodingName::HarmonyGptOss => {
            let encoding_ext = tiktoken_ext::Encoding::O200kHarmony;

            // Size limits: the message budget is whatever remains of the
            // context once the action budget has been subtracted.
            let n_ctx = 1_048_576;
            let max_action_length = 524_288;
            let max_message_tokens = n_ctx - max_action_length;

            // The only fallible step; bail out before building anything else.
            let tokenizer = Arc::new(encoding_ext.load().await?);
            let tokenizer_name = encoding_ext.name().to_owned();

            // Literal text used for each formatting token.
            let format_token_mapping = make_mapping([
                (FormattingToken::Start, "<|start|>"),
                (FormattingToken::Message, "<|message|>"),
                (FormattingToken::EndMessage, "<|end|>"),
                (FormattingToken::EndMessageDoneSampling, "<|return|>"),
                (FormattingToken::Refusal, "<|refusal|>"),
                (FormattingToken::ConstrainedFormat, "<|constrain|>"),
                (FormattingToken::Channel, "<|channel|>"),
                (FormattingToken::EndMessageAssistantToTool, "<|call|>"),
                (FormattingToken::BeginUntrusted, "<|untrusted|>"),
                (FormattingToken::EndUntrusted, "<|end_untrusted|>"),
            ]);

            let stop_formatting_tokens = HashSet::from([
                FormattingToken::EndMessageDoneSampling,
                FormattingToken::EndMessageAssistantToTool,
                FormattingToken::EndMessage,
            ]);
            // Same set as above minus `EndMessage`.
            let stop_formatting_tokens_for_assistant_actions = HashSet::from([
                FormattingToken::EndMessageDoneSampling,
                FormattingToken::EndMessageAssistantToTool,
            ]);

            Ok(HarmonyEncoding {
                name: name.to_string(),
                n_ctx,
                tokenizer,
                tokenizer_name,
                max_message_tokens,
                max_action_length,
                format_token_mapping,
                stop_formatting_tokens,
                stop_formatting_tokens_for_assistant_actions,
            })
        }
    }
}
123
124fn make_mapping<I>(iter: I) -> HashMap<FormattingToken, String>
125where
126 I: IntoIterator<Item = (FormattingToken, &'static str)>,
127{
128 iter.into_iter().map(|(k, v)| (k, v.to_string())).collect()
129}