harmony_protocol/
registry.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use crate::{
7    encoding::{FormattingToken, HarmonyEncoding},
8    tiktoken_ext,
9};
10
/// Identifies one of the named Harmony encoding configurations.
///
/// The textual form (see the `Display` / `FromStr` impls) is the variant name itself.
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
pub enum HarmonyEncodingName {
    /// The gpt-oss Harmony encoding; `load_harmony_encoding` builds it on top of the
    /// `O200kHarmony` tokenizer.
    HarmonyGptOss,
}
15
16impl std::fmt::Display for HarmonyEncodingName {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        write!(
19            f,
20            "{}",
21            match self {
22                HarmonyEncodingName::HarmonyGptOss => "HarmonyGptOss",
23            }
24        )
25    }
26}
27
28impl std::str::FromStr for HarmonyEncodingName {
29    type Err = anyhow::Error;
30    fn from_str(s: &str) -> Result<Self, Self::Err> {
31        match s {
32            "HarmonyGptOss" => Ok(HarmonyEncodingName::HarmonyGptOss),
33            _ => anyhow::bail!("Invalid HarmonyEncodingName: {}", s),
34        }
35    }
36}
37
38impl std::fmt::Debug for HarmonyEncodingName {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "{self}")
41    }
42}
43
44#[cfg(not(target_arch = "wasm32"))]
45pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<HarmonyEncoding> {
46    match name {
47        HarmonyEncodingName::HarmonyGptOss => {
48            let n_ctx = 1_048_576; // 2^20
49            let max_action_length = 524_288; // 2^19
50            let encoding_ext = tiktoken_ext::Encoding::O200kHarmony;
51            Ok(HarmonyEncoding {
52                name: name.to_string(),
53                n_ctx,
54                tokenizer: Arc::new(encoding_ext.load()?),
55                tokenizer_name: encoding_ext.name().to_owned(),
56                max_message_tokens: n_ctx - max_action_length,
57                max_action_length,
58                format_token_mapping: make_mapping([
59                    (FormattingToken::Start, "<|start|>"),
60                    (FormattingToken::Message, "<|message|>"),
61                    (FormattingToken::EndMessage, "<|end|>"),
62                    (FormattingToken::EndMessageDoneSampling, "<|return|>"),
63                    (FormattingToken::Refusal, "<|refusal|>"),
64                    (FormattingToken::ConstrainedFormat, "<|constrain|>"),
65                    (FormattingToken::Channel, "<|channel|>"),
66                    (FormattingToken::EndMessageAssistantToTool, "<|call|>"),
67                    (FormattingToken::BeginUntrusted, "<|untrusted|>"),
68                    (FormattingToken::EndUntrusted, "<|end_untrusted|>"),
69                ]),
70                stop_formatting_tokens: HashSet::from([
71                    FormattingToken::EndMessageDoneSampling,
72                    FormattingToken::EndMessageAssistantToTool,
73                    FormattingToken::EndMessage,
74                ]),
75                stop_formatting_tokens_for_assistant_actions: HashSet::from([
76                    FormattingToken::EndMessageDoneSampling,
77                    FormattingToken::EndMessageAssistantToTool,
78                ]),
79            })
80        }
81    }
82}
83
84#[cfg(target_arch = "wasm32")]
85pub async fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<HarmonyEncoding> {
86    match name {
87        HarmonyEncodingName::HarmonyGptOss => {
88            let n_ctx = 1_048_576; // 2^20
89            let max_action_length = 524_288; // 2^19
90            let encoding_ext = tiktoken_ext::Encoding::O200kHarmony;
91            Ok(HarmonyEncoding {
92                name: name.to_string(),
93                n_ctx,
94                tokenizer: Arc::new(encoding_ext.load().await?),
95                tokenizer_name: encoding_ext.name().to_owned(),
96                max_message_tokens: n_ctx - max_action_length,
97                max_action_length,
98                format_token_mapping: make_mapping([
99                    (FormattingToken::Start, "<|start|>"),
100                    (FormattingToken::Message, "<|message|>"),
101                    (FormattingToken::EndMessage, "<|end|>"),
102                    (FormattingToken::EndMessageDoneSampling, "<|return|>"),
103                    (FormattingToken::Refusal, "<|refusal|>"),
104                    (FormattingToken::ConstrainedFormat, "<|constrain|>"),
105                    (FormattingToken::Channel, "<|channel|>"),
106                    (FormattingToken::EndMessageAssistantToTool, "<|call|>"),
107                    (FormattingToken::BeginUntrusted, "<|untrusted|>"),
108                    (FormattingToken::EndUntrusted, "<|end_untrusted|>"),
109                ]),
110                stop_formatting_tokens: HashSet::from([
111                    FormattingToken::EndMessageDoneSampling,
112                    FormattingToken::EndMessageAssistantToTool,
113                    FormattingToken::EndMessage,
114                ]),
115                stop_formatting_tokens_for_assistant_actions: HashSet::from([
116                    FormattingToken::EndMessageDoneSampling,
117                    FormattingToken::EndMessageAssistantToTool,
118                ]),
119            })
120        }
121    }
122}
123
124fn make_mapping<I>(iter: I) -> HashMap<FormattingToken, String>
125where
126    I: IntoIterator<Item = (FormattingToken, &'static str)>,
127{
128    iter.into_iter().map(|(k, v)| (k, v.to_string())).collect()
129}