use std::cell::RefCell;
use std::path::Path;
use std::str::FromStr;
use magnus::prelude::*;
use magnus::{Error, RArray, RHash, Ruby, function, method};
use lindera::mode::Mode;
use lindera::segmenter::Segmenter;
use lindera::tokenizer::{Tokenizer, TokenizerBuilder};
use crate::dictionary::{RbDictionary, RbUserDictionary};
use crate::error::to_magnus_error;
use crate::token::RbToken;
use crate::util::rb_hash_to_json;
/// Ruby-visible wrapper around lindera's `TokenizerBuilder`, exposed to Ruby
/// as the `Lindera::TokenizerBuilder` class.
#[magnus::wrap(class = "Lindera::TokenizerBuilder", free_immediately, size)]
pub struct RbTokenizerBuilder {
    // RefCell provides interior mutability: magnus passes `&self` to bound
    // methods, but the builder's setter methods need `&mut` access.
    inner: RefCell<TokenizerBuilder>,
}
impl RbTokenizerBuilder {
fn new() -> Result<Self, Error> {
let ruby = Ruby::get().expect("Ruby runtime not initialized");
let inner = TokenizerBuilder::new().map_err(|err| {
to_magnus_error(&ruby, format!("Failed to create TokenizerBuilder: {err}"))
})?;
Ok(Self {
inner: RefCell::new(inner),
})
}
fn from_file(file_path: String) -> Result<Self, Error> {
let ruby = Ruby::get().expect("Ruby runtime not initialized");
let inner = TokenizerBuilder::from_file(Path::new(&file_path)).map_err(|err| {
to_magnus_error(&ruby, format!("Failed to load config from file: {err}"))
})?;
Ok(Self {
inner: RefCell::new(inner),
})
}
fn set_mode(&self, mode: String) -> Result<(), Error> {
let ruby = Ruby::get().expect("Ruby runtime not initialized");
let m = Mode::from_str(&mode)
.map_err(|err| to_magnus_error(&ruby, format!("Failed to create mode: {err}")))?;
self.inner.borrow_mut().set_segmenter_mode(&m);
Ok(())
}
fn set_dictionary(&self, path: String) {
self.inner.borrow_mut().set_segmenter_dictionary(&path);
}
fn set_user_dictionary(&self, uri: String) {
self.inner.borrow_mut().set_segmenter_user_dictionary(&uri);
}
fn set_keep_whitespace(&self, keep_whitespace: bool) {
self.inner
.borrow_mut()
.set_segmenter_keep_whitespace(keep_whitespace);
}
fn append_character_filter(&self, kind: String, args: Option<RHash>) -> Result<(), Error> {
let ruby = Ruby::get().expect("Ruby runtime not initialized");
let filter_args = if let Some(hash) = args {
rb_hash_to_json(&ruby, hash)?
} else {
serde_json::Value::Object(serde_json::Map::new())
};
self.inner
.borrow_mut()
.append_character_filter(&kind, &filter_args);
Ok(())
}
fn append_token_filter(&self, kind: String, args: Option<RHash>) -> Result<(), Error> {
let ruby = Ruby::get().expect("Ruby runtime not initialized");
let filter_args = if let Some(hash) = args {
rb_hash_to_json(&ruby, hash)?
} else {
serde_json::Value::Object(serde_json::Map::new())
};
self.inner
.borrow_mut()
.append_token_filter(&kind, &filter_args);
Ok(())
}
fn build(&self) -> Result<RbTokenizer, Error> {
let ruby = Ruby::get().expect("Ruby runtime not initialized");
let tokenizer =
self.inner.borrow().build().map_err(|err| {
to_magnus_error(&ruby, format!("Failed to build tokenizer: {err}"))
})?;
Ok(RbTokenizer { inner: tokenizer })
}
}
/// Ruby-visible wrapper around a built lindera `Tokenizer`, exposed to Ruby
/// as the `Lindera::Tokenizer` class.
#[magnus::wrap(class = "Lindera::Tokenizer", free_immediately, size)]
pub struct RbTokenizer {
    // Tokenization only needs `&self`, so no interior mutability is needed.
    inner: Tokenizer,
}
/// Backs `Lindera::Tokenizer.new(dictionary, mode, user_dictionary)`.
///
/// `mode` falls back to "normal" when nil; `user_dictionary` may be nil.
/// Returns a Ruby error if the mode string is not recognized.
fn tokenizer_new(
    dictionary: &RbDictionary,
    mode: Option<String>,
    user_dictionary: Option<&RbUserDictionary>,
) -> Result<RbTokenizer, Error> {
    let ruby = Ruby::get().expect("Ruby runtime not initialized");
    // Default to "normal" segmentation when no mode was supplied.
    let segmenter_mode = Mode::from_str(mode.as_deref().unwrap_or("normal"))
        .map_err(|err| to_magnus_error(&ruby, format!("Failed to create mode: {err}")))?;
    let segmenter = Segmenter::new(
        segmenter_mode,
        dictionary.inner.clone(),
        user_dictionary.map(|ud| ud.inner.clone()),
    );
    Ok(RbTokenizer {
        inner: Tokenizer::new(segmenter),
    })
}
impl RbTokenizer {
    /// Tokenizes `text` and returns a Ruby array of `Lindera::Token` objects.
    ///
    /// Returns a Ruby error if lindera fails to tokenize the input.
    fn tokenize(&self, text: String) -> Result<RArray, Error> {
        let ruby = Ruby::get().expect("Ruby runtime not initialized");
        let tokens = self
            .inner
            .tokenize(&text)
            .map_err(|err| to_magnus_error(&ruby, format!("Failed to tokenize text: {err}")))?;
        // Convert each lindera token to a Ruby value as we fill the array.
        let result = ruby.ary_new_capa(tokens.len());
        for token in tokens {
            result.push(ruby.into_value(RbToken::from_token(token)))?;
        }
        Ok(result)
    }

    /// Returns the `n` best tokenizations of `text` as a Ruby array of
    /// `[tokens, cost]` pairs.
    ///
    /// `unique` (default false) and `cost_threshold` are forwarded to
    /// lindera's `tokenize_nbest`. Returns a Ruby error on failure.
    fn tokenize_nbest(
        &self,
        text: String,
        n: usize,
        unique: Option<bool>,
        cost_threshold: Option<i64>,
    ) -> Result<RArray, Error> {
        let ruby = Ruby::get().expect("Ruby runtime not initialized");
        let candidates = self
            .inner
            .tokenize_nbest(&text, n, unique.unwrap_or(false), cost_threshold)
            .map_err(|err| {
                to_magnus_error(&ruby, format!("Failed to tokenize_nbest text: {err}"))
            })?;
        let outer = ruby.ary_new_capa(candidates.len());
        for (tokens, cost) in candidates {
            // Each candidate becomes a two-element array: [tokens, cost].
            let token_arr = ruby.ary_new_capa(tokens.len());
            for token in tokens {
                token_arr.push(ruby.into_value(RbToken::from_token(token)))?;
            }
            let entry = ruby.ary_new_capa(2);
            entry.push(token_arr)?;
            entry.push(cost)?;
            outer.push(entry)?;
        }
        Ok(outer)
    }
}
/// Registers the `TokenizerBuilder` and `Tokenizer` classes (and their
/// methods) under `module` in the Ruby runtime.
pub fn define(ruby: &Ruby, module: &magnus::RModule) -> Result<(), Error> {
    // Lindera::TokenizerBuilder
    let builder = module.define_class("TokenizerBuilder", ruby.class_object())?;
    builder.define_singleton_method("new", function!(RbTokenizerBuilder::new, 0))?;
    builder.define_singleton_method("from_file", function!(RbTokenizerBuilder::from_file, 1))?;
    builder.define_method("set_mode", method!(RbTokenizerBuilder::set_mode, 1))?;
    builder.define_method("set_dictionary", method!(RbTokenizerBuilder::set_dictionary, 1))?;
    builder.define_method(
        "set_user_dictionary",
        method!(RbTokenizerBuilder::set_user_dictionary, 1),
    )?;
    builder.define_method(
        "set_keep_whitespace",
        method!(RbTokenizerBuilder::set_keep_whitespace, 1),
    )?;
    builder.define_method(
        "append_character_filter",
        method!(RbTokenizerBuilder::append_character_filter, 2),
    )?;
    builder.define_method(
        "append_token_filter",
        method!(RbTokenizerBuilder::append_token_filter, 2),
    )?;
    builder.define_method("build", method!(RbTokenizerBuilder::build, 0))?;

    // Lindera::Tokenizer
    let tokenizer = module.define_class("Tokenizer", ruby.class_object())?;
    tokenizer.define_singleton_method("new", function!(tokenizer_new, 3))?;
    tokenizer.define_method("tokenize", method!(RbTokenizer::tokenize, 1))?;
    tokenizer.define_method("tokenize_nbest", method!(RbTokenizer::tokenize_nbest, 4))?;
    Ok(())
}