use dashmap::DashMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::RwLock;
/// An interned string handle: a 4-byte index into the process-wide symbol table.
///
/// Indices below `static_strings::STATIC_COUNT` name the predefined symbols;
/// all others are assigned on demand by [`sym`]. Equality, ordering and hashing
/// operate on the raw index only, so `Ord` reflects the order in which indices
/// were assigned, not the lexicographic order of the underlying text.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
#[repr(transparent)]
pub struct Sym(u32);

impl Sym {
    /// Rebuilds a symbol from a raw index previously obtained via [`Sym::as_raw`].
    ///
    /// # Safety
    ///
    /// `index` must be the raw value of a `Sym` created in the current process
    /// (by [`sym`] or one of the predefined constants). This is not checked:
    /// resolving a `Sym` whose index was never handed out makes [`Sym::as_str`]
    /// (and every trait impl built on it) panic, or return text other than the
    /// one the caller expects.
    #[inline(always)]
    pub const unsafe fn from_raw(index: u32) -> Self {
        Sym(index)
    }

    /// Returns the raw table index of this symbol.
    #[inline(always)]
    pub const fn as_raw(&self) -> u32 {
        self.0
    }

    /// Returns the interned text of this symbol.
    ///
    /// The returned slice lives for the rest of the program.
    ///
    /// # Panics
    ///
    /// Panics if the index is unknown to the global table, which can only
    /// happen for a `Sym` built via [`Sym::from_raw`] with an invalid index.
    #[inline]
    pub fn as_str(&self) -> &'static str {
        resolve(*self)
    }
}
/// Hashes only the raw index, which keeps `Hash` consistent with the derived `Eq`.
impl Hash for Sym {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Same bytes as hashing the inner `u32`: `u32::hash` forwards to `write_u32`.
        state.write_u32(self.0);
    }
}
/// Debug-prints as `Sym("text")`, showing the resolved text rather than the raw index.
impl fmt::Debug for Sym {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = self.as_str();
        write!(f, "Sym({text:?})")
    }
}
/// Displays the symbol's text.
///
/// Uses `Formatter::pad`, so width, fill, alignment and precision flags are
/// honoured exactly as they are for `str` (e.g. `format!("{:>8}", sym)`).
/// Plain `{}` output is just the text, as before.
impl fmt::Display for Sym {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// Serializes a symbol as its text, so the wire format is a plain string.
impl Serialize for Sym {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(self.as_str())
    }
}
impl<'de> Deserialize<'de> for Sym {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Ok(sym(&s))
}
}
impl AsRef<str> for Sym {
#[inline]
fn as_ref(&self) -> &str {
self.as_str()
}
}
/// Predefined symbols whose indices are fixed at compile time.
///
/// Position `i` of `STRINGS` holds the text of `Sym(i)`, so each constant
/// below must keep the same number as its entry's position in that table.
mod static_strings {
    use super::Sym;

    pub const QUESTION: Sym = Sym(0);
    pub const ANSWER: Sym = Sym(1);
    pub const CONTEXT: Sym = Sym(2);
    pub const THOUGHT: Sym = Sym(3);
    pub const ACTION: Sym = Sym(4);
    pub const OBSERVATION: Sym = Sym(5);
    pub const REASONING: Sym = Sym(6);
    pub const CODE: Sym = Sym(7);
    pub const RESULT: Sym = Sym(8);
    pub const INPUT: Sym = Sym(9);
    pub const OUTPUT: Sym = Sym(10);
    pub const QUERY: Sym = Sym(11);
    pub const PASSAGE: Sym = Sym(12);
    pub const DOCUMENT: Sym = Sym(13);
    pub const SUMMARY: Sym = Sym(14);
    pub const RATIONALE: Sym = Sym(15);
    pub const CLAIM: Sym = Sym(16);
    pub const EVIDENCE: Sym = Sym(17);
    pub const LABEL: Sym = Sym(18);
    pub const SCORE: Sym = Sym(19);
    pub const RESPONSE: Sym = Sym(20);
    pub const INSTRUCTION: Sym = Sym(21);
    pub const TOOL: Sym = Sym(22);
    pub const TOOL_INPUT: Sym = Sym(23);
    pub const TOOL_OUTPUT: Sym = Sym(24);
    pub const STEP: Sym = Sym(25);
    pub const FINAL_ANSWER: Sym = Sym(26);
    pub const EXPLANATION: Sym = Sym(27);
    pub const HYPOTHESIS: Sym = Sym(28);
    pub const CONCLUSION: Sym = Sym(29);
    pub const FEEDBACK: Sym = Sym(30);
    pub const ERROR: Sym = Sym(31);

    /// Number of predefined symbols; the dynamic interner numbers new symbols from here.
    pub const STATIC_COUNT: u32 = 32;

    /// Text of each predefined symbol, indexed by its raw `Sym` value.
    pub static STRINGS: [&str; STATIC_COUNT as usize] = [
        "question",     // 0
        "answer",       // 1
        "context",      // 2
        "thought",      // 3
        "action",       // 4
        "observation",  // 5
        "reasoning",    // 6
        "code",         // 7
        "result",       // 8
        "input",        // 9
        "output",       // 10
        "query",        // 11
        "passage",      // 12
        "document",     // 13
        "summary",      // 14
        "rationale",    // 15
        "claim",        // 16
        "evidence",     // 17
        "label",        // 18
        "score",        // 19
        "response",     // 20
        "instruction",  // 21
        "tool",         // 22
        "tool_input",   // 23
        "tool_output",  // 24
        "step",         // 25
        "final_answer", // 26
        "explanation",  // 27
        "hypothesis",   // 28
        "conclusion",   // 29
        "feedback",     // 30
        "error",        // 31
    ];
}
pub use static_strings::{
ACTION, ANSWER, CLAIM, CODE, CONCLUSION, CONTEXT, DOCUMENT, ERROR, EVIDENCE, EXPLANATION,
FEEDBACK, FINAL_ANSWER, HYPOTHESIS, INPUT, INSTRUCTION, LABEL, OBSERVATION, OUTPUT, PASSAGE,
QUERY, QUESTION, RATIONALE, REASONING, RESPONSE, RESULT, SCORE, STEP, SUMMARY, THOUGHT, TOOL,
TOOL_INPUT, TOOL_OUTPUT,
};
/// Backing store for dynamically interned symbols.
///
/// Dynamic symbols are numbered from `static_strings::STATIC_COUNT` upward.
/// Their text is leaked with `Box::leak`, so it can be handed out as
/// `&'static str` without any `unsafe` lifetime extension. Nothing in this
/// module ever removes an entry, and the instance used here lives in a
/// `static`, so these strings were never going to be freed anyway.
struct Interner {
    /// Reverse map used for deduplication: text -> symbol.
    string_to_sym: DashMap<Box<str>, Sym>,
    /// Forward table: element `i` is the text of `Sym(STATIC_COUNT + i)`.
    /// A slot whose index has been reserved but not yet filled holds `""`.
    dynamic_strings: RwLock<Vec<&'static str>>,
    /// Next raw index to hand out; starts at `STATIC_COUNT`.
    next_index: AtomicU32,
}

impl Interner {
    /// Creates an empty interner whose first dynamic symbol will be `STATIC_COUNT`.
    fn new() -> Self {
        Self {
            string_to_sym: DashMap::new(),
            dynamic_strings: RwLock::new(Vec::new()),
            next_index: AtomicU32::new(static_strings::STATIC_COUNT),
        }
    }

    /// Returns the symbol for `s`, assigning a fresh index if `s` has not been seen.
    ///
    /// The index is allocated inside `entry(..).or_insert_with`, whose closure
    /// runs only when the key is vacant. An index is therefore reserved only
    /// for a string that is actually being inserted.
    fn intern(&self, s: &str) -> Sym {
        if let Some(sym) = static_lookup(s) {
            return sym;
        }
        // Fast path: the string is already interned.
        if let Some(existing) = self.string_to_sym.get(s) {
            return *existing;
        }
        *self
            .string_to_sym
            .entry(Box::<str>::from(s))
            .or_insert_with(|| {
                // Relaxed suffices: the counter only has to hand out unique values.
                let index = self.next_index.fetch_add(1, Ordering::Relaxed);
                let text: &'static str = Box::leak(Box::<str>::from(s));
                let slot = (index - static_strings::STATIC_COUNT) as usize;
                let mut strings = self
                    .dynamic_strings
                    .write()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                // The index is reserved before the lock is taken, so another thread
                // may hold a lower index it has not stored yet: pad with placeholders.
                if strings.len() <= slot {
                    strings.resize(slot + 1, "");
                }
                strings[slot] = text;
                Sym(index)
            })
    }

    /// Returns the text of `sym`.
    ///
    /// # Panics
    ///
    /// Panics if `sym`'s dynamic slot does not exist in the table. This is only
    /// reachable through a `Sym` built with `Sym::from_raw`.
    fn resolve(&self, sym: Sym) -> &'static str {
        let index = sym.0;
        if index < static_strings::STATIC_COUNT {
            return static_strings::STRINGS[index as usize];
        }
        let slot = (index - static_strings::STATIC_COUNT) as usize;
        // The table stores `&'static str`, so the value is copied out and the read
        // guard is released at the end of this statement.
        let found = self
            .dynamic_strings
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .get(slot)
            .copied();
        let Some(text) = found else {
            panic!("Sym({index}) was not produced by this interner");
        };
        text
    }
}
#[inline]
fn static_lookup(s: &str) -> Option<Sym> {
match s {
"question" => Some(QUESTION),
"answer" => Some(ANSWER),
"context" => Some(CONTEXT),
"thought" => Some(THOUGHT),
"action" => Some(ACTION),
"observation" => Some(OBSERVATION),
"reasoning" => Some(REASONING),
"code" => Some(CODE),
"result" => Some(RESULT),
"input" => Some(INPUT),
"output" => Some(OUTPUT),
"query" => Some(QUERY),
"passage" => Some(PASSAGE),
"document" => Some(DOCUMENT),
"summary" => Some(SUMMARY),
"rationale" => Some(RATIONALE),
"claim" => Some(CLAIM),
"evidence" => Some(EVIDENCE),
"label" => Some(LABEL),
"score" => Some(SCORE),
"response" => Some(RESPONSE),
"instruction" => Some(INSTRUCTION),
"tool" => Some(TOOL),
"tool_input" => Some(TOOL_INPUT),
"tool_output" => Some(TOOL_OUTPUT),
"step" => Some(STEP),
"final_answer" => Some(FINAL_ANSWER),
"explanation" => Some(EXPLANATION),
"hypothesis" => Some(HYPOTHESIS),
"conclusion" => Some(CONCLUSION),
"feedback" => Some(FEEDBACK),
"error" => Some(ERROR),
_ => None,
}
}
/// Lazily created interner shared by the whole process.
static INTERNER: std::sync::OnceLock<Interner> = std::sync::OnceLock::new();

/// Returns the global [`Interner`], building it on first use.
fn get_interner() -> &'static Interner {
    std::sync::OnceLock::get_or_init(&INTERNER, Interner::new)
}
/// Interns `s` and returns its symbol.
///
/// Predefined names map straight to their constants without touching (or
/// lazily creating) the global interner. Every other string goes through it,
/// and interning the same text again yields the same `Sym`.
#[inline]
pub fn sym(s: &str) -> Sym {
    static_lookup(s).unwrap_or_else(|| get_interner().intern(s))
}
/// Returns the text behind `sym`.
///
/// Predefined symbols are answered from the static table. Dynamic ones are
/// looked up in the global interner.
///
/// # Panics
///
/// Panics if `sym` refers to a dynamic index the interner does not hold
/// (only reachable through [`Sym::from_raw`]).
#[inline]
pub fn resolve(sym: Sym) -> &'static str {
    match static_strings::STRINGS.get(sym.0 as usize) {
        Some(&text) => text,
        None => get_interner().resolve(sym),
    }
}
/// Returns how many symbol indices have been handed out so far.
///
/// The figure includes the `static_strings::STATIC_COUNT` predefined symbols,
/// because the counter starts there. It is a relaxed snapshot: concurrent
/// calls to [`sym`] may advance it at any time.
#[inline]
pub fn interned_count() -> u32 {
    let counter = &get_interner().next_index;
    counter.load(Ordering::Relaxed)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_static_symbols() {
        assert_eq!(sym("question"), QUESTION);
        assert_eq!(sym("answer"), ANSWER);
        assert_eq!(sym("context"), CONTEXT);
        assert_eq!(sym("reasoning"), REASONING);
    }

    #[test]
    fn test_static_resolve() {
        assert_eq!(QUESTION.as_str(), "question");
        assert_eq!(ANSWER.as_str(), "answer");
        assert_eq!(REASONING.as_str(), "reasoning");
    }

    #[test]
    fn test_dynamic_interning() {
        let s1 = sym("custom_field_1");
        let s2 = sym("custom_field_2");
        assert_ne!(s1, s2);
        assert_eq!(s1.as_str(), "custom_field_1");
        assert_eq!(s2.as_str(), "custom_field_2");
    }

    #[test]
    fn test_same_string_same_symbol() {
        let s1 = sym("repeated_field");
        let s2 = sym("repeated_field");
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_sym_size() {
        assert_eq!(std::mem::size_of::<Sym>(), 4);
    }

    #[test]
    fn test_sym_hash() {
        use std::collections::HashMap;
        let mut map = HashMap::new();
        map.insert(sym("key"), "value");
        assert_eq!(map.get(&sym("key")), Some(&"value"));
    }

    #[test]
    fn test_sym_ord() {
        assert!(QUESTION < ANSWER);
        assert!(ANSWER < CONTEXT);
    }

    #[test]
    fn test_sym_display() {
        assert_eq!(format!("{}", QUESTION), "question");
        assert_eq!(format!("{}", sym("custom")), "custom");
    }

    #[test]
    fn test_sym_debug() {
        assert_eq!(format!("{:?}", QUESTION), "Sym(\"question\")");
    }

    #[test]
    fn test_sym_serde() {
        let original = sym("test_field");
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, "\"test_field\"");
        let deserialized: Sym = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_interned_count() {
        let initial = interned_count();
        sym("count_test_1");
        sym("count_test_2");
        assert!(interned_count() >= initial + 2);
    }

    #[test]
    fn test_all_static_strings() {
        for (i, &s) in static_strings::STRINGS.iter().enumerate() {
            let sym = Sym(i as u32);
            assert_eq!(sym.as_str(), s);
            assert_eq!(super::sym(s), sym);
        }
    }

    /// `as_raw` followed by `from_raw` must give back an equal, resolvable symbol.
    #[test]
    fn test_raw_roundtrip() {
        let original = sym("raw_roundtrip_field");
        // SAFETY: the index comes straight from `as_raw` on a symbol created in
        // this process, which is exactly what `from_raw` requires.
        let rebuilt = unsafe { Sym::from_raw(original.as_raw()) };
        assert_eq!(original, rebuilt);
        assert_eq!(rebuilt.as_str(), "raw_roundtrip_field");
    }

    /// Threads interning the same new strings concurrently must agree on every
    /// symbol, and each symbol must resolve back to its own text.
    #[test]
    fn test_concurrent_interning_agrees() {
        use std::collections::HashSet;
        use std::thread;

        const THREADS: usize = 8;
        const NAMES: usize = 64;

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                thread::spawn(|| {
                    (0..NAMES)
                        .map(|i| sym(&format!("concurrent_field_{i}")))
                        .collect::<Vec<Sym>>()
                })
            })
            .collect();
        let results: Vec<Vec<Sym>> = handles
            .into_iter()
            .map(|handle| handle.join().expect("interning thread panicked"))
            .collect();

        for other in &results[1..] {
            assert_eq!(other, &results[0]);
        }
        let unique: HashSet<Sym> = results[0].iter().copied().collect();
        assert_eq!(unique.len(), NAMES);
        for (i, s) in results[0].iter().enumerate() {
            assert_eq!(s.as_str(), format!("concurrent_field_{i}"));
        }
    }
}