use jieba_rs::{Jieba, TokenizeMode as JiebaTokenizeMode};
use rusqlite::ffi::{self, fts5_api, fts5_tokenizer_v2};
use rusqlite::types::ToSqlOutput;
use rusqlite::{Connection, params};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::ffi::{CStr, c_char, c_int, c_void};
use std::ptr;
use std::slice;
use std::str;
use std::sync::{Arc, Mutex, OnceLock, RwLock};
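/// Name of the SQLite table that stores user-defined custom dictionary words.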
pub const VULCAN_DICT_TABLE: &str = "_vulcan_dict";
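// Tokenizer name registered with FTS5, and the pointer type tag expected by `SELECT fts5(?1)`.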
const SQLITE_JIEBA_TOKENIZER_NAME: &CStr = c"jieba";
const SQLITE_FTS5_API_PTR_TYPE: &CStr = c"fts5_api_ptr";
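/// Tokenization strategy: plain whitespace/punctuation splitting or jieba Chinese word segmentation.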
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerMode {
#[default]
None,
Jieba,
}
impl TokenizerMode {
pub fn as_str(self) -> &'static str {
match self {
Self::None => "none",
Self::Jieba => "jieba",
}
}
#[allow(dead_code)]
pub fn parse(value: &str) -> Option<Self> {
match value.trim().to_ascii_lowercase().as_str() {
"" | "none" | "plain" | "default" => Some(Self::None),
"jieba" | "zh" | "zh_cn" => Some(Self::Jieba),
_ => None,
}
}
}
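/// A user-defined dictionary word together with its jieba frequency weight.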
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CustomWordEntry {
pub word: String,
pub weight: usize,
}
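/// Result of [`tokenize_text`]: the normalized input, the emitted tokens, and an FTS5 query built from them.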
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenizeOutput {
pub tokenizer_mode: String,
pub normalized_text: String,
pub tokens: Vec<String>,
pub fts_query: String,
}
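/// Outcome of a dictionary insert/update/delete, including the number of affected rows.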
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DictionaryMutationResult {
pub success: bool,
pub message: String,
pub affected_rows: u64,
}
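/// Outcome of listing the currently enabled custom dictionary words.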
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ListCustomWordsResult {
pub success: bool,
pub message: String,
pub words: Vec<CustomWordEntry>,
}
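/// Custom words shared between the Rust tokenize helpers and the FTS5 tokenizer instances.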
#[derive(Debug, Default)]
struct SharedDictionaryState {
custom_words: Vec<CustomWordEntry>,
}
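/// User data handed to SQLite at registration time; released by `sqlite_jieba_tokenizer_registration_destroy`.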
#[derive(Debug)]
struct RegisteredTokenizerContext {
connection_handle: usize,
shared_state: Arc<RwLock<SharedDictionaryState>>,
}
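/// Per-tokenizer state created by `xCreate` and reclaimed by `xDelete`.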
#[derive(Debug)]
struct JiebaTokenizerInstance {
shared_state: Arc<RwLock<SharedDictionaryState>>,
}
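/// A token plus its byte offsets in the original text, as required by the FTS5 token callback.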
#[derive(Debug, Clone, PartialEq, Eq)]
struct TokenSpan {
token: String,
start_byte: usize,
end_byte: usize,
}
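/// Process-wide map from database key to the dictionary state shared by its connections.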
fn shared_dictionary_registry() -> &'static Mutex<HashMap<String, Arc<RwLock<SharedDictionaryState>>>> {
static REGISTRY: OnceLock<Mutex<HashMap<String, Arc<RwLock<SharedDictionaryState>>>>> =
OnceLock::new();
REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
}
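/// Raw connection handles that already have the jieba tokenizer registered, to avoid re-registration.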
fn registered_connection_handles() -> &'static Mutex<HashSet<usize>> {
static REGISTRY: OnceLock<Mutex<HashSet<usize>>> = OnceLock::new();
REGISTRY.get_or_init(|| Mutex::new(HashSet::new()))
}
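/// Creates the custom dictionary table if it does not already exist.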
pub fn ensure_vulcan_dict_table(connection: &Connection) -> rusqlite::Result<()> {
connection.execute_batch(&format!(
"CREATE TABLE IF NOT EXISTS {table_name} (
word TEXT PRIMARY KEY,
weight INTEGER NOT NULL DEFAULT 1,
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);",
table_name = VULCAN_DICT_TABLE
))
}
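/// Registers the `jieba` FTS5 tokenizer on this connection (idempotent) and loads the
/// custom dictionary from the database into the shared in-memory state.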
pub fn ensure_jieba_tokenizer_registered(connection: &Connection) -> rusqlite::Result<()> {
ensure_vulcan_dict_table(connection)?;
let db_key = connection_registry_key(connection)?;
let shared_state = shared_dictionary_state_for_key(&db_key);
refresh_shared_dictionary_state(connection, &shared_state)?;
let connection_handle = sqlite_connection_handle(connection);
{
let registered = registered_connection_handles()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
if registered.contains(&connection_handle) {
return Ok(());
}
}
let fts_api = fetch_fts5_api(connection)?;
let registration_context = Box::new(RegisteredTokenizerContext {
connection_handle,
shared_state,
});
let registration_context_ptr = Box::into_raw(registration_context) as *mut c_void;
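    // FTS5 copies this method table during registration, so a stack-allocated value is sufficient here.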
let tokenizer = fts5_tokenizer_v2 {
iVersion: 2,
xCreate: Some(sqlite_jieba_tokenizer_create),
xDelete: Some(sqlite_jieba_tokenizer_delete),
xTokenize: Some(sqlite_jieba_tokenizer_tokenize),
};
    // xCreateTokenizer_v2 is only present when the fts5_api reports iVersion >= 3.
    let create = unsafe {
        if (*fts_api).iVersion >= 3 { (*fts_api).xCreateTokenizer_v2 } else { None }
    }
    .ok_or_else(|| {
        rusqlite::Error::SqliteFailure(
            ffi::Error::new(ffi::SQLITE_ERROR),
            Some("fts5 xCreateTokenizer_v2 unavailable / fts5 不支持 xCreateTokenizer_v2".to_string()),
        )
    })?;
let rc = unsafe {
create(
fts_api,
SQLITE_JIEBA_TOKENIZER_NAME.as_ptr(),
registration_context_ptr,
&tokenizer as *const fts5_tokenizer_v2 as *mut fts5_tokenizer_v2,
Some(sqlite_jieba_tokenizer_registration_destroy),
)
};
if rc != ffi::SQLITE_OK {
unsafe {
drop(Box::from_raw(
registration_context_ptr as *mut RegisteredTokenizerContext,
));
}
return Err(rusqlite::Error::SqliteFailure(
ffi::Error::new(rc),
Some("register jieba tokenizer failed / 注册 jieba tokenizer 失败".to_string()),
));
}
registered_connection_handles()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.insert(connection_handle);
Ok(())
}
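/// Inserts or updates a custom word, then refreshes the in-memory dictionary used by the tokenizer.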
pub fn upsert_custom_word(
connection: &Connection,
word: &str,
weight: usize,
) -> rusqlite::Result<DictionaryMutationResult> {
ensure_jieba_tokenizer_registered(connection)?;
let trimmed = word.trim();
if trimmed.is_empty() {
return Ok(DictionaryMutationResult {
success: false,
message: "custom word must not be empty / 自定义词不能为空".to_string(),
affected_rows: 0,
});
}
let affected_rows = connection.execute(
&format!(
"INSERT INTO {table_name} (word, weight, enabled, created_at, updated_at)
VALUES (?1, ?2, 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'), strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
ON CONFLICT(word) DO UPDATE SET
weight = excluded.weight,
enabled = 1,
updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')",
table_name = VULCAN_DICT_TABLE
),
params![trimmed, weight as i64],
)?;
refresh_registered_dictionary(connection)?;
Ok(DictionaryMutationResult {
success: true,
message: "custom word upserted / 自定义词已写入".to_string(),
affected_rows: affected_rows as u64,
})
}
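/// Deletes a custom word, then refreshes the in-memory dictionary used by the tokenizer.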
pub fn remove_custom_word(
connection: &Connection,
word: &str,
) -> rusqlite::Result<DictionaryMutationResult> {
ensure_jieba_tokenizer_registered(connection)?;
let trimmed = word.trim();
if trimmed.is_empty() {
return Ok(DictionaryMutationResult {
success: false,
message: "custom word must not be empty / 自定义词不能为空".to_string(),
affected_rows: 0,
});
}
let affected_rows = connection.execute(
&format!("DELETE FROM {table_name} WHERE word = ?1", table_name = VULCAN_DICT_TABLE),
params![trimmed],
)?;
refresh_registered_dictionary(connection)?;
Ok(DictionaryMutationResult {
success: true,
message: if affected_rows > 0 {
"custom word removed / 自定义词已删除".to_string()
} else {
"custom word not found / 自定义词不存在".to_string()
},
affected_rows: affected_rows as u64,
})
}
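/// Loads all enabled custom words from the dictionary table, clamping weights to at least 1.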
pub fn load_custom_words(connection: &Connection) -> rusqlite::Result<Vec<CustomWordEntry>> {
ensure_vulcan_dict_table(connection)?;
let mut statement = connection.prepare(&format!(
"SELECT word, weight
FROM {table_name}
WHERE enabled = 1
ORDER BY word ASC",
table_name = VULCAN_DICT_TABLE
))?;
let mut rows = statement.query([])?;
let mut entries = Vec::new();
while let Some(row) = rows.next()? {
entries.push(CustomWordEntry {
word: row.get::<_, String>(0)?,
weight: row.get::<_, i64>(1)?.max(1) as usize,
});
}
Ok(entries)
}
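/// Lists the enabled custom words together with a human-readable summary message.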
pub fn list_custom_words(connection: &Connection) -> rusqlite::Result<ListCustomWordsResult> {
let words = load_custom_words(connection)?;
Ok(ListCustomWordsResult {
success: true,
message: format!(
"listed {} custom words / 已列出 {} 个自定义词",
words.len(),
words.len()
),
words,
})
}
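/// Tokenizes `text` with the requested mode. In search mode jieba uses its finer-grained
/// segmentation and the resulting FTS query ORs the tokens; otherwise tokens are joined as phrases.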
pub fn tokenize_text(
connection: Option<&Connection>,
mode: TokenizerMode,
text: &str,
search_mode: bool,
) -> rusqlite::Result<TokenizeOutput> {
let normalized_text = normalize_text(text);
let tokens = match mode {
TokenizerMode::None => tokenize_plain(&normalized_text),
TokenizerMode::Jieba => tokenize_with_jieba(connection, &normalized_text, search_mode)?,
};
let fts_query = build_fts_query(&tokens, search_mode);
Ok(TokenizeOutput {
tokenizer_mode: mode.as_str().to_string(),
normalized_text,
tokens,
fts_query,
})
}
fn normalize_text(text: &str) -> String {
text.split_whitespace().collect::<Vec<_>>().join(" ")
}
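/// Splits on whitespace and ASCII punctuation, falling back to the whole text when nothing survives the split.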
fn tokenize_plain(text: &str) -> Vec<String> {
if text.is_empty() {
return Vec::new();
}
let split = text
.split(|ch: char| ch.is_whitespace() || ch.is_ascii_punctuation())
.filter(|part| !part.is_empty())
.map(|part| part.to_string())
.collect::<Vec<_>>();
if split.is_empty() {
vec![text.to_string()]
} else {
split
}
}
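/// Jieba segmentation; when a connection is supplied, its custom dictionary is applied first.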
fn tokenize_with_jieba(
connection: Option<&Connection>,
text: &str,
search_mode: bool,
) -> rusqlite::Result<Vec<String>> {
if text.is_empty() {
return Ok(Vec::new());
}
let custom_words = if let Some(connection) = connection {
ensure_jieba_tokenizer_registered(connection)?;
current_custom_words(connection)?
} else {
Vec::new()
};
Ok(jieba_token_spans(text, search_mode, &custom_words)
.into_iter()
.map(|span| span.token)
.collect())
}
fn build_fts_query(tokens: &[String], search_mode: bool) -> String {
tokens
.iter()
.filter(|token| !token.is_empty())
.map(|token| format!("\"{}\"", token.replace('"', "\"\"")))
.collect::<Vec<_>>()
.join(if search_mode { " OR " } else { " " })
}
fn current_custom_words(connection: &Connection) -> rusqlite::Result<Vec<CustomWordEntry>> {
let db_key = connection_registry_key(connection)?;
let shared_state = shared_dictionary_state_for_key(&db_key);
Ok(shared_state
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.custom_words
.clone())
}
fn refresh_registered_dictionary(connection: &Connection) -> rusqlite::Result<()> {
let db_key = connection_registry_key(connection)?;
let shared_state = shared_dictionary_state_for_key(&db_key);
refresh_shared_dictionary_state(connection, &shared_state)
}
fn shared_dictionary_state_for_key(db_key: &str) -> Arc<RwLock<SharedDictionaryState>> {
let mut registry = shared_dictionary_registry()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
registry
.entry(db_key.to_string())
.or_insert_with(|| Arc::new(RwLock::new(SharedDictionaryState::default())))
.clone()
}
fn refresh_shared_dictionary_state(
connection: &Connection,
shared_state: &Arc<RwLock<SharedDictionaryState>>,
) -> rusqlite::Result<()> {
let custom_words = load_custom_words(connection)?;
let mut guard = shared_state
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
guard.custom_words = custom_words;
Ok(())
}
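/// On-disk databases are keyed by path; in-memory databases fall back to the raw handle so they never share state.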
fn connection_registry_key(connection: &Connection) -> rusqlite::Result<String> {
let handle = sqlite_connection_handle(connection);
match connection.path() {
Some(path) if !path.trim().is_empty() => Ok(path.to_string()),
_ => Ok(format!(":memory:#{handle:x}")),
}
}
fn sqlite_connection_handle(connection: &Connection) -> usize {
unsafe { connection.handle() as usize }
}
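/// Obtains the `fts5_api` pointer via `SELECT fts5(?1)` using SQLite's pointer-passing interface.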
fn fetch_fts5_api(connection: &Connection) -> rusqlite::Result<*mut fts5_api> {
    // `p_ret` must be mutable: the fts5() SQL function writes the API pointer through
    // the bound "fts5_api_ptr" out-parameter.
    let mut p_ret: *mut fts5_api = ptr::null_mut();
    let ptr_arg = ToSqlOutput::Pointer((
        &mut p_ret as *mut *mut fts5_api as _,
        SQLITE_FTS5_API_PTR_TYPE,
        None,
    ));
    connection.query_row("SELECT fts5(?)", [ptr_arg], |_| Ok(()))?;
    if p_ret.is_null() {
        return Err(rusqlite::Error::SqliteFailure(
            ffi::Error::new(ffi::SQLITE_ERROR),
            Some("fts5() returned a null API pointer / fts5() 返回了空指针".to_string()),
        ));
    }
    Ok(p_ret)
}
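/// Segments `text` with jieba and maps its char offsets to byte offsets. A fresh `Jieba`
/// instance is built per call so that custom dictionary changes take effect immediately.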
fn jieba_token_spans(
text: &str,
search_mode: bool,
custom_words: &[CustomWordEntry],
) -> Vec<TokenSpan> {
if text.is_empty() {
return Vec::new();
}
let mut jieba = Jieba::new();
for entry in custom_words {
jieba.add_word(&entry.word, Some(entry.weight), None);
}
let char_to_byte = unicode_char_to_byte_offsets(text);
let mode = if search_mode {
JiebaTokenizeMode::Search
} else {
JiebaTokenizeMode::Default
};
jieba
.tokenize(text, mode, true)
.into_iter()
.filter_map(|token| {
let trimmed = token.word.trim();
if trimmed.is_empty() {
return None;
}
let start_byte = *char_to_byte.get(token.start)?;
let end_byte = *char_to_byte.get(token.end)?;
Some(TokenSpan {
token: trimmed.to_string(),
start_byte,
end_byte,
})
})
.collect()
}
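/// Byte offset of every char index in `text`, plus a trailing entry for the end of the string.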
fn unicode_char_to_byte_offsets(text: &str) -> Vec<usize> {
let mut offsets = text.char_indices().map(|(index, _)| index).collect::<Vec<_>>();
offsets.push(text.len());
offsets
}
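/// FTS5 `xCreate`: builds a tokenizer instance that shares the registration's dictionary state.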
unsafe extern "C" fn sqlite_jieba_tokenizer_create(
user_data: *mut c_void,
_args: *mut *const c_char,
_arg_count: c_int,
out_tokenizer: *mut *mut ffi::Fts5Tokenizer,
) -> c_int {
if user_data.is_null() || out_tokenizer.is_null() {
return ffi::SQLITE_MISUSE;
}
let context = unsafe { &*(user_data as *const RegisteredTokenizerContext) };
let tokenizer = Box::new(JiebaTokenizerInstance {
shared_state: Arc::clone(&context.shared_state),
});
unsafe {
*out_tokenizer = Box::into_raw(tokenizer) as *mut ffi::Fts5Tokenizer;
}
ffi::SQLITE_OK
}
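/// FTS5 `xDelete`: reclaims the boxed tokenizer instance.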
unsafe extern "C" fn sqlite_jieba_tokenizer_delete(tokenizer: *mut ffi::Fts5Tokenizer) {
if tokenizer.is_null() {
return;
}
unsafe {
drop(Box::from_raw(tokenizer as *mut JiebaTokenizerInstance));
}
}
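/// Destructor for the registration user data; also removes the connection handle from the registered set.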
unsafe extern "C" fn sqlite_jieba_tokenizer_registration_destroy(user_data: *mut c_void) {
if user_data.is_null() {
return;
}
let context = unsafe { Box::from_raw(user_data as *mut RegisteredTokenizerContext) };
registered_connection_handles()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.remove(&context.connection_handle);
}
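/// FTS5 `xTokenize` (v2): segments the UTF-8 input with jieba and reports each token with its
/// byte range; query and aux calls use jieba's search mode for finer-grained tokens.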
#[allow(non_snake_case)]
unsafe extern "C" fn sqlite_jieba_tokenizer_tokenize(
tokenizer: *mut ffi::Fts5Tokenizer,
token_context: *mut c_void,
flags: c_int,
text_ptr: *const c_char,
text_len: c_int,
_locale_ptr: *const c_char,
_locale_len: c_int,
token_callback: Option<
unsafe extern "C" fn(
pCtx: *mut c_void,
tflags: c_int,
pToken: *const c_char,
nToken: c_int,
iStart: c_int,
iEnd: c_int,
) -> c_int,
>,
) -> c_int {
    if tokenizer.is_null() || token_context.is_null() || text_len < 0 {
        return ffi::SQLITE_MISUSE;
    }
    // Never build a slice from a null pointer; treat a null, zero-length buffer as empty input.
    if text_ptr.is_null() {
        return if text_len == 0 { ffi::SQLITE_OK } else { ffi::SQLITE_MISUSE };
    }
let Some(token_callback) = token_callback else {
return ffi::SQLITE_MISUSE;
};
let tokenizer = unsafe { &*(tokenizer as *const JiebaTokenizerInstance) };
let shared_state = tokenizer
.shared_state
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let text_bytes = unsafe { slice::from_raw_parts(text_ptr as *const u8, text_len as usize) };
let Ok(text) = str::from_utf8(text_bytes) else {
return ffi::SQLITE_ERROR;
};
let search_mode = (flags & ffi::FTS5_TOKENIZE_QUERY) != 0 || (flags & ffi::FTS5_TOKENIZE_AUX) != 0;
let spans = jieba_token_spans(text, search_mode, &shared_state.custom_words);
for span in spans {
let token_bytes = span.token.as_bytes();
let rc = unsafe {
token_callback(
token_context,
0,
token_bytes.as_ptr() as *const c_char,
token_bytes.len() as c_int,
span.start_byte as c_int,
span.end_byte as c_int,
)
};
if rc != ffi::SQLITE_OK {
return rc;
}
}
ffi::SQLITE_OK
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn custom_words_affect_jieba_tokenization() -> rusqlite::Result<()> {
let connection = Connection::open_in_memory()?;
ensure_jieba_tokenizer_registered(&connection)?;
let before =
tokenize_text(Some(&connection), TokenizerMode::Jieba, "市民田-女士急匆匆", false)?;
assert!(!before.tokens.iter().any(|token| token == "田-女士"));
let mutation = upsert_custom_word(&connection, "田-女士", 42)?;
assert!(mutation.success);
let after =
tokenize_text(Some(&connection), TokenizerMode::Jieba, "市民田-女士急匆匆", false)?;
assert!(after.tokens.iter().any(|token| token == "田-女士"));
let removed = remove_custom_word(&connection, "田-女士")?;
assert!(removed.success);
Ok(())
}
#[test]
fn sqlite_fts_jieba_tokenizer_is_registered() -> rusqlite::Result<()> {
let connection = Connection::open_in_memory()?;
ensure_jieba_tokenizer_registered(&connection)?;
upsert_custom_word(&connection, "田-女士", 42)?;
connection.execute_batch(
"CREATE VIRTUAL TABLE IF NOT EXISTS mcp_memory_fts USING fts5(
content,
tokenize='jieba'
);",
)?;
connection.execute(
"INSERT INTO mcp_memory_fts (content) VALUES (?1)",
params!["市民田-女士急匆匆"],
)?;
connection.execute_batch(
"CREATE VIRTUAL TABLE IF NOT EXISTS mcp_memory_vocab USING fts5vocab(
mcp_memory_fts,
'instance'
);",
)?;
let count: i64 = connection.query_row(
"SELECT count(*) FROM mcp_memory_vocab WHERE term = ?1",
params!["田-女士"],
|row| row.get(0),
)?;
assert_eq!(count, 1);
Ok(())
}
}